Skip to content

Commit

Permalink
feat: 优化数据库删除
Browse files Browse the repository at this point in the history
  • Loading branch information
devhaozi committed Nov 26, 2024
1 parent 5d5633b commit c2ae9dc
Show file tree
Hide file tree
Showing 14 changed files with 292 additions and 97 deletions.
21 changes: 11 additions & 10 deletions internal/biz/database_user.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ const (
)

type DatabaseUser struct {
ID uint `gorm:"primaryKey" json:"id"`
ServerID uint `gorm:"not null" json:"server_id"`
Username string `gorm:"not null" json:"username"`
Password string `gorm:"not null" json:"password"`
Host string `gorm:"not null" json:"host"` // 仅 mysql
Status DatabaseUserStatus `gorm:"-:all" json:"status"` // 仅显示
Privileges map[string][]string `gorm:"-:all" json:"privileges"` // 仅显示
Remark string `gorm:"not null" json:"remark"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ID uint `gorm:"primaryKey" json:"id"`
ServerID uint `gorm:"not null" json:"server_id"`
Username string `gorm:"not null" json:"username"`
Password string `gorm:"not null" json:"password"`
Host string `gorm:"not null" json:"host"` // 仅 mysql
Status DatabaseUserStatus `gorm:"-:all" json:"status"` // 仅显示
Privileges []string `gorm:"-:all" json:"privileges"` // 仅显示
Remark string `gorm:"not null" json:"remark"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`

Server *DatabaseServer `gorm:"foreignKey:ServerID;references:ID" json:"server"`
}
Expand Down Expand Up @@ -59,5 +59,6 @@ type DatabaseUserRepo interface {
Update(req *request.DatabaseUserUpdate) error
UpdateRemark(req *request.DatabaseUserUpdateRemark) error
Delete(id uint) error
DeleteByNames(serverID uint, names []string) error
DeleteByServerID(serverID uint) error
}
14 changes: 12 additions & 2 deletions internal/data/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,12 @@ func (r databaseRepo) Create(req *request.DatabaseCreate) error {
return err
}
if req.CreateUser {
if err = mysql.UserCreate(req.Username, req.Password, req.Host); err != nil {
if err = NewDatabaseUserRepo().Create(&request.DatabaseUserCreate{
ServerID: req.ServerID,
Username: req.Username,
Password: req.Password,
Host: req.Host,
}); err != nil {
return err
}
}
Expand All @@ -96,7 +101,12 @@ func (r databaseRepo) Create(req *request.DatabaseCreate) error {
return err
}
if req.CreateUser {
if err = postgres.UserCreate(req.Username, req.Password); err != nil {
if err = NewDatabaseUserRepo().Create(&request.DatabaseUserCreate{
ServerID: req.ServerID,
Username: req.Username,
Password: req.Password,
Host: req.Host,
}); err != nil {
return err
}
}
Expand Down
58 changes: 51 additions & 7 deletions internal/data/database_user.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ func (r databaseUserRepo) Create(req *request.DatabaseUserCreate) error {
return err
}

user := new(biz.DatabaseUser)
switch server.Type {
case biz.DatabaseTypeMysql:
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
Expand All @@ -70,6 +71,11 @@ func (r databaseUserRepo) Create(req *request.DatabaseUserCreate) error {
return err
}
}
user = &biz.DatabaseUser{
ServerID: req.ServerID,
Username: req.Username,
Host: req.Host,
}
case biz.DatabaseTypePostgresql:
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
if err != nil {
Expand All @@ -83,12 +89,10 @@ func (r databaseUserRepo) Create(req *request.DatabaseUserCreate) error {
return err
}
}
}

user := &biz.DatabaseUser{
ServerID: req.ServerID,
Username: req.Username,
Host: req.Host,
user = &biz.DatabaseUser{
ServerID: req.ServerID,
Username: req.Username,
}
}

if err = app.Orm.FirstOrInit(user, user).Error; err != nil {
Expand Down Expand Up @@ -191,6 +195,46 @@ func (r databaseUserRepo) Delete(id uint) error {
return app.Orm.Where("id = ?", id).Delete(&biz.DatabaseUser{}).Error
}

func (r databaseUserRepo) DeleteByNames(serverID uint, names []string) error {
server, err := NewDatabaseServerRepo().Get(serverID)
if err != nil {
return err
}

switch server.Type {
case biz.DatabaseTypeMysql:
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
if err != nil {
return err
}
users := make([]*biz.DatabaseUser, 0)
if err = app.Orm.Where("server_id = ? AND username IN ?", serverID, names).Find(&users).Error; err != nil {
return err
}
for name := range slices.Values(names) {
host := "localhost"
for u := range slices.Values(users) {
if u.Username == name {
host = u.Host
break
}
}
_ = mysql.UserDrop(name, host)
}
case biz.DatabaseTypePostgresql:
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
if err != nil {
return err
}
for name := range slices.Values(names) {
_ = postgres.UserDrop(name)
}
}

return app.Orm.Where("server_id = ? AND username IN ?", serverID, names).Delete(&biz.DatabaseUser{}).Error
}

// DeleteByServerID 删除指定服务器的所有用户,只是删除面板记录,不会实际删除
func (r databaseUserRepo) DeleteByServerID(serverID uint) error {
return app.Orm.Where("server_id = ?", serverID).Delete(&biz.DatabaseUser{}).Error
}
Expand Down Expand Up @@ -225,6 +269,6 @@ func (r databaseUserRepo) fillUser(user *biz.DatabaseUser) {
}
// 初始化,防止 nil
if user.Privileges == nil {
user.Privileges = make(map[string][]string)
user.Privileges = make([]string, 0)
}
}
20 changes: 8 additions & 12 deletions internal/data/website.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"github.com/TheTNB/panel/internal/http/request"
"github.com/TheTNB/panel/pkg/acme"
"github.com/TheTNB/panel/pkg/cert"
"github.com/TheTNB/panel/pkg/db"
"github.com/TheTNB/panel/pkg/io"
"github.com/TheTNB/panel/pkg/nginx"
"github.com/TheTNB/panel/pkg/punycode"
Expand Down Expand Up @@ -507,17 +506,14 @@ func (r *websiteRepo) Delete(req *request.WebsiteDelete) error {
_ = io.Remove(website.Path)
}
if req.DB {
rootPassword, err := NewSettingRepo().Get(biz.SettingKeyMySQLRootPassword)
if err != nil {
return err
}
if mysql, err := db.NewMySQL("root", rootPassword, "/tmp/mysql.sock", "unix"); err == nil {
_ = mysql.UserDrop(website.Name, "localhost")
_ = mysql.DatabaseDrop(website.Name)
}
if postgres, err := db.NewPostgres("postgres", "", "127.0.0.1", 5432); err == nil {
_ = postgres.UserDrop(website.Name)
_ = postgres.DatabaseDrop(website.Name)
repo := NewDatabaseServerRepo()
if mysql, err := repo.GetByName("local_mysql"); err == nil {
_ = NewDatabaseUserRepo().DeleteByNames(mysql.ID, []string{website.Name})
_ = NewDatabaseRepo().Delete(mysql.ID, website.Name)
}
if postgres, err := repo.GetByName("local_postgresql"); err == nil {
_ = NewDatabaseUserRepo().DeleteByNames(postgres.ID, []string{website.Name})
_ = NewDatabaseRepo().Delete(postgres.ID, website.Name)
}
}

Expand Down
2 changes: 2 additions & 0 deletions internal/route/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ func Http(r chi.Router) {
server := service.NewDatabaseServerService()
r.Get("/", server.List)
r.Post("/", server.Create)
r.Get("/{id}", server.Get)
r.Put("/{id}", server.Update)
r.Put("/{id}/remark", server.UpdateRemark)
r.Delete("/{id}", server.Delete)
Expand All @@ -84,6 +85,7 @@ func Http(r chi.Router) {
user := service.NewDatabaseUserService()
r.Get("/", user.List)
r.Post("/", user.Create)
r.Get("/{id}", user.Get)
r.Put("/{id}", user.Update)
r.Put("/{id}/remark", user.UpdateRemark)
r.Delete("/{id}", user.Delete)
Expand Down
16 changes: 16 additions & 0 deletions internal/service/database_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,22 @@ func (s *DatabaseServer) Create(w http.ResponseWriter, r *http.Request) {
Success(w, nil)
}

func (s *DatabaseServer) Get(w http.ResponseWriter, r *http.Request) {
req, err := Bind[request.ID](r)
if err != nil {
Error(w, http.StatusUnprocessableEntity, "%v", err)
return
}

server, err := s.databaseServerRepo.Get(req.ID)
if err != nil {
Error(w, http.StatusInternalServerError, "%v", err)
return
}

Success(w, server)
}

func (s *DatabaseServer) Update(w http.ResponseWriter, r *http.Request) {
req, err := Bind[request.DatabaseServerUpdate](r)
if err != nil {
Expand Down
16 changes: 16 additions & 0 deletions internal/service/database_user.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,22 @@ func (s *DatabaseUser) Create(w http.ResponseWriter, r *http.Request) {
Success(w, nil)
}

func (s *DatabaseUser) Get(w http.ResponseWriter, r *http.Request) {
req, err := Bind[request.ID](r)
if err != nil {
Error(w, http.StatusUnprocessableEntity, "%v", err)
return
}

user, err := s.databaseUserRepo.Get(req.ID)
if err != nil {
Error(w, http.StatusInternalServerError, "%v", err)
return
}

Success(w, user)
}

func (s *DatabaseUser) Update(w http.ResponseWriter, r *http.Request) {
req, err := Bind[request.DatabaseUserUpdate](r)
if err != nil {
Expand Down
49 changes: 20 additions & 29 deletions pkg/db/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ package db
import (
"database/sql"
"fmt"
"strings"

_ "github.com/go-sql-driver/mysql"
"regexp"
"slices"

"github.com/TheTNB/panel/pkg/types"
)
Expand Down Expand Up @@ -122,52 +122,43 @@ func (r *MySQL) PrivilegesGrant(user, database, host string) error {
return err
}

func (r *MySQL) UserPrivileges(user, host string) (map[string][]string, error) {
func (r *MySQL) PrivilegesRevoke(user, database, host string) error {
_, err := r.Exec(fmt.Sprintf("REVOKE ALL PRIVILEGES ON %s.* FROM '%s'@'%s'", database, user, host))
r.flushPrivileges()
return err
}

func (r *MySQL) UserPrivileges(user, host string) ([]string, error) {
rows, err := r.Query(fmt.Sprintf("SHOW GRANTS FOR '%s'@'%s'", user, host))
if err != nil {
return nil, err
}
defer rows.Close()

privileges := make(map[string][]string)
re := regexp.MustCompile(`GRANT\s+ALL PRIVILEGES\s+ON\s+[\x60'"]?([^\s\x60'"]+)[\x60'"]?\.\*\s+TO\s+`)
var databases []string
for rows.Next() {
var grant string
if err = rows.Scan(&grant); err != nil {
return nil, err
}
if !strings.HasPrefix(grant, "GRANT ") {
continue
}

parts := strings.Split(grant, " ON ")
if len(parts) < 2 {
continue
}

privList := strings.TrimPrefix(parts[0], "GRANT ")
privs := strings.Split(privList, ", ")

dbPart := strings.Split(parts[1], " TO")[0]
// *.* 表示全局权限
if dbPart == "*.*" {
dbPart = "*"
// 使用正则表达式匹配
matches := re.FindStringSubmatch(grant)
if len(matches) == 2 {
dbName := matches[1]
if dbName != "*" {
databases = append(databases, dbName)
}
}

dbPart = strings.Trim(dbPart, "`")
privileges[dbPart] = append(privileges[dbPart], privs...)
}

