Skip to content

Commit

Permalink
union isjsontype check
Browse files Browse the repository at this point in the history
Signed-off-by: Clemens Vasters <clemens@vasters.com>
  • Loading branch information
clemensv committed Apr 11, 2024
1 parent 420db97 commit 5c6aff7
Showing 1 changed file with 95 additions and 35 deletions.
130 changes: 95 additions & 35 deletions avrotize/avrotocsharp.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,40 +298,7 @@ def generate_class(self, avro_schema: Dict, parent_namespace: str, write_file: b

# emit IsJsonMatch method for System.Text.Json
if self.system_text_json_annotation:
class_definition += f"\n\n{INDENT}public static bool IsJsonMatch(System.Text.Json.JsonElement element)\n{INDENT}{{"
class_definition += f"\n{INDENT*2}return "
field_count = 0
for field in avro_schema.get('fields', []):
if field_count > 0:
class_definition += f" && \n{INDENT*3}"
field_count += 1
field_name = field['name']
if self.is_csharp_reserved_word(field_name):
field_name = f"@{field_name}"
if field_name == class_name:
field_name += "_"
field_type = self.convert_avro_type_to_csharp(
class_name, field_name, field['type'], namespace)
is_optional = field_type[-1] == '?'
field_type = field_type[:-1] if is_optional else field_type
if field_type == 'byte[]':
class_definition += f"({'!' if is_optional else ''}element.TryGetProperty(\"{field_name}\", out System.Text.Json.JsonElement {field_name}) {'||' if is_optional else '&&'} ({field_name}.ValueKind == System.Text.Json.JsonValueKind.String){f' || {field_name}.ValueKind == System.Text.Json.JsonValueKind.Null' if is_optional else ''})"
elif field_type == 'string':
class_definition += f"({'!' if is_optional else ''}element.TryGetProperty(\"{field_name}\", out System.Text.Json.JsonElement {field_name}) {'||' if is_optional else '&&'} ({field_name}.ValueKind == System.Text.Json.JsonValueKind.String){f' || {field_name}.ValueKind == System.Text.Json.JsonValueKind.Null' if is_optional else ''})"
elif field_type in ['int', 'long', 'float', 'double', 'decimal', 'short', 'sbyte', 'ushort', 'uint', 'ulong']:
class_definition += f"({'!' if is_optional else ''}element.TryGetProperty(\"{field_name}\", out System.Text.Json.JsonElement {field_name}) {'||' if is_optional else '&&'} ({field_name}.ValueKind == System.Text.Json.JsonValueKind.Number){f' || {field_name}.ValueKind == System.Text.Json.JsonValueKind.Null' if is_optional else ''})"
elif field_type == 'bool':
class_definition += f"({'!' if is_optional else ''}element.TryGetProperty(\"{field_name}\", out System.Text.Json.JsonElement {field_name}) {'||' if is_optional else '&&'} ({field_name}.ValueKind == System.Text.Json.JsonValueKind.True || {field_name}.ValueKind == System.Text.Json.JsonValueKind.False){f' || {field_name}.ValueKind == System.Text.Json.JsonValueKind.Null' if is_optional else ''})"
elif field_type.startswith("global::"):
type_kind = self.generated_types[field_type] if field_type in self.generated_types else "class"
if type_kind == "class":
class_definition += f"({'!' if is_optional else ''}element.TryGetProperty(\"{field_name}\", out System.Text.Json.JsonElement {field_name}) {'||' if is_optional else '&&'} ({f'{field_name}.ValueKind == System.Text.Json.JsonValueKind.Null || ' if is_optional else ''}{field_type}.IsJsonMatch({field_name})))"
elif type_kind == "enum":
class_definition += f"({'!' if is_optional else ''}element.TryGetProperty(\"{field_name}\", out System.Text.Json.JsonElement {field_name}) {'||' if is_optional else '&&'} ({f'{field_name}.ValueKind == System.Text.Json.JsonValueKind.Null ||' if is_optional else ''}({field_name}.ValueKind == System.Text.Json.JsonValueKind.String && Enum.TryParse<{field_type}>({field_name}.GetString(), true, out _ ))))"
else:
class_definition += f"({'!' if is_optional else ''}element.TryGetProperty(\"{field_name}\", out System.Text.Json.JsonElement {field_name}) {'||' if is_optional else '&&'} true )"
class_definition += f";\n{INDENT}}}"

class_definition += self.create_is_json_match_method(avro_schema, namespace, class_name)
class_definition += "\n"+"}"

if write_file:
Expand All @@ -340,6 +307,92 @@ def generate_class(self, avro_schema: Dict, parent_namespace: str, write_file: b
self.generated_types[ref] = "class"
return ref

def create_is_json_match_method(self, avro_schema, namespace, class_name) -> str:
""" Generates the IsJsonMatch method for System.Text.Json """
class_definition = ''
class_definition += f"\n\n{INDENT}public static bool IsJsonMatch(System.Text.Json.JsonElement element)\n{INDENT}{{"
class_definition += f"\n{INDENT*2}return "
field_count = 0
for field in avro_schema.get('fields', []):
if field_count > 0:
class_definition += f" && \n{INDENT*3}"
field_count += 1
field_name = field['name']
if self.is_csharp_reserved_word(field_name):
field_name = f"@{field_name}"
if field_name == class_name:
field_name += "_"
field_type = self.convert_avro_type_to_csharp(
class_name, field_name, field['type'], namespace)
class_definition += self.get_is_json_match_clause(class_name, field_name, field_type)
class_definition += f";\n{INDENT}}}"
return class_definition

def get_is_json_match_clause(self, class_name, field_name, field_type) -> str:
""" Generates the IsJsonMatch clause for a field """
class_definition = ''
field_name_js = field_name[1:] if field_name[0] == '@' else field_name
is_optional = field_type[-1] == '?'
field_type = field_type[:-1] if is_optional else field_type
if field_type == 'byte[]':
class_definition += f"({'!' if is_optional else ''}element.TryGetProperty(\"{field_name_js}\", out System.Text.Json.JsonElement {field_name}) {'||' if is_optional else '&&'} ({field_name}.ValueKind == System.Text.Json.JsonValueKind.String){f' || {field_name}.ValueKind == System.Text.Json.JsonValueKind.Null' if is_optional else ''})"
elif field_type == 'string':
class_definition += f"({'!' if is_optional else ''}element.TryGetProperty(\"{field_name_js}\", out System.Text.Json.JsonElement {field_name}) {'||' if is_optional else '&&'} ({field_name}.ValueKind == System.Text.Json.JsonValueKind.String){f' || {field_name}.ValueKind == System.Text.Json.JsonValueKind.Null' if is_optional else ''})"
elif field_type in ['int', 'long', 'float', 'double', 'decimal', 'short', 'sbyte', 'ushort', 'uint', 'ulong']:
class_definition += f"({'!' if is_optional else ''}element.TryGetProperty(\"{field_name_js}\", out System.Text.Json.JsonElement {field_name}) {'||' if is_optional else '&&'} ({field_name}.ValueKind == System.Text.Json.JsonValueKind.Number){f' || {field_name}.ValueKind == System.Text.Json.JsonValueKind.Null' if is_optional else ''})"
elif field_type == 'bool':
class_definition += f"({'!' if is_optional else ''}element.TryGetProperty(\"{field_name_js}\", out System.Text.Json.JsonElement {field_name}) {'||' if is_optional else '&&'} ({field_name}.ValueKind == System.Text.Json.JsonValueKind.True || {field_name}.ValueKind == System.Text.Json.JsonValueKind.False){f' || {field_name}.ValueKind == System.Text.Json.JsonValueKind.Null' if is_optional else ''})"
elif field_type.startswith("global::"):
type_kind = self.generated_types[field_type] if field_type in self.generated_types else "class"
if type_kind == "class":
class_definition += f"({'!' if is_optional else ''}element.TryGetProperty(\"{field_name_js}\", out System.Text.Json.JsonElement {field_name}) {'||' if is_optional else '&&'} ({f'{field_name}.ValueKind == System.Text.Json.JsonValueKind.Null || ' if is_optional else ''}{field_type}.IsJsonMatch({field_name})))"
elif type_kind == "enum":
class_definition += f"({'!' if is_optional else ''}element.TryGetProperty(\"{field_name_js}\", out System.Text.Json.JsonElement {field_name}) {'||' if is_optional else '&&'} ({f'{field_name}.ValueKind == System.Text.Json.JsonValueKind.Null ||' if is_optional else ''}({field_name}.ValueKind == System.Text.Json.JsonValueKind.String && Enum.TryParse<{field_type}>({field_name}.GetString(), true, out _ ))))"
else:
is_union = False
field_union = pascal(field_name)+'Union'
if field_type == field_union:
field_union = class_name+"."+pascal(field_name)+'Union'
type_kind = self.generated_types[field_union] if field_union in self.generated_types else "class"
if type_kind == "union":
is_union = True
class_definition += f"({'!' if is_optional else ''}element.TryGetProperty(\"{field_name_js}\", out System.Text.Json.JsonElement {field_name}) {'||' if is_optional else '&&'} ({f'{field_name}.ValueKind == System.Text.Json.JsonValueKind.Null || ' if is_optional else ''}{field_type}.IsJsonMatch({field_name})))"
if not is_union:
class_definition += f"({'!' if is_optional else ''}element.TryGetProperty(\"{field_name_js}\", out System.Text.Json.JsonElement {field_name}) {'||' if is_optional else '&&'} true )"
return class_definition

def get_is_json_match_clause_type(self, element_name, class_name, field_type) -> str:
""" Generates the IsJsonMatch clause for a field """
class_definition = ''
is_optional = field_type[-1] == '?'
field_type = field_type[:-1] if is_optional else field_type
if field_type == 'byte[]':
class_definition += f"({element_name}.ValueKind == System.Text.Json.JsonValueKind.String{f' || {element_name}.ValueKind == System.Text.Json.JsonValueKind.Null' if is_optional else ''})"
elif field_type == 'string':
class_definition += f"({element_name}.ValueKind == System.Text.Json.JsonValueKind.String{f' || {element_name}.ValueKind == System.Text.Json.JsonValueKind.Null' if is_optional else ''})"
elif field_type in ['int', 'long', 'float', 'double', 'decimal', 'short', 'sbyte', 'ushort', 'uint', 'ulong']:
class_definition += f"({element_name}.ValueKind == System.Text.Json.JsonValueKind.Number{f' || {element_name}.ValueKind == System.Text.Json.JsonValueKind.Null' if is_optional else ''})"
elif field_type == 'bool':
class_definition += f"({element_name}.ValueKind == System.Text.Json.JsonValueKind.True || {element_name}.ValueKind == System.Text.Json.JsonValueKind.False{f' || {element_name}.ValueKind == System.Text.Json.JsonValueKind.Null' if is_optional else ''})"
elif field_type.startswith("global::"):
type_kind = self.generated_types[field_type] if field_type in self.generated_types else "class"
if type_kind == "class":
class_definition += f"({f'{element_name}.ValueKind == System.Text.Json.JsonValueKind.Null || ' if is_optional else ''}{field_type}.IsJsonMatch({element_name}))"
elif type_kind == "enum":
class_definition += f"({f'{element_name}.ValueKind == System.Text.Json.JsonValueKind.Null ||' if is_optional else ''}({element_name}.ValueKind == System.Text.Json.JsonValueKind.String && Enum.TryParse<{field_type}>({element_name}.GetString(), true, out _ ))))"
else:
is_union = False
field_union = pascal(element_name)+'Union'
if field_type == field_union:
field_union = class_name+"."+pascal(element_name)+'Union'
type_kind = self.generated_types[field_union] if field_union in self.generated_types else "class"
if type_kind == "union":
is_union = True
class_definition += f"({f'{element_name}.ValueKind == System.Text.Json.JsonValueKind.Null || ' if is_optional else ''}{field_type}.IsJsonMatch({element_name})))"
if not is_union:
class_definition += f"({element_name}.ValueKind == System.Text.Json.JsonValueKind.Object{f' || {element_name}.ValueKind == System.Text.Json.JsonValueKind.Null' if is_optional else ''})"
return class_definition

def generate_enum(self, avro_schema: Dict, parent_namespace: str, write_file: bool) -> str:
""" Generates an Enum """
enum_definition = ''
Expand All @@ -362,6 +415,7 @@ def generate_enum(self, avro_schema: Dict, parent_namespace: str, write_file: bo
def generate_embedded_union_class_system_json_text(self, class_name: str, field_name: str, avro_type: List, parent_namespace: str, write_file: bool) -> str:
""" Generates an embedded Union Class """
class_definition_ctors = class_definition_decls = class_definition_read = class_definition_write = class_definition = ''
list_is_json_match: List [str] = []
union_class_name = pascal(field_name)+'Union'
union_types = [self.convert_avro_type_to_csharp(class_name, field_name+"Option"+str(i), t, parent_namespace) for i,t in enumerate(avro_type)]
for i, union_type in enumerate(union_types):
Expand Down Expand Up @@ -410,6 +464,9 @@ def generate_embedded_union_class_system_json_text(self, class_name: str, field_
else:
class_definition_read += f"{INDENT*3}if ({union_type}.IsJsonMatch(element))\n{INDENT*3}{{\n{INDENT*4}return new {union_class_name}({union_type}.FromData(element, System.Net.Mime.MediaTypeNames.Application.Json));\n{INDENT*3}}}\n"
class_definition_write += f"{INDENT*3}{'else ' if i>0 else ''}if (value.{union_type_name} != null)\n{INDENT*3}{{\n{INDENT*4}System.Text.Json.JsonSerializer.Serialize(writer, value.{union_type_name}, options);\n{INDENT*3}}}\n"
gij = self.get_is_json_match_clause_type("element", class_name, union_type)
if gij:
list_is_json_match.append(gij)

class_definition = \
f"public partial class {class_name}\n{{\n{INDENT}[System.Text.Json.Serialization.JsonConverter(typeof({union_class_name}))]\n{INDENT}public sealed class {union_class_name} : System.Text.Json.Serialization.JsonConverter<{union_class_name}>\n{INDENT}{{\n" + \
Expand All @@ -422,10 +479,13 @@ def generate_embedded_union_class_system_json_text(self, class_name: str, field_
f"\n{INDENT*2}public override void Write(Utf8JsonWriter writer, {union_class_name} value, JsonSerializerOptions options)\n{INDENT*2}{{\n" + \
class_definition_write + \
f"{INDENT*3}else\n{INDENT*3}{{\n{INDENT*4}throw new NotSupportedException(\"No record type is set in the union\");\n{INDENT*3}}}\n{INDENT*2}}}\n" + \
f"{INDENT}}}\n}}\n"
f"\n{INDENT*2}public static bool IsJsonMatch(System.Text.Json.JsonElement element)\n{INDENT*2}{{" + \
f"\n{INDENT*3}return "+f"\n{INDENT*3} || ".join(list_is_json_match)+f";\n{INDENT*2}}}\n" + \
f"{INDENT*1}}}\n}}\n"

if write_file:
self.write_to_file(parent_namespace, class_name +"."+union_class_name, class_definition)
self.generated_types[class_name+'.'+union_class_name] = "union" # it doesn't matter if the names clash, we just need to know whether it's a union
return union_class_name

def generate_property(self, field: Dict, class_name: str, parent_namespace: str) -> str:
Expand Down

0 comments on commit 5c6aff7

Please sign in to comment.