feat: add list api

This commit is contained in:
2025-12-18 12:39:50 +08:00
parent d06f2d9df1
commit 8a6da5b627
15 changed files with 133 additions and 57 deletions

View File

@@ -9,6 +9,7 @@ import (
"gitea.com/bitwsd/document_ai/internal/service" "gitea.com/bitwsd/document_ai/internal/service"
"gitea.com/bitwsd/document_ai/internal/storage/dao" "gitea.com/bitwsd/document_ai/internal/storage/dao"
"gitea.com/bitwsd/document_ai/pkg/common" "gitea.com/bitwsd/document_ai/pkg/common"
"gitea.com/bitwsd/document_ai/pkg/constant"
"gitea.com/bitwsd/document_ai/pkg/utils" "gitea.com/bitwsd/document_ai/pkg/utils"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -37,11 +38,14 @@ func NewFormulaEndpoint() *FormulaEndpoint {
// @Router /v1/formula/recognition [post] // @Router /v1/formula/recognition [post]
func (endpoint *FormulaEndpoint) CreateTask(ctx *gin.Context) { func (endpoint *FormulaEndpoint) CreateTask(ctx *gin.Context) {
var req formula.CreateFormulaRecognitionRequest var req formula.CreateFormulaRecognitionRequest
uid := ctx.GetInt64(constant.ContextUserID)
if err := ctx.BindJSON(&req); err != nil { if err := ctx.BindJSON(&req); err != nil {
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeParamError, "Invalid parameters")) ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeParamError, "Invalid parameters"))
return return
} }
req.UserID = uid
if !utils.InArray(req.TaskType, []string{string(dao.TaskTypeFormula), string(dao.TaskTypeFormula)}) { if !utils.InArray(req.TaskType, []string{string(dao.TaskTypeFormula), string(dao.TaskTypeFormula)}) {
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeParamError, "Invalid task type")) ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeParamError, "Invalid task type"))
return return

View File

