// api.rs
use std::fmt::{Debug, Display};
use anyhow::{bail, Result};
use derivre::RegexAst;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use crate::{
earley::{lexerspec::LexerSpec, Grammar},
lark::lark_regex_quote,
};
/// This represents a collection of grammars, with a designated
/// "start" grammar at the first position.
/// Grammars can refer to each other via GenGrammar nodes.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct TopLevelGrammar {
pub grammars: Vec<GrammarWithLexer>,
pub max_tokens: Option<usize>,
}
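// Illustrative sketch (not from the original source): a single-grammar
// collection built via the constructors defined further below.
//
//     let top = TopLevelGrammar::from_lark("start: \"yes\" | \"no\"".to_string());
//     assert_eq!(top.grammars.len(), 1);
//     assert!(top.max_tokens.is_none());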
#[allow(clippy::large_enum_variant)]
#[derive(Clone)]
pub enum GrammarInit {
Serialized(TopLevelGrammar),
Internal(Grammar, LexerSpec),
}
/// cbindgen:ignore
pub const DEFAULT_CONTEXTUAL: bool = true;
/// In Lark syntax, this can be specified as a JSON object after the '%llguidance' declaration in the grammar.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LLGuidanceOptions {
/// Normally, when a sequence of bytes is forced by the grammar, it is
/// tokenized canonically and forced as tokens.
/// With `no_forcing`, we let the model decide on the tokenization.
/// This generally reduces both quality and speed, so it should not be used
/// outside of testing.
#[serde(default)]
pub no_forcing: bool,
/// If set, the grammar will allow invalid UTF-8 byte sequences.
/// Any Unicode regex will then cause an error.
/// This is very unlikely to be what you need.
#[serde(default)]
pub allow_invalid_utf8: bool,
}
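// Illustrative sketch (not from the original source): as the doc comment above
// notes, these options can be embedded in a Lark grammar as a JSON object
// following the '%llguidance' declaration, e.g.:
//
//     %llguidance {"no_forcing": true}
//     start: /[0-9]+/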
impl LLGuidanceOptions {
pub fn apply(&mut self, other: &LLGuidanceOptions) {
if other.no_forcing {
self.no_forcing = true;
}
if other.allow_invalid_utf8 {
self.allow_invalid_utf8 = true;
}
}
}
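// Illustrative sketch (not from the original source): `apply` is a logical OR
// of the flags, so an option enabled in either value stays enabled.
//
//     let mut opts = LLGuidanceOptions::default();
//     opts.apply(&LLGuidanceOptions { no_forcing: true, ..Default::default() });
//     assert!(opts.no_forcing);
//     assert!(!opts.allow_invalid_utf8);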
#[derive(Serialize, Deserialize, Clone, Default)]
pub struct GrammarWithLexer {
/// The name of this grammar; it can be used in GenGrammar nodes.
pub name: Option<String>,
/// The JSON schema that the grammar should generate.
/// When this is set, nodes and rx_nodes must be empty.
pub json_schema: Option<Value>,
/// The Lark grammar that the grammar should generate.
/// When this is set, nodes and rx_nodes must be empty.
pub lark_grammar: Option<String>,
// #[serde(flatten)]
// pub options: LLGuidanceOptions,
}
impl Debug for GrammarWithLexer {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"GrammarWithLexer [{}]",
if self.lark_grammar.is_some() {
"lark"
} else {
"json"
}
)
}
}
// /// If false, all other lexemes are excluded when this lexeme is recognized.
// /// This is normal behavior for keywords in programming languages.
// /// Set to true for eg. a JSON schema with both `/"type"/` and `/"[^"]*"/` as lexemes,
// /// or for "get"/"set" contextual keywords in C#.
// /// Default value set in GrammarWithLexer.
// contextual: Option<bool>,
// /// It lists the allowed escape sequences, typically one of:
// /// "nrbtf\\\"u" - to allow all JSON escapes, including \u00XX for control characters
// /// this is the default
// /// "nrbtf\\\"" - to disallow \u00XX control characters
// /// "nrt\\\"" - to also disallow unusual escapes (\f and \b)
// /// "" - to disallow all escapes
// /// Note that \uXXXX for non-control characters (code points above U+001F) are never allowed,
// /// as they never have to be quoted in JSON.
// json_allowed_escapes: Option<String>,
/// Optional fields allowed on any Node
#[derive(Serialize, Deserialize, Default, Clone, PartialEq, Eq)]
pub struct NodeProps {
pub max_tokens: Option<usize>,
pub name: Option<String>,
pub capture_name: Option<String>,
}
#[derive(Clone)]
pub struct GenOptions {
/// Regular expression matching the body of generation.
pub body_rx: RegexAst,
/// The whole generation must match `body_rx + stop_rx`.
/// Whatever matched `stop_rx` is discarded.
/// If `stop_rx` is empty, it's assumed to be EOS.
pub stop_rx: RegexAst,
/// When set, the string matching `stop_rx` will be output as a capture
/// with the given name.
pub stop_capture_name: Option<String>,
/// Lazy gen()s take the shortest match; non-lazy ones take the longest.
/// If not specified, the gen() is lazy when `stop_rx` is non-empty.
pub lazy: Option<bool>,
/// Treat `stop_rx` as a suffix, i.e., do not hide it from the LLM
/// (but do not include it in the capture).
pub is_suffix: Option<bool>,
/// Override sampling temperature.
pub temperature: Option<f32>,
}
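// Illustrative sketch (not from the original source): filling in GenOptions
// for "generate digits until EOS"; assumes the `RegexAst::Regex` and
// `RegexAst::EmptyString` variants from `derivre`.
//
//     let opts = GenOptions {
//         body_rx: RegexAst::Regex("[0-9]+".to_string()),
//         stop_rx: RegexAst::EmptyString, // empty stop regex => stop at EOS
//         stop_capture_name: None,
//         lazy: None, // defaults to lazy only when stop_rx is non-empty
//         is_suffix: None,
//         temperature: None,
//     };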
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct GenGrammarOptions {
pub grammar: GrammarId,
/// Override sampling temperature.
pub temperature: Option<f32>,
}
#[derive(Serialize, Deserialize, Hash, PartialEq, Eq, Clone, Debug)]
#[serde(untagged)]
pub enum GrammarId {
Name(String),
}
impl Display for GrammarId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
GrammarId::Name(s) => write!(f, "@{}", s),
}
}
}
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(deny_unknown_fields)]
pub struct RegexExt {
/// The lexeme should accept any (possibly empty) contiguous sequence of these chunks.
pub substring_chunks: Option<Vec<String>>,
/// Like `substring_chunks` with the chunks being `s.split(/\s+/)` (split on whitespace).
pub substring_words: Option<String>,
/// Like `substring_chunks` with the chunks being `s.split('')` (individual characters).
pub substring_chars: Option<String>,
}
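// Illustrative sketch (not from the original source): with the setting below,
// the lexeme accepts any contiguous run of the given words, e.g. "quick brown",
// "fox", or the empty string, but not "quick fox".
//
//     let ext = RegexExt {
//         substring_chunks: None,
//         substring_words: Some("the quick brown fox".to_string()),
//         substring_chars: None,
//     };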
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum StopReason {
/// Parser has not emitted stop() yet.
NotStopped,
/// max_tokens limit on the total number of tokens has been reached.
MaxTokensTotal,
/// max_tokens limit on the number of tokens in the top-level parser has been reached. (no longer used)
MaxTokensParser,
/// Top-level parser indicates that no more bytes can be added.
NoExtension,
/// Top-level parser indicates that no more bytes can be added; however, this was recognized late.
NoExtensionBias,
/// Top-level parser allowed EOS (as it was in an accepting state), and EOS was generated.
EndOfSentence,
/// Something went wrong with creating a nested parser.
InternalError,
/// The lexer is too complex.
LexerTooComplex,
/// The parser is too complex.
ParserTooComplex,
}
impl Display for StopReason {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
serde_json::to_value(self).unwrap().as_str().unwrap()
)
}
}
impl StopReason {
pub fn is_ok(&self) -> bool {
matches!(
self,
StopReason::NotStopped
| StopReason::EndOfSentence
| StopReason::NoExtension
| StopReason::NoExtensionBias
)
}
}
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(default)]
#[repr(C)]
pub struct ParserLimits {
/// For unambiguous grammars, this is the maximum "branching factor" of the grammar.
/// For ambiguous grammars, this limit might be hit much sooner.
/// Default: 2000
pub max_items_in_row: usize,
/// How much "fuel" we are willing to spend to build the initial lexer regex AST nodes.
/// Default: 1_000_000
/// Speed: 50k/ms
pub initial_lexer_fuel: u64,
/// Maximum lexer fuel for computation of the whole token mask.
/// Default: 200_000
/// Speed: 14k/ms
pub step_lexer_fuel: u64,
/// Maximum number of Earley items created for the whole token mask.
/// Default: 50_000
/// Speed: 20k/ms
pub step_max_items: usize,
/// Maximum number of lexer states.
/// Affects memory consumption, but not the speed for the most part.
/// Default: 250_000
/// Speed: ~1-2kB of memory per state
pub max_lexer_states: usize,
/// Maximum size of the grammar (number of symbols in productions).
/// Default: 500_000 (a few megabytes of JSON)
pub max_grammar_size: usize,
/// If true, we'll run any extremely large regexes against the whole
/// trie of the tokenizer while constructing the lexer.
/// This reduces future mask computation time, but increases
/// the time it takes to construct the lexer.
/// Default: true
pub precompute_large_lexemes: bool,
}
impl Default for ParserLimits {
fn default() -> Self {
Self {
max_items_in_row: 2000,
initial_lexer_fuel: 1_000_000, // fhir schema => 500k
step_lexer_fuel: 200_000, //
max_lexer_states: 250_000, //
max_grammar_size: 500_000, // fhir schema => 200k
step_max_items: 50_000, //
precompute_large_lexemes: true,
}
}
}
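// Illustrative sketch (not from the original source): tightening a single
// limit while keeping the remaining defaults.
//
//     let limits = ParserLimits {
//         step_max_items: 10_000,
//         ..ParserLimits::default()
//     };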
impl TopLevelGrammar {
pub fn from_lark_or_grammar_list(s: &str) -> Result<Self> {
let first_non_whitespace = s.chars().find(|c| !c.is_whitespace());
if first_non_whitespace.is_none() {
bail!("Empty grammar");
}
if first_non_whitespace == Some('{') {
Ok(serde_json::from_str(s)?)
} else {
Ok(TopLevelGrammar::from_lark(s.to_string()))
}
}
pub fn from_regex(rx: &str) -> Self {
Self::from_grammar(GrammarWithLexer::from_regex(rx))
}
pub fn from_lark(lark_grammar: String) -> Self {
Self::from_grammar(GrammarWithLexer::from_lark(lark_grammar))
}
pub fn from_json_schema(json_schema: Value) -> Self {
Self::from_grammar(GrammarWithLexer::from_json_schema(json_schema))
}
pub fn from_grammar(grammar: GrammarWithLexer) -> Self {
TopLevelGrammar {
grammars: vec![grammar],
max_tokens: None,
}
}
/// The format of the data depends on the tag:
/// - "regex" - data is a regular expression in Rust regex syntax,
///   see https://docs.rs/regex/latest/regex/#syntax
/// - "json" or "json_schema" - data is a (stringified) JSON schema,
///   see https://github.com/guidance-ai/llguidance/blob/main/docs/json_schema.md
/// - "json_object" - equivalent to the JSON schema {"type":"object"}
/// - "lark" - data is a grammar in a variant of Lark syntax,
///   see https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md
/// - "llguidance" or "guidance" - data is a list of Lark grammars or JSON schemas, in JSON format
pub fn from_tagged_str(tag: &str, data: &str) -> Result<Self> {
match tag {
"regex" => Ok(Self::from_regex(data)),
"json" | "json_schema" => Ok(Self::from_json_schema(serde_json::from_str(data)?)),
"json_object" => Ok(Self::from_json_schema(json!({"type": "object"}))),
"lark" => Ok(Self::from_lark(data.to_string())),
"llguidance" | "guidance" => Self::from_lark_or_grammar_list(data),
_ => bail!("unknown constraint type: {tag}"),
}
}
}
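// Illustrative sketch (not from the original source): the tagged constructor
// above accepts each of the documented tags; the `?` assumes a surrounding
// function returning `anyhow::Result`.
//
//     let from_rx = TopLevelGrammar::from_tagged_str("regex", "[a-z]+")?;
//     let from_schema = TopLevelGrammar::from_tagged_str("json", r#"{"type": "string"}"#)?;
//     let from_lark = TopLevelGrammar::from_tagged_str("lark", "start: \"yes\" | \"no\"")?;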
impl GrammarWithLexer {
pub fn from_lark(lark_grammar: String) -> Self {
GrammarWithLexer {
name: Some("lark_grammar".to_string()),
lark_grammar: Some(lark_grammar),
..GrammarWithLexer::default()
}
}
pub fn from_json_schema(json_schema: Value) -> Self {
GrammarWithLexer {
name: Some("json_schema".to_string()),
json_schema: Some(json_schema),
..GrammarWithLexer::default()
}
}
pub fn from_regex(rx: &str) -> Self {
let rx = lark_regex_quote(rx);
let mut r = Self::from_lark(format!("start: /{}/", rx));
r.name = Some("regex".to_string());
r
}
}
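// Illustrative sketch (not from the original source): `from_regex` wraps the
// (quoted) regex in a minimal Lark grammar with a single `start` rule.
//
//     let g = GrammarWithLexer::from_regex("[0-9]{3}");
//     assert!(g.lark_grammar.as_deref().unwrap().starts_with("start: /"));
//     assert_eq!(g.name.as_deref(), Some("regex"));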