Skip to content

Commit 4e1a285

Browse files
pyrcoyunzheng
andauthored
Be robust against invalid utf-8 byte sequences and surrogateescape them when en- or decoding (#144)
This commit also takes the opportunity to remove Python 2 string compatibility code. It will also remove the final left-over Python 2 compatibility in the test cases. Co-authored-by: Yun Zheng Hu <[email protected]>
1 parent 8d6fe37 commit 4e1a285

16 files changed

+133
-140
lines changed

flow/record/adapter/elastic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def record_to_document(self, record: Record, index: str) -> dict:
106106
}
107107

108108
if self.hash_record:
109-
document["_id"] = hashlib.md5(document["_source"].encode()).hexdigest()
109+
document["_id"] = hashlib.md5(document["_source"].encode(errors="surrogateescape")).hexdigest()
110110

111111
return document
112112

flow/record/adapter/line.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def write(self, rec: Record) -> None:
6969
for key, value in rdict.items():
7070
if rdict_types:
7171
key = f"{key} ({rdict_types[key]})"
72-
self.fp.write(fmt.format(key, value).encode())
72+
self.fp.write(fmt.format(key, value).encode(errors="surrogateescape"))
7373

7474
def flush(self) -> None:
7575
if self.fp:

flow/record/adapter/sqlite.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def read_table(self, table_name: str) -> Iterator[Record]:
187187
if value == 0:
188188
row[idx] = None
189189
elif isinstance(value, str):
190-
row[idx] = value.encode("utf-8")
190+
row[idx] = value.encode(errors="surrogateescape")
191191
yield descriptor_cls.init_from_dict(dict(zip(fnames, row)))
192192

193193
def __iter__(self) -> Iterator[Record]:

flow/record/adapter/text.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def write(self, rec):
4141
buf = self.format_spec.format_map(DefaultMissing(rec._asdict()))
4242
else:
4343
buf = repr(rec)
44-
self.fp.write(buf.encode() + b"\n")
44+
self.fp.write(buf.encode(errors="surrogateescape") + b"\n")
4545

4646
# because stdout is usually line buffered we force flush here if wanted
4747
if self.auto_flush:

flow/record/adapter/xlsx.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def sanitize_fieldvalues(values: Iterator[Any]) -> Iterator[Any]:
3636
elif isinstance(value, bytes):
3737
base64_encode = False
3838
try:
39-
new_value = 'b"' + value.decode() + '"'
39+
new_value = 'b"' + value.decode(errors="surrogateescape") + '"'
4040
if ILLEGAL_CHARACTERS_RE.search(new_value):
4141
base64_encode = True
4242
else:
@@ -142,7 +142,7 @@ def __iter__(self):
142142
if field_types[idx] == "bytes":
143143
if value[1] == '"': # If so, we know this is b""
144144
# Cut of the b" at the start and the trailing "
145-
value = value[2:-1].encode()
145+
value = value[2:-1].encode(errors="surrogateescape")
146146
else:
147147
# If not, we know it is base64 encoded (so we cut of the starting 'base64:')
148148
value = b64decode(value[7:])

flow/record/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161

6262
from collections import OrderedDict
6363

64-
from .utils import to_native_str, to_str
64+
from .utils import to_str
6565
from .whitelist import WHITELIST, WHITELIST_TREE
6666

