diff --git a/cli/collection_create.go b/cli/collection_create.go index 994911a14c..f4c36fbd53 100644 --- a/cli/collection_create.go +++ b/cli/collection_create.go @@ -17,14 +17,25 @@ import ( "github.com/spf13/cobra" "github.com/sourcenetwork/defradb/client" + "github.com/sourcenetwork/defradb/internal/db" ) func MakeCollectionCreateCommand() *cobra.Command { var file string + var shouldEncrypt bool var cmd = &cobra.Command{ - Use: "create [-i --identity] ", + Use: "create [-i --identity] [-e --encrypt] ", Short: "Create a new document.", Long: `Create a new document. + +Options: + -i, --identity + Marks the document as private and set the identity as the owner. The access to the document + and permissions are controlled by ACP (Access Control Policy). + + -e, --encrypt + Encrypt flag specified if the document needs to be encrypted. If set, DefraDB will generate a + symmetric key for encryption using AES-GCM. Example: create from string: defradb client collection create --name User '{ "name": "Bob" }' @@ -69,6 +80,9 @@ Example: create from stdin: return cmd.Usage() } + txn, _ := db.TryGetContextTxn(cmd.Context()) + setContextDocEncryption(cmd, shouldEncrypt, txn) + if client.IsJSONArray(docData) { docs, err := client.NewDocsFromJSON(docData, col.Definition()) if err != nil { @@ -84,6 +98,8 @@ Example: create from stdin: return col.Create(cmd.Context(), doc) }, } + cmd.PersistentFlags().BoolVarP(&shouldEncrypt, "encrypt", "e", false, + "Flag to enable encryption of the document") cmd.Flags().StringVarP(&file, "file", "f", "", "File containing document(s)") return cmd } diff --git a/cli/utils.go b/cli/utils.go index d1ee09962b..b2d4c076bc 100644 --- a/cli/utils.go +++ b/cli/utils.go @@ -25,8 +25,10 @@ import ( acpIdentity "github.com/sourcenetwork/defradb/acp/identity" "github.com/sourcenetwork/defradb/client" + "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/http" "github.com/sourcenetwork/defradb/internal/db" + 
"github.com/sourcenetwork/defradb/internal/encryption" "github.com/sourcenetwork/defradb/keyring" ) @@ -160,6 +162,19 @@ func setContextIdentity(cmd *cobra.Command, privateKeyHex string) error { return nil } +// setContextDocEncryption sets doc encryption for the current command context. +func setContextDocEncryption(cmd *cobra.Command, shouldEncrypt bool, txn datastore.Txn) { + if !shouldEncrypt { + return + } + ctx := cmd.Context() + if txn != nil { + ctx = encryption.ContextWithStore(ctx, txn) + } + ctx = encryption.SetContextConfig(ctx, encryption.DocEncConfig{IsEncrypted: true}) + cmd.SetContext(ctx) +} + // setContextRootDir sets the rootdir for the current command context. func setContextRootDir(cmd *cobra.Command) error { rootdir, err := cmd.Root().PersistentFlags().GetString("rootdir") diff --git a/client/document.go b/client/document.go index ada47cc8f9..f8d427c349 100644 --- a/client/document.go +++ b/client/document.go @@ -118,7 +118,7 @@ func NewDocFromMap(data map[string]any, collectionDefinition CollectionDefinitio return doc, nil } -var jsonArrayPattern = regexp.MustCompile(`^\s*\[.*\]\s*$`) +var jsonArrayPattern = regexp.MustCompile(`(?s)^\s*\[.*\]\s*$`) // IsJSONArray returns true if the given byte array is a JSON Array. func IsJSONArray(obj []byte) bool { diff --git a/client/document_test.go b/client/document_test.go index b15c7b019a..a11a6a67c8 100644 --- a/client/document_test.go +++ b/client/document_test.go @@ -206,3 +206,66 @@ func TestNewFromJSON_WithInvalidJSONFieldValueSimpleString_Error(t *testing.T) { _, err := NewDocFromJSON(objWithJSONField, def) require.ErrorContains(t, err, "invalid JSON payload. 
Payload: blah") } + +func TestIsJSONArray(t *testing.T) { + tests := []struct { + name string + input []byte + expected bool + }{ + { + name: "Valid JSON Array", + input: []byte(`[{"name":"John","age":21},{"name":"Islam","age":33}]`), + expected: true, + }, + { + name: "Valid Empty JSON Array", + input: []byte(`[]`), + expected: true, + }, + { + name: "Valid JSON Object", + input: []byte(`{"name":"John","age":21}`), + expected: false, + }, + { + name: "Invalid JSON String", + input: []byte(`{"name":"John","age":21`), + expected: false, + }, + { + name: "Non-JSON String", + input: []byte(`Hello, World!`), + expected: false, + }, + { + name: "Array of Primitives", + input: []byte(`[1, 2, 3, 4]`), + expected: true, + }, + { + name: "Nested JSON Array", + input: []byte(`[[1, 2], [3, 4]]`), + expected: true, + }, + { + name: "Valid JSON Array with Whitespace", + input: []byte(` + [ + { "name": "John", "age": 21 }, + { "name": "Islam", "age": 33 } + ] + `), + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := IsJSONArray(tt.input) + if actual != tt.expected { + t.Errorf("IsJSONArray(%s) = %v; expected %v", tt.input, actual, tt.expected) + } + }) + } +} diff --git a/client/request/consts.go b/client/request/consts.go index 1a1d653a25..cba609b788 100644 --- a/client/request/consts.go +++ b/client/request/consts.go @@ -21,10 +21,13 @@ const ( Cid = "cid" Input = "input" + Inputs = "inputs" FieldName = "field" FieldIDName = "fieldId" ShowDeleted = "showDeleted" + EncryptArgName = "encrypt" + FilterClause = "filter" GroupByClause = "groupBy" LimitClause = "limit" diff --git a/client/request/mutation.go b/client/request/mutation.go index 81fcc823c9..70d0bed1d9 100644 --- a/client/request/mutation.go +++ b/client/request/mutation.go @@ -41,6 +41,15 @@ type ObjectMutation struct { // // This is ignored for [DeleteObjects] mutations. 
Input map[string]any + + // Inputs is the array of json representations of the fieldName-value pairs of document + // properties to mutate. + // + // This is ignored for [DeleteObjects] mutations. + Inputs []map[string]any + + // Encrypt is a boolean flag that indicates whether the input data should be encrypted. + Encrypt bool } // ToSelect returns a basic Select object, with the same Name, Alias, and Fields as diff --git a/datastore/mocks/txn.go b/datastore/mocks/txn.go index f29c045dcd..41606260ea 100644 --- a/datastore/mocks/txn.go +++ b/datastore/mocks/txn.go @@ -195,6 +195,53 @@ func (_c *Txn_Discard_Call) RunAndReturn(run func(context.Context)) *Txn_Discard return _c } +// Encstore provides a mock function with given fields: +func (_m *Txn) Encstore() datastore.DSReaderWriter { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Encstore") + } + + var r0 datastore.DSReaderWriter + if rf, ok := ret.Get(0).(func() datastore.DSReaderWriter); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(datastore.DSReaderWriter) + } + } + + return r0 +} + +// Txn_Encstore_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Encstore' +type Txn_Encstore_Call struct { + *mock.Call +} + +// Encstore is a helper method to define mock.On call +func (_e *Txn_Expecter) Encstore() *Txn_Encstore_Call { + return &Txn_Encstore_Call{Call: _e.mock.On("Encstore")} +} + +func (_c *Txn_Encstore_Call) Run(run func()) *Txn_Encstore_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Txn_Encstore_Call) Return(_a0 datastore.DSReaderWriter) *Txn_Encstore_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Txn_Encstore_Call) RunAndReturn(run func() datastore.DSReaderWriter) *Txn_Encstore_Call { + _c.Call.Return(run) + return _c +} + // Headstore provides a mock function with given fields: func (_m *Txn) Headstore() datastore.DSReaderWriter { ret := _m.Called() 
diff --git a/datastore/multi.go b/datastore/multi.go index a70a24a60d..f863924d5d 100644 --- a/datastore/multi.go +++ b/datastore/multi.go @@ -23,11 +23,13 @@ var ( headStoreKey = rootStoreKey.ChildString("heads") blockStoreKey = rootStoreKey.ChildString("blocks") peerStoreKey = rootStoreKey.ChildString("ps") + encStoreKey = rootStoreKey.ChildString("enc") ) type multistore struct { root DSReaderWriter data DSReaderWriter + enc DSReaderWriter head DSReaderWriter peer DSBatching system DSReaderWriter @@ -43,6 +45,7 @@ func MultiStoreFrom(rootstore ds.Datastore) MultiStore { ms := &multistore{ root: rootRW, data: prefix(rootRW, dataStoreKey), + enc: prefix(rootRW, encStoreKey), head: prefix(rootRW, headStoreKey), peer: namespace.Wrap(rootstore, peerStoreKey), system: prefix(rootRW, systemStoreKey), @@ -57,6 +60,11 @@ func (ms multistore) Datastore() DSReaderWriter { return ms.data } +// Encstore implements MultiStore. +func (ms multistore) Encstore() DSReaderWriter { + return ms.enc +} + // Headstore implements MultiStore. func (ms multistore) Headstore() DSReaderWriter { return ms.head diff --git a/datastore/store.go b/datastore/store.go index 66501270d1..516bfe0b65 100644 --- a/datastore/store.go +++ b/datastore/store.go @@ -34,26 +34,26 @@ type Rootstore interface { type MultiStore interface { Rootstore() DSReaderWriter - // Datastore is a wrapped root DSReaderWriter - // under the /data namespace + // Datastore is a wrapped root DSReaderWriter under the /data namespace Datastore() DSReaderWriter - // Headstore is a wrapped root DSReaderWriter - // under the /head namespace + // Encstore is a wrapped root DSReaderWriter under the /enc namespace + // This store is used for storing symmetric encryption keys for doc encryption. + // The store keys are comprised of docID + field name. 
+ Encstore() DSReaderWriter + + // Headstore is a wrapped root DSReaderWriter under the /head namespace Headstore() DSReaderWriter - // Peerstore is a wrapped root DSReaderWriter - // as a ds.Batching, embedded into a DSBatching + // Peerstore is a wrapped root DSReaderWriter as a ds.Batching, embedded into a DSBatching // under the /peers namespace Peerstore() DSBatching - // Blockstore is a wrapped root DSReaderWriter - // as a Blockstore, embedded into a Blockstore + // Blockstore is a wrapped root DSReaderWriter as a Blockstore, embedded into a Blockstore // under the /blocks namespace Blockstore() Blockstore - // Headstore is a wrapped root DSReaderWriter - // under the /system namespace + // Systemstore is a wrapped root DSReaderWriter under the /system namespace Systemstore() DSReaderWriter } diff --git a/docs/website/references/cli/defradb_client_collection_create.md b/docs/website/references/cli/defradb_client_collection_create.md index 425be82753..5425e3f860 100644 --- a/docs/website/references/cli/defradb_client_collection_create.md +++ b/docs/website/references/cli/defradb_client_collection_create.md @@ -5,6 +5,15 @@ Create a new document. ### Synopsis Create a new document. + +Options: + -i, --identity + Marks the document as private and set the identity as the owner. The access to the document + and permissions are controlled by ACP (Access Control Policy). + + -e, --encrypt + Encrypt flag specified if the document needs to be encrypted. If set, DefraDB will generate a + symmetric key for encryption using AES-GCM. 
Example: create from string: defradb client collection create --name User '{ "name": "Bob" }' @@ -24,12 +33,13 @@ Example: create from stdin: ``` -defradb client collection create [-i --identity] [flags] +defradb client collection create [-i --identity] [-e --encrypt] [flags] ``` ### Options ``` + -e, --encrypt Flag to enable encryption of the document -f, --file string File containing document(s) -h, --help help for create ``` diff --git a/http/client.go b/http/client.go index 2082604599..6e5cc21276 100644 --- a/http/client.go +++ b/http/client.go @@ -349,12 +349,16 @@ func (c *Client) ExecRequest( result.GQL.Errors = []error{err} return result } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, methodURL.String(), bytes.NewBuffer(body)) if err != nil { result.GQL.Errors = []error{err} return result } err = c.http.setDefaultHeaders(req) + + setDocEncryptionFlagIfNeeded(ctx, req) + if err != nil { result.GQL.Errors = []error{err} return result diff --git a/http/client_collection.go b/http/client_collection.go index ee614c1dba..8df094f5fc 100644 --- a/http/client_collection.go +++ b/http/client_collection.go @@ -24,6 +24,7 @@ import ( sse "github.com/vito/go-sse/sse" "github.com/sourcenetwork/defradb/client" + "github.com/sourcenetwork/defradb/internal/encryption" ) var _ client.Collection = (*Collection)(nil) @@ -78,6 +79,8 @@ func (c *Collection) Create( return err } + setDocEncryptionFlagIfNeeded(ctx, req) + _, err = c.http.request(req) if err != nil { return err @@ -114,6 +117,8 @@ func (c *Collection) CreateMany( return err } + setDocEncryptionFlagIfNeeded(ctx, req) + _, err = c.http.request(req) if err != nil { return err @@ -125,6 +130,15 @@ func (c *Collection) CreateMany( return nil } +func setDocEncryptionFlagIfNeeded(ctx context.Context, req *http.Request) { + encConf := encryption.GetContextConfig(ctx) + if encConf.HasValue() && encConf.Value().IsEncrypted { + q := req.URL.Query() + q.Set(docEncryptParam, "true") + req.URL.RawQuery = 
q.Encode() + } +} + func (c *Collection) Update( ctx context.Context, doc *client.Document, diff --git a/http/client_tx.go b/http/client_tx.go index a804b934f1..5b99f5aaad 100644 --- a/http/client_tx.go +++ b/http/client_tx.go @@ -91,6 +91,10 @@ func (c *Transaction) Datastore() datastore.DSReaderWriter { panic("client side transaction") } +func (c *Transaction) Encstore() datastore.DSReaderWriter { + panic("client side transaction") +} + func (c *Transaction) Headstore() datastore.DSReaderWriter { panic("client side transaction") } diff --git a/http/handler_collection.go b/http/handler_collection.go index 60c18b3442..412f486602 100644 --- a/http/handler_collection.go +++ b/http/handler_collection.go @@ -21,8 +21,11 @@ import ( "github.com/go-chi/chi/v5" "github.com/sourcenetwork/defradb/client" + "github.com/sourcenetwork/defradb/internal/encryption" ) +const docEncryptParam = "encrypt" + type collectionHandler struct{} type CollectionDeleteRequest struct { @@ -43,6 +46,11 @@ func (s *collectionHandler) Create(rw http.ResponseWriter, req *http.Request) { return } + ctx := req.Context() + if req.URL.Query().Get(docEncryptParam) == "true" { + ctx = encryption.SetContextConfig(ctx, encryption.DocEncConfig{IsEncrypted: true}) + } + switch { case client.IsJSONArray(data): docList, err := client.NewDocsFromJSON(data, col.Definition()) @@ -51,7 +59,7 @@ func (s *collectionHandler) Create(rw http.ResponseWriter, req *http.Request) { return } - if err := col.CreateMany(req.Context(), docList); err != nil { + if err := col.CreateMany(ctx, docList); err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return } @@ -62,7 +70,7 @@ func (s *collectionHandler) Create(rw http.ResponseWriter, req *http.Request) { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return } - if err := col.Create(req.Context(), doc); err != nil { + if err := col.Create(ctx, doc); err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return } diff --git 
a/internal/core/block/block.go b/internal/core/block/block.go index 8482a23d91..d2caa610f7 100644 --- a/internal/core/block/block.go +++ b/internal/core/block/block.go @@ -103,6 +103,9 @@ type Block struct { Delta crdt.CRDT // Links are the links to other blocks in the DAG. Links []DAGLink + // IsEncrypted is a flag that indicates if the block's delta is encrypted. + // It needs to be a pointer so that it can be translated from and to `optional Bool` in the IPLD schema. + IsEncrypted *bool } // IPLDSchemaBytes returns the IPLD schema representation for the block. @@ -111,8 +114,9 @@ type Block struct { func (b Block) IPLDSchemaBytes() []byte { return []byte(` type Block struct { - delta CRDT - links [ DAGLink ] + delta CRDT + links [ DAGLink ] + isEncrypted optional Bool }`) } @@ -143,19 +147,9 @@ func New(delta core.Delta, links []DAGLink, heads ...cid.Cid) *Block { blockLinks = append(blockLinks, links...) - var crdtDelta crdt.CRDT - switch delta := delta.(type) { - case *crdt.LWWRegDelta: - crdtDelta = crdt.CRDT{LWWRegDelta: delta} - case *crdt.CompositeDAGDelta: - crdtDelta = crdt.CRDT{CompositeDAGDelta: delta} - case *crdt.CounterDelta: - crdtDelta = crdt.CRDT{CounterDelta: delta} - } - return &Block{ Links: blockLinks, - Delta: crdtDelta, + Delta: crdt.NewCRDT(delta), } } diff --git a/internal/core/crdt/ipld_union.go b/internal/core/crdt/ipld_union.go index 361a41b150..95023f28b2 100644 --- a/internal/core/crdt/ipld_union.go +++ b/internal/core/crdt/ipld_union.go @@ -19,6 +19,19 @@ type CRDT struct { CounterDelta *CounterDelta } +// NewCRDT returns a new CRDT. +func NewCRDT(delta core.Delta) CRDT { + switch d := delta.(type) { + case *LWWRegDelta: + return CRDT{LWWRegDelta: d} + case *CompositeDAGDelta: + return CRDT{CompositeDAGDelta: d} + case *CounterDelta: + return CRDT{CounterDelta: d} + } + return CRDT{} +} + // IPLDSchemaBytes returns the IPLD schema representation for the CRDT. 
// // This needs to match the [CRDT] struct or [mustSetSchema] will panic on init. @@ -96,6 +109,39 @@ func (c CRDT) GetSchemaVersionID() string { return "" } +// Clone returns a clone of the CRDT. +func (c CRDT) Clone() CRDT { + var cloned CRDT + switch { + case c.LWWRegDelta != nil: + cloned.LWWRegDelta = &LWWRegDelta{ + DocID: c.LWWRegDelta.DocID, + FieldName: c.LWWRegDelta.FieldName, + Priority: c.LWWRegDelta.Priority, + SchemaVersionID: c.LWWRegDelta.SchemaVersionID, + Data: c.LWWRegDelta.Data, + } + case c.CompositeDAGDelta != nil: + cloned.CompositeDAGDelta = &CompositeDAGDelta{ + DocID: c.CompositeDAGDelta.DocID, + FieldName: c.CompositeDAGDelta.FieldName, + Priority: c.CompositeDAGDelta.Priority, + SchemaVersionID: c.CompositeDAGDelta.SchemaVersionID, + Status: c.CompositeDAGDelta.Status, + } + case c.CounterDelta != nil: + cloned.CounterDelta = &CounterDelta{ + DocID: c.CounterDelta.DocID, + FieldName: c.CounterDelta.FieldName, + Priority: c.CounterDelta.Priority, + SchemaVersionID: c.CounterDelta.SchemaVersionID, + Nonce: c.CounterDelta.Nonce, + Data: c.CounterDelta.Data, + } + } + return cloned +} + // GetStatus returns the status of the delta. // // Currently only implemented for CompositeDAGDelta. @@ -107,15 +153,24 @@ func (c CRDT) GetStatus() uint8 { } // GetData returns the data of the delta. -// -// Currently only implemented for LWWRegDelta. func (c CRDT) GetData() []byte { if c.LWWRegDelta != nil { return c.LWWRegDelta.Data + } else if c.CounterDelta != nil { + return c.CounterDelta.Data } return nil } +// SetData sets the data of the delta. +func (c CRDT) SetData(data []byte) { + if c.LWWRegDelta != nil { + c.LWWRegDelta.Data = data + } else if c.CounterDelta != nil { + c.CounterDelta.Data = data + } +} + // IsComposite returns true if the CRDT is a composite CRDT. 
func (c CRDT) IsComposite() bool { return c.CompositeDAGDelta != nil diff --git a/internal/core/crdt/lwwreg_test.go b/internal/core/crdt/lwwreg_test.go index 2083a5b800..136d5cd09d 100644 --- a/internal/core/crdt/lwwreg_test.go +++ b/internal/core/crdt/lwwreg_test.go @@ -31,7 +31,7 @@ func setupLWWRegister() LWWRegister { return NewLWWRegister(store, core.CollectionSchemaVersionKey{}, key, "") } -func setupLoadedLWWRegster(ctx context.Context) LWWRegister { +func setupLoadedLWWRegister(ctx context.Context) LWWRegister { lww := setupLWWRegister() addDelta := lww.Set([]byte("test")) addDelta.SetPriority(1) @@ -73,7 +73,7 @@ func TestLWWRegisterInitialMerge(t *testing.T) { func TestLWWReisterFollowupMerge(t *testing.T) { ctx := context.Background() - lww := setupLoadedLWWRegster(ctx) + lww := setupLoadedLWWRegister(ctx) addDelta := lww.Set([]byte("test2")) addDelta.SetPriority(2) lww.Merge(ctx, addDelta) @@ -90,7 +90,7 @@ func TestLWWReisterFollowupMerge(t *testing.T) { func TestLWWRegisterOldMerge(t *testing.T) { ctx := context.Background() - lww := setupLoadedLWWRegster(ctx) + lww := setupLoadedLWWRegister(ctx) addDelta := lww.Set([]byte("test-1")) addDelta.SetPriority(0) lww.Merge(ctx, addDelta) diff --git a/internal/core/key.go b/internal/core/key.go index d087c43af8..efc2d73017 100644 --- a/internal/core/key.go +++ b/internal/core/key.go @@ -71,7 +71,7 @@ type DataStoreKey struct { CollectionRootID uint32 InstanceType InstanceType DocID string - FieldId string + FieldID string } var _ Key = (*DataStoreKey)(nil) @@ -238,7 +238,7 @@ func NewDataStoreKey(key string) (DataStoreKey, error) { dataStoreKey.InstanceType = InstanceType(elements[1]) dataStoreKey.DocID = elements[2] if numberOfElements == 4 { - dataStoreKey.FieldId = elements[3] + dataStoreKey.FieldID = elements[3] } return dataStoreKey, nil @@ -429,21 +429,21 @@ func (k DataStoreKey) WithDocID(docID string) DataStoreKey { func (k DataStoreKey) WithInstanceInfo(key DataStoreKey) DataStoreKey { newKey := k 
newKey.DocID = key.DocID - newKey.FieldId = key.FieldId + newKey.FieldID = key.FieldID newKey.InstanceType = key.InstanceType return newKey } func (k DataStoreKey) WithFieldId(fieldId string) DataStoreKey { newKey := k - newKey.FieldId = fieldId + newKey.FieldID = fieldId return newKey } func (k DataStoreKey) ToHeadStoreKey() HeadStoreKey { return HeadStoreKey{ DocID: k.DocID, - FieldId: k.FieldId, + FieldId: k.FieldID, } } @@ -477,8 +477,8 @@ func (k DataStoreKey) ToString() string { if k.DocID != "" { result = result + "/" + k.DocID } - if k.FieldId != "" { - result = result + "/" + k.FieldId + if k.FieldID != "" { + result = result + "/" + k.FieldID } return result @@ -495,7 +495,7 @@ func (k DataStoreKey) ToDS() ds.Key { func (k DataStoreKey) Equal(other DataStoreKey) bool { return k.CollectionRootID == other.CollectionRootID && k.DocID == other.DocID && - k.FieldId == other.FieldId && + k.FieldID == other.FieldID && k.InstanceType == other.InstanceType } @@ -769,8 +769,8 @@ func (k HeadStoreKey) ToDS() ds.Key { func (k DataStoreKey) PrefixEnd() DataStoreKey { newKey := k - if k.FieldId != "" { - newKey.FieldId = string(bytesPrefixEnd([]byte(k.FieldId))) + if k.FieldID != "" { + newKey.FieldID = string(bytesPrefixEnd([]byte(k.FieldID))) return newKey } if k.DocID != "" { @@ -789,12 +789,12 @@ func (k DataStoreKey) PrefixEnd() DataStoreKey { return newKey } -// FieldID extracts the Field Identifier from the Key. -// In a Primary index, the last key path is the FieldID. +// FieldIDAsUint extracts the Field Identifier from the Key. +// In a Primary index, the last key path is the FieldIDAsUint. // This may be different in Secondary Indexes. // An error is returned if it can't correct convert the field to a uint32. 
-func (k DataStoreKey) FieldID() (uint32, error) { - fieldID, err := strconv.Atoi(k.FieldId) +func (k DataStoreKey) FieldIDAsUint() (uint32, error) { + fieldID, err := strconv.Atoi(k.FieldID) if err != nil { return 0, NewErrFailedToGetFieldIdOfKey(err) } @@ -814,3 +814,34 @@ func bytesPrefixEnd(b []byte) []byte { // maximal byte string (i.e. already \xff...). return b } + +// EncStoreDocKey is a key for the encryption store. +type EncStoreDocKey struct { + DocID string + FieldID uint32 +} + +var _ Key = (*EncStoreDocKey)(nil) + +// NewEncStoreDocKey creates a new EncStoreDocKey from a docID and fieldID. +func NewEncStoreDocKey(docID string, fieldID uint32) EncStoreDocKey { + return EncStoreDocKey{ + DocID: docID, + FieldID: fieldID, + } +} + +func (k EncStoreDocKey) ToString() string { + if k.FieldID == 0 { + return k.DocID + } + return fmt.Sprintf("%s/%d", k.DocID, k.FieldID) +} + +func (k EncStoreDocKey) Bytes() []byte { + return []byte(k.ToString()) +} + +func (k EncStoreDocKey) ToDS() ds.Key { + return ds.NewKey(k.ToString()) +} diff --git a/internal/core/key_test.go b/internal/core/key_test.go index c5e34073a3..90bd122d6f 100644 --- a/internal/core/key_test.go +++ b/internal/core/key_test.go @@ -54,7 +54,7 @@ func TestNewDataStoreKey_ReturnsCollectionIdAndIndexIdAndDocIDAndFieldIdAndInsta DataStoreKey{ CollectionRootID: collectionRootID, DocID: docID, - FieldId: fieldID, + FieldID: fieldID, InstanceType: InstanceType(instanceType)}, result) assert.Equal(t, fmt.Sprintf("/%v/%s/%s/%s", collectionRootID, instanceType, docID, fieldID), resultString) diff --git a/internal/db/collection.go b/internal/db/collection.go index 7e20f0da8f..64f90960cc 100644 --- a/internal/db/collection.go +++ b/internal/db/collection.go @@ -657,7 +657,7 @@ func (c *collection) save( return cid.Undef, err } - link, _, err := merkleCRDT.Save(ctx, val) + link, _, err := merkleCRDT.Save(ctx, &merklecrdt.DocField{DocID: primaryKey.DocID, FieldValue: val}) if err != nil { return cid.Undef, err 
} @@ -905,7 +905,7 @@ func (c *collection) getDataStoreKeyFromDocID(docID client.DocID) core.DataStore } func (c *collection) tryGetFieldKey(primaryKey core.PrimaryDataStoreKey, fieldName string) (core.DataStoreKey, bool) { - fieldId, hasField := c.tryGetFieldID(fieldName) + fieldID, hasField := c.tryGetFieldID(fieldName) if !hasField { return core.DataStoreKey{}, false } @@ -913,7 +913,7 @@ func (c *collection) tryGetFieldKey(primaryKey core.PrimaryDataStoreKey, fieldNa return core.DataStoreKey{ CollectionRootID: c.Description().RootID, DocID: primaryKey.DocID, - FieldId: strconv.FormatUint(uint64(fieldId), 10), + FieldID: strconv.FormatUint(uint64(fieldID), 10), }, true } diff --git a/internal/db/context.go b/internal/db/context.go index 88019af323..8ad51c86ce 100644 --- a/internal/db/context.go +++ b/internal/db/context.go @@ -17,6 +17,7 @@ import ( acpIdentity "github.com/sourcenetwork/defradb/acp/identity" "github.com/sourcenetwork/defradb/datastore" + "github.com/sourcenetwork/defradb/internal/encryption" ) // txnContextKey is the key type for transaction context values. @@ -62,6 +63,7 @@ func ensureContextTxn(ctx context.Context, db transactionDB, readOnly bool) (con if err != nil { return nil, txn, err } + ctx = encryption.ContextWithStore(ctx, txn) return SetContextTxn(ctx, txn), txn, nil } @@ -87,9 +89,6 @@ func SetContextTxn(ctx context.Context, txn datastore.Txn) context.Context { return context.WithValue(ctx, txnContextKey{}, txn) } -// TryGetContextTxn returns an identity and a bool indicating if the -// identity was retrieved from the given context. - // GetContextIdentity returns the identity from the given context. // // If an identity does not exist `NoIdentity` is returned. 
diff --git a/internal/db/fetcher/fetcher.go b/internal/db/fetcher/fetcher.go index bfaed9d871..06e3255e8c 100644 --- a/internal/db/fetcher/fetcher.go +++ b/internal/db/fetcher/fetcher.go @@ -351,7 +351,7 @@ func (df *DocumentFetcher) nextKey(ctx context.Context, seekNext bool) (spanDone if seekNext { curKey := df.kv.Key - curKey.FieldId = "" // clear field so prefixEnd applies to docID + curKey.FieldID = "" // clear field so prefixEnd applies to docID seekKey := curKey.PrefixEnd().ToString() spanDone, df.kv, err = df.seekKV(seekKey) // handle any internal errors @@ -504,7 +504,7 @@ func (df *DocumentFetcher) processKV(kv *keyValue) error { } } - if kv.Key.FieldId == core.DATASTORE_DOC_VERSION_FIELD_ID { + if kv.Key.FieldID == core.DATASTORE_DOC_VERSION_FIELD_ID { df.doc.schemaVersionID = string(kv.Value) return nil } @@ -515,7 +515,7 @@ func (df *DocumentFetcher) processKV(kv *keyValue) error { } // extract the FieldID and update the encoded doc properties map - fieldID, err := kv.Key.FieldID() + fieldID, err := kv.Key.FieldIDAsUint() if err != nil { return err } diff --git a/internal/db/fetcher/versioned.go b/internal/db/fetcher/versioned.go index 892b84e329..0ff58c4eeb 100644 --- a/internal/db/fetcher/versioned.go +++ b/internal/db/fetcher/versioned.go @@ -415,7 +415,7 @@ func (vf *VersionedFetcher) processBlock( vf.mCRDTs[crdtIndex] = mcrdt } - err = mcrdt.Clock().ProcessBlock(vf.ctx, block, blockLink) + err = mcrdt.Clock().ProcessBlock(vf.ctx, block, blockLink, false) return err } diff --git a/internal/db/merge.go b/internal/db/merge.go index bbfedd98d8..e588cb60a4 100644 --- a/internal/db/merge.go +++ b/internal/db/merge.go @@ -227,11 +227,15 @@ func (mp *mergeProcessor) loadComposites( func (mp *mergeProcessor) mergeComposites(ctx context.Context) error { for e := mp.composites.Front(); e != nil; e = e.Next() { block := e.Value.(*coreblock.Block) + var onlyHeads bool + if block.IsEncrypted != nil && *block.IsEncrypted { + onlyHeads = true + } link, err := 
block.GenerateLink() if err != nil { return err } - err = mp.processBlock(ctx, block, link) + err = mp.processBlock(ctx, block, link, onlyHeads) if err != nil { return err } @@ -240,10 +244,12 @@ func (mp *mergeProcessor) mergeComposites(ctx context.Context) error { } // processBlock merges the block and its children to the datastore and sets the head accordingly. +// If onlyHeads is true, it will skip merging and update only the heads. func (mp *mergeProcessor) processBlock( ctx context.Context, block *coreblock.Block, blockLink cidlink.Link, + onlyHeads bool, ) error { crdt, err := mp.initCRDTForType(block.Delta.GetFieldName()) if err != nil { @@ -256,7 +262,7 @@ func (mp *mergeProcessor) processBlock( return nil } - err = crdt.Clock().ProcessBlock(ctx, block, blockLink) + err = crdt.Clock().ProcessBlock(ctx, block, blockLink, onlyHeads) if err != nil { return err } @@ -276,7 +282,7 @@ func (mp *mergeProcessor) processBlock( return err } - if err := mp.processBlock(ctx, childBlock, link.Link); err != nil { + if err := mp.processBlock(ctx, childBlock, link.Link, onlyHeads); err != nil { return err } } diff --git a/internal/encryption/aes.go b/internal/encryption/aes.go new file mode 100644 index 0000000000..e3a7feb563 --- /dev/null +++ b/internal/encryption/aes.go @@ -0,0 +1,79 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package encryption + +import ( + "crypto/aes" + "crypto/cipher" + "encoding/base64" + "fmt" +) + +// EncryptAES encrypts data using AES-GCM with a provided key. 
+func EncryptAES(plainText, key []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + nonce, err := generateNonceFunc() + if err != nil { + return nil, err + } + + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + cipherText := aesGCM.Seal(nonce, nonce, plainText, nil) + + buf := make([]byte, base64.StdEncoding.EncodedLen(len(cipherText))) + base64.StdEncoding.Encode(buf, cipherText) + + return buf, nil +} + +// DecryptAES decrypts AES-GCM encrypted data with a provided key. +func DecryptAES(cipherTextBase64, key []byte) ([]byte, error) { + cipherText := make([]byte, base64.StdEncoding.DecodedLen(len(cipherTextBase64))) + n, err := base64.StdEncoding.Decode(cipherText, []byte(cipherTextBase64)) + + if err != nil { + return nil, err + } + + cipherText = cipherText[:n] + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + if len(cipherText) < nonceLength { + return nil, fmt.Errorf("cipherText too short") + } + + nonce := cipherText[:nonceLength] + cipherText = cipherText[nonceLength:] + + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + plainText, err := aesGCM.Open(nil, nonce, cipherText, nil) + if err != nil { + return nil, err + } + + return plainText, nil +} diff --git a/internal/encryption/config.go b/internal/encryption/config.go new file mode 100644 index 0000000000..ddb4a3815a --- /dev/null +++ b/internal/encryption/config.go @@ -0,0 +1,16 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package encryption + +// DocEncConfig is the configuration for document encryption. 
+type DocEncConfig struct { + IsEncrypted bool +} diff --git a/internal/encryption/context.go b/internal/encryption/context.go new file mode 100644 index 0000000000..10a03c89c1 --- /dev/null +++ b/internal/encryption/context.go @@ -0,0 +1,73 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package encryption + +import ( + "context" + + "github.com/sourcenetwork/immutable" + + "github.com/sourcenetwork/defradb/datastore" +) + +// docEncContextKey is the key type for document encryption context values. +type docEncContextKey struct{} + +// configContextKey is the key type for encryption context values. +type configContextKey struct{} + +// TryGetContextDocEnc returns a document encryption and a bool indicating if +// it was retrieved from the given context. +func TryGetContextEncryptor(ctx context.Context) (*DocEncryptor, bool) { + enc, ok := ctx.Value(docEncContextKey{}).(*DocEncryptor) + if ok { + checkKeyGenerationFlag(ctx, enc) + } + return enc, ok +} + +func checkKeyGenerationFlag(ctx context.Context, enc *DocEncryptor) { + encConfig := GetContextConfig(ctx) + if encConfig.HasValue() && encConfig.Value().IsEncrypted { + enc.EnableKeyGeneration() + } +} + +func ensureContextWithDocEnc(ctx context.Context) (context.Context, *DocEncryptor) { + enc, ok := TryGetContextEncryptor(ctx) + if !ok { + enc = newDocEncryptor(ctx) + ctx = context.WithValue(ctx, docEncContextKey{}, enc) + } + return ctx, enc +} + +// ContextWithStore sets the store on the doc encryptor in the context. +// If the doc encryptor is not present, it will be created. 
+func ContextWithStore(ctx context.Context, txn datastore.Txn) context.Context { + ctx, encryptor := ensureContextWithDocEnc(ctx) + encryptor.SetStore(txn.Encstore()) + return ctx +} + +// GetContextConfig returns the doc encryption config from the given context. +func GetContextConfig(ctx context.Context) immutable.Option[DocEncConfig] { + encConfig, ok := ctx.Value(configContextKey{}).(DocEncConfig) + if ok { + return immutable.Some(encConfig) + } + return immutable.None[DocEncConfig]() +} + +// SetContextConfig returns a new context with the doc encryption config set. +func SetContextConfig(ctx context.Context, encConfig DocEncConfig) context.Context { + return context.WithValue(ctx, configContextKey{}, encConfig) +} diff --git a/internal/encryption/encryptor.go b/internal/encryption/encryptor.go new file mode 100644 index 0000000000..596e9f9903 --- /dev/null +++ b/internal/encryption/encryptor.go @@ -0,0 +1,127 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package encryption + +import ( + "context" + "crypto/rand" + "errors" + "io" + + ds "github.com/ipfs/go-datastore" + + "github.com/sourcenetwork/defradb/datastore" + "github.com/sourcenetwork/defradb/internal/core" +) + +var generateEncryptionKeyFunc = generateEncryptionKey + +const keyLength = 32 // 32 bytes for AES-256 + +const testEncryptionKey = "examplekey1234567890examplekey12" + +// generateEncryptionKey generates a random AES key. 
+func generateEncryptionKey() ([]byte, error) { + key := make([]byte, keyLength) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + return nil, err + } + return key, nil +} + +// generateTestEncryptionKey generates a deterministic encryption key for testing. +func generateTestEncryptionKey() ([]byte, error) { + return []byte(testEncryptionKey), nil +} + +type DocEncryptor struct { + shouldGenerateKey bool + ctx context.Context + store datastore.DSReaderWriter +} + +func newDocEncryptor(ctx context.Context) *DocEncryptor { + return &DocEncryptor{ctx: ctx} +} + +func (d *DocEncryptor) EnableKeyGeneration() { + d.shouldGenerateKey = true +} + +func (d *DocEncryptor) SetStore(store datastore.DSReaderWriter) { + d.store = store +} + +func (d *DocEncryptor) Encrypt(docID string, fieldID uint32, plainText []byte) ([]byte, error) { + encryptionKey, storeKey, err := d.fetchEncryptionKey(docID, fieldID) + if err != nil { + return nil, err + } + + if len(encryptionKey) == 0 { + if !d.shouldGenerateKey { + return plainText, nil + } + + encryptionKey, err = generateEncryptionKeyFunc() + if err != nil { + return nil, err + } + + err = d.store.Put(d.ctx, storeKey.ToDS(), encryptionKey) + if err != nil { + return nil, err + } + } + return EncryptAES(plainText, encryptionKey) +} + +func (d *DocEncryptor) Decrypt(docID string, fieldID uint32, cipherText []byte) ([]byte, error) { + encKey, _, err := d.fetchEncryptionKey(docID, fieldID) + if err != nil { + return nil, err + } + if len(encKey) == 0 { + return nil, nil + } + return DecryptAES(cipherText, encKey) +} + +// fetchEncryptionKey fetches the encryption key for the given docID and fieldID. +// If the key is not found, it returns an empty key. 
+func (d *DocEncryptor) fetchEncryptionKey(docID string, fieldID uint32) ([]byte, core.EncStoreDocKey, error) { + storeKey := core.NewEncStoreDocKey(docID, fieldID) + if d.store == nil { + return nil, core.EncStoreDocKey{}, ErrNoStorageProvided + } + encryptionKey, err := d.store.Get(d.ctx, storeKey.ToDS()) + isNotFound := errors.Is(err, ds.ErrNotFound) + if err != nil && !isNotFound { + return nil, core.EncStoreDocKey{}, err + } + return encryptionKey, storeKey, nil +} + +func EncryptDoc(ctx context.Context, docID string, fieldID uint32, plainText []byte) ([]byte, error) { + enc, ok := TryGetContextEncryptor(ctx) + if !ok { + return nil, nil + } + return enc.Encrypt(docID, fieldID, plainText) +} + +func DecryptDoc(ctx context.Context, docID string, fieldID uint32, cipherText []byte) ([]byte, error) { + enc, ok := TryGetContextEncryptor(ctx) + if !ok { + return nil, nil + } + return enc.Decrypt(docID, fieldID, cipherText) +} diff --git a/internal/encryption/encryptor_test.go b/internal/encryption/encryptor_test.go new file mode 100644 index 0000000000..10abd1f062 --- /dev/null +++ b/internal/encryption/encryptor_test.go @@ -0,0 +1,175 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package encryption + +import ( + "context" + "errors" + "testing" + + ds "github.com/ipfs/go-datastore" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/sourcenetwork/defradb/datastore/mocks" + "github.com/sourcenetwork/defradb/internal/core" +) + +var testErr = errors.New("test error") + +var docID = "bae-c9fb0fa4-1195-589c-aa54-e68333fb90b3" + +func getPlainText() []byte { + return []byte("test") +} + +func getCipherText(t *testing.T) []byte { + cipherText, err := EncryptAES(getPlainText(), []byte(testEncryptionKey)) + assert.NoError(t, err) + return cipherText +} + +func TestEncryptorEncrypt_IfStorageReturnsError_Error(t *testing.T) { + enc := newDocEncryptor(context.Background()) + st := mocks.NewDSReaderWriter(t) + enc.SetStore(st) + + st.EXPECT().Get(mock.Anything, mock.Anything).Return(nil, testErr) + + _, err := enc.Encrypt(docID, 0, []byte("test")) + + assert.ErrorIs(t, err, testErr) +} + +func TestEncryptorEncrypt_IfNoKeyFoundInStorage_ShouldGenerateKeyStoreItAndReturnCipherText(t *testing.T) { + enc := newDocEncryptor(context.Background()) + st := mocks.NewDSReaderWriter(t) + enc.EnableKeyGeneration() + enc.SetStore(st) + + st.EXPECT().Get(mock.Anything, mock.Anything).Return(nil, ds.ErrNotFound) + + storeKey := core.NewEncStoreDocKey(docID, 0) + + st.EXPECT().Put(mock.Anything, storeKey.ToDS(), []byte(testEncryptionKey)).Return(nil) + + cipherText, err := enc.Encrypt(docID, 0, getPlainText()) + + assert.NoError(t, err) + assert.Equal(t, getCipherText(t), cipherText) +} + +func TestEncryptorEncrypt_IfKeyFoundInStorage_ShouldUseItToReturnCipherText(t *testing.T) { + enc := newDocEncryptor(context.Background()) + st := mocks.NewDSReaderWriter(t) + enc.EnableKeyGeneration() + enc.SetStore(st) + + st.EXPECT().Get(mock.Anything, mock.Anything).Return([]byte(testEncryptionKey), nil) + + cipherText, err := enc.Encrypt(docID, 0, getPlainText()) + + assert.NoError(t, err) + assert.Equal(t, getCipherText(t), 
cipherText) +} + +func TestEncryptorEncrypt_IfStorageFailsToStoreEncryptionKey_ReturnError(t *testing.T) { + enc := newDocEncryptor(context.Background()) + st := mocks.NewDSReaderWriter(t) + enc.EnableKeyGeneration() + enc.SetStore(st) + + st.EXPECT().Get(mock.Anything, mock.Anything).Return(nil, ds.ErrNotFound) + + st.EXPECT().Put(mock.Anything, mock.Anything, mock.Anything).Return(testErr) + + _, err := enc.Encrypt(docID, 0, getPlainText()) + + assert.ErrorIs(t, err, testErr) +} + +func TestEncryptorEncrypt_IfKeyGenerationIsNotEnabled_ShouldReturnPlainText(t *testing.T) { + enc := newDocEncryptor(context.Background()) + st := mocks.NewDSReaderWriter(t) + // we don't call enc.EnableKeyGeneration() + enc.SetStore(st) + + st.EXPECT().Get(mock.Anything, mock.Anything).Return(nil, ds.ErrNotFound) + + cipherText, err := enc.Encrypt(docID, 0, getPlainText()) + + assert.NoError(t, err) + assert.Equal(t, getPlainText(), cipherText) +} + +func TestEncryptorEncrypt_IfNoStorageProvided_Error(t *testing.T) { + enc := newDocEncryptor(context.Background()) + enc.EnableKeyGeneration() + // we don't call enc.SetStore(st) + + _, err := enc.Encrypt(docID, 0, getPlainText()) + + assert.ErrorIs(t, err, ErrNoStorageProvided) +} + +func TestEncryptorDecrypt_IfNoStorageProvided_Error(t *testing.T) { + enc := newDocEncryptor(context.Background()) + enc.EnableKeyGeneration() + // we don't call enc.SetStore(st) + + _, err := enc.Decrypt(docID, 0, getPlainText()) + + assert.ErrorIs(t, err, ErrNoStorageProvided) +} + +func TestEncryptorDecrypt_IfStorageReturnsError_Error(t *testing.T) { + enc := newDocEncryptor(context.Background()) + st := mocks.NewDSReaderWriter(t) + enc.SetStore(st) + + st.EXPECT().Get(mock.Anything, mock.Anything).Return(nil, testErr) + + _, err := enc.Decrypt(docID, 0, []byte("test")) + + assert.ErrorIs(t, err, testErr) +} + +func TestEncryptorDecrypt_IfKeyFoundInStorage_ShouldUseItToReturnPlainText(t *testing.T) { + enc := newDocEncryptor(context.Background()) + st :=
mocks.NewDSReaderWriter(t) + enc.EnableKeyGeneration() + enc.SetStore(st) + + st.EXPECT().Get(mock.Anything, mock.Anything).Return([]byte(testEncryptionKey), nil) + + plainText, err := enc.Decrypt(docID, 0, getCipherText(t)) + + assert.NoError(t, err) + assert.Equal(t, getPlainText(), plainText) +} + +func TestEncryptorDecrypt_IfNoKeyFoundInStorage_ShouldGenerateKeyStoreItAndReturnCipherText(t *testing.T) { + enc := newDocEncryptor(context.Background()) + st := mocks.NewDSReaderWriter(t) + enc.EnableKeyGeneration() + enc.SetStore(st) + + st.EXPECT().Get(mock.Anything, mock.Anything).Return(nil, ds.ErrNotFound) + + storeKey := core.NewEncStoreDocKey(docID, 0) + + st.EXPECT().Put(mock.Anything, storeKey.ToDS(), []byte(testEncryptionKey)).Return(nil) + + cipherText, err := enc.Encrypt(docID, 0, getPlainText()) + + assert.NoError(t, err) + assert.Equal(t, getCipherText(t), cipherText) +} diff --git a/internal/encryption/errors.go b/internal/encryption/errors.go new file mode 100644 index 0000000000..6a443ad834 --- /dev/null +++ b/internal/encryption/errors.go @@ -0,0 +1,23 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package encryption + +import ( + "github.com/sourcenetwork/defradb/errors" +) + +const ( + errNoStorageProvided string = "no storage provided" +) + +var ( + ErrNoStorageProvided = errors.New(errNoStorageProvided) +) diff --git a/internal/encryption/nonce.go b/internal/encryption/nonce.go new file mode 100644 index 0000000000..67a5467a4e --- /dev/null +++ b/internal/encryption/nonce.go @@ -0,0 +1,53 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package encryption + +import ( + "crypto/rand" + "errors" + "io" + "os" + "strings" +) + +const nonceLength = 12 + +var generateNonceFunc = generateNonce + +func generateNonce() ([]byte, error) { + nonce := make([]byte, nonceLength) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + + return nonce, nil +} + +// generateTestNonce generates a deterministic nonce for testing. +func generateTestNonce() ([]byte, error) { + nonce := []byte("deterministic nonce for testing") + + if len(nonce) < nonceLength { + return nil, errors.New("nonce length is longer than available deterministic nonce") + } + + return nonce[:nonceLength], nil +} + +func init() { + arg := os.Args[0] + // If the binary is a test binary, use a deterministic nonce. 
+ // TODO: We should try to find a better way to detect this https://github.com/sourcenetwork/defradb/issues/2801 + if strings.HasSuffix(arg, ".test") || strings.Contains(arg, "/defradb/tests/") { + generateNonceFunc = generateTestNonce + generateEncryptionKeyFunc = generateTestEncryptionKey + } +} diff --git a/internal/merkle/clock/clock.go b/internal/merkle/clock/clock.go index d16d1d6a5b..1cb79ed756 100644 --- a/internal/merkle/clock/clock.go +++ b/internal/merkle/clock/clock.go @@ -16,6 +16,7 @@ package clock import ( "context" + cid "github.com/ipfs/go-cid" "github.com/ipld/go-ipld-prime/linking" cidlink "github.com/ipld/go-ipld-prime/linking/cid" @@ -24,6 +25,7 @@ import ( "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/internal/core" coreblock "github.com/sourcenetwork/defradb/internal/core/block" + "github.com/sourcenetwork/defradb/internal/encryption" ) var ( @@ -34,9 +36,8 @@ var ( type MerkleClock struct { headstore datastore.DSReaderWriter blockstore datastore.Blockstore - // dagSyncer - headset *heads - crdt core.ReplicatedData + headset *heads + crdt core.ReplicatedData } // NewMerkleClock returns a new MerkleClock. @@ -86,7 +87,24 @@ func (mc *MerkleClock) AddDelta( block := coreblock.New(delta, links, heads...) // Write the new block to the dag store. 
- link, err := mc.putBlock(ctx, block) + isEncrypted, err := mc.checkIfBlockEncryptionEnabled(ctx, heads) + if err != nil { + return cidlink.Link{}, nil, err + } + + dagBlock := block + if isEncrypted { + if !block.Delta.IsComposite() { + dagBlock, err = encryptBlock(ctx, block) + if err != nil { + return cidlink.Link{}, nil, err + } + } else { + dagBlock.IsEncrypted = &isEncrypted + } + } + + link, err := mc.putBlock(ctx, dagBlock) if err != nil { return cidlink.Link{}, nil, err } @@ -96,12 +114,13 @@ func (mc *MerkleClock) AddDelta( ctx, block, link, + false, ) if err != nil { return cidlink.Link{}, nil, err } - b, err := block.Marshal() + b, err := dagBlock.Marshal() if err != nil { return cidlink.Link{}, nil, err } @@ -109,19 +128,68 @@ func (mc *MerkleClock) AddDelta( return link, b, err } +func (mc *MerkleClock) checkIfBlockEncryptionEnabled( + ctx context.Context, + heads []cid.Cid, +) (bool, error) { + encConf := encryption.GetContextConfig(ctx) + if encConf.HasValue() && encConf.Value().IsEncrypted { + return true, nil + } + + for _, headCid := range heads { + bytes, err := mc.blockstore.AsIPLDStorage().Get(ctx, headCid.KeyString()) + if err != nil { + return false, NewErrCouldNotFindBlock(headCid, err) + } + prevBlock, err := coreblock.GetFromBytes(bytes) + if err != nil { + return false, err + } + if prevBlock.IsEncrypted != nil && *prevBlock.IsEncrypted { + return true, nil + } + } + + return false, nil +} + +func encryptBlock(ctx context.Context, block *coreblock.Block) (*coreblock.Block, error) { + clonedCRDT := block.Delta.Clone() + bytes, err := encryption.EncryptDoc(ctx, string(clonedCRDT.GetDocID()), 0, clonedCRDT.GetData()) + if err != nil { + return nil, err + } + clonedCRDT.SetData(bytes) + isEncrypted := true + return &coreblock.Block{Delta: clonedCRDT, Links: block.Links, IsEncrypted: &isEncrypted}, nil +} + // ProcessBlock merges the delta CRDT and updates the state accordingly. 
+// If onlyHeads is true, it will skip merging and update only the heads. func (mc *MerkleClock) ProcessBlock( ctx context.Context, block *coreblock.Block, blockLink cidlink.Link, + onlyHeads bool, ) error { - priority := block.Delta.GetPriority() - - err := mc.crdt.Merge(ctx, block.Delta.GetDelta()) - if err != nil { - return NewErrMergingDelta(blockLink.Cid, err) + if !onlyHeads { + err := mc.crdt.Merge(ctx, block.Delta.GetDelta()) + if err != nil { + return NewErrMergingDelta(blockLink.Cid, err) + } } + return mc.updateHeads(ctx, block, blockLink) +} + +func (mc *MerkleClock) updateHeads( + ctx context.Context, + block *coreblock.Block, + blockLink cidlink.Link, +) error { + priority := block.Delta.GetPriority() + // check if we have any HEAD links hasHeads := false for _, l := range block.Links { diff --git a/internal/merkle/crdt/counter.go b/internal/merkle/crdt/counter.go index 2553dcfd2f..1ff6874b08 100644 --- a/internal/merkle/crdt/counter.go +++ b/internal/merkle/crdt/counter.go @@ -49,11 +49,11 @@ func NewMerkleCounter( // Save the value of the Counter to the DAG. func (mc *MerkleCounter) Save(ctx context.Context, data any) (cidlink.Link, []byte, error) { - value, ok := data.(*client.FieldValue) + value, ok := data.(*DocField) if !ok { return cidlink.Link{}, nil, NewErrUnexpectedValueType(mc.reg.CType(), &client.FieldValue{}, data) } - bytes, err := value.Bytes() + bytes, err := value.FieldValue.Bytes() if err != nil { return cidlink.Link{}, nil, err } diff --git a/internal/merkle/crdt/errors.go b/internal/merkle/crdt/errors.go index 9e828df5dc..58ee8b6bc4 100644 --- a/internal/merkle/crdt/errors.go +++ b/internal/merkle/crdt/errors.go @@ -1,4 +1,4 @@ -// Copyright 2023 Democratized Data Foundation +// Copyright 2024 Democratized Data Foundation // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt. 
diff --git a/internal/merkle/crdt/field.go b/internal/merkle/crdt/field.go new file mode 100644 index 0000000000..6426165f49 --- /dev/null +++ b/internal/merkle/crdt/field.go @@ -0,0 +1,22 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package merklecrdt + +import "github.com/sourcenetwork/defradb/client" + +// DocField is a struct that holds the document ID and the field value. +// This is used to have a link between the document and the field value. +// For example, to check if the field value needs to be encrypted depending on the document-level +// encryption is enabled or not. +type DocField struct { + DocID string + FieldValue *client.FieldValue +} diff --git a/internal/merkle/crdt/lwwreg.go b/internal/merkle/crdt/lwwreg.go index b8132ccad5..11e73089bf 100644 --- a/internal/merkle/crdt/lwwreg.go +++ b/internal/merkle/crdt/lwwreg.go @@ -47,14 +47,15 @@ func NewMerkleLWWRegister( // Save the value of the register to the DAG. 
func (mlwwreg *MerkleLWWRegister) Save(ctx context.Context, data any) (cidlink.Link, []byte, error) { - value, ok := data.(*client.FieldValue) + value, ok := data.(*DocField) if !ok { return cidlink.Link{}, nil, NewErrUnexpectedValueType(client.LWW_REGISTER, &client.FieldValue{}, data) } - bytes, err := value.Bytes() + bytes, err := value.FieldValue.Bytes() if err != nil { return cidlink.Link{}, nil, err } + // Set() call on underlying LWWRegister CRDT // persist/publish delta delta := mlwwreg.reg.Set(bytes) diff --git a/internal/merkle/crdt/merklecrdt.go b/internal/merkle/crdt/merklecrdt.go index fc3019b05c..c7733be778 100644 --- a/internal/merkle/crdt/merklecrdt.go +++ b/internal/merkle/crdt/merklecrdt.go @@ -47,7 +47,10 @@ type MerkleClock interface { delta core.Delta, links ...coreblock.DAGLink, ) (cidlink.Link, []byte, error) - ProcessBlock(context.Context, *coreblock.Block, cidlink.Link) error + // ProcessBlock processes a block and updates the CRDT state. + // The bool argument indicates whether only heads need to be updated. It is needed in case + // merge should be skipped for example if the block is encrypted. 
+ ProcessBlock(context.Context, *coreblock.Block, cidlink.Link, bool) error } // baseMerkleCRDT handles the MerkleCRDT overhead functions that aren't CRDT specific like the mutations and state diff --git a/internal/planner/commit.go b/internal/planner/commit.go index 3a5bec39f9..bbb5fdc09c 100644 --- a/internal/planner/commit.go +++ b/internal/planner/commit.go @@ -112,7 +112,7 @@ func (n *dagScanNode) Spans(spans core.Spans) { } for i, span := range headSetSpans.Value { - if span.Start().FieldId != fieldId { + if span.Start().FieldID != fieldId { headSetSpans.Value[i] = core.NewSpan(span.Start().WithFieldId(fieldId), core.DataStoreKey{}) } } diff --git a/internal/planner/create.go b/internal/planner/create.go index 21a36fcc24..b03f2c1765 100644 --- a/internal/planner/create.go +++ b/internal/planner/create.go @@ -15,6 +15,7 @@ import ( "github.com/sourcenetwork/defradb/client/request" "github.com/sourcenetwork/defradb/internal/core" "github.com/sourcenetwork/defradb/internal/db/base" + "github.com/sourcenetwork/defradb/internal/encryption" "github.com/sourcenetwork/defradb/internal/planner/mapper" ) @@ -36,13 +37,12 @@ type createNode struct { collection client.Collection // input map of fields and values - input map[string]any - doc *client.Document + input []map[string]any + docs []*client.Document - err error + didCreate bool - returned bool - results planNode + results planNode execInfo createExecInfo } @@ -56,76 +56,63 @@ func (n *createNode) Kind() string { return "createNode" } func (n *createNode) Init() error { return nil } -func (n *createNode) Start() error { - doc, err := client.NewDocFromMap(n.input, n.collection.Definition()) - if err != nil { - n.err = err - return err +func docIDsToSpans(ids []string, desc client.CollectionDescription) core.Spans { + spans := make([]core.Span, len(ids)) + for i, id := range ids { + docID := base.MakeDataStoreKeyWithCollectionAndDocID(desc, id) + spans[i] = core.NewSpan(docID, docID.PrefixEnd()) } - n.doc = doc - 
return nil + return core.NewSpans(spans...) } -// Next only returns once. -func (n *createNode) Next() (bool, error) { - n.execInfo.iterations++ - - if n.err != nil { - return false, n.err +func documentsToDocIDs(docs []*client.Document) []string { + docIDs := make([]string, len(docs)) + for i, doc := range docs { + docIDs[i] = doc.ID().String() } + return docIDs +} - if n.returned { - return false, nil - } +func (n *createNode) Start() error { + n.docs = make([]*client.Document, len(n.input)) - if err := n.collection.Create( - n.p.ctx, - n.doc, - ); err != nil { - return false, err + for i, input := range n.input { + doc, err := client.NewDocFromMap(input, n.collection.Definition()) + if err != nil { + return err + } + n.docs[i] = doc } - currentValue := n.documentMapping.NewDoc() + return nil +} - currentValue.SetID(n.doc.ID().String()) - for i, value := range n.doc.Values() { - if len(n.documentMapping.IndexesByName[i.Name()]) > 0 { - n.documentMapping.SetFirstOfName(¤tValue, i.Name(), value.Value()) - } else if aliasName := i.Name() + request.RelatedObjectID; len(n.documentMapping.IndexesByName[aliasName]) > 0 { - n.documentMapping.SetFirstOfName(¤tValue, aliasName, value.Value()) - } else { - return false, client.NewErrFieldNotExist(i.Name()) - } - } +func (n *createNode) Next() (bool, error) { + n.execInfo.iterations++ - n.returned = true - n.currentValue = currentValue + if !n.didCreate { + err := n.collection.CreateMany(n.p.ctx, n.docs) + if err != nil { + return false, err + } - desc := n.collection.Description() - docID := base.MakeDataStoreKeyWithCollectionAndDocID(desc, currentValue.GetID()) - n.results.Spans(core.NewSpans(core.NewSpan(docID, docID.PrefixEnd()))) + n.results.Spans(docIDsToSpans(documentsToDocIDs(n.docs), n.collection.Description())) - err := n.results.Init() - if err != nil { - return false, err - } + err = n.results.Init() + if err != nil { + return false, err + } - err = n.results.Start() - if err != nil { - return false, err + err = 
n.results.Start() + if err != nil { + return false, err + } + n.didCreate = true } - // get the next result based on our point lookup next, err := n.results.Next() - if err != nil { - return false, err - } - if !next { - return false, nil - } - n.currentValue = n.results.Value() - return true, nil + return next, err } func (n *createNode) Spans(spans core.Spans) { /* no-op */ } @@ -155,7 +142,7 @@ func (n *createNode) Explain(explainType request.ExplainType) (map[string]any, e } } -func (p *Planner) CreateDoc(parsed *mapper.Mutation) (planNode, error) { +func (p *Planner) CreateDocs(parsed *mapper.Mutation) (planNode, error) { results, err := p.Select(&parsed.Select) if err != nil { return nil, err @@ -164,10 +151,17 @@ func (p *Planner) CreateDoc(parsed *mapper.Mutation) (planNode, error) { // create a mutation createNode. create := &createNode{ p: p, - input: parsed.Input, + input: parsed.Inputs, results: results, docMapper: docMapper{parsed.DocumentMapping}, } + if parsed.Input != nil { + create.input = []map[string]any{parsed.Input} + } + + if parsed.Encrypt { + p.ctx = encryption.SetContextConfig(p.ctx, encryption.DocEncConfig{IsEncrypted: true}) + } // get collection col, err := p.db.GetCollectionByName(p.ctx, parsed.Name) diff --git a/internal/planner/explain.go b/internal/planner/explain.go index f6d3f57209..34c3b3b644 100644 --- a/internal/planner/explain.go +++ b/internal/planner/explain.go @@ -342,7 +342,7 @@ func collectExecuteExplainInfo(executedPlan planNode) (map[string]any, error) { // Note: This function only fails if the collection of the datapoints goes wrong, otherwise // even if plan execution fails this function would return the collected datapoints. 
func (p *Planner) executeAndExplainRequest( - ctx context.Context, + _ context.Context, plan planNode, ) ([]map[string]any, error) { executionSuccess := false diff --git a/internal/planner/mapper/mapper.go b/internal/planner/mapper/mapper.go index be52066b54..858ef0e0ae 100644 --- a/internal/planner/mapper/mapper.go +++ b/internal/planner/mapper/mapper.go @@ -1165,9 +1165,11 @@ func ToMutation(ctx context.Context, store client.Store, mutationRequest *reques } return &Mutation{ - Select: *underlyingSelect, - Type: MutationType(mutationRequest.Type), - Input: mutationRequest.Input, + Select: *underlyingSelect, + Type: MutationType(mutationRequest.Type), + Input: mutationRequest.Input, + Inputs: mutationRequest.Inputs, + Encrypt: mutationRequest.Encrypt, }, nil } diff --git a/internal/planner/mapper/mutation.go b/internal/planner/mapper/mutation.go index a38444e01c..251d01298f 100644 --- a/internal/planner/mapper/mutation.go +++ b/internal/planner/mapper/mutation.go @@ -29,16 +29,10 @@ type Mutation struct { // Input is the map of fields and values used for the mutation. Input map[string]any -} -func (m *Mutation) CloneTo(index int) Requestable { - return m.cloneTo(index) -} + // Inputs is the array of maps of fields and values used for the mutation. + Inputs []map[string]any -func (m *Mutation) cloneTo(index int) *Mutation { - return &Mutation{ - Select: *m.Select.cloneTo(index), - Type: m.Type, - Input: m.Input, - } + // Encrypt is a flag to indicate if the input data should be encrypted. 
+ Encrypt bool } diff --git a/internal/planner/multi.go b/internal/planner/multi.go index 27d6886d7c..de220c43e5 100644 --- a/internal/planner/multi.go +++ b/internal/planner/multi.go @@ -131,7 +131,7 @@ func (p *parallelNode) Next() (bool, error) { return orNext, nil } -func (p *parallelNode) nextMerge(index int, plan planNode) (bool, error) { +func (p *parallelNode) nextMerge(_ int, plan planNode) (bool, error) { if next, err := plan.Next(); !next { return false, err } diff --git a/internal/planner/planner.go b/internal/planner/planner.go index f7a875af70..db7b0510ab 100644 --- a/internal/planner/planner.go +++ b/internal/planner/planner.go @@ -163,7 +163,7 @@ func (p *Planner) newPlan(stmt any) (planNode, error) { func (p *Planner) newObjectMutationPlan(stmt *mapper.Mutation) (planNode, error) { switch stmt.Type { case mapper.CreateObjects: - return p.CreateDoc(stmt) + return p.CreateDocs(stmt) case mapper.UpdateObjects: return p.UpdateDocs(stmt) @@ -528,7 +528,7 @@ func walkAndFindPlanType[T planNode](planNode planNode) (T, bool) { // executeRequest executes the plan graph that represents the request that was made. 
func (p *Planner) executeRequest( - ctx context.Context, + _ context.Context, planNode planNode, ) ([]map[string]any, error) { if err := planNode.Start(); err != nil { diff --git a/internal/request/graphql/parser/mutation.go b/internal/request/graphql/parser/mutation.go index 92071b6e93..2ed4ebf539 100644 --- a/internal/request/graphql/parser/mutation.go +++ b/internal/request/graphql/parser/mutation.go @@ -102,6 +102,18 @@ func parseMutation(schema gql.Schema, parent *gql.Object, field *ast.Field) (*re if prop == request.Input { // parse input raw := argument.Value.(*ast.ObjectValue) mut.Input = parseMutationInputObject(raw) + } else if prop == request.Inputs { + raw := argument.Value.(*ast.ListValue) + + mut.Inputs = make([]map[string]any, len(raw.Values)) + + for i, val := range raw.Values { + doc, ok := val.(*ast.ObjectValue) + if !ok { + return nil, client.NewErrUnexpectedType[*ast.ObjectValue]("doc array element", val) + } + mut.Inputs[i] = parseMutationInputObject(doc) + } } else if prop == request.FilterClause { // parse filter obj := argument.Value.(*ast.ObjectValue) filterType, ok := getArgumentType(fieldDef, request.FilterClause) @@ -128,6 +140,8 @@ func parseMutation(schema gql.Schema, parent *gql.Object, field *ast.Field) (*re ids[i] = id.Value } mut.DocIDs = immutable.Some(ids) + } else if prop == request.EncryptArgName { + mut.Encrypt = argument.Value.(*ast.BooleanValue).Value } } diff --git a/internal/request/graphql/schema/descriptions.go b/internal/request/graphql/schema/descriptions.go index 6f89932c0a..281002366a 100644 --- a/internal/request/graphql/schema/descriptions.go +++ b/internal/request/graphql/schema/descriptions.go @@ -155,5 +155,10 @@ Indicates as to whether or not this document has been deleted. ` versionFieldDescription string = ` Returns the head commit for this document. +` + + encryptArgDescription string = ` +Encrypt flag specified if the input document(s) needs to be encrypted. 
If set, DefraDB will generate a +symmetric key for encryption using AES-GCM. ` ) diff --git a/internal/request/graphql/schema/generate.go b/internal/request/graphql/schema/generate.go index 5fd6b5ecf6..82ff15d057 100644 --- a/internal/request/graphql/schema/generate.go +++ b/internal/request/graphql/schema/generate.go @@ -27,6 +27,12 @@ import ( // create a fully DefraDB complaint GraphQL schema using a "code-first" dynamic // approach +const ( + filterInputNameSuffix = "FilterArg" + mutationInputNameSuffix = "MutationInputArg" + mutationInputsNameSuffix = "MutationInputsArg" +) + // Generator creates all the necessary typed schema definitions from an AST Document // and adds them to the Schema via the SchemaManager type Generator struct { @@ -171,7 +177,7 @@ func (g *Generator) generate(ctx context.Context, collections []client.Collectio for name, aggregateTarget := range def.Args { expandedField := &gql.InputObjectFieldConfig{ Description: aggregateFilterArgDescription, - Type: g.manager.schema.TypeMap()[name+"FilterArg"], + Type: g.manager.schema.TypeMap()[name+filterInputNameSuffix], } aggregateTarget.Type.(*gql.InputObject).AddFieldConfig(request.FilterClause, expandedField) } @@ -308,7 +314,7 @@ func (g *Generator) createExpandedFieldAggregate( target := aggregateTarget.Name() var filterTypeName string if target == request.GroupFieldName { - filterTypeName = obj.Name() + "FilterArg" + filterTypeName = obj.Name() + filterInputNameSuffix } else { if targeted := obj.Fields()[target]; targeted != nil { if list, isList := targeted.Type.(*gql.List); isList && gql.IsLeafType(list.OfType) { @@ -319,10 +325,10 @@ func (g *Generator) createExpandedFieldAggregate( // underlying name like this if it is a nullable type. 
filterTypeName = fmt.Sprintf("NotNull%sFilterArg", notNull.OfType.Name()) } else { - filterTypeName = genTypeName(list.OfType, "FilterArg") + filterTypeName = genTypeName(list.OfType, filterInputNameSuffix) } } else { - filterTypeName = targeted.Type.Name() + "FilterArg" + filterTypeName = targeted.Type.Name() + filterInputNameSuffix } } else { return NewErrAggregateTargetNotFound(obj.Name(), target) @@ -353,7 +359,7 @@ func (g *Generator) createExpandedFieldSingle( Type: t, Args: gql.FieldConfigArgument{ "filter": schemaTypes.NewArgConfig( - g.manager.schema.TypeMap()[typeName+"FilterArg"], + g.manager.schema.TypeMap()[typeName+filterInputNameSuffix], singleFieldFilterArgDescription, ), }, @@ -375,7 +381,7 @@ func (g *Generator) createExpandedFieldList( request.DocIDArgName: schemaTypes.NewArgConfig(gql.String, docIDArgDescription), request.DocIDsArgName: schemaTypes.NewArgConfig(gql.NewList(gql.NewNonNull(gql.String)), docIDsArgDescription), "filter": schemaTypes.NewArgConfig( - g.manager.schema.TypeMap()[typeName+"FilterArg"], + g.manager.schema.TypeMap()[typeName+filterInputNameSuffix], listFieldFilterArgDescription, ), "groupBy": schemaTypes.NewArgConfig( @@ -540,7 +546,7 @@ func (g *Generator) buildMutationInputTypes(collections []client.CollectionDefin // will be reassigned before the thunk is run // TODO remove when Go 1.22 collection := c - mutationInputName := collection.Description.Name.Value() + "MutationInputArg" + mutationInputName := collection.Description.Name.Value() + mutationInputNameSuffix // check if mutation input type exists if _, ok := g.manager.schema.TypeMap()[mutationInputName]; ok { @@ -1027,13 +1033,14 @@ func (g *Generator) GenerateMutationInputForGQLType(obj *gql.Object) ([]*gql.Fie return nil, obj.Error() } - filterInputName := genTypeName(obj, "FilterArg") - mutationInputName := genTypeName(obj, "MutationInputArg") + filterInputName := genTypeName(obj, filterInputNameSuffix) + mutationInputName := genTypeName(obj, 
mutationInputNameSuffix) filterInput, ok := g.manager.schema.TypeMap()[filterInputName].(*gql.InputObject) if !ok { return nil, NewErrTypeNotFound(filterInputName) } + mutationInput, ok := g.manager.schema.TypeMap()[mutationInputName] if !ok { return nil, NewErrTypeNotFound(mutationInputName) @@ -1044,7 +1051,9 @@ func (g *Generator) GenerateMutationInputForGQLType(obj *gql.Object) ([]*gql.Fie Description: createDocumentDescription, Type: obj, Args: gql.FieldConfigArgument{ - "input": schemaTypes.NewArgConfig(mutationInput, "Create field values"), + "input": schemaTypes.NewArgConfig(mutationInput, "Create a "+obj.Name()+" document"), + "inputs": schemaTypes.NewArgConfig(gql.NewList(mutationInput), "Create "+obj.Name()+" documents"), + "encrypt": schemaTypes.NewArgConfig(gql.Boolean, encryptArgDescription), }, } @@ -1092,7 +1101,7 @@ func (g *Generator) genTypeFilterArgInput(obj *gql.Object) *gql.InputObject { var selfRefType *gql.InputObject inputCfg := gql.InputObjectConfig{ - Name: genTypeName(obj, "FilterArg"), + Name: genTypeName(obj, filterInputNameSuffix), } fieldThunk := (gql.InputObjectConfigFieldMapThunk)( func() (gql.InputObjectConfigFieldMap, error) { @@ -1136,7 +1145,7 @@ func (g *Generator) genTypeFilterArgInput(obj *gql.Object) *gql.InputObject { // We want the FilterArg for the object, not the list of objects. 
fieldType = l.OfType } - filterType, isFilterable := g.manager.schema.TypeMap()[genTypeName(fieldType, "FilterArg")] + filterType, isFilterable := g.manager.schema.TypeMap()[genTypeName(fieldType, filterInputNameSuffix)] if !isFilterable { filterType = &gql.InputObjectField{} } @@ -1169,7 +1178,7 @@ func (g *Generator) genLeafFilterArgInput(obj gql.Type) *gql.InputObject { } inputCfg := gql.InputObjectConfig{ - Name: fmt.Sprintf("%s%s", filterTypeName, "FilterArg"), + Name: fmt.Sprintf("%s%s", filterTypeName, filterInputNameSuffix), } var fieldThunk gql.InputObjectConfigFieldMapThunk = func() (gql.InputObjectConfigFieldMap, error) { diff --git a/tests/bench/query/planner/utils.go b/tests/bench/query/planner/utils.go index 5a842222f5..b2a6e3c0d6 100644 --- a/tests/bench/query/planner/utils.go +++ b/tests/bench/query/planner/utils.go @@ -134,6 +134,7 @@ type dummyTxn struct{} func (*dummyTxn) Rootstore() datastore.DSReaderWriter { return nil } func (*dummyTxn) Datastore() datastore.DSReaderWriter { return nil } +func (*dummyTxn) Encstore() datastore.DSReaderWriter { return nil } func (*dummyTxn) Headstore() datastore.DSReaderWriter { return nil } func (*dummyTxn) Peerstore() datastore.DSBatching { return nil } func (*dummyTxn) Blockstore() datastore.Blockstore { return nil } diff --git a/tests/clients/cli/wrapper_collection.go b/tests/clients/cli/wrapper_collection.go index 62458dae99..f26142c8e9 100644 --- a/tests/clients/cli/wrapper_collection.go +++ b/tests/clients/cli/wrapper_collection.go @@ -21,6 +21,7 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/errors" "github.com/sourcenetwork/defradb/http" + "github.com/sourcenetwork/defradb/internal/encryption" ) var _ client.Collection = (*Collection)(nil) @@ -65,6 +66,11 @@ func (c *Collection) Create( args := []string{"client", "collection", "create"} args = append(args, "--name", c.Description().Name.Value()) + encConf := encryption.GetContextConfig(ctx) + if 
encConf.HasValue() && encConf.Value().IsEncrypted { + args = append(args, "--encrypt") + } + document, err := doc.String() if err != nil { return err @@ -90,21 +96,22 @@ func (c *Collection) CreateMany( args := []string{"client", "collection", "create"} args = append(args, "--name", c.Description().Name.Value()) - docMapList := make([]map[string]any, len(docs)) + encConf := encryption.GetContextConfig(ctx) + if encConf.HasValue() && encConf.Value().IsEncrypted { + args = append(args, "--encrypt") + } + + docStrings := make([]string, len(docs)) for i, doc := range docs { - docMap, err := doc.ToMap() + docStr, err := doc.String() if err != nil { return err } - docMapList[i] = docMap + docStrings[i] = docStr } - documents, err := json.Marshal(docMapList) - if err != nil { - return err - } - args = append(args, string(documents)) + args = append(args, "["+strings.Join(docStrings, ",")+"]") - _, err = c.cmd.execute(ctx, args) + _, err := c.cmd.execute(ctx, args) if err != nil { return err } diff --git a/tests/clients/cli/wrapper_tx.go b/tests/clients/cli/wrapper_tx.go index 0330b8d47e..46aefd000d 100644 --- a/tests/clients/cli/wrapper_tx.go +++ b/tests/clients/cli/wrapper_tx.go @@ -75,6 +75,10 @@ func (w *Transaction) Datastore() datastore.DSReaderWriter { return w.tx.Datastore() } +func (w *Transaction) Encstore() datastore.DSReaderWriter { + return w.tx.Encstore() +} + func (w *Transaction) Headstore() datastore.DSReaderWriter { return w.tx.Headstore() } diff --git a/tests/clients/http/wrapper_tx.go b/tests/clients/http/wrapper_tx.go index 133d3bc1d3..e4b838a2e9 100644 --- a/tests/clients/http/wrapper_tx.go +++ b/tests/clients/http/wrapper_tx.go @@ -69,6 +69,10 @@ func (w *TxWrapper) Datastore() datastore.DSReaderWriter { return w.server.Datastore() } +func (w *TxWrapper) Encstore() datastore.DSReaderWriter { + return w.server.Encstore() +} + func (w *TxWrapper) Headstore() datastore.DSReaderWriter { return w.server.Headstore() } diff --git 
a/tests/integration/encryption/commit_test.go b/tests/integration/encryption/commit_test.go new file mode 100644 index 0000000000..6a94621b3a --- /dev/null +++ b/tests/integration/encryption/commit_test.go @@ -0,0 +1,356 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package encryption + +import ( + "testing" + + "github.com/sourcenetwork/defradb/internal/encryption" + testUtils "github.com/sourcenetwork/defradb/tests/integration" +) + +func encrypt(plaintext []byte) []byte { + val, _ := encryption.EncryptAES(plaintext, []byte("examplekey1234567890examplekey12")) + return val +} + +func TestDocEncryption_WithEncryptionOnLWWCRDT_ShouldStoreCommitsDeltaEncrypted(t *testing.T) { + const docID = "bae-c9fb0fa4-1195-589c-aa54-e68333fb90b3" + + test := testUtils.TestCase{ + Actions: []any{ + updateUserCollectionSchema(), + testUtils.CreateDoc{ + Doc: `{ + "name": "John", + "age": 21 + }`, + IsEncrypted: true, + }, + testUtils.Request{ + Request: ` + query { + commits { + cid + collectionID + delta + docID + fieldId + fieldName + height + links { + cid + name + } + } + } + `, + Results: []map[string]any{ + { + "cid": "bafyreih7ry7ef26xn3lm2rhxusf2rbgyvl535tltrt6ehpwtvdnhlmptiu", + "collectionID": int64(1), + "delta": encrypt(testUtils.CBORValue(21)), + "docID": docID, + "fieldId": "1", + "fieldName": "age", + "height": int64(1), + "links": []map[string]any{}, + }, + { + "cid": "bafyreifusejlwidaqswasct37eorazlfix6vyyn5af42pmjvktilzj5cty", + "collectionID": int64(1), + "delta": encrypt(testUtils.CBORValue("John")), + "docID": docID, + "fieldId": "2", + "fieldName": "name", + "height": int64(1), + "links": []map[string]any{}, + }, + 
{ + "cid": "bafyreicvxlfxeqghmc3gy56rp5rzfejnbng4nu77x5e3wjinfydl6wvycq", + "collectionID": int64(1), + "delta": nil, + "docID": docID, + "fieldId": "C", + "fieldName": nil, + "height": int64(1), + "links": []map[string]any{ + { + "cid": "bafyreifusejlwidaqswasct37eorazlfix6vyyn5af42pmjvktilzj5cty", + "name": "name", + }, + { + "cid": "bafyreih7ry7ef26xn3lm2rhxusf2rbgyvl535tltrt6ehpwtvdnhlmptiu", + "name": "age", + }, + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestDocEncryption_UponUpdateOnLWWCRDT_ShouldEncryptCommitDelta(t *testing.T) { + test := testUtils.TestCase{ + Actions: []any{ + updateUserCollectionSchema(), + testUtils.CreateDoc{ + Doc: `{ + "name": "John", + "age": 21 + }`, + IsEncrypted: true, + }, + testUtils.UpdateDoc{ + Doc: `{ + "age": 22 + }`, + }, + testUtils.Request{ + Request: ` + query { + commits(fieldId: "1") { + delta + } + } + `, + Results: []map[string]any{ + { + "delta": encrypt(testUtils.CBORValue(22)), + }, + { + "delta": encrypt(testUtils.CBORValue(21)), + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestDocEncryption_WithMultipleDocsUponUpdate_ShouldEncryptOnlyRelevantDocs(t *testing.T) { + const johnDocID = "bae-c9fb0fa4-1195-589c-aa54-e68333fb90b3" + const islamDocID = "bae-d55bd956-1cc4-5d26-aa71-b98807ad49d6" + + test := testUtils.TestCase{ + Actions: []any{ + updateUserCollectionSchema(), + testUtils.CreateDoc{ + Doc: `{ + "name": "John", + "age": 21 + }`, + IsEncrypted: true, + }, + testUtils.CreateDoc{ + Doc: `{ + "name": "Islam", + "age": 33 + }`, + }, + testUtils.UpdateDoc{ + DocID: 0, + Doc: `{ + "age": 22 + }`, + }, + testUtils.UpdateDoc{ + DocID: 1, + Doc: `{ + "age": 34 + }`, + }, + testUtils.Request{ + Request: ` + query { + commits(fieldId: "1") { + delta + docID + } + } + `, + Results: []map[string]any{ + { + "delta": encrypt(testUtils.CBORValue(22)), + "docID": johnDocID, + }, + { + "delta": encrypt(testUtils.CBORValue(21)), + "docID": johnDocID, + }, + { + 
"delta": testUtils.CBORValue(34), + "docID": islamDocID, + }, + { + "delta": testUtils.CBORValue(33), + "docID": islamDocID, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestDocEncryption_WithEncryptionOnCounterCRDT_ShouldStoreCommitsDeltaEncrypted(t *testing.T) { + const docID = "bae-d3cc98b4-38d5-5c50-85a3-d3045d44094e" + + test := testUtils.TestCase{ + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: ` + type Users { + points: Int @crdt(type: "pcounter") + } + `}, + testUtils.CreateDoc{ + Doc: `{ "points": 5 }`, + IsEncrypted: true, + }, + testUtils.Request{ + Request: ` + query { + commits { + cid + delta + docID + } + } + `, + Results: []map[string]any{ + { + "cid": "bafyreieb6owsoljj4vondkx35ngxmhliauwvphicz4edufcy7biexij7mu", + "delta": encrypt(testUtils.CBORValue(5)), + "docID": docID, + }, + { + "cid": "bafyreif2lejhvdja2rmo237lrwpj45usrm55h6gzr4ewl6gajq3cl4ppsi", + "delta": nil, + "docID": docID, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestDocEncryption_UponUpdateOnCounterCRDT_ShouldEncryptedCommitDelta(t *testing.T) { + test := testUtils.TestCase{ + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: ` + type Users { + points: Int @crdt(type: "pcounter") + } + `}, + testUtils.CreateDoc{ + Doc: `{ "points": 5 }`, + IsEncrypted: true, + }, + testUtils.UpdateDoc{ + Doc: `{ + "points": 3 + }`, + }, + testUtils.Request{ + Request: ` + query { + commits(fieldId: "1") { + delta + } + } + `, + Results: []map[string]any{ + { + "delta": encrypt(testUtils.CBORValue(3)), + }, + { + "delta": encrypt(testUtils.CBORValue(5)), + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestDocEncryption_UponEncryptionSeveralDocs_ShouldStoreAllCommitsDeltaEncrypted(t *testing.T) { + test := testUtils.TestCase{ + Actions: []any{ + updateUserCollectionSchema(), + testUtils.CreateDoc{ + Doc: `[{ + "name": "John", + "age": 21 + }, + { + "name": "Islam", + "age": 33 + }]`, + IsEncrypted: true, + 
}, + testUtils.Request{ + Request: ` + query { + commits { + cid + delta + docID + } + } + `, + Results: []map[string]any{ + { + "cid": "bafyreih7ry7ef26xn3lm2rhxusf2rbgyvl535tltrt6ehpwtvdnhlmptiu", + "delta": encrypt(testUtils.CBORValue(21)), + "docID": testUtils.NewDocIndex(0, 0), + }, + { + "cid": "bafyreifusejlwidaqswasct37eorazlfix6vyyn5af42pmjvktilzj5cty", + "delta": encrypt(testUtils.CBORValue("John")), + "docID": testUtils.NewDocIndex(0, 0), + }, + { + "cid": "bafyreicvxlfxeqghmc3gy56rp5rzfejnbng4nu77x5e3wjinfydl6wvycq", + "delta": nil, + "docID": testUtils.NewDocIndex(0, 0), + }, + { + "cid": "bafyreibe24bo67owxewoso3ekinera2bhusguij5qy2ahgyufaq3fbvaxa", + "delta": encrypt(testUtils.CBORValue(33)), + "docID": testUtils.NewDocIndex(0, 1), + }, + { + "cid": "bafyreie2fddpidgc62fhd2fjrsucq3spgh2mgvto2xwolcdmdhb5pdeok4", + "delta": encrypt(testUtils.CBORValue("Islam")), + "docID": testUtils.NewDocIndex(0, 1), + }, + { + "cid": "bafyreifulxdkf4m3wmmdxjg43l4mw7uuxl5il27eabklc22nptilrh64sa", + "delta": nil, + "docID": testUtils.NewDocIndex(0, 1), + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} diff --git a/tests/integration/encryption/peer_test.go b/tests/integration/encryption/peer_test.go new file mode 100644 index 0000000000..6d9c937278 --- /dev/null +++ b/tests/integration/encryption/peer_test.go @@ -0,0 +1,147 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package encryption + +import ( + "testing" + + "github.com/sourcenetwork/immutable" + + testUtils "github.com/sourcenetwork/defradb/tests/integration" +) + +func TestDocEncryptionPeer_IfPeerHasNoKey_ShouldNotFetch(t *testing.T) { + test := testUtils.TestCase{ + Actions: []any{ + testUtils.RandomNetworkingConfig(), + testUtils.RandomNetworkingConfig(), + updateUserCollectionSchema(), + testUtils.ConnectPeers{ + SourceNodeID: 1, + TargetNodeID: 0, + }, + testUtils.SubscribeToCollection{ + NodeID: 1, + CollectionIDs: []int{0}, + }, + testUtils.CreateDoc{ + NodeID: immutable.Some(0), + Doc: `{ + "name": "John", + "age": 21 + }`, + IsEncrypted: true, + }, + testUtils.WaitForSync{}, + testUtils.Request{ + NodeID: immutable.Some(1), + Request: `query { + Users { + age + } + }`, + Results: []map[string]any{}, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestDocEncryptionPeer_UponSync_ShouldSyncEncryptedDAG(t *testing.T) { + test := testUtils.TestCase{ + Actions: []any{ + testUtils.RandomNetworkingConfig(), + testUtils.RandomNetworkingConfig(), + updateUserCollectionSchema(), + testUtils.ConnectPeers{ + SourceNodeID: 1, + TargetNodeID: 0, + }, + testUtils.SubscribeToCollection{ + NodeID: 1, + CollectionIDs: []int{0}, + }, + testUtils.CreateDoc{ + NodeID: immutable.Some(0), + Doc: `{ + "name": "John", + "age": 21 + }`, + IsEncrypted: true, + }, + testUtils.WaitForSync{}, + testUtils.Request{ + NodeID: immutable.Some(1), + Request: ` + query { + commits { + cid + collectionID + delta + docID + fieldId + fieldName + height + links { + cid + name + } + } + } + `, + Results: []map[string]any{ + { + "cid": "bafyreih7ry7ef26xn3lm2rhxusf2rbgyvl535tltrt6ehpwtvdnhlmptiu", + "collectionID": int64(1), + "delta": encrypt(testUtils.CBORValue(21)), + "docID": "bae-c9fb0fa4-1195-589c-aa54-e68333fb90b3", + "fieldId": "1", + "fieldName": "age", + "height": int64(1), + "links": []map[string]any{}, + }, + { + "cid": 
"bafyreifusejlwidaqswasct37eorazlfix6vyyn5af42pmjvktilzj5cty", + "collectionID": int64(1), + "delta": encrypt(testUtils.CBORValue("John")), + "docID": "bae-c9fb0fa4-1195-589c-aa54-e68333fb90b3", + "fieldId": "2", + "fieldName": "name", + "height": int64(1), + "links": []map[string]any{}, + }, + { + "cid": "bafyreicvxlfxeqghmc3gy56rp5rzfejnbng4nu77x5e3wjinfydl6wvycq", + "collectionID": int64(1), + "delta": nil, + "docID": "bae-c9fb0fa4-1195-589c-aa54-e68333fb90b3", + "fieldId": "C", + "fieldName": nil, + "height": int64(1), + "links": []map[string]any{ + { + "cid": "bafyreifusejlwidaqswasct37eorazlfix6vyyn5af42pmjvktilzj5cty", + "name": "name", + }, + { + "cid": "bafyreih7ry7ef26xn3lm2rhxusf2rbgyvl535tltrt6ehpwtvdnhlmptiu", + "name": "age", + }, + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} diff --git a/tests/integration/encryption/query_test.go b/tests/integration/encryption/query_test.go new file mode 100644 index 0000000000..32d9bd2c94 --- /dev/null +++ b/tests/integration/encryption/query_test.go @@ -0,0 +1,110 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package encryption + +import ( + "testing" + + testUtils "github.com/sourcenetwork/defradb/tests/integration" +) + +func TestDocEncryption_WithEncryption_ShouldFetchDecrypted(t *testing.T) { + test := testUtils.TestCase{ + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: ` + type Users { + name: String + age: Int + } + `}, + testUtils.CreateDoc{ + Doc: `{ + "name": "John", + "age": 21 + }`, + IsEncrypted: true, + }, + testUtils.Request{ + Request: ` + query { + Users { + _docID + name + age + } + }`, + Results: []map[string]any{ + { + "_docID": testUtils.NewDocIndex(0, 0), + "name": "John", + "age": int64(21), + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestDocEncryption_WithEncryptionOnCounterCRDT_ShouldFetchDecrypted(t *testing.T) { + const query = ` + query { + Users { + name + points + } + }` + + test := testUtils.TestCase{ + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: ` + type Users { + name: String + points: Int @crdt(type: "pcounter") + } + `}, + testUtils.CreateDoc{ + Doc: `{ + "name": "John", + "points": 5 + }`, + IsEncrypted: true, + }, + testUtils.Request{ + Request: query, + Results: []map[string]any{ + { + "name": "John", + "points": 5, + }, + }, + }, + testUtils.UpdateDoc{ + DocID: 0, + Doc: `{ "points": 3 }`, + }, + testUtils.Request{ + Request: query, + Results: []map[string]any{ + { + "name": "John", + "points": 8, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} diff --git a/tests/integration/encryption/utils.go b/tests/integration/encryption/utils.go new file mode 100644 index 0000000000..400a0d34c3 --- /dev/null +++ b/tests/integration/encryption/utils.go @@ -0,0 +1,31 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package encryption + +import ( + testUtils "github.com/sourcenetwork/defradb/tests/integration" +) + +// we explicitly set LWW CRDT type because we want to test encryption with this specific CRDT type +// and we don't want to rely on the default behavior +const userCollectionGQLSchema = (` + type Users { + name: String + age: Int @crdt(type: "lww") + verified: Boolean + } +`) + +func updateUserCollectionSchema() testUtils.SchemaUpdate { + return testUtils.SchemaUpdate{ + Schema: userCollectionGQLSchema, + } +} diff --git a/tests/integration/events/simple/with_update_test.go b/tests/integration/events/simple/with_update_test.go index 0b49486aa4..2d5fbfc5fe 100644 --- a/tests/integration/events/simple/with_update_test.go +++ b/tests/integration/events/simple/with_update_test.go @@ -49,17 +49,18 @@ func TestEventsSimpleWithUpdate(t *testing.T) { "Users": []func(c client.Collection){ func(c client.Collection) { err = c.Save(context.Background(), doc1) - assert.Nil(t, err) + assert.NoError(t, err) }, func(c client.Collection) { err = c.Save(context.Background(), doc2) - assert.Nil(t, err) + assert.NoError(t, err) }, func(c client.Collection) { // Update John - doc1.Set("name", "Johnnnnn") + err = doc1.Set("name", "Johnnnnn") + assert.NoError(t, err) err = c.Save(context.Background(), doc1) - assert.Nil(t, err) + assert.NoError(t, err) }, }, }, diff --git a/tests/integration/explain/default/create_test.go b/tests/integration/explain/default/create_test.go index dc57671bdd..3fdccc8b44 100644 --- a/tests/integration/explain/default/create_test.go +++ b/tests/integration/explain/default/create_test.go @@ -52,11 +52,11 @@ func TestDefaultExplainMutationRequestWithCreate(t *testing.T) { TargetNodeName: "createNode", IncludeChildNodes: false, 
ExpectedAttributes: dataMap{ - "input": dataMap{ + "input": []dataMap{{ "age": int32(27), "name": "Shahzad Lone", "verified": true, - }, + }}, }, }, }, @@ -90,10 +90,10 @@ func TestDefaultExplainMutationRequestDoesNotCreateDocGivenDuplicate(t *testing. TargetNodeName: "createNode", IncludeChildNodes: false, ExpectedAttributes: dataMap{ - "input": dataMap{ + "input": []dataMap{{ "age": int32(27), "name": "Shahzad Lone", - }, + }}, }, }, }, diff --git a/tests/integration/explain/execute/create_test.go b/tests/integration/explain/execute/create_test.go index 58736edb90..54876b57f5 100644 --- a/tests/integration/explain/execute/create_test.go +++ b/tests/integration/explain/execute/create_test.go @@ -42,10 +42,10 @@ func TestExecuteExplainMutationRequestWithCreate(t *testing.T) { "iterations": uint64(2), "selectTopNode": dataMap{ "selectNode": dataMap{ - "iterations": uint64(1), + "iterations": uint64(2), "filterMatches": uint64(1), "scanNode": dataMap{ - "iterations": uint64(1), + "iterations": uint64(2), "docFetches": uint64(1), "fieldFetches": uint64(1), "indexFetches": uint64(0), diff --git a/tests/integration/gql.go b/tests/integration/gql.go index 22a368adf7..1f6cd26d6e 100644 --- a/tests/integration/gql.go +++ b/tests/integration/gql.go @@ -14,15 +14,27 @@ import ( "encoding/json" "fmt" "strings" + + "github.com/sourcenetwork/defradb/client" ) -// jsonToGql transforms a json doc string to a gql string. +// jsonToGQL transforms a json doc string to a gql string. 
func jsonToGQL(val string) (string, error) { - var doc map[string]any - if err := json.Unmarshal([]byte(val), &doc); err != nil { - return "", err + bytes := []byte(val) + + if client.IsJSONArray(bytes) { + var doc []map[string]any + if err := json.Unmarshal(bytes, &doc); err != nil { + return "", err + } + return arrayToGQL(doc) + } else { + var doc map[string]any + if err := json.Unmarshal(bytes, &doc); err != nil { + return "", err + } + return mapToGQL(doc) } - return mapToGQL(doc) } // valueToGQL transforms a value to a gql string. @@ -41,7 +53,7 @@ func valueToGQL(val any) (string, error) { return string(out), nil } -// mapToGql transforms a map to a gql string. +// mapToGQL transforms a map to a gql string. func mapToGQL(val map[string]any) (string, error) { var entries []string for k, v := range val { @@ -66,3 +78,16 @@ func sliceToGQL(val []any) (string, error) { } return fmt.Sprintf("[%s]", strings.Join(entries, ",")), nil } + +// arrayToGQL transforms an array of maps to a gql string. +func arrayToGQL(val []map[string]any) (string, error) { + var entries []string + for _, v := range val { + out, err := mapToGQL(v) + if err != nil { + return "", err + } + entries = append(entries, out) + } + return fmt.Sprintf("[%s]", strings.Join(entries, ",")), nil +} diff --git a/tests/integration/mutation/create/simple_create_many_test.go b/tests/integration/mutation/create/simple_create_many_test.go new file mode 100644 index 0000000000..5f1e425549 --- /dev/null +++ b/tests/integration/mutation/create/simple_create_many_test.go @@ -0,0 +1,70 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package create + +import ( + "testing" + + testUtils "github.com/sourcenetwork/defradb/tests/integration" +) + +func TestMutationCreateMany(t *testing.T) { + test := testUtils.TestCase{ + Description: "Simple create many mutation", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: ` + type Users { + name: String + age: Int + } + `, + }, + testUtils.CreateDoc{ + Doc: `[ + { + "name": "John", + "age": 27 + }, + { + "name": "Islam", + "age": 33 + } + ]`, + }, + testUtils.Request{ + Request: ` + query { + Users { + _docID + name + age + } + } + `, + Results: []map[string]any{ + { + "_docID": "bae-48339725-ed14-55b1-8e63-3fda5f590725", + "name": "Islam", + "age": int64(33), + }, + { + "_docID": "bae-8c89a573-c287-5d8c-8ba6-c47c814c594d", + "name": "John", + "age": int64(27), + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} diff --git a/tests/integration/test_case.go b/tests/integration/test_case.go index 4536c0cd0a..948ae1838e 100644 --- a/tests/integration/test_case.go +++ b/tests/integration/test_case.go @@ -230,6 +230,9 @@ type CreateDoc struct { // created document(s) will be owned by this Identity. Identity immutable.Option[acpIdentity.Identity] + // Specifies whether the document should be encrypted. + IsEncrypted bool + // The collection in which this document should be created. 
CollectionID int diff --git a/tests/integration/utils2.go b/tests/integration/utils2.go index fab8cc5ed9..5012642f22 100644 --- a/tests/integration/utils2.go +++ b/tests/integration/utils2.go @@ -12,6 +12,7 @@ package tests import ( "context" + "encoding/json" "fmt" "os" "reflect" @@ -28,12 +29,14 @@ import ( "github.com/stretchr/testify/require" "github.com/sourcenetwork/defradb/client" + "github.com/sourcenetwork/defradb/client/request" "github.com/sourcenetwork/defradb/crypto" "github.com/sourcenetwork/defradb/datastore" badgerds "github.com/sourcenetwork/defradb/datastore/badger/v4" "github.com/sourcenetwork/defradb/errors" "github.com/sourcenetwork/defradb/event" "github.com/sourcenetwork/defradb/internal/db" + "github.com/sourcenetwork/defradb/internal/encryption" "github.com/sourcenetwork/defradb/internal/request/graphql" "github.com/sourcenetwork/defradb/net" changeDetector "github.com/sourcenetwork/defradb/tests/change_detector" @@ -105,6 +108,7 @@ func init() { if value, ok := os.LookupEnv(skipNetworkTestsEnvName); ok { skipNetworkTests, _ = strconv.ParseBool(value) } + mutationType = GQLRequestMutationType } // AssertPanic asserts that the code inside the specified PanicTestFunc panics. @@ -882,7 +886,8 @@ func refreshDocuments( continue } - ctx := db.SetContextIdentity(s.ctx, action.Identity) + ctx := makeContextForDocCreate(s.ctx, &action) + // The document may have been mutated by other actions, so to be sure we have the latest // version without having to worry about the individual update mechanics we fetch it. 
doc, err = collection.Get(ctx, doc.ID(), false) @@ -1169,7 +1174,7 @@ func createDoc( substituteRelations(s, action) } - var mutation func(*state, CreateDoc, client.DB, []client.Collection) (*client.Document, error) + var mutation func(*state, CreateDoc, client.DB, []client.Collection) ([]*client.Document, error) switch mutationType { case CollectionSaveMutationType: @@ -1183,7 +1188,7 @@ func createDoc( } var expectedErrorRaised bool - var doc *client.Document + var docs []*client.Document actionNodes := getNodes(action.NodeID, s.nodes) for nodeID, collections := range getNodeCollections(action.NodeID, s.collections) { err := withRetry( @@ -1191,7 +1196,7 @@ func createDoc( nodeID, func() error { var err error - doc, err = mutation(s, action, actionNodes[nodeID], collections) + docs, err = mutation(s, action, actionNodes[nodeID], collections) return err }, ) @@ -1204,7 +1209,7 @@ func createDoc( // Expand the slice if required, so that the document can be accessed by collection index s.documents = append(s.documents, make([][]*client.Document, action.CollectionID-len(s.documents)+1)...) } - s.documents[action.CollectionID] = append(s.documents[action.CollectionID], doc) + s.documents[action.CollectionID] = append(s.documents[action.CollectionID], docs...) 
} func createDocViaColSave( @@ -1212,13 +1217,21 @@ func createDocViaColSave( action CreateDoc, node client.DB, collections []client.Collection, -) (*client.Document, error) { - var err error +) ([]*client.Document, error) { + var docs []*client.Document var doc *client.Document + var err error if action.DocMap != nil { doc, err = client.NewDocFromMap(action.DocMap, collections[action.CollectionID].Definition()) + docs = []*client.Document{doc} } else { - doc, err = client.NewDocFromJSON([]byte(action.Doc), collections[action.CollectionID].Definition()) + bytes := []byte(action.Doc) + if client.IsJSONArray(bytes) { + docs, err = client.NewDocsFromJSON(bytes, collections[action.CollectionID].Definition()) + } else { + doc, err = client.NewDocFromJSON(bytes, collections[action.CollectionID].Definition()) + docs = []*client.Document{doc} + } } if err != nil { return nil, err @@ -1226,10 +1239,23 @@ func createDocViaColSave( txn := getTransaction(s, node, immutable.None[int](), action.ExpectedError) - ctx := db.SetContextTxn(s.ctx, txn) - ctx = db.SetContextIdentity(ctx, action.Identity) + ctx := makeContextForDocCreate(db.SetContextTxn(s.ctx, txn), &action) - return doc, collections[action.CollectionID].Save(ctx, doc) + for _, doc := range docs { + err = collections[action.CollectionID].Save(ctx, doc) + if err != nil { + return nil, err + } + } + return docs, nil +} + +func makeContextForDocCreate(ctx context.Context, action *CreateDoc) context.Context { + ctx = db.SetContextIdentity(ctx, action.Identity) + if action.IsEncrypted { + ctx = encryption.SetContextConfig(ctx, encryption.DocEncConfig{IsEncrypted: true}) + } + return ctx } func createDocViaColCreate( @@ -1237,13 +1263,20 @@ func createDocViaColCreate( action CreateDoc, node client.DB, collections []client.Collection, -) (*client.Document, error) { - var err error +) ([]*client.Document, error) { + var docs []*client.Document var doc *client.Document + var err error if action.DocMap != nil { doc, err = 
client.NewDocFromMap(action.DocMap, collections[action.CollectionID].Definition()) + docs = []*client.Document{doc} } else { - doc, err = client.NewDocFromJSON([]byte(action.Doc), collections[action.CollectionID].Definition()) + if client.IsJSONArray([]byte(action.Doc)) { + docs, err = client.NewDocsFromJSON([]byte(action.Doc), collections[action.CollectionID].Definition()) + } else { + doc, err = client.NewDocFromJSON([]byte(action.Doc), collections[action.CollectionID].Definition()) + docs = []*client.Document{doc} + } } if err != nil { return nil, err @@ -1251,10 +1284,15 @@ func createDocViaColCreate( txn := getTransaction(s, node, immutable.None[int](), action.ExpectedError) - ctx := db.SetContextTxn(s.ctx, txn) - ctx = db.SetContextIdentity(ctx, action.Identity) + ctx := makeContextForDocCreate(db.SetContextTxn(s.ctx, txn), &action) - return doc, collections[action.CollectionID].Create(ctx, doc) + if len(docs) > 1 { + err = collections[action.CollectionID].CreateMany(ctx, docs) + } else { + err = collections[action.CollectionID].Create(ctx, doc) + } + + return docs, err } func createDocViaGQL( @@ -1262,37 +1300,47 @@ func createDocViaGQL( action CreateDoc, node client.DB, collections []client.Collection, -) (*client.Document, error) { +) ([]*client.Document, error) { collection := collections[action.CollectionID] - var err error var input string + paramName := request.Input + + var err error if action.DocMap != nil { input, err = valueToGQL(action.DocMap) + } else if client.IsJSONArray([]byte(action.Doc)) { + var docMaps []map[string]any + err = json.Unmarshal([]byte(action.Doc), &docMaps) + require.NoError(s.t, err) + paramName = request.Inputs + input, err = arrayToGQL(docMaps) } else { input, err = jsonToGQL(action.Doc) } require.NoError(s.t, err) - request := fmt.Sprintf( + params := paramName + ": " + input + + if action.IsEncrypted { + params = params + ", " + request.EncryptArgName + ": true" + } + + req := fmt.Sprintf( `mutation { - create_%s(input: 
%s) { + create_%s(%s) { _docID } }`, collection.Name().Value(), - input, + params, ) txn := getTransaction(s, node, immutable.None[int](), action.ExpectedError) - ctx := db.SetContextTxn(s.ctx, txn) - ctx = db.SetContextIdentity(ctx, action.Identity) + ctx := makeContextForDocCreate(db.SetContextTxn(s.ctx, txn), &action) - result := node.ExecRequest( - ctx, - request, - ) + result := node.ExecRequest(ctx, req) if len(result.GQL.Errors) > 0 { return nil, result.GQL.Errors[0] } @@ -1302,14 +1350,20 @@ func createDocViaGQL( return nil, nil } - docIDString := resultantDocs[0]["_docID"].(string) - docID, err := client.NewDocIDFromString(docIDString) - require.NoError(s.t, err) + docs := make([]*client.Document, len(resultantDocs)) - doc, err := collection.Get(ctx, docID, false) - require.NoError(s.t, err) + for i, docMap := range resultantDocs { + docIDString := docMap[request.DocIDFieldName].(string) + docID, err := client.NewDocIDFromString(docIDString) + require.NoError(s.t, err) + + doc, err := collection.Get(ctx, docID, false) + require.NoError(s.t, err) + + docs[i] = doc + } - return doc, nil + return docs, nil } // substituteRelations scans the fields defined in [action.DocMap], if any are of type [DocIndex]