Merge pull request 'feat: add user register' (#1) from feature/user_login into master

Reviewed-on: #1
This commit was merged in pull request #1.
This commit is contained in:
2025-12-17 20:45:41 +08:00
9 changed files with 216 additions and 42 deletions

View File

@@ -3,13 +3,13 @@ package user
import (
"net/http"
"gitea.com/bitwsd/document_ai/pkg/log"
"gitea.com/bitwsd/document_ai/config"
model "gitea.com/bitwsd/document_ai/internal/model/user"
"gitea.com/bitwsd/document_ai/internal/service"
"gitea.com/bitwsd/document_ai/pkg/common"
"gitea.com/bitwsd/document_ai/pkg/constant"
"gitea.com/bitwsd/document_ai/pkg/jwt"
"gitea.com/bitwsd/document_ai/pkg/log"
"github.com/gin-gonic/gin"
)
@@ -55,12 +55,15 @@ func (h *UserEndpoint) LoginByPhoneCode(ctx *gin.Context) {
if config.GlobalConfig.Server.IsDebug() {
uid := 1
token, err := jwt.CreateToken(jwt.User{UserId: int64(uid)})
tokenResult, err := jwt.CreateToken(jwt.User{UserId: int64(uid)})
if err != nil {
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeUnauthorized, common.CodeUnauthorizedMsg))
return
}
ctx.JSON(http.StatusOK, common.SuccessResponse(ctx, model.PhoneLoginResponse{Token: token}))
ctx.JSON(http.StatusOK, common.SuccessResponse(ctx, model.PhoneLoginResponse{
Token: tokenResult.Token,
ExpiresAt: tokenResult.ExpiresAt,
}))
return
}
@@ -70,13 +73,16 @@ func (h *UserEndpoint) LoginByPhoneCode(ctx *gin.Context) {
return
}
token, err := jwt.CreateToken(jwt.User{UserId: uid})
tokenResult, err := jwt.CreateToken(jwt.User{UserId: uid})
if err != nil {
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeUnauthorized, common.CodeUnauthorizedMsg))
return
}
ctx.JSON(http.StatusOK, common.SuccessResponse(ctx, model.PhoneLoginResponse{Token: token}))
ctx.JSON(http.StatusOK, common.SuccessResponse(ctx, model.PhoneLoginResponse{
Token: tokenResult.Token,
ExpiresAt: tokenResult.ExpiresAt,
}))
}
func (h *UserEndpoint) GetUserInfo(ctx *gin.Context) {
@@ -103,3 +109,63 @@ func (h *UserEndpoint) GetUserInfo(ctx *gin.Context) {
Status: status,
}))
}
func (h *UserEndpoint) RegisterByEmail(ctx *gin.Context) {
req := model.EmailRegisterRequest{}
if err := ctx.ShouldBindJSON(&req); err != nil {
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeParamError, common.CodeParamErrorMsg))
return
}
uid, err := h.userService.RegisterByEmail(ctx, req.Email, req.Password)
if err != nil {
if bizErr, ok := err.(*common.BusinessError); ok {
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, int(bizErr.Code), bizErr.Message))
return
}
log.Error(ctx, "func", "RegisterByEmail", "msg", "注册失败", "error", err)
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeSystemError, common.CodeSystemErrorMsg))
return
}
tokenResult, err := jwt.CreateToken(jwt.User{UserId: uid})
if err != nil {
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeSystemError, common.CodeSystemErrorMsg))
return
}
ctx.JSON(http.StatusOK, common.SuccessResponse(ctx, model.EmailRegisterResponse{
Token: tokenResult.Token,
ExpiresAt: tokenResult.ExpiresAt,
}))
}
func (h *UserEndpoint) LoginByEmail(ctx *gin.Context) {
req := model.EmailLoginRequest{}
if err := ctx.ShouldBindJSON(&req); err != nil {
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeParamError, common.CodeParamErrorMsg))
return
}
uid, err := h.userService.LoginByEmail(ctx, req.Email, req.Password)
if err != nil {
if bizErr, ok := err.(*common.BusinessError); ok {
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, int(bizErr.Code), bizErr.Message))
return
}
log.Error(ctx, "func", "LoginByEmail", "msg", "登录失败", "error", err)
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeSystemError, common.CodeSystemErrorMsg))
return
}
tokenResult, err := jwt.CreateToken(jwt.User{UserId: uid})
if err != nil {
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeSystemError, common.CodeSystemErrorMsg))
return
}
ctx.JSON(http.StatusOK, common.SuccessResponse(ctx, model.EmailLoginResponse{
Token: tokenResult.Token,
ExpiresAt: tokenResult.ExpiresAt,
}))
}

