diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1f3062f..6b5f75c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -5,29 +5,42 @@ name: fabric-chaincode-go on: workflow_dispatch: + workflow_call: pull_request: branches: - main - - release-* jobs: build: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest steps: - - uses: actions/setup-go@v3 - with: - go-version: 1.21.x - - uses: actions/checkout@v2 - with: - fetch-depth: 1 - clean: true - - name: install Tools - run: | - pushd ci/tools - go install golang.org/x/lint/golint - go install golang.org/x/tools/cmd/goimports - popd - - name: Vet and lint - run: ci/scripts/lint.sh - - name: Run tests - run: go test -race ./... + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: "1.21" + - name: install Tools + working-directory: ci/tools + run: | + go install golang.org/x/lint/golint + go install golang.org/x/tools/cmd/goimports + - name: Vet and lint + run: ci/scripts/lint.sh + - name: Run tests + run: go test -race ./... + + build-v2: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: "1.21" + - name: Staticcheck + run: make staticcheck + - name: golangci-lint + uses: golangci/golangci-lint-action@v6 + with: + version: latest + working-directory: v2 + - name: Unit test + run: make unit-test diff --git a/.github/workflows/schedule.yml b/.github/workflows/schedule.yml new file mode 100644 index 0000000..791e377 --- /dev/null +++ b/.github/workflows/schedule.yml @@ -0,0 +1,10 @@ +name: Scheduled build + +on: + schedule: + - cron: "42 2 * * 0" + workflow_dispatch: + +jobs: + main: + uses: ./.github/workflows/build.yml diff --git a/.github/workflows/vulnerability-scan.yml b/.github/workflows/vulnerability-scan.yml new file mode 100644 index 0000000..a74cd0c --- /dev/null +++ b/.github/workflows/vulnerability-scan.yml @@ -0,0 +1,25 @@ +name: "Security vulnerability scan" + +on: + schedule: + - cron: "27 3 * * *" + workflow_dispatch: + +jobs: + scan: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + target: + - v1 + - v2 + steps: + - uses: actions/checkout@v4 + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: stable + check-latest: true + - name: Scan + run: make scan-${{ matrix.target }} diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..300d540 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,18 @@ +# See https://golangci-lint.run/usage/configuration/ + +run: + timeout: 5m + +linters: + disable-all: true + enable: + - errcheck + - gofmt + - goimports + - gosec + - gosimple + - govet + - ineffassign + - misspell + - typecheck + - unused diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..f4913fe --- /dev/null +++ b/Makefile @@ -0,0 +1,53 @@ +# +# SPDX-License-Identifier: Apache-2.0 +# + +base_dir := $(patsubst %/,%,$(dir $(realpath $(lastword $(MAKEFILE_LIST))))) + +v2_dir := $(base_dir)/v2 + +go_bin_dir := $(shell go env GOPATH)/bin + +.PHONY: unit-test +unit-test: + cd '$(v2_dir)' && \ + go test -timeout 10s -race -coverprofile=cover.out ./... + +.PHONY: generate +generate: + go install github.com/maxbrunsfeld/counterfeiter/v6@latest + cd '$(v2_dir)' && \ + go generate ./... + +.PHONY: lint +lint: staticcheck golangci-lint + +.PHONY: staticcheck +staticcheck: + go install honnef.co/go/tools/cmd/staticcheck@latest + cd '$(v2_dir)' && \ + staticcheck -f stylish ./... + +.PHONY: install-golangci-lint +install-golangci-lint: + curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b '$(go_bin_dir)' + +$(go_bin_dir)/golangci-lint: + $(MAKE) install-golangci-lint + +.PHONY: golangci-lint +golangci-lint: $(go_bin_dir)/golangci-lint + cd '$(v2_dir)' && \ + golangci-lint run + +.PHONY: scan-v2 +scan-v2: + go install golang.org/x/vuln/cmd/govulncheck@latest + cd '$(v2_dir)' && \ + govulncheck ./... + +.PHONY: scan-v1 +scan-v1: + go install golang.org/x/vuln/cmd/govulncheck@latest + cd '$(base_dir)' && \ + govulncheck ./... diff --git a/v2/.gitignore b/v2/.gitignore new file mode 100644 index 0000000..b51c70b --- /dev/null +++ b/v2/.gitignore @@ -0,0 +1 @@ +cover.out diff --git a/v2/doc.go b/v2/doc.go new file mode 100644 index 0000000..a4ca35e --- /dev/null +++ b/v2/doc.go @@ -0,0 +1,6 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Package chaincode contains the code necessary for chaincode to interact +// with a Hyperledger Fabric peer. +package chaincode diff --git a/v2/go.mod b/v2/go.mod new file mode 100644 index 0000000..060709b --- /dev/null +++ b/v2/go.mod @@ -0,0 +1,22 @@ +module github.com/hyperledger/fabric-chaincode-go/v2 + +go 1.21 + +require ( + github.com/hyperledger/fabric-protos-go-apiv2 v0.3.3 + github.com/stretchr/testify v1.9.0 + google.golang.org/grpc v1.64.0 + google.golang.org/protobuf v1.34.1 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/net v0.23.0 // indirect + golang.org/x/sys v0.19.0 // indirect + golang.org/x/text v0.14.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240528184218-531527333157 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/v2/go.sum b/v2/go.sum new file mode 100644 index 0000000..558c0f4 --- /dev/null +++ b/v2/go.sum @@ -0,0 +1,34 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/hyperledger/fabric-protos-go-apiv2 v0.3.3 h1:Xpd6fzG/KjAOHJsq7EQXY2l+qi/y8muxBaY7R6QWABk= +github.com/hyperledger/fabric-protos-go-apiv2 v0.3.3/go.mod h1:2pq0ui6ZWA0cC8J+eCErgnMDCS1kPOEYVY+06ZAK0qE= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= +golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= +golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240528184218-531527333157 h1:Zy9XzmMEflZ/MAaA7vNcoebnRAld7FsPW1EeBB7V0m8= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240528184218-531527333157/go.mod h1:EfXuqaE1J41VCDicxHzUDm+8rk+7ZdXzHV0IhO/I6s0= +google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY= +google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg= +google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= +google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/v2/pkg/attrmgr/attrmgr.go b/v2/pkg/attrmgr/attrmgr.go new file mode 100644 index 0000000..48b686c --- /dev/null +++ b/v2/pkg/attrmgr/attrmgr.go @@ -0,0 +1,245 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Package attrmgr contains utilities for managing attributes. +// Attributes are added to an X509 certificate as an extension. +package attrmgr + +import ( + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/json" + "errors" + "fmt" + + "github.com/hyperledger/fabric-protos-go-apiv2/msp" + "google.golang.org/protobuf/proto" +) + +var ( + // AttrOID is the ASN.1 object identifier for an attribute extension in an + // X509 certificate + AttrOID = asn1.ObjectIdentifier{1, 2, 3, 4, 5, 6, 7, 8, 1} + // AttrOIDString is the string version of AttrOID + AttrOIDString = "1.2.3.4.5.6.7.8.1" +) + +// Attribute is a name/value pair +type Attribute interface { + // GetName returns the name of the attribute + GetName() string + // GetValue returns the value of the attribute + GetValue() string +} + +// AttributeRequest is a request for an attribute +type AttributeRequest interface { + // GetName returns the name of an attribute + GetName() string + // IsRequired returns true if the attribute is required + IsRequired() bool +} + +// New constructs an attribute manager +func New() *Mgr { return &Mgr{} } + +// Mgr is the attribute manager and is the main object for this package +type Mgr struct{} + +// ProcessAttributeRequestsForCert add attributes to an X509 certificate, given +// attribute requests and attributes. +func (mgr *Mgr) ProcessAttributeRequestsForCert(requests []AttributeRequest, attributes []Attribute, cert *x509.Certificate) error { + attrs, err := mgr.ProcessAttributeRequests(requests, attributes) + if err != nil { + return err + } + return mgr.AddAttributesToCert(attrs, cert) +} + +// ProcessAttributeRequests takes an array of attribute requests and an identity's attributes +// and returns an Attributes object containing the requested attributes. +func (mgr *Mgr) ProcessAttributeRequests(requests []AttributeRequest, attributes []Attribute) (*Attributes, error) { + attrsMap := map[string]string{} + attrs := &Attributes{Attrs: attrsMap} + missingRequiredAttrs := []string{} + // For each of the attribute requests + for _, req := range requests { + // Get the attribute + name := req.GetName() + attr := getAttrByName(name, attributes) + if attr == nil { + if req.IsRequired() { + // Didn't find attribute and it was required; return error below + missingRequiredAttrs = append(missingRequiredAttrs, name) + } + // Skip attribute requests which aren't required + continue + } + attrsMap[name] = attr.GetValue() + } + if len(missingRequiredAttrs) > 0 { + return nil, fmt.Errorf("the following required attributes are missing: %+v", + missingRequiredAttrs) + } + return attrs, nil +} + +// AddAttributesToCert adds public attribute info to an X509 certificate. +func (mgr *Mgr) AddAttributesToCert(attrs *Attributes, cert *x509.Certificate) error { + buf, err := json.Marshal(attrs) + if err != nil { + return fmt.Errorf("failed to marshal attributes: %s", err) + } + ext := pkix.Extension{ + Id: AttrOID, + Critical: false, + Value: buf, + } + cert.Extensions = append(cert.Extensions, ext) + return nil +} + +// GetAttributesFromCert gets the attributes from a certificate. +func (mgr *Mgr) GetAttributesFromCert(cert *x509.Certificate) (*Attributes, error) { + // Get certificate attributes from the certificate if it exists + buf, err := getAttributesFromCert(cert) + if err != nil { + return nil, err + } + // Unmarshal into attributes object + attrs := &Attributes{} + if buf != nil { + err := json.Unmarshal(buf, attrs) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal attributes from certificate: %s", err) + } + } + return attrs, nil +} + +// GetAttributesFromIdemix ... +func (mgr *Mgr) GetAttributesFromIdemix(creator []byte) (*Attributes, error) { + if creator == nil { + return nil, errors.New("creator is nil") + } + + sid := &msp.SerializedIdentity{} + err := proto.Unmarshal(creator, sid) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal transaction invoker's identity: %s", err) + } + idemixID := &msp.SerializedIdemixIdentity{} + err = proto.Unmarshal(sid.IdBytes, idemixID) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal transaction invoker's idemix identity: %s", err) + } + // Unmarshal into attributes object + attrs := &Attributes{ + Attrs: make(map[string]string), + } + + ou := &msp.OrganizationUnit{} + err = proto.Unmarshal(idemixID.Ou, ou) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal transaction invoker's ou: %s", err) + } + attrs.Attrs["ou"] = ou.OrganizationalUnitIdentifier + + role := &msp.MSPRole{} + err = proto.Unmarshal(idemixID.Role, role) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal transaction invoker's role: %s", err) + } + var roleStr string + switch role.Role { + case 0: + roleStr = "member" + case 1: + roleStr = "admin" + case 2: + roleStr = "client" + case 3: + roleStr = "peer" + } + attrs.Attrs["role"] = roleStr + + return attrs, nil +} + +// Attributes contains attribute names and values +type Attributes struct { + Attrs map[string]string `json:"attrs"` +} + +// Names returns the names of the attributes +func (a *Attributes) Names() []string { + i := 0 + names := make([]string, len(a.Attrs)) + for name := range a.Attrs { + names[i] = name + i++ + } + return names +} + +// Contains returns true if the named attribute is found +func (a *Attributes) Contains(name string) bool { + _, ok := a.Attrs[name] + return ok +} + +// Value returns an attribute's value +func (a *Attributes) Value(name string) (string, bool, error) { + attr, ok := a.Attrs[name] + return attr, ok, nil +} + +// True returns nil if the value of attribute 'name' is true; +// otherwise, an appropriate error is returned. +func (a *Attributes) True(name string) error { + val, ok, err := a.Value(name) + if err != nil { + return err + } + if !ok { + return fmt.Errorf("Attribute '%s' was not found", name) + } + if val != "true" { + return fmt.Errorf("Attribute '%s' is not true", name) + } + return nil +} + +// Get the attribute info from a certificate extension, or return nil if not found +func getAttributesFromCert(cert *x509.Certificate) ([]byte, error) { + for _, ext := range cert.Extensions { + if isAttrOID(ext.Id) { + return ext.Value, nil + } + } + return nil, nil +} + +// Is the object ID equal to the attribute info object ID? +func isAttrOID(oid asn1.ObjectIdentifier) bool { + if len(oid) != len(AttrOID) { + return false + } + for idx, val := range oid { + if val != AttrOID[idx] { + return false + } + } + return true +} + +// Get an attribute from 'attrs' by its name, or nil if not found +func getAttrByName(name string, attrs []Attribute) Attribute { + for _, attr := range attrs { + if attr.GetName() == name { + return attr + } + } + return nil +} diff --git a/v2/pkg/attrmgr/attrmgr_test.go b/v2/pkg/attrmgr/attrmgr_test.go new file mode 100644 index 0000000..4aee51b --- /dev/null +++ b/v2/pkg/attrmgr/attrmgr_test.go @@ -0,0 +1,125 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package attrmgr_test + +import ( + "crypto/x509" + "encoding/base64" + "testing" + + "github.com/hyperledger/fabric-chaincode-go/v2/pkg/attrmgr" + "github.com/stretchr/testify/assert" +) + +const creator = `CgxpZGVtaXhNU1BJRDISgQgKIAZM+v2JgGPuCod5T3RGBdeSUGGAgpu1W1TMwOeEn1sJEiCBvWZYvM0Q7Vpz498M1KlsILTZ5jk6pGihIfWaeGV+0RpCCgxpZGVtaXhNU1BJRDISEG9yZzEuZGVwYXJ0bWVudDEaIIQ6XWpZn5NGEMPdfoKXn262cOdbyiKjTLa+4nXEc0wyIg4KDGlkZW1peE1TUElEMirmBgpECiDUyAZaFx3+OBClul07XsuS1Kh6VKxAkp8CYWGylozr5BIgIxzFAuzglE95JvJYbzUo16mYsiLwLUA7KuDK0lgyYogSRAogILB8Pu98YqrMYURrsftwFtHzWQiZtdwQImcNuPhBA1QSIHrgGLSNFqGHXxC5nOqfDqySyfwYEKLaxWyuO0tMqy8xGkQKIEP2aKh/YLIKMc6vqz8kCIAtHON2iC/TFAcTKo0B8gMAEiDHnrLuVSWUZzRe1iwUh2rsK6UMTnlF7nFPXC/NE2EhNSIg9JqjO+vb3iU0YXdbLlh3vCU1b8hkGkFxd1r91B8ZyL0qIFs7ajZtYPU/gc4x8j/95ujxavBM2CY9+aWo0HHMq5AyMiDkDCZAYRico3+k5UMUyOb/dr2EkO/1Hay8jjZpUGazQzogcxIUhnyP/Jkfmce0KTClAwK4EWYjqSsPYJ9OMKI5R+hCIM7tzJGcK324QYiFCwGLCdIRcf4b0iX2q9+RSsCJmVuuSiC/p2ZvXGKN8HeCzJVbGB8qVE1G1/vx0zCNJ/vqMSdKsFIglOQPuVIHAF6kwVE7Fid5Me4bolxJml2h44aoWXR2slZSILyd0LbL0uwUksqzZ10WqwVbuQ2D69E5e5ItB2CVIF99WiD94PNz3TBMERm7ZPouFYRtw/mhnlNh0T+j5w5R8+BXhGJECiAGTPr9iYBj7gqHeU90RgXXklBhgIKbtVtUzMDnhJ9bCRIggb1mWLzNEO1ac+PfDNSpbCC02eY5OqRooSH1mnhlftFqIJc15atDPZQ+S4ARmu375M/8NuYAUXtwFCViRvzWOuf+cogBCiD+DDNQtMlsIChWD1d8KJE6zhxTmhK/hDzSJha2icCe+xIgTqZgV3OKwFTbWuHGN9gTuSTdeOKH0DWJ0mntNKN+aisaIHAgRufFQqOzdncNdRJOPlHvyyR1jWFYSOkJtIG+3Cf/IiAFVOO804jCkELupkkpfrKfi0y+gIIamLPgEoERSq0Em3pnMGUCMQCgFofNfUeO+uc8wNdqOpwt4dHn/8AggYMNwZD7gY2om71ZrCXDpmznw2eSmaHb2K8CMEk0d4Y29f2xBv2XLMsC0JrkiXjEo0YakZn66FACO02lEBku2/aGBKokDLRfofA1d4ABAYoBAA==` + +// TestAttrs tests attributes +func TestAttrs(t *testing.T) { + mgr := attrmgr.New() + attrs := []attrmgr.Attribute{ + &Attribute{Name: "attr1", Value: "val1"}, + &Attribute{Name: "attr2", Value: "val2"}, + &Attribute{Name: "attr3", Value: "val3"}, + &Attribute{Name: "boolAttr", Value: "true"}, + } + reqs := []attrmgr.AttributeRequest{ + &AttributeRequest{Name: "attr1", Require: false}, + &AttributeRequest{Name: "attr2", Require: true}, + &AttributeRequest{Name: "boolAttr", Require: true}, + &AttributeRequest{Name: "noattr1", Require: false}, + } + cert := &x509.Certificate{} + + // Verify that the certificate has no attributes + at, err := mgr.GetAttributesFromCert(cert) + if err != nil { + t.Fatalf("Failed to GetAttributesFromCert: %s", err) + } + numAttrs := len(at.Names()) + assert.True(t, numAttrs == 0, "expecting 0 attributes but found %d", numAttrs) + + // Add attributes to certificate + err = mgr.ProcessAttributeRequestsForCert(reqs, attrs, cert) + if err != nil { + t.Fatalf("Failed to ProcessAttributeRequestsForCert: %s", err) + } + + // Get attributes from the certificate and verify the count is correct + at, err = mgr.GetAttributesFromCert(cert) + if err != nil { + t.Fatalf("Failed to GetAttributesFromCert: %s", err) + } + numAttrs = len(at.Names()) + assert.True(t, numAttrs == 3, "expecting 3 attributes but found %d", numAttrs) + + // Check individual attributes + checkAttr(t, "attr1", "val1", at) + checkAttr(t, "attr2", "val2", at) + checkAttr(t, "attr3", "", at) + checkAttr(t, "noattr1", "", at) + assert.NoError(t, at.True("boolAttr")) + + // Negative test case: add required attributes which don't exist + reqs = []attrmgr.AttributeRequest{ + &AttributeRequest{Name: "noattr1", Require: true}, + } + err = mgr.ProcessAttributeRequestsForCert(reqs, attrs, cert) + assert.Error(t, err) +} + +func TestIdemixAttrs(t *testing.T) { + mgr := attrmgr.New() + + _, err := mgr.GetAttributesFromIdemix(nil) + assert.Error(t, err, "Should fail, if nil passed for creator") + + creatorBytes, err := base64.StdEncoding.DecodeString(creator) + assert.NoError(t, err, "Failed to base64 decode creator string") + + attrs, err := mgr.GetAttributesFromIdemix(creatorBytes) + assert.NoError(t, err, "GetAttributesFromIdemix") + numAttrs := len(attrs.Names()) + assert.True(t, numAttrs == 2, "expecting 2 attributes but found %d", numAttrs) + checkAttr(t, "ou", "org1.department1", attrs) + checkAttr(t, "role", "member", attrs) + checkAttr(t, "id", "", attrs) +} + +func checkAttr(t *testing.T, name, val string, attrs *attrmgr.Attributes) { + v, ok, err := attrs.Value(name) + assert.NoError(t, err) + if val == "" { + assert.False(t, attrs.Contains(name), "contains attribute '%s'", name) + assert.False(t, ok, "attribute '%s' was found", name) + } else { + assert.True(t, attrs.Contains(name), "does not contain attribute '%s'", name) + assert.True(t, ok, "attribute '%s' was not found", name) + assert.True(t, v == val, "incorrect value for '%s'; expected '%s' but found '%s'", name, val, v) + } +} + +type Attribute struct { + Name, Value string +} + +func (a *Attribute) GetName() string { + return a.Name +} + +func (a *Attribute) GetValue() string { + return a.Value +} + +type AttributeRequest struct { + Name string + Require bool +} + +func (ar *AttributeRequest) GetName() string { + return ar.Name +} + +func (ar *AttributeRequest) IsRequired() bool { + return ar.Require +} diff --git a/v2/pkg/cid/README.md b/v2/pkg/cid/README.md new file mode 100644 index 0000000..6c763ad --- /dev/null +++ b/v2/pkg/cid/README.md @@ -0,0 +1,235 @@ +# Client Identity Chaincode Library + +The client identity chaincode library enables you to write chaincode which +makes access control decisions based on the identity of the client +(i.e. the invoker of the chaincode). In particular, you may make access +control decisions based on any or a combination of the following information associated with +the client: + +* the client identity's MSP (Membership Service Provider) ID +* an attribute associated with the client identity +* an OU (Organizational Unit) value associated with the client identity + +Attributes are simply name and value pairs associated with an identity. +For example, `email=me@gmail.com` indicates an identity has the `email` +attribute with a value of `me@gmail.com`. + +## Using the client identity chaincode library + +This section describes how to use the client identity chaincode library. + +All code samples below assume two things: + +1. The type of the `stub` variable is `ChaincodeStubInterface` as passed + to your chaincode. +2. You have added the following import statement to your chaincode. + + ```golang + import "github.com/hyperledger/fabric-chaincode-go/v2/pkg/cid" + ``` + +### Getting the client's ID + +The following demonstrates how to get an ID for the client which is guaranteed +to be unique within the MSP: + +```golang +id, err := cid.GetID(stub) +``` + +### Getting the MSP ID + +The following demonstrates how to get the MSP ID of the client's identity: + +```golang +mspid, err := cid.GetMSPID(stub) +``` + +### Getting an attribute value + +The following demonstrates how to get the value of the *attr1* attribute: + +```golang +val, ok, err := cid.GetAttributeValue(stub, "attr1") +if err != nil { + // There was an error trying to retrieve the attribute +} +if !ok { + // The client identity does not possess the attribute +} +// Do something with the value of 'val' +``` + +### Asserting an attribute value + +Often all you want to do is to make an access control decision based on the value +of an attribute, i.e. to assert the value of an attribute. For example, the following +will return an error if the client does not have the `myapp.admin` attribute +with a value of `true`: + +```golang +err := cid.AssertAttributeValue(stub, "myapp.admin", "true") +if err != nil { + // Return an error +} +``` + +This is effectively using attributes to implement role-based access control, +or RBAC for short. + +### Checking for a specific OU value + +```golang +found, err := cid.HasOUValue(stub, "myapp.admin") +if err != nil { + // Return an error +} +if !found { + // The client identity is not part of the Organizational Unit + // Return an error +} +``` + +### Getting the client's X509 certificate + +The following demonstrates how to get the X509 certificate of the client, or +nil if the client's identity was not based on an X509 certificate: + +```golang +cert, err := cid.GetX509Certificate(stub) +``` + +Note that both `cert` and `err` may be nil as will be the case if the identity +is not using an X509 certificate. + +### Performing multiple operations more efficiently + +Sometimes you may need to perform multiple operations in order to make an access +decision. For example, the following demonstrates how to grant access to +identities with MSP *org1MSP* and *attr1* OR with MSP *org1MSP* and *attr2*. + +```golang +// Get the Client ID object +id, err := cid.New(stub) +if err != nil { + // Handle error +} +mspid, err := id.GetMSPID() +if err != nil { + // Handle error +} +switch mspid { + case "org1MSP": + err = id.AssertAttributeValue("attr1", "true") + case "org2MSP": + err = id.AssertAttributeValue("attr2", "true") + default: + err = errors.New("Wrong MSP") +} +``` + +Although it is not required, it is more efficient to make the `cid.New` call +to get the ClientID object if you need to perform multiple operations, +as demonstrated above. + +## Adding Attributes to Identities + +This section describes how to add custom attributes to certificates when +using Hyperledger Fabric CA as well as when using an external CA. + +### Managing attributes with Fabric CA + +There are two methods of adding attributes to an enrollment certificate +with fabric-ca: + + 1. When you register an identity, you can specify that an enrollment certificate + issued for the identity should by default contain an attribute. This behavior + can be overridden at enrollment time, but this is useful for establishing + default behavior and, assuming registration occurs outside of your application, + does not require any application change. + + The following shows how to register *user1* with two attributes: + *app1Admin* and *email*. + The ":ecert" suffix causes the *appAdmin* attribute to be inserted into user1's + enrollment certificate by default. The *email* attribute is not added + to the enrollment certificate by default. + + ```bash + fabric-ca-client register --id.name user1 --id.secret user1pw --id.type user --id.affiliation org1 --id.attrs 'app1Admin=true:ecert,email=user1@gmail.com' + ``` + + 2. When you enroll an identity, you may request that one or more attributes + be added to the certificate. + For each attribute requested, you may specify whether the attribute is + optional or not. If it is not optional but does not exist for the identity, + enrollment fails. + + The following shows how to enroll *user1* with the *email* attribute, + without the *app1Admin* attribute and optionally with the *phone* attribute + (if the user possesses *phone* attribute). + + ```bash + fabric-ca-client enroll -u http://user1:user1pw@localhost:7054 --enrollment.attrs "email,phone:opt" + ``` + +#### Attribute format in a certificate + +Attributes are stored inside an X509 certificate as an extension with an +ASN.1 OID (Abstract Syntax Notation Object IDentifier) +of `1.2.3.4.5.6.7.8.1`. The value of the extension is a JSON string of the +form `{"attrs":{: 0 { + s += "," + } + for j, tv := range rdn { + if j > 0 { + s += "+" + } + typeString := tv.Type.String() + typeName, ok := attributeTypeNames[typeString] + if !ok { + derBytes, err := asn1.Marshal(tv.Value) + if err == nil { + s += typeString + "=#" + hex.EncodeToString(derBytes) + continue // No value escaping necessary. + } + typeName = typeString + } + valueString := fmt.Sprint(tv.Value) + escaped := "" + begin := 0 + for idx, c := range valueString { + if (idx == 0 && (c == ' ' || c == '#')) || + (idx == len(valueString)-1 && c == ' ') { + escaped += valueString[begin:idx] + escaped += "\\" + string(c) + begin = idx + 1 + continue + } + switch c { + case ',', '+', '"', '\\', '<', '>', ';': + escaped += valueString[begin:idx] + escaped += "\\" + string(c) + begin = idx + 1 + } + } + escaped += valueString[begin:] + s += typeName + "=" + escaped + } + } + return s +} + +var attributeTypeNames = map[string]string{ + "2.5.4.6": "C", + "2.5.4.10": "O", + "2.5.4.11": "OU", + "2.5.4.3": "CN", + "2.5.4.5": "SERIALNUMBER", + "2.5.4.7": "L", + "2.5.4.8": "ST", + "2.5.4.9": "STREET", + "2.5.4.17": "POSTALCODE", +} diff --git a/v2/pkg/cid/cid_test.go b/v2/pkg/cid/cid_test.go new file mode 100644 index 0000000..04b9db1 --- /dev/null +++ b/v2/pkg/cid/cid_test.go @@ -0,0 +1,193 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package cid_test + +import ( + "encoding/base64" + "testing" + + "github.com/hyperledger/fabric-chaincode-go/v2/pkg/cid" + "github.com/hyperledger/fabric-protos-go-apiv2/msp" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" +) + +const certWithOutAttrs = `-----BEGIN CERTIFICATE----- +MIICXTCCAgSgAwIBAgIUeLy6uQnq8wwyElU/jCKRYz3tJiQwCgYIKoZIzj0EAwIw +eTELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExFjAUBgNVBAcTDVNh +biBGcmFuY2lzY28xGTAXBgNVBAoTEEludGVybmV0IFdpZGdldHMxDDAKBgNVBAsT +A1dXVzEUMBIGA1UEAxMLZXhhbXBsZS5jb20wHhcNMTcwOTA4MDAxNTAwWhcNMTgw +OTA4MDAxNTAwWjBdMQswCQYDVQQGEwJVUzEXMBUGA1UECBMOTm9ydGggQ2Fyb2xp +bmExFDASBgNVBAoTC0h5cGVybGVkZ2VyMQ8wDQYDVQQLEwZGYWJyaWMxDjAMBgNV +BAMTBWFkbWluMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEFq/90YMuH4tWugHa +oyZtt4Mbwgv6CkBSDfYulVO1CVInw1i/k16DocQ/KSDTeTfgJxrX1Ree1tjpaodG +1wWyM6OBhTCBgjAOBgNVHQ8BAf8EBAMCB4AwDAYDVR0TAQH/BAIwADAdBgNVHQ4E +FgQUhKs/VJ9IWJd+wer6sgsgtZmxZNwwHwYDVR0jBBgwFoAUIUd4i/sLTwYWvpVr +TApzcT8zv/kwIgYDVR0RBBswGYIXQW5pbHMtTWFjQm9vay1Qcm8ubG9jYWwwCgYI +KoZIzj0EAwIDRwAwRAIgCoXaCdU8ZiRKkai0QiXJM/GL5fysLnmG2oZ6XOIdwtsC +IEmCsI8Mhrvx1doTbEOm7kmIrhQwUVDBNXCWX1t3kJVN +-----END CERTIFICATE----- +` +const certWithAttrs = `-----BEGIN CERTIFICATE----- +MIIB6TCCAY+gAwIBAgIUHkmY6fRP0ANTvzaBwKCkMZZPUnUwCgYIKoZIzj0EAwIw +GzEZMBcGA1UEAxMQZmFicmljLWNhLXNlcnZlcjAeFw0xNzA5MDgwMzQyMDBaFw0x +ODA5MDgwMzQyMDBaMB4xHDAaBgNVBAMTE015VGVzdFVzZXJXaXRoQXR0cnMwWTAT +BgcqhkjOPQIBBggqhkjOPQMBBwNCAATmB1r3CdWvOOP3opB3DjJnW3CnN8q1ydiR +dzmuA6A2rXKzPIltHvYbbSqISZJubsy8gVL6GYgYXNdu69RzzFF5o4GtMIGqMA4G +A1UdDwEB/wQEAwICBDAMBgNVHRMBAf8EAjAAMB0GA1UdDgQWBBTYKLTAvJJK08OM +VGwIhjMQpo2DrjAfBgNVHSMEGDAWgBTEs/52DeLePPx1+65VhgTwu3/2ATAiBgNV +HREEGzAZghdBbmlscy1NYWNCb29rLVByby5sb2NhbDAmBggqAwQFBgcIAQQaeyJh +dHRycyI6eyJhdHRyMSI6InZhbDEifX0wCgYIKoZIzj0EAwIDSAAwRQIhAPuEqWUp +svTTvBqLR5JeQSctJuz3zaqGRqSs2iW+QB3FAiAIP0mGWKcgSGRMMBvaqaLytBYo +9v3hRt1r8j8vN0pMcg== +-----END CERTIFICATE----- +` + +// #nosec G101 +const idemixCred = `CiAGTPr9iYBj7gqHeU90RgXXklBhgIKbtVtUzMDnhJ9bCRIggb1mWLzNEO1ac+PfDNSpbCC02eY5OqRooSH1mnhlftEaQgoMaWRlbWl4TVNQSUQyEhBvcmcxLmRlcGFydG1lbnQxGiCEOl1qWZ+TRhDD3X6Cl59utnDnW8oio0y2vuJ1xHNMMiIOCgxpZGVtaXhNU1BJRDIq5gYKRAog1MgGWhcd/jgQpbpdO17LktSoelSsQJKfAmFhspaM6+QSICMcxQLs4JRPeSbyWG81KNepmLIi8C1AOyrgytJYMmKIEkQKICCwfD7vfGKqzGFEa7H7cBbR81kImbXcECJnDbj4QQNUEiB64Bi0jRahh18QuZzqnw6sksn8GBCi2sVsrjtLTKsvMRpECiBD9miof2CyCjHOr6s/JAiALRzjdogv0xQHEyqNAfIDABIgx56y7lUllGc0XtYsFIdq7CulDE55Re5xT1wvzRNhITUiIPSaozvr294lNGF3Wy5Yd7wlNW/IZBpBcXda/dQfGci9KiBbO2o2bWD1P4HOMfI//ebo8WrwTNgmPfmlqNBxzKuQMjIg5AwmQGEYnKN/pOVDFMjm/3a9hJDv9R2svI42aVBms0M6IHMSFIZ8j/yZH5nHtCkwpQMCuBFmI6krD2CfTjCiOUfoQiDO7cyRnCt9uEGIhQsBiwnSEXH+G9Il9qvfkUrAiZlbrkogv6dmb1xijfB3gsyVWxgfKlRNRtf78dMwjSf76jEnSrBSIJTkD7lSBwBepMFROxYneTHuG6JcSZpdoeOGqFl0drJWUiC8ndC2y9LsFJLKs2ddFqsFW7kNg+vROXuSLQdglSBffVog/eDzc90wTBEZu2T6LhWEbcP5oZ5TYdE/o+cOUfPgV4RiRAogBkz6/YmAY+4Kh3lPdEYF15JQYYCCm7VbVMzA54SfWwkSIIG9Zli8zRDtWnPj3wzUqWwgtNnmOTqkaKEh9Zp4ZX7RaiCXNeWrQz2UPkuAEZrt++TP/DbmAFF7cBQlYkb81jrn/nKIAQog/gwzULTJbCAoVg9XfCiROs4cU5oSv4Q80iYWtonAnvsSIE6mYFdzisBU21rhxjfYE7kk3Xjih9A1idJp7TSjfmorGiBwIEbnxUKjs3Z3DXUSTj5R78skdY1hWEjpCbSBvtwn/yIgBVTjvNOIwpBC7qZJKX6yn4tMvoCCGpiz4BKBEUqtBJt6ZzBlAjEAoBaHzX1HjvrnPMDXajqcLeHR5//AIIGDDcGQ+4GNqJu9Wawlw6Zs58Nnkpmh29ivAjBJNHeGNvX9sQb9lyzLAtCa5Il4xKNGGpGZ+uhQAjtNpRAZLtv2hgSqJAy0X6HwNXeAAQGKAQA=` + +func TestClient(t *testing.T) { + stub, err := getMockStub() + assert.NoError(t, err, "Failed to get mock submitter") + sinfo, err := cid.New(stub) + assert.NoError(t, err, "Error getting submitter of the transaction") + id, err := cid.GetID(stub) + assert.NoError(t, err, "Error getting ID of the submitter of the transaction") + assert.NotEmpty(t, id, "Transaction submitter ID should not be empty") + t.Logf("The client's ID is: %s", id) + cert, err := cid.GetX509Certificate(stub) + assert.NoError(t, err, "Error getting X509 certificate of the submitter of the transaction") + assert.NotNil(t, cert, "Transaction submitter certificate should not be nil") + mspid, err := cid.GetMSPID(stub) + assert.NoError(t, err, "Error getting MSP ID of the submitter of the transaction") + assert.NotEmpty(t, mspid, "Transaction submitter MSP ID should not be empty") + _, found, err := sinfo.GetAttributeValue("foo") + assert.NoError(t, err, "Error getting Unique ID of the submitter of the transaction") + assert.False(t, found, "Attribute 'foo' should not be found in the submitter cert") + err = cid.AssertAttributeValue(stub, "foo", "") + assert.Error(t, err, "AssertAttributeValue should have returned an error with no attribute") + found, err = cid.HasOUValue(stub, "Fabric") + assert.NoError(t, err, "Error getting X509 cert of the submitter of the transaction") + assert.True(t, found) + found, err = cid.HasOUValue(stub, "foo") + assert.NoError(t, err, "HasOUValue") + assert.False(t, found, "OU 'foo' should not be found in the submitter cert") + + stub, err = getMockStubWithAttrs() + assert.NoError(t, err, "Failed to get mock submitter") + sinfo, err = cid.New(stub) + assert.NoError(t, err, "Failed to new client") + attrVal, found, err := sinfo.GetAttributeValue("attr1") + assert.NoError(t, err, "Error getting Unique ID of the submitter of the transaction") + assert.True(t, found, "Attribute 'attr1' should be found in the submitter cert") + assert.Equal(t, attrVal, "val1", "Value of attribute 'attr1' should be 'val1'") + attrVal, found, err = cid.GetAttributeValue(stub, "attr1") + assert.NoError(t, err, "Error getting Unique ID of the submitter of the transaction") + assert.True(t, found, "Attribute 'attr1' should be found in the submitter cert") + assert.Equal(t, attrVal, "val1", "Value of attribute 'attr1' should be 'val1'") + err = cid.AssertAttributeValue(stub, "attr1", "val1") + assert.NoError(t, err, "Error in AssertAttributeValue") + err = cid.AssertAttributeValue(stub, "attr1", "val2") + assert.Error(t, err, "Assert should have failed; value was val1, not val2") + found, err = cid.HasOUValue(stub, "foo") + assert.NoError(t, err, "Error getting X509 cert of the submitter of the transaction") + assert.False(t, found, "HasOUValue") + + // Error case1 + stub, err = getMockStubWithNilCreator() + assert.NoError(t, err, "Failed to get mock submitter") + _, err = cid.New(stub) + assert.Error(t, err, "NewSubmitterInfo should have returned an error when submitter with nil creator is passed") + + // Error case2 + stub, err = getMockStubWithFakeCreator() + assert.NoError(t, err, "Failed to get mock submitter") + _, err = cid.New(stub) + assert.Error(t, err, "NewSubmitterInfo should have returned an error when submitter with fake creator is passed") +} + +func TestIdemix(t *testing.T) { + stub, err := getIdemixMockStubWithAttrs() + assert.NoError(t, err, "Failed to get mock idemix stub") + sinfo, err := cid.New(stub) + assert.NoError(t, err, "Failed to new client") + cert, err := sinfo.GetX509Certificate() + assert.Nil(t, cert, "Idemix can't get x509 type of cert") + assert.NoError(t, err, "Err for this func is nil") + id, err := cid.GetID(stub) + assert.Error(t, err, "Cannot determine identity") + assert.Equal(t, id, "", "Id should be empty when Idemix") + attrVal, found, err := sinfo.GetAttributeValue("ou") + assert.NoError(t, err, "Error getting 'ou' of the submitter of the transaction") + assert.True(t, found, "Attribute 'ou' should be found in the submitter cert") + assert.Equal(t, attrVal, "org1.department1", "Value of attribute 'attr1' should be 'val1'") + attrVal, found, err = sinfo.GetAttributeValue("role") + assert.NoError(t, err, "Error getting 'role' of the submitter of the transaction") + assert.True(t, found, "Attribute 'role' should be found in the submitter cert") + assert.Equal(t, attrVal, "member", "Value of attribute 'attr1' should be 'val1'") + _, found, err = sinfo.GetAttributeValue("id") + assert.NoError(t, err, "GetAttributeValue") + assert.False(t, found, "Attribute 'id' should not be found in the submitter cert") +} + +func getMockStub() (cid.ChaincodeStubInterface, error) { + stub := &mockStub{} + sid := &msp.SerializedIdentity{Mspid: "SampleOrg", + IdBytes: []byte(certWithOutAttrs)} + b, err := proto.Marshal(sid) + if err != nil { + return nil, err + } + stub.creator = b + return stub, nil +} + +func getMockStubWithAttrs() (cid.ChaincodeStubInterface, error) { + stub := &mockStub{} + sid := &msp.SerializedIdentity{Mspid: "SampleOrg", + IdBytes: []byte(certWithAttrs)} + b, err := proto.Marshal(sid) + if err != nil { + return nil, err + } + stub.creator = b + return stub, nil +} + +func getIdemixMockStubWithAttrs() (cid.ChaincodeStubInterface, error) { + stub := &mockStub{} + idBytes, err := base64.StdEncoding.DecodeString(idemixCred) + if err != nil { + return nil, err + } + sid := &msp.SerializedIdentity{Mspid: "idemixOrg", + IdBytes: idBytes, + } + b, err := proto.Marshal(sid) + if err != nil { + return nil, err + } + stub.creator = b + return stub, nil +} + +func getMockStubWithNilCreator() (cid.ChaincodeStubInterface, error) { + c := &mockStub{} + c.creator = nil + return c, nil +} + +func getMockStubWithFakeCreator() (cid.ChaincodeStubInterface, error) { + c := &mockStub{} + c.creator = []byte("foo") + return c, nil +} + +type mockStub struct { + creator []byte +} + +func (s *mockStub) GetCreator() ([]byte, error) { + return s.creator, nil +} diff --git a/v2/pkg/cid/interfaces.go b/v2/pkg/cid/interfaces.go new file mode 100644 index 0000000..140ffe9 --- /dev/null +++ b/v2/pkg/cid/interfaces.go @@ -0,0 +1,42 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package cid + +import "crypto/x509" + +// ChaincodeStubInterface is used by deployable chaincode apps to get identity +// of the agent (or user) submitting the transaction. +type ChaincodeStubInterface interface { + // GetCreator returns `SignatureHeader.Creator` (e.g. an identity) + // of the `SignedProposal`. This is the identity of the agent (or user) + // submitting the transaction. + GetCreator() ([]byte, error) +} + +// ClientIdentity represents information about the identity that submitted the +// transaction +type ClientIdentity interface { + + // GetID returns the ID associated with the invoking identity. This ID + // is guaranteed to be unique within the MSP. + GetID() (string, error) + + // Return the MSP ID of the client + GetMSPID() (string, error) + + // GetAttributeValue returns the value of the client's attribute named `attrName`. + // If the client possesses the attribute, `found` is true and `value` equals the + // value of the attribute. + // If the client does not possess the attribute, `found` is false and `value` + // equals "". + GetAttributeValue(attrName string) (value string, found bool, err error) + + // AssertAttributeValue verifies that the client has the attribute named `attrName` + // with a value of `attrValue`; otherwise, an error is returned. + AssertAttributeValue(attrName, attrValue string) error + + // GetX509Certificate returns the X509 certificate associated with the client, + // or nil if it was not identified by an X509 certificate. + GetX509Certificate() (*x509.Certificate, error) +} diff --git a/v2/pkg/statebased/interfaces.go b/v2/pkg/statebased/interfaces.go new file mode 100644 index 0000000..7f410af --- /dev/null +++ b/v2/pkg/statebased/interfaces.go @@ -0,0 +1,51 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package statebased + +import "fmt" + +// RoleType of an endorsement policy's identity +type RoleType string + +const ( + // RoleTypeMember identifies an org's member identity + RoleTypeMember = RoleType("MEMBER") + // RoleTypePeer identifies an org's peer identity + RoleTypePeer = RoleType("PEER") +) + +// RoleTypeDoesNotExistError is returned by function AddOrgs of +// KeyEndorsementPolicy if a role type that does not match one +// specified above is passed as an argument. +type RoleTypeDoesNotExistError struct { + RoleType RoleType +} + +func (r *RoleTypeDoesNotExistError) Error() string { + return fmt.Sprintf("role type %s does not exist", r.RoleType) +} + +// KeyEndorsementPolicy provides a set of convenience methods to create and +// modify a state-based endorsement policy. Endorsement policies created by +// this convenience layer will always be a logical AND of ".peer" +// principals for one or more ORGs specified by the caller. +type KeyEndorsementPolicy interface { + // Policy returns the endorsement policy as bytes + Policy() ([]byte, error) + + // AddOrgs adds the specified orgs to the list of orgs that are required + // to endorse. All orgs MSP role types will be set to the role that is + // specified in the first parameter. Among other aspects the desired role + // depends on the channel's configuration: if it supports node OUs, it is + // likely going to be the PEER role, while the MEMBER role is the suited + // one if it does not. + AddOrgs(roleType RoleType, organizations ...string) error + + // DelOrgs deletes the specified channel orgs from the existing key-level endorsement + // policy for this KVS key. + DelOrgs(organizations ...string) + + // ListOrgs returns an array of channel orgs that are required to endorse changes. + ListOrgs() []string +} diff --git a/v2/pkg/statebased/statebasedimpl.go b/v2/pkg/statebased/statebasedimpl.go new file mode 100644 index 0000000..ae21152 --- /dev/null +++ b/v2/pkg/statebased/statebasedimpl.go @@ -0,0 +1,143 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package statebased + +import ( + "fmt" + "sort" + + "github.com/hyperledger/fabric-protos-go-apiv2/common" + "github.com/hyperledger/fabric-protos-go-apiv2/msp" + "google.golang.org/protobuf/proto" +) + +// stateEP implements the KeyEndorsementPolicy +type stateEP struct { + orgs map[string]msp.MSPRole_MSPRoleType +} + +// NewStateEP constructs a state-based endorsement policy from a given +// serialized EP byte array. If the byte array is empty, a new EP is created. +func NewStateEP(policy []byte) (KeyEndorsementPolicy, error) { + s := &stateEP{orgs: make(map[string]msp.MSPRole_MSPRoleType)} + if policy != nil { + spe := &common.SignaturePolicyEnvelope{} + if err := proto.Unmarshal(policy, spe); err != nil { + return nil, fmt.Errorf("Error unmarshaling to SignaturePolicy: %s", err) + } + + err := s.setMSPIDsFromSP(spe) + if err != nil { + return nil, err + } + } + return s, nil +} + +// Policy returns the endorsement policy as bytes. +func (s *stateEP) Policy() ([]byte, error) { + spe, err := s.policyFromMSPIDs() + if err != nil { + return nil, err + } + spBytes, err := proto.Marshal(spe) + if err != nil { + return nil, err + } + return spBytes, nil +} + +// AddOrgs adds the specified channel orgs to the existing key-level EP. +func (s *stateEP) AddOrgs(role RoleType, neworgs ...string) error { + var mspRole msp.MSPRole_MSPRoleType + switch role { + case RoleTypeMember: + mspRole = msp.MSPRole_MEMBER + case RoleTypePeer: + mspRole = msp.MSPRole_PEER + default: + return &RoleTypeDoesNotExistError{RoleType: role} + } + + // add new orgs + for _, addorg := range neworgs { + s.orgs[addorg] = mspRole + } + + return nil +} + +// DelOrgs delete the specified channel orgs from the existing key-level EP. +func (s *stateEP) DelOrgs(delorgs ...string) { + for _, delorg := range delorgs { + delete(s.orgs, delorg) + } +} + +// ListOrgs returns an array of channel orgs that are required to endorse changes. +func (s *stateEP) ListOrgs() []string { + orgNames := make([]string, 0, len(s.orgs)) + for mspid := range s.orgs { + orgNames = append(orgNames, mspid) + } + return orgNames +} + +func (s *stateEP) setMSPIDsFromSP(sp *common.SignaturePolicyEnvelope) error { + // iterate over the identities in this envelope + for _, identity := range sp.Identities { + // this imlementation only supports the ROLE type + if identity.PrincipalClassification == msp.MSPPrincipal_ROLE { + msprole := &msp.MSPRole{} + err := proto.Unmarshal(identity.Principal, msprole) + if err != nil { + return fmt.Errorf("error unmarshaling msp principal: %s", err) + } + s.orgs[msprole.GetMspIdentifier()] = msprole.GetRole() + } + } + return nil +} + +func (s *stateEP) policyFromMSPIDs() (*common.SignaturePolicyEnvelope, error) { + mspids := s.ListOrgs() + sort.Strings(mspids) + principals := make([]*msp.MSPPrincipal, len(mspids)) + sigspolicy := make([]*common.SignaturePolicy, len(mspids)) + for i, id := range mspids { + principal, err := proto.Marshal( + &msp.MSPRole{ + Role: s.orgs[id], + MspIdentifier: id, + }, + ) + if err != nil { + return nil, err + } + principals[i] = &msp.MSPPrincipal{ + PrincipalClassification: msp.MSPPrincipal_ROLE, + Principal: principal, + } + sigspolicy[i] = &common.SignaturePolicy{ + Type: &common.SignaturePolicy_SignedBy{ + SignedBy: int32(i), + }, + } + } + + // create the policy: it requires exactly 1 signature from all of the principals + p := &common.SignaturePolicyEnvelope{ + Version: 0, + Rule: &common.SignaturePolicy{ + Type: &common.SignaturePolicy_NOutOf_{ + NOutOf: &common.SignaturePolicy_NOutOf{ + N: int32(len(mspids)), + Rules: sigspolicy, + }, + }, + }, + Identities: principals, + } + return p, nil +} diff --git a/v2/pkg/statebased/statebasedimpl_test.go b/v2/pkg/statebased/statebasedimpl_test.go new file mode 100644 index 0000000..fd0e820 --- /dev/null +++ b/v2/pkg/statebased/statebasedimpl_test.go @@ -0,0 +1,113 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package statebased_test + +import ( + "testing" + + "github.com/hyperledger/fabric-chaincode-go/v2/pkg/statebased" + "github.com/hyperledger/fabric-protos-go-apiv2/common" + "github.com/hyperledger/fabric-protos-go-apiv2/msp" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" +) + +func TestAddOrg(t *testing.T) { + // add an org + ep, err := statebased.NewStateEP(nil) + assert.NoError(t, err) + err = ep.AddOrgs(statebased.RoleTypePeer, "Org1") + assert.NoError(t, err) + + // bad role type + err = ep.AddOrgs("unknown", "Org1") + assert.Equal(t, &statebased.RoleTypeDoesNotExistError{RoleType: statebased.RoleType("unknown")}, err) + assert.EqualError(t, err, "role type unknown does not exist") + + epBytes, err := ep.Policy() + assert.NoError(t, err) + expectedEP := signedByMspPeer("Org1", t) + expectedEPBytes, err := proto.Marshal(expectedEP) + assert.NoError(t, err) + assert.Equal(t, expectedEPBytes, epBytes) +} + +func TestListOrgs(t *testing.T) { + expectedEP := signedByMspPeer("Org1", t) + expectedEPBytes, err := proto.Marshal(expectedEP) + assert.NoError(t, err) + + // retrieve the orgs + ep, err := statebased.NewStateEP(expectedEPBytes) + assert.NoError(t, err, "NewStateEP") + orgs := ep.ListOrgs() + assert.Equal(t, []string{"Org1"}, orgs) +} + +func TestDelAddOrg(t *testing.T) { + expectedEP := signedByMspPeer("Org1", t) + expectedEPBytes, err := proto.Marshal(expectedEP) + assert.NoError(t, err) + ep, err := statebased.NewStateEP(expectedEPBytes) + assert.NoError(t, err) + + // retrieve the orgs + orgs := ep.ListOrgs() + assert.Equal(t, []string{"Org1"}, orgs) + + // mod the endorsement policy + err = ep.AddOrgs(statebased.RoleTypePeer, "Org2") + assert.NoError(t, err) + ep.DelOrgs("Org1") + + // check whether what is stored is correct + epBytes, err := ep.Policy() + assert.NoError(t, err) + expectedEP = signedByMspPeer("Org2", t) + expectedEPBytes, err = proto.Marshal(expectedEP) + assert.NoError(t, err) + assert.Equal(t, expectedEPBytes, epBytes) +} + +// SignedByMspPeer creates a SignaturePolicyEnvelope +// requiring 1 signature from any peer of the specified MSP +func signedByMspPeer(mspId string, t *testing.T) *common.SignaturePolicyEnvelope { + // specify the principal: it's a member of the msp we just found + principal, err := proto.Marshal( + &msp.MSPRole{ + Role: msp.MSPRole_PEER, + MspIdentifier: mspId, + }, + ) + if err != nil { + t.Fatalf("failed to marshal principal: %s", err) + } + + // create the policy: it requires exactly 1 signature from the first (and only) principal + p := &common.SignaturePolicyEnvelope{ + Version: 0, + Rule: &common.SignaturePolicy{ + Type: &common.SignaturePolicy_NOutOf_{ + NOutOf: &common.SignaturePolicy_NOutOf{ + N: 1, + Rules: []*common.SignaturePolicy{ + { + Type: &common.SignaturePolicy_SignedBy{ + SignedBy: 0, + }, + }, + }, + }, + }, + }, + Identities: []*msp.MSPPrincipal{ + { + PrincipalClassification: msp.MSPPrincipal_ROLE, + Principal: principal, + }, + }, + } + + return p +} diff --git a/v2/shim/chaincodeserver.go b/v2/shim/chaincodeserver.go new file mode 100644 index 0000000..455ec4a --- /dev/null +++ b/v2/shim/chaincodeserver.go @@ -0,0 +1,79 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package shim + +import ( + "crypto/tls" + "errors" + + "github.com/hyperledger/fabric-chaincode-go/v2/shim/internal" + "github.com/hyperledger/fabric-protos-go-apiv2/peer" + + "google.golang.org/grpc/keepalive" +) + +// TLSProperties passed to ChaincodeServer +type TLSProperties struct { + //Disabled forces default to be TLS enabled + Disabled bool + Key []byte + Cert []byte + // ClientCACerts set if connecting peer should be verified + ClientCACerts []byte +} + +// ChaincodeServer encapsulates basic properties needed for a chaincode server +type ChaincodeServer struct { + // CCID should match chaincode's package name on peer + CCID string + // Addesss is the listen address of the chaincode server + Address string + // CC is the chaincode that handles Init and Invoke + CC Chaincode + // TLSProps is the TLS properties passed to chaincode server + TLSProps TLSProperties + // KaOpts keepalive options, sensible defaults provided if nil + KaOpts *keepalive.ServerParameters +} + +// Connect the bidi stream entry point called by chaincode to register with the Peer. +func (cs *ChaincodeServer) Connect(stream peer.Chaincode_ConnectServer) error { + return chatWithPeer(cs.CCID, stream, cs.CC) +} + +// Start the server +func (cs *ChaincodeServer) Start() error { + if cs.CCID == "" { + return errors.New("ccid must be specified") + } + + if cs.Address == "" { + return errors.New("address must be specified") + } + + if cs.CC == nil { + return errors.New("chaincode must be specified") + } + + var tlsCfg *tls.Config + var err error + if !cs.TLSProps.Disabled { + tlsCfg, err = internal.LoadTLSConfig(true, cs.TLSProps.Key, cs.TLSProps.Cert, cs.TLSProps.ClientCACerts) + if err != nil { + return err + } + } + + // create listener and grpc server + server, err := internal.NewServer(cs.Address, tlsCfg, cs.KaOpts) + if err != nil { + return err + } + + // register the server with grpc ... + peer.RegisterChaincodeServer(server.Server, cs) + + // ... and start + return server.Start() +} diff --git a/v2/shim/handler.go b/v2/shim/handler.go new file mode 100644 index 0000000..5b4e217 --- /dev/null +++ b/v2/shim/handler.go @@ -0,0 +1,708 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package shim + +import ( + "errors" + "fmt" + "sync" + + "github.com/hyperledger/fabric-protos-go-apiv2/peer" + "google.golang.org/protobuf/proto" +) + +type state string + +const ( + created state = "created" // start state + established state = "established" // connection established + ready state = "ready" // ready for requests +) + +// PeerChaincodeStream is the common stream interface for Peer - chaincode communication. +// Both chaincode-as-server and chaincode-as-client patterns need to support this +type PeerChaincodeStream interface { + Send(*peer.ChaincodeMessage) error + Recv() (*peer.ChaincodeMessage, error) +} + +// ClientStream supports the (original) chaincode-as-client interaction pattern +type ClientStream interface { + PeerChaincodeStream + CloseSend() error +} + +// Handler handler implementation for shim side of chaincode. +type Handler struct { + // serialLock is used to prevent concurrent calls to Send on the + // PeerChaincodeStream. This is required by gRPC. + serialLock sync.Mutex + // chatStream is the client used to access the chaincode support server on + // the peer. + chatStream PeerChaincodeStream + + // cc is the chaincode associated with this handler. + cc Chaincode + // state holds the current state of this handler. + state state + + // Multiple queries (and one transaction) with different txids can be executing in parallel for this chaincode + // responseChannels is the channel on which responses are communicated by the shim to the chaincodeStub. + // need lock to protect chaincode from attempting + // concurrent requests to the peer + responseChannelsMutex sync.Mutex + responseChannels map[string]chan *peer.ChaincodeMessage +} + +func shorttxid(txid string) string { + if len(txid) < 8 { + return txid + } + return txid[0:8] +} + +// serialSend serializes calls to Send on the gRPC client. +func (h *Handler) serialSend(msg *peer.ChaincodeMessage) error { + h.serialLock.Lock() + defer h.serialLock.Unlock() + + return h.chatStream.Send(msg) +} + +// serialSendAsync sends the provided message asynchronously in a separate +// goroutine. The result of the send is communicated back to the caller via +// errc. +func (h *Handler) serialSendAsync(msg *peer.ChaincodeMessage, errc chan<- error) { + go func() { + errc <- h.serialSend(msg) + }() +} + +// transactionContextID builds a transaction context identifier by +// concatenating a channel ID and a transaction ID. +func transactionContextID(chainID, txid string) string { + return chainID + txid +} + +func (h *Handler) createResponseChannel(channelID, txid string) (<-chan *peer.ChaincodeMessage, error) { + h.responseChannelsMutex.Lock() + defer h.responseChannelsMutex.Unlock() + + if h.responseChannels == nil { + return nil, fmt.Errorf("[%s] cannot create response channel", shorttxid(txid)) + } + + txCtxID := transactionContextID(channelID, txid) + if h.responseChannels[txCtxID] != nil { + return nil, fmt.Errorf("[%s] channel exists", shorttxid(txCtxID)) + } + + responseChan := make(chan *peer.ChaincodeMessage) + h.responseChannels[txCtxID] = responseChan + return responseChan, nil +} + +func (h *Handler) deleteResponseChannel(channelID, txid string) { + h.responseChannelsMutex.Lock() + defer h.responseChannelsMutex.Unlock() + if h.responseChannels != nil { + txCtxID := transactionContextID(channelID, txid) + delete(h.responseChannels, txCtxID) + } +} + +func (h *Handler) handleResponse(msg *peer.ChaincodeMessage) error { + h.responseChannelsMutex.Lock() + defer h.responseChannelsMutex.Unlock() + + if h.responseChannels == nil { + return fmt.Errorf("[%s] Cannot send message response channel", shorttxid(msg.Txid)) + } + + txCtxID := transactionContextID(msg.ChannelId, msg.Txid) + responseCh := h.responseChannels[txCtxID] + if responseCh == nil { + return fmt.Errorf("[%s] responseChannel does not exist", shorttxid(msg.Txid)) + } + responseCh <- msg + return nil +} + +// sendReceive sends msg to the peer and waits for the response to arrive on +// the provided responseChan. On success, the response message will be +// returned. An error will be returned msg was not successfully sent to the +// peer. +func (h *Handler) sendReceive(msg *peer.ChaincodeMessage, responseChan <-chan *peer.ChaincodeMessage) (*peer.ChaincodeMessage, error) { + err := h.serialSend(msg) + if err != nil { + return &peer.ChaincodeMessage{}, err + } + + outmsg := <-responseChan + return outmsg, nil +} + +// NewChaincodeHandler returns a new instance of the shim side handler. +func newChaincodeHandler(peerChatStream PeerChaincodeStream, chaincode Chaincode) *Handler { + return &Handler{ + chatStream: peerChatStream, + cc: chaincode, + responseChannels: map[string]chan *peer.ChaincodeMessage{}, + state: created, + } +} + +type stubHandlerFunc func(*peer.ChaincodeMessage) (*peer.ChaincodeMessage, error) + +func (h *Handler) handleStubInteraction(handler stubHandlerFunc, msg *peer.ChaincodeMessage, errc chan<- error) { + resp, err := handler(msg) + if err != nil { + resp = &peer.ChaincodeMessage{Type: peer.ChaincodeMessage_ERROR, Payload: []byte(err.Error()), Txid: msg.Txid, ChannelId: msg.ChannelId} + } + h.serialSendAsync(resp, errc) +} + +// handleInit calls the Init function of the associated chaincode. +func (h *Handler) handleInit(msg *peer.ChaincodeMessage) (*peer.ChaincodeMessage, error) { + // Get the function and args from Payload + input := &peer.ChaincodeInput{} + err := proto.Unmarshal(msg.Payload, input) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal input: %s", err) + } + + // Create the ChaincodeStub which the chaincode can use to callback + stub, err := newChaincodeStub(h, msg.ChannelId, msg.Txid, input, msg.Proposal) + if err != nil { + return nil, fmt.Errorf("failed to create new ChaincodeStub: %s", err) + } + + res := h.cc.Init(stub) + if res.Status >= ERROR { + return &peer.ChaincodeMessage{Type: peer.ChaincodeMessage_ERROR, Payload: []byte(res.Message), Txid: msg.Txid, ChaincodeEvent: stub.chaincodeEvent, ChannelId: msg.ChannelId}, nil + } + + resBytes, err := proto.Marshal(res) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %s", err) + } + + return &peer.ChaincodeMessage{Type: peer.ChaincodeMessage_COMPLETED, Payload: resBytes, Txid: msg.Txid, ChaincodeEvent: stub.chaincodeEvent, ChannelId: stub.ChannelID}, nil +} + +// handleTransaction calls Invoke on the associated chaincode. +func (h *Handler) handleTransaction(msg *peer.ChaincodeMessage) (*peer.ChaincodeMessage, error) { + // Get the function and args from Payload + input := &peer.ChaincodeInput{} + err := proto.Unmarshal(msg.Payload, input) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal input: %s", err) + } + + // Create the ChaincodeStub which the chaincode can use to callback + stub, err := newChaincodeStub(h, msg.ChannelId, msg.Txid, input, msg.Proposal) + if err != nil { + return nil, fmt.Errorf("failed to create new ChaincodeStub: %s", err) + } + + res := h.cc.Invoke(stub) + + // Endorser will handle error contained in Response. + resBytes, err := proto.Marshal(res) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %s", err) + } + + return &peer.ChaincodeMessage{Type: peer.ChaincodeMessage_COMPLETED, Payload: resBytes, Txid: msg.Txid, ChaincodeEvent: stub.chaincodeEvent, ChannelId: stub.ChannelID}, nil +} + +// callPeerWithChaincodeMsg sends a chaincode message to the peer for the given +// txid and channel and receives the response. +func (h *Handler) callPeerWithChaincodeMsg(msg *peer.ChaincodeMessage, channelID, txid string) (*peer.ChaincodeMessage, error) { + // Create the channel on which to communicate the response from the peer + respChan, err := h.createResponseChannel(channelID, txid) + if err != nil { + return &peer.ChaincodeMessage{}, err + } + defer h.deleteResponseChannel(channelID, txid) + + return h.sendReceive(msg, respChan) +} + +// handleGetState communicates with the peer to fetch the requested state information from the ledger. +func (h *Handler) handleGetState(collection string, key string, channelID string, txid string) ([]byte, error) { + // Construct payload for GET_STATE + payloadBytes := marshalOrPanic(&peer.GetState{Collection: collection, Key: key}) + + msg := &peer.ChaincodeMessage{Type: peer.ChaincodeMessage_GET_STATE, Payload: payloadBytes, Txid: txid, ChannelId: channelID} + responseMsg, err := h.callPeerWithChaincodeMsg(msg, channelID, txid) + if err != nil { + return nil, fmt.Errorf("[%s] error sending %s: %s", shorttxid(txid), peer.ChaincodeMessage_GET_STATE, err) + } + + if responseMsg.Type == peer.ChaincodeMessage_RESPONSE { + // Success response + return responseMsg.Payload, nil + } + if responseMsg.Type == peer.ChaincodeMessage_ERROR { + // Error response + return nil, fmt.Errorf("%s", responseMsg.Payload[:]) + } + + // Incorrect chaincode message received + return nil, fmt.Errorf("[%s] incorrect chaincode message %s received. Expecting %s or %s", shorttxid(responseMsg.Txid), responseMsg.Type, peer.ChaincodeMessage_RESPONSE, peer.ChaincodeMessage_ERROR) +} + +func (h *Handler) handleGetPrivateDataHash(collection string, key string, channelID string, txid string) ([]byte, error) { + // Construct payload for GET_PRIVATE_DATA_HASH + payloadBytes := marshalOrPanic(&peer.GetState{Collection: collection, Key: key}) + + msg := &peer.ChaincodeMessage{Type: peer.ChaincodeMessage_GET_PRIVATE_DATA_HASH, Payload: payloadBytes, Txid: txid, ChannelId: channelID} + responseMsg, err := h.callPeerWithChaincodeMsg(msg, channelID, txid) + if err != nil { + return nil, fmt.Errorf("[%s] error sending %s: %s", shorttxid(txid), peer.ChaincodeMessage_GET_PRIVATE_DATA_HASH, err) + } + + if responseMsg.Type == peer.ChaincodeMessage_RESPONSE { + // Success response + return responseMsg.Payload, nil + } + if responseMsg.Type == peer.ChaincodeMessage_ERROR { + // Error response + return nil, fmt.Errorf("%s", responseMsg.Payload[:]) + } + + // Incorrect chaincode message received + return nil, fmt.Errorf("[%s] incorrect chaincode message %s received. Expecting %s or %s", shorttxid(responseMsg.Txid), responseMsg.Type, peer.ChaincodeMessage_RESPONSE, peer.ChaincodeMessage_ERROR) +} + +func (h *Handler) handleGetStateMetadata(collection string, key string, channelID string, txID string) (map[string][]byte, error) { + // Construct payload for GET_STATE_METADATA + payloadBytes := marshalOrPanic(&peer.GetStateMetadata{Collection: collection, Key: key}) + + msg := &peer.ChaincodeMessage{Type: peer.ChaincodeMessage_GET_STATE_METADATA, Payload: payloadBytes, Txid: txID, ChannelId: channelID} + responseMsg, err := h.callPeerWithChaincodeMsg(msg, channelID, txID) + if err != nil { + return nil, fmt.Errorf("[%s] error sending %s: %s", shorttxid(txID), peer.ChaincodeMessage_GET_STATE_METADATA, err) + } + + if responseMsg.Type == peer.ChaincodeMessage_RESPONSE { + // Success response + var mdResult peer.StateMetadataResult + err := proto.Unmarshal(responseMsg.Payload, &mdResult) + if err != nil { + return nil, errors.New("could not unmarshal metadata response") + } + metadata := make(map[string][]byte) + for _, md := range mdResult.Entries { + metadata[md.Metakey] = md.Value + } + + return metadata, nil + } + if responseMsg.Type == peer.ChaincodeMessage_ERROR { + // Error response + return nil, fmt.Errorf("%s", responseMsg.Payload[:]) + } + + // Incorrect chaincode message received + return nil, fmt.Errorf("[%s] incorrect chaincode message %s received. Expecting %s or %s", shorttxid(responseMsg.Txid), responseMsg.Type, peer.ChaincodeMessage_RESPONSE, peer.ChaincodeMessage_ERROR) +} + +// handlePutState communicates with the peer to put state information into the ledger. +func (h *Handler) handlePutState(collection string, key string, value []byte, channelID string, txid string) error { + // Construct payload for PUT_STATE + payloadBytes := marshalOrPanic(&peer.PutState{Collection: collection, Key: key, Value: value}) + + msg := &peer.ChaincodeMessage{Type: peer.ChaincodeMessage_PUT_STATE, Payload: payloadBytes, Txid: txid, ChannelId: channelID} + + // Execute the request and get response + responseMsg, err := h.callPeerWithChaincodeMsg(msg, channelID, txid) + if err != nil { + return fmt.Errorf("[%s] error sending %s: %s", msg.Txid, peer.ChaincodeMessage_PUT_STATE, err) + } + + if responseMsg.Type == peer.ChaincodeMessage_RESPONSE { + // Success response + return nil + } + + if responseMsg.Type == peer.ChaincodeMessage_ERROR { + // Error response + return fmt.Errorf("%s", responseMsg.Payload[:]) + } + + // Incorrect chaincode message received + return fmt.Errorf("[%s] incorrect chaincode message %s received. Expecting %s or %s", shorttxid(responseMsg.Txid), responseMsg.Type, peer.ChaincodeMessage_RESPONSE, peer.ChaincodeMessage_ERROR) +} + +func (h *Handler) handlePutStateMetadataEntry(collection string, key string, metakey string, metadata []byte, channelID string, txID string) error { + // Construct payload for PUT_STATE_METADATA + md := &peer.StateMetadata{Metakey: metakey, Value: metadata} + payloadBytes := marshalOrPanic(&peer.PutStateMetadata{Collection: collection, Key: key, Metadata: md}) + + msg := &peer.ChaincodeMessage{Type: peer.ChaincodeMessage_PUT_STATE_METADATA, Payload: payloadBytes, Txid: txID, ChannelId: channelID} + // Execute the request and get response + responseMsg, err := h.callPeerWithChaincodeMsg(msg, channelID, txID) + if err != nil { + return fmt.Errorf("[%s] error sending %s: %s", msg.Txid, peer.ChaincodeMessage_PUT_STATE_METADATA, err) + } + + if responseMsg.Type == peer.ChaincodeMessage_RESPONSE { + // Success response + return nil + } + + if responseMsg.Type == peer.ChaincodeMessage_ERROR { + // Error response + return fmt.Errorf("%s", responseMsg.Payload[:]) + } + + // Incorrect chaincode message received + return fmt.Errorf("[%s]incorrect chaincode message %s received. Expecting %s or %s", shorttxid(responseMsg.Txid), responseMsg.Type, peer.ChaincodeMessage_RESPONSE, peer.ChaincodeMessage_ERROR) +} + +// handleDelState communicates with the peer to delete a key from the state in the ledger. +func (h *Handler) handleDelState(collection string, key string, channelID string, txid string) error { + payloadBytes := marshalOrPanic(&peer.DelState{Collection: collection, Key: key}) + msg := &peer.ChaincodeMessage{Type: peer.ChaincodeMessage_DEL_STATE, Payload: payloadBytes, Txid: txid, ChannelId: channelID} + // Execute the request and get response + responseMsg, err := h.callPeerWithChaincodeMsg(msg, channelID, txid) + if err != nil { + return fmt.Errorf("[%s] error sending %s", shorttxid(msg.Txid), peer.ChaincodeMessage_DEL_STATE) + } + + if responseMsg.Type == peer.ChaincodeMessage_RESPONSE { + // Success response + return nil + } + if responseMsg.Type == peer.ChaincodeMessage_ERROR { + // Error response + return fmt.Errorf("%s", responseMsg.Payload[:]) + } + + // Incorrect chaincode message received + return fmt.Errorf("[%s] incorrect chaincode message %s received. Expecting %s or %s", shorttxid(responseMsg.Txid), responseMsg.Type, peer.ChaincodeMessage_RESPONSE, peer.ChaincodeMessage_ERROR) +} + +// handlerPurgeState communicates with the peer to purge a state from private data +func (h *Handler) handlePurgeState(collection string, key string, channelID string, txid string) error { + payloadBytes := marshalOrPanic(&peer.DelState{Collection: collection, Key: key}) + msg := &peer.ChaincodeMessage{Type: peer.ChaincodeMessage_PURGE_PRIVATE_DATA, Payload: payloadBytes, Txid: txid, ChannelId: channelID} + // Execute the request and get response + responseMsg, err := h.callPeerWithChaincodeMsg(msg, channelID, txid) + if err != nil { + return fmt.Errorf("[%s] error sending %s", shorttxid(msg.Txid), peer.ChaincodeMessage_DEL_STATE) + } + + if responseMsg.Type == peer.ChaincodeMessage_RESPONSE { + // Success response + return nil + } + if responseMsg.Type == peer.ChaincodeMessage_ERROR { + // Error response + return fmt.Errorf("%s", responseMsg.Payload[:]) + } + + // Incorrect chaincode message received + return fmt.Errorf("[%s] incorrect chaincode message %s received. Expecting %s or %s", shorttxid(responseMsg.Txid), responseMsg.Type, peer.ChaincodeMessage_RESPONSE, peer.ChaincodeMessage_ERROR) +} + +func (h *Handler) handleGetStateByRange(collection, startKey, endKey string, metadata []byte, + channelID string, txid string) (*peer.QueryResponse, error) { + // Send GET_STATE_BY_RANGE message to peer chaincode support + payloadBytes := marshalOrPanic(&peer.GetStateByRange{Collection: collection, StartKey: startKey, EndKey: endKey, Metadata: metadata}) + msg := &peer.ChaincodeMessage{Type: peer.ChaincodeMessage_GET_STATE_BY_RANGE, Payload: payloadBytes, Txid: txid, ChannelId: channelID} + responseMsg, err := h.callPeerWithChaincodeMsg(msg, channelID, txid) + if err != nil { + return nil, fmt.Errorf("[%s] error sending %s", shorttxid(msg.Txid), peer.ChaincodeMessage_GET_STATE_BY_RANGE) + } + + if responseMsg.Type == peer.ChaincodeMessage_RESPONSE { + // Success response + rangeQueryResponse := &peer.QueryResponse{} + err = proto.Unmarshal(responseMsg.Payload, rangeQueryResponse) + if err != nil { + return nil, fmt.Errorf("[%s] GetStateByRangeResponse unmarshall error", shorttxid(responseMsg.Txid)) + } + + return rangeQueryResponse, nil + } + if responseMsg.Type == peer.ChaincodeMessage_ERROR { + // Error response + return nil, fmt.Errorf("%s", responseMsg.Payload[:]) + } + + // Incorrect chaincode message received + return nil, fmt.Errorf("incorrect chaincode message %s received. Expecting %s or %s", responseMsg.Type, peer.ChaincodeMessage_RESPONSE, peer.ChaincodeMessage_ERROR) +} + +func (h *Handler) handleQueryStateNext(id, channelID, txid string) (*peer.QueryResponse, error) { + // Create the channel on which to communicate the response from validating peer + respChan, err := h.createResponseChannel(channelID, txid) + if err != nil { + return nil, err + } + defer h.deleteResponseChannel(channelID, txid) + + // Send QUERY_STATE_NEXT message to peer chaincode support + payloadBytes := marshalOrPanic(&peer.QueryStateNext{Id: id}) + + msg := &peer.ChaincodeMessage{Type: peer.ChaincodeMessage_QUERY_STATE_NEXT, Payload: payloadBytes, Txid: txid, ChannelId: channelID} + + var responseMsg *peer.ChaincodeMessage + + if responseMsg, err = h.sendReceive(msg, respChan); err != nil { + return nil, fmt.Errorf("[%s] error sending %s", shorttxid(msg.Txid), peer.ChaincodeMessage_QUERY_STATE_NEXT) + } + + if responseMsg.Type == peer.ChaincodeMessage_RESPONSE { + // Success response + queryResponse := &peer.QueryResponse{} + if err = proto.Unmarshal(responseMsg.Payload, queryResponse); err != nil { + return nil, fmt.Errorf("[%s] unmarshal error", shorttxid(responseMsg.Txid)) + } + + return queryResponse, nil + } + if responseMsg.Type == peer.ChaincodeMessage_ERROR { + // Error response + return nil, fmt.Errorf("%s", responseMsg.Payload[:]) + } + + // Incorrect chaincode message received + return nil, fmt.Errorf("incorrect chaincode message %s received. Expecting %s or %s", responseMsg.Type, peer.ChaincodeMessage_RESPONSE, peer.ChaincodeMessage_ERROR) +} + +func (h *Handler) handleQueryStateClose(id, channelID, txid string) (*peer.QueryResponse, error) { + // Create the channel on which to communicate the response from validating peer + respChan, err := h.createResponseChannel(channelID, txid) + if err != nil { + return nil, err + } + defer h.deleteResponseChannel(channelID, txid) + + // Send QUERY_STATE_CLOSE message to peer chaincode support + payloadBytes := marshalOrPanic(&peer.QueryStateClose{Id: id}) + + msg := &peer.ChaincodeMessage{Type: peer.ChaincodeMessage_QUERY_STATE_CLOSE, Payload: payloadBytes, Txid: txid, ChannelId: channelID} + + var responseMsg *peer.ChaincodeMessage + + if responseMsg, err = h.sendReceive(msg, respChan); err != nil { + return nil, fmt.Errorf("[%s] error sending %s", shorttxid(msg.Txid), peer.ChaincodeMessage_QUERY_STATE_CLOSE) + } + + if responseMsg.Type == peer.ChaincodeMessage_RESPONSE { + // Success response + queryResponse := &peer.QueryResponse{} + if err = proto.Unmarshal(responseMsg.Payload, queryResponse); err != nil { + return nil, fmt.Errorf("[%s] unmarshal error", shorttxid(responseMsg.Txid)) + } + + return queryResponse, nil + } + if responseMsg.Type == peer.ChaincodeMessage_ERROR { + // Error response + return nil, fmt.Errorf("%s", responseMsg.Payload[:]) + } + + // Incorrect chaincode message received + return nil, fmt.Errorf("incorrect chaincode message %s received. Expecting %s or %s", responseMsg.Type, peer.ChaincodeMessage_RESPONSE, peer.ChaincodeMessage_ERROR) +} + +func (h *Handler) handleGetQueryResult(collection string, query string, metadata []byte, + channelID string, txid string) (*peer.QueryResponse, error) { + // Send GET_QUERY_RESULT message to peer chaincode support + payloadBytes := marshalOrPanic(&peer.GetQueryResult{Collection: collection, Query: query, Metadata: metadata}) + msg := &peer.ChaincodeMessage{Type: peer.ChaincodeMessage_GET_QUERY_RESULT, Payload: payloadBytes, Txid: txid, ChannelId: channelID} + responseMsg, err := h.callPeerWithChaincodeMsg(msg, channelID, txid) + if err != nil { + return nil, fmt.Errorf("[%s] error sending %s", shorttxid(msg.Txid), peer.ChaincodeMessage_GET_QUERY_RESULT) + } + + if responseMsg.Type == peer.ChaincodeMessage_RESPONSE { + // Success response + executeQueryResponse := &peer.QueryResponse{} + if err = proto.Unmarshal(responseMsg.Payload, executeQueryResponse); err != nil { + return nil, fmt.Errorf("[%s] unmarshal error", shorttxid(responseMsg.Txid)) + } + + return executeQueryResponse, nil + } + if responseMsg.Type == peer.ChaincodeMessage_ERROR { + // Error response + return nil, fmt.Errorf("%s", responseMsg.Payload[:]) + } + + // Incorrect chaincode message received + return nil, fmt.Errorf("incorrect chaincode message %s received. Expecting %s or %s", responseMsg.Type, peer.ChaincodeMessage_RESPONSE, peer.ChaincodeMessage_ERROR) +} + +func (h *Handler) handleGetHistoryForKey(key string, channelID string, txid string) (*peer.QueryResponse, error) { + // Create the channel on which to communicate the response from validating peer + respChan, err := h.createResponseChannel(channelID, txid) + if err != nil { + return nil, err + } + defer h.deleteResponseChannel(channelID, txid) + + // Send GET_HISTORY_FOR_KEY message to peer chaincode support + payloadBytes := marshalOrPanic(&peer.GetHistoryForKey{Key: key}) + + msg := &peer.ChaincodeMessage{Type: peer.ChaincodeMessage_GET_HISTORY_FOR_KEY, Payload: payloadBytes, Txid: txid, ChannelId: channelID} + var responseMsg *peer.ChaincodeMessage + + if responseMsg, err = h.sendReceive(msg, respChan); err != nil { + return nil, fmt.Errorf("[%s] error sending %s", shorttxid(msg.Txid), peer.ChaincodeMessage_GET_HISTORY_FOR_KEY) + } + + if responseMsg.Type == peer.ChaincodeMessage_RESPONSE { + // Success response + getHistoryForKeyResponse := &peer.QueryResponse{} + if err = proto.Unmarshal(responseMsg.Payload, getHistoryForKeyResponse); err != nil { + return nil, fmt.Errorf("[%s] unmarshal error", shorttxid(responseMsg.Txid)) + } + + return getHistoryForKeyResponse, nil + } + if responseMsg.Type == peer.ChaincodeMessage_ERROR { + // Error response + return nil, fmt.Errorf("%s", responseMsg.Payload[:]) + } + + // Incorrect chaincode message received + return nil, fmt.Errorf("incorrect chaincode message %s received. Expecting %s or %s", responseMsg.Type, peer.ChaincodeMessage_RESPONSE, peer.ChaincodeMessage_ERROR) +} + +func (h *Handler) createResponse(status int32, payload []byte) *peer.Response { + return &peer.Response{Status: status, Payload: payload} +} + +// handleInvokeChaincode communicates with the peer to invoke another chaincode. +func (h *Handler) handleInvokeChaincode(chaincodeName string, args [][]byte, channelID string, txid string) *peer.Response { + payloadBytes := marshalOrPanic(&peer.ChaincodeSpec{ChaincodeId: &peer.ChaincodeID{Name: chaincodeName}, Input: &peer.ChaincodeInput{Args: args}}) + + // Create the channel on which to communicate the response from validating peer + respChan, err := h.createResponseChannel(channelID, txid) + if err != nil { + return h.createResponse(ERROR, []byte(err.Error())) + } + defer h.deleteResponseChannel(channelID, txid) + + // Send INVOKE_CHAINCODE message to peer chaincode support + msg := &peer.ChaincodeMessage{Type: peer.ChaincodeMessage_INVOKE_CHAINCODE, Payload: payloadBytes, Txid: txid, ChannelId: channelID} + + var responseMsg *peer.ChaincodeMessage + + if responseMsg, err = h.sendReceive(msg, respChan); err != nil { + errStr := fmt.Sprintf("[%s] error sending %s", shorttxid(msg.Txid), peer.ChaincodeMessage_INVOKE_CHAINCODE) + return h.createResponse(ERROR, []byte(errStr)) + } + + if responseMsg.Type == peer.ChaincodeMessage_RESPONSE { + // Success response + respMsg := &peer.ChaincodeMessage{} + if err := proto.Unmarshal(responseMsg.Payload, respMsg); err != nil { + return h.createResponse(ERROR, []byte(err.Error())) + } + if respMsg.Type == peer.ChaincodeMessage_COMPLETED { + // Success response + res := &peer.Response{} + if err = proto.Unmarshal(respMsg.Payload, res); err != nil { + return h.createResponse(ERROR, []byte(err.Error())) + } + return res + } + return h.createResponse(ERROR, responseMsg.Payload) + } + if responseMsg.Type == peer.ChaincodeMessage_ERROR { + // Error response + return h.createResponse(ERROR, responseMsg.Payload) + } + + // Incorrect chaincode message received + return h.createResponse(ERROR, []byte(fmt.Sprintf("[%s] Incorrect chaincode message %s received. Expecting %s or %s", shorttxid(responseMsg.Txid), responseMsg.Type, peer.ChaincodeMessage_RESPONSE, peer.ChaincodeMessage_ERROR))) +} + +// handleReady handles messages received from the peer when the handler is in the "ready" state. +func (h *Handler) handleReady(msg *peer.ChaincodeMessage, errc chan error) error { + switch msg.Type { + case peer.ChaincodeMessage_RESPONSE, peer.ChaincodeMessage_ERROR: + if err := h.handleResponse(msg); err != nil { + return err + } + return nil + + case peer.ChaincodeMessage_INIT: + go h.handleStubInteraction(h.handleInit, msg, errc) + return nil + + case peer.ChaincodeMessage_TRANSACTION: + go h.handleStubInteraction(h.handleTransaction, msg, errc) + return nil + + default: + return fmt.Errorf("[%s] Chaincode h cannot handle message (%s) while in state: %s", msg.Txid, msg.Type, h.state) + } +} + +// handleEstablished handles messages received from the peer when the handler is in the "established" state. +func (h *Handler) handleEstablished(msg *peer.ChaincodeMessage) error { + if msg.Type != peer.ChaincodeMessage_READY { + return fmt.Errorf("[%s] Chaincode h cannot handle message (%s) while in state: %s", msg.Txid, msg.Type, h.state) + } + + h.state = ready + return nil +} + +// hanndleCreated handles messages received from the peer when the handler is in the "created" state. +func (h *Handler) handleCreated(msg *peer.ChaincodeMessage) error { + if msg.Type != peer.ChaincodeMessage_REGISTERED { + return fmt.Errorf("[%s] Chaincode h cannot handle message (%s) while in state: %s", msg.Txid, msg.Type, h.state) + } + + h.state = established + return nil +} + +// handleMessage message handles loop for shim side of chaincode/peer stream. +func (h *Handler) handleMessage(msg *peer.ChaincodeMessage, errc chan error) error { + if msg.Type == peer.ChaincodeMessage_KEEPALIVE { + h.serialSendAsync(msg, errc) + return nil + } + var err error + + switch h.state { + case ready: + err = h.handleReady(msg, errc) + case established: + err = h.handleEstablished(msg) + case created: + err = h.handleCreated(msg) + default: + panic(fmt.Sprintf("invalid handler state: %s", h.state)) + } + + if err != nil { + payload := []byte(err.Error()) + errorMsg := &peer.ChaincodeMessage{Type: peer.ChaincodeMessage_ERROR, Payload: payload, Txid: msg.Txid} + h.serialSend(errorMsg) //nolint:errcheck + return err + } + + return nil +} + +// marshalOrPanic attempts to marshal the provided protobbuf message but will panic +// when marshaling fails instead of returning an error. +func marshalOrPanic(msg proto.Message) []byte { + bytes, err := proto.Marshal(msg) + if err != nil { + panic(fmt.Sprintf("failed to marshal message: %s", err)) + } + return bytes +} diff --git a/v2/shim/handler_test.go b/v2/shim/handler_test.go new file mode 100644 index 0000000..f9d325c --- /dev/null +++ b/v2/shim/handler_test.go @@ -0,0 +1,299 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package shim + +import ( + "fmt" + "testing" + + "github.com/hyperledger/fabric-chaincode-go/v2/shim/internal/mock" + "github.com/hyperledger/fabric-protos-go-apiv2/peer" + + "github.com/stretchr/testify/assert" +) + +//go:generate counterfeiter -o internal/mock/peer_chaincode_stream.go --fake-name PeerChaincodeStream . peerChaincodeStream + +//lint:ignore U1000 Required to avoid circular dependency with mock +type peerChaincodeStream interface{ PeerChaincodeStream } + +//go:generate counterfeiter -o internal/mock/client_stream.go --fake-name ClientStream . clientStream + +//lint:ignore U1000 Required to avoid circular dependency with mock +type clientStream interface{ ClientStream } + +type mockChaincode struct { + initCalled bool + invokeCalled bool +} + +func (mcc *mockChaincode) Init(stub ChaincodeStubInterface) *peer.Response { + mcc.initCalled = true + return Success(nil) +} + +func (mcc *mockChaincode) Invoke(stub ChaincodeStubInterface) *peer.Response { + mcc.invokeCalled = true + return Success(nil) +} + +func TestNewHandler_CreatedState(t *testing.T) { + t.Parallel() + + chatStream := &mock.PeerChaincodeStream{} + cc := &mockChaincode{} + + expected := &Handler{ + chatStream: chatStream, + cc: cc, + responseChannels: map[string]chan *peer.ChaincodeMessage{}, + state: created, + } + + handler := newChaincodeHandler(chatStream, cc) + if handler == nil { + t.Fatal("Handler should not be nil") + } + assert.Equal(t, expected, handler) +} + +func TestHandlerState(t *testing.T) { + t.Parallel() + + var tests = []struct { + name string + state state + msg *peer.ChaincodeMessage + expectedErr string + }{ + { + name: "created", + state: created, + msg: &peer.ChaincodeMessage{ + Type: peer.ChaincodeMessage_REGISTERED, + }, + }, + { + name: "wrong message type in created state", + state: created, + msg: &peer.ChaincodeMessage{ + Type: peer.ChaincodeMessage_READY, + }, + expectedErr: fmt.Sprintf("cannot handle message (%s)", peer.ChaincodeMessage_READY), + }, + { + name: "established", + state: established, + msg: &peer.ChaincodeMessage{ + Type: peer.ChaincodeMessage_READY, + }, + }, + { + name: "wrong message type in established state", + state: established, + msg: &peer.ChaincodeMessage{ + Type: peer.ChaincodeMessage_REGISTERED, + }, + expectedErr: fmt.Sprintf("cannot handle message (%s)", peer.ChaincodeMessage_REGISTERED), + }, + { + name: "wrong message type in ready state", + state: ready, + msg: &peer.ChaincodeMessage{ + Type: peer.ChaincodeMessage_REGISTERED, + }, + expectedErr: fmt.Sprintf("cannot handle message (%s)", peer.ChaincodeMessage_REGISTERED), + }, + { + name: "keepalive", + state: established, + msg: &peer.ChaincodeMessage{ + Type: peer.ChaincodeMessage_KEEPALIVE, + }, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + handler := &Handler{ + chatStream: &mock.PeerChaincodeStream{}, + cc: &mockChaincode{}, + state: test.state, + } + err := handler.handleMessage(test.msg, nil) + if test.expectedErr != "" { + assert.Contains(t, err.Error(), test.expectedErr) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestHandleMessage(t *testing.T) { + t.Parallel() + + var tests = []struct { + name string + msg *peer.ChaincodeMessage + msgType peer.ChaincodeMessage_Type + expectedErr string + invokeCalled bool + initCalled bool + }{ + { + name: "INIT", + msg: &peer.ChaincodeMessage{ + Type: peer.ChaincodeMessage_INIT, + }, + msgType: peer.ChaincodeMessage_COMPLETED, + initCalled: true, + invokeCalled: false, + }, + { + name: "INIT with bad payload", + msg: &peer.ChaincodeMessage{ + Type: peer.ChaincodeMessage_INIT, + Payload: []byte{1}, + }, + msgType: peer.ChaincodeMessage_ERROR, + initCalled: false, + invokeCalled: false, + }, + { + name: "INVOKE", + msg: &peer.ChaincodeMessage{ + Type: peer.ChaincodeMessage_TRANSACTION, + }, + msgType: peer.ChaincodeMessage_COMPLETED, + initCalled: false, + invokeCalled: true, + }, + { + name: "INVOKE with bad payload", + msg: &peer.ChaincodeMessage{ + Type: peer.ChaincodeMessage_TRANSACTION, + Payload: []byte{1}, + }, + msgType: peer.ChaincodeMessage_ERROR, + initCalled: false, + invokeCalled: false, + }, + { + name: "RESPONSE with no responseChannel", + msg: &peer.ChaincodeMessage{ + Type: peer.ChaincodeMessage_RESPONSE, + }, + expectedErr: "responseChannel does not exist", + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + chatStream := &mock.PeerChaincodeStream{} + cc := &mockChaincode{} + + msgChan := make(chan *peer.ChaincodeMessage) + chatStream.SendStub = func(msg *peer.ChaincodeMessage) error { + go func() { + msgChan <- msg + }() + return nil + } + + // create handler in ready state + handler := &Handler{ + chatStream: chatStream, + cc: cc, + responseChannels: map[string]chan *peer.ChaincodeMessage{}, + state: ready, + } + + err := handler.handleMessage(test.msg, nil) + if test.expectedErr != "" { + assert.Contains(t, err.Error(), test.expectedErr) + } else { + if err != nil { + t.Fatalf("Unexpected error for '%s': %s", test.name, err) + } + resp := <-msgChan + assert.Equal(t, test.msgType, resp.GetType()) + assert.Equal(t, test.initCalled, cc.initCalled) + assert.Equal(t, test.invokeCalled, cc.invokeCalled) + } + }) + } +} + +func TestHandlePeerCalls(t *testing.T) { + payload := []byte("error") + h := &Handler{ + cc: &mockChaincode{}, + responseChannels: map[string]chan *peer.ChaincodeMessage{}, + state: ready, + } + chatStream := &mock.PeerChaincodeStream{} + chatStream.SendStub = func(msg *peer.ChaincodeMessage) error { + go func() { + err := h.handleResponse( + &peer.ChaincodeMessage{ + Type: peer.ChaincodeMessage_ERROR, + ChannelId: msg.GetChannelId(), + Txid: msg.GetTxid(), + Payload: payload, + }, + ) + assert.NoError(t, err, "handleResponse") + }() + return nil + } + h.chatStream = chatStream + + _, err := h.handleQueryStateNext("id", "channel", "txid") + assert.EqualError(t, err, string(payload)) + + _, err = h.handleQueryStateClose("id", "channel", "txid") + assert.EqualError(t, err, string(payload)) + + // force error by removing responseChannels + h.responseChannels = nil + _, err = h.handleGetState("col", "key", "channel", "txid") + assert.Contains(t, err.Error(), "[txid] error sending GET_STATE") + + _, err = h.handleGetPrivateDataHash("col", "key", "channel", "txid") + assert.Contains(t, err.Error(), "[txid] error sending GET_PRIVATE_DATA_HASH") + + _, err = h.handleGetStateMetadata("col", "key", "channel", "txid") + assert.Contains(t, err.Error(), "[txid] error sending GET_STATE_METADATA") + + err = h.handlePutState("col", "key", []byte{}, "channel", "txid") + assert.Contains(t, err.Error(), "[txid] error sending PUT_STATE") + + err = h.handlePutStateMetadataEntry("col", "key", "mkey", []byte{}, "channel", "txid") + assert.Contains(t, err.Error(), "[txid] error sending PUT_STATE_METADATA") + + err = h.handleDelState("col", "key", "channel", "txid") + assert.Contains(t, err.Error(), "[txid] error sending DEL_STATE") + + _, err = h.handleGetStateByRange("col", "start", "end", []byte{}, "channel", "txid") + assert.Contains(t, err.Error(), "[txid] error sending GET_STATE_BY_RANGE") + + _, err = h.handleQueryStateNext("id", "channel", "txid") + assert.Contains(t, err.Error(), "cannot create response channel") + + _, err = h.handleQueryStateClose("id", "channel", "txid") + assert.Contains(t, err.Error(), "cannot create response channel") + + _, err = h.handleGetQueryResult("col", "query", []byte{}, "channel", "txid") + assert.Contains(t, err.Error(), "[txid] error sending GET_QUERY_RESULT") + + _, err = h.handleGetHistoryForKey("key", "channel", "txid") + assert.Contains(t, err.Error(), "cannot create response channel") + +} diff --git a/v2/shim/interfaces.go b/v2/shim/interfaces.go new file mode 100644 index 0000000..7e6a643 --- /dev/null +++ b/v2/shim/interfaces.go @@ -0,0 +1,405 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package shim + +import ( + "github.com/hyperledger/fabric-protos-go-apiv2/ledger/queryresult" + "github.com/hyperledger/fabric-protos-go-apiv2/peer" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// Chaincode interface must be implemented by all chaincodes. The fabric runs +// the transactions by calling these functions as specified. +type Chaincode interface { + // Init is called during Instantiate transaction after the chaincode container + // has been established for the first time, allowing the chaincode to + // initialize its internal data + Init(stub ChaincodeStubInterface) *peer.Response + + // Invoke is called to update or query the ledger in a proposal transaction. + // Updated state variables are not committed to the ledger until the + // transaction is committed. + Invoke(stub ChaincodeStubInterface) *peer.Response +} + +// ChaincodeStubInterface is used by deployable chaincode apps to access and +// modify their ledgers +type ChaincodeStubInterface interface { + // GetArgs returns the arguments intended for the chaincode Init and Invoke + // as an array of byte arrays. + GetArgs() [][]byte + + // GetStringArgs returns the arguments intended for the chaincode Init and + // Invoke as a string array. Only use GetStringArgs if the client passes + // arguments intended to be used as strings. + GetStringArgs() []string + + // GetFunctionAndParameters returns the first argument as the function + // name and the rest of the arguments as parameters in a string array. + // Only use GetFunctionAndParameters if the client passes arguments intended + // to be used as strings. + GetFunctionAndParameters() (string, []string) + + // GetArgsSlice returns the arguments intended for the chaincode Init and + // Invoke as a byte array + GetArgsSlice() ([]byte, error) + + // GetTxID returns the tx_id of the transaction proposal, which is unique per + // transaction and per client. See + // https://godoc.org/github.com/hyperledger/fabric-protos-go-apiv2/common#ChannelHeader + // for further details. + GetTxID() string + + // GetChannelID returns the channel the proposal is sent to for chaincode to process. + // This would be the channel_id of the transaction proposal (see + // https://godoc.org/github.com/hyperledger/fabric-protos-go-apiv2/common#ChannelHeader ) + // except where the chaincode is calling another on a different channel. + GetChannelID() string + + // InvokeChaincode locally calls the specified chaincode `Invoke` using the + // same transaction context; that is, chaincode calling chaincode doesn't + // create a new transaction message. + // If the called chaincode is on the same channel, it simply adds the called + // chaincode read set and write set to the calling transaction. + // If the called chaincode is on a different channel, + // only the Response is returned to the calling chaincode; any PutState calls + // from the called chaincode will not have any effect on the ledger; that is, + // the called chaincode on a different channel will not have its read set + // and write set applied to the transaction. Only the calling chaincode's + // read set and write set will be applied to the transaction. Effectively + // the called chaincode on a different channel is a `Query`, which does not + // participate in state validation checks in subsequent commit phase. + // If `channel` is empty, the caller's channel is assumed. + InvokeChaincode(chaincodeName string, args [][]byte, channel string) *peer.Response + + // GetState returns the value of the specified `key` from the + // ledger. Note that GetState doesn't read data from the writeset, which + // has not been committed to the ledger. In other words, GetState doesn't + // consider data modified by PutState that has not been committed. + // If the key does not exist in the state database, (nil, nil) is returned. + GetState(key string) ([]byte, error) + + // PutState puts the specified `key` and `value` into the transaction's + // writeset as a data-write proposal. PutState doesn't effect the ledger + // until the transaction is validated and successfully committed. + // Simple keys must not be an empty string and must not start with a + // null character (0x00) in order to avoid range query collisions with + // composite keys, which internally get prefixed with 0x00 as composite + // key namespace. In addition, if using CouchDB, keys can only contain + // valid UTF-8 strings and cannot begin with an underscore ("_"). + PutState(key string, value []byte) error + + // DelState records the specified `key` to be deleted in the writeset of + // the transaction proposal. The `key` and its value will be deleted from + // the ledger when the transaction is validated and successfully committed. + DelState(key string) error + + // SetStateValidationParameter sets the key-level endorsement policy for `key`. + SetStateValidationParameter(key string, ep []byte) error + + // GetStateValidationParameter retrieves the key-level endorsement policy + // for `key`. Note that this will introduce a read dependency on `key` in + // the transaction's readset. + GetStateValidationParameter(key string) ([]byte, error) + + // GetStateByRange returns a range iterator over a set of keys in the + // ledger. The iterator can be used to iterate over all keys + // between the startKey (inclusive) and endKey (exclusive). + // However, if the number of keys between startKey and endKey is greater than the + // totalQueryLimit (defined in core.yaml), this iterator cannot be used + // to fetch all keys (results will be capped by the totalQueryLimit). + // The keys are returned by the iterator in lexical order. Note + // that startKey and endKey can be empty string, which implies unbounded range + // query on start or end. + // Call Close() on the returned StateQueryIteratorInterface object when done. + // The query is re-executed during validation phase to ensure result set + // has not changed since transaction endorsement (phantom reads detected). + GetStateByRange(startKey, endKey string) (StateQueryIteratorInterface, error) + + // GetStateByRangeWithPagination returns a range iterator over a set of keys in the + // ledger. The iterator can be used to fetch keys between the startKey (inclusive) + // and endKey (exclusive). + // When an empty string is passed as a value to the bookmark argument, the returned + // iterator can be used to fetch the first `pageSize` keys between the startKey + // (inclusive) and endKey (exclusive). + // When the bookmark is a non-emptry string, the iterator can be used to fetch + // the first `pageSize` keys between the bookmark (inclusive) and endKey (exclusive). + // Note that only the bookmark present in a prior page of query results (ResponseMetadata) + // can be used as a value to the bookmark argument. Otherwise, an empty string must + // be passed as bookmark. + // The keys are returned by the iterator in lexical order. Note + // that startKey and endKey can be empty string, which implies unbounded range + // query on start or end. + // Call Close() on the returned StateQueryIteratorInterface object when done. + // This call is only supported in a read only transaction. + GetStateByRangeWithPagination(startKey, endKey string, pageSize int32, + bookmark string) (StateQueryIteratorInterface, *peer.QueryResponseMetadata, error) + + // GetStateByPartialCompositeKey queries the state in the ledger based on + // a given partial composite key. This function returns an iterator + // which can be used to iterate over all composite keys whose prefix matches + // the given partial composite key. However, if the number of matching composite + // keys is greater than the totalQueryLimit (defined in core.yaml), this iterator + // cannot be used to fetch all matching keys (results will be limited by the totalQueryLimit). + // The `objectType` and attributes are expected to have only valid utf8 strings and + // should not contain U+0000 (nil byte) and U+10FFFF (biggest and unallocated code point). + // See related functions SplitCompositeKey and CreateCompositeKey. + // Call Close() on the returned StateQueryIteratorInterface object when done. + // The query is re-executed during validation phase to ensure result set + // has not changed since transaction endorsement (phantom reads detected). This function should be used only for + // a partial composite key. For a full composite key, an iter with empty response + // would be returned. + GetStateByPartialCompositeKey(objectType string, keys []string) (StateQueryIteratorInterface, error) + + // GetStateByPartialCompositeKeyWithPagination queries the state in the ledger based on + // a given partial composite key. This function returns an iterator + // which can be used to iterate over the composite keys whose + // prefix matches the given partial composite key. + // When an empty string is passed as a value to the bookmark argument, the returned + // iterator can be used to fetch the first `pageSize` composite keys whose prefix + // matches the given partial composite key. + // When the bookmark is a non-emptry string, the iterator can be used to fetch + // the first `pageSize` keys between the bookmark (inclusive) and the last matching + // composite key. + // Note that only the bookmark present in a prior page of query result (ResponseMetadata) + // can be used as a value to the bookmark argument. Otherwise, an empty string must + // be passed as bookmark. + // The `objectType` and attributes are expected to have only valid utf8 strings + // and should not contain U+0000 (nil byte) and U+10FFFF (biggest and unallocated + // code point). See related functions SplitCompositeKey and CreateCompositeKey. + // Call Close() on the returned StateQueryIteratorInterface object when done. + // This call is only supported in a read only transaction. This function should be used only for + // a partial composite key. For a full composite key, an iter with empty response + // would be returned. + GetStateByPartialCompositeKeyWithPagination(objectType string, keys []string, + pageSize int32, bookmark string) (StateQueryIteratorInterface, *peer.QueryResponseMetadata, error) + + // CreateCompositeKey combines the given `attributes` to form a composite + // key. The objectType and attributes are expected to have only valid utf8 + // strings and should not contain U+0000 (nil byte) and U+10FFFF + // (biggest and unallocated code point). + // The resulting composite key can be used as the key in PutState(). + CreateCompositeKey(objectType string, attributes []string) (string, error) + + // SplitCompositeKey splits the specified key into attributes on which the + // composite key was formed. Composite keys found during range queries + // or partial composite key queries can therefore be split into their + // composite parts. + SplitCompositeKey(compositeKey string) (string, []string, error) + + // GetQueryResult performs a "rich" query against a state database. It is + // only supported for state databases that support rich query, + // e.g.CouchDB. The query string is in the native syntax + // of the underlying state database. An iterator is returned + // which can be used to iterate over all keys in the query result set. + // However, if the number of keys in the query result set is greater than the + // totalQueryLimit (defined in core.yaml), this iterator cannot be used + // to fetch all keys in the query result set (results will be limited by + // the totalQueryLimit). + // The query is NOT re-executed during validation phase, phantom reads are + // not detected. That is, other committed transactions may have added, + // updated, or removed keys that impact the result set, and this would not + // be detected at validation/commit time. Applications susceptible to this + // should therefore not use GetQueryResult as part of transactions that update + // ledger, and should limit use to read-only chaincode operations. + GetQueryResult(query string) (StateQueryIteratorInterface, error) + + // GetQueryResultWithPagination performs a "rich" query against a state database. + // It is only supported for state databases that support rich query, + // e.g., CouchDB. The query string is in the native syntax + // of the underlying state database. An iterator is returned + // which can be used to iterate over keys in the query result set. + // When an empty string is passed as a value to the bookmark argument, the returned + // iterator can be used to fetch the first `pageSize` of query results. + // When the bookmark is a non-emptry string, the iterator can be used to fetch + // the first `pageSize` keys between the bookmark and the last key in the query result. + // Note that only the bookmark present in a prior page of query results (ResponseMetadata) + // can be used as a value to the bookmark argument. Otherwise, an empty string + // must be passed as bookmark. + // This call is only supported in a read only transaction. + GetQueryResultWithPagination(query string, pageSize int32, + bookmark string) (StateQueryIteratorInterface, *peer.QueryResponseMetadata, error) + + // GetHistoryForKey returns a history of key values across time. + // For each historic key update, the historic value and associated + // transaction id and timestamp are returned. The timestamp is the + // timestamp provided by the client in the proposal header. + // GetHistoryForKey requires peer configuration + // core.ledger.history.enableHistoryDatabase to be true. + // The query is NOT re-executed during validation phase, phantom reads are + // not detected. That is, other committed transactions may have updated + // the key concurrently, impacting the result set, and this would not be + // detected at validation/commit time. Applications susceptible to this + // should therefore not use GetHistoryForKey as part of transactions that + // update ledger, and should limit use to read-only chaincode operations. + // Starting in Fabric v2.0, the GetHistoryForKey chaincode API + // will return results from newest to oldest in terms of ordered transaction + // height (block height and transaction height within block). + // This will allow applications to efficiently iterate through the top results + // to understand recent changes to a key. + GetHistoryForKey(key string) (HistoryQueryIteratorInterface, error) + + // GetPrivateData returns the value of the specified `key` from the specified + // `collection`. Note that GetPrivateData doesn't read data from the + // private writeset, which has not been committed to the `collection`. In + // other words, GetPrivateData doesn't consider data modified by PutPrivateData + // that has not been committed. + GetPrivateData(collection, key string) ([]byte, error) + + // GetPrivateDataHash returns the hash of the value of the specified `key` from the specified + // `collection` + GetPrivateDataHash(collection, key string) ([]byte, error) + + // PutPrivateData puts the specified `key` and `value` into the transaction's + // private writeset. Note that only hash of the private writeset goes into the + // transaction proposal response (which is sent to the client who issued the + // transaction) and the actual private writeset gets temporarily stored in a + // transient store. PutPrivateData doesn't effect the `collection` until the + // transaction is validated and successfully committed. Simple keys must not + // be an empty string and must not start with a null character (0x00) in order + // to avoid range query collisions with composite keys, which internally get + // prefixed with 0x00 as composite key namespace. In addition, if using + // CouchDB, keys can only contain valid UTF-8 strings and cannot begin with an + // an underscore ("_"). + PutPrivateData(collection string, key string, value []byte) error + + // DelPrivateData records the specified `key` to be deleted in the private writeset + // of the transaction. Note that only hash of the private writeset goes into the + // transaction proposal response (which is sent to the client who issued the + // transaction) and the actual private writeset gets temporarily stored in a + // transient store. The `key` and its value will be deleted from the collection + // when the transaction is validated and successfully committed. + DelPrivateData(collection, key string) error + + // PurgePrivateData records the specified `key` to be purged in the private writeset + // of the transaction. Note that only hash of the private writeset goes into the + // transaction proposal response (which is sent to the client who issued the + // transaction) and the actual private writeset gets temporarily stored in a + // transient store. The `key` and its value will be deleted from the collection + // when the transaction is validated and successfully committed, and will + // subsequently be completely removed from the private data store (that maintains + // the historical versions of private writesets) as a background operation. + PurgePrivateData(collection, key string) error + + // SetPrivateDataValidationParameter sets the key-level endorsement policy + // for the private data specified by `key`. + SetPrivateDataValidationParameter(collection, key string, ep []byte) error + + // GetPrivateDataValidationParameter retrieves the key-level endorsement + // policy for the private data specified by `key`. Note that this introduces + // a read dependency on `key` in the transaction's readset. + GetPrivateDataValidationParameter(collection, key string) ([]byte, error) + + // GetPrivateDataByRange returns a range iterator over a set of keys in a + // given private collection. The iterator can be used to iterate over all keys + // between the startKey (inclusive) and endKey (exclusive). + // The keys are returned by the iterator in lexical order. Note + // that startKey and endKey can be empty string, which implies unbounded range + // query on start or end. + // Call Close() on the returned StateQueryIteratorInterface object when done. + // The query is re-executed during validation phase to ensure result set + // has not changed since transaction endorsement (phantom reads detected). + GetPrivateDataByRange(collection, startKey, endKey string) (StateQueryIteratorInterface, error) + + // GetPrivateDataByPartialCompositeKey queries the state in a given private + // collection based on a given partial composite key. This function returns + // an iterator which can be used to iterate over all composite keys whose prefix + // matches the given partial composite key. The `objectType` and attributes are + // expected to have only valid utf8 strings and should not contain + // U+0000 (nil byte) and U+10FFFF (biggest and unallocated code point). + // See related functions SplitCompositeKey and CreateCompositeKey. + // Call Close() on the returned StateQueryIteratorInterface object when done. + // The query is re-executed during validation phase to ensure result set + // has not changed since transaction endorsement (phantom reads detected). This function should be used only for + //a partial composite key. For a full composite key, an iter with empty response + //would be returned. + GetPrivateDataByPartialCompositeKey(collection, objectType string, keys []string) (StateQueryIteratorInterface, error) + + // GetPrivateDataQueryResult performs a "rich" query against a given private + // collection. It is only supported for state databases that support rich query, + // e.g.CouchDB. The query string is in the native syntax + // of the underlying state database. An iterator is returned + // which can be used to iterate (next) over the query result set. + // The query is NOT re-executed during validation phase, phantom reads are + // not detected. That is, other committed transactions may have added, + // updated, or removed keys that impact the result set, and this would not + // be detected at validation/commit time. Applications susceptible to this + // should therefore not use GetPrivateDataQueryResult as part of transactions that update + // ledger, and should limit use to read-only chaincode operations. + GetPrivateDataQueryResult(collection, query string) (StateQueryIteratorInterface, error) + + // GetCreator returns `SignatureHeader.Creator` (e.g. an identity) + // of the `SignedProposal`. This is the identity of the agent (or user) + // submitting the transaction. + GetCreator() ([]byte, error) + + // GetTransient returns the `ChaincodeProposalPayload.Transient` field. + // It is a map that contains data (e.g. cryptographic material) + // that might be used to implement some form of application-level + // confidentiality. The contents of this field, as prescribed by + // `ChaincodeProposalPayload`, are supposed to always + // be omitted from the transaction and excluded from the ledger. + GetTransient() (map[string][]byte, error) + + // GetBinding returns the transaction binding, which is used to enforce a + // link between application data (like those stored in the transient field + // above) to the proposal itself. This is useful to avoid possible replay + // attacks. + GetBinding() ([]byte, error) + + // GetDecorations returns additional data (if applicable) about the proposal + // that originated from the peer. This data is set by the decorators of the + // peer, which append or mutate the chaincode input passed to the chaincode. + GetDecorations() map[string][]byte + + // GetSignedProposal returns the SignedProposal object, which contains all + // data elements part of a transaction proposal. + GetSignedProposal() (*peer.SignedProposal, error) + + // GetTxTimestamp returns the timestamp when the transaction was created. This + // is taken from the transaction ChannelHeader, therefore it will indicate the + // client's timestamp and will have the same value across all endorsers. + GetTxTimestamp() (*timestamppb.Timestamp, error) + + // SetEvent allows the chaincode to set an event on the response to the + // proposal to be included as part of a transaction. The event will be + // available within the transaction in the committed block regardless of the + // validity of the transaction. + // Only a single event can be included in a transaction, and must originate + // from the outer-most invoked chaincode in chaincode-to-chaincode scenarios. + // The marshaled ChaincodeEvent will be available in the transaction's ChaincodeAction.events field. + SetEvent(name string, payload []byte) error +} + +// CommonIteratorInterface allows a chaincode to check whether any more result +// to be fetched from an iterator and close it when done. +type CommonIteratorInterface interface { + // HasNext returns true if the range query iterator contains additional keys + // and values. + HasNext() bool + + // Close closes the iterator. This should be called when done + // reading from the iterator to free up resources. + Close() error +} + +// StateQueryIteratorInterface allows a chaincode to iterate over a set of +// key/value pairs returned by range and execute query. +type StateQueryIteratorInterface interface { + // Inherit HasNext() and Close() + CommonIteratorInterface + + // Next returns the next key and value in the range and execute query iterator. + Next() (*queryresult.KV, error) +} + +// HistoryQueryIteratorInterface allows a chaincode to iterate over a set of +// key/value pairs returned by a history query. +type HistoryQueryIteratorInterface interface { + // Inherit HasNext() and Close() + CommonIteratorInterface + + // Next returns the next key and value in the history query iterator. + Next() (*queryresult.KeyModification, error) +} diff --git a/v2/shim/internal/client.go b/v2/shim/internal/client.go new file mode 100644 index 0000000..4516375 --- /dev/null +++ b/v2/shim/internal/client.go @@ -0,0 +1,52 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "context" + "crypto/tls" + "time" + + "github.com/hyperledger/fabric-protos-go-apiv2/peer" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/keepalive" +) + +const ( + dialTimeout = 10 * time.Second + maxRecvMessageSize = 100 * 1024 * 1024 // 100 MiB + maxSendMessageSize = 100 * 1024 * 1024 // 100 MiB +) + +// NewClientConn ... +func NewClientConn( + address string, + tlsConf *tls.Config, + kaOpts keepalive.ClientParameters, +) (*grpc.ClientConn, error) { + + dialOpts := []grpc.DialOption{ + grpc.WithKeepaliveParams(kaOpts), + grpc.WithDefaultCallOptions( + grpc.MaxCallRecvMsgSize(maxRecvMessageSize), + grpc.MaxCallSendMsgSize(maxSendMessageSize), + ), + } + + if tlsConf != nil { + creds := credentials.NewTLS(tlsConf) + dialOpts = append(dialOpts, grpc.WithTransportCredentials(creds)) + } else { + dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + + return grpc.NewClient(address, dialOpts...) +} + +// NewRegisterClient ... +func NewRegisterClient(conn *grpc.ClientConn) (peer.ChaincodeSupport_RegisterClient, error) { + return peer.NewChaincodeSupportClient(conn).Register(context.Background()) +} diff --git a/v2/shim/internal/client_test.go b/v2/shim/internal/client_test.go new file mode 100644 index 0000000..ff888c5 --- /dev/null +++ b/v2/shim/internal/client_test.go @@ -0,0 +1,134 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "errors" + "net" + "testing" + "time" + + "github.com/hyperledger/fabric-protos-go-apiv2/peer" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + "google.golang.org/grpc/keepalive" +) + +type testServer struct { + receivedMessages chan<- *peer.ChaincodeMessage + sendMessages <-chan *peer.ChaincodeMessage + waitTime time.Duration +} + +func (t *testServer) Register(registerServer peer.ChaincodeSupport_RegisterServer) error { + for { + recv, err := registerServer.Recv() + if err != nil { + return err + } + + select { + case t.receivedMessages <- recv: + case <-time.After(t.waitTime): + return errors.New("failed to capture received message") + } + + select { + case msg, ok := <-t.sendMessages: + if !ok { + return nil + } + if err := registerServer.Send(msg); err != nil { + return err + } + case <-time.After(t.waitTime): + return errors.New("no messages available on send channel") + } + } +} + +func TestMessageSizes(t *testing.T) { + const waitTime = 10 * time.Second + + lis, err := net.Listen("tcp", "127.0.0.1:0") + assert.NoError(t, err, "listen failed") + defer lis.Close() + + sendMessages := make(chan *peer.ChaincodeMessage, 1) + receivedMessages := make(chan *peer.ChaincodeMessage, 1) + testServer := &testServer{ + receivedMessages: receivedMessages, + sendMessages: sendMessages, + waitTime: waitTime, + } + + server := grpc.NewServer( + grpc.MaxSendMsgSize(2*maxSendMessageSize), + grpc.MaxRecvMsgSize(2*maxRecvMessageSize), + ) + peer.RegisterChaincodeSupportServer(server, testServer) + + serveCompleteCh := make(chan error, 1) + go func() { serveCompleteCh <- server.Serve(lis) }() + + client, err := NewClientConn(lis.Addr().String(), nil, keepalive.ClientParameters{}) + assert.NoError(t, err, "failed to create client connection") + + regClient, err := NewRegisterClient(client) + assert.NoError(t, err, "failed to create register client") + + t.Run("acceptable messaages", func(t *testing.T) { + acceptableMessage := &peer.ChaincodeMessage{ + Payload: make([]byte, maxSendMessageSize-100), + } + sendMessages <- acceptableMessage + err = regClient.Send(acceptableMessage) + assert.NoError(t, err, "sending messge below size threshold failed") + + select { + case m := <-receivedMessages: + assert.Len(t, m.Payload, maxSendMessageSize-100) + case <-time.After(waitTime): + t.Fatalf("acceptable message was not received by server") + } + + msg, err := regClient.Recv() + assert.NoError(t, err, "failed to receive message") + assert.Len(t, msg.Payload, maxSendMessageSize-100) + }) + + t.Run("response message is too large", func(t *testing.T) { + sendMessages <- &peer.ChaincodeMessage{ + Payload: make([]byte, maxSendMessageSize+1), + } + err = regClient.Send(&peer.ChaincodeMessage{}) + assert.NoError(t, err, "sending messge below size threshold should succeed") + + select { + case m := <-receivedMessages: + assert.Len(t, m.Payload, 0) + case <-time.After(waitTime): + t.Fatalf("acceptable message was not received by server") + } + + _, err := regClient.Recv() + assert.Error(t, err, "receiving a message that is too large should fail") + }) + + t.Run("sent message is too large", func(t *testing.T) { + tooBig := &peer.ChaincodeMessage{ + Payload: make([]byte, maxSendMessageSize+1), + } + err = regClient.Send(tooBig) + assert.Error(t, err, "sending messge above size threshold should fail") + }) + + err = lis.Close() + assert.NoError(t, err, "close failed") + select { + case <-serveCompleteCh: + case <-time.After(waitTime): + t.Fatal("server shutdown timeout") + } +} diff --git a/v2/shim/internal/config.go b/v2/shim/internal/config.go new file mode 100644 index 0000000..e7c2c4b --- /dev/null +++ b/v2/shim/internal/config.go @@ -0,0 +1,150 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "crypto/tls" + "crypto/x509" + "encoding/base64" + "errors" + "fmt" + "os" + "strconv" + "time" + + "google.golang.org/grpc/keepalive" +) + +// Config contains chaincode's configuration +type Config struct { + ChaincodeName string + TLS *tls.Config + KaOpts keepalive.ClientParameters +} + +// LoadConfig loads the chaincode configuration +func LoadConfig() (Config, error) { + var err error + tlsEnabled, err := strconv.ParseBool(os.Getenv("CORE_PEER_TLS_ENABLED")) + if err != nil { + return Config{}, errors.New("'CORE_PEER_TLS_ENABLED' must be set to 'true' or 'false'") + } + + conf := Config{ + ChaincodeName: os.Getenv("CORE_CHAINCODE_ID_NAME"), + // hardcode to match chaincode server + KaOpts: keepalive.ClientParameters{ + Time: 1 * time.Minute, + Timeout: 20 * time.Second, + PermitWithoutStream: true, + }, + } + + if !tlsEnabled { + return conf, nil + } + + var key []byte + path, set := os.LookupEnv("CORE_TLS_CLIENT_KEY_FILE") + if set { + key, err = os.ReadFile(path) + if err != nil { + return Config{}, fmt.Errorf("failed to read private key file: %s", err) + } + } else { + data, err := os.ReadFile(os.Getenv("CORE_TLS_CLIENT_KEY_PATH")) + if err != nil { + return Config{}, fmt.Errorf("failed to read private key file: %s", err) + } + key, err = base64.StdEncoding.DecodeString(string(data)) + if err != nil { + return Config{}, fmt.Errorf("failed to decode private key file: %s", err) + } + } + + var cert []byte + path, set = os.LookupEnv("CORE_TLS_CLIENT_CERT_FILE") + if set { + cert, err = os.ReadFile(path) + if err != nil { + return Config{}, fmt.Errorf("failed to read public key file: %s", err) + } + } else { + data, err := os.ReadFile(os.Getenv("CORE_TLS_CLIENT_CERT_PATH")) + if err != nil { + return Config{}, fmt.Errorf("failed to read public key file: %s", err) + } + cert, err = base64.StdEncoding.DecodeString(string(data)) + if err != nil { + return Config{}, fmt.Errorf("failed to decode public key file: %s", err) + } + } + + root, err := os.ReadFile(os.Getenv("CORE_PEER_TLS_ROOTCERT_FILE")) + if err != nil { + return Config{}, fmt.Errorf("failed to read root cert file: %s", err) + } + + tlscfg, err := LoadTLSConfig(false, key, cert, root) + if err != nil { + return Config{}, err + } + + conf.TLS = tlscfg + + return conf, nil +} + +// LoadTLSConfig loads the TLS configuration for the chaincode +func LoadTLSConfig(isserver bool, key, cert, root []byte) (*tls.Config, error) { + if key == nil { + return nil, fmt.Errorf("key not provided") + } + + if cert == nil { + return nil, fmt.Errorf("cert not provided") + } + + if !isserver && root == nil { + return nil, fmt.Errorf("root cert not provided") + } + + cccert, err := tls.X509KeyPair(cert, key) + if err != nil { + return nil, fmt.Errorf("failed to parse client key pair: %s", err) + } + + var rootCertPool *x509.CertPool + if root != nil { + rootCertPool = x509.NewCertPool() + if ok := rootCertPool.AppendCertsFromPEM(root); !ok { + return nil, errors.New("failed to load root cert file") + } + } + + tlscfg := &tls.Config{ + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{cccert}, + } + + //follow Peer's server default config properties + if isserver { + tlscfg.ClientCAs = rootCertPool + tlscfg.SessionTicketsDisabled = true + tlscfg.CipherSuites = []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + } + if rootCertPool != nil { + tlscfg.ClientAuth = tls.RequireAndVerifyClientCert + } + } else { + tlscfg.RootCAs = rootCertPool + } + + return tlscfg, nil +} diff --git a/v2/shim/internal/config_test.go b/v2/shim/internal/config_test.go new file mode 100644 index 0000000..63586fd --- /dev/null +++ b/v2/shim/internal/config_test.go @@ -0,0 +1,731 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal_test + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "os" + "testing" + "time" + + . "github.com/hyperledger/fabric-chaincode-go/v2/shim/internal" + "github.com/hyperledger/fabric-protos-go-apiv2/peer" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/keepalive" +) + +// TLS tuples for client and server were created +// using cryptogen tool. Of course, any standard tool such as openssl +// could have been used as well +var keyPEM = `-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgKg8jpiNIB5LXLull +IRoYMsQximSiU7XvGCYLslx4GauhRANCAARBGdslxalpg0dxk9GwVhi+Qw9oKZPE +n1hWPFmusDKtNbDLsHd9k1lU+SWnJKYlg7hmaUvxC1lR2M6KmvAwSUfN +-----END PRIVATE KEY----- +` +var certPEM = `-----BEGIN CERTIFICATE----- +MIICaTCCAhCgAwIBAgIQS46wcUDY2nJ2gQ/7fp/ptzAKBggqhkjOPQQDAjB2MQsw +CQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZy +YW5jaXNjbzEZMBcGA1UEChMQb3JnMS5leGFtcGxlLmNvbTEfMB0GA1UEAxMWdGxz +Y2Eub3JnMS5leGFtcGxlLmNvbTAeFw0xOTEyMTIwMTA1NTBaFw0yOTEyMDkwMTA1 +NTBaMFoxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQH +Ew1TYW4gRnJhbmNpc2NvMR4wHAYDVQQDExVteWNjLm9yZzEuZXhhbXBsZS5jb20w +WTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAARBGdslxalpg0dxk9GwVhi+Qw9oKZPE +n1hWPFmusDKtNbDLsHd9k1lU+SWnJKYlg7hmaUvxC1lR2M6KmvAwSUfNo4GbMIGY +MA4GA1UdDwEB/wQEAwIFoDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIw +DAYDVR0TAQH/BAIwADArBgNVHSMEJDAigCBxQqUF6hEsSgXTc47WT4U58SOdgX8n +8RlMuxFg0wRtjjAsBgNVHREEJTAjghVteWNjLm9yZzEuZXhhbXBsZS5jb22CBG15 +Y2OHBH8AAAEwCgYIKoZIzj0EAwIDRwAwRAIgWgxAuGibD+Da/qCLBryJMDGlyIrx +HV+tI33lEy1B9qoCIEJD4xipI2WYp1sHmK2nxYPcoTb9WLFdNZ6twKZyw9c8 +-----END CERTIFICATE----- +` +var rootPEM = `-----BEGIN CERTIFICATE----- +MIICSTCCAe+gAwIBAgIQWpamEC5/D2N5JKS8FEpgTzAKBggqhkjOPQQDAjB2MQsw +CQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZy +YW5jaXNjbzEZMBcGA1UEChMQb3JnMS5leGFtcGxlLmNvbTEfMB0GA1UEAxMWdGxz +Y2Eub3JnMS5leGFtcGxlLmNvbTAeFw0xOTEyMTIwMTA1NTBaFw0yOTEyMDkwMTA1 +NTBaMHYxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQH +Ew1TYW4gRnJhbmNpc2NvMRkwFwYDVQQKExBvcmcxLmV4YW1wbGUuY29tMR8wHQYD +VQQDExZ0bHNjYS5vcmcxLmV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0D +AQcDQgAE2eFjoZkB/ozmheZZ9P05kUXAQAG+j0oTmRr9vX2qJa+tyrbS/i4UKrXo +82dqcDmmL16l2ukBXt7/aBre5WbVEaNfMF0wDgYDVR0PAQH/BAQDAgGmMA8GA1Ud +JQQIMAYGBFUdJQAwDwYDVR0TAQH/BAUwAwEB/zApBgNVHQ4EIgQgcUKlBeoRLEoF +03OO1k+FOfEjnYF/J/EZTLsRYNMEbY4wCgYIKoZIzj0EAwIDSAAwRQIhANmPRnJi +p7amrl9rF5xWtW0rR+y9uSCi6cy/T8bJl1JTAiATHlHcuNhHFeGb+Vl512FC3sGM +bHHlP/A/QkbGqJL4HQ== +-----END CERTIFICATE----- +` + +// #nosec G101 +var clientKeyPEM = `-----BEGIN EC PRIVATE KEY----- +MHcCAQEEINVHep4/z6iPa151Ipp4MmCb1l/VKkY3vuMfUQf3LhQboAoGCCqGSM49 +AwEHoUQDQgAEcE6hZ7muszSi5wXIVKPdIuLYPTIxQxj+jekPRfFnJF/RJKM0Nj3T +Bk9spwCHwu1t3REyobjaZcFQk0y32Pje5A== +-----END EC PRIVATE KEY----- +` + +var clientCertPEM = `-----BEGIN CERTIFICATE----- +MIICAzCCAaqgAwIBAgIQe/ZUgn+/dH6FGrx+dr/PfjAKBggqhkjOPQQDAjBYMQsw +CQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZy +YW5jaXNjbzENMAsGA1UEChMET3JnMTENMAsGA1UEAxMET3JnMTAeFw0xODA4MjEw +ODI1MzNaFw0yODA4MTgwODI1MzNaMGgxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpD +YWxpZm9ybmlhMRYwFAYDVQQHEw1TYW4gRnJhbmNpc2NvMRUwEwYDVQQKEwxPcmcx +LWNsaWVudDExFTATBgNVBAMTDE9yZzEtY2xpZW50MTBZMBMGByqGSM49AgEGCCqG +SM49AwEHA0IABHBOoWe5rrM0oucFyFSj3SLi2D0yMUMY/o3pD0XxZyRf0SSjNDY9 +0wZPbKcAh8Ltbd0RMqG42mXBUJNMt9j43uSjRjBEMA4GA1UdDwEB/wQEAwIFoDAT +BgNVHSUEDDAKBggrBgEFBQcDAjAMBgNVHRMBAf8EAjAAMA8GA1UdIwQIMAaABAEC +AwQwCgYIKoZIzj0EAwIDRwAwRAIgaK/prRkZS6zctxwBUl2QApUrH7pMmab30Nn9 +ER8f3m0CICBZ9XoxKXEFFcSRpfiA2/vzoOPg76lRXcCklxzGSJYu +-----END CERTIFICATE----- +` + +var clientRootPEM = `-----BEGIN CERTIFICATE----- +MIIB8TCCAZegAwIBAgIQUigdJy6IudO7sVOXsKVrtzAKBggqhkjOPQQDAjBYMQsw +CQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZy +YW5jaXNjbzENMAsGA1UEChMET3JnMTENMAsGA1UEAxMET3JnMTAeFw0xODA4MjEw +ODI1MzNaFw0yODA4MTgwODI1MzNaMFgxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpD +YWxpZm9ybmlhMRYwFAYDVQQHEw1TYW4gRnJhbmNpc2NvMQ0wCwYDVQQKEwRPcmcx +MQ0wCwYDVQQDEwRPcmcxMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEVOI+oAAB +Pl+iRsCcGq81WbXap2L1r432T5gbzUNKYRvVsyFFYmdO8ql8uDi4UxSY64eaeRFT +uxdcsTG7M5K2yaNDMEEwDgYDVR0PAQH/BAQDAgGmMA8GA1UdJQQIMAYGBFUdJQAw +DwYDVR0TAQH/BAUwAwEB/zANBgNVHQ4EBgQEAQIDBDAKBggqhkjOPQQDAgNIADBF +AiEA6U7IRGf+S7e9U2+jSI2eFiBsVEBIi35LgYoKqjELj5oCIAD7DfVMaMHzzjiQ +XIlJQdS/9afDi32qZWZfe3kAUAs0 +-----END CERTIFICATE----- +` + +func TestLoadBase64EncodedConfig(t *testing.T) { + // setup key/cert files + testDir, err := os.MkdirTemp("", "shiminternal") + if err != nil { + t.Fatalf("Failed to test directory: %s", err) + } + defer os.RemoveAll(testDir) + + keyFile, err := os.CreateTemp(testDir, "testKey") + if err != nil { + t.Fatalf("Failed to create key file: %s", err) + } + b64Key := base64.StdEncoding.EncodeToString([]byte(keyPEM)) + if _, err := keyFile.WriteString(b64Key); err != nil { + t.Fatalf("Failed to write to key file: %s", err) + } + + certFile, err := os.CreateTemp(testDir, "testCert") + if err != nil { + t.Fatalf("Failed to create cert file: %s", err) + } + b64Cert := base64.StdEncoding.EncodeToString([]byte(certPEM)) + if _, err := certFile.WriteString(b64Cert); err != nil { + t.Fatalf("Failed to write to cert file: %s", err) + } + + rootFile, err := os.CreateTemp(testDir, "testRoot") + if err != nil { + t.Fatalf("Failed to create root file: %s", err) + } + if _, err := rootFile.WriteString(rootPEM); err != nil { + t.Fatalf("Failed to write to root file: %s", err) + } + + notb64File, err := os.CreateTemp(testDir, "testNotb64") + if err != nil { + t.Fatalf("Failed to create notb64 file: %s", err) + } + if _, err := notb64File.WriteString("#####"); err != nil { + t.Fatalf("Failed to write to notb64 file: %s", err) + } + + notPEMFile, err := os.CreateTemp(testDir, "testNotPEM") + if err != nil { + t.Fatalf("Failed to create notPEM file: %s", err) + } + b64 := base64.StdEncoding.EncodeToString([]byte("not pem")) + if _, err := notPEMFile.WriteString(b64); err != nil { + t.Fatalf("Failed to write to notPEM file: %s", err) + } + + defer cleanupEnv() + + // expected TLS config + rootPool := x509.NewCertPool() + rootPool.AppendCertsFromPEM([]byte(rootPEM)) + clientCert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM)) + if err != nil { + t.Fatalf("Failed to load client cert pair: %s", err) + } + + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{clientCert}, + RootCAs: rootPool, + } + + kaOpts := keepalive.ClientParameters{ + Time: 1 * time.Minute, + Timeout: 20 * time.Second, + PermitWithoutStream: true, + } + + var tests = []struct { + name string + env map[string]string + expected Config + errMsg string + }{ + { + name: "TLS disabled", + env: map[string]string{ + "CORE_CHAINCODE_ID_NAME": "testCC", + "CORE_PEER_TLS_ENABLED": "false", + }, + expected: Config{ + ChaincodeName: "testCC", + KaOpts: kaOpts, + }, + }, + { + name: "TLS Enabled", + env: map[string]string{ + "CORE_CHAINCODE_ID_NAME": "testCC", + "CORE_PEER_TLS_ENABLED": "true", + "CORE_TLS_CLIENT_KEY_PATH": keyFile.Name(), + "CORE_TLS_CLIENT_CERT_PATH": certFile.Name(), + "CORE_PEER_TLS_ROOTCERT_FILE": rootFile.Name(), + }, + expected: Config{ + ChaincodeName: "testCC", + TLS: tlsConfig, + KaOpts: kaOpts, + }, + }, + { + name: "Bad TLS_ENABLED", + env: map[string]string{ + "CORE_PEER_TLS_ENABLED": "nottruthy", + }, + errMsg: "'CORE_PEER_TLS_ENABLED' must be set to 'true' or 'false'", + }, + { + name: "Missing key file", + env: map[string]string{ + "CORE_CHAINCODE_ID_NAME": "testCC", + "CORE_PEER_TLS_ENABLED": "true", + "CORE_TLS_CLIENT_KEY_PATH": "missingkey", + }, + errMsg: "failed to read private key file", + }, + { + name: "Bad key file", + env: map[string]string{ + "CORE_CHAINCODE_ID_NAME": "testCC", + "CORE_PEER_TLS_ENABLED": "true", + "CORE_TLS_CLIENT_KEY_PATH": notb64File.Name(), + }, + errMsg: "failed to decode private key file", + }, + { + name: "Missing cert file", + env: map[string]string{ + "CORE_CHAINCODE_ID_NAME": "testCC", + "CORE_PEER_TLS_ENABLED": "true", + "CORE_TLS_CLIENT_KEY_PATH": keyFile.Name(), + "CORE_TLS_CLIENT_CERT_PATH": "missingkey", + }, + errMsg: "failed to read public key file", + }, + { + name: "Bad cert file", + env: map[string]string{ + "CORE_CHAINCODE_ID_NAME": "testCC", + "CORE_PEER_TLS_ENABLED": "true", + "CORE_TLS_CLIENT_KEY_PATH": keyFile.Name(), + "CORE_TLS_CLIENT_CERT_PATH": notb64File.Name(), + }, + errMsg: "failed to decode public key file", + }, + { + name: "Missing root file", + env: map[string]string{ + "CORE_CHAINCODE_ID_NAME": "testCC", + "CORE_PEER_TLS_ENABLED": "true", + "CORE_TLS_CLIENT_KEY_PATH": keyFile.Name(), + "CORE_TLS_CLIENT_CERT_PATH": certFile.Name(), + "CORE_PEER_TLS_ROOTCERT_FILE": "missingkey", + }, + errMsg: "failed to read root cert file", + }, + { + name: "Bad root file", + env: map[string]string{ + "CORE_CHAINCODE_ID_NAME": "testCC", + "CORE_PEER_TLS_ENABLED": "true", + "CORE_TLS_CLIENT_KEY_PATH": keyFile.Name(), + "CORE_TLS_CLIENT_CERT_PATH": certFile.Name(), + "CORE_PEER_TLS_ROOTCERT_FILE": notb64File.Name(), + }, + errMsg: "failed to load root cert file", + }, + { + name: "Key not PEM", + env: map[string]string{ + "CORE_CHAINCODE_ID_NAME": "testCC", + "CORE_PEER_TLS_ENABLED": "true", + "CORE_TLS_CLIENT_KEY_PATH": notPEMFile.Name(), + "CORE_TLS_CLIENT_CERT_PATH": certFile.Name(), + "CORE_PEER_TLS_ROOTCERT_FILE": rootFile.Name(), + }, + errMsg: "failed to parse client key pair", + }, + { + name: "Cert not PEM", + env: map[string]string{ + "CORE_CHAINCODE_ID_NAME": "testCC", + "CORE_PEER_TLS_ENABLED": "true", + "CORE_TLS_CLIENT_KEY_PATH": keyFile.Name(), + "CORE_TLS_CLIENT_CERT_PATH": notPEMFile.Name(), + "CORE_PEER_TLS_ROOTCERT_FILE": rootFile.Name(), + }, + errMsg: "failed to parse client key pair", + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + for k, v := range test.env { + os.Setenv(k, v) + } + conf, err := LoadConfig() + if test.errMsg == "" { + assert.EqualValues(t, test.expected.ChaincodeName, conf.ChaincodeName) + assert.Equal(t, test.expected.KaOpts, conf.KaOpts) + if test.expected.TLS != nil { + tlsConfigEquals(t, test.expected.TLS, conf.TLS) + } + } else { + assert.Contains(t, err.Error(), test.errMsg) + } + }) + } + + tlsServerConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{clientCert}, + ClientCAs: rootPool, + SessionTicketsDisabled: true, + CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_RSA_WITH_AES_128_GCM_SHA256, // #nosec G402 + tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + }, + ClientAuth: tls.RequireAndVerifyClientCert, + } + + tlsServerNonMutualConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{clientCert}, + RootCAs: nil, + SessionTicketsDisabled: true, + CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_RSA_WITH_AES_128_GCM_SHA256, // #nosec G402 + tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + }, + ClientAuth: tls.NoClientCert, + } + + // additional tests to differentiate client vs server + var tlsTests = []struct { + name string + issrv bool + key []byte + cert []byte + rootCert []byte + expected *tls.Config + errMsg string + }{ + { + name: "Server TLS", + issrv: true, + key: []byte(keyPEM), + cert: []byte(certPEM), + rootCert: []byte(rootPEM), + expected: tlsServerConfig, + }, + { + name: "Server non-mutual TLS", + issrv: true, + key: []byte(keyPEM), + cert: []byte(certPEM), + rootCert: nil, + expected: tlsServerNonMutualConfig, + }, + { + name: "Server key unspecified", + issrv: true, + key: nil, + cert: []byte(certPEM), + rootCert: []byte(rootPEM), + errMsg: "key not provided", + }, + { + name: "Server cert unspecified", + issrv: true, + key: []byte(keyPEM), + cert: nil, + rootCert: []byte(rootPEM), + errMsg: "cert not provided", + }, + { + name: "Client TLS root CA unspecified", + issrv: false, + key: []byte(keyPEM), + cert: []byte(certPEM), + rootCert: nil, + errMsg: "root cert not provided", + }, + } + + for _, test := range tlsTests { + t.Run(test.name, func(t *testing.T) { + tlsCfg, err := LoadTLSConfig(test.issrv, test.key, test.cert, test.rootCert) + if test.errMsg == "" { + tlsConfigEquals(t, test.expected, tlsCfg) + } else { + assert.Contains(t, err.Error(), test.errMsg) + } + }) + } +} + +func tlsConfigEquals(t *testing.T, cfg1 *tls.Config, cfg2 *tls.Config) { + assert.EqualValues(t, cfg1.MinVersion, cfg2.MinVersion) + assert.EqualValues(t, cfg1.ClientAuth, cfg2.ClientAuth) +} + +func TestLoadPEMEncodedConfig(t *testing.T) { + // setup key/cert files + testDir, err := os.MkdirTemp("", "shiminternal") + if err != nil { + t.Fatalf("Failed to test directory: %s", err) + } + defer os.RemoveAll(testDir) + + keyFile, err := os.CreateTemp(testDir, "testKey") + if err != nil { + t.Fatalf("Failed to create key file: %s", err) + } + if _, err := keyFile.WriteString(keyPEM); err != nil { + t.Fatalf("Failed to write to key file: %s", err) + } + + certFile, err := os.CreateTemp(testDir, "testCert") + if err != nil { + t.Fatalf("Failed to create cert file: %s", err) + } + if _, err := certFile.WriteString(certPEM); err != nil { + t.Fatalf("Failed to write to cert file: %s", err) + } + + rootFile, err := os.CreateTemp(testDir, "testRoot") + if err != nil { + t.Fatalf("Failed to create root file: %s", err) + } + if _, err := rootFile.WriteString(rootPEM); err != nil { + t.Fatalf("Failed to write to root file: %s", err) + } + + keyFile64, err := os.CreateTemp(testDir, "testKey64") + if err != nil { + t.Fatalf("Failed to create key file: %s", err) + } + b64Key := base64.StdEncoding.EncodeToString([]byte(keyPEM)) + if _, err := keyFile64.WriteString(b64Key); err != nil { + t.Fatalf("Failed to write to key file: %s", err) + } + + certFile64, err := os.CreateTemp(testDir, "testCert64") + if err != nil { + t.Fatalf("Failed to create cert file: %s", err) + } + b64Cert := base64.StdEncoding.EncodeToString([]byte(certPEM)) + if _, err := certFile64.WriteString(b64Cert); err != nil { + t.Fatalf("Failed to write to cert file: %s", err) + } + + defer cleanupEnv() + + // expected TLS config + rootPool := x509.NewCertPool() + rootPool.AppendCertsFromPEM([]byte(rootPEM)) + clientCert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM)) + if err != nil { + t.Fatalf("Failed to load client cert pair: %s", err) + } + + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{clientCert}, + RootCAs: rootPool, + } + + kaOpts := keepalive.ClientParameters{ + Time: 1 * time.Minute, + Timeout: 20 * time.Second, + PermitWithoutStream: true, + } + + var tests = []struct { + name string + env map[string]string + expected Config + errMsg string + }{ + { + name: "TLS Enabled with PEM-encoded variables", + env: map[string]string{ + "CORE_CHAINCODE_ID_NAME": "testCC", + "CORE_PEER_TLS_ENABLED": "true", + "CORE_TLS_CLIENT_KEY_FILE": keyFile.Name(), + "CORE_TLS_CLIENT_CERT_FILE": certFile.Name(), + "CORE_PEER_TLS_ROOTCERT_FILE": rootFile.Name(), + }, + expected: Config{ + ChaincodeName: "testCC", + TLS: tlsConfig, + KaOpts: kaOpts, + }, + }, + { + name: "Client cert uses base64 encoding", + env: map[string]string{ + "CORE_CHAINCODE_ID_NAME": "testCC", + "CORE_PEER_TLS_ENABLED": "true", + "CORE_TLS_CLIENT_KEY_FILE": keyFile.Name(), + "CORE_TLS_CLIENT_CERT_PATH": certFile64.Name(), + "CORE_PEER_TLS_ROOTCERT_FILE": rootFile.Name(), + }, + expected: Config{ + ChaincodeName: "testCC", + TLS: tlsConfig, + KaOpts: kaOpts, + }, + }, + { + name: "Client key uses base64 encoding", + env: map[string]string{ + "CORE_CHAINCODE_ID_NAME": "testCC", + "CORE_PEER_TLS_ENABLED": "true", + "CORE_TLS_CLIENT_KEY_PATH": keyFile64.Name(), + "CORE_TLS_CLIENT_CERT_FILE": certFile.Name(), + "CORE_PEER_TLS_ROOTCERT_FILE": rootFile.Name(), + }, + expected: Config{ + ChaincodeName: "testCC", + TLS: tlsConfig, + KaOpts: kaOpts, + }, + }, + { + name: "Client cert uses base64 encoding with PEM variable", + env: map[string]string{ + "CORE_CHAINCODE_ID_NAME": "testCC", + "CORE_PEER_TLS_ENABLED": "true", + "CORE_TLS_CLIENT_KEY_FILE": keyFile.Name(), + "CORE_TLS_CLIENT_CERT_FILE": certFile64.Name(), + "CORE_PEER_TLS_ROOTCERT_FILE": rootFile.Name(), + }, + errMsg: "failed to parse client key pair", + }, + { + name: "Client key uses base64 encoding with PEM variable", + env: map[string]string{ + "CORE_CHAINCODE_ID_NAME": "testCC", + "CORE_PEER_TLS_ENABLED": "true", + "CORE_TLS_CLIENT_KEY_FILE": keyFile64.Name(), + "CORE_TLS_CLIENT_CERT_FILE": certFile.Name(), + "CORE_PEER_TLS_ROOTCERT_FILE": rootFile.Name(), + }, + errMsg: "failed to parse client key pair", + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + for k, v := range test.env { + os.Setenv(k, v) + } + conf, err := LoadConfig() + if test.errMsg == "" { + assert.EqualValues(t, test.expected.ChaincodeName, conf.ChaincodeName) + assert.Equal(t, test.expected.KaOpts, conf.KaOpts) + if test.expected.TLS != nil { + tlsConfigEquals(t, test.expected.TLS, conf.TLS) + } + } else { + assert.Contains(t, err.Error(), test.errMsg) + } + }) + } +} + +func newTLSConnection(t *testing.T, address string, crt, key, rootCert []byte) *grpc.ClientConn { + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + } + + tlsConfig.RootCAs = x509.NewCertPool() + tlsConfig.RootCAs.AppendCertsFromPEM(rootCert) + if crt != nil && key != nil { + cert, err := tls.X509KeyPair(crt, key) + assert.NoError(t, err) + assert.NotNil(t, cert) + + tlsConfig.Certificates = append(tlsConfig.Certificates, cert) + } + + var dialOpts []grpc.DialOption + dialOpts = append(dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) + + kap := keepalive.ClientParameters{ + Time: time.Duration(1) * time.Minute, + Timeout: time.Duration(20) * time.Second, + PermitWithoutStream: true, + } + + dialOpts = append(dialOpts, grpc.WithKeepaliveParams(kap)) + + conn, err := grpc.NewClient(address, dialOpts...) + assert.NoError(t, err) + assert.NotNil(t, conn) + + return conn +} + +func TestTLSClientWithChaincodeServer(t *testing.T) { + rootPool := x509.NewCertPool() + ok := rootPool.AppendCertsFromPEM([]byte(clientRootPEM)) + if !ok { + t.Fatal("failed to create test root cert pool") + } + + cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM)) + if err != nil { + t.Fatalf("Failed to load client cert pair: %s", err) + } + + tlsServerConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{cert}, + ClientCAs: rootPool, + SessionTicketsDisabled: true, + CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_RSA_WITH_AES_128_GCM_SHA256, // #nosec G402 + tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + }, + ClientAuth: tls.RequireAndVerifyClientCert, + } + + // given server is good and expects valid TLS connection, test good and invalid scenarios + var tlsTests = []struct { + name string + issrv bool + clientKey []byte + clientCert []byte + clientRootCert []byte + expected *tls.Config + errMsg string + success bool + address string + }{ + { + name: "Good TLS", + issrv: true, + clientKey: []byte(clientKeyPEM), + clientCert: []byte(clientCertPEM), + clientRootCert: []byte(rootPEM), + success: true, + address: "127.0.0.1:0", + }, + { + name: "Bad server RootCA", + issrv: true, + clientKey: []byte(clientKeyPEM), + clientCert: []byte(clientCertPEM), + clientRootCert: []byte(clientRootPEM), + success: false, + errMsg: "transport: authentication handshake failed: tls: failed to verify certificate: x509: certificate signed by unknown authority", + address: "127.0.0.1:0", + }, + { + name: "Bad client cert", + issrv: true, + clientKey: []byte(keyPEM), + clientCert: []byte(certPEM), + clientRootCert: []byte(rootPEM), + success: false, + errMsg: "rpc error", + address: "127.0.0.1:0", + }, + { + name: "No client cert", + issrv: true, + clientRootCert: []byte(rootPEM), + success: false, + errMsg: "rpc error", + address: "127.0.0.1:0", + }, + } + + for _, test := range tlsTests { + t.Run(test.name, func(t *testing.T) { + srv, err := NewServer(test.address, tlsServerConfig, nil) + if err != nil { + t.Fatalf("error creating server for test: %v", err) + } + defer srv.Stop() + go func() { + err = srv.Start() + assert.NoError(t, err, "srv.Start") + }() + + conn := newTLSConnection(t, srv.Listener.Addr().String(), test.clientCert, test.clientKey, test.clientRootCert) + assert.NotNil(t, conn) + + ccclient := peer.NewChaincodeClient(conn) + assert.NotNil(t, ccclient) + + stream, err := ccclient.Connect(context.Background()) + if test.success { + assert.NoError(t, err) + assert.NotNil(t, stream) + } else { + assert.Error(t, err) + assert.Regexp(t, test.errMsg, err.Error()) + } + }) + } +} + +func cleanupEnv() { + os.Unsetenv("CORE_PEER_TLS_ENABLED") + os.Unsetenv("CORE_TLS_CLIENT_KEY_PATH") + os.Unsetenv("CORE_TLS_CLIENT_CERT_PATH") + os.Unsetenv("CORE_PEER_TLS_ROOTCERT_FILE") + os.Unsetenv("CORE_CHAINCODE_ID_NAME") +} diff --git a/v2/shim/internal/mock/client_stream.go b/v2/shim/internal/mock/client_stream.go new file mode 100644 index 0000000..201cd63 --- /dev/null +++ b/v2/shim/internal/mock/client_stream.go @@ -0,0 +1,244 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package mock + +import ( + "sync" + + "github.com/hyperledger/fabric-protos-go-apiv2/peer" +) + +type ClientStream struct { + CloseSendStub func() error + closeSendMutex sync.RWMutex + closeSendArgsForCall []struct { + } + closeSendReturns struct { + result1 error + } + closeSendReturnsOnCall map[int]struct { + result1 error + } + RecvStub func() (*peer.ChaincodeMessage, error) + recvMutex sync.RWMutex + recvArgsForCall []struct { + } + recvReturns struct { + result1 *peer.ChaincodeMessage + result2 error + } + recvReturnsOnCall map[int]struct { + result1 *peer.ChaincodeMessage + result2 error + } + SendStub func(*peer.ChaincodeMessage) error + sendMutex sync.RWMutex + sendArgsForCall []struct { + arg1 *peer.ChaincodeMessage + } + sendReturns struct { + result1 error + } + sendReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *ClientStream) CloseSend() error { + fake.closeSendMutex.Lock() + ret, specificReturn := fake.closeSendReturnsOnCall[len(fake.closeSendArgsForCall)] + fake.closeSendArgsForCall = append(fake.closeSendArgsForCall, struct { + }{}) + stub := fake.CloseSendStub + fakeReturns := fake.closeSendReturns + fake.recordInvocation("CloseSend", []interface{}{}) + fake.closeSendMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *ClientStream) CloseSendCallCount() int { + fake.closeSendMutex.RLock() + defer fake.closeSendMutex.RUnlock() + return len(fake.closeSendArgsForCall) +} + +func (fake *ClientStream) CloseSendCalls(stub func() error) { + fake.closeSendMutex.Lock() + defer fake.closeSendMutex.Unlock() + fake.CloseSendStub = stub +} + +func (fake *ClientStream) CloseSendReturns(result1 error) { + fake.closeSendMutex.Lock() + defer fake.closeSendMutex.Unlock() + fake.CloseSendStub = nil + fake.closeSendReturns = struct { + result1 error + }{result1} +} + +func (fake *ClientStream) CloseSendReturnsOnCall(i int, result1 error) { + fake.closeSendMutex.Lock() + defer fake.closeSendMutex.Unlock() + fake.CloseSendStub = nil + if fake.closeSendReturnsOnCall == nil { + fake.closeSendReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.closeSendReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *ClientStream) Recv() (*peer.ChaincodeMessage, error) { + fake.recvMutex.Lock() + ret, specificReturn := fake.recvReturnsOnCall[len(fake.recvArgsForCall)] + fake.recvArgsForCall = append(fake.recvArgsForCall, struct { + }{}) + stub := fake.RecvStub + fakeReturns := fake.recvReturns + fake.recordInvocation("Recv", []interface{}{}) + fake.recvMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *ClientStream) RecvCallCount() int { + fake.recvMutex.RLock() + defer fake.recvMutex.RUnlock() + return len(fake.recvArgsForCall) +} + +func (fake *ClientStream) RecvCalls(stub func() (*peer.ChaincodeMessage, error)) { + fake.recvMutex.Lock() + defer fake.recvMutex.Unlock() + fake.RecvStub = stub +} + +func (fake *ClientStream) RecvReturns(result1 *peer.ChaincodeMessage, result2 error) { + fake.recvMutex.Lock() + defer fake.recvMutex.Unlock() + fake.RecvStub = nil + fake.recvReturns = struct { + result1 *peer.ChaincodeMessage + result2 error + }{result1, result2} +} + +func (fake *ClientStream) RecvReturnsOnCall(i int, result1 *peer.ChaincodeMessage, result2 error) { + fake.recvMutex.Lock() + defer fake.recvMutex.Unlock() + fake.RecvStub = nil + if fake.recvReturnsOnCall == nil { + fake.recvReturnsOnCall = make(map[int]struct { + result1 *peer.ChaincodeMessage + result2 error + }) + } + fake.recvReturnsOnCall[i] = struct { + result1 *peer.ChaincodeMessage + result2 error + }{result1, result2} +} + +func (fake *ClientStream) Send(arg1 *peer.ChaincodeMessage) error { + fake.sendMutex.Lock() + ret, specificReturn := fake.sendReturnsOnCall[len(fake.sendArgsForCall)] + fake.sendArgsForCall = append(fake.sendArgsForCall, struct { + arg1 *peer.ChaincodeMessage + }{arg1}) + stub := fake.SendStub + fakeReturns := fake.sendReturns + fake.recordInvocation("Send", []interface{}{arg1}) + fake.sendMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *ClientStream) SendCallCount() int { + fake.sendMutex.RLock() + defer fake.sendMutex.RUnlock() + return len(fake.sendArgsForCall) +} + +func (fake *ClientStream) SendCalls(stub func(*peer.ChaincodeMessage) error) { + fake.sendMutex.Lock() + defer fake.sendMutex.Unlock() + fake.SendStub = stub +} + +func (fake *ClientStream) SendArgsForCall(i int) *peer.ChaincodeMessage { + fake.sendMutex.RLock() + defer fake.sendMutex.RUnlock() + argsForCall := fake.sendArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *ClientStream) SendReturns(result1 error) { + fake.sendMutex.Lock() + defer fake.sendMutex.Unlock() + fake.SendStub = nil + fake.sendReturns = struct { + result1 error + }{result1} +} + +func (fake *ClientStream) SendReturnsOnCall(i int, result1 error) { + fake.sendMutex.Lock() + defer fake.sendMutex.Unlock() + fake.SendStub = nil + if fake.sendReturnsOnCall == nil { + fake.sendReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.sendReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *ClientStream) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.closeSendMutex.RLock() + defer fake.closeSendMutex.RUnlock() + fake.recvMutex.RLock() + defer fake.recvMutex.RUnlock() + fake.sendMutex.RLock() + defer fake.sendMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *ClientStream) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} diff --git a/v2/shim/internal/mock/peer_chaincode_stream.go b/v2/shim/internal/mock/peer_chaincode_stream.go new file mode 100644 index 0000000..d5a7f01 --- /dev/null +++ b/v2/shim/internal/mock/peer_chaincode_stream.go @@ -0,0 +1,179 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package mock + +import ( + "sync" + + "github.com/hyperledger/fabric-protos-go-apiv2/peer" +) + +type PeerChaincodeStream struct { + RecvStub func() (*peer.ChaincodeMessage, error) + recvMutex sync.RWMutex + recvArgsForCall []struct { + } + recvReturns struct { + result1 *peer.ChaincodeMessage + result2 error + } + recvReturnsOnCall map[int]struct { + result1 *peer.ChaincodeMessage + result2 error + } + SendStub func(*peer.ChaincodeMessage) error + sendMutex sync.RWMutex + sendArgsForCall []struct { + arg1 *peer.ChaincodeMessage + } + sendReturns struct { + result1 error + } + sendReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *PeerChaincodeStream) Recv() (*peer.ChaincodeMessage, error) { + fake.recvMutex.Lock() + ret, specificReturn := fake.recvReturnsOnCall[len(fake.recvArgsForCall)] + fake.recvArgsForCall = append(fake.recvArgsForCall, struct { + }{}) + stub := fake.RecvStub + fakeReturns := fake.recvReturns + fake.recordInvocation("Recv", []interface{}{}) + fake.recvMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *PeerChaincodeStream) RecvCallCount() int { + fake.recvMutex.RLock() + defer fake.recvMutex.RUnlock() + return len(fake.recvArgsForCall) +} + +func (fake *PeerChaincodeStream) RecvCalls(stub func() (*peer.ChaincodeMessage, error)) { + fake.recvMutex.Lock() + defer fake.recvMutex.Unlock() + fake.RecvStub = stub +} + +func (fake *PeerChaincodeStream) RecvReturns(result1 *peer.ChaincodeMessage, result2 error) { + fake.recvMutex.Lock() + defer fake.recvMutex.Unlock() + fake.RecvStub = nil + fake.recvReturns = struct { + result1 *peer.ChaincodeMessage + result2 error + }{result1, result2} +} + +func (fake *PeerChaincodeStream) RecvReturnsOnCall(i int, result1 *peer.ChaincodeMessage, result2 error) { + fake.recvMutex.Lock() + defer fake.recvMutex.Unlock() + fake.RecvStub = nil + if fake.recvReturnsOnCall == nil { + fake.recvReturnsOnCall = make(map[int]struct { + result1 *peer.ChaincodeMessage + result2 error + }) + } + fake.recvReturnsOnCall[i] = struct { + result1 *peer.ChaincodeMessage + result2 error + }{result1, result2} +} + +func (fake *PeerChaincodeStream) Send(arg1 *peer.ChaincodeMessage) error { + fake.sendMutex.Lock() + ret, specificReturn := fake.sendReturnsOnCall[len(fake.sendArgsForCall)] + fake.sendArgsForCall = append(fake.sendArgsForCall, struct { + arg1 *peer.ChaincodeMessage + }{arg1}) + stub := fake.SendStub + fakeReturns := fake.sendReturns + fake.recordInvocation("Send", []interface{}{arg1}) + fake.sendMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *PeerChaincodeStream) SendCallCount() int { + fake.sendMutex.RLock() + defer fake.sendMutex.RUnlock() + return len(fake.sendArgsForCall) +} + +func (fake *PeerChaincodeStream) SendCalls(stub func(*peer.ChaincodeMessage) error) { + fake.sendMutex.Lock() + defer fake.sendMutex.Unlock() + fake.SendStub = stub +} + +func (fake *PeerChaincodeStream) SendArgsForCall(i int) *peer.ChaincodeMessage { + fake.sendMutex.RLock() + defer fake.sendMutex.RUnlock() + argsForCall := fake.sendArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *PeerChaincodeStream) SendReturns(result1 error) { + fake.sendMutex.Lock() + defer fake.sendMutex.Unlock() + fake.SendStub = nil + fake.sendReturns = struct { + result1 error + }{result1} +} + +func (fake *PeerChaincodeStream) SendReturnsOnCall(i int, result1 error) { + fake.sendMutex.Lock() + defer fake.sendMutex.Unlock() + fake.SendStub = nil + if fake.sendReturnsOnCall == nil { + fake.sendReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.sendReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *PeerChaincodeStream) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.recvMutex.RLock() + defer fake.recvMutex.RUnlock() + fake.sendMutex.RLock() + defer fake.sendMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *PeerChaincodeStream) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} diff --git a/v2/shim/internal/server.go b/v2/shim/internal/server.go new file mode 100644 index 0000000..89a31fc --- /dev/null +++ b/v2/shim/internal/server.go @@ -0,0 +1,106 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "crypto/tls" + "errors" + "net" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/keepalive" +) + +const ( + serverInterval = time.Duration(2) * time.Hour // 2 hours - gRPC default + serverTimeout = time.Duration(20) * time.Second // 20 sec - gRPC default + serverMinInterval = time.Duration(1) * time.Minute + connectionTimeout = 5 * time.Second +) + +// Server abstracts grpc service properties +type Server struct { + Listener net.Listener + Server *grpc.Server +} + +// Start the server +func (s *Server) Start() error { + if s.Listener == nil { + return errors.New("nil listener") + } + + if s.Server == nil { + return errors.New("nil server") + } + + return s.Server.Serve(s.Listener) +} + +// Stop the server +func (s *Server) Stop() { + if s.Server != nil { + s.Server.Stop() + } +} + +// NewServer creates a new implementation of a GRPC Server given a +// listen address +func NewServer( + address string, + tlsConf *tls.Config, + srvKaOpts *keepalive.ServerParameters, +) (*Server, error) { + if address == "" { + return nil, errors.New("server listen address not provided") + } + + //create our listener + listener, err := net.Listen("tcp", address) + if err != nil { + return nil, err + } + + //set up server options for keepalive and TLS + var serverOpts []grpc.ServerOption + + if srvKaOpts != nil { + serverOpts = append(serverOpts, grpc.KeepaliveParams(*srvKaOpts)) + } else { + serverKeepAliveParameters := keepalive.ServerParameters{ + Time: 1 * time.Minute, + Timeout: 20 * time.Second, + } + serverOpts = append(serverOpts, grpc.KeepaliveParams(serverKeepAliveParameters)) + } + + if tlsConf != nil { + serverOpts = append(serverOpts, grpc.Creds(credentials.NewTLS(tlsConf))) + } + + // Default properties follow - let's start simple and stick with defaults for now. + // These match Fabric peer side properties. We can expose these as user properties + // if needed + + // set max send and recv msg sizes + serverOpts = append(serverOpts, grpc.MaxSendMsgSize(maxSendMessageSize)) + serverOpts = append(serverOpts, grpc.MaxRecvMsgSize(maxRecvMessageSize)) + + //set enforcement policy + kep := keepalive.EnforcementPolicy{ + MinTime: serverMinInterval, + // allow keepalive w/o rpc + PermitWithoutStream: true, + } + serverOpts = append(serverOpts, grpc.KeepaliveEnforcementPolicy(kep)) + + //set default connection timeout + serverOpts = append(serverOpts, grpc.ConnectionTimeout(connectionTimeout)) + + server := grpc.NewServer(serverOpts...) + + return &Server{Listener: listener, Server: server}, nil +} diff --git a/v2/shim/internal/server_test.go b/v2/shim/internal/server_test.go new file mode 100644 index 0000000..03d0e85 --- /dev/null +++ b/v2/shim/internal/server_test.go @@ -0,0 +1,59 @@ +/* +Copyright State Street Corp. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package internal_test + +import ( + "net" + "testing" + "time" + + "github.com/hyperledger/fabric-chaincode-go/v2/shim/internal" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/keepalive" +) + +func TestBadServer(t *testing.T) { + srv := &internal.Server{} + err := srv.Start() + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "nil listener") + + l, err := net.Listen("tcp", ":0") // #nosec G102 + assert.NotNil(t, l) + assert.Nil(t, err) + srv = &internal.Server{Listener: l} + err = srv.Start() + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "nil server") +} + +func TestServerAddressNotProvided(t *testing.T) { + kaOpts := &keepalive.ServerParameters{ + Time: 1 * time.Minute, + Timeout: 20 * time.Second, + } + srv, err := internal.NewServer("", nil, kaOpts) + assert.Nil(t, srv) + assert.NotNil(t, err, "server listen address not provided") +} + +func TestBadServerAddress(t *testing.T) { + kaOpts := &keepalive.ServerParameters{ + Time: 1 * time.Minute, + Timeout: 20 * time.Second, + } + srv, err := internal.NewServer("__badhost__:0", nil, kaOpts) + assert.Nil(t, srv) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "listen tcp: lookup __badhost__") + + srv, err = internal.NewServer("host", nil, kaOpts) + assert.Nil(t, srv) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "listen tcp: address host: missing port in address") +} diff --git a/v2/shim/response.go b/v2/shim/response.go new file mode 100644 index 0000000..f56b4f6 --- /dev/null +++ b/v2/shim/response.go @@ -0,0 +1,36 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package shim + +import ( + "github.com/hyperledger/fabric-protos-go-apiv2/peer" +) + +const ( + // OK constant - status code less than 400, endorser will endorse it. + // OK means init or invoke successfully. + OK = 200 + + // ERRORTHRESHOLD constant - status code greater than or equal to 400 will be considered an error and rejected by endorser. + ERRORTHRESHOLD = 400 + + // ERROR constant - default error value + ERROR = 500 +) + +// Success ... +func Success(payload []byte) *peer.Response { + return &peer.Response{ + Status: OK, + Payload: payload, + } +} + +// Error ... +func Error(msg string) *peer.Response { + return &peer.Response{ + Status: ERROR, + Message: msg, + } +} diff --git a/v2/shim/shim.go b/v2/shim/shim.go new file mode 100644 index 0000000..f75839d --- /dev/null +++ b/v2/shim/shim.go @@ -0,0 +1,153 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Package shim provides APIs for the chaincode to access its state +// variables, transaction context and call other chaincodes. +package shim + +import ( + "errors" + "flag" + "fmt" + "io" + "os" + "unicode/utf8" + + "github.com/hyperledger/fabric-chaincode-go/v2/shim/internal" + "github.com/hyperledger/fabric-protos-go-apiv2/peer" + "google.golang.org/protobuf/proto" +) + +const ( + minUnicodeRuneValue = 0 //U+0000 + maxUnicodeRuneValue = utf8.MaxRune //U+10FFFF - maximum (and unallocated) code point + compositeKeyNamespace = "\x00" + emptyKeySubstitute = "\x01" +) + +// peer as server +var peerAddress = flag.String("peer.address", "", "peer address") + +// this separates the chaincode stream interface establishment +// so we can replace it with a mock peer stream +type peerStreamGetter func(name string) (ClientStream, error) + +// UTs to setup mock peer stream getter +var streamGetter peerStreamGetter + +// the non-mock user CC stream establishment func +func userChaincodeStreamGetter(name string) (ClientStream, error) { + if *peerAddress == "" { + return nil, errors.New("flag 'peer.address' must be set") + } + + conf, err := internal.LoadConfig() + if err != nil { + return nil, err + } + + conn, err := internal.NewClientConn(*peerAddress, conf.TLS, conf.KaOpts) + if err != nil { + return nil, err + } + + return internal.NewRegisterClient(conn) +} + +// Start chaincodes +func Start(cc Chaincode) error { + flag.Parse() + chaincodename := os.Getenv("CORE_CHAINCODE_ID_NAME") + if chaincodename == "" { + return errors.New("'CORE_CHAINCODE_ID_NAME' must be set") + } + + //mock stream not set up ... get real stream + if streamGetter == nil { + streamGetter = userChaincodeStreamGetter + } + + stream, err := streamGetter(chaincodename) + if err != nil { + return err + } + + err = chaincodeAsClientChat(chaincodename, stream, cc) + + return err +} + +// StartInProc is an entry point for system chaincodes bootstrap. It is not an +// API for chaincodes. +func StartInProc(chaincodename string, stream ClientStream, cc Chaincode) error { + return chaincodeAsClientChat(chaincodename, stream, cc) +} + +// this is the chat stream resulting from the chaincode-as-client model where the chaincode initiates connection +func chaincodeAsClientChat(chaincodename string, stream ClientStream, cc Chaincode) error { + defer stream.CloseSend() //nolint:Errcheck + return chatWithPeer(chaincodename, stream, cc) +} + +// chat stream for peer-chaincode interactions post connection +func chatWithPeer(chaincodename string, stream PeerChaincodeStream, cc Chaincode) error { + // Create the shim handler responsible for all control logic + handler := newChaincodeHandler(stream, cc) + + // Send the ChaincodeID during register. + chaincodeID := &peer.ChaincodeID{Name: chaincodename} + payload, err := proto.Marshal(chaincodeID) + if err != nil { + return fmt.Errorf("error marshalling chaincodeID during chaincode registration: %s", err) + } + + // Register on the stream + if err = handler.serialSend(&peer.ChaincodeMessage{Type: peer.ChaincodeMessage_REGISTER, Payload: payload}); err != nil { + return fmt.Errorf("error sending chaincode REGISTER: %s", err) + + } + + // holds return values from gRPC Recv below + type recvMsg struct { + msg *peer.ChaincodeMessage + err error + } + msgAvail := make(chan *recvMsg, 1) + errc := make(chan error) + + receiveMessage := func() { + in, err := stream.Recv() + msgAvail <- &recvMsg{in, err} + } + + go receiveMessage() + for { + select { + case rmsg := <-msgAvail: + switch { + case rmsg.err == io.EOF: + return errors.New("received EOF, ending chaincode stream") + case rmsg.err != nil: + err := fmt.Errorf("receive failed: %s", rmsg.err) + return err + case rmsg.msg == nil: + err := errors.New("received nil message, ending chaincode stream") + return err + default: + err := handler.handleMessage(rmsg.msg, errc) + if err != nil { + err = fmt.Errorf("error handling message: %s", err) + return err + } + + go receiveMessage() + } + + case sendErr := <-errc: + if sendErr != nil { + err := fmt.Errorf("error sending: %s", sendErr) + return err + } + } + } +} diff --git a/v2/shim/shim_test.go b/v2/shim/shim_test.go new file mode 100644 index 0000000..d64c1bc --- /dev/null +++ b/v2/shim/shim_test.go @@ -0,0 +1,204 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package shim + +import ( + "errors" + "io" + "os" + "testing" + + "github.com/hyperledger/fabric-chaincode-go/v2/shim/internal/mock" + "github.com/hyperledger/fabric-protos-go-apiv2/peer" + + "github.com/stretchr/testify/assert" +) + +// MockQueryIteratorInterface allows a chaincode to iterate over a set of +// key/value pairs returned by range query. +// TODO: Once the execute query and history query are implemented in MockStub, +// we need to update this interface +type MockQueryIteratorInterface interface { + StateQueryIteratorInterface +} + +func TestStart(t *testing.T) { + + var tests = []struct { + name string + envVars map[string]string + peerAddress string + chaincodeAddress string + streamGetter func(name string) (ClientStream, error) + cc Chaincode + expectedErr string + }{ + { + name: "Missing Chaincode ID", + expectedErr: "'CORE_CHAINCODE_ID_NAME' must be set", + }, + { + name: "Missing Peer Address", + envVars: map[string]string{ + "CORE_CHAINCODE_ID_NAME": "cc", + }, + expectedErr: "flag 'peer.address' must be set", + }, + { + name: "TLS Not Set", + envVars: map[string]string{ + "CORE_CHAINCODE_ID_NAME": "cc", + }, + peerAddress: "127.0.0.1:12345", + expectedErr: "'CORE_PEER_TLS_ENABLED' must be set to 'true' or 'false'", + }, + { + name: "Connection Error", + envVars: map[string]string{ + "CORE_CHAINCODE_ID_NAME": "cc", + "CORE_PEER_TLS_ENABLED": "false", + }, + peerAddress: "127.0.0.1:12345", + expectedErr: `rpc error: code = Unavailable desc = connection error: desc = "transport: Error while dialing: dial tcp 127.0.0.1:12345: connect: connection refused"`, + }, + { + name: "Chat - Nil Message", + envVars: map[string]string{ + "CORE_CHAINCODE_ID_NAME": "cc", + "CORE_PEER_TLS_ENABLED": "false", + }, + peerAddress: "127.0.0.1:12345", + streamGetter: func(name string) (ClientStream, error) { + stream := &mock.ClientStream{} + return stream, nil + }, + expectedErr: "received nil message, ending chaincode stream", + }, + { + name: "Chat - EOF", + envVars: map[string]string{ + "CORE_CHAINCODE_ID_NAME": "cc", + "CORE_PEER_TLS_ENABLED": "false", + }, + peerAddress: "127.0.0.1:12345", + streamGetter: func(name string) (ClientStream, error) { + stream := &mock.ClientStream{} + stream.RecvReturns(nil, io.EOF) + return stream, nil + }, + expectedErr: "received EOF, ending chaincode stream", + }, + { + name: "Chat - Recv Error", + envVars: map[string]string{ + "CORE_CHAINCODE_ID_NAME": "cc", + "CORE_PEER_TLS_ENABLED": "false", + }, + peerAddress: "127.0.0.1:12345", + streamGetter: func(name string) (ClientStream, error) { + stream := &mock.ClientStream{} + stream.RecvReturns(nil, errors.New("recvError")) + return stream, nil + }, + expectedErr: "receive failed: recvError", + }, + { + name: "Chat - Not Ready", + envVars: map[string]string{ + "CORE_CHAINCODE_ID_NAME": "cc", + "CORE_PEER_TLS_ENABLED": "false", + }, + peerAddress: "127.0.0.1:12345", + streamGetter: func(name string) (ClientStream, error) { + stream := &mock.ClientStream{} + stream.RecvReturnsOnCall( + 0, + &peer.ChaincodeMessage{ + Type: peer.ChaincodeMessage_READY, + Txid: "txid", + }, + nil, + ) + return stream, nil + }, + expectedErr: "error handling message: [txid] Chaincode h cannot handle message (READY) while in state: created", + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + for k, v := range test.envVars { + os.Setenv(k, v) + defer os.Unsetenv(k) + } + peerAddress = &test.peerAddress + streamGetter = test.streamGetter + err := Start(test.cc) + assert.EqualError(t, err, test.expectedErr) + }) + } + +} + +func TestChaincodeServerStart(t *testing.T) { + + var tests = []struct { + name string + ccsrv ChaincodeServer + streamGetter func(name string) (ClientStream, error) + expectedErr string + containsErr string + }{ + { + name: "Missing Chaincode ID", + ccsrv: ChaincodeServer{}, + expectedErr: "ccid must be specified", + }, + { + name: "Missing Peer Address", + ccsrv: ChaincodeServer{CCID: "cc"}, + expectedErr: "address must be specified", + }, + { + name: "Missing Peer Address and Chaincode Address", + ccsrv: ChaincodeServer{CCID: "cc", Address: "127.0.0.1:12345"}, + expectedErr: "chaincode must be specified", + }, + { + name: "Badly formed chaincode server address", + ccsrv: ChaincodeServer{CCID: "cc", Address: "127.0.0.1", CC: &mockChaincode{}, TLSProps: TLSProperties{Disabled: true}}, + expectedErr: "listen tcp: address 127.0.0.1: missing port in address", + }, + { + name: "Bad host in chaincode server address", + ccsrv: ChaincodeServer{CCID: "cc", Address: "__badhost__:12345", CC: &mockChaincode{}, TLSProps: TLSProperties{Disabled: true}}, + containsErr: "listen tcp: lookup __badhost__", + }, + // Basic TLS tests, path tests + { + name: "TLS enabled but key path not provided", + ccsrv: ChaincodeServer{CCID: "cc", Address: "host:12345", CC: &mockChaincode{}, TLSProps: TLSProperties{Disabled: false}}, + containsErr: "key not provided", + }, + { + name: "TLS enabled but cert path not provided", + ccsrv: ChaincodeServer{CCID: "cc", Address: "host:12345", CC: &mockChaincode{}, TLSProps: TLSProperties{Disabled: false, Key: []byte("key")}}, + containsErr: "cert not provided", + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + err := test.ccsrv.Start() + if test.expectedErr != "" { + assert.EqualError(t, err, test.expectedErr) + } else if test.containsErr != "" { + assert.Contains(t, err.Error(), test.containsErr) + } + }) + } + +} diff --git a/v2/shim/stub.go b/v2/shim/stub.go new file mode 100644 index 0000000..4e3fbe7 --- /dev/null +++ b/v2/shim/stub.go @@ -0,0 +1,759 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package shim + +import ( + "crypto/sha256" + "encoding/binary" + "errors" + "fmt" + "os" + "unicode/utf8" + + "github.com/hyperledger/fabric-protos-go-apiv2/common" + "github.com/hyperledger/fabric-protos-go-apiv2/ledger/queryresult" + "github.com/hyperledger/fabric-protos-go-apiv2/peer" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// ChaincodeStub is an object passed to chaincode for shim side handling of +// APIs. +type ChaincodeStub struct { + TxID string + ChannelID string + chaincodeEvent *peer.ChaincodeEvent + args [][]byte + handler *Handler + signedProposal *peer.SignedProposal + proposal *peer.Proposal + validationParameterMetakey string + + // Additional fields extracted from the signedProposal + creator []byte + transient map[string][]byte + binding []byte + + decorations map[string][]byte +} + +// ChaincodeInvocation functionality + +func newChaincodeStub(handler *Handler, channelID, txid string, input *peer.ChaincodeInput, signedProposal *peer.SignedProposal) (*ChaincodeStub, error) { + stub := &ChaincodeStub{ + TxID: txid, + ChannelID: channelID, + args: input.Args, + handler: handler, + signedProposal: signedProposal, + decorations: input.Decorations, + validationParameterMetakey: peer.MetaDataKeys_VALIDATION_PARAMETER.String(), + } + + // TODO: sanity check: verify that every call to init with a nil + // signedProposal is a legitimate one, meaning it is an internal call + // to system chaincodes. + if signedProposal != nil { + var err error + + stub.proposal = &peer.Proposal{} + err = proto.Unmarshal(signedProposal.ProposalBytes, stub.proposal) + if err != nil { + + return nil, fmt.Errorf("failed to extract Proposal from SignedProposal: %s", err) + } + + // check for header + if len(stub.proposal.GetHeader()) == 0 { + return nil, errors.New("failed to extract Proposal fields: proposal header is nil") + } + + // Extract creator, transient, binding... + hdr := &common.Header{} + if err := proto.Unmarshal(stub.proposal.GetHeader(), hdr); err != nil { + return nil, fmt.Errorf("failed to extract proposal header: %s", err) + } + + // extract and validate channel header + chdr := &common.ChannelHeader{} + if err := proto.Unmarshal(hdr.ChannelHeader, chdr); err != nil { + return nil, fmt.Errorf("failed to extract channel header: %s", err) + } + validTypes := map[common.HeaderType]bool{ + common.HeaderType_ENDORSER_TRANSACTION: true, + common.HeaderType_CONFIG: true, + } + if !validTypes[common.HeaderType(chdr.GetType())] { + return nil, fmt.Errorf( + "invalid channel header type. Expected %s or %s, received %s", + common.HeaderType_ENDORSER_TRANSACTION, + common.HeaderType_CONFIG, + common.HeaderType(chdr.GetType()), + ) + } + + // extract creator from signature header + shdr := &common.SignatureHeader{} + if err := proto.Unmarshal(hdr.GetSignatureHeader(), shdr); err != nil { + return nil, fmt.Errorf("failed to extract signature header: %s", err) + } + stub.creator = shdr.GetCreator() + + // extract trasient data from proposal payload + payload := &peer.ChaincodeProposalPayload{} + if err := proto.Unmarshal(stub.proposal.GetPayload(), payload); err != nil { + return nil, fmt.Errorf("failed to extract proposal payload: %s", err) + } + stub.transient = payload.GetTransientMap() + + // compute the proposal binding from the nonce, creator and epoch + epoch := make([]byte, 8) + binary.LittleEndian.PutUint64(epoch, chdr.GetEpoch()) + digest := sha256.Sum256(append(append(shdr.GetNonce(), stub.creator...), epoch...)) + stub.binding = digest[:] + + } + + return stub, nil +} + +// GetTxID returns the transaction ID for the proposal +func (s *ChaincodeStub) GetTxID() string { + return s.TxID +} + +// GetChannelID returns the channel for the proposal +func (s *ChaincodeStub) GetChannelID() string { + return s.ChannelID +} + +// GetDecorations ... +func (s *ChaincodeStub) GetDecorations() map[string][]byte { + return s.decorations +} + +// GetMSPID returns the local mspid of the peer by checking the CORE_PEER_LOCALMSPID +// env var and returns an error if the env var is not set +func GetMSPID() (string, error) { + mspid := os.Getenv("CORE_PEER_LOCALMSPID") + + if mspid == "" { + return "", errors.New("'CORE_PEER_LOCALMSPID' is not set") + } + + return mspid, nil +} + +// ------------- Call Chaincode functions --------------- + +// InvokeChaincode documentation can be found in interfaces.go +func (s *ChaincodeStub) InvokeChaincode(chaincodeName string, args [][]byte, channel string) *peer.Response { + // Internally we handle chaincode name as a composite name + if channel != "" { + chaincodeName = chaincodeName + "/" + channel + } + return s.handler.handleInvokeChaincode(chaincodeName, args, s.ChannelID, s.TxID) +} + +// --------- State functions ---------- + +// GetState documentation can be found in interfaces.go +func (s *ChaincodeStub) GetState(key string) ([]byte, error) { + // Access public data by setting the collection to empty string + collection := "" + return s.handler.handleGetState(collection, key, s.ChannelID, s.TxID) +} + +// SetStateValidationParameter documentation can be found in interfaces.go +func (s *ChaincodeStub) SetStateValidationParameter(key string, ep []byte) error { + return s.handler.handlePutStateMetadataEntry("", key, s.validationParameterMetakey, ep, s.ChannelID, s.TxID) +} + +// GetStateValidationParameter documentation can be found in interfaces.go +func (s *ChaincodeStub) GetStateValidationParameter(key string) ([]byte, error) { + md, err := s.handler.handleGetStateMetadata("", key, s.ChannelID, s.TxID) + if err != nil { + return nil, err + } + if ep, ok := md[s.validationParameterMetakey]; ok { + return ep, nil + } + return nil, nil +} + +// PutState documentation can be found in interfaces.go +func (s *ChaincodeStub) PutState(key string, value []byte) error { + if key == "" { + return errors.New("key must not be an empty string") + } + // Access public data by setting the collection to empty string + collection := "" + return s.handler.handlePutState(collection, key, value, s.ChannelID, s.TxID) +} + +func (s *ChaincodeStub) createStateQueryIterator(response *peer.QueryResponse) *StateQueryIterator { + return &StateQueryIterator{ + CommonIterator: &CommonIterator{ + handler: s.handler, + channelID: s.ChannelID, + txid: s.TxID, + response: response, + currentLoc: 0, + }, + } +} + +// GetQueryResult documentation can be found in interfaces.go +func (s *ChaincodeStub) GetQueryResult(query string) (StateQueryIteratorInterface, error) { + // Access public data by setting the collection to empty string + collection := "" + // ignore QueryResponseMetadata as it is not applicable for a rich query without pagination + iterator, _, err := s.handleGetQueryResult(collection, query, nil) + + return iterator, err +} + +// DelState documentation can be found in interfaces.go +func (s *ChaincodeStub) DelState(key string) error { + // Access public data by setting the collection to empty string + collection := "" + return s.handler.handleDelState(collection, key, s.ChannelID, s.TxID) +} + +// --------- private state functions --------- + +// GetPrivateData documentation can be found in interfaces.go +func (s *ChaincodeStub) GetPrivateData(collection string, key string) ([]byte, error) { + if collection == "" { + return nil, fmt.Errorf("collection must not be an empty string") + } + return s.handler.handleGetState(collection, key, s.ChannelID, s.TxID) +} + +// GetPrivateDataHash documentation can be found in interfaces.go +func (s *ChaincodeStub) GetPrivateDataHash(collection string, key string) ([]byte, error) { + if collection == "" { + return nil, fmt.Errorf("collection must not be an empty string") + } + return s.handler.handleGetPrivateDataHash(collection, key, s.ChannelID, s.TxID) +} + +// PutPrivateData documentation can be found in interfaces.go +func (s *ChaincodeStub) PutPrivateData(collection string, key string, value []byte) error { + if collection == "" { + return fmt.Errorf("collection must not be an empty string") + } + if key == "" { + return fmt.Errorf("key must not be an empty string") + } + return s.handler.handlePutState(collection, key, value, s.ChannelID, s.TxID) +} + +// DelPrivateData documentation can be found in interfaces.go +func (s *ChaincodeStub) DelPrivateData(collection string, key string) error { + if collection == "" { + return fmt.Errorf("collection must not be an empty string") + } + return s.handler.handleDelState(collection, key, s.ChannelID, s.TxID) +} + +// PurgePrivateData documentation can be found in interfaces.go +func (s *ChaincodeStub) PurgePrivateData(collection string, key string) error { + if collection == "" { + return fmt.Errorf("collection must not be an empty string") + } + return s.handler.handlePurgeState(collection, key, s.ChannelID, s.TxID) +} + +// GetPrivateDataByRange documentation can be found in interfaces.go +func (s *ChaincodeStub) GetPrivateDataByRange(collection, startKey, endKey string) (StateQueryIteratorInterface, error) { + if collection == "" { + return nil, fmt.Errorf("collection must not be an empty string") + } + if startKey == "" { + startKey = emptyKeySubstitute + } + if err := validateSimpleKeys(startKey, endKey); err != nil { + return nil, err + } + // ignore QueryResponseMetadata as it is not applicable for a range query without pagination + iterator, _, err := s.handleGetStateByRange(collection, startKey, endKey, nil) + + return iterator, err +} + +func (s *ChaincodeStub) createRangeKeysForPartialCompositeKey(objectType string, attributes []string) (string, string, error) { + partialCompositeKey, err := s.CreateCompositeKey(objectType, attributes) + if err != nil { + return "", "", err + } + startKey := partialCompositeKey + endKey := partialCompositeKey + string(maxUnicodeRuneValue) + + return startKey, endKey, nil +} + +// GetPrivateDataByPartialCompositeKey documentation can be found in interfaces.go +func (s *ChaincodeStub) GetPrivateDataByPartialCompositeKey(collection, objectType string, attributes []string) (StateQueryIteratorInterface, error) { + if collection == "" { + return nil, fmt.Errorf("collection must not be an empty string") + } + + startKey, endKey, err := s.createRangeKeysForPartialCompositeKey(objectType, attributes) + if err != nil { + return nil, err + } + // ignore QueryResponseMetadata as it is not applicable for a partial composite key query without pagination + iterator, _, err := s.handleGetStateByRange(collection, startKey, endKey, nil) + + return iterator, err +} + +// GetPrivateDataQueryResult documentation can be found in interfaces.go +func (s *ChaincodeStub) GetPrivateDataQueryResult(collection, query string) (StateQueryIteratorInterface, error) { + if collection == "" { + return nil, fmt.Errorf("collection must not be an empty string") + } + // ignore QueryResponseMetadata as it is not applicable for a range query without pagination + iterator, _, err := s.handleGetQueryResult(collection, query, nil) + + return iterator, err +} + +// GetPrivateDataValidationParameter documentation can be found in interfaces.go +func (s *ChaincodeStub) GetPrivateDataValidationParameter(collection, key string) ([]byte, error) { + md, err := s.handler.handleGetStateMetadata(collection, key, s.ChannelID, s.TxID) + if err != nil { + return nil, err + } + if ep, ok := md[s.validationParameterMetakey]; ok { + return ep, nil + } + return nil, nil +} + +// SetPrivateDataValidationParameter documentation can be found in interfaces.go +func (s *ChaincodeStub) SetPrivateDataValidationParameter(collection, key string, ep []byte) error { + return s.handler.handlePutStateMetadataEntry(collection, key, s.validationParameterMetakey, ep, s.ChannelID, s.TxID) +} + +// CommonIterator documentation can be found in interfaces.go +type CommonIterator struct { + handler *Handler + channelID string + txid string + response *peer.QueryResponse + currentLoc int +} + +// StateQueryIterator documentation can be found in interfaces.go +type StateQueryIterator struct { + *CommonIterator +} + +// HistoryQueryIterator documentation can be found in interfaces.go +type HistoryQueryIterator struct { + *CommonIterator +} + +// General interface for supporting different types of query results. +// Actual types differ for different queries +type queryResult interface{} + +type resultType uint8 + +// TODO: Document constants +/* + Constants ... +*/ +const ( + StateQueryResult resultType = iota + 1 + HistoryQueryResult +) + +func createQueryResponseMetadata(metadataBytes []byte) (*peer.QueryResponseMetadata, error) { + metadata := &peer.QueryResponseMetadata{} + err := proto.Unmarshal(metadataBytes, metadata) + if err != nil { + return nil, err + } + + return metadata, nil +} + +func (s *ChaincodeStub) handleGetStateByRange(collection, startKey, endKey string, + metadata []byte) (StateQueryIteratorInterface, *peer.QueryResponseMetadata, error) { + + response, err := s.handler.handleGetStateByRange(collection, startKey, endKey, metadata, s.ChannelID, s.TxID) + if err != nil { + return nil, nil, err + } + + iterator := s.createStateQueryIterator(response) + responseMetadata, err := createQueryResponseMetadata(response.Metadata) + if err != nil { + return nil, nil, err + } + + return iterator, responseMetadata, nil +} + +func (s *ChaincodeStub) handleGetQueryResult(collection, query string, + metadata []byte) (StateQueryIteratorInterface, *peer.QueryResponseMetadata, error) { + + response, err := s.handler.handleGetQueryResult(collection, query, metadata, s.ChannelID, s.TxID) + if err != nil { + return nil, nil, err + } + + iterator := s.createStateQueryIterator(response) + responseMetadata, err := createQueryResponseMetadata(response.Metadata) + if err != nil { + return nil, nil, err + } + + return iterator, responseMetadata, nil +} + +// GetStateByRange documentation can be found in interfaces.go +func (s *ChaincodeStub) GetStateByRange(startKey, endKey string) (StateQueryIteratorInterface, error) { + if startKey == "" { + startKey = emptyKeySubstitute + } + if err := validateSimpleKeys(startKey, endKey); err != nil { + return nil, err + } + collection := "" + + // ignore QueryResponseMetadata as it is not applicable for a range query without pagination + iterator, _, err := s.handleGetStateByRange(collection, startKey, endKey, nil) + + return iterator, err +} + +// GetHistoryForKey documentation can be found in interfaces.go +func (s *ChaincodeStub) GetHistoryForKey(key string) (HistoryQueryIteratorInterface, error) { + response, err := s.handler.handleGetHistoryForKey(key, s.ChannelID, s.TxID) + if err != nil { + return nil, err + } + return &HistoryQueryIterator{CommonIterator: &CommonIterator{s.handler, s.ChannelID, s.TxID, response, 0}}, nil +} + +// CreateCompositeKey documentation can be found in interfaces.go +func (s *ChaincodeStub) CreateCompositeKey(objectType string, attributes []string) (string, error) { + return CreateCompositeKey(objectType, attributes) +} + +// SplitCompositeKey documentation can be found in interfaces.go +func (s *ChaincodeStub) SplitCompositeKey(compositeKey string) (string, []string, error) { + return splitCompositeKey(compositeKey) +} + +// CreateCompositeKey ... +func CreateCompositeKey(objectType string, attributes []string) (string, error) { + if err := validateCompositeKeyAttribute(objectType); err != nil { + return "", err + } + ck := compositeKeyNamespace + objectType + string(rune(minUnicodeRuneValue)) + for _, att := range attributes { + if err := validateCompositeKeyAttribute(att); err != nil { + return "", err + } + ck += att + string(rune(minUnicodeRuneValue)) + } + return ck, nil +} + +func splitCompositeKey(compositeKey string) (string, []string, error) { + componentIndex := 1 + components := []string{} + for i := 1; i < len(compositeKey); i++ { + if compositeKey[i] == minUnicodeRuneValue { + components = append(components, compositeKey[componentIndex:i]) + componentIndex = i + 1 + } + } + return components[0], components[1:], nil +} + +func validateCompositeKeyAttribute(str string) error { + if !utf8.ValidString(str) { + return fmt.Errorf("not a valid utf8 string: [%x]", str) + } + for index, runeValue := range str { + if runeValue == minUnicodeRuneValue || runeValue == maxUnicodeRuneValue { + return fmt.Errorf(`input contains unicode %#U starting at position [%d]. %#U and %#U are not allowed in the input attribute of a composite key`, + runeValue, index, minUnicodeRuneValue, maxUnicodeRuneValue) + } + } + return nil +} + +// To ensure that simple keys do not go into composite key namespace, +// we validate simplekey to check whether the key starts with 0x00 (which +// is the namespace for compositeKey). This helps in avoding simple/composite +// key collisions. +func validateSimpleKeys(simpleKeys ...string) error { + for _, key := range simpleKeys { + if len(key) > 0 && key[0] == compositeKeyNamespace[0] { + return fmt.Errorf(`first character of the key [%s] contains a null character which is not allowed`, key) + } + } + return nil +} + +// GetStateByPartialCompositeKey documentation can be found in interfaces.go +func (s *ChaincodeStub) GetStateByPartialCompositeKey(objectType string, attributes []string) (StateQueryIteratorInterface, error) { + collection := "" + startKey, endKey, err := s.createRangeKeysForPartialCompositeKey(objectType, attributes) + if err != nil { + return nil, err + } + // ignore QueryResponseMetadata as it is not applicable for a partial composite key query without pagination + iterator, _, err := s.handleGetStateByRange(collection, startKey, endKey, nil) + + return iterator, err +} + +func createQueryMetadata(pageSize int32, bookmark string) ([]byte, error) { + // Construct the QueryMetadata with a page size and a bookmark needed for pagination + metadata := &peer.QueryMetadata{PageSize: pageSize, Bookmark: bookmark} + metadataBytes, err := proto.Marshal(metadata) + if err != nil { + return nil, err + } + return metadataBytes, nil +} + +// GetStateByRangeWithPagination ... +func (s *ChaincodeStub) GetStateByRangeWithPagination(startKey, endKey string, pageSize int32, + bookmark string) (StateQueryIteratorInterface, *peer.QueryResponseMetadata, error) { + + if startKey == "" { + startKey = emptyKeySubstitute + } + if err := validateSimpleKeys(startKey, endKey); err != nil { + return nil, nil, err + } + + collection := "" + + metadata, err := createQueryMetadata(pageSize, bookmark) + if err != nil { + return nil, nil, err + } + + return s.handleGetStateByRange(collection, startKey, endKey, metadata) +} + +// GetStateByPartialCompositeKeyWithPagination ... +func (s *ChaincodeStub) GetStateByPartialCompositeKeyWithPagination(objectType string, keys []string, + pageSize int32, bookmark string) (StateQueryIteratorInterface, *peer.QueryResponseMetadata, error) { + + collection := "" + + metadata, err := createQueryMetadata(pageSize, bookmark) + if err != nil { + return nil, nil, err + } + + startKey, endKey, err := s.createRangeKeysForPartialCompositeKey(objectType, keys) + if err != nil { + return nil, nil, err + } + return s.handleGetStateByRange(collection, startKey, endKey, metadata) +} + +// GetQueryResultWithPagination ... +func (s *ChaincodeStub) GetQueryResultWithPagination(query string, pageSize int32, + bookmark string) (StateQueryIteratorInterface, *peer.QueryResponseMetadata, error) { + // Access public data by setting the collection to empty string + collection := "" + + metadata, err := createQueryMetadata(pageSize, bookmark) + if err != nil { + return nil, nil, err + } + return s.handleGetQueryResult(collection, query, metadata) +} + +// Next ... +func (iter *StateQueryIterator) Next() (*queryresult.KV, error) { + result, err := iter.nextResult(StateQueryResult) + if err != nil { + return nil, err + } + return result.(*queryresult.KV), err +} + +// Next ... +func (iter *HistoryQueryIterator) Next() (*queryresult.KeyModification, error) { + result, err := iter.nextResult(HistoryQueryResult) + if err != nil { + return nil, err + } + return result.(*queryresult.KeyModification), err +} + +// HasNext documentation can be found in interfaces.go +func (iter *CommonIterator) HasNext() bool { + if iter.currentLoc < len(iter.response.Results) || iter.response.HasMore { + return true + } + return false +} + +// getResultsFromBytes deserializes QueryResult and return either a KV struct +// or KeyModification depending on the result type (i.e., state (range/execute) +// query, history query). Note that queryResult is an empty golang +// interface that can hold values of any type. +func (iter *CommonIterator) getResultFromBytes(queryResultBytes *peer.QueryResultBytes, + rType resultType) (queryResult, error) { + + if rType == StateQueryResult { + stateQueryResult := &queryresult.KV{} + if err := proto.Unmarshal(queryResultBytes.ResultBytes, stateQueryResult); err != nil { + return nil, fmt.Errorf("error unmarshaling result from bytes: %s", err) + } + return stateQueryResult, nil + + } else if rType == HistoryQueryResult { + historyQueryResult := &queryresult.KeyModification{} + if err := proto.Unmarshal(queryResultBytes.ResultBytes, historyQueryResult); err != nil { + return nil, err + } + return historyQueryResult, nil + } + return nil, errors.New("wrong result type") +} + +func (iter *CommonIterator) fetchNextQueryResult() error { + response, err := iter.handler.handleQueryStateNext(iter.response.Id, iter.channelID, iter.txid) + if err != nil { + return err + } + iter.currentLoc = 0 + iter.response = response + return nil +} + +// nextResult returns the next QueryResult (i.e., either a KV struct or KeyModification) +// from the state or history query iterator. Note that queryResult is an +// empty golang interface that can hold values of any type. +func (iter *CommonIterator) nextResult(rType resultType) (queryResult, error) { + if iter.currentLoc < len(iter.response.Results) { + // On valid access of an element from cached results + queryResult, err := iter.getResultFromBytes(iter.response.Results[iter.currentLoc], rType) + if err != nil { + return nil, err + } + iter.currentLoc++ + + if iter.currentLoc == len(iter.response.Results) && iter.response.HasMore { + // On access of last item, pre-fetch to update HasMore flag + if err = iter.fetchNextQueryResult(); err != nil { + return nil, err + } + } + + return queryResult, err + } else if !iter.response.HasMore { + // On call to Next() without check of HasMore + return nil, errors.New("no such key") + } + + // should not fall through here + // case: no cached results but HasMore is true. + return nil, errors.New("invalid iterator state") +} + +// Close documentation can be found in interfaces.go +func (iter *CommonIterator) Close() error { + _, err := iter.handler.handleQueryStateClose(iter.response.Id, iter.channelID, iter.txid) + return err +} + +// GetArgs documentation can be found in interfaces.go +func (s *ChaincodeStub) GetArgs() [][]byte { + return s.args +} + +// GetStringArgs documentation can be found in interfaces.go +func (s *ChaincodeStub) GetStringArgs() []string { + args := s.GetArgs() + strargs := make([]string, 0, len(args)) + for _, barg := range args { + strargs = append(strargs, string(barg)) + } + return strargs +} + +// GetFunctionAndParameters documentation can be found in interfaces.go +func (s *ChaincodeStub) GetFunctionAndParameters() (function string, params []string) { + allargs := s.GetStringArgs() + function = "" + params = []string{} + if len(allargs) >= 1 { + function = allargs[0] + params = allargs[1:] + } + return +} + +// GetCreator documentation can be found in interfaces.go +func (s *ChaincodeStub) GetCreator() ([]byte, error) { + return s.creator, nil +} + +// GetTransient documentation can be found in interfaces.go +func (s *ChaincodeStub) GetTransient() (map[string][]byte, error) { + return s.transient, nil +} + +// GetBinding documentation can be found in interfaces.go +func (s *ChaincodeStub) GetBinding() ([]byte, error) { + return s.binding, nil +} + +// GetSignedProposal documentation can be found in interfaces.go +func (s *ChaincodeStub) GetSignedProposal() (*peer.SignedProposal, error) { + return s.signedProposal, nil +} + +// GetArgsSlice documentation can be found in interfaces.go +func (s *ChaincodeStub) GetArgsSlice() ([]byte, error) { + args := s.GetArgs() + res := []byte{} + for _, barg := range args { + res = append(res, barg...) + } + return res, nil +} + +// GetTxTimestamp documentation can be found in interfaces.go +func (s *ChaincodeStub) GetTxTimestamp() (*timestamppb.Timestamp, error) { + hdr := &common.Header{} + if err := proto.Unmarshal(s.proposal.Header, hdr); err != nil { + return nil, fmt.Errorf("error unmarshaling Header: %s", err) + } + + chdr := &common.ChannelHeader{} + if err := proto.Unmarshal(hdr.ChannelHeader, chdr); err != nil { + return nil, fmt.Errorf("error unmarshaling ChannelHeader: %s", err) + } + + return chdr.GetTimestamp(), nil +} + +// ------------- ChaincodeEvent API ---------------------- + +// SetEvent documentation can be found in interfaces.go +func (s *ChaincodeStub) SetEvent(name string, payload []byte) error { + if name == "" { + return errors.New("event name can not be empty string") + } + s.chaincodeEvent = &peer.ChaincodeEvent{EventName: name, Payload: payload} + return nil +} diff --git a/v2/shim/stub_test.go b/v2/shim/stub_test.go new file mode 100644 index 0000000..5f41d6e --- /dev/null +++ b/v2/shim/stub_test.go @@ -0,0 +1,619 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package shim + +import ( + "crypto/sha256" + "encoding/binary" + "os" + "testing" + + "github.com/hyperledger/fabric-chaincode-go/v2/shim/internal/mock" + "github.com/hyperledger/fabric-protos-go-apiv2/common" + "github.com/hyperledger/fabric-protos-go-apiv2/ledger/queryresult" + "github.com/hyperledger/fabric-protos-go-apiv2/peer" + "google.golang.org/protobuf/proto" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func toChaincodeArgs(args ...string) [][]byte { + ccArgs := make([][]byte, len(args)) + for i, a := range args { + ccArgs[i] = []byte(a) + } + return ccArgs +} + +// requireProtoEqual ensures an expected protobuf message matches an actual message +func requireProtoEqual(t *testing.T, expected proto.Message, actual proto.Message) { + require.True(t, proto.Equal(expected, actual), "Expected %v, got %v", expected, actual) +} + +func TestNewChaincodeStub(t *testing.T) { + expectedArgs := toChaincodeArgs("function", "arg1", "arg2") + expectedDecorations := map[string][]byte{"decoration-key": []byte("decoration-value")} + expectedCreator := []byte("signature-header-creator") + expectedTransient := map[string][]byte{"key": []byte("value")} + expectedEpoch := uint64(999) + + validSignedProposal := &peer.SignedProposal{ + ProposalBytes: marshalOrPanic(&peer.Proposal{ + Header: marshalOrPanic(&common.Header{ + ChannelHeader: marshalOrPanic(&common.ChannelHeader{ + Type: int32(common.HeaderType_ENDORSER_TRANSACTION), + Epoch: expectedEpoch, + }), + SignatureHeader: marshalOrPanic(&common.SignatureHeader{ + Creator: expectedCreator, + }), + }), + Payload: marshalOrPanic(&peer.ChaincodeProposalPayload{ + Input: []byte("chaincode-proposal-input"), + TransientMap: expectedTransient, + }), + }), + } + + tests := []struct { + signedProposal *peer.SignedProposal + expectedErr string + }{ + {signedProposal: nil}, + {signedProposal: proto.Clone(validSignedProposal).(*peer.SignedProposal)}, + { + signedProposal: &peer.SignedProposal{ProposalBytes: []byte("garbage")}, + expectedErr: "failed to extract Proposal from SignedProposal", + }, + { + signedProposal: &peer.SignedProposal{}, + expectedErr: "failed to extract Proposal fields: proposal header is nil", + }, + { + signedProposal: &peer.SignedProposal{}, + expectedErr: "failed to extract Proposal fields: proposal header is nil", + }, + { + signedProposal: &peer.SignedProposal{ + ProposalBytes: marshalOrPanic(&peer.Proposal{ + Header: marshalOrPanic(&common.Header{ + ChannelHeader: marshalOrPanic(&common.ChannelHeader{ + Type: int32(common.HeaderType_CONFIG_UPDATE), + Epoch: expectedEpoch, + }), + }), + }), + }, + expectedErr: "invalid channel header type. Expected ENDORSER_TRANSACTION or CONFIG, received CONFIG_UPDATE", + }, + } + + for _, tt := range tests { + stub, err := newChaincodeStub( + &Handler{}, + "channel-id", + "transaction-id", + &peer.ChaincodeInput{Args: expectedArgs[:], Decorations: expectedDecorations}, + tt.signedProposal, + ) + if tt.expectedErr != "" { + assert.Error(t, err) + assert.ErrorContains(t, err, tt.expectedErr) + continue + } + assert.NoError(t, err) + assert.NotNil(t, stub) + + assert.Equal(t, &Handler{}, stub.handler, "expected empty handler") + assert.Equal(t, "channel-id", stub.ChannelID) + assert.Equal(t, "transaction-id", stub.TxID) + assert.Equal(t, expectedArgs, stub.args) + assert.Equal(t, expectedDecorations, stub.decorations) + assert.Equal(t, "VALIDATION_PARAMETER", stub.validationParameterMetakey) + if tt.signedProposal == nil { + assert.Nil(t, stub.proposal, "expected nil proposal") + assert.Nil(t, stub.creator, "expected nil creator") + assert.Nil(t, stub.transient, "expected nil transient") + assert.Nil(t, stub.binding, "expected nil binding") + continue + } + + prop := &peer.Proposal{} + err = proto.Unmarshal(tt.signedProposal.ProposalBytes, prop) + assert.NoError(t, err) + assert.Equal(t, prop, stub.proposal) + + assert.Equal(t, expectedCreator, stub.creator) + assert.Equal(t, expectedTransient, stub.transient) + + epoch := make([]byte, 8) + binary.LittleEndian.PutUint64(epoch, expectedEpoch) + shdr := &common.SignatureHeader{} + digest := sha256.Sum256(append(append(shdr.GetNonce(), expectedCreator...), epoch...)) + assert.Equal(t, digest[:], stub.binding) + } +} + +func TestChaincodeStubSetEvent(t *testing.T) { + stub := &ChaincodeStub{} + err := stub.SetEvent("", []byte("event payload")) + assert.EqualError(t, err, "event name can not be empty string") + assert.Nil(t, stub.chaincodeEvent) + + stub = &ChaincodeStub{} + err = stub.SetEvent("name", []byte("payload")) + assert.NoError(t, err) + assert.Equal(t, &peer.ChaincodeEvent{EventName: "name", Payload: []byte("payload")}, stub.chaincodeEvent) +} + +func TestChaincodeStubAccessors(t *testing.T) { + stub := &ChaincodeStub{TxID: "transaction-id"} + assert.Equal(t, "transaction-id", stub.GetTxID()) + + stub = &ChaincodeStub{ChannelID: "channel-id"} + assert.Equal(t, "channel-id", stub.GetChannelID()) + + stub = &ChaincodeStub{decorations: map[string][]byte{"key": []byte("value")}} + assert.Equal(t, map[string][]byte{"key": []byte("value")}, stub.GetDecorations()) + + stub = &ChaincodeStub{args: [][]byte{[]byte("function"), []byte("arg1"), []byte("arg2")}} + assert.Equal(t, [][]byte{[]byte("function"), []byte("arg1"), []byte("arg2")}, stub.GetArgs()) + assert.Equal(t, []string{"function", "arg1", "arg2"}, stub.GetStringArgs()) + + f, a := stub.GetFunctionAndParameters() + assert.Equal(t, "function", f) + assert.Equal(t, []string{"arg1", "arg2"}, a) + + as, err := stub.GetArgsSlice() + assert.NoError(t, err) + assert.Equal(t, []byte("functionarg1arg2"), as) + + stub = &ChaincodeStub{} + f, a = stub.GetFunctionAndParameters() + assert.Equal(t, "", f) + assert.Empty(t, a) + + stub = &ChaincodeStub{creator: []byte("creator")} + creator, err := stub.GetCreator() + assert.NoError(t, err) + assert.Equal(t, []byte("creator"), creator) + + stub = &ChaincodeStub{transient: map[string][]byte{"key": []byte("value")}} + transient, err := stub.GetTransient() + assert.NoError(t, err) + assert.Equal(t, map[string][]byte{"key": []byte("value")}, transient) + + stub = &ChaincodeStub{binding: []byte("binding")} + binding, err := stub.GetBinding() + assert.NoError(t, err) + assert.Equal(t, []byte("binding"), binding) + + stub = &ChaincodeStub{signedProposal: &peer.SignedProposal{ProposalBytes: []byte("proposal-bytes")}} + sp, err := stub.GetSignedProposal() + assert.NoError(t, err) + assert.Equal(t, &peer.SignedProposal{ProposalBytes: []byte("proposal-bytes")}, sp) +} + +func TestChaincodeStubGetTxTimestamp(t *testing.T) { + now := timestamppb.Now() + tests := []struct { + proposal *peer.Proposal + ts *timestamppb.Timestamp + expectedErr string + }{ + { + ts: now, + proposal: &peer.Proposal{ + Header: marshalOrPanic(&common.Header{ + ChannelHeader: marshalOrPanic(&common.ChannelHeader{ + Timestamp: now, + }), + }), + }, + }, + { + proposal: &peer.Proposal{ + Header: marshalOrPanic(&common.Header{ + ChannelHeader: []byte("garbage-channel-header"), + }), + }, + expectedErr: "error unmarshaling ChannelHeader", + }, + { + proposal: &peer.Proposal{Header: []byte("garbage-header")}, + expectedErr: "error unmarshaling Header", + }, + } + + for _, tt := range tests { + stub := &ChaincodeStub{proposal: tt.proposal} + ts, err := stub.GetTxTimestamp() + if tt.expectedErr != "" { + assert.ErrorContains(t, err, tt.expectedErr) + continue + } + + assert.NoError(t, err) + assert.True(t, proto.Equal(ts, tt.ts)) + } +} + +func TestGetMSPID(t *testing.T) { + _, err := GetMSPID() + assert.EqualError(t, err, "'CORE_PEER_LOCALMSPID' is not set") + + os.Setenv("CORE_PEER_LOCALMSPID", "mspid") + + mspid, err := GetMSPID() + assert.NoError(t, err) + assert.Equal(t, "mspid", mspid) + + os.Unsetenv("CORE_PEER_LOCALMSPID") +} + +func TestChaincodeStubHandlers(t *testing.T) { + var tests = []struct { + name string + resType peer.ChaincodeMessage_Type + payload []byte + testFunc func(*ChaincodeStub, *Handler, *testing.T, []byte) + }{ + { + name: "Simple Response", + resType: peer.ChaincodeMessage_RESPONSE, + payload: []byte("myvalue"), + testFunc: func(s *ChaincodeStub, h *Handler, t *testing.T, payload []byte) { + resp, err := s.GetState("key") + if err != nil { + t.Fatalf("Unexpected error for GetState: %s", err) + } + assert.Equal(t, payload, resp) + + resp, err = s.GetPrivateData("col", "key") + if err != nil { + t.Fatalf("Unexpected error for GetState: %s", err) + } + assert.Equal(t, payload, resp) + _, err = s.GetPrivateData("", "key") + assert.EqualError(t, err, "collection must not be an empty string") + + resp, err = s.GetPrivateDataHash("col", "key") + if err != nil { + t.Fatalf("Unexpected error for GetPrivateDataHash: %s", err) + } + assert.Equal(t, payload, resp) + _, err = s.GetPrivateDataHash("", "key") + assert.EqualError(t, err, "collection must not be an empty string") + + err = s.PutState("key", payload) + assert.NoError(t, err) + + err = s.PutPrivateData("col", "key", payload) + assert.NoError(t, err) + err = s.PutPrivateData("", "key", payload) + assert.EqualError(t, err, "collection must not be an empty string") + err = s.PutPrivateData("col", "", payload) + assert.EqualError(t, err, "key must not be an empty string") + + err = s.SetStateValidationParameter("key", payload) + assert.NoError(t, err) + + err = s.SetPrivateDataValidationParameter("col", "key", payload) + assert.NoError(t, err) + + err = s.DelState("key") + assert.NoError(t, err) + + err = s.DelPrivateData("col", "key") + assert.NoError(t, err) + err = s.DelPrivateData("", "key") + assert.EqualError(t, err, "collection must not be an empty string") + + err = s.PurgePrivateData("col", "key") + assert.NoError(t, err) + err = s.PurgePrivateData("", "key") + assert.EqualError(t, err, "collection must not be an empty string") + + }, + }, + { + name: "ValidationParameter", + resType: peer.ChaincodeMessage_RESPONSE, + payload: marshalOrPanic( + &peer.StateMetadataResult{ + Entries: []*peer.StateMetadata{ + { + Metakey: "mkey", + Value: []byte("metavalue"), + }, + }, + }, + ), + testFunc: func(s *ChaincodeStub, h *Handler, t *testing.T, payload []byte) { + resp, err := s.GetStateValidationParameter("key") + if err != nil { + t.Fatalf("Unexpected error for GetStateValidationParameter: %s", err) + } + assert.Equal(t, []byte("metavalue"), resp) + + resp, err = s.GetPrivateDataValidationParameter("col", "key") + if err != nil { + t.Fatalf("Unexpected error for GetPrivateDataValidationParameter: %s", err) + } + assert.Equal(t, []byte("metavalue"), resp) + }, + }, + { + name: "InvokeChaincode", + resType: peer.ChaincodeMessage_RESPONSE, + payload: marshalOrPanic( + &peer.ChaincodeMessage{ + Type: peer.ChaincodeMessage_COMPLETED, + Payload: marshalOrPanic( + &peer.Response{ + Status: OK, + Payload: []byte("invokechaincode"), + }, + ), + }, + ), + testFunc: func(s *ChaincodeStub, h *Handler, t *testing.T, payload []byte) { + resp := s.InvokeChaincode("cc", [][]byte{}, "channel") + assert.Equal(t, resp.Payload, []byte("invokechaincode")) + }, + }, + { + name: "QueryResponse", + resType: peer.ChaincodeMessage_RESPONSE, + payload: marshalOrPanic( + &peer.QueryResponse{ + Results: []*peer.QueryResultBytes{ + { + ResultBytes: marshalOrPanic( + &queryresult.KV{ + Key: "querykey", + Value: []byte("queryvalue"), + }, + ), + }, + }, + Metadata: marshalOrPanic( + &peer.QueryResponseMetadata{ + Bookmark: "book", + FetchedRecordsCount: 1, + }, + ), + HasMore: true, + }, + ), + testFunc: func(s *ChaincodeStub, h *Handler, t *testing.T, payload []byte) { + expectedResult := &queryresult.KV{ + Key: "querykey", + Value: []byte("queryvalue"), + } + + // stub stuff + sqi, err := s.GetQueryResult("query") + if err != nil { + t.Fatalf("Unexpected error for GetQueryResult: %s", err) + } + kv, err := sqi.Next() + if err != nil { + t.Fatalf("Unexpected error for GetQueryResult: %s", err) + } + requireProtoEqual(t, expectedResult, kv) + + sqi, err = s.GetPrivateDataQueryResult("col", "query") + if err != nil { + t.Fatalf("Unexpected error for GetPrivateDataQueryResult: %s", err) + } + kv, err = sqi.Next() + if err != nil { + t.Fatalf("Unexpected error for GetPrivateDataQueryResult: %s", err) + } + requireProtoEqual(t, expectedResult, kv) + + _, err = s.GetPrivateDataQueryResult("", "query") + assert.EqualError(t, err, "collection must not be an empty string") + + sqi, err = s.GetStateByRange("", "end") + if err != nil { + t.Fatalf("Unexpected error for GetStateByRange: %s", err) + } + // first result + kv, err = sqi.Next() + if err != nil { + t.Fatalf("Unexpected error for GetStateByRange: %s", err) + } + requireProtoEqual(t, expectedResult, kv) + // second result + assert.True(t, sqi.HasNext()) + kv, err = sqi.Next() + if err != nil { + t.Fatalf("Unexpected error for GetStateByRange: %s", err) + } + requireProtoEqual(t, expectedResult, kv) + err = sqi.Close() + assert.NoError(t, err) + + sqi, qrm, err := s.GetStateByRangeWithPagination("", "end", 1, "book") + assert.NoError(t, err) + kv, err = sqi.Next() + if err != nil { + t.Fatalf("Unexpected error for GetStateByRangeWithPagination: %s", err) + } + requireProtoEqual(t, expectedResult, kv) + assert.Equal(t, "book", qrm.GetBookmark()) + assert.Equal(t, int32(1), qrm.GetFetchedRecordsCount()) + + sqi, err = s.GetPrivateDataByRange("col", "", "end") + if err != nil { + t.Fatalf("Unexpected error for GetPrivateDataByRange: %s", err) + } + kv, err = sqi.Next() + if err != nil { + t.Fatalf("Unexpected error for GetPrivateDataByRange: %s", err) + } + requireProtoEqual(t, expectedResult, kv) + + _, err = s.GetPrivateDataByRange("", "", "end") + assert.EqualError(t, err, "collection must not be an empty string") + + sqi, err = s.GetStateByPartialCompositeKey("object", []string{"attr1", "attr2"}) + assert.NoError(t, err) + kv, err = sqi.Next() + if err != nil { + t.Fatalf("Unexpected error for GetStateByPartialCompositeKey: %s", err) + } + requireProtoEqual(t, expectedResult, kv) + + sqi, err = s.GetPrivateDataByPartialCompositeKey("col", "object", []string{"attr1", "attr2"}) + assert.NoError(t, err) + kv, err = sqi.Next() + if err != nil { + t.Fatalf("Unexpected error for GetPrivateDataByPartialCompositeKey: %s", err) + } + requireProtoEqual(t, expectedResult, kv) + + _, err = s.GetPrivateDataByPartialCompositeKey("", "object", []string{"attr1", "attr2"}) + assert.EqualError(t, err, "collection must not be an empty string") + + sqi, qrm, err = s.GetStateByPartialCompositeKeyWithPagination( + "object", + []string{"key1", "key2"}, + 1, + "book", + ) + assert.NoError(t, err) + kv, err = sqi.Next() + if err != nil { + t.Fatalf("Unexpected error for GetStateByPartialCompositeKeyWithPagination: %s", err) + } + requireProtoEqual(t, expectedResult, kv) + assert.Equal(t, "book", qrm.GetBookmark()) + assert.Equal(t, int32(1), qrm.GetFetchedRecordsCount()) + + sqi, qrm, err = s.GetQueryResultWithPagination("query", 1, "book") + assert.NoError(t, err) + kv, err = sqi.Next() + if err != nil { + t.Fatalf("Unexpected error forGetQueryResultWithPagination: %s", err) + } + requireProtoEqual(t, expectedResult, kv) + assert.Equal(t, "book", qrm.GetBookmark()) + assert.Equal(t, int32(1), qrm.GetFetchedRecordsCount()) + }, + }, + { + name: "GetHistoryForKey", + resType: peer.ChaincodeMessage_RESPONSE, + payload: marshalOrPanic( + &peer.QueryResponse{ + Results: []*peer.QueryResultBytes{ + { + ResultBytes: marshalOrPanic( + &queryresult.KeyModification{ + TxId: "txid", + Value: []byte("historyforkey"), + }, + ), + }, + }, + HasMore: false, + }, + ), + testFunc: func(s *ChaincodeStub, h *Handler, t *testing.T, payload []byte) { + expectedResult := &queryresult.KeyModification{ + TxId: "txid", + Value: []byte("historyforkey"), + } + hqi, err := s.GetHistoryForKey("key") + if err != nil { + t.Fatalf("Unexpected error for GetHistoryForKey: %s", err) + } + km, err := hqi.Next() + if err != nil { + t.Fatalf("Unexpected error for GetPrivateDataByRangee: %s", err) + } + requireProtoEqual(t, expectedResult, km) + assert.False(t, hqi.HasNext()) + }, + }, + { + name: "Error Conditions", + resType: peer.ChaincodeMessage_ERROR, + payload: []byte("error"), + testFunc: func(s *ChaincodeStub, h *Handler, t *testing.T, payload []byte) { + _, err := s.GetState("key") + assert.EqualError(t, err, string(payload)) + + _, err = s.GetPrivateDataHash("col", "key") + assert.EqualError(t, err, string(payload)) + + _, err = s.GetStateValidationParameter("key") + assert.EqualError(t, err, string(payload)) + + err = s.PutState("key", payload) + assert.EqualError(t, err, string(payload)) + + err = s.SetPrivateDataValidationParameter("col", "key", payload) + assert.EqualError(t, err, string(payload)) + + err = s.DelState("key") + assert.EqualError(t, err, string(payload)) + + _, err = s.GetStateByRange("start", "end") + assert.EqualError(t, err, string(payload)) + + _, err = s.GetQueryResult("query") + assert.EqualError(t, err, string(payload)) + + _, err = s.GetHistoryForKey("key") + assert.EqualError(t, err, string(payload)) + + resp := s.InvokeChaincode("cc", [][]byte{}, "channel") + assert.Equal(t, payload, resp.GetPayload()) + + }, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + handler := &Handler{ + cc: &mockChaincode{}, + responseChannels: map[string]chan *peer.ChaincodeMessage{}, + state: ready, + } + stub := &ChaincodeStub{ + ChannelID: "channel", + TxID: "txid", + handler: handler, + validationParameterMetakey: "mkey", + } + chatStream := &mock.PeerChaincodeStream{} + chatStream.SendStub = func(msg *peer.ChaincodeMessage) error { + go func() { + err := handler.handleResponse( + &peer.ChaincodeMessage{ + Type: test.resType, + ChannelId: msg.GetChannelId(), + Txid: msg.GetTxid(), + Payload: test.payload, + }, + ) + assert.NoError(t, err, "handleResponse") + }() + return nil + } + handler.chatStream = chatStream + test.testFunc(stub, handler, t, test.payload) + }) + } +} diff --git a/v2/shimtest/mock/chaincode.go b/v2/shimtest/mock/chaincode.go new file mode 100644 index 0000000..2ccf99f --- /dev/null +++ b/v2/shimtest/mock/chaincode.go @@ -0,0 +1,184 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package mock + +import ( + "sync" + + "github.com/hyperledger/fabric-chaincode-go/v2/shim" + "github.com/hyperledger/fabric-protos-go-apiv2/peer" +) + +type Chaincode struct { + InitStub func(shim.ChaincodeStubInterface) *peer.Response + initMutex sync.RWMutex + initArgsForCall []struct { + arg1 shim.ChaincodeStubInterface + } + initReturns struct { + result1 *peer.Response + } + initReturnsOnCall map[int]struct { + result1 *peer.Response + } + InvokeStub func(shim.ChaincodeStubInterface) *peer.Response + invokeMutex sync.RWMutex + invokeArgsForCall []struct { + arg1 shim.ChaincodeStubInterface + } + invokeReturns struct { + result1 *peer.Response + } + invokeReturnsOnCall map[int]struct { + result1 *peer.Response + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *Chaincode) Init(arg1 shim.ChaincodeStubInterface) *peer.Response { + fake.initMutex.Lock() + ret, specificReturn := fake.initReturnsOnCall[len(fake.initArgsForCall)] + fake.initArgsForCall = append(fake.initArgsForCall, struct { + arg1 shim.ChaincodeStubInterface + }{arg1}) + stub := fake.InitStub + fakeReturns := fake.initReturns + fake.recordInvocation("Init", []interface{}{arg1}) + fake.initMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *Chaincode) InitCallCount() int { + fake.initMutex.RLock() + defer fake.initMutex.RUnlock() + return len(fake.initArgsForCall) +} + +func (fake *Chaincode) InitCalls(stub func(shim.ChaincodeStubInterface) *peer.Response) { + fake.initMutex.Lock() + defer fake.initMutex.Unlock() + fake.InitStub = stub +} + +func (fake *Chaincode) InitArgsForCall(i int) shim.ChaincodeStubInterface { + fake.initMutex.RLock() + defer fake.initMutex.RUnlock() + argsForCall := fake.initArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *Chaincode) InitReturns(result1 *peer.Response) { + fake.initMutex.Lock() + defer fake.initMutex.Unlock() + fake.InitStub = nil + fake.initReturns = struct { + result1 *peer.Response + }{result1} +} + +func (fake *Chaincode) InitReturnsOnCall(i int, result1 *peer.Response) { + fake.initMutex.Lock() + defer fake.initMutex.Unlock() + fake.InitStub = nil + if fake.initReturnsOnCall == nil { + fake.initReturnsOnCall = make(map[int]struct { + result1 *peer.Response + }) + } + fake.initReturnsOnCall[i] = struct { + result1 *peer.Response + }{result1} +} + +func (fake *Chaincode) Invoke(arg1 shim.ChaincodeStubInterface) *peer.Response { + fake.invokeMutex.Lock() + ret, specificReturn := fake.invokeReturnsOnCall[len(fake.invokeArgsForCall)] + fake.invokeArgsForCall = append(fake.invokeArgsForCall, struct { + arg1 shim.ChaincodeStubInterface + }{arg1}) + stub := fake.InvokeStub + fakeReturns := fake.invokeReturns + fake.recordInvocation("Invoke", []interface{}{arg1}) + fake.invokeMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *Chaincode) InvokeCallCount() int { + fake.invokeMutex.RLock() + defer fake.invokeMutex.RUnlock() + return len(fake.invokeArgsForCall) +} + +func (fake *Chaincode) InvokeCalls(stub func(shim.ChaincodeStubInterface) *peer.Response) { + fake.invokeMutex.Lock() + defer fake.invokeMutex.Unlock() + fake.InvokeStub = stub +} + +func (fake *Chaincode) InvokeArgsForCall(i int) shim.ChaincodeStubInterface { + fake.invokeMutex.RLock() + defer fake.invokeMutex.RUnlock() + argsForCall := fake.invokeArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *Chaincode) InvokeReturns(result1 *peer.Response) { + fake.invokeMutex.Lock() + defer fake.invokeMutex.Unlock() + fake.InvokeStub = nil + fake.invokeReturns = struct { + result1 *peer.Response + }{result1} +} + +func (fake *Chaincode) InvokeReturnsOnCall(i int, result1 *peer.Response) { + fake.invokeMutex.Lock() + defer fake.invokeMutex.Unlock() + fake.InvokeStub = nil + if fake.invokeReturnsOnCall == nil { + fake.invokeReturnsOnCall = make(map[int]struct { + result1 *peer.Response + }) + } + fake.invokeReturnsOnCall[i] = struct { + result1 *peer.Response + }{result1} +} + +func (fake *Chaincode) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.initMutex.RLock() + defer fake.initMutex.RUnlock() + fake.invokeMutex.RLock() + defer fake.invokeMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *Chaincode) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} diff --git a/v2/shimtest/mockstub.go b/v2/shimtest/mockstub.go new file mode 100644 index 0000000..6c9260e --- /dev/null +++ b/v2/shimtest/mockstub.go @@ -0,0 +1,632 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Package shimtest provides a mock of the ChaincodeStubInterface for +// unit testing chaincode. +// +// Deprecated: ShimTest will be removed in a future release. +// Future development should make use of the ChaincodeStub Interface +// for generating mocks +package shimtest + +import ( + "container/list" + "errors" + "fmt" + "strings" + "unicode/utf8" + + "github.com/hyperledger/fabric-chaincode-go/v2/shim" + "github.com/hyperledger/fabric-protos-go-apiv2/ledger/queryresult" + "github.com/hyperledger/fabric-protos-go-apiv2/peer" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" +) + +const ( + minUnicodeRuneValue = 0 //U+0000 + compositeKeyNamespace = "\x00" +) + +// MockStub is an implementation of ChaincodeStubInterface for unit testing chaincode. +// Use this instead of ChaincodeStub in your chaincode's unit test calls to Init or Invoke. +type MockStub struct { + // arguments the stub was called with + args [][]byte + + // transientMap + TransientMap map[string][]byte + // A pointer back to the chaincode that will invoke this, set by constructor. + // If a peer calls this stub, the chaincode will be invoked from here. + cc shim.Chaincode + + // A nice name that can be used for logging + Name string + + // State keeps name value pairs + State map[string][]byte + + // Keys stores the list of mapped values in lexical order + Keys *list.List + + // registered list of other MockStub chaincodes that can be called from this MockStub + Invokables map[string]*MockStub + + // stores a transaction uuid while being Invoked / Deployed + // TODO if a chaincode uses recursion this may need to be a stack of TxIDs or possibly a reference counting map + TxID string + + TxTimestamp *timestamppb.Timestamp + + // mocked signedProposal + signedProposal *peer.SignedProposal + + // stores a channel ID of the proposal + ChannelID string + + PvtState map[string]map[string][]byte + + // stores per-key endorsement policy, first map index is the collection, second map index is the key + EndorsementPolicies map[string]map[string][]byte + + // channel to store ChaincodeEvents + ChaincodeEventsChannel chan *peer.ChaincodeEvent + + Creator []byte + + Decorations map[string][]byte +} + +// GetTxID ... +func (stub *MockStub) GetTxID() string { + return stub.TxID +} + +// GetChannelID ... +func (stub *MockStub) GetChannelID() string { + return stub.ChannelID +} + +// GetArgs ... +func (stub *MockStub) GetArgs() [][]byte { + return stub.args +} + +// GetStringArgs ... +func (stub *MockStub) GetStringArgs() []string { + args := stub.GetArgs() + strargs := make([]string, 0, len(args)) + for _, barg := range args { + strargs = append(strargs, string(barg)) + } + return strargs +} + +// GetFunctionAndParameters ... +func (stub *MockStub) GetFunctionAndParameters() (function string, params []string) { + allargs := stub.GetStringArgs() + function = "" + params = []string{} + if len(allargs) >= 1 { + function = allargs[0] + params = allargs[1:] + } + return +} + +// MockTransactionStart Used to indicate to a chaincode that it is part of a transaction. +// This is important when chaincodes invoke each other. +// MockStub doesn't support concurrent transactions at present. +func (stub *MockStub) MockTransactionStart(txid string) { + stub.TxID = txid + stub.setSignedProposal(&peer.SignedProposal{}) + stub.setTxTimestamp(timestamppb.Now()) +} + +// MockTransactionEnd End a mocked transaction, clearing the UUID. +func (stub *MockStub) MockTransactionEnd(uuid string) { + stub.signedProposal = nil + stub.TxID = "" +} + +// MockPeerChaincode Register another MockStub chaincode with this MockStub. +// invokableChaincodeName is the name of a chaincode. +// otherStub is a MockStub of the chaincode, already initialized. +// channel is the name of a channel on which another MockStub is called. +func (stub *MockStub) MockPeerChaincode(invokableChaincodeName string, otherStub *MockStub, channel string) { + // Internally we use chaincode name as a composite name + if channel != "" { + invokableChaincodeName = invokableChaincodeName + "/" + channel + } + stub.Invokables[invokableChaincodeName] = otherStub +} + +// MockInit Initialise this chaincode, also starts and ends a transaction. +func (stub *MockStub) MockInit(uuid string, args [][]byte) *peer.Response { + stub.args = args + stub.MockTransactionStart(uuid) + res := stub.cc.Init(stub) + stub.MockTransactionEnd(uuid) + return res +} + +// MockInvoke Invoke this chaincode, also starts and ends a transaction. +func (stub *MockStub) MockInvoke(uuid string, args [][]byte) *peer.Response { + stub.args = args + stub.MockTransactionStart(uuid) + res := stub.cc.Invoke(stub) + stub.MockTransactionEnd(uuid) + return res +} + +// GetDecorations ... +func (stub *MockStub) GetDecorations() map[string][]byte { + return stub.Decorations +} + +// MockInvokeWithSignedProposal Invoke this chaincode, also starts and ends a transaction. +func (stub *MockStub) MockInvokeWithSignedProposal(uuid string, args [][]byte, sp *peer.SignedProposal) *peer.Response { + stub.args = args + stub.MockTransactionStart(uuid) + stub.signedProposal = sp + res := stub.cc.Invoke(stub) + stub.MockTransactionEnd(uuid) + return res +} + +// GetPrivateData ... +func (stub *MockStub) GetPrivateData(collection string, key string) ([]byte, error) { + m, in := stub.PvtState[collection] + + if !in { + return nil, nil + } + + return m[key], nil +} + +// GetPrivateDataHash ... +func (stub *MockStub) GetPrivateDataHash(collection, key string) ([]byte, error) { + return nil, errors.New("not Implemented") +} + +// PutPrivateData ... +func (stub *MockStub) PutPrivateData(collection string, key string, value []byte) error { + m, in := stub.PvtState[collection] + if !in { + stub.PvtState[collection] = make(map[string][]byte) + m = stub.PvtState[collection] + } + + m[key] = value + + return nil +} + +// DelPrivateData ... +func (stub *MockStub) DelPrivateData(collection string, key string) error { + return errors.New("not Implemented") +} + +// PurgePrivateData ... +func (stub *MockStub) PurgePrivateData(collection string, key string) error { + return errors.New("not Implemented") +} + +// GetPrivateDataByRange ... +func (stub *MockStub) GetPrivateDataByRange(collection, startKey, endKey string) (shim.StateQueryIteratorInterface, error) { + return nil, errors.New("not Implemented") +} + +// GetPrivateDataByPartialCompositeKey ... +func (stub *MockStub) GetPrivateDataByPartialCompositeKey(collection, objectType string, attributes []string) (shim.StateQueryIteratorInterface, error) { + return nil, errors.New("not Implemented") +} + +// GetPrivateDataQueryResult ... +func (stub *MockStub) GetPrivateDataQueryResult(collection, query string) (shim.StateQueryIteratorInterface, error) { + // Not implemented since the mock engine does not have a query engine. + // However, a very simple query engine that supports string matching + // could be implemented to test that the framework supports queries + return nil, errors.New("not Implemented") +} + +// GetState retrieves the value for a given key from the ledger +func (stub *MockStub) GetState(key string) ([]byte, error) { + value := stub.State[key] + return value, nil +} + +// PutState writes the specified `value` and `key` into the ledger. +func (stub *MockStub) PutState(key string, value []byte) error { + if stub.TxID == "" { + err := errors.New("cannot PutState without a transactions - call stub.MockTransactionStart()?") + return err + } + + // If the value is nil or empty, delete the key + if len(value) == 0 { + return stub.DelState(key) + } + stub.State[key] = value + + // insert key into ordered list of keys + for elem := stub.Keys.Front(); elem != nil; elem = elem.Next() { + elemValue := elem.Value.(string) + comp := strings.Compare(key, elemValue) + if comp < 0 { + // key < elem, insert it before elem + stub.Keys.InsertBefore(key, elem) + break + } else if comp == 0 { + // keys exists, no need to change + break + } else { // comp > 0 + // key > elem, keep looking unless this is the end of the list + if elem.Next() == nil { + stub.Keys.PushBack(key) + break + } + } + } + + // special case for empty Keys list + if stub.Keys.Len() == 0 { + stub.Keys.PushFront(key) + } + + return nil +} + +// DelState removes the specified `key` and its value from the ledger. +func (stub *MockStub) DelState(key string) error { + delete(stub.State, key) + + for elem := stub.Keys.Front(); elem != nil; elem = elem.Next() { + if strings.Compare(key, elem.Value.(string)) == 0 { + stub.Keys.Remove(elem) + } + } + + return nil +} + +// GetStateByRange ... +func (stub *MockStub) GetStateByRange(startKey, endKey string) (shim.StateQueryIteratorInterface, error) { + if err := validateSimpleKeys(startKey, endKey); err != nil { + return nil, err + } + return NewMockStateRangeQueryIterator(stub, startKey, endKey), nil +} + +// To ensure that simple keys do not go into composite key namespace, +// we validate simplekey to check whether the key starts with 0x00 (which +// is the namespace for compositeKey). This helps in avoding simple/composite +// key collisions. +func validateSimpleKeys(simpleKeys ...string) error { + for _, key := range simpleKeys { + if len(key) > 0 && key[0] == compositeKeyNamespace[0] { + return fmt.Errorf(`first character of the key [%s] contains a null character which is not allowed`, key) + } + } + return nil +} + +// GetQueryResult function can be invoked by a chaincode to perform a +// rich query against state database. Only supported by state database implementations +// that support rich query. The query string is in the syntax of the underlying +// state database. An iterator is returned which can be used to iterate (next) over +// the query result set +func (stub *MockStub) GetQueryResult(query string) (shim.StateQueryIteratorInterface, error) { + // Not implemented since the mock engine does not have a query engine. + // However, a very simple query engine that supports string matching + // could be implemented to test that the framework supports queries + return nil, errors.New("not implemented") +} + +// GetHistoryForKey function can be invoked by a chaincode to return a history of +// key values across time. GetHistoryForKey is intended to be used for read-only queries. +func (stub *MockStub) GetHistoryForKey(key string) (shim.HistoryQueryIteratorInterface, error) { + return nil, errors.New("not implemented") +} + +// GetStateByPartialCompositeKey function can be invoked by a chaincode to query the +// state based on a given partial composite key. This function returns an +// iterator which can be used to iterate over all composite keys whose prefix +// matches the given partial composite key. This function should be used only for +// a partial composite key. For a full composite key, an iter with empty response +// would be returned. +func (stub *MockStub) GetStateByPartialCompositeKey(objectType string, attributes []string) (shim.StateQueryIteratorInterface, error) { + partialCompositeKey, err := stub.CreateCompositeKey(objectType, attributes) + if err != nil { + return nil, err + } + return NewMockStateRangeQueryIterator(stub, partialCompositeKey, partialCompositeKey+string(utf8.MaxRune)), nil +} + +// CreateCompositeKey combines the list of attributes +// to form a composite key. +func (stub *MockStub) CreateCompositeKey(objectType string, attributes []string) (string, error) { + return shim.CreateCompositeKey(objectType, attributes) +} + +// SplitCompositeKey splits the composite key into attributes +// on which the composite key was formed. +func (stub *MockStub) SplitCompositeKey(compositeKey string) (string, []string, error) { + return splitCompositeKey(compositeKey) +} + +func splitCompositeKey(compositeKey string) (string, []string, error) { + componentIndex := 1 + components := []string{} + for i := 1; i < len(compositeKey); i++ { + if compositeKey[i] == minUnicodeRuneValue { + components = append(components, compositeKey[componentIndex:i]) + componentIndex = i + 1 + } + } + return components[0], components[1:], nil +} + +// GetStateByRangeWithPagination ... +func (stub *MockStub) GetStateByRangeWithPagination(startKey, endKey string, pageSize int32, + bookmark string) (shim.StateQueryIteratorInterface, *peer.QueryResponseMetadata, error) { + return nil, nil, nil +} + +// GetStateByPartialCompositeKeyWithPagination ... +func (stub *MockStub) GetStateByPartialCompositeKeyWithPagination(objectType string, keys []string, + pageSize int32, bookmark string) (shim.StateQueryIteratorInterface, *peer.QueryResponseMetadata, error) { + return nil, nil, nil +} + +// GetQueryResultWithPagination ... +func (stub *MockStub) GetQueryResultWithPagination(query string, pageSize int32, + bookmark string) (shim.StateQueryIteratorInterface, *peer.QueryResponseMetadata, error) { + return nil, nil, nil +} + +// InvokeChaincode locally calls the specified chaincode `Invoke`. +// E.g. stub1.InvokeChaincode("othercc", funcArgs, channel) +// Before calling this make sure to create another MockStub stub2, call shim.NewMockStub("othercc", Chaincode) +// and register it with stub1 by calling stub1.MockPeerChaincode("othercc", stub2, channel) +func (stub *MockStub) InvokeChaincode(chaincodeName string, args [][]byte, channel string) *peer.Response { + // Internally we use chaincode name as a composite name + if channel != "" { + chaincodeName = chaincodeName + "/" + channel + } + // TODO "args" here should possibly be a serialized peer.ChaincodeInput + otherStub := stub.Invokables[chaincodeName] + // function, strings := getFuncArgs(args) + res := otherStub.MockInvoke(stub.TxID, args) + return res +} + +// GetCreator ... +func (stub *MockStub) GetCreator() ([]byte, error) { + return stub.Creator, nil +} + +// SetTransient set TransientMap to mockStub +func (stub *MockStub) SetTransient(tMap map[string][]byte) error { + if stub.signedProposal == nil { + return fmt.Errorf("signedProposal is not initialized") + } + payloadByte, err := proto.Marshal(&peer.ChaincodeProposalPayload{ + TransientMap: tMap, + }) + if err != nil { + return err + } + proposalByte, err := proto.Marshal(&peer.Proposal{ + Payload: payloadByte, + }) + if err != nil { + return err + } + stub.signedProposal.ProposalBytes = proposalByte + stub.TransientMap = tMap + return nil +} + +// GetTransient ... +func (stub *MockStub) GetTransient() (map[string][]byte, error) { + return stub.TransientMap, nil +} + +// GetBinding Not implemented ... +func (stub *MockStub) GetBinding() ([]byte, error) { + return nil, nil +} + +// GetSignedProposal Not implemented ... +func (stub *MockStub) GetSignedProposal() (*peer.SignedProposal, error) { + return stub.signedProposal, nil +} + +func (stub *MockStub) setSignedProposal(sp *peer.SignedProposal) { + stub.signedProposal = sp +} + +// GetArgsSlice Not implemented ... +func (stub *MockStub) GetArgsSlice() ([]byte, error) { + return nil, nil +} + +func (stub *MockStub) setTxTimestamp(time *timestamppb.Timestamp) { + stub.TxTimestamp = time +} + +// GetTxTimestamp ... +func (stub *MockStub) GetTxTimestamp() (*timestamppb.Timestamp, error) { + if stub.TxTimestamp == nil { + return nil, errors.New("TxTimestamp not set") + } + return stub.TxTimestamp, nil +} + +// SetEvent ... +func (stub *MockStub) SetEvent(name string, payload []byte) error { + stub.ChaincodeEventsChannel <- &peer.ChaincodeEvent{EventName: name, Payload: payload} + return nil +} + +// SetStateValidationParameter ... +func (stub *MockStub) SetStateValidationParameter(key string, ep []byte) error { + return stub.SetPrivateDataValidationParameter("", key, ep) +} + +// GetStateValidationParameter ... +func (stub *MockStub) GetStateValidationParameter(key string) ([]byte, error) { + return stub.GetPrivateDataValidationParameter("", key) +} + +// SetPrivateDataValidationParameter ... +func (stub *MockStub) SetPrivateDataValidationParameter(collection, key string, ep []byte) error { + m, in := stub.EndorsementPolicies[collection] + if !in { + stub.EndorsementPolicies[collection] = make(map[string][]byte) + m = stub.EndorsementPolicies[collection] + } + + m[key] = ep + return nil +} + +// GetPrivateDataValidationParameter ... +func (stub *MockStub) GetPrivateDataValidationParameter(collection, key string) ([]byte, error) { + m, in := stub.EndorsementPolicies[collection] + + if !in { + return nil, nil + } + + return m[key], nil +} + +// NewMockStub Constructor to initialise the internal State map +func NewMockStub(name string, cc shim.Chaincode) *MockStub { + s := new(MockStub) + s.Name = name + s.cc = cc + s.State = make(map[string][]byte) + s.PvtState = make(map[string]map[string][]byte) + s.EndorsementPolicies = make(map[string]map[string][]byte) + s.Invokables = make(map[string]*MockStub) + s.Keys = list.New() + s.ChaincodeEventsChannel = make(chan *peer.ChaincodeEvent, 100) //define large capacity for non-blocking setEvent calls. + s.Decorations = make(map[string][]byte) + + return s +} + +/***************************** + Range Query Iterator +*****************************/ + +// MockStateRangeQueryIterator ... +type MockStateRangeQueryIterator struct { + Closed bool + Stub *MockStub + StartKey string + EndKey string + Current *list.Element +} + +// HasNext returns true if the range query iterator contains additional keys +// and values. +func (iter *MockStateRangeQueryIterator) HasNext() bool { + if iter.Closed { + // previously called Close() + return false + } + + if iter.Current == nil { + return false + } + + current := iter.Current + for current != nil { + // if this is an open-ended query for all keys, return true + if iter.StartKey == "" && iter.EndKey == "" { + return true + } + comp1 := strings.Compare(current.Value.(string), iter.StartKey) + comp2 := strings.Compare(current.Value.(string), iter.EndKey) + if comp1 >= 0 { + return comp2 < 0 + } + current = current.Next() + } + return false +} + +// Next returns the next key and value in the range query iterator. +func (iter *MockStateRangeQueryIterator) Next() (*queryresult.KV, error) { + if iter.Closed { + err := errors.New("MockStateRangeQueryIterator.Next() called after Close()") + return nil, err + } + + if !iter.HasNext() { + err := errors.New("MockStateRangeQueryIterator.Next() called when it does not HaveNext()") + return nil, err + } + + for iter.Current != nil { + comp1 := strings.Compare(iter.Current.Value.(string), iter.StartKey) + comp2 := strings.Compare(iter.Current.Value.(string), iter.EndKey) + // compare to start and end keys. or, if this is an open-ended query for + // all keys, it should always return the key and value + if (comp1 >= 0 && comp2 < 0) || (iter.StartKey == "" && iter.EndKey == "") { + key := iter.Current.Value.(string) + value, err := iter.Stub.GetState(key) + iter.Current = iter.Current.Next() + return &queryresult.KV{Key: key, Value: value}, err + } + iter.Current = iter.Current.Next() + } + err := errors.New("MockStateRangeQueryIterator.Next() went past end of range") + return nil, err +} + +// Close closes the range query iterator. This should be called when done +// reading from the iterator to free up resources. +func (iter *MockStateRangeQueryIterator) Close() error { + if iter.Closed { + err := errors.New("MockStateRangeQueryIterator.Close() called after Close()") + return err + } + + iter.Closed = true + return nil +} + +// NewMockStateRangeQueryIterator ... +func NewMockStateRangeQueryIterator(stub *MockStub, startKey string, endKey string) *MockStateRangeQueryIterator { + iter := new(MockStateRangeQueryIterator) + iter.Closed = false + iter.Stub = stub + iter.StartKey = startKey + iter.EndKey = endKey + iter.Current = stub.Keys.Front() + return iter +} + +func getBytes(function string, args []string) [][]byte { + bytes := make([][]byte, 0, len(args)+1) + bytes = append(bytes, []byte(function)) + for _, s := range args { + bytes = append(bytes, []byte(s)) + } + return bytes +} + +func getFuncArgs(bytes [][]byte) (string, []string) { + function := string(bytes[0]) + args := make([]string, len(bytes)-1) + for i := 1; i < len(bytes); i++ { + args[i-1] = string(bytes[i]) + } + return function, args +} diff --git a/v2/shimtest/mockstub_test.go b/v2/shimtest/mockstub_test.go new file mode 100644 index 0000000..bf965c9 --- /dev/null +++ b/v2/shimtest/mockstub_test.go @@ -0,0 +1,340 @@ +// Copyright the Hyperledger Fabric contributors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package shimtest + +import ( + "encoding/json" + "fmt" + "reflect" + "testing" + + "github.com/hyperledger/fabric-chaincode-go/v2/shim" + "github.com/hyperledger/fabric-chaincode-go/v2/shimtest/mock" + "github.com/hyperledger/fabric-protos-go-apiv2/peer" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +//go:generate counterfeiter -o mock/chaincode.go --fake-name Chaincode . chaincode +//lint:ignore U1000 Required to avoid circular dependency with mock +type chaincode interface { + shim.Chaincode +} + +func stubPutState(t *testing.T, stub *MockStub, key string, value []byte) { + err := stub.PutState(key, value) + require.NoErrorf(t, err, "PutState(%s, %s)", key, value) +} + +func stubDelState(t *testing.T, stub *MockStub, key string) { + err := stub.DelState(key) + require.NoErrorf(t, err, "DelState(%s)", key) +} + +func stubGetStateByRange(t *testing.T, stub *MockStub, start string, end string) shim.StateQueryIteratorInterface { + result, err := stub.GetStateByRange(start, end) + require.NoErrorf(t, err, "GetStateByRange(%s, %s)", start, end) + return result +} + +func stubGetCreator(t *testing.T, stub *MockStub) []byte { + result, err := stub.GetCreator() + require.NoError(t, err, "GetCreator") + return result +} + +func stubGetTransient(t *testing.T, stub *MockStub) map[string][]byte { + result, err := stub.GetTransient() + require.NoError(t, err, "GetTransient") + return result +} + +func stubGetBinding(t *testing.T, stub *MockStub) []byte { + result, err := stub.GetBinding() + require.NoError(t, err, "GetBinding") + return result +} + +func stubGetSignedProposal(t *testing.T, stub *MockStub) *peer.SignedProposal { + result, err := stub.GetSignedProposal() + require.NoError(t, err, "GetSignedProposal") + return result +} + +func stubGetArgsSlice(t *testing.T, stub *MockStub) []byte { + result, err := stub.GetArgsSlice() + require.NoError(t, err, "GetArgsSlice") + return result +} + +func stubSetEvent(t *testing.T, stub *MockStub, key string, value []byte) { + err := stub.SetEvent(key, value) + require.NoErrorf(t, err, "SetEvent(%s, %s)", key, value) +} + +func TestMockStateRangeQueryIterator(t *testing.T) { + stub := NewMockStub("rangeTest", nil) + stub.MockTransactionStart("init") + stubPutState(t, stub, "1", []byte{61}) + stubPutState(t, stub, "0", []byte{62}) + stubPutState(t, stub, "5", []byte{65}) + stubPutState(t, stub, "3", []byte{63}) + stubPutState(t, stub, "4", []byte{64}) + stubPutState(t, stub, "6", []byte{66}) + stub.MockTransactionEnd("init") + + expectKeys := []string{"3", "4"} + expectValues := [][]byte{{63}, {64}} + + rqi := NewMockStateRangeQueryIterator(stub, "2", "5") + + fmt.Println("Running loop") + for i := 0; i < 2; i++ { + response, err := rqi.Next() + fmt.Println("Loop", i, "got", response.Key, response.Value, err) + if expectKeys[i] != response.Key { + fmt.Println("Expected key", expectKeys[i], "got", response.Key) + t.FailNow() + } + if expectValues[i][0] != response.Value[0] { + fmt.Println("Expected value", expectValues[i], "got", response.Value) + } + } +} + +// TestMockStateRangeQueryIterator_openEnded tests running an open-ended query +// for all keys on the MockStateRangeQueryIterator +func TestMockStateRangeQueryIterator_openEnded(t *testing.T) { + stub := NewMockStub("rangeTest", nil) + stub.MockTransactionStart("init") + stubPutState(t, stub, "1", []byte{61}) + stubPutState(t, stub, "0", []byte{62}) + stubPutState(t, stub, "5", []byte{65}) + stubPutState(t, stub, "3", []byte{63}) + stubPutState(t, stub, "4", []byte{64}) + stubPutState(t, stub, "6", []byte{66}) + stub.MockTransactionEnd("init") + + rqi := NewMockStateRangeQueryIterator(stub, "", "") + + count := 0 + for rqi.HasNext() { + _, err := rqi.Next() + require.NoError(t, err) + count++ + } + + if count != rqi.Stub.Keys.Len() { + t.FailNow() + } +} + +type Marble struct { + ObjectType string `json:"docType"` //docType is used to distinguish the various types of objects in state database + Name string `json:"name"` //the fieldtags are needed to keep case from bouncing around + Color string `json:"color"` + Size int `json:"size"` + Owner string `json:"owner"` +} + +// JSONBytesEqual compares the JSON in two byte slices. +func jsonBytesEqual(expected []byte, actual []byte) bool { + var infExpected, infActual interface{} + if err := json.Unmarshal(expected, &infExpected); err != nil { + return false + } + if err := json.Unmarshal(actual, &infActual); err != nil { + return false + } + return reflect.DeepEqual(infActual, infExpected) +} + +func TestGetStateByPartialCompositeKey(t *testing.T) { + stub := NewMockStub("GetStateByPartialCompositeKeyTest", nil) + stub.MockTransactionStart("init") + + marble1 := &Marble{"marble", "set-1", "red", 5, "tom"} + // Convert marble1 to JSON with Color and Name as composite key + compositeKey1, _ := stub.CreateCompositeKey(marble1.ObjectType, []string{marble1.Name, marble1.Color}) + marbleJSONBytes1, _ := json.Marshal(marble1) + // Add marble1 JSON to state + stubPutState(t, stub, compositeKey1, marbleJSONBytes1) + + marble2 := &Marble{"marble", "set-1", "blue", 5, "jerry"} + compositeKey2, _ := stub.CreateCompositeKey(marble2.ObjectType, []string{marble2.Name, marble2.Color}) + marbleJSONBytes2, _ := json.Marshal(marble2) + stubPutState(t, stub, compositeKey2, marbleJSONBytes2) + + marble3 := &Marble{"marble", "set-2", "red", 5, "tom-jerry"} + compositeKey3, _ := stub.CreateCompositeKey(marble3.ObjectType, []string{marble3.Name, marble3.Color}) + marbleJSONBytes3, _ := json.Marshal(marble3) + stubPutState(t, stub, compositeKey3, marbleJSONBytes3) + + stub.MockTransactionEnd("init") + // should return in sorted order of attributes + expectKeys := []string{compositeKey2, compositeKey1} + expectKeysAttributes := [][]string{{"set-1", "blue"}, {"set-1", "red"}} + expectValues := [][]byte{marbleJSONBytes2, marbleJSONBytes1} + + rqi, _ := stub.GetStateByPartialCompositeKey("marble", []string{"set-1"}) + fmt.Println("Running loop") + for i := 0; i < 2; i++ { + response, err := rqi.Next() + fmt.Println("Loop", i, "got", response.Key, response.Value, err) + if expectKeys[i] != response.Key { + fmt.Println("Expected key", expectKeys[i], "got", response.Key) + t.FailNow() + } + objectType, attributes, _ := stub.SplitCompositeKey(response.Key) + if objectType != "marble" { + fmt.Println("Expected objectType", "marble", "got", objectType) + t.FailNow() + } + fmt.Println(attributes) + for index, attr := range attributes { + if expectKeysAttributes[i][index] != attr { + fmt.Println("Expected keys attribute", expectKeysAttributes[index][i], "got", attr) + t.FailNow() + } + } + if jsonBytesEqual(expectValues[i], response.Value) != true { + fmt.Println("Expected value", expectValues[i], "got", response.Value) + t.FailNow() + } + } +} + +func TestGetStateByPartialCompositeKeyCollision(t *testing.T) { + stub := NewMockStub("GetStateByPartialCompositeKeyCollisionTest", nil) + stub.MockTransactionStart("init") + + vehicle1Bytes := []byte("vehicle1") + compositeKeyVehicle1, _ := stub.CreateCompositeKey("Vehicle", []string{"VIN_1234"}) + stubPutState(t, stub, compositeKeyVehicle1, vehicle1Bytes) + + vehicleListing1Bytes := []byte("vehicleListing1") + compositeKeyVehicleListing1, _ := stub.CreateCompositeKey("VehicleListing", []string{"LIST_1234"}) + stubPutState(t, stub, compositeKeyVehicleListing1, vehicleListing1Bytes) + + stub.MockTransactionEnd("init") + + // Only the single "Vehicle" object should be returned, not the "VehicleListing" object + rqi, _ := stub.GetStateByPartialCompositeKey("Vehicle", []string{}) + i := 0 + fmt.Println("Running loop") + for rqi.HasNext() { + i++ + response, err := rqi.Next() + fmt.Println("Loop", i, "got", response.Key, response.Value, err) + } + // Only the single "Vehicle" object should be returned, not the "VehicleListing" object + if i != 1 { + fmt.Println("Expected 1, got", i) + t.FailNow() + } +} + +func TestGetTxTimestamp(t *testing.T) { + stub := NewMockStub("GetTxTimestamp", nil) + stub.MockTransactionStart("init") + + timestamp, err := stub.GetTxTimestamp() + if timestamp == nil || err != nil { + t.FailNow() + } + + stub.MockTransactionEnd("init") +} + +// TestPutEmptyState confirms that setting a key value to empty or nil in the mock state deletes the key +// instead of storing an empty key. +func TestPutEmptyState(t *testing.T) { + stub := NewMockStub("FAB-12545", nil) + + // Put an empty and nil state value + stub.MockTransactionStart("1") + stubPutState(t, stub, "empty", []byte{}) + stubPutState(t, stub, "nil", nil) + stub.MockTransactionEnd("1") + + // Confirm both are nil + stub.MockTransactionStart("2") + val, err := stub.GetState("empty") + assert.NoError(t, err) + assert.Nil(t, val) + val, err = stub.GetState("nil") + assert.NoError(t, err) + assert.Nil(t, val) + // Add a value to both empty and nil + stubPutState(t, stub, "empty", []byte{0}) + stubPutState(t, stub, "nil", []byte{0}) + stub.MockTransactionEnd("2") + + // Confirm the value is in both + stub.MockTransactionStart("3") + val, err = stub.GetState("empty") + assert.NoError(t, err) + assert.Equal(t, val, []byte{0}) + val, err = stub.GetState("nil") + assert.NoError(t, err) + assert.Equal(t, val, []byte{0}) + stub.MockTransactionEnd("3") + + // Set both back to empty / nil + stub.MockTransactionStart("4") + stubPutState(t, stub, "empty", []byte{}) + stubPutState(t, stub, "nil", nil) + stub.MockTransactionEnd("4") + + // Confirm both are nil + stub.MockTransactionStart("5") + val, err = stub.GetState("empty") + assert.NoError(t, err) + assert.Nil(t, val) + val, err = stub.GetState("nil") + assert.NoError(t, err) + assert.Nil(t, val) + stub.MockTransactionEnd("5") + +} + +// TestMockMock clearly cheating for coverage... but not. Mock should +// be tucked away under common/mocks package which is not +// included for coverage. Moving mockstub to another package +// will cause upheaval in other code best dealt with separately +// For now, call all the methods to get mock covered in this +// package +func TestMockMock(t *testing.T) { + stub := NewMockStub("MOCKMOCK", &mock.Chaincode{}) + stub.args = [][]byte{[]byte("a"), []byte("b")} + stub.MockInit("id", nil) + stub.GetArgs() + stub.GetStringArgs() + stub.GetFunctionAndParameters() + stub.GetTxID() + stub.GetChannelID() + stub.MockInvoke("id", nil) + stub.MockInvokeWithSignedProposal("id", nil, nil) + stubDelState(t, stub, "dummy") + stubGetStateByRange(t, stub, "start", "end") + _, err := stub.GetQueryResult("q") + require.Error(t, err, "GetQueryResult not implemented") + + stub2 := NewMockStub("othercc", &mock.Chaincode{}) + stub2.MockPeerChaincode("othercc", stub2, "mychan") + stub2.InvokeChaincode("othercc", nil, "mychan") + stubGetCreator(t, stub2) + stubGetTransient(t, stub2) + stubGetBinding(t, stub2) + stubGetSignedProposal(t, stub2) + stubGetArgsSlice(t, stub2) + stubSetEvent(t, stub2, "e", nil) + _, err = stub2.GetHistoryForKey("k") + require.Error(t, err, "GetHistoryForKey not implemented") + iter := &MockStateRangeQueryIterator{} + iter.HasNext() + iter.Close() + getBytes("f", []string{"a", "b"}) + getFuncArgs([][]byte{[]byte("a")}) +}