Files
doc_ai_backed/internal/service/recognition_service.go

1120 lines
36 KiB
Go
Raw Normal View History

2025-12-10 18:33:37 +08:00
package service
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
2025-12-27 22:06:48 +08:00
"mime/multipart"
2025-12-10 18:33:37 +08:00
"net/http"
"strings"
"time"
"gitea.com/bitwsd/document_ai/config"
"gitea.com/bitwsd/document_ai/internal/model/formula"
"gitea.com/bitwsd/document_ai/internal/storage/cache"
"gitea.com/bitwsd/document_ai/internal/storage/dao"
2025-12-10 23:17:24 +08:00
"gitea.com/bitwsd/document_ai/pkg/log"
2025-12-10 18:33:37 +08:00
"gitea.com/bitwsd/document_ai/pkg/common"
"gitea.com/bitwsd/document_ai/pkg/constant"
"gitea.com/bitwsd/document_ai/pkg/httpclient"
"gitea.com/bitwsd/document_ai/pkg/oss"
2025-12-31 17:53:12 +08:00
"gitea.com/bitwsd/document_ai/pkg/requestid"
2025-12-10 18:33:37 +08:00
"gitea.com/bitwsd/document_ai/pkg/utils"
"gorm.io/gorm"
)
// RecognitionService bundles the dependencies used by the recognition
// workflows in this file.
type RecognitionService struct {
	// db is initialised from dao.DB by NewRecognitionService.
	db *gorm.DB
	// queueLimit is a buffered channel sized from
	// config.GlobalConfig.Limit.FormulaRecognition; processOneTask takes a slot
	// before popping a task and frees it when it returns.
	queueLimit chan struct{}
	// stopChan is closed by Stop; processFormulaQueue returns once it observes
	// the close.
	stopChan chan struct{}
	// httpClient performs the outbound HTTP calls made by this service.
	httpClient *httpclient.Client
}
// NewRecognitionService builds a RecognitionService and immediately launches
// the background formula-queue consumer.
//
// The consumer only starts when cache.GetDistributedLock reports that the
// lock was obtained; otherwise the failure is logged and the goroutine ends.
func NewRecognitionService() *RecognitionService {
	svc := new(RecognitionService)
	svc.db = dao.DB
	svc.queueLimit = make(chan struct{}, config.GlobalConfig.Limit.FormulaRecognition)
	svc.stopChan = make(chan struct{})
	svc.httpClient = httpclient.NewClient(nil) // nil selects the client's default configuration

	// Start draining the queue as soon as the service exists.
	utils.SafeGo(func() {
		acquired, lockErr := cache.GetDistributedLock(context.Background())
		switch {
		case lockErr != nil:
			log.Error(context.Background(), "func", "NewRecognitionService", "msg", "获取分布式锁失败", "error", lockErr)
		case !acquired:
			log.Error(context.Background(), "func", "NewRecognitionService", "msg", "获取分布式锁失败")
		default:
			svc.processFormulaQueue(context.Background())
		}
	})

	return svc
}
// AIEnhanceRecognition re-runs recognition for an already finished task using
// the VLM backend.
//
// The call is rejected when the per-IP daily VLM counter has reached
// constant.VLMFormulaCount, when the task cannot be found, or when the task is
// still pending/processing. On acceptance the task is switched to the
// processing status and the actual recognition runs in a background
// goroutine; the returned task is the record as loaded before that update.
func (s *RecognitionService) AIEnhanceRecognition(ctx context.Context, req *formula.AIEnhanceRecognitionRequest) (*dao.RecognitionTask, error) {
	// Resolve the caller IP once, up front: the same value keys both the quota
	// check and the later increment, and the background goroutine must not
	// depend on ctx still being readable after the caller has returned.
	ip := common.GetIPFromContext(ctx)

	count, err := cache.GetVLMFormulaCount(ctx, ip)
	if err != nil {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "获取VLM公式识别次数失败", "error", err)
		return nil, common.NewError(common.CodeSystemError, "系统错误", err)
	}
	if count >= constant.VLMFormulaCount {
		return nil, common.NewError(common.CodeForbidden, "今日VLM公式识别次数已达上限,请明天再试!", nil)
	}

	taskDao := dao.NewRecognitionTaskDao()
	task, err := taskDao.GetByTaskNo(dao.DB.WithContext(ctx), req.TaskNo)
	if err != nil {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "获取任务失败", "error", err)
		return nil, common.NewError(common.CodeDBError, "获取任务失败", err)
	}
	if task == nil {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "任务不存在", "task_no", req.TaskNo)
		// err is nil on this path, so no underlying cause is attached.
		return nil, common.NewError(common.CodeNotFound, "任务不存在", nil)
	}
	if task.Status == dao.TaskStatusProcessing || task.Status == dao.TaskStatusPending {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "任务未完成", "task_no", req.TaskNo)
		return nil, common.NewError(common.CodeInvalidStatus, "任务未完成", nil)
	}

	err = taskDao.Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": task.ID}, map[string]interface{}{"status": dao.TaskStatusProcessing})
	if err != nil {
		log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "更新任务状态失败", "error", err)
		return nil, common.NewError(common.CodeDBError, "更新任务状态失败", err)
	}

	taskID := task.ID
	utils.SafeGo(func() {
		// The recognition outlives the request, hence the background context.
		s.processVLFormula(context.Background(), taskID)
		if _, incrErr := cache.IncrVLMFormulaCount(context.Background(), ip); incrErr != nil {
			log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "增加VLM公式识别次数失败", "error", incrErr)
		}
	})

	return task, nil
}
func (s *RecognitionService) CreateRecognitionTask(ctx context.Context, req *formula.CreateFormulaRecognitionRequest) (*dao.RecognitionTask, error) {
sess := dao.DB.WithContext(ctx)
taskDao := dao.NewRecognitionTaskDao()
task := &dao.RecognitionTask{
2025-12-18 12:39:50 +08:00
UserID: req.UserID,
2025-12-10 18:33:37 +08:00
TaskUUID: utils.NewUUID(),
TaskType: dao.TaskType(req.TaskType),
Status: dao.TaskStatusPending,
FileURL: req.FileURL,
FileName: req.FileName,
FileHash: req.FileHash,
IP: common.GetIPFromContext(ctx),
}
if err := taskDao.Create(sess, task); err != nil {
log.Error(ctx, "func", "CreateRecognitionTask", "msg", "创建任务失败", "error", err)
return nil, common.NewError(common.CodeDBError, "创建任务失败", err)
}
err := s.handleFormulaRecognition(ctx, task.ID)
if err != nil {
log.Error(ctx, "func", "CreateRecognitionTask", "msg", "处理任务失败", "error", err)
return nil, common.NewError(common.CodeSystemError, "处理任务失败", err)
}
return task, nil
}
func (s *RecognitionService) GetFormualTask(ctx context.Context, taskNo string) (*formula.GetFormulaTaskResponse, error) {
taskDao := dao.NewRecognitionTaskDao()
resultDao := dao.NewRecognitionResultDao()
sess := dao.DB.WithContext(ctx)
count, err := cache.GetFormulaTaskCount(ctx)
if err != nil {
log.Error(ctx, "func", "GetFormualTask", "msg", "获取任务数量失败", "error", err)
}
if count > int64(config.GlobalConfig.Limit.FormulaRecognition) {
return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Status: int(dao.TaskStatusPending), Count: int(count)}, nil
}
task, err := taskDao.GetByTaskNo(sess, taskNo)
if err != nil {
if err == gorm.ErrRecordNotFound {
log.Info(ctx, "func", "GetFormualTask", "msg", "任务不存在", "task_no", taskNo)
return nil, common.NewError(common.CodeNotFound, "任务不存在", err)
}
log.Error(ctx, "func", "GetFormualTask", "msg", "查询任务失败", "error", err, "task_no", taskNo)
return nil, common.NewError(common.CodeDBError, "查询任务失败", err)
}
if task.Status != dao.TaskStatusCompleted {
return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Status: int(task.Status)}, nil
}
taskRet, err := resultDao.GetByTaskID(sess, task.ID)
if err != nil {
if err == gorm.ErrRecordNotFound {
log.Info(ctx, "func", "GetFormualTask", "msg", "任务结果不存在", "task_no", taskNo)
return nil, common.NewError(common.CodeNotFound, "任务结果不存在", err)
}
log.Error(ctx, "func", "GetFormualTask", "msg", "查询任务结果失败", "error", err, "task_no", taskNo)
return nil, common.NewError(common.CodeDBError, "查询任务结果失败", err)
}
2025-12-20 21:42:58 +08:00
// 构建 Markdown 格式
markdown := taskRet.Markdown
if markdown == "" {
markdown = fmt.Sprintf("$$%s$$", taskRet.Latex)
}
return &formula.GetFormulaTaskResponse{
TaskNo: taskNo,
Latex: taskRet.Latex,
Markdown: markdown,
MathML: taskRet.MathML,
Status: int(task.Status),
}, nil
2025-12-10 18:33:37 +08:00
}
// handleFormulaRecognition only enqueues the task: it pushes taskID via
// cache.PushFormulaTask and returns the push error, if any.
func (s *RecognitionService) handleFormulaRecognition(ctx context.Context, taskID int64) error {
	if _, pushErr := cache.PushFormulaTask(ctx, taskID); pushErr != nil {
		log.Error(ctx, "func", "handleFormulaRecognition", "msg", "增加任务计数失败", "error", pushErr)
		return pushErr
	}

	log.Info(ctx, "func", "handleFormulaRecognition", "msg", "增加任务计数成功", "task_id", taskID)
	return nil
}
// Stop shuts the service down gracefully by closing stopChan, which makes
// processFormulaQueue return.
//
// Closing an already closed channel panics, so a repeated call is turned into
// a no-op. The check-then-close is not atomic: Stop must not be called from
// several goroutines at the same time.
func (s *RecognitionService) Stop() {
	select {
	case <-s.stopChan:
		// Already closed by an earlier Stop call.
	default:
		close(s.stopChan)
	}
}
func (s *RecognitionService) processVLFormula(ctx context.Context, taskID int64) {
task, err := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
if err != nil {
log.Error(ctx, "func", "processVLFormulaQueue", "msg", "获取任务失败", "error", err)
return
}
if task == nil {
log.Error(ctx, "func", "processVLFormulaQueue", "msg", "任务不存在", "task_id", taskID)
return
}
ctx = context.WithValue(ctx, utils.RequestIDKey, task.TaskUUID)
log.Info(ctx, "func", "processVLFormulaQueue", "msg", "获取任务成功", "task_id", taskID)
// 处理具体任务
2025-12-11 19:51:51 +08:00
if err := s.processVLFormulaTask(ctx, taskID, task.FileURL, utils.ModelVLQwen3VL32BInstruct); err != nil {
2025-12-10 18:33:37 +08:00
log.Error(ctx, "func", "processVLFormulaQueue", "msg", "处理任务失败", "error", err)
return
}
log.Info(ctx, "func", "processVLFormulaQueue", "msg", "处理任务成功", "task_id", taskID)
}
2025-12-20 21:42:58 +08:00
// MathpixRequest Mathpix API /v3/text 完整请求结构
type MathpixRequest struct {
// 图片源URL 或 base64 编码
Src string `json:"src"`
// 元数据键值对
Metadata map[string]interface{} `json:"metadata"`
// 标签列表,用于标识结果
Tags []string `json:"tags"`
// 异步请求标志
Async bool `json:"async"`
// 回调配置
Callback *MathpixCallback `json:"callback"`
// 输出格式列表text, data, html, latex_styled
Formats []string `json:"formats"`
// 数据选项
2025-12-20 22:15:56 +08:00
DataOptions *MathpixDataOptions `json:"data_options,omitempty"`
2025-12-20 21:42:58 +08:00
// 返回检测到的字母表
2025-12-20 22:15:56 +08:00
IncludeDetectedAlphabets *bool `json:"include_detected_alphabets,omitempty"`
2025-12-20 21:42:58 +08:00
// 允许的字母表
2025-12-20 22:15:56 +08:00
AlphabetsAllowed *MathpixAlphabetsAllowed `json:"alphabets_allowed,omitempty"`
2025-12-20 21:42:58 +08:00
// 指定图片区域
2025-12-20 22:15:56 +08:00
Region *MathpixRegion `json:"region,omitempty"`
2025-12-20 21:42:58 +08:00
// 蓝色HSV过滤模式
EnableBlueHsvFilter bool `json:"enable_blue_hsv_filter"`
// 置信度阈值
ConfidenceThreshold float64 `json:"confidence_threshold"`
// 符号级别置信度阈值默认0.75
ConfidenceRateThreshold float64 `json:"confidence_rate_threshold"`
// 包含公式标签
IncludeEquationTags bool `json:"include_equation_tags"`
// 返回逐行信息
IncludeLineData bool `json:"include_line_data"`
// 返回逐词信息
IncludeWordData bool `json:"include_word_data"`
// 化学结构OCR
IncludeSmiles bool `json:"include_smiles"`
// InChI数据
IncludeInchi bool `json:"include_inchi"`
// 几何图形数据
IncludeGeometryData bool `json:"include_geometry_data"`
// 图表文本提取
IncludeDiagramText bool `json:"include_diagram_text"`
// 页面信息默认true
2025-12-20 22:15:56 +08:00
IncludePageInfo *bool `json:"include_page_info,omitempty"`
2025-12-20 21:42:58 +08:00
// 自动旋转置信度阈值默认0.99
AutoRotateConfidenceThreshold float64 `json:"auto_rotate_confidence_threshold"`
// 移除多余空格默认true
2025-12-20 22:15:56 +08:00
RmSpaces *bool `json:"rm_spaces,omitempty"`
2025-12-20 21:42:58 +08:00
// 移除字体命令默认false
RmFonts bool `json:"rm_fonts"`
// 使用aligned/gathered/cases代替array默认false
IdiomaticEqnArrays bool `json:"idiomatic_eqn_arrays"`
// 移除不必要的大括号默认false
IdiomaticBraces bool `json:"idiomatic_braces"`
// 数字始终为数学模式默认false
NumbersDefaultToMath bool `json:"numbers_default_to_math"`
// 数学字体始终为数学模式默认false
MathFontsDefaultToMath bool `json:"math_fonts_default_to_math"`
// 行内数学分隔符,默认 ["\\(", "\\)"]
MathInlineDelimiters []string `json:"math_inline_delimiters"`
// 行间数学分隔符,默认 ["\\[", "\\]"]
MathDisplayDelimiters []string `json:"math_display_delimiters"`
// 高级表格处理默认false
EnableTablesFallback bool `json:"enable_tables_fallback"`
// 全角标点null表示自动判断
2025-12-20 22:48:02 +08:00
FullwidthPunctuation *bool `json:"fullwidth_punctuation,omitempty"`
2025-12-20 21:42:58 +08:00
}
// MathpixCallback is the callback configuration embedded in a MathpixRequest.
type MathpixCallback struct {
	// URL is serialized under the "url" key.
	URL string `json:"url"`
	// Headers is serialized under the "headers" key.
	Headers map[string]string `json:"headers"`
}
// MathpixDataOptions holds the data options of a MathpixRequest: one boolean
// switch per data output (AsciiMath, MathML, LaTeX, TSV).
type MathpixDataOptions struct {
	IncludeAsciimath bool `json:"include_asciimath"`
	IncludeMathml bool `json:"include_mathml"`
	IncludeLatex bool `json:"include_latex"`
	IncludeTsv bool `json:"include_tsv"`
}
// MathpixAlphabetsAllowed lists the allowed alphabets of a MathpixRequest;
// each field is a boolean serialized under its alphabet code.
type MathpixAlphabetsAllowed struct {
	En bool `json:"en"`
	Hi bool `json:"hi"`
	Zh bool `json:"zh"`
	Ja bool `json:"ja"`
	Ko bool `json:"ko"`
	Ru bool `json:"ru"`
	Th bool `json:"th"`
	Vi bool `json:"vi"`
}
// MathpixRegion describes an image region by its top-left corner and its
// width and height.
// NOTE(review): units are presumably pixels — confirm against the Mathpix API docs.
type MathpixRegion struct {
	TopLeftX int `json:"top_left_x"`
	TopLeftY int `json:"top_left_y"`
	Width int `json:"width"`
	Height int `json:"height"`
}
// MathpixResponse is the full response payload of the Mathpix API /v3/text
// endpoint.
type MathpixResponse struct {
	// Request ID, useful for debugging.
	RequestID string `json:"request_id"`
	// Text in Mathpix Markdown format.
	Text string `json:"text"`
	// Styled LaTeX; returned only for images containing a single equation.
	LatexStyled string `json:"latex_styled"`
	// Confidence in [0,1].
	Confidence float64 `json:"confidence"`
	// Confidence rate in [0,1].
	ConfidenceRate float64 `json:"confidence_rate"`
	// Line data.
	LineData []map[string]interface{} `json:"line_data"`
	// Word data.
	WordData []map[string]interface{} `json:"word_data"`
	// List of data objects (see GetMathML / GetAsciiMath).
	Data []MathpixDataItem `json:"data"`
	// HTML output.
	HTML string `json:"html"`
	// Detected alphabets.
	DetectedAlphabets []map[string]interface{} `json:"detected_alphabets"`
	// Whether the content is printed.
	IsPrinted bool `json:"is_printed"`
	// Whether the content is handwritten.
	IsHandwritten bool `json:"is_handwritten"`
	// Auto-rotation confidence.
	AutoRotateConfidence float64 `json:"auto_rotate_confidence"`
	// Geometry data.
	GeometryData []map[string]interface{} `json:"geometry_data"`
	// Auto-rotation angle, one of {0, 90, -90, 180}.
	AutoRotateDegrees int `json:"auto_rotate_degrees"`
	// Image width.
	ImageWidth int `json:"image_width"`
	// Image height.
	ImageHeight int `json:"image_height"`
	// Error message; processMathpixTask treats a non-empty value as a failure.
	Error string `json:"error"`
	// Error details.
	ErrorInfo *MathpixErrorInfo `json:"error_info"`
	// API version.
	Version string `json:"version"`
}
// MathpixDataItem is one entry of MathpixResponse.Data.
type MathpixDataItem struct {
	// Type names the representation; this file looks for "mathml" and "asciimath".
	Type string `json:"type"`
	// Value is the content in that representation.
	Value string `json:"value"`
}
// MathpixErrorInfo carries the error details of a MathpixResponse.
type MathpixErrorInfo struct {
	// ID is serialized under the "id" key.
	ID string `json:"id"`
	// Message is serialized under the "message" key.
	Message string `json:"message"`
}
2025-12-25 14:02:06 +08:00
// BaiduOCRRequest is the request payload for the Baidu OCR layout-analysis
// service.
type BaiduOCRRequest struct {
	// File content, base64 encoded.
	File string `json:"file"`
	// File type: 0 = PDF, 1 = image.
	FileType int `json:"fileType"`
	// Whether to enable document orientation classification.
	UseDocOrientationClassify bool `json:"useDocOrientationClassify"`
	// Whether to enable document unwarping (distortion correction).
	UseDocUnwarping bool `json:"useDocUnwarping"`
	// Whether to enable chart recognition.
	UseChartRecognition bool `json:"useChartRecognition"`
}
// BaiduOCRResponse is the response payload of the Baidu OCR layout-analysis
// service.
type BaiduOCRResponse struct {
	// ErrorCode is treated as a failure by processBaiduOCRTask when non-zero.
	ErrorCode int `json:"errorCode"`
	// ErrorMsg is the message reported together with ErrorCode.
	ErrorMsg string `json:"errorMsg"`
	// Result holds the parsing output; it may be nil.
	Result *BaiduOCRResult `json:"result"`
}
// BaiduOCRResult is the result object of a Baidu OCR response.
type BaiduOCRResult struct {
	// LayoutParsingResults holds the per-page layout parsing results.
	LayoutParsingResults []BaiduLayoutParsingResult `json:"layoutParsingResults"`
}
// BaiduLayoutParsingResult is the layout parsing result of a single page.
type BaiduLayoutParsingResult struct {
	// Markdown is the Markdown rendering of the page.
	Markdown BaiduMarkdownResult `json:"markdown"`
	// OutputImages is decoded from the "outputImages" key.
	// NOTE(review): key/value meaning is not visible in this file — confirm against the API docs.
	OutputImages map[string]string `json:"outputImages"`
}
// BaiduMarkdownResult is the Markdown part of a page result.
type BaiduMarkdownResult struct {
	// Text is the Markdown text; processBaiduOCRTask joins the non-empty texts of all pages.
	Text string `json:"text"`
	// Images is decoded from the "images" key.
	// NOTE(review): key/value meaning is not visible in this file — confirm against the API docs.
	Images map[string]string `json:"images"`
}
2025-12-20 21:42:58 +08:00
// GetMathML returns the value of the first data item whose type is "mathml",
// or the empty string when the response carries no such item.
func (r *MathpixResponse) GetMathML() string {
	const wanted = "mathml"
	for i := range r.Data {
		if entry := r.Data[i]; entry.Type == wanted {
			return entry.Value
		}
	}
	return ""
}
// GetAsciiMath returns the value of the first data item whose type is
// "asciimath", or the empty string when the response carries no such item.
func (r *MathpixResponse) GetAsciiMath() string {
	const wanted = "asciimath"
	for i := range r.Data {
		if entry := r.Data[i]; entry.Type == wanted {
			return entry.Value
		}
	}
	return ""
}
2025-12-10 18:33:37 +08:00
func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int64, fileURL string) (err error) {
// 为整个任务处理添加超时控制
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
defer cancel()
tx := dao.DB.Begin()
var (
taskDao = dao.NewRecognitionTaskDao()
resultDao = dao.NewRecognitionResultDao()
)
isSuccess := false
defer func() {
if !isSuccess {
tx.Rollback()
status := dao.TaskStatusFailed
remark := "任务处理失败"
if ctx.Err() == context.DeadlineExceeded {
remark = "任务处理超时"
}
if err != nil {
remark = err.Error()
}
_ = taskDao.Update(dao.DB.WithContext(context.Background()), // 使用新的context避免已取消的context影响状态更新
map[string]interface{}{"id": taskID},
map[string]interface{}{
"status": status,
"completed_at": time.Now(),
"remark": remark,
})
return
}
_ = taskDao.Update(tx, map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusCompleted, "completed_at": time.Now()})
tx.Commit()
}()
err = taskDao.Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusProcessing})
if err != nil {
log.Error(ctx, "func", "processFormulaTask", "msg", "更新任务状态失败", "error", err)
}
// 下载图片文件
reader, err := oss.DownloadFile(ctx, fileURL)
if err != nil {
log.Error(ctx, "func", "processFormulaTask", "msg", "下载图片文件失败", "error", err)
return err
}
defer reader.Close()
// 读取图片数据
imageData, err := io.ReadAll(reader)
if err != nil {
log.Error(ctx, "func", "processFormulaTask", "msg", "读取图片数据失败", "error", err)
return err
}
// 将图片转为base64编码
base64Image := base64.StdEncoding.EncodeToString(imageData)
// 创建JSON请求
requestData := map[string]string{
"image_base64": base64Image,
}
jsonData, err := json.Marshal(requestData)
if err != nil {
log.Error(ctx, "func", "processFormulaTask", "msg", "JSON编码失败", "error", err)
return err
}
// 设置Content-Type头为application/json
headers := map[string]string{"Content-Type": "application/json", utils.RequestIDHeaderKey: utils.GetRequestIDFromContext(ctx)}
2025-12-31 17:53:12 +08:00
// 发送请求到新的 OCR 接口
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "https://cloud.texpixel.com:10443/doc_process/v1/image/ocr", bytes.NewReader(jsonData), headers)
2025-12-10 18:33:37 +08:00
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
log.Error(ctx, "func", "processFormulaTask", "msg", "请求超时")
return fmt.Errorf("request timeout")
}
log.Error(ctx, "func", "processFormulaTask", "msg", "请求失败", "error", err)
return err
}
defer resp.Body.Close()
2025-12-15 23:29:28 +08:00
log.Info(ctx, "func", "processFormulaTask", "msg", "请求成功")
2025-12-10 18:33:37 +08:00
body := &bytes.Buffer{}
if _, err = body.ReadFrom(resp.Body); err != nil {
log.Error(ctx, "func", "processFormulaTask", "msg", "读取响应体失败", "error", err)
return err
}
2025-12-15 23:29:28 +08:00
log.Info(ctx, "func", "processFormulaTask", "msg", "响应内容", "body", body.String())
// 解析 JSON 响应
2025-12-31 17:53:12 +08:00
var ocrResp formula.ImageOCRResponse
if err := json.Unmarshal(body.Bytes(), &ocrResp); err != nil {
2025-12-15 23:29:28 +08:00
log.Error(ctx, "func", "processFormulaTask", "msg", "解析响应JSON失败", "error", err)
return err
}
2025-12-31 17:53:12 +08:00
err = resultDao.Create(tx, dao.RecognitionResult{
TaskID: taskID,
TaskType: dao.TaskTypeFormula,
Latex: ocrResp.Latex,
Markdown: ocrResp.Markdown,
MathML: ocrResp.MathML,
})
2025-12-20 21:42:58 +08:00
if err != nil {
2025-12-10 18:33:37 +08:00
log.Error(ctx, "func", "processFormulaTask", "msg", "保存任务结果失败", "error", err)
return err
}
2025-12-20 21:42:58 +08:00
2025-12-10 18:33:37 +08:00
isSuccess = true
return nil
}
2025-12-11 19:39:35 +08:00
func (s *RecognitionService) processVLFormulaTask(ctx context.Context, taskID int64, fileURL string, model string) error {
2025-12-10 18:33:37 +08:00
isSuccess := false
defer func() {
if !isSuccess {
err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusFailed})
if err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务状态失败", "error", err)
}
return
}
err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusCompleted})
if err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务状态失败", "error", err)
}
}()
reader, err := oss.DownloadFile(ctx, fileURL)
if err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "获取签名URL失败", "error", err)
return err
}
defer reader.Close()
imageData, err := io.ReadAll(reader)
if err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "读取图片数据失败", "error", err)
return err
}
2025-12-11 19:51:51 +08:00
prompt := `Please perform OCR on the image and output only LaTeX code.`
2025-12-10 18:33:37 +08:00
base64Image := base64.StdEncoding.EncodeToString(imageData)
requestBody := formula.VLFormulaRequest{
2025-12-11 19:39:35 +08:00
Model: model,
2025-12-10 18:33:37 +08:00
Stream: false,
MaxTokens: 512,
Temperature: 0.1,
TopP: 0.1,
TopK: 50,
FrequencyPenalty: 0.2,
N: 1,
Messages: []formula.Message{
{
Role: "user",
Content: []formula.Content{
{
Type: "text",
Text: prompt,
},
{
Type: "image_url",
ImageURL: formula.Image{
Detail: "auto",
URL: "data:image/jpeg;base64," + base64Image,
},
},
},
},
},
}
// 将请求体转换为JSON
jsonData, err := json.Marshal(requestBody)
if err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "JSON编码失败", "error", err)
return err
}
headers := map[string]string{
"Content-Type": "application/json",
"Authorization": utils.SiliconFlowToken,
}
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "https://api.siliconflow.cn/v1/chat/completions", bytes.NewReader(jsonData), headers)
if err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "请求VL服务失败", "error", err)
return err
}
defer resp.Body.Close()
// 解析响应
var response formula.VLFormulaResponse
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "解析响应失败", "error", err)
return err
}
// 提取LaTeX代码
var latex string
if len(response.Choices) > 0 {
if response.Choices[0].Message.Content != "" {
latex = strings.ReplaceAll(response.Choices[0].Message.Content, "\n", "")
latex = strings.ReplaceAll(latex, "```latex", "")
latex = strings.ReplaceAll(latex, "```", "")
// 规范化LaTeX代码移除不必要的空格
latex = strings.ReplaceAll(latex, " = ", "=")
latex = strings.ReplaceAll(latex, "\\left[ ", "\\left[")
latex = strings.TrimPrefix(latex, "\\[")
latex = strings.TrimSuffix(latex, "\\]")
}
}
resultDao := dao.NewRecognitionResultDao()
result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID)
if err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "获取任务结果失败", "error", err)
return err
}
if result == nil {
2025-12-25 14:02:06 +08:00
formulaRes := &dao.RecognitionResult{TaskID: taskID, TaskType: dao.TaskTypeFormula, Latex: latex}
2025-12-20 21:42:58 +08:00
err = resultDao.Create(dao.DB.WithContext(ctx), *formulaRes)
2025-12-10 18:33:37 +08:00
if err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "创建任务结果失败", "error", err)
return err
}
} else {
2025-12-25 14:02:06 +08:00
result.Latex = latex
2025-12-20 21:42:58 +08:00
err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{"latex": latex})
2025-12-10 18:33:37 +08:00
if err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务结果失败", "error", err)
return err
}
}
isSuccess = true
return nil
}
// processFormulaQueue keeps handling queued tasks one after another until
// stopChan is closed. The stop signal is checked before every task.
func (s *RecognitionService) processFormulaQueue(ctx context.Context) {
	for {
		// Non-blocking check of the stop signal.
		select {
		case <-s.stopChan:
			return
		default:
		}

		s.processOneTask(ctx)
	}
}
func (s *RecognitionService) processOneTask(ctx context.Context) {
// 限制队列数量
s.queueLimit <- struct{}{}
defer func() { <-s.queueLimit }()
taskID, err := cache.PopFormulaTask(ctx)
if err != nil {
log.Error(ctx, "func", "processFormulaQueue", "msg", "获取任务失败", "error", err)
return
}
task, err := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
if err != nil {
log.Error(ctx, "func", "processFormulaQueue", "msg", "获取任务失败", "error", err)
return
}
if task == nil {
log.Error(ctx, "func", "processFormulaQueue", "msg", "任务不存在", "task_id", taskID)
return
}
ctx = context.WithValue(ctx, utils.RequestIDKey, task.TaskUUID)
2025-12-31 17:53:12 +08:00
// 使用 gls 设置 request_id确保在整个任务处理过程中可用
requestid.SetRequestID(task.TaskUUID, func() {
log.Info(ctx, "func", "processFormulaQueue", "msg", "获取任务成功", "task_id", taskID)
2025-12-10 18:33:37 +08:00
2025-12-31 17:53:12 +08:00
err = s.processFormulaTask(ctx, taskID, task.FileURL)
if err != nil {
log.Error(ctx, "func", "processFormulaQueue", "msg", "处理任务失败", "error", err)
return
}
log.Info(ctx, "func", "processFormulaQueue", "msg", "处理任务成功", "task_id", taskID)
})
2025-12-10 18:33:37 +08:00
}
2025-12-20 21:42:58 +08:00
// processMathpixTask 使用 Mathpix API 处理公式识别任务(用于增强识别)
func (s *RecognitionService) processMathpixTask(ctx context.Context, taskID int64, fileURL string) error {
isSuccess := false
logDao := dao.NewRecognitionLogDao()
defer func() {
if !isSuccess {
err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusFailed})
if err != nil {
log.Error(ctx, "func", "processMathpixTask", "msg", "更新任务状态失败", "error", err)
}
return
}
err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusCompleted})
if err != nil {
log.Error(ctx, "func", "processMathpixTask", "msg", "更新任务状态失败", "error", err)
}
}()
// 下载图片
imageUrl, err := oss.GetDownloadURL(ctx, fileURL)
if err != nil {
log.Error(ctx, "func", "processMathpixTask", "msg", "获取图片URL失败", "error", err)
return err
}
// 创建 Mathpix API 请求
mathpixReq := MathpixRequest{
Src: imageUrl,
Formats: []string{
"text",
"latex_styled",
"data",
"html",
},
DataOptions: &MathpixDataOptions{
IncludeMathml: true,
IncludeAsciimath: true,
IncludeLatex: true,
IncludeTsv: true,
},
2025-12-20 22:15:56 +08:00
MathInlineDelimiters: []string{"$", "$"},
MathDisplayDelimiters: []string{"$$", "$$"},
RmSpaces: &[]bool{true}[0],
2025-12-20 21:42:58 +08:00
}
jsonData, err := json.Marshal(mathpixReq)
if err != nil {
log.Error(ctx, "func", "processMathpixTask", "msg", "JSON编码失败", "error", err)
return err
}
headers := map[string]string{
"Content-Type": "application/json",
"app_id": config.GlobalConfig.Mathpix.AppID,
"app_key": config.GlobalConfig.Mathpix.AppKey,
}
endpoint := "https://api.mathpix.com/v3/text"
2025-12-23 22:32:29 +08:00
startTime := time.Now()
log.Info(ctx, "func", "processMathpixTask", "msg", "MathpixApi_Start", "start_time", startTime)
2025-12-20 21:42:58 +08:00
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData), headers)
if err != nil {
log.Error(ctx, "func", "processMathpixTask", "msg", "Mathpix API 请求失败", "error", err)
return err
}
defer resp.Body.Close()
2025-12-23 22:32:29 +08:00
log.Info(ctx, "func", "processMathpixTask", "msg", "MathpixApi_End", "end_time", time.Now(), "duration", time.Since(startTime))
2025-12-20 21:42:58 +08:00
body := &bytes.Buffer{}
if _, err = body.ReadFrom(resp.Body); err != nil {
log.Error(ctx, "func", "processMathpixTask", "msg", "读取响应体失败", "error", err)
return err
}
// 创建日志记录
recognitionLog := &dao.RecognitionLog{
TaskID: taskID,
Provider: dao.ProviderMathpix,
RequestBody: string(jsonData),
ResponseBody: body.String(),
}
// 解析响应
var mathpixResp MathpixResponse
if err := json.Unmarshal(body.Bytes(), &mathpixResp); err != nil {
log.Error(ctx, "func", "processMathpixTask", "msg", "解析响应失败", "error", err)
return err
}
// 检查错误
if mathpixResp.Error != "" {
errMsg := mathpixResp.Error
if mathpixResp.ErrorInfo != nil {
errMsg = fmt.Sprintf("%s: %s", mathpixResp.ErrorInfo.ID, mathpixResp.ErrorInfo.Message)
}
log.Error(ctx, "func", "processMathpixTask", "msg", "Mathpix API 返回错误", "error", errMsg)
return fmt.Errorf("mathpix error: %s", errMsg)
}
// 保存日志
err = logDao.Create(dao.DB.WithContext(ctx), recognitionLog)
if err != nil {
log.Error(ctx, "func", "processMathpixTask", "msg", "保存日志失败", "error", err)
}
// 更新或创建识别结果
resultDao := dao.NewRecognitionResultDao()
result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID)
if err != nil {
log.Error(ctx, "func", "processMathpixTask", "msg", "获取任务结果失败", "error", err)
return err
}
2025-12-23 22:32:29 +08:00
log.Info(ctx, "func", "processMathpixTask", "msg", "saveLog", "end_time", time.Now(), "duration", time.Since(startTime))
2025-12-20 21:42:58 +08:00
if result == nil {
// 创建新结果
err = resultDao.Create(dao.DB.WithContext(ctx), dao.RecognitionResult{
TaskID: taskID,
TaskType: dao.TaskTypeFormula,
Latex: mathpixResp.LatexStyled,
Markdown: mathpixResp.Text,
MathML: mathpixResp.GetMathML(),
})
if err != nil {
log.Error(ctx, "func", "processMathpixTask", "msg", "创建任务结果失败", "error", err)
return err
}
} else {
// 更新现有结果
err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{
"latex": mathpixResp.LatexStyled,
"markdown": mathpixResp.Text,
"mathml": mathpixResp.GetMathML(),
})
if err != nil {
log.Error(ctx, "func", "processMathpixTask", "msg", "更新任务结果失败", "error", err)
return err
}
}
isSuccess = true
return nil
}
2025-12-20 22:48:02 +08:00
2025-12-25 14:02:06 +08:00
func (s *RecognitionService) processBaiduOCRTask(ctx context.Context, taskID int64, fileURL string) error {
isSuccess := false
logDao := dao.NewRecognitionLogDao()
defer func() {
if !isSuccess {
err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusFailed})
if err != nil {
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "更新任务状态失败", "error", err)
}
return
}
err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusCompleted})
if err != nil {
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "更新任务状态失败", "error", err)
}
}()
// 从 OSS 下载文件
reader, err := oss.DownloadFile(ctx, fileURL)
if err != nil {
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "从OSS下载文件失败", "error", err)
return err
}
defer reader.Close()
// 读取文件内容
fileBytes, err := io.ReadAll(reader)
if err != nil {
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "读取文件内容失败", "error", err)
return err
}
// Base64 编码
fileData := base64.StdEncoding.EncodeToString(fileBytes)
// 根据文件扩展名确定 fileType: 0=PDF, 1=图片
fileType := 1 // 默认为图片
lowerFileURL := strings.ToLower(fileURL)
if strings.HasSuffix(lowerFileURL, ".pdf") {
fileType = 0
}
// 创建百度 OCR API 请求
baiduReq := BaiduOCRRequest{
File: fileData,
FileType: fileType,
UseDocOrientationClassify: false,
UseDocUnwarping: false,
UseChartRecognition: false,
}
jsonData, err := json.Marshal(baiduReq)
if err != nil {
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "JSON编码失败", "error", err)
return err
}
headers := map[string]string{
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("token %s", config.GlobalConfig.BaiduOCR.Token),
}
endpoint := "https://j5veh2l2r6ubk6cb.aistudio-app.com/layout-parsing"
startTime := time.Now()
log.Info(ctx, "func", "processBaiduOCRTask", "msg", "BaiduOCRApi_Start", "start_time", startTime)
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData), headers)
if err != nil {
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "百度 OCR API 请求失败", "error", err)
return err
}
defer resp.Body.Close()
log.Info(ctx, "func", "processBaiduOCRTask", "msg", "BaiduOCRApi_End", "end_time", time.Now(), "duration", time.Since(startTime))
body := &bytes.Buffer{}
if _, err = body.ReadFrom(resp.Body); err != nil {
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "读取响应体失败", "error", err)
return err
}
// 创建日志记录(不记录请求体中的 base64 数据以节省存储)
requestLogData := map[string]interface{}{
"fileType": fileType,
"useDocOrientationClassify": false,
"useDocUnwarping": false,
"useChartRecognition": false,
"fileSize": len(fileBytes),
}
requestLogBytes, _ := json.Marshal(requestLogData)
recognitionLog := &dao.RecognitionLog{
TaskID: taskID,
Provider: dao.ProviderBaiduOCR,
RequestBody: string(requestLogBytes),
ResponseBody: body.String(),
}
// 解析响应
var baiduResp BaiduOCRResponse
if err := json.Unmarshal(body.Bytes(), &baiduResp); err != nil {
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "解析响应失败", "error", err)
return err
}
// 检查错误
if baiduResp.ErrorCode != 0 {
errMsg := fmt.Sprintf("errorCode: %d, errorMsg: %s", baiduResp.ErrorCode, baiduResp.ErrorMsg)
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "百度 OCR API 返回错误", "error", errMsg)
return fmt.Errorf("baidu ocr error: %s", errMsg)
}
// 保存日志
err = logDao.Create(dao.DB.WithContext(ctx), recognitionLog)
if err != nil {
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "保存日志失败", "error", err)
}
// 合并所有页面的 markdown 结果
var markdownTexts []string
if baiduResp.Result != nil && len(baiduResp.Result.LayoutParsingResults) > 0 {
for _, res := range baiduResp.Result.LayoutParsingResults {
if res.Markdown.Text != "" {
markdownTexts = append(markdownTexts, res.Markdown.Text)
}
}
}
markdownResult := strings.Join(markdownTexts, "\n\n---\n\n")
2025-12-27 22:21:34 +08:00
latex, mml, e := s.HandleConvert(ctx, markdownResult)
if e != nil {
2025-12-27 22:06:48 +08:00
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "转换失败", "error", err)
}
2025-12-25 14:02:06 +08:00
// 更新或创建识别结果
resultDao := dao.NewRecognitionResultDao()
result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID)
if err != nil {
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "获取任务结果失败", "error", err)
return err
}
log.Info(ctx, "func", "processBaiduOCRTask", "msg", "saveLog", "end_time", time.Now(), "duration", time.Since(startTime))
if result == nil {
// 创建新结果
err = resultDao.Create(dao.DB.WithContext(ctx), dao.RecognitionResult{
TaskID: taskID,
TaskType: dao.TaskTypeFormula,
Markdown: markdownResult,
2025-12-27 22:06:48 +08:00
Latex: latex,
MathML: mml,
2025-12-25 14:02:06 +08:00
})
if err != nil {
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "创建任务结果失败", "error", err)
return err
}
} else {
// 更新现有结果
err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{
"markdown": markdownResult,
2025-12-27 22:06:48 +08:00
"latex": latex,
"mathml": mml,
2025-12-25 14:02:06 +08:00
})
if err != nil {
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "更新任务结果失败", "error", err)
return err
}
}
isSuccess = true
return nil
}
2025-12-20 22:48:02 +08:00
// TestProcessMathpixTask loads the task with the given ID and runs the
// Mathpix recognition on its file.
//
// It returns the lookup error, a "not found" error when the task does not
// exist, or whatever processMathpixTask returns.
func (s *RecognitionService) TestProcessMathpixTask(ctx context.Context, taskID int64) error {
	task, err := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
	if err != nil {
		log.Error(ctx, "func", "TestProcessMathpixTask", "msg", "获取任务失败", "error", err)
		return err
	}
	if task == nil {
		log.Error(ctx, "func", "TestProcessMathpixTask", "msg", "任务不存在", "task_id", taskID)
		// err is nil here; report the missing task explicitly instead of
		// signalling success.
		return fmt.Errorf("task %d not found", taskID)
	}
	return s.processMathpixTask(ctx, taskID, task.FileURL)
}
2025-12-27 22:06:48 +08:00
// ConvertResponse is the structure returned by the Python conversion
// endpoint called from HandleConvert.
type ConvertResponse struct {
	// Latex is the converted LaTeX.
	Latex string `json:"latex"`
	// MathML is the converted MathML.
	MathML string `json:"mathml"`
	// Error is a business-level error message; HandleConvert fails when it is non-empty.
	Error string `json:"error,omitempty"`
}
// HandleConvert sends markdown to the document converter service as a
// multipart form field named "markdown_input" and returns the LaTeX and
// MathML it produces.
//
// An error is returned when the form cannot be built, the request fails, the
// HTTP status is not 200, the body is not valid JSON, or the service reports
// a business error. On error both strings are empty.
func (s *RecognitionService) HandleConvert(ctx context.Context, markdown string) (latex string, mml string, err error) {
	const endpoint = "https://cloud.texpixel.com:10443/doc_converter/v1/convert"

	// Build the multipart form.
	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)
	if err = writer.WriteField("markdown_input", markdown); err != nil {
		return "", "", fmt.Errorf("write multipart field: %w", err)
	}
	// Close writes the trailing boundary; without it the form is incomplete.
	if err = writer.Close(); err != nil {
		return "", "", fmt.Errorf("close multipart writer: %w", err)
	}

	// The Content-Type must carry the boundary chosen by the writer.
	headers := map[string]string{
		"Content-Type": writer.FormDataContentType(),
	}

	resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, endpoint, body, headers)
	if err != nil {
		return "", "", err
	}
	defer resp.Body.Close()

	// Read the response body.
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", "", err
	}

	// Check the HTTP status code.
	if resp.StatusCode != http.StatusOK {
		return "", "", fmt.Errorf("convert failed: status %d, body: %s", resp.StatusCode, string(respBody))
	}

	// Parse the JSON response.
	var convertResp ConvertResponse
	if err = json.Unmarshal(respBody, &convertResp); err != nil {
		return "", "", fmt.Errorf("unmarshal response failed: %w, body: %s", err, string(respBody))
	}

	// Check for a business-level error.
	if convertResp.Error != "" {
		return "", "", fmt.Errorf("convert error: %s", convertResp.Error)
	}

	return convertResp.Latex, convertResp.MathML, nil
}