2025-12-10 18:33:37 +08:00
|
|
|
|
package service
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"bytes"
|
|
|
|
|
|
"context"
|
|
|
|
|
|
"encoding/base64"
|
|
|
|
|
|
"encoding/json"
|
|
|
|
|
|
"fmt"
|
|
|
|
|
|
"io"
|
|
|
|
|
|
"net/http"
|
|
|
|
|
|
"strings"
|
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
|
|
"gitea.com/bitwsd/document_ai/config"
|
|
|
|
|
|
"gitea.com/bitwsd/document_ai/internal/model/formula"
|
|
|
|
|
|
"gitea.com/bitwsd/document_ai/internal/storage/cache"
|
|
|
|
|
|
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
2025-12-10 23:17:24 +08:00
|
|
|
|
"gitea.com/bitwsd/document_ai/pkg/log"
|
2025-12-10 18:33:37 +08:00
|
|
|
|
|
|
|
|
|
|
"gitea.com/bitwsd/document_ai/pkg/common"
|
|
|
|
|
|
"gitea.com/bitwsd/document_ai/pkg/constant"
|
|
|
|
|
|
"gitea.com/bitwsd/document_ai/pkg/httpclient"
|
|
|
|
|
|
"gitea.com/bitwsd/document_ai/pkg/oss"
|
|
|
|
|
|
"gitea.com/bitwsd/document_ai/pkg/utils"
|
|
|
|
|
|
"gorm.io/gorm"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
type RecognitionService struct {
|
|
|
|
|
|
db *gorm.DB
|
|
|
|
|
|
queueLimit chan struct{}
|
|
|
|
|
|
stopChan chan struct{}
|
|
|
|
|
|
httpClient *httpclient.Client
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func NewRecognitionService() *RecognitionService {
|
|
|
|
|
|
s := &RecognitionService{
|
|
|
|
|
|
db: dao.DB,
|
|
|
|
|
|
queueLimit: make(chan struct{}, config.GlobalConfig.Limit.FormulaRecognition),
|
|
|
|
|
|
stopChan: make(chan struct{}),
|
|
|
|
|
|
httpClient: httpclient.NewClient(nil), // 使用默认配置
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 服务启动时就开始处理队列
|
|
|
|
|
|
utils.SafeGo(func() {
|
|
|
|
|
|
lock, err := cache.GetDistributedLock(context.Background())
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(context.Background(), "func", "NewRecognitionService", "msg", "获取分布式锁失败", "error", err)
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
if !lock {
|
|
|
|
|
|
log.Error(context.Background(), "func", "NewRecognitionService", "msg", "获取分布式锁失败")
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
s.processFormulaQueue(context.Background())
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
return s
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (s *RecognitionService) AIEnhanceRecognition(ctx context.Context, req *formula.AIEnhanceRecognitionRequest) (*dao.RecognitionTask, error) {
|
|
|
|
|
|
count, err := cache.GetVLMFormulaCount(ctx, common.GetIPFromContext(ctx))
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "获取VLM公式识别次数失败", "error", err)
|
|
|
|
|
|
return nil, common.NewError(common.CodeSystemError, "系统错误", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if count >= constant.VLMFormulaCount {
|
|
|
|
|
|
return nil, common.NewError(common.CodeForbidden, "今日VLM公式识别次数已达上限,请明天再试!", nil)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
taskDao := dao.NewRecognitionTaskDao()
|
|
|
|
|
|
task, err := taskDao.GetByTaskNo(dao.DB.WithContext(ctx), req.TaskNo)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "获取任务失败", "error", err)
|
|
|
|
|
|
return nil, common.NewError(common.CodeDBError, "获取任务失败", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if task == nil {
|
|
|
|
|
|
log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "任务不存在", "task_no", req.TaskNo)
|
|
|
|
|
|
return nil, common.NewError(common.CodeNotFound, "任务不存在", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if task.Status == dao.TaskStatusProcessing || task.Status == dao.TaskStatusPending {
|
|
|
|
|
|
log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "任务未完成", "task_no", req.TaskNo)
|
|
|
|
|
|
return nil, common.NewError(common.CodeInvalidStatus, "任务未完成", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
err = taskDao.Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": task.ID}, map[string]interface{}{"status": dao.TaskStatusProcessing})
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "更新任务状态失败", "error", err)
|
|
|
|
|
|
return nil, common.NewError(common.CodeDBError, "更新任务状态失败", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
utils.SafeGo(func() {
|
|
|
|
|
|
s.processVLFormula(context.Background(), task.ID)
|
|
|
|
|
|
_, err := cache.IncrVLMFormulaCount(context.Background(), common.GetIPFromContext(ctx))
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "增加VLM公式识别次数失败", "error", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
return task, nil
|
|
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (s *RecognitionService) CreateRecognitionTask(ctx context.Context, req *formula.CreateFormulaRecognitionRequest) (*dao.RecognitionTask, error) {
|
|
|
|
|
|
sess := dao.DB.WithContext(ctx)
|
|
|
|
|
|
taskDao := dao.NewRecognitionTaskDao()
|
|
|
|
|
|
task := &dao.RecognitionTask{
|
|
|
|
|
|
TaskUUID: utils.NewUUID(),
|
|
|
|
|
|
TaskType: dao.TaskType(req.TaskType),
|
|
|
|
|
|
Status: dao.TaskStatusPending,
|
|
|
|
|
|
FileURL: req.FileURL,
|
|
|
|
|
|
FileName: req.FileName,
|
|
|
|
|
|
FileHash: req.FileHash,
|
|
|
|
|
|
IP: common.GetIPFromContext(ctx),
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if err := taskDao.Create(sess, task); err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "CreateRecognitionTask", "msg", "创建任务失败", "error", err)
|
|
|
|
|
|
return nil, common.NewError(common.CodeDBError, "创建任务失败", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
err := s.handleFormulaRecognition(ctx, task.ID)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "CreateRecognitionTask", "msg", "处理任务失败", "error", err)
|
|
|
|
|
|
return nil, common.NewError(common.CodeSystemError, "处理任务失败", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return task, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (s *RecognitionService) GetFormualTask(ctx context.Context, taskNo string) (*formula.GetFormulaTaskResponse, error) {
|
|
|
|
|
|
taskDao := dao.NewRecognitionTaskDao()
|
|
|
|
|
|
resultDao := dao.NewRecognitionResultDao()
|
|
|
|
|
|
sess := dao.DB.WithContext(ctx)
|
|
|
|
|
|
|
|
|
|
|
|
count, err := cache.GetFormulaTaskCount(ctx)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "GetFormualTask", "msg", "获取任务数量失败", "error", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if count > int64(config.GlobalConfig.Limit.FormulaRecognition) {
|
|
|
|
|
|
return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Status: int(dao.TaskStatusPending), Count: int(count)}, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
task, err := taskDao.GetByTaskNo(sess, taskNo)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
if err == gorm.ErrRecordNotFound {
|
|
|
|
|
|
log.Info(ctx, "func", "GetFormualTask", "msg", "任务不存在", "task_no", taskNo)
|
|
|
|
|
|
return nil, common.NewError(common.CodeNotFound, "任务不存在", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
log.Error(ctx, "func", "GetFormualTask", "msg", "查询任务失败", "error", err, "task_no", taskNo)
|
|
|
|
|
|
return nil, common.NewError(common.CodeDBError, "查询任务失败", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if task.Status != dao.TaskStatusCompleted {
|
|
|
|
|
|
return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Status: int(task.Status)}, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
taskRet, err := resultDao.GetByTaskID(sess, task.ID)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
if err == gorm.ErrRecordNotFound {
|
|
|
|
|
|
log.Info(ctx, "func", "GetFormualTask", "msg", "任务结果不存在", "task_no", taskNo)
|
|
|
|
|
|
return nil, common.NewError(common.CodeNotFound, "任务结果不存在", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
log.Error(ctx, "func", "GetFormualTask", "msg", "查询任务结果失败", "error", err, "task_no", taskNo)
|
|
|
|
|
|
return nil, common.NewError(common.CodeDBError, "查询任务结果失败", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
latex := taskRet.NewContentCodec().GetContent().(string)
|
|
|
|
|
|
return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Latex: latex, Status: int(task.Status)}, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (s *RecognitionService) handleFormulaRecognition(ctx context.Context, taskID int64) error {
|
|
|
|
|
|
// 简化为只负责将任务加入队列
|
|
|
|
|
|
_, err := cache.PushFormulaTask(ctx, taskID)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "handleFormulaRecognition", "msg", "增加任务计数失败", "error", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
log.Info(ctx, "func", "handleFormulaRecognition", "msg", "增加任务计数成功", "task_id", taskID)
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Stop 用于优雅关闭服务
|
|
|
|
|
|
func (s *RecognitionService) Stop() {
|
|
|
|
|
|
close(s.stopChan)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (s *RecognitionService) processVLFormula(ctx context.Context, taskID int64) {
|
|
|
|
|
|
task, err := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processVLFormulaQueue", "msg", "获取任务失败", "error", err)
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
if task == nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processVLFormulaQueue", "msg", "任务不存在", "task_id", taskID)
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
ctx = context.WithValue(ctx, utils.RequestIDKey, task.TaskUUID)
|
|
|
|
|
|
log.Info(ctx, "func", "processVLFormulaQueue", "msg", "获取任务成功", "task_id", taskID)
|
|
|
|
|
|
|
|
|
|
|
|
// 处理具体任务
|
|
|
|
|
|
if err := s.processVLFormulaTask(ctx, taskID, task.FileURL); err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processVLFormulaQueue", "msg", "处理任务失败", "error", err)
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
log.Info(ctx, "func", "processVLFormulaQueue", "msg", "处理任务成功", "task_id", taskID)
|
|
|
|
|
|
}
|
|
|
|
|
|
func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int64, fileURL string) (err error) {
|
|
|
|
|
|
// 为整个任务处理添加超时控制
|
|
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
|
|
|
|
|
|
defer cancel()
|
|
|
|
|
|
|
|
|
|
|
|
tx := dao.DB.Begin()
|
|
|
|
|
|
var (
|
|
|
|
|
|
taskDao = dao.NewRecognitionTaskDao()
|
|
|
|
|
|
resultDao = dao.NewRecognitionResultDao()
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
isSuccess := false
|
|
|
|
|
|
defer func() {
|
|
|
|
|
|
if !isSuccess {
|
|
|
|
|
|
tx.Rollback()
|
|
|
|
|
|
status := dao.TaskStatusFailed
|
|
|
|
|
|
remark := "任务处理失败"
|
|
|
|
|
|
if ctx.Err() == context.DeadlineExceeded {
|
|
|
|
|
|
remark = "任务处理超时"
|
|
|
|
|
|
}
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
remark = err.Error()
|
|
|
|
|
|
}
|
|
|
|
|
|
_ = taskDao.Update(dao.DB.WithContext(context.Background()), // 使用新的context,避免已取消的context影响状态更新
|
|
|
|
|
|
map[string]interface{}{"id": taskID},
|
|
|
|
|
|
map[string]interface{}{
|
|
|
|
|
|
"status": status,
|
|
|
|
|
|
"completed_at": time.Now(),
|
|
|
|
|
|
"remark": remark,
|
|
|
|
|
|
})
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
_ = taskDao.Update(tx, map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusCompleted, "completed_at": time.Now()})
|
|
|
|
|
|
tx.Commit()
|
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
|
|
err = taskDao.Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusProcessing})
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processFormulaTask", "msg", "更新任务状态失败", "error", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 下载图片文件
|
|
|
|
|
|
reader, err := oss.DownloadFile(ctx, fileURL)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processFormulaTask", "msg", "下载图片文件失败", "error", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
defer reader.Close()
|
|
|
|
|
|
|
|
|
|
|
|
// 读取图片数据
|
|
|
|
|
|
imageData, err := io.ReadAll(reader)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processFormulaTask", "msg", "读取图片数据失败", "error", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
downloadURL, err := oss.GetDownloadURL(ctx, fileURL)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processFormulaTask", "msg", "获取下载URL失败", "error", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 将图片转为base64编码
|
|
|
|
|
|
base64Image := base64.StdEncoding.EncodeToString(imageData)
|
|
|
|
|
|
|
|
|
|
|
|
// 创建JSON请求
|
|
|
|
|
|
requestData := map[string]string{
|
|
|
|
|
|
"image_base64": base64Image,
|
|
|
|
|
|
"img_url": downloadURL,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
jsonData, err := json.Marshal(requestData)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processFormulaTask", "msg", "JSON编码失败", "error", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 设置Content-Type头为application/json
|
|
|
|
|
|
headers := map[string]string{"Content-Type": "application/json", utils.RequestIDHeaderKey: utils.GetRequestIDFromContext(ctx)}
|
|
|
|
|
|
|
|
|
|
|
|
// 发送请求时会使用带超时的context
|
|
|
|
|
|
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, s.getURL(ctx), bytes.NewReader(jsonData), headers)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
if ctx.Err() == context.DeadlineExceeded {
|
|
|
|
|
|
log.Error(ctx, "func", "processFormulaTask", "msg", "请求超时")
|
|
|
|
|
|
return fmt.Errorf("request timeout")
|
|
|
|
|
|
}
|
|
|
|
|
|
log.Error(ctx, "func", "processFormulaTask", "msg", "请求失败", "error", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
defer resp.Body.Close()
|
|
|
|
|
|
|
|
|
|
|
|
log.Info(ctx, "func", "processFormulaTask", "msg", "请求成功", "resp", resp.Body)
|
|
|
|
|
|
body := &bytes.Buffer{}
|
|
|
|
|
|
if _, err = body.ReadFrom(resp.Body); err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processFormulaTask", "msg", "读取响应体失败", "error", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
katex := utils.ToKatex(body.String())
|
|
|
|
|
|
content := &dao.FormulaRecognitionContent{Latex: katex}
|
|
|
|
|
|
b, _ := json.Marshal(content)
|
|
|
|
|
|
// Save recognition result
|
|
|
|
|
|
result := &dao.RecognitionResult{
|
|
|
|
|
|
TaskID: taskID,
|
|
|
|
|
|
TaskType: dao.TaskTypeFormula,
|
|
|
|
|
|
Content: b,
|
|
|
|
|
|
}
|
|
|
|
|
|
if err := resultDao.Create(tx, *result); err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processFormulaTask", "msg", "保存任务结果失败", "error", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
isSuccess = true
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (s *RecognitionService) processVLFormulaTask(ctx context.Context, taskID int64, fileURL string) error {
|
|
|
|
|
|
isSuccess := false
|
|
|
|
|
|
defer func() {
|
|
|
|
|
|
if !isSuccess {
|
|
|
|
|
|
err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusFailed})
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务状态失败", "error", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusCompleted})
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务状态失败", "error", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
|
|
reader, err := oss.DownloadFile(ctx, fileURL)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processVLFormulaTask", "msg", "获取签名URL失败", "error", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
defer reader.Close()
|
|
|
|
|
|
imageData, err := io.ReadAll(reader)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processVLFormulaTask", "msg", "读取图片数据失败", "error", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
prompt := `
|
|
|
|
|
|
Please perform OCR on the image and output only LaTeX code.
|
|
|
|
|
|
Important instructions:
|
|
|
|
|
|
|
|
|
|
|
|
* "The image contains mathematical formulas, no plain text."
|
|
|
|
|
|
|
|
|
|
|
|
* "Preserve all layout, symbols, subscripts, summations, parentheses, etc., exactly as shown."
|
|
|
|
|
|
|
|
|
|
|
|
* "Use \[ ... \] or align environments to represent multiline math expressions."
|
|
|
|
|
|
|
|
|
|
|
|
* "Use adaptive symbols such as \left and \right where applicable."
|
|
|
|
|
|
|
|
|
|
|
|
* "Do not include any extra commentary, template answers, or unrelated equations."
|
|
|
|
|
|
|
|
|
|
|
|
* "Only output valid LaTeX code based on the actual content of the image, and not change the original mathematical expression."
|
|
|
|
|
|
|
|
|
|
|
|
* "The output result must be can render by better-react-mathjax."
|
|
|
|
|
|
`
|
|
|
|
|
|
base64Image := base64.StdEncoding.EncodeToString(imageData)
|
|
|
|
|
|
|
|
|
|
|
|
requestBody := formula.VLFormulaRequest{
|
|
|
|
|
|
Model: "Qwen/Qwen2.5-VL-32B-Instruct",
|
|
|
|
|
|
Stream: false,
|
|
|
|
|
|
MaxTokens: 512,
|
|
|
|
|
|
Temperature: 0.1,
|
|
|
|
|
|
TopP: 0.1,
|
|
|
|
|
|
TopK: 50,
|
|
|
|
|
|
FrequencyPenalty: 0.2,
|
|
|
|
|
|
N: 1,
|
|
|
|
|
|
Messages: []formula.Message{
|
|
|
|
|
|
{
|
|
|
|
|
|
Role: "user",
|
|
|
|
|
|
Content: []formula.Content{
|
|
|
|
|
|
{
|
|
|
|
|
|
Type: "text",
|
|
|
|
|
|
Text: prompt,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
Type: "image_url",
|
|
|
|
|
|
ImageURL: formula.Image{
|
|
|
|
|
|
Detail: "auto",
|
|
|
|
|
|
URL: "data:image/jpeg;base64," + base64Image,
|
|
|
|
|
|
},
|
|
|
|
|
|
},
|
|
|
|
|
|
},
|
|
|
|
|
|
},
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 将请求体转换为JSON
|
|
|
|
|
|
jsonData, err := json.Marshal(requestBody)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processVLFormulaTask", "msg", "JSON编码失败", "error", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
headers := map[string]string{
|
|
|
|
|
|
"Content-Type": "application/json",
|
|
|
|
|
|
"Authorization": utils.SiliconFlowToken,
|
|
|
|
|
|
}
|
|
|
|
|
|
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "https://api.siliconflow.cn/v1/chat/completions", bytes.NewReader(jsonData), headers)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processVLFormulaTask", "msg", "请求VL服务失败", "error", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
defer resp.Body.Close()
|
|
|
|
|
|
|
|
|
|
|
|
// 解析响应
|
|
|
|
|
|
var response formula.VLFormulaResponse
|
|
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processVLFormulaTask", "msg", "解析响应失败", "error", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 提取LaTeX代码
|
|
|
|
|
|
var latex string
|
|
|
|
|
|
if len(response.Choices) > 0 {
|
|
|
|
|
|
if response.Choices[0].Message.Content != "" {
|
|
|
|
|
|
latex = strings.ReplaceAll(response.Choices[0].Message.Content, "\n", "")
|
|
|
|
|
|
latex = strings.ReplaceAll(latex, "```latex", "")
|
|
|
|
|
|
latex = strings.ReplaceAll(latex, "```", "")
|
|
|
|
|
|
// 规范化LaTeX代码,移除不必要的空格
|
|
|
|
|
|
latex = strings.ReplaceAll(latex, " = ", "=")
|
|
|
|
|
|
latex = strings.ReplaceAll(latex, "\\left[ ", "\\left[")
|
|
|
|
|
|
latex = strings.TrimPrefix(latex, "\\[")
|
|
|
|
|
|
latex = strings.TrimSuffix(latex, "\\]")
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
resultDao := dao.NewRecognitionResultDao()
|
|
|
|
|
|
var formulaRes *dao.FormulaRecognitionContent
|
|
|
|
|
|
result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processVLFormulaTask", "msg", "获取任务结果失败", "error", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
if result == nil {
|
|
|
|
|
|
formulaRes = &dao.FormulaRecognitionContent{EnhanceLatex: latex}
|
|
|
|
|
|
b, err := formulaRes.Encode()
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processVLFormulaTask", "msg", "编码任务结果失败", "error", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
err = resultDao.Create(dao.DB.WithContext(ctx), dao.RecognitionResult{TaskID: taskID, TaskType: dao.TaskTypeFormula, Content: b})
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processVLFormulaTask", "msg", "创建任务结果失败", "error", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
} else {
|
|
|
|
|
|
formulaRes = result.NewContentCodec().(*dao.FormulaRecognitionContent)
|
|
|
|
|
|
err = formulaRes.Decode()
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processVLFormulaTask", "msg", "解码任务结果失败", "error", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
formulaRes.EnhanceLatex = latex
|
|
|
|
|
|
b, err := formulaRes.Encode()
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processVLFormulaTask", "msg", "编码任务结果失败", "error", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{"content": b})
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务结果失败", "error", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
isSuccess = true
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (s *RecognitionService) processFormulaQueue(ctx context.Context) {
|
|
|
|
|
|
for {
|
|
|
|
|
|
select {
|
|
|
|
|
|
case <-s.stopChan:
|
|
|
|
|
|
return
|
|
|
|
|
|
default:
|
|
|
|
|
|
s.processOneTask(ctx)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (s *RecognitionService) processOneTask(ctx context.Context) {
|
|
|
|
|
|
// 限制队列数量
|
|
|
|
|
|
s.queueLimit <- struct{}{}
|
|
|
|
|
|
defer func() { <-s.queueLimit }()
|
|
|
|
|
|
|
|
|
|
|
|
taskID, err := cache.PopFormulaTask(ctx)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processFormulaQueue", "msg", "获取任务失败", "error", err)
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
task, err := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processFormulaQueue", "msg", "获取任务失败", "error", err)
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
if task == nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processFormulaQueue", "msg", "任务不存在", "task_id", taskID)
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
ctx = context.WithValue(ctx, utils.RequestIDKey, task.TaskUUID)
|
|
|
|
|
|
log.Info(ctx, "func", "processFormulaQueue", "msg", "获取任务成功", "task_id", taskID)
|
|
|
|
|
|
|
|
|
|
|
|
// 处理具体任务
|
|
|
|
|
|
if err := s.processFormulaTask(ctx, taskID, task.FileURL); err != nil {
|
|
|
|
|
|
log.Error(ctx, "func", "processFormulaQueue", "msg", "处理任务失败", "error", err)
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
log.Info(ctx, "func", "processFormulaQueue", "msg", "处理任务成功", "task_id", taskID)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (s *RecognitionService) getURL(ctx context.Context) string {
|
|
|
|
|
|
return "http://cloud.srcstar.com:8045/formula/predict"
|
|
|
|
|
|
}
|