Merge pull request 'feat: add user register' (#1) from feature/user_login into master
Reviewed-on: #1
This commit was merged in pull request #1.
This commit is contained in:
@@ -3,13 +3,13 @@ package user
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"gitea.com/bitwsd/document_ai/pkg/log"
|
|
||||||
"gitea.com/bitwsd/document_ai/config"
|
"gitea.com/bitwsd/document_ai/config"
|
||||||
model "gitea.com/bitwsd/document_ai/internal/model/user"
|
model "gitea.com/bitwsd/document_ai/internal/model/user"
|
||||||
"gitea.com/bitwsd/document_ai/internal/service"
|
"gitea.com/bitwsd/document_ai/internal/service"
|
||||||
"gitea.com/bitwsd/document_ai/pkg/common"
|
"gitea.com/bitwsd/document_ai/pkg/common"
|
||||||
"gitea.com/bitwsd/document_ai/pkg/constant"
|
"gitea.com/bitwsd/document_ai/pkg/constant"
|
||||||
"gitea.com/bitwsd/document_ai/pkg/jwt"
|
"gitea.com/bitwsd/document_ai/pkg/jwt"
|
||||||
|
"gitea.com/bitwsd/document_ai/pkg/log"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -55,12 +55,15 @@ func (h *UserEndpoint) LoginByPhoneCode(ctx *gin.Context) {
|
|||||||
|
|
||||||
if config.GlobalConfig.Server.IsDebug() {
|
if config.GlobalConfig.Server.IsDebug() {
|
||||||
uid := 1
|
uid := 1
|
||||||
token, err := jwt.CreateToken(jwt.User{UserId: int64(uid)})
|
tokenResult, err := jwt.CreateToken(jwt.User{UserId: int64(uid)})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeUnauthorized, common.CodeUnauthorizedMsg))
|
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeUnauthorized, common.CodeUnauthorizedMsg))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ctx.JSON(http.StatusOK, common.SuccessResponse(ctx, model.PhoneLoginResponse{Token: token}))
|
ctx.JSON(http.StatusOK, common.SuccessResponse(ctx, model.PhoneLoginResponse{
|
||||||
|
Token: tokenResult.Token,
|
||||||
|
ExpiresAt: tokenResult.ExpiresAt,
|
||||||
|
}))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,13 +73,16 @@ func (h *UserEndpoint) LoginByPhoneCode(ctx *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := jwt.CreateToken(jwt.User{UserId: uid})
|
tokenResult, err := jwt.CreateToken(jwt.User{UserId: uid})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeUnauthorized, common.CodeUnauthorizedMsg))
|
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeUnauthorized, common.CodeUnauthorizedMsg))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.JSON(http.StatusOK, common.SuccessResponse(ctx, model.PhoneLoginResponse{Token: token}))
|
ctx.JSON(http.StatusOK, common.SuccessResponse(ctx, model.PhoneLoginResponse{
|
||||||
|
Token: tokenResult.Token,
|
||||||
|
ExpiresAt: tokenResult.ExpiresAt,
|
||||||
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *UserEndpoint) GetUserInfo(ctx *gin.Context) {
|
func (h *UserEndpoint) GetUserInfo(ctx *gin.Context) {
|
||||||
@@ -103,3 +109,63 @@ func (h *UserEndpoint) GetUserInfo(ctx *gin.Context) {
|
|||||||
Status: status,
|
Status: status,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *UserEndpoint) RegisterByEmail(ctx *gin.Context) {
|
||||||
|
req := model.EmailRegisterRequest{}
|
||||||
|
if err := ctx.ShouldBindJSON(&req); err != nil {
|
||||||
|
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeParamError, common.CodeParamErrorMsg))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
uid, err := h.userService.RegisterByEmail(ctx, req.Email, req.Password)
|
||||||
|
if err != nil {
|
||||||
|
if bizErr, ok := err.(*common.BusinessError); ok {
|
||||||
|
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, int(bizErr.Code), bizErr.Message))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Error(ctx, "func", "RegisterByEmail", "msg", "注册失败", "error", err)
|
||||||
|
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeSystemError, common.CodeSystemErrorMsg))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenResult, err := jwt.CreateToken(jwt.User{UserId: uid})
|
||||||
|
if err != nil {
|
||||||
|
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeSystemError, common.CodeSystemErrorMsg))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.JSON(http.StatusOK, common.SuccessResponse(ctx, model.EmailRegisterResponse{
|
||||||
|
Token: tokenResult.Token,
|
||||||
|
ExpiresAt: tokenResult.ExpiresAt,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *UserEndpoint) LoginByEmail(ctx *gin.Context) {
|
||||||
|
req := model.EmailLoginRequest{}
|
||||||
|
if err := ctx.ShouldBindJSON(&req); err != nil {
|
||||||
|
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeParamError, common.CodeParamErrorMsg))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
uid, err := h.userService.LoginByEmail(ctx, req.Email, req.Password)
|
||||||
|
if err != nil {
|
||||||
|
if bizErr, ok := err.(*common.BusinessError); ok {
|
||||||
|
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, int(bizErr.Code), bizErr.Message))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Error(ctx, "func", "LoginByEmail", "msg", "登录失败", "error", err)
|
||||||
|
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeSystemError, common.CodeSystemErrorMsg))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenResult, err := jwt.CreateToken(jwt.User{UserId: uid})
|
||||||
|
if err != nil {
|
||||||
|
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeSystemError, common.CodeSystemErrorMsg))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.JSON(http.StatusOK, common.SuccessResponse(ctx, model.EmailLoginResponse{
|
||||||
|
Token: tokenResult.Token,
|
||||||
|
ExpiresAt: tokenResult.ExpiresAt,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ func SetupRouter(router *gin.RouterGroup) {
|
|||||||
{
|
{
|
||||||
userRouter.POST("/get/sms", userEndpoint.SendVerificationCode)
|
userRouter.POST("/get/sms", userEndpoint.SendVerificationCode)
|
||||||
userRouter.POST("/login/phone", userEndpoint.LoginByPhoneCode)
|
userRouter.POST("/login/phone", userEndpoint.LoginByPhoneCode)
|
||||||
|
userRouter.POST("/register/email", userEndpoint.RegisterByEmail)
|
||||||
|
userRouter.POST("/login/email", userEndpoint.LoginByEmail)
|
||||||
userRouter.GET("/info", common.GetAuthMiddleware(), userEndpoint.GetUserInfo)
|
userRouter.GET("/info", common.GetAuthMiddleware(), userEndpoint.GetUserInfo)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,10 +6,16 @@ type CreateTaskResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type GetFormulaTaskResponse struct {
|
type GetFormulaTaskResponse struct {
|
||||||
TaskNo string `json:"task_no"`
|
TaskNo string `json:"task_no"`
|
||||||
Status int `json:"status"`
|
Status int `json:"status"`
|
||||||
Count int `json:"count"`
|
Count int `json:"count"`
|
||||||
Latex string `json:"latex"`
|
Latex string `json:"latex"`
|
||||||
|
Markdown string `json:"markdown"`
|
||||||
|
MathML string `json:"mathml"`
|
||||||
|
MathMLMW string `json:"mathml_mw"`
|
||||||
|
ImageBlob string `json:"image_blob"`
|
||||||
|
DocxURL string `json:"docx_url"`
|
||||||
|
PDFURL string `json:"pdf_url"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// FormulaRecognitionResponse 公式识别服务返回的响应
|
// FormulaRecognitionResponse 公式识别服务返回的响应
|
||||||
|
|||||||
@@ -14,7 +14,8 @@ type PhoneLoginRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PhoneLoginResponse struct {
|
type PhoneLoginResponse struct {
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
|
ExpiresAt int64 `json:"expires_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserInfoResponse struct {
|
type UserInfoResponse struct {
|
||||||
@@ -22,3 +23,23 @@ type UserInfoResponse struct {
|
|||||||
Phone string `json:"phone"`
|
Phone string `json:"phone"`
|
||||||
Status int `json:"status"` // 0: not login, 1: login
|
Status int `json:"status"` // 0: not login, 1: login
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EmailRegisterRequest struct {
|
||||||
|
Email string `json:"email" binding:"required,email"`
|
||||||
|
Password string `json:"password" binding:"required,min=6"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmailRegisterResponse struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
ExpiresAt int64 `json:"expires_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmailLoginRequest struct {
|
||||||
|
Email string `json:"email" binding:"required,email"`
|
||||||
|
Password string `json:"password" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmailLoginResponse struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
ExpiresAt int64 `json:"expires_at"`
|
||||||
|
}
|
||||||
|
|||||||
@@ -263,12 +263,6 @@ func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int6
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// downloadURL, err := oss.GetDownloadURL(ctx, fileURL)
|
|
||||||
// if err != nil {
|
|
||||||
// log.Error(ctx, "func", "processFormulaTask", "msg", "获取下载URL失败", "error", err)
|
|
||||||
// return err
|
|
||||||
// }
|
|
||||||
|
|
||||||
// 将图片转为base64编码
|
// 将图片转为base64编码
|
||||||
base64Image := base64.StdEncoding.EncodeToString(imageData)
|
base64Image := base64.StdEncoding.EncodeToString(imageData)
|
||||||
|
|
||||||
|
|||||||
@@ -6,10 +6,12 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
|
||||||
"gitea.com/bitwsd/document_ai/pkg/log"
|
|
||||||
"gitea.com/bitwsd/document_ai/internal/storage/cache"
|
"gitea.com/bitwsd/document_ai/internal/storage/cache"
|
||||||
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
||||||
|
"gitea.com/bitwsd/document_ai/pkg/common"
|
||||||
|
"gitea.com/bitwsd/document_ai/pkg/log"
|
||||||
"gitea.com/bitwsd/document_ai/pkg/sms"
|
"gitea.com/bitwsd/document_ai/pkg/sms"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserService struct {
|
type UserService struct {
|
||||||
@@ -107,3 +109,53 @@ func (svc *UserService) GetUserInfo(ctx context.Context, uid int64) (*dao.User,
|
|||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (svc *UserService) RegisterByEmail(ctx context.Context, email, password string) (uid int64, err error) {
|
||||||
|
existingUser, err := svc.userDao.GetByEmail(dao.DB.WithContext(ctx), email)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "RegisterByEmail", "msg", "get user by email error", "error", err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if existingUser != nil {
|
||||||
|
log.Warn(ctx, "func", "RegisterByEmail", "msg", "email already registered", "email", email)
|
||||||
|
return 0, common.ErrEmailExists
|
||||||
|
}
|
||||||
|
|
||||||
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "RegisterByEmail", "msg", "hash password error", "error", err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
user := &dao.User{
|
||||||
|
Email: email,
|
||||||
|
Password: string(hashedPassword),
|
||||||
|
}
|
||||||
|
err = svc.userDao.Create(dao.DB.WithContext(ctx), user)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "RegisterByEmail", "msg", "create user error", "error", err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return user.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (svc *UserService) LoginByEmail(ctx context.Context, email, password string) (uid int64, err error) {
|
||||||
|
user, err := svc.userDao.GetByEmail(dao.DB.WithContext(ctx), email)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "LoginByEmail", "msg", "get user by email error", "error", err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if user == nil {
|
||||||
|
log.Warn(ctx, "func", "LoginByEmail", "msg", "user not found", "email", email)
|
||||||
|
return 0, common.ErrEmailNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
|
||||||
|
if err != nil {
|
||||||
|
log.Warn(ctx, "func", "LoginByEmail", "msg", "password mismatch", "email", email)
|
||||||
|
return 0, common.ErrPasswordMismatch
|
||||||
|
}
|
||||||
|
|
||||||
|
return user.ID, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ type User struct {
|
|||||||
BaseModel
|
BaseModel
|
||||||
Username string `gorm:"column:username" json:"username"`
|
Username string `gorm:"column:username" json:"username"`
|
||||||
Phone string `gorm:"column:phone" json:"phone"`
|
Phone string `gorm:"column:phone" json:"phone"`
|
||||||
|
Email string `gorm:"column:email" json:"email"`
|
||||||
Password string `gorm:"column:password" json:"password"`
|
Password string `gorm:"column:password" json:"password"`
|
||||||
WechatOpenID string `gorm:"column:wechat_open_id" json:"wechat_open_id"`
|
WechatOpenID string `gorm:"column:wechat_open_id" json:"wechat_open_id"`
|
||||||
WechatUnionID string `gorm:"column:wechat_union_id" json:"wechat_union_id"`
|
WechatUnionID string `gorm:"column:wechat_union_id" json:"wechat_union_id"`
|
||||||
@@ -51,3 +52,14 @@ func (dao *UserDao) GetByID(tx *gorm.DB, id int64) (*User, error) {
|
|||||||
}
|
}
|
||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (dao *UserDao) GetByEmail(tx *gorm.DB, email string) (*User, error) {
|
||||||
|
var user User
|
||||||
|
if err := tx.Where("email = ?", email).First(&user).Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,31 +3,37 @@ package common
|
|||||||
type ErrorCode int
|
type ErrorCode int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
CodeSuccess = 200
|
CodeSuccess = 200
|
||||||
CodeParamError = 400
|
CodeParamError = 400
|
||||||
CodeUnauthorized = 401
|
CodeUnauthorized = 401
|
||||||
CodeForbidden = 403
|
CodeForbidden = 403
|
||||||
CodeNotFound = 404
|
CodeNotFound = 404
|
||||||
CodeInvalidStatus = 405
|
CodeInvalidStatus = 405
|
||||||
CodeDBError = 500
|
CodeDBError = 500
|
||||||
CodeSystemError = 501
|
CodeSystemError = 501
|
||||||
CodeTaskNotComplete = 1001
|
CodeTaskNotComplete = 1001
|
||||||
CodeRecordRepeat = 1002
|
CodeRecordRepeat = 1002
|
||||||
CodeSmsCodeError = 1003
|
CodeSmsCodeError = 1003
|
||||||
|
CodeEmailExists = 1004
|
||||||
|
CodeEmailNotFound = 1005
|
||||||
|
CodePasswordMismatch = 1006
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
CodeSuccessMsg = "success"
|
CodeSuccessMsg = "success"
|
||||||
CodeParamErrorMsg = "param error"
|
CodeParamErrorMsg = "param error"
|
||||||
CodeUnauthorizedMsg = "unauthorized"
|
CodeUnauthorizedMsg = "unauthorized"
|
||||||
CodeForbiddenMsg = "forbidden"
|
CodeForbiddenMsg = "forbidden"
|
||||||
CodeNotFoundMsg = "not found"
|
CodeNotFoundMsg = "not found"
|
||||||
CodeInvalidStatusMsg = "invalid status"
|
CodeInvalidStatusMsg = "invalid status"
|
||||||
CodeDBErrorMsg = "database error"
|
CodeDBErrorMsg = "database error"
|
||||||
CodeSystemErrorMsg = "system error"
|
CodeSystemErrorMsg = "system error"
|
||||||
CodeTaskNotCompleteMsg = "task not complete"
|
CodeTaskNotCompleteMsg = "task not complete"
|
||||||
CodeRecordRepeatMsg = "record repeat"
|
CodeRecordRepeatMsg = "record repeat"
|
||||||
CodeSmsCodeErrorMsg = "sms code error"
|
CodeSmsCodeErrorMsg = "sms code error"
|
||||||
|
CodeEmailExistsMsg = "email already registered"
|
||||||
|
CodeEmailNotFoundMsg = "email not found"
|
||||||
|
CodePasswordMismatchMsg = "password mismatch"
|
||||||
)
|
)
|
||||||
|
|
||||||
type BusinessError struct {
|
type BusinessError struct {
|
||||||
@@ -47,3 +53,10 @@ func NewError(code ErrorCode, message string, err error) *BusinessError {
|
|||||||
Err: err,
|
Err: err,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 预定义业务错误
|
||||||
|
var (
|
||||||
|
ErrEmailExists = NewError(CodeEmailExists, CodeEmailExistsMsg, nil)
|
||||||
|
ErrEmailNotFound = NewError(CodeEmailNotFound, CodeEmailNotFoundMsg, nil)
|
||||||
|
ErrPasswordMismatch = NewError(CodePasswordMismatch, CodePasswordMismatchMsg, nil)
|
||||||
|
)
|
||||||
|
|||||||
@@ -18,7 +18,12 @@ type CustomClaims struct {
|
|||||||
jwt.StandardClaims
|
jwt.StandardClaims
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateToken(user User) (string, error) {
|
type TokenResult struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
ExpiresAt int64 `json:"expires_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateToken(user User) (*TokenResult, error) {
|
||||||
expire := time.Now().Add(time.Duration(ValidTime) * time.Second)
|
expire := time.Now().Add(time.Duration(ValidTime) * time.Second)
|
||||||
claims := &CustomClaims{
|
claims := &CustomClaims{
|
||||||
User: user,
|
User: user,
|
||||||
@@ -32,10 +37,13 @@ func CreateToken(user User) (string, error) {
|
|||||||
|
|
||||||
t, err := token.SignedString(JwtKey)
|
t, err := token.SignedString(JwtKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return "Bearer " + t, nil
|
return &TokenResult{
|
||||||
|
Token: "Bearer " + t,
|
||||||
|
ExpiresAt: expire.Unix(),
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseToken(signToken string) (*CustomClaims, error) {
|
func ParseToken(signToken string) (*CustomClaims, error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user