202 lines
6.3 KiB
Go
202 lines
6.3 KiB
Go
package service
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"gitea.com/texpixel/document_ai/internal/model/task"
|
|
"gitea.com/texpixel/document_ai/internal/storage/dao"
|
|
"gitea.com/texpixel/document_ai/pkg/log"
|
|
"gitea.com/texpixel/document_ai/pkg/oss"
|
|
)
|
|
|
|
type TaskService struct {
|
|
recognitionTaskDao *dao.RecognitionTaskDao
|
|
evaluateTaskDao *dao.EvaluateTaskDao
|
|
recognitionResultDao *dao.RecognitionResultDao
|
|
}
|
|
|
|
func NewTaskService() *TaskService {
|
|
return &TaskService{
|
|
recognitionTaskDao: dao.NewRecognitionTaskDao(),
|
|
evaluateTaskDao: dao.NewEvaluateTaskDao(),
|
|
recognitionResultDao: dao.NewRecognitionResultDao(),
|
|
}
|
|
}
|
|
|
|
func (svc *TaskService) EvaluateTask(ctx context.Context, req *task.EvaluateTaskRequest) error {
|
|
task, err := svc.recognitionTaskDao.GetByTaskNo(dao.DB.WithContext(ctx), req.TaskNo)
|
|
if err != nil {
|
|
log.Error(ctx, "func", "EvaluateTask", "msg", "get task by task no failed", "error", err)
|
|
return err
|
|
}
|
|
|
|
if task == nil {
|
|
log.Error(ctx, "func", "EvaluateTask", "msg", "task not found")
|
|
return errors.New("task not found")
|
|
}
|
|
if task.Status != dao.TaskStatusCompleted {
|
|
log.Error(ctx, "func", "EvaluateTask", "msg", "task not finished")
|
|
return errors.New("task not finished")
|
|
}
|
|
|
|
evaluateTask := &dao.EvaluateTask{
|
|
TaskID: task.ID,
|
|
Satisfied: req.Satisfied,
|
|
Feedback: req.Feedback,
|
|
Comment: strings.Join(req.Suggestion, ","),
|
|
}
|
|
err = svc.evaluateTaskDao.Create(dao.DB.WithContext(ctx), evaluateTask)
|
|
if err != nil {
|
|
log.Error(ctx, "func", "EvaluateTask", "msg", "create evaluate task failed", "error", err)
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (svc *TaskService) GetTaskList(ctx context.Context, req *task.TaskListRequest) (*task.TaskListResponse, error) {
|
|
tasks, total, err := svc.recognitionTaskDao.GetTaskList(dao.DB.WithContext(ctx), req.UserID, dao.TaskType(req.TaskType), req.Page, req.PageSize)
|
|
if err != nil {
|
|
log.Error(ctx, "func", "GetTaskList", "msg", "get task list failed", "error", err)
|
|
return nil, err
|
|
}
|
|
|
|
taskIDs := make([]int64, 0, len(tasks))
|
|
for _, item := range tasks {
|
|
taskIDs = append(taskIDs, item.ID)
|
|
}
|
|
|
|
recognitionResults, err := svc.recognitionResultDao.GetByTaskIDs(dao.DB.WithContext(ctx), taskIDs)
|
|
if err != nil {
|
|
log.Error(ctx, "func", "GetTaskList", "msg", "get recognition results failed", "error", err)
|
|
return nil, err
|
|
}
|
|
|
|
recognitionResultMap := make(map[int64]*dao.RecognitionResult)
|
|
for _, item := range recognitionResults {
|
|
recognitionResultMap[item.TaskID] = item
|
|
}
|
|
|
|
resp := &task.TaskListResponse{
|
|
TaskList: make([]*task.TaskListDTO, 0, len(tasks)),
|
|
Total: total,
|
|
}
|
|
for _, item := range tasks {
|
|
var latex string
|
|
var markdown string
|
|
var mathML string
|
|
recognitionResult := recognitionResultMap[item.ID]
|
|
if recognitionResult != nil {
|
|
latex = recognitionResult.Latex
|
|
markdown = recognitionResult.Markdown
|
|
mathML = recognitionResult.MathML
|
|
}
|
|
originURL, err := oss.GetDownloadURL(ctx, item.FileURL)
|
|
if err != nil {
|
|
log.Error(ctx, "func", "GetTaskList", "msg", "get origin url failed", "error", err)
|
|
}
|
|
resp.TaskList = append(resp.TaskList, &task.TaskListDTO{
|
|
Latex: latex,
|
|
Markdown: markdown,
|
|
MathML: mathML,
|
|
TaskID: item.TaskUUID,
|
|
FileName: item.FileName,
|
|
Status: int(item.Status),
|
|
OriginURL: originURL,
|
|
TaskType: item.TaskType.String(),
|
|
CreatedAt: item.CreatedAt.Format("2006-01-02 15:04:05"),
|
|
})
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
func (svc *TaskService) ExportTask(ctx context.Context, req *task.ExportTaskRequest) ([]byte, string, error) {
|
|
recognitionTask, err := svc.recognitionTaskDao.GetByTaskNo(dao.DB.WithContext(ctx), req.TaskNo)
|
|
if err != nil {
|
|
log.Error(ctx, "func", "ExportTask", "msg", "get task by task id failed", "error", err)
|
|
return nil, "", err
|
|
}
|
|
|
|
if recognitionTask == nil {
|
|
log.Error(ctx, "func", "ExportTask", "msg", "task not found")
|
|
return nil, "", errors.New("task not found")
|
|
}
|
|
|
|
if recognitionTask.Status != dao.TaskStatusCompleted {
|
|
log.Error(ctx, "func", "ExportTask", "msg", "task not finished")
|
|
return nil, "", errors.New("task not finished")
|
|
}
|
|
|
|
recognitionResult, err := svc.recognitionResultDao.GetByTaskID(dao.DB.WithContext(ctx), recognitionTask.ID)
|
|
if err != nil {
|
|
log.Error(ctx, "func", "ExportTask", "msg", "get recognition result by task id failed", "error", err)
|
|
return nil, "", err
|
|
}
|
|
|
|
if recognitionResult == nil {
|
|
log.Error(ctx, "func", "ExportTask", "msg", "recognition result not found")
|
|
return nil, "", errors.New("recognition result not found")
|
|
}
|
|
|
|
markdown := recognitionResult.Markdown
|
|
if markdown == "" {
|
|
log.Error(ctx, "func", "ExportTask", "msg", "markdown not found")
|
|
return nil, "", errors.New("markdown not found")
|
|
}
|
|
|
|
// 获取文件名(去掉扩展名)
|
|
filename := strings.TrimSuffix(recognitionTask.FileName, "."+strings.ToLower(strings.Split(recognitionTask.FileName, ".")[len(strings.Split(recognitionTask.FileName, "."))-1]))
|
|
if filename == "" {
|
|
filename = "texpixel"
|
|
}
|
|
|
|
// 构建 JSON 请求体
|
|
requestBody := map[string]string{
|
|
"markdown": markdown,
|
|
"filename": filename,
|
|
}
|
|
jsonData, err := json.Marshal(requestBody)
|
|
if err != nil {
|
|
log.Error(ctx, "func", "ExportTask", "msg", "json marshal failed", "error", err)
|
|
return nil, "", err
|
|
}
|
|
|
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://cloud.texpixel.com:10443/doc_process/v1/convert/file", bytes.NewReader(jsonData))
|
|
if err != nil {
|
|
log.Error(ctx, "func", "ExportTask", "msg", "create http request failed", "error", err)
|
|
return nil, "", err
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
|
|
client := &http.Client{}
|
|
resp, err := client.Do(httpReq)
|
|
if err != nil {
|
|
log.Error(ctx, "func", "ExportTask", "msg", "http request failed", "error", err)
|
|
return nil, "", err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
log.Error(ctx, "func", "ExportTask", "msg", "http request failed", "status", resp.StatusCode)
|
|
return nil, "", fmt.Errorf("export service returned status: %d", resp.StatusCode)
|
|
}
|
|
|
|
fileData, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
log.Error(ctx, "func", "ExportTask", "msg", "read response body failed", "error", err)
|
|
return nil, "", err
|
|
}
|
|
|
|
// 新接口只返回 DOCX 格式
|
|
contentType := "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
|
|
|
return fileData, contentType, nil
|
|
}
|