From b3819d32268930fd3d64ce95dd732e9a8c164f57 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Tue, 25 Jul 2023 06:43:07 +0530 Subject: [PATCH] Ensure output.Write is not called outside of the stream decompressor function --- tools/utils/stream_decompressor.go | 54 +++++++++++++++++++++---- tools/utils/stream_decompressor_test.go | 21 ++++++++++ 2 files changed, 67 insertions(+), 8 deletions(-) diff --git a/tools/utils/stream_decompressor.go b/tools/utils/stream_decompressor.go index 9552a8f74..23ba10876 100644 --- a/tools/utils/stream_decompressor.go +++ b/tools/utils/stream_decompressor.go @@ -3,6 +3,7 @@ package utils import ( + "errors" "fmt" "io" ) @@ -11,15 +12,30 @@ var _ = fmt.Print type StreamDecompressor = func(chunk []byte, is_last bool) error +type pipe_reader struct { + pr *io.PipeReader +} + +func (self *pipe_reader) Read(b []byte) (n int, err error) { + // ensure the decompressor code never gets a zero byte read with no error + for len(b) > 0 { + n, err = self.pr.Read(b) + if err != nil || n > 0 { + return + } + } + return +} + // Wrap Go's awful decompressor routines to allow feeding them // data in chunks. For example: -// sd, err := NewStreamDecompressor(zlib.NewReader, output) +// sd := NewStreamDecompressor(zlib.NewReader, output) // sd(chunk, false) // ... // sd(last_chunk, true) -// after this call calling sd() further will just return io.EOF -// WARNING: output.Write() may be called from a different thread, possibly even after sd() -// returns. It will never be called after sd(is_last=True) returns, however. +// after this call, calling sd() further will just return io.EOF. +// To close the decompressor at any time, call sd(nil, true). +// Note: output.Write() may be called from a different thread, but only while the main thread is in sd() func NewStreamDecompressor(constructor func(io.Reader) (io.ReadCloser, error), output io.Writer) StreamDecompressor { if constructor == nil { // identity decompressor var err error @@ -40,14 +56,16 @@ func NewStreamDecompressor(constructor func(io.Reader) (io.ReadCloser, error), o } } pipe_r, pipe_w := io.Pipe() + pr := pipe_reader{pr: pipe_r} finished := make(chan error, 1) + finished_err := errors.New("finished") go func() { var err error defer func() { finished <- err }() var impl io.ReadCloser - impl, err = constructor(pipe_r) + impl, err = constructor(&pr) if err != nil { pipe_r.CloseWithError(err) return @@ -57,24 +75,44 @@ func NewStreamDecompressor(constructor func(io.Reader) (io.ReadCloser, error), o if err == nil { err = cerr } + if err == nil { + err = finished_err + } pipe_r.CloseWithError(err) }() var iter_err error return func(chunk []byte, is_last bool) error { if iter_err != nil { + if iter_err == finished_err { + iter_err = io.EOF + } return iter_err } if len(chunk) > 0 { - _, iter_err = pipe_w.Write(chunk) - if iter_err != nil { + var n int + n, iter_err = pipe_w.Write(chunk) + if iter_err != nil && iter_err != finished_err { return iter_err } + if n < len(chunk) { + iter_err = io.ErrShortWrite + return iter_err + } + // wait for output to finish + if iter_err == nil { + // after a zero byte read, pipe_reader.Read() calls pipe_r.Read() again so + // we know it is either blocked waiting for a write to pipe_w or has finished + _, iter_err = pipe_w.Write(nil) + if iter_err != nil && iter_err != finished_err { + return iter_err + } + } } if is_last { pipe_w.CloseWithError(io.EOF) err := <-finished - if err != nil && err != io.EOF { + if err != nil && err != io.EOF && err != finished_err { iter_err = err return err } diff --git a/tools/utils/stream_decompressor_test.go b/tools/utils/stream_decompressor_test.go index 6473db4ea..1c3c29897 100644 --- a/tools/utils/stream_decompressor_test.go +++ b/tools/utils/stream_decompressor_test.go @@ -33,4 +33,25 @@ func TestStreamDecompressor(t *testing.T) { if !bytes.Equal(o.Bytes(), input) { t.Fatalf("Roundtripping via zlib failed output (%d) != input (%d)", len(o.Bytes()), len(input)) } + + o.Reset() + sd = NewStreamDecompressor(zlib.NewReader, &o) + err := sd([]byte("abcd"), true) + if err == nil { + t.Fatalf("Did not get an invalid header error from zlib") + } + + o.Reset() + sd = NewStreamDecompressor(zlib.NewReader, &o) + err = sd(b.Bytes(), false) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(o.Bytes(), input) { + t.Fatalf("Roundtripping via zlib failed output (%d) != input (%d)", len(o.Bytes()), len(input)) + } + err = sd([]byte("extra trailing data"), true) + if err == nil { + t.Fatalf("Did not get an invalid header error from zlib") + } }