package service

import (
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"strings"
	"time"

	"gitea.com/bitwsd/document_ai/config"
	"gitea.com/bitwsd/document_ai/internal/model/formula"
	"gitea.com/bitwsd/document_ai/internal/storage/cache"
	"gitea.com/bitwsd/document_ai/internal/storage/dao"
	"gitea.com/bitwsd/document_ai/pkg/common"
	"gitea.com/bitwsd/document_ai/pkg/constant"
	"gitea.com/bitwsd/document_ai/pkg/httpclient"
	"gitea.com/bitwsd/document_ai/pkg/log"
	"gitea.com/bitwsd/document_ai/pkg/oss"
	"gitea.com/bitwsd/document_ai/pkg/requestid"
	"gitea.com/bitwsd/document_ai/pkg/utils"
	"gorm.io/gorm"
)

// RecognitionService creates formula recognition tasks, queues them, and runs
// them against the configured OCR / vision-model back ends.
type RecognitionService struct {
	db         *gorm.DB
	queueLimit chan struct{}
	stopChan   chan struct{}
	httpClient *httpclient.Client
}

// NewRecognitionService builds a RecognitionService and starts a background
// goroutine that consumes the formula task queue. The consumer only runs when
// the distributed lock is obtained; otherwise the goroutine logs and exits.
func NewRecognitionService() *RecognitionService {
	s := &RecognitionService{
		db:         dao.DB,
		queueLimit: make(chan struct{}, config.GlobalConfig.Limit.FormulaRecognition),
		stopChan:   make(chan struct{}),
		httpClient: httpclient.NewClient(nil), // default client configuration
	}

	// Start processing the queue as soon as the service is created.
	utils.SafeGo(func() {
		lock, err := cache.GetDistributedLock(context.Background())
		if err != nil {
			log.Error(context.Background(), "func", "NewRecognitionService", "msg", "获取分布式锁失败", "error", err)
			return
		}
		if !lock {
			log.Error(context.Background(), "func", "NewRecognitionService", "msg", "获取分布式锁失败")
			return
		}
		s.processFormulaQueue(context.Background())
	})

	return s
}

// AIEnhanceRecognition re-runs recognition for an already finished task with
// the vision-language model. It enforces the per-IP daily limit, flips the
// task to "processing", and performs the actual recognition in a background
// goroutine. The returned task is the record as loaded before the update.
func (s *RecognitionService) AIEnhanceRecognition(ctx context.Context, req *formula.AIEnhanceRecognitionRequest) (*dao.RecognitionTask, error) {
	// Resolve the client IP once so the background goroutine does not have to
	// read it from the request context after the request has returned.
	ip := common.GetIPFromContext(ctx)

	count, err := cache.GetVLMFormulaCount(ctx, ip)
	if err != nil {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "获取VLM公式识别次数失败", "error", err)
		return nil, common.NewError(common.CodeSystemError, "系统错误", err)
	}
	if count >= constant.VLMFormulaCount {
		return nil, common.NewError(common.CodeForbidden, "今日VLM公式识别次数已达上限,请明天再试!", nil)
	}

	taskDao := dao.NewRecognitionTaskDao()
	task, err := taskDao.GetByTaskNo(dao.DB.WithContext(ctx), req.TaskNo)
	if err != nil {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "获取任务失败", "error", err)
		return nil, common.NewError(common.CodeDBError, "获取任务失败", err)
	}
	if task == nil {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "任务不存在", "task_no", req.TaskNo)
		return nil, common.NewError(common.CodeNotFound, "任务不存在", nil)
	}
	if task.Status == dao.TaskStatusProcessing || task.Status == dao.TaskStatusPending {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "任务未完成", "task_no", req.TaskNo)
		return nil, common.NewError(common.CodeInvalidStatus, "任务未完成", nil)
	}

	err = taskDao.Update(
		dao.DB.WithContext(ctx),
		map[string]interface{}{"id": task.ID},
		map[string]interface{}{"status": dao.TaskStatusProcessing},
	)
	if err != nil {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "更新任务状态失败", "error", err)
		return nil, common.NewError(common.CodeDBError, "更新任务状态失败", err)
	}

	utils.SafeGo(func() {
		// Use a fresh context: the work must outlive the incoming request.
		s.processVLFormula(context.Background(), task.ID)
		if _, err := cache.IncrVLMFormulaCount(context.Background(), ip); err != nil {
			log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "增加VLM公式识别次数失败", "error", err)
		}
	})

	return task, nil
}

// CreateRecognitionTask persists a new pending recognition task built from req
// and pushes it onto the formula task queue.
func (s *RecognitionService) CreateRecognitionTask(ctx context.Context, req *formula.CreateFormulaRecognitionRequest) (*dao.RecognitionTask, error) {
	sess := dao.DB.WithContext(ctx)
	taskDao := dao.NewRecognitionTaskDao()

	task := &dao.RecognitionTask{
		UserID:   req.UserID,
		TaskUUID: utils.NewUUID(),
		TaskType: dao.TaskType(req.TaskType),
		Status:   dao.TaskStatusPending,
		FileURL:  req.FileURL,
		FileName: req.FileName,
		FileHash: req.FileHash,
		IP:       common.GetIPFromContext(ctx),
	}
	if err := taskDao.Create(sess, task); err != nil {
		log.Error(ctx, "func", "CreateRecognitionTask", "msg", "创建任务失败", "error", err)
		return nil, common.NewError(common.CodeDBError, "创建任务失败", err)
	}

	if err := s.handleFormulaRecognition(ctx, task.ID); err != nil {
		log.Error(ctx, "func", "CreateRecognitionTask", "msg", "处理任务失败", "error", err)
		return nil, common.NewError(common.CodeSystemError, "处理任务失败", err)
	}

	return task, nil
}

