feat: add list api

This commit is contained in:
2025-12-18 12:39:50 +08:00
parent d06f2d9df1
commit 8a6da5b627
15 changed files with 133 additions and 57 deletions

View File

@@ -105,6 +105,7 @@ func (s *RecognitionService) CreateRecognitionTask(ctx context.Context, req *for
sess := dao.DB.WithContext(ctx)
taskDao := dao.NewRecognitionTaskDao()
task := &dao.RecognitionTask{
UserID: req.UserID,
TaskUUID: utils.NewUUID(),
TaskType: dao.TaskType(req.TaskType),
Status: dao.TaskStatusPending,
@@ -166,7 +167,8 @@ func (s *RecognitionService) GetFormualTask(ctx context.Context, taskNo string)
return nil, common.NewError(common.CodeDBError, "查询任务结果失败", err)
}
latex := taskRet.NewContentCodec().GetContent().(string)
return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Latex: latex, Status: int(task.Status)}, nil
markdown := fmt.Sprintf("$$%s$$", latex)
return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Latex: latex, Markdown: markdown, Status: int(task.Status)}, nil
}
func (s *RecognitionService) handleFormulaRecognition(ctx context.Context, taskID int64) error {
@@ -281,7 +283,7 @@ func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int6
headers := map[string]string{"Content-Type": "application/json", utils.RequestIDHeaderKey: utils.GetRequestIDFromContext(ctx)}
// 发送请求时会使用带超时的context
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "http://cloud.texpixel.com:1080/formula/predict", bytes.NewReader(jsonData), headers)
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "https://cloud.texpixel.com:10443/formula/predict", bytes.NewReader(jsonData), headers)
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
log.Error(ctx, "func", "processFormulaTask", "msg", "请求超时")

View File

@@ -3,25 +3,30 @@ package service
import (
"context"
"errors"
"fmt"
"strings"
"gitea.com/bitwsd/document_ai/pkg/log"
"gitea.com/bitwsd/document_ai/internal/model/task"
"gitea.com/bitwsd/document_ai/internal/storage/dao"
"gorm.io/gorm"
"gitea.com/bitwsd/document_ai/pkg/log"
)
type TaskService struct {
db *gorm.DB
recognitionTaskDao *dao.RecognitionTaskDao
evaluateTaskDao *dao.EvaluateTaskDao
recognitionResultDao *dao.RecognitionResultDao
}
func NewTaskService() *TaskService {
return &TaskService{dao.DB}
return &TaskService{
recognitionTaskDao: dao.NewRecognitionTaskDao(),
evaluateTaskDao: dao.NewEvaluateTaskDao(),
recognitionResultDao: dao.NewRecognitionResultDao(),
}
}
func (svc *TaskService) EvaluateTask(ctx context.Context, req *task.EvaluateTaskRequest) error {
taskDao := dao.NewRecognitionTaskDao()
task, err := taskDao.GetByTaskNo(svc.db.WithContext(ctx), req.TaskNo)
task, err := svc.recognitionTaskDao.GetByTaskNo(dao.DB.WithContext(ctx), req.TaskNo)
if err != nil {
log.Error(ctx, "func", "EvaluateTask", "msg", "get task by task no failed", "error", err)
return err
@@ -36,14 +41,13 @@ func (svc *TaskService) EvaluateTask(ctx context.Context, req *task.EvaluateTask
return errors.New("task not finished")
}
evaluateTaskDao := dao.NewEvaluateTaskDao()
evaluateTask := &dao.EvaluateTask{
TaskID: task.ID,
Satisfied: req.Satisfied,
Feedback: req.Feedback,
Comment: strings.Join(req.Suggestion, ","),
}
err = evaluateTaskDao.Create(svc.db.WithContext(ctx), evaluateTask)
err = svc.evaluateTaskDao.Create(dao.DB.WithContext(ctx), evaluateTask)
if err != nil {
log.Error(ctx, "func", "EvaluateTask", "msg", "create evaluate task failed", "error", err)
return err
@@ -53,19 +57,43 @@ func (svc *TaskService) EvaluateTask(ctx context.Context, req *task.EvaluateTask
}
func (svc *TaskService) GetTaskList(ctx context.Context, req *task.TaskListRequest) (*task.TaskListResponse, error) {
taskDao := dao.NewRecognitionTaskDao()
tasks, err := taskDao.GetTaskList(svc.db.WithContext(ctx), dao.TaskType(req.TaskType), req.Page, req.PageSize)
tasks, total, err := svc.recognitionTaskDao.GetTaskList(dao.DB.WithContext(ctx), req.UserID, dao.TaskType(req.TaskType), req.Page, req.PageSize)
if err != nil {
log.Error(ctx, "func", "GetTaskList", "msg", "get task list failed", "error", err)
return nil, err
}
taskIDs := make([]int64, 0, len(tasks))
for _, item := range tasks {
taskIDs = append(taskIDs, item.ID)
}
recognitionResults, err := svc.recognitionResultDao.GetByTaskIDs(dao.DB.WithContext(ctx), taskIDs)
if err != nil {
log.Error(ctx, "func", "GetTaskList", "msg", "get recognition results failed", "error", err)
return nil, err
}
recognitionResultMap := make(map[int64]*dao.RecognitionResult)
for _, item := range recognitionResults {
recognitionResultMap[item.TaskID] = item
}
resp := &task.TaskListResponse{
TaskList: make([]*task.TaskListDTO, 0, len(tasks)),
HasMore: false,
NextPage: 0,
Total: total,
}
for _, item := range tasks {
var latex string
var markdown string
recognitionResult := recognitionResultMap[item.ID]
if recognitionResult != nil {
latex = recognitionResult.NewContentCodec().GetContent().(string)
markdown = fmt.Sprintf("$$%s$$", latex)
}
resp.TaskList = append(resp.TaskList, &task.TaskListDTO{
Latex: latex,
Markdown: markdown,
TaskID: item.TaskUUID,
FileName: item.FileName,
Status: item.Status.String(),