Transformers-CFG is a fork of the Hugging Face Transformers library adding support for Context-Free Grammar based Constrained Generation methods. The library is trying to keep up-to-date with the main branch of the 🤗 Transformers library. The library tries to offer a compatible interface to llama-cpp project.
pip install git+https://github.com/epfl-dlab/transformers-CFG
The below example can be found in examples/pytorch/text-geenration/grammar_constrained_generation.py
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.grammar_utils import IncrementalGrammarConstraint
from transformers.generation.logits_process import GrammarConstrainedLogitsProcessor
if __name__ == "__main__":
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2")
# Load json grammar
with open("examples/grammars/json.gbnf", "r") as file:
grammar_str = file.read()
grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
# Generate
prefix1 = "This is a valid json string for http request:"
prefix2 = "This is a valid json string for shopping cart:"
input_ids = tokenizer([prefix1, prefix2], add_special_tokens=False, return_tensors="pt", padding=True)["input_ids"]
output = model.generate(
input_ids,
do_sample=False,
max_length=50,
num_beams=2,
logits_processor=[grammar_processor],
repetition_penalty=5.0,
num_return_sequences=1,
)
# decode output
generations = tokenizer.batch_decode(output, skip_special_tokens=True)
print(generations)
"""
'This is a valid json string for http request:{ "request": { "method": "GET", "headers": [], "content": "Content","type": "application" }}
'This is a valid json string for shopping cart:This is a valid json string for shopping cart:{ "name": "MyCart", "price": 0, "value": 1 }
"""
We provide a collection of grammars in the examples/grammars
folder, which are mostly identical to the grammars in llama-cpp project.
We try to keep the grammars up-to-date with the original grammars from llama-cpp project.
But up to now, we can not yet guarantee that all grammars from llama-cpp project can be directly used in transformers-CFG.
The list of grammars contains:
json.gbnf
: A grammar for generating valid json objects.c.gbnf
: A grammar for generating valid C programs.chess.gbnf
: A grammar for generating valid chess moves.arithmetic.gbnf
: A grammar for generating valid arithmetic expressions.
- We offer the same grammar interface as llama-cpp project, allowing you to drop-in replace llama-cpp with transformers-CFG.
- We allow you to use any of the models in the 🤗 Transformers library, including the ones that are not supported by llama-cpp.