View File

@@ -11,6 +11,8 @@ func SetupRouter(router *gin.RouterGroup) {
{
userRouter.POST("/get/sms", userEndpoint.SendVerificationCode)
userRouter.POST("/login/phone", userEndpoint.LoginByPhoneCode)
userRouter.POST("/register/email", userEndpoint.RegisterByEmail)
userRouter.POST("/login/email", userEndpoint.LoginByEmail)
userRouter.GET("/info", common.GetAuthMiddleware(), userEndpoint.GetUserInfo)
}
}

View File

@@ -6,10 +6,16 @@ type CreateTaskResponse struct {
}
type GetFormulaTaskResponse struct {
TaskNo string `json:"task_no"`
Status int `json:"status"`
Count int `json:"count"`
Latex string `json:"latex"`
TaskNo string `json:"task_no"`
Status int `json:"status"`
Count int `json:"count"`
Latex string `json:"latex"`
Markdown string `json:"markdown"`
MathML string `json:"mathml"`
MathMLMW string `json:"mathml_mw"`
ImageBlob string `json:"image_blob"`
DocxURL string `json:"docx_url"`
PDFURL string `json:"pdf_url"`
}
// FormulaRecognitionResponse 公式识别服务返回的响应

View File

@@ -14,7 +14,8 @@ type PhoneLoginRequest struct {
}
type PhoneLoginResponse struct {
Token string `json:"token"`
Token string `json:"token"`
ExpiresAt int64 `json:"expires_at"`
}
type UserInfoResponse struct {
@@ -22,3 +23,23 @@ type UserInfoResponse struct {
Phone string `json:"phone"`
Status int `json:"status"` // 0: not login, 1: login
}
type EmailRegisterRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
}
type EmailRegisterResponse struct {
Token string `json:"token"`
ExpiresAt int64 `json:"expires_at"`
}
type EmailLoginRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required"`
}
type EmailLoginResponse struct {
Token string `json:"token"`
ExpiresAt int64 `json:"expires_at"`
}

View File

@@ -263,12 +263,6 @@ func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int6
return err
}
// downloadURL, err := oss.GetDownloadURL(ctx, fileURL)
// if err != nil {
// log.Error(ctx, "func", "processFormulaTask", "msg", "获取下载URL失败", "error", err)
// return err
// }
// 将图片转为base64编码
base64Image := base64.StdEncoding.EncodeToString(imageData)

View File

