Simplify streaming base64 decoder by using the streaming API of libbase64

This commit is contained in:
Kovid Goyal
2024-07-29 20:08:28 +05:30
parent eb1bb493a7
commit 4ba9fcaf37
4 changed files with 21 additions and 57 deletions

View File

@@ -279,10 +279,6 @@ class WriteRequest:
self.currently_writing_mime = mime
self.write_base64_data(data)
@property
def current_leftover_bytes(self) -> memoryview:
return self.decoder.leftover_bytes()
def flush_base64_data(self) -> None:
if self.currently_writing_mime:
self.decoder.flush()

View File

@@ -108,8 +108,8 @@ pybase64_decode(PyObject UNUSED *self, PyObject *args) {
typedef struct StreamingBase64Decoder {
PyObject_HEAD
PyObject *output;
size_t output_sz, output_capacity, num_leftover_bytes, initial_capacity;
unsigned char leftover_bytes[8];
size_t output_sz, output_capacity, initial_capacity;
struct base64_state state;
} StreamingBase64Decoder;
static int
@@ -119,6 +119,7 @@ StreamingBase64Decoder_init(PyObject *s, PyObject *args, PyObject *kwds) {
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|k", kwlist, &initial_capacity)) return -1;
StreamingBase64Decoder *self = (StreamingBase64Decoder*)s;
self->initial_capacity = initial_capacity;
base64_stream_decode_init(&self->state, 0);
return 0;
}
@@ -130,16 +131,17 @@ StreamingBase64Decoder_dealloc(PyObject *self) {
}
static bool
write_base64_data(StreamingBase64Decoder *self, const void *data, size_t len) {
write_base64_data(StreamingBase64Decoder *self, const void *data, const size_t len) {
if (!len) return true;
size_t sz = required_buffer_size_for_base64_decode(len);
if ((self->output_sz + sz) > self->output_capacity) {
size_t cap = MAX(self->output_capacity * 2, self->output_sz + sz + self->initial_capacity);
size_t cap = MAX(self->output_capacity * 2, MAX(self->output_sz + sz, self->initial_capacity));
if (self->output) { if (_PyBytes_Resize(&self->output, cap) != 0) return false; }
else { self->output = PyBytes_FromStringAndSize(NULL, cap); if (!self->output) return false; }
self->output_capacity = cap;
}
if (!base64_decode8(data, len, (unsigned char*)(PyBytes_AS_STRING(self->output) + self->output_sz), &sz)) {
bool ok = base64_stream_decode(&self->state, data, len, PyBytes_AS_STRING(self->output) + self->output_sz, &sz);
if (!ok) {
PyErr_SetString(PyExc_ValueError, "Invalid base64 input data");
return false;
}
@@ -147,55 +149,22 @@ write_base64_data(StreamingBase64Decoder *self, const void *data, size_t len) {
return true;
}
static bool
write_saving_leftover_bytes(StreamingBase64Decoder *self, const unsigned char *data, size_t len) {
size_t extra = len % 4;
if (!write_base64_data(self, data, len - extra)) return false;
self->num_leftover_bytes = extra;
if (extra) memcpy(self->leftover_bytes, data + len - extra, extra);
return true;
}
static PyObject*
StreamingBase64Decoder_add(StreamingBase64Decoder *self, PyObject *a) {
RAII_PY_BUFFER(data);
if (PyObject_GetBuffer(a, &data, PyBUF_SIMPLE) != 0) return NULL;
if (!data.buf || !data.len) return PyLong_FromLong(0);
unsigned char *d = data.buf; size_t dlen = data.len;
size_t before = self->output_sz;
if (self->num_leftover_bytes) {
size_t extra = 4 - self->num_leftover_bytes;
if (dlen >= extra) {
memcpy(self->leftover_bytes + self->num_leftover_bytes, d, extra);
if (!write_base64_data(self, self->leftover_bytes, self->num_leftover_bytes + extra)) return NULL;
self->num_leftover_bytes = 0;
d += extra; dlen -= extra;
if (!write_saving_leftover_bytes(self, d, dlen)) return NULL;
} else {
memcpy(self->leftover_bytes + self->num_leftover_bytes, d, dlen);
self->num_leftover_bytes += dlen;
}
} else if (!write_saving_leftover_bytes(self, d, dlen)) return NULL;
if (!write_base64_data(self, data.buf, data.len)) return NULL;
return PyLong_FromSize_t(self->output_sz - before);
}
static Py_ssize_t
StreamingBase64Decoder_len(PyObject *s) { return ((StreamingBase64Decoder*)s)->output_sz; }
static PyObject*
StreamingBase64Decoder_leftover_bytes(StreamingBase64Decoder *self, PyObject *a UNUSED) {
return PyMemoryView_FromMemory((char*)self->leftover_bytes, self->num_leftover_bytes, PyBUF_READ);
}
static PyObject*
StreamingBase64Decoder_flush(StreamingBase64Decoder *self, PyObject *args UNUSED) {
size_t padding = 4 - self->num_leftover_bytes;
switch(padding) {
case 1: self->leftover_bytes[self->num_leftover_bytes++] = '='; break;
case 2: self->leftover_bytes[self->num_leftover_bytes++] = '='; self->leftover_bytes[self->num_leftover_bytes++] = '='; break;
}
write_base64_data(self, self->leftover_bytes, self->num_leftover_bytes);
self->num_leftover_bytes = 0;
base64_stream_decode_init(&self->state, 0);
Py_RETURN_NONE;
}
@@ -207,11 +176,9 @@ StreamingBase64Decoder_copy_output(StreamingBase64Decoder *self, PyObject *args
static PyObject*
StreamingBase64Decoder_take_output(StreamingBase64Decoder *self, PyObject *args UNUSED) {
if (!self->output_sz) return PyBytes_FromStringAndSize(NULL, 0);
RAII_PyObject(newbuf, PyBytes_FromStringAndSize(NULL, self->initial_capacity));
if (!newbuf) return NULL;
if (_PyBytes_Resize(&self->output, self->output_sz) != 0) return NULL;
PyObject *ans = self->output;
self->output = Py_NewRef(newbuf); self->output_sz = 0; self->output_capacity = self->initial_capacity;
self->output = NULL; self->output_sz = 0; self->output_capacity = 0;
return ans;
}
@@ -227,7 +194,6 @@ static PyTypeObject StreamingBase64Decoder_Type = {
{"flush", (PyCFunction)StreamingBase64Decoder_flush, METH_NOARGS, ""},
{"take_output", (PyCFunction)StreamingBase64Decoder_take_output, METH_NOARGS, ""},
{"copy_output", (PyCFunction)StreamingBase64Decoder_copy_output, METH_NOARGS, ""},
{"leftover_bytes", (PyCFunction)StreamingBase64Decoder_leftover_bytes, METH_NOARGS, ""},
{NULL, NULL, 0, NULL},
},
.tp_new = PyType_GenericNew,

View File

@@ -1711,7 +1711,6 @@ class StreamingBase64Decoder:
def take_output(self) -> bytes: ... # take the output so far. The decoder no longer references this output
def copy_output(self) -> bytes: ... # copy the output so far
def __len__(self) -> int: ... # return the length of the current output
def leftover_bytes(self) -> memoryview: ... # return the currently leftover bytes that will be consumed by flush()
class DiskCache:

View File

@@ -10,14 +10,17 @@ from . import BaseTest
class TestClipboard(BaseTest):
def test_clipboard_write_request(self):
wr = WriteRequest(max_size=64)
wr.add_base64_data('bGlnaHQgd29yaw')
self.ae(bytes(wr.current_leftover_bytes), b'aw')
wr.flush_base64_data()
self.ae(wr.data_for(), b'light work')
wr = WriteRequest(max_size=64)
wr.add_base64_data('bGlnaHQgd29yaw==')
self.ae(wr.data_for(), b'light work')
def t(data, expected):
wr = WriteRequest(max_size=64)
wr.add_base64_data(data)
self.ae(wr.data_for(), expected)
t('dGl0bGU=', b'title')
t('dGl0bGU', b'title')
t('dGl0bG', b'titl')
t('dGl0bG==', b'titl')
t('dGl0b', b'tit')
t('bGlnaHQgd29yaw', b'light work')
t('bGlnaHQgd29yaw==', b'light work')
wr = WriteRequest(max_size=64)
wr.add_base64_data('bGlnaHQgd29')
for x in b'y', b'a', b'y', b'4', b'=':