if err = rows.Err(); err != nil {
return nil, err
}

return privileges, nil
}

func (r *MySQL) PrivilegesRevoke(user, database, host string) error {
_, err := r.Exec(fmt.Sprintf("REVOKE ALL PRIVILEGES ON %s.* FROM '%s'@'%s'", database, user, host))
r.flushPrivileges()
return err
slices.Sort(databases)
return slices.Compact(databases), nil
}

func (r *MySQL) Users() ([]types.MySQLUser, error) {
Expand Down
31 changes: 15 additions & 16 deletions pkg/db/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@ package db
import (
"database/sql"
"fmt"
"slices"
"strings"

_ "github.com/lib/pq"
"slices"

"github.com/TheTNB/panel/pkg/systemctl"
"github.com/TheTNB/panel/pkg/types"
Expand Down Expand Up @@ -123,37 +121,38 @@ func (r *Postgres) UserPassword(user, password string) error {
return err
}

func (r *Postgres) UserPrivileges(user string) (map[string][]string, error) {
func (r *Postgres) UserPrivileges(user string) ([]string, error) {
query := `
SELECT
table_catalog as database_name,
string_agg(DISTINCT privilege_type, ',') as privileges
FROM information_schema.role_database_privileges
WHERE grantee = $1
GROUP BY table_catalog`
SELECT d.datname
FROM pg_catalog.pg_database d
JOIN pg_catalog.pg_roles r ON d.datdba = r.oid
WHERE r.rolname = $1
AND d.datistemplate = false
AND d.datname NOT IN ('template0', 'template1', 'postgres')
ORDER BY d.datname;
`

rows, err := r.Query(query, user)
if err != nil {
return nil, err
}
defer rows.Close()

privileges := make(map[string][]string)
var databases []string

for rows.Next() {
var db, privilegeStr string
if err = rows.Scan(&db, &privilegeStr); err != nil {
var dbName string
if err = rows.Scan(&dbName); err != nil {
return nil, err
}

privileges[db] = strings.Split(privilegeStr, ",")
databases = append(databases, dbName)
}

if err = rows.Err(); err != nil {
return nil, err
}

return privileges, nil
return databases, nil
}

func (r *Postgres) PrivilegesGrant(user, database string) error {
Expand Down
Loading

0 comments on commit c2ae9dc

Please sign in to comment.