init repo

This commit is contained in:
liuyuanchuang
2025-12-10 18:33:37 +08:00
commit 48e63894eb
2408 changed files with 1053045 additions and 0 deletions

View File

@@ -0,0 +1,69 @@
package formula
type CreateFormulaRecognitionRequest struct {
FileURL string `json:"file_url" binding:"required"` // oss file url
FileHash string `json:"file_hash" binding:"required"` // file hash
FileName string `json:"file_name" binding:"required"` // file name
TaskType string `json:"task_type" binding:"required,oneof=FORMULA"` // task type
}
type GetRecognitionStatusRequest struct {
TaskNo string `uri:"task_no" binding:"required"`
}
type AIEnhanceRecognitionRequest struct {
TaskNo string `json:"task_no" binding:"required"`
}
type VLFormulaRequest struct {
Model string `json:"model"`
Stream bool `json:"stream"`
MaxTokens int `json:"max_tokens"`
Temperature float64 `json:"temperature"`
TopP float64 `json:"top_p"`
TopK int `json:"top_k"`
N int `json:"n"`
FrequencyPenalty float64 `json:"frequency_penalty"`
Messages []Message `json:"messages"`
}
type Message struct {
Role string `json:"role"`
Content []Content `json:"content"`
}
type Content struct {
Text string `json:"text"`
Type string `json:"type"`
ImageURL Image `json:"image_url"`
}
type Image struct {
Detail string `json:"detail"`
URL string `json:"url"`
}
type VLFormulaResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []Choice `json:"choices"`
Usage Usage `json:"usage"`
SystemFingerprint string `json:"system_fingerprint"`
}
type Choice struct {
Index int `json:"index"`
Message struct {
Content string `json:"content"`
Role string `json:"role"`
} `json:"message"`
FinishReason string `json:"finish_reason"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}

View File

@@ -0,0 +1,13 @@
package formula
type CreateTaskResponse struct {
TaskNo string `json:"task_no"`
Status int `json:"status"`
}
type GetFormulaTaskResponse struct {
TaskNo string `json:"task_no"`
Status int `json:"status"`
Count int `json:"count"`
Latex string `json:"latex"`
}

View File

@@ -0,0 +1,36 @@
package task
type EvaluateTaskRequest struct {
TaskNo string `json:"task_no" binding:"required"` // 任务编号
Satisfied int `json:"satisfied"` // 0: 不满意, 1: 满意
Suggestion []string `json:"suggestion"` // 建议 1. 公式无法渲染 2. 公式渲染错误
Feedback string `json:"feedback"` // 反馈
}
type TaskListRequest struct {
TaskType string `json:"task_type" form:"task_type" binding:"required"`
Page int `json:"page" form:"page"`
PageSize int `json:"page_size" form:"page_size"`
}
type PdfInfo struct {
PageCount int `json:"page_count"`
PageWidth int `json:"page_width"`
PageHeight int `json:"page_height"`
}
type TaskListDTO struct {
TaskID string `json:"task_id"`
FileName string `json:"file_name"`
Status string `json:"status"`
Path string `json:"path"`
TaskType string `json:"task_type"`
CreatedAt string `json:"created_at"`
PdfInfo PdfInfo `json:"pdf_info"`
}
type TaskListResponse struct {
TaskList []*TaskListDTO `json:"task_list"`
HasMore bool `json:"has_more"`
NextPage int `json:"next_page"`
}

View File

@@ -0,0 +1,24 @@
package model
type SmsSendRequest struct {
Phone string `json:"phone" binding:"required"`
}
type SmsSendResponse struct {
Code string `json:"code"`
}
type PhoneLoginRequest struct {
Phone string `json:"phone" binding:"required"`
Code string `json:"code" binding:"required"`
}
type PhoneLoginResponse struct {
Token string `json:"token"`
}
type UserInfoResponse struct {
Username string `json:"username"`
Phone string `json:"phone"`
Status int `json:"status"` // 0: not login, 1: login
}

View File

@@ -0,0 +1,540 @@
package service
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"gitea.com/bitwsd/core/common/log"
"gitea.com/bitwsd/document_ai/config"
"gitea.com/bitwsd/document_ai/internal/model/formula"
"gitea.com/bitwsd/document_ai/internal/storage/cache"
"gitea.com/bitwsd/document_ai/internal/storage/dao"
"gitea.com/bitwsd/document_ai/pkg/common"
"gitea.com/bitwsd/document_ai/pkg/constant"
"gitea.com/bitwsd/document_ai/pkg/httpclient"
"gitea.com/bitwsd/document_ai/pkg/oss"
"gitea.com/bitwsd/document_ai/pkg/utils"
"gorm.io/gorm"
)
type RecognitionService struct {
db *gorm.DB
queueLimit chan struct{}
stopChan chan struct{}
httpClient *httpclient.Client
}
func NewRecognitionService() *RecognitionService {
s := &RecognitionService{
db: dao.DB,
queueLimit: make(chan struct{}, config.GlobalConfig.Limit.FormulaRecognition),
stopChan: make(chan struct{}),
httpClient: httpclient.NewClient(nil), // 使用默认配置
}
// 服务启动时就开始处理队列
utils.SafeGo(func() {
lock, err := cache.GetDistributedLock(context.Background())
if err != nil {
log.Error(context.Background(), "func", "NewRecognitionService", "msg", "获取分布式锁失败", "error", err)
return
}
if !lock {
log.Error(context.Background(), "func", "NewRecognitionService", "msg", "获取分布式锁失败")
return
}
s.processFormulaQueue(context.Background())
})
return s
}
func (s *RecognitionService) AIEnhanceRecognition(ctx context.Context, req *formula.AIEnhanceRecognitionRequest) (*dao.RecognitionTask, error) {
count, err := cache.GetVLMFormulaCount(ctx, common.GetIPFromContext(ctx))
if err != nil {
log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "获取VLM公式识别次数失败", "error", err)
return nil, common.NewError(common.CodeSystemError, "系统错误", err)
}
if count >= constant.VLMFormulaCount {
return nil, common.NewError(common.CodeForbidden, "今日VLM公式识别次数已达上限,请明天再试!", nil)
}
taskDao := dao.NewRecognitionTaskDao()
task, err := taskDao.GetByTaskNo(dao.DB.WithContext(ctx), req.TaskNo)
if err != nil {
log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "获取任务失败", "error", err)
return nil, common.NewError(common.CodeDBError, "获取任务失败", err)
}
if task == nil {
log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "任务不存在", "task_no", req.TaskNo)
return nil, common.NewError(common.CodeNotFound, "任务不存在", err)
}
if task.Status == dao.TaskStatusProcessing || task.Status == dao.TaskStatusPending {
log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "任务未完成", "task_no", req.TaskNo)
return nil, common.NewError(common.CodeInvalidStatus, "任务未完成", err)
}
err = taskDao.Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": task.ID}, map[string]interface{}{"status": dao.TaskStatusProcessing})
if err != nil {
log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "更新任务状态失败", "error", err)
return nil, common.NewError(common.CodeDBError, "更新任务状态失败", err)
}
utils.SafeGo(func() {
s.processVLFormula(context.Background(), task.ID)
_, err := cache.IncrVLMFormulaCount(context.Background(), common.GetIPFromContext(ctx))
if err != nil {
log.Error(ctx, "func", "AIEnhanceRecognition", "msg", "增加VLM公式识别次数失败", "error", err)
}
})
return task, nil
}
func (s *RecognitionService) CreateRecognitionTask(ctx context.Context, req *formula.CreateFormulaRecognitionRequest) (*dao.RecognitionTask, error) {
sess := dao.DB.WithContext(ctx)
taskDao := dao.NewRecognitionTaskDao()
task := &dao.RecognitionTask{
TaskUUID: utils.NewUUID(),
TaskType: dao.TaskType(req.TaskType),
Status: dao.TaskStatusPending,
FileURL: req.FileURL,
FileName: req.FileName,
FileHash: req.FileHash,
IP: common.GetIPFromContext(ctx),
}
if err := taskDao.Create(sess, task); err != nil {
log.Error(ctx, "func", "CreateRecognitionTask", "msg", "创建任务失败", "error", err)
return nil, common.NewError(common.CodeDBError, "创建任务失败", err)
}
err := s.handleFormulaRecognition(ctx, task.ID)
if err != nil {
log.Error(ctx, "func", "CreateRecognitionTask", "msg", "处理任务失败", "error", err)
return nil, common.NewError(common.CodeSystemError, "处理任务失败", err)
}
return task, nil
}
func (s *RecognitionService) GetFormualTask(ctx context.Context, taskNo string) (*formula.GetFormulaTaskResponse, error) {
taskDao := dao.NewRecognitionTaskDao()
resultDao := dao.NewRecognitionResultDao()
sess := dao.DB.WithContext(ctx)
count, err := cache.GetFormulaTaskCount(ctx)
if err != nil {
log.Error(ctx, "func", "GetFormualTask", "msg", "获取任务数量失败", "error", err)
}
if count > int64(config.GlobalConfig.Limit.FormulaRecognition) {
return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Status: int(dao.TaskStatusPending), Count: int(count)}, nil
}
task, err := taskDao.GetByTaskNo(sess, taskNo)
if err != nil {
if err == gorm.ErrRecordNotFound {
log.Info(ctx, "func", "GetFormualTask", "msg", "任务不存在", "task_no", taskNo)
return nil, common.NewError(common.CodeNotFound, "任务不存在", err)
}
log.Error(ctx, "func", "GetFormualTask", "msg", "查询任务失败", "error", err, "task_no", taskNo)
return nil, common.NewError(common.CodeDBError, "查询任务失败", err)
}
if task.Status != dao.TaskStatusCompleted {
return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Status: int(task.Status)}, nil
}
taskRet, err := resultDao.GetByTaskID(sess, task.ID)
if err != nil {
if err == gorm.ErrRecordNotFound {
log.Info(ctx, "func", "GetFormualTask", "msg", "任务结果不存在", "task_no", taskNo)
return nil, common.NewError(common.CodeNotFound, "任务结果不存在", err)
}
log.Error(ctx, "func", "GetFormualTask", "msg", "查询任务结果失败", "error", err, "task_no", taskNo)
return nil, common.NewError(common.CodeDBError, "查询任务结果失败", err)
}
latex := taskRet.NewContentCodec().GetContent().(string)
return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Latex: latex, Status: int(task.Status)}, nil
}
func (s *RecognitionService) handleFormulaRecognition(ctx context.Context, taskID int64) error {
// 简化为只负责将任务加入队列
_, err := cache.PushFormulaTask(ctx, taskID)
if err != nil {
log.Error(ctx, "func", "handleFormulaRecognition", "msg", "增加任务计数失败", "error", err)
return err
}
log.Info(ctx, "func", "handleFormulaRecognition", "msg", "增加任务计数成功", "task_id", taskID)
return nil
}
// Stop 用于优雅关闭服务
func (s *RecognitionService) Stop() {
close(s.stopChan)
}
func (s *RecognitionService) processVLFormula(ctx context.Context, taskID int64) {
task, err := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
if err != nil {
log.Error(ctx, "func", "processVLFormulaQueue", "msg", "获取任务失败", "error", err)
return
}
if task == nil {
log.Error(ctx, "func", "processVLFormulaQueue", "msg", "任务不存在", "task_id", taskID)
return
}
ctx = context.WithValue(ctx, utils.RequestIDKey, task.TaskUUID)
log.Info(ctx, "func", "processVLFormulaQueue", "msg", "获取任务成功", "task_id", taskID)
// 处理具体任务
if err := s.processVLFormulaTask(ctx, taskID, task.FileURL); err != nil {
log.Error(ctx, "func", "processVLFormulaQueue", "msg", "处理任务失败", "error", err)
return
}
log.Info(ctx, "func", "processVLFormulaQueue", "msg", "处理任务成功", "task_id", taskID)
}
func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int64, fileURL string) (err error) {
// 为整个任务处理添加超时控制
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
defer cancel()
tx := dao.DB.Begin()
var (
taskDao = dao.NewRecognitionTaskDao()
resultDao = dao.NewRecognitionResultDao()
)
isSuccess := false
defer func() {
if !isSuccess {
tx.Rollback()
status := dao.TaskStatusFailed
remark := "任务处理失败"
if ctx.Err() == context.DeadlineExceeded {
remark = "任务处理超时"
}
if err != nil {
remark = err.Error()
}
_ = taskDao.Update(dao.DB.WithContext(context.Background()), // 使用新的context避免已取消的context影响状态更新
map[string]interface{}{"id": taskID},
map[string]interface{}{
"status": status,
"completed_at": time.Now(),
"remark": remark,
})
return
}
_ = taskDao.Update(tx, map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusCompleted, "completed_at": time.Now()})
tx.Commit()
}()
err = taskDao.Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusProcessing})
if err != nil {
log.Error(ctx, "func", "processFormulaTask", "msg", "更新任务状态失败", "error", err)
}
// 下载图片文件
reader, err := oss.DownloadFile(ctx, fileURL)
if err != nil {
log.Error(ctx, "func", "processFormulaTask", "msg", "下载图片文件失败", "error", err)
return err
}
defer reader.Close()
// 读取图片数据
imageData, err := io.ReadAll(reader)
if err != nil {
log.Error(ctx, "func", "processFormulaTask", "msg", "读取图片数据失败", "error", err)
return err
}
downloadURL, err := oss.GetDownloadURL(ctx, fileURL)
if err != nil {
log.Error(ctx, "func", "processFormulaTask", "msg", "获取下载URL失败", "error", err)
return err
}
// 将图片转为base64编码
base64Image := base64.StdEncoding.EncodeToString(imageData)
// 创建JSON请求
requestData := map[string]string{
"image_base64": base64Image,
"img_url": downloadURL,
}
jsonData, err := json.Marshal(requestData)
if err != nil {
log.Error(ctx, "func", "processFormulaTask", "msg", "JSON编码失败", "error", err)
return err
}
// 设置Content-Type头为application/json
headers := map[string]string{"Content-Type": "application/json", utils.RequestIDHeaderKey: utils.GetRequestIDFromContext(ctx)}
// 发送请求时会使用带超时的context
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, s.getURL(ctx), bytes.NewReader(jsonData), headers)
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
log.Error(ctx, "func", "processFormulaTask", "msg", "请求超时")
return fmt.Errorf("request timeout")
}
log.Error(ctx, "func", "processFormulaTask", "msg", "请求失败", "error", err)
return err
}
defer resp.Body.Close()
log.Info(ctx, "func", "processFormulaTask", "msg", "请求成功", "resp", resp.Body)
body := &bytes.Buffer{}
if _, err = body.ReadFrom(resp.Body); err != nil {
log.Error(ctx, "func", "processFormulaTask", "msg", "读取响应体失败", "error", err)
return err
}
katex := utils.ToKatex(body.String())
content := &dao.FormulaRecognitionContent{Latex: katex}
b, _ := json.Marshal(content)
// Save recognition result
result := &dao.RecognitionResult{
TaskID: taskID,
TaskType: dao.TaskTypeFormula,
Content: b,
}
if err := resultDao.Create(tx, *result); err != nil {
log.Error(ctx, "func", "processFormulaTask", "msg", "保存任务结果失败", "error", err)
return err
}
isSuccess = true
return nil
}
func (s *RecognitionService) processVLFormulaTask(ctx context.Context, taskID int64, fileURL string) error {
isSuccess := false
defer func() {
if !isSuccess {
err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusFailed})
if err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务状态失败", "error", err)
}
return
}
err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusCompleted})
if err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务状态失败", "error", err)
}
}()
reader, err := oss.DownloadFile(ctx, fileURL)
if err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "获取签名URL失败", "error", err)
return err
}
defer reader.Close()
imageData, err := io.ReadAll(reader)
if err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "读取图片数据失败", "error", err)
return err
}
prompt := `
Please perform OCR on the image and output only LaTeX code.
Important instructions:
* "The image contains mathematical formulas, no plain text."
* "Preserve all layout, symbols, subscripts, summations, parentheses, etc., exactly as shown."
* "Use \[ ... \] or align environments to represent multiline math expressions."
* "Use adaptive symbols such as \left and \right where applicable."
* "Do not include any extra commentary, template answers, or unrelated equations."
* "Only output valid LaTeX code based on the actual content of the image, and not change the original mathematical expression."
* "The output result must be can render by better-react-mathjax."
`
base64Image := base64.StdEncoding.EncodeToString(imageData)
requestBody := formula.VLFormulaRequest{
Model: "Qwen/Qwen2.5-VL-32B-Instruct",
Stream: false,
MaxTokens: 512,
Temperature: 0.1,
TopP: 0.1,
TopK: 50,
FrequencyPenalty: 0.2,
N: 1,
Messages: []formula.Message{
{
Role: "user",
Content: []formula.Content{
{
Type: "text",
Text: prompt,
},
{
Type: "image_url",
ImageURL: formula.Image{
Detail: "auto",
URL: "data:image/jpeg;base64," + base64Image,
},
},
},
},
},
}
// 将请求体转换为JSON
jsonData, err := json.Marshal(requestBody)
if err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "JSON编码失败", "error", err)
return err
}
headers := map[string]string{
"Content-Type": "application/json",
"Authorization": utils.SiliconFlowToken,
}
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "https://api.siliconflow.cn/v1/chat/completions", bytes.NewReader(jsonData), headers)
if err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "请求VL服务失败", "error", err)
return err
}
defer resp.Body.Close()
// 解析响应
var response formula.VLFormulaResponse
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "解析响应失败", "error", err)
return err
}
// 提取LaTeX代码
var latex string
if len(response.Choices) > 0 {
if response.Choices[0].Message.Content != "" {
latex = strings.ReplaceAll(response.Choices[0].Message.Content, "\n", "")
latex = strings.ReplaceAll(latex, "```latex", "")
latex = strings.ReplaceAll(latex, "```", "")
// 规范化LaTeX代码移除不必要的空格
latex = strings.ReplaceAll(latex, " = ", "=")
latex = strings.ReplaceAll(latex, "\\left[ ", "\\left[")
latex = strings.TrimPrefix(latex, "\\[")
latex = strings.TrimSuffix(latex, "\\]")
}
}
resultDao := dao.NewRecognitionResultDao()
var formulaRes *dao.FormulaRecognitionContent
result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID)
if err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "获取任务结果失败", "error", err)
return err
}
if result == nil {
formulaRes = &dao.FormulaRecognitionContent{EnhanceLatex: latex}
b, err := formulaRes.Encode()
if err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "编码任务结果失败", "error", err)
return err
}
err = resultDao.Create(dao.DB.WithContext(ctx), dao.RecognitionResult{TaskID: taskID, TaskType: dao.TaskTypeFormula, Content: b})
if err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "创建任务结果失败", "error", err)
return err
}
} else {
formulaRes = result.NewContentCodec().(*dao.FormulaRecognitionContent)
err = formulaRes.Decode()
if err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "解码任务结果失败", "error", err)
return err
}
formulaRes.EnhanceLatex = latex
b, err := formulaRes.Encode()
if err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "编码任务结果失败", "error", err)
return err
}
err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{"content": b})
if err != nil {
log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务结果失败", "error", err)
return err
}
}
isSuccess = true
return nil
}
func (s *RecognitionService) processFormulaQueue(ctx context.Context) {
for {
select {
case <-s.stopChan:
return
default:
s.processOneTask(ctx)
}
}
}
func (s *RecognitionService) processOneTask(ctx context.Context) {
// 限制队列数量
s.queueLimit <- struct{}{}
defer func() { <-s.queueLimit }()
taskID, err := cache.PopFormulaTask(ctx)
if err != nil {
log.Error(ctx, "func", "processFormulaQueue", "msg", "获取任务失败", "error", err)
return
}
task, err := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
if err != nil {
log.Error(ctx, "func", "processFormulaQueue", "msg", "获取任务失败", "error", err)
return
}
if task == nil {
log.Error(ctx, "func", "processFormulaQueue", "msg", "任务不存在", "task_id", taskID)
return
}
ctx = context.WithValue(ctx, utils.RequestIDKey, task.TaskUUID)
log.Info(ctx, "func", "processFormulaQueue", "msg", "获取任务成功", "task_id", taskID)
// 处理具体任务
if err := s.processFormulaTask(ctx, taskID, task.FileURL); err != nil {
log.Error(ctx, "func", "processFormulaQueue", "msg", "处理任务失败", "error", err)
return
}
log.Info(ctx, "func", "processFormulaQueue", "msg", "处理任务成功", "task_id", taskID)
}
func (s *RecognitionService) getURL(ctx context.Context) string {
return "http://cloud.srcstar.com:8045/formula/predict"
count, err := cache.IncrURLCount(ctx)
if err != nil {
log.Error(ctx, "func", "getURL", "msg", "获取URL计数失败", "error", err)
return "http://cloud.srcstar.com:8026/formula/predict"
}
if count%2 == 0 {
return "http://cloud.srcstar.com:8026/formula/predict"
}
return "https://cloud.texpixel.com:1080/formula/predict"
}

