Ensure output.Write is not called outside of the stream decompressor function

This commit is contained in:
Kovid Goyal
2023-07-25 06:43:07 +05:30
parent 301f309444
commit b3819d3226
2 changed files with 67 additions and 8 deletions

View File

@@ -3,6 +3,7 @@
package utils
import (
"errors"
"fmt"
"io"
)
@@ -11,15 +12,30 @@ var _ = fmt.Print
type StreamDecompressor = func(chunk []byte, is_last bool) error
// pipe_reader wraps the read end of an io.Pipe so that consumers never
// observe a read that returns zero bytes together with a nil error.
type pipe_reader struct {
	pr *io.PipeReader
}

// Read fills b from the underlying pipe. A zero-length write on the other end
// of the pipe surfaces as a (0, nil) read; such reads are retried until data
// arrives or the pipe reports an error. When b is empty the pipe is not
// touched and (0, nil) is returned.
func (self *pipe_reader) Read(b []byte) (n int, err error) {
	if len(b) == 0 {
		return 0, nil
	}
	for {
		n, err = self.pr.Read(b)
		if n > 0 || err != nil {
			return n, err
		}
		// got (0, nil): go back and block on the pipe again
	}
}
// Wrap Go's awful decompressor routines to allow feeding them
// data in chunks. For example:
// sd, err := NewStreamDecompressor(zlib.NewReader, output)
// sd := NewStreamDecompressor(zlib.NewReader, output)
// sd(chunk, false)
// ...
// sd(last_chunk, true)
// after this call calling sd() further will just return io.EOF
// WARNING: output.Write() may be called from a different thread, possibly even after sd()
// returns. It will never be called after sd(is_last=True) returns, however.
// after this call, calling sd() further will just return io.EOF.
// To close the decompressor at any time, call sd(nil, true).
// Note: output.Write() may be called from a different thread, but only while the main thread is in sd()
func NewStreamDecompressor(constructor func(io.Reader) (io.ReadCloser, error), output io.Writer) StreamDecompressor {
if constructor == nil { // identity decompressor
var err error
@@ -40,14 +56,16 @@ func NewStreamDecompressor(constructor func(io.Reader) (io.ReadCloser, error), o
}
}
pipe_r, pipe_w := io.Pipe()
pr := pipe_reader{pr: pipe_r}
finished := make(chan error, 1)
finished_err := errors.New("finished")
go func() {
var err error
defer func() {
finished <- err
}()
var impl io.ReadCloser
impl, err = constructor(pipe_r)
impl, err = constructor(&pr)
if err != nil {
pipe_r.CloseWithError(err)
return
@@ -57,24 +75,44 @@ func NewStreamDecompressor(constructor func(io.Reader) (io.ReadCloser, error), o
if err == nil {
err = cerr
}
if err == nil {
err = finished_err
}
pipe_r.CloseWithError(err)
}()
var iter_err error
return func(chunk []byte, is_last bool) error {
if iter_err != nil {
if iter_err == finished_err {
iter_err = io.EOF
}
return iter_err
}
if len(chunk) > 0 {
_, iter_err = pipe_w.Write(chunk)
if iter_err != nil {
var n int
n, iter_err = pipe_w.Write(chunk)
if iter_err != nil && iter_err != finished_err {
return iter_err
}
if n < len(chunk) {
iter_err = io.ErrShortWrite
return iter_err
}
// wait for output to finish
if iter_err == nil {
// after a zero byte read, pipe_reader.Read() calls pipe_r.Read() again so
// we know it is either blocked waiting for a write to pipe_w or has finished
_, iter_err = pipe_w.Write(nil)
if iter_err != nil && iter_err != finished_err {
return iter_err
}
}
}
if is_last {
pipe_w.CloseWithError(io.EOF)
err := <-finished
if err != nil && err != io.EOF {
if err != nil && err != io.EOF && err != finished_err {
iter_err = err
return err
}

View File

@@ -33,4 +33,25 @@ func TestStreamDecompressor(t *testing.T) {
if !bytes.Equal(o.Bytes(), input) {
t.Fatalf("Roundtripping via zlib failed output (%d) != input (%d)", len(o.Bytes()), len(input))
}
o.Reset()
sd = NewStreamDecompressor(zlib.NewReader, &o)
err := sd([]byte("abcd"), true)
if err == nil {
t.Fatalf("Did not get an invalid header error from zlib")
}
o.Reset()
sd = NewStreamDecompressor(zlib.NewReader, &o)
err = sd(b.Bytes(), false)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(o.Bytes(), input) {
t.Fatalf("Roundtripping via zlib failed output (%d) != input (%d)", len(o.Bytes()), len(input))
}
err = sd([]byte("extra trailing data"), true)
if err == nil {
t.Fatalf("Did not get an invalid header error from zlib")
}
}