diff --git a/internal/service/user.go b/internal/service/user.go index baa15f3460..11ab9ce661 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -2,7 +2,7 @@ package service import ( "crypto/rsa" - "encoding/json" + "encoding/gob" "fmt" "net" "net/http" @@ -24,6 +24,7 @@ type UserService struct { } func NewUserService() *UserService { + gob.Register(rsa.PrivateKey{}) // 必须注册 rsa.PrivateKey 类型否则无法反序列化 session 中的 key return &UserService{ repo: data.NewUserRepo(), } @@ -41,13 +42,7 @@ func (s *UserService) GetKey(w http.ResponseWriter, r *http.Request) { Error(w, http.StatusInternalServerError, "%v", err) return } - - encoded, err := json.Marshal(key) - if err != nil { - Error(w, http.StatusInternalServerError, "%v", err) - return - } - sess.Put("key", encoded) + sess.Put("key", *key) pk, err := rsacrypto.PublicKeyToString(&key.PublicKey) if err != nil { @@ -71,22 +66,22 @@ func (s *UserService) Login(w http.ResponseWriter, r *http.Request) { return } - key := new(rsa.PrivateKey) - if err = json.Unmarshal(sess.Get("key").([]byte), key); err != nil { + key, ok := sess.Get("key").(rsa.PrivateKey) + if !ok { Error(w, http.StatusForbidden, "invalid key, please refresh the page") return } - decryptedUsername, _ := rsacrypto.DecryptData(key, req.Username) - decryptedPassword, _ := rsacrypto.DecryptData(key, req.Password) + decryptedUsername, _ := rsacrypto.DecryptData(&key, req.Username) + decryptedPassword, _ := rsacrypto.DecryptData(&key, req.Password) user, err := s.repo.CheckPassword(string(decryptedUsername), string(decryptedPassword)) if err != nil { Error(w, http.StatusForbidden, "%v", err) return } - // 安全登录模式下,将当前客户端与会话绑定 - // 安全登录模式只在未启用TLS时生效,因为TLS本身就是安全的 + // 安全登录下,将当前客户端与会话绑定 + // 安全登录只在未启用面板 HTTPS 时生效 ip, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr)) if err != nil { Error(w, http.StatusInternalServerError, "%v", err)