// Files
// doc_ai_backed/internal/service/recognition_service.go
//
// 532 lines
// 18 KiB
// Go
// Raw Normal View History
//
// 2025-12-10 18:33:37 +08:00
package service
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"gitea.com/bitwsd/document_ai/config"
"gitea.com/bitwsd/document_ai/internal/model/formula"
"gitea.com/bitwsd/document_ai/internal/storage/cache"
"gitea.com/bitwsd/document_ai/internal/storage/dao"
// 2025-12-10 23:17:24 +08:00
"gitea.com/bitwsd/document_ai/pkg/log"
// 2025-12-10 18:33:37 +08:00
"gitea.com/bitwsd/document_ai/pkg/common"
"gitea.com/bitwsd/document_ai/pkg/constant"
"gitea.com/bitwsd/document_ai/pkg/httpclient"
"gitea.com/bitwsd/document_ai/pkg/oss"
"gitea.com/bitwsd/document_ai/pkg/utils"
"gorm.io/gorm"
)
// RecognitionService owns the formula-recognition workflow: creating tasks,
// draining the task queue in a background worker and serving task results.
type RecognitionService struct {
	// db is the gorm handle captured from dao.DB when the service is built.
	db *gorm.DB
	// queueLimit is a buffered channel used as a semaphore around each queue
	// iteration; stopChan is closed by Stop to end the background worker.
	queueLimit, stopChan chan struct{}
	// httpClient sends the outbound recognition requests.
	httpClient *httpclient.Client
}
// NewRecognitionService builds a RecognitionService and launches its
// background queue worker through utils.SafeGo.
//
// The worker only starts draining the queue when cache.GetDistributedLock
// reports success; if the lock call fails or the lock is not granted, the
// goroutine logs the problem and exits without processing anything.
func NewRecognitionService() *RecognitionService {
	svc := new(RecognitionService)
	svc.db = dao.DB
	svc.queueLimit = make(chan struct{}, config.GlobalConfig.Limit.FormulaRecognition)
	svc.stopChan = make(chan struct{})
	svc.httpClient = httpclient.NewClient(nil) // default client configuration

	// Begin consuming the queue as soon as the service exists.
	utils.SafeGo(func() {
		bg := context.Background()
		acquired, err := cache.GetDistributedLock(bg)
		switch {
		case err != nil:
			log.Error(bg, "func", "NewRecognitionService", "msg", "获取分布式锁失败", "error", err)
		case !acquired:
			log.Error(bg, "func", "NewRecognitionService", "msg", "获取分布式锁失败")
		default:
			svc.processFormulaQueue(bg)
		}
	})

	return svc
}
// AIEnhanceRecognition re-runs an existing formula task through the
// vision-language (VL) model.
//
// Flow:
//  1. Reject the call when the caller IP has already used up
//     constant.VLMFormulaCount enhancements.
//  2. Load the task by req.TaskNo; it must exist and must not be pending or
//     processing.
//  3. Mark the task as processing and start the VL recognition in a detached
//     goroutine; the per-IP counter is incremented after that run returns.
//
// The returned task is the record as loaded in step 2. Errors are returned as
// common.NewError values (system, forbidden, DB, not-found, invalid-status).
func (s *RecognitionService) AIEnhanceRecognition(ctx context.Context, req *formula.AIEnhanceRecognitionRequest) (*dao.RecognitionTask, error) {
	// Resolve everything that depends on the request context up front. The
	// goroutine started below outlives this call, so it must not read from a
	// request-scoped context that may be cancelled or reused once we return.
	ip := common.GetIPFromContext(ctx)
	requestID := utils.GetRequestIDFromContext(ctx)

	count, err := cache.GetVLMFormulaCount(ctx, ip)
	if err != nil {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "获取VLM公式识别次数失败", "error", err)
		return nil, common.NewError(common.CodeSystemError, "系统错误", err)
	}
	if count >= constant.VLMFormulaCount {
		return nil, common.NewError(common.CodeForbidden, "今日VLM公式识别次数已达上限,请明天再试!", nil)
	}

	taskDao := dao.NewRecognitionTaskDao()
	task, err := taskDao.GetByTaskNo(dao.DB.WithContext(ctx), req.TaskNo)
	if err != nil {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "获取任务失败", "error", err)
		return nil, common.NewError(common.CodeDBError, "获取任务失败", err)
	}
	if task == nil {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "任务不存在", "task_no", req.TaskNo)
		return nil, common.NewError(common.CodeNotFound, "任务不存在", nil)
	}
	if task.Status == dao.TaskStatusProcessing || task.Status == dao.TaskStatusPending {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "任务未完成", "task_no", req.TaskNo)
		return nil, common.NewError(common.CodeInvalidStatus, "任务未完成", nil)
	}

	err = taskDao.Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": task.ID}, map[string]interface{}{"status": dao.TaskStatusProcessing})
	if err != nil {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "更新任务状态失败", "error", err)
		return nil, common.NewError(common.CodeDBError, "更新任务状态失败", err)
	}

	taskID, taskNo := task.ID, req.TaskNo
	utils.SafeGo(func() {
		// Detached context: keeps the request ID for log correlation but is
		// independent of the caller's lifetime.
		bg := context.WithValue(context.Background(), utils.RequestIDKey, requestID)

		s.processVLFormula(bg, taskID)
		if _, err := cache.IncrVLMFormulaCount(bg, ip); err != nil {
			log.Error(bg, "func", "AIEnhanceRecognition", "msg", "增加VLM公式识别次数失败", "error", err, "task_no", taskNo)
		}
	})

	return task, nil
}
// CreateRecognitionTask stores a new pending recognition task built from req
// and hands its ID to handleFormulaRecognition for queuing.
//
// It returns the stored task, a CodeDBError when the insert fails, or a
// CodeSystemError when the task could not be queued.
func (s *RecognitionService) CreateRecognitionTask(ctx context.Context, req *formula.CreateFormulaRecognitionRequest) (*dao.RecognitionTask, error) {
	record := new(dao.RecognitionTask)
	record.TaskUUID = utils.NewUUID()
	record.TaskType = dao.TaskType(req.TaskType)
	record.Status = dao.TaskStatusPending
	record.FileURL = req.FileURL
	record.FileName = req.FileName
	record.FileHash = req.FileHash
	record.IP = common.GetIPFromContext(ctx)

	createErr := dao.NewRecognitionTaskDao().Create(dao.DB.WithContext(ctx), record)
	if createErr != nil {
		log.Error(ctx, "func", "CreateRecognitionTask", "msg", "创建任务失败", "error", createErr)
		return nil, common.NewError(common.CodeDBError, "创建任务失败", createErr)
	}

	if queueErr := s.handleFormulaRecognition(ctx, record.ID); queueErr != nil {
		log.Error(ctx, "func", "CreateRecognitionTask", "msg", "处理任务失败", "error", queueErr)
		return nil, common.NewError(common.CodeSystemError, "处理任务失败", queueErr)
	}

	return record, nil
}
// GetFormualTask reports the state of the formula task identified by taskNo
// and, once the task is completed, its recognised LaTeX.
//
// While the value returned by cache.GetFormulaTaskCount exceeds
// config.GlobalConfig.Limit.FormulaRecognition, the task is reported as
// pending together with that count and the database is not consulted. A
// failure to read the count is logged and otherwise ignored.
//
// Errors are common.NewError values: CodeNotFound when the task or its result
// is missing, CodeDBError when a query fails, and CodeSystemError when the
// stored result does not decode to a string.
func (s *RecognitionService) GetFormualTask(ctx context.Context, taskNo string) (*formula.GetFormulaTaskResponse, error) {
	var (
		taskDao   = dao.NewRecognitionTaskDao()
		resultDao = dao.NewRecognitionResultDao()
		sess      = dao.DB.WithContext(ctx)
	)

	count, err := cache.GetFormulaTaskCount(ctx)
	if err != nil {
		// Best effort: without a count we fall through to the database lookup.
		log.Error(ctx, "func", "GetFormualTask", "msg", "获取任务数量失败", "error", err)
	}
	if count > int64(config.GlobalConfig.Limit.FormulaRecognition) {
		return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Status: int(dao.TaskStatusPending), Count: int(count)}, nil
	}

	task, err := taskDao.GetByTaskNo(sess, taskNo)
	if err != nil {
		if err == gorm.ErrRecordNotFound {
			log.Info(ctx, "func", "GetFormualTask", "msg", "任务不存在", "task_no", taskNo)
			return nil, common.NewError(common.CodeNotFound, "任务不存在", err)
		}
		log.Error(ctx, "func", "GetFormualTask", "msg", "查询任务失败", "error", err, "task_no", taskNo)
		return nil, common.NewError(common.CodeDBError, "查询任务失败", err)
	}
	// Other callers of GetByTaskNo in this file guard against a nil task with
	// a nil error; do the same here instead of dereferencing blindly.
	if task == nil {
		log.Info(ctx, "func", "GetFormualTask", "msg", "任务不存在", "task_no", taskNo)
		return nil, common.NewError(common.CodeNotFound, "任务不存在", nil)
	}

	if task.Status != dao.TaskStatusCompleted {
		return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Status: int(task.Status)}, nil
	}

	taskRet, err := resultDao.GetByTaskID(sess, task.ID)
	if err != nil {
		if err == gorm.ErrRecordNotFound {
			log.Info(ctx, "func", "GetFormualTask", "msg", "任务结果不存在", "task_no", taskNo)
			return nil, common.NewError(common.CodeNotFound, "任务结果不存在", err)
		}
		log.Error(ctx, "func", "GetFormualTask", "msg", "查询任务结果失败", "error", err, "task_no", taskNo)
		return nil, common.NewError(common.CodeDBError, "查询任务结果失败", err)
	}
	// Same guard for the result row: a nil result must not be dereferenced.
	if taskRet == nil {
		log.Info(ctx, "func", "GetFormualTask", "msg", "任务结果不存在", "task_no", taskNo)
		return nil, common.NewError(common.CodeNotFound, "任务结果不存在", nil)
	}

	// Checked assertion: unexpected content yields an error rather than a panic.
	latex, ok := taskRet.NewContentCodec().GetContent().(string)
	if !ok {
		log.Error(ctx, "func", "GetFormualTask", "msg", "任务结果格式错误", "task_no", taskNo)
		return nil, common.NewError(common.CodeSystemError, "任务结果格式错误", nil)
	}

	return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Latex: latex, Status: int(task.Status)}, nil
}
// handleFormulaRecognition pushes taskID via cache.PushFormulaTask and does
// nothing else; the actual recognition is done by the background worker.
// It returns the push error unchanged.
func (s *RecognitionService) handleFormulaRecognition(ctx context.Context, taskID int64) error {
	if _, pushErr := cache.PushFormulaTask(ctx, taskID); pushErr != nil {
		log.Error(ctx, "func", "handleFormulaRecognition", "msg", "增加任务计数失败", "error", pushErr)
		return pushErr
	}

	log.Info(ctx, "func", "handleFormulaRecognition", "msg", "增加任务计数成功", "task_id", taskID)
	return nil
}
// Stop shuts the service down gracefully by closing stopChan, which the
// background queue worker watches.
//
// Repeated sequential calls are harmless: once the channel is closed, further
// calls are no-ops instead of panicking on a double close. Stop is not
// synchronised, so it must not be invoked from several goroutines at once.
func (s *RecognitionService) Stop() {
	select {
	case <-s.stopChan:
		// Already closed by an earlier Stop call.
	default:
		close(s.stopChan)
	}
}
// processVLFormula loads the task with the given ID and runs the VL-model
// recognition on its file. The task UUID is stored in the context under
// utils.RequestIDKey before processing. Failures are logged, not returned.
func (s *RecognitionService) processVLFormula(ctx context.Context, taskID int64) {
	// Log label kept exactly as emitted by the existing code.
	const fn = "processVLFormulaQueue"

	task, err := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
	switch {
	case err != nil:
		log.Error(ctx, "func", fn, "msg", "获取任务失败", "error", err)
		return
	case task == nil:
		log.Error(ctx, "func", fn, "msg", "任务不存在", "task_id", taskID)
		return
	}

	ctx = context.WithValue(ctx, utils.RequestIDKey, task.TaskUUID)
	log.Info(ctx, "func", fn, "msg", "获取任务成功", "task_id", taskID)

	// Run the actual recognition for this task.
	runErr := s.processVLFormulaTask(ctx, taskID, task.FileURL)
	if runErr != nil {
		log.Error(ctx, "func", fn, "msg", "处理任务失败", "error", runErr)
		return
	}
	log.Info(ctx, "func", fn, "msg", "处理任务成功", "task_id", taskID)
}
// processFormulaTask runs one formula-recognition task end to end under a
// 45-second deadline:
//
//  1. mark the task as processing (best effort; a failure is only logged),
//  2. download the image and obtain its download URL,
//  3. POST the base64 image and URL as JSON to the recognition endpoint,
//  4. convert the response text with utils.ToKatex,
//  5. store the result and mark the task completed in a single transaction.
//
// The transaction is opened only for step 5, so no database transaction is
// held while waiting on the network. Errors from the final status update and
// from Commit are returned instead of being discarded.
//
// Whenever the function does not finish successfully (including a panic), the
// deferred handler marks the task as failed, using a fresh background context
// so that an expired deadline cannot block the status write.
func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int64, fileURL string) (err error) {
	// Bound the whole task, not only the HTTP request.
	ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
	defer cancel()

	var (
		taskDao   = dao.NewRecognitionTaskDao()
		resultDao = dao.NewRecognitionResultDao()
	)

	isSuccess := false
	defer func() {
		if isSuccess {
			return
		}
		remark := "任务处理失败"
		if ctx.Err() == context.DeadlineExceeded {
			remark = "任务处理超时"
		}
		if err != nil {
			remark = err.Error()
		}
		// Use a new context: ctx may already be cancelled or timed out.
		updateErr := taskDao.Update(dao.DB.WithContext(context.Background()),
			map[string]interface{}{"id": taskID},
			map[string]interface{}{
				"status":       dao.TaskStatusFailed,
				"completed_at": time.Now(),
				"remark":       remark,
			})
		if updateErr != nil {
			log.Error(ctx, "func", "processFormulaTask", "msg", "更新任务状态失败", "error", updateErr)
		}
	}()

	if updateErr := taskDao.Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusProcessing}); updateErr != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "更新任务状态失败", "error", updateErr)
	}

	// Download the image file.
	reader, err := oss.DownloadFile(ctx, fileURL)
	if err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "下载图片文件失败", "error", err)
		return err
	}
	defer reader.Close()

	// Read the image bytes.
	imageData, err := io.ReadAll(reader)
	if err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "读取图片数据失败", "error", err)
		return err
	}

	downloadURL, err := oss.GetDownloadURL(ctx, fileURL)
	if err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "获取下载URL失败", "error", err)
		return err
	}

	// Build the JSON request carrying both the base64 image and its URL.
	requestData := map[string]string{
		"image_base64": base64.StdEncoding.EncodeToString(imageData),
		"img_url":      downloadURL,
	}
	jsonData, err := json.Marshal(requestData)
	if err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "JSON编码失败", "error", err)
		return err
	}

	headers := map[string]string{"Content-Type": "application/json", utils.RequestIDHeaderKey: utils.GetRequestIDFromContext(ctx)}

	// The request inherits the deadline carried by ctx.
	resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, s.getURL(ctx), bytes.NewReader(jsonData), headers)
	if err != nil {
		if ctx.Err() == context.DeadlineExceeded {
			log.Error(ctx, "func", "processFormulaTask", "msg", "请求超时")
			return fmt.Errorf("request timeout")
		}
		log.Error(ctx, "func", "processFormulaTask", "msg", "请求失败", "error", err)
		return err
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "读取响应体失败", "error", err)
		return err
	}
	// Log the response text itself rather than the body reader object.
	log.Info(ctx, "func", "processFormulaTask", "msg", "请求成功", "resp", string(respBody))

	content := &dao.FormulaRecognitionContent{Latex: utils.ToKatex(string(respBody))}
	encoded, err := json.Marshal(content)
	if err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "JSON编码失败", "error", err)
		return err
	}

	// Persist the result and the completed status atomically.
	tx := dao.DB.Begin()
	if err = tx.Error; err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "开启事务失败", "error", err)
		return err
	}
	defer func() {
		// Covers every non-success exit after Begin, including panics.
		if !isSuccess {
			tx.Rollback()
		}
	}()

	result := dao.RecognitionResult{
		TaskID:   taskID,
		TaskType: dao.TaskTypeFormula,
		Content:  encoded,
	}
	if err = resultDao.Create(tx, result); err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "保存任务结果失败", "error", err)
		return err
	}
	if err = taskDao.Update(tx, map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusCompleted, "completed_at": time.Now()}); err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "更新任务状态失败", "error", err)
		return err
	}
	if err = tx.Commit().Error; err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "提交事务失败", "error", err)
		return err
	}

	isSuccess = true
	return nil
}
// vlFormulaPrompt is the instruction text sent to the VL model together with
// the formula image.
const vlFormulaPrompt = `
Please perform OCR on the image and output only LaTeX code.
Important instructions:
* "The image contains mathematical formulas, no plain text."
* "Preserve all layout, symbols, subscripts, summations, parentheses, etc., exactly as shown."
* "Use \[ ... \] or align environments to represent multiline math expressions."
* "Use adaptive symbols such as \left and \right where applicable."
* "Do not include any extra commentary, template answers, or unrelated equations."
* "Only output valid LaTeX code based on the actual content of the image, and not change the original mathematical expression."
* "The output result must be can render by better-react-mathjax."
`

