Skip to content

Commit 88d055d

Browse files
authoredJan 31, 2025··
Merge pull request #34 from mlabs-haskell/rmgaray/fix-tagged-variants
fix: union types with multiple tagged variants not being decoded correctly
2 parents ac0ebbb + 1c62d9f commit 88d055d

File tree

5 files changed

+161
-32
lines changed

5 files changed

+161
-32
lines changed
 

‎conway-cddl/codegen/generators/union.ts

+101-21
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { SchemaTable } from "..";
44
export type Variant = {
55
tag: number; // number assigned to the variant in the FooKind enum
66
peek_type: string | string[]; // decode this variant if the CBOR type tag equals any of these values
7+
valid_tags?: number[];
78
name: string; // used in Class.new_foo()
89
type: string; // used to do if (reader.getTag() == tag) type.deserialize()
910
kind_name?: string; // name of the variant in the FooKind enum
@@ -13,6 +14,15 @@ export type GenUnionOptions = {
1314
variants: Variant[];
1415
} & CodeGeneratorBaseOptions;
1516

17+
// We say tagged to refer to variants that are encoded with a CBOR tag
18+
type TaggedVariant = {
19+
tag: number; // number assigned to the variant in the FooKind enum
20+
valid_tags: number[];
21+
name: string; // used in Class.new_foo()
22+
type: string; // used to do if (reader.getTag() == tag) type.deserialize()
23+
kind_name?: string; // name of the variant in the FooKind enum
24+
}
25+
1626
export class GenUnion extends CodeGeneratorBase {
1727
variants: Variant[];
1828

@@ -87,33 +97,103 @@ export class GenUnion extends CodeGeneratorBase {
8797
}
8898

8999
generateDeserialize(reader: string, path: string): string {
100+
const constructUntagged = (v: Variant) => {
101+
let out = "";
102+
if (Array.isArray(v.peek_type)) {
103+
for (let t of v.peek_type) {
104+
out += `case "${t}":\n`;
105+
}
106+
} else {
107+
out += `case "${v.peek_type}":\n`;
108+
}
109+
return out +
110+
`
111+
variant = {
112+
kind: ${this.name}Kind.${v.kind_name ?? v.type},
113+
value: ${this.typeUtils.readType(reader, v.type, `[...${path}, '${v.type}(${v.name})']`)}
114+
};
115+
break;
116+
`
117+
}
118+
119+
const constructTagged = (v: TaggedVariant) => {
120+
if(v.valid_tags.length == 0) {
121+
throw new Error("Expected a non-empty 'valid_tags' field because multiple tagged variants exist. These are needed to disambiguate.")
122+
} else {
123+
return `if ([${v.valid_tags.toString()}].includes(tagNumber)) {
124+
variant = {
125+
kind: ${this.name}Kind.${v.kind_name ?? v.type},
126+
value: ${this.typeUtils.readType(reader, v.type, `[...${path}, '${v.type}(${v.name})']`)}
127+
};
128+
break;
129+
}`
130+
}
131+
}
132+
133+
// split variants into tagged and untagged types
134+
let [taggedVariants, untaggedVariants] = this.variants.reduce((acc, v) => {
135+
let [tagged, untagged] = acc;
136+
if (typeof v.peek_type == "string") {
137+
if (v.peek_type == "tagged") {
138+
let tagged_v: TaggedVariant = {
139+
tag: v.tag,
140+
valid_tags: v.valid_tags ? v.valid_tags : [],
141+
name: v.name,
142+
type: v.type,
143+
kind_name: v.kind_name
144+
}
145+
tagged.push(tagged_v);
146+
} else {
147+
untagged.push(v);
148+
}
149+
} else if (v.peek_type.includes("tagged")) {
150+
let untagged_v: Variant = structuredClone(v)
151+
untagged_v.peek_type = v.peek_type.filter((t) => t != "tagged");
152+
153+
let tagged_v: TaggedVariant = {
154+
tag: v.tag,
155+
valid_tags: v.valid_tags ? v.valid_tags : [],
156+
name: v.name,
157+
type: v.type,
158+
kind_name: v.kind_name
159+
}
160+
161+
untagged.push(untagged_v);
162+
tagged.push(tagged_v);
163+
} else {
164+
untagged.push(v);
165+
}
166+
167+
return [tagged, untagged]
168+
}, [[], []] as [TaggedVariant[], Variant[]]);
169+
90170
return `
91171
let tag = ${reader}.peekType(${path});
92172
let variant: ${this.name}Variant;
93173
94174
switch(tag) {
95-
${this.variants
96-
.map((x) => {
97-
let out = "";
98-
if (Array.isArray(x.peek_type)) {
99-
for (let t of x.peek_type) {
100-
out += `case "${t}":\n`;
101-
}
102-
} else {
103-
out += `case "${x.peek_type}":\n`;
104-
}
105-
return (
106-
out +
107-
`
108-
variant = {
109-
kind: ${this.name}Kind.${x.kind_name ?? x.type},
110-
value: ${this.typeUtils.readType(reader, x.type, `[...${path}, '${x.type}(${x.name})']`)}
111-
};
112-
break;
113-
`
114-
);
115-
})
175+
${untaggedVariants
176+
.map(constructUntagged)
116177
.join("\n")}
178+
${taggedVariants.length > 0
179+
? (taggedVariants.length == 1
180+
? `case "tagged":
181+
variant = {
182+
kind: ${this.name}Kind.${taggedVariants[0].kind_name ?? taggedVariants[0].type},
183+
value: ${this.typeUtils.readType(reader, taggedVariants[0].type, `[...${path}, '${taggedVariants[0].type}(${taggedVariants[0].name})']`)}
184+
};
185+
break;
186+
`
187+
: `case "tagged":
188+
const tagNumber = ${reader}.peekTagNumber(${path});
189+
${constructTagged(taggedVariants[0])}
190+
${taggedVariants.slice(1).map((v) => "else " + constructTagged(v)).join("\n")}
191+
else {
192+
throw new Error("Unexpected tag number " + tagNumber + " (at " + ${path}.join("/") + ")")
193+
}
194+
`)
195+
: ''
196+
}
117197
default:
118198
throw new Error("Unexpected subtype for ${this.name}: " + tag + "(at " + ${path}.join("/") + ")");
119199
}

‎conway-cddl/codegen/types.ts

+1
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ export const Schema = Type.Intersect([
120120
Type.Object({
121121
tag: Type.Number(),
122122
peek_type: Type.Union([Type.String(), Type.Array(Type.String())]),
123+
valid_tags: Type.Optional(Type.Array(Type.Number())),
123124
name: Type.String(),
124125
type: Type.String(),
125126
kind_name: Type.Optional(Type.String()),

‎conway-cddl/yaml/custom/plutus_data.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ PlutusData:
33
variants:
44
- tag: 0
55
peek_type: "tagged"
6+
valid_tags: [102, 121, 122, 123, 124, 125, 126, 127]
67
name: constr_plutus_data
78
type: ConstrPlutusData
89
- tag: 1
@@ -15,6 +16,7 @@ PlutusData:
1516
type: PlutusList
1617
- tag: 3
1718
peek_type: ["uint", "nint", "tagged"]
19+
valid_tags: [2, 3]
1820
name: integer
1921
type: CSLBigInt
2022
- tag: 4

‎src/generated.ts

+30-11
Original file line numberDiff line numberDiff line change
@@ -10026,16 +10026,6 @@ export class PlutusData {
1002610026
let variant: PlutusDataVariant;
1002710027

1002810028
switch (tag) {
10029-
case "tagged":
10030-
variant = {
10031-
kind: PlutusDataKind.ConstrPlutusData,
10032-
value: ConstrPlutusData.deserialize(reader, [
10033-
...path,
10034-
"ConstrPlutusData(constr_plutus_data)",
10035-
]),
10036-
};
10037-
break;
10038-
1003910029
case "map":
1004010030
variant = {
1004110031
kind: PlutusDataKind.PlutusMap,
@@ -10052,7 +10042,6 @@ export class PlutusData {
1005210042

1005310043
case "uint":
1005410044
case "nint":
10055-
case "tagged":
1005610045
variant = {
1005710046
kind: PlutusDataKind.CSLBigInt,
1005810047
value: CSLBigInt.deserialize(reader, [...path, "CSLBigInt(integer)"]),
@@ -10066,6 +10055,36 @@ export class PlutusData {
1006610055
};
1006710056
break;
1006810057

10058+
case "tagged":
10059+
const tagNumber = reader.peekTagNumber(path);
10060+
if ([102, 121, 122, 123, 124, 125, 126, 127].includes(tagNumber)) {
10061+
variant = {
10062+
kind: PlutusDataKind.ConstrPlutusData,
10063+
value: ConstrPlutusData.deserialize(reader, [
10064+
...path,
10065+
"ConstrPlutusData(constr_plutus_data)",
10066+
]),
10067+
};
10068+
break;
10069+
} else if ([2, 3].includes(tagNumber)) {
10070+
variant = {
10071+
kind: PlutusDataKind.CSLBigInt,
10072+
value: CSLBigInt.deserialize(reader, [
10073+
...path,
10074+
"CSLBigInt(integer)",
10075+
]),
10076+
};
10077+
break;
10078+
} else {
10079+
throw new Error(
10080+
"Unexpected tag number " +
10081+
tagNumber +
10082+
" (at " +
10083+
path.join("/") +
10084+
")",
10085+
);
10086+
}
10087+
1006910088
default:
1007010089
throw new Error(
1007110090
"Unexpected subtype for PlutusData: " +

‎src/lib/cbor/reader.ts

+27
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,33 @@ export class CBORReader {
5858
throw err
5959
}
6060

61+
peekTagNumber(path: string[]): number {
62+
let tag = this.buffer[0];
63+
64+
let len = tag & 0b11111;
65+
66+
// the value of the length field must be between 0x00 and 0x1b
67+
if (!(len >= 0x00 && len <= 0x1b)) {
68+
let err = new CBORInvalidTag(tag);
69+
err.message += ` (at ${path.join("/")})`;
70+
throw err;
71+
}
72+
73+
let slicedBuffer = this.buffer.slice(1);
74+
75+
// if the length field is less than 0x18, then that itself is the value
76+
// (optimization for small values)
77+
if (len < 0x18) {
78+
return Number(BigInt(len));
79+
}
80+
81+
// Else the length is 2^(length - 0x18)
82+
let nBytes = Math.pow(2, len - 0x18);
83+
84+
let x = Number(bigintFromBytes(nBytes, slicedBuffer));
85+
return x;
86+
}
87+
6188
isBreak(): boolean {
6289
return this.buffer[0] == 0xff;
6390
}

0 commit comments

Comments
 (0)
Please sign in to comment.