78
internal/service/task.go Normal file
View File

@@ -0,0 +1,78 @@
package service
import (
"context"
"errors"
"strings"
"gitea.com/bitwsd/core/common/log"
"gitea.com/bitwsd/document_ai/internal/model/task"
"gitea.com/bitwsd/document_ai/internal/storage/dao"
"gorm.io/gorm"
)
type TaskService struct {
db *gorm.DB
}
func NewTaskService() *TaskService {
return &TaskService{dao.DB}
}
func (svc *TaskService) EvaluateTask(ctx context.Context, req *task.EvaluateTaskRequest) error {
taskDao := dao.NewRecognitionTaskDao()
task, err := taskDao.GetByTaskNo(svc.db.WithContext(ctx), req.TaskNo)
if err != nil {
log.Error(ctx, "func", "EvaluateTask", "msg", "get task by task no failed", "error", err)
return err
}
if task == nil {
log.Error(ctx, "func", "EvaluateTask", "msg", "task not found")
return errors.New("task not found")
}
if task.Status != dao.TaskStatusCompleted {
log.Error(ctx, "func", "EvaluateTask", "msg", "task not finished")
return errors.New("task not finished")
}
evaluateTaskDao := dao.NewEvaluateTaskDao()
evaluateTask := &dao.EvaluateTask{
TaskID: task.ID,
Satisfied: req.Satisfied,
Feedback: req.Feedback,
Comment: strings.Join(req.Suggestion, ","),
}
err = evaluateTaskDao.Create(svc.db.WithContext(ctx), evaluateTask)
if err != nil {
log.Error(ctx, "func", "EvaluateTask", "msg", "create evaluate task failed", "error", err)
return err
}
return nil
}
func (svc *TaskService) GetTaskList(ctx context.Context, req *task.TaskListRequest) (*task.TaskListResponse, error) {
taskDao := dao.NewRecognitionTaskDao()
tasks, err := taskDao.GetTaskList(svc.db.WithContext(ctx), dao.TaskType(req.TaskType), req.Page, req.PageSize)
if err != nil {
log.Error(ctx, "func", "GetTaskList", "msg", "get task list failed", "error", err)
return nil, err
}
resp := &task.TaskListResponse{
TaskList: make([]*task.TaskListDTO, 0, len(tasks)),
HasMore: false,
NextPage: 0,
}
for _, item := range tasks {
resp.TaskList = append(resp.TaskList, &task.TaskListDTO{
TaskID: item.TaskUUID,
FileName: item.FileName,
Status: item.Status.String(),
Path: item.FileURL,
TaskType: item.TaskType.String(),
CreatedAt: item.CreatedAt.Format("2006-01-02 15:04:05"),
})
}
return resp, nil
}

