1047 lines
34 KiB
Go
1047 lines
34 KiB
Go
package service
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"encoding/base64"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"gitea.com/bitwsd/document_ai/config"
|
||
"gitea.com/bitwsd/document_ai/internal/model/formula"
|
||
"gitea.com/bitwsd/document_ai/internal/storage/cache"
|
||
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
||
"gitea.com/bitwsd/document_ai/pkg/log"
|
||
|
||
"gitea.com/bitwsd/document_ai/pkg/common"
|
||
"gitea.com/bitwsd/document_ai/pkg/constant"
|
||
"gitea.com/bitwsd/document_ai/pkg/httpclient"
|
||
"gitea.com/bitwsd/document_ai/pkg/oss"
|
||
"gitea.com/bitwsd/document_ai/pkg/utils"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
type RecognitionService struct {
|
||
db *gorm.DB
|
||
queueLimit chan struct{}
|
||
stopChan chan struct{}
|
||
httpClient *httpclient.Client
|
||
}
|
||
|
||
func NewRecognitionService() *RecognitionService {
|
||
s := &RecognitionService{
|
||
db: dao.DB,
|
||
queueLimit: make(chan struct{}, config.GlobalConfig.Limit.FormulaRecognition),
|
||
stopChan: make(chan struct{}),
|
||
httpClient: httpclient.NewClient(nil), // 使用默认配置
|
||
}
|
||
|
||
// 服务启动时就开始处理队列
|
||
utils.SafeGo(func() {
|
||
lock, err := cache.GetDistributedLock(context.Background())
|
||
if err != nil {
|
||
log.Error(context.Background(), "func", "NewRecognitionService", "msg", "获取分布式锁失败", "error", err)
|
||
return
|
||
}
|
||
if !lock {
|
||
log.Error(context.Background(), "func", "NewRecognitionService", "msg", "获取分布式锁失败")
|
||
return
|
||
}
|
||
s.processFormulaQueue(context.Background())
|
||
})
|
||
|
||
return s
|
||
}
|
||
|
||
func (s *RecognitionService) AIEnhanceRecognition(ctx context.Context, req *formula.AIEnhanceRecognitionRequest) (*dao.RecognitionTask, error) {
|
||
count, err := cache.GetVLMFormulaCount(ctx, common.GetIPFromContext(ctx))
|
||
if err != nil {
|
||
log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "获取VLM公式识别次数失败", "error", err)
|
||
return nil, common.NewError(common.CodeSystemError, "系统错误", err)
|
||
}
|
||
if count >= constant.VLMFormulaCount {
|
||
return nil, common.NewError(common.CodeForbidden, "今日VLM公式识别次数已达上限,请明天再试!", nil)
|
||
}
|
||
|
||
taskDao := dao.NewRecognitionTaskDao()
|
||
task, err := taskDao.GetByTaskNo(dao.DB.WithContext(ctx), req.TaskNo)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "获取任务失败", "error", err)
|
||
return nil, common.NewError(common.CodeDBError, "获取任务失败", err)
|
||
}
|
||
if task == nil {
|
||
log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "任务不存在", "task_no", req.TaskNo)
|
||
return nil, common.NewError(common.CodeNotFound, "任务不存在", err)
|
||
}
|
||
|
||
if task.Status == dao.TaskStatusProcessing || task.Status == dao.TaskStatusPending {
|
||
log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "任务未完成", "task_no", req.TaskNo)
|
||
return nil, common.NewError(common.CodeInvalidStatus, "任务未完成", err)
|
||
}
|
||
|
||
err = taskDao.Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": task.ID}, map[string]interface{}{"status": dao.TaskStatusProcessing})
|
||
if err != nil {
|
||
log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "更新任务状态失败", "error", err)
|
||
return nil, common.NewError(common.CodeDBError, "更新任务状态失败", err)
|
||
}
|
||
|
||
utils.SafeGo(func() {
|
||
s.processVLFormula(context.Background(), task.ID)
|
||
_, err := cache.IncrVLMFormulaCount(context.Background(), common.GetIPFromContext(ctx))
|
||
if err != nil {
|
||
log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "增加VLM公式识别次数失败", "error", err)
|
||
}
|
||
})
|
||
|
||
return task, nil
|
||
|
||
}
|
||
|
||
func (s *RecognitionService) CreateRecognitionTask(ctx context.Context, req *formula.CreateFormulaRecognitionRequest) (*dao.RecognitionTask, error) {
|
||
sess := dao.DB.WithContext(ctx)
|
||
taskDao := dao.NewRecognitionTaskDao()
|
||
task := &dao.RecognitionTask{
|
||
UserID: req.UserID,
|
||
TaskUUID: utils.NewUUID(),
|
||
TaskType: dao.TaskType(req.TaskType),
|
||
Status: dao.TaskStatusPending,
|
||
FileURL: req.FileURL,
|
||
FileName: req.FileName,
|
||
FileHash: req.FileHash,
|
||
IP: common.GetIPFromContext(ctx),
|
||
}
|
||
|
||
if err := taskDao.Create(sess, task); err != nil {
|
||
log.Error(ctx, "func", "CreateRecognitionTask", "msg", "创建任务失败", "error", err)
|
||
return nil, common.NewError(common.CodeDBError, "创建任务失败", err)
|
||
}
|
||
|
||
err := s.handleFormulaRecognition(ctx, task.ID)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "CreateRecognitionTask", "msg", "处理任务失败", "error", err)
|
||
return nil, common.NewError(common.CodeSystemError, "处理任务失败", err)
|
||
}
|
||
|
||
return task, nil
|
||
}
|
||
|
||
func (s *RecognitionService) GetFormualTask(ctx context.Context, taskNo string) (*formula.GetFormulaTaskResponse, error) {
|
||
taskDao := dao.NewRecognitionTaskDao()
|
||
resultDao := dao.NewRecognitionResultDao()
|
||
sess := dao.DB.WithContext(ctx)
|
||
|
||
count, err := cache.GetFormulaTaskCount(ctx)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "GetFormualTask", "msg", "获取任务数量失败", "error", err)
|
||
}
|
||
|
||
if count > int64(config.GlobalConfig.Limit.FormulaRecognition) {
|
||
return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Status: int(dao.TaskStatusPending), Count: int(count)}, nil
|
||
}
|
||
|
||
task, err := taskDao.GetByTaskNo(sess, taskNo)
|
||
if err != nil {
|
||
if err == gorm.ErrRecordNotFound {
|
||
log.Info(ctx, "func", "GetFormualTask", "msg", "任务不存在", "task_no", taskNo)
|
||
return nil, common.NewError(common.CodeNotFound, "任务不存在", err)
|
||
}
|
||
log.Error(ctx, "func", "GetFormualTask", "msg", "查询任务失败", "error", err, "task_no", taskNo)
|
||
return nil, common.NewError(common.CodeDBError, "查询任务失败", err)
|
||
}
|
||
|
||
if task.Status != dao.TaskStatusCompleted {
|
||
return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Status: int(task.Status)}, nil
|
||
}
|
||
|
||
taskRet, err := resultDao.GetByTaskID(sess, task.ID)
|
||
if err != nil {
|
||
if err == gorm.ErrRecordNotFound {
|
||
log.Info(ctx, "func", "GetFormualTask", "msg", "任务结果不存在", "task_no", taskNo)
|
||
return nil, common.NewError(common.CodeNotFound, "任务结果不存在", err)
|
||
}
|
||
log.Error(ctx, "func", "GetFormualTask", "msg", "查询任务结果失败", "error", err, "task_no", taskNo)
|
||
return nil, common.NewError(common.CodeDBError, "查询任务结果失败", err)
|
||
}
|
||
|
||
// 构建 Markdown 格式
|
||
markdown := taskRet.Markdown
|
||
if markdown == "" {
|
||
markdown = fmt.Sprintf("$$%s$$", taskRet.Latex)
|
||
}
|
||
|
||
return &formula.GetFormulaTaskResponse{
|
||
TaskNo: taskNo,
|
||
Latex: taskRet.Latex,
|
||
Markdown: markdown,
|
||
MathML: taskRet.MathML,
|
||
Status: int(task.Status),
|
||
}, nil
|
||
}
|
||
|
||
func (s *RecognitionService) handleFormulaRecognition(ctx context.Context, taskID int64) error {
|
||
// 简化为只负责将任务加入队列
|
||
_, err := cache.PushFormulaTask(ctx, taskID)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "handleFormulaRecognition", "msg", "增加任务计数失败", "error", err)
|
||
return err
|
||
}
|
||
log.Info(ctx, "func", "handleFormulaRecognition", "msg", "增加任务计数成功", "task_id", taskID)
|
||
return nil
|
||
}
|
||
|
||
// Stop 用于优雅关闭服务
|
||
func (s *RecognitionService) Stop() {
|
||
close(s.stopChan)
|
||
}
|
||
|
||
func (s *RecognitionService) processVLFormula(ctx context.Context, taskID int64) {
|
||
task, err := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processVLFormulaQueue", "msg", "获取任务失败", "error", err)
|
||
return
|
||
}
|
||
if task == nil {
|
||
log.Error(ctx, "func", "processVLFormulaQueue", "msg", "任务不存在", "task_id", taskID)
|
||
return
|
||
}
|
||
|
||
ctx = context.WithValue(ctx, utils.RequestIDKey, task.TaskUUID)
|
||
log.Info(ctx, "func", "processVLFormulaQueue", "msg", "获取任务成功", "task_id", taskID)
|
||
|
||
// 处理具体任务
|
||
if err := s.processVLFormulaTask(ctx, taskID, task.FileURL, utils.ModelVLQwen3VL32BInstruct); err != nil {
|
||
log.Error(ctx, "func", "processVLFormulaQueue", "msg", "处理任务失败", "error", err)
|
||
return
|
||
}
|
||
|
||
log.Info(ctx, "func", "processVLFormulaQueue", "msg", "处理任务成功", "task_id", taskID)
|
||
}
|
||
|
||
// MathpixRequest Mathpix API /v3/text 完整请求结构
|
||
type MathpixRequest struct {
|
||
// 图片源:URL 或 base64 编码
|
||
Src string `json:"src"`
|
||
// 元数据键值对
|
||
Metadata map[string]interface{} `json:"metadata"`
|
||
// 标签列表,用于标识结果
|
||
Tags []string `json:"tags"`
|
||
// 异步请求标志
|
||
Async bool `json:"async"`
|
||
// 回调配置
|
||
Callback *MathpixCallback `json:"callback"`
|
||
// 输出格式列表:text, data, html, latex_styled
|
||
Formats []string `json:"formats"`
|
||
// 数据选项
|
||
DataOptions *MathpixDataOptions `json:"data_options,omitempty"`
|
||
// 返回检测到的字母表
|
||
IncludeDetectedAlphabets *bool `json:"include_detected_alphabets,omitempty"`
|
||
// 允许的字母表
|
||
AlphabetsAllowed *MathpixAlphabetsAllowed `json:"alphabets_allowed,omitempty"`
|
||
// 指定图片区域
|
||
Region *MathpixRegion `json:"region,omitempty"`
|
||
// 蓝色HSV过滤模式
|
||
EnableBlueHsvFilter bool `json:"enable_blue_hsv_filter"`
|
||
// 置信度阈值
|
||
ConfidenceThreshold float64 `json:"confidence_threshold"`
|
||
// 符号级别置信度阈值,默认0.75
|
||
ConfidenceRateThreshold float64 `json:"confidence_rate_threshold"`
|
||
// 包含公式标签
|
||
IncludeEquationTags bool `json:"include_equation_tags"`
|
||
// 返回逐行信息
|
||
IncludeLineData bool `json:"include_line_data"`
|
||
// 返回逐词信息
|
||
IncludeWordData bool `json:"include_word_data"`
|
||
// 化学结构OCR
|
||
IncludeSmiles bool `json:"include_smiles"`
|
||
// InChI数据
|
||
IncludeInchi bool `json:"include_inchi"`
|
||
// 几何图形数据
|
||
IncludeGeometryData bool `json:"include_geometry_data"`
|
||
// 图表文本提取
|
||
IncludeDiagramText bool `json:"include_diagram_text"`
|
||
// 页面信息,默认true
|
||
IncludePageInfo *bool `json:"include_page_info,omitempty"`
|
||
// 自动旋转置信度阈值,默认0.99
|
||
AutoRotateConfidenceThreshold float64 `json:"auto_rotate_confidence_threshold"`
|
||
// 移除多余空格,默认true
|
||
RmSpaces *bool `json:"rm_spaces,omitempty"`
|
||
// 移除字体命令,默认false
|
||
RmFonts bool `json:"rm_fonts"`
|
||
// 使用aligned/gathered/cases代替array,默认false
|
||
IdiomaticEqnArrays bool `json:"idiomatic_eqn_arrays"`
|
||
// 移除不必要的大括号,默认false
|
||
IdiomaticBraces bool `json:"idiomatic_braces"`
|
||
// 数字始终为数学模式,默认false
|
||
NumbersDefaultToMath bool `json:"numbers_default_to_math"`
|
||
// 数学字体始终为数学模式,默认false
|
||
MathFontsDefaultToMath bool `json:"math_fonts_default_to_math"`
|
||
// 行内数学分隔符,默认 ["\\(", "\\)"]
|
||
MathInlineDelimiters []string `json:"math_inline_delimiters"`
|
||
// 行间数学分隔符,默认 ["\\[", "\\]"]
|
||
MathDisplayDelimiters []string `json:"math_display_delimiters"`
|
||
// 高级表格处理,默认false
|
||
EnableTablesFallback bool `json:"enable_tables_fallback"`
|
||
// 全角标点,null表示自动判断
|
||
FullwidthPunctuation *bool `json:"fullwidth_punctuation,omitempty"`
|
||
}
|
||
|
||
// MathpixCallback 回调配置
|
||
type MathpixCallback struct {
|
||
URL string `json:"url"`
|
||
Headers map[string]string `json:"headers"`
|
||
}
|
||
|
||
// MathpixDataOptions 数据选项
|
||
type MathpixDataOptions struct {
|
||
IncludeAsciimath bool `json:"include_asciimath"`
|
||
IncludeMathml bool `json:"include_mathml"`
|
||
IncludeLatex bool `json:"include_latex"`
|
||
IncludeTsv bool `json:"include_tsv"`
|
||
}
|
||
|
||
// MathpixAlphabetsAllowed 允许的字母表
|
||
type MathpixAlphabetsAllowed struct {
|
||
En bool `json:"en"`
|
||
Hi bool `json:"hi"`
|
||
Zh bool `json:"zh"`
|
||
Ja bool `json:"ja"`
|
||
Ko bool `json:"ko"`
|
||
Ru bool `json:"ru"`
|
||
Th bool `json:"th"`
|
||
Vi bool `json:"vi"`
|
||
}
|
||
|
||
// MathpixRegion 图片区域
|
||
type MathpixRegion struct {
|
||
TopLeftX int `json:"top_left_x"`
|
||
TopLeftY int `json:"top_left_y"`
|
||
Width int `json:"width"`
|
||
Height int `json:"height"`
|
||
}
|
||
|
||
// MathpixResponse Mathpix API /v3/text 完整响应结构
|
||
type MathpixResponse struct {
|
||
// 请求ID,用于调试
|
||
RequestID string `json:"request_id"`
|
||
// Mathpix Markdown 格式文本
|
||
Text string `json:"text"`
|
||
// 带样式的LaTeX(仅单个公式图片时返回)
|
||
LatexStyled string `json:"latex_styled"`
|
||
// 置信度 [0,1]
|
||
Confidence float64 `json:"confidence"`
|
||
// 置信度比率 [0,1]
|
||
ConfidenceRate float64 `json:"confidence_rate"`
|
||
// 行数据
|
||
LineData []map[string]interface{} `json:"line_data"`
|
||
// 词数据
|
||
WordData []map[string]interface{} `json:"word_data"`
|
||
// 数据对象列表
|
||
Data []MathpixDataItem `json:"data"`
|
||
// HTML输出
|
||
HTML string `json:"html"`
|
||
// 检测到的字母表
|
||
DetectedAlphabets []map[string]interface{} `json:"detected_alphabets"`
|
||
// 是否打印内容
|
||
IsPrinted bool `json:"is_printed"`
|
||
// 是否手写内容
|
||
IsHandwritten bool `json:"is_handwritten"`
|
||
// 自动旋转置信度
|
||
AutoRotateConfidence float64 `json:"auto_rotate_confidence"`
|
||
// 几何数据
|
||
GeometryData []map[string]interface{} `json:"geometry_data"`
|
||
// 自动旋转角度 {0, 90, -90, 180}
|
||
AutoRotateDegrees int `json:"auto_rotate_degrees"`
|
||
// 图片宽度
|
||
ImageWidth int `json:"image_width"`
|
||
// 图片高度
|
||
ImageHeight int `json:"image_height"`
|
||
// 错误信息
|
||
Error string `json:"error"`
|
||
// 错误详情
|
||
ErrorInfo *MathpixErrorInfo `json:"error_info"`
|
||
// API版本
|
||
Version string `json:"version"`
|
||
}
|
||
|
||
// MathpixDataItem 数据项
|
||
type MathpixDataItem struct {
|
||
Type string `json:"type"`
|
||
Value string `json:"value"`
|
||
}
|
||
|
||
// MathpixErrorInfo 错误详情
|
||
type MathpixErrorInfo struct {
|
||
ID string `json:"id"`
|
||
Message string `json:"message"`
|
||
}
|
||
|
||
// BaiduOCRRequest 百度 OCR 版面分析请求结构
|
||
type BaiduOCRRequest struct {
|
||
// 文件内容 base64 编码
|
||
File string `json:"file"`
|
||
// 文件类型: 0=PDF, 1=图片
|
||
FileType int `json:"fileType"`
|
||
// 是否启用文档方向分类
|
||
UseDocOrientationClassify bool `json:"useDocOrientationClassify"`
|
||
// 是否启用文档扭曲矫正
|
||
UseDocUnwarping bool `json:"useDocUnwarping"`
|
||
// 是否启用图表识别
|
||
UseChartRecognition bool `json:"useChartRecognition"`
|
||
}
|
||
|
||
// BaiduOCRResponse 百度 OCR 版面分析响应结构
|
||
type BaiduOCRResponse struct {
|
||
ErrorCode int `json:"errorCode"`
|
||
ErrorMsg string `json:"errorMsg"`
|
||
Result *BaiduOCRResult `json:"result"`
|
||
}
|
||
|
||
// BaiduOCRResult 百度 OCR 响应结果
|
||
type BaiduOCRResult struct {
|
||
LayoutParsingResults []BaiduLayoutParsingResult `json:"layoutParsingResults"`
|
||
}
|
||
|
||
// BaiduLayoutParsingResult 单页版面解析结果
|
||
type BaiduLayoutParsingResult struct {
|
||
Markdown BaiduMarkdownResult `json:"markdown"`
|
||
OutputImages map[string]string `json:"outputImages"`
|
||
}
|
||
|
||
// BaiduMarkdownResult markdown 结果
|
||
type BaiduMarkdownResult struct {
|
||
Text string `json:"text"`
|
||
Images map[string]string `json:"images"`
|
||
}
|
||
|
||
// GetMathML 从响应中获取MathML
|
||
func (r *MathpixResponse) GetMathML() string {
|
||
for _, item := range r.Data {
|
||
if item.Type == "mathml" {
|
||
return item.Value
|
||
}
|
||
}
|
||
return ""
|
||
}
|
||
|
||
// GetAsciiMath 从响应中获取AsciiMath
|
||
func (r *MathpixResponse) GetAsciiMath() string {
|
||
for _, item := range r.Data {
|
||
if item.Type == "asciimath" {
|
||
return item.Value
|
||
}
|
||
}
|
||
return ""
|
||
}
|
||
|
||
func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int64, fileURL string) (err error) {
|
||
// 为整个任务处理添加超时控制
|
||
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
|
||
defer cancel()
|
||
|
||
tx := dao.DB.Begin()
|
||
var (
|
||
taskDao = dao.NewRecognitionTaskDao()
|
||
resultDao = dao.NewRecognitionResultDao()
|
||
)
|
||
|
||
isSuccess := false
|
||
defer func() {
|
||
if !isSuccess {
|
||
tx.Rollback()
|
||
status := dao.TaskStatusFailed
|
||
remark := "任务处理失败"
|
||
if ctx.Err() == context.DeadlineExceeded {
|
||
remark = "任务处理超时"
|
||
}
|
||
if err != nil {
|
||
remark = err.Error()
|
||
}
|
||
_ = taskDao.Update(dao.DB.WithContext(context.Background()), // 使用新的context,避免已取消的context影响状态更新
|
||
map[string]interface{}{"id": taskID},
|
||
map[string]interface{}{
|
||
"status": status,
|
||
"completed_at": time.Now(),
|
||
"remark": remark,
|
||
})
|
||
return
|
||
}
|
||
_ = taskDao.Update(tx, map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusCompleted, "completed_at": time.Now()})
|
||
tx.Commit()
|
||
}()
|
||
|
||
err = taskDao.Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusProcessing})
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processFormulaTask", "msg", "更新任务状态失败", "error", err)
|
||
}
|
||
|
||
// 下载图片文件
|
||
reader, err := oss.DownloadFile(ctx, fileURL)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processFormulaTask", "msg", "下载图片文件失败", "error", err)
|
||
return err
|
||
}
|
||
defer reader.Close()
|
||
|
||
// 读取图片数据
|
||
imageData, err := io.ReadAll(reader)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processFormulaTask", "msg", "读取图片数据失败", "error", err)
|
||
return err
|
||
}
|
||
|
||
// 将图片转为base64编码
|
||
base64Image := base64.StdEncoding.EncodeToString(imageData)
|
||
|
||
// 创建JSON请求
|
||
requestData := map[string]string{
|
||
"image_base64": base64Image,
|
||
}
|
||
|
||
jsonData, err := json.Marshal(requestData)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processFormulaTask", "msg", "JSON编码失败", "error", err)
|
||
return err
|
||
}
|
||
|
||
// 设置Content-Type头为application/json
|
||
headers := map[string]string{"Content-Type": "application/json", utils.RequestIDHeaderKey: utils.GetRequestIDFromContext(ctx)}
|
||
|
||
// 发送请求时会使用带超时的context
|
||
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "https://cloud.texpixel.com:10443/vlm/formula/predict", bytes.NewReader(jsonData), headers)
|
||
if err != nil {
|
||
if ctx.Err() == context.DeadlineExceeded {
|
||
log.Error(ctx, "func", "processFormulaTask", "msg", "请求超时")
|
||
return fmt.Errorf("request timeout")
|
||
}
|
||
log.Error(ctx, "func", "processFormulaTask", "msg", "请求失败", "error", err)
|
||
return err
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
log.Info(ctx, "func", "processFormulaTask", "msg", "请求成功")
|
||
body := &bytes.Buffer{}
|
||
if _, err = body.ReadFrom(resp.Body); err != nil {
|
||
log.Error(ctx, "func", "processFormulaTask", "msg", "读取响应体失败", "error", err)
|
||
return err
|
||
}
|
||
log.Info(ctx, "func", "processFormulaTask", "msg", "响应内容", "body", body.String())
|
||
|
||
// 解析 JSON 响应
|
||
var formulaResp formula.FormulaRecognitionResponse
|
||
if err := json.Unmarshal(body.Bytes(), &formulaResp); err != nil {
|
||
log.Error(ctx, "func", "processFormulaTask", "msg", "解析响应JSON失败", "error", err)
|
||
return err
|
||
}
|
||
err = resultDao.Create(tx, dao.RecognitionResult{TaskID: taskID, TaskType: dao.TaskTypeFormula, Latex: formulaResp.Result})
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processFormulaTask", "msg", "保存任务结果失败", "error", err)
|
||
return err
|
||
}
|
||
|
||
isSuccess = true
|
||
return nil
|
||
}
|
||
|
||
func (s *RecognitionService) processVLFormulaTask(ctx context.Context, taskID int64, fileURL string, model string) error {
|
||
isSuccess := false
|
||
defer func() {
|
||
if !isSuccess {
|
||
err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusFailed})
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务状态失败", "error", err)
|
||
}
|
||
return
|
||
}
|
||
err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusCompleted})
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务状态失败", "error", err)
|
||
}
|
||
}()
|
||
|
||
reader, err := oss.DownloadFile(ctx, fileURL)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processVLFormulaTask", "msg", "获取签名URL失败", "error", err)
|
||
return err
|
||
}
|
||
defer reader.Close()
|
||
imageData, err := io.ReadAll(reader)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processVLFormulaTask", "msg", "读取图片数据失败", "error", err)
|
||
return err
|
||
}
|
||
prompt := `Please perform OCR on the image and output only LaTeX code.`
|
||
base64Image := base64.StdEncoding.EncodeToString(imageData)
|
||
|
||
requestBody := formula.VLFormulaRequest{
|
||
Model: model,
|
||
Stream: false,
|
||
MaxTokens: 512,
|
||
Temperature: 0.1,
|
||
TopP: 0.1,
|
||
TopK: 50,
|
||
FrequencyPenalty: 0.2,
|
||
N: 1,
|
||
Messages: []formula.Message{
|
||
{
|
||
Role: "user",
|
||
Content: []formula.Content{
|
||
{
|
||
Type: "text",
|
||
Text: prompt,
|
||
},
|
||
{
|
||
Type: "image_url",
|
||
ImageURL: formula.Image{
|
||
Detail: "auto",
|
||
URL: "data:image/jpeg;base64," + base64Image,
|
||
},
|
||
},
|
||
},
|
||
},
|
||
},
|
||
}
|
||
|
||
// 将请求体转换为JSON
|
||
jsonData, err := json.Marshal(requestBody)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processVLFormulaTask", "msg", "JSON编码失败", "error", err)
|
||
return err
|
||
}
|
||
|
||
headers := map[string]string{
|
||
"Content-Type": "application/json",
|
||
"Authorization": utils.SiliconFlowToken,
|
||
}
|
||
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "https://api.siliconflow.cn/v1/chat/completions", bytes.NewReader(jsonData), headers)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processVLFormulaTask", "msg", "请求VL服务失败", "error", err)
|
||
return err
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 解析响应
|
||
var response formula.VLFormulaResponse
|
||
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
||
log.Error(ctx, "func", "processVLFormulaTask", "msg", "解析响应失败", "error", err)
|
||
return err
|
||
}
|
||
|
||
// 提取LaTeX代码
|
||
var latex string
|
||
if len(response.Choices) > 0 {
|
||
if response.Choices[0].Message.Content != "" {
|
||
latex = strings.ReplaceAll(response.Choices[0].Message.Content, "\n", "")
|
||
latex = strings.ReplaceAll(latex, "```latex", "")
|
||
latex = strings.ReplaceAll(latex, "```", "")
|
||
// 规范化LaTeX代码,移除不必要的空格
|
||
latex = strings.ReplaceAll(latex, " = ", "=")
|
||
latex = strings.ReplaceAll(latex, "\\left[ ", "\\left[")
|
||
latex = strings.TrimPrefix(latex, "\\[")
|
||
latex = strings.TrimSuffix(latex, "\\]")
|
||
}
|
||
}
|
||
|
||
resultDao := dao.NewRecognitionResultDao()
|
||
result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processVLFormulaTask", "msg", "获取任务结果失败", "error", err)
|
||
return err
|
||
}
|
||
if result == nil {
|
||
formulaRes := &dao.RecognitionResult{TaskID: taskID, TaskType: dao.TaskTypeFormula, Latex: latex}
|
||
err = resultDao.Create(dao.DB.WithContext(ctx), *formulaRes)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processVLFormulaTask", "msg", "创建任务结果失败", "error", err)
|
||
return err
|
||
}
|
||
} else {
|
||
result.Latex = latex
|
||
err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{"latex": latex})
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务结果失败", "error", err)
|
||
return err
|
||
}
|
||
}
|
||
|
||
isSuccess = true
|
||
return nil
|
||
}
|
||
|
||
func (s *RecognitionService) processFormulaQueue(ctx context.Context) {
|
||
for {
|
||
select {
|
||
case <-s.stopChan:
|
||
return
|
||
default:
|
||
s.processOneTask(ctx)
|
||
}
|
||
}
|
||
}
|
||
|
||
func (s *RecognitionService) processOneTask(ctx context.Context) {
|
||
// 限制队列数量
|
||
s.queueLimit <- struct{}{}
|
||
defer func() { <-s.queueLimit }()
|
||
|
||
taskID, err := cache.PopFormulaTask(ctx)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processFormulaQueue", "msg", "获取任务失败", "error", err)
|
||
return
|
||
}
|
||
|
||
task, err := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processFormulaQueue", "msg", "获取任务失败", "error", err)
|
||
return
|
||
}
|
||
if task == nil {
|
||
log.Error(ctx, "func", "processFormulaQueue", "msg", "任务不存在", "task_id", taskID)
|
||
return
|
||
}
|
||
|
||
ctx = context.WithValue(ctx, utils.RequestIDKey, task.TaskUUID)
|
||
log.Info(ctx, "func", "processFormulaQueue", "msg", "获取任务成功", "task_id", taskID)
|
||
|
||
err = s.processBaiduOCRTask(ctx, taskID, task.FileURL)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processFormulaQueue", "msg", "处理任务失败", "error", err)
|
||
return
|
||
}
|
||
|
||
log.Info(ctx, "func", "processFormulaQueue", "msg", "处理任务成功", "task_id", taskID)
|
||
}
|
||
|
||
// processMathpixTask 使用 Mathpix API 处理公式识别任务(用于增强识别)
|
||
func (s *RecognitionService) processMathpixTask(ctx context.Context, taskID int64, fileURL string) error {
|
||
isSuccess := false
|
||
logDao := dao.NewRecognitionLogDao()
|
||
|
||
defer func() {
|
||
if !isSuccess {
|
||
err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusFailed})
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processMathpixTask", "msg", "更新任务状态失败", "error", err)
|
||
}
|
||
return
|
||
}
|
||
err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusCompleted})
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processMathpixTask", "msg", "更新任务状态失败", "error", err)
|
||
}
|
||
}()
|
||
|
||
// 下载图片
|
||
imageUrl, err := oss.GetDownloadURL(ctx, fileURL)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processMathpixTask", "msg", "获取图片URL失败", "error", err)
|
||
return err
|
||
}
|
||
|
||
// 创建 Mathpix API 请求
|
||
mathpixReq := MathpixRequest{
|
||
Src: imageUrl,
|
||
Formats: []string{
|
||
"text",
|
||
"latex_styled",
|
||
"data",
|
||
"html",
|
||
},
|
||
DataOptions: &MathpixDataOptions{
|
||
IncludeMathml: true,
|
||
IncludeAsciimath: true,
|
||
IncludeLatex: true,
|
||
IncludeTsv: true,
|
||
},
|
||
MathInlineDelimiters: []string{"$", "$"},
|
||
MathDisplayDelimiters: []string{"$$", "$$"},
|
||
RmSpaces: &[]bool{true}[0],
|
||
}
|
||
|
||
jsonData, err := json.Marshal(mathpixReq)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processMathpixTask", "msg", "JSON编码失败", "error", err)
|
||
return err
|
||
}
|
||
|
||
headers := map[string]string{
|
||
"Content-Type": "application/json",
|
||
"app_id": config.GlobalConfig.Mathpix.AppID,
|
||
"app_key": config.GlobalConfig.Mathpix.AppKey,
|
||
}
|
||
|
||
endpoint := "https://api.mathpix.com/v3/text"
|
||
|
||
startTime := time.Now()
|
||
|
||
log.Info(ctx, "func", "processMathpixTask", "msg", "MathpixApi_Start", "start_time", startTime)
|
||
|
||
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData), headers)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processMathpixTask", "msg", "Mathpix API 请求失败", "error", err)
|
||
return err
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
log.Info(ctx, "func", "processMathpixTask", "msg", "MathpixApi_End", "end_time", time.Now(), "duration", time.Since(startTime))
|
||
|
||
body := &bytes.Buffer{}
|
||
if _, err = body.ReadFrom(resp.Body); err != nil {
|
||
log.Error(ctx, "func", "processMathpixTask", "msg", "读取响应体失败", "error", err)
|
||
return err
|
||
}
|
||
|
||
// 创建日志记录
|
||
recognitionLog := &dao.RecognitionLog{
|
||
TaskID: taskID,
|
||
Provider: dao.ProviderMathpix,
|
||
RequestBody: string(jsonData),
|
||
ResponseBody: body.String(),
|
||
}
|
||
|
||
// 解析响应
|
||
var mathpixResp MathpixResponse
|
||
if err := json.Unmarshal(body.Bytes(), &mathpixResp); err != nil {
|
||
log.Error(ctx, "func", "processMathpixTask", "msg", "解析响应失败", "error", err)
|
||
return err
|
||
}
|
||
|
||
// 检查错误
|
||
if mathpixResp.Error != "" {
|
||
errMsg := mathpixResp.Error
|
||
if mathpixResp.ErrorInfo != nil {
|
||
errMsg = fmt.Sprintf("%s: %s", mathpixResp.ErrorInfo.ID, mathpixResp.ErrorInfo.Message)
|
||
}
|
||
log.Error(ctx, "func", "processMathpixTask", "msg", "Mathpix API 返回错误", "error", errMsg)
|
||
return fmt.Errorf("mathpix error: %s", errMsg)
|
||
}
|
||
|
||
// 保存日志
|
||
err = logDao.Create(dao.DB.WithContext(ctx), recognitionLog)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processMathpixTask", "msg", "保存日志失败", "error", err)
|
||
}
|
||
|
||
// 更新或创建识别结果
|
||
resultDao := dao.NewRecognitionResultDao()
|
||
result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processMathpixTask", "msg", "获取任务结果失败", "error", err)
|
||
return err
|
||
}
|
||
|
||
log.Info(ctx, "func", "processMathpixTask", "msg", "saveLog", "end_time", time.Now(), "duration", time.Since(startTime))
|
||
|
||
if result == nil {
|
||
// 创建新结果
|
||
err = resultDao.Create(dao.DB.WithContext(ctx), dao.RecognitionResult{
|
||
TaskID: taskID,
|
||
TaskType: dao.TaskTypeFormula,
|
||
Latex: mathpixResp.LatexStyled,
|
||
Markdown: mathpixResp.Text,
|
||
MathML: mathpixResp.GetMathML(),
|
||
})
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processMathpixTask", "msg", "创建任务结果失败", "error", err)
|
||
return err
|
||
}
|
||
} else {
|
||
// 更新现有结果
|
||
err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{
|
||
"latex": mathpixResp.LatexStyled,
|
||
"markdown": mathpixResp.Text,
|
||
"mathml": mathpixResp.GetMathML(),
|
||
})
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processMathpixTask", "msg", "更新任务结果失败", "error", err)
|
||
return err
|
||
}
|
||
}
|
||
|
||
isSuccess = true
|
||
return nil
|
||
}
|
||
|
||
func (s *RecognitionService) processBaiduOCRTask(ctx context.Context, taskID int64, fileURL string) error {
|
||
isSuccess := false
|
||
logDao := dao.NewRecognitionLogDao()
|
||
|
||
defer func() {
|
||
if !isSuccess {
|
||
err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusFailed})
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "更新任务状态失败", "error", err)
|
||
}
|
||
return
|
||
}
|
||
err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusCompleted})
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "更新任务状态失败", "error", err)
|
||
}
|
||
}()
|
||
|
||
// 从 OSS 下载文件
|
||
reader, err := oss.DownloadFile(ctx, fileURL)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "从OSS下载文件失败", "error", err)
|
||
return err
|
||
}
|
||
defer reader.Close()
|
||
|
||
// 读取文件内容
|
||
fileBytes, err := io.ReadAll(reader)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "读取文件内容失败", "error", err)
|
||
return err
|
||
}
|
||
|
||
// Base64 编码
|
||
fileData := base64.StdEncoding.EncodeToString(fileBytes)
|
||
|
||
// 根据文件扩展名确定 fileType: 0=PDF, 1=图片
|
||
fileType := 1 // 默认为图片
|
||
lowerFileURL := strings.ToLower(fileURL)
|
||
if strings.HasSuffix(lowerFileURL, ".pdf") {
|
||
fileType = 0
|
||
}
|
||
|
||
// 创建百度 OCR API 请求
|
||
baiduReq := BaiduOCRRequest{
|
||
File: fileData,
|
||
FileType: fileType,
|
||
UseDocOrientationClassify: false,
|
||
UseDocUnwarping: false,
|
||
UseChartRecognition: false,
|
||
}
|
||
|
||
jsonData, err := json.Marshal(baiduReq)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "JSON编码失败", "error", err)
|
||
return err
|
||
}
|
||
|
||
headers := map[string]string{
|
||
"Content-Type": "application/json",
|
||
"Authorization": fmt.Sprintf("token %s", config.GlobalConfig.BaiduOCR.Token),
|
||
}
|
||
|
||
endpoint := "https://j5veh2l2r6ubk6cb.aistudio-app.com/layout-parsing"
|
||
|
||
startTime := time.Now()
|
||
|
||
log.Info(ctx, "func", "processBaiduOCRTask", "msg", "BaiduOCRApi_Start", "start_time", startTime)
|
||
|
||
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData), headers)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "百度 OCR API 请求失败", "error", err)
|
||
return err
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
log.Info(ctx, "func", "processBaiduOCRTask", "msg", "BaiduOCRApi_End", "end_time", time.Now(), "duration", time.Since(startTime))
|
||
|
||
body := &bytes.Buffer{}
|
||
if _, err = body.ReadFrom(resp.Body); err != nil {
|
||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "读取响应体失败", "error", err)
|
||
return err
|
||
}
|
||
|
||
// 创建日志记录(不记录请求体中的 base64 数据以节省存储)
|
||
requestLogData := map[string]interface{}{
|
||
"fileType": fileType,
|
||
"useDocOrientationClassify": false,
|
||
"useDocUnwarping": false,
|
||
"useChartRecognition": false,
|
||
"fileSize": len(fileBytes),
|
||
}
|
||
requestLogBytes, _ := json.Marshal(requestLogData)
|
||
recognitionLog := &dao.RecognitionLog{
|
||
TaskID: taskID,
|
||
Provider: dao.ProviderBaiduOCR,
|
||
RequestBody: string(requestLogBytes),
|
||
ResponseBody: body.String(),
|
||
}
|
||
|
||
// 解析响应
|
||
var baiduResp BaiduOCRResponse
|
||
if err := json.Unmarshal(body.Bytes(), &baiduResp); err != nil {
|
||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "解析响应失败", "error", err)
|
||
return err
|
||
}
|
||
|
||
// 检查错误
|
||
if baiduResp.ErrorCode != 0 {
|
||
errMsg := fmt.Sprintf("errorCode: %d, errorMsg: %s", baiduResp.ErrorCode, baiduResp.ErrorMsg)
|
||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "百度 OCR API 返回错误", "error", errMsg)
|
||
return fmt.Errorf("baidu ocr error: %s", errMsg)
|
||
}
|
||
|
||
// 保存日志
|
||
err = logDao.Create(dao.DB.WithContext(ctx), recognitionLog)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "保存日志失败", "error", err)
|
||
}
|
||
|
||
// 合并所有页面的 markdown 结果
|
||
var markdownTexts []string
|
||
if baiduResp.Result != nil && len(baiduResp.Result.LayoutParsingResults) > 0 {
|
||
for _, res := range baiduResp.Result.LayoutParsingResults {
|
||
if res.Markdown.Text != "" {
|
||
markdownTexts = append(markdownTexts, res.Markdown.Text)
|
||
}
|
||
}
|
||
}
|
||
markdownResult := strings.Join(markdownTexts, "\n\n---\n\n")
|
||
|
||
// 更新或创建识别结果
|
||
resultDao := dao.NewRecognitionResultDao()
|
||
result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "获取任务结果失败", "error", err)
|
||
return err
|
||
}
|
||
|
||
log.Info(ctx, "func", "processBaiduOCRTask", "msg", "saveLog", "end_time", time.Now(), "duration", time.Since(startTime))
|
||
|
||
if result == nil {
|
||
// 创建新结果
|
||
err = resultDao.Create(dao.DB.WithContext(ctx), dao.RecognitionResult{
|
||
TaskID: taskID,
|
||
TaskType: dao.TaskTypeFormula,
|
||
Markdown: markdownResult,
|
||
})
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "创建任务结果失败", "error", err)
|
||
return err
|
||
}
|
||
} else {
|
||
// 更新现有结果
|
||
err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{
|
||
"markdown": markdownResult,
|
||
})
|
||
if err != nil {
|
||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "更新任务结果失败", "error", err)
|
||
return err
|
||
}
|
||
}
|
||
|
||
isSuccess = true
|
||
return nil
|
||
}
|
||
|
||
func (s *RecognitionService) TestProcessMathpixTask(ctx context.Context, taskID int64) error {
|
||
task, err := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
|
||
if err != nil {
|
||
log.Error(ctx, "func", "TestProcessMathpixTask", "msg", "获取任务失败", "error", err)
|
||
return err
|
||
}
|
||
if task == nil {
|
||
log.Error(ctx, "func", "TestProcessMathpixTask", "msg", "任务不存在", "task_id", taskID)
|
||
return err
|
||
}
|
||
return s.processMathpixTask(ctx, taskID, task.FileURL)
|
||
}
|