@@ -6,10 +6,12 @@ import (
"fmt"
"math/rand"
"gitea.com/bitwsd/document_ai/pkg/log"
"gitea.com/bitwsd/document_ai/internal/storage/cache"
"gitea.com/bitwsd/document_ai/internal/storage/dao"
"gitea.com/bitwsd/document_ai/pkg/common"
"gitea.com/bitwsd/document_ai/pkg/log"
"gitea.com/bitwsd/document_ai/pkg/sms"
"golang.org/x/crypto/bcrypt"
)
type UserService struct {
@@ -107,3 +109,53 @@ func (svc *UserService) GetUserInfo(ctx context.Context, uid int64) (*dao.User,
return user, nil
}
func (svc *UserService) RegisterByEmail(ctx context.Context, email, password string) (uid int64, err error) {
existingUser, err := svc.userDao.GetByEmail(dao.DB.WithContext(ctx), email)
if err != nil {
log.Error(ctx, "func", "RegisterByEmail", "msg", "get user by email error", "error", err)
return 0, err
}
if existingUser != nil {
log.Warn(ctx, "func", "RegisterByEmail", "msg", "email already registered", "email", email)
return 0, common.ErrEmailExists
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
log.Error(ctx, "func", "RegisterByEmail", "msg", "hash password error", "error", err)
return 0, err
}
user := &dao.User{
Email: email,
Password: string(hashedPassword),
}
err = svc.userDao.Create(dao.DB.WithContext(ctx), user)
if err != nil {
log.Error(ctx, "func", "RegisterByEmail", "msg", "create user error", "error", err)
return 0, err
}
return user.ID, nil
}
func (svc *UserService) LoginByEmail(ctx context.Context, email, password string) (uid int64, err error) {
user, err := svc.userDao.GetByEmail(dao.DB.WithContext(ctx), email)
if err != nil {
log.Error(ctx, "func", "LoginByEmail", "msg", "get user by email error", "error", err)
return 0, err
}
if user == nil {
log.Warn(ctx, "func", "LoginByEmail", "msg", "user not found", "email", email)
return 0, common.ErrEmailNotFound
}
err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
if err != nil {
log.Warn(ctx, "func", "LoginByEmail", "msg", "password mismatch", "email", email)
return 0, common.ErrPasswordMismatch
}
return user.ID, nil
}

View File

@@ -10,6 +10,7 @@ type User struct {
BaseModel
Username string `gorm:"column:username" json:"username"`
Phone string `gorm:"column:phone" json:"phone"`
Email string `gorm:"column:email" json:"email"`
Password string `gorm:"column:password" json:"password"`
WechatOpenID string `gorm:"column:wechat_open_id" json:"wechat_open_id"`
WechatUnionID string `gorm:"column:wechat_union_id" json:"wechat_union_id"`
@@ -51,3 +52,14 @@ func (dao *UserDao) GetByID(tx *gorm.DB, id int64) (*User, error) {
}
return &user, nil
}
func (dao *UserDao) GetByEmail(tx *gorm.DB, email string) (*User, error) {
var user User
if err := tx.Where("email = ?", email).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &user, nil
}

View File

@@ -3,31 +3,37 @@ package common
type ErrorCode int
const (
CodeSuccess = 200
CodeParamError = 400
CodeUnauthorized = 401
CodeForbidden = 403
CodeNotFound = 404
CodeInvalidStatus = 405
CodeDBError = 500
CodeSystemError = 501
CodeTaskNotComplete = 1001
CodeRecordRepeat = 1002
CodeSmsCodeError = 1003
CodeSuccess = 200
CodeParamError = 400
CodeUnauthorized = 401
CodeForbidden = 403
CodeNotFound = 404
CodeInvalidStatus = 405
CodeDBError = 500
CodeSystemError = 501
CodeTaskNotComplete = 1001
CodeRecordRepeat = 1002
CodeSmsCodeError = 1003
CodeEmailExists = 1004
CodeEmailNotFound = 1005
CodePasswordMismatch = 1006
)
const (
CodeSuccessMsg = "success"
CodeParamErrorMsg = "param error"
CodeUnauthorizedMsg = "unauthorized"
CodeForbiddenMsg = "forbidden"
CodeNotFoundMsg = "not found"
CodeInvalidStatusMsg = "invalid status"
CodeDBErrorMsg = "database error"
CodeSystemErrorMsg = "system error"
CodeTaskNotCompleteMsg = "task not complete"
CodeRecordRepeatMsg = "record repeat"
CodeSmsCodeErrorMsg = "sms code error"
CodeSuccessMsg = "success"
CodeParamErrorMsg = "param error"
CodeUnauthorizedMsg = "unauthorized"
CodeForbiddenMsg = "forbidden"
CodeNotFoundMsg = "not found"
CodeInvalidStatusMsg = "invalid status"
CodeDBErrorMsg = "database error"
CodeSystemErrorMsg = "system error"
CodeTaskNotCompleteMsg = "task not complete"
CodeRecordRepeatMsg = "record repeat"
CodeSmsCodeErrorMsg = "sms code error"
CodeEmailExistsMsg = "email already registered"
CodeEmailNotFoundMsg = "email not found"
CodePasswordMismatchMsg = "password mismatch"
)
type BusinessError struct {
@@ -47,3 +53,10 @@ func NewError(code ErrorCode, message string, err error) *BusinessError {
Err: err,
}
}
// 预定义业务错误
var (
ErrEmailExists = NewError(CodeEmailExists, CodeEmailExistsMsg, nil)
ErrEmailNotFound = NewError(CodeEmailNotFound, CodeEmailNotFoundMsg, nil)
ErrPasswordMismatch = NewError(CodePasswordMismatch, CodePasswordMismatchMsg, nil)
)

View File

@@ -18,7 +18,12 @@ type CustomClaims struct {
jwt.StandardClaims
}
func CreateToken(user User) (string, error) {
type TokenResult struct {
Token string `json:"token"`
ExpiresAt int64 `json:"expires_at"`
}
func CreateToken(user User) (*TokenResult, error) {
expire := time.Now().Add(time.Duration(ValidTime) * time.Second)
claims := &CustomClaims{
User: user,
@@ -32,10 +37,13 @@ func CreateToken(user User) (string, error) {
t, err := token.SignedString(JwtKey)
if err != nil {
return "", err
return nil, err
}
return "Bearer " + t, nil
return &TokenResult{
Token: "Bearer " + t,
ExpiresAt: expire.Unix(),
}, nil
}
func ParseToken(signToken string) (*CustomClaims, error) {