From 84e9e4e3ff329e85bd9d9b42ecb2b8f96e04219c Mon Sep 17 00:00:00 2001 From: Richard Brodie Date: Fri, 5 Apr 2024 14:32:07 +0200 Subject: [PATCH] Fix inability to handle files with a trailing newline Support for json arrays in the stub files was added previously, but it only worked with array files that did not end with a newline or files only containing a single non-array stub definition. Since all files written on linux will typically end in a newline this PR adds a rudimentary input sanitisation step while reading the file which fixes this. Also tidy up some deprecations and add some more helpful error messages. --- gripmock.go | 8 ++------ stub/storage.go | 28 +++++++++++++++++----------- stub/stub.go | 38 +++++++++++++++++++++++++------------- 3 files changed, 44 insertions(+), 30 deletions(-) diff --git a/gripmock.go b/gripmock.go index 2b0a763a..1129865b 100644 --- a/gripmock.go +++ b/gripmock.go @@ -70,20 +70,17 @@ func main() { imports: importDirs, }) - // build the server - //buildServer(output) - // and run run, runerr := runGrpcServer(output) - var term = make(chan os.Signal) + term := make(chan os.Signal) signal.Notify(term, syscall.SIGTERM, syscall.SIGKILL, syscall.SIGINT) select { case err := <-runerr: log.Fatal(err) case <-term: fmt.Println("Stopping gRPC Server") - run.Process.Kill() + _ = run.Process.Kill() } } @@ -151,7 +148,6 @@ func generateProtoc(param protocParam) { if err != nil { log.Fatal("Fail on protoc ", err) } - } // append gopackage in proto files if doesn't have any diff --git a/stub/storage.go b/stub/storage.go index 4faf810d..fe54f473 100644 --- a/stub/storage.go +++ b/stub/storage.go @@ -1,10 +1,11 @@ package stub import ( + "bytes" "encoding/json" "fmt" - "io/ioutil" "log" + "os" "reflect" "regexp" "sync" @@ -60,11 +61,11 @@ func findStub(stub *findStubPayload) (*Output, error) { mx.Lock() defer mx.Unlock() if _, ok := stubStorage[stub.Service]; !ok { - return nil, fmt.Errorf("Can't find stub for Service: %s", stub.Service) + return nil, fmt.Errorf("can't find stub for Service: %s", stub.Service) } if _, ok := stubStorage[stub.Service][stub.Method]; !ok { - return nil, fmt.Errorf("Can't find stub for Service:%s and Method:%s", stub.Service, stub.Method) + return nil, fmt.Errorf("can't find stub for Service:%s and Method:%s", stub.Service, stub.Method) } stubs := stubStorage[stub.Service][stub.Method] @@ -171,8 +172,8 @@ func deepEqual(expect, actual interface{}) bool { } func regexMatch(expect, actual interface{}) bool { - var expectedStr, expectedStringOk = expect.(string) - var actualStr, actualStringOk = actual.(string) + expectedStr, expectedStringOk := expect.(string) + actualStr, actualStringOk := actual.(string) if expectedStringOk && actualStringOk { match, err := regexp.Match(expectedStr, []byte(actualStr)) @@ -198,9 +199,8 @@ func matches(expect, actual map[string]interface{}) bool { } func find(expect, actual interface{}, acc, exactMatch bool, f matchFunc) bool { - // circuit brake - if acc == false { + if !acc { return false } @@ -277,7 +277,7 @@ func readStubFromFile(path string) { } func (sm *stubMapping) readStubFromFile(path string) { - files, err := ioutil.ReadDir(path) + files, err := os.ReadDir(path) if err != nil { log.Printf("Can't read stub from %s. %v\n", path, err) return @@ -289,12 +289,14 @@ func (sm *stubMapping) readStubFromFile(path string) { continue } - byt, err := ioutil.ReadFile(path + "/" + file.Name()) + byt, err := os.ReadFile(path + "/" + file.Name()) if err != nil { log.Printf("Error when reading file %s. %v. skipping...", file.Name(), err) continue } + // most files have a trailing newline so trim that before checking + byt = bytes.TrimSuffix(byt, []byte("\n")) if byt[0] == '[' && byt[len(byt)-1] == ']' { var stubs []*Stub err = json.Unmarshal(byt, &stubs) @@ -303,7 +305,9 @@ func (sm *stubMapping) readStubFromFile(path string) { continue } for _, s := range stubs { - sm.storeStub(s) + if err = sm.storeStub(s); err != nil { + log.Printf("Error when storing Stub from %s. %v. skipping...", file.Name(), err) + } } continue } @@ -315,6 +319,8 @@ func (sm *stubMapping) readStubFromFile(path string) { continue } - sm.storeStub(stub) + if err = sm.storeStub(stub); err != nil { + log.Printf("Error when storing Stub from %s. %v. skipping...", file.Name(), err) + } } } diff --git a/stub/stub.go b/stub/stub.go index 5509e7ea..ab742a34 100644 --- a/stub/stub.go +++ b/stub/stub.go @@ -3,11 +3,13 @@ package stub import ( "encoding/json" "fmt" - "google.golang.org/grpc/codes" - "io/ioutil" + "io" "log" "net/http" - "strings" + + "golang.org/x/text/cases" + "golang.org/x/text/language" + "google.golang.org/grpc/codes" "github.com/go-chi/chi" ) @@ -44,7 +46,9 @@ func RunStubServer(opt Options) { func responseError(err error, w http.ResponseWriter) { w.WriteHeader(500) - w.Write([]byte(err.Error())) + if _, err = w.Write([]byte(err.Error())); err != nil { + log.Println("Error writing response: %w", err) + } } type Stub struct { @@ -67,7 +71,7 @@ type Output struct { } func addStub(w http.ResponseWriter, r *http.Request) { - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { responseError(err, w) return @@ -92,26 +96,30 @@ func addStub(w http.ResponseWriter, r *http.Request) { return } - w.Write([]byte("Success add stub")) + if _, err = w.Write([]byte("Success add stub")); err != nil { + log.Println("Error writing response: %w", err) + } } func listStub(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(allStub()) + if err := json.NewEncoder(w).Encode(allStub()); err != nil { + log.Println("Error writing listStub response: %w", err) + } } func validateStub(stub *Stub) error { if stub.Service == "" { - return fmt.Errorf("Service name can't be empty") + return fmt.Errorf("service name can't be empty") } if stub.Method == "" { - return fmt.Errorf("Method name can't be emtpy") + return fmt.Errorf("method name can't be emtpy") } // due to golang implementation // method name must capital - stub.Method = strings.Title(stub.Method) + stub.Method = cases.Title(language.Und, cases.NoLower).String(stub.Method) switch { case stub.Input.Contains != nil: @@ -148,7 +156,7 @@ func handleFindStub(w http.ResponseWriter, r *http.Request) { // due to golang implementation // method name must capital - stub.Method = strings.Title(stub.Method) + stub.Method = cases.Title(language.Und, cases.NoLower).String(stub.Method) output, err := findStub(stub) if err != nil { @@ -158,10 +166,14 @@ func handleFindStub(w http.ResponseWriter, r *http.Request) { } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(output) + if err := json.NewEncoder(w).Encode(output); err != nil { + log.Println("Error writing handleFindStub response: %w", err) + } } func handleClearStub(w http.ResponseWriter, r *http.Request) { clearStorage() - w.Write([]byte("OK")) + if _, err := w.Write([]byte("OK")); err != nil { + log.Println("Error writing handleClearStub response: %w", err) + } }