package service

import (
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"gitea.com/bitwsd/document_ai/config"
	"gitea.com/bitwsd/document_ai/internal/model/formula"
	"gitea.com/bitwsd/document_ai/internal/storage/cache"
	"gitea.com/bitwsd/document_ai/internal/storage/dao"
	"gitea.com/bitwsd/document_ai/pkg/common"
	"gitea.com/bitwsd/document_ai/pkg/constant"
	"gitea.com/bitwsd/document_ai/pkg/httpclient"
	"gitea.com/bitwsd/document_ai/pkg/log"
	"gitea.com/bitwsd/document_ai/pkg/oss"
	"gitea.com/bitwsd/document_ai/pkg/utils"
	"gorm.io/gorm"
)

// RecognitionService creates formula recognition tasks, pushes them onto the
// cache-backed queue, runs the background queue worker and performs the
// VLM-based enhanced recognition.
type RecognitionService struct {
	db         *gorm.DB
	queueLimit chan struct{} // slots acquired around each queued task
	stopChan   chan struct{} // closed by Stop to end the queue worker loop
	httpClient *httpclient.Client
}

// NewRecognitionService builds the service and starts the queue worker in a
// background goroutine. The worker only runs when the distributed lock from
// the cache package is obtained; otherwise the goroutine logs and exits.
func NewRecognitionService() *RecognitionService {
	s := &RecognitionService{
		db:         dao.DB,
		queueLimit: make(chan struct{}, config.GlobalConfig.Limit.FormulaRecognition),
		stopChan:   make(chan struct{}),
		httpClient: httpclient.NewClient(nil), // use the default configuration
	}

	// Start consuming the queue as soon as the service is created.
	utils.SafeGo(func() {
		lock, err := cache.GetDistributedLock(context.Background())
		if err != nil {
			log.Error(context.Background(), "func", "NewRecognitionService", "msg", "获取分布式锁失败", "error", err)
			return
		}
		if !lock {
			log.Error(context.Background(), "func", "NewRecognitionService", "msg", "获取分布式锁失败")
			return
		}
		s.processFormulaQueue(context.Background())
	})

	return s
}

// AIEnhanceRecognition re-runs recognition of an existing task through the
// VLM service. It enforces the per-IP VLM quota, rejects tasks that are still
// pending or processing, marks the task as processing and then performs the
// recognition asynchronously. The returned task is the record as loaded
// before the status change.
func (s *RecognitionService) AIEnhanceRecognition(ctx context.Context, req *formula.AIEnhanceRecognitionRequest) (*dao.RecognitionTask, error) {
	// Resolve everything the background goroutine needs while the request
	// context is still valid; it must not be read after this call returns.
	clientIP := common.GetIPFromContext(ctx)
	taskNo := req.TaskNo

	count, err := cache.GetVLMFormulaCount(ctx, clientIP)
	if err != nil {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "获取VLM公式识别次数失败", "error", err)
		return nil, common.NewError(common.CodeSystemError, "系统错误", err)
	}
	if count >= constant.VLMFormulaCount {
		return nil, common.NewError(common.CodeForbidden, "今日VLM公式识别次数已达上限,请明天再试!", nil)
	}

	taskDao := dao.NewRecognitionTaskDao()
	task, err := taskDao.GetByTaskNo(dao.DB.WithContext(ctx), taskNo)
	if err != nil {
		// Report a missing record the same way GetFormualTask does.
		if errors.Is(err, gorm.ErrRecordNotFound) {
			log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "任务不存在", "task_no", taskNo)
			return nil, common.NewError(common.CodeNotFound, "任务不存在", err)
		}
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "获取任务失败", "error", err)
		return nil, common.NewError(common.CodeDBError, "获取任务失败", err)
	}
	if task == nil {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "任务不存在", "task_no", taskNo)
		return nil, common.NewError(common.CodeNotFound, "任务不存在", nil)
	}
	if task.Status == dao.TaskStatusProcessing || task.Status == dao.TaskStatusPending {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "任务未完成", "task_no", taskNo)
		return nil, common.NewError(common.CodeInvalidStatus, "任务未完成", nil)
	}

	err = taskDao.Update(dao.DB.WithContext(ctx),
		map[string]interface{}{"id": task.ID},
		map[string]interface{}{"status": dao.TaskStatusProcessing})
	if err != nil {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "更新任务状态失败", "error", err)
		return nil, common.NewError(common.CodeDBError, "更新任务状态失败", err)
	}

	taskID := task.ID
	// Detached context for the background work, tagged with the task UUID the
	// same way processVLFormula tags its own context.
	asyncCtx := context.WithValue(context.Background(), utils.RequestIDKey, task.TaskUUID)
	utils.SafeGo(func() {
		s.processVLFormula(context.Background(), taskID)
		// The quota is consumed once the attempt has run, whatever its outcome.
		if _, err := cache.IncrVLMFormulaCount(context.Background(), clientIP); err != nil {
			log.Error(asyncCtx, "func", "AIEnhanceRecognition", "msg", "增加VLM公式识别次数失败", "error", err)
		}
	})

	return task, nil
}