// GetFormualTask returns the state of the task identified by taskNo and, once
// the task is completed, its recognition result.
//
// When the queued task count exceeds the configured limit, a pending response
// carrying that count is returned without touching the database. A missing
// task or a missing result yields a CodeNotFound error.
func (s *RecognitionService) GetFormualTask(ctx context.Context, taskNo string) (*formula.GetFormulaTaskResponse, error) {
	taskDao := dao.NewRecognitionTaskDao()
	resultDao := dao.NewRecognitionResultDao()
	sess := dao.DB.WithContext(ctx)

	// Best effort: a failure to read the queue size is logged and the lookup
	// simply continues with whatever count was returned.
	count, err := cache.GetFormulaTaskCount(ctx)
	if err != nil {
		log.Error(ctx, "func", "GetFormualTask", "msg", "获取任务数量失败", "error", err)
	}
	if count > int64(config.GlobalConfig.Limit.FormulaRecognition) {
		return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Status: int(dao.TaskStatusPending), Count: int(count)}, nil
	}

	task, err := taskDao.GetByTaskNo(sess, taskNo)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			log.Info(ctx, "func", "GetFormualTask", "msg", "任务不存在", "task_no", taskNo)
			return nil, common.NewError(common.CodeNotFound, "任务不存在", err)
		}
		log.Error(ctx, "func", "GetFormualTask", "msg", "查询任务失败", "error", err, "task_no", taskNo)
		return nil, common.NewError(common.CodeDBError, "查询任务失败", err)
	}
	// Other callers in this file treat (nil, nil) as "not found"; guard
	// against it here too instead of dereferencing a nil task.
	if task == nil {
		log.Info(ctx, "func", "GetFormualTask", "msg", "任务不存在", "task_no", taskNo)
		return nil, common.NewError(common.CodeNotFound, "任务不存在", nil)
	}

	if task.Status != dao.TaskStatusCompleted {
		return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Status: int(task.Status)}, nil
	}

	taskRet, err := resultDao.GetByTaskID(sess, task.ID)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			log.Info(ctx, "func", "GetFormualTask", "msg", "任务结果不存在", "task_no", taskNo)
			return nil, common.NewError(common.CodeNotFound, "任务结果不存在", err)
		}
		log.Error(ctx, "func", "GetFormualTask", "msg", "查询任务结果失败", "error", err, "task_no", taskNo)
		return nil, common.NewError(common.CodeDBError, "查询任务结果失败", err)
	}
	if taskRet == nil {
		log.Info(ctx, "func", "GetFormualTask", "msg", "任务结果不存在", "task_no", taskNo)
		return nil, common.NewError(common.CodeNotFound, "任务结果不存在", nil)
	}

	// Build the Markdown form; fall back to wrapping the LaTeX in $$...$$.
	markdown := taskRet.Markdown
	if markdown == "" {
		markdown = fmt.Sprintf("$$%s$$", taskRet.Latex)
	}

	return &formula.GetFormulaTaskResponse{
		TaskNo:   taskNo,
		Latex:    taskRet.Latex,
		Markdown: markdown,
		MathML:   taskRet.MathML,
		Status:   int(task.Status),
	}, nil
}

// handleFormulaRecognition only enqueues the task; the queue consumer does
// the actual processing.
func (s *RecognitionService) handleFormulaRecognition(ctx context.Context, taskID int64) error {
	if _, err := cache.PushFormulaTask(ctx, taskID); err != nil {
		log.Error(ctx, "func", "handleFormulaRecognition", "msg", "增加任务计数失败", "error", err)
		return err
	}
	log.Info(ctx, "func", "handleFormulaRecognition", "msg", "增加任务计数成功", "task_id", taskID)
	return nil
}

// Stop signals the queue consumer to exit for a graceful shutdown. It closes
// the stop channel, so it must be called at most once.
func (s *RecognitionService) Stop() {
	close(s.stopChan)
}

// processVLFormula loads the task and runs vision-language-model recognition
// on its file. Errors are logged, not returned.
func (s *RecognitionService) processVLFormula(ctx context.Context, taskID int64) {
	task, err := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaQueue", "msg", "获取任务失败", "error", err)
		return
	}
	if task == nil {
		log.Error(ctx, "func", "processVLFormulaQueue", "msg", "任务不存在", "task_id", taskID)
		return
	}

	// Tag every subsequent log line with the task UUID as the request ID.
	ctx = context.WithValue(ctx, utils.RequestIDKey, task.TaskUUID)
	log.Info(ctx, "func", "processVLFormulaQueue", "msg", "获取任务成功", "task_id", taskID)

	// Process the task itself.
	if err := s.processVLFormulaTask(ctx, taskID, task.FileURL, utils.ModelVLQwen3VL32BInstruct); err != nil {
		log.Error(ctx, "func", "processVLFormulaQueue", "msg", "处理任务失败", "error", err)
		return
	}
	log.Info(ctx, "func", "processVLFormulaQueue", "msg", "处理任务成功", "task_id", taskID)
}

