diff --git a/avrotize/avrotocsharp.py b/avrotize/avrotocsharp.py index f20dde1..004935a 100644 --- a/avrotize/avrotocsharp.py +++ b/avrotize/avrotocsharp.py @@ -60,15 +60,19 @@ if (contentType.MediaType.StartsWith("avro/binary") || contentType.MediaType.StartsWith("application/vnd.apache.avro+avro")) { var stream = new System.IO.MemoryStream(); - var writer = new Avro.Specific.SpecificDatumWriter<{type_name}>({type_name}.AvroSchema); - writer.Write(this, new Avro.IO.BinaryEncoder(stream)); + var writer = new global::Avro.Specific.SpecificDatumWriter<{type_name}>({type_name}.AvroSchema); + var encoder = new global::Avro.IO.BinaryEncoder(stream); + writer.Write(this, encoder); + encoder.Flush(); result = stream.ToArray(); } else if (contentType.MediaType.StartsWith("avro/json") || contentType.MediaType.StartsWith("application/vnd.apache.avro+json")) { var stream = new System.IO.MemoryStream(); - var writer = new Avro.Specific.SpecificDatumWriter<{type_name}>({type_name}.AvroSchema); - writer.Write(this, new Avro.IO.JsonEncoder({type_name}.AvroSchema, stream)); + var writer = new global::Avro.Specific.SpecificDatumWriter<{type_name}>({type_name}.AvroSchema); + var encoder = new global::Avro.IO.JsonEncoder({type_name}.AvroSchema, stream); + writer.Write(this, encoder); + encoder.Flush(); result = stream.ToArray(); } """ @@ -104,8 +108,10 @@ }; using (var gzip = new System.IO.Compression.GZipStream(stream, System.IO.Compression.CompressionMode.Decompress)) { - data = new System.IO.MemoryStream(); - gzip.CopyTo((System.IO.MemoryStream)data); + System.IO.MemoryStream memoryStream = new System.IO.MemoryStream(); + gzip.CopyTo(memoryStream); + memoryStream.Position = 0; + data = memoryStream.ToArray(); } } """ @@ -131,6 +137,14 @@ { return ((System.BinaryData)data).ToObjectFromJson<{type_name}>(); } + else if (data is byte[]) + { + return System.Text.Json.JsonSerializer.Deserialize<{type_name}>(new ReadOnlySpan((byte[])data)); + } + else if (data is System.IO.Stream) + { + return System.Text.Json.JsonSerializer.DeserializeAsync<{type_name}>((System.IO.Stream)data).Result; + } } """ @@ -158,19 +172,32 @@ System.IO.Stream s => s, System.BinaryData bd => bd.ToStream(), byte[] bytes => new System.IO.MemoryStream(bytes), _ => throw new NotSupportedException("Data is not of a supported type for conversion to Stream") }; + #pragma warning disable CS8625 // Cannot convert null literal to non-nullable reference type. if (contentType.MediaType.StartsWith("avro/binary") || contentType.MediaType.StartsWith("application/vnd.apache.avro+avro")) { - var reader = new Avro.Specific.SpecificDatumReader<{type_name}>({type_name}.AvroSchema, {type_name}.AvroSchema); - return reader.Read(new {type_name}(), new Avro.IO.BinaryDecoder(stream)); + var reader = new global::Avro.Generic.GenericDatumReader({type_name}.AvroSchema, {type_name}.AvroSchema); + return new {type_name}(reader.Read(null, new global::Avro.IO.BinaryDecoder(stream))); } if ( contentType.MediaType.StartsWith("avro/json") || contentType.MediaType.StartsWith("application/vnd.apache.avro+json")) { - var reader = new Avro.Specific.SpecificDatumReader<{type_name}>({type_name}.AvroSchema, {type_name}.AvroSchema); - return reader.Read(new {type_name}(), new Avro.IO.JsonDecoder({type_name}.AvroSchema, stream)); + var reader = new global::Avro.Generic.GenericDatumReader({type_name}.AvroSchema, {type_name}.AvroSchema); + return new {type_name}(reader.Read(null, new global::Avro.IO.JsonDecoder({type_name}.AvroSchema, stream))); } + #pragma warning restore CS8625 } """ +AVRO_CLASS_PREAMBLE = \ +""" +public {type_name}(global::Avro.Generic.GenericRecord obj) +{ + global::Avro.Specific.ISpecificRecord self = this; + for (int i = 0; obj.Schema.Fields.Count > i; ++i) + { + self.Put(i, obj.GetValue(i)); + } +} +""" class AvroToCSharp: """ Converts Avro schema to C# classes """ @@ -281,8 +308,11 @@ def generate_class(self, avro_schema: Dict, parent_namespace: str, write_file: b """ Generates a Class """ class_definition = '' avro_namespace = avro_schema.get('namespace', parent_namespace) + if not 'namespace' in avro_schema: + avro_schema['namespace'] = parent_namespace namespace = pascal(self.concat_namespace(self.base_namespace, avro_namespace)) class_name = pascal(avro_schema['name']) + class_definition += f"/// \n/// { avro_schema.get('doc', class_name ) }\n/// \n" fields_str = [self.generate_property(field, class_name, avro_namespace) for field in avro_schema.get('fields', [])] class_body = "\n".join(fields_str) @@ -290,6 +320,14 @@ def generate_class(self, avro_schema: Dict, parent_namespace: str, write_file: b if self.avro_annotation: class_definition += " : global::Avro.Specific.ISpecificRecord" class_definition += "\n{\n"+class_body + class_definition += f"\n{INDENT}/// \n{INDENT}/// Default constructor\n{INDENT}///\n" + class_definition += f"{INDENT}public {class_name}()\n{INDENT}{{\n{INDENT}}}" + if self.avro_annotation: + class_definition += f"\n\n{INDENT}/// \n{INDENT}/// Constructor from Avro GenericRecord\n{INDENT}///\n" + class_definition += f"{INDENT}public {class_name}(global::Avro.Generic.GenericRecord obj)\n{INDENT}{{\n" + class_definition += f"{INDENT*2}global::Avro.Specific.ISpecificRecord self = this;\n" + class_definition += f"{INDENT*2}for (int i = 0; obj.Schema.Fields.Count > i; ++i)\n{INDENT*2}{{\n" + class_definition += f"{INDENT*3}self.Put(i, obj.GetValue(i));\n{INDENT*2}}}\n{INDENT}}}\n" if self.avro_annotation: avro_schema_json = json.dumps(avro_schema) # wrap schema at 80 characters @@ -299,7 +337,7 @@ def generate_class(self, avro_schema: Dict, parent_namespace: str, write_file: b avro_schema_json = avro_schema_json.replace('ยง', '\\"') class_definition += f"\n\n{INDENT}/// \n{INDENT}/// Avro schema for this class\n{INDENT}/// " class_definition += f"\n{INDENT}public static global::Avro.Schema AvroSchema = global::Avro.Schema.Parse(\n{INDENT}\"{avro_schema_json}\");\n" - class_definition += f"\n{INDENT}Schema global::Avro.Specific.ISpecificRecord.Schema => AvroSchema;\n" + class_definition += f"\n{INDENT}global::Avro.Schema global::Avro.Specific.ISpecificRecord.Schema => AvroSchema;\n" get_method = f"{INDENT}object global::Avro.Specific.ISpecificRecord.Get(int fieldPos)\n" + \ INDENT+"{"+f"\n{INDENT*2}switch (fieldPos)\n{INDENT*2}" + "{" put_method = f"{INDENT}void global::Avro.Specific.ISpecificRecord.Put(int fieldPos, object fieldValue)\n" + \ @@ -313,8 +351,12 @@ def generate_class(self, avro_schema: Dict, parent_namespace: str, write_file: b field_name = pascal(field_name) if field_name == class_name: field_name += "_" - get_method += f"\n{INDENT*3}case {pos}: return this.{field_name};" - put_method += f"\n{INDENT*3}case {pos}: this.{field_name} = ({field_type})fieldValue; break;" + if class_name + '.' + field_type in self.generated_types and self.generated_types[class_name + '.' + field_type] == "union": + get_method += f"\n{INDENT*3}case {pos}: return this.{field_name}?.ToObject();" + put_method += f"\n{INDENT*3}case {pos}: this.{field_name} = new {field_type}((global::Avro.Generic.GenericRecord)fieldValue); break;" + else: + get_method += f"\n{INDENT*3}case {pos}: return this.{field_name};" + put_method += f"\n{INDENT*3}case {pos}: this.{field_name} = ({field_type})fieldValue; break;" get_method += f"\n{INDENT*3}default: throw new global::Avro.AvroRuntimeException($\"Bad index {{fieldPos}} in Get()\");" put_method += f"\n{INDENT*3}default: throw new global::Avro.AvroRuntimeException($\"Bad index {{fieldPos}} in Put()\");" get_method += "\n"+INDENT+INDENT+"}\n"+INDENT+"}" @@ -484,7 +526,9 @@ def generate_enum(self, avro_schema: Dict, parent_namespace: str, write_file: bo def generate_embedded_union_class_system_json_text(self, class_name: str, field_name: str, avro_type: List, parent_namespace: str, write_file: bool) -> str: """ Generates an embedded Union Class """ - class_definition_ctors = class_definition_decls = class_definition_read = class_definition_write = class_definition = '' + class_definition_ctors = class_definition_decls = class_definition_read = '' + class_definition_write = class_definition = class_definition_toobject = '' + class_definition_objctr = class_definition_genericrecordctor = '' list_is_json_match: List [str] = [] union_class_name = pascal(field_name)+'Union' union_types = [self.convert_avro_type_to_csharp(class_name, field_name+"Option"+str(i), t, parent_namespace) for i,t in enumerate(avro_type)] @@ -506,12 +550,16 @@ def generate_embedded_union_class_system_json_text(self, class_name: str, field_ union_type_name = union_type.rsplit('.', 1)[-1] if self.is_csharp_reserved_word(union_type_name): union_type_name = f"@{union_type_name}" + class_definition_objctr += f"{INDENT*3}if (obj is {union_type})\n{INDENT*3}{{\n{INDENT*4}this.{union_type_name} = ({union_type})obj;\n{INDENT*4}return;\n{INDENT*3}}}\n" + class_definition_genericrecordctor += f"{INDENT*3}if (obj.Schema.Fullname == {union_type}.AvroSchema.Fullname)\n{INDENT*3}{{\n{INDENT*4}this.{union_type_name} = new {union_type}(obj);\n{INDENT*4}return;\n{INDENT*3}}}\n" class_definition_ctors += \ f"{INDENT*2}/// \n{INDENT*2}/// Constructor for {union_type_name} values\n{INDENT*2}/// \n" + \ f"{INDENT*2}public {union_class_name}({union_type}? {union_type_name})\n{INDENT*2}{{\n{INDENT*3}this.{union_type_name} = {union_type_name};\n{INDENT*2}}}\n" class_definition_decls += \ f"{INDENT*2}/// \n{INDENT*2}/// Gets the {union_type_name} value\n{INDENT*2}/// \n" + \ - f"{INDENT*2}public {union_type}? {union_type_name} {{ get; private set; }} = null;\n" + f"{INDENT*2}public {union_type}? {union_type_name} {{ get; private set; }} = null;\n" + class_definition_toobject += f"{INDENT*3}if ({union_type_name} != null) {{\n{INDENT*4}return {union_type_name};\n{INDENT*3}}}\n" + if is_dict: class_definition_read += f"{INDENT*3}if (element.ValueKind == JsonValueKind.Object)\n{INDENT*3}{{\n" + \ f"{INDENT*4}var map = System.Text.Json.JsonSerializer.Deserialize<{union_type}>(element, options);\n" + \ @@ -546,9 +594,27 @@ def generate_embedded_union_class_system_json_text(self, class_name: str, field_ f"/// \n/// {class_name}. Type union resolver. \n/// \n" + \ f"public partial class {class_name}\n{{\n{INDENT}[System.Text.Json.Serialization.JsonConverter(typeof({union_class_name}))]\n{INDENT}public sealed class {union_class_name} : System.Text.Json.Serialization.JsonConverter<{union_class_name}>\n{INDENT}{{\n" + \ f"{INDENT*2}/// \n{INDENT*2}/// Default constructor\n{INDENT*2}/// \n" + \ - f"{INDENT*2}public {union_class_name}() {{ }}\n" + \ - class_definition_ctors + \ + f"{INDENT*2}public {union_class_name}() {{ }}\n" + class_definition += class_definition_ctors + if self.avro_annotation: + class_definition += \ + f"{INDENT*2}/// \n{INDENT*2}/// Constructor for Avro decoder\n{INDENT*2}/// \n" + \ + f"{INDENT*2}internal {union_class_name}(object obj)\n{INDENT*2}{{\n" + \ + class_definition_objctr + \ + f"{INDENT*3}throw new NotSupportedException(\"No record type matched the type\");\n" + \ + f"{INDENT*2}}}\n" + class_definition += f"\n{INDENT}/// \n{INDENT}/// Constructor from Avro GenericRecord\n{INDENT}/// \n" + \ + f"{INDENT*2}public {union_class_name}(global::Avro.Generic.GenericRecord obj)\n{INDENT*2}{{\n" + \ + class_definition_genericrecordctor + \ + f"{INDENT*3}throw new NotSupportedException(\"No record type matched the type\");\n" + \ + f"{INDENT*2}}}\n" + class_definition += \ class_definition_decls + \ + f"\n{INDENT*2}/// \n{INDENT*2}/// Yields the current value of the union\n{INDENT*2}/// \n" + \ + f"\n{INDENT*2}public Object ToObject()\n{INDENT*2}{{\n" + \ + class_definition_toobject+ \ + f"{INDENT*3}throw new NotSupportedException(\"No record type is set in the union\");\n" + \ + f"{INDENT*2}}}\n" + \ f"\n{INDENT*2}/// \n{INDENT*2}/// Reads the JSON representation of the object.\n{INDENT*2}/// \n" + \ f"{INDENT*2}public override {union_class_name}? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)\n{INDENT*2}{{\n{INDENT*3}var element = JsonElement.ParseValue(ref reader);\n" + \ class_definition_read + \ @@ -638,8 +704,7 @@ def write_to_file(self, namespace: str, name: str, definition: str): file_content += "using System.Text.Json.Serialization;\n" if self.newtonsoft_json_annotation: file_content += "using Newtonsoft.Json;\n" - if self.avro_annotation: - file_content += "using Avro;\nusing Avro.Specific;\n" + # Namespace declaration with correct indentation for the definition file_content += f"\nnamespace {namespace}\n{{\n" indented_definition = '\n'.join( diff --git a/test/cs/twotypeunion/Program.cs b/test/cs/twotypeunion/Program.cs index 42ebd41..a527adf 100644 --- a/test/cs/twotypeunion/Program.cs +++ b/test/cs/twotypeunion/Program.cs @@ -1,54 +1,168 @@ using System; +using System.Text; using System.Text.Json; using Com.Example.Avro; public class TestProgram { - public static int Main() + const string jsonRecord1 = "{\"document\":{\"field1\":\"Value1\",\"field2\":10,\"field3\":\"hey\",\"field4\":5.5,\"field5\":100,\"fieldB\":\"OptionalB\"}}"; + const string jsonRecord2 = "{\"document\":{\"field1\":\"Value2\",\"field2\":20,\"field3\":false,\"field4\":null,\"field5\":null,\"fieldA\":\"MandatoryA\"}}"; + const string jsonRecord3 = "{\"document\":{\"field1\":\"Value2\",\"field2\":20,\"field3\":false,\"fieldA\":\"MandatoryA\"}}"; + const string jsonInvalid = "{\"document\":{\"invalidField\":\"Value3\"}}"; + + public static void testJsonDeserialize() { + var tl1 = JsonSerializer.Deserialize(jsonRecord1); + if (tl1 == null) + { + throw new Exception("Deserialization failed"); + } + if (tl1.Document.RecordType1 == null) + { + throw new Exception("RecordType1 is null"); + } + var tl2 = JsonSerializer.Deserialize(jsonRecord2); + if (tl2 == null) + { + throw new Exception("Deserialization failed"); + } + if (tl2.Document.RecordType2 == null) + { + throw new Exception("RecordType2 is null"); + } + var tl3 = JsonSerializer.Deserialize(jsonRecord2); + if (tl3 == null) + { + throw new Exception("Deserialization failed"); + } + if (tl3.Document.RecordType2 == null) + { + throw new Exception("RecordType2 is null"); + } try { - string jsonRecord1 = "{\"document\":{\"field1\":\"Value1\",\"field2\":10,\"field3\":\"hey\",\"field4\":5.5,\"field5\":100,\"fieldB\":\"OptionalB\"}}"; - string jsonRecord2 = "{\"document\":{\"field1\":\"Value2\",\"field2\":20,\"field3\":false,\"field4\":null,\"field5\":null,\"fieldA\":\"MandatoryA\"}}"; - string jsonRecord3 = "{\"document\":{\"field1\":\"Value2\",\"field2\":20,\"field3\":false,\"fieldA\":\"MandatoryA\"}}"; - string jsonInvalid = "{\"document\":{\"invalidField\":\"Value3\"}}"; + JsonSerializer.Deserialize(jsonInvalid); + throw new Exception("Expected exception was not thrown"); + } + catch (NotSupportedException) + { + // expected + } + } - var tl1 = JsonSerializer.Deserialize(jsonRecord1); - if (tl1 == null) - { - throw new Exception("Deserialization failed"); - } - if (tl1.Document.RecordType1 == null) - { - throw new Exception("RecordType1 is null"); - } - var tl2 = JsonSerializer.Deserialize(jsonRecord2); - if (tl2 == null) - { - throw new Exception("Deserialization failed"); - } - if (tl2.Document.RecordType2 == null) - { - throw new Exception("RecordType2 is null"); - } - var tl3 = JsonSerializer.Deserialize(jsonRecord2); - if (tl3 == null) - { - throw new Exception("Deserialization failed"); - } - if (tl3.Document.RecordType2 == null) - { - throw new Exception("RecordType2 is null"); - } - try - { - JsonSerializer.Deserialize(jsonInvalid); - throw new Exception("Expected exception was not thrown"); - } - catch (NotSupportedException) - { - // expected - } + public static void testJsonFromData() + { + var tl1 = TopLevelType.FromData(jsonRecord1, "application/json"); + if (tl1 == null) + { + throw new Exception("Deserialization failed"); + } + if (tl1.Document.RecordType1 == null) + { + throw new Exception("RecordType1 is null"); + } + var tl2 = TopLevelType.FromData(jsonRecord2, "application/json"); + if (tl2 == null) + { + throw new Exception("Deserialization failed"); + } + if (tl2.Document.RecordType2 == null) + { + throw new Exception("RecordType2 is null"); + } + var tl3 = TopLevelType.FromData(jsonRecord2, "application/json"); + if (tl3 == null) + { + throw new Exception("Deserialization failed"); + } + if (tl3.Document.RecordType2 == null) + { + throw new Exception("RecordType2 is null"); + } + try + { + TopLevelType.FromData(jsonInvalid, "application/json"); + throw new Exception("Expected exception was not thrown"); + } + catch (NotSupportedException) + { + // expected + } + } + + public static void testReadWrite(string mediaType) + { + byte[] bytes = null; + var tl1 = JsonSerializer.Deserialize(jsonRecord1); + if (tl1 == null) + { + throw new Exception("Deserialization failed"); + } + bytes = tl1.ToByteArray(mediaType); + tl1 = TopLevelType.FromData(bytes, mediaType); + if (tl1 == null) + { + throw new Exception("Deserialization failed"); + } + if (tl1.Document.RecordType1 == null) + { + throw new Exception("RecordType1 is null"); + } + var tl2 = JsonSerializer.Deserialize(jsonRecord2); + if (tl2 == null) + { + throw new Exception("Deserialization failed"); + } + bytes = tl2.ToByteArray(mediaType); + tl2 = TopLevelType.FromData(bytes, mediaType); + if (tl2 == null) + { + throw new Exception("Deserialization failed"); + } + if (tl2.Document.RecordType2 == null) + { + throw new Exception("RecordType2 is null"); + } + var tl3 = JsonSerializer.Deserialize(jsonRecord2); + if (tl3 == null) + { + throw new Exception("Deserialization failed"); + } + bytes = tl3.ToByteArray(mediaType); + tl3 = TopLevelType.FromData(bytes, mediaType); + if (tl3 == null) + { + throw new Exception("Deserialization failed"); + } + if (tl3.Document.RecordType2 == null) + { + throw new Exception("RecordType2 is null"); + } + try + { + bytes = Encoding.UTF8.GetBytes(jsonInvalid); + TopLevelType.FromData(bytes, mediaType); + throw new Exception("Expected exception was not thrown"); + } + catch (Exception) + { + // expected + } + } + + + public static int Main() + { + try + { + testJsonDeserialize(); + testJsonFromData(); + testReadWrite("application/json"); + testReadWrite("application/vnd.apache.avro+avro"); + testReadWrite("application/vnd.apache.avro+json"); + testReadWrite("application/json+gzip"); + testReadWrite("application/vnd.apache.avro+avro+gzip"); + testReadWrite("application/vnd.apache.avro+json+gzip"); return 0; } catch (Exception e) diff --git a/test/cs/twotypeunion/twotypeunion-test.csproj b/test/cs/twotypeunion/twotypeunion-test.csproj index 8d09ab4..290e4be 100644 --- a/test/cs/twotypeunion/twotypeunion-test.csproj +++ b/test/cs/twotypeunion/twotypeunion-test.csproj @@ -12,6 +12,6 @@ - + diff --git a/test/cs/typemapunion/typemapunion-test.csproj b/test/cs/typemapunion/typemapunion-test.csproj index 81c6887..cd65b10 100644 --- a/test/cs/typemapunion/typemapunion-test.csproj +++ b/test/cs/typemapunion/typemapunion-test.csproj @@ -12,6 +12,6 @@ - + diff --git a/test/test_avrotocsharp.py b/test/test_avrotocsharp.py index d193578..2cfe1fe 100644 --- a/test/test_avrotocsharp.py +++ b/test/test_avrotocsharp.py @@ -104,17 +104,19 @@ def test_convert_twotypeunion_avsc_to_csharp(self): """ Test converting an address.avsc file to C# """ cwd = os.getcwd() avro_path = os.path.join(cwd, "test", "avsc", "twotypeunion.avsc") - cs_path = os.path.join(tempfile.gettempdir(), "avrotize", "twotypeunion-cs-test", "twotypeunion-cs") + cs_path = os.path.join(tempfile.gettempdir(), "avrotize", "twotypeunion-cs") cs_test_path = os.path.join(tempfile.gettempdir(), "avrotize", "twotypeunion-cs-test") test_csproj = os.path.join(cwd, "test", "cs", "twotypeunion") if os.path.exists(cs_test_path): shutil.rmtree(cs_test_path, ignore_errors=True) + if os.path.exists(cs_path): + shutil.rmtree(cs_path, ignore_errors=True) os.makedirs(cs_test_path, exist_ok=True) os.makedirs(cs_path, exist_ok=True) copy_tree(test_csproj, cs_test_path) - convert_avro_to_csharp(avro_path, cs_path, system_text_json_annotation=True, pascal_properties=True) + convert_avro_to_csharp(avro_path, cs_path, system_text_json_annotation=True, avro_annotation=True, pascal_properties=True) assert subprocess.check_call( ['dotnet', 'run', '--force'], cwd=cs_test_path, stdout=sys.stdout, stderr=sys.stderr) == 0 @@ -122,12 +124,14 @@ def test_convert_typemapunion_avsc_to_csharp(self): """ Test converting an address.avsc file to C# """ cwd = os.getcwd() avro_path = os.path.join(cwd, "test", "avsc", "typemapunion.avsc") - cs_path = os.path.join(tempfile.gettempdir(), "avrotize", "typemapunion-cs-test", "typemapunion-cs") + cs_path = os.path.join(tempfile.gettempdir(), "avrotize", "typemapunion-cs") cs_test_path = os.path.join(tempfile.gettempdir(), "avrotize", "typemapunion-cs-test") test_csproj = os.path.join(cwd, "test", "cs", "typemapunion") if os.path.exists(cs_test_path): shutil.rmtree(cs_test_path, ignore_errors=True) + if os.path.exists(cs_path): + shutil.rmtree(cs_path, ignore_errors=True) os.makedirs(cs_test_path, exist_ok=True) os.makedirs(cs_path, exist_ok=True)