View File

@@ -0,0 +1,109 @@
package service
import (
"context"
"errors"
"fmt"
"math/rand"
"gitea.com/bitwsd/core/common/log"
"gitea.com/bitwsd/document_ai/internal/storage/cache"
"gitea.com/bitwsd/document_ai/internal/storage/dao"
"gitea.com/bitwsd/document_ai/pkg/sms"
)
type UserService struct {
userDao *dao.UserDao
}
func NewUserService() *UserService {
return &UserService{
userDao: dao.NewUserDao(),
}
}
func (svc *UserService) GetSmsCode(ctx context.Context, phone string) (string, error) {
limit, err := cache.GetUserSendSmsLimit(ctx, phone)
if err != nil {
log.Error(ctx, "func", "GetSmsCode", "msg", "get user send sms limit error", "error", err)
return "", err
}
if limit >= cache.UserSendSmsLimitCount {
return "", errors.New("sms code send limit reached")
}
user, err := svc.userDao.GetByPhone(dao.DB.WithContext(ctx), phone)
if err != nil {
log.Error(ctx, "func", "GetSmsCode", "msg", "get user error", "error", err)
return "", err
}
if user == nil {
user = &dao.User{Phone: phone}
err = svc.userDao.Create(dao.DB.WithContext(ctx), user)
if err != nil {
log.Error(ctx, "func", "GetSmsCode", "msg", "create user error", "error", err)
return "", err
}
}
code := fmt.Sprintf("%06d", rand.Intn(1000000))
err = sms.SendMessage(&sms.SendSmsRequest{PhoneNumbers: phone, SignName: sms.Signature, TemplateCode: sms.TemplateCode, TemplateParam: fmt.Sprintf(sms.TemplateParam, code)})
if err != nil {
log.Error(ctx, "func", "GetSmsCode", "msg", "send message error", "error", err)
return "", err
}
cacheErr := cache.SetUserSmsCode(ctx, phone, code)
if cacheErr != nil {
log.Error(ctx, "func", "GetSmsCode", "msg", "set user sms code error", "error", cacheErr)
}
cacheErr = cache.SetUserSendSmsLimit(ctx, phone)
if cacheErr != nil {
log.Error(ctx, "func", "GetSmsCode", "msg", "set user send sms limit error", "error", cacheErr)
}
return code, nil
}
func (svc *UserService) VerifySmsCode(ctx context.Context, phone, code string) (uid int64, err error) {
user, err := svc.userDao.GetByPhone(dao.DB.WithContext(ctx), phone)
if err != nil {
log.Error(ctx, "func", "VerifySmsCode", "msg", "get user error", "error", err, "phone", phone)
return 0, err
}
if user == nil {
log.Error(ctx, "func", "VerifySmsCode", "msg", "user not found", "phone", phone)
return 0, errors.New("user not found")
}
storedCode, err := cache.GetUserSmsCode(ctx, phone)
if err != nil {
log.Error(ctx, "func", "VerifySmsCode", "msg", "get user sms code error", "error", err)
return 0, err
}
if storedCode != code {
log.Error(ctx, "func", "VerifySmsCode", "msg", "invalid sms code", "phone", phone, "code", code, "storedCode", storedCode)
return 0, errors.New("invalid sms code")
}
cacheErr := cache.DeleteUserSmsCode(ctx, phone)
if cacheErr != nil {
log.Error(ctx, "func", "VerifySmsCode", "msg", "delete user sms code error", "error", cacheErr)
}
return user.ID, nil
}
func (svc *UserService) GetUserInfo(ctx context.Context, uid int64) (*dao.User, error) {
user, err := svc.userDao.GetByID(dao.DB.WithContext(ctx), uid)
if err != nil {
log.Error(ctx, "func", "GetUserInfo", "msg", "get user error", "error", err)
return nil, err
}
if user == nil {
log.Warn(ctx, "func", "GetUserInfo", "msg", "user not found", "uid", uid)
return &dao.User{}, nil
}
return user, nil
}