// MathpixRequest is the full request structure for the Mathpix API /v3/text.
type MathpixRequest struct {
	// Image source: a URL or base64-encoded data.
	Src string `json:"src"`
	// Metadata key-value pairs.
	Metadata map[string]interface{} `json:"metadata"`
	// Tags used to identify results.
	Tags []string `json:"tags"`
	// Asynchronous request flag.
	Async bool `json:"async"`
	// Callback configuration.
	Callback *MathpixCallback `json:"callback"`
	// Output formats: text, data, html, latex_styled.
	Formats []string `json:"formats"`
	// Data options.
	DataOptions *MathpixDataOptions `json:"data_options,omitempty"`
	// Return the detected alphabets.
	IncludeDetectedAlphabets *bool `json:"include_detected_alphabets,omitempty"`
	// Allowed alphabets.
	AlphabetsAllowed *MathpixAlphabetsAllowed `json:"alphabets_allowed,omitempty"`
	// Image region to process.
	Region *MathpixRegion `json:"region,omitempty"`
	// Blue HSV filter mode.
	EnableBlueHsvFilter bool `json:"enable_blue_hsv_filter"`
	// Confidence threshold.
	ConfidenceThreshold float64 `json:"confidence_threshold"`
	// Symbol-level confidence threshold, default 0.75.
	ConfidenceRateThreshold float64 `json:"confidence_rate_threshold"`
	// Include equation tags.
	IncludeEquationTags bool `json:"include_equation_tags"`
	// Return line-by-line data.
	IncludeLineData bool `json:"include_line_data"`
	// Return word-by-word data.
	IncludeWordData bool `json:"include_word_data"`
	// Chemical structure OCR.
	IncludeSmiles bool `json:"include_smiles"`
	// InChI data.
	IncludeInchi bool `json:"include_inchi"`
	// Geometry data.
	IncludeGeometryData bool `json:"include_geometry_data"`
	// Diagram text extraction.
	IncludeDiagramText bool `json:"include_diagram_text"`
	// Page info, default true.
	IncludePageInfo *bool `json:"include_page_info,omitempty"`
	// Auto-rotate confidence threshold, default 0.99.
	AutoRotateConfidenceThreshold float64 `json:"auto_rotate_confidence_threshold"`
	// Remove extra spaces, default true.
	RmSpaces *bool `json:"rm_spaces,omitempty"`
	// Remove font commands, default false.
	RmFonts bool `json:"rm_fonts"`
	// Use aligned/gathered/cases instead of array, default false.
	IdiomaticEqnArrays bool `json:"idiomatic_eqn_arrays"`
	// Remove unnecessary braces, default false.
	IdiomaticBraces bool `json:"idiomatic_braces"`
	// Numbers are always in math mode, default false.
	NumbersDefaultToMath bool `json:"numbers_default_to_math"`
	// Math fonts are always in math mode, default false.
	MathFontsDefaultToMath bool `json:"math_fonts_default_to_math"`
	// Inline math delimiters, default ["\\(", "\\)"].
	MathInlineDelimiters []string `json:"math_inline_delimiters"`
	// Display math delimiters, default ["\\[", "\\]"].
	MathDisplayDelimiters []string `json:"math_display_delimiters"`
	// Advanced table processing, default false.
	EnableTablesFallback bool `json:"enable_tables_fallback"`
	// Full-width punctuation; null means decide automatically.
	FullwidthPunctuation *bool `json:"fullwidth_punctuation,omitempty"`
}

// MathpixCallback is the callback configuration of a Mathpix request.
type MathpixCallback struct {
	URL     string            `json:"url"`
	Headers map[string]string `json:"headers"`
}

// MathpixDataOptions selects which data formats Mathpix should return.
type MathpixDataOptions struct {
	IncludeAsciimath bool `json:"include_asciimath"`
	IncludeMathml    bool `json:"include_mathml"`
	IncludeLatex     bool `json:"include_latex"`
	IncludeTsv       bool `json:"include_tsv"`
}

// MathpixAlphabetsAllowed lists the alphabets allowed in the result.
type MathpixAlphabetsAllowed struct {
	En bool `json:"en"`
	Hi bool `json:"hi"`
	Zh bool `json:"zh"`
	Ja bool `json:"ja"`
	Ko bool `json:"ko"`
	Ru bool `json:"ru"`
	Th bool `json:"th"`
	Vi bool `json:"vi"`
}

// MathpixRegion describes a rectangular region of the image.
type MathpixRegion struct {
	TopLeftX int `json:"top_left_x"`
	TopLeftY int `json:"top_left_y"`
	Width    int `json:"width"`
	Height   int `json:"height"`
}

// MathpixResponse is the full response structure of the Mathpix API /v3/text.
type MathpixResponse struct {
	// Request ID, useful for debugging.
	RequestID string `json:"request_id"`
	// Text in Mathpix Markdown format.
	Text string `json:"text"`
	// Styled LaTeX (only returned for single-formula images).
	LatexStyled string `json:"latex_styled"`
	// Confidence in [0,1].
	Confidence float64 `json:"confidence"`
	// Confidence rate in [0,1].
	ConfidenceRate float64 `json:"confidence_rate"`
	// Line data.
	LineData []map[string]interface{} `json:"line_data"`
	// Word data.
	WordData []map[string]interface{} `json:"word_data"`
	// List of data objects.
	Data []MathpixDataItem `json:"data"`
	// HTML output.
	HTML string `json:"html"`
	// Detected alphabets.
	DetectedAlphabets []map[string]interface{} `json:"detected_alphabets"`
	// Whether the content is printed.
	IsPrinted bool `json:"is_printed"`
	// Whether the content is handwritten.
	IsHandwritten bool `json:"is_handwritten"`
	// Auto-rotate confidence.
	AutoRotateConfidence float64 `json:"auto_rotate_confidence"`
	// Geometry data.
	GeometryData []map[string]interface{} `json:"geometry_data"`
	// Auto-rotate angle, one of {0, 90, -90, 180}.
	AutoRotateDegrees int `json:"auto_rotate_degrees"`
	// Image width.
	ImageWidth int `json:"image_width"`
	// Image height.
	ImageHeight int `json:"image_height"`
	// Error message.
	Error string `json:"error"`
	// Error details.
	ErrorInfo *MathpixErrorInfo `json:"error_info"`
	// API version.
	Version string `json:"version"`
}