// CreateRecognitionTask stores a new pending task built from req and pushes
// it onto the recognition queue. It returns the stored task.
func (s *RecognitionService) CreateRecognitionTask(ctx context.Context, req *formula.CreateFormulaRecognitionRequest) (*dao.RecognitionTask, error) {
	sess := dao.DB.WithContext(ctx)
	taskDao := dao.NewRecognitionTaskDao()

	task := &dao.RecognitionTask{
		TaskUUID: utils.NewUUID(),
		TaskType: dao.TaskType(req.TaskType),
		Status:   dao.TaskStatusPending,
		FileURL:  req.FileURL,
		FileName: req.FileName,
		FileHash: req.FileHash,
		IP:       common.GetIPFromContext(ctx),
	}

	if err := taskDao.Create(sess, task); err != nil {
		log.Error(ctx, "func", "CreateRecognitionTask", "msg", "创建任务失败", "error", err)
		return nil, common.NewError(common.CodeDBError, "创建任务失败", err)
	}

	if err := s.handleFormulaRecognition(ctx, task.ID); err != nil {
		log.Error(ctx, "func", "CreateRecognitionTask", "msg", "处理任务失败", "error", err)
		return nil, common.NewError(common.CodeSystemError, "处理任务失败", err)
	}

	return task, nil
}

