diff --git a/main.go b/main.go index e7cace9..4c1cfce 100644 --- a/main.go +++ b/main.go @@ -112,10 +112,8 @@ func init() { panic(err) } - _state.OnGatewayPortChange(func(s string) error { - config.Set(common.ConfigKeyGatewayPort, _state.GetGatewayPort()) - config.Set(common.ConfigKeyRuntimePath, _state.GetRuntimePath()) - + _state.OnGatewayPortChange(func(port string) error { + config.Set(common.ConfigKeyGatewayPort, port) return config.WriteConfig() }) } diff --git a/route/management_route_test.go b/route/management_route_test.go index df9c944..0698e52 100644 --- a/route/management_route_test.go +++ b/route/management_route_test.go @@ -3,6 +3,7 @@ package route import ( "bytes" "encoding/json" + "errors" "net/http" "net/http/httptest" "os" @@ -134,3 +135,75 @@ func TestChangePort(t *testing.T) { assert.NilError(t, err) assert.Equal(t, expectedPort, result.Data) } + +func TestChangePortNegative(t *testing.T) { + defer setup(t)(t) + + expectedPort := "123" + + // set + request := &model.ChangePortRequest{ + Port: expectedPort, + } + + body, err := json.Marshal(request) + assert.NilError(t, err) + + req, _ := http.NewRequest(http.MethodPut, "/v1/gateway/port", bytes.NewReader(body)) + req.RemoteAddr = "127.0.0.1:0" + + w := httptest.NewRecorder() + _router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, expectedPort, "123") + + // get + req, _ = http.NewRequest(http.MethodGet, "/v1/gateway/port", nil) + + w = httptest.NewRecorder() + _router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var result *model.Result + decoder := json.NewDecoder(w.Body) + + err = decoder.Decode(&result) + assert.NilError(t, err) + assert.Equal(t, expectedPort, result.Data) + + // emulate error + _state.OnGatewayPortChange(func(_ string) error { + return errors.New("error") + }) + + // set + request.Port = "456" + + body, err = json.Marshal(request) + assert.NilError(t, err) + + req, _ = http.NewRequest(http.MethodPut, "/v1/gateway/port", bytes.NewReader(body)) + req.RemoteAddr = "127.0.0.1:0" + + w = httptest.NewRecorder() + _router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusInternalServerError, w.Code) + assert.Equal(t, expectedPort, "123") + + // get + req, _ = http.NewRequest(http.MethodGet, "/v1/gateway/port", nil) + + w = httptest.NewRecorder() + _router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + decoder = json.NewDecoder(w.Body) + + err = decoder.Decode(&result) + assert.NilError(t, err) + assert.Equal(t, expectedPort, result.Data) +} diff --git a/service/state.go b/service/state.go index 7f09dce..78b7bb9 100644 --- a/service/state.go +++ b/service/state.go @@ -18,9 +18,13 @@ func NewState() *State { } } -func (c *State) SetGatewayPort(port string) error { - c.gatewayPort = port - return c.notifyOnGatewayPortChange() +func (c *State) SetGatewayPort(port string) (err error) { + defer func() { + if err == nil { + c.gatewayPort = port + } + }() + return c.notifyOnGatewayPortChange(port) } func (c *State) GetGatewayPort() string { @@ -32,9 +36,9 @@ func (c *State) OnGatewayPortChange(f func(string) error) { c.onGatewayPortChange = append(c.onGatewayPortChange, f) } -func (c *State) notifyOnGatewayPortChange() error { +func (c *State) notifyOnGatewayPortChange(port string) error { for i := len(c.onGatewayPortChange) - 1; i >= 0; i-- { - if err := c.onGatewayPortChange[i](c.gatewayPort); err != nil { + if err := c.onGatewayPortChange[i](port); err != nil { return err } }