// MathpixDataItem is a single typed entry of MathpixResponse.Data.
type MathpixDataItem struct {
	Type  string `json:"type"`
	Value string `json:"value"`
}

// MathpixErrorInfo carries the details of a Mathpix error.
type MathpixErrorInfo struct {
	ID      string `json:"id"`
	Message string `json:"message"`
}

// BaiduOCRRequest is the request structure of the Baidu OCR layout-parsing API.
type BaiduOCRRequest struct {
	// File content, base64-encoded.
	File string `json:"file"`
	// File type: 0=PDF, 1=image.
	FileType int `json:"fileType"`
	// Whether to enable document orientation classification.
	UseDocOrientationClassify bool `json:"useDocOrientationClassify"`
	// Whether to enable document unwarping.
	UseDocUnwarping bool `json:"useDocUnwarping"`
	// Whether to enable chart recognition.
	UseChartRecognition bool `json:"useChartRecognition"`
}

// BaiduOCRResponse is the response structure of the Baidu OCR layout-parsing API.
type BaiduOCRResponse struct {
	ErrorCode int             `json:"errorCode"`
	ErrorMsg  string          `json:"errorMsg"`
	Result    *BaiduOCRResult `json:"result"`
}

// BaiduOCRResult is the result payload of a Baidu OCR response.
type BaiduOCRResult struct {
	LayoutParsingResults []BaiduLayoutParsingResult `json:"layoutParsingResults"`
}

// BaiduLayoutParsingResult is the layout-parsing result of a single page.
type BaiduLayoutParsingResult struct {
	Markdown     BaiduMarkdownResult `json:"markdown"`
	OutputImages map[string]string   `json:"outputImages"`
}

// BaiduMarkdownResult is the markdown part of a page result.
type BaiduMarkdownResult struct {
	Text   string            `json:"text"`
	Images map[string]string `json:"images"`
}

// GetMathML returns the value of the first "mathml" data item in the
// response, or "" when there is none.
func (r *MathpixResponse) GetMathML() string {
	for _, item := range r.Data {
		if item.Type == "mathml" {
			return item.Value
		}
	}
	return ""
}

// GetAsciiMath returns the value of the first "asciimath" data item in the
// response, or "" when there is none.
func (r *MathpixResponse) GetAsciiMath() string {
	for _, item := range r.Data {
		if item.Type == "asciimath" {
			return item.Value
		}
	}
	return ""
}

// processFormulaTask runs one formula recognition task end to end: it marks
// the task as processing, downloads the image, posts it to the OCR endpoint
// and stores the result. The result row and the "completed" status are written
// in one transaction; on any failure the transaction is rolled back and the
// task is marked failed with a remark describing the cause.
//
// The whole run is bounded by a 45 second timeout.
func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int64, fileURL string) (err error) {
	// Bound the whole task with a timeout.
	ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
	defer cancel()

	tx := dao.DB.Begin()
	var (
		taskDao   = dao.NewRecognitionTaskDao()
		resultDao = dao.NewRecognitionResultDao()
	)

	succeeded := false
	defer func() {
		if succeeded {
			// Persist the final status together with the result row. A failed
			// update or commit must not be reported as success, so it falls
			// through to the failure handling below.
			finalizeErr := taskDao.Update(
				tx,
				map[string]interface{}{"id": taskID},
				map[string]interface{}{"status": dao.TaskStatusCompleted, "completed_at": time.Now()},
			)
			if finalizeErr == nil {
				finalizeErr = tx.Commit().Error
			}
			if finalizeErr == nil {
				return
			}
			log.Error(ctx, "func", "processFormulaTask", "msg", "提交任务结果失败", "error", finalizeErr)
			err = finalizeErr
		}

		tx.Rollback()

		remark := "任务处理失败"
		if errors.Is(ctx.Err(), context.DeadlineExceeded) {
			remark = "任务处理超时"
		}
		if err != nil {
			remark = err.Error()
		}
		// Use a fresh context so a cancelled/expired ctx cannot block the
		// status update.
		updateErr := taskDao.Update(
			dao.DB.WithContext(context.Background()),
			map[string]interface{}{"id": taskID},
			map[string]interface{}{
				"status":       dao.TaskStatusFailed,
				"completed_at": time.Now(),
				"remark":       remark,
			},
		)
		if updateErr != nil {
			log.Error(ctx, "func", "processFormulaTask", "msg", "更新任务状态失败", "error", updateErr)
		}
	}()

	// Marking the task as processing is best effort: a failure is logged and
	// processing continues.
	if updateErr := taskDao.Update(
		dao.DB.WithContext(ctx),
		map[string]interface{}{"id": taskID},
		map[string]interface{}{"status": dao.TaskStatusProcessing},
	); updateErr != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "更新任务状态失败", "error", updateErr)
	}

	// Download the image file.
	reader, err := oss.DownloadFile(ctx, fileURL)
	if err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "下载图片文件失败", "error", err)
		return err
	}
	defer reader.Close()

	// Read the image data.
	imageData, err := io.ReadAll(reader)
	if err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "读取图片数据失败", "error", err)
		return err
	}

	// Encode the image as base64 and build the JSON request.
	requestData := map[string]string{
		"image_base64": base64.StdEncoding.EncodeToString(imageData),
	}
	jsonData, err := json.Marshal(requestData)
	if err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "JSON编码失败", "error", err)
		return err
	}

	headers := map[string]string{
		"Content-Type":           "application/json",
		utils.RequestIDHeaderKey: utils.GetRequestIDFromContext(ctx),
	}

	// Send the request to the OCR endpoint.
	resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "https://cloud.texpixel.com:10443/doc_process/v1/image/ocr", bytes.NewReader(jsonData), headers)
	if err != nil {
		if errors.Is(ctx.Err(), context.DeadlineExceeded) {
			log.Error(ctx, "func", "processFormulaTask", "msg", "请求超时")
			return errors.New("request timeout")
		}
		log.Error(ctx, "func", "processFormulaTask", "msg", "请求失败", "error", err)
		return err
	}
	defer resp.Body.Close()
	log.Info(ctx, "func", "processFormulaTask", "msg", "请求成功")

	body := &bytes.Buffer{}
	if _, err = body.ReadFrom(resp.Body); err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "读取响应体失败", "error", err)
		return err
	}
	log.Info(ctx, "func", "processFormulaTask", "msg", "响应内容", "body", body.String())

	// Parse the JSON response.
	var ocrResp formula.ImageOCRResponse
	if err = json.Unmarshal(body.Bytes(), &ocrResp); err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "解析响应JSON失败", "error", err)
		return err
	}

	err = resultDao.Create(tx, dao.RecognitionResult{
		TaskID:   taskID,
		TaskType: dao.TaskTypeFormula,
		Latex:    ocrResp.Latex,
		Markdown: ocrResp.Markdown,
		MathML:   ocrResp.MathML,
	})
	if err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "保存任务结果失败", "error", err)
		return err
	}

	succeeded = true
	return nil
}