// imageDataURL encodes data as a base64 data URL. The MIME type is sniffed
// from the bytes; anything not recognised as an image falls back to
// image/jpeg.
func imageDataURL(data []byte) string {
	mime := http.DetectContentType(data)
	if !strings.HasPrefix(mime, "image/") {
		mime = "image/jpeg"
	}
	return "data:" + mime + ";base64," + base64.StdEncoding.EncodeToString(data)
}

// normalizeVLLatex strips Markdown code fences and the outer \[ ... \]
// wrapper from the model output and tightens a few spacing patterns.
func normalizeVLLatex(content string) string {
	// Replace newlines with a space rather than deleting them, so a control
	// word at the end of one line cannot fuse with a letter on the next.
	latex := strings.ReplaceAll(content, "\n", " ")
	latex = strings.ReplaceAll(latex, "```latex", "")
	latex = strings.ReplaceAll(latex, "```", "")
	// Drop spacing that carries no meaning.
	latex = strings.ReplaceAll(latex, " = ", "=")
	latex = strings.ReplaceAll(latex, "\\left[ ", "\\left[")
	// Trim first so the display-math delimiters are found even when the model
	// surrounded them with whitespace.
	latex = strings.TrimSpace(latex)
	latex = strings.TrimPrefix(latex, "\\[")
	latex = strings.TrimSuffix(latex, "\\]")
	return strings.TrimSpace(latex)
}