// GetFormualTask reports the state of the task identified by taskNo. While
// the queue length exceeds the configured limit it answers "pending" together
// with the queue length and does not touch the database. For a completed task
// the response also carries the recognized LaTeX.
func (s *RecognitionService) GetFormualTask(ctx context.Context, taskNo string) (*formula.GetFormulaTaskResponse, error) {
	taskDao := dao.NewRecognitionTaskDao()
	resultDao := dao.NewRecognitionResultDao()
	sess := dao.DB.WithContext(ctx)

	count, err := cache.GetFormulaTaskCount(ctx)
	if err != nil {
		// Best effort: a failed counter lookup must not block the status query.
		log.Error(ctx, "func", "GetFormualTask", "msg", "获取任务数量失败", "error", err)
	}
	if count > int64(config.GlobalConfig.Limit.FormulaRecognition) {
		return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Status: int(dao.TaskStatusPending), Count: int(count)}, nil
	}

	task, err := taskDao.GetByTaskNo(sess, taskNo)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			log.Info(ctx, "func", "GetFormualTask", "msg", "任务不存在", "task_no", taskNo)
			return nil, common.NewError(common.CodeNotFound, "任务不存在", err)
		}
		log.Error(ctx, "func", "GetFormualTask", "msg", "查询任务失败", "error", err, "task_no", taskNo)
		return nil, common.NewError(common.CodeDBError, "查询任务失败", err)
	}
	if task == nil {
		// AIEnhanceRecognition guards the same DAO call against a nil task; do
		// the same here instead of dereferencing a nil pointer.
		log.Info(ctx, "func", "GetFormualTask", "msg", "任务不存在", "task_no", taskNo)
		return nil, common.NewError(common.CodeNotFound, "任务不存在", nil)
	}

	if task.Status != dao.TaskStatusCompleted {
		return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Status: int(task.Status)}, nil
	}

	taskRet, err := resultDao.GetByTaskID(sess, task.ID)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			log.Info(ctx, "func", "GetFormualTask", "msg", "任务结果不存在", "task_no", taskNo)
			return nil, common.NewError(common.CodeNotFound, "任务结果不存在", err)
		}
		log.Error(ctx, "func", "GetFormualTask", "msg", "查询任务结果失败", "error", err, "task_no", taskNo)
		return nil, common.NewError(common.CodeDBError, "查询任务结果失败", err)
	}
	if taskRet == nil {
		log.Info(ctx, "func", "GetFormualTask", "msg", "任务结果不存在", "task_no", taskNo)
		return nil, common.NewError(common.CodeNotFound, "任务结果不存在", nil)
	}

	// Checked assertion: unexpected stored content yields an error, not a panic.
	latex, ok := taskRet.NewContentCodec().GetContent().(string)
	if !ok {
		log.Error(ctx, "func", "GetFormualTask", "msg", "任务结果格式错误", "task_no", taskNo)
		return nil, common.NewError(common.CodeSystemError, "任务结果格式错误", nil)
	}
	return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Latex: latex, Status: int(task.Status)}, nil
}

// handleFormulaRecognition only enqueues the task; the queue worker performs
// the actual recognition.
func (s *RecognitionService) handleFormulaRecognition(ctx context.Context, taskID int64) error {
	if _, err := cache.PushFormulaTask(ctx, taskID); err != nil {
		log.Error(ctx, "func", "handleFormulaRecognition", "msg", "增加任务计数失败", "error", err)
		return err
	}
	log.Info(ctx, "func", "handleFormulaRecognition", "msg", "增加任务计数成功", "task_id", taskID)
	return nil
}

// Stop shuts the queue worker down gracefully by closing the stop channel.
// Repeated sequential calls are no-ops; concurrent calls are not synchronized.
func (s *RecognitionService) Stop() {
	select {
	case <-s.stopChan:
		// Already stopped; closing again would panic.
	default:
		close(s.stopChan)
	}
}

// processVLFormula loads the task and runs the VLM recognition for it,
// logging the outcome. Errors are logged and not returned.
func (s *RecognitionService) processVLFormula(ctx context.Context, taskID int64) {
	task, err := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaQueue", "msg", "获取任务失败", "error", err)
		return
	}
	if task == nil {
		log.Error(ctx, "func", "processVLFormulaQueue", "msg", "任务不存在", "task_id", taskID)
		return
	}

	ctx = context.WithValue(ctx, utils.RequestIDKey, task.TaskUUID)
	log.Info(ctx, "func", "processVLFormulaQueue", "msg", "获取任务成功", "task_id", taskID)

	// Run the actual task.
	if err := s.processVLFormulaTask(ctx, taskID, task.FileURL); err != nil {
		log.Error(ctx, "func", "processVLFormulaQueue", "msg", "处理任务失败", "error", err)
		return
	}
	log.Info(ctx, "func", "processVLFormulaQueue", "msg", "处理任务成功", "task_id", taskID)
}