// processVLFormulaTask sends the task image to the SiliconFlow chat-completions
// endpoint using the given model, extracts the LaTeX from the first choice and
// creates or updates the task's recognition result. The task status is set to
// completed or failed when the function returns.
func (s *RecognitionService) processVLFormulaTask(ctx context.Context, taskID int64, fileURL string, model string) error {
	succeeded := false
	defer func() {
		status := dao.TaskStatusFailed
		if succeeded {
			status = dao.TaskStatusCompleted
		}
		err := dao.NewRecognitionTaskDao().Update(
			dao.DB.WithContext(ctx),
			map[string]interface{}{"id": taskID},
			map[string]interface{}{"status": status},
		)
		if err != nil {
			log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务状态失败", "error", err)
		}
	}()

	reader, err := oss.DownloadFile(ctx, fileURL)
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "获取签名URL失败", "error", err)
		return err
	}
	defer reader.Close()

	imageData, err := io.ReadAll(reader)
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "读取图片数据失败", "error", err)
		return err
	}

	prompt := `Please perform OCR on the image and output only LaTeX code.`
	base64Image := base64.StdEncoding.EncodeToString(imageData)

	requestBody := formula.VLFormulaRequest{
		Model:            model,
		Stream:           false,
		MaxTokens:        512,
		Temperature:      0.1,
		TopP:             0.1,
		TopK:             50,
		FrequencyPenalty: 0.2,
		N:                1,
		Messages: []formula.Message{
			{
				Role: "user",
				Content: []formula.Content{
					{
						Type: "text",
						Text: prompt,
					},
					{
						Type: "image_url",
						ImageURL: formula.Image{
							Detail: "auto",
							URL:    "data:image/jpeg;base64," + base64Image,
						},
					},
				},
			},
		},
	}

	// Serialize the request body to JSON.
	jsonData, err := json.Marshal(requestBody)
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "JSON编码失败", "error", err)
		return err
	}

	headers := map[string]string{
		"Content-Type":  "application/json",
		"Authorization": utils.SiliconFlowToken,
	}
	resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "https://api.siliconflow.cn/v1/chat/completions", bytes.NewReader(jsonData), headers)
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "请求VL服务失败", "error", err)
		return err
	}
	defer resp.Body.Close()

	// Parse the response.
	var response formula.VLFormulaResponse
	if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "解析响应失败", "error", err)
		return err
	}

	// Extract the LaTeX code from the first choice, stripping newlines and
	// markdown code fences.
	var latex string
	if len(response.Choices) > 0 && response.Choices[0].Message.Content != "" {
		latex = strings.ReplaceAll(response.Choices[0].Message.Content, "\n", "")
		latex = strings.ReplaceAll(latex, "```latex", "")
		latex = strings.ReplaceAll(latex, "```", "")
		// Normalize the LaTeX by removing unnecessary spaces and the
		// surrounding \[ ... \] delimiters.
		latex = strings.ReplaceAll(latex, " = ", "=")
		latex = strings.ReplaceAll(latex, "\\left[ ", "\\left[")
		latex = strings.TrimPrefix(latex, "\\[")
		latex = strings.TrimSuffix(latex, "\\]")
	}

	resultDao := dao.NewRecognitionResultDao()
	result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID)
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "获取任务结果失败", "error", err)
		return err
	}

	if result == nil {
		formulaRes := dao.RecognitionResult{TaskID: taskID, TaskType: dao.TaskTypeFormula, Latex: latex}
		if err = resultDao.Create(dao.DB.WithContext(ctx), formulaRes); err != nil {
			log.Error(ctx, "func", "processVLFormulaTask", "msg", "创建任务结果失败", "error", err)
			return err
		}
	} else {
		result.Latex = latex
		err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{"latex": latex})
		if err != nil {
			log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务结果失败", "error", err)
			return err
		}
	}

	succeeded = true
	return nil
}