// saveEnhancedLatex writes latex into the EnhanceLatex field of the task's
// stored result, creating the result row when the task has none yet.
func (s *RecognitionService) saveEnhancedLatex(ctx context.Context, taskID int64, latex string) error {
	resultDao := dao.NewRecognitionResultDao()

	result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID)
	if err != nil && err != gorm.ErrRecordNotFound {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "获取任务结果失败", "error", err)
		return err
	}

	// A missing row may be signalled by a nil result or by
	// gorm.ErrRecordNotFound; both mean "create a new result".
	if err != nil || result == nil {
		content := &dao.FormulaRecognitionContent{EnhanceLatex: latex}
		encoded, err := content.Encode()
		if err != nil {
			log.Error(ctx, "func", "processVLFormulaTask", "msg", "编码任务结果失败", "error", err)
			return err
		}
		err = resultDao.Create(dao.DB.WithContext(ctx), dao.RecognitionResult{TaskID: taskID, TaskType: dao.TaskTypeFormula, Content: encoded})
		if err != nil {
			log.Error(ctx, "func", "processVLFormulaTask", "msg", "创建任务结果失败", "error", err)
			return err
		}
		return nil
	}

	codec := result.NewContentCodec()
	content, ok := codec.(*dao.FormulaRecognitionContent)
	if !ok {
		err := fmt.Errorf("unexpected result content type %T", codec)
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "解码任务结果失败", "error", err)
		return err
	}
	if err := content.Decode(); err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "解码任务结果失败", "error", err)
		return err
	}
	content.EnhanceLatex = latex
	encoded, err := content.Encode()
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "编码任务结果失败", "error", err)
		return err
	}
	err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{"content": encoded})
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务结果失败", "error", err)
		return err
	}
	return nil
}

