feat: add baidu api
This commit is contained in:
@@ -379,6 +379,44 @@ type MathpixErrorInfo struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// BaiduOCRRequest 百度 OCR 版面分析请求结构
|
||||
type BaiduOCRRequest struct {
|
||||
// 文件内容 base64 编码
|
||||
File string `json:"file"`
|
||||
// 文件类型: 0=PDF, 1=图片
|
||||
FileType int `json:"fileType"`
|
||||
// 是否启用文档方向分类
|
||||
UseDocOrientationClassify bool `json:"useDocOrientationClassify"`
|
||||
// 是否启用文档扭曲矫正
|
||||
UseDocUnwarping bool `json:"useDocUnwarping"`
|
||||
// 是否启用图表识别
|
||||
UseChartRecognition bool `json:"useChartRecognition"`
|
||||
}
|
||||
|
||||
// BaiduOCRResponse 百度 OCR 版面分析响应结构
|
||||
type BaiduOCRResponse struct {
|
||||
ErrorCode int `json:"errorCode"`
|
||||
ErrorMsg string `json:"errorMsg"`
|
||||
Result *BaiduOCRResult `json:"result"`
|
||||
}
|
||||
|
||||
// BaiduOCRResult 百度 OCR 响应结果
|
||||
type BaiduOCRResult struct {
|
||||
LayoutParsingResults []BaiduLayoutParsingResult `json:"layoutParsingResults"`
|
||||
}
|
||||
|
||||
// BaiduLayoutParsingResult 单页版面解析结果
|
||||
type BaiduLayoutParsingResult struct {
|
||||
Markdown BaiduMarkdownResult `json:"markdown"`
|
||||
OutputImages map[string]string `json:"outputImages"`
|
||||
}
|
||||
|
||||
// BaiduMarkdownResult markdown 结果
|
||||
type BaiduMarkdownResult struct {
|
||||
Text string `json:"text"`
|
||||
Images map[string]string `json:"images"`
|
||||
}
|
||||
|
||||
// GetMathML 从响应中获取MathML
|
||||
func (r *MathpixResponse) GetMathML() string {
|
||||
for _, item := range r.Data {
|
||||
@@ -608,21 +646,20 @@ func (s *RecognitionService) processVLFormulaTask(ctx context.Context, taskID in
|
||||
}
|
||||
|
||||
resultDao := dao.NewRecognitionResultDao()
|
||||
var formulaRes *dao.RecognitionResult
|
||||
result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID)
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "processVLFormulaTask", "msg", "获取任务结果失败", "error", err)
|
||||
return err
|
||||
}
|
||||
if result == nil {
|
||||
formulaRes = &dao.RecognitionResult{TaskID: taskID, TaskType: dao.TaskTypeFormula, Latex: latex}
|
||||
formulaRes := &dao.RecognitionResult{TaskID: taskID, TaskType: dao.TaskTypeFormula, Latex: latex}
|
||||
err = resultDao.Create(dao.DB.WithContext(ctx), *formulaRes)
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "processVLFormulaTask", "msg", "创建任务结果失败", "error", err)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
formulaRes.Latex = latex
|
||||
result.Latex = latex
|
||||
err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{"latex": latex})
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务结果失败", "error", err)
|
||||
@@ -669,7 +706,7 @@ func (s *RecognitionService) processOneTask(ctx context.Context) {
|
||||
ctx = context.WithValue(ctx, utils.RequestIDKey, task.TaskUUID)
|
||||
log.Info(ctx, "func", "processFormulaQueue", "msg", "获取任务成功", "task_id", taskID)
|
||||
|
||||
err = s.processMathpixTask(ctx, taskID, task.FileURL)
|
||||
err = s.processBaiduOCRTask(ctx, taskID, task.FileURL)
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "processFormulaQueue", "msg", "处理任务失败", "error", err)
|
||||
return
|
||||
@@ -828,6 +865,173 @@ func (s *RecognitionService) processMathpixTask(ctx context.Context, taskID int6
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RecognitionService) processBaiduOCRTask(ctx context.Context, taskID int64, fileURL string) error {
|
||||
isSuccess := false
|
||||
logDao := dao.NewRecognitionLogDao()
|
||||
|
||||
defer func() {
|
||||
if !isSuccess {
|
||||
err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusFailed})
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "更新任务状态失败", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusCompleted})
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "更新任务状态失败", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 从 OSS 下载文件
|
||||
reader, err := oss.DownloadFile(ctx, fileURL)
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "从OSS下载文件失败", "error", err)
|
||||
return err
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// 读取文件内容
|
||||
fileBytes, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "读取文件内容失败", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Base64 编码
|
||||
fileData := base64.StdEncoding.EncodeToString(fileBytes)
|
||||
|
||||
// 根据文件扩展名确定 fileType: 0=PDF, 1=图片
|
||||
fileType := 1 // 默认为图片
|
||||
lowerFileURL := strings.ToLower(fileURL)
|
||||
if strings.HasSuffix(lowerFileURL, ".pdf") {
|
||||
fileType = 0
|
||||
}
|
||||
|
||||
// 创建百度 OCR API 请求
|
||||
baiduReq := BaiduOCRRequest{
|
||||
File: fileData,
|
||||
FileType: fileType,
|
||||
UseDocOrientationClassify: false,
|
||||
UseDocUnwarping: false,
|
||||
UseChartRecognition: false,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(baiduReq)
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "JSON编码失败", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
headers := map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": fmt.Sprintf("token %s", config.GlobalConfig.BaiduOCR.Token),
|
||||
}
|
||||
|
||||
endpoint := "https://j5veh2l2r6ubk6cb.aistudio-app.com/layout-parsing"
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
log.Info(ctx, "func", "processBaiduOCRTask", "msg", "BaiduOCRApi_Start", "start_time", startTime)
|
||||
|
||||
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData), headers)
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "百度 OCR API 请求失败", "error", err)
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
log.Info(ctx, "func", "processBaiduOCRTask", "msg", "BaiduOCRApi_End", "end_time", time.Now(), "duration", time.Since(startTime))
|
||||
|
||||
body := &bytes.Buffer{}
|
||||
if _, err = body.ReadFrom(resp.Body); err != nil {
|
||||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "读取响应体失败", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建日志记录(不记录请求体中的 base64 数据以节省存储)
|
||||
requestLogData := map[string]interface{}{
|
||||
"fileType": fileType,
|
||||
"useDocOrientationClassify": false,
|
||||
"useDocUnwarping": false,
|
||||
"useChartRecognition": false,
|
||||
"fileSize": len(fileBytes),
|
||||
}
|
||||
requestLogBytes, _ := json.Marshal(requestLogData)
|
||||
recognitionLog := &dao.RecognitionLog{
|
||||
TaskID: taskID,
|
||||
Provider: dao.ProviderBaiduOCR,
|
||||
RequestBody: string(requestLogBytes),
|
||||
ResponseBody: body.String(),
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var baiduResp BaiduOCRResponse
|
||||
if err := json.Unmarshal(body.Bytes(), &baiduResp); err != nil {
|
||||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "解析响应失败", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查错误
|
||||
if baiduResp.ErrorCode != 0 {
|
||||
errMsg := fmt.Sprintf("errorCode: %d, errorMsg: %s", baiduResp.ErrorCode, baiduResp.ErrorMsg)
|
||||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "百度 OCR API 返回错误", "error", errMsg)
|
||||
return fmt.Errorf("baidu ocr error: %s", errMsg)
|
||||
}
|
||||
|
||||
// 保存日志
|
||||
err = logDao.Create(dao.DB.WithContext(ctx), recognitionLog)
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "保存日志失败", "error", err)
|
||||
}
|
||||
|
||||
// 合并所有页面的 markdown 结果
|
||||
var markdownTexts []string
|
||||
if baiduResp.Result != nil && len(baiduResp.Result.LayoutParsingResults) > 0 {
|
||||
for _, res := range baiduResp.Result.LayoutParsingResults {
|
||||
if res.Markdown.Text != "" {
|
||||
markdownTexts = append(markdownTexts, res.Markdown.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
markdownResult := strings.Join(markdownTexts, "\n\n---\n\n")
|
||||
|
||||
// 更新或创建识别结果
|
||||
resultDao := dao.NewRecognitionResultDao()
|
||||
result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID)
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "获取任务结果失败", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info(ctx, "func", "processBaiduOCRTask", "msg", "saveLog", "end_time", time.Now(), "duration", time.Since(startTime))
|
||||
|
||||
if result == nil {
|
||||
// 创建新结果
|
||||
err = resultDao.Create(dao.DB.WithContext(ctx), dao.RecognitionResult{
|
||||
TaskID: taskID,
|
||||
TaskType: dao.TaskTypeFormula,
|
||||
Markdown: markdownResult,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "创建任务结果失败", "error", err)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// 更新现有结果
|
||||
err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{
|
||||
"markdown": markdownResult,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "更新任务结果失败", "error", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
isSuccess = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RecognitionService) TestProcessMathpixTask(ctx context.Context, taskID int64) error {
|
||||
task, err := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
|
||||
if err != nil {
|
||||
|
||||
@@ -11,6 +11,7 @@ const (
|
||||
ProviderMathpix RecognitionLogProvider = "mathpix"
|
||||
ProviderSiliconflow RecognitionLogProvider = "siliconflow"
|
||||
ProviderTexpixel RecognitionLogProvider = "texpixel"
|
||||
ProviderBaiduOCR RecognitionLogProvider = "baidu_ocr"
|
||||
)
|
||||
|
||||
// RecognitionLog 识别调用日志表,记录第三方API调用请求和响应
|
||||
|
||||
Reference in New Issue
Block a user