// processFormulaQueue keeps processing queued tasks one at a time until the
// stop channel is closed.
func (s *RecognitionService) processFormulaQueue(ctx context.Context) {
	for {
		select {
		case <-s.stopChan:
			return
		default:
			s.processOneTask(ctx)
		}
	}
}

// processOneTask pops one task ID from the queue, loads the task and runs
// processFormulaTask on it. A slot of queueLimit is held for the duration.
func (s *RecognitionService) processOneTask(ctx context.Context) {
	// Limit the number of tasks in flight.
	s.queueLimit <- struct{}{}
	defer func() { <-s.queueLimit }()

	taskID, err := cache.PopFormulaTask(ctx)
	if err != nil {
		log.Error(ctx, "func", "processFormulaQueue", "msg", "获取任务失败", "error", err)
		return
	}

	task, err := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
	if err != nil {
		log.Error(ctx, "func", "processFormulaQueue", "msg", "获取任务失败", "error", err)
		return
	}
	if task == nil {
		log.Error(ctx, "func", "processFormulaQueue", "msg", "任务不存在", "task_id", taskID)
		return
	}

	ctx = context.WithValue(ctx, utils.RequestIDKey, task.TaskUUID)

	// Also set the request ID through requestid so it stays available for the
	// whole task run.
	requestid.SetRequestID(task.TaskUUID, func() {
		log.Info(ctx, "func", "processFormulaQueue", "msg", "获取任务成功", "task_id", taskID)

		if err := s.processFormulaTask(ctx, taskID, task.FileURL); err != nil {
			log.Error(ctx, "func", "processFormulaQueue", "msg", "处理任务失败", "error", err)
			return
		}

		log.Info(ctx, "func", "processFormulaQueue", "msg", "处理任务成功", "task_id", taskID)
	})
}

// processMathpixTask handles a formula recognition task with the Mathpix API
// (used for enhanced recognition). It stores a recognition log for successful
// API responses and creates or updates the task's recognition result. The task
// status is set to completed or failed when the function returns.
func (s *RecognitionService) processMathpixTask(ctx context.Context, taskID int64, fileURL string) error {
	succeeded := false
	logDao := dao.NewRecognitionLogDao()

	defer func() {
		status := dao.TaskStatusFailed
		if succeeded {
			status = dao.TaskStatusCompleted
		}
		err := dao.NewRecognitionTaskDao().Update(
			dao.DB.WithContext(ctx),
			map[string]interface{}{"id": taskID},
			map[string]interface{}{"status": status},
		)
		if err != nil {
			log.Error(ctx, "func", "processMathpixTask", "msg", "更新任务状态失败", "error", err)
		}
	}()

	// Resolve a download URL for the image; Mathpix fetches it itself.
	imageURL, err := oss.GetDownloadURL(ctx, fileURL)
	if err != nil {
		log.Error(ctx, "func", "processMathpixTask", "msg", "获取图片URL失败", "error", err)
		return err
	}

	// Build the Mathpix API request.
	rmSpaces := true
	mathpixReq := MathpixRequest{
		Src: imageURL,
		Formats: []string{
			"text",
			"latex_styled",
			"data",
			"html",
		},
		DataOptions: &MathpixDataOptions{
			IncludeMathml:    true,
			IncludeAsciimath: true,
			IncludeLatex:     true,
			IncludeTsv:       true,
		},
		MathInlineDelimiters:  []string{"$", "$"},
		MathDisplayDelimiters: []string{"$$", "$$"},
		RmSpaces:              &rmSpaces,
	}

	jsonData, err := json.Marshal(mathpixReq)
	if err != nil {
		log.Error(ctx, "func", "processMathpixTask", "msg", "JSON编码失败", "error", err)
		return err
	}

	headers := map[string]string{
		"Content-Type": "application/json",
		"app_id":       config.GlobalConfig.Mathpix.AppID,
		"app_key":      config.GlobalConfig.Mathpix.AppKey,
	}

	endpoint := "https://api.mathpix.com/v3/text"

	startTime := time.Now()
	log.Info(ctx, "func", "processMathpixTask", "msg", "MathpixApi_Start", "start_time", startTime)

	resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData), headers)
	if err != nil {
		log.Error(ctx, "func", "processMathpixTask", "msg", "Mathpix API 请求失败", "error", err)
		return err
	}
	defer resp.Body.Close()

	log.Info(ctx, "func", "processMathpixTask", "msg", "MathpixApi_End", "end_time", time.Now(), "duration", time.Since(startTime))

	body := &bytes.Buffer{}
	if _, err = body.ReadFrom(resp.Body); err != nil {
		log.Error(ctx, "func", "processMathpixTask", "msg", "读取响应体失败", "error", err)
		return err
	}

	// Prepare the log record.
	recognitionLog := &dao.RecognitionLog{
		TaskID:       taskID,
		Provider:     dao.ProviderMathpix,
		RequestBody:  string(jsonData),
		ResponseBody: body.String(),
	}

	// Parse the response.
	var mathpixResp MathpixResponse
	if err := json.Unmarshal(body.Bytes(), &mathpixResp); err != nil {
		log.Error(ctx, "func", "processMathpixTask", "msg", "解析响应失败", "error", err)
		return err
	}

	// Check for an API-level error.
	if mathpixResp.Error != "" {
		errMsg := mathpixResp.Error
		if mathpixResp.ErrorInfo != nil {
			errMsg = fmt.Sprintf("%s: %s", mathpixResp.ErrorInfo.ID, mathpixResp.ErrorInfo.Message)
		}
		log.Error(ctx, "func", "processMathpixTask", "msg", "Mathpix API 返回错误", "error", errMsg)
		return fmt.Errorf("mathpix error: %s", errMsg)
	}

	// Save the log; a failure here is logged but does not fail the task.
	if err = logDao.Create(dao.DB.WithContext(ctx), recognitionLog); err != nil {
		log.Error(ctx, "func", "processMathpixTask", "msg", "保存日志失败", "error", err)
	}

	// Update or create the recognition result.
	resultDao := dao.NewRecognitionResultDao()
	result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID)
	if err != nil {
		log.Error(ctx, "func", "processMathpixTask", "msg", "获取任务结果失败", "error", err)
		return err
	}

	log.Info(ctx, "func", "processMathpixTask", "msg", "saveLog", "end_time", time.Now(), "duration", time.Since(startTime))

	if result == nil {
		// Create a new result.
		err = resultDao.Create(dao.DB.WithContext(ctx), dao.RecognitionResult{
			TaskID:   taskID,
			TaskType: dao.TaskTypeFormula,
			Latex:    mathpixResp.LatexStyled,
			Markdown: mathpixResp.Text,
			MathML:   mathpixResp.GetMathML(),
		})
		if err != nil {
			log.Error(ctx, "func", "processMathpixTask", "msg", "创建任务结果失败", "error", err)
			return err
		}
	} else {
		// Update the existing result.
		err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{
			"latex":    mathpixResp.LatexStyled,
			"markdown": mathpixResp.Text,
			"mathml":   mathpixResp.GetMathML(),
		})
		if err != nil {
			log.Error(ctx, "func", "processMathpixTask", "msg", "更新任务结果失败", "error", err)
			return err
		}
	}

	succeeded = true
	return nil
}

