Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ObjectId in schema-bson #673

Merged
merged 1 commit into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:

build:
runs-on: ubuntu-20.04
timeout-minutes: 30
timeout-minutes: 40
strategy:
fail-fast: false
matrix:
Expand Down Expand Up @@ -71,4 +71,4 @@ jobs:
PGP_SECRET: ${{ secrets.PGP_SECRET }}
SONATYPE_PASSWORD: ${{ secrets.SONATYPE_PASSWORD }}
SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }}


240 changes: 144 additions & 96 deletions zio-schema-bson/src/main/scala/zio/schema/codec/BsonSchemaCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import scala.collection.compat._
import scala.collection.immutable.{ HashMap, ListMap }
import scala.jdk.CollectionConverters._

import org.bson.types.ObjectId
import org.bson.{ BsonDocument, BsonNull, BsonReader, BsonType, BsonValue, BsonWriter }

import zio.bson.BsonBuilder._
Expand Down Expand Up @@ -505,6 +506,10 @@ object BsonSchemaCodec {
directEncoder =>
override def encode(writer: BsonWriter, value: DynamicValue, ctx: BsonEncoder.EncoderContext): Unit =
value match {
case DynamicValue.Record(_, values) if values.headOption.exists(_._1 == bson.ObjectIdTag) =>
val id = values.head._2.toTypedValueOption[String].get
writer.writeObjectId(new ObjectId(id))

case DynamicValue.Record(_, values) =>
val nextCtx = BsonEncoder.EncoderContext.default

Expand Down Expand Up @@ -547,6 +552,11 @@ object BsonSchemaCodec {

override def toBsonValue(value: DynamicValue): BsonValue =
value match {
case DynamicValue.Record(_, values) if values.headOption.exists(_._1 == bson.ObjectIdTag) =>
val id = values.head._2.toTypedValueOption[String].get
val objectId = new ObjectId(id)
objectId.toBsonValue

case DynamicValue.Record(_, values) =>
new BsonDocument(values.view.map {
case (key, value) => element(key, directEncoder.toBsonValue(value))
Expand Down Expand Up @@ -835,7 +845,15 @@ object BsonSchemaCodec {
DynamicValue.Primitive(Chunk.fromArray(bsonValue.asBinary().getData), StandardType.BinaryType)
case BsonType.UNDEFINED => DynamicValue.NoneValue
case BsonType.OBJECT_ID =>
DynamicValue.Primitive(bsonValue.asObjectId().getValue.toHexString, StandardType.StringType)
DynamicValue.Record(
TypeId.Structural,
ListMap(
bson.ObjectIdTag -> DynamicValue.Primitive(
bsonValue.asObjectId().getValue.toHexString,
StandardType.StringType
)
)
)
case BsonType.BOOLEAN => DynamicValue.Primitive(bsonValue.asBoolean().getValue, StandardType.BoolType)
case BsonType.DATE_TIME =>
DynamicValue.Primitive(Instant.ofEpochMilli(bsonValue.asDateTime().getValue), StandardType.InstantType)
Expand Down Expand Up @@ -1146,39 +1164,49 @@ object BsonSchemaCodec {

private val len = nonTransientFields.length

def encode(writer: BsonWriter, value: Z, ctx: BsonEncoder.EncoderContext): Unit = {
val nextCtx = ctx.copy(inlineNextObject = false)
def encode(writer: BsonWriter, value: Z, ctx: BsonEncoder.EncoderContext): Unit =
if (names.size == 1 && names(0) == bson.ObjectIdTag) {
val fieldValue = nonTransientFields(0).get(value)
val id = new ObjectId(fieldValue.toString)
writer.writeObjectId(id)
} else {
val nextCtx = ctx.copy(inlineNextObject = false)

if (!ctx.inlineNextObject) writer.writeStartDocument()

if (!ctx.inlineNextObject) writer.writeStartDocument()
var i = 0

var i = 0
while (i < len) {
val tc = tcs(i)
val fieldValue = nonTransientFields(i).get(value)

while (i < len) {
val tc = tcs(i)
val fieldValue = nonTransientFields(i).get(value)
if (keepNulls || !tc.isAbsent(fieldValue)) {
writer.writeName(names(i))
tc.encode(writer, fieldValue, nextCtx)
}

if (keepNulls || !tc.isAbsent(fieldValue)) {
writer.writeName(names(i))
tc.encode(writer, fieldValue, nextCtx)
i += 1
}

i += 1
if (!ctx.inlineNextObject) writer.writeEndDocument()
}

if (!ctx.inlineNextObject) writer.writeEndDocument()
}

def toBsonValue(value: Z): BsonValue = {
val elements = nonTransientFields.indices.view.flatMap { idx =>
val fieldValue = nonTransientFields(idx).get(value)
val tc = tcs(idx)
def toBsonValue(value: Z): BsonValue =
if (names.size == 1 && names(0) == bson.ObjectIdTag) {
val fieldValue = nonTransientFields(0).get(value)
val id = new ObjectId(fieldValue.toString)
id.toBsonValue
} else {
val elements = nonTransientFields.indices.view.flatMap { idx =>
val fieldValue = nonTransientFields(idx).get(value)
val tc = tcs(idx)

if (keepNulls || !tc.isAbsent(fieldValue)) Some(element(names(idx), tc.toBsonValue(fieldValue)))
else None
}.to(Chunk)
if (keepNulls || !tc.isAbsent(fieldValue)) Some(element(names(idx), tc.toBsonValue(fieldValue)))
else None
}.to(Chunk)

new BsonDocument(elements.asJava)
}
new BsonDocument(elements.asJava)
}
}
}

Expand All @@ -1190,7 +1218,6 @@ object BsonSchemaCodec {
private[codec] def caseClassDecoder[Z](caseClassSchema: Schema.Record[Z]): BsonDecoder[Z] = {
val fields = caseClassSchema.fields
val len: Int = fields.length
Array.ofDim[Any](len)
val fieldNames = fields.map { f =>
f.annotations.collectFirst { case bsonField(n) => n }.getOrElse(f.name.asInstanceOf[String])
}.toArray
Expand All @@ -1210,91 +1237,112 @@ object BsonSchemaCodec {
lazy val tcs: Array[BsonDecoder[Any]] = schemas.map(s => schemaDecoder(s).asInstanceOf[BsonDecoder[Any]])

new BsonDecoder[Z] {
def decodeUnsafe(reader: BsonReader, trace: List[BsonTrace], ctx: BsonDecoder.BsonDecoderContext): Z = unsafeCall(trace) {
reader.readStartDocument()

val nextCtx = BsonDecoder.BsonDecoderContext.default
val ps: Array[Any] = Array.ofDim(len)

while (reader.readBsonType() != BsonType.END_OF_DOCUMENT) {
val name = reader.readName()
val idx = indexes.getOrElse(name, -1)

if (idx >= 0) {
val nextTrace = spans(idx) :: trace
val tc = tcs(idx)
if (ps(idx) != null) throw BsonDecoder.Error(nextTrace, "duplicate")
ps(idx) = if ((fields(idx).optional || fields(idx).transient) && fields(idx).defaultValue.isDefined) {
val opt = BsonDecoder.option(tc).decodeUnsafe(reader, nextTrace, nextCtx)
opt.getOrElse(fields(idx).defaultValue.get)
} else {
tc.decodeUnsafe(reader, nextTrace, nextCtx)
def decodeUnsafe(reader: BsonReader, trace: List[BsonTrace], ctx: BsonDecoder.BsonDecoderContext): Z =
if (fieldNames.size == 1 && fieldNames(0) == bson.ObjectIdTag) {
val id = reader.readObjectId.toHexString
Unsafe.unsafe { implicit u =>
caseClassSchema.construct(Chunk.fromArray(Array(id))) match {
case Left(err) => throw BsonDecoder.Error(trace, s"Failed to construct case class: $err")
case Right(value) => value
}
}
} else {
unsafeCall(trace) {
reader.readStartDocument()

val nextCtx = BsonDecoder.BsonDecoderContext.default
val ps: Array[Any] = Array.ofDim(len)

while (reader.readBsonType() != BsonType.END_OF_DOCUMENT) {
val name = reader.readName()
val idx = indexes.getOrElse(name, -1)

if (idx >= 0) {
val nextTrace = spans(idx) :: trace
val tc = tcs(idx)
if (ps(idx) != null) throw BsonDecoder.Error(nextTrace, "duplicate")
ps(idx) = if ((fields(idx).optional || fields(idx).transient) && fields(idx).defaultValue.isDefined) {
val opt = BsonDecoder.option(tc).decodeUnsafe(reader, nextTrace, nextCtx)
opt.getOrElse(fields(idx).defaultValue.get)
} else {
tc.decodeUnsafe(reader, nextTrace, nextCtx)
}
} else if (noExtra && !ctx.ignoreExtraField.contains(name)) {
throw BsonDecoder.Error(BsonTrace.Field(name) :: trace, "Invalid extra field.")
} else reader.skipValue()
}
} else if (noExtra && !ctx.ignoreExtraField.contains(name)) {
throw BsonDecoder.Error(BsonTrace.Field(name) :: trace, "Invalid extra field.")
} else reader.skipValue()
}

var i = 0
while (i < len) {
if (ps(i) == null) {
if ((fields(i).optional || fields(i).transient) && fields(i).defaultValue.isDefined) {
ps(i) = fields(i).defaultValue.get
} else {
ps(i) = tcs(i).decodeMissingUnsafe(spans(i) :: trace)
var i = 0
while (i < len) {
if (ps(i) == null) {
if ((fields(i).optional || fields(i).transient) && fields(i).defaultValue.isDefined) {
ps(i) = fields(i).defaultValue.get
} else {
ps(i) = tcs(i).decodeMissingUnsafe(spans(i) :: trace)
}
}
i += 1
}
}
i += 1
}

reader.readEndDocument()
reader.readEndDocument()

Unsafe.unsafe { implicit u =>
caseClassSchema.construct(Chunk.fromArray(ps)) match {
case Left(err) => throw BsonDecoder.Error(trace, s"Failed to construct case class: $err")
case Right(value) => value
Unsafe.unsafe { implicit u =>
caseClassSchema.construct(Chunk.fromArray(ps)) match {
case Left(err) => throw BsonDecoder.Error(trace, s"Failed to construct case class: $err")
case Right(value) => value
}
}
}
}
}

def fromBsonValueUnsafe(value: BsonValue, trace: List[BsonTrace], ctx: BsonDecoder.BsonDecoderContext): Z =
assumeType(trace)(BsonType.DOCUMENT, value) { value =>
val nextCtx = BsonDecoder.BsonDecoderContext.default
val ps: Array[Any] = Array.ofDim(len)

value.asDocument().asScala.foreachEntry { (name, value) =>
val idx = indexes.getOrElse(name, -1)

if (idx >= 0) {
val nextTrace = spans(idx) :: trace
val tc = tcs(idx)
if (ps(idx) != null) throw BsonDecoder.Error(nextTrace, "duplicate")
ps(idx) = if ((fields(idx).optional || fields(idx).transient) && fields(idx).defaultValue.isDefined) {
val opt = BsonDecoder.option(tc).fromBsonValueUnsafe(value, nextTrace, nextCtx)
opt.getOrElse(fields(idx).defaultValue.get)
} else {
tc.fromBsonValueUnsafe(value, nextTrace, nextCtx)
}
} else if (noExtra && !ctx.ignoreExtraField.contains(name))
throw BsonDecoder.Error(BsonTrace.Field(name) :: trace, "Invalid extra field.")
if (value.getBsonType == BsonType.OBJECT_ID) {
Unsafe.unsafe { implicit u =>
val ps: Array[Any] = Array(value.asObjectId.getValue.toHexString)
caseClassSchema.construct(Chunk.fromArray(ps)) match {
case Left(err) => throw BsonDecoder.Error(trace, s"Failed to construct case class: $err")
case Right(value) => value
}
}
} else {
assumeType(trace)(BsonType.DOCUMENT, value) { value =>
val nextCtx = BsonDecoder.BsonDecoderContext.default
val ps: Array[Any] = Array.ofDim(len)

value.asDocument().asScala.foreachEntry { (name, value) =>
val idx = indexes.getOrElse(name, -1)

if (idx >= 0) {
val nextTrace = spans(idx) :: trace
val tc = tcs(idx)
if (ps(idx) != null) throw BsonDecoder.Error(nextTrace, "duplicate")
ps(idx) = if ((fields(idx).optional || fields(idx).transient) && fields(idx).defaultValue.isDefined) {
val opt = BsonDecoder.option(tc).fromBsonValueUnsafe(value, nextTrace, nextCtx)
opt.getOrElse(fields(idx).defaultValue.get)
} else {
tc.fromBsonValueUnsafe(value, nextTrace, nextCtx)
}
} else if (noExtra && !ctx.ignoreExtraField.contains(name))
throw BsonDecoder.Error(BsonTrace.Field(name) :: trace, "Invalid extra field.")
}

var i = 0
while (i < len) {
if (ps(i) == null) {
ps(i) = if ((fields(i).optional || fields(i).transient) && fields(i).defaultValue.isDefined) {
fields(i).defaultValue.get
} else {
tcs(i).decodeMissingUnsafe(spans(i) :: trace)
var i = 0
while (i < len) {
if (ps(i) == null) {
ps(i) = if ((fields(i).optional || fields(i).transient) && fields(i).defaultValue.isDefined) {
fields(i).defaultValue.get
} else {
tcs(i).decodeMissingUnsafe(spans(i) :: trace)
}
}
i += 1
}
i += 1
}

Unsafe.unsafe { implicit u =>
caseClassSchema.construct(Chunk.fromArray(ps)) match {
case Left(err) => throw BsonDecoder.Error(trace, s"Failed to construct case class: $err")
case Right(value) => value
Unsafe.unsafe { implicit u =>
caseClassSchema.construct(Chunk.fromArray(ps)) match {
case Left(err) => throw BsonDecoder.Error(trace, s"Failed to construct case class: $err")
case Right(value) => value
}
}
}
}
Expand Down
22 changes: 22 additions & 0 deletions zio-schema-bson/src/main/scala/zio/schema/codec/package.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package zio.schema.codec

import org.bson.types.ObjectId

import zio.schema.{ Schema, TypeId }

package object bson {
val ObjectIdTag = "$oid"

implicit val ObjectIdSchema: Schema[ObjectId] =
Schema.CaseClass1[String, ObjectId](
id0 = TypeId.fromTypeName("ObjectId"),
field0 = Schema.Field(
name0 = ObjectIdTag,
schema0 = Schema[String],
get0 = _.toHexString,
set0 = (_, idStr) => new ObjectId(idStr)
),
defaultConstruct0 = new ObjectId(_)
)

}
Loading
Loading