Files
doc_ai_backed/internal/service/pdf_recognition_service.go
yoge 9d712c921a feat: add PDF document recognition with 10-page pre-hook
- Migrate recognition_results table to JSON schema (meta_data + content),
  replacing flat latex/markdown/mathml/mml columns
- Add TaskTypePDF constant and update all formula read/write paths
- Add PDFRecognitionService using pdftoppm (Poppler) for CGO-free page
  rendering; limits processing to first 10 pages (pre-hook)
- Reuse existing downstream OCR endpoint (cloud.texpixel.com) for each
  page image; stores results as [{page_number, markdown}] JSON array
- Add Redis queue + distributed lock for PDF worker goroutine
- Add REST endpoints: POST /v1/pdf/recognition, GET /v1/pdf/recognition/:task_no
- Add .pdf to OSS upload file type whitelist
- Add migrations/pdf_recognition.sql for safe data migration

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-31 14:17:44 +08:00

344 lines
9.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"sort"
	"time"

	"gitea.com/texpixel/document_ai/internal/model/formula"
	pdfmodel "gitea.com/texpixel/document_ai/internal/model/pdf"
	"gitea.com/texpixel/document_ai/internal/storage/cache"
	"gitea.com/texpixel/document_ai/internal/storage/dao"
	"gitea.com/texpixel/document_ai/pkg/common"
	"gitea.com/texpixel/document_ai/pkg/httpclient"
	"gitea.com/texpixel/document_ai/pkg/log"
	"gitea.com/texpixel/document_ai/pkg/oss"
	"gitea.com/texpixel/document_ai/pkg/requestid"
	"gitea.com/texpixel/document_ai/pkg/utils"
	"gorm.io/gorm"
)
// pdfMaxPages is the page limit that processPDFTask hands to renderPDFPages.
const pdfMaxPages = 10