// processVLFormulaTask sends the task's image to the VL model and stores the
// returned LaTeX as the task's enhanced result.
//
// On return the task status is set to completed when everything succeeded and
// to failed otherwise. A response without usable LaTeX counts as a failure, so
// an empty string never replaces a previously stored enhancement.
func (s *RecognitionService) processVLFormulaTask(ctx context.Context, taskID int64, fileURL string) error {
	isSuccess := false
	defer func() {
		status := dao.TaskStatusFailed
		if isSuccess {
			status = dao.TaskStatusCompleted
		}
		err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": status})
		if err != nil {
			log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务状态失败", "error", err)
		}
	}()

	reader, err := oss.DownloadFile(ctx, fileURL)
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "下载图片文件失败", "error", err)
		return err
	}
	defer reader.Close()

	imageData, err := io.ReadAll(reader)
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "读取图片数据失败", "error", err)
		return err
	}

	requestBody := formula.VLFormulaRequest{
		Model:            "Qwen/Qwen2.5-VL-32B-Instruct",
		Stream:           false,
		MaxTokens:        512,
		Temperature:      0.1,
		TopP:             0.1,
		TopK:             50,
		FrequencyPenalty: 0.2,
		N:                1,
		Messages: []formula.Message{
			{
				Role: "user",
				Content: []formula.Content{
					{
						Type: "text",
						Text: vlFormulaPrompt,
					},
					{
						Type: "image_url",
						ImageURL: formula.Image{
							Detail: "auto",
							URL:    imageDataURL(imageData),
						},
					},
				},
			},
		},
	}

	// Serialise the request body to JSON.
	jsonData, err := json.Marshal(requestBody)
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "JSON编码失败", "error", err)
		return err
	}

	headers := map[string]string{
		"Content-Type":  "application/json",
		"Authorization": utils.SiliconFlowToken,
	}
	resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "https://api.siliconflow.cn/v1/chat/completions", bytes.NewReader(jsonData), headers)
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "请求VL服务失败", "error", err)
		return err
	}
	defer resp.Body.Close()

	// Parse the response.
	var response formula.VLFormulaResponse
	if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "解析响应失败", "error", err)
		return err
	}

	// Extract the LaTeX from the first choice, if any.
	var latex string
	if len(response.Choices) > 0 {
		latex = normalizeVLLatex(response.Choices[0].Message.Content)
	}
	if latex == "" {
		err := fmt.Errorf("vl model returned no latex content")
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "解析响应失败", "error", err)
		return err
	}

	if err := s.saveEnhancedLatex(ctx, taskID, latex); err != nil {
		return err
	}

	isSuccess = true
	return nil
}
// processFormulaQueue is the background worker loop: it keeps calling
// processOneTask until the service is stopped (stopChan closed) or ctx is
// cancelled. Both conditions are checked between tasks, never in the middle
// of one.
func (s *RecognitionService) processFormulaQueue(ctx context.Context) {
	for {
		select {
		case <-s.stopChan:
			return
		case <-ctx.Done():
			// Honour cancellation in addition to Stop. For a context that can
			// never be cancelled this case is simply never selected.
			return
		default:
			s.processOneTask(ctx)
		}
	}
}
// formulaPopRetryDelay is how long processOneTask waits after a failed queue
// pop before returning, so that a persistent failure does not make the worker
// loop spin and flood the log.
const formulaPopRetryDelay = time.Second

// processOneTask takes a slot from queueLimit, pops one task ID via
// cache.PopFormulaTask, loads the task and runs processFormulaTask on it.
// The slot is released when the function returns. All failures are logged.
//
// Waiting for a slot and the back-off after a failed pop are both interrupted
// when the service is stopped.
func (s *RecognitionService) processOneTask(ctx context.Context) {
	// Limit the number of in-flight queue iterations.
	select {
	case s.queueLimit <- struct{}{}:
	case <-s.stopChan:
		return
	}
	defer func() { <-s.queueLimit }()

	taskID, err := cache.PopFormulaTask(ctx)
	if err != nil {
		log.Error(ctx, "func", "processFormulaQueue", "msg", "获取任务失败", "error", err)
		// Back off before the caller retries; wake early on shutdown.
		select {
		case <-time.After(formulaPopRetryDelay):
		case <-s.stopChan:
		case <-ctx.Done():
		}
		return
	}

	task, err := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
	if err != nil {
		log.Error(ctx, "func", "processFormulaQueue", "msg", "获取任务失败", "error", err)
		return
	}
	if task == nil {
		log.Error(ctx, "func", "processFormulaQueue", "msg", "任务不存在", "task_id", taskID)
		return
	}

	ctx = context.WithValue(ctx, utils.RequestIDKey, task.TaskUUID)
	log.Info(ctx, "func", "processFormulaQueue", "msg", "获取任务成功", "task_id", taskID)

	// Run the actual recognition for this task.
	if err := s.processFormulaTask(ctx, taskID, task.FileURL); err != nil {
		log.Error(ctx, "func", "processFormulaQueue", "msg", "处理任务失败", "error", err)
		return
	}
	log.Info(ctx, "func", "processFormulaQueue", "msg", "处理任务成功", "task_id", taskID)
}
// getURL returns the endpoint that formula-recognition requests are posted
// to. The address is fixed; ctx is accepted but not used.
func (s *RecognitionService) getURL(ctx context.Context) string {
	const endpoint = "http://cloud.srcstar.com:8045/formula/predict"
	return endpoint
}