Files
doc_ai_backed/internal/service/recognition_service.go
2025-12-20 22:48:02 +08:00

835 lines
27 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"gitea.com/bitwsd/document_ai/config"
"gitea.com/bitwsd/document_ai/internal/model/formula"
"gitea.com/bitwsd/document_ai/internal/storage/cache"
"gitea.com/bitwsd/document_ai/internal/storage/dao"
"gitea.com/bitwsd/document_ai/pkg/log"
"gitea.com/bitwsd/document_ai/pkg/common"
"gitea.com/bitwsd/document_ai/pkg/constant"
"gitea.com/bitwsd/document_ai/pkg/httpclient"
"gitea.com/bitwsd/document_ai/pkg/oss"
"gitea.com/bitwsd/document_ai/pkg/utils"
"gorm.io/gorm"
)
// RecognitionService bundles the dependencies used by the recognition
// workflows implemented in this file.
type RecognitionService struct {
// db is the gorm handle; NewRecognitionService sets it from dao.DB.
db *gorm.DB
// queueLimit is a buffered channel that processOneTask uses as a slot
// counter: it sends before working and receives when done.
queueLimit chan struct{}
// stopChan is closed by Stop to make processFormulaQueue return.
stopChan chan struct{}
// httpClient performs the outbound requests to the recognition backends.
httpClient *httpclient.Client
}
// NewRecognitionService builds a RecognitionService and immediately starts a
// background goroutine (via utils.SafeGo) that tries to obtain the distributed
// lock and, when it gets it, runs the formula queue loop. If the lock cannot
// be obtained the goroutine logs the failure and exits; the service itself is
// returned either way.
func NewRecognitionService() *RecognitionService {
	svc := new(RecognitionService)
	svc.db = dao.DB
	svc.queueLimit = make(chan struct{}, config.GlobalConfig.Limit.FormulaRecognition)
	svc.stopChan = make(chan struct{})
	svc.httpClient = httpclient.NewClient(nil) // default client configuration

	// Queue processing starts as soon as the service is created.
	runQueue := func() {
		bg := context.Background()
		acquired, lockErr := cache.GetDistributedLock(bg)
		switch {
		case lockErr != nil:
			log.Error(bg, "func", "NewRecognitionService", "msg", "获取分布式锁失败", "error", lockErr)
		case !acquired:
			log.Error(bg, "func", "NewRecognitionService", "msg", "获取分布式锁失败")
		default:
			svc.processFormulaQueue(bg)
		}
	}
	utils.SafeGo(runQueue)

	return svc
}
// AIEnhanceRecognition re-runs recognition of an existing, finished task
// through the vision-language model.
//
// The request is rejected when the caller IP has already reached
// constant.VLMFormulaCount, when the task cannot be found, or when the task is
// still pending or processing. Otherwise the task is switched to the
// processing state and the recognition itself runs in a background goroutine;
// the per-IP counter is incremented once that run has finished.
//
// It returns the task as it was loaded before the status change, or an error
// built with common.NewError.
func (s *RecognitionService) AIEnhanceRecognition(ctx context.Context, req *formula.AIEnhanceRecognitionRequest) (*dao.RecognitionTask, error) {
	// Resolve the caller IP once while the request context is certainly valid;
	// the background goroutine below outlives this call and must not depend on
	// reading request-scoped values later.
	clientIP := common.GetIPFromContext(ctx)

	count, err := cache.GetVLMFormulaCount(ctx, clientIP)
	if err != nil {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "获取VLM公式识别次数失败", "error", err)
		return nil, common.NewError(common.CodeSystemError, "系统错误", err)
	}
	if count >= constant.VLMFormulaCount {
		return nil, common.NewError(common.CodeForbidden, "今日VLM公式识别次数已达上限,请明天再试!", nil)
	}

	taskDao := dao.NewRecognitionTaskDao()
	task, err := taskDao.GetByTaskNo(dao.DB.WithContext(ctx), req.TaskNo)
	if err != nil && err != gorm.ErrRecordNotFound {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "获取任务失败", "error", err)
		return nil, common.NewError(common.CodeDBError, "获取任务失败", err)
	}
	// A missing row is "not found" regardless of whether the DAO signals it
	// with a nil task or with gorm.ErrRecordNotFound (as GetFormualTask expects).
	if err != nil || task == nil {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "任务不存在", "task_no", req.TaskNo)
		return nil, common.NewError(common.CodeNotFound, "任务不存在", err)
	}
	if task.Status == dao.TaskStatusProcessing || task.Status == dao.TaskStatusPending {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "任务未完成", "task_no", req.TaskNo)
		return nil, common.NewError(common.CodeInvalidStatus, "任务未完成", nil)
	}

	err = taskDao.Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": task.ID}, map[string]interface{}{"status": dao.TaskStatusProcessing})
	if err != nil {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "更新任务状态失败", "error", err)
		return nil, common.NewError(common.CodeDBError, "更新任务状态失败", err)
	}

	taskID := task.ID
	taskNo := req.TaskNo
	utils.SafeGo(func() {
		// Detached from the request: only values captured above are used here.
		bgCtx := context.Background()
		s.processVLFormula(bgCtx, taskID)
		if _, incrErr := cache.IncrVLMFormulaCount(bgCtx, clientIP); incrErr != nil {
			log.Error(bgCtx, "func", "AIEnhanceRecognition", "msg", "增加VLM公式识别次数失败", "error", incrErr, "task_no", taskNo)
		}
	})
	return task, nil
}
// CreateRecognitionTask stores a new pending recognition task built from req
// (with a fresh UUID and the caller IP taken from ctx) and then pushes its ID
// onto the formula queue. It returns the stored task, or a common.NewError
// value when either the insert or the enqueue step fails.
func (s *RecognitionService) CreateRecognitionTask(ctx context.Context, req *formula.CreateFormulaRecognitionRequest) (*dao.RecognitionTask, error) {
	sess := dao.DB.WithContext(ctx)
	taskDao := dao.NewRecognitionTaskDao()

	created := &dao.RecognitionTask{}
	created.UserID = req.UserID
	created.TaskUUID = utils.NewUUID()
	created.TaskType = dao.TaskType(req.TaskType)
	created.Status = dao.TaskStatusPending
	created.FileURL = req.FileURL
	created.FileName = req.FileName
	created.FileHash = req.FileHash
	created.IP = common.GetIPFromContext(ctx)

	if createErr := taskDao.Create(sess, created); createErr != nil {
		log.Error(ctx, "func", "CreateRecognitionTask", "msg", "创建任务失败", "error", createErr)
		return nil, common.NewError(common.CodeDBError, "创建任务失败", createErr)
	}

	if queueErr := s.handleFormulaRecognition(ctx, created.ID); queueErr != nil {
		log.Error(ctx, "func", "CreateRecognitionTask", "msg", "处理任务失败", "error", queueErr)
		return nil, common.NewError(common.CodeSystemError, "处理任务失败", queueErr)
	}

	return created, nil
}
// GetFormualTask reports the state of the task identified by taskNo.
//
// When the queued-task count exceeds the configured formula-recognition limit
// the task is reported as pending together with that count, without touching
// the database. Otherwise the task is loaded; unfinished tasks return only
// their status, completed tasks also return LaTeX, Markdown and MathML. If the
// stored Markdown is empty it is derived from the LaTeX wrapped in "$$".
//
// NOTE(review): a failure to read the queue count is only logged; the check
// then runs with whatever count value the cache call returned.
func (s *RecognitionService) GetFormualTask(ctx context.Context, taskNo string) (*formula.GetFormulaTaskResponse, error) {
	taskDao := dao.NewRecognitionTaskDao()
	resultDao := dao.NewRecognitionResultDao()
	sess := dao.DB.WithContext(ctx)

	queued, countErr := cache.GetFormulaTaskCount(ctx)
	if countErr != nil {
		log.Error(ctx, "func", "GetFormualTask", "msg", "获取任务数量失败", "error", countErr)
	}
	if limit := int64(config.GlobalConfig.Limit.FormulaRecognition); queued > limit {
		return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Status: int(dao.TaskStatusPending), Count: int(queued)}, nil
	}

	task, taskErr := taskDao.GetByTaskNo(sess, taskNo)
	switch {
	case taskErr == gorm.ErrRecordNotFound:
		log.Info(ctx, "func", "GetFormualTask", "msg", "任务不存在", "task_no", taskNo)
		return nil, common.NewError(common.CodeNotFound, "任务不存在", taskErr)
	case taskErr != nil:
		log.Error(ctx, "func", "GetFormualTask", "msg", "查询任务失败", "error", taskErr, "task_no", taskNo)
		return nil, common.NewError(common.CodeDBError, "查询任务失败", taskErr)
	}

	resp := &formula.GetFormulaTaskResponse{TaskNo: taskNo, Status: int(task.Status)}
	if task.Status != dao.TaskStatusCompleted {
		return resp, nil
	}

	stored, resultErr := resultDao.GetByTaskID(sess, task.ID)
	switch {
	case resultErr == gorm.ErrRecordNotFound:
		log.Info(ctx, "func", "GetFormualTask", "msg", "任务结果不存在", "task_no", taskNo)
		return nil, common.NewError(common.CodeNotFound, "任务结果不存在", resultErr)
	case resultErr != nil:
		log.Error(ctx, "func", "GetFormualTask", "msg", "查询任务结果失败", "error", resultErr, "task_no", taskNo)
		return nil, common.NewError(common.CodeDBError, "查询任务结果失败", resultErr)
	}

	// Build the Markdown form, falling back to the LaTeX in display-math delimiters.
	resp.Latex = stored.Latex
	resp.Markdown = stored.Markdown
	if resp.Markdown == "" {
		resp.Markdown = fmt.Sprintf("$$%s$$", stored.Latex)
	}
	resp.MathML = stored.MathML
	return resp, nil
}
// handleFormulaRecognition only enqueues the task: it pushes taskID onto the
// formula task queue in the cache and returns the push error, if any.
func (s *RecognitionService) handleFormulaRecognition(ctx context.Context, taskID int64) error {
	if _, pushErr := cache.PushFormulaTask(ctx, taskID); pushErr != nil {
		log.Error(ctx, "func", "handleFormulaRecognition", "msg", "增加任务计数失败", "error", pushErr)
		return pushErr
	}

	log.Info(ctx, "func", "handleFormulaRecognition", "msg", "增加任务计数成功", "task_id", taskID)
	return nil
}
// Stop shuts the service down gracefully by closing stopChan, which makes
// processFormulaQueue return the next time it checks the channel.
// NOTE(review): closing an already closed channel panics, so Stop must be
// called at most once — confirm callers guarantee this.
func (s *RecognitionService) Stop() {
close(s.stopChan)
}
// processVLFormula loads the task with the given ID, tags ctx with the task
// UUID as request ID, and runs the vision-language recognition for the task's
// file using utils.ModelVLQwen3VL32BInstruct. Failures are logged; nothing is
// returned to the caller.
func (s *RecognitionService) processVLFormula(ctx context.Context, taskID int64) {
	const fn = "processVLFormulaQueue"

	task, lookupErr := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
	switch {
	case lookupErr != nil:
		log.Error(ctx, "func", fn, "msg", "获取任务失败", "error", lookupErr)
		return
	case task == nil:
		log.Error(ctx, "func", fn, "msg", "任务不存在", "task_id", taskID)
		return
	}

	ctx = context.WithValue(ctx, utils.RequestIDKey, task.TaskUUID)
	log.Info(ctx, "func", fn, "msg", "获取任务成功", "task_id", taskID)

	// Run the actual recognition for this task.
	runErr := s.processVLFormulaTask(ctx, taskID, task.FileURL, utils.ModelVLQwen3VL32BInstruct)
	if runErr != nil {
		log.Error(ctx, "func", fn, "msg", "处理任务失败", "error", runErr)
		return
	}
	log.Info(ctx, "func", fn, "msg", "处理任务成功", "task_id", taskID)
}
// MathpixRequest is the full request payload for the Mathpix API /v3/text
// endpoint.
// NOTE(review): fields whose tag lacks omitempty are always serialized by
// encoding/json, zero values included (for example "confidence_threshold": 0
// or "callback": null) — confirm the API treats those like an absent field,
// especially for the thresholds whose documented defaults are non-zero.
type MathpixRequest struct {
// Image source: a URL or base64-encoded data.
Src string `json:"src"`
// Metadata key/value pairs.
Metadata map[string]interface{} `json:"metadata"`
// List of tags used to identify results.
Tags []string `json:"tags"`
// Asynchronous request flag.
Async bool `json:"async"`
// Callback configuration.
Callback *MathpixCallback `json:"callback"`
// List of output formats: text, data, html, latex_styled.
Formats []string `json:"formats"`
// Data options.
DataOptions *MathpixDataOptions `json:"data_options,omitempty"`
// Return the detected alphabets.
IncludeDetectedAlphabets *bool `json:"include_detected_alphabets,omitempty"`
// Allowed alphabets.
AlphabetsAllowed *MathpixAlphabetsAllowed `json:"alphabets_allowed,omitempty"`
// Restrict processing to a region of the image.
Region *MathpixRegion `json:"region,omitempty"`
// Blue HSV filter mode.
EnableBlueHsvFilter bool `json:"enable_blue_hsv_filter"`
// Confidence threshold.
ConfidenceThreshold float64 `json:"confidence_threshold"`
// Symbol-level confidence threshold; default 0.75.
ConfidenceRateThreshold float64 `json:"confidence_rate_threshold"`
// Include equation tags.
IncludeEquationTags bool `json:"include_equation_tags"`
// Return line-by-line information.
IncludeLineData bool `json:"include_line_data"`
// Return word-by-word information.
IncludeWordData bool `json:"include_word_data"`
// Chemical structure OCR.
IncludeSmiles bool `json:"include_smiles"`
// InChI data.
IncludeInchi bool `json:"include_inchi"`
// Geometry data.
IncludeGeometryData bool `json:"include_geometry_data"`
// Diagram text extraction.
IncludeDiagramText bool `json:"include_diagram_text"`
// Page information; default true.
IncludePageInfo *bool `json:"include_page_info,omitempty"`
// Auto-rotation confidence threshold; default 0.99.
AutoRotateConfidenceThreshold float64 `json:"auto_rotate_confidence_threshold"`
// Remove extra spaces; default true.
RmSpaces *bool `json:"rm_spaces,omitempty"`
// Remove font commands; default false.
RmFonts bool `json:"rm_fonts"`
// Use aligned/gathered/cases instead of array; default false.
IdiomaticEqnArrays bool `json:"idiomatic_eqn_arrays"`
// Remove unnecessary braces; default false.
IdiomaticBraces bool `json:"idiomatic_braces"`
// Numbers are always in math mode; default false.
NumbersDefaultToMath bool `json:"numbers_default_to_math"`
// Math fonts are always in math mode; default false.
MathFontsDefaultToMath bool `json:"math_fonts_default_to_math"`
// Inline math delimiters; default ["\\(", "\\)"].
MathInlineDelimiters []string `json:"math_inline_delimiters"`
// Display math delimiters; default ["\\[", "\\]"].
MathDisplayDelimiters []string `json:"math_display_delimiters"`
// Advanced table processing; default false.
EnableTablesFallback bool `json:"enable_tables_fallback"`
// Full-width punctuation; null means decide automatically.
FullwidthPunctuation *bool `json:"fullwidth_punctuation,omitempty"`
}
// MathpixCallback is the callback configuration of a MathpixRequest.
type MathpixCallback struct {
// URL is serialized as "url".
URL string `json:"url"`
// Headers is serialized as "headers".
Headers map[string]string `json:"headers"`
}
// MathpixDataOptions holds the data options of a MathpixRequest; each flag is
// serialized under the JSON key given in its tag.
type MathpixDataOptions struct {
IncludeAsciimath bool `json:"include_asciimath"`
IncludeMathml bool `json:"include_mathml"`
IncludeLatex bool `json:"include_latex"`
IncludeTsv bool `json:"include_tsv"`
}
// MathpixAlphabetsAllowed lists the allowed alphabets of a MathpixRequest,
// one boolean per alphabet code used as the JSON key.
type MathpixAlphabetsAllowed struct {
En bool `json:"en"`
Hi bool `json:"hi"`
Zh bool `json:"zh"`
Ja bool `json:"ja"`
Ko bool `json:"ko"`
Ru bool `json:"ru"`
Th bool `json:"th"`
Vi bool `json:"vi"`
}
// MathpixRegion describes an image region by its top-left corner and size.
// NOTE(review): units are presumably pixels — confirm against the API docs.
type MathpixRegion struct {
TopLeftX int `json:"top_left_x"`
TopLeftY int `json:"top_left_y"`
Width int `json:"width"`
Height int `json:"height"`
}
// MathpixResponse is the full response payload of the Mathpix API /v3/text
// endpoint.
type MathpixResponse struct {
// Request ID, useful for debugging.
RequestID string `json:"request_id"`
// Text in Mathpix Markdown format.
Text string `json:"text"`
// Styled LaTeX; only returned for images containing a single formula.
LatexStyled string `json:"latex_styled"`
// Confidence in [0,1].
Confidence float64 `json:"confidence"`
// Confidence rate in [0,1].
ConfidenceRate float64 `json:"confidence_rate"`
// Line data.
LineData []map[string]interface{} `json:"line_data"`
// Word data.
WordData []map[string]interface{} `json:"word_data"`
// List of data objects (searched by GetMathML and GetAsciiMath).
Data []MathpixDataItem `json:"data"`
// HTML output.
HTML string `json:"html"`
// Detected alphabets.
DetectedAlphabets []map[string]interface{} `json:"detected_alphabets"`
// Whether the content is printed.
IsPrinted bool `json:"is_printed"`
// Whether the content is handwritten.
IsHandwritten bool `json:"is_handwritten"`
// Auto-rotation confidence.
AutoRotateConfidence float64 `json:"auto_rotate_confidence"`
// Geometry data.
GeometryData []map[string]interface{} `json:"geometry_data"`
// Auto-rotation angle, one of {0, 90, -90, 180}.
AutoRotateDegrees int `json:"auto_rotate_degrees"`
// Image width.
ImageWidth int `json:"image_width"`
// Image height.
ImageHeight int `json:"image_height"`
// Error message; non-empty when the API reports a failure.
Error string `json:"error"`
// Error details.
ErrorInfo *MathpixErrorInfo `json:"error_info"`
// API version.
Version string `json:"version"`
}
// MathpixDataItem is one entry of MathpixResponse.Data: Type names the format
// (the accessors in this file look for "mathml" and "asciimath") and Value
// carries the content in that format.
type MathpixDataItem struct {
Type string `json:"type"`
Value string `json:"value"`
}
// MathpixErrorInfo carries the error details of a MathpixResponse: an error
// identifier and a message.
type MathpixErrorInfo struct {
ID string `json:"id"`
Message string `json:"message"`
}
// GetMathML returns the value of the first data item whose type is "mathml",
// or the empty string when the response contains no such item.
func (r *MathpixResponse) GetMathML() string {
	const wanted = "mathml"
	for i := range r.Data {
		if entry := r.Data[i]; entry.Type == wanted {
			return entry.Value
		}
	}
	return ""
}
// GetAsciiMath returns the value of the first data item whose type is
// "asciimath", or the empty string when the response contains no such item.
func (r *MathpixResponse) GetAsciiMath() string {
	const wanted = "asciimath"
	for i := range r.Data {
		if entry := r.Data[i]; entry.Type == wanted {
			return entry.Value
		}
	}
	return ""
}
// processFormulaTask runs one formula recognition through the texpixel
// prediction endpoint: it marks the task as processing, downloads the image,
// posts it base64-encoded, and stores the returned LaTeX as the task result.
//
// The whole run is bounded by a 45 second timeout. The result row and the
// final "completed" status are written in one transaction; if anything fails
// (including that final update or the commit) the transaction is rolled back,
// the task is marked failed with a remark, and the error is returned.
func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int64, fileURL string) (err error) {
	// Bound the entire task with a timeout.
	ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
	defer cancel()

	tx := dao.DB.Begin()
	var (
		taskDao   = dao.NewRecognitionTaskDao()
		resultDao = dao.NewRecognitionResultDao()
	)

	// isSuccess (rather than err == nil) drives the outcome so that a panic,
	// which leaves err nil, is still treated as a failure.
	isSuccess := false
	defer func() {
		if isSuccess {
			finishErr := taskDao.Update(tx, map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusCompleted, "completed_at": time.Now()})
			if finishErr == nil {
				finishErr = tx.Commit().Error
			}
			if finishErr == nil {
				return
			}
			// The result could not be persisted: report it through the named
			// return value and fall through to the failure handling.
			log.Error(ctx, "func", "processFormulaTask", "msg", "提交任务结果失败", "error", finishErr)
			err = finishErr
		}

		tx.Rollback()
		remark := "任务处理失败"
		if err != nil {
			remark = err.Error()
		}
		// A timeout is the more useful explanation, so it takes precedence
		// over whatever error the timed-out step produced.
		if ctx.Err() == context.DeadlineExceeded {
			remark = "任务处理超时"
		}
		// Use a fresh context so an expired or cancelled ctx cannot block the
		// status update.
		updErr := taskDao.Update(dao.DB.WithContext(context.Background()),
			map[string]interface{}{"id": taskID},
			map[string]interface{}{
				"status":       dao.TaskStatusFailed,
				"completed_at": time.Now(),
				"remark":       remark,
			})
		if updErr != nil {
			log.Error(ctx, "func", "processFormulaTask", "msg", "更新任务状态失败", "error", updErr)
		}
	}()

	if tx.Error != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "开启事务失败", "error", tx.Error)
		return tx.Error
	}

	// Failing to flag the task as processing is logged but does not abort the run.
	err = taskDao.Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusProcessing})
	if err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "更新任务状态失败", "error", err)
	}

	// Download the image file.
	reader, err := oss.DownloadFile(ctx, fileURL)
	if err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "下载图片文件失败", "error", err)
		return err
	}
	defer reader.Close()

	// Read the image bytes.
	imageData, err := io.ReadAll(reader)
	if err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "读取图片数据失败", "error", err)
		return err
	}

	// Build the JSON request carrying the base64-encoded image.
	requestData := map[string]string{
		"image_base64": base64.StdEncoding.EncodeToString(imageData),
	}
	jsonData, err := json.Marshal(requestData)
	if err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "JSON编码失败", "error", err)
		return err
	}

	headers := map[string]string{"Content-Type": "application/json", utils.RequestIDHeaderKey: utils.GetRequestIDFromContext(ctx)}

	// The request is sent with the timeout-bound context.
	resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "https://cloud.texpixel.com:10443/vlm/formula/predict", bytes.NewReader(jsonData), headers)
	if err != nil {
		if ctx.Err() == context.DeadlineExceeded {
			log.Error(ctx, "func", "processFormulaTask", "msg", "请求超时")
			return fmt.Errorf("request timeout")
		}
		log.Error(ctx, "func", "processFormulaTask", "msg", "请求失败", "error", err)
		return err
	}
	defer resp.Body.Close()
	log.Info(ctx, "func", "processFormulaTask", "msg", "请求成功")

	body := &bytes.Buffer{}
	if _, err = body.ReadFrom(resp.Body); err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "读取响应体失败", "error", err)
		return err
	}
	log.Info(ctx, "func", "processFormulaTask", "msg", "响应内容", "body", body.String())

	// Parse the JSON response.
	var formulaResp formula.FormulaRecognitionResponse
	if err := json.Unmarshal(body.Bytes(), &formulaResp); err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "解析响应JSON失败", "error", err)
		return err
	}

	err = resultDao.Create(tx, dao.RecognitionResult{TaskID: taskID, TaskType: dao.TaskTypeFormula, Latex: formulaResp.Result})
	if err != nil {
		log.Error(ctx, "func", "processFormulaTask", "msg", "保存任务结果失败", "error", err)
		return err
	}

	isSuccess = true
	return nil
}
// processVLFormulaTask recognizes the formula in the task's image with the
// given vision-language model (SiliconFlow chat-completions endpoint) and
// stores the cleaned LaTeX as the task result, creating the result row when
// none exists and updating it otherwise.
//
// On return the task status is set to completed when everything succeeded and
// to failed otherwise. An empty model reply is treated as a failure so that an
// existing result is never overwritten with an empty string.
func (s *RecognitionService) processVLFormulaTask(ctx context.Context, taskID int64, fileURL string, model string) error {
	isSuccess := false
	defer func() {
		status := dao.TaskStatusFailed
		if isSuccess {
			status = dao.TaskStatusCompleted
		}
		err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": status})
		if err != nil {
			log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务状态失败", "error", err)
		}
	}()

	reader, err := oss.DownloadFile(ctx, fileURL)
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "下载图片文件失败", "error", err)
		return err
	}
	defer reader.Close()

	imageData, err := io.ReadAll(reader)
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "读取图片数据失败", "error", err)
		return err
	}

	prompt := `Please perform OCR on the image and output only LaTeX code.`
	base64Image := base64.StdEncoding.EncodeToString(imageData)

	requestBody := formula.VLFormulaRequest{
		Model:            model,
		Stream:           false,
		MaxTokens:        512,
		Temperature:      0.1,
		TopP:             0.1,
		TopK:             50,
		FrequencyPenalty: 0.2,
		N:                1,
		Messages: []formula.Message{
			{
				Role: "user",
				Content: []formula.Content{
					{
						Type: "text",
						Text: prompt,
					},
					{
						Type: "image_url",
						ImageURL: formula.Image{
							Detail: "auto",
							URL:    "data:image/jpeg;base64," + base64Image,
						},
					},
				},
			},
		},
	}

	// Encode the request body as JSON.
	jsonData, err := json.Marshal(requestBody)
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "JSON编码失败", "error", err)
		return err
	}

	headers := map[string]string{
		"Content-Type":  "application/json",
		"Authorization": utils.SiliconFlowToken,
	}

	resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "https://api.siliconflow.cn/v1/chat/completions", bytes.NewReader(jsonData), headers)
	if err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "请求VL服务失败", "error", err)
		return err
	}
	defer resp.Body.Close()

	// Decode the response.
	var response formula.VLFormulaResponse
	if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "解析响应失败", "error", err)
		return err
	}

	// A reply without content (for example an error payload that still decodes)
	// must not be stored as a successful, empty recognition.
	if len(response.Choices) == 0 || response.Choices[0].Message.Content == "" {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "VL服务返回内容为空")
		return fmt.Errorf("vl service returned empty content")
	}

	// Extract the LaTeX: drop newlines and markdown code fences, normalize a
	// few spacing patterns, then strip surrounding display-math delimiters.
	latex := strings.ReplaceAll(response.Choices[0].Message.Content, "\n", "")
	latex = strings.ReplaceAll(latex, "```latex", "")
	latex = strings.ReplaceAll(latex, "```", "")
	latex = strings.ReplaceAll(latex, " = ", "=")
	latex = strings.ReplaceAll(latex, "\\left[ ", "\\left[")
	latex = strings.TrimPrefix(latex, "\\[")
	latex = strings.TrimSuffix(latex, "\\]")

	resultDao := dao.NewRecognitionResultDao()
	result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID)
	if err != nil && err != gorm.ErrRecordNotFound {
		log.Error(ctx, "func", "processVLFormulaTask", "msg", "获取任务结果失败", "error", err)
		return err
	}

	if err != nil || result == nil {
		// No stored result yet (nil row or gorm.ErrRecordNotFound): create one.
		err = resultDao.Create(dao.DB.WithContext(ctx), dao.RecognitionResult{TaskID: taskID, TaskType: dao.TaskTypeFormula, Latex: latex})
		if err != nil {
			log.Error(ctx, "func", "processVLFormulaTask", "msg", "创建任务结果失败", "error", err)
			return err
		}
	} else {
		// Update the existing row in place.
		err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{"latex": latex})
		if err != nil {
			log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务结果失败", "error", err)
			return err
		}
	}

	isSuccess = true
	return nil
}
// processFormulaQueue keeps calling processOneTask until stopChan has been
// closed. The stop signal is only checked between tasks, never while one is
// being processed.
func (s *RecognitionService) processFormulaQueue(ctx context.Context) {
	stopRequested := func() bool {
		select {
		case <-s.stopChan:
			return true
		default:
			return false
		}
	}

	for !stopRequested() {
		s.processOneTask(ctx)
	}
}
// processOneTask takes one slot from queueLimit, pops a task ID from the
// formula queue, loads the task, tags ctx with the task UUID as request ID and
// runs the Mathpix recognition for it. Every failure is logged and ends the
// call; the slot is released on return.
//
// NOTE(review): if cache.PopFormulaTask can fail without blocking, the caller's
// loop will retry immediately — confirm it blocks or add a back-off.
func (s *RecognitionService) processOneTask(ctx context.Context) {
	const fn = "processFormulaQueue"

	// Limit the number of tasks in flight.
	s.queueLimit <- struct{}{}
	release := func() { <-s.queueLimit }
	defer release()

	taskID, popErr := cache.PopFormulaTask(ctx)
	if popErr != nil {
		log.Error(ctx, "func", fn, "msg", "获取任务失败", "error", popErr)
		return
	}

	task, lookupErr := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
	switch {
	case lookupErr != nil:
		log.Error(ctx, "func", fn, "msg", "获取任务失败", "error", lookupErr)
		return
	case task == nil:
		log.Error(ctx, "func", fn, "msg", "任务不存在", "task_id", taskID)
		return
	}

	ctx = context.WithValue(ctx, utils.RequestIDKey, task.TaskUUID)
	log.Info(ctx, "func", fn, "msg", "获取任务成功", "task_id", taskID)

	if runErr := s.processMathpixTask(ctx, taskID, task.FileURL); runErr != nil {
		log.Error(ctx, "func", fn, "msg", "处理任务失败", "error", runErr)
		return
	}
	log.Info(ctx, "func", fn, "msg", "处理任务成功", "task_id", taskID)
}
// processMathpixTask recognizes the task's image with the Mathpix /v3/text API
// and stores LaTeX, Markdown and MathML as the task result, creating the
// result row when none exists and updating it otherwise.
//
// The raw request and response are persisted as a recognition log as soon as
// the response body has been read, so that failed calls (unparsable body or an
// API-level error) leave a trace as well; a failure to save the log is only
// logged. On return the task status is set to completed on success and to
// failed otherwise.
func (s *RecognitionService) processMathpixTask(ctx context.Context, taskID int64, fileURL string) error {
	isSuccess := false
	logDao := dao.NewRecognitionLogDao()

	defer func() {
		status := dao.TaskStatusFailed
		if isSuccess {
			status = dao.TaskStatusCompleted
		}
		err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": status})
		if err != nil {
			log.Error(ctx, "func", "processMathpixTask", "msg", "更新任务状态失败", "error", err)
		}
	}()

	// Resolve a download URL for the image; Mathpix fetches the image itself.
	imageUrl, err := oss.GetDownloadURL(ctx, fileURL)
	if err != nil {
		log.Error(ctx, "func", "processMathpixTask", "msg", "获取图片URL失败", "error", err)
		return err
	}

	// Build the Mathpix API request.
	rmSpaces := true
	mathpixReq := MathpixRequest{
		Src: imageUrl,
		Formats: []string{
			"text",
			"latex_styled",
			"data",
			"html",
		},
		DataOptions: &MathpixDataOptions{
			IncludeMathml:    true,
			IncludeAsciimath: true,
			IncludeLatex:     true,
			IncludeTsv:       true,
		},
		MathInlineDelimiters:  []string{"$", "$"},
		MathDisplayDelimiters: []string{"$$", "$$"},
		RmSpaces:              &rmSpaces,
	}

	jsonData, err := json.Marshal(mathpixReq)
	if err != nil {
		log.Error(ctx, "func", "processMathpixTask", "msg", "JSON编码失败", "error", err)
		return err
	}

	headers := map[string]string{
		"Content-Type": "application/json",
		"app_id":       config.GlobalConfig.Mathpix.AppID,
		"app_key":      config.GlobalConfig.Mathpix.AppKey,
	}

	endpoint := "https://api.mathpix.com/v3/text"
	resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData), headers)
	if err != nil {
		log.Error(ctx, "func", "processMathpixTask", "msg", "Mathpix API 请求失败", "error", err)
		return err
	}
	defer resp.Body.Close()

	body := &bytes.Buffer{}
	if _, err = body.ReadFrom(resp.Body); err != nil {
		log.Error(ctx, "func", "processMathpixTask", "msg", "读取响应体失败", "error", err)
		return err
	}

	// Persist the exchange before interpreting it: the record matters most
	// when parsing fails or the API reports an error.
	recognitionLog := &dao.RecognitionLog{
		TaskID:       taskID,
		Provider:     dao.ProviderMathpix,
		RequestBody:  string(jsonData),
		ResponseBody: body.String(),
	}
	if logErr := logDao.Create(dao.DB.WithContext(ctx), recognitionLog); logErr != nil {
		log.Error(ctx, "func", "processMathpixTask", "msg", "保存日志失败", "error", logErr)
	}

	// Parse the response.
	var mathpixResp MathpixResponse
	if err := json.Unmarshal(body.Bytes(), &mathpixResp); err != nil {
		log.Error(ctx, "func", "processMathpixTask", "msg", "解析响应失败", "error", err)
		return err
	}

	// Surface API-level errors, preferring the structured error info.
	if mathpixResp.Error != "" {
		errMsg := mathpixResp.Error
		if mathpixResp.ErrorInfo != nil {
			errMsg = fmt.Sprintf("%s: %s", mathpixResp.ErrorInfo.ID, mathpixResp.ErrorInfo.Message)
		}
		log.Error(ctx, "func", "processMathpixTask", "msg", "Mathpix API 返回错误", "error", errMsg)
		return fmt.Errorf("mathpix error: %s", errMsg)
	}

	// Create or update the recognition result.
	resultDao := dao.NewRecognitionResultDao()
	result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID)
	if err != nil && err != gorm.ErrRecordNotFound {
		log.Error(ctx, "func", "processMathpixTask", "msg", "获取任务结果失败", "error", err)
		return err
	}

	mathML := mathpixResp.GetMathML()
	if err != nil || result == nil {
		// No stored result yet (nil row or gorm.ErrRecordNotFound): create one.
		err = resultDao.Create(dao.DB.WithContext(ctx), dao.RecognitionResult{
			TaskID:   taskID,
			TaskType: dao.TaskTypeFormula,
			Latex:    mathpixResp.LatexStyled,
			Markdown: mathpixResp.Text,
			MathML:   mathML,
		})
		if err != nil {
			log.Error(ctx, "func", "processMathpixTask", "msg", "创建任务结果失败", "error", err)
			return err
		}
	} else {
		// Update the existing result.
		err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{
			"latex":    mathpixResp.LatexStyled,
			"markdown": mathpixResp.Text,
			"mathml":   mathML,
		})
		if err != nil {
			log.Error(ctx, "func", "processMathpixTask", "msg", "更新任务结果失败", "error", err)
			return err
		}
	}

	isSuccess = true
	return nil
}
// TestProcessMathpixTask loads the task with the given ID and runs the Mathpix
// recognition for it directly, bypassing the queue. It returns the lookup
// error, an error when the task does not exist, or whatever
// processMathpixTask returns.
func (s *RecognitionService) TestProcessMathpixTask(ctx context.Context, taskID int64) error {
	task, err := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
	if err != nil {
		log.Error(ctx, "func", "TestProcessMathpixTask", "msg", "获取任务失败", "error", err)
		return err
	}
	if task == nil {
		log.Error(ctx, "func", "TestProcessMathpixTask", "msg", "任务不存在", "task_id", taskID)
		// err is nil on this path; returning it would report a missing task
		// as success.
		return fmt.Errorf("task not found: task_id=%d", taskID)
	}
	return s.processMathpixTask(ctx, taskID, task.FileURL)
}