Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit d3f5528

Browse filesBrowse files
committed
fix: from_json_schema oneof/anyof bug. Closes abetlen#1097
1 parent 8eefdbc commit d3f5528
Copy full SHA for d3f5528

File tree

Expand file treeCollapse file tree

2 files changed

+39
-10
lines changed
Filter options
Expand file treeCollapse file tree

2 files changed

+39
-10
lines changed

‎llama_cpp/llama_grammar.py

Copy file name to clipboardExpand all lines: llama_cpp/llama_grammar.py
+13-10Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1432,7 +1432,6 @@ def _add_rule(self, name: str, rule: str):
14321432
return key
14331433

14341434
def visit(self, schema: Dict[str, Any], name: str) -> str:
1435-
schema_type: Optional[str] = schema.get("type") # type: ignore
14361435
rule_name = name or "root"
14371436

14381437
if "$defs" in schema:
@@ -1458,7 +1457,19 @@ def visit(self, schema: Dict[str, Any], name: str) -> str:
14581457
rule = " | ".join((self._format_literal(v) for v in schema["enum"]))
14591458
return self._add_rule(rule_name, rule)
14601459

1461-
elif schema_type == "object" and "properties" in schema:
1460+
elif "$ref" in schema:
1461+
ref = schema["$ref"]
1462+
assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}"
1463+
# inline $defs
1464+
def_name = ref[len("#/$defs/") :]
1465+
def_schema = self._defs[def_name]
1466+
return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}')
1467+
1468+
1469+
schema_type: Optional[str] = schema.get("type") # type: ignore
1470+
assert isinstance(schema_type, str), f"Unrecognized schema: {schema}"
1471+
1472+
if schema_type == "object" and "properties" in schema:
14621473
# TODO: `required` keyword
14631474
prop_order = self._prop_order
14641475
prop_pairs = sorted(
@@ -1489,14 +1500,6 @@ def visit(self, schema: Dict[str, Any], name: str) -> str:
14891500
)
14901501
return self._add_rule(rule_name, rule)
14911502

1492-
elif "$ref" in schema:
1493-
ref = schema["$ref"]
1494-
assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}"
1495-
# inline $defs
1496-
def_name = ref[len("#/$defs/") :]
1497-
def_schema = self._defs[def_name]
1498-
return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}')
1499-
15001503
else:
15011504
assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
15021505
return self._add_rule(

‎tests/test_grammar.py

Copy file name to clipboardExpand all lines: tests/test_grammar.py
+26Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,29 @@ class B(BaseModel):
5050
grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(schema))
5151

5252
assert grammar.grammar is not None
53+
54+
55+
def test_grammar_anyof():
56+
sch = {
57+
"properties": {
58+
"temperature": {
59+
"description": "The temperature mentioned",
60+
"type": "number",
61+
},
62+
"unit": {
63+
"anyOf": [
64+
{
65+
"description": "Unit for temperature",
66+
"enum": ["celsius", "fahrenheit"],
67+
"type": "string",
68+
},
69+
{"type": "null"},
70+
],
71+
},
72+
},
73+
"type": "object",
74+
}
75+
76+
grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(sch))
77+
78+
assert grammar.grammar is not None

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.