mirror of
https://github.com/kovidgoyal/kitty.git
synced 2026-07-03 11:12:30 +08:00
Ensure output.Write is not called outside of the stream decompressor function
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
@@ -11,15 +12,30 @@ var _ = fmt.Print
|
||||
|
||||
type StreamDecompressor = func(chunk []byte, is_last bool) error
|
||||
|
||||
type pipe_reader struct {
|
||||
pr *io.PipeReader
|
||||
}
|
||||
|
||||
func (self *pipe_reader) Read(b []byte) (n int, err error) {
|
||||
// ensure the decompressor code never gets a zero byte read with no error
|
||||
for len(b) > 0 {
|
||||
n, err = self.pr.Read(b)
|
||||
if err != nil || n > 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Wrap Go's awful decompressor routines to allow feeding them
|
||||
// data in chunks. For example:
|
||||
// sd, err := NewStreamDecompressor(zlib.NewReader, output)
|
||||
// sd := NewStreamDecompressor(zlib.NewReader, output)
|
||||
// sd(chunk, false)
|
||||
// ...
|
||||
// sd(last_chunk, true)
|
||||
// after this call calling sd() further will just return io.EOF
|
||||
// WARNING: output.Write() may be called from a different thread, possibly even after sd()
|
||||
// returns. It will never be called after sd(is_last=True) returns, however.
|
||||
// after this call, calling sd() further will just return io.EOF.
|
||||
// To close the decompressor at any time, call sd(nil, true).
|
||||
// Note: output.Write() may be called from a different thread, but only while the main thread is in sd()
|
||||
func NewStreamDecompressor(constructor func(io.Reader) (io.ReadCloser, error), output io.Writer) StreamDecompressor {
|
||||
if constructor == nil { // identity decompressor
|
||||
var err error
|
||||
@@ -40,14 +56,16 @@ func NewStreamDecompressor(constructor func(io.Reader) (io.ReadCloser, error), o
|
||||
}
|
||||
}
|
||||
pipe_r, pipe_w := io.Pipe()
|
||||
pr := pipe_reader{pr: pipe_r}
|
||||
finished := make(chan error, 1)
|
||||
finished_err := errors.New("finished")
|
||||
go func() {
|
||||
var err error
|
||||
defer func() {
|
||||
finished <- err
|
||||
}()
|
||||
var impl io.ReadCloser
|
||||
impl, err = constructor(pipe_r)
|
||||
impl, err = constructor(&pr)
|
||||
if err != nil {
|
||||
pipe_r.CloseWithError(err)
|
||||
return
|
||||
@@ -57,24 +75,44 @@ func NewStreamDecompressor(constructor func(io.Reader) (io.ReadCloser, error), o
|
||||
if err == nil {
|
||||
err = cerr
|
||||
}
|
||||
if err == nil {
|
||||
err = finished_err
|
||||
}
|
||||
pipe_r.CloseWithError(err)
|
||||
}()
|
||||
|
||||
var iter_err error
|
||||
return func(chunk []byte, is_last bool) error {
|
||||
if iter_err != nil {
|
||||
if iter_err == finished_err {
|
||||
iter_err = io.EOF
|
||||
}
|
||||
return iter_err
|
||||
}
|
||||
if len(chunk) > 0 {
|
||||
_, iter_err = pipe_w.Write(chunk)
|
||||
if iter_err != nil {
|
||||
var n int
|
||||
n, iter_err = pipe_w.Write(chunk)
|
||||
if iter_err != nil && iter_err != finished_err {
|
||||
return iter_err
|
||||
}
|
||||
if n < len(chunk) {
|
||||
iter_err = io.ErrShortWrite
|
||||
return iter_err
|
||||
}
|
||||
// wait for output to finish
|
||||
if iter_err == nil {
|
||||
// after a zero byte read, pipe_reader.Read() calls pipe_r.Read() again so
|
||||
// we know it is either blocked waiting for a write to pipe_w or has finished
|
||||
_, iter_err = pipe_w.Write(nil)
|
||||
if iter_err != nil && iter_err != finished_err {
|
||||
return iter_err
|
||||
}
|
||||
}
|
||||
}
|
||||
if is_last {
|
||||
pipe_w.CloseWithError(io.EOF)
|
||||
err := <-finished
|
||||
if err != nil && err != io.EOF {
|
||||
if err != nil && err != io.EOF && err != finished_err {
|
||||
iter_err = err
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -33,4 +33,25 @@ func TestStreamDecompressor(t *testing.T) {
|
||||
if !bytes.Equal(o.Bytes(), input) {
|
||||
t.Fatalf("Roundtripping via zlib failed output (%d) != input (%d)", len(o.Bytes()), len(input))
|
||||
}
|
||||
|
||||
o.Reset()
|
||||
sd = NewStreamDecompressor(zlib.NewReader, &o)
|
||||
err := sd([]byte("abcd"), true)
|
||||
if err == nil {
|
||||
t.Fatalf("Did not get an invalid header error from zlib")
|
||||
}
|
||||
|
||||
o.Reset()
|
||||
sd = NewStreamDecompressor(zlib.NewReader, &o)
|
||||
err = sd(b.Bytes(), false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(o.Bytes(), input) {
|
||||
t.Fatalf("Roundtripping via zlib failed output (%d) != input (%d)", len(o.Bytes()), len(input))
|
||||
}
|
||||
err = sd([]byte("extra trailing data"), true)
|
||||
if err == nil {
|
||||
t.Fatalf("Did not get an invalid header error from zlib")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user