Skip to content

Commit

Permalink
Support ObjectId in schema-bson
Browse files Browse the repository at this point in the history
  • Loading branch information
balx committed Apr 3, 2024
1 parent 28930b3 commit 260ac7a
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 96 deletions.
240 changes: 144 additions & 96 deletions zio-schema-bson/src/main/scala/zio/schema/codec/BsonSchemaCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import scala.collection.compat._
import scala.collection.immutable.{ HashMap, ListMap }
import scala.jdk.CollectionConverters._

import org.bson.types.ObjectId
import org.bson.{ BsonDocument, BsonNull, BsonReader, BsonType, BsonValue, BsonWriter }

import zio.bson.BsonBuilder._
Expand Down Expand Up @@ -504,6 +505,10 @@ object BsonSchemaCodec {
directEncoder =>
override def encode(writer: BsonWriter, value: DynamicValue, ctx: BsonEncoder.EncoderContext): Unit =
value match {
case DynamicValue.Record(_, values) if values.headOption.exists(_._1 == bson.ObjectIdTag) =>
val id = values.head._2.toTypedValueOption[String].get
writer.writeObjectId(new ObjectId(id))

case DynamicValue.Record(_, values) =>
val nextCtx = BsonEncoder.EncoderContext.default

Expand Down Expand Up @@ -546,6 +551,11 @@ object BsonSchemaCodec {

override def toBsonValue(value: DynamicValue): BsonValue =
value match {
case DynamicValue.Record(_, values) if values.headOption.exists(_._1 == bson.ObjectIdTag) =>
val id = values.head._2.toTypedValueOption[String].get
val objectId = new ObjectId(id)
objectId.toBsonValue

case DynamicValue.Record(_, values) =>
new BsonDocument(values.view.map {
case (key, value) => element(key, directEncoder.toBsonValue(value))
Expand Down Expand Up @@ -834,7 +844,15 @@ object BsonSchemaCodec {
DynamicValue.Primitive(Chunk.fromArray(bsonValue.asBinary().getData), StandardType.BinaryType)
case BsonType.UNDEFINED => DynamicValue.NoneValue
case BsonType.OBJECT_ID =>
DynamicValue.Primitive(bsonValue.asObjectId().getValue.toHexString, StandardType.StringType)
DynamicValue.Record(
TypeId.Structural,
ListMap(
bson.ObjectIdTag -> DynamicValue.Primitive(
bsonValue.asObjectId().getValue.toHexString,
StandardType.StringType
)
)
)
case BsonType.BOOLEAN => DynamicValue.Primitive(bsonValue.asBoolean().getValue, StandardType.BoolType)
case BsonType.DATE_TIME =>
DynamicValue.Primitive(Instant.ofEpochMilli(bsonValue.asDateTime().getValue), StandardType.InstantType)
Expand Down Expand Up @@ -1145,39 +1163,49 @@ object BsonSchemaCodec {

private val len = nonTransientFields.length

def encode(writer: BsonWriter, value: Z, ctx: BsonEncoder.EncoderContext): Unit = {
val nextCtx = ctx.copy(inlineNextObject = false)
def encode(writer: BsonWriter, value: Z, ctx: BsonEncoder.EncoderContext): Unit =
if (names.size == 1 && names(0) == bson.ObjectIdTag) {
val fieldValue = nonTransientFields(0).get(value)
val id = new ObjectId(fieldValue.toString)
writer.writeObjectId(id)
} else {
val nextCtx = ctx.copy(inlineNextObject = false)

if (!ctx.inlineNextObject) writer.writeStartDocument()

if (!ctx.inlineNextObject) writer.writeStartDocument()
var i = 0

var i = 0
while (i < len) {
val tc = tcs(i)
val fieldValue = nonTransientFields(i).get(value)

while (i < len) {
val tc = tcs(i)
val fieldValue = nonTransientFields(i).get(value)
if (keepNulls || !tc.isAbsent(fieldValue)) {
writer.writeName(names(i))
tc.encode(writer, fieldValue, nextCtx)
}

if (keepNulls || !tc.isAbsent(fieldValue)) {
writer.writeName(names(i))
tc.encode(writer, fieldValue, nextCtx)
i += 1
}

i += 1
if (!ctx.inlineNextObject) writer.writeEndDocument()
}

if (!ctx.inlineNextObject) writer.writeEndDocument()
}

def toBsonValue(value: Z): BsonValue = {
val elements = nonTransientFields.indices.view.flatMap { idx =>
val fieldValue = nonTransientFields(idx).get(value)
val tc = tcs(idx)
def toBsonValue(value: Z): BsonValue =
if (names.size == 1 && names(0) == bson.ObjectIdTag) {
val fieldValue = nonTransientFields(0).get(value)
val id = new ObjectId(fieldValue.toString)
id.toBsonValue
} else {
val elements = nonTransientFields.indices.view.flatMap { idx =>
val fieldValue = nonTransientFields(idx).get(value)
val tc = tcs(idx)

if (keepNulls || !tc.isAbsent(fieldValue)) Some(element(names(idx), tc.toBsonValue(fieldValue)))
else None
}.to(Chunk)
if (keepNulls || !tc.isAbsent(fieldValue)) Some(element(names(idx), tc.toBsonValue(fieldValue)))
else None
}.to(Chunk)

new BsonDocument(elements.asJava)
}
new BsonDocument(elements.asJava)
}
}
}

Expand All @@ -1189,7 +1217,6 @@ object BsonSchemaCodec {
private[codec] def caseClassDecoder[Z](caseClassSchema: Schema.Record[Z]): BsonDecoder[Z] = {
val fields = caseClassSchema.fields
val len: Int = fields.length
Array.ofDim[Any](len)
val fieldNames = fields.map { f =>
f.annotations.collectFirst { case bsonField(n) => n }.getOrElse(f.name.asInstanceOf[String])
}.toArray
Expand All @@ -1209,91 +1236,112 @@ object BsonSchemaCodec {
lazy val tcs: Array[BsonDecoder[Any]] = schemas.map(s => schemaDecoder(s).asInstanceOf[BsonDecoder[Any]])

new BsonDecoder[Z] {
def decodeUnsafe(reader: BsonReader, trace: List[BsonTrace], ctx: BsonDecoder.BsonDecoderContext): Z = unsafeCall(trace) {
reader.readStartDocument()

val nextCtx = BsonDecoder.BsonDecoderContext.default
val ps: Array[Any] = Array.ofDim(len)

while (reader.readBsonType() != BsonType.END_OF_DOCUMENT) {
val name = reader.readName()
val idx = indexes.getOrElse(name, -1)

if (idx >= 0) {
val nextTrace = spans(idx) :: trace
val tc = tcs(idx)
if (ps(idx) != null) throw BsonDecoder.Error(nextTrace, "duplicate")
ps(idx) = if ((fields(idx).optional || fields(idx).transient) && fields(idx).defaultValue.isDefined) {
val opt = BsonDecoder.option(tc).decodeUnsafe(reader, nextTrace, nextCtx)
opt.getOrElse(fields(idx).defaultValue.get)
} else {
tc.decodeUnsafe(reader, nextTrace, nextCtx)
def decodeUnsafe(reader: BsonReader, trace: List[BsonTrace], ctx: BsonDecoder.BsonDecoderContext): Z =
if (fieldNames.size == 1 && fieldNames(0) == bson.ObjectIdTag) {
val id = reader.readObjectId.toHexString
Unsafe.unsafe { implicit u =>
caseClassSchema.construct(Chunk.fromArray(Array(id))) match {
case Left(err) => throw BsonDecoder.Error(trace, s"Failed to construct case class: $err")
case Right(value) => value
}
}
} else {
unsafeCall(trace) {
reader.readStartDocument()

val nextCtx = BsonDecoder.BsonDecoderContext.default
val ps: Array[Any] = Array.ofDim(len)

while (reader.readBsonType() != BsonType.END_OF_DOCUMENT) {
val name = reader.readName()
val idx = indexes.getOrElse(name, -1)

if (idx >= 0) {
val nextTrace = spans(idx) :: trace
val tc = tcs(idx)
if (ps(idx) != null) throw BsonDecoder.Error(nextTrace, "duplicate")
ps(idx) = if ((fields(idx).optional || fields(idx).transient) && fields(idx).defaultValue.isDefined) {
val opt = BsonDecoder.option(tc).decodeUnsafe(reader, nextTrace, nextCtx)
opt.getOrElse(fields(idx).defaultValue.get)
} else {
tc.decodeUnsafe(reader, nextTrace, nextCtx)
}
} else if (noExtra && !ctx.ignoreExtraField.contains(name)) {
throw BsonDecoder.Error(BsonTrace.Field(name) :: trace, "Invalid extra field.")
} else reader.skipValue()
}
} else if (noExtra && !ctx.ignoreExtraField.contains(name)) {
throw BsonDecoder.Error(BsonTrace.Field(name) :: trace, "Invalid extra field.")
} else reader.skipValue()
}

var i = 0
while (i < len) {
if (ps(i) == null) {
if ((fields(i).optional || fields(i).transient) && fields(i).defaultValue.isDefined) {
ps(i) = fields(i).defaultValue.get
} else {
ps(i) = tcs(i).decodeMissingUnsafe(spans(i) :: trace)
var i = 0
while (i < len) {
if (ps(i) == null) {
if ((fields(i).optional || fields(i).transient) && fields(i).defaultValue.isDefined) {
ps(i) = fields(i).defaultValue.get
} else {
ps(i) = tcs(i).decodeMissingUnsafe(spans(i) :: trace)
}
}
i += 1
}
}
i += 1
}

reader.readEndDocument()
reader.readEndDocument()

Unsafe.unsafe { implicit u =>
caseClassSchema.construct(Chunk.fromArray(ps)) match {
case Left(err) => throw BsonDecoder.Error(trace, s"Failed to construct case class: $err")
case Right(value) => value
Unsafe.unsafe { implicit u =>
caseClassSchema.construct(Chunk.fromArray(ps)) match {
case Left(err) => throw BsonDecoder.Error(trace, s"Failed to construct case class: $err")
case Right(value) => value
}
}
}
}
}

def fromBsonValueUnsafe(value: BsonValue, trace: List[BsonTrace], ctx: BsonDecoder.BsonDecoderContext): Z =
assumeType(trace)(BsonType.DOCUMENT, value) { value =>
val nextCtx = BsonDecoder.BsonDecoderContext.default
val ps: Array[Any] = Array.ofDim(len)

value.asDocument().asScala.foreachEntry { (name, value) =>
val idx = indexes.getOrElse(name, -1)

if (idx >= 0) {
val nextTrace = spans(idx) :: trace
val tc = tcs(idx)
if (ps(idx) != null) throw BsonDecoder.Error(nextTrace, "duplicate")
ps(idx) = if ((fields(idx).optional || fields(idx).transient) && fields(idx).defaultValue.isDefined) {
val opt = BsonDecoder.option(tc).fromBsonValueUnsafe(value, nextTrace, nextCtx)
opt.getOrElse(fields(idx).defaultValue.get)
} else {
tc.fromBsonValueUnsafe(value, nextTrace, nextCtx)
}
} else if (noExtra && !ctx.ignoreExtraField.contains(name))
throw BsonDecoder.Error(BsonTrace.Field(name) :: trace, "Invalid extra field.")
if (value.getBsonType == BsonType.OBJECT_ID) {
Unsafe.unsafe { implicit u =>
val ps: Array[Any] = Array(value.asObjectId.getValue.toHexString)
caseClassSchema.construct(Chunk.fromArray(ps)) match {
case Left(err) => throw BsonDecoder.Error(trace, s"Failed to construct case class: $err")
case Right(value) => value
}
}
} else {
assumeType(trace)(BsonType.DOCUMENT, value) { value =>
val nextCtx = BsonDecoder.BsonDecoderContext.default
val ps: Array[Any] = Array.ofDim(len)

value.asDocument().asScala.foreachEntry { (name, value) =>
val idx = indexes.getOrElse(name, -1)

if (idx >= 0) {
val nextTrace = spans(idx) :: trace
val tc = tcs(idx)
if (ps(idx) != null) throw BsonDecoder.Error(nextTrace, "duplicate")
ps(idx) = if ((fields(idx).optional || fields(idx).transient) && fields(idx).defaultValue.isDefined) {
val opt = BsonDecoder.option(tc).fromBsonValueUnsafe(value, nextTrace, nextCtx)
opt.getOrElse(fields(idx).defaultValue.get)
} else {
tc.fromBsonValueUnsafe(value, nextTrace, nextCtx)
}
} else if (noExtra && !ctx.ignoreExtraField.contains(name))
throw BsonDecoder.Error(BsonTrace.Field(name) :: trace, "Invalid extra field.")
}

var i = 0
while (i < len) {
if (ps(i) == null) {
ps(i) = if ((fields(i).optional || fields(i).transient) && fields(i).defaultValue.isDefined) {
fields(i).defaultValue.get
} else {
tcs(i).decodeMissingUnsafe(spans(i) :: trace)
var i = 0
while (i < len) {
if (ps(i) == null) {
ps(i) = if ((fields(i).optional || fields(i).transient) && fields(i).defaultValue.isDefined) {
fields(i).defaultValue.get
} else {
tcs(i).decodeMissingUnsafe(spans(i) :: trace)
}
}
i += 1
}
i += 1
}

Unsafe.unsafe { implicit u =>
caseClassSchema.construct(Chunk.fromArray(ps)) match {
case Left(err) => throw BsonDecoder.Error(trace, s"Failed to construct case class: $err")
case Right(value) => value
Unsafe.unsafe { implicit u =>
caseClassSchema.construct(Chunk.fromArray(ps)) match {
case Left(err) => throw BsonDecoder.Error(trace, s"Failed to construct case class: $err")
case Right(value) => value
}
}
}
}
Expand Down
22 changes: 22 additions & 0 deletions zio-schema-bson/src/main/scala/zio/schema/codec/package.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package zio.schema.codec

import org.bson.types.ObjectId

import zio.schema.{ Schema, TypeId }

package object bson {
val ObjectIdTag = "$oid"

implicit val ObjectIdSchema: Schema[ObjectId] =
Schema.CaseClass1[String, ObjectId](
id0 = TypeId.fromTypeName("ObjectId"),
field0 = Schema.Field(
name0 = ObjectIdTag,
schema0 = Schema[String],
get0 = _.toHexString,
set0 = (_, idStr) => new ObjectId(idStr)
),
defaultConstruct0 = new ObjectId(_)
)

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import org.bson.codecs.configuration.CodecRegistry
import org.bson.codecs.{ Codec => BCodec, DecoderContext, EncoderContext }
import org.bson.conversions.Bson
import org.bson.io.BasicOutputBuffer
import org.bson.types.ObjectId

import zio.bson.BsonBuilder._
import zio.bson._
Expand Down Expand Up @@ -50,6 +51,34 @@ object BsonSchemaCodecSpec extends ZIOSpecDefault {
implicit lazy val codec: BsonCodec[EnumLike] = BsonSchemaCodec.bsonCodec(schema)
}

case class CustomerId(value: ObjectId) extends AnyVal
case class Customer(id: CustomerId, name: String, age: Int, invitedFriends: List[CustomerId])

object Customer {
implicit lazy val customerIdSchema: Schema[CustomerId] = bson.ObjectIdSchema.transform(CustomerId(_), _.value)

implicit lazy val customerSchema: Schema[Customer] = DeriveSchema.gen[Customer]
implicit lazy val customerCodec: BsonCodec[Customer] = BsonSchemaCodec.bsonCodec(customerSchema)

val example: Customer = Customer(
id = CustomerId(ObjectId.get),
name = "Joseph",
age = 18,
invitedFriends = List(CustomerId(ObjectId.get), CustomerId(ObjectId.get))
)

lazy val genCustomerId: Gen[Any, CustomerId] =
Gen.vectorOfN(12)(Gen.byte).map(bs => new ObjectId(bs.toArray)).map(CustomerId.apply)

def gen: Gen[Sized, Customer] =
for {
id <- genCustomerId
name <- Gen.string
age <- Gen.int
friends <- Gen.listOf(genCustomerId)
} yield Customer(id, name, age, friends)
}

def spec: Spec[TestEnvironment with Scope, Any] = suite("BsonSchemaCodecSpec")(
suite("round trip")(
roundTripTest("SimpleClass")(
Expand All @@ -66,6 +95,18 @@ object BsonSchemaCodecSpec extends ZIOSpecDefault {
Gen.fromIterable(Chunk(EnumLike.A, EnumLike.B)),
EnumLike.A,
str("A")
),
roundTripTest("Customer")(
Customer.gen,
Customer.example,
doc(
"id" -> Customer.example.id.value.toBsonValue,
"name" -> str(Customer.example.name),
"age" -> int(Customer.example.age),
"invitedFriends" -> array(
Customer.example.invitedFriends.map(_.value.toBsonValue): _*
)
)
)
),
suite("configuration")(
Expand Down

0 comments on commit 260ac7a

Please sign in to comment.