// pdfOCREndpoint is the URL that callOCR posts each rendered page image to.
const pdfOCREndpoint = "https://cloud.texpixel.com:10443/doc_process/v1/image/ocr"
// PDFRecognitionService handles PDF recognition tasks.
type PDFRecognitionService struct {
	// db is the database handle; NewPDFRecognitionService sets it to dao.DB.
	db *gorm.DB
	// queueLimit is a buffered channel that processOnePDFTask uses as a
	// semaphore; stopChan is closed by Stop to end the queue loop.
	queueLimit, stopChan chan struct{}
	// httpClient sends the OCR requests issued by callOCR.
	httpClient *httpclient.Client
}
// NewPDFRecognitionService builds a PDFRecognitionService and hands its queue
// consumer to utils.SafeGo (presumably run on its own goroutine — confirm).
//
// The consumer first asks cache.GetPDFDistributedLock for the PDF lock. If the
// call fails or reports that the lock was not obtained, the consumer logs and
// exits, so this instance never consumes the queue. Otherwise it runs
// processPDFQueue until Stop is called.
func NewPDFRecognitionService() *PDFRecognitionService {
	s := &PDFRecognitionService{
		db:         dao.DB,
		queueLimit: make(chan struct{}, 3),
		stopChan:   make(chan struct{}),
		httpClient: httpclient.NewClient(nil),
	}

	utils.SafeGo(func() {
		ctx := context.Background()
		lock, err := cache.GetPDFDistributedLock(ctx)
		if err != nil || !lock {
			// Record both the error and the lock result so that "the lock call
			// failed" can be told apart from "the lock was not obtained".
			log.Error(ctx, "func", "NewPDFRecognitionService", "msg", "获取PDF分布式锁失败",
				"error", err, "locked", lock)
			return
		}
		s.processPDFQueue(ctx)
	})
	return s
}
// CreatePDFTask creates a pending PDF recognition task from req and pushes its
// ID onto the PDF queue.
//
// It returns the stored task on success. A database failure is reported as a
// common.CodeDBError; a queue failure as a common.CodeSystemError. When the
// queue push fails after the row was created, the row is marked failed on a
// best-effort basis so it does not stay pending without a queue entry.
func (s *PDFRecognitionService) CreatePDFTask(ctx context.Context, req *pdfmodel.CreatePDFRecognitionRequest) (*dao.RecognitionTask, error) {
	task := &dao.RecognitionTask{
		UserID:   req.UserID,
		TaskUUID: utils.NewUUID(),
		TaskType: dao.TaskTypePDF,
		Status:   dao.TaskStatusPending,
		FileURL:  req.FileURL,
		FileName: req.FileName,
		FileHash: req.FileHash,
		IP:       common.GetIPFromContext(ctx),
	}

	taskDao := dao.NewRecognitionTaskDao()
	if err := taskDao.Create(dao.DB.WithContext(ctx), task); err != nil {
		log.Error(ctx, "func", "CreatePDFTask", "msg", "创建任务失败", "error", err)
		return nil, common.NewError(common.CodeDBError, "创建任务失败", err)
	}

	if _, err := cache.PushPDFTask(ctx, task.ID); err != nil {
		log.Error(ctx, "func", "CreatePDFTask", "msg", "推入队列失败", "error", err)

		// The row exists but was never queued. Use a background context so the
		// clean-up still runs if the request context is already cancelled.
		updErr := taskDao.Update(dao.DB.WithContext(context.Background()),
			map[string]interface{}{"id": task.ID},
			map[string]interface{}{
				"status":       dao.TaskStatusFailed,
				"completed_at": time.Now(),
				"remark":       "推入队列失败",
			},
		)
		if updErr != nil {
			log.Error(ctx, "func", "CreatePDFTask", "msg", "标记任务失败状态出错",
				"task_id", task.ID, "error", updErr)
		}
		return nil, common.NewError(common.CodeSystemError, "推入队列失败", err)
	}
	return task, nil
}
// GetPDFTask returns the status of the task identified by taskNo and, once the
// task has completed, its per-page recognition results.
//
// Errors (all built with common.NewError):
//   - CodeNotFound when no task matches taskNo or the task is not a PDF task;
//   - CodeDBError when loading the task or its result fails (a nil result is
//     reported the same way);
//   - CodeSystemError when the stored content cannot be decoded.
func (s *PDFRecognitionService) GetPDFTask(ctx context.Context, taskNo string) (*pdfmodel.GetPDFTaskResponse, error) {
	sess := dao.DB.WithContext(ctx)

	task, err := dao.NewRecognitionTaskDao().GetByTaskNo(sess, taskNo)
	if err != nil {
		// errors.Is also matches a not-found error that was wrapped on the way up.
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, common.NewError(common.CodeNotFound, "任务不存在", err)
		}
		return nil, common.NewError(common.CodeDBError, "查询任务失败", err)
	}

	// Type check: a non-PDF task (e.g. a formula task) must not be parsed as a
	// PDF result, so it is reported as not found.
	if task.TaskType != dao.TaskTypePDF {
		return nil, common.NewError(common.CodeNotFound, "任务不存在", nil)
	}

	resp := &pdfmodel.GetPDFTaskResponse{
		TaskNo: taskNo,
		Status: int(task.Status),
	}
	// Only completed tasks carry page results; otherwise return status alone.
	if task.Status != dao.TaskStatusCompleted {
		return resp, nil
	}

	result, err := dao.NewRecognitionResultDao().GetByTaskID(sess, task.ID)
	if err != nil || result == nil {
		return nil, common.NewError(common.CodeDBError, "查询识别结果失败", err)
	}

	pages, err := result.GetPDFContent()
	if err != nil {
		return nil, common.NewError(common.CodeSystemError, "解析识别结果失败", err)
	}

	resp.TotalPages = len(pages)
	for _, p := range pages {
		resp.Pages = append(resp.Pages, pdfmodel.PDFPageResult{
			PageNumber: p.PageNumber,
			Markdown:   p.Markdown,
		})
	}
	return resp, nil
}
// processPDFQueue consumes the PDF queue in a loop, handling one task per
// iteration, and returns once stopChan has been closed.
func (s *PDFRecognitionService) processPDFQueue(ctx context.Context) {
	for {
		// Non-blocking stop check before each unit of work.
		select {
		case <-s.stopChan:
			return
		default:
		}
		s.processOnePDFTask(ctx)
	}
}
// processOnePDFTask pops one task ID from the PDF queue, loads the task and
// runs processPDFTask for it, tagging the context and the request-id scope
// with the task UUID.
//
// A queueLimit slot is held for the duration of the call. Failures are logged
// rather than returned. After a failed pop the call waits briefly (or until
// Stop) before returning, so the calling loop does not spin and flood the log
// while the queue backend is failing.
func (s *PDFRecognitionService) processOnePDFTask(ctx context.Context) {
	const popRetryDelay = time.Second

	s.queueLimit <- struct{}{}
	defer func() { <-s.queueLimit }()

	taskID, err := cache.PopPDFTask(ctx)
	if err != nil {
		log.Error(ctx, "func", "processOnePDFTask", "msg", "获取任务失败", "error", err)
		// Back off before the caller retries; wake early when Stop is called.
		select {
		case <-s.stopChan:
		case <-time.After(popRetryDelay):
		}
		return
	}

	task, err := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
	if err != nil || task == nil {
		// Include err so a lookup failure can be told apart from a missing row.
		log.Error(ctx, "func", "processOnePDFTask", "msg", "任务不存在", "task_id", taskID, "error", err)
		return
	}

	ctx = context.WithValue(ctx, utils.RequestIDKey, task.TaskUUID)
	requestid.SetRequestID(task.TaskUUID, func() {
		if err := s.processPDFTask(ctx, taskID, task.FileURL); err != nil {
			log.Error(ctx, "func", "processOnePDFTask", "msg", "处理PDF任务失败", "error", err)
		}
	})
}
// processPDFTask is the core pipeline for one PDF task:
// download -> pre-hook (render the first pdfMaxPages pages) -> OCR page by
// page -> store a single result row.
//
// The whole run is bounded by a 10-minute timeout. On return the task row is
// always updated to completed or failed; that final update uses a background
// context so it is still attempted after the timeout has fired. The first
// failing step aborts the run and its error is returned wrapped.
func (s *PDFRecognitionService) processPDFTask(ctx context.Context, taskID int64, fileURL string) error {
	ctx, cancel := context.WithTimeout(ctx, 10*time.Minute)
	defer cancel()

	taskDao := dao.NewRecognitionTaskDao()
	resultDao := dao.NewRecognitionResultDao()

	isSuccess := false
	defer func() {
		status, remark := dao.TaskStatusFailed, "任务处理失败"
		if isSuccess {
			status, remark = dao.TaskStatusCompleted, ""
		}
		// A failure here leaves the task in a stale state, so log it instead
		// of discarding it.
		if err := taskDao.Update(dao.DB.WithContext(context.Background()),
			map[string]interface{}{"id": taskID},
			map[string]interface{}{"status": status, "completed_at": time.Now(), "remark": remark},
		); err != nil {
			log.Error(ctx, "func", "processPDFTask", "msg", "更新任务最终状态失败",
				"task_id", taskID, "error", err)
		}
	}()

	// Mark the task as processing.
	if err := taskDao.Update(dao.DB.WithContext(ctx),
		map[string]interface{}{"id": taskID},
		map[string]interface{}{"status": dao.TaskStatusProcessing},
	); err != nil {
		return fmt.Errorf("更新任务状态失败: %w", err)
	}

	// Download the PDF.
	reader, err := oss.DownloadFile(ctx, fileURL)
	if err != nil {
		return fmt.Errorf("下载PDF失败: %w", err)
	}
	defer reader.Close()

	pdfBytes, err := io.ReadAll(reader)
	if err != nil {
		return fmt.Errorf("读取PDF数据失败: %w", err)
	}

	// Pre-hook: render at most pdfMaxPages pages to PNG with pdftoppm.
	pageImages, err := renderPDFPages(ctx, pdfBytes, pdfMaxPages)
	if err != nil {
		return fmt.Errorf("渲染PDF页面失败: %w", err)
	}
	processPages := len(pageImages)
	log.Info(ctx, "func", "processPDFTask", "msg", "开始处理PDF",
		"task_id", taskID, "process_pages", processPages)

	// OCR each page in order and collect the results.
	pages := make([]dao.PDFPageContent, 0, processPages)
	for i, imgBytes := range pageImages {
		ocrResult, err := s.callOCR(ctx, imgBytes)
		if err != nil {
			return fmt.Errorf("OCR第%d页失败: %w", i+1, err)
		}
		pages = append(pages, dao.PDFPageContent{
			PageNumber: i + 1,
			Markdown:   ocrResult.Markdown,
		})
		log.Info(ctx, "func", "processPDFTask", "msg", "页面OCR完成",
			"page", i+1, "total", processPages)
	}

	// Serialize all pages and write them to the DB as a single row.
	contentJSON, err := dao.MarshalPDFContent(pages)
	if err != nil {
		return fmt.Errorf("序列化PDF内容失败: %w", err)
	}
	dbResult := dao.RecognitionResult{
		TaskID:   taskID,
		TaskType: dao.TaskTypePDF,
		Content:  contentJSON,
	}
	if err := dbResult.SetMetaData(dao.ResultMetaData{TotalNum: processPages}); err != nil {
		return fmt.Errorf("序列化MetaData失败: %w", err)
	}
	if err := resultDao.Create(dao.DB.WithContext(ctx), dbResult); err != nil {
		return fmt.Errorf("保存PDF结果失败: %w", err)
	}

	isSuccess = true
	return nil
}
// renderPDFPages renders a PDF to PNG images with the pdftoppm command and
// returns the image bytes, one element per page in page order, for at most
// maxPages pages.
//
// The PDF is written to a temporary directory that is removed before the
// function returns. ctx is passed to exec.CommandContext, so cancelling it
// kills the pdftoppm process.
//
// An error is returned when maxPages is less than 1, pdfBytes is empty, a
// temporary file cannot be created or read, pdftoppm fails (its combined
// output is included in the error), or no page image was produced.
func renderPDFPages(ctx context.Context, pdfBytes []byte, maxPages int) ([][]byte, error) {
	// A non-positive "-l" value would not express a page limit.
	// NOTE(review): pdftoppm appears to treat it as "render every page" — confirm.
	if maxPages < 1 {
		return nil, fmt.Errorf("maxPages必须大于0: %d", maxPages)
	}
	if len(pdfBytes) == 0 {
		return nil, fmt.Errorf("PDF内容为空")
	}

	tmpDir, err := os.MkdirTemp("", "pdf-ocr-*")
	if err != nil {
		return nil, fmt.Errorf("创建临时目录失败: %w", err)
	}
	defer os.RemoveAll(tmpDir)

	pdfPath := filepath.Join(tmpDir, "input.pdf")
	if err := os.WriteFile(pdfPath, pdfBytes, 0600); err != nil {
		return nil, fmt.Errorf("写入临时PDF失败: %w", err)
	}

	// pdftoppm writes its output as <outPrefix>-<page>.png.
	outPrefix := filepath.Join(tmpDir, "page")
	cmd := exec.CommandContext(ctx, "pdftoppm",
		"-r", "150",
		"-png",
		"-l", fmt.Sprintf("%d", maxPages),
		pdfPath,
		outPrefix,
	)
	if out, err := cmd.CombinedOutput(); err != nil {
		return nil, fmt.Errorf("pdftoppm失败: %w, output: %s", err, string(out))
	}

	files, err := filepath.Glob(filepath.Join(tmpDir, "page-*.png"))
	if err != nil {
		return nil, fmt.Errorf("查找渲染输出文件失败: %w", err)
	}
	if len(files) == 0 {
		return nil, fmt.Errorf("pdftoppm未输出任何页面")
	}

	// File-name order is used as page order.
	// NOTE(review): this relies on pdftoppm zero-padding page numbers to a
	// uniform width within one run — confirm.
	sort.Strings(files)
	// Enforce the limit here too, whatever the tool produced.
	if len(files) > maxPages {
		files = files[:maxPages]
	}

	pages := make([][]byte, 0, len(files))
	for _, f := range files {
		data, err := os.ReadFile(f)
		if err != nil {
			return nil, fmt.Errorf("读取页面图片失败: %w", err)
		}
		pages = append(pages, data)
	}
	return pages, nil
}
// callOCR sends one page image to the downstream OCR endpoint (the same one
// used for formula recognition) and returns the decoded response.
//
// The image is base64-encoded into a JSON body under "image_base64", and the
// request ID found in ctx is forwarded as a header. Any status other than
// 200 OK is an error, so that an error payload is never stored as a
// recognition result.
func (s *PDFRecognitionService) callOCR(ctx context.Context, imgBytes []byte) (*formula.ImageOCRResponse, error) {
	// Upper bound on how much of an error response is copied into the error.
	const maxErrorBodyBytes = 4 << 10

	reqBody := map[string]string{
		"image_base64": base64.StdEncoding.EncodeToString(imgBytes),
	}
	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		return nil, err
	}

	headers := map[string]string{
		"Content-Type":           "application/json",
		utils.RequestIDHeaderKey: utils.GetRequestIDFromContext(ctx),
	}
	resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, pdfOCREndpoint, bytes.NewReader(jsonData), headers)
	if err != nil {
		return nil, fmt.Errorf("请求OCR接口失败: %w", err)
	}
	defer resp.Body.Close()

	// A non-200 downstream response counts as failure. The body is read only
	// for diagnostics, so cap it; a read error just shortens the message.
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
		return nil, fmt.Errorf("OCR接口返回非200状态: %d, body: %s", resp.StatusCode, string(body))
	}

	var ocrResp formula.ImageOCRResponse
	if err := json.NewDecoder(resp.Body).Decode(&ocrResp); err != nil {
		return nil, fmt.Errorf("解析OCR响应失败: %w", err)
	}
	return &ocrResp, nil
}
// Stop signals the queue loop to exit by closing stopChan.
//
// Calling Stop again after it has returned is a no-op instead of a
// close-of-closed-channel panic. The check and the close are not atomic, so
// Stop must not be called from several goroutines at the same time.
func (s *PDFRecognitionService) Stop() {
	select {
	case <-s.stopChan:
		// Already closed by an earlier call.
	default:
		close(s.stopChan)
	}
}