6767
log = logging.getLogger(__package__)
@@ -513,7 +513,7 @@ def __init__(self, name: str, fields: Optional[Sequence[tuple[str, str]]] = None
513513
name, fields = parse_def(name)
514514

515515
self.name = name
516-
self._field_tuples = tuple([(to_native_str(k), to_str(v)) for k, v in fields])
516+
self._field_tuples = tuple([(to_str(k), to_str(v)) for k, v in fields])
517517
self.recordType = _generate_record_class(name, self._field_tuples)
518518
self.recordType._desc = self
519519

flow/record/fieldtypes/__init__.py

+2-27
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from flow.record.base import FieldType
2929

3030
RE_NORMALIZE_PATH = re.compile(r"[\\/]+")
31-
NATIVE_UNICODE = isinstance("", str)
3231

3332
UTC = timezone.utc
3433

@@ -207,10 +206,7 @@ def _pack(self):
207206
class string(string_type, FieldType):
208207
def __new__(cls, value):
209208
if isinstance(value, bytes_type):
210-
value = cls._decode(value, "utf-8")
211-
if isinstance(value, bytes_type):
212-
# Still bytes, so decoding failed (Python 2)
213-
return bytes(value)
209+
value = value.decode(errors="surrogateescape")
214210
return super().__new__(cls, value)
215211

216212
def _pack(self):
@@ -221,27 +217,6 @@ def __format__(self, spec):
221217
return defang(self)
222218
return str.__format__(self, spec)
223219

224-
@classmethod
225-
def _decode(cls, data, encoding):
226-
"""Decode a byte-string into a unicode-string.
227-
228-
Python 3: When `data` contains invalid unicode characters a `UnicodeDecodeError` is raised.
229-
Python 2: When `data` contains invalid unicode characters the original byte-string is returned.
230-
"""
231-
if NATIVE_UNICODE:
232-
# Raises exception on decode error
233-
return data.decode(encoding)
234-
try:
235-
return data.decode(encoding)
236-
except UnicodeDecodeError:
237-
# Fallback to bytes (Python 2 only)
238-
preview = data[:16].encode("hex_codec") + (".." if len(data) > 16 else "")
239-
warnings.warn(
240-
"Got binary data in string field (hex: {}). Compatibility is not guaranteed.".format(preview),
241-
RuntimeWarning,
242-
)
243-
return data
244-
245220

246221
# Alias for backwards compatibility
247222
wstring = string
@@ -278,7 +253,7 @@ def __new__(cls, *args, **kwargs):
278253
if len(args) == 1 and not kwargs:
279254
arg = args[0]
280255
if isinstance(arg, bytes_type):
281-
arg = arg.decode("utf-8")
256+
arg = arg.decode(errors="surrogateescape")
282257
if isinstance(arg, string_type):
283258
# If we are on Python 3.11 or newer, we can use fromisoformat() to parse the string (fast path)
284259
#

flow/record/fieldtypes/net/ipv4.py

-7
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import warnings
44

55
from flow.record import FieldType
6-
from flow.record.utils import to_native_str
76

87

98
def addr_long(s):
@@ -45,9 +44,6 @@ def __init__(self, addr, netmask=None):
4544
DeprecationWarning,
4645
stacklevel=5,
4746
)
48-
if isinstance(addr, type("")):
49-
addr = to_native_str(addr)
50-
5147
if not isinstance(addr, str):
5248
raise TypeError("Subnet() argument 1 must be string, not {}".format(type(addr).__name__))
5349

@@ -67,9 +63,6 @@ def __contains__(self, addr):
6763
if addr is None:
6864
return False
6965

70-
if isinstance(addr, type("")):
71-
addr = to_native_str(addr)
72-
7366
if isinstance(addr, str):
7467
addr = addr_long(addr)
7568

flow/record/jsonpacker.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,8 @@ def pack_obj(self, obj):
4747
serial["_recorddescriptor"] = obj._desc.identifier
4848

4949
for field_type, field_name in obj._desc.get_field_tuples():
50-
# PYTHON2: Because "bytes" are also "str" we have to handle this here
51-
if field_type == "bytes" and isinstance(serial[field_name], str):
52-
serial[field_name] = base64.b64encode(serial[field_name]).decode()
53-
5450
# Boolean field types should be cast to a bool instead of staying ints
55-
elif field_type == "boolean" and isinstance(serial[field_name], int):
51+
if field_type == "boolean" and isinstance(serial[field_name], int):
5652
serial[field_name] = bool(serial[field_name])
5753

5854
return serial

flow/record/utils.py

+18-22
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,10 @@
33
import base64
44
import os
55
import sys
6+
import warnings
67
from functools import wraps
78
from typing import BinaryIO, TextIO
89

9-
_native = str
10-
_unicode = type("")
11-
_bytes = type(b"")
12-
1310

