Commit
get rid of EOS_MARKER
mmoskal committed Jun 26, 2024
1 parent a10c7cd commit e5142ad
Showing 9 changed files with 99 additions and 115 deletions.
27 changes: 16 additions & 11 deletions controllers/derivre/src/regexvec.rs
@@ -83,19 +83,19 @@ pub struct RegexVec {
#[derive(Clone, Debug)]
pub struct StateDesc {
pub state: StateID,
- pub lowest_accepting: isize, // -1 if no accepting state
+ pub lowest_accepting: Option<usize>,
pub accepting: SimpleVob,
pub possible: SimpleVob,

possible_lookahead_len: Option<usize>,
lookahead_len: Option<Option<usize>>,
next_byte: Option<NextByte>,
- lowest_match: Option<Option<usize>>,
+ lowest_match: Option<Option<(usize, usize)>>,
}

impl StateDesc {
pub fn is_accepting(&self) -> bool {
- self.lowest_accepting != -1
+ self.lowest_accepting.is_some()
}

pub fn is_dead(&self) -> bool {
@@ -124,6 +124,10 @@ impl RegexVec {
Ok(Self::new_with_exprset(&exprset, &acc, None))
}

+ pub fn lazy_regexes(&self) -> &SimpleVob {
+ &self.lazy
+ }
+
pub fn initial_state_all(&mut self) -> StateID {
self.initial_state(&SimpleVob::all_true(self.rx_list.len()))
}
@@ -155,18 +159,18 @@ impl RegexVec {

pub fn lookahead_len_for_state(&mut self, state: StateID) -> Option<usize> {
let desc = &mut self.state_descs[state.as_usize()];
- let idx = desc.lowest_accepting;
- if idx < 0 {
+ if desc.lowest_accepting.is_none() {
return None;
}
+ let idx = desc.lowest_accepting.unwrap();
if let Some(len) = desc.lookahead_len {
return len;
}
let mut res = None;
let exprs = &self.exprs;
for (idx2, e) in iter_state(&self.rx_sets, state) {
if res.is_none() && exprs.is_nullable(e) {
- assert!(idx == idx2 as isize);
+ assert!(idx == idx2);
res = Some(exprs.lookahead_len(e).unwrap_or(0));
}
}
@@ -236,7 +240,7 @@ impl RegexVec {
/// Return index of lowest matching regex if any.
/// Lazy regexes match as soon as they accept, while greedy only
/// if they accept and force EOI.
- pub fn lowest_match(&mut self, state: StateID) -> Option<usize> {
+ pub fn lowest_match(&mut self, state: StateID) -> Option<(usize, usize)> {
let desc = &mut self.state_descs[state.as_usize()];
if let Some(lowest_match) = desc.lowest_match {
return lowest_match;
@@ -247,7 +251,8 @@
continue;
}
if self.lazy[idx] || self.next_byte.next_byte(&self.exprs, e) == NextByte::ForcedEOI {
- res = Some(idx);
+ let len = self.exprs.possible_lookahead_len(e);
+ res = Some((idx, len));
break;
}
}
@@ -503,7 +508,7 @@ impl RegexVec {
fn compute_state_desc(&self, state: StateID) -> StateDesc {
let mut res = StateDesc {
state,
- lowest_accepting: -1,
+ lowest_accepting: None,
accepting: SimpleVob::alloc(self.rx_list.len()),
possible: SimpleVob::alloc(self.rx_list.len()),
possible_lookahead_len: None,
@@ -515,8 +520,8 @@
res.possible.set(idx, true);
if self.exprs().is_nullable(e) {
res.accepting.set(idx, true);
- if res.lowest_accepting == -1 {
- res.lowest_accepting = idx as isize;
+ if res.lowest_accepting.is_none() {
+ res.lowest_accepting = Some(idx);
}
}
}
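The heart of this file's change is replacing the -1 sentinel in `lowest_accepting` with `Option<usize>` (and having `lowest_match` also report the possible lookahead length). As a rough illustration of the sentinel-to-Option pattern, here is a minimal, self-contained Rust sketch; `StateDescSketch` is a hypothetical stand-in, not the derivre crate's actual `StateDesc`:

// Hypothetical sketch: the old field was an `isize` with -1 meaning
// "no accepting regex"; the new field encodes absence in the type instead.
struct StateDescSketch {
    lowest_accepting: Option<usize>,
}

impl StateDescSketch {
    fn is_accepting(&self) -> bool {
        self.lowest_accepting.is_some()
    }
}

fn main() {
    let dead = StateDescSketch { lowest_accepting: None };
    let matched = StateDescSketch { lowest_accepting: Some(2) };
    assert!(!dead.is_accepting());
    assert!(matched.is_accepting());
    // Callers now pattern-match instead of comparing against -1:
    if let Some(idx) = matched.lowest_accepting {
        println!("lowest accepting regex index: {}", idx);
    }
}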
4 changes: 2 additions & 2 deletions controllers/derivre/tests/c_lexer.rs
@@ -80,14 +80,14 @@ fn c_lexer() {
if new_state == StateID::DEAD {
let desc = rx.state_desc(state);
if desc.is_accepting() {
- let lexeme = &patterns[desc.lowest_accepting as usize];
+ let lexeme = &patterns[desc.lowest_accepting.unwrap()];
println!(
"matched: {:?} {:?}",
// desc,
lexeme,
&C_SAMPLE[start_idx..idx]
);
- if patterns[desc.lowest_accepting as usize] == "if" {
+ if patterns[desc.lowest_accepting.unwrap()] == "if" {
num_if += 1;
}
start_idx = idx;
6 changes: 3 additions & 3 deletions controllers/derivre/tests/sample_multi.rs
@@ -9,7 +9,7 @@ fn multi_matching_sample() {
let state1 = rx.transition_bytes(state0, b"while");
let desc = rx.state_desc(state1);
// first accepting pattern is #1 ("while")
- assert!(desc.lowest_accepting == 1);
+ assert!(desc.lowest_accepting == Some(1));
assert!(desc.accepting[0] == false);
assert!(desc.accepting[1] == true);
// but [a-z]+ is also matching
@@ -20,7 +20,7 @@
let state2 = rx.transition_bytes(state0, b"i");
let desc = rx.state_desc(state2);
// after we go through just 'i' from start, we can only match [a-z]+
- assert!(desc.lowest_accepting == 2);
+ assert!(desc.lowest_accepting == Some(2));
assert!(desc.accepting.to_bin_string().as_str() == "001");
// however, the 'if' is still possible (but 'while' is not)
assert!(desc.possible.to_bin_string().as_str() == "101");
@@ -31,6 +31,6 @@
let state1 = rx.transition_bytes(state0, b"if");
let desc = rx.state_desc(state1);
// the string matches the identifier rule only
- assert!(desc.lowest_accepting == 2);
+ assert!(desc.lowest_accepting == Some(2));
assert!(desc.accepting.to_bin_string().as_str() == "001");
}
34 changes: 31 additions & 3 deletions controllers/llguidance_ctrl/run_g.py
@@ -280,8 +280,14 @@ def character_maker2(lm, id, description, valid_weapons):
prompt = ""
grm = optional("A")

grm = "Q: Are dolphins fish?\nA: " + gen("dolphins", regex="Yes|No", max_tokens=10) + \
"\nQ: Are sharks fish?\nA: " + gen("sharks", regex="Yes|No", max_tokens=10)
grm = (
"Q: Are dolphins fish?\nA: "
+ gen("dolphins", regex="Yes|No", max_tokens=10)
+ "\nQ: Are sharks fish?\nA: "
+ gen("sharks", regex="Yes|No", max_tokens=10)
)

grm = one_or_more(gen(regex="[a-z]"))

# grm = "Q: 7 * 8\nA: " + gen("text", regex="[0-9]+", max_tokens=5)

@@ -297,6 +303,28 @@ def character_maker2(lm, id, description, valid_weapons):

serialized = grm.ll_serialize()

# with open("tmp/long_json_grammar_req.json", "r") as f:
# with open("tmp/email_regex_grammar.json", "r") as f:
# max_tokens = 2000
# serialized = json.load(f)

x_serialized = {
"grammars": [
{
"greedy_lexer": False,
"nodes": [
{"Join": {"sequence": [1]}},
{"Join": {"sequence": [2, 3]}},
{"Gen": {"body_rx": 0, "stop_rx": "", "temperature": None}},
{"Select": {"among": [4, 5]}},
{"Join": {"sequence": [3, 2]}},
{"String": {"literal": ""}},
],
"rx_nodes": [{"ByteSet": [0, 0, 0, 134217726, 0, 0, 0, 0]}],
}
]
}

x_serialized = {
"grammars": [
{
@@ -318,7 +346,7 @@
"greedy_skip_rx": "[\\x20\\x0A\\x0D\\x09]+",
"nodes": [
{"Lexeme": {"rx": "-?(?:0|[1-9][0-9]*)", "contextual": False}}
#{"Lexeme": {"rx": "[ab][ab]", "contextual": False}}
# {"Lexeme": {"rx": "[ab][ab]", "contextual": False}}
],
"rx_nodes": [],
},
73 changes: 21 additions & 52 deletions controllers/llguidance_ctrl/src/earley/lexer.rs
@@ -3,7 +3,7 @@ use anyhow::Result;
use derivre::{NextByte, RegexVec, StateDesc};
use std::fmt::Debug;

- use super::lexerspec::{LexemeIdx, LexerSpec, EOS_MARKER};
+ use super::lexerspec::{LexemeIdx, LexerSpec};

const DEBUG: bool = true;

@@ -18,6 +18,7 @@
#[derive(Clone)]
pub struct Lexer {
dfa: RegexVec,
+ #[allow(dead_code)]
spec: LexerSpec,
}

@@ -69,19 +70,8 @@ impl Lexer {
self.dfa.state_desc(state)
}

- pub fn allows_eos(&mut self, state: StateID, allowed_eos_lexemes: &SimpleVob) -> bool {
- if allowed_eos_lexemes.is_zero() {
- return false;
- }
-
- let state = self.dfa.transition_bytes(state, EOS_MARKER);
-
- let accepting = &self.dfa.state_desc(state).accepting;
- if accepting.and_is_zero(allowed_eos_lexemes) {
- false
- } else {
- true
- }
+ pub fn allows_eos(&mut self, state: StateID) -> bool {
+ self.state_info(state).is_accepting()
}

pub fn force_lexeme_end(&self, prev: StateID) -> LexerResult {
@@ -98,23 +88,16 @@
}

pub fn try_lexeme_end(&mut self, prev: StateID) -> LexerResult {
- let prev_accepting = self.state_info(prev).accepting.first_bit_set();
- let eos_state = self.dfa.transition_bytes(prev, EOS_MARKER);
- let eos_accepting = self.state_info(eos_state).accepting.first_bit_set();
-
- let idx = match (prev_accepting, eos_accepting) {
- (Some(p), Some(e)) if p < e => p,
- (_, Some(e)) => e,
- (Some(p), None) => p,
- (None, None) => return LexerResult::Error,
- };
-
- LexerResult::Lexeme(PreLexeme {
- idx: LexemeIdx::new(idx),
- byte: None,
- byte_next_row: false,
- hidden_len: 0,
- })
+ if let Some(idx) = self.state_info(prev).lowest_accepting {
+ LexerResult::Lexeme(PreLexeme {
+ idx: LexemeIdx::new(idx),
+ byte: None,
+ byte_next_row: false,
+ hidden_len: 0,
+ })
+ } else {
+ LexerResult::Error
+ }
}

pub fn check_for_single_byte_lexeme(&mut self, state: StateID, b: u8) -> Option<PreLexeme> {
@@ -139,39 +122,32 @@
if enable_logging {
let info = self.state_info(state);
debug!(
"lex: {:?} -{:?}-> {:?}, acpt={}",
"lex: {:?} -{:?}-> {:?}, acpt={:?}",
prev, byte as char, state, info.lowest_accepting
);
}

if state.is_dead() {
- if !self.spec.greedy {
- return LexerResult::Error;
- }
-
let info = self.dfa.state_desc(prev);
// we take the first token that matched
// (eg., "while" will match both keyword and identifier, but keyword is first)
- if info.is_accepting() {
+ if let Some(idx) = info.lowest_accepting {
LexerResult::Lexeme(PreLexeme {
- idx: LexemeIdx::from_state_desc(info),
+ idx: LexemeIdx::new(idx),
byte: Some(byte),
byte_next_row: true,
- hidden_len: self.dfa.possible_lookahead_len(prev),
+ hidden_len: 0,
})
} else {
LexerResult::Error
}
} else {
- let can_stop_now =
- !self.spec.greedy || self.dfa.next_byte(state) == NextByte::ForcedEOI;
- let info = self.state_info(state);
- if can_stop_now && info.is_accepting() {
+ if let Some((idx, hidden_len)) = self.dfa.lowest_match(state) {
LexerResult::Lexeme(PreLexeme {
- idx: LexemeIdx::from_state_desc(info),
+ idx: LexemeIdx::new(idx),
byte: Some(byte),
byte_next_row: false,
- hidden_len: self.dfa.possible_lookahead_len(state),
+ hidden_len,
})
} else {
LexerResult::State(state, byte)
@@ -180,13 +156,6 @@
}
}

- impl LexemeIdx {
- fn from_state_desc(desc: &StateDesc) -> Self {
- assert!(desc.lowest_accepting >= 0);
- LexemeIdx::new(desc.lowest_accepting as usize)
- }
- }
-
impl LexerResult {
#[inline(always)]
pub fn is_error(&self) -> bool {
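With EOS_MARKER gone, the lexer no longer probes a synthetic end-of-stream byte: `try_lexeme_end` just reads `lowest_accepting`, and the non-dead branch of the scan asks `lowest_match`, which now returns both the index of the lowest matching regex and its lookahead length. A rough, self-contained sketch of that decision, using a hypothetical `Outcome` enum rather than the real `LexerResult`:

// Hypothetical sketch: `lowest_match` yields Option<(regex index, hidden/lookahead len)>;
// Some means a lazy regex matched (or a greedy one was forced to EOI).
#[derive(Debug, PartialEq)]
enum Outcome {
    Lexeme { idx: usize, hidden_len: usize },
    KeepScanning,
}

fn on_transition(lowest_match: Option<(usize, usize)>) -> Outcome {
    match lowest_match {
        Some((idx, hidden_len)) => Outcome::Lexeme { idx, hidden_len },
        None => Outcome::KeepScanning,
    }
}

fn main() {
    // Regex #1 matched with 3 bytes of lookahead: emit the lexeme.
    assert_eq!(
        on_transition(Some((1, 3))),
        Outcome::Lexeme { idx: 1, hidden_len: 3 }
    );
    // No match yet: keep feeding bytes to the DFA.
    assert_eq!(on_transition(None), Outcome::KeepScanning);
}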