Skip to content

Commit

Permalink
[Feature] Enhance JSON Schema Converter (#134)
Browse files Browse the repository at this point in the history
This PR enhances the JSONSchemaConverter:
- Support `oneOf`
- Support `allOf` with one option
- Support `$ref` with any path (previously only those starting with
`#/$defs/` were supported)
- Support objects with `properties` specified but not `type`, and arrays
with `items` specified but not `type`
This PR also completes support for pydantic V1.
  • Loading branch information
Ubospica authored Dec 25, 2024
1 parent 49655f4 commit 937f3c4
Show file tree
Hide file tree
Showing 3 changed files with 405 additions and 77 deletions.
127 changes: 105 additions & 22 deletions cpp/json_schema_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ class JSONSchemaConverter {
std::string VisitRef(const picojson::object& schema, const std::string& rule_name);

/*! \brief Get the schema from the URI. */
picojson::value URIToSchema(const picojson::value& uri);
picojson::value URIToSchema(const std::string& uri);

/*! \brief Visit a const schema. */
std::string VisitConst(const picojson::object& schema, const std::string& rule_name);
Expand All @@ -203,6 +203,11 @@ class JSONSchemaConverter {
/*! \brief Visit an anyOf schema. */
std::string VisitAnyOf(const picojson::object& schema, const std::string& rule_name);

picojson::value FuseAllOfSchema(const std::vector<picojson::value>& schemas);

/*! \brief Visit an allOf schema. */
std::string VisitAllOf(const picojson::object& schema, const std::string& rule_name);

/*! \brief Visit a true schema that can match anything. */
std::string VisitAny(const picojson::value& schema, const std::string& rule_name);

Expand Down Expand Up @@ -456,7 +461,7 @@ void JSONSchemaConverter::WarnUnsupportedKeywords(
return;
}

XGRAMMAR_DCHECK(schema.is<picojson::object>());
XGRAMMAR_DCHECK(schema.is<picojson::object>()) << "Schema should be an object or bool";
WarnUnsupportedKeywords(schema.get<picojson::object>(), keywords, verbose);
}

Expand Down Expand Up @@ -541,15 +546,14 @@ std::string JSONSchemaConverter::VisitSchema(
const picojson::value& schema, const std::string& rule_name
) {
if (schema.is<bool>()) {
XGRAMMAR_CHECK(schema.get<bool>());
XGRAMMAR_CHECK(schema.get<bool>()) << "Schema should not be false: it cannot accept any value";
return VisitAny(schema, rule_name);
}
XGRAMMAR_CHECK(schema.is<picojson::object>()) << "Schema should be an object or bool";

WarnUnsupportedKeywords(
schema,
{
"allof",
"oneof",
"not",
"if",
"then",
Expand All @@ -559,8 +563,6 @@ std::string JSONSchemaConverter::VisitSchema(
}
);

XGRAMMAR_CHECK(schema.is<picojson::object>());

const auto& schema_obj = schema.get<picojson::object>();

if (schema_obj.count("$ref")) {
Expand All @@ -569,8 +571,10 @@ std::string JSONSchemaConverter::VisitSchema(
return VisitConst(schema_obj, rule_name);
} else if (schema_obj.count("enum")) {
return VisitEnum(schema_obj, rule_name);
} else if (schema_obj.count("anyOf")) {
} else if (schema_obj.count("anyOf") || schema_obj.count("oneOf")) {
return VisitAnyOf(schema_obj, rule_name);
} else if (schema_obj.count("allOf")) {
return VisitAllOf(schema_obj, rule_name);
} else if (schema_obj.count("type")) {
const std::string& type = schema_obj.at("type").get<std::string>();
if (type == "integer") {
Expand All @@ -591,6 +595,12 @@ std::string JSONSchemaConverter::VisitSchema(
XGRAMMAR_LOG(FATAL) << "Unsupported type " << type << " in schema "
<< schema.serialize(false);
}
} else if (schema_obj.count("properties") || schema_obj.count("additionalProperties") ||
schema_obj.count("unevaluatedProperties")) {
return VisitObject(schema_obj, rule_name);
} else if (schema_obj.count("items") || schema_obj.count("prefixItems") ||
schema_obj.count("unevaluatedItems")) {
return VisitArray(schema_obj, rule_name);
}

// If no above keyword is detected, we treat it as any
Expand All @@ -600,9 +610,11 @@ std::string JSONSchemaConverter::VisitSchema(
std::string JSONSchemaConverter::VisitRef(
const picojson::object& schema, const std::string& rule_name
) {
XGRAMMAR_CHECK(schema.count("$ref"));
picojson::value new_schema = URIToSchema(schema.at("$ref"));
XGRAMMAR_CHECK(schema.count("$ref") && schema.at("$ref").is<std::string>())
<< "Schema $ref should be a string";
picojson::value new_schema = URIToSchema(schema.at("$ref").get<std::string>());
if (!new_schema.is<bool>()) {
XGRAMMAR_CHECK(new_schema.is<picojson::object>()) << "Schema should be an object or bool";
picojson::object new_schema_obj = new_schema.get<picojson::object>();
for (const auto& [k, v] : schema) {
if (k != "$ref") {
Expand All @@ -614,13 +626,32 @@ std::string JSONSchemaConverter::VisitRef(
return VisitSchema(new_schema, rule_name);
}

picojson::value JSONSchemaConverter::URIToSchema(const picojson::value& uri) {
if (uri.get<std::string>().substr(0, 8) == "#/$defs/") {
return json_schema_.get("$defs").get(uri.get<std::string>().substr(8));
picojson::value JSONSchemaConverter::URIToSchema(const std::string& uri) {
if (uri.size() < 2 || uri[0] != '#' || uri[1] != '/') {
XGRAMMAR_LOG(WARNING) << "Now only support URI starting with '#/' but got " << uri;
return picojson::value(true);
}
std::vector<std::string> parts;
std::stringstream ss(uri.substr(2));
std::string part;
while (std::getline(ss, part, '/')) {
if (!part.empty()) {
parts.push_back(part);
}
}

picojson::value current = json_schema_;
for (const auto& part : parts) {
XGRAMMAR_CHECK(current.is<picojson::object>() && current.contains(part))
<< "Cannot find field " << part << " in " << current.serialize(false);
current = current.get(part);
}
XGRAMMAR_LOG(WARNING) << "Now only support URI starting with '#/$defs/' but got "
<< uri.serialize(false);
return picojson::value(true);

// Handle recursive reference
if (current.is<picojson::object>() && current.get<picojson::object>().count("$ref")) {
return URIToSchema(current.get<picojson::object>().at("$ref").get<std::string>());
}
return current;
}

std::string JSONSchemaConverter::VisitConst(
Expand Down Expand Up @@ -665,10 +696,12 @@ std::string JSONSchemaConverter::JSONStrToPrintableStr(const std::string& json_s
std::string JSONSchemaConverter::VisitAnyOf(
const picojson::object& schema, const std::string& rule_name
) {
XGRAMMAR_CHECK(schema.count("anyOf"));
XGRAMMAR_CHECK(schema.count("anyOf") || schema.count("oneOf"));
std::string result = "";
int idx = 0;
for (auto anyof_schema : schema.at("anyOf").get<picojson::array>()) {
auto anyof_schema = schema.count("anyOf") ? schema.at("anyOf") : schema.at("oneOf");
XGRAMMAR_CHECK(anyof_schema.is<picojson::array>()) << "anyOf or oneOf must be an array";
for (auto anyof_schema : anyof_schema.get<picojson::array>()) {
if (idx != 0) {
result += " | ";
}
Expand All @@ -678,6 +711,29 @@ std::string JSONSchemaConverter::VisitAnyOf(
return result;
}

/*!
 * \brief Fuse the sub-schemas of an allOf into a single schema.
 * \param schemas The sub-schemas listed under allOf. They are not inspected yet.
 * \return An empty object schema. Fusing is not implemented, so a warning is logged and no
 * constraint from the sub-schemas is carried over.
 */
picojson::value JSONSchemaConverter::FuseAllOfSchema(const std::vector<picojson::value>& schemas) {
  // Placeholder behavior: tell the user that multi-schema allOf is not handled yet.
  XGRAMMAR_LOG(WARNING) << "Support for allOf with multiple options is still ongoing";
  // The result is a schema object without any keyword.
  return picojson::value(picojson::object());
}

/*!
 * \brief Visit an allOf schema.
 * \param schema The schema object; it must contain an "allOf" key whose value is an array.
 * \param rule_name The name of the rule being generated for this schema.
 * \return The result of visiting the single sub-schema (when allOf has exactly one entry), or
 * the result of visiting the schema produced by FuseAllOfSchema otherwise.
 */
std::string JSONSchemaConverter::VisitAllOf(
    const picojson::object& schema, const std::string& rule_name
) {
  // We support common usecases of AllOf, but not all, because it's impossible to support all
  // cases with CFG
  XGRAMMAR_CHECK(schema.count("allOf"));
  XGRAMMAR_CHECK(schema.at("allOf").is<picojson::array>()) << "allOf must be an array";
  // The array is only read here, so bind it by const reference instead of copying it.
  const auto& all_array = schema.at("allOf").get<picojson::array>();
  // Case 1: allOf is a single schema
  if (all_array.size() == 1) {
    return VisitSchema(all_array[0], rule_name + "_case_0");
  }
  // Case 2: allOf is a list of schemas, we fuse them into a single schema
  const picojson::value fused_schema = FuseAllOfSchema(all_array);
  return VisitSchema(fused_schema, rule_name);
}

std::string JSONSchemaConverter::VisitAny(
const picojson::value& schema, const std::string& rule_name
) {
Expand Down Expand Up @@ -820,19 +876,36 @@ std::string JSONSchemaConverter::VisitInteger(
schema.count("exclusiveMaximum")) {
std::optional<int> start, end;
if (schema.count("minimum")) {
XGRAMMAR_CHECK(schema.at("minimum").is<double>() || schema.at("minimum").is<int64_t>())
<< "minimum must be a number";
double start_double = schema.at("minimum").get<double>();
XGRAMMAR_CHECK(start_double == static_cast<int>(start_double))
<< "minimum must be an integer";
start = static_cast<int>(start_double);
}
if (schema.count("exclusiveMinimum")) {
XGRAMMAR_CHECK(
schema.at("exclusiveMinimum").is<double>() || schema.at("exclusiveMinimum").is<int64_t>()
) << "exclusiveMinimum must be a number";
double start_double = schema.at("exclusiveMinimum").get<double>();
XGRAMMAR_CHECK(start_double == static_cast<int>(start_double))
<< "exclusiveMinimum must be an integer";
start = static_cast<int>(start_double);
}
if (schema.count("maximum")) {
XGRAMMAR_CHECK(schema.at("maximum").is<double>() || schema.at("maximum").is<int64_t>())
<< "maximum must be a number";
double end_double = schema.at("maximum").get<double>();
XGRAMMAR_CHECK(end_double == static_cast<int>(end_double)) << "maximum must be an integer";
end = static_cast<int>(end_double);
}
if (schema.count("exclusiveMaximum")) {
XGRAMMAR_CHECK(
schema.at("exclusiveMaximum").is<double>() || schema.at("exclusiveMaximum").is<int64_t>()
) << "exclusiveMaximum must be a number";
double end_double = schema.at("exclusiveMaximum").get<double>();
XGRAMMAR_CHECK(end_double == static_cast<int>(end_double))
<< "exclusiveMaximum must be an integer";
end = static_cast<int>(end_double);
}
range_regex = GenerateRangeRegex(start, end);
Expand Down Expand Up @@ -902,8 +975,10 @@ std::string JSONSchemaConverter::VisitNull(
std::string JSONSchemaConverter::VisitArray(
const picojson::object& schema, const std::string& rule_name
) {
XGRAMMAR_CHECK(schema.count("type"));
XGRAMMAR_CHECK(schema.at("type").get<std::string>() == "array");
XGRAMMAR_CHECK(
(schema.count("type") && schema.at("type").get<std::string>() == "array") ||
schema.count("items") || schema.count("prefixItems") || schema.count("unevaluatedItems")
);
WarnUnsupportedKeywords(
schema,
{
Expand All @@ -922,6 +997,8 @@ std::string JSONSchemaConverter::VisitArray(

// 1. Handle prefix items
if (schema.count("prefixItems")) {
XGRAMMAR_CHECK(schema.at("prefixItems").is<picojson::array>())
<< "prefixItems must be an array";
const auto& prefix_items = schema.at("prefixItems").get<picojson::array>();
for (int i = 0; i < static_cast<int>(prefix_items.size()); ++i) {
XGRAMMAR_CHECK(prefix_items[i].is<picojson::object>());
Expand Down Expand Up @@ -1117,8 +1194,11 @@ std::string JSONSchemaConverter::GetPartialRuleForPropertiesContainRequired(
std::string JSONSchemaConverter::VisitObject(
const picojson::object& schema, const std::string& rule_name
) {
XGRAMMAR_CHECK(schema.count("type"));
XGRAMMAR_CHECK(schema.at("type").get<std::string>() == "object");
XGRAMMAR_CHECK(
(schema.count("type") && schema.at("type").get<std::string>() == "object") ||
schema.count("properties") || schema.count("additionalProperties") ||
schema.count("unevaluatedProperties")
);
WarnUnsupportedKeywords(
schema,
{
Expand All @@ -1140,6 +1220,8 @@ std::string JSONSchemaConverter::VisitObject(
// 1. Handle properties
std::vector<std::pair<std::string, picojson::value>> properties;
if (schema.count("properties")) {
XGRAMMAR_CHECK(schema.at("properties").is<picojson::object>())
<< "properties must be an object";
auto properties_obj = schema.at("properties").get<picojson::object>();
for (const auto& key : properties_obj.ordered_keys()) {
properties.push_back({key, properties_obj.at(key)});
Expand All @@ -1148,6 +1230,7 @@ std::string JSONSchemaConverter::VisitObject(

std::unordered_set<std::string> required;
if (schema.count("required")) {
XGRAMMAR_CHECK(schema.at("required").is<picojson::array>()) << "required must be an array";
for (const auto& required_prop : schema.at("required").get<picojson::array>()) {
required.insert(required_prop.get<std::string>());
}
Expand Down
11 changes: 10 additions & 1 deletion python/xgrammar/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,16 @@ def compile_json_schema(
The compiled grammar.
"""
if isinstance(schema, type) and issubclass(schema, BaseModel):
schema = json.dumps(schema.model_json_schema())
if hasattr(schema, "model_json_schema"):
# pydantic 2.x
schema = json.dumps(schema.model_json_schema())
elif hasattr(schema, "schema_json"):
# pydantic 1.x
schema = json.dumps(schema.schema_json())
else:
raise ValueError(
"The schema should have a model_json_schema or json_schema method."
)

return CompiledGrammar._create_from_handle(
self._handle.compile_json_schema(
Expand Down
Loading

0 comments on commit 937f3c4

Please sign in to comment.