From e04e3b886fb33817c808f2614175fbfc92bfe8cb Mon Sep 17 00:00:00 2001
From: Yixin Dong
Date: Fri, 13 Dec 2024 06:02:10 +0800
Subject: [PATCH] [Examples] Update examples (#129)

---
 .../bench_grammar_compile_mask_gen.py       | 20 +++++++++++---------
 .../hf_transformers/transformers_example.py |  3 +--
 2 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/examples/benchmark/bench_grammar_compile_mask_gen.py b/examples/benchmark/bench_grammar_compile_mask_gen.py
index d7e26b8..afa3238 100644
--- a/examples/benchmark/bench_grammar_compile_mask_gen.py
+++ b/examples/benchmark/bench_grammar_compile_mask_gen.py
@@ -12,8 +12,9 @@
     build_token_enforcer_tokenizer_data,
 )
 from outlines.fsm.guide import Guide, RegexGuide
-from outlines.fsm.json_schema import build_regex_from_schema, convert_json_schema_to_str
+from outlines.fsm.json_schema import convert_json_schema_to_str
 from outlines.generate.generator import bias_logits
+from outlines.generate.json import build_regex_from_schema
 from outlines.models import TransformerTokenizer
 from tqdm import tqdm
 from transformers import AutoTokenizer
@@ -23,18 +24,18 @@
 wrong_data_indices = [1]
 
 
-def xgrammar_build(schema: str, tokenizer_info: TokenizerInfo):
-    grammar = BuiltinGrammar.json_schema(schema, strict_mode=False)
-    matcher = GrammarMatcher(grammar, tokenizer_info)
+def xgrammar_build(schema: str, grammar_compiler: xgr.GrammarCompiler):
+    grammar = grammar_compiler.compile_json_schema(schema)
+    matcher = xgr.GrammarMatcher(grammar)
     return matcher
 
 
 def xgrammar_exec(
-    matcher: GrammarMatcher, logits: torch.Tensor, bitmask: torch.Tensor, token_id: int
+    matcher: xgr.GrammarMatcher, logits: torch.Tensor, bitmask: torch.Tensor, token_id: int
 ):
     # Logits processing
     matcher.fill_next_token_bitmask(bitmask)
-    matcher.apply_token_bitmask_inplace(logits, bitmask)
+    xgr.apply_token_bitmask_inplace(logits, bitmask)
     # Update state
     assert matcher.accept_token(token_id)
     return
@@ -93,7 +94,8 @@ def lmformatenforcer_exec(token_enforcer: TokenEnforcer, logits: torch.Tensor, t
 
     hf_model_path = "meta-llama/Llama-3.1-8B-Instruct"
     hf_tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
-    xgrammar_tokenizer_info = TokenizerInfo.from_huggingface(hf_tokenizer)
+    xgrammar_tokenizer_info = xgr.TokenizerInfo.from_huggingface(hf_tokenizer)
+    xgrammar_grammar_compiler = xgr.GrammarCompiler(xgrammar_tokenizer_info)
     outlines_tokenizer = TransformerTokenizer(hf_tokenizer)
     lmformatenforcer_tokenizer = build_token_enforcer_tokenizer_data(hf_tokenizer)
 
@@ -137,8 +139,8 @@ def lmformatenforcer_exec(token_enforcer: TokenEnforcer, logits: torch.Tensor, t
         start = time.perf_counter()
         try:
             if backend == "xgrammar":
-                worker = xgrammar_build(schema, xgrammar_tokenizer_info)
-                bitmask = GrammarMatcher.allocate_token_bitmask(worker.vocab_size)
+                worker = xgrammar_build(schema, xgrammar_grammar_compiler)
+                bitmask = xgr.allocate_token_bitmask(worker.vocab_size)
             elif backend == "outlines":
                 worker = outlines_build(schema, outlines_tokenizer)
             elif backend == "lmformatenforcer":
diff --git a/examples/hf_transformers/transformers_example.py b/examples/hf_transformers/transformers_example.py
index efc948e..caf7e53 100644
--- a/examples/hf_transformers/transformers_example.py
+++ b/examples/hf_transformers/transformers_example.py
@@ -67,5 +67,4 @@
 ]
 responses = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
 for response in responses:
-    print(response)
-    print()
+    print(response, end="\n\n")
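
Pieced together, the xgrammar flow this patch migrates the benchmark to looks roughly like the minimal sketch below. It is assembled only from the calls visible in the diff; the schema string, the random logits tensor, and the argmax "sampler" are hypothetical placeholders, and exact shapes/signatures may differ in other xgrammar versions.

    import torch
    import xgrammar as xgr
    from transformers import AutoTokenizer

    # One-time setup: derive tokenizer info, then build a grammar compiler from it.
    hf_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
    tokenizer_info = xgr.TokenizerInfo.from_huggingface(hf_tokenizer)
    compiler = xgr.GrammarCompiler(tokenizer_info)

    # Per-schema: compile the JSON schema and wrap the result in a matcher.
    schema = '{"type": "object", "properties": {"x": {"type": "integer"}}}'  # placeholder
    matcher = xgr.GrammarMatcher(compiler.compile_json_schema(schema))

    # Per-request: allocate the bitmask once, then fill and apply it each decoding step.
    bitmask = xgr.allocate_token_bitmask(matcher.vocab_size)
    logits = torch.randn(matcher.vocab_size)  # placeholder for real model logits
    matcher.fill_next_token_bitmask(bitmask)
    xgr.apply_token_bitmask_inplace(logits, bitmask)

    # Advance the matcher state with whatever token is sampled from the masked logits.
    token_id = int(torch.argmax(logits))  # placeholder for the real sampler
    assert matcher.accept_token(token_id)

Note the split of responsibilities the patch introduces: compilation (tokenizer-dependent, cacheable) lives in GrammarCompiler, per-request matching state lives in GrammarMatcher, and the bitmask helpers (allocate_token_bitmask, apply_token_bitmask_inplace) are module-level functions rather than matcher methods.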