// processBaiduOCRTask handles a task with the Baidu OCR layout-parsing API.
// It uploads the file as base64, joins the markdown of all pages, converts the
// markdown to LaTeX/MathML through HandleConvert (best effort) and creates or
// updates the task's recognition result. The task status is set to completed
// or failed when the function returns.
func (s *RecognitionService) processBaiduOCRTask(ctx context.Context, taskID int64, fileURL string) error {
	succeeded := false
	logDao := dao.NewRecognitionLogDao()

	defer func() {
		status := dao.TaskStatusFailed
		if succeeded {
			status = dao.TaskStatusCompleted
		}
		err := dao.NewRecognitionTaskDao().Update(
			dao.DB.WithContext(ctx),
			map[string]interface{}{"id": taskID},
			map[string]interface{}{"status": status},
		)
		if err != nil {
			log.Error(ctx, "func", "processBaiduOCRTask", "msg", "更新任务状态失败", "error", err)
		}
	}()

	// Download the file from OSS.
	reader, err := oss.DownloadFile(ctx, fileURL)
	if err != nil {
		log.Error(ctx, "func", "processBaiduOCRTask", "msg", "从OSS下载文件失败", "error", err)
		return err
	}
	defer reader.Close()

	// Read the file content.
	fileBytes, err := io.ReadAll(reader)
	if err != nil {
		log.Error(ctx, "func", "processBaiduOCRTask", "msg", "读取文件内容失败", "error", err)
		return err
	}

	// Base64-encode the content.
	fileData := base64.StdEncoding.EncodeToString(fileBytes)

	// Derive fileType from the extension: 0=PDF, 1=image (the default).
	fileType := 1
	if strings.HasSuffix(strings.ToLower(fileURL), ".pdf") {
		fileType = 0
	}

	// Build the Baidu OCR API request.
	baiduReq := BaiduOCRRequest{
		File:                      fileData,
		FileType:                  fileType,
		UseDocOrientationClassify: false,
		UseDocUnwarping:           false,
		UseChartRecognition:       false,
	}

	jsonData, err := json.Marshal(baiduReq)
	if err != nil {
		log.Error(ctx, "func", "processBaiduOCRTask", "msg", "JSON编码失败", "error", err)
		return err
	}

	headers := map[string]string{
		"Content-Type":  "application/json",
		"Authorization": fmt.Sprintf("token %s", config.GlobalConfig.BaiduOCR.Token),
	}

	endpoint := "https://j5veh2l2r6ubk6cb.aistudio-app.com/layout-parsing"

	startTime := time.Now()
	log.Info(ctx, "func", "processBaiduOCRTask", "msg", "BaiduOCRApi_Start", "start_time", startTime)

	resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData), headers)
	if err != nil {
		log.Error(ctx, "func", "processBaiduOCRTask", "msg", "百度 OCR API 请求失败", "error", err)
		return err
	}
	defer resp.Body.Close()

	log.Info(ctx, "func", "processBaiduOCRTask", "msg", "BaiduOCRApi_End", "end_time", time.Now(), "duration", time.Since(startTime))

	body := &bytes.Buffer{}
	if _, err = body.ReadFrom(resp.Body); err != nil {
		log.Error(ctx, "func", "processBaiduOCRTask", "msg", "读取响应体失败", "error", err)
		return err
	}

	// Prepare the log record. The base64 payload is left out of the logged
	// request body to save storage; only its size is recorded.
	requestLogData := map[string]interface{}{
		"fileType":                  fileType,
		"useDocOrientationClassify": false,
		"useDocUnwarping":           false,
		"useChartRecognition":       false,
		"fileSize":                  len(fileBytes),
	}
	// The map only holds ints and bools, so Marshal cannot fail here.
	requestLogBytes, _ := json.Marshal(requestLogData)
	recognitionLog := &dao.RecognitionLog{
		TaskID:       taskID,
		Provider:     dao.ProviderBaiduOCR,
		RequestBody:  string(requestLogBytes),
		ResponseBody: body.String(),
	}

	// Parse the response.
	var baiduResp BaiduOCRResponse
	if err := json.Unmarshal(body.Bytes(), &baiduResp); err != nil {
		log.Error(ctx, "func", "processBaiduOCRTask", "msg", "解析响应失败", "error", err)
		return err
	}

	// Check for an API-level error.
	if baiduResp.ErrorCode != 0 {
		errMsg := fmt.Sprintf("errorCode: %d, errorMsg: %s", baiduResp.ErrorCode, baiduResp.ErrorMsg)
		log.Error(ctx, "func", "processBaiduOCRTask", "msg", "百度 OCR API 返回错误", "error", errMsg)
		return fmt.Errorf("baidu ocr error: %s", errMsg)
	}

	// Save the log; a failure here is logged but does not fail the task.
	if err = logDao.Create(dao.DB.WithContext(ctx), recognitionLog); err != nil {
		log.Error(ctx, "func", "processBaiduOCRTask", "msg", "保存日志失败", "error", err)
	}

	// Merge the markdown of all pages, skipping empty ones.
	var markdownTexts []string
	if baiduResp.Result != nil {
		for _, res := range baiduResp.Result.LayoutParsingResults {
			if res.Markdown.Text != "" {
				markdownTexts = append(markdownTexts, res.Markdown.Text)
			}
		}
	}
	markdownResult := strings.Join(markdownTexts, "\n\n---\n\n")

	// Conversion is best effort: on failure the markdown is still stored,
	// with empty LaTeX/MathML.
	latex, mml, convertErr := s.HandleConvert(ctx, markdownResult)
	if convertErr != nil {
		log.Error(ctx, "func", "processBaiduOCRTask", "msg", "转换失败", "error", convertErr)
	}

	// Update or create the recognition result.
	resultDao := dao.NewRecognitionResultDao()
	result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID)
	if err != nil {
		log.Error(ctx, "func", "processBaiduOCRTask", "msg", "获取任务结果失败", "error", err)
		return err
	}

	log.Info(ctx, "func", "processBaiduOCRTask", "msg", "saveLog", "end_time", time.Now(), "duration", time.Since(startTime))

	if result == nil {
		// Create a new result.
		err = resultDao.Create(dao.DB.WithContext(ctx), dao.RecognitionResult{
			TaskID:   taskID,
			TaskType: dao.TaskTypeFormula,
			Markdown: markdownResult,
			Latex:    latex,
			MathML:   mml,
		})
		if err != nil {
			log.Error(ctx, "func", "processBaiduOCRTask", "msg", "创建任务结果失败", "error", err)
			return err
		}
	} else {
		// Update the existing result.
		err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{
			"markdown": markdownResult,
			"latex":    latex,
			"mathml":   mml,
		})
		if err != nil {
			log.Error(ctx, "func", "processBaiduOCRTask", "msg", "更新任务结果失败", "error", err)
			return err
		}
	}

	succeeded = true
	return nil
}

