feat: add list api
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
|||||||
"gitea.com/bitwsd/document_ai/internal/service"
|
"gitea.com/bitwsd/document_ai/internal/service"
|
||||||
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
||||||
"gitea.com/bitwsd/document_ai/pkg/common"
|
"gitea.com/bitwsd/document_ai/pkg/common"
|
||||||
|
"gitea.com/bitwsd/document_ai/pkg/constant"
|
||||||
"gitea.com/bitwsd/document_ai/pkg/utils"
|
"gitea.com/bitwsd/document_ai/pkg/utils"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -37,11 +38,14 @@ func NewFormulaEndpoint() *FormulaEndpoint {
|
|||||||
// @Router /v1/formula/recognition [post]
|
// @Router /v1/formula/recognition [post]
|
||||||
func (endpoint *FormulaEndpoint) CreateTask(ctx *gin.Context) {
|
func (endpoint *FormulaEndpoint) CreateTask(ctx *gin.Context) {
|
||||||
var req formula.CreateFormulaRecognitionRequest
|
var req formula.CreateFormulaRecognitionRequest
|
||||||
|
uid := ctx.GetInt64(constant.ContextUserID)
|
||||||
if err := ctx.BindJSON(&req); err != nil {
|
if err := ctx.BindJSON(&req); err != nil {
|
||||||
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeParamError, "Invalid parameters"))
|
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeParamError, "Invalid parameters"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
req.UserID = uid
|
||||||
|
|
||||||
if !utils.InArray(req.TaskType, []string{string(dao.TaskTypeFormula), string(dao.TaskTypeFormula)}) {
|
if !utils.InArray(req.TaskType, []string{string(dao.TaskTypeFormula), string(dao.TaskTypeFormula)}) {
|
||||||
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeParamError, "Invalid task type"))
|
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeParamError, "Invalid task type"))
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,12 +1,16 @@
|
|||||||
package formula
|
package formula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"gitea.com/bitwsd/document_ai/pkg/common"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetupRouter(engine *gin.RouterGroup) {
|
func SetupRouter(engine *gin.RouterGroup) {
|
||||||
endpoint := NewFormulaEndpoint()
|
endpoint := NewFormulaEndpoint()
|
||||||
engine.POST("/formula/recognition", endpoint.CreateTask)
|
formulaRouter := engine.Group("/formula", common.GetAuthMiddleware())
|
||||||
engine.POST("/formula/ai_enhance", endpoint.AIEnhanceRecognition)
|
{
|
||||||
engine.GET("/formula/recognition/:task_no", endpoint.GetTaskStatus)
|
formulaRouter.POST("/recognition", endpoint.CreateTask)
|
||||||
|
formulaRouter.POST("/ai_enhance", endpoint.AIEnhanceRecognition)
|
||||||
|
formulaRouter.GET("/recognition/:task_no", endpoint.GetTaskStatus)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"gitea.com/bitwsd/document_ai/config"
|
"gitea.com/bitwsd/document_ai/config"
|
||||||
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
||||||
"gitea.com/bitwsd/document_ai/pkg/common"
|
"gitea.com/bitwsd/document_ai/pkg/common"
|
||||||
"gitea.com/bitwsd/document_ai/pkg/constant"
|
|
||||||
"gitea.com/bitwsd/document_ai/pkg/oss"
|
"gitea.com/bitwsd/document_ai/pkg/oss"
|
||||||
"gitea.com/bitwsd/document_ai/pkg/utils"
|
"gitea.com/bitwsd/document_ai/pkg/utils"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -38,7 +37,6 @@ func GetPostObjectSignature(ctx *gin.Context) {
|
|||||||
// @Failure 200 {object} common.Response "Error response"
|
// @Failure 200 {object} common.Response "Error response"
|
||||||
// @Router /signature_url [get]
|
// @Router /signature_url [get]
|
||||||
func GetSignatureURL(ctx *gin.Context) {
|
func GetSignatureURL(ctx *gin.Context) {
|
||||||
userID := ctx.GetInt64(constant.ContextUserID)
|
|
||||||
type Req struct {
|
type Req struct {
|
||||||
FileHash string `json:"file_hash" binding:"required"`
|
FileHash string `json:"file_hash" binding:"required"`
|
||||||
FileName string `json:"file_name" binding:"required"`
|
FileName string `json:"file_name" binding:"required"`
|
||||||
@@ -51,7 +49,7 @@ func GetSignatureURL(ctx *gin.Context) {
|
|||||||
}
|
}
|
||||||
taskDao := dao.NewRecognitionTaskDao()
|
taskDao := dao.NewRecognitionTaskDao()
|
||||||
sess := dao.DB.WithContext(ctx)
|
sess := dao.DB.WithContext(ctx)
|
||||||
task, err := taskDao.GetTaskByFileURL(sess, userID, req.FileHash)
|
task, err := taskDao.GetTaskByFileURL(sess, req.FileHash)
|
||||||
if err != nil && err != gorm.ErrRecordNotFound {
|
if err != nil && err != gorm.ErrRecordNotFound {
|
||||||
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeDBError, "failed to get task"))
|
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeDBError, "failed to get task"))
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -3,10 +3,10 @@ package task
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"gitea.com/bitwsd/document_ai/pkg/log"
|
|
||||||
"gitea.com/bitwsd/document_ai/internal/model/task"
|
"gitea.com/bitwsd/document_ai/internal/model/task"
|
||||||
"gitea.com/bitwsd/document_ai/internal/service"
|
"gitea.com/bitwsd/document_ai/internal/service"
|
||||||
"gitea.com/bitwsd/document_ai/pkg/common"
|
"gitea.com/bitwsd/document_ai/pkg/common"
|
||||||
|
"gitea.com/bitwsd/document_ai/pkg/log"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -43,6 +43,8 @@ func (h *TaskEndpoint) GetTaskList(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
req.UserID = common.GetUserIDFromContext(c)
|
||||||
|
|
||||||
if req.Page <= 0 {
|
if req.Page <= 0 {
|
||||||
req.Page = 1
|
req.Page = 1
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
package task
|
package task
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"gitea.com/bitwsd/document_ai/pkg/common"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetupRouter(engine *gin.RouterGroup) {
|
func SetupRouter(engine *gin.RouterGroup) {
|
||||||
endpoint := NewTaskEndpoint()
|
endpoint := NewTaskEndpoint()
|
||||||
engine.POST("/task/evaluate", endpoint.EvaluateTask)
|
engine.POST("/task/evaluate", endpoint.EvaluateTask)
|
||||||
engine.GET("/task/list", endpoint.GetTaskList)
|
engine.GET("/task/list", common.MustAuthMiddleware(), endpoint.GetTaskList)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,10 +9,9 @@ func SetupRouter(router *gin.RouterGroup) {
|
|||||||
userEndpoint := NewUserEndpoint()
|
userEndpoint := NewUserEndpoint()
|
||||||
userRouter := router.Group("/user")
|
userRouter := router.Group("/user")
|
||||||
{
|
{
|
||||||
userRouter.POST("/get/sms", userEndpoint.SendVerificationCode)
|
userRouter.POST("/sms", userEndpoint.SendVerificationCode)
|
||||||
userRouter.POST("/login/phone", userEndpoint.LoginByPhoneCode)
|
userRouter.POST("/register", userEndpoint.RegisterByEmail)
|
||||||
userRouter.POST("/register/email", userEndpoint.RegisterByEmail)
|
userRouter.POST("/login", userEndpoint.LoginByEmail)
|
||||||
userRouter.POST("/login/email", userEndpoint.LoginByEmail)
|
userRouter.GET("/info", common.MustAuthMiddleware(), userEndpoint.GetUserInfo)
|
||||||
userRouter.GET("/info", common.GetAuthMiddleware(), userEndpoint.GetUserInfo)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ server:
|
|||||||
database:
|
database:
|
||||||
driver: mysql
|
driver: mysql
|
||||||
host: mysql
|
host: mysql
|
||||||
port: 3306 # 容器内部端口,不是宿主机映射的 3006
|
port: 3306
|
||||||
username: root
|
username: root
|
||||||
password: texpixel#pwd123!
|
password: texpixel#pwd123!
|
||||||
dbname: doc_ai
|
dbname: doc_ai
|
||||||
@@ -13,7 +13,7 @@ database:
|
|||||||
max_open: 100
|
max_open: 100
|
||||||
|
|
||||||
redis:
|
redis:
|
||||||
addr: redis:6379 # 容器内部端口,不是宿主机映射的 6079
|
addr: redis:6379
|
||||||
password: yoge@123321!
|
password: yoge@123321!
|
||||||
db: 0
|
db: 0
|
||||||
|
|
||||||
@@ -22,7 +22,7 @@ limit:
|
|||||||
|
|
||||||
log:
|
log:
|
||||||
appName: document_ai
|
appName: document_ai
|
||||||
level: info # debug, info, warn, error
|
level: info
|
||||||
format: console # json, console
|
format: console # json, console
|
||||||
outputPath: ./logs/app.log # 日志文件路径
|
outputPath: ./logs/app.log # 日志文件路径
|
||||||
maxSize: 2 # 单个日志文件最大尺寸,单位MB
|
maxSize: 2 # 单个日志文件最大尺寸,单位MB
|
||||||
@@ -41,6 +41,6 @@ aliyun:
|
|||||||
oss:
|
oss:
|
||||||
endpoint: oss-cn-beijing.aliyuncs.com
|
endpoint: oss-cn-beijing.aliyuncs.com
|
||||||
inner_endpoint: oss-cn-beijing-internal.aliyuncs.com
|
inner_endpoint: oss-cn-beijing-internal.aliyuncs.com
|
||||||
access_key_id: LTAI5tKogxeiBb4gJGWEePWN
|
access_key_id: LTAI5t8qXhow6NCdYDtu1saF
|
||||||
access_key_secret: l4oCxtt5iLSQ1DAs40guTzKUfrxXwq
|
access_key_secret: qZ2SwYsNCEBckCVSOszH31yYwXU44A
|
||||||
bucket_name: bitwsd-doc-ai
|
bucket_name: texpixel-doc
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ type CreateFormulaRecognitionRequest struct {
|
|||||||
FileHash string `json:"file_hash" binding:"required"` // file hash
|
FileHash string `json:"file_hash" binding:"required"` // file hash
|
||||||
FileName string `json:"file_name" binding:"required"` // file name
|
FileName string `json:"file_name" binding:"required"` // file name
|
||||||
TaskType string `json:"task_type" binding:"required,oneof=FORMULA"` // task type
|
TaskType string `json:"task_type" binding:"required,oneof=FORMULA"` // task type
|
||||||
|
UserID int64 `json:"user_id"` // user id
|
||||||
}
|
}
|
||||||
|
|
||||||
type GetRecognitionStatusRequest struct {
|
type GetRecognitionStatusRequest struct {
|
||||||
|
|||||||
@@ -11,26 +11,26 @@ type TaskListRequest struct {
|
|||||||
TaskType string `json:"task_type" form:"task_type" binding:"required"`
|
TaskType string `json:"task_type" form:"task_type" binding:"required"`
|
||||||
Page int `json:"page" form:"page"`
|
Page int `json:"page" form:"page"`
|
||||||
PageSize int `json:"page_size" form:"page_size"`
|
PageSize int `json:"page_size" form:"page_size"`
|
||||||
}
|
UserID int64 `json:"-"`
|
||||||
|
|
||||||
type PdfInfo struct {
|
|
||||||
PageCount int `json:"page_count"`
|
|
||||||
PageWidth int `json:"page_width"`
|
|
||||||
PageHeight int `json:"page_height"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type TaskListDTO struct {
|
type TaskListDTO struct {
|
||||||
TaskID string `json:"task_id"`
|
TaskID string `json:"task_id"`
|
||||||
FileName string `json:"file_name"`
|
FileName string `json:"file_name"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Path string `json:"path"`
|
Path string `json:"path"`
|
||||||
TaskType string `json:"task_type"`
|
TaskType string `json:"task_type"`
|
||||||
CreatedAt string `json:"created_at"`
|
CreatedAt string `json:"created_at"`
|
||||||
PdfInfo PdfInfo `json:"pdf_info"`
|
Latex string `json:"latex"`
|
||||||
|
Markdown string `json:"markdown"`
|
||||||
|
MathML string `json:"mathml"`
|
||||||
|
MathMLMW string `json:"mathml_mw"`
|
||||||
|
ImageBlob string `json:"image_blob"`
|
||||||
|
DocxURL string `json:"docx_url"`
|
||||||
|
PDFURL string `json:"pdf_url"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type TaskListResponse struct {
|
type TaskListResponse struct {
|
||||||
TaskList []*TaskListDTO `json:"task_list"`
|
TaskList []*TaskListDTO `json:"task_list"`
|
||||||
HasMore bool `json:"has_more"`
|
Total int64 `json:"total"`
|
||||||
NextPage int `json:"next_page"`
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -105,6 +105,7 @@ func (s *RecognitionService) CreateRecognitionTask(ctx context.Context, req *for
|
|||||||
sess := dao.DB.WithContext(ctx)
|
sess := dao.DB.WithContext(ctx)
|
||||||
taskDao := dao.NewRecognitionTaskDao()
|
taskDao := dao.NewRecognitionTaskDao()
|
||||||
task := &dao.RecognitionTask{
|
task := &dao.RecognitionTask{
|
||||||
|
UserID: req.UserID,
|
||||||
TaskUUID: utils.NewUUID(),
|
TaskUUID: utils.NewUUID(),
|
||||||
TaskType: dao.TaskType(req.TaskType),
|
TaskType: dao.TaskType(req.TaskType),
|
||||||
Status: dao.TaskStatusPending,
|
Status: dao.TaskStatusPending,
|
||||||
@@ -166,7 +167,8 @@ func (s *RecognitionService) GetFormualTask(ctx context.Context, taskNo string)
|
|||||||
return nil, common.NewError(common.CodeDBError, "查询任务结果失败", err)
|
return nil, common.NewError(common.CodeDBError, "查询任务结果失败", err)
|
||||||
}
|
}
|
||||||
latex := taskRet.NewContentCodec().GetContent().(string)
|
latex := taskRet.NewContentCodec().GetContent().(string)
|
||||||
return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Latex: latex, Status: int(task.Status)}, nil
|
markdown := fmt.Sprintf("$$%s$$", latex)
|
||||||
|
return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Latex: latex, Markdown: markdown, Status: int(task.Status)}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *RecognitionService) handleFormulaRecognition(ctx context.Context, taskID int64) error {
|
func (s *RecognitionService) handleFormulaRecognition(ctx context.Context, taskID int64) error {
|
||||||
@@ -281,7 +283,7 @@ func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int6
|
|||||||
headers := map[string]string{"Content-Type": "application/json", utils.RequestIDHeaderKey: utils.GetRequestIDFromContext(ctx)}
|
headers := map[string]string{"Content-Type": "application/json", utils.RequestIDHeaderKey: utils.GetRequestIDFromContext(ctx)}
|
||||||
|
|
||||||
// 发送请求时会使用带超时的context
|
// 发送请求时会使用带超时的context
|
||||||
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "http://cloud.texpixel.com:1080/formula/predict", bytes.NewReader(jsonData), headers)
|
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "https://cloud.texpixel.com:10443/formula/predict", bytes.NewReader(jsonData), headers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if ctx.Err() == context.DeadlineExceeded {
|
if ctx.Err() == context.DeadlineExceeded {
|
||||||
log.Error(ctx, "func", "processFormulaTask", "msg", "请求超时")
|
log.Error(ctx, "func", "processFormulaTask", "msg", "请求超时")
|
||||||
|
|||||||
@@ -3,25 +3,30 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gitea.com/bitwsd/document_ai/pkg/log"
|
|
||||||
"gitea.com/bitwsd/document_ai/internal/model/task"
|
"gitea.com/bitwsd/document_ai/internal/model/task"
|
||||||
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
||||||
"gorm.io/gorm"
|
"gitea.com/bitwsd/document_ai/pkg/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TaskService struct {
|
type TaskService struct {
|
||||||
db *gorm.DB
|
recognitionTaskDao *dao.RecognitionTaskDao
|
||||||
|
evaluateTaskDao *dao.EvaluateTaskDao
|
||||||
|
recognitionResultDao *dao.RecognitionResultDao
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTaskService() *TaskService {
|
func NewTaskService() *TaskService {
|
||||||
return &TaskService{dao.DB}
|
return &TaskService{
|
||||||
|
recognitionTaskDao: dao.NewRecognitionTaskDao(),
|
||||||
|
evaluateTaskDao: dao.NewEvaluateTaskDao(),
|
||||||
|
recognitionResultDao: dao.NewRecognitionResultDao(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (svc *TaskService) EvaluateTask(ctx context.Context, req *task.EvaluateTaskRequest) error {
|
func (svc *TaskService) EvaluateTask(ctx context.Context, req *task.EvaluateTaskRequest) error {
|
||||||
taskDao := dao.NewRecognitionTaskDao()
|
task, err := svc.recognitionTaskDao.GetByTaskNo(dao.DB.WithContext(ctx), req.TaskNo)
|
||||||
task, err := taskDao.GetByTaskNo(svc.db.WithContext(ctx), req.TaskNo)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(ctx, "func", "EvaluateTask", "msg", "get task by task no failed", "error", err)
|
log.Error(ctx, "func", "EvaluateTask", "msg", "get task by task no failed", "error", err)
|
||||||
return err
|
return err
|
||||||
@@ -36,14 +41,13 @@ func (svc *TaskService) EvaluateTask(ctx context.Context, req *task.EvaluateTask
|
|||||||
return errors.New("task not finished")
|
return errors.New("task not finished")
|
||||||
}
|
}
|
||||||
|
|
||||||
evaluateTaskDao := dao.NewEvaluateTaskDao()
|
|
||||||
evaluateTask := &dao.EvaluateTask{
|
evaluateTask := &dao.EvaluateTask{
|
||||||
TaskID: task.ID,
|
TaskID: task.ID,
|
||||||
Satisfied: req.Satisfied,
|
Satisfied: req.Satisfied,
|
||||||
Feedback: req.Feedback,
|
Feedback: req.Feedback,
|
||||||
Comment: strings.Join(req.Suggestion, ","),
|
Comment: strings.Join(req.Suggestion, ","),
|
||||||
}
|
}
|
||||||
err = evaluateTaskDao.Create(svc.db.WithContext(ctx), evaluateTask)
|
err = svc.evaluateTaskDao.Create(dao.DB.WithContext(ctx), evaluateTask)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(ctx, "func", "EvaluateTask", "msg", "create evaluate task failed", "error", err)
|
log.Error(ctx, "func", "EvaluateTask", "msg", "create evaluate task failed", "error", err)
|
||||||
return err
|
return err
|
||||||
@@ -53,19 +57,43 @@ func (svc *TaskService) EvaluateTask(ctx context.Context, req *task.EvaluateTask
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (svc *TaskService) GetTaskList(ctx context.Context, req *task.TaskListRequest) (*task.TaskListResponse, error) {
|
func (svc *TaskService) GetTaskList(ctx context.Context, req *task.TaskListRequest) (*task.TaskListResponse, error) {
|
||||||
taskDao := dao.NewRecognitionTaskDao()
|
tasks, total, err := svc.recognitionTaskDao.GetTaskList(dao.DB.WithContext(ctx), req.UserID, dao.TaskType(req.TaskType), req.Page, req.PageSize)
|
||||||
tasks, err := taskDao.GetTaskList(svc.db.WithContext(ctx), dao.TaskType(req.TaskType), req.Page, req.PageSize)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(ctx, "func", "GetTaskList", "msg", "get task list failed", "error", err)
|
log.Error(ctx, "func", "GetTaskList", "msg", "get task list failed", "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
taskIDs := make([]int64, 0, len(tasks))
|
||||||
|
for _, item := range tasks {
|
||||||
|
taskIDs = append(taskIDs, item.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
recognitionResults, err := svc.recognitionResultDao.GetByTaskIDs(dao.DB.WithContext(ctx), taskIDs)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "GetTaskList", "msg", "get recognition results failed", "error", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
recognitionResultMap := make(map[int64]*dao.RecognitionResult)
|
||||||
|
for _, item := range recognitionResults {
|
||||||
|
recognitionResultMap[item.TaskID] = item
|
||||||
|
}
|
||||||
|
|
||||||
resp := &task.TaskListResponse{
|
resp := &task.TaskListResponse{
|
||||||
TaskList: make([]*task.TaskListDTO, 0, len(tasks)),
|
TaskList: make([]*task.TaskListDTO, 0, len(tasks)),
|
||||||
HasMore: false,
|
Total: total,
|
||||||
NextPage: 0,
|
|
||||||
}
|
}
|
||||||
for _, item := range tasks {
|
for _, item := range tasks {
|
||||||
|
var latex string
|
||||||
|
var markdown string
|
||||||
|
recognitionResult := recognitionResultMap[item.ID]
|
||||||
|
if recognitionResult != nil {
|
||||||
|
latex = recognitionResult.NewContentCodec().GetContent().(string)
|
||||||
|
markdown = fmt.Sprintf("$$%s$$", latex)
|
||||||
|
}
|
||||||
resp.TaskList = append(resp.TaskList, &task.TaskListDTO{
|
resp.TaskList = append(resp.TaskList, &task.TaskListDTO{
|
||||||
|
Latex: latex,
|
||||||
|
Markdown: markdown,
|
||||||
TaskID: item.TaskUUID,
|
TaskID: item.TaskUUID,
|
||||||
FileName: item.FileName,
|
FileName: item.FileName,
|
||||||
Status: item.Status.String(),
|
Status: item.Status.String(),
|
||||||
|
|||||||
@@ -84,6 +84,11 @@ func (dao *RecognitionResultDao) GetByTaskID(tx *gorm.DB, taskID int64) (result
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (dao *RecognitionResultDao) GetByTaskIDs(tx *gorm.DB, taskIDs []int64) (results []*RecognitionResult, err error) {
|
||||||
|
err = tx.Where("task_id IN (?)", taskIDs).Find(&results).Error
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (dao *RecognitionResultDao) Update(tx *gorm.DB, id int64, updates map[string]interface{}) error {
|
func (dao *RecognitionResultDao) Update(tx *gorm.DB, id int64, updates map[string]interface{}) error {
|
||||||
return tx.Model(&RecognitionResult{}).Where("id = ?", id).Updates(updates).Error
|
return tx.Model(&RecognitionResult{}).Where("id = ?", id).Updates(updates).Error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -69,9 +69,9 @@ func (dao *RecognitionTaskDao) GetByTaskNo(tx *gorm.DB, taskUUID string) (task *
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dao *RecognitionTaskDao) GetTaskByFileURL(tx *gorm.DB, userID int64, fileHash string) (task *RecognitionTask, err error) {
|
func (dao *RecognitionTaskDao) GetTaskByFileURL(tx *gorm.DB, fileHash string) (task *RecognitionTask, err error) {
|
||||||
task = &RecognitionTask{}
|
task = &RecognitionTask{}
|
||||||
err = tx.Model(RecognitionTask{}).Where("user_id = ? AND file_hash = ?", userID, fileHash).First(task).Error
|
err = tx.Model(RecognitionTask{}).Where("file_hash = ?", fileHash).Last(task).Error
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,8 +87,13 @@ func (dao *RecognitionTaskDao) GetTaskByID(tx *gorm.DB, id int64) (task *Recogni
|
|||||||
return task, nil
|
return task, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dao *RecognitionTaskDao) GetTaskList(tx *gorm.DB, taskType TaskType, page int, pageSize int) (tasks []*RecognitionTask, err error) {
|
func (dao *RecognitionTaskDao) GetTaskList(tx *gorm.DB, userID int64, taskType TaskType, page int, pageSize int) (tasks []*RecognitionTask, total int64, err error) {
|
||||||
offset := (page - 1) * pageSize
|
offset := (page - 1) * pageSize
|
||||||
err = tx.Model(RecognitionTask{}).Where("task_type = ?", taskType).Offset(offset).Limit(pageSize).Order(clause.OrderByColumn{Column: clause.Column{Name: "id"}, Desc: true}).Find(&tasks).Error
|
query := tx.Model(RecognitionTask{}).Where("user_id = ? AND task_type = ?", userID, taskType)
|
||||||
return
|
err = query.Count(&total).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
err = query.Offset(offset).Limit(pageSize).Order(clause.OrderByColumn{Column: clause.Column{Name: "id"}, Desc: true}).Find(&tasks).Error
|
||||||
|
return tasks, total, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,6 +45,19 @@ func AuthMiddleware(ctx *gin.Context) {
|
|||||||
ctx.Set(constant.ContextUserID, claims.UserId)
|
ctx.Set(constant.ContextUserID, claims.UserId)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MustAuthMiddleware() gin.HandlerFunc {
|
||||||
|
return func(ctx *gin.Context) {
|
||||||
|
token := ctx.GetHeader("Authorization")
|
||||||
|
if token != "" {
|
||||||
|
token = strings.TrimPrefix(token, "Bearer ")
|
||||||
|
claims, err := jwt.ParseToken(token)
|
||||||
|
if err == nil {
|
||||||
|
ctx.Set(constant.ContextUserID, claims.UserId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func GetAuthMiddleware() gin.HandlerFunc {
|
func GetAuthMiddleware() gin.HandlerFunc {
|
||||||
return func(ctx *gin.Context) {
|
return func(ctx *gin.Context) {
|
||||||
token := ctx.GetHeader("Authorization")
|
token := ctx.GetHeader("Authorization")
|
||||||
|
|||||||
@@ -19,9 +19,9 @@ type Config struct {
|
|||||||
func DefaultConfig() Config {
|
func DefaultConfig() Config {
|
||||||
return Config{
|
return Config{
|
||||||
AllowOrigins: []string{"*"},
|
AllowOrigins: []string{"*"},
|
||||||
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"},
|
||||||
AllowHeaders: []string{"Origin", "Content-Type", "Accept"},
|
AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Requested-With"},
|
||||||
ExposeHeaders: []string{"Content-Length"},
|
ExposeHeaders: []string{"Content-Length", "Content-Type"},
|
||||||
AllowCredentials: true,
|
AllowCredentials: true,
|
||||||
MaxAge: 86400, // 24 hours
|
MaxAge: 86400, // 24 hours
|
||||||
}
|
}
|
||||||
@@ -30,16 +30,30 @@ func DefaultConfig() Config {
|
|||||||
func Cors(config Config) gin.HandlerFunc {
|
func Cors(config Config) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
origin := c.Request.Header.Get("Origin")
|
origin := c.Request.Header.Get("Origin")
|
||||||
|
if origin == "" {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 检查是否允许该来源
|
// 检查是否允许该来源
|
||||||
allowOrigin := "*"
|
allowOrigin := ""
|
||||||
for _, o := range config.AllowOrigins {
|
for _, o := range config.AllowOrigins {
|
||||||
|
if o == "*" {
|
||||||
|
// 通配符时,回显实际 origin(兼容 credentials)
|
||||||
|
allowOrigin = origin
|
||||||
|
break
|
||||||
|
}
|
||||||
if o == origin {
|
if o == origin {
|
||||||
allowOrigin = origin
|
allowOrigin = origin
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if allowOrigin == "" {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
c.Header("Access-Control-Allow-Origin", allowOrigin)
|
c.Header("Access-Control-Allow-Origin", allowOrigin)
|
||||||
c.Header("Access-Control-Allow-Methods", strings.Join(config.AllowMethods, ","))
|
c.Header("Access-Control-Allow-Methods", strings.Join(config.AllowMethods, ","))
|
||||||
c.Header("Access-Control-Allow-Headers", strings.Join(config.AllowHeaders, ","))
|
c.Header("Access-Control-Allow-Headers", strings.Join(config.AllowHeaders, ","))
|
||||||
|
|||||||
Reference in New Issue
Block a user