From f3bc3b752c0cd9d981921bb828a9af59f1c92f49 Mon Sep 17 00:00:00 2001 From: Gradey Cullins Date: Wed, 15 Nov 2023 19:44:09 -0700 Subject: [PATCH] Add golangci lint config - fix lint warnings - fix tests - fix other --- .golangci.yaml | 3 ++ main.go | 2 +- src/filter.go | 12 +++--- src/handlers.go | 22 +++++++---- src/images.go | 4 +- src/images_test.go | 8 +++- src/remove_test.go | 2 +- src/safe_search.go | 2 +- src/server.go | 4 +- src/server_test.go | 92 ++++++++++++++++++++++++---------------------- 10 files changed, 88 insertions(+), 63 deletions(-) create mode 100644 .golangci.yaml diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..4601e46 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,3 @@ +linters-settings: + govet: + check-shadowing: true \ No newline at end of file diff --git a/main.go b/main.go index 4d0d82b..29e5e2b 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,6 @@ import ( ) func main() { - godotenv.Load() + _ = godotenv.Load() src.InitServer() } diff --git a/src/filter.go b/src/filter.go index a0596a9..41a98aa 100644 --- a/src/filter.go +++ b/src/filter.go @@ -17,7 +17,7 @@ func getCachedSSAs(ctx appContext, uris []string) ([]*ImageAnnotation, []string, return nil, nil, err } - uncachedURIs := make([]string, 0) + uncachedURIs := make([]string, 0, len(uris)) for _, uri := range uris { found := false @@ -59,7 +59,7 @@ func filterImages(ctx appContext, uris []string, licenseID string) ([]*ImageAnno uris = uris[:remainingUsage] } if remainingUsage <= 0 { // return early if trial license is expired - license, err := ctx.licenseStore.ExpireTrial(license) + license, err = ctx.licenseStore.ExpireTrial(license) if err != nil { return res, fmt.Errorf("failed to mark trial license as expired: %s", err.Error()) } else { @@ -78,21 +78,21 @@ func filterImages(ctx appContext, uris []string, licenseID string) ([]*ImageAnno if err = ctx.licenseStore.UpdateLicense(license); err != nil { ctx.logger.Error().Msgf("failed to update license request count: %s", err) } - if err := IncrementSubscriptionMeter(ctx.config.StripeKey, license, int64(len(annotateImageResponses))); err != nil { + if err = IncrementSubscriptionMeter(ctx.config.StripeKey, license, int64(len(annotateImageResponses))); err != nil { ctx.logger.Error().Msgf("failed to update stripe subscription usage: %s", err.Error()) } } var buildSSARes = func(annotations []*pb.AnnotateImageResponse) []*ImageAnnotation { - var res []*ImageAnnotation + var annoRes []*ImageAnnotation for i, annotation := range annotations { if annotation == nil { continue } uri := uris[i] - res = append(res, annotationToSafeSearchResponseRes(uri, annotation)) + annoRes = append(annoRes, annotationToSafeSearchResponseRes(uri, annotation)) } - return res + return annoRes } safeSearchAnnotationsRes := buildSSARes(annotateImageResponses) diff --git a/src/handlers.go b/src/handlers.go index 482a1ec..e80096d 100644 --- a/src/handlers.go +++ b/src/handlers.go @@ -16,15 +16,19 @@ import ( "github.com/stripe/stripe-go/v74/webhook" ) -func health(w http.ResponseWriter, req *http.Request) { +func handleHealth(ctx appContext, w http.ResponseWriter, req *http.Request) (int, error) { w.WriteHeader(200) - w.Write([]byte("All Good ☮️")) + if _, err := w.Write([]byte("All Good ☮️")); err != nil { + return http.StatusInternalServerError, fmt.Errorf("failed to write response: %v", err) + } + + return http.StatusOK, nil } const MAX_IMAGES_PER_REQUEST = 16 func removeDuplicates(logger zerolog.Logger, vals []string) []string { - res := make([]string, 0) + res := make([]string, 0, len(vals)) strMap := make(map[string]bool, 0) for _, v := range vals { @@ -81,7 +85,9 @@ func handleBatchFilter(ctx appContext, w http.ResponseWriter, req *http.Request) } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(res) + if err := json.NewEncoder(w).Encode(res); err != nil { + return http.StatusInternalServerError, err + } return http.StatusOK, nil } @@ -123,7 +129,7 @@ func handleWebhook(ctx appContext, w http.ResponseWriter, req *http.Request) (in if license != nil { ctx.logger.Debug().Msg("existing license found, ensuring IsValid is true") license.IsValid = true - if err := ctx.licenseStore.UpdateLicense(license); err != nil { + if err = ctx.licenseStore.UpdateLicense(license); err != nil { return http.StatusInternalServerError, errors.New("") } // TODO: email person to remind them their subscription is renewed. @@ -152,7 +158,7 @@ func handleWebhook(ctx appContext, w http.ResponseWriter, req *http.Request) (in metadata := map[string]string{ "license": licenseID, } - if _, err := customer.Update(session.Customer.ID, &stripe.CustomerParams{ + if _, err = customer.Update(session.Customer.ID, &stripe.CustomerParams{ Params: stripe.Params{Metadata: metadata}, }); err != nil { return http.StatusInternalServerError, fmt.Errorf("error adding license to customer metadata: %v", err) @@ -219,7 +225,9 @@ func handleGetLicense(ctx appContext, w http.ResponseWriter, req *http.Request) // return // } - json.NewEncoder(w).Encode(license) + if err := json.NewEncoder(w).Encode(license); err != nil { + return http.StatusInternalServerError, err + } return http.StatusOK, nil } diff --git a/src/images.go b/src/images.go index d2354a6..a2118ee 100644 --- a/src/images.go +++ b/src/images.go @@ -42,7 +42,9 @@ func FindAnnotationsByURI(conn pg.DB, uris []string) ([]ImageAnnotation, error) return nil, fmt.Errorf("imgURIList cannot be empty") } - conn.Model(&annotations).Where("uri IN (?)", pg.In(uris)).Select() + if err := conn.Model(&annotations).Where("uri IN (?)", pg.In(uris)).Select(); err != nil { + return nil, err + } return annotations, nil } diff --git a/src/images_test.go b/src/images_test.go index 8ec6487..30b8f20 100644 --- a/src/images_test.go +++ b/src/images_test.go @@ -2,6 +2,7 @@ package src import ( "database/sql" + "fmt" "os" "testing" "time" @@ -18,6 +19,7 @@ func getTestCtx() (appContext, error) { return ctx, err } config.DBName = "purity_test" + config.DBHost = "localhost" conn, err := InitDB(config) if err != nil { return ctx, err @@ -31,10 +33,14 @@ func getTestCtx() (appContext, error) { } func TestMain(m *testing.M) { - godotenv.Load() + if err := godotenv.Load("../.env"); err != nil { + fmt.Println(err) + } + m.Run() } func TestImages(t *testing.T) { + // t.FailNow() ctx, err := getTestCtx() if err != nil { t.Fatal(err) diff --git a/src/remove_test.go b/src/remove_test.go index eded079..913b700 100644 --- a/src/remove_test.go +++ b/src/remove_test.go @@ -25,7 +25,7 @@ func TestStringSliceRemove(t *testing.T) { t.Fatalf("result array should be empty") } - if _, err := StringSliceRemove(arr, 100); err == nil { + if _, err = StringSliceRemove(arr, 100); err == nil { t.Fatal("an error should have been thrown") } diff --git a/src/safe_search.go b/src/safe_search.go index c1a4337..cc59476 100644 --- a/src/safe_search.go +++ b/src/safe_search.go @@ -17,7 +17,7 @@ func batchAnnotateURIs(uris []string) (*pb.BatchAnnotateImagesResponse, error) { } defer client.Close() - requests := make([]*pb.AnnotateImageRequest, 0) + requests := make([]*pb.AnnotateImageRequest, 0, len(uris)) for _, uri := range uris { requests = append(requests, &pb.AnnotateImageRequest{ Image: vision.NewImageFromURI(uri), diff --git a/src/server.go b/src/server.go index a315b45..c8a330f 100644 --- a/src/server.go +++ b/src/server.go @@ -44,7 +44,7 @@ func (ah appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Updated to pass ah.appContext as a parameter to our handler type. status, err := ah.H(ah.appContext, w, r) if err != nil { - log.Printf("HTTP %d: %q", status, err) + ah.appContext.logger.Printf("HTTP %d: %q", status, err) switch status { case http.StatusNotFound: http.NotFound(w, r) @@ -96,7 +96,7 @@ func InitServer() { r.Use(addCorsHeaders) r.Handle("/", http.FileServer(http.Dir("./"))).Methods("GET") - r.HandleFunc("/health", health).Methods("GET", "OPTIONS") + r.Handle("/health", &appHandler{ctx, handleHealth}).Methods("GET", "OPTIONS") r.Handle("/license/{id}", &appHandler{ctx, handleGetLicense}).Methods("GET", "OPTIONS") r.Handle("/webhook", &appHandler{ctx, handleWebhook}).Methods("POST") // r.HandleFunc("/trial-register", handleTrialRegister).Methods("POST", "OPTIONS") diff --git a/src/server_test.go b/src/server_test.go index 37aff9c..8b05b64 100644 --- a/src/server_test.go +++ b/src/server_test.go @@ -33,59 +33,62 @@ type FilterTest struct { Expect FilterTestExpect } -func testHealthNoBody(t *testing.T) { - req, err := http.NewRequest("GET", "/health", nil) - if err != nil { - t.Error("Failed to create test HTTP request") - } - - rr := httptest.NewRecorder() - handler := http.HandlerFunc(health) - - handler.ServeHTTP(rr, req) - - if rr.Code != 200 { - t.Errorf("Health endpoint expected response 200 but got %d", rr.Code) - } -} - type junkData struct { Name string Color int } -// The health endpoint given junk POST data should still simply return a 200 code. -func testHealthJunkBody(t *testing.T) { - someData := junkData{ - Name: "pil", - Color: 221, - } - b, err := json.Marshal(someData) - if err != nil { - t.Error("Failed to marshal request body struct") - } - r := bytes.NewReader(b) - req, err := http.NewRequest("POST", "/health", r) +func TestHealthEndpoint(t *testing.T) { + ctx, err := getTestCtx() if err != nil { - t.Error("Failed to create test HTTP request") + t.Error(err) } - rr := httptest.NewRecorder() - handler := http.HandlerFunc(health) + t.Run("returns 200 if there is no POST body", func(t *testing.T) { + req, err := http.NewRequest("GET", "/health", nil) + if err != nil { + t.Error("Failed to create test HTTP request") + } - handler.ServeHTTP(rr, req) + rr := httptest.NewRecorder() - if rr.Code != 200 { - t.Errorf("Health endpoint expected response 200 but got %d", rr.Code) - } -} + code, err := handleHealth(ctx, rr, req) + if err != nil { + t.Error(err) + } + if code != 200 { + t.Errorf("Health endpoint expected response 200 but got %d", rr.Code) + } + }) + + t.Run("returns 200 if there is a junk POST body", func(t *testing.T) { + someData := junkData{ + Name: "pil", + Color: 221, + } + b, err := json.Marshal(someData) + if err != nil { + t.Error("Failed to marshal request body struct") + } + r := bytes.NewReader(b) + req, err := http.NewRequest("POST", "/health", r) + if err != nil { + t.Error("Failed to create test HTTP request") + } + + rr := httptest.NewRecorder() -func TestHealthHandler(t *testing.T) { - testHealthNoBody(t) - testHealthJunkBody(t) + code, err := handleHealth(ctx, rr, req) + if err != nil { + t.Error(err) + } + if code != 200 { + t.Errorf("Health endpoint expected response 200 but got %d", rr.Code) + } + }) } -func TestServer(t *testing.T) { +func TestFilterEndpoint(t *testing.T) { ctx, err := getTestCtx() if err != nil { t.Fatal(err) @@ -93,8 +96,11 @@ func TestServer(t *testing.T) { ctx.logger = zerolog.Logger{} t.Cleanup(func() { - ctx.db.Model(&ImageAnnotation{}).Where("1=1").Delete() - _, err := ctx.db.Model(&License{}).Where("1=1").Delete() + _, err := ctx.db.Model(&ImageAnnotation{}).Where("1=1").Delete() + if err != nil { + fmt.Println("error: ", err) + } + _, err = ctx.db.Model(&License{}).Where("1=1").Delete() if err != nil { fmt.Println("error: ", err) } @@ -231,7 +237,7 @@ func TestServer(t *testing.T) { t.Errorf("expected status %d but got %d", test.Expect.Code, rec.Code) } var annotations []*ImageAnnotation - json.Unmarshal(rec.Body.Bytes(), &annotations) + _ = json.Unmarshal(rec.Body.Bytes(), &annotations) if len(annotations) != len(test.Expect.Res) { t.Errorf("expected %d annotation results but got %d", len(test.Expect.Res), len(annotations))