Skip to content

Commit

Permalink
add OPTIONS handler for CORS (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
noahdietz authored Jul 10, 2019
1 parent cd3ef8a commit 22f01c7
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 1 deletion.
17 changes: 16 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import (
"google.golang.org/grpc/status"
)

const fallbackPath = "/$rpc/{service:[.a-zA-Z0-9]+}/{method:[a-zA-Z]+}"

// FallbackServer is a grpc-fallback HTTP server.
type FallbackServer struct {
backend string
Expand Down Expand Up @@ -84,7 +86,10 @@ func (f *FallbackServer) preStart() {

// setup grpc-fallback complient router
r := mux.NewRouter()
r.HandleFunc("/$rpc/{service:[.a-zA-Z0-9]+}/{method:[a-zA-Z]+}", f.handler).Headers("Content-Type", "application/x-protobuf")
r.HandleFunc(fallbackPath, f.options).
Methods(http.MethodOptions)
r.HandleFunc(fallbackPath, f.handler).
Headers("Content-Type", "application/x-protobuf")
f.server.Handler = r
}

Expand Down Expand Up @@ -144,3 +149,13 @@ func (f *FallbackServer) dial() (connection, error) {

return grpc.Dial(f.backend, opts...)
}

// options is a handler for the OPTIONS call that precedes CORS-enabled calls.
func (f *FallbackServer) options(w http.ResponseWriter, r *http.Request) {
w.Header().Add("access-control-allow-credentials", "true")
w.Header().Add("access-control-allow-headers", "*")
w.Header().Add("access-control-allow-methods", http.MethodPost)
w.Header().Add("access-control-allow-origin", "*")
w.Header().Add("access-control-max-age", "3600")
w.WriteHeader(http.StatusOK)
}
59 changes: 59 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,62 @@ func TestFallbackServer_preStart(t *testing.T) {
})
}
}

func TestFallbackServer_options(t *testing.T) {
type fields struct {
backend string
server http.Server
cc connection
}
type args struct {
w http.ResponseWriter
r *http.Request
}

req, _ := http.NewRequest("OPTIONS", "/test", nil)
hdr := make(http.Header)
hdr.Add("access-control-allow-credentials", "true")
hdr.Add("access-control-allow-headers", "*")
hdr.Add("access-control-allow-methods", http.MethodPost)
hdr.Add("access-control-allow-origin", "*")
hdr.Add("access-control-max-age", "3600")

tests := []struct {
name string
fields fields
args args
wantHeader http.Header
}{
{
name: "basic",
args: args{
r: req,
w: &testRespWriter{},
},
fields: fields{
cc: &testConnection{},
},
wantHeader: hdr,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := &FallbackServer{
backend: tt.fields.backend,
server: tt.fields.server,
cc: tt.fields.cc,
}
f.options(tt.args.w, tt.args.r)

resp := tt.args.w.(*testRespWriter)

if resp.code != http.StatusOK {
t.Errorf("handler() %s: got = %d, want = %d", tt.name, resp.code, http.StatusOK)
}

if !reflect.DeepEqual(resp.Header(), tt.wantHeader) {
t.Errorf("handler() %s: got = %s, want = %s", tt.name, resp.Header(), tt.wantHeader)
}
})
}
}

0 comments on commit 22f01c7

Please sign in to comment.