1411
def get_stdout(binary: bool = False) -> TextIO | BinaryIO:
1512
"""Return the stdout stream as binary or text stream.
@@ -50,33 +47,32 @@ def is_stdout(fp: TextIO | BinaryIO) -> bool:
5047

5148
def to_bytes(value):
5249
"""Convert a value to a byte string."""
53-
if value is None or isinstance(value, _bytes):
50+
if value is None or isinstance(value, bytes):
5451
return value
55-
if isinstance(value, _unicode):
56-
return value.encode("utf-8")
57-
return _bytes(value)
52+
if isinstance(value, str):
53+
return value.encode(errors="surrogateescape")
54+
return bytes(value)
5855

5956

6057
def to_str(value):
6158
"""Convert a value to a unicode string."""
62-
if value is None or isinstance(value, _unicode):
59+
if value is None or isinstance(value, str):
6360
return value
64-
if isinstance(value, _bytes):
65-
return value.decode("utf-8")
66-
return _unicode(value)
61+
if isinstance(value, bytes):
62+
return value.decode(errors="surrogateescape")
63+
return str(value)
6764

6865

6966
def to_native_str(value):
70-
"""Convert a value to a native `str`."""
71-
if value is None or isinstance(value, _native):
72-
return value
73-
if isinstance(value, _unicode):
74-
# Python 2: unicode -> str
75-
return value.encode("utf-8")
76-
if isinstance(value, _bytes):
77-
# Python 3: bytes -> str
78-
return value.decode("utf-8")
79-
return _native(value)
67+
warnings.warn(
68+
(
69+
"The to_native_str() function is deprecated, "
70+
"this function will be removed in flow.record 3.20, "
71+
"use to_str() instead"
72+
),
73+
DeprecationWarning,
74+
)
75+
return to_str(value)
8076

8177

8278
def to_base64(value):

tests/test_adapter_line.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from io import BytesIO
2+
3+
from flow.record import RecordDescriptor
4+
from flow.record.adapter.line import LineWriter
5+
6+
7+
def test_line_writer_write_surrogateescape():
8+
output = BytesIO()
9+
10+
lw = LineWriter(
11+
path=output,
12+
fields="name",
13+
)
14+
15+
TestRecord = RecordDescriptor(
16+
"test/string",
17+
[
18+
("string", "name"),
19+
],
20+
)
21+
22+
# construct from 'bytes' but with invalid unicode bytes
23+
record = TestRecord(b"R\xc3\xa9\xeamy")
24+
lw.write(record)
25+
26+
output.seek(0)
27+
data = output.read()
28+
29+
assert data == b"--[ RECORD 1 ]--\nname = R\xc3\xa9\xeamy\n"

tests/test_adapter_text.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from io import BytesIO
2+
3+
from flow.record import RecordDescriptor
4+
from flow.record.adapter.text import TextWriter
5+
6+
7+
def test_text_writer_write_surrogateescape():
8+
output = BytesIO()
9+
10+
tw = TextWriter(
11+
path=output,
12+
)
13+
14+
TestRecord = RecordDescriptor(
15+
"test/string",
16+
[
17+
("string", "name"),
18+
],
19+
)
20+
21+
# construct from 'bytes' but with invalid unicode bytes
22+
record = TestRecord(b"R\xc3\xa9\xeamy")
23+
tw.write(record)
24+
25+
output.seek(0)
26+
data = output.read()
27+
28+
assert data == b"<test/string name='R\xc3\xa9\\udceamy'>\n"

tests/test_fieldtypes.py

+2-9
Original file line numberDiff line numberDiff line change
@@ -213,15 +213,8 @@ def test_string():
213213
assert r.name == "Rémy"
214214

215215
# construct from 'bytes' but with invalid unicode bytes
216-
if isinstance("", str):
217-
# Python 3
218-
with pytest.raises(UnicodeDecodeError):
219-
TestRecord(b"R\xc3\xa9\xeamy")
220-
else:
221-
# Python 2
222-
with pytest.warns(RuntimeWarning):
223-
r = TestRecord(b"R\xc3\xa9\xeamy")
224-
assert r.name
216+
r = TestRecord(b"R\xc3\xa9\xeamy")
217+
assert r.name == "Ré\udceamy"
225218

226219

227220
def test_wstring():

tests/test_json_packer.py

+20
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,23 @@ def test_record_pack_bool_regression() -> None:
9090

9191
# pack the json string back to a record and make sure it is the same as before
9292
assert packer.unpack(data) == record
93+
94+
95+
def test_record_pack_surrogateescape() -> None:
96+
TestRecord = RecordDescriptor(
97+
"test/string",
98+
[
99+
("string", "name"),
100+
],
101+
)
102+
103+
record = TestRecord(b"R\xc3\xa9\xeamy")
104+
packer = JsonRecordPacker()
105+
106+
data = packer.pack(record)
107+
108+
# pack to json string and check if the 3rd and 4th byte are properly surrogate escaped
109+
assert data.startswith('{"name": "R\\u00e9\\udceamy",')
110+
111+
# pack the json string back to a record and make sure it is the same as before
112+
assert packer.unpack(data) == record

tests/test_record.py

+25-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import importlib
2+
import inspect
23
import os
34
import sys
45
from unittest.mock import patch
@@ -27,8 +28,6 @@
2728
from flow.record.exceptions import RecordDescriptorError
2829
from flow.record.stream import RecordFieldRewriter
2930

30-
from . import utils_inspect as inspect
31-
3231

3332
def test_record_creation():
3433
TestRecord = RecordDescriptor(
@@ -288,8 +287,30 @@ def isatty():
288287
writer.write(record)
289288

290289
out, err = capsys.readouterr()
291-
modifier = "" if isinstance("", str) else "u"
292-
expected = "<test/a a_string={u}'hello' common={u}'world' a_count=10>\n".format(u=modifier)
290+
expected = "<test/a a_string='hello' common='world' a_count=10>\n"
291+
assert out == expected
292+
293+
294+
def test_record_printer_stdout_surrogateescape(capsys):
295+
Record = RecordDescriptor(
296+
"test/a",
297+
[
298+
("string", "name"),
299+
],
300+
)
301+
record = Record(b"R\xc3\xa9\xeamy")
302+
303+
# fake capsys to be a tty.
304+
def isatty():
305+
return True
306+
307+
capsys._capture.out.tmpfile.isatty = isatty
308+
309+
writer = RecordPrinter(getattr(sys.stdout, "buffer", sys.stdout))
310+
writer.write(record)
311+
312+
out, err = capsys.readouterr()
313+
expected = "<test/a name='Ré\\udceamy'>\n"
293314
assert out == expected
294315

295316

0 commit comments

Comments
 (0)