From 52b979cc701a7264a784eaaeea526a9305f97416 Mon Sep 17 00:00:00 2001 From: Tor Colvin Date: Tue, 12 Dec 2023 17:27:29 -0500 Subject: [PATCH] [3.1.3 backport] CBG-3660 check Origin header on websocket (#6612) backports CBG-3652 check Origin header on websocket (#6610) --- db/active_replicator.go | 3 +- db/active_replicator_test.go | 2 +- db/blip.go | 14 ++- go.mod | 2 +- go.sum | 4 +- rest/blip_client_test.go | 4 + rest/blip_sync.go | 26 ++++- rest/blip_sync_test.go | 72 ++++++++++++++ rest/config.go | 5 + rest/config_test.go | 17 ++++ rest/cors_test.go | 179 ++++++++++++++++++++++++++++++----- rest/utilities_testing.go | 25 ++++- 12 files changed, 313 insertions(+), 40 deletions(-) create mode 100644 rest/blip_sync_test.go diff --git a/db/active_replicator.go b/db/active_replicator.go index f728c087fd..dbf55314f5 100644 --- a/db/active_replicator.go +++ b/db/active_replicator.go @@ -207,7 +207,8 @@ func (ar *ActiveReplicator) GetStatus(ctx context.Context) *ReplicationStatus { func connect(arc *activeReplicatorCommon, idSuffix string) (blipSender *blip.Sender, bsc *BlipSyncContext, err error) { arc.replicationStats.NumConnectAttempts.Add(1) - blipContext, err := NewSGBlipContext(arc.ctx, arc.config.ID+idSuffix) + var originPatterns []string // no origin headers for ISGR + blipContext, err := NewSGBlipContext(arc.ctx, arc.config.ID+idSuffix, originPatterns) if err != nil { return nil, nil, err } diff --git a/db/active_replicator_test.go b/db/active_replicator_test.go index bb029b48c5..93e5eb4808 100644 --- a/db/active_replicator_test.go +++ b/db/active_replicator_test.go @@ -65,7 +65,7 @@ func TestBlipSyncErrorUserinfo(t *testing.T) { srvURL.Path = "/db1" t.Logf("srvURL: %v", srvURL.String()) - blipContext, err := NewSGBlipContext(base.TestCtx(t), t.Name()) + blipContext, err := NewSGBlipContext(base.TestCtx(t), t.Name(), nil) require.NoError(t, err) _, err = blipSync(*srvURL, blipContext, false) diff --git a/db/blip.go b/db/blip.go index 39a927e387..058b95e7d3 100644 --- a/db/blip.go +++ b/db/blip.go @@ -43,18 +43,22 @@ var ( ) // NewSGBlipContext returns a go-blip context with the given ID, initialized for use in Sync Gateway. -func NewSGBlipContext(ctx context.Context, id string) (bc *blip.Context, err error) { +func NewSGBlipContext(ctx context.Context, id string, origin []string) (bc *blip.Context, err error) { // V3 is first here as it is the preferred communication method // In the host case this means SGW can accept both V3 and V2 clients // In the client case this means we prefer V3 but can fallback to V2 - return NewSGBlipContextWithProtocols(ctx, id, BlipCBMobileReplicationV3, BlipCBMobileReplicationV2) + return NewSGBlipContextWithProtocols(ctx, id, origin, []string{BlipCBMobileReplicationV3, BlipCBMobileReplicationV2}) } -func NewSGBlipContextWithProtocols(ctx context.Context, id string, protocol ...string) (bc *blip.Context, err error) { +func NewSGBlipContextWithProtocols(ctx context.Context, id string, origin []string, protocols []string) (bc *blip.Context, err error) { + opts := blip.ContextOptions{ + Origin: origin, + ProtocolIds: protocols, + } if id == "" { - bc, err = blip.NewContext(protocol...) + bc, err = blip.NewContext(opts) } else { - bc, err = blip.NewContextCustomID(id, protocol...) + bc, err = blip.NewContextCustomID(id, opts) } bc.LogMessages = base.LogDebugEnabled(base.KeyWebSocket) diff --git a/go.mod b/go.mod index 33b84bc04c..2cef29990f 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/coreos/go-oidc v2.2.1+incompatible github.com/couchbase/cbgt v1.3.9 github.com/couchbase/clog v0.1.0 - github.com/couchbase/go-blip v0.0.0-20221021161139-215cbac22bd7 + github.com/couchbase/go-blip v0.0.0-20231212221113-a6ee87e0c16f github.com/couchbase/go-couchbase v0.1.1 github.com/couchbase/gocb/v2 v2.6.5 github.com/couchbase/gocbcore/v10 v10.2.10 diff --git a/go.sum b/go.sum index a57889c334..255fbcc621 100644 --- a/go.sum +++ b/go.sum @@ -72,8 +72,8 @@ github.com/couchbase/cbgt v1.3.9 h1:MAT3FwD1ctekxuFe0yau0H1BCTvgLXvh1ipbZ3nZhBE= github.com/couchbase/cbgt v1.3.9/go.mod h1:MImhtmvk0qjJit5HbmA34tnYThZoNtvgjL7jJH/kCAE= github.com/couchbase/clog v0.1.0 h1:4Kh/YHkhRjMCbdQuvRVsm39XZh4FtL1d8fAwJsHrEPY= github.com/couchbase/clog v0.1.0/go.mod h1:7tzUpEOsE+fgU81yfcjy5N1H6XtbVC8SgOz/3mCjmd4= -github.com/couchbase/go-blip v0.0.0-20221021161139-215cbac22bd7 h1:/GTlMVovmGKrFAl5e7u9CXuhjTlR5a4911Ujou18Q4Q= -github.com/couchbase/go-blip v0.0.0-20221021161139-215cbac22bd7/go.mod h1:nSpldGTqAhTOaDDL0Li2dSE0smqbISKagT7fIqYIRec= +github.com/couchbase/go-blip v0.0.0-20231212221113-a6ee87e0c16f h1:vmZxVtUFv5TfEXXrnTAqVikt6m++/gGja891m1ig6PU= +github.com/couchbase/go-blip v0.0.0-20231212221113-a6ee87e0c16f/go.mod h1:nSpldGTqAhTOaDDL0Li2dSE0smqbISKagT7fIqYIRec= github.com/couchbase/go-couchbase v0.1.1 h1:ClFXELcKj/ojyoTYbsY34QUrrYCBi/1G749sXSCkdhk= github.com/couchbase/go-couchbase v0.1.1/go.mod h1:+/bddYDxXsf9qt0xpDUtRR47A2GjaXmGGAqQ/k3GJ8A= github.com/couchbase/gocb/v2 v2.6.5 h1:xaZu29o8UJEV1ZQ3n2s9jcRCUHz/JsQ6+y6JBnVsy5A= diff --git a/rest/blip_client_test.go b/rest/blip_client_test.go index edc8474af5..bc572fc519 100644 --- a/rest/blip_client_test.go +++ b/rest/blip_client_test.go @@ -39,6 +39,9 @@ type BlipTesterClientOpts struct { // a deltaSrc rev ID for which to reject a delta rejectDeltasForSrcRev string + + // optional Origin header + origin *string } // BlipTesterClient is a fully fledged client to emulate CBL behaviour on both push and pull replications through methods on this type. @@ -539,6 +542,7 @@ func newBlipTesterReplication(tb testing.TB, id string, btc *BlipTesterClient, s connectingUserChannelGrants: btc.Channels, blipProtocols: btc.SupportedBLIPProtocols, skipCollectionsInitialization: skipCollectionsInitialization, + origin: btc.origin, }, btc.rt) if err != nil { return nil, err diff --git a/rest/blip_sync.go b/rest/blip_sync.go index 0671a8b7cb..7a50299557 100644 --- a/rest/blip_sync.go +++ b/rest/blip_sync.go @@ -13,6 +13,7 @@ package rest import ( "fmt" "net/http" + "net/url" "github.com/couchbase/sync_gateway/db" @@ -36,8 +37,11 @@ func (h *handler) handleBLIPSync() error { blip.CompressionLevel = *c } + // error is checked at the time of database load, and ignored at this time + originPatterns, _ := hostOnlyCORS(h.db.CORS.Origin) + // Create a BLIP context: - blipContext, err := db.NewSGBlipContext(h.ctx(), "") + blipContext, err := db.NewSGBlipContext(h.ctx(), "", originPatterns) if err != nil { return err } @@ -71,3 +75,23 @@ func (h *handler) handleBLIPSync() error { return nil } + +// hostOnlyCORS returns the host portion of the origin URL, suitable for passing to websocket library. +func hostOnlyCORS(originPatterns []string) ([]string, error) { + var origins []string + var multiError *base.MultiError + for _, origin := range originPatterns { + // this is a special pattern for allowing all origins + if origin == "*" { + origins = append(origins, origin) + continue + } + u, err := url.Parse(origin) + if err != nil { + multiError = multiError.Append(fmt.Errorf("%s is not a valid pattern for CORS config", err)) + continue + } + origins = append(origins, u.Host) + } + return origins, multiError.ErrorOrNil() +} diff --git a/rest/blip_sync_test.go b/rest/blip_sync_test.go new file mode 100644 index 0000000000..d85c4cb46b --- /dev/null +++ b/rest/blip_sync_test.go @@ -0,0 +1,72 @@ +// Copyright 2023-Present Couchbase, Inc. +// +// Use of this software is governed by the Business Source License included +// in the file licenses/BSL-Couchbase.txt. As of the Change Date specified +// in that file, in accordance with the Business Source License, use of this +// software will be governed by the Apache License, Version 2.0, included in +// the file licenses/APL2.txt. + +package rest + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHostOnlyCORS(t *testing.T) { + const unparseableURL = "1http:///example.com" + testsCases := []struct { + input []string + output []string + hasError bool + }{ + { + input: []string{"http://example.com"}, + output: []string{"example.com"}, + }, + { + input: []string{"https://example.com", "http://example.com"}, + output: []string{"example.com", "example.com"}, + }, + { + input: []string{"*", "http://example.com"}, + output: []string{"*", "example.com"}, + }, + { + input: []string{"wss://example.com"}, + output: []string{"example.com"}, + }, + { + input: []string{"http://example.com:12345"}, + output: []string{"example.com:12345"}, + }, + { + input: []string{unparseableURL}, + output: nil, + hasError: true, + }, + { + input: []string{"*", unparseableURL}, + output: []string{"*"}, + hasError: true, + }, + { + input: []string{"*", unparseableURL, "http://example.com"}, + output: []string{"*", "example.com"}, + hasError: true, + }, + } + for _, test := range testsCases { + t.Run(fmt.Sprintf("%v->%v", test.input, test.output), func(t *testing.T) { + output, err := hostOnlyCORS(test.input) + if test.hasError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.output, output) + }) + } +} diff --git a/rest/config.go b/rest/config.go index af58da0ed3..f92183eaf6 100644 --- a/rest/config.go +++ b/rest/config.go @@ -961,6 +961,11 @@ func (dbConfig *DbConfig) validateVersion(ctx context.Context, isEnterpriseEditi } } + if dbConfig.CORS != nil { + // these values will likely to be ignored by the CORS handler unless browser sends abornmal Origin headers + _, err := hostOnlyCORS(dbConfig.CORS.Origin) + base.WarnfCtx(ctx, "The cors.origin contains values that may be ignored: %s", err) + } return multiError.ErrorOrNil() } diff --git a/rest/config_test.go b/rest/config_test.go index 4980e0b697..1fa6544ddb 100644 --- a/rest/config_test.go +++ b/rest/config_test.go @@ -2713,3 +2713,20 @@ func TestDatabaseConfigDropScopes(t *testing.T) { require.Contains(t, resp.Body.String(), "cannot change scope") } + +func TestBadCORSValuesConfig(t *testing.T) { + if base.UnitTestUrlIsWalrus() { + t.Skip("test only works with CBS/rosmar") + } + rt := NewRestTester(t, &RestTesterConfig{PersistentConfig: true}) + defer rt.Close() + + // expect database to be created with bad CORS values, but do log a warning + dbConfig := rt.NewDbConfig() + dbConfig.CORS = &auth.CORSConfig{ + Origin: []string{"http://example.com", "1http://example.com"}, + } + base.AssertLogContains(t, "cors.origin contains values", func() { + rt.CreateDatabase("db", dbConfig) + }) +} diff --git a/rest/cors_test.go b/rest/cors_test.go index e4345db574..3f072dffdc 100644 --- a/rest/cors_test.go +++ b/rest/cors_test.go @@ -15,10 +15,13 @@ import ( "testing" "github.com/couchbase/sync_gateway/auth" + "github.com/couchbase/sync_gateway/base" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +const accessControlAllowOrigin = "Access-Control-Allow-Origin" + func TestCORSDynamicSet(t *testing.T) { rt := NewRestTester(t, &RestTesterConfig{ PersistentConfig: true, @@ -42,7 +45,7 @@ func TestCORSDynamicSet(t *testing.T) { for _, method := range []string{http.MethodGet, http.MethodOptions} { response := rt.SendRequestWithHeaders(method, "/{{.keyspace}}/", "", reqHeaders) - require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "http://example.com", response.Header().Get(accessControlAllowOrigin)) if method == http.MethodGet { RequireStatus(t, response, http.StatusBadRequest) require.Contains(t, response.Body.String(), invalidDatabaseName) @@ -53,7 +56,7 @@ func TestCORSDynamicSet(t *testing.T) { // successful request for _, method := range []string{http.MethodGet, http.MethodOptions} { response := rt.SendUserRequestWithHeaders(method, "/{{.keyspace}}/_all_docs", "", reqHeaders, username, RestTesterDefaultUserPassword) - require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "http://example.com", response.Header().Get(accessControlAllowOrigin)) if method == http.MethodGet { RequireStatus(t, response, http.StatusOK) } else { @@ -62,7 +65,7 @@ func TestCORSDynamicSet(t *testing.T) { } for _, method := range []string{http.MethodGet, http.MethodOptions} { response := rt.SendRequestWithHeaders(method, "/{{.db}}/", "", reqHeaders) - require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "http://example.com", response.Header().Get(accessControlAllowOrigin)) if method == http.MethodGet { RequireStatus(t, response, http.StatusUnauthorized) require.Contains(t, response.Body.String(), ErrLoginRequired.Message) @@ -72,7 +75,7 @@ func TestCORSDynamicSet(t *testing.T) { } for _, method := range []string{http.MethodGet, http.MethodOptions} { response := rt.SendUserRequestWithHeaders(method, "/{{.db}}/", "", reqHeaders, username, RestTesterDefaultUserPassword) - require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "http://example.com", response.Header().Get(accessControlAllowOrigin)) if method == http.MethodGet { RequireStatus(t, response, http.StatusOK) } else { @@ -91,12 +94,15 @@ func TestCORSDynamicSet(t *testing.T) { for _, method := range []string{http.MethodGet, http.MethodOptions} { response := rt.SendRequestWithHeaders(method, "/{{.keyspace}}/", "", reqHeaders) if method == http.MethodGet { - require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) - RequireStatus(t, response, http.StatusBadRequest) - require.Contains(t, response.Body.String(), invalidDatabaseName) + if base.TestsUseNamedCollections() { + RequireStatus(t, response, http.StatusBadRequest) + require.Contains(t, response.Body.String(), invalidDatabaseName) + } else { // CBG-2978, should not be different from GSI/collections + RequireStatus(t, response, http.StatusUnauthorized) + } } else { // information leak: the options request knows about the database and knows it doesn't match - require.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "", response.Header().Get(accessControlAllowOrigin)) RequireStatus(t, response, http.StatusNoContent) } } @@ -104,7 +110,7 @@ func TestCORSDynamicSet(t *testing.T) { // successful request - mismatched headers for _, method := range []string{http.MethodGet, http.MethodOptions} { response := rt.SendUserRequestWithHeaders(method, "/{{.keyspace}}/_all_docs", "", reqHeaders, username, RestTesterDefaultUserPassword) - require.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "", response.Header().Get(accessControlAllowOrigin)) if method == http.MethodGet { RequireStatus(t, response, http.StatusOK) } else { @@ -117,16 +123,16 @@ func TestCORSDynamicSet(t *testing.T) { if method == http.MethodGet { RequireStatus(t, response, http.StatusUnauthorized) require.Contains(t, response.Body.String(), ErrLoginRequired.Message) - require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "http://example.com", response.Header().Get(accessControlAllowOrigin)) } else { RequireStatus(t, response, http.StatusNoContent) // information leak: the options request knows about the database and knows it doesn't match - require.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "", response.Header().Get(accessControlAllowOrigin)) } } for _, method := range []string{http.MethodGet, http.MethodOptions} { response := rt.SendRequestWithHeaders(method, "/notadb/", "", reqHeaders) - require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "http://example.com", response.Header().Get(accessControlAllowOrigin)) if method == http.MethodGet { RequireStatus(t, response, http.StatusUnauthorized) require.Contains(t, response.Body.String(), ErrLoginRequired.Message) @@ -136,7 +142,7 @@ func TestCORSDynamicSet(t *testing.T) { } for _, method := range []string{http.MethodGet, http.MethodOptions} { response := rt.SendUserRequestWithHeaders(method, "/{{.db}}/", "", reqHeaders, username, RestTesterDefaultUserPassword) - require.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "", response.Header().Get(accessControlAllowOrigin)) if method == http.MethodGet { RequireStatus(t, response, http.StatusOK) } else { @@ -151,7 +157,7 @@ func TestCORSDynamicSet(t *testing.T) { for _, method := range []string{http.MethodGet, http.MethodOptions} { response := rt.SendUserRequestWithHeaders(method, "/{{.keyspace}}/_all_docs", "", reqHeaders, username, RestTesterDefaultUserPassword) - require.Equal(t, "http://example.org", response.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "http://example.org", response.Header().Get(accessControlAllowOrigin)) if method == http.MethodGet { RequireStatus(t, response, http.StatusOK) } else { @@ -162,18 +168,18 @@ func TestCORSDynamicSet(t *testing.T) { for _, method := range []string{http.MethodGet, http.MethodOptions} { response := rt.SendRequestWithHeaders(method, "/{{.db}}/", "", reqHeaders) if method == http.MethodGet { - require.Equal(t, "*", response.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "*", response.Header().Get(accessControlAllowOrigin)) RequireStatus(t, response, http.StatusUnauthorized) require.Contains(t, response.Body.String(), ErrLoginRequired.Message) } else { // information leak: the options request knows about the database and knows it doesn't match - require.Equal(t, "http://example.org", response.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "http://example.org", response.Header().Get(accessControlAllowOrigin)) RequireStatus(t, response, http.StatusNoContent) } } for _, method := range []string{http.MethodGet, http.MethodOptions} { response := rt.SendRequestWithHeaders(method, "/notadb/", "", reqHeaders) - require.Equal(t, "*", response.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "*", response.Header().Get(accessControlAllowOrigin)) if method == http.MethodGet { RequireStatus(t, response, http.StatusUnauthorized) require.Contains(t, response.Body.String(), ErrLoginRequired.Message) @@ -184,7 +190,7 @@ func TestCORSDynamicSet(t *testing.T) { } for _, method := range []string{http.MethodGet, http.MethodOptions} { response := rt.SendUserRequestWithHeaders(method, "/{{.db}}/", "", reqHeaders, username, RestTesterDefaultUserPassword) - require.Equal(t, "http://example.org", response.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "http://example.org", response.Header().Get(accessControlAllowOrigin)) if method == http.MethodGet { RequireStatus(t, response, http.StatusOK) } else { @@ -204,7 +210,7 @@ func TestCORSNoMux(t *testing.T) { // this method doesn't exist for _, method := range []string{http.MethodGet, http.MethodOptions} { response := rt.SendRequestWithHeaders(method, "/_notanendpoint/", "", reqHeaders) - require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "http://example.com", response.Header().Get(accessControlAllowOrigin)) RequireStatus(t, response, http.StatusNotFound) require.Contains(t, response.Body.String(), "unknown URL") } @@ -212,14 +218,14 @@ func TestCORSNoMux(t *testing.T) { // admin port shouldn't populate CORS for _, method := range []string{http.MethodGet, http.MethodOptions} { response := rt.SendAdminRequestWithHeaders(method, "/_notanendpoint/", "", reqHeaders) - require.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "", response.Header().Get(accessControlAllowOrigin)) RequireStatus(t, response, http.StatusNotFound) require.Contains(t, response.Body.String(), "unknown URL") } // this method doesn't exist for _, method := range []string{http.MethodDelete, http.MethodOptions} { response := rt.SendRequestWithHeaders(method, "/notadb/", "", reqHeaders) - require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "http://example.com", response.Header().Get(accessControlAllowOrigin)) if method == http.MethodDelete { RequireStatus(t, response, http.StatusMethodNotAllowed) } else { @@ -235,7 +241,7 @@ func TestCORSNoMux(t *testing.T) { RequireStatus(t, response, http.StatusMethodNotAllowed) } - require.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "", response.Header().Get(accessControlAllowOrigin)) require.Equal(t, "", response.Header().Get("Access-Control-Max-Age")) require.Equal(t, "", response.Header().Get("Access-Control-Allow-Methods")) } @@ -268,9 +274,9 @@ func TestCORSUserNoAccess(t *testing.T) { response := rt.SendRequestWithHeaders(method, endpoint, "", reqHeaders) if method == http.MethodOptions && endpoint == "/{{.db}}/" { // information leak: the options request knows about the database and knows it doesn't match - assert.Equal(t, "http://couchbase.com", response.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "http://couchbase.com", response.Header().Get(accessControlAllowOrigin)) } else { - assert.Equal(t, "*", response.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "*", response.Header().Get(accessControlAllowOrigin)) } if method == http.MethodGet { @@ -369,7 +375,7 @@ func TestCORSOriginPerDatabase(t *testing.T) { } else { require.Equal(t, http.StatusNoContent, response.Code) } - require.Equal(t, test.headerResponse, response.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, test.headerResponse, response.Header().Get(accessControlAllowOrigin)) if method == http.MethodOptions { if strings.Contains(test.endpoint, "{{.db}}") { require.Equal(t, strconv.Itoa(perDBMaxAge), response.Header().Get("Access-Control-Max-Age")) @@ -405,3 +411,126 @@ func TestCORSValidation(t *testing.T) { RequireStatus(t, resp, http.StatusCreated) } + +func TestCORSBlipSync(t *testing.T) { + rtConfig := &RestTesterConfig{ + PersistentConfig: true, + } + + rt := NewRestTester(t, rtConfig) + defer rt.Close() + + dbConfig := rt.NewDbConfig() + dbConfig.CORS = &auth.CORSConfig{ + Origin: []string{"http://example.com"}, + } + + rt.CreateDatabase("corsdb", dbConfig) + require.NoError(t, rt.SetAdminParty(true)) + testCases := []struct { + name string + origin *string + errorMessage string + }{ + { + name: "CORS matching origin", + origin: base.StringPtr("http://example.com"), + }, + { + name: "CORS non-matching origin", + origin: base.StringPtr("http://example2.com"), + errorMessage: "expected handshake response", + }, + { + name: "CORS empty", + origin: base.StringPtr(""), + }, + } + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + + spec := getDefaultBlipTesterSpec() + spec.origin = test.origin + _, err := createBlipTesterWithSpec(t, spec, rt) + if test.errorMessage == "" { + require.NoError(t, err) + } else { + require.ErrorContains(t, err, "expected handshake response") + } + }) + } + requireBlipHandshakeEmptyCORS(rt) + requireBlipHandshakeMatchingHost(rt) +} + +func TestCORSBlipSyncStar(t *testing.T) { + rtConfig := &RestTesterConfig{ + PersistentConfig: true, + } + + rt := NewRestTester(t, rtConfig) + defer rt.Close() + + dbConfig := rt.NewDbConfig() + dbConfig.CORS = &auth.CORSConfig{ + Origin: []string{"*"}, + } + rt.CreateDatabase("corsdb", dbConfig) + require.NoError(t, rt.SetAdminParty(true)) + urls := []string{"http://example.com", "http://example2.com", "https://example.com"} + for _, url := range urls { + t.Run(url, func(t *testing.T) { + spec := getDefaultBlipTesterSpec() + spec.origin = &url + _, err := createBlipTesterWithSpec(t, spec, rt) + require.NoError(t, err) + }) + } + requireBlipHandshakeEmptyCORS(rt) + requireBlipHandshakeMatchingHost(rt) +} + +// TestCORSBlipNoConfig has no CORS config set on the database, and should fail any CORS checks. +func TestCORSBlipNoConfig(t *testing.T) { + rtConfig := &RestTesterConfig{ + PersistentConfig: true, + } + + rt := NewRestTester(t, rtConfig) + defer rt.Close() + + dbConfig := rt.NewDbConfig() + dbConfig.CORS = &auth.CORSConfig{ + Origin: []string{""}, + } + + rt.CreateDatabase("corsdb", dbConfig) + require.NoError(t, rt.SetAdminParty(true)) + + urls := []string{"http://example.com", "http://example2.com", "https://example.com"} + for _, url := range urls { + t.Run(url, func(t *testing.T) { + spec := getDefaultBlipTesterSpec() + spec.origin = &url + _, err := createBlipTesterWithSpec(t, spec, rt) + require.Error(t, err) + }) + } + requireBlipHandshakeEmptyCORS(rt) + requireBlipHandshakeMatchingHost(rt) +} + +// requireBlipHandshakeEmptyCORS creates a new blip tester with no Origin header +func requireBlipHandshakeEmptyCORS(rt *RestTester) { + spec := getDefaultBlipTesterSpec() + _, err := createBlipTesterWithSpec(rt.TB, spec, rt) + require.NoError(rt.TB, err) +} + +// requireBlipHandshakeMatchingHost creates a new blip tester with an Origin header that matches the host name of the test +func requireBlipHandshakeMatchingHost(rt *RestTester) { + spec := getDefaultBlipTesterSpec() + spec.useHostOrigin = true + _, err := createBlipTesterWithSpec(rt.TB, spec, rt) + require.NoError(rt.TB, err) +} diff --git a/rest/utilities_testing.go b/rest/utilities_testing.go index 400d557d00..221759629d 100644 --- a/rest/utilities_testing.go +++ b/rest/utilities_testing.go @@ -544,7 +544,7 @@ func (rt *RestTester) WaitForPendingChanges() error { func (rt *RestTester) SetAdminParty(partyTime bool) error { ctx := rt.Context() - a := rt.ServerContext().Database(ctx, rt.DatabaseConfig.Name).Authenticator(ctx) + a := rt.GetDatabase().Authenticator(ctx) guest, err := a.GetUser("") if err != nil { return err @@ -1245,6 +1245,12 @@ type BlipTesterSpec struct { // If set, use custom sync function for all collections. syncFn string + + // Represents Origin header values to be used in the blip handshake. + origin *string + + // If true, pass Allow-Header-Origin: to the hostname in the blip handshake. + useHostOrigin bool } // State associated with a BlipTester @@ -1407,8 +1413,12 @@ func createBlipTesterWithSpec(tb testing.TB, spec BlipTesterSpec, rt *RestTester protocols = []string{db.BlipCBMobileReplicationV3} } + origin, err := hostOnlyCORS(bt.restTester.GetDatabase().CORS.Origin) + if err != nil { + return nil, err + } // Make BLIP/Websocket connection - bt.blipContext, err = db.NewSGBlipContextWithProtocols(base.TestCtx(tb), "", protocols...) + bt.blipContext, err = db.NewSGBlipContextWithProtocols(base.TestCtx(tb), "", origin, protocols) if err != nil { return nil, err } @@ -1426,10 +1436,17 @@ func createBlipTesterWithSpec(tb testing.TB, spec BlipTesterSpec, rt *RestTester URL: u.String(), } + config.HTTPHeader = make(http.Header) if len(spec.connectingUsername) > 0 { - config.HTTPHeader = http.Header{ - "Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte(spec.connectingUsername+":"+spec.connectingPassword))}, + config.HTTPHeader.Add("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(spec.connectingUsername+":"+spec.connectingPassword))) + } + if spec.origin != nil { + if spec.useHostOrigin { + require.Fail(tb, "setting both origin and useHostOrigin is not supported") } + config.HTTPHeader.Add("Origin", *spec.origin) + } else if spec.useHostOrigin { + config.HTTPHeader.Add("Origin", "https://"+u.Host) } bt.sender, err = bt.blipContext.DialConfig(&config)