33
internal/storage/cache/engine.go vendored Normal file
View File

@@ -0,0 +1,33 @@
package cache
import (
"context"
"fmt"
"time"
"gitea.com/bitwsd/document_ai/config"
"github.com/redis/go-redis/v9"
)
var RedisClient *redis.Client
func InitRedisClient(config config.RedisConfig) {
fmt.Println("Initializing Redis client...")
RedisClient = redis.NewClient(&redis.Options{
Addr: config.Addr,
Password: config.Password,
DB: config.DB,
DialTimeout: 10 * time.Second,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
})
fmt.Println("Pinging Redis server...")
_, err := RedisClient.Ping(context.Background()).Result()
if err != nil {
fmt.Printf("Init redis client failed, err: %v\n", err)
panic(err)
}
fmt.Println("Redis client initialized successfully.")
}

100
internal/storage/cache/formula.go vendored Normal file
View File

@@ -0,0 +1,100 @@
package cache
import (
"context"
"fmt"
"strconv"
"time"
"github.com/redis/go-redis/v9"
)
const (
FormulaRecognitionTaskCount = "formula_recognition_task"
FormulaRecognitionTaskQueue = "formula_recognition_queue"
FormulaRecognitionDistLock = "formula_recognition_dist_lock"
VLMFormulaCount = "vlm_formula_count:%s" // VLM公式识别次数 ip
VLMRecognitionTaskQueue = "vlm_recognition_queue"
DefaultLockTimeout = 60 * time.Second // 默认锁超时时间
)
// TODO the sigle queue not reliable, message maybe lost
func PushVLMRecognitionTask(ctx context.Context, taskID int64) (count int64, err error) {
count, err = RedisClient.LPush(ctx, VLMRecognitionTaskQueue, taskID).Result()
if err != nil {
return 0, err
}
return count, nil
}
func PopVLMRecognitionTask(ctx context.Context) (int64, error) {
result, err := RedisClient.BRPop(ctx, 0, VLMRecognitionTaskQueue).Result()
if err != nil {
return 0, err
}
return strconv.ParseInt(result[1], 10, 64)
}
func PushFormulaTask(ctx context.Context, taskID int64) (count int64, err error) {
count, err = RedisClient.LPush(ctx, FormulaRecognitionTaskQueue, taskID).Result()
if err != nil {
return 0, err
}
return count, nil
}
func PopFormulaTask(ctx context.Context) (int64, error) {
result, err := RedisClient.BRPop(ctx, 0, FormulaRecognitionTaskQueue).Result()
if err != nil {
return 0, err
}
return strconv.ParseInt(result[1], 10, 64)
}
func GetFormulaTaskCount(ctx context.Context) (int64, error) {
count, err := RedisClient.LLen(ctx, FormulaRecognitionTaskQueue).Result()
if err != nil {
return 0, err
}
return count, nil
}
// GetDistributedLock 获取分布式锁
func GetDistributedLock(ctx context.Context) (bool, error) {
return RedisClient.SetNX(ctx, FormulaRecognitionDistLock, "locked", DefaultLockTimeout).Result()
}
// ReleaseLock 释放分布式锁
func ReleaseLock(ctx context.Context) error {
return RedisClient.Del(ctx, FormulaRecognitionDistLock).Err()
}
func GetVLMFormulaCount(ctx context.Context, ip string) (int64, error) {
count, err := RedisClient.Get(ctx, fmt.Sprintf(VLMFormulaCount, ip)).Result()
if err != nil {
if err == redis.Nil {
return 0, nil
}
return 0, err
}
return strconv.ParseInt(count, 10, 64)
}
func IncrVLMFormulaCount(ctx context.Context, ip string) (int64, error) {
key := fmt.Sprintf(VLMFormulaCount, ip)
count, err := RedisClient.Incr(ctx, key).Result()
if err != nil {
return 0, err
}
if count == 1 {
now := time.Now()
nextMidnight := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, now.Location())
ttl := nextMidnight.Sub(now)
if err := RedisClient.Expire(ctx, key, ttl).Err(); err != nil {
return count, err
}
}
return count, nil
}

