Skip to content

Commit

Permalink
Additional CBOR types (#160)
Browse files Browse the repository at this point in the history
  • Loading branch information
remade authored Oct 22, 2024
1 parent 08ad39f commit f98f67e
Show file tree
Hide file tree
Showing 14 changed files with 568 additions and 115 deletions.
10 changes: 5 additions & 5 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func Query[TResult any](db *DB, sql string, vars map[string]interface{}) (*[]Que
return res.Result, nil
}

func Create[TResult any, TWhat models.TableOrRecord](db *DB, what TWhat, data interface{}) (*TResult, error) {
func Create[TResult any, TWhat TableOrRecord](db *DB, what TWhat, data interface{}) (*TResult, error) {
var res connection.RPCResponse[TResult]
if err := db.con.Send(&res, "create", what, data); err != nil {
return nil, err
Expand All @@ -210,7 +210,7 @@ func Create[TResult any, TWhat models.TableOrRecord](db *DB, what TWhat, data in
return res.Result, nil
}

func Select[TResult any, TWhat models.TableOrRecord](db *DB, what TWhat) (*TResult, error) {
func Select[TResult any, TWhat TableOrRecord](db *DB, what TWhat) (*TResult, error) {
var res connection.RPCResponse[TResult]

if err := db.con.Send(&res, "select", what); err != nil {
Expand All @@ -226,16 +226,16 @@ func Patch(db *DB, what interface{}, patches []PatchData) (*[]PatchData, error)
return patchRes.Result, err
}

func Delete[TWhat models.TableOrRecord](db *DB, what TWhat) error {
func Delete[TWhat TableOrRecord](db *DB, what TWhat) error {
return db.con.Send(nil, "delete", what)
}

func Upsert[TWhat models.TableOrRecord](db *DB, what TWhat, data interface{}) error {
func Upsert[TWhat TableOrRecord](db *DB, what TWhat, data interface{}) error {
return db.con.Send(nil, "upsert", what, data)
}

// Update a table or record in the database like a PUT request.
func Update[TResult any, TWhat models.TableOrRecord](db *DB, what TWhat, data interface{}) (*TResult, error) {
func Update[TResult any, TWhat TableOrRecord](db *DB, what TWhat, data interface{}) (*TResult, error) {
var res connection.RPCResponse[TResult]
if err := db.con.Send(&res, "update", what, data); err != nil {
return nil, err
Expand Down
86 changes: 47 additions & 39 deletions pkg/models/cbor.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,60 +3,68 @@ package models
import (
"io"
"reflect"
"time"

"github.com/fxamacker/cbor/v2"
"github.com/surrealdb/surrealdb.go/internal/codec"
)

type CustomCBORTag uint64

var (
NoneTag CustomCBORTag = 6
TableNameTag CustomCBORTag = 7
RecordIDTag CustomCBORTag = 8
UUIDStringTag CustomCBORTag = 9
DecimalStringTag CustomCBORTag = 10
DateTimeCompactString CustomCBORTag = 12
DurationStringTag CustomCBORTag = 13
DurationCompactTag CustomCBORTag = 14
BinaryUUIDTag CustomCBORTag = 37
GeometryPointTag CustomCBORTag = 88
GeometryLineTag CustomCBORTag = 89
GeometryPolygonTag CustomCBORTag = 90
GeometryMultiPointTag CustomCBORTag = 91
GeometryMultiLineTag CustomCBORTag = 92
GeometryMultiPolygonTag CustomCBORTag = 93
GeometryCollectionTag CustomCBORTag = 94
TagNone uint64 = 6
TagTable uint64 = 7
TagRecordID uint64 = 8
TagCustomDatetime uint64 = 12
TagCustomDuration uint64 = 14
TagFuture uint64 = 15

TagStringUUID uint64 = 9
TagStringDecimal uint64 = 10
TagStringDuration uint64 = 13

TagSpecBinaryUUID uint64 = 37

TagRange uint64 = 49
TagBoundIncluded uint64 = 50
TagBoundExcluded uint64 = 51

TagGeometryPoint uint64 = 88
TagGeometryLine uint64 = 89
TagGeometryPolygon uint64 = 90
TagGeometryMultiPoint uint64 = 91
TagGeometryMultiLine uint64 = 92
TagGeometryMultiPolygon uint64 = 93
TagGeometryCollection uint64 = 94
)

func registerCborTags() cbor.TagSet {
customTags := map[CustomCBORTag]interface{}{
GeometryPointTag: GeometryPoint{},
GeometryLineTag: GeometryLine{},
GeometryPolygonTag: GeometryPolygon{},
GeometryMultiPointTag: GeometryMultiPoint{},
GeometryMultiLineTag: GeometryMultiLine{},
GeometryMultiPolygonTag: GeometryMultiPolygon{},
GeometryCollectionTag: GeometryCollection{},

TableNameTag: Table(""),
//UUIDStringTag: UUID(""),
DecimalStringTag: Decimal(""),
BinaryUUIDTag: UUID{},
NoneTag: CustomNil{},

DateTimeCompactString: CustomDateTime(time.Now()),
DurationStringTag: CustomDurationStr("2w"),
//DurationCompactTag: CustomDuration(0),
customTags := map[uint64]interface{}{
TagNone: CustomNil{},
TagTable: Table(""),
TagRecordID: RecordID{},

TagCustomDatetime: CustomDateTime{},
TagCustomDuration: CustomDuration{},
TagFuture: Future{},

TagStringUUID: UUIDString(""),
TagStringDecimal: DecimalString(""),
TagStringDuration: CustomDurationString(""),

TagSpecBinaryUUID: UUID{},

TagGeometryPoint: GeometryPoint{},
TagGeometryLine: GeometryLine{},
TagGeometryPolygon: GeometryPolygon{},
TagGeometryMultiPoint: GeometryMultiPoint{},
TagGeometryMultiLine: GeometryMultiLine{},
TagGeometryMultiPolygon: GeometryMultiPolygon{},
TagGeometryCollection: GeometryCollection{},
}

tags := cbor.NewTagSet()
for tag, customType := range customTags {
err := tags.Add(
cbor.TagOptions{EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired},
reflect.TypeOf(customType),
uint64(tag),
tag,
)
if err != nil {
panic(err)
Expand Down
131 changes: 123 additions & 8 deletions pkg/models/cbor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,14 @@ func TestForRequestPayload(t *testing.T) {
params := []interface{}{
"SELECT marketing, count() FROM $tb GROUP BY marketing",
map[string]interface{}{
"tb": Table("person"),
"line": GeometryLine{NewGeometryPoint(11.11, 22.22), NewGeometryPoint(33.33, 44.44)},
"datetime": time.Now(),
"testNone": None,
"testNil": nil,
"duration": time.Duration(340),
// "custom_duration": CustomDuration(340),
"custom_datetime": CustomDateTime(time.Now()),
"tb": Table("person"),
"line": GeometryLine{NewGeometryPoint(11.11, 22.22), NewGeometryPoint(33.33, 44.44)},
"datetime": time.Now(),
"testNone": None,
"testNil": nil,
"duration": time.Duration(340),
"custom_duration": CustomDuration{340},
"custom_datetime": CustomDateTime{time.Now()},
},
}

Expand All @@ -94,3 +94,118 @@ func TestForRequestPayload(t *testing.T) {

fmt.Println(diagStr)
}

func TestRange_GetJoinString(t *testing.T) {
t.Run("begin excluded, end excluded", func(s *testing.T) {
r := &Range[int, BoundExcluded[int], BoundExcluded[int]]{
Begin: &BoundExcluded[int]{0},
End: &BoundExcluded[int]{10},
}
assert.Equal(t, ">..", r.GetJoinString())
})

t.Run("begin excluded, end included", func(t *testing.T) {
r := Range[int, BoundExcluded[int], BoundIncluded[int]]{
Begin: &BoundExcluded[int]{0},
End: &BoundIncluded[int]{10},
}
assert.Equal(t, ">..=", r.GetJoinString())
})

t.Run("begin included, end excluded", func(t *testing.T) {
r := Range[int, BoundIncluded[int], BoundExcluded[int]]{
Begin: &BoundIncluded[int]{0},
End: &BoundExcluded[int]{10},
}
assert.Equal(t, "..", r.GetJoinString())
})

t.Run("begin included, end included", func(t *testing.T) {
r := Range[int, BoundIncluded[int], BoundIncluded[int]]{
Begin: &BoundIncluded[int]{0},
End: &BoundIncluded[int]{10},
}
assert.Equal(t, "..=", r.GetJoinString())
})
}

func TestRange_Bounds(t *testing.T) {
em := getCborEncoder()
dm := getCborDecoder()

t.Run("bound included should be marshaled and unmarshaled properly", func(t *testing.T) {
bi := BoundIncluded[int]{10}
encoded, err := em.Marshal(bi)
assert.NoError(t, err)

var decoded BoundIncluded[int]
err = dm.Unmarshal(encoded, &decoded)
assert.NoError(t, err)
assert.Equal(t, bi, decoded)
})

t.Run("bound excluded should be marshaled and unmarshaled properly", func(t *testing.T) {
be := BoundExcluded[int]{10}
encoded, err := em.Marshal(be)
assert.NoError(t, err)

var decoded BoundExcluded[int]
err = dm.Unmarshal(encoded, &decoded)
assert.NoError(t, err)
assert.Equal(t, be, decoded)
})
}

func TestRange_CODEC(t *testing.T) {
em := getCborEncoder()
dm := getCborDecoder()

r := Range[int, BoundIncluded[int], BoundExcluded[int]]{
Begin: &BoundIncluded[int]{0},
End: &BoundExcluded[int]{10},
}

encoded, err := em.Marshal(r)
assert.NoError(t, err)

var decoded Range[int, BoundIncluded[int], BoundExcluded[int]]
err = dm.Unmarshal(encoded, &decoded)
assert.NoError(t, err)
assert.Equal(t, r, decoded)
}

func TestCustomDateTime_String(t *testing.T) {
time1, err := time.Parse("2006-01-02 15:04:05", "2024-10-30 12:05:00")
assert.NoError(t, err)

cd := CustomDateTime{time1}
assert.Equal(t, "2024-10-30T12:05:00Z", cd.String())
}

func TestTable_String(t *testing.T) {
table := Table("mytesttable")
assert.Equal(t, "mytesttable", table.String())
}

func TestCustomDuration_String(t *testing.T) {
cd := CustomDuration{time.Duration(33333333333000000)}
assert.Equal(t, "1y2w6d19h15m33s333ms", cd.String())
}

func TestRecordID_String(t *testing.T) {
rid := RecordID{Table: "mytesttable", ID: "121212121"}
assert.Equal(t, "mytesttable:121212121", rid.String())
}

func TestFormatDurationAndParseDuration(t *testing.T) {
durationStr := "1y2w6d19h15m33s333ms"

ns, _ := ParseDuration(durationStr)
d := FormatDuration(ns)
assert.Equal(t, durationStr, d)
}

func TestFormatDuration(t *testing.T) {
d := FormatDuration(33333333333000000)
assert.Equal(t, "1y2w6d19h15m33s333ms", d)
}
54 changes: 54 additions & 0 deletions pkg/models/datetime.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package models

import (
"fmt"
"time"

"github.com/fxamacker/cbor/v2"
"github.com/surrealdb/surrealdb.go/pkg/constants"
)

// CustomDateTime embeds time.Time
type CustomDateTime struct {
time.Time
}

func (d *CustomDateTime) MarshalCBOR() ([]byte, error) {
enc := getCborEncoder()

totalNS := d.Nanosecond()

s := totalNS / constants.OneSecondToNanoSecond
ns := totalNS % constants.OneSecondToNanoSecond

return enc.Marshal(cbor.Tag{
Number: TagCustomDatetime,
Content: [2]int64{int64(s), int64(ns)},
})
}

func (d *CustomDateTime) UnmarshalCBOR(data []byte) error {
dec := getCborDecoder()

var temp [2]int64
err := dec.Unmarshal(data, &temp)
if err != nil {
return err
}

s := temp[0]
ns := temp[1]

*d = CustomDateTime{time.Unix(s, ns)}

return nil
}

func (d *CustomDateTime) String() string {
layout := "2006-01-02T15:04:05Z"
return d.Format(layout)
}

func (d *CustomDateTime) SurrealString() string {
return fmt.Sprintf("<datetime> '%s'", d.String())
}
Loading

0 comments on commit f98f67e

Please sign in to comment.