diff --git a/controllers/url.go b/controllers/url.go index 0dd0962..3815732 100644 --- a/controllers/url.go +++ b/controllers/url.go @@ -31,15 +31,21 @@ func CreateTinyURL(ctx *gin.Context, db *bun.DB) { return } + userID, exists := ctx.Get("userID") + if !exists { + ctx.JSON(http.StatusUnauthorized, dtos.URLCreationResponse{ + Message: "User not authenticated", + }) + return + } + var existingOriginalURL models.Tinyurl if err := db.NewSelect().Model(&existingOriginalURL). - Where("original_url = ?", body.OriginalUrl). - Where("user_id = ?", body.UserID). - Where("is_deleted = ?", false). + Where("original_url = ? AND user_id = ? AND is_deleted = ?", body.OriginalUrl, userID, false). Scan(ctx); err == nil { ctx.JSON(http.StatusOK, dtos.URLCreationResponse{ - Message: "Shortened URL already exists", - ShortURL: existingOriginalURL.ShortUrl, + Message: "Shortened URL already exists", + ShortURL: existingOriginalURL.ShortUrl, CreatedAt: existingOriginalURL.CreatedAt, }) return @@ -76,8 +82,7 @@ func CreateTinyURL(ctx *gin.Context, db *bun.DB) { count, err := db.NewSelect(). Model(&models.Tinyurl{}). - Where("user_id = ?", body.UserID). - Where("is_deleted = ?", false). + Where("user_id = ? AND is_deleted = ?", userID, false). Count(ctx) if err != nil { @@ -94,16 +99,21 @@ func CreateTinyURL(ctx *gin.Context, db *bun.DB) { return } - body.CreatedAt = time.Now().UTC() + newTinyURL := models.Tinyurl{ + OriginalUrl: body.OriginalUrl, + ShortUrl: body.ShortUrl, + UserID: userID.(int64), + CreatedAt: time.Now().UTC(), + } - if _, err := db.NewInsert().Model(&body).Exec(ctx); err != nil { + if _, err := db.NewInsert().Model(&newTinyURL).Exec(ctx); err != nil { ctx.JSON(http.StatusInternalServerError, dtos.URLCreationResponse{ Message: "Failed to create tiny URL", }) return } - if err := utils.IncrementURLCount(body.UserID, db, ctx); err != nil { + if err := utils.IncrementURLCount(userID.(int64), db, ctx); err != nil { ctx.JSON(http.StatusInternalServerError, dtos.URLCreationResponse{ Message: "Failed to increment URL count: " + err.Error(), }) @@ -112,8 +122,7 @@ func CreateTinyURL(ctx *gin.Context, db *bun.DB) { updatedCount, err := db.NewSelect(). Model(&models.Tinyurl{}). - Where("user_id = ?", body.UserID). - Where("is_deleted = ?", false). + Where("user_id = ? AND is_deleted = ?", userID, false). Count(ctx) if err != nil { @@ -124,9 +133,9 @@ func CreateTinyURL(ctx *gin.Context, db *bun.DB) { } ctx.JSON(http.StatusOK, dtos.URLCreationResponse{ - Message: "Tiny URL created successfully", - ShortURL: body.ShortUrl, - URLCount: updatedCount, + Message: "Tiny URL created successfully", + ShortURL: newTinyURL.ShortUrl, + URLCount: updatedCount, }) } @@ -215,51 +224,55 @@ func GetAllURLs(ctx *gin.Context, db *bun.DB) { } func DeleteURL(ctx *gin.Context, db *bun.DB) { - id, _ := ctx.Params.Get("id") - - var body struct { - UserID int64 `json:"user_id"` - } - if err := ctx.BindJSON(&body); err != nil { - ctx.JSON(http.StatusBadRequest, dtos.UserURLsResponse{ - Message: "Invalid Request.", - }) - return - } - - _, err := db.NewUpdate().Model(&models.Tinyurl{}).Set("is_deleted=?", true).Set("deleted_at=?", time.Now().UTC()).Where("id = ?", id).Exec(ctx) - if err != nil { - ctx.JSON(http.StatusNotFound, dtos.UserURLsResponse{ - Message: "No URLs found", - }) - return - } - - - if err := utils.DecrementURLCount(body.UserID, db, ctx); err != nil { - ctx.JSON(http.StatusInternalServerError, dtos.URLCreationResponse{ - Message: "Failed to decrement URL count: " + err.Error(), - }) - return - } + id, _ := ctx.Params.Get("id") + + userID, exists := ctx.Get("userID") + if !exists { + ctx.JSON(http.StatusUnauthorized, dtos.UserURLsResponse{ + Message: "User not authenticated", + }) + return + } + + _, err := db.NewUpdate(). + Model(&models.Tinyurl{}). + Set("is_deleted=?", true). + Set("deleted_at=?", time.Now().UTC()). + Where("id = ?", id). + Where("user_id = ?", userID). + Exec(ctx) + + if err != nil { + ctx.JSON(http.StatusNotFound, dtos.UserURLsResponse{ + Message: "No URLs found", + }) + return + } + + if err := utils.DecrementURLCount(userID.(int64), db, ctx); err != nil { + ctx.JSON(http.StatusInternalServerError, dtos.URLCreationResponse{ + Message: "Failed to decrement URL count: " + err.Error(), + }) + return + } updatedCount, err := db.NewSelect(). - Model(&models.Tinyurl{}). - Where("user_id = ?", body.UserID). - Where("is_deleted = ?", false). - Count(ctx) - - if err != nil { - ctx.JSON(http.StatusInternalServerError, dtos.URLCreationResponse{ - Message: "Failed to fetch updated URL count", - }) - return - } - - ctx.JSON(http.StatusOK, dtos.URLDeleteResponse{ - Message: "URL deleted", - URLCount: updatedCount, - }) + Model(&models.Tinyurl{}). + Where("user_id = ?", userID). + Where("is_deleted = ?", false). + Count(ctx) + + if err != nil { + ctx.JSON(http.StatusInternalServerError, dtos.URLCreationResponse{ + Message: "Failed to fetch updated URL count", + }) + return + } + + ctx.JSON(http.StatusOK, dtos.URLDeleteResponse{ + Message: "URL deleted", + URLCount: updatedCount, + }) } func GetURLDetails(ctx *gin.Context, db *bun.DB) { diff --git a/middlewares/auth.go b/middlewares/auth.go index e632476..1926924 100644 --- a/middlewares/auth.go +++ b/middlewares/auth.go @@ -18,7 +18,7 @@ func AuthMiddleware() gin.HandlerFunc { token := tokenCookie.Value - email, err := utils.VerifyToken(token) + claims, err := utils.VerifyToken(token) if err != nil { ctx.JSON(http.StatusUnauthorized, gin.H{"message": "Unauthorized"}) @@ -26,7 +26,16 @@ func AuthMiddleware() gin.HandlerFunc { return } - ctx.Set("user", email) + userID, ok := claims["userID"].(float64) + + if !ok { + ctx.JSON(http.StatusUnauthorized, gin.H{"message": "Invalid UserID format"}) + ctx.Abort() + return + } + + ctx.Set("user", claims["email"]) + ctx.Set("userID", int64(userID)) ctx.Next() } -} +} \ No newline at end of file diff --git a/routes/url.go b/routes/url.go index 91be58c..aaeb872 100644 --- a/routes/url.go +++ b/routes/url.go @@ -29,7 +29,7 @@ func TinyURLRoutes(rg *gin.RouterGroup, db *bun.DB) { redirect.GET("/:shortURL", func(ctx *gin.Context) { controller.RedirectShortURL(ctx, db) }) - urls.DELETE("/:id", func(ctx *gin.Context) { + urls.DELETE("/:id", middleware.AuthMiddleware(), func(ctx *gin.Context) { controller.DeleteURL(ctx, db) - }) + }) } diff --git a/tests/integration/url_test.go b/tests/integration/url_test.go index acefb50..3b6eda8 100644 --- a/tests/integration/url_test.go +++ b/tests/integration/url_test.go @@ -15,6 +15,12 @@ import ( func (suite *AppTestSuite) TestCreateTinyURLSuccess() { // Setup the router and route for creating a tiny URL router := gin.Default() + + router.Use(func(ctx *gin.Context) { + ctx.Set("userID", int64(1)) + ctx.Next() + }) + router.POST("/v1/tinyurl", func(ctx *gin.Context) { controller.CreateTinyURL(ctx, suite.db) }) @@ -81,6 +87,12 @@ func (suite *AppTestSuite) TestCreateTinyURLEmptyOriginalURL() { // TestCreateTinyURLCustomShortURL tests the creation of a tiny URL with a custom short URL and expects a successful response. func (suite *AppTestSuite) TestCreateTinyURLCustomShortURL() { router := gin.Default() + + router.Use(func(ctx *gin.Context) { + ctx.Set("userID", int64(1)) + ctx.Next() + }) + router.POST("/v1/tinyurl", func(ctx *gin.Context) { controller.CreateTinyURL(ctx, suite.db) }) @@ -88,7 +100,6 @@ func (suite *AppTestSuite) TestCreateTinyURLCustomShortURL() { requestBody := map[string]interface{}{ "OriginalUrl": "https://example.com", "ShortUrl": "short", - "UserId": 1, } requestJSON, _ := json.Marshal(requestBody) req, _ := http.NewRequest("POST", "/v1/tinyurl", bytes.NewBuffer(requestJSON)) @@ -102,6 +113,12 @@ func (suite *AppTestSuite) TestCreateTinyURLCustomShortURL() { func (suite *AppTestSuite) TestCreateTinyURLCustomShortURLExists() { router := gin.Default() + + router.Use(func(ctx *gin.Context) { + ctx.Set("userID", int64(1)) + ctx.Next() + }) + router.POST("/v1/tinyurl", func(ctx *gin.Context) { controller.CreateTinyURL(ctx, suite.db) }) @@ -109,7 +126,6 @@ func (suite *AppTestSuite) TestCreateTinyURLCustomShortURLExists() { requestBody := map[string]interface{}{ "OriginalUrl": "https://rds.com", "ShortUrl": "37fff", - "UserId": 1, } requestJSON, _ := json.Marshal(requestBody) req, _ := http.NewRequest("POST", "/v1/tinyurl", bytes.NewBuffer(requestJSON)) @@ -123,6 +139,12 @@ func (suite *AppTestSuite) TestCreateTinyURLCustomShortURLExists() { func (suite *AppTestSuite) TestCreateTinyURLExistingOriginalURL() { router := gin.Default() + + router.Use(func(ctx *gin.Context) { + ctx.Set("userID", int64(1)) + ctx.Next() + }) + router.POST("/v1/tinyurl", func(ctx *gin.Context) { controller.CreateTinyURL(ctx, suite.db) }) @@ -131,7 +153,6 @@ func (suite *AppTestSuite) TestCreateTinyURLExistingOriginalURL() { requestBody := map[string]interface{}{ "OriginalUrl": existingOriginalURL, - "UserId": 1, } requestJSON, _ := json.Marshal(requestBody) req, _ := http.NewRequest("POST", "/v1/tinyurl", bytes.NewBuffer(requestJSON)) diff --git a/tests/unit/jwt_test.go b/tests/unit/jwt_test.go index 1992993..1ce8806 100644 --- a/tests/unit/jwt_test.go +++ b/tests/unit/jwt_test.go @@ -3,9 +3,12 @@ package unit import ( "os" "testing" + "time" + "github.com/Real-Dev-Squad/tiny-site-backend/config" "github.com/Real-Dev-Squad/tiny-site-backend/models" "github.com/Real-Dev-Squad/tiny-site-backend/utils" + "github.com/golang-jwt/jwt/v5" ) func TestMain(m *testing.M) { @@ -19,6 +22,7 @@ func TestMain(m *testing.M) { func TestGenerateJWT(t *testing.T) { dummyUser := &models.User{ Email: "test@gmail.com", + ID: 123, } token, err := utils.GenerateToken(dummyUser) @@ -36,32 +40,51 @@ func TestVerifyJWT(t *testing.T) { t.Run("ValidToken", func(t *testing.T) { dummyUser := &models.User{ Email: "test@gmail.com", + ID: 123, } validToken, generateTokenError := utils.GenerateToken(dummyUser) - if generateTokenError != nil { t.Fatalf("Error: %v", generateTokenError) } - email, validTokenError := utils.VerifyToken(validToken) - + claims, validTokenError := utils.VerifyToken(validToken) if validTokenError != nil { t.Fatalf("Error: %v", validTokenError) } - if email != dummyUser.Email { - t.Fatalf("Expected %v but got %v", dummyUser.Email, email) + if claims["email"] != dummyUser.Email { + t.Fatalf("Expected email %v but got %v", dummyUser.Email, claims["email"]) + } + + if claims["userID"] != float64(dummyUser.ID) { + t.Fatalf("Expected userID %v but got %v", dummyUser.ID, claims["userID"]) } }) t.Run("ExpiredToken", func(t *testing.T) { - expiredToken := "eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9.eyJlbWFpbCI6InRlc3RAZ21haWwuY29tIiwiZXhwIjoiMjAyMy0xMC0wMVQxOTo1Njo0OS4zOTc5NzEyWiIsImlzcyI6Indpc2VlLWJhY2tlbmQifQ.h11JtaPg-ITKR8UXTyz_Q7pJU_3gYyXwIkqX7lI1UK2nVkvxQvkyN23-u3wj8fV5mNIvp-ePTOp-7odsPcGC_g" + expiredToken := jwt.NewWithClaims(jwt.SigningMethodHS512, jwt.MapClaims{ + "iss": config.JwtIssuer, + "exp": time.Now().Add(-time.Hour).Unix(), + "email": "test@gmail.com", + "userID": 123, + }) + + key := []byte(config.JwtSecret) + expiredTokenString, _ := expiredToken.SignedString(key) + + _, expiredTokenError := utils.VerifyToken(expiredTokenString) + if expiredTokenError != utils.ErrTokenExpired { + t.Fatalf("Expected error %v but got %v", utils.ErrTokenExpired, expiredTokenError) + } + }) - _, expiredTokenError := utils.VerifyToken(expiredToken) + t.Run("InvalidToken", func(t *testing.T) { + invalidToken := "invalid.token.here" - if expiredTokenError == nil { - t.Fatalf("Expected error but got nil") + _, invalidTokenError := utils.VerifyToken(invalidToken) + if invalidTokenError == nil { + t.Fatalf("Expected an error but got nil") } }) } @@ -72,21 +95,24 @@ func TestVerifyJWTForOneYear(t *testing.T) { dummyUser := &models.User{ Email: "test@gmail.com", + ID: 123, } validToken, generateTokenError := utils.GenerateToken(dummyUser) - if generateTokenError != nil { t.Fatalf("Error: %v", generateTokenError) } - email, validTokenError := utils.VerifyToken(validToken) - + claims, validTokenError := utils.VerifyToken(validToken) if validTokenError != nil { t.Fatalf("Error: %v", validTokenError) } - if email != dummyUser.Email { - t.Fatalf("Expected %v but got %v", dummyUser.Email, email) + if claims["email"] != dummyUser.Email { + t.Fatalf("Expected email %v but got %v", dummyUser.Email, claims["email"]) + } + + if claims["userID"] != float64(dummyUser.ID) { + t.Fatalf("Expected userID %v but got %v", dummyUser.ID, claims["userID"]) } } diff --git a/utils/jwt.go b/utils/jwt.go index 46cd732..aede20a 100644 --- a/utils/jwt.go +++ b/utils/jwt.go @@ -9,58 +9,53 @@ import ( "github.com/golang-jwt/jwt/v5" ) -/* - * GenerateToken generates a JWT token for the user - */ +var ( + ErrUnexpectedSigningMethod = errors.New("unexpected signing method") + ErrInvalidToken = errors.New("invalid token") + ErrTokenExpired = errors.New("token has expired") +) + func GenerateToken(user *models.User) (string, error) { - issuer := config.JwtIssuer key := []byte(config.JwtSecret) + expiryTime := time.Now().Add(time.Duration(config.JwtValidity) * time.Hour).UTC() + + claims := jwt.MapClaims{ + "iss": config.JwtIssuer, + "exp": expiryTime.Unix(), + "iat": time.Now().UTC().Unix(), + "email": user.Email, + "userID": user.ID, + } - tokenValidityInHours := config.JwtValidity - - tokenExpiryTime := time.Now().Add(time.Duration(tokenValidityInHours) * time.Hour).UTC().Format(time.RFC3339) - - t := jwt.NewWithClaims(jwt.SigningMethodHS512, jwt.MapClaims{ - "iss": issuer, - "exp": tokenExpiryTime, - "email": user.Email, - }) - - token, error := t.SignedString(key) - - return token, error + token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims) + return token.SignedString(key) } -/* - * VerifyToken verifies the token and returns the email of the user - */ -func VerifyToken(tokenString string) (string, error) { - var claims jwt.MapClaims = nil +func VerifyToken(tokenString string) (jwt.MapClaims, error) { + key := []byte(config.JwtSecret) - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + // Parsing the token - if token.Method.Alg() != jwt.SigningMethodHS512.Alg() { - return nil, jwt.ErrSignatureInvalid - } + token, err := jwt.ParseWithClaims(tokenString, jwt.MapClaims{}, func(token *jwt.Token) (interface{}, error) { - return []byte(config.JwtSecret), nil + //validatint the algo + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, ErrUnexpectedSigningMethod + } + return key, nil }) - if c, ok := token.Claims.(jwt.MapClaims); !ok && !token.Valid { - return "", err - } else { - claims = c - } - - expiryTime, err := time.Parse(time.RFC3339, claims["exp"].(string)) - if err != nil { - return "", err + if errors.Is(err, jwt.ErrTokenExpired) || errors.Is(err, jwt.ErrTokenNotValidYet) { + return nil, ErrTokenExpired + } + return nil, err } - if time.Now().UTC().After(expiryTime) { - return "", errors.New("token has expired") + // Validating the token and casting the claims :P + if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { + return claims, nil } - return claims["email"].(string), nil + return nil, ErrInvalidToken }