Skip to content

Commit fe809a8

Browse files
authored
feat(core): JSON Serialization w/ Tensors (#16)
This PR introduces JSON serialization for tensor types: any MLC dataclass that contains Tensors can now be serialized to, and deserialized from, JSON.
1 parent 4f1f81e commit fe809a8

File tree

4 files changed

+129
-32
lines changed

4 files changed

+129
-32
lines changed

cpp/registry.h

+5-4
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,11 @@ Any CopyShallow(AnyView root);
2727
Any CopyDeep(AnyView root);
2828
Str DocToPythonScript(mlc::printer::Node node, mlc::printer::PrinterConfig cfg);
2929
UDict BuildInfo();
30-
Tensor TensorFromBytes(const StrObj *);
31-
Str TensorToBytes(const TensorObj *);
32-
Tensor TensorFromBase64(const StrObj *);
33-
Str TensorToBase64(const TensorObj *);
30+
31+
Str TensorToBytes(const TensorObj *src);
32+
Str TensorToBase64(const TensorObj *src);
33+
Tensor TensorFromBytes(AnyView any);
34+
Tensor TensorFromBase64(AnyView any);
3435

3536
struct DSOLibrary {
3637
~DSOLibrary() { Unload(); }

cpp/structure.cc

+77-20
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#include "mlc/base/traits_dtype.h"
21
#include <algorithm>
32
#include <cmath>
43
#include <cstdint>
@@ -290,7 +289,7 @@ static const std::array<uint8_t, 256> kBase64DecTable = []() {
290289

291290
Str Base64Encode(const uint8_t *data, int64_t len) {
292291
constexpr int BITS_PER_CHAR = 6;
293-
Str ret(::mlc::core::StrPad::Allocator::NewWithPad<uint8_t>(((len + 2) / 3) * 4, 0));
292+
Str ret(::mlc::core::StrPad::Allocator::NewWithPad<uint8_t>(((len + 2) / 3) * 4 + 1, 0));
294293
uint8_t *out = reinterpret_cast<uint8_t *>(ret.get()->::MLCStr::data);
295294
int64_t &out_len = ret.get()->::MLCStr::length;
296295
for (int64_t i = 0; i < len; i += 3) {
@@ -323,14 +322,16 @@ Str Base64Encode(const uint8_t *data, int64_t len) {
323322
}
324323
}
325324
}
325+
out[out_len] = '\0';
326326
return ret;
327327
}
328328

329329
Str Base64Decode(const uint8_t *data, int64_t len) {
330330
if (len % 4 != 0) {
331-
MLC_THROW(ValueError) << "Base64Decode: Input length not multiple of 4.";
331+
MLC_THROW(ValueError) << "Base64Decode: Input length not multiple of 4: length = " << len
332+
<< ", data = " << reinterpret_cast<const char *>(data);
332333
}
333-
Str ret(::mlc::core::StrPad::Allocator::NewWithPad<uint8_t>((len / 4) * 3, 0));
334+
Str ret(::mlc::core::StrPad::Allocator::NewWithPad<uint8_t>((len / 4) * 3 + 1, 0));
334335
uint8_t *out = reinterpret_cast<uint8_t *>(ret.get()->::MLCStr::data);
335336
int64_t &result_len = ret.get()->::MLCStr::length;
336337
for (int64_t i = 0; i < len; i += 4) {
@@ -359,6 +360,7 @@ Str Base64Decode(const uint8_t *data, int64_t len) {
359360
out[result_len++] = byte_val;
360361
}
361362
}
363+
out[result_len] = '\0';
362364
return ret;
363365
}
364366

@@ -1155,7 +1157,7 @@ Str TensorToBytes(const DLTensor *src) {
11551157
int64_t numel = ::mlc::core::ShapeToNumel(ndim, src->shape);
11561158
int32_t elem_size = ::mlc::base::DataTypeSize(src->dtype);
11571159
int64_t total_bytes = 8 + 4 + 4 + 8 * ndim + numel * elem_size;
1158-
Str ret(::mlc::core::StrPad::Allocator::NewWithPad<uint8_t>(total_bytes, total_bytes));
1160+
Str ret(::mlc::core::StrPad::Allocator::NewWithPad<uint8_t>(total_bytes + 1, total_bytes));
11591161
uint8_t *data_ptr = reinterpret_cast<uint8_t *>(ret->data());
11601162
int64_t tail = 0;
11611163
WriteElem<8>(data_ptr, &tail, static_cast<uint64_t>(kMLCTensorMagic));
@@ -1165,6 +1167,10 @@ Str TensorToBytes(const DLTensor *src) {
11651167
WriteElem<8>(data_ptr, &tail, src->shape[i]);
11661168
}
11671169
WriteElemMany(data_ptr, &tail, static_cast<uint8_t *>(src->data), elem_size, numel);
1170+
data_ptr[tail] = '\0';
1171+
if (tail != total_bytes) {
1172+
MLC_THROW(InternalError) << "SaveDLPack: Internal error in serialization.";
1173+
}
11681174
return ret;
11691175
}
11701176

@@ -1302,8 +1308,9 @@ inline mlc::Str Serialize(Any any) {
13021308
};
13031309

13041310
std::unordered_map<Object *, int32_t> topo_indices;
1311+
std::vector<TensorObj *> tensors;
13051312
std::ostringstream os;
1306-
auto on_visit = [&topo_indices, get_json_type_index = &get_json_type_index, os = &os,
1313+
auto on_visit = [&topo_indices, get_json_type_index = &get_json_type_index, os = &os, &tensors,
13071314
is_first_object = true](Object *object, MLCTypeInfo *type_info) mutable -> void {
13081315
int32_t &topo_index = topo_indices[object];
13091316
if (topo_index == 0) {
@@ -1331,10 +1338,11 @@ inline mlc::Str Serialize(Any any) {
13311338
emitter(nullptr, &kv.first);
13321339
emitter(nullptr, &kv.second);
13331340
}
1341+
} else if (TensorObj *tensor = object->TryCast<TensorObj>()) {
1342+
(*os) << ", " << tensors.size();
1343+
tensors.push_back(tensor);
13341344
} else if (object->IsInstance<FuncObj>() || object->IsInstance<ErrorObj>()) {
13351345
MLC_THROW(TypeError) << "Unserializable type: " << object->GetTypeKey();
1336-
} else if (object->IsInstance<TensorObj>()) {
1337-
// TODO: tensors
13381346
} else if (object->IsInstance<OpaqueObj>()) {
13391347
MLC_THROW(TypeError) << "Cannot serialize `mlc.Opaque` of type: " << object->Cast<OpaqueObj>()->opaque_type_name;
13401348
} else {
@@ -1375,11 +1383,24 @@ inline mlc::Str Serialize(Any any) {
13751383
}
13761384
os << '"' << type_keys[i] << '\"';
13771385
}
1378-
os << "]}";
1386+
os << "]";
1387+
if (!tensors.empty()) {
1388+
os << ", \"tensors\": [";
1389+
for (size_t i = 0; i < tensors.size(); ++i) {
1390+
if (i > 0) {
1391+
os << ", ";
1392+
}
1393+
Str b64 = tensors[i]->ToBase64();
1394+
os << '"' << b64->data() << '"';
1395+
}
1396+
os << "]";
1397+
}
1398+
os << "}";
13791399
return os.str();
13801400
}
13811401

13821402
inline Any Deserialize(const char *json_str, int64_t json_str_len) {
1403+
int32_t json_type_index_tensor = -1;
13831404
// Step 0. Parse JSON string
13841405
UDict json_obj = JSONLoads(json_str, json_str_len);
13851406
// Step 1. type_key => constructors
@@ -1388,7 +1409,12 @@ inline Any Deserialize(const char *json_str, int64_t json_str_len) {
13881409
constructors.reserve(type_keys.size());
13891410
for (Str type_key : type_keys) {
13901411
int32_t type_index = Lib::GetTypeIndex(type_key->data());
1391-
FuncObj *func = Lib::_init(type_index);
1412+
FuncObj *func = nullptr;
1413+
if (type_index != kMLCTensor) {
1414+
func = Lib::_init(type_index);
1415+
} else {
1416+
json_type_index_tensor = static_cast<int32_t>(constructors.size());
1417+
}
13921418
constructors.push_back(func);
13931419
}
13941420
auto invoke_init = [&constructors](UList args) {
@@ -1398,13 +1424,29 @@ inline Any Deserialize(const char *json_str, int64_t json_str_len) {
13981424
&ret);
13991425
return ret;
14001426
};
1401-
// Step 2. Translate JSON object to objects
1427+
// Step 2. Handle tensors
1428+
std::vector<Tensor> tensors;
1429+
if (json_obj->count("tensors")) {
1430+
UList tensors_b64 = json_obj->at("tensors");
1431+
while (!tensors_b64->empty()) {
1432+
Tensor tensor = Tensor::FromBase64(tensors_b64->back());
1433+
tensors.push_back(tensor);
1434+
tensors_b64->pop_back();
1435+
}
1436+
json_obj->erase("tensors");
1437+
std::reverse(tensors.begin(), tensors.end());
1438+
}
1439+
// Step 3. Translate JSON object to objects
14021440
UList values = json_obj->at("values");
14031441
for (int64_t i = 0; i < values->size(); ++i) {
14041442
Any obj = values[i];
14051443
if (obj.type_index == kMLCList) {
14061444
UList list = obj.operator UList();
14071445
int32_t json_type_index = list[0];
1446+
if (json_type_index == json_type_index_tensor) {
1447+
values[i] = tensors[list[1].operator int32_t()];
1448+
continue;
1449+
}
14081450
for (int64_t j = 1; j < list.size(); ++j) {
14091451
Any arg = list[j];
14101452
if (arg.type_index == kMLCInt) {
@@ -1487,23 +1529,38 @@ Any JSONDeserialize(AnyView json_str) {
14871529

14881530
Str JSONSerialize(AnyView source) { return ::mlc::Serialize(source); }
14891531

1490-
Tensor TensorFromBytes(const StrObj *src) {
1491-
return ::mlc::TensorFromBytes(reinterpret_cast<const uint8_t *>(src->::MLCStr::data), src->length());
1492-
}
1493-
14941532
Str TensorToBytes(const TensorObj *src) {
14951533
return ::mlc::TensorToBytes(&src->tensor); //
14961534
}
14971535

1498-
Tensor TensorFromBase64(const StrObj *src) {
1499-
Str bytes = ::mlc::Base64Decode(reinterpret_cast<const uint8_t *>(src->data()), src->size());
1500-
return ::mlc::TensorFromBytes(reinterpret_cast<const uint8_t *>(bytes->data()), bytes->size());
1501-
}
1502-
15031536
Str TensorToBase64(const TensorObj *src) {
15041537
Str bytes = ::mlc::TensorToBytes(&src->tensor);
15051538
return ::mlc::Base64Encode(reinterpret_cast<uint8_t *>(bytes->data()), bytes->size());
15061539
}
15071540

1541+
Tensor TensorFromBytes(AnyView any) {
1542+
if (any.type_index == kMLCRawStr) {
1543+
const char *src = any;
1544+
int64_t len = std::strlen(src);
1545+
return ::mlc::TensorFromBytes(reinterpret_cast<const uint8_t *>(src), len);
1546+
} else {
1547+
Str src = any;
1548+
return ::mlc::TensorFromBytes(reinterpret_cast<const uint8_t *>(src->::MLCStr::data), src->length());
1549+
}
1550+
}
1551+
1552+
Tensor TensorFromBase64(AnyView any) {
1553+
if (any.type_index == kMLCRawStr) {
1554+
const char *src = any;
1555+
int64_t len = std::strlen(src);
1556+
Str bytes = ::mlc::Base64Decode(reinterpret_cast<const uint8_t *>(src), len);
1557+
return ::mlc::TensorFromBytes(reinterpret_cast<const uint8_t *>(bytes->data()), bytes->size());
1558+
} else {
1559+
Str src = any;
1560+
Str bytes = ::mlc::Base64Decode(reinterpret_cast<const uint8_t *>(src->::MLCStr::data), src->length());
1561+
return ::mlc::TensorFromBytes(reinterpret_cast<const uint8_t *>(bytes->data()), bytes->size());
1562+
}
1563+
}
1564+
15081565
} // namespace registry
15091566
} // namespace mlc

include/mlc/core/tensor.h

+7-7
Original file line numberDiff line numberDiff line change
@@ -27,22 +27,22 @@ struct TensorObj : public MLCTensor {
2727
~TensorObj() { delete[] this->tensor.shape; }
2828

2929
Str ToBytes() const {
30-
static auto func = ::mlc::base::GetGlobalFuncCall<2>("mlc.core.TensorToBytes");
30+
static auto func = ::mlc::base::GetGlobalFuncCall<1>("mlc.core.TensorToBytes");
3131
return func({this});
3232
}
3333

3434
static Ref<TensorObj> FromBytes(const Str &source) {
35-
static auto func = ::mlc::base::GetGlobalFuncCall<2>("mlc.core.TensorFromBytes");
35+
static auto func = ::mlc::base::GetGlobalFuncCall<1>("mlc.core.TensorFromBytes");
3636
return func({source});
3737
}
3838

3939
Str ToBase64() const {
40-
static auto func = ::mlc::base::GetGlobalFuncCall<2>("mlc.core.TensorToBase64");
40+
static auto func = ::mlc::base::GetGlobalFuncCall<1>("mlc.core.TensorToBase64");
4141
return func({this});
4242
}
4343

4444
static Ref<TensorObj> FromBase64(const Str &source) {
45-
static auto func = ::mlc::base::GetGlobalFuncCall<2>("mlc.core.TensorFromBase64");
45+
static auto func = ::mlc::base::GetGlobalFuncCall<1>("mlc.core.TensorFromBase64");
4646
return func({source});
4747
}
4848

@@ -125,7 +125,7 @@ struct TensorObj : public MLCTensor {
125125
if (shape[i] == 0) {
126126
return true;
127127
}
128-
if (strides[i] != stride) {
128+
if (shape[i] > 1 && strides[i] != stride) {
129129
return false;
130130
}
131131
stride *= shape[i];
@@ -171,8 +171,8 @@ struct TensorObj::Allocator {
171171

172172
struct Tensor : public ObjectRef {
173173
explicit Tensor(DLManagedTensor *tensor) : ObjectRef(TensorObj::Allocator::New(tensor)) {}
174-
Tensor FromBytes(const Str &source) { return Tensor(TensorObj::FromBytes(source)); }
175-
Tensor FromBase64(const Str &source) { return Tensor(TensorObj::FromBase64(source)); }
174+
static Tensor FromBytes(const Str &source) { return Tensor(TensorObj::FromBytes(source)); }
175+
static Tensor FromBase64(const Str &source) { return Tensor(TensorObj::FromBase64(source)); }
176176

177177
const void *data() const { return this->get()->tensor.data; }
178178
DLDevice device() const { return this->get()->tensor.device; }

tests/python/test_core_tensor.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,28 @@ def test_opaque_from_torch_cuda(cxx_func: mlc.Func) -> None:
7676
assert torch.equal(a, b.torch())
7777

7878

79-
def test_tensor_serialize() -> None:
79+
def test_tensor_base64_int16() -> None:
8080
a = mlc.Tensor(np.arange(24, dtype=np.int16).reshape(2, 3, 4))
81+
assert (
82+
a.base64()
83+
== "P6G0lvBAXt0DAAAAABABAAIAAAAAAAAAAwAAAAAAAAAEAAAAAAAAAAAAAQACAAMABAAFAAYABwAIAAkACgALAAwADQAOAA8AEAARABIAEwAUABUAFgAXAA=="
84+
)
85+
b = mlc.Tensor.from_base64(a.base64())
86+
assert a.ndim == b.ndim
87+
assert a.shape == b.shape
88+
assert a.dtype == b.dtype
89+
assert a.device == b.device
90+
assert a.strides == b.strides
91+
assert a.byte_offset == b.byte_offset
92+
assert a.base64() == b.base64()
93+
94+
assert np.array_equal(a.numpy(), b.numpy())
95+
assert torch.equal(a.torch(), b.torch())
96+
97+
98+
def test_tensor_base64_float16() -> None:
99+
a = mlc.Tensor(np.array([3.0, 10.0, 20.0, 30.0, 35.50], dtype=np.float16))
100+
assert a.base64() == "P6G0lvBAXt0BAAAAAhABAAUAAAAAAAAAAEIASQBNgE9wUA=="
81101
b = mlc.Tensor.from_base64(a.base64())
82102
assert a.ndim == b.ndim
83103
assert a.shape == b.shape
@@ -89,3 +109,22 @@ def test_tensor_serialize() -> None:
89109

90110
assert np.array_equal(a.numpy(), b.numpy())
91111
assert torch.equal(a.torch(), b.torch())
112+
113+
114+
def test_torch_strides() -> None:
115+
a = torch.empty(4, 1, 6, 1, 10, dtype=torch.int16)
116+
a = torch.from_dlpack(torch.to_dlpack(a))
117+
b = mlc.Tensor(a)
118+
assert b.strides is None
119+
120+
121+
def test_tensor_serialize() -> None:
122+
a = mlc.Tensor(np.arange(24, dtype=np.int16).reshape(2, 3, 4))
123+
a_json = mlc.List([a, a]).json()
124+
b = mlc.List.from_json(a_json)
125+
assert isinstance(b, mlc.List)
126+
assert len(b) == 2
127+
assert isinstance(b[0], mlc.Tensor)
128+
assert isinstance(b[1], mlc.Tensor)
129+
assert b[0].eq_ptr(b[1])
130+
assert np.array_equal(a.numpy(), b[0].numpy())

0 commit comments

Comments
 (0)