@@ -1,12 +1,16 @@
package formula package formula
import ( import (
"gitea.com/bitwsd/document_ai/pkg/common"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func SetupRouter(engine *gin.RouterGroup) { func SetupRouter(engine *gin.RouterGroup) {
endpoint := NewFormulaEndpoint() endpoint := NewFormulaEndpoint()
engine.POST("/formula/recognition", endpoint.CreateTask) formulaRouter := engine.Group("/formula", common.GetAuthMiddleware())
engine.POST("/formula/ai_enhance", endpoint.AIEnhanceRecognition) {
engine.GET("/formula/recognition/:task_no", endpoint.GetTaskStatus) formulaRouter.POST("/recognition", endpoint.CreateTask)
formulaRouter.POST("/ai_enhance", endpoint.AIEnhanceRecognition)
formulaRouter.GET("/recognition/:task_no", endpoint.GetTaskStatus)
}
} }

View File

@@ -11,7 +11,6 @@ import (
"gitea.com/bitwsd/document_ai/config" "gitea.com/bitwsd/document_ai/config"
"gitea.com/bitwsd/document_ai/internal/storage/dao" "gitea.com/bitwsd/document_ai/internal/storage/dao"
"gitea.com/bitwsd/document_ai/pkg/common" "gitea.com/bitwsd/document_ai/pkg/common"
"gitea.com/bitwsd/document_ai/pkg/constant"
"gitea.com/bitwsd/document_ai/pkg/oss" "gitea.com/bitwsd/document_ai/pkg/oss"
"gitea.com/bitwsd/document_ai/pkg/utils" "gitea.com/bitwsd/document_ai/pkg/utils"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -38,7 +37,6 @@ func GetPostObjectSignature(ctx *gin.Context) {
// @Failure 200 {object} common.Response "Error response" // @Failure 200 {object} common.Response "Error response"
// @Router /signature_url [get] // @Router /signature_url [get]
func GetSignatureURL(ctx *gin.Context) { func GetSignatureURL(ctx *gin.Context) {
userID := ctx.GetInt64(constant.ContextUserID)
type Req struct { type Req struct {
FileHash string `json:"file_hash" binding:"required"` FileHash string `json:"file_hash" binding:"required"`
FileName string `json:"file_name" binding:"required"` FileName string `json:"file_name" binding:"required"`
@@ -51,7 +49,7 @@ func GetSignatureURL(ctx *gin.Context) {
} }
taskDao := dao.NewRecognitionTaskDao() taskDao := dao.NewRecognitionTaskDao()
sess := dao.DB.WithContext(ctx) sess := dao.DB.WithContext(ctx)
task, err := taskDao.GetTaskByFileURL(sess, userID, req.FileHash) task, err := taskDao.GetTaskByFileURL(sess, req.FileHash)
if err != nil && err != gorm.ErrRecordNotFound { if err != nil && err != gorm.ErrRecordNotFound {
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeDBError, "failed to get task")) ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeDBError, "failed to get task"))
return return

View File

@@ -3,10 +3,10 @@ package task
import ( import (
"net/http" "net/http"
"gitea.com/bitwsd/document_ai/pkg/log"
"gitea.com/bitwsd/document_ai/internal/model/task" "gitea.com/bitwsd/document_ai/internal/model/task"
"gitea.com/bitwsd/document_ai/internal/service" "gitea.com/bitwsd/document_ai/internal/service"
"gitea.com/bitwsd/document_ai/pkg/common" "gitea.com/bitwsd/document_ai/pkg/common"
"gitea.com/bitwsd/document_ai/pkg/log"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -43,6 +43,8 @@ func (h *TaskEndpoint) GetTaskList(c *gin.Context) {
return return
} }
req.UserID = common.GetUserIDFromContext(c)
if req.Page <= 0 { if req.Page <= 0 {
req.Page = 1 req.Page = 1
} }

View File

@@ -1,11 +1,12 @@
package task package task
import ( import (
"gitea.com/bitwsd/document_ai/pkg/common"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func SetupRouter(engine *gin.RouterGroup) { func SetupRouter(engine *gin.RouterGroup) {
endpoint := NewTaskEndpoint() endpoint := NewTaskEndpoint()
engine.POST("/task/evaluate", endpoint.EvaluateTask) engine.POST("/task/evaluate", endpoint.EvaluateTask)
engine.GET("/task/list", endpoint.GetTaskList) engine.GET("/task/list", common.MustAuthMiddleware(), endpoint.GetTaskList)
} }

View File

@@ -9,10 +9,9 @@ func SetupRouter(router *gin.RouterGroup) {
userEndpoint := NewUserEndpoint() userEndpoint := NewUserEndpoint()
userRouter := router.Group("/user") userRouter := router.Group("/user")
{ {
userRouter.POST("/get/sms", userEndpoint.SendVerificationCode) userRouter.POST("/sms", userEndpoint.SendVerificationCode)
userRouter.POST("/login/phone", userEndpoint.LoginByPhoneCode) userRouter.POST("/register", userEndpoint.RegisterByEmail)
userRouter.POST("/register/email", userEndpoint.RegisterByEmail) userRouter.POST("/login", userEndpoint.LoginByEmail)
userRouter.POST("/login/email", userEndpoint.LoginByEmail) userRouter.GET("/info", common.MustAuthMiddleware(), userEndpoint.GetUserInfo)
userRouter.GET("/info", common.GetAuthMiddleware(), userEndpoint.GetUserInfo)
} }
} }

View File

@@ -5,7 +5,7 @@ server:
database: database:
driver: mysql driver: mysql
host: mysql host: mysql
port: 3306 # 容器内部端口,不是宿主机映射的 3006 port: 3306
username: root username: root
password: texpixel#pwd123! password: texpixel#pwd123!
dbname: doc_ai dbname: doc_ai
@@ -13,7 +13,7 @@ database:
max_open: 100 max_open: 100
redis: redis:
addr: redis:6379 # 容器内部端口,不是宿主机映射的 6079 addr: redis:6379
password: yoge@123321! password: yoge@123321!
db: 0 db: 0
@@ -22,7 +22,7 @@ limit:
log: log:
appName: document_ai appName: document_ai
level: info # debug, info, warn, error level: info
format: console # json, console format: console # json, console
outputPath: ./logs/app.log # 日志文件路径 outputPath: ./logs/app.log # 日志文件路径
maxSize: 2 # 单个日志文件最大尺寸单位MB maxSize: 2 # 单个日志文件最大尺寸单位MB
@@ -41,6 +41,6 @@ aliyun:
oss: oss:
endpoint: oss-cn-beijing.aliyuncs.com endpoint: oss-cn-beijing.aliyuncs.com
inner_endpoint: oss-cn-beijing-internal.aliyuncs.com inner_endpoint: oss-cn-beijing-internal.aliyuncs.com
access_key_id: LTAI5tKogxeiBb4gJGWEePWN access_key_id: LTAI5t8qXhow6NCdYDtu1saF
access_key_secret: l4oCxtt5iLSQ1DAs40guTzKUfrxXwq access_key_secret: qZ2SwYsNCEBckCVSOszH31yYwXU44A
bucket_name: bitwsd-doc-ai bucket_name: texpixel-doc

View File

@@ -5,6 +5,7 @@ type CreateFormulaRecognitionRequest struct {
FileHash string `json:"file_hash" binding:"required"` // file hash FileHash string `json:"file_hash" binding:"required"` // file hash
FileName string `json:"file_name" binding:"required"` // file name FileName string `json:"file_name" binding:"required"` // file name
TaskType string `json:"task_type" binding:"required,oneof=FORMULA"` // task type TaskType string `json:"task_type" binding:"required,oneof=FORMULA"` // task type
UserID int64 `json:"user_id"` // user id
} }
type GetRecognitionStatusRequest struct { type GetRecognitionStatusRequest struct {

View File

@@ -11,12 +11,7 @@ type TaskListRequest struct {
TaskType string `json:"task_type" form:"task_type" binding:"required"` TaskType string `json:"task_type" form:"task_type" binding:"required"`
Page int `json:"page" form:"page"` Page int `json:"page" form:"page"`
PageSize int `json:"page_size" form:"page_size"` PageSize int `json:"page_size" form:"page_size"`
} UserID int64 `json:"-"`
type PdfInfo struct {
PageCount int `json:"page_count"`
PageWidth int `json:"page_width"`
PageHeight int `json:"page_height"`
} }
type TaskListDTO struct { type TaskListDTO struct {
@@ -26,11 +21,16 @@ type TaskListDTO struct {
Path string `json:"path"` Path string `json:"path"`
TaskType string `json:"task_type"` TaskType string `json:"task_type"`
CreatedAt string `json:"created_at"` CreatedAt string `json:"created_at"`
PdfInfo PdfInfo `json:"pdf_info"` Latex string `json:"latex"`
Markdown string `json:"markdown"`
MathML string `json:"mathml"`
MathMLMW string `json:"mathml_mw"`
ImageBlob string `json:"image_blob"`
DocxURL string `json:"docx_url"`
PDFURL string `json:"pdf_url"`
} }
type TaskListResponse struct { type TaskListResponse struct {
TaskList []*TaskListDTO `json:"task_list"` TaskList []*TaskListDTO `json:"task_list"`
HasMore bool `json:"has_more"` Total int64 `json:"total"`
NextPage int `json:"next_page"`
} }

View File

@@ -105,6 +105,7 @@ func (s *RecognitionService) CreateRecognitionTask(ctx context.Context, req *for
sess := dao.DB.WithContext(ctx) sess := dao.DB.WithContext(ctx)
taskDao := dao.NewRecognitionTaskDao() taskDao := dao.NewRecognitionTaskDao()
task := &dao.RecognitionTask{ task := &dao.RecognitionTask{
UserID: req.UserID,
TaskUUID: utils.NewUUID(), TaskUUID: utils.NewUUID(),
TaskType: dao.TaskType(req.TaskType), TaskType: dao.TaskType(req.TaskType),
Status: dao.TaskStatusPending, Status: dao.TaskStatusPending,
@@ -166,7 +167,8 @@ func (s *RecognitionService) GetFormualTask(ctx context.Context, taskNo string)
return nil, common.NewError(common.CodeDBError, "查询任务结果失败", err) return nil, common.NewError(common.CodeDBError, "查询任务结果失败", err)
} }
latex := taskRet.NewContentCodec().GetContent().(string) latex := taskRet.NewContentCodec().GetContent().(string)
return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Latex: latex, Status: int(task.Status)}, nil markdown := fmt.Sprintf("$$%s$$", latex)
return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Latex: latex, Markdown: markdown, Status: int(task.Status)}, nil
} }
func (s *RecognitionService) handleFormulaRecognition(ctx context.Context, taskID int64) error { func (s *RecognitionService) handleFormulaRecognition(ctx context.Context, taskID int64) error {
@@ -281,7 +283,7 @@ func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int6
headers := map[string]string{"Content-Type": "application/json", utils.RequestIDHeaderKey: utils.GetRequestIDFromContext(ctx)} headers := map[string]string{"Content-Type": "application/json", utils.RequestIDHeaderKey: utils.GetRequestIDFromContext(ctx)}
// 发送请求时会使用带超时的context // 发送请求时会使用带超时的context
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "http://cloud.texpixel.com:1080/formula/predict", bytes.NewReader(jsonData), headers) resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "https://cloud.texpixel.com:10443/formula/predict", bytes.NewReader(jsonData), headers)
if err != nil { if err != nil {
if ctx.Err() == context.DeadlineExceeded { if ctx.Err() == context.DeadlineExceeded {
log.Error(ctx, "func", "processFormulaTask", "msg", "请求超时") log.Error(ctx, "func", "processFormulaTask", "msg", "请求超时")

View File

@@ -3,25 +3,30 @@ package service
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"strings" "strings"
"gitea.com/bitwsd/document_ai/pkg/log"
"gitea.com/bitwsd/document_ai/internal/model/task" "gitea.com/bitwsd/document_ai/internal/model/task"
"gitea.com/bitwsd/document_ai/internal/storage/dao" "gitea.com/bitwsd/document_ai/internal/storage/dao"
"gorm.io/gorm" "gitea.com/bitwsd/document_ai/pkg/log"
) )
type TaskService struct { type TaskService struct {
db *gorm.DB recognitionTaskDao *dao.RecognitionTaskDao
evaluateTaskDao *dao.EvaluateTaskDao
recognitionResultDao *dao.RecognitionResultDao
} }
func NewTaskService() *TaskService { func NewTaskService() *TaskService {
return &TaskService{dao.DB} return &TaskService{
recognitionTaskDao: dao.NewRecognitionTaskDao(),
evaluateTaskDao: dao.NewEvaluateTaskDao(),
recognitionResultDao: dao.NewRecognitionResultDao(),
}
} }
func (svc *TaskService) EvaluateTask(ctx context.Context, req *task.EvaluateTaskRequest) error { func (svc *TaskService) EvaluateTask(ctx context.Context, req *task.EvaluateTaskRequest) error {
taskDao := dao.NewRecognitionTaskDao() task, err := svc.recognitionTaskDao.GetByTaskNo(dao.DB.WithContext(ctx), req.TaskNo)
task, err := taskDao.GetByTaskNo(svc.db.WithContext(ctx), req.TaskNo)
if err != nil { if err != nil {
log.Error(ctx, "func", "EvaluateTask", "msg", "get task by task no failed", "error", err) log.Error(ctx, "func", "EvaluateTask", "msg", "get task by task no failed", "error", err)
return err return err
@@ -36,14 +41,13 @@ func (svc *TaskService) EvaluateTask(ctx context.Context, req *task.EvaluateTask
return errors.New("task not finished") return errors.New("task not finished")
} }
evaluateTaskDao := dao.NewEvaluateTaskDao()
evaluateTask := &dao.EvaluateTask{ evaluateTask := &dao.EvaluateTask{
TaskID: task.ID, TaskID: task.ID,
Satisfied: req.Satisfied, Satisfied: req.Satisfied,
Feedback: req.Feedback, Feedback: req.Feedback,
Comment: strings.Join(req.Suggestion, ","), Comment: strings.Join(req.Suggestion, ","),
} }
err = evaluateTaskDao.Create(svc.db.WithContext(ctx), evaluateTask) err = svc.evaluateTaskDao.Create(dao.DB.WithContext(ctx), evaluateTask)
if err != nil { if err != nil {
log.Error(ctx, "func", "EvaluateTask", "msg", "create evaluate task failed", "error", err) log.Error(ctx, "func", "EvaluateTask", "msg", "create evaluate task failed", "error", err)
return err return err
@@ -53,19 +57,43 @@ func (svc *TaskService) EvaluateTask(ctx context.Context, req *task.EvaluateTask
} }
func (svc *TaskService) GetTaskList(ctx context.Context, req *task.TaskListRequest) (*task.TaskListResponse, error) { func (svc *TaskService) GetTaskList(ctx context.Context, req *task.TaskListRequest) (*task.TaskListResponse, error) {
taskDao := dao.NewRecognitionTaskDao() tasks, total, err := svc.recognitionTaskDao.GetTaskList(dao.DB.WithContext(ctx), req.UserID, dao.TaskType(req.TaskType), req.Page, req.PageSize)
tasks, err := taskDao.GetTaskList(svc.db.WithContext(ctx), dao.TaskType(req.TaskType), req.Page, req.PageSize)
if err != nil { if err != nil {
log.Error(ctx, "func", "GetTaskList", "msg", "get task list failed", "error", err) log.Error(ctx, "func", "GetTaskList", "msg", "get task list failed", "error", err)
return nil, err return nil, err
} }
taskIDs := make([]int64, 0, len(tasks))
for _, item := range tasks {
taskIDs = append(taskIDs, item.ID)
}
recognitionResults, err := svc.recognitionResultDao.GetByTaskIDs(dao.DB.WithContext(ctx), taskIDs)
if err != nil {
log.Error(ctx, "func", "GetTaskList", "msg", "get recognition results failed", "error", err)
return nil, err
}
recognitionResultMap := make(map[int64]*dao.RecognitionResult)
for _, item := range recognitionResults {
recognitionResultMap[item.TaskID] = item
}
resp := &task.TaskListResponse{ resp := &task.TaskListResponse{
TaskList: make([]*task.TaskListDTO, 0, len(tasks)), TaskList: make([]*task.TaskListDTO, 0, len(tasks)),
HasMore: false, Total: total,
NextPage: 0,
} }
for _, item := range tasks { for _, item := range tasks {
var latex string
var markdown string
recognitionResult := recognitionResultMap[item.ID]
if recognitionResult != nil {
latex = recognitionResult.NewContentCodec().GetContent().(string)
markdown = fmt.Sprintf("$$%s$$", latex)
}
resp.TaskList = append(resp.TaskList, &task.TaskListDTO{ resp.TaskList = append(resp.TaskList, &task.TaskListDTO{
Latex: latex,
Markdown: markdown,
TaskID: item.TaskUUID, TaskID: item.TaskUUID,
FileName: item.FileName, FileName: item.FileName,
Status: item.Status.String(), Status: item.Status.String(),

View File

@@ -84,6 +84,11 @@ func (dao *RecognitionResultDao) GetByTaskID(tx *gorm.DB, taskID int64) (result
return return
} }
func (dao *RecognitionResultDao) GetByTaskIDs(tx *gorm.DB, taskIDs []int64) (results []*RecognitionResult, err error) {
err = tx.Where("task_id IN (?)", taskIDs).Find(&results).Error
return
}
func (dao *RecognitionResultDao) Update(tx *gorm.DB, id int64, updates map[string]interface{}) error { func (dao *RecognitionResultDao) Update(tx *gorm.DB, id int64, updates map[string]interface{}) error {
return tx.Model(&RecognitionResult{}).Where("id = ?", id).Updates(updates).Error return tx.Model(&RecognitionResult{}).Where("id = ?", id).Updates(updates).Error
} }

View File

@@ -69,9 +69,9 @@ func (dao *RecognitionTaskDao) GetByTaskNo(tx *gorm.DB, taskUUID string) (task *
return return
} }
func (dao *RecognitionTaskDao) GetTaskByFileURL(tx *gorm.DB, userID int64, fileHash string) (task *RecognitionTask, err error) { func (dao *RecognitionTaskDao) GetTaskByFileURL(tx *gorm.DB, fileHash string) (task *RecognitionTask, err error) {
task = &RecognitionTask{} task = &RecognitionTask{}
err = tx.Model(RecognitionTask{}).Where("user_id = ? AND file_hash = ?", userID, fileHash).First(task).Error err = tx.Model(RecognitionTask{}).Where("file_hash = ?", fileHash).Last(task).Error
return return
} }
@@ -87,8 +87,13 @@ func (dao *RecognitionTaskDao) GetTaskByID(tx *gorm.DB, id int64) (task *Recogni
return task, nil return task, nil
} }
func (dao *RecognitionTaskDao) GetTaskList(tx *gorm.DB, taskType TaskType, page int, pageSize int) (tasks []*RecognitionTask, err error) { func (dao *RecognitionTaskDao) GetTaskList(tx *gorm.DB, userID int64, taskType TaskType, page int, pageSize int) (tasks []*RecognitionTask, total int64, err error) {
offset := (page - 1) * pageSize offset := (page - 1) * pageSize
err = tx.Model(RecognitionTask{}).Where("task_type = ?", taskType).Offset(offset).Limit(pageSize).Order(clause.OrderByColumn{Column: clause.Column{Name: "id"}, Desc: true}).Find(&tasks).Error query := tx.Model(RecognitionTask{}).Where("user_id = ? AND task_type = ?", userID, taskType)
return err = query.Count(&total).Error
if err != nil {
return nil, 0, err
}
err = query.Offset(offset).Limit(pageSize).Order(clause.OrderByColumn{Column: clause.Column{Name: "id"}, Desc: true}).Find(&tasks).Error
return tasks, total, err
} }

View File

@@ -45,6 +45,19 @@ func AuthMiddleware(ctx *gin.Context) {
ctx.Set(constant.ContextUserID, claims.UserId) ctx.Set(constant.ContextUserID, claims.UserId)
} }
func MustAuthMiddleware() gin.HandlerFunc {
return func(ctx *gin.Context) {
token := ctx.GetHeader("Authorization")
if token != "" {
token = strings.TrimPrefix(token, "Bearer ")
claims, err := jwt.ParseToken(token)
if err == nil {
ctx.Set(constant.ContextUserID, claims.UserId)
}
}
}
}
func GetAuthMiddleware() gin.HandlerFunc { func GetAuthMiddleware() gin.HandlerFunc {
return func(ctx *gin.Context) { return func(ctx *gin.Context) {
token := ctx.GetHeader("Authorization") token := ctx.GetHeader("Authorization")

View File

@@ -19,9 +19,9 @@ type Config struct {
func DefaultConfig() Config { func DefaultConfig() Config {
return Config{ return Config{
AllowOrigins: []string{"*"}, AllowOrigins: []string{"*"},
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"},
AllowHeaders: []string{"Origin", "Content-Type", "Accept"}, AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Requested-With"},
ExposeHeaders: []string{"Content-Length"}, ExposeHeaders: []string{"Content-Length", "Content-Type"},
AllowCredentials: true, AllowCredentials: true,
MaxAge: 86400, // 24 hours MaxAge: 86400, // 24 hours
} }
@@ -30,16 +30,30 @@ func DefaultConfig() Config {
func Cors(config Config) gin.HandlerFunc { func Cors(config Config) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
origin := c.Request.Header.Get("Origin") origin := c.Request.Header.Get("Origin")
if origin == "" {
c.Next()
return
}
// 检查是否允许该来源 // 检查是否允许该来源
allowOrigin := "*" allowOrigin := ""
for _, o := range config.AllowOrigins { for _, o := range config.AllowOrigins {
if o == "*" {
// 通配符时,回显实际 origin兼容 credentials
allowOrigin = origin
break
}
if o == origin { if o == origin {
allowOrigin = origin allowOrigin = origin
break break
} }
} }
if allowOrigin == "" {
c.Next()
return
}
c.Header("Access-Control-Allow-Origin", allowOrigin) c.Header("Access-Control-Allow-Origin", allowOrigin)
c.Header("Access-Control-Allow-Methods", strings.Join(config.AllowMethods, ",")) c.Header("Access-Control-Allow-Methods", strings.Join(config.AllowMethods, ","))
c.Header("Access-Control-Allow-Headers", strings.Join(config.AllowHeaders, ",")) c.Header("Access-Control-Allow-Headers", strings.Join(config.AllowHeaders, ","))