From 937f3c470f1f79e935bdd67b6b2abefe6d0d1dc6 Mon Sep 17 00:00:00 2001 From: Yixin Dong Date: Wed, 25 Dec 2024 23:50:42 +0800 Subject: [PATCH] [Feature] Enhance JSON Schema Converter (#134) This PR enhances the JSONSchemaConverter: - Support `oneOf` - Support `allOf` with one option - Support `$ref` with any path (previously only those starting with `#/$defs/` were supported) - Support object with `properties` specified but not `type`, and array with `items` specified but not `type` This PR also completes support for pydantic V1. --- cpp/json_schema_converter.cc | 127 ++++++-- python/xgrammar/compiler.py | 11 +- tests/python/test_json_schema_converter.py | 344 +++++++++++++++++---- 3 files changed, 405 insertions(+), 77 deletions(-) diff --git a/cpp/json_schema_converter.cc b/cpp/json_schema_converter.cc index 89c3098..ecc664c 100644 --- a/cpp/json_schema_converter.cc +++ b/cpp/json_schema_converter.cc @@ -189,7 +189,7 @@ class JSONSchemaConverter { std::string VisitRef(const picojson::object& schema, const std::string& rule_name); /*! \brief Get the schema from the URI. */ - picojson::value URIToSchema(const picojson::value& uri); + picojson::value URIToSchema(const std::string& uri); /*! \brief Visit a const schema. */ std::string VisitConst(const picojson::object& schema, const std::string& rule_name); @@ -203,6 +203,11 @@ class JSONSchemaConverter { /*! \brief Visit an anyOf schema. */ std::string VisitAnyOf(const picojson::object& schema, const std::string& rule_name); + picojson::value FuseAllOfSchema(const std::vector& schemas); + + /*! \brief Visit an allOf schema. */ + std::string VisitAllOf(const picojson::object& schema, const std::string& rule_name); + /*! \brief Visit a true schema that can match anything. 
*/ std::string VisitAny(const picojson::value& schema, const std::string& rule_name); @@ -456,7 +461,7 @@ void JSONSchemaConverter::WarnUnsupportedKeywords( return; } - XGRAMMAR_DCHECK(schema.is()); + XGRAMMAR_DCHECK(schema.is()) << "Schema should be an object or bool"; WarnUnsupportedKeywords(schema.get(), keywords, verbose); } @@ -541,15 +546,14 @@ std::string JSONSchemaConverter::VisitSchema( const picojson::value& schema, const std::string& rule_name ) { if (schema.is()) { - XGRAMMAR_CHECK(schema.get()); + XGRAMMAR_CHECK(schema.get()) << "Schema should not be false: it cannot accept any value"; return VisitAny(schema, rule_name); } + XGRAMMAR_CHECK(schema.is()) << "Schema should be an object or bool"; WarnUnsupportedKeywords( schema, { - "allof", - "oneof", "not", "if", "then", @@ -559,8 +563,6 @@ std::string JSONSchemaConverter::VisitSchema( } ); - XGRAMMAR_CHECK(schema.is()); - const auto& schema_obj = schema.get(); if (schema_obj.count("$ref")) { @@ -569,8 +571,10 @@ std::string JSONSchemaConverter::VisitSchema( return VisitConst(schema_obj, rule_name); } else if (schema_obj.count("enum")) { return VisitEnum(schema_obj, rule_name); - } else if (schema_obj.count("anyOf")) { + } else if (schema_obj.count("anyOf") || schema_obj.count("oneOf")) { return VisitAnyOf(schema_obj, rule_name); + } else if (schema_obj.count("allOf")) { + return VisitAllOf(schema_obj, rule_name); } else if (schema_obj.count("type")) { const std::string& type = schema_obj.at("type").get(); if (type == "integer") { @@ -591,6 +595,12 @@ std::string JSONSchemaConverter::VisitSchema( XGRAMMAR_LOG(FATAL) << "Unsupported type " << type << " in schema " << schema.serialize(false); } + } else if (schema_obj.count("properties") || schema_obj.count("additionalProperties") || + schema_obj.count("unevaluatedProperties")) { + return VisitObject(schema_obj, rule_name); + } else if (schema_obj.count("items") || schema_obj.count("prefixItems") || + schema_obj.count("unevaluatedItems")) { + return 
VisitArray(schema_obj, rule_name); } // If no above keyword is detected, we treat it as any @@ -600,9 +610,11 @@ std::string JSONSchemaConverter::VisitSchema( std::string JSONSchemaConverter::VisitRef( const picojson::object& schema, const std::string& rule_name ) { - XGRAMMAR_CHECK(schema.count("$ref")); - picojson::value new_schema = URIToSchema(schema.at("$ref")); + XGRAMMAR_CHECK(schema.count("$ref") && schema.at("$ref").is()) + << "Schema $ref should be a string"; + picojson::value new_schema = URIToSchema(schema.at("$ref").get()); if (!new_schema.is()) { + XGRAMMAR_CHECK(new_schema.is()) << "Schema should be an object or bool"; picojson::object new_schema_obj = new_schema.get(); for (const auto& [k, v] : schema) { if (k != "$ref") { @@ -614,13 +626,32 @@ std::string JSONSchemaConverter::VisitRef( return VisitSchema(new_schema, rule_name); } -picojson::value JSONSchemaConverter::URIToSchema(const picojson::value& uri) { - if (uri.get().substr(0, 8) == "#/$defs/") { - return json_schema_.get("$defs").get(uri.get().substr(8)); +picojson::value JSONSchemaConverter::URIToSchema(const std::string& uri) { + if (uri.size() < 2 || uri[0] != '#' || uri[1] != '/') { + XGRAMMAR_LOG(WARNING) << "Now only support URI starting with '#/' but got " << uri; + return picojson::value(true); + } + std::vector parts; + std::stringstream ss(uri.substr(2)); + std::string part; + while (std::getline(ss, part, '/')) { + if (!part.empty()) { + parts.push_back(part); + } + } + + picojson::value current = json_schema_; + for (const auto& part : parts) { + XGRAMMAR_CHECK(current.is() && current.contains(part)) + << "Cannot find field " << part << " in " << current.serialize(false); + current = current.get(part); } - XGRAMMAR_LOG(WARNING) << "Now only support URI starting with '#/$defs/' but got " - << uri.serialize(false); - return picojson::value(true); + + // Handle recursive reference + if (current.is() && current.get().count("$ref")) { + return 
URIToSchema(current.get().at("$ref").get()); + } + return current; } std::string JSONSchemaConverter::VisitConst( @@ -665,10 +696,12 @@ std::string JSONSchemaConverter::JSONStrToPrintableStr(const std::string& json_s std::string JSONSchemaConverter::VisitAnyOf( const picojson::object& schema, const std::string& rule_name ) { - XGRAMMAR_CHECK(schema.count("anyOf")); + XGRAMMAR_CHECK(schema.count("anyOf") || schema.count("oneOf")); std::string result = ""; int idx = 0; - for (auto anyof_schema : schema.at("anyOf").get()) { + auto anyof_schema = schema.count("anyOf") ? schema.at("anyOf") : schema.at("oneOf"); + XGRAMMAR_CHECK(anyof_schema.is()) << "anyOf or oneOf must be an array"; + for (auto anyof_schema : anyof_schema.get()) { if (idx != 0) { result += " | "; } @@ -678,6 +711,29 @@ std::string JSONSchemaConverter::VisitAnyOf( return result; } +picojson::value JSONSchemaConverter::FuseAllOfSchema(const std::vector& schemas) { + picojson::object fused_schema; + XGRAMMAR_LOG(WARNING) << "Support for allOf with multiple options is still ongoing"; + return picojson::value(fused_schema); +} + +std::string JSONSchemaConverter::VisitAllOf( + const picojson::object& schema, const std::string& rule_name +) { + // We support common usecases of AllOf, but not all, because it's impossible to support all + // cases with CFG + XGRAMMAR_CHECK(schema.count("allOf")); + XGRAMMAR_CHECK(schema.at("allOf").is()) << "allOf must be an array"; + auto all_array = schema.at("allOf").get(); + // Case 1: allOf is a single schema + if (all_array.size() == 1) { + return VisitSchema(all_array[0], rule_name + "_case_0"); + } + // Case 2: allOf is a list of schemas, we fuse them into a single schema + auto fused_schema = FuseAllOfSchema(all_array); + return VisitSchema(fused_schema, rule_name); +} + std::string JSONSchemaConverter::VisitAny( const picojson::value& schema, const std::string& rule_name ) { @@ -820,19 +876,36 @@ std::string JSONSchemaConverter::VisitInteger( 
schema.count("exclusiveMaximum")) { std::optional start, end; if (schema.count("minimum")) { + XGRAMMAR_CHECK(schema.at("minimum").is() || schema.at("minimum").is()) + << "minimum must be a number"; double start_double = schema.at("minimum").get(); + XGRAMMAR_CHECK(start_double == static_cast(start_double)) + << "minimum must be an integer"; start = static_cast(start_double); } if (schema.count("exclusiveMinimum")) { + XGRAMMAR_CHECK( + schema.at("exclusiveMinimum").is() || schema.at("exclusiveMinimum").is() + ) << "exclusiveMinimum must be a number"; double start_double = schema.at("exclusiveMinimum").get(); + XGRAMMAR_CHECK(start_double == static_cast(start_double)) + << "exclusiveMinimum must be an integer"; start = static_cast(start_double); } if (schema.count("maximum")) { + XGRAMMAR_CHECK(schema.at("maximum").is() || schema.at("maximum").is()) + << "maximum must be a number"; double end_double = schema.at("maximum").get(); + XGRAMMAR_CHECK(end_double == static_cast(end_double)) << "maximum must be an integer"; end = static_cast(end_double); } if (schema.count("exclusiveMaximum")) { + XGRAMMAR_CHECK( + schema.at("exclusiveMaximum").is() || schema.at("exclusiveMaximum").is() + ) << "exclusiveMaximum must be a number"; double end_double = schema.at("exclusiveMaximum").get(); + XGRAMMAR_CHECK(end_double == static_cast(end_double)) + << "exclusiveMaximum must be an integer"; end = static_cast(end_double); } range_regex = GenerateRangeRegex(start, end); @@ -902,8 +975,10 @@ std::string JSONSchemaConverter::VisitNull( std::string JSONSchemaConverter::VisitArray( const picojson::object& schema, const std::string& rule_name ) { - XGRAMMAR_CHECK(schema.count("type")); - XGRAMMAR_CHECK(schema.at("type").get() == "array"); + XGRAMMAR_CHECK( + (schema.count("type") && schema.at("type").get() == "array") || + schema.count("items") || schema.count("prefixItems") || schema.count("unevaluatedItems") + ); WarnUnsupportedKeywords( schema, { @@ -922,6 +997,8 @@ std::string 
JSONSchemaConverter::VisitArray( // 1. Handle prefix items if (schema.count("prefixItems")) { + XGRAMMAR_CHECK(schema.at("prefixItems").is()) + << "prefixItems must be an array"; const auto& prefix_items = schema.at("prefixItems").get(); for (int i = 0; i < static_cast(prefix_items.size()); ++i) { XGRAMMAR_CHECK(prefix_items[i].is()); @@ -1117,8 +1194,11 @@ std::string JSONSchemaConverter::GetPartialRuleForPropertiesContainRequired( std::string JSONSchemaConverter::VisitObject( const picojson::object& schema, const std::string& rule_name ) { - XGRAMMAR_CHECK(schema.count("type")); - XGRAMMAR_CHECK(schema.at("type").get() == "object"); + XGRAMMAR_CHECK( + (schema.count("type") && schema.at("type").get() == "object") || + schema.count("properties") || schema.count("additionalProperties") || + schema.count("unevaluatedProperties") + ); WarnUnsupportedKeywords( schema, { @@ -1140,6 +1220,8 @@ std::string JSONSchemaConverter::VisitObject( // 1. Handle properties std::vector> properties; if (schema.count("properties")) { + XGRAMMAR_CHECK(schema.at("properties").is()) + << "properties must be an object"; auto properties_obj = schema.at("properties").get(); for (const auto& key : properties_obj.ordered_keys()) { properties.push_back({key, properties_obj.at(key)}); @@ -1148,6 +1230,7 @@ std::string JSONSchemaConverter::VisitObject( std::unordered_set required; if (schema.count("required")) { + XGRAMMAR_CHECK(schema.at("required").is()) << "required must be an array"; for (const auto& required_prop : schema.at("required").get()) { required.insert(required_prop.get()); } diff --git a/python/xgrammar/compiler.py b/python/xgrammar/compiler.py index a758961..8b13e81 100644 --- a/python/xgrammar/compiler.py +++ b/python/xgrammar/compiler.py @@ -102,7 +102,16 @@ def compile_json_schema( The compiled grammar. 
""" if isinstance(schema, type) and issubclass(schema, BaseModel): - schema = json.dumps(schema.model_json_schema()) + if hasattr(schema, "model_json_schema"): + # pydantic 2.x + schema = json.dumps(schema.model_json_schema()) + elif hasattr(schema, "schema_json"): + # pydantic 1.x + schema = json.dumps(schema.schema_json()) + else: + raise ValueError( + "The schema should have a model_json_schema or json_schema method." + ) return CompiledGrammar._create_from_handle( self._handle.compile_json_schema( diff --git a/tests/python/test_json_schema_converter.py b/tests/python/test_json_schema_converter.py index f5700f9..fcd3df7 100644 --- a/tests/python/test_json_schema_converter.py +++ b/tests/python/test_json_schema_converter.py @@ -29,9 +29,9 @@ def check_schema_with_grammar( assert json_schema_ebnf == expected_grammar_ebnf -def check_schema_with_json( +def check_schema_with_instance( schema: Dict[str, Any], - json_str: str, + instance: Union[str, BaseModel, Any], is_accepted: bool = True, any_whitespace: bool = True, indent: Optional[int] = None, @@ -46,29 +46,21 @@ def check_schema_with_json( strict_mode=strict_mode, ) + # instance: pydantic model, json string, or any other object (dumped to json string) + if isinstance(instance, BaseModel): + instance = json.dumps( + instance.model_dump(mode="json", round_trip=True), indent=indent, separators=separators + ) + elif not isinstance(instance, str): + instance = json.dumps(instance, indent=indent, separators=separators) + if is_accepted: - assert _is_grammar_accept_string(json_schema_grammar, json_str) + assert _is_grammar_accept_string(json_schema_grammar, instance) else: - assert not _is_grammar_accept_string(json_schema_grammar, json_str) + assert not _is_grammar_accept_string(json_schema_grammar, instance) -def check_schema_with_instance( - schema: Dict[str, Any], - instance: BaseModel, - is_accepted: bool = True, - any_whitespace: bool = True, - indent: Optional[int] = None, - separators: Optional[Tuple[str, str]] 
= None, - strict_mode: bool = True, -): - instance_obj = instance.model_dump(mode="json", round_trip=True) - instance_str = json.dumps(instance_obj, indent=indent, separators=separators) - check_schema_with_json( - schema, instance_str, is_accepted, any_whitespace, indent, separators, strict_mode - ) - - -def test_basic() -> None: +def test_basic(): class MainModel(BaseModel): integer_field: int number_field: float @@ -129,7 +121,7 @@ class MainModel(BaseModel): check_schema_with_instance(schema, instance_empty, is_accepted=False, any_whitespace=False) -def test_indent() -> None: +def test_indent(): class MainModel(BaseModel): array_field: List[str] tuple_field: Tuple[str, int, List[str]] @@ -166,7 +158,7 @@ class MainModel(BaseModel): ) -def test_non_strict() -> None: +def test_non_strict(): class Foo(BaseModel): pass @@ -215,10 +207,12 @@ class MainModel(BaseModel): check_schema_with_grammar( schema, ebnf_grammar, any_whitespace=False, indent=2, strict_mode=False ) - check_schema_with_json(schema, instance_json, any_whitespace=False, indent=2, strict_mode=False) + check_schema_with_instance( + schema, instance_json, any_whitespace=False, indent=2, strict_mode=False + ) -def test_enum_const() -> None: +def test_enum_const(): class Field(Enum): FOO = "foo" BAR = "bar" @@ -254,7 +248,7 @@ class MainModel(BaseModel): check_schema_with_instance(schema, instance, any_whitespace=False) -def test_optional() -> None: +def test_optional(): class MainModel(BaseModel): num: int = 0 opt_bool: Optional[bool] = None @@ -285,12 +279,14 @@ class MainModel(BaseModel): instance = MainModel(size=None) check_schema_with_instance(schema, instance, any_whitespace=False) - check_schema_with_json(schema, '{"size": null}', any_whitespace=False) - check_schema_with_json(schema, '{"size": null, "name": "foo"}', any_whitespace=False) - check_schema_with_json(schema, '{"num": 1, "size": null, "name": "foo"}', any_whitespace=False) + check_schema_with_instance(schema, '{"size": null}', 
any_whitespace=False) + check_schema_with_instance(schema, '{"size": null, "name": "foo"}', any_whitespace=False) + check_schema_with_instance( + schema, '{"num": 1, "size": null, "name": "foo"}', any_whitespace=False + ) -def test_all_optional() -> None: +def test_all_optional(): class MainModel(BaseModel): size: int = 0 state: bool = False @@ -317,8 +313,8 @@ class MainModel(BaseModel): instance = MainModel(size=42, state=True, num=3.14) check_schema_with_instance(schema, instance, any_whitespace=False) - check_schema_with_json(schema, '{"state": false}', any_whitespace=False) - check_schema_with_json(schema, '{"size": 1, "num": 1.5}', any_whitespace=False) + check_schema_with_instance(schema, '{"state": false}', any_whitespace=False) + check_schema_with_instance(schema, '{"size": 1, "num": 1.5}', any_whitespace=False) ebnf_grammar_non_strict = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] basic_string_sub ::= ("\"" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub) (= [ \n\t]* [,}\]:]) @@ -340,13 +336,13 @@ class MainModel(BaseModel): schema, ebnf_grammar_non_strict, any_whitespace=False, strict_mode=False ) - check_schema_with_json( + check_schema_with_instance( schema, '{"size": 1, "num": 1.5, "other": false}', any_whitespace=False, strict_mode=False ) - check_schema_with_json(schema, '{"other": false}', any_whitespace=False, strict_mode=False) + check_schema_with_instance(schema, '{"other": false}', any_whitespace=False, strict_mode=False) -def test_empty() -> None: +def test_empty(): class MainModel(BaseModel): pass @@ -369,10 +365,10 @@ class MainModel(BaseModel): instance = MainModel() check_schema_with_instance(schema, instance, any_whitespace=False) - check_schema_with_json(schema, '{"tmp": 123}', any_whitespace=False, strict_mode=False) + check_schema_with_instance(schema, '{"tmp": 123}', any_whitespace=False, strict_mode=False) -def test_reference() -> None: +def test_reference(): class 
Foo(BaseModel): count: int size: Optional[float] = None @@ -413,7 +409,103 @@ class MainModel(BaseModel): check_schema_with_instance(schema, instance, any_whitespace=False) -def test_union() -> None: +def test_reference_schema(): + # Test simple reference with $defs + schema = { + "type": "object", + "properties": {"value": {"$ref": "#/$defs/nested"}}, + "required": ["value"], + "$defs": { + "nested": { + "type": "object", + "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, + "required": ["name", "age"], + } + }, + } + + instance = {"value": {"name": "John", "age": 30}} + instance_rejected = {"value": {"name": "John"}} + + check_schema_with_instance(schema, instance, any_whitespace=False) + check_schema_with_instance(schema, instance_rejected, is_accepted=False, any_whitespace=False) + + # Test simple reference with definitions + schema_def = { + "type": "object", + "properties": {"value": {"$ref": "#/definitions/nested"}}, + "required": ["value"], + "definitions": { + "nested": { + "type": "object", + "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, + "required": ["name", "age"], + } + }, + } + + check_schema_with_instance(schema_def, instance, any_whitespace=False) + check_schema_with_instance( + schema_def, instance_rejected, is_accepted=False, any_whitespace=False + ) + + # Test multi-level reference path + schema_multi = { + "type": "object", + "properties": {"value": {"$ref": "#/$defs/level1/level2/nested"}}, + "required": ["value"], + "$defs": { + "level1": { + "level2": { + "nested": { + "type": "object", + "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, + "required": ["name", "age"], + } + } + } + }, + } + + check_schema_with_instance(schema_multi, instance, any_whitespace=False) + check_schema_with_instance( + schema_multi, instance_rejected, is_accepted=False, any_whitespace=False + ) + + # Test nested reference + schema_nested = { + "type": "object", + "properties": {"value": {"$ref": 
"#/definitions/node_a"}}, + "required": ["value"], + "definitions": { + "node_a": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "child": {"$ref": "#/definitions/node_b"}, + }, + "required": ["name"], + }, + "node_b": { + "type": "object", + "properties": { + "id": {"type": "integer"}, + }, + "required": ["id"], + }, + }, + } + + instance_nested = {"value": {"name": "first", "child": {"id": 1}}} + instance_nested_rejected = {"value": {"name": "first", "child": {}}} + + check_schema_with_instance(schema_nested, instance_nested, any_whitespace=False) + check_schema_with_instance( + schema_nested, instance_nested_rejected, is_accepted=False, any_whitespace=False + ) + + +def test_union(): class Cat(BaseModel): name: str color: str @@ -447,12 +539,51 @@ class Dog(BaseModel): check_schema_with_instance( model_schema, Dog(name="doggy", breed="bulldog"), any_whitespace=False ) - check_schema_with_json( + check_schema_with_instance( model_schema, '{"name": "kitty", "test": "black"}', False, any_whitespace=False ) -def test_alias() -> None: +def test_anyof_oneof(): + schema = { + "type": "object", + "properties": { + "name": { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + ], + }, + }, + } + schema_accepted_1 = '{"name": "John"}' + schema_accepted_2 = '{"name": 123}' + schema_rejected = '{"name": {"a": 1}}' + check_schema_with_instance(schema, schema_accepted_1, any_whitespace=False) + check_schema_with_instance(schema, schema_accepted_2, any_whitespace=False) + check_schema_with_instance(schema, schema_rejected, is_accepted=False, any_whitespace=False) + + schema = { + "type": "object", + "properties": { + "name": { + "oneOf": [ + {"type": "string"}, + {"type": "integer"}, + ], + }, + }, + } + + schema_accepted_1 = '{"name": "John"}' + schema_accepted_2 = '{"name": 123}' + schema_rejected = '{"name": {"a": 1}}' + check_schema_with_instance(schema, schema_accepted_1, any_whitespace=False) + check_schema_with_instance(schema, 
schema_accepted_2, any_whitespace=False) + check_schema_with_instance(schema, schema_rejected, is_accepted=False, any_whitespace=False) + + +def test_alias(): class MainModel(BaseModel): test: str = Field(..., alias="name") @@ -473,12 +604,12 @@ class MainModel(BaseModel): instance = MainModel(name="kitty") instance_str = json.dumps(instance.model_dump(mode="json", round_trip=True, by_alias=False)) - check_schema_with_json( + check_schema_with_instance( MainModel.model_json_schema(by_alias=False), instance_str, any_whitespace=False ) instance_str = json.dumps(instance.model_dump(mode="json", round_trip=True, by_alias=True)) - check_schema_with_json( + check_schema_with_instance( MainModel.model_json_schema(by_alias=True), instance_str, any_whitespace=False ) @@ -508,20 +639,20 @@ class MainModelSpace(BaseModel): instance_space_str = json.dumps( instance_space.model_dump(mode="json", round_trip=True, by_alias=True), ) - check_schema_with_json( + check_schema_with_instance( MainModelSpace.model_json_schema(by_alias=True), instance_space_str, any_whitespace=False ) -def test_restricted_string() -> None: +def test_restricted_string(): class MainModel(BaseModel): restricted_string: str = Field(..., pattern=r"[a-f]") instance = MainModel(restricted_string="a") instance_str = json.dumps(instance.model_dump(mode="json")) - check_schema_with_json(MainModel.model_json_schema(), instance_str, any_whitespace=False) + check_schema_with_instance(MainModel.model_json_schema(), instance_str, any_whitespace=False) - check_schema_with_json( + check_schema_with_instance( MainModel.model_json_schema(), '{"restricted_string": "j"}', is_accepted=False, @@ -529,7 +660,7 @@ class MainModel(BaseModel): ) -def test_complex_restrictions() -> None: +def test_complex_restrictions(): class RestrictedModel(BaseModel): restricted_string: Annotated[str, WithJsonSchema({"type": "string", "pattern": r"[^\"]*"})] restricted_value: Annotated[int, Field(strict=True, ge=0, lt=44)] @@ -537,18 +668,20 @@ 
class RestrictedModel(BaseModel): # working instance instance = RestrictedModel(restricted_string="abd", restricted_value=42) instance_str = json.dumps(instance.model_dump(mode="json")) - check_schema_with_json(RestrictedModel.model_json_schema(), instance_str, any_whitespace=False) + check_schema_with_instance( + RestrictedModel.model_json_schema(), instance_str, any_whitespace=False + ) instance_err = RestrictedModel(restricted_string='"', restricted_value=42) instance_str = json.dumps(instance_err.model_dump(mode="json")) - check_schema_with_json( + check_schema_with_instance( RestrictedModel.model_json_schema(), instance_str, is_accepted=False, any_whitespace=False, ) - check_schema_with_json( + check_schema_with_instance( RestrictedModel.model_json_schema(), '{"restricted_string": "j", "restricted_value": 45}', is_accepted=False, @@ -556,7 +689,7 @@ class RestrictedModel(BaseModel): ) -def test_dynamic_model() -> None: +def test_dynamic_model(): class MainModel(BaseModel): restricted_string: Annotated[str, WithJsonSchema({"type": "string", "pattern": r"[a-f]"})] @@ -572,10 +705,12 @@ class MainModel(BaseModel): ) instance = CompleteModel(restricted_string="a", restricted_string_dynamic="j") instance_str = json.dumps(instance.model_dump(mode="json")) - check_schema_with_json(CompleteModel.model_json_schema(), instance_str, any_whitespace=False) + check_schema_with_instance( + CompleteModel.model_json_schema(), instance_str, any_whitespace=False + ) -def test_any_whitespace() -> None: +def test_any_whitespace(): class SimpleModel(BaseModel): value: str arr: List[int] @@ -625,7 +760,108 @@ class SimpleModel(BaseModel): '{\t"value"\t:\t"test",\t"arr":\t[1,\t2],\t"obj":\t{"a":\t1}\t}', ] for instance in instances: - check_schema_with_json(schema, instance, any_whitespace=True) + check_schema_with_instance(schema, instance, any_whitespace=True) + + +def test_array_with_only_items_keyword(): + schema = { + "items": { + "type": "object", + "properties": { + "name": 
{"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + }, + } + instance_accepted = [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}] + instance_rejected = [{"name": "John"}] + check_schema_with_instance(schema, instance_accepted, any_whitespace=False) + check_schema_with_instance(schema, instance_rejected, is_accepted=False, any_whitespace=False) + + schema_prefix_items = { + "prefixItems": [ + { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + }, + { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + }, + ], + } + + check_schema_with_instance(schema_prefix_items, instance_accepted, any_whitespace=False) + check_schema_with_instance( + schema_prefix_items, instance_rejected, is_accepted=False, any_whitespace=False + ) + + schema_unevaluated_items = { + "unevaluatedItems": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + }, + } + + check_schema_with_instance(schema_unevaluated_items, instance_accepted, any_whitespace=False) + check_schema_with_instance( + schema_unevaluated_items, instance_rejected, is_accepted=False, any_whitespace=False + ) + + +def test_object_with_only_properties_keyword(): + schema = { + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + } + instance_accepted = {"name": "John", "age": 30} + instance_rejected = {"name": "John"} + check_schema_with_instance(schema, instance_accepted, any_whitespace=False) + check_schema_with_instance(schema, instance_rejected, is_accepted=False, any_whitespace=False) + + schema_additional_properties = { + "additionalProperties": { + "type": "string", + }, + } + instance_accepted = {"name": "John"} + instance_rejected = {"name": "John", "age": 30} + + 
check_schema_with_instance( + schema_additional_properties, instance_accepted, any_whitespace=False + ) + check_schema_with_instance( + schema_additional_properties, instance_rejected, is_accepted=False, any_whitespace=False + ) + + schema_unevaluated_properties = { + "unevaluatedProperties": { + "type": "string", + }, + } + + check_schema_with_instance( + schema_unevaluated_properties, instance_accepted, any_whitespace=False + ) + check_schema_with_instance( + schema_unevaluated_properties, instance_rejected, is_accepted=False, any_whitespace=False + ) if __name__ == "__main__":