Skip to content

Commit

Permalink
Highlights removed when a user is removed
Browse files Browse the repository at this point in the history
  • Loading branch information
svera authored Nov 4, 2023
1 parent 75f9090 commit df7857d
Show file tree
Hide file tree
Showing 12 changed files with 132 additions and 41 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ bin
tmp
*.zip
*.exe
.air.toml
21 changes: 9 additions & 12 deletions internal/controller/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@ import (
"github.com/svera/coreander/v3/internal/infrastructure"
"github.com/svera/coreander/v3/internal/model"
"golang.org/x/text/message"
"gorm.io/gorm"
)

type authRepository interface {
FindByEmail(email string) (model.User, error)
FindByRecoveryUuid(recoveryUuid string) (model.User, error)
Update(user model.User) error
FindByEmail(email string) (*model.User, error)
FindByRecoveryUuid(recoveryUuid string) (*model.User, error)
Update(user *model.User) error
}

type recoveryEmail interface {
Expand Down Expand Up @@ -91,7 +90,7 @@ func (a *Auth) Login(c *fiber.Ctx) error {
// Signs in a user and gives them a JWT.
func (a *Auth) SignIn(c *fiber.Ctx) error {
var (
user model.User
user *model.User
err error
)

Expand All @@ -101,7 +100,7 @@ func (a *Auth) SignIn(c *fiber.Ctx) error {
return fiber.ErrInternalServerError
}

if user.Password != model.Hash(c.FormValue("password")) {
if user == nil || user.Password != model.Hash(c.FormValue("password")) {
return c.Status(fiber.StatusUnauthorized).Render("auth/login", fiber.Map{
"Title": "Login",
"Error": "Wrong email or password",
Expand All @@ -111,9 +110,7 @@ func (a *Auth) SignIn(c *fiber.Ctx) error {
// Send back JWT as a cookie.
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"userdata": model.User{
Model: gorm.Model{
ID: user.ID,
},
ID: user.ID,
Name: user.Name,
Email: user.Email,
Role: user.Role,
Expand Down Expand Up @@ -250,13 +247,13 @@ func (a *Auth) UpdatePassword(c *fiber.Ctx) error {
return c.Redirect(fmt.Sprintf("/%s/login", c.Params("lang")))
}

func (a *Auth) validateRecoveryAccess(c *fiber.Ctx, recoveryUuid string) (model.User, error) {
func (a *Auth) validateRecoveryAccess(c *fiber.Ctx, recoveryUuid string) (*model.User, error) {
if _, ok := a.sender.(*infrastructure.NoEmail); ok {
return model.User{}, fiber.ErrNotFound
return &model.User{}, fiber.ErrNotFound
}

if recoveryUuid == "" {
return model.User{}, fiber.ErrBadRequest
return &model.User{}, fiber.ErrBadRequest
}
user, err := a.repository.FindByRecoveryUuid(recoveryUuid)
if err != nil {
Expand Down
6 changes: 5 additions & 1 deletion internal/controller/highlights.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ func (h *Highlights) Highlights(c *fiber.Ctx) error {

user, err := h.usrRepository.FindByUuid(c.Params("uuid"))
if err != nil {
return fiber.ErrBadRequest
return fiber.ErrInternalServerError
}

if user == nil {
return fiber.ErrNotFound
}

highlights, err := h.hlRepository.Highlights(int(user.ID), page, model.ResultsPerPage)
Expand Down
18 changes: 9 additions & 9 deletions internal/controller/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ import (
type usersRepository interface {
List(page int, resultsPerPage int) ([]model.User, error)
Total() int64
FindByUuid(uuid string) (model.User, error)
Create(user model.User) error
Update(user model.User) error
FindByEmail(email string) (model.User, error)
FindByUuid(uuid string) (*model.User, error)
Create(user *model.User) error
Update(user *model.User) error
FindByEmail(email string) (*model.User, error)
Admins() int64
Delete(uuid string) error
}
Expand Down Expand Up @@ -106,7 +106,7 @@ func (u *Users) Create(c *fiber.Ctx) error {
user.WordsPerMinute, _ = strconv.ParseFloat(c.FormValue("words-per-minute"), 64)

errs := user.Validate(u.minPasswordLength)
if exist, _ := u.repository.FindByEmail(c.FormValue("email")); exist.Email != "" {
if exist, _ := u.repository.FindByEmail(c.FormValue("email")); exist != nil {
errs["email"] = "A user with this email address already exist"
}

Expand All @@ -120,7 +120,7 @@ func (u *Users) Create(c *fiber.Ctx) error {
}

user.Password = model.Hash(user.Password)
if err := u.repository.Create(user); err != nil {
if err := u.repository.Create(&user); err != nil {
return fiber.ErrInternalServerError
}

Expand Down Expand Up @@ -163,7 +163,7 @@ func (u *Users) Update(c *fiber.Ctx) error {
}

if c.FormValue("password-tab") == "true" {
return u.updatePassword(c, session, user)
return u.updatePassword(c, session, *user)
}

user.Name = c.FormValue("name")
Expand Down Expand Up @@ -225,7 +225,7 @@ func (u *Users) updatePassword(c *fiber.Ctx, session, user model.User) error {
}

user.Password = model.Hash(user.Password)
if err := u.repository.Update(user); err != nil {
if err := u.repository.Update(&user); err != nil {
return fiber.ErrInternalServerError
}

Expand All @@ -244,7 +244,7 @@ func (u *Users) updatePassword(c *fiber.Ctx, session, user model.User) error {
func (u *Users) Delete(c *fiber.Ctx) error {
session := jwtclaimsreader.SessionData(c)

if session.Role != model.RoleAdmin && session.Uuid != c.Params("uuid") {
if session.Role != model.RoleAdmin {
return fiber.ErrForbidden
}

Expand Down
17 changes: 17 additions & 0 deletions internal/infrastructure/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ func Connect(path string, wordsPerMinute float64) *gorm.DB {
log.Printf("Created database at %s\n", path)
}

// Use the following line to connect when the temporary code block below is removed
//db, err := gorm.Open(sqlite.Open(fmt.Sprintf("%s?_pragma=foreign_keys(1)", path)), &gorm.Config{})
db, err := gorm.Open(sqlite.Open(path), &gorm.Config{})
if err != nil {
log.Fatal(err)
Expand All @@ -27,6 +29,21 @@ func Connect(path string, wordsPerMinute float64) *gorm.DB {
if err := db.AutoMigrate(&model.User{}, &model.Highlight{}); err != nil {
log.Fatal(err)
}
// The next block is temporary, used to add constraints to an en existing highlights table
// Remove when the new format is established
if !db.Migrator().HasConstraint(&model.User{}, "Highlights") {
err := db.Migrator().CreateConstraint(&model.User{}, "Highlights")
if err != nil {
log.Fatal(err)
}
err = db.Migrator().CreateConstraint(&model.User{}, "fk_users_highlights")
if err != nil {
log.Fatal(err)
}
}
if res := db.Exec("PRAGMA foreign_keys(1)", nil); res.Error != nil {
log.Fatal(err)
}
addDefaultAdmin(db, wordsPerMinute)
return db
}
Expand Down
8 changes: 4 additions & 4 deletions internal/model/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package model
import (
"net/mail"
"time"

"gorm.io/gorm"
)

const (
Expand All @@ -13,7 +11,9 @@ const (
)

type User struct {
gorm.Model
ID uint `gorm:"primarykey"`
CreatedAt time.Time
UpdatedAt time.Time
Uuid string `gorm:"uniqueIndex"`
Name string
Email string `gorm:"uniqueIndex"`
Expand All @@ -23,7 +23,7 @@ type User struct {
WordsPerMinute float64
RecoveryUUID string
RecoveryValidUntil time.Time
Highlights []Highlight
Highlights []Highlight `gorm:"constraint:OnDelete:CASCADE"`
}

// Validate checks all user's fields to ensure they are in the required format
Expand Down
26 changes: 12 additions & 14 deletions internal/model/user_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,31 +29,31 @@ func (u *UserRepository) Total() int64 {
return totalRows
}

func (u *UserRepository) FindByUuid(uuid string) (User, error) {
func (u *UserRepository) FindByUuid(uuid string) (*User, error) {
return u.find("uuid", uuid)
}

func (u *UserRepository) Create(user User) error {
if result := u.DB.Create(&user); result.Error != nil {
func (u *UserRepository) Create(user *User) error {
if result := u.DB.Create(user); result.Error != nil {
log.Printf("error creating user: %s\n", result.Error)
return result.Error
}
return nil
}

func (u *UserRepository) Update(user User) error {
if result := u.DB.Save(&user); result.Error != nil {
func (u *UserRepository) Update(user *User) error {
if result := u.DB.Save(user); result.Error != nil {
log.Printf("error updating user: %s\n", result.Error)
return result.Error
}
return nil
}

func (u *UserRepository) FindByEmail(email string) (User, error) {
func (u *UserRepository) FindByEmail(email string) (*User, error) {
return u.find("email", email)
}

func (u *UserRepository) FindByRecoveryUuid(recoveryUuid string) (User, error) {
func (u *UserRepository) FindByRecoveryUuid(recoveryUuid string) (*User, error) {
return u.find("recovery_uuid", recoveryUuid)
}

Expand All @@ -78,15 +78,13 @@ func Hash(s string) string {
return string(h.Sum(nil))
}

func (u *UserRepository) find(field, value string) (User, error) {
func (u *UserRepository) find(field, value string) (*User, error) {
var (
err error
user User
)
result := u.DB.Limit(1).Where(fmt.Sprintf("%s = ?", field), value).Find(&user)
if result.Error != nil && !errors.Is(result.Error, gorm.ErrRecordNotFound) {
err = result.Error
log.Printf("error retrieving user: %s\n", result.Error)
result := u.DB.Where(fmt.Sprintf("%s = ?", field), value).First(&user)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return user, err
return &user, result.Error
}
1 change: 1 addition & 0 deletions internal/webserver/embedded/translations/es.yml
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,4 @@
"%s highlights": "Destacados de %s"
"Remove from highlights": "Quitar de destacados"
"%d highlighted documents": "%d documentos destacados"
"Method not allowed": "Método no permitido"
1 change: 1 addition & 0 deletions internal/webserver/embedded/translations/fr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,4 @@
"%s highlights": "Favoris du %s"
"Remove from highlights": "Retirer des favoris"
"%d highlighted documents": "%d documents favoris"
"Method not allowed": "Méthode Non Autorisée"
3 changes: 3 additions & 0 deletions internal/webserver/embedded/views/errors/405.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
<div class="px-4 py-5 my-5 text-center">
<h2>{{t .Lang "Method not allowed"}}</h2>
</div>
69 changes: 69 additions & 0 deletions internal/webserver/highlights_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,25 @@ func TestHighlights(t *testing.T) {
adminUser := model.User{}
db.Where("email = ?", "admin@example.com").First(&adminUser)

regularUserData := url.Values{
"name": {"Test user"},
"email": {"test@example.com"},
"password": {"test"},
"confirm-password": {"test"},
"role": {fmt.Sprint(model.RoleRegular)},
"words-per-minute": {"250"},
}

adminCookie, err := login(app, "admin@example.com", "admin")
if err != nil {
t.Fatalf("Unexpected error: %v", err.Error())
}

response, err := addUser(regularUserData, adminCookie, app)
if response == nil {
t.Fatalf("Unexpected error: %v", err.Error())
}

t.Run("Try to highlight a document without an active session", func(t *testing.T) {
response, err := highlight(&http.Cookie{}, app, strings.NewReader(data.Encode()), fiber.MethodPost)
if err != nil {
Expand Down Expand Up @@ -56,6 +70,46 @@ func TestHighlights(t *testing.T) {

assertHighlights(app, t, adminCookie, adminUser.Uuid, 0)
})

t.Run("Deleting a user also remove his/her highlights", func(t *testing.T) {
regularUser := model.User{}
db.Where("email = ?", "test@example.com").First(&regularUser)

regularUserCookie, err := login(app, "test@example.com", "test")
if err != nil {
t.Fatalf("Unexpected error: %v", err.Error())
}

response, err := highlight(regularUserCookie, app, strings.NewReader(data.Encode()), fiber.MethodPost)
if err != nil {
t.Fatalf("Unexpected error: %v", err.Error())
}

mustReturnStatus(response, fiber.StatusOK, t)

assertHighlights(app, t, regularUserCookie, regularUser.Uuid, 1)

adminCookie, err = login(app, "admin@example.com", "admin")
if err != nil {
t.Fatalf("Unexpected error: %v", err.Error())
}

data = url.Values{
"uuid": {regularUser.Uuid},
}

_, err = deleteUser(data, adminCookie, app)
if err != nil {
t.Fatalf("Unexpected error: %v", err.Error())
}

var total int64
db.Table("highlights").Where("user_id = ?", regularUser.ID).Count(&total)
if total != 0 {
t.Errorf("Expected no highlights in DB for deleted user, got %d", total)
}
assertNoHighlights(app, t, adminCookie, regularUser.Uuid)
})
}

func highlight(cookie *http.Cookie, app *fiber.App, reader *strings.Reader, method string) (*http.Response, error) {
Expand Down Expand Up @@ -92,3 +146,18 @@ func assertHighlights(app *fiber.App, t *testing.T, cookie *http.Cookie, uuid st
t.Errorf("Expected %d results, got %d", expectedResults, actualResults)
}
}

func assertNoHighlights(app *fiber.App, t *testing.T, cookie *http.Cookie, uuid string) {
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/en/highlights/%s", uuid), nil)
req.AddCookie(cookie)
if err != nil {
t.Fatalf("Unexpected error: %v", err.Error())
}
response, err := app.Test(req)
if err != nil {
t.Fatalf("Unexpected error: %v", err.Error())
}
if expectedStatus := http.StatusNotFound; response.StatusCode != expectedStatus {
t.Errorf("Expected status %d, received %d", expectedStatus, response.StatusCode)
}
}
2 changes: 1 addition & 1 deletion internal/webserver/user_management_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestUserManagement(t *testing.T) {
"email": {"test@example.com"},
"password": {"test"},
"confirm-password": {"test"},
"role": {"1"},
"role": {fmt.Sprint(model.RoleRegular)},
"words-per-minute": {"250"},
}

Expand Down

0 comments on commit df7857d

Please sign in to comment.