12
internal/storage/cache/url.go vendored Normal file
View File

@@ -0,0 +1,12 @@
package cache
import "context"
func IncrURLCount(ctx context.Context) (int64, error) {
key := "formula_recognition:url_count"
count, err := RedisClient.Incr(ctx, key).Result()
if err != nil {
return 0, err
}
return count, nil
}

63
internal/storage/cache/user.go vendored Normal file
View File

@@ -0,0 +1,63 @@
package cache
import (
"context"
"errors"
"fmt"
"strconv"
"time"
"github.com/redis/go-redis/v9"
)
const (
UserSmsCodeTTL = 10 * time.Minute
UserSendSmsLimitTTL = 24 * time.Hour
UserSendSmsLimitCount = 5
)
const (
UserSmsCodePrefix = "user:sms_code:%s"
UserSendSmsLimit = "user:send_sms_limit:%s"
)
func GetUserSmsCode(ctx context.Context, phone string) (string, error) {
code, err := RedisClient.Get(ctx, fmt.Sprintf(UserSmsCodePrefix, phone)).Result()
if err != nil {
if err == redis.Nil {
return "", nil
}
return "", err
}
return code, nil
}
func SetUserSmsCode(ctx context.Context, phone, code string) error {
return RedisClient.Set(ctx, fmt.Sprintf(UserSmsCodePrefix, phone), code, UserSmsCodeTTL).Err()
}
func GetUserSendSmsLimit(ctx context.Context, phone string) (int, error) {
limit, err := RedisClient.Get(ctx, fmt.Sprintf(UserSendSmsLimit, phone)).Result()
if err != nil {
if err == redis.Nil {
return 0, nil
}
return 0, err
}
return strconv.Atoi(limit)
}
func SetUserSendSmsLimit(ctx context.Context, phone string) error {
count, err := RedisClient.Incr(ctx, fmt.Sprintf(UserSendSmsLimit, phone)).Result()
if err != nil {
return err
}
if count > UserSendSmsLimitCount {
return errors.New("send sms limit")
}
return RedisClient.Expire(ctx, fmt.Sprintf(UserSendSmsLimit, phone), UserSendSmsLimitTTL).Err()
}
func DeleteUserSmsCode(ctx context.Context, phone string) error {
return RedisClient.Del(ctx, fmt.Sprintf(UserSmsCodePrefix, phone)).Err()
}