// TestProcessMathpixTask loads the task with the given ID and runs
// processMathpixTask on its file. It returns an error when the task cannot be
// loaded or does not exist.
func (s *RecognitionService) TestProcessMathpixTask(ctx context.Context, taskID int64) error {
	task, err := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
	if err != nil {
		log.Error(ctx, "func", "TestProcessMathpixTask", "msg", "获取任务失败", "error", err)
		return err
	}
	if task == nil {
		log.Error(ctx, "func", "TestProcessMathpixTask", "msg", "任务不存在", "task_id", taskID)
		// err is nil at this point; return an explicit error so callers do
		// not mistake a missing task for success.
		return fmt.Errorf("task %d not found", taskID)
	}
	return s.processMathpixTask(ctx, taskID, task.FileURL)
}

// ConvertResponse is the structure returned by the Python convert endpoint.
type ConvertResponse struct {
	Latex  string `json:"latex"`
	MathML string `json:"mathml"`
	Error  string `json:"error,omitempty"`
}

// HandleConvert posts markdown as a multipart form field ("markdown_input") to
// the doc_converter endpoint and returns the LaTeX and MathML it produces.
// It returns an error for transport failures, non-200 responses, undecodable
// JSON, or when the response carries a non-empty error field.
func (s *RecognitionService) HandleConvert(ctx context.Context, markdown string) (latex string, mml string, err error) {
	endpoint := "https://cloud.texpixel.com:10443/doc_converter/v1/convert"

	// Build the multipart form.
	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)
	if err = writer.WriteField("markdown_input", markdown); err != nil {
		return "", "", fmt.Errorf("write markdown_input field: %w", err)
	}
	// Close writes the trailing boundary; the form is incomplete without it.
	if err = writer.Close(); err != nil {
		return "", "", fmt.Errorf("close multipart writer: %w", err)
	}

	// Use the writer's Content-Type, which includes the boundary.
	headers := map[string]string{
		"Content-Type": writer.FormDataContentType(),
	}

	resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, endpoint, body, headers)
	if err != nil {
		return "", "", err
	}
	defer resp.Body.Close()

	// Read the response body.
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", "", err
	}

	// Check the HTTP status code.
	if resp.StatusCode != http.StatusOK {
		return "", "", fmt.Errorf("convert failed: status %d, body: %s", resp.StatusCode, string(respBody))
	}

	// Parse the JSON response.
	var convertResp ConvertResponse
	if err := json.Unmarshal(respBody, &convertResp); err != nil {
		return "", "", fmt.Errorf("unmarshal response failed: %w, body: %s", err, string(respBody))
	}

	// Check for a business-level error.
	if convertResp.Error != "" {
		return "", "", fmt.Errorf("convert error: %s", convertResp.Error)
	}

	return convertResp.Latex, convertResp.MathML, nil
}