From 14430bc8129858c54ed7e30fe33ce04adf74e141 Mon Sep 17 00:00:00 2001 From: Mark Neumann Date: Sun, 14 Jan 2024 14:03:28 -0800 Subject: [PATCH] fix model parsing --- llama_cpp/llama_grammar.py | 1 - tests/test_grammar.py | 39 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 0c3b2e0ef..c02e65642 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -1433,7 +1433,6 @@ def _add_rule(self, name: str, rule: str): def visit(self, schema: Dict[str, Any], name: str) -> str: schema_type: Optional[str] = schema.get("type") # type: ignore - assert isinstance(schema_type, str), f"Unrecognized schema: {schema}" rule_name = name or "root" if "$defs" in schema: diff --git a/tests/test_grammar.py b/tests/test_grammar.py index 2e24903c2..ef9392b7a 100644 --- a/tests/test_grammar.py +++ b/tests/test_grammar.py @@ -1,4 +1,5 @@ import llama_cpp +import json tree = """ leaf ::= "." @@ -6,8 +7,46 @@ root ::= node """ + def test_grammar_from_string(): grammar = llama_cpp.LlamaGrammar.from_string(tree) assert grammar._n_rules == 3 assert grammar._start_rule_index == 2 assert grammar.grammar is not None + + +def test_composed_pydantic_grammar(): + """ + from pydantic import BaseModel + + class A(BaseModel): + a: int + + class B(BaseModel): + a: A + b: int + """ + + # This schema corresponds to the grammar in the comment above. + # We don't use the pydantic models directly to avoid the dependency. + schema = { + "$defs": { + "A": { + "properties": {"a": {"title": "A", "type": "integer"}}, + "required": ["a"], + "title": "A", + "type": "object", + } + }, + "properties": { + "a": {"$ref": "#/$defs/A"}, + "b": {"title": "B", "type": "integer"}, + }, + "required": ["a", "b"], + "title": "B", + "type": "object", + } + + grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(schema)) + + assert grammar.grammar is not None