View File

@@ -0,0 +1,11 @@
package dao
import (
"time"
)
type BaseModel struct {
ID int64 `gorm:"bigint;primaryKey;autoIncrement;column:id;comment:主键ID" json:"id"`
CreatedAt time.Time `gorm:"column:created_at;comment:创建时间;not null;default:current_timestamp" json:"created_at"`
UpdatedAt time.Time `gorm:"column:updated_at;comment:更新时间;not null;default:current_timestamp on update current_timestamp" json:"updated_at"`
}

View File

@@ -0,0 +1,31 @@
package dao
import (
"fmt"
"gitea.com/bitwsd/document_ai/config"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
var DB *gorm.DB
func InitDB(conf config.DatabaseConfig) {
dns := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Asia%%2FShanghai", conf.Username, conf.Password, conf.Host, conf.Port, conf.DBName)
db, err := gorm.Open(mysql.Open(dns), &gorm.Config{})
if err != nil {
panic(err)
}
sqlDB, err := db.DB()
if err != nil {
panic(err)
}
sqlDB.SetMaxIdleConns(10)
sqlDB.SetMaxOpenConns(100)
DB = db
}
func CloseDB() {
sqlDB, _ := DB.DB()
sqlDB.Close()
}

View File

@@ -0,0 +1,26 @@
package dao
import "gorm.io/gorm"
type EvaluateTask struct {
BaseModel
TaskID int64 `gorm:"column:task_id;type:int;not null;comment:任务ID"`
Satisfied int `gorm:"column:satisfied;type:int;not null;comment:满意"`
Feedback string `gorm:"column:feedback;type:text;not null;comment:反馈"`
Comment string `gorm:"column:comment;type:text;not null;comment:评论"`
}
func (EvaluateTask) TableName() string {
return "evaluate_tasks"
}
type EvaluateTaskDao struct {
}
func NewEvaluateTaskDao() *EvaluateTaskDao {
return &EvaluateTaskDao{}
}
func (dao *EvaluateTaskDao) Create(sess *gorm.DB, data *EvaluateTask) error {
return sess.Create(data).Error
}

View File

@@ -0,0 +1,89 @@
package dao
import (
"encoding/json"
"gorm.io/gorm"
)
type JSON []byte
// ContentCodec 定义内容编解码接口
type ContentCodec interface {
Encode() (JSON, error)
Decode() error
GetContent() interface{} // 更明确的方法名
}
type FormulaRecognitionContent struct {
content JSON
Latex string `json:"latex"`
AdjustLatex string `json:"adjust_latex"`
EnhanceLatex string `json:"enhance_latex"`
}
func (c *FormulaRecognitionContent) Encode() (JSON, error) {
b, err := json.Marshal(c)
if err != nil {
return nil, err
}
return b, nil
}
func (c *FormulaRecognitionContent) Decode() error {
return json.Unmarshal(c.content, c)
}
// GetPreferredContent 按优先级返回公式内容
func (c *FormulaRecognitionContent) GetContent() interface{} {
c.Decode()
if c.EnhanceLatex != "" {
return c.EnhanceLatex
} else if c.AdjustLatex != "" {
return c.AdjustLatex
} else {
return c.Latex
}
}
type RecognitionResult struct {
BaseModel
TaskID int64 `gorm:"column:task_id;bigint;not null;default:0;comment:任务ID" json:"task_id"`
TaskType TaskType `gorm:"column:task_type;varchar(16);not null;comment:任务类型;default:''" json:"task_type"`
Content JSON `gorm:"column:content;type:json;not null;comment:识别内容" json:"content"`
}
// NewContentCodec 创建对应任务类型的内容编解码器
func (r *RecognitionResult) NewContentCodec() ContentCodec {
switch r.TaskType {
case TaskTypeFormula:
return &FormulaRecognitionContent{content: r.Content}
default:
return nil
}
}
type RecognitionResultDao struct {
}
func NewRecognitionResultDao() *RecognitionResultDao {
return &RecognitionResultDao{}
}
// 模型方法
func (dao *RecognitionResultDao) Create(tx *gorm.DB, data RecognitionResult) error {
return tx.Create(&data).Error
}
func (dao *RecognitionResultDao) GetByTaskID(tx *gorm.DB, taskID int64) (result *RecognitionResult, err error) {
result = &RecognitionResult{}
err = tx.Where("task_id = ?", taskID).First(result).Error
if err != nil && err == gorm.ErrRecordNotFound {
return nil, nil
}
return
}
func (dao *RecognitionResultDao) Update(tx *gorm.DB, id int64, updates map[string]interface{}) error {
return tx.Model(&RecognitionResult{}).Where("id = ?", id).Updates(updates).Error
}

View File

@@ -0,0 +1,94 @@
package dao
import (
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type TaskStatus int
type TaskType string
const (
TaskStatusPending TaskStatus = 0
TaskStatusProcessing TaskStatus = 1
TaskStatusCompleted TaskStatus = 2
TaskStatusFailed TaskStatus = 3
TaskTypeFormula TaskType = "FORMULA"
TaskTypeText TaskType = "TEXT"
TaskTypeTable TaskType = "TABLE"
TaskTypeLayout TaskType = "LAYOUT"
)
func (t TaskType) String() string {
return string(t)
}
func (t TaskStatus) String() string {
return []string{"PENDING", "PROCESSING", "COMPLETED", "FAILED"}[t]
}
type RecognitionTask struct {
BaseModel
UserID int64 `gorm:"column:user_id;not null;default:0;comment:用户ID" json:"user_id"`
TaskUUID string `gorm:"column:task_uuid;varchar(64);not null;default:'';comment:任务唯一标识" json:"task_uuid"`
FileName string `gorm:"column:file_name;varchar(256);not null;default:'';comment:文件名" json:"file_name"`
FileHash string `gorm:"column:file_hash;varchar(64);not null;default:'';comment:文件hash" json:"file_hash"`
FileURL string `gorm:"column:file_url;varchar(128);not null;comment:oss文件地址;default:''" json:"file_url"`
TaskType TaskType `gorm:"column:task_type;varchar(16);not null;comment:任务类型;default:''" json:"task_type"`
Status TaskStatus `gorm:"column:status;tinyint(2);not null;comment:任务状态;default:0" json:"status"`
CompletedAt time.Time `gorm:"column:completed_at;not null;default:current_timestamp;comment:完成时间" json:"completed_at"`
Remark string `gorm:"column:remark;varchar(64);comment:备注;not null;default:''" json:"remark"`
IP string `gorm:"column:ip;varchar(16);comment:IP地址;not null;default:''" json:"ip"`
}
func (t *RecognitionTask) TableName() string {
return "recognition_tasks"
}
type RecognitionTaskDao struct{}
func NewRecognitionTaskDao() *RecognitionTaskDao {
return &RecognitionTaskDao{}
}
// 模型方法
func (dao *RecognitionTaskDao) Create(tx *gorm.DB, data *RecognitionTask) error {
return tx.Create(data).Error
}
func (dao *RecognitionTaskDao) Update(tx *gorm.DB, filter map[string]interface{}, data map[string]interface{}) error {
return tx.Model(RecognitionTask{}).Where(filter).Updates(data).Error
}
func (dao *RecognitionTaskDao) GetByTaskNo(tx *gorm.DB, taskUUID string) (task *RecognitionTask, err error) {
task = &RecognitionTask{}
err = tx.Model(RecognitionTask{}).Where("task_uuid = ?", taskUUID).First(task).Error
return
}
func (dao *RecognitionTaskDao) GetTaskByFileURL(tx *gorm.DB, userID int64, fileHash string) (task *RecognitionTask, err error) {
task = &RecognitionTask{}
err = tx.Model(RecognitionTask{}).Where("user_id = ? AND file_hash = ?", userID, fileHash).First(task).Error
return
}
func (dao *RecognitionTaskDao) GetTaskByID(tx *gorm.DB, id int64) (task *RecognitionTask, err error) {
task = &RecognitionTask{}
err = tx.Model(RecognitionTask{}).Where("id = ?", id).First(task).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, err
}
return task, nil
}
func (dao *RecognitionTaskDao) GetTaskList(tx *gorm.DB, taskType TaskType, page int, pageSize int) (tasks []*RecognitionTask, err error) {
offset := (page - 1) * pageSize
err = tx.Model(RecognitionTask{}).Where("task_type = ?", taskType).Offset(offset).Limit(pageSize).Order(clause.OrderByColumn{Column: clause.Column{Name: "id"}, Desc: true}).Find(&tasks).Error
return
}

View File

@@ -0,0 +1,53 @@
package dao
import (
"errors"
"gorm.io/gorm"
)
type User struct {
BaseModel
Username string `gorm:"column:username" json:"username"`
Phone string `gorm:"column:phone" json:"phone"`
Password string `gorm:"column:password" json:"password"`
WechatOpenID string `gorm:"column:wechat_open_id" json:"wechat_open_id"`
WechatUnionID string `gorm:"column:wechat_union_id" json:"wechat_union_id"`
}
func (u *User) TableName() string {
return "users"
}
type UserDao struct {
}
func NewUserDao() *UserDao {
return &UserDao{}
}
func (dao *UserDao) Create(tx *gorm.DB, user *User) error {
return tx.Create(user).Error
}
func (dao *UserDao) GetByPhone(tx *gorm.DB, phone string) (*User, error) {
var user User
if err := tx.Where("phone = ?", phone).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &user, nil
}
func (dao *UserDao) GetByID(tx *gorm.DB, id int64) (*User, error) {
var user User
if err := tx.Where("id = ?", id).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &user, nil
}