diff --git a/.gitignore b/.gitignore index ac3dd36..1a1d885 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ texpixel /vendor -dev_deploy.sh \ No newline at end of file +dev_deploy.sh +speed_take.sh \ No newline at end of file diff --git a/config/config.go b/config/config.go index e3fb295..fe61de6 100644 --- a/config/config.go +++ b/config/config.go @@ -14,6 +14,11 @@ type Config struct { Limit LimitConfig `mapstructure:"limit"` Aliyun AliyunConfig `mapstructure:"aliyun"` Mathpix MathpixConfig `mapstructure:"mathpix"` + BaiduOCR BaiduOCRConfig `mapstructure:"baidu_ocr"` +} + +type BaiduOCRConfig struct { + Token string `mapstructure:"token"` } type MathpixConfig struct { diff --git a/config/config_dev.yaml b/config/config_dev.yaml index be99b85..ff32c59 100644 --- a/config/config_dev.yaml +++ b/config/config_dev.yaml @@ -48,3 +48,7 @@ aliyun: mathpix: app_id: "ocr_eede6f_ea9b5c" app_key: "fb72d251e33ac85c929bfd4eec40d78368d08d82fb2ee1cffb04a8bb967d1db5" + + +baidu_ocr: + token: "e3a47bd2438f1f38840c203fc5939d17a54482d1" \ No newline at end of file diff --git a/config/config_prod.yaml b/config/config_prod.yaml index dda41bc..8f009b2 100644 --- a/config/config_prod.yaml +++ b/config/config_prod.yaml @@ -47,4 +47,7 @@ aliyun: mathpix: app_id: "ocr_eede6f_ea9b5c" - app_key: "fb72d251e33ac85c929bfd4eec40d78368d08d82fb2ee1cffb04a8bb967d1db5" \ No newline at end of file + app_key: "fb72d251e33ac85c929bfd4eec40d78368d08d82fb2ee1cffb04a8bb967d1db5" + +baidu_ocr: + token: "e3a47bd2438f1f38840c203fc5939d17a54482d1" \ No newline at end of file diff --git a/internal/service/recognition_service.go b/internal/service/recognition_service.go index 9643d45..b6428db 100644 --- a/internal/service/recognition_service.go +++ b/internal/service/recognition_service.go @@ -379,6 +379,44 @@ type MathpixErrorInfo struct { Message string `json:"message"` } +// BaiduOCRRequest 百度 OCR 版面分析请求结构 +type BaiduOCRRequest struct { + // 文件内容 base64 编码 + File string `json:"file"` + // 文件类型: 0=PDF, 1=图片 + FileType int `json:"fileType"` + // 是否启用文档方向分类 + UseDocOrientationClassify bool `json:"useDocOrientationClassify"` + // 是否启用文档扭曲矫正 + UseDocUnwarping bool `json:"useDocUnwarping"` + // 是否启用图表识别 + UseChartRecognition bool `json:"useChartRecognition"` +} + +// BaiduOCRResponse 百度 OCR 版面分析响应结构 +type BaiduOCRResponse struct { + ErrorCode int `json:"errorCode"` + ErrorMsg string `json:"errorMsg"` + Result *BaiduOCRResult `json:"result"` +} + +// BaiduOCRResult 百度 OCR 响应结果 +type BaiduOCRResult struct { + LayoutParsingResults []BaiduLayoutParsingResult `json:"layoutParsingResults"` +} + +// BaiduLayoutParsingResult 单页版面解析结果 +type BaiduLayoutParsingResult struct { + Markdown BaiduMarkdownResult `json:"markdown"` + OutputImages map[string]string `json:"outputImages"` +} + +// BaiduMarkdownResult markdown 结果 +type BaiduMarkdownResult struct { + Text string `json:"text"` + Images map[string]string `json:"images"` +} + // GetMathML 从响应中获取MathML func (r *MathpixResponse) GetMathML() string { for _, item := range r.Data { @@ -608,21 +646,20 @@ func (s *RecognitionService) processVLFormulaTask(ctx context.Context, taskID in } resultDao := dao.NewRecognitionResultDao() - var formulaRes *dao.RecognitionResult result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID) if err != nil { log.Error(ctx, "func", "processVLFormulaTask", "msg", "获取任务结果失败", "error", err) return err } if result == nil { - formulaRes = &dao.RecognitionResult{TaskID: taskID, TaskType: dao.TaskTypeFormula, Latex: latex} + formulaRes := &dao.RecognitionResult{TaskID: taskID, TaskType: dao.TaskTypeFormula, Latex: latex} err = resultDao.Create(dao.DB.WithContext(ctx), *formulaRes) if err != nil { log.Error(ctx, "func", "processVLFormulaTask", "msg", "创建任务结果失败", "error", err) return err } } else { - formulaRes.Latex = latex + result.Latex = latex err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{"latex": latex}) if err != nil { log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务结果失败", "error", err) @@ -669,7 +706,7 @@ func (s *RecognitionService) processOneTask(ctx context.Context) { ctx = context.WithValue(ctx, utils.RequestIDKey, task.TaskUUID) log.Info(ctx, "func", "processFormulaQueue", "msg", "获取任务成功", "task_id", taskID) - err = s.processMathpixTask(ctx, taskID, task.FileURL) + err = s.processBaiduOCRTask(ctx, taskID, task.FileURL) if err != nil { log.Error(ctx, "func", "processFormulaQueue", "msg", "处理任务失败", "error", err) return @@ -738,6 +775,10 @@ func (s *RecognitionService) processMathpixTask(ctx context.Context, taskID int6 endpoint := "https://api.mathpix.com/v3/text" + startTime := time.Now() + + log.Info(ctx, "func", "processMathpixTask", "msg", "MathpixApi_Start", "start_time", startTime) + resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData), headers) if err != nil { log.Error(ctx, "func", "processMathpixTask", "msg", "Mathpix API 请求失败", "error", err) @@ -745,6 +786,8 @@ func (s *RecognitionService) processMathpixTask(ctx context.Context, taskID int6 } defer resp.Body.Close() + log.Info(ctx, "func", "processMathpixTask", "msg", "MathpixApi_End", "end_time", time.Now(), "duration", time.Since(startTime)) + body := &bytes.Buffer{} if _, err = body.ReadFrom(resp.Body); err != nil { log.Error(ctx, "func", "processMathpixTask", "msg", "读取响应体失败", "error", err) @@ -790,6 +833,8 @@ func (s *RecognitionService) processMathpixTask(ctx context.Context, taskID int6 return err } + log.Info(ctx, "func", "processMathpixTask", "msg", "saveLog", "end_time", time.Now(), "duration", time.Since(startTime)) + if result == nil { // 创建新结果 err = resultDao.Create(dao.DB.WithContext(ctx), dao.RecognitionResult{ @@ -820,6 +865,173 @@ func (s *RecognitionService) processMathpixTask(ctx context.Context, taskID int6 return nil } +func (s *RecognitionService) processBaiduOCRTask(ctx context.Context, taskID int64, fileURL string) error { + isSuccess := false + logDao := dao.NewRecognitionLogDao() + + defer func() { + if !isSuccess { + err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusFailed}) + if err != nil { + log.Error(ctx, "func", "processBaiduOCRTask", "msg", "更新任务状态失败", "error", err) + } + return + } + err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusCompleted}) + if err != nil { + log.Error(ctx, "func", "processBaiduOCRTask", "msg", "更新任务状态失败", "error", err) + } + }() + + // 从 OSS 下载文件 + reader, err := oss.DownloadFile(ctx, fileURL) + if err != nil { + log.Error(ctx, "func", "processBaiduOCRTask", "msg", "从OSS下载文件失败", "error", err) + return err + } + defer reader.Close() + + // 读取文件内容 + fileBytes, err := io.ReadAll(reader) + if err != nil { + log.Error(ctx, "func", "processBaiduOCRTask", "msg", "读取文件内容失败", "error", err) + return err + } + + // Base64 编码 + fileData := base64.StdEncoding.EncodeToString(fileBytes) + + // 根据文件扩展名确定 fileType: 0=PDF, 1=图片 + fileType := 1 // 默认为图片 + lowerFileURL := strings.ToLower(fileURL) + if strings.HasSuffix(lowerFileURL, ".pdf") { + fileType = 0 + } + + // 创建百度 OCR API 请求 + baiduReq := BaiduOCRRequest{ + File: fileData, + FileType: fileType, + UseDocOrientationClassify: false, + UseDocUnwarping: false, + UseChartRecognition: false, + } + + jsonData, err := json.Marshal(baiduReq) + if err != nil { + log.Error(ctx, "func", "processBaiduOCRTask", "msg", "JSON编码失败", "error", err) + return err + } + + headers := map[string]string{ + "Content-Type": "application/json", + "Authorization": fmt.Sprintf("token %s", config.GlobalConfig.BaiduOCR.Token), + } + + endpoint := "https://j5veh2l2r6ubk6cb.aistudio-app.com/layout-parsing" + + startTime := time.Now() + + log.Info(ctx, "func", "processBaiduOCRTask", "msg", "BaiduOCRApi_Start", "start_time", startTime) + + resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData), headers) + if err != nil { + log.Error(ctx, "func", "processBaiduOCRTask", "msg", "百度 OCR API 请求失败", "error", err) + return err + } + defer resp.Body.Close() + + log.Info(ctx, "func", "processBaiduOCRTask", "msg", "BaiduOCRApi_End", "end_time", time.Now(), "duration", time.Since(startTime)) + + body := &bytes.Buffer{} + if _, err = body.ReadFrom(resp.Body); err != nil { + log.Error(ctx, "func", "processBaiduOCRTask", "msg", "读取响应体失败", "error", err) + return err + } + + // 创建日志记录(不记录请求体中的 base64 数据以节省存储) + requestLogData := map[string]interface{}{ + "fileType": fileType, + "useDocOrientationClassify": false, + "useDocUnwarping": false, + "useChartRecognition": false, + "fileSize": len(fileBytes), + } + requestLogBytes, _ := json.Marshal(requestLogData) + recognitionLog := &dao.RecognitionLog{ + TaskID: taskID, + Provider: dao.ProviderBaiduOCR, + RequestBody: string(requestLogBytes), + ResponseBody: body.String(), + } + + // 解析响应 + var baiduResp BaiduOCRResponse + if err := json.Unmarshal(body.Bytes(), &baiduResp); err != nil { + log.Error(ctx, "func", "processBaiduOCRTask", "msg", "解析响应失败", "error", err) + return err + } + + // 检查错误 + if baiduResp.ErrorCode != 0 { + errMsg := fmt.Sprintf("errorCode: %d, errorMsg: %s", baiduResp.ErrorCode, baiduResp.ErrorMsg) + log.Error(ctx, "func", "processBaiduOCRTask", "msg", "百度 OCR API 返回错误", "error", errMsg) + return fmt.Errorf("baidu ocr error: %s", errMsg) + } + + // 保存日志 + err = logDao.Create(dao.DB.WithContext(ctx), recognitionLog) + if err != nil { + log.Error(ctx, "func", "processBaiduOCRTask", "msg", "保存日志失败", "error", err) + } + + // 合并所有页面的 markdown 结果 + var markdownTexts []string + if baiduResp.Result != nil && len(baiduResp.Result.LayoutParsingResults) > 0 { + for _, res := range baiduResp.Result.LayoutParsingResults { + if res.Markdown.Text != "" { + markdownTexts = append(markdownTexts, res.Markdown.Text) + } + } + } + markdownResult := strings.Join(markdownTexts, "\n\n---\n\n") + + // 更新或创建识别结果 + resultDao := dao.NewRecognitionResultDao() + result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID) + if err != nil { + log.Error(ctx, "func", "processBaiduOCRTask", "msg", "获取任务结果失败", "error", err) + return err + } + + log.Info(ctx, "func", "processBaiduOCRTask", "msg", "saveLog", "end_time", time.Now(), "duration", time.Since(startTime)) + + if result == nil { + // 创建新结果 + err = resultDao.Create(dao.DB.WithContext(ctx), dao.RecognitionResult{ + TaskID: taskID, + TaskType: dao.TaskTypeFormula, + Markdown: markdownResult, + }) + if err != nil { + log.Error(ctx, "func", "processBaiduOCRTask", "msg", "创建任务结果失败", "error", err) + return err + } + } else { + // 更新现有结果 + err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{ + "markdown": markdownResult, + }) + if err != nil { + log.Error(ctx, "func", "processBaiduOCRTask", "msg", "更新任务结果失败", "error", err) + return err + } + } + + isSuccess = true + return nil +} + func (s *RecognitionService) TestProcessMathpixTask(ctx context.Context, taskID int64) error { task, err := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID) if err != nil { diff --git a/internal/storage/dao/recognition_log.go b/internal/storage/dao/recognition_log.go index 77cc2ce..c40e9f5 100644 --- a/internal/storage/dao/recognition_log.go +++ b/internal/storage/dao/recognition_log.go @@ -11,6 +11,7 @@ const ( ProviderMathpix RecognitionLogProvider = "mathpix" ProviderSiliconflow RecognitionLogProvider = "siliconflow" ProviderTexpixel RecognitionLogProvider = "texpixel" + ProviderBaiduOCR RecognitionLogProvider = "baidu_ocr" ) // RecognitionLog 识别调用日志表,记录第三方API调用请求和响应 diff --git a/pkg/httpclient/client.go b/pkg/httpclient/client.go index d508ed0..d07715f 100644 --- a/pkg/httpclient/client.go +++ b/pkg/httpclient/client.go @@ -23,9 +23,9 @@ type RetryConfig struct { // DefaultRetryConfig 默认重试配置 var DefaultRetryConfig = RetryConfig{ - MaxRetries: 2, + MaxRetries: 1, InitialInterval: 100 * time.Millisecond, - MaxInterval: 5 * time.Second, + MaxInterval: 30 * time.Second, SkipTLSVerify: true, }