Skip to content

Commit

Permalink
Non empty collection schemas (#717) (#723)
Browse files Browse the repository at this point in the history
  • Loading branch information
987Nabil authored Aug 14, 2024
1 parent b178062 commit ce3a315
Show file tree
Hide file tree
Showing 17 changed files with 433 additions and 106 deletions.
3 changes: 1 addition & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ inThisBuild(
ThisBuild / publishTo := sonatypePublishToBundle.value
scalacOptions ++= Seq("-scalajs")

addCommandAlias("prepare", "fix; fmt")
addCommandAlias("fmt", "all scalafmtSbt scalafmtAll")
addCommandAlias("fmt", "all scalafmtSbt scalafmtAll;fix")
addCommandAlias("fmtCheck", "all scalafmtSbtCheck scalafmtCheckAll")
addCommandAlias("fix", "scalafixAll")
addCommandAlias("fixCheck", "scalafixAll --check")
Expand Down
2 changes: 2 additions & 0 deletions tests/shared/src/test/scala/zio/schema/DynamicValueGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ object DynamicValueGen {
case Schema.Enum22(_, case1, case2, case3, case4, case5, case6, case7, case8, case9, case10, case11, case12, case13, case14, case15, case16, case17, case18, case19, case20, case21, case22, _) => anyDynamicValueOfEnum(Chunk(case1, case2, case3, case4, case5, case6, case7, case8, case9, case10, case11, case12, case13, case14, case15, case16, case17, case18, case19, case20, case21, case22))
case Schema.EnumN(_, cases, _) => anyDynamicValueOfEnum(Chunk.fromIterable(cases.toSeq))
case Schema.Sequence(schema, _, _, _, _) => Gen.chunkOfBounded(0, 2)(anyDynamicValueOfSchema(schema)).map(DynamicValue.Sequence(_))
case Schema.NonEmptySequence(schema, _, _, _, _) => Gen.chunkOfBounded(1, 2)(anyDynamicValueOfSchema(schema)).map(DynamicValue.Sequence(_))
case Schema.Map(ks, vs, _) => Gen.chunkOfBounded(0, 2)(anyDynamicValueOfSchema(ks).zip(anyDynamicValueOfSchema(vs))).map(DynamicValue.Dictionary(_))
case Schema.NonEmptyMap(ks, vs, _) => Gen.chunkOfBounded(1, 2)(anyDynamicValueOfSchema(ks).zip(anyDynamicValueOfSchema(vs))).map(DynamicValue.Dictionary(_))
case Schema.Set(schema, _) => Gen.setOfBounded(0, 2)(anyDynamicValueOfSchema(schema)).map(DynamicValue.SetValue(_))
case Schema.Optional(schema, _) => Gen.oneOf(anyDynamicSomeValueOfSchema(schema), Gen.const(DynamicValue.NoneValue))
case Schema.Tuple2(left, right, _) => anyDynamicTupleValue(left, right)
Expand Down
30 changes: 26 additions & 4 deletions zio-schema-avro/src/main/scala/zio/schema/codec/AvroCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import org.apache.avro.io.{ DecoderFactory, EncoderFactory }
import org.apache.avro.util.Utf8
import org.apache.avro.{ Conversions, LogicalTypes, Schema => SchemaAvro }

import zio.prelude.NonEmptyMap
import zio.schema.{ Fallback, FieldSet, Schema, StandardType, TypeId }
import zio.stream.ZPipeline
import zio.{ Chunk, Unsafe, ZIO }
Expand Down Expand Up @@ -201,9 +202,20 @@ object AvroCodec {
case record: Schema.Record[_] => decodeRecord(raw, record).map(_.asInstanceOf[A])
case Schema.Sequence(element, f, _, _, _) =>
decodeSequence(raw, element.asInstanceOf[Schema[Any]]).map(f.asInstanceOf[Chunk[Any] => A])
case nes @ Schema.NonEmptySequence(element, _, _, _, _) =>
decodeSequence(raw, element.asInstanceOf[Schema[Any]]).map(nes.fromChunk.asInstanceOf[Chunk[Any] => A])
case Schema.Set(element, _) => decodeSequence(raw, element.asInstanceOf[Schema[Any]]).map(_.toSet.asInstanceOf[A])
case mapSchema: Schema.Map[_, _] =>
decodeMap(raw, mapSchema.asInstanceOf[Schema.Map[Any, Any]]).map(_.asInstanceOf[A])
case mapSchema: Schema.NonEmptyMap[_, _] =>
decodeMap(
raw,
Schema.Map(
mapSchema.keySchema.asInstanceOf[Schema[Any]],
mapSchema.valueSchema.asInstanceOf[Schema[Any]],
mapSchema.annotations
)
).map(mapSchema.asInstanceOf[Schema.NonEmptyMap[Any, Any]].fromMap(_).asInstanceOf[A])
case Schema.Transform(schema, f, _, _, _) =>
decodeValue(raw, schema).flatMap(
a => f(a).left.map(msg => DecodeError.MalformedFieldWithPath(Chunk.single("Error"), msg))
Expand Down Expand Up @@ -662,12 +674,22 @@ object AvroCodec {
c21,
c22
)
case Schema.GenericRecord(typeId, structure, _) => encodeGenericRecord(a, typeId, structure)
case Schema.Primitive(standardType, _) => encodePrimitive(a, standardType)
case Schema.Sequence(element, _, g, _, _) => encodeSequence(element, g(a))
case Schema.Set(element, _) => encodeSet(element, a)
case Schema.GenericRecord(typeId, structure, _) => encodeGenericRecord(a, typeId, structure)
case Schema.Primitive(standardType, _) => encodePrimitive(a, standardType)
case Schema.Sequence(element, _, g, _, _) => encodeSequence(element, g(a))
case Schema.NonEmptySequence(element, _, g, _, _) => encodeSequence(element, g(a))
case Schema.Set(element, _) => encodeSet(element, a)
case mapSchema: Schema.Map[_, _] =>
encodeMap(mapSchema.asInstanceOf[Schema.Map[Any, Any]], a.asInstanceOf[scala.collection.immutable.Map[Any, Any]])
case mapSchema: Schema.NonEmptyMap[_, _] =>
encodeMap(
Schema.Map(
mapSchema.keySchema.asInstanceOf[Schema[Any]],
mapSchema.valueSchema.asInstanceOf[Schema[Any]],
mapSchema.annotations
),
a.asInstanceOf[NonEmptyMap[Any, Any]].toMap
)
case Schema.Transform(schema, _, g, _, _) =>
g(a).map(encodeValue(_, schema)).getOrElse(throw new Exception("Transform failed."))
case Schema.Optional(schema, _) => encodeOption(schema, a)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,12 +277,14 @@ object AvroSchemaCodec extends AvroSchemaCodec {

private def toAvroSchema(schema: Schema[_]): scala.util.Either[String, SchemaAvro] = {
schema match {
case e: Enum[_] => toAvroEnum(e)
case record: Record[_] => toAvroRecord(record)
case map: Schema.Map[_, _] => toAvroMap(map)
case seq: Schema.Sequence[_, _, _] => toAvroSchema(seq.elementSchema).map(SchemaAvro.createArray)
case set: Schema.Set[_] => toAvroSchema(set.elementSchema).map(SchemaAvro.createArray)
case Transform(codec, _, _, _, _) => toAvroSchema(codec)
case e: Enum[_] => toAvroEnum(e)
case record: Record[_] => toAvroRecord(record)
case map: Schema.Map[_, _] => toAvroMap(map)
case map: Schema.NonEmptyMap[_, _] => toAvroMap(map)
case seq: Schema.Sequence[_, _, _] => toAvroSchema(seq.elementSchema).map(SchemaAvro.createArray)
case seq: Schema.NonEmptySequence[_, _, _] => toAvroSchema(seq.elementSchema).map(SchemaAvro.createArray)
case set: Schema.Set[_] => toAvroSchema(set.elementSchema).map(SchemaAvro.createArray)
case Transform(codec, _, _, _, _) => toAvroSchema(codec)
case Primitive(standardType, _) =>
standardType match {
case StandardType.UnitType => Right(SchemaAvro.create(SchemaAvro.Type.NULL))
Expand Down Expand Up @@ -624,6 +626,18 @@ object AvroSchemaCodec extends AvroSchemaCodec {
toAvroSchema(tupleSchema).map(SchemaAvro.createArray)
}

/**
 * Derives the Avro schema for a [[zio.prelude.NonEmptyMap]] schema.
 *
 * Avro's native `map` type only supports string keys, so:
 *  - string-keyed maps become an Avro MAP of the value schema;
 *  - any other key type is encoded as an Avro ARRAY of (key, value)
 *    2-tuples, named "Tuple" in namespace "scala".
 *
 * Returns `Left` with an error message when the key/value schemas cannot
 * be converted. NOTE(review): the non-empty constraint itself is not
 * representable in Avro, so the resulting schema also admits empty maps.
 */
private[codec] def toAvroMap(map: NonEmptyMap[_, _]): scala.util.Either[String, SchemaAvro] =
map.keySchema match {
case p: Schema.Primitive[_] if p.standardType == StandardType.StringType =>
// String keys: use Avro's built-in map type keyed by the value schema.
toAvroSchema(map.valueSchema).map(SchemaAvro.createMap)
case _ =>
// Non-string keys: model each entry as a tuple record annotated with the
// name "scala.Tuple", and emit the map as an Avro array of those tuples.
val tupleSchema = Schema
.Tuple2(map.keySchema, map.valueSchema)
.annotate(AvroAnnotations.name("Tuple"))
.annotate(AvroAnnotations.namespace("scala"))
toAvroSchema(tupleSchema).map(SchemaAvro.createArray)
}

private[codec] def toAvroDecimal(schema: Schema[_]): scala.util.Either[String, SchemaAvro] = {
val scale = schema.annotations.collectFirst { case AvroAnnotations.scale(s) => s }
.getOrElse(AvroAnnotations.scale().scale)
Expand Down Expand Up @@ -820,7 +834,9 @@ object AvroSchemaCodec extends AvroSchemaCodec {
case c: Dynamic => Right(c)
case c: GenericRecord => Right(c)
case c: Map[_, _] => Right(c)
case c: NonEmptyMap[_, _] => Right(c)
case c: Sequence[_, _, _] => Right(c)
case c: NonEmptySequence[_, _, _] => Right(c)
case c: Set[_] => Right(c)
case c: Fail[_] => Right(c)
case c: Lazy[_] => Right(c)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1918,10 +1918,7 @@ object AssertionHelper {
def recordFields(assertion: Assertion[Iterable[Schema.Field[_, _]]]): Assertion[Schema.Record[_]] =
Assertion.assertionRec[Schema.Record[_], Chunk[Field[_, _]]]("hasRecordField")(
assertion
) {
case r: Schema.Record[_] => Some(r.fields)
case _ => None
}
)((r: Schema.Record[_]) => Some(r.fields))

/**
 * Builds an assertion on the element schema of a [[Schema.Sequence]],
 * extracting it via the `elementSchema` accessor (reported in failure
 * messages under the field label "schemaA").
 */
def hasSequenceElementSchema[A](assertion: Assertion[Schema[A]]): Assertion[Schema.Sequence[_, A, _]] =
Assertion.hasField("schemaA", _.elementSchema, assertion)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -457,22 +457,24 @@ object BsonSchemaCodec {
//scalafmt: { maxColumn = 400, optIn.configStyleArguments = false }
private[codec] def schemaEncoder[A](schema: Schema[A]): BsonEncoder[A] =
schema match {
case Schema.Primitive(standardType, _) => primitiveCodec(standardType).encoder
case Schema.Sequence(schema, _, g, _, _) => chunkEncoder(schemaEncoder(schema)).contramap(g)
case Schema.Map(ks, vs, _) => mapEncoder(ks, vs)
case Schema.Set(s, _) => chunkEncoder(schemaEncoder(s)).contramap(m => Chunk.fromIterable(m))
case Schema.Transform(c, _, g, _, _) => transformEncoder(c, g)
case Schema.Tuple2(l, r, _) => tuple2Encoder(schemaEncoder(l), schemaEncoder(r))
case Schema.Optional(schema, _) => BsonEncoder.option(schemaEncoder(schema))
case Schema.Fail(_, _) => unitEncoder.contramap(_ => ())
case Schema.GenericRecord(_, structure, _) => genericRecordEncoder(structure.toChunk)
case Schema.Either(left, right, _) => eitherEncoder(schemaEncoder(left), schemaEncoder(right))
case Schema.Fallback(left, right, _, _) => fallbackEncoder(schemaEncoder(left), schemaEncoder(right))
case l @ Schema.Lazy(_) => schemaEncoder(l.schema)
case r: Schema.Record[A] => caseClassEncoder(r)
case e: Schema.Enum[A] => enumEncoder(e, e.cases)
case d @ Schema.Dynamic(_) => dynamicEncoder(d)
case null => throw new Exception(s"A captured schema is null, most likely due to wrong field initialization order")
case Schema.Primitive(standardType, _) => primitiveCodec(standardType).encoder
case Schema.Sequence(schema, _, g, _, _) => chunkEncoder(schemaEncoder(schema)).contramap(g)
case Schema.NonEmptySequence(schema, _, g, _, _) => chunkEncoder(schemaEncoder(schema)).contramap(g)
case Schema.Map(ks, vs, _) => mapEncoder(ks, vs)
case Schema.NonEmptyMap(ks, vs, _) => mapEncoder(ks, vs).contramap(_.toMap)
case Schema.Set(s, _) => chunkEncoder(schemaEncoder(s)).contramap(m => Chunk.fromIterable(m))
case Schema.Transform(c, _, g, _, _) => transformEncoder(c, g)
case Schema.Tuple2(l, r, _) => tuple2Encoder(schemaEncoder(l), schemaEncoder(r))
case Schema.Optional(schema, _) => BsonEncoder.option(schemaEncoder(schema))
case Schema.Fail(_, _) => unitEncoder.contramap(_ => ())
case Schema.GenericRecord(_, structure, _) => genericRecordEncoder(structure.toChunk)
case Schema.Either(left, right, _) => eitherEncoder(schemaEncoder(left), schemaEncoder(right))
case Schema.Fallback(left, right, _, _) => fallbackEncoder(schemaEncoder(left), schemaEncoder(right))
case l @ Schema.Lazy(_) => schemaEncoder(l.schema)
case r: Schema.Record[A] => caseClassEncoder(r)
case e: Schema.Enum[A] => enumEncoder(e, e.cases)
case d @ Schema.Dynamic(_) => dynamicEncoder(d)
case null => throw new Exception(s"A captured schema is null, most likely due to wrong field initialization order")
}
//scalafmt: { maxColumn = 120, optIn.configStyleArguments = true }

Expand Down Expand Up @@ -773,22 +775,24 @@ object BsonSchemaCodec {

//scalafmt: { maxColumn = 400, optIn.configStyleArguments = false }
private[codec] def schemaDecoder[A](schema: Schema[A]): BsonDecoder[A] = schema match {
case Schema.Primitive(standardType, _) => primitiveCodec(standardType).decoder
case Schema.Optional(codec, _) => BsonDecoder.option(schemaDecoder(codec))
case Schema.Tuple2(left, right, _) => tuple2Decoder(schemaDecoder(left), schemaDecoder(right))
case Schema.Transform(codec, f, _, _, _) => schemaDecoder(codec).mapOrFail(f)
case Schema.Sequence(codec, f, _, _, _) => chunkDecoder(schemaDecoder(codec)).map(f)
case Schema.Map(ks, vs, _) => mapDecoder(ks, vs)
case Schema.Set(s, _) => chunkDecoder(schemaDecoder(s)).map(entries => entries.toSet)
case Schema.Fail(message, _) => failDecoder(message)
case Schema.GenericRecord(_, structure, _) => recordDecoder(structure.toChunk)
case Schema.Either(left, right, _) => eitherDecoder(schemaDecoder(left), schemaDecoder(right))
case Schema.Fallback(left, right, _, _) => fallbackDecoder(schemaDecoder(left), schemaDecoder(right))
case l @ Schema.Lazy(_) => schemaDecoder(l.schema)
case s: Schema.Record[A] => caseClassDecoder(s)
case e: Schema.Enum[A] => enumDecoder(e)
case d @ Schema.Dynamic(_) => dynamicDecoder(d)
case null => throw new Exception(s"Missing a handler for decoding of schema $schema.")
case Schema.Primitive(standardType, _) => primitiveCodec(standardType).decoder
case Schema.Optional(codec, _) => BsonDecoder.option(schemaDecoder(codec))
case Schema.Tuple2(left, right, _) => tuple2Decoder(schemaDecoder(left), schemaDecoder(right))
case Schema.Transform(codec, f, _, _, _) => schemaDecoder(codec).mapOrFail(f)
case Schema.Sequence(codec, f, _, _, _) => chunkDecoder(schemaDecoder(codec)).map(f)
case s @ Schema.NonEmptySequence(codec, _, _, _, _) => chunkDecoder(schemaDecoder(codec)).map(s.fromChunk)
case Schema.Map(ks, vs, _) => mapDecoder(ks, vs)
case s @ Schema.NonEmptyMap(ks, vs, _) => mapDecoder(ks, vs).map(s.fromMap)
case Schema.Set(s, _) => chunkDecoder(schemaDecoder(s)).map(entries => entries.toSet)
case Schema.Fail(message, _) => failDecoder(message)
case Schema.GenericRecord(_, structure, _) => recordDecoder(structure.toChunk)
case Schema.Either(left, right, _) => eitherDecoder(schemaDecoder(left), schemaDecoder(right))
case Schema.Fallback(left, right, _, _) => fallbackDecoder(schemaDecoder(left), schemaDecoder(right))
case l @ Schema.Lazy(_) => schemaDecoder(l.schema)
case s: Schema.Record[A] => caseClassDecoder(s)
case e: Schema.Enum[A] => enumDecoder(e)
case d @ Schema.Dynamic(_) => dynamicDecoder(d)
case _ => throw new Exception(s"Missing a handler for decoding of schema $schema.")
}
//scalafmt: { maxColumn = 120, optIn.configStyleArguments = true }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ object BsonSchemaCodecSpec extends ZIOSpecDefault {
implicit lazy val schema: Schema[Tree] = DeriveSchema.gen
implicit lazy val codec: BsonCodec[Tree] = BsonSchemaCodec.bsonCodec(schema)

private val genLeaf = Gen.int.map(Leaf)
private val genLeaf = Gen.int.map(Leaf.apply)

lazy val gen: Gen[Any, Tree] = Gen.sized { i =>
if (i >= 2) Gen.oneOf(genLeaf, Gen.suspend(gen.zipWith(gen)(Branch)).resize(i / 2))
if (i >= 2) Gen.oneOf(genLeaf, Gen.suspend(gen.zipWith(gen)(Branch.apply)).resize(i / 2))
else genLeaf
}
}
Expand Down
Loading

0 comments on commit ce3a315

Please sign in to comment.