1
- #include " mlc/base/traits_dtype.h"
2
1
#include < algorithm>
3
2
#include < cmath>
4
3
#include < cstdint>
@@ -290,7 +289,7 @@ static const std::array<uint8_t, 256> kBase64DecTable = []() {
290
289
291
290
Str Base64Encode (const uint8_t *data, int64_t len) {
292
291
constexpr int BITS_PER_CHAR = 6 ;
293
- Str ret (::mlc::core::StrPad::Allocator::NewWithPad<uint8_t >(((len + 2 ) / 3 ) * 4 , 0 ));
292
+ Str ret (::mlc::core::StrPad::Allocator::NewWithPad<uint8_t >(((len + 2 ) / 3 ) * 4 + 1 , 0 ));
294
293
uint8_t *out = reinterpret_cast <uint8_t *>(ret.get ()->::MLCStr::data);
295
294
int64_t &out_len = ret.get ()->::MLCStr::length;
296
295
for (int64_t i = 0 ; i < len; i += 3 ) {
@@ -323,14 +322,16 @@ Str Base64Encode(const uint8_t *data, int64_t len) {
323
322
}
324
323
}
325
324
}
325
+ out[out_len] = ' \0 ' ;
326
326
return ret;
327
327
}
328
328
329
329
Str Base64Decode (const uint8_t *data, int64_t len) {
330
330
if (len % 4 != 0 ) {
331
- MLC_THROW (ValueError) << " Base64Decode: Input length not multiple of 4." ;
331
+ MLC_THROW (ValueError) << " Base64Decode: Input length not multiple of 4: length = " << len
332
+ << " , data = " << reinterpret_cast <const char *>(data);
332
333
}
333
- Str ret (::mlc::core::StrPad::Allocator::NewWithPad<uint8_t >((len / 4 ) * 3 , 0 ));
334
+ Str ret (::mlc::core::StrPad::Allocator::NewWithPad<uint8_t >((len / 4 ) * 3 + 1 , 0 ));
334
335
uint8_t *out = reinterpret_cast <uint8_t *>(ret.get ()->::MLCStr::data);
335
336
int64_t &result_len = ret.get ()->::MLCStr::length;
336
337
for (int64_t i = 0 ; i < len; i += 4 ) {
@@ -359,6 +360,7 @@ Str Base64Decode(const uint8_t *data, int64_t len) {
359
360
out[result_len++] = byte_val;
360
361
}
361
362
}
363
+ out[result_len] = ' \0 ' ;
362
364
return ret;
363
365
}
364
366
@@ -1155,7 +1157,7 @@ Str TensorToBytes(const DLTensor *src) {
1155
1157
int64_t numel = ::mlc::core::ShapeToNumel (ndim, src->shape );
1156
1158
int32_t elem_size = ::mlc::base::DataTypeSize (src->dtype );
1157
1159
int64_t total_bytes = 8 + 4 + 4 + 8 * ndim + numel * elem_size;
1158
- Str ret (::mlc::core::StrPad::Allocator::NewWithPad<uint8_t >(total_bytes, total_bytes));
1160
+ Str ret (::mlc::core::StrPad::Allocator::NewWithPad<uint8_t >(total_bytes + 1 , total_bytes));
1159
1161
uint8_t *data_ptr = reinterpret_cast <uint8_t *>(ret->data ());
1160
1162
int64_t tail = 0 ;
1161
1163
WriteElem<8 >(data_ptr, &tail, static_cast <uint64_t >(kMLCTensorMagic ));
@@ -1165,6 +1167,10 @@ Str TensorToBytes(const DLTensor *src) {
1165
1167
WriteElem<8 >(data_ptr, &tail, src->shape [i]);
1166
1168
}
1167
1169
WriteElemMany (data_ptr, &tail, static_cast <uint8_t *>(src->data ), elem_size, numel);
1170
+ data_ptr[tail] = ' \0 ' ;
1171
+ if (tail != total_bytes) {
1172
+ MLC_THROW (InternalError) << " SaveDLPack: Internal error in serialization." ;
1173
+ }
1168
1174
return ret;
1169
1175
}
1170
1176
@@ -1302,8 +1308,9 @@ inline mlc::Str Serialize(Any any) {
1302
1308
};
1303
1309
1304
1310
std::unordered_map<Object *, int32_t > topo_indices;
1311
+ std::vector<TensorObj *> tensors;
1305
1312
std::ostringstream os;
1306
- auto on_visit = [&topo_indices, get_json_type_index = &get_json_type_index, os = &os,
1313
+ auto on_visit = [&topo_indices, get_json_type_index = &get_json_type_index, os = &os, &tensors,
1307
1314
is_first_object = true ](Object *object, MLCTypeInfo *type_info) mutable -> void {
1308
1315
int32_t &topo_index = topo_indices[object];
1309
1316
if (topo_index == 0 ) {
@@ -1331,10 +1338,11 @@ inline mlc::Str Serialize(Any any) {
1331
1338
emitter (nullptr , &kv.first );
1332
1339
emitter (nullptr , &kv.second );
1333
1340
}
1341
+ } else if (TensorObj *tensor = object->TryCast <TensorObj>()) {
1342
+ (*os) << " , " << tensors.size ();
1343
+ tensors.push_back (tensor);
1334
1344
} else if (object->IsInstance <FuncObj>() || object->IsInstance <ErrorObj>()) {
1335
1345
MLC_THROW (TypeError) << " Unserializable type: " << object->GetTypeKey ();
1336
- } else if (object->IsInstance <TensorObj>()) {
1337
- // TODO: tensors
1338
1346
} else if (object->IsInstance <OpaqueObj>()) {
1339
1347
MLC_THROW (TypeError) << " Cannot serialize `mlc.Opaque` of type: " << object->Cast <OpaqueObj>()->opaque_type_name ;
1340
1348
} else {
@@ -1375,11 +1383,24 @@ inline mlc::Str Serialize(Any any) {
1375
1383
}
1376
1384
os << ' "' << type_keys[i] << ' \" ' ;
1377
1385
}
1378
- os << " ]}" ;
1386
+ os << " ]" ;
1387
+ if (!tensors.empty ()) {
1388
+ os << " , \" tensors\" : [" ;
1389
+ for (size_t i = 0 ; i < tensors.size (); ++i) {
1390
+ if (i > 0 ) {
1391
+ os << " , " ;
1392
+ }
1393
+ Str b64 = tensors[i]->ToBase64 ();
1394
+ os << ' "' << b64->data () << ' "' ;
1395
+ }
1396
+ os << " ]" ;
1397
+ }
1398
+ os << " }" ;
1379
1399
return os.str ();
1380
1400
}
1381
1401
1382
1402
inline Any Deserialize (const char *json_str, int64_t json_str_len) {
1403
+ int32_t json_type_index_tensor = -1 ;
1383
1404
// Step 0. Parse JSON string
1384
1405
UDict json_obj = JSONLoads (json_str, json_str_len);
1385
1406
// Step 1. type_key => constructors
@@ -1388,7 +1409,12 @@ inline Any Deserialize(const char *json_str, int64_t json_str_len) {
1388
1409
constructors.reserve (type_keys.size ());
1389
1410
for (Str type_key : type_keys) {
1390
1411
int32_t type_index = Lib::GetTypeIndex (type_key->data ());
1391
- FuncObj *func = Lib::_init (type_index);
1412
+ FuncObj *func = nullptr ;
1413
+ if (type_index != kMLCTensor ) {
1414
+ func = Lib::_init (type_index);
1415
+ } else {
1416
+ json_type_index_tensor = static_cast <int32_t >(constructors.size ());
1417
+ }
1392
1418
constructors.push_back (func);
1393
1419
}
1394
1420
auto invoke_init = [&constructors](UList args) {
@@ -1398,13 +1424,29 @@ inline Any Deserialize(const char *json_str, int64_t json_str_len) {
1398
1424
&ret);
1399
1425
return ret;
1400
1426
};
1401
- // Step 2. Translate JSON object to objects
1427
+ // Step 2. Handle tensors
1428
+ std::vector<Tensor> tensors;
1429
+ if (json_obj->count (" tensors" )) {
1430
+ UList tensors_b64 = json_obj->at (" tensors" );
1431
+ while (!tensors_b64->empty ()) {
1432
+ Tensor tensor = Tensor::FromBase64 (tensors_b64->back ());
1433
+ tensors.push_back (tensor);
1434
+ tensors_b64->pop_back ();
1435
+ }
1436
+ json_obj->erase (" tensors" );
1437
+ std::reverse (tensors.begin (), tensors.end ());
1438
+ }
1439
+ // Step 3. Translate JSON object to objects
1402
1440
UList values = json_obj->at (" values" );
1403
1441
for (int64_t i = 0 ; i < values->size (); ++i) {
1404
1442
Any obj = values[i];
1405
1443
if (obj.type_index == kMLCList ) {
1406
1444
UList list = obj.operator UList ();
1407
1445
int32_t json_type_index = list[0 ];
1446
+ if (json_type_index == json_type_index_tensor) {
1447
+ values[i] = tensors[list[1 ].operator int32_t ()];
1448
+ continue ;
1449
+ }
1408
1450
for (int64_t j = 1 ; j < list.size (); ++j) {
1409
1451
Any arg = list[j];
1410
1452
if (arg.type_index == kMLCInt ) {
@@ -1487,23 +1529,38 @@ Any JSONDeserialize(AnyView json_str) {
1487
1529
1488
1530
// Serializes `source` into a string by forwarding it to ::mlc::Serialize.
Str JSONSerialize(AnyView source) {
  Str json = ::mlc::Serialize(source);
  return json;
}
1489
1531
1490
- Tensor TensorFromBytes (const StrObj *src) {
1491
- return ::mlc::TensorFromBytes (reinterpret_cast <const uint8_t *>(src->::MLCStr::data), src->length ());
1492
- }
1493
-
1494
1532
// Serializes the tensor held by `src` by forwarding the address of its
// `tensor` member to ::mlc::TensorToBytes.
Str TensorToBytes(const TensorObj *src) {
  const auto *dl_tensor = &src->tensor;
  return ::mlc::TensorToBytes(dl_tensor);
}
1497
1535
1498
- Tensor TensorFromBase64 (const StrObj *src) {
1499
- Str bytes = ::mlc::Base64Decode (reinterpret_cast <const uint8_t *>(src->data ()), src->size ());
1500
- return ::mlc::TensorFromBytes (reinterpret_cast <const uint8_t *>(bytes->data ()), bytes->size ());
1501
- }
1502
-
1503
1536
// Produces the Base64 text form of the tensor held by `src`:
// first serialize it with ::mlc::TensorToBytes, then encode those bytes
// with ::mlc::Base64Encode.
Str TensorToBase64(const TensorObj *src) {
  Str raw = ::mlc::TensorToBytes(&src->tensor);
  uint8_t *payload = reinterpret_cast<uint8_t *>(raw->data());
  return ::mlc::Base64Encode(payload, raw->size());
}
1507
1540
1541
+ Tensor TensorFromBytes (AnyView any) {
1542
+ if (any.type_index == kMLCRawStr ) {
1543
+ const char *src = any;
1544
+ int64_t len = std::strlen (src);
1545
+ return ::mlc::TensorFromBytes (reinterpret_cast <const uint8_t *>(src), len);
1546
+ } else {
1547
+ Str src = any;
1548
+ return ::mlc::TensorFromBytes (reinterpret_cast <const uint8_t *>(src->::MLCStr::data), src->length ());
1549
+ }
1550
+ }
1551
+
1552
+ Tensor TensorFromBase64 (AnyView any) {
1553
+ if (any.type_index == kMLCRawStr ) {
1554
+ const char *src = any;
1555
+ int64_t len = std::strlen (src);
1556
+ Str bytes = ::mlc::Base64Decode (reinterpret_cast <const uint8_t *>(src), len);
1557
+ return ::mlc::TensorFromBytes (reinterpret_cast <const uint8_t *>(bytes->data ()), bytes->size ());
1558
+ } else {
1559
+ Str src = any;
1560
+ Str bytes = ::mlc::Base64Decode (reinterpret_cast <const uint8_t *>(src->::MLCStr::data), src->length ());
1561
+ return ::mlc::TensorFromBytes (reinterpret_cast <const uint8_t *>(bytes->data ()), bytes->size ());
1562
+ }
1563
+ }
1564
+
1508
1565
} // namespace registry
1509
1566
} // namespace mlc
0 commit comments