// processFormulaTask runs one queued recognition: it downloads the image,
// posts it to the prediction service, stores the result and marks the task
// completed, with the result insert and the status change in one transaction.
// The whole run is bounded by a 45 second timeout. On any failure, including
// a failed commit, the transaction is rolled back, the task is marked failed
// with a remark, and the error is returned.
func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int64, fileURL string) (err error) {
	// Bound the whole task with a timeout.
	ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
	defer cancel()

	taskDao := dao.NewRecognitionTaskDao()
	resultDao := dao.NewRecognitionResultDao()

	tx := dao.DB.Begin()
	committed := false
	defer func() {
		if committed {
			return
		}
		tx.Rollback()

		remark := "任务处理失败"
		if errors.Is(ctx.Err(), context.DeadlineExceeded) {
			remark = "任务处理超时"
		}
		if err != nil {
			remark = err.Error()
		}
		// Use a fresh context so an expired or cancelled ctx cannot block
		// the status update.
		updateErr := taskDao.Update(dao.DB.WithContext(context.Background()),
			map[string]interface{}{"id": taskID},
			map[string]interface{}{
				"status":       dao.TaskStatusFailed,
				"completed_at": time.Now(),
				"remark":       remark,
			})
		if updateErr != nil {
			log.Error(ctx, "func", "processFormulaTask", "msg", "更新任务状态失败", "error", updateErr)
		}
	}()
	if tx.Error != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "开启事务失败", "error", tx.Error)
		return tx.Error
	}

	// Best effort: failing to flag the task as processing does not abort it.
	if updateErr := taskDao.Update(dao.DB.WithContext(ctx),
		map[string]interface{}{"id": taskID},
		map[string]interface{}{"status": dao.TaskStatusProcessing}); updateErr != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "更新任务状态失败", "error", updateErr)
	}

	// Download the image file.
	reader, err := oss.DownloadFile(ctx, fileURL)
	if err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "下载图片文件失败", "error", err)
		return err
	}
	defer reader.Close()

	// Read the image data.
	imageData, err := io.ReadAll(reader)
	if err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "读取图片数据失败", "error", err)
		return err
	}

	downloadURL, err := oss.GetDownloadURL(ctx, fileURL)
	if err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "获取下载URL失败", "error", err)
		return err
	}

	// Build the JSON request carrying the base64 image and its download URL.
	requestData := map[string]string{
		"image_base64": base64.StdEncoding.EncodeToString(imageData),
		"img_url":      downloadURL,
	}
	jsonData, err := json.Marshal(requestData)
	if err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "JSON编码失败", "error", err)
		return err
	}

	headers := map[string]string{
		"Content-Type":           "application/json",
		utils.RequestIDHeaderKey: utils.GetRequestIDFromContext(ctx),
	}

	// The request uses the timeout-bound context.
	resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, s.getURL(ctx), bytes.NewReader(jsonData), headers)
	if err != nil {
		if errors.Is(ctx.Err(), context.DeadlineExceeded) {
			log.Error(ctx, "func", "processFormulaTask", "msg", "请求超时")
			return fmt.Errorf("request timeout")
		}
		log.Error(ctx, "func", "processFormulaTask", "msg", "请求失败", "error", err)
		return err
	}
	defer resp.Body.Close()
	// NOTE(review): the response status is not inspected here; confirm that
	// RequestWithRetry already turns non-2xx answers into errors.
	log.Info(ctx, "func", "processFormulaTask", "msg", "请求成功", "resp", resp.Body)

	body := &bytes.Buffer{}
	if _, err = body.ReadFrom(resp.Body); err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "读取响应体失败", "error", err)
		return err
	}

	content := &dao.FormulaRecognitionContent{Latex: utils.ToKatex(body.String())}
	encoded, err := json.Marshal(content)
	if err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "JSON编码失败", "error", err)
		return err
	}

	// Save the recognition result.
	result := &dao.RecognitionResult{
		TaskID:   taskID,
		TaskType: dao.TaskTypeFormula,
		Content:  encoded,
	}
	if err = resultDao.Create(tx, *result); err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "保存任务结果失败", "error", err)
		return err
	}

	// Mark the task completed in the same transaction as the result insert so
	// the two can never disagree.
	err = taskDao.Update(tx,
		map[string]interface{}{"id": taskID},
		map[string]interface{}{"status": dao.TaskStatusCompleted, "completed_at": time.Now()})
	if err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "更新任务状态失败", "error", err)
		return err
	}
	if err = tx.Commit().Error; err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "提交事务失败", "error", err)
		return err
	}

	committed = true
	return nil
}

// normalizeVLLatex strips the wrapping the VLM adds around its answer: line
// breaks, Markdown code fences, a few redundant spaces and the outer
// \[ ... \] display-math delimiters.
func normalizeVLLatex(content string) string {
	latex := strings.ReplaceAll(content, "\n", "")
	latex = strings.ReplaceAll(latex, "```latex", "")
	latex = strings.ReplaceAll(latex, "```", "")
	// Normalize the LaTeX by removing unnecessary spaces.
	latex = strings.ReplaceAll(latex, " = ", "=")
	latex = strings.ReplaceAll(latex, "\\left[ ", "\\left[")
	latex = strings.TrimPrefix(latex, "\\[")
	latex = strings.TrimSuffix(latex, "\\]")
	return latex
}

// processVLFormulaTask sends the task image to the VLM chat-completions
// endpoint, extracts the LaTeX from the first choice and stores it as
// EnhanceLatex on the task's result, creating the result row when none exists.
// A deferred update marks the task completed on success and failed otherwise.
func (s *RecognitionService) processVLFormulaTask(ctx context.Context, taskID int64, fileURL string) error {
	isSuccess := false
	defer func() {
		status := dao.TaskStatusFailed
		if isSuccess {
			status = dao.TaskStatusCompleted
		}
		err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx),
			map[string]interface{}{"id": taskID},
			map[string]interface{}{"status": status})
		if err != nil {
			log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务状态失败", "error", err)
		}
	}()

	reader, err := oss.DownloadFile(ctx, fileURL)
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "下载图片文件失败", "error", err)
		return err
	}
	defer reader.Close()

	imageData, err := io.ReadAll(reader)
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "读取图片数据失败", "error", err)
		return err
	}

	// The trailing newline is written explicitly so it cannot be lost to
	// whitespace trimming.
	prompt := ` Please perform OCR on the image and output only LaTeX code. Important instructions: * "The image contains mathematical formulas, no plain text." * "Preserve all layout, symbols, subscripts, summations, parentheses, etc., exactly as shown." * "Use \[ ... \] or align environments to represent multiline math expressions." * "Use adaptive symbols such as \left and \right where applicable." * "Do not include any extra commentary, template answers, or unrelated equations." * "Only output valid LaTeX code based on the actual content of the image, and not change the original mathematical expression." * "The output result must be can render by better-react-mathjax." ` + "\n"

	base64Image := base64.StdEncoding.EncodeToString(imageData)

	requestBody := formula.VLFormulaRequest{
		Model:            "Qwen/Qwen2.5-VL-32B-Instruct",
		Stream:           false,
		MaxTokens:        512,
		Temperature:      0.1,
		TopP:             0.1,
		TopK:             50,
		FrequencyPenalty: 0.2,
		N:                1,
		Messages: []formula.Message{
			{
				Role: "user",
				Content: []formula.Content{
					{
						Type: "text",
						Text: prompt,
					},
					{
						Type: "image_url",
						ImageURL: formula.Image{
							Detail: "auto",
							URL:    "data:image/jpeg;base64," + base64Image,
						},
					},
				},
			},
		},
	}

	// Encode the request body as JSON.
	jsonData, err := json.Marshal(requestBody)
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "JSON编码失败", "error", err)
		return err
	}

	headers := map[string]string{
		"Content-Type":  "application/json",
		"Authorization": utils.SiliconFlowToken,
	}

	resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "https://api.siliconflow.cn/v1/chat/completions", bytes.NewReader(jsonData), headers)
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "请求VL服务失败", "error", err)
		return err
	}
	defer resp.Body.Close()

	// Decode the response.
	var response formula.VLFormulaResponse
	if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "解析响应失败", "error", err)
		return err
	}

	// Extract the LaTeX from the first choice.
	var latex string
	if len(response.Choices) > 0 {
		latex = normalizeVLLatex(response.Choices[0].Message.Content)
	}
	if strings.TrimSpace(latex) == "" {
		// An empty answer must not be stored as a successful recognition.
		err := fmt.Errorf("vl service returned no latex content")
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "VL服务未返回识别结果", "error", err)
		return err
	}

	resultDao := dao.NewRecognitionResultDao()
	result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID)
	if err != nil {
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			log.Error(ctx, "func", "processVLFormulaTask", "msg", "获取任务结果失败", "error", err)
			return err
		}
		// No stored result yet (e.g. the first recognition failed): fall
		// through to the create path.
		result = nil
	}

	if result == nil {
		formulaRes := &dao.FormulaRecognitionContent{EnhanceLatex: latex}
		b, err := formulaRes.Encode()
		if err != nil {
			log.Error(ctx, "func", "processVLFormulaTask", "msg", "编码任务结果失败", "error", err)
			return err
		}
		err = resultDao.Create(dao.DB.WithContext(ctx), dao.RecognitionResult{TaskID: taskID, TaskType: dao.TaskTypeFormula, Content: b})
		if err != nil {
			log.Error(ctx, "func", "processVLFormulaTask", "msg", "创建任务结果失败", "error", err)
			return err
		}
		isSuccess = true
		return nil
	}

	// Merge the enhanced LaTeX into the existing result content.
	formulaRes, ok := result.NewContentCodec().(*dao.FormulaRecognitionContent)
	if !ok {
		err := fmt.Errorf("unexpected result content type for task %d", taskID)
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "解码任务结果失败", "error", err)
		return err
	}
	if err := formulaRes.Decode(); err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "解码任务结果失败", "error", err)
		return err
	}
	formulaRes.EnhanceLatex = latex
	b, err := formulaRes.Encode()
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "编码任务结果失败", "error", err)
		return err
	}
	err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{"content": b})
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务结果失败", "error", err)
		return err
	}

	isSuccess = true
	return nil
}

// processFormulaQueue processes queued tasks one at a time until Stop is
// called or ctx is cancelled.
//
// NOTE(review): if cache.PopFormulaTask does not block on an empty queue this
// loop spins; confirm its behavior.
func (s *RecognitionService) processFormulaQueue(ctx context.Context) {
	for {
		select {
		case <-s.stopChan:
			return
		case <-ctx.Done():
			return
		default:
			s.processOneTask(ctx)
		}
	}
}

// processOneTask pops one task ID from the queue, loads the task and runs the
// recognition for it. Errors are logged and not returned.
func (s *RecognitionService) processOneTask(ctx context.Context) {
	// Hold a queueLimit slot for the duration of the task.
	s.queueLimit <- struct{}{}
	defer func() { <-s.queueLimit }()

	taskID, err := cache.PopFormulaTask(ctx)
	if err != nil {
		log.Error(ctx, "func", "processFormulaQueue", "msg", "获取任务失败", "error", err)
		return
	}

	task, err := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
	if err != nil {
		log.Error(ctx, "func", "processFormulaQueue", "msg", "获取任务失败", "error", err)
		return
	}
	if task == nil {
		log.Error(ctx, "func", "processFormulaQueue", "msg", "任务不存在", "task_id", taskID)
		return
	}

	ctx = context.WithValue(ctx, utils.RequestIDKey, task.TaskUUID)
	log.Info(ctx, "func", "processFormulaQueue", "msg", "获取任务成功", "task_id", taskID)

	// Run the actual task.
	if err := s.processFormulaTask(ctx, taskID, task.FileURL); err != nil {
		log.Error(ctx, "func", "processFormulaQueue", "msg", "处理任务失败", "error", err)
		return
	}
	log.Info(ctx, "func", "processFormulaQueue", "msg", "处理任务成功", "task_id", taskID)
}

// getURL returns the endpoint of the formula prediction service. ctx is
// currently unused.
func (s *RecognitionService) getURL(ctx context.Context) string {
	return "http://cloud.srcstar.com:8045/formula/predict"
}