Compare commits
36 Commits
feature/us
...
eabfd83fdf
| Author | SHA1 | Date | |
|---|---|---|---|
| eabfd83fdf | |||
| 97c3617731 | |||
| ece026bea2 | |||
| b9124451d2 | |||
| be1047618e | |||
| 3293f1f8a5 | |||
| ff6795b469 | |||
| cb461f0134 | |||
| 7c4dfaba54 | |||
| 5ee1cea0d7 | |||
| a538bd6680 | |||
| cd221719cf | |||
| d0c0d2cbc3 | |||
| 930d782f18 | |||
| bdd21c4b0f | |||
| 0aaafdbaa3 | |||
| 68a1755a83 | |||
| bb7403f700 | |||
| 3a86f811d0 | |||
| 28295f825b | |||
| e0904f5bfb | |||
| 073808eb30 | |||
| 7be0d705fe | |||
| 770c334083 | |||
| 08d5e37d0e | |||
| 203c2b64c0 | |||
| aa7fb1c7ca | |||
| ae2b58149d | |||
| 9e088879c2 | |||
| be00a91637 | |||
| 4bbbb99634 | |||
| 4bb59ecf7e | |||
| 5a1983f08b | |||
| 8a6da5b627 | |||
| d06f2d9df1 | |||
| b1a3b7cd17 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -6,3 +6,6 @@
|
|||||||
/upload
|
/upload
|
||||||
texpixel
|
texpixel
|
||||||
/vendor
|
/vendor
|
||||||
|
|
||||||
|
dev_deploy.sh
|
||||||
|
speed_take.sh
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
# Build stage
|
# Build stage
|
||||||
FROM crpi-8s2ierii2xan4klg.cn-beijing.personal.cr.aliyuncs.com/texpixel/golang:1.20-apline AS builder
|
FROM crpi-8s2ierii2xan4klg.cn-beijing.personal.cr.aliyuncs.com/texpixel/golang:1.20-alpine AS builder
|
||||||
|
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
@@ -43,4 +42,4 @@ EXPOSE 8024
|
|||||||
ENTRYPOINT ["./doc_ai"]
|
ENTRYPOINT ["./doc_ai"]
|
||||||
|
|
||||||
# Default command (can be overridden)
|
# Default command (can be overridden)
|
||||||
CMD ["-env", "prod"]
|
CMD ["-env", "prod"]
|
||||||
|
|||||||
@@ -5,15 +5,48 @@ import (
|
|||||||
"gitea.com/bitwsd/document_ai/api/v1/oss"
|
"gitea.com/bitwsd/document_ai/api/v1/oss"
|
||||||
"gitea.com/bitwsd/document_ai/api/v1/task"
|
"gitea.com/bitwsd/document_ai/api/v1/task"
|
||||||
"gitea.com/bitwsd/document_ai/api/v1/user"
|
"gitea.com/bitwsd/document_ai/api/v1/user"
|
||||||
|
"gitea.com/bitwsd/document_ai/pkg/common"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetupRouter(engine *gin.RouterGroup) {
|
func SetupRouter(engine *gin.RouterGroup) {
|
||||||
v1 := engine.Group("/v1")
|
v1 := engine.Group("/v1")
|
||||||
{
|
{
|
||||||
formula.SetupRouter(v1)
|
formulaRouter := v1.Group("/formula", common.GetAuthMiddleware())
|
||||||
oss.SetupRouter(v1)
|
{
|
||||||
task.SetupRouter(v1)
|
endpoint := formula.NewFormulaEndpoint()
|
||||||
user.SetupRouter(v1)
|
formulaRouter.POST("/recognition", endpoint.CreateTask)
|
||||||
|
formulaRouter.POST("/ai_enhance", endpoint.AIEnhanceRecognition)
|
||||||
|
formulaRouter.GET("/recognition/:task_no", endpoint.GetTaskStatus)
|
||||||
|
formulaRouter.POST("/test_process_mathpix_task", endpoint.TestProcessMathpixTask)
|
||||||
|
}
|
||||||
|
|
||||||
|
taskRouter := v1.Group("/task", common.GetAuthMiddleware())
|
||||||
|
{
|
||||||
|
endpoint := task.NewTaskEndpoint()
|
||||||
|
taskRouter.POST("/evaluate", endpoint.EvaluateTask)
|
||||||
|
taskRouter.GET("/list", common.MustAuthMiddleware(), endpoint.GetTaskList)
|
||||||
|
taskRouter.POST("/export", endpoint.ExportTask)
|
||||||
|
}
|
||||||
|
|
||||||
|
ossRouter := v1.Group("/oss", common.GetAuthMiddleware())
|
||||||
|
{
|
||||||
|
endpoint := oss.NewOSSEndpoint()
|
||||||
|
ossRouter.POST("/signature", endpoint.GetPostObjectSignature)
|
||||||
|
ossRouter.POST("/signature_url", endpoint.GetSignatureURL)
|
||||||
|
ossRouter.POST("/file/upload", endpoint.UploadFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
userRouter := v1.Group("/user", common.GetAuthMiddleware())
|
||||||
|
{
|
||||||
|
userEndpoint := user.NewUserEndpoint()
|
||||||
|
{
|
||||||
|
userRouter.POST("/sms", userEndpoint.SendVerificationCode)
|
||||||
|
userRouter.POST("/register", userEndpoint.RegisterByEmail)
|
||||||
|
userRouter.POST("/login", userEndpoint.LoginByEmail)
|
||||||
|
userRouter.GET("/info", common.MustAuthMiddleware(), userEndpoint.GetUserInfo)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"gitea.com/bitwsd/document_ai/internal/service"
|
"gitea.com/bitwsd/document_ai/internal/service"
|
||||||
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
||||||
"gitea.com/bitwsd/document_ai/pkg/common"
|
"gitea.com/bitwsd/document_ai/pkg/common"
|
||||||
|
"gitea.com/bitwsd/document_ai/pkg/constant"
|
||||||
"gitea.com/bitwsd/document_ai/pkg/utils"
|
"gitea.com/bitwsd/document_ai/pkg/utils"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -37,11 +38,14 @@ func NewFormulaEndpoint() *FormulaEndpoint {
|
|||||||
// @Router /v1/formula/recognition [post]
|
// @Router /v1/formula/recognition [post]
|
||||||
func (endpoint *FormulaEndpoint) CreateTask(ctx *gin.Context) {
|
func (endpoint *FormulaEndpoint) CreateTask(ctx *gin.Context) {
|
||||||
var req formula.CreateFormulaRecognitionRequest
|
var req formula.CreateFormulaRecognitionRequest
|
||||||
|
uid := ctx.GetInt64(constant.ContextUserID)
|
||||||
if err := ctx.BindJSON(&req); err != nil {
|
if err := ctx.BindJSON(&req); err != nil {
|
||||||
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeParamError, "Invalid parameters"))
|
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeParamError, "Invalid parameters"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
req.UserID = uid
|
||||||
|
|
||||||
if !utils.InArray(req.TaskType, []string{string(dao.TaskTypeFormula), string(dao.TaskTypeFormula)}) {
|
if !utils.InArray(req.TaskType, []string{string(dao.TaskTypeFormula), string(dao.TaskTypeFormula)}) {
|
||||||
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeParamError, "Invalid task type"))
|
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeParamError, "Invalid task type"))
|
||||||
return
|
return
|
||||||
@@ -117,3 +121,20 @@ func (endpoint *FormulaEndpoint) AIEnhanceRecognition(c *gin.Context) {
|
|||||||
|
|
||||||
c.JSON(http.StatusOK, common.SuccessResponse(c, nil))
|
c.JSON(http.StatusOK, common.SuccessResponse(c, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (endpoint *FormulaEndpoint) TestProcessMathpixTask(c *gin.Context) {
|
||||||
|
postData := make(map[string]int)
|
||||||
|
if err := c.BindJSON(&postData); err != nil {
|
||||||
|
c.JSON(http.StatusOK, common.ErrorResponse(c, common.CodeParamError, "Invalid parameters"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
taskID := postData["task_id"]
|
||||||
|
err := endpoint.recognitionService.TestProcessMathpixTask(c, int64(taskID))
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, common.ErrorResponse(c, common.CodeSystemError, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, common.SuccessResponse(c, nil))
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
package formula
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
func SetupRouter(engine *gin.RouterGroup) {
|
|
||||||
endpoint := NewFormulaEndpoint()
|
|
||||||
engine.POST("/formula/recognition", endpoint.CreateTask)
|
|
||||||
engine.POST("/formula/ai_enhance", endpoint.AIEnhanceRecognition)
|
|
||||||
engine.GET("/formula/recognition/:task_no", endpoint.GetTaskStatus)
|
|
||||||
}
|
|
||||||
@@ -11,14 +11,20 @@ import (
|
|||||||
"gitea.com/bitwsd/document_ai/config"
|
"gitea.com/bitwsd/document_ai/config"
|
||||||
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
||||||
"gitea.com/bitwsd/document_ai/pkg/common"
|
"gitea.com/bitwsd/document_ai/pkg/common"
|
||||||
"gitea.com/bitwsd/document_ai/pkg/constant"
|
|
||||||
"gitea.com/bitwsd/document_ai/pkg/oss"
|
"gitea.com/bitwsd/document_ai/pkg/oss"
|
||||||
"gitea.com/bitwsd/document_ai/pkg/utils"
|
"gitea.com/bitwsd/document_ai/pkg/utils"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetPostObjectSignature(ctx *gin.Context) {
|
type OSSEndpoint struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOSSEndpoint() *OSSEndpoint {
|
||||||
|
return &OSSEndpoint{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *OSSEndpoint) GetPostObjectSignature(ctx *gin.Context) {
|
||||||
policyToken, err := oss.GetPolicyToken()
|
policyToken, err := oss.GetPolicyToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeSystemError, err.Error()))
|
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeSystemError, err.Error()))
|
||||||
@@ -37,8 +43,7 @@ func GetPostObjectSignature(ctx *gin.Context) {
|
|||||||
// @Success 200 {object} common.Response{data=map[string]string{"sign_url":string, "repeat":bool, "path":string}} "Signed URL generated successfully"
|
// @Success 200 {object} common.Response{data=map[string]string{"sign_url":string, "repeat":bool, "path":string}} "Signed URL generated successfully"
|
||||||
// @Failure 200 {object} common.Response "Error response"
|
// @Failure 200 {object} common.Response "Error response"
|
||||||
// @Router /signature_url [get]
|
// @Router /signature_url [get]
|
||||||
func GetSignatureURL(ctx *gin.Context) {
|
func (h *OSSEndpoint) GetSignatureURL(ctx *gin.Context) {
|
||||||
userID := ctx.GetInt64(constant.ContextUserID)
|
|
||||||
type Req struct {
|
type Req struct {
|
||||||
FileHash string `json:"file_hash" binding:"required"`
|
FileHash string `json:"file_hash" binding:"required"`
|
||||||
FileName string `json:"file_name" binding:"required"`
|
FileName string `json:"file_name" binding:"required"`
|
||||||
@@ -51,7 +56,7 @@ func GetSignatureURL(ctx *gin.Context) {
|
|||||||
}
|
}
|
||||||
taskDao := dao.NewRecognitionTaskDao()
|
taskDao := dao.NewRecognitionTaskDao()
|
||||||
sess := dao.DB.WithContext(ctx)
|
sess := dao.DB.WithContext(ctx)
|
||||||
task, err := taskDao.GetTaskByFileURL(sess, userID, req.FileHash)
|
task, err := taskDao.GetTaskByFileURL(sess, req.FileHash)
|
||||||
if err != nil && err != gorm.ErrRecordNotFound {
|
if err != nil && err != gorm.ErrRecordNotFound {
|
||||||
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeDBError, "failed to get task"))
|
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeDBError, "failed to get task"))
|
||||||
return
|
return
|
||||||
@@ -78,7 +83,7 @@ func GetSignatureURL(ctx *gin.Context) {
|
|||||||
ctx.JSON(http.StatusOK, common.SuccessResponse(ctx, gin.H{"sign_url": url, "repeat": false, "path": path}))
|
ctx.JSON(http.StatusOK, common.SuccessResponse(ctx, gin.H{"sign_url": url, "repeat": false, "path": path}))
|
||||||
}
|
}
|
||||||
|
|
||||||
func UploadFile(ctx *gin.Context) {
|
func (h *OSSEndpoint) UploadFile(ctx *gin.Context) {
|
||||||
if err := os.MkdirAll(config.GlobalConfig.UploadDir, 0755); err != nil {
|
if err := os.MkdirAll(config.GlobalConfig.UploadDir, 0755); err != nil {
|
||||||
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeSystemError, "Failed to create upload directory"))
|
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeSystemError, "Failed to create upload directory"))
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
package oss
|
|
||||||
|
|
||||||
import "github.com/gin-gonic/gin"
|
|
||||||
|
|
||||||
func SetupRouter(parent *gin.RouterGroup) {
|
|
||||||
router := parent.Group("oss")
|
|
||||||
{
|
|
||||||
router.POST("/signature", GetPostObjectSignature)
|
|
||||||
router.POST("/signature_url", GetSignatureURL)
|
|
||||||
router.POST("/file/upload", UploadFile)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -3,10 +3,10 @@ package task
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"gitea.com/bitwsd/document_ai/pkg/log"
|
|
||||||
"gitea.com/bitwsd/document_ai/internal/model/task"
|
"gitea.com/bitwsd/document_ai/internal/model/task"
|
||||||
"gitea.com/bitwsd/document_ai/internal/service"
|
"gitea.com/bitwsd/document_ai/internal/service"
|
||||||
"gitea.com/bitwsd/document_ai/pkg/common"
|
"gitea.com/bitwsd/document_ai/pkg/common"
|
||||||
|
"gitea.com/bitwsd/document_ai/pkg/log"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -43,6 +43,8 @@ func (h *TaskEndpoint) GetTaskList(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
req.UserID = common.GetUserIDFromContext(c)
|
||||||
|
|
||||||
if req.Page <= 0 {
|
if req.Page <= 0 {
|
||||||
req.Page = 1
|
req.Page = 1
|
||||||
}
|
}
|
||||||
@@ -59,3 +61,31 @@ func (h *TaskEndpoint) GetTaskList(c *gin.Context) {
|
|||||||
|
|
||||||
c.JSON(http.StatusOK, common.SuccessResponse(c, resp))
|
c.JSON(http.StatusOK, common.SuccessResponse(c, resp))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *TaskEndpoint) ExportTask(c *gin.Context) {
|
||||||
|
var req task.ExportTaskRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
log.Error(c, "func", "ExportTask", "msg", "Invalid parameters", "error", err)
|
||||||
|
c.JSON(http.StatusOK, common.ErrorResponse(c, common.CodeParamError, "Invalid parameters"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fileData, contentType, err := h.taskService.ExportTask(c, &req)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, common.ErrorResponse(c, common.CodeSystemError, "导出任务失败"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// set filename based on export type
|
||||||
|
var filename string
|
||||||
|
switch req.Type {
|
||||||
|
case "pdf":
|
||||||
|
filename = "texpixel_export.pdf"
|
||||||
|
case "docx":
|
||||||
|
filename = "texpixel_export.docx"
|
||||||
|
default:
|
||||||
|
filename = "texpixel_export"
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Header("Content-Disposition", "attachment; filename="+filename)
|
||||||
|
c.Data(http.StatusOK, contentType, fileData)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,11 +0,0 @@
|
|||||||
package task
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
func SetupRouter(engine *gin.RouterGroup) {
|
|
||||||
endpoint := NewTaskEndpoint()
|
|
||||||
engine.POST("/task/evaluate", endpoint.EvaluateTask)
|
|
||||||
engine.GET("/task/list", endpoint.GetTaskList)
|
|
||||||
}
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
package user
|
|
||||||
|
|
||||||
import (
|
|
||||||
"gitea.com/bitwsd/document_ai/pkg/common"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
func SetupRouter(router *gin.RouterGroup) {
|
|
||||||
userEndpoint := NewUserEndpoint()
|
|
||||||
userRouter := router.Group("/user")
|
|
||||||
{
|
|
||||||
userRouter.POST("/get/sms", userEndpoint.SendVerificationCode)
|
|
||||||
userRouter.POST("/login/phone", userEndpoint.LoginByPhoneCode)
|
|
||||||
userRouter.POST("/register/email", userEndpoint.RegisterByEmail)
|
|
||||||
userRouter.POST("/login/email", userEndpoint.LoginByEmail)
|
|
||||||
userRouter.GET("/info", common.GetAuthMiddleware(), userEndpoint.GetUserInfo)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
73
cmd/migrate/README.md
Normal file
73
cmd/migrate/README.md
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
# 数据迁移工具
|
||||||
|
|
||||||
|
用于将测试数据库的数据迁移到生产数据库,避免ID冲突,使用事务确保数据一致性。
|
||||||
|
|
||||||
|
## 功能特性
|
||||||
|
|
||||||
|
- ✅ 自动避免ID冲突(使用数据库自增ID)
|
||||||
|
- ✅ 使用事务确保每个任务和结果数据的一致性
|
||||||
|
- ✅ 自动跳过已存在的任务(基于task_uuid)
|
||||||
|
- ✅ 保留原始时间戳
|
||||||
|
- ✅ 处理NULL值
|
||||||
|
- ✅ 详细的日志输出和统计信息
|
||||||
|
|
||||||
|
## 使用方法
|
||||||
|
|
||||||
|
### 基本用法
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 从dev环境迁移到prod环境
|
||||||
|
go run cmd/migrate/main.go -test-env=dev -prod-env=prod
|
||||||
|
|
||||||
|
# 从prod环境迁移到dev环境(测试反向迁移)
|
||||||
|
go run cmd/migrate/main.go -test-env=prod -prod-env=dev
|
||||||
|
```
|
||||||
|
|
||||||
|
### 参数说明
|
||||||
|
|
||||||
|
- `-test-env`: 测试环境配置文件名(dev/prod),默认值:dev
|
||||||
|
- `-prod-env`: 生产环境配置文件名(dev/prod),默认值:prod
|
||||||
|
|
||||||
|
### 编译后使用
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 编译
|
||||||
|
go build -o migrate cmd/migrate/main.go
|
||||||
|
|
||||||
|
# 运行
|
||||||
|
./migrate -test-env=dev -prod-env=prod
|
||||||
|
```
|
||||||
|
|
||||||
|
## 工作原理
|
||||||
|
|
||||||
|
1. **连接数据库**:同时连接测试数据库和生产数据库
|
||||||
|
2. **读取数据**:从测试数据库读取所有任务和结果数据(LEFT JOIN)
|
||||||
|
3. **检查重复**:基于`task_uuid`检查生产数据库中是否已存在
|
||||||
|
4. **事务迁移**:为每个任务创建独立事务:
|
||||||
|
- 创建任务记录(自动生成新ID)
|
||||||
|
- 如果存在结果数据,创建结果记录(关联新任务ID)
|
||||||
|
- 提交事务或回滚
|
||||||
|
5. **统计报告**:输出迁移统计信息
|
||||||
|
|
||||||
|
## 注意事项
|
||||||
|
|
||||||
|
1. **配置文件**:确保`config/config_dev.yaml`和`config/config_prod.yaml`存在且配置正确
|
||||||
|
2. **数据库权限**:确保数据库用户有读写权限
|
||||||
|
3. **网络连接**:确保能同时连接到两个数据库
|
||||||
|
4. **数据备份**:迁移前建议备份生产数据库
|
||||||
|
5. **ID冲突**:脚本会自动处理ID冲突,使用数据库自增ID,不会覆盖现有数据
|
||||||
|
|
||||||
|
## 输出示例
|
||||||
|
|
||||||
|
```
|
||||||
|
从测试数据库读取到 100 条任务记录
|
||||||
|
[1/100] 创建任务成功: task_uuid=xxx, 新ID=1001
|
||||||
|
[1/100] 创建结果成功: task_id=1001
|
||||||
|
[2/100] 跳过已存在的任务: task_uuid=yyy, id=1002
|
||||||
|
...
|
||||||
|
迁移完成统计:
|
||||||
|
成功: 95 条
|
||||||
|
跳过: 3 条
|
||||||
|
失败: 2 条
|
||||||
|
数据迁移完成!
|
||||||
|
```
|
||||||
255
cmd/migrate/main.go
Normal file
255
cmd/migrate/main.go
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gitea.com/bitwsd/document_ai/config"
|
||||||
|
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
"gorm.io/driver/mysql"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// 解析命令行参数
|
||||||
|
testEnv := flag.String("test-env", "dev", "测试环境配置 (dev/prod)")
|
||||||
|
prodEnv := flag.String("prod-env", "prod", "生产环境配置 (dev/prod)")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
// 加载测试环境配置
|
||||||
|
testConfigPath := fmt.Sprintf("./config/config_%s.yaml", *testEnv)
|
||||||
|
testConfig, err := loadDatabaseConfig(testConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("加载测试环境配置失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 连接测试数据库
|
||||||
|
testDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Asia%%2FShanghai",
|
||||||
|
testConfig.Username, testConfig.Password, testConfig.Host, testConfig.Port, testConfig.DBName)
|
||||||
|
testDB, err := gorm.Open(mysql.Open(testDSN), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Info),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("连接测试数据库失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 加载生产环境配置
|
||||||
|
prodConfigPath := fmt.Sprintf("./config/config_%s.yaml", *prodEnv)
|
||||||
|
prodConfig, err := loadDatabaseConfig(prodConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("加载生产环境配置失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 连接生产数据库
|
||||||
|
prodDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Asia%%2FShanghai",
|
||||||
|
prodConfig.Username, prodConfig.Password, prodConfig.Host, prodConfig.Port, prodConfig.DBName)
|
||||||
|
prodDB, err := gorm.Open(mysql.Open(prodDSN), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Info),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("连接生产数据库失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行迁移
|
||||||
|
if err := migrateData(testDB, prodDB); err != nil {
|
||||||
|
log.Fatalf("数据迁移失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("数据迁移完成!")
|
||||||
|
}
|
||||||
|
|
||||||
|
func migrateData(testDB, prodDB *gorm.DB) error {
|
||||||
|
_ = context.Background() // 保留以备将来使用
|
||||||
|
|
||||||
|
// 从测试数据库读取所有任务数据(包含结果)
|
||||||
|
type TaskWithResult struct {
|
||||||
|
// Task 字段
|
||||||
|
TaskID int64 `gorm:"column:id"`
|
||||||
|
UserID int64 `gorm:"column:user_id"`
|
||||||
|
TaskUUID string `gorm:"column:task_uuid"`
|
||||||
|
FileName string `gorm:"column:file_name"`
|
||||||
|
FileHash string `gorm:"column:file_hash"`
|
||||||
|
FileURL string `gorm:"column:file_url"`
|
||||||
|
TaskType string `gorm:"column:task_type"`
|
||||||
|
Status int `gorm:"column:status"`
|
||||||
|
CompletedAt time.Time `gorm:"column:completed_at"`
|
||||||
|
Remark string `gorm:"column:remark"`
|
||||||
|
IP string `gorm:"column:ip"`
|
||||||
|
TaskCreatedAt time.Time `gorm:"column:created_at"`
|
||||||
|
TaskUpdatedAt time.Time `gorm:"column:updated_at"`
|
||||||
|
// Result 字段
|
||||||
|
ResultID *int64 `gorm:"column:result_id"`
|
||||||
|
ResultTaskID *int64 `gorm:"column:result_task_id"`
|
||||||
|
ResultTaskType *string `gorm:"column:result_task_type"`
|
||||||
|
Latex *string `gorm:"column:latex"`
|
||||||
|
Markdown *string `gorm:"column:markdown"`
|
||||||
|
MathML *string `gorm:"column:mathml"`
|
||||||
|
ResultCreatedAt *time.Time `gorm:"column:result_created_at"`
|
||||||
|
ResultUpdatedAt *time.Time `gorm:"column:result_updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var tasksWithResults []TaskWithResult
|
||||||
|
query := `
|
||||||
|
SELECT
|
||||||
|
t.id,
|
||||||
|
t.user_id,
|
||||||
|
t.task_uuid,
|
||||||
|
t.file_name,
|
||||||
|
t.file_hash,
|
||||||
|
t.file_url,
|
||||||
|
t.task_type,
|
||||||
|
t.status,
|
||||||
|
t.completed_at,
|
||||||
|
t.remark,
|
||||||
|
t.ip,
|
||||||
|
t.created_at,
|
||||||
|
t.updated_at,
|
||||||
|
r.id as result_id,
|
||||||
|
r.task_id as result_task_id,
|
||||||
|
r.task_type as result_task_type,
|
||||||
|
r.latex,
|
||||||
|
r.markdown,
|
||||||
|
r.mathml,
|
||||||
|
r.created_at as result_created_at,
|
||||||
|
r.updated_at as result_updated_at
|
||||||
|
FROM recognition_tasks t
|
||||||
|
LEFT JOIN recognition_results r ON t.id = r.task_id
|
||||||
|
ORDER BY t.id
|
||||||
|
`
|
||||||
|
|
||||||
|
if err := testDB.Raw(query).Scan(&tasksWithResults).Error; err != nil {
|
||||||
|
return fmt.Errorf("读取测试数据失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("从测试数据库读取到 %d 条任务记录", len(tasksWithResults))
|
||||||
|
|
||||||
|
successCount := 0
|
||||||
|
skipCount := 0
|
||||||
|
errorCount := 0
|
||||||
|
|
||||||
|
// 为每个任务使用独立事务,确保单个任务失败不影响其他任务
|
||||||
|
for i, item := range tasksWithResults {
|
||||||
|
// 开始事务
|
||||||
|
tx := prodDB.Begin()
|
||||||
|
|
||||||
|
// 检查生产数据库中是否已存在相同的 task_uuid
|
||||||
|
var existingTask dao.RecognitionTask
|
||||||
|
err := tx.Where("task_uuid = ?", item.TaskUUID).First(&existingTask).Error
|
||||||
|
if err == nil {
|
||||||
|
log.Printf("[%d/%d] 跳过已存在的任务: task_uuid=%s, id=%d", i+1, len(tasksWithResults), item.TaskUUID, existingTask.ID)
|
||||||
|
tx.Rollback()
|
||||||
|
skipCount++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err != gorm.ErrRecordNotFound {
|
||||||
|
log.Printf("[%d/%d] 检查任务是否存在时出错: task_uuid=%s, error=%v", i+1, len(tasksWithResults), item.TaskUUID, err)
|
||||||
|
tx.Rollback()
|
||||||
|
errorCount++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建新任务(不指定ID,让数据库自动生成)
|
||||||
|
newTask := &dao.RecognitionTask{
|
||||||
|
UserID: item.UserID,
|
||||||
|
TaskUUID: item.TaskUUID,
|
||||||
|
FileName: item.FileName,
|
||||||
|
FileHash: item.FileHash,
|
||||||
|
FileURL: item.FileURL,
|
||||||
|
TaskType: dao.TaskType(item.TaskType),
|
||||||
|
Status: dao.TaskStatus(item.Status),
|
||||||
|
CompletedAt: item.CompletedAt,
|
||||||
|
Remark: item.Remark,
|
||||||
|
IP: item.IP,
|
||||||
|
}
|
||||||
|
// 保留原始时间戳
|
||||||
|
newTask.CreatedAt = item.TaskCreatedAt
|
||||||
|
newTask.UpdatedAt = item.TaskUpdatedAt
|
||||||
|
|
||||||
|
if err := tx.Create(newTask).Error; err != nil {
|
||||||
|
log.Printf("[%d/%d] 创建任务失败: task_uuid=%s, error=%v", i+1, len(tasksWithResults), item.TaskUUID, err)
|
||||||
|
tx.Rollback()
|
||||||
|
errorCount++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[%d/%d] 创建任务成功: task_uuid=%s, 新ID=%d", i+1, len(tasksWithResults), item.TaskUUID, newTask.ID)
|
||||||
|
|
||||||
|
// 如果有结果数据,创建结果记录
|
||||||
|
if item.ResultID != nil {
|
||||||
|
// 处理可能为NULL的字段
|
||||||
|
latex := ""
|
||||||
|
if item.Latex != nil {
|
||||||
|
latex = *item.Latex
|
||||||
|
}
|
||||||
|
markdown := ""
|
||||||
|
if item.Markdown != nil {
|
||||||
|
markdown = *item.Markdown
|
||||||
|
}
|
||||||
|
mathml := ""
|
||||||
|
if item.MathML != nil {
|
||||||
|
mathml = *item.MathML
|
||||||
|
}
|
||||||
|
|
||||||
|
newResult := dao.RecognitionResult{
|
||||||
|
TaskID: newTask.ID, // 使用新任务的ID
|
||||||
|
TaskType: dao.TaskType(item.TaskType),
|
||||||
|
Latex: latex,
|
||||||
|
Markdown: markdown,
|
||||||
|
MathML: mathml,
|
||||||
|
}
|
||||||
|
// 保留原始时间戳
|
||||||
|
if item.ResultCreatedAt != nil {
|
||||||
|
newResult.CreatedAt = *item.ResultCreatedAt
|
||||||
|
}
|
||||||
|
if item.ResultUpdatedAt != nil {
|
||||||
|
newResult.UpdatedAt = *item.ResultUpdatedAt
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Create(&newResult).Error; err != nil {
|
||||||
|
log.Printf("[%d/%d] 创建结果失败: task_id=%d, error=%v", i+1, len(tasksWithResults), newTask.ID, err)
|
||||||
|
tx.Rollback() // 回滚整个事务(包括任务)
|
||||||
|
errorCount++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[%d/%d] 创建结果成功: task_id=%d", i+1, len(tasksWithResults), newTask.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提交事务
|
||||||
|
if err := tx.Commit().Error; err != nil {
|
||||||
|
log.Printf("[%d/%d] 提交事务失败: task_uuid=%s, error=%v", i+1, len(tasksWithResults), item.TaskUUID, err)
|
||||||
|
errorCount++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
successCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("迁移完成统计:")
|
||||||
|
log.Printf(" 成功: %d 条", successCount)
|
||||||
|
log.Printf(" 跳过: %d 条", skipCount)
|
||||||
|
log.Printf(" 失败: %d 条", errorCount)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadDatabaseConfig 从配置文件加载数据库配置
|
||||||
|
func loadDatabaseConfig(configPath string) (config.DatabaseConfig, error) {
|
||||||
|
v := viper.New()
|
||||||
|
v.SetConfigFile(configPath)
|
||||||
|
if err := v.ReadInConfig(); err != nil {
|
||||||
|
return config.DatabaseConfig{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var dbConfig config.DatabaseConfig
|
||||||
|
if err := v.UnmarshalKey("database", &dbConfig); err != nil {
|
||||||
|
return config.DatabaseConfig{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return dbConfig, nil
|
||||||
|
}
|
||||||
@@ -13,6 +13,17 @@ type Config struct {
|
|||||||
UploadDir string `mapstructure:"upload_dir"`
|
UploadDir string `mapstructure:"upload_dir"`
|
||||||
Limit LimitConfig `mapstructure:"limit"`
|
Limit LimitConfig `mapstructure:"limit"`
|
||||||
Aliyun AliyunConfig `mapstructure:"aliyun"`
|
Aliyun AliyunConfig `mapstructure:"aliyun"`
|
||||||
|
Mathpix MathpixConfig `mapstructure:"mathpix"`
|
||||||
|
BaiduOCR BaiduOCRConfig `mapstructure:"baidu_ocr"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaiduOCRConfig struct {
|
||||||
|
Token string `mapstructure:"token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MathpixConfig struct {
|
||||||
|
AppID string `mapstructure:"app_id"`
|
||||||
|
AppKey string `mapstructure:"app_key"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type LimitConfig struct {
|
type LimitConfig struct {
|
||||||
|
|||||||
@@ -4,16 +4,16 @@ server:
|
|||||||
|
|
||||||
database:
|
database:
|
||||||
driver: mysql
|
driver: mysql
|
||||||
host: 182.92.150.161
|
host: mysql
|
||||||
port: 3006
|
port: 3306
|
||||||
username: root
|
username: root
|
||||||
password: yoge@coder%%%123321!
|
password: texpixel#pwd123!
|
||||||
dbname: doc_ai
|
dbname: doc_ai
|
||||||
max_idle: 10
|
max_idle: 10
|
||||||
max_open: 100
|
max_open: 100
|
||||||
|
|
||||||
redis:
|
redis:
|
||||||
addr: 182.92.150.161:6379
|
addr: redis:6379
|
||||||
password: yoge@123321!
|
password: yoge@123321!
|
||||||
db: 0
|
db: 0
|
||||||
|
|
||||||
@@ -22,7 +22,7 @@ limit:
|
|||||||
|
|
||||||
log:
|
log:
|
||||||
appName: document_ai
|
appName: document_ai
|
||||||
level: info # debug, info, warn, error
|
level: info
|
||||||
format: console # json, console
|
format: console # json, console
|
||||||
outputPath: ./logs/app.log # 日志文件路径
|
outputPath: ./logs/app.log # 日志文件路径
|
||||||
maxSize: 2 # 单个日志文件最大尺寸,单位MB
|
maxSize: 2 # 单个日志文件最大尺寸,单位MB
|
||||||
@@ -39,8 +39,16 @@ aliyun:
|
|||||||
template_code: "SMS_291510729"
|
template_code: "SMS_291510729"
|
||||||
|
|
||||||
oss:
|
oss:
|
||||||
endpoint: oss-cn-beijing.aliyuncs.com
|
endpoint: static.texpixel.com
|
||||||
inner_endpoint: oss-cn-beijing-internal.aliyuncs.com
|
inner_endpoint: oss-cn-beijing-internal.aliyuncs.com
|
||||||
access_key_id: LTAI5tKogxeiBb4gJGWEePWN
|
access_key_id: LTAI5t8qXhow6NCdYDtu1saF
|
||||||
access_key_secret: l4oCxtt5iLSQ1DAs40guTzKUfrxXwq
|
access_key_secret: qZ2SwYsNCEBckCVSOszH31yYwXU44A
|
||||||
bucket_name: bitwsd-doc-ai
|
bucket_name: texpixel-doc
|
||||||
|
|
||||||
|
mathpix:
|
||||||
|
app_id: "ocr_eede6f_ea9b5c"
|
||||||
|
app_key: "fb72d251e33ac85c929bfd4eec40d78368d08d82fb2ee1cffb04a8bb967d1db5"
|
||||||
|
|
||||||
|
|
||||||
|
baidu_ocr:
|
||||||
|
token: "e3a47bd2438f1f38840c203fc5939d17a54482d1"
|
||||||
@@ -18,7 +18,7 @@ redis:
|
|||||||
db: 0
|
db: 0
|
||||||
|
|
||||||
limit:
|
limit:
|
||||||
formula_recognition: 2
|
formula_recognition: 10
|
||||||
|
|
||||||
log:
|
log:
|
||||||
appName: document_ai
|
appName: document_ai
|
||||||
@@ -38,8 +38,16 @@ aliyun:
|
|||||||
template_code: "SMS_291510729"
|
template_code: "SMS_291510729"
|
||||||
|
|
||||||
oss:
|
oss:
|
||||||
endpoint: oss-cn-beijing.aliyuncs.com
|
endpoint: static.texpixel.com
|
||||||
inner_endpoint: oss-cn-beijing-internal.aliyuncs.com
|
inner_endpoint: oss-cn-beijing-internal.aliyuncs.com
|
||||||
access_key_id: LTAI5t8qXhow6NCdYDtu1saF
|
access_key_id: LTAI5t8qXhow6NCdYDtu1saF
|
||||||
access_key_secret: qZ2SwYsNCEBckCVSOszH31yYwXU44A
|
access_key_secret: qZ2SwYsNCEBckCVSOszH31yYwXU44A
|
||||||
bucket_name: texpixel-doc
|
bucket_name: texpixel-doc
|
||||||
|
|
||||||
|
|
||||||
|
mathpix:
|
||||||
|
app_id: "ocr_eede6f_ea9b5c"
|
||||||
|
app_key: "fb72d251e33ac85c929bfd4eec40d78368d08d82fb2ee1cffb04a8bb967d1db5"
|
||||||
|
|
||||||
|
baidu_ocr:
|
||||||
|
token: "e3a47bd2438f1f38840c203fc5939d17a54482d1"
|
||||||
@@ -1,27 +1,50 @@
|
|||||||
version: '3.8'
|
|
||||||
|
|
||||||
services:
|
services:
|
||||||
|
doc_ai:
|
||||||
|
build: .
|
||||||
|
container_name: doc_ai
|
||||||
|
ports:
|
||||||
|
- "8024:8024"
|
||||||
|
volumes:
|
||||||
|
- ./config:/app/config
|
||||||
|
- ./logs:/app/logs
|
||||||
|
networks:
|
||||||
|
- backend
|
||||||
|
depends_on:
|
||||||
|
mysql:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_started
|
||||||
|
command: ["-env", "dev"]
|
||||||
|
restart: always
|
||||||
|
|
||||||
mysql:
|
mysql:
|
||||||
image: mysql:8.0
|
image: mysql:8.0
|
||||||
container_name: mysql
|
container_name: mysql
|
||||||
environment:
|
environment:
|
||||||
MYSQL_ROOT_PASSWORD: 123456 # 设置root用户密码
|
MYSQL_ROOT_PASSWORD: texpixel#pwd123!
|
||||||
MYSQL_DATABASE: document_ai # 设置默认数据库名
|
MYSQL_DATABASE: doc_ai
|
||||||
MYSQL_USER: bitwsd_document # 设置数据库用户名
|
MYSQL_USER: texpixel
|
||||||
MYSQL_PASSWORD: 123456 # 设置数据库用户密码
|
MYSQL_PASSWORD: texpixel#pwd123!
|
||||||
ports:
|
ports:
|
||||||
- "3306:3306" # 映射宿主机的3306端口到容器内的3306
|
- "3006:3306"
|
||||||
volumes:
|
volumes:
|
||||||
- mysql_data:/var/lib/mysql # 持久化MySQL数据
|
- mysql_data:/var/lib/mysql
|
||||||
networks:
|
networks:
|
||||||
- backend
|
- backend
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "mysqladmin", "ping", "-h", "localhost", "-uroot", "-ptexpixel#pwd123!"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 10
|
||||||
|
start_period: 30s
|
||||||
restart: always
|
restart: always
|
||||||
|
|
||||||
redis:
|
redis:
|
||||||
image: redis:latest
|
image: redis:latest
|
||||||
container_name: redis
|
container_name: redis
|
||||||
|
command: redis-server --requirepass "yoge@123321!"
|
||||||
ports:
|
ports:
|
||||||
- "6379:6379" # 映射宿主机的6379端口到容器内的6379
|
- "6079:6379"
|
||||||
networks:
|
networks:
|
||||||
- backend
|
- backend
|
||||||
restart: always
|
restart: always
|
||||||
|
|||||||
4
go.mod
4
go.mod
@@ -11,9 +11,11 @@ require (
|
|||||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible
|
github.com/dgrijalva/jwt-go v3.2.0+incompatible
|
||||||
github.com/gin-gonic/gin v1.10.0
|
github.com/gin-gonic/gin v1.10.0
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
|
github.com/jtolds/gls v4.20.0+incompatible
|
||||||
github.com/redis/go-redis/v9 v9.7.0
|
github.com/redis/go-redis/v9 v9.7.0
|
||||||
github.com/rs/zerolog v1.33.0
|
github.com/rs/zerolog v1.33.0
|
||||||
github.com/spf13/viper v1.19.0
|
github.com/spf13/viper v1.19.0
|
||||||
|
golang.org/x/crypto v0.23.0
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||||
gorm.io/driver/mysql v1.5.7
|
gorm.io/driver/mysql v1.5.7
|
||||||
gorm.io/gorm v1.25.12
|
gorm.io/gorm v1.25.12
|
||||||
@@ -42,6 +44,7 @@ require (
|
|||||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
||||||
github.com/goccy/go-json v0.10.2 // indirect
|
github.com/goccy/go-json v0.10.2 // indirect
|
||||||
|
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 // indirect
|
||||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
github.com/jinzhu/now v1.1.5 // indirect
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
@@ -68,7 +71,6 @@ require (
|
|||||||
go.uber.org/atomic v1.9.0 // indirect
|
go.uber.org/atomic v1.9.0 // indirect
|
||||||
go.uber.org/multierr v1.9.0 // indirect
|
go.uber.org/multierr v1.9.0 // indirect
|
||||||
golang.org/x/arch v0.8.0 // indirect
|
golang.org/x/arch v0.8.0 // indirect
|
||||||
golang.org/x/crypto v0.23.0 // indirect
|
|
||||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
|
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
|
||||||
golang.org/x/net v0.25.0 // indirect
|
golang.org/x/net v0.25.0 // indirect
|
||||||
golang.org/x/sys v0.20.0 // indirect
|
golang.org/x/sys v0.20.0 // indirect
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -79,6 +79,7 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/
|
|||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
|
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
|
||||||
|
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 h1:l5lAOZEym3oK3SQ2HBHWsJUfbNBiTXJDeW2QDxw9AQ0=
|
||||||
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
|
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
|
||||||
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
||||||
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
||||||
@@ -89,6 +90,7 @@ github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/
|
|||||||
github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
|
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
|
||||||
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
|
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
|
||||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||||
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ type CreateFormulaRecognitionRequest struct {
|
|||||||
FileHash string `json:"file_hash" binding:"required"` // file hash
|
FileHash string `json:"file_hash" binding:"required"` // file hash
|
||||||
FileName string `json:"file_name" binding:"required"` // file name
|
FileName string `json:"file_name" binding:"required"` // file name
|
||||||
TaskType string `json:"task_type" binding:"required,oneof=FORMULA"` // task type
|
TaskType string `json:"task_type" binding:"required,oneof=FORMULA"` // task type
|
||||||
|
UserID int64 `json:"user_id"` // user id
|
||||||
}
|
}
|
||||||
|
|
||||||
type GetRecognitionStatusRequest struct {
|
type GetRecognitionStatusRequest struct {
|
||||||
|
|||||||
@@ -22,3 +22,10 @@ type GetFormulaTaskResponse struct {
|
|||||||
type FormulaRecognitionResponse struct {
|
type FormulaRecognitionResponse struct {
|
||||||
Result string `json:"result"`
|
Result string `json:"result"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ImageOCRResponse is the response returned by the image OCR endpoint.
|
||||||
|
type ImageOCRResponse struct {
|
||||||
|
Latex string `json:"latex"` // content in LaTeX format
|
||||||
|
Markdown string `json:"markdown"` // content in Markdown format
|
||||||
|
MathML string `json:"mathml"` // content in MathML format (empty when there is no formula)
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,26 +11,31 @@ type TaskListRequest struct {
|
|||||||
TaskType string `json:"task_type" form:"task_type" binding:"required"`
|
TaskType string `json:"task_type" form:"task_type" binding:"required"`
|
||||||
Page int `json:"page" form:"page"`
|
Page int `json:"page" form:"page"`
|
||||||
PageSize int `json:"page_size" form:"page_size"`
|
PageSize int `json:"page_size" form:"page_size"`
|
||||||
}
|
UserID int64 `json:"-"`
|
||||||
|
|
||||||
type PdfInfo struct {
|
|
||||||
PageCount int `json:"page_count"`
|
|
||||||
PageWidth int `json:"page_width"`
|
|
||||||
PageHeight int `json:"page_height"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type TaskListDTO struct {
|
type TaskListDTO struct {
|
||||||
TaskID string `json:"task_id"`
|
TaskID string `json:"task_id"`
|
||||||
FileName string `json:"file_name"`
|
FileName string `json:"file_name"`
|
||||||
Status string `json:"status"`
|
Status int `json:"status"`
|
||||||
Path string `json:"path"`
|
OriginURL string `json:"origin_url"`
|
||||||
TaskType string `json:"task_type"`
|
TaskType string `json:"task_type"`
|
||||||
CreatedAt string `json:"created_at"`
|
CreatedAt string `json:"created_at"`
|
||||||
PdfInfo PdfInfo `json:"pdf_info"`
|
Latex string `json:"latex"`
|
||||||
|
Markdown string `json:"markdown"`
|
||||||
|
MathML string `json:"mathml"`
|
||||||
|
MathMLMW string `json:"mathml_mw"`
|
||||||
|
ImageBlob string `json:"image_blob"`
|
||||||
|
DocxURL string `json:"docx_url"`
|
||||||
|
PDFURL string `json:"pdf_url"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type TaskListResponse struct {
|
type TaskListResponse struct {
|
||||||
TaskList []*TaskListDTO `json:"task_list"`
|
TaskList []*TaskListDTO `json:"task_list"`
|
||||||
HasMore bool `json:"has_more"`
|
Total int64 `json:"total"`
|
||||||
NextPage int `json:"next_page"`
|
}
|
||||||
|
|
||||||
|
type ExportTaskRequest struct {
|
||||||
|
TaskNo string `json:"task_no" binding:"required"`
|
||||||
|
Type string `json:"type" binding:"required,oneof=pdf docx"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -21,6 +22,7 @@ import (
|
|||||||
"gitea.com/bitwsd/document_ai/pkg/constant"
|
"gitea.com/bitwsd/document_ai/pkg/constant"
|
||||||
"gitea.com/bitwsd/document_ai/pkg/httpclient"
|
"gitea.com/bitwsd/document_ai/pkg/httpclient"
|
||||||
"gitea.com/bitwsd/document_ai/pkg/oss"
|
"gitea.com/bitwsd/document_ai/pkg/oss"
|
||||||
|
"gitea.com/bitwsd/document_ai/pkg/requestid"
|
||||||
"gitea.com/bitwsd/document_ai/pkg/utils"
|
"gitea.com/bitwsd/document_ai/pkg/utils"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@@ -105,6 +107,7 @@ func (s *RecognitionService) CreateRecognitionTask(ctx context.Context, req *for
|
|||||||
sess := dao.DB.WithContext(ctx)
|
sess := dao.DB.WithContext(ctx)
|
||||||
taskDao := dao.NewRecognitionTaskDao()
|
taskDao := dao.NewRecognitionTaskDao()
|
||||||
task := &dao.RecognitionTask{
|
task := &dao.RecognitionTask{
|
||||||
|
UserID: req.UserID,
|
||||||
TaskUUID: utils.NewUUID(),
|
TaskUUID: utils.NewUUID(),
|
||||||
TaskType: dao.TaskType(req.TaskType),
|
TaskType: dao.TaskType(req.TaskType),
|
||||||
Status: dao.TaskStatusPending,
|
Status: dao.TaskStatusPending,
|
||||||
@@ -165,8 +168,20 @@ func (s *RecognitionService) GetFormualTask(ctx context.Context, taskNo string)
|
|||||||
log.Error(ctx, "func", "GetFormualTask", "msg", "查询任务结果失败", "error", err, "task_no", taskNo)
|
log.Error(ctx, "func", "GetFormualTask", "msg", "查询任务结果失败", "error", err, "task_no", taskNo)
|
||||||
return nil, common.NewError(common.CodeDBError, "查询任务结果失败", err)
|
return nil, common.NewError(common.CodeDBError, "查询任务结果失败", err)
|
||||||
}
|
}
|
||||||
latex := taskRet.NewContentCodec().GetContent().(string)
|
|
||||||
return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Latex: latex, Status: int(task.Status)}, nil
|
// 构建 Markdown 格式
|
||||||
|
markdown := taskRet.Markdown
|
||||||
|
if markdown == "" {
|
||||||
|
markdown = fmt.Sprintf("$$%s$$", taskRet.Latex)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &formula.GetFormulaTaskResponse{
|
||||||
|
TaskNo: taskNo,
|
||||||
|
Latex: taskRet.Latex,
|
||||||
|
Markdown: markdown,
|
||||||
|
MathML: taskRet.MathML,
|
||||||
|
Status: int(task.Status),
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *RecognitionService) handleFormulaRecognition(ctx context.Context, taskID int64) error {
|
func (s *RecognitionService) handleFormulaRecognition(ctx context.Context, taskID int64) error {
|
||||||
@@ -207,6 +222,223 @@ func (s *RecognitionService) processVLFormula(ctx context.Context, taskID int64)
|
|||||||
|
|
||||||
log.Info(ctx, "func", "processVLFormulaQueue", "msg", "处理任务成功", "task_id", taskID)
|
log.Info(ctx, "func", "processVLFormulaQueue", "msg", "处理任务成功", "task_id", taskID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MathpixRequest is the full request payload for the Mathpix /v3/text API.
//
// Optional fields whose Go zero value is nil (maps, slices, pointers) carry
// `omitempty`, so an unset field is left out of the JSON body instead of being
// sent as an explicit null.
//
// NOTE(review): the three float thresholds below are always serialized, so a
// request that does not set them sends 0 and overrides the defaults stated in
// their comments. Callers that rely on those defaults should set the values
// explicitly; switching the fields to *float64 with omitempty would fix this
// but changes the struct's Go interface — confirm with callers first.
type MathpixRequest struct {
	// Src is the image source: a URL or base64-encoded data.
	Src string `json:"src"`
	// Metadata holds arbitrary key-value pairs.
	Metadata map[string]interface{} `json:"metadata,omitempty"`
	// Tags is a list of labels used to identify the result.
	Tags []string `json:"tags,omitempty"`
	// Async flags the request as asynchronous.
	Async bool `json:"async"`
	// Callback is the callback configuration.
	Callback *MathpixCallback `json:"callback,omitempty"`
	// Formats lists the output formats: text, data, html, latex_styled.
	Formats []string `json:"formats,omitempty"`
	// DataOptions selects what the "data" format contains.
	DataOptions *MathpixDataOptions `json:"data_options,omitempty"`
	// IncludeDetectedAlphabets asks for the detected alphabets to be returned.
	IncludeDetectedAlphabets *bool `json:"include_detected_alphabets,omitempty"`
	// AlphabetsAllowed restricts the alphabets that may be recognized.
	AlphabetsAllowed *MathpixAlphabetsAllowed `json:"alphabets_allowed,omitempty"`
	// Region restricts recognition to an area of the image.
	Region *MathpixRegion `json:"region,omitempty"`
	// EnableBlueHsvFilter turns on the blue HSV filter mode.
	EnableBlueHsvFilter bool `json:"enable_blue_hsv_filter"`
	// ConfidenceThreshold is the confidence threshold.
	ConfidenceThreshold float64 `json:"confidence_threshold"`
	// ConfidenceRateThreshold is the symbol-level confidence threshold
	// (API default 0.75).
	ConfidenceRateThreshold float64 `json:"confidence_rate_threshold"`
	// IncludeEquationTags includes equation tags.
	IncludeEquationTags bool `json:"include_equation_tags"`
	// IncludeLineData returns per-line information.
	IncludeLineData bool `json:"include_line_data"`
	// IncludeWordData returns per-word information.
	IncludeWordData bool `json:"include_word_data"`
	// IncludeSmiles enables chemical structure OCR.
	IncludeSmiles bool `json:"include_smiles"`
	// IncludeInchi includes InChI data.
	IncludeInchi bool `json:"include_inchi"`
	// IncludeGeometryData includes geometry data.
	IncludeGeometryData bool `json:"include_geometry_data"`
	// IncludeDiagramText extracts text from diagrams.
	IncludeDiagramText bool `json:"include_diagram_text"`
	// IncludePageInfo includes page information (API default true).
	IncludePageInfo *bool `json:"include_page_info,omitempty"`
	// AutoRotateConfidenceThreshold is the auto-rotation confidence threshold
	// (API default 0.99).
	AutoRotateConfidenceThreshold float64 `json:"auto_rotate_confidence_threshold"`
	// RmSpaces removes superfluous spaces (API default true).
	RmSpaces *bool `json:"rm_spaces,omitempty"`
	// RmFonts removes font commands (API default false).
	RmFonts bool `json:"rm_fonts"`
	// IdiomaticEqnArrays uses aligned/gathered/cases instead of array
	// (API default false).
	IdiomaticEqnArrays bool `json:"idiomatic_eqn_arrays"`
	// IdiomaticBraces removes unnecessary braces (API default false).
	IdiomaticBraces bool `json:"idiomatic_braces"`
	// NumbersDefaultToMath always puts numbers in math mode (API default false).
	NumbersDefaultToMath bool `json:"numbers_default_to_math"`
	// MathFontsDefaultToMath always puts math fonts in math mode
	// (API default false).
	MathFontsDefaultToMath bool `json:"math_fonts_default_to_math"`
	// MathInlineDelimiters are the inline math delimiters
	// (API default ["\\(", "\\)"]).
	MathInlineDelimiters []string `json:"math_inline_delimiters,omitempty"`
	// MathDisplayDelimiters are the display math delimiters
	// (API default ["\\[", "\\]"]).
	MathDisplayDelimiters []string `json:"math_display_delimiters,omitempty"`
	// EnableTablesFallback enables advanced table processing
	// (API default false).
	EnableTablesFallback bool `json:"enable_tables_fallback"`
	// FullwidthPunctuation selects full-width punctuation; nil lets the API
	// decide automatically.
	FullwidthPunctuation *bool `json:"fullwidth_punctuation,omitempty"`
}

// MathpixCallback is the callback configuration of a MathpixRequest.
type MathpixCallback struct {
	URL     string            `json:"url"`
	Headers map[string]string `json:"headers"`
}

// MathpixDataOptions selects the representations included in the "data" output.
type MathpixDataOptions struct {
	IncludeAsciimath bool `json:"include_asciimath"`
	IncludeMathml    bool `json:"include_mathml"`
	IncludeLatex     bool `json:"include_latex"`
	IncludeTsv       bool `json:"include_tsv"`
}

// MathpixAlphabetsAllowed lists, per alphabet, whether it is allowed.
type MathpixAlphabetsAllowed struct {
	En bool `json:"en"`
	Hi bool `json:"hi"`
	Zh bool `json:"zh"`
	Ja bool `json:"ja"`
	Ko bool `json:"ko"`
	Ru bool `json:"ru"`
	Th bool `json:"th"`
	Vi bool `json:"vi"`
}

// MathpixRegion describes an area of the image.
type MathpixRegion struct {
	TopLeftX int `json:"top_left_x"`
	TopLeftY int `json:"top_left_y"`
	Width    int `json:"width"`
	Height   int `json:"height"`
}
|
||||||
|
|
||||||
|
// MathpixResponse is the full response payload of the Mathpix /v3/text API.
type MathpixResponse struct {
	// RequestID identifies the request, for debugging.
	RequestID string `json:"request_id"`
	// Text is the result in Mathpix Markdown format.
	Text string `json:"text"`
	// LatexStyled is styled LaTeX (only returned for single-formula images).
	LatexStyled string `json:"latex_styled"`
	// Confidence is the confidence, in [0,1].
	Confidence float64 `json:"confidence"`
	// ConfidenceRate is the confidence rate, in [0,1].
	ConfidenceRate float64 `json:"confidence_rate"`
	// LineData holds per-line data.
	LineData []map[string]interface{} `json:"line_data"`
	// WordData holds per-word data.
	WordData []map[string]interface{} `json:"word_data"`
	// Data is the list of data objects.
	Data []MathpixDataItem `json:"data"`
	// HTML is the HTML output.
	HTML string `json:"html"`
	// DetectedAlphabets holds the detected alphabets.
	DetectedAlphabets []map[string]interface{} `json:"detected_alphabets"`
	// IsPrinted reports whether the content is printed.
	IsPrinted bool `json:"is_printed"`
	// IsHandwritten reports whether the content is handwritten.
	IsHandwritten bool `json:"is_handwritten"`
	// AutoRotateConfidence is the auto-rotation confidence.
	AutoRotateConfidence float64 `json:"auto_rotate_confidence"`
	// GeometryData holds geometry data.
	GeometryData []map[string]interface{} `json:"geometry_data"`
	// AutoRotateDegrees is the auto-rotation angle, one of {0, 90, -90, 180}.
	AutoRotateDegrees int `json:"auto_rotate_degrees"`
	// ImageWidth is the image width.
	ImageWidth int `json:"image_width"`
	// ImageHeight is the image height.
	ImageHeight int `json:"image_height"`
	// Error is the error message.
	Error string `json:"error"`
	// ErrorInfo carries the error details.
	ErrorInfo *MathpixErrorInfo `json:"error_info"`
	// Version is the API version.
	Version string `json:"version"`
}

// MathpixDataItem is one entry of MathpixResponse.Data.
type MathpixDataItem struct {
	Type  string `json:"type"`
	Value string `json:"value"`
}

// MathpixErrorInfo carries the details of a Mathpix error.
type MathpixErrorInfo struct {
	ID      string `json:"id"`
	Message string `json:"message"`
}

// BaiduOCRRequest is the request payload for Baidu OCR layout analysis.
type BaiduOCRRequest struct {
	// File is the file content, base64-encoded.
	File string `json:"file"`
	// FileType is the file type: 0 = PDF, 1 = image.
	FileType int `json:"fileType"`
	// UseDocOrientationClassify enables document orientation classification.
	UseDocOrientationClassify bool `json:"useDocOrientationClassify"`
	// UseDocUnwarping enables document unwarping.
	UseDocUnwarping bool `json:"useDocUnwarping"`
	// UseChartRecognition enables chart recognition.
	UseChartRecognition bool `json:"useChartRecognition"`
}

// BaiduOCRResponse is the response payload of Baidu OCR layout analysis.
type BaiduOCRResponse struct {
	ErrorCode int             `json:"errorCode"`
	ErrorMsg  string          `json:"errorMsg"`
	Result    *BaiduOCRResult `json:"result"`
}

// BaiduOCRResult is the result part of a BaiduOCRResponse.
type BaiduOCRResult struct {
	LayoutParsingResults []BaiduLayoutParsingResult `json:"layoutParsingResults"`
}

// BaiduLayoutParsingResult is the layout parsing result of a single page.
type BaiduLayoutParsingResult struct {
	Markdown     BaiduMarkdownResult `json:"markdown"`
	OutputImages map[string]string   `json:"outputImages"`
}

// BaiduMarkdownResult is the markdown part of a layout parsing result.
type BaiduMarkdownResult struct {
	Text   string            `json:"text"`
	Images map[string]string `json:"images"`
}

// dataValue returns the Value of the first Data item whose Type equals kind.
// It returns "" when r is nil or no item matches.
func (r *MathpixResponse) dataValue(kind string) string {
	if r == nil {
		return ""
	}
	for _, item := range r.Data {
		if item.Type == kind {
			return item.Value
		}
	}
	return ""
}

// GetMathML returns the MathML value from the response data, or "" when the
// response carries none.
func (r *MathpixResponse) GetMathML() string {
	return r.dataValue("mathml")
}

// GetAsciiMath returns the AsciiMath value from the response data, or "" when
// the response carries none.
func (r *MathpixResponse) GetAsciiMath() string {
	return r.dataValue("asciimath")
}
|
||||||
|
|
||||||
func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int64, fileURL string) (err error) {
|
func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int64, fileURL string) (err error) {
|
||||||
// 为整个任务处理添加超时控制
|
// 为整个任务处理添加超时控制
|
||||||
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
|
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
|
||||||
@@ -280,8 +512,8 @@ func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int6
|
|||||||
// 设置Content-Type头为application/json
|
// 设置Content-Type头为application/json
|
||||||
headers := map[string]string{"Content-Type": "application/json", utils.RequestIDHeaderKey: utils.GetRequestIDFromContext(ctx)}
|
headers := map[string]string{"Content-Type": "application/json", utils.RequestIDHeaderKey: utils.GetRequestIDFromContext(ctx)}
|
||||||
|
|
||||||
// 发送请求时会使用带超时的context
|
// 发送请求到新的 OCR 接口
|
||||||
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "http://cloud.texpixel.com:1080/formula/predict", bytes.NewReader(jsonData), headers)
|
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "https://cloud.texpixel.com:10443/doc_process/v1/image/ocr", bytes.NewReader(jsonData), headers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if ctx.Err() == context.DeadlineExceeded {
|
if ctx.Err() == context.DeadlineExceeded {
|
||||||
log.Error(ctx, "func", "processFormulaTask", "msg", "请求超时")
|
log.Error(ctx, "func", "processFormulaTask", "msg", "请求超时")
|
||||||
@@ -301,24 +533,23 @@ func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int6
|
|||||||
log.Info(ctx, "func", "processFormulaTask", "msg", "响应内容", "body", body.String())
|
log.Info(ctx, "func", "processFormulaTask", "msg", "响应内容", "body", body.String())
|
||||||
|
|
||||||
// 解析 JSON 响应
|
// 解析 JSON 响应
|
||||||
var formulaResp formula.FormulaRecognitionResponse
|
var ocrResp formula.ImageOCRResponse
|
||||||
if err := json.Unmarshal(body.Bytes(), &formulaResp); err != nil {
|
if err := json.Unmarshal(body.Bytes(), &ocrResp); err != nil {
|
||||||
log.Error(ctx, "func", "processFormulaTask", "msg", "解析响应JSON失败", "error", err)
|
log.Error(ctx, "func", "processFormulaTask", "msg", "解析响应JSON失败", "error", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
katex := utils.ToKatex(formulaResp.Result)
|
err = resultDao.Create(tx, dao.RecognitionResult{
|
||||||
content := &dao.FormulaRecognitionContent{Latex: katex}
|
|
||||||
b, _ := json.Marshal(content)
|
|
||||||
// Save recognition result
|
|
||||||
result := &dao.RecognitionResult{
|
|
||||||
TaskID: taskID,
|
TaskID: taskID,
|
||||||
TaskType: dao.TaskTypeFormula,
|
TaskType: dao.TaskTypeFormula,
|
||||||
Content: b,
|
Latex: ocrResp.Latex,
|
||||||
}
|
Markdown: ocrResp.Markdown,
|
||||||
if err := resultDao.Create(tx, *result); err != nil {
|
MathML: ocrResp.MathML,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
log.Error(ctx, "func", "processFormulaTask", "msg", "保存任务结果失败", "error", err)
|
log.Error(ctx, "func", "processFormulaTask", "msg", "保存任务结果失败", "error", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
isSuccess = true
|
isSuccess = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -423,39 +654,21 @@ func (s *RecognitionService) processVLFormulaTask(ctx context.Context, taskID in
|
|||||||
}
|
}
|
||||||
|
|
||||||
resultDao := dao.NewRecognitionResultDao()
|
resultDao := dao.NewRecognitionResultDao()
|
||||||
var formulaRes *dao.FormulaRecognitionContent
|
|
||||||
result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID)
|
result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(ctx, "func", "processVLFormulaTask", "msg", "获取任务结果失败", "error", err)
|
log.Error(ctx, "func", "processVLFormulaTask", "msg", "获取任务结果失败", "error", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if result == nil {
|
if result == nil {
|
||||||
formulaRes = &dao.FormulaRecognitionContent{EnhanceLatex: latex}
|
formulaRes := &dao.RecognitionResult{TaskID: taskID, TaskType: dao.TaskTypeFormula, Latex: latex}
|
||||||
b, err := formulaRes.Encode()
|
err = resultDao.Create(dao.DB.WithContext(ctx), *formulaRes)
|
||||||
if err != nil {
|
|
||||||
log.Error(ctx, "func", "processVLFormulaTask", "msg", "编码任务结果失败", "error", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = resultDao.Create(dao.DB.WithContext(ctx), dao.RecognitionResult{TaskID: taskID, TaskType: dao.TaskTypeFormula, Content: b})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(ctx, "func", "processVLFormulaTask", "msg", "创建任务结果失败", "error", err)
|
log.Error(ctx, "func", "processVLFormulaTask", "msg", "创建任务结果失败", "error", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
formulaRes = result.NewContentCodec().(*dao.FormulaRecognitionContent)
|
result.Latex = latex
|
||||||
err = formulaRes.Decode()
|
err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{"latex": latex})
|
||||||
if err != nil {
|
|
||||||
log.Error(ctx, "func", "processVLFormulaTask", "msg", "解码任务结果失败", "error", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
formulaRes.EnhanceLatex = latex
|
|
||||||
b, err := formulaRes.Encode()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(ctx, "func", "processVLFormulaTask", "msg", "编码任务结果失败", "error", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{"content": b})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务结果失败", "error", err)
|
log.Error(ctx, "func", "processVLFormulaTask", "msg", "更新任务结果失败", "error", err)
|
||||||
return err
|
return err
|
||||||
@@ -499,14 +712,408 @@ func (s *RecognitionService) processOneTask(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx = context.WithValue(ctx, utils.RequestIDKey, task.TaskUUID)
|
ctx = context.WithValue(ctx, utils.RequestIDKey, task.TaskUUID)
|
||||||
log.Info(ctx, "func", "processFormulaQueue", "msg", "获取任务成功", "task_id", taskID)
|
|
||||||
|
|
||||||
// 处理任务
|
// 使用 gls 设置 request_id,确保在整个任务处理过程中可用
|
||||||
err = s.processFormulaTask(ctx, taskID, task.FileURL)
|
requestid.SetRequestID(task.TaskUUID, func() {
|
||||||
|
log.Info(ctx, "func", "processFormulaQueue", "msg", "获取任务成功", "task_id", taskID)
|
||||||
|
|
||||||
|
err = s.processFormulaTask(ctx, taskID, task.FileURL)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "processFormulaQueue", "msg", "处理任务失败", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info(ctx, "func", "processFormulaQueue", "msg", "处理任务成功", "task_id", taskID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// processMathpixTask 使用 Mathpix API 处理公式识别任务(用于增强识别)
|
||||||
|
func (s *RecognitionService) processMathpixTask(ctx context.Context, taskID int64, fileURL string) error {
|
||||||
|
isSuccess := false
|
||||||
|
logDao := dao.NewRecognitionLogDao()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if !isSuccess {
|
||||||
|
err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusFailed})
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "processMathpixTask", "msg", "更新任务状态失败", "error", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusCompleted})
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "processMathpixTask", "msg", "更新任务状态失败", "error", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 下载图片
|
||||||
|
imageUrl, err := oss.GetDownloadURL(ctx, fileURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(ctx, "func", "processFormulaQueue", "msg", "处理任务失败", "error", err)
|
log.Error(ctx, "func", "processMathpixTask", "msg", "获取图片URL失败", "error", err)
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info(ctx, "func", "processFormulaQueue", "msg", "处理任务成功", "task_id", taskID)
|
// 创建 Mathpix API 请求
|
||||||
|
mathpixReq := MathpixRequest{
|
||||||
|
Src: imageUrl,
|
||||||
|
Formats: []string{
|
||||||
|
"text",
|
||||||
|
"latex_styled",
|
||||||
|
"data",
|
||||||
|
"html",
|
||||||
|
},
|
||||||
|
DataOptions: &MathpixDataOptions{
|
||||||
|
IncludeMathml: true,
|
||||||
|
IncludeAsciimath: true,
|
||||||
|
IncludeLatex: true,
|
||||||
|
IncludeTsv: true,
|
||||||
|
},
|
||||||
|
MathInlineDelimiters: []string{"$", "$"},
|
||||||
|
MathDisplayDelimiters: []string{"$$", "$$"},
|
||||||
|
RmSpaces: &[]bool{true}[0],
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(mathpixReq)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "processMathpixTask", "msg", "JSON编码失败", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
headers := map[string]string{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"app_id": config.GlobalConfig.Mathpix.AppID,
|
||||||
|
"app_key": config.GlobalConfig.Mathpix.AppKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := "https://api.mathpix.com/v3/text"
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
log.Info(ctx, "func", "processMathpixTask", "msg", "MathpixApi_Start", "start_time", startTime)
|
||||||
|
|
||||||
|
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData), headers)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "processMathpixTask", "msg", "Mathpix API 请求失败", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
log.Info(ctx, "func", "processMathpixTask", "msg", "MathpixApi_End", "end_time", time.Now(), "duration", time.Since(startTime))
|
||||||
|
|
||||||
|
body := &bytes.Buffer{}
|
||||||
|
if _, err = body.ReadFrom(resp.Body); err != nil {
|
||||||
|
log.Error(ctx, "func", "processMathpixTask", "msg", "读取响应体失败", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建日志记录
|
||||||
|
recognitionLog := &dao.RecognitionLog{
|
||||||
|
TaskID: taskID,
|
||||||
|
Provider: dao.ProviderMathpix,
|
||||||
|
RequestBody: string(jsonData),
|
||||||
|
ResponseBody: body.String(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析响应
|
||||||
|
var mathpixResp MathpixResponse
|
||||||
|
if err := json.Unmarshal(body.Bytes(), &mathpixResp); err != nil {
|
||||||
|
log.Error(ctx, "func", "processMathpixTask", "msg", "解析响应失败", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查错误
|
||||||
|
if mathpixResp.Error != "" {
|
||||||
|
errMsg := mathpixResp.Error
|
||||||
|
if mathpixResp.ErrorInfo != nil {
|
||||||
|
errMsg = fmt.Sprintf("%s: %s", mathpixResp.ErrorInfo.ID, mathpixResp.ErrorInfo.Message)
|
||||||
|
}
|
||||||
|
log.Error(ctx, "func", "processMathpixTask", "msg", "Mathpix API 返回错误", "error", errMsg)
|
||||||
|
return fmt.Errorf("mathpix error: %s", errMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保存日志
|
||||||
|
err = logDao.Create(dao.DB.WithContext(ctx), recognitionLog)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "processMathpixTask", "msg", "保存日志失败", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新或创建识别结果
|
||||||
|
resultDao := dao.NewRecognitionResultDao()
|
||||||
|
result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "processMathpixTask", "msg", "获取任务结果失败", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info(ctx, "func", "processMathpixTask", "msg", "saveLog", "end_time", time.Now(), "duration", time.Since(startTime))
|
||||||
|
|
||||||
|
if result == nil {
|
||||||
|
// 创建新结果
|
||||||
|
err = resultDao.Create(dao.DB.WithContext(ctx), dao.RecognitionResult{
|
||||||
|
TaskID: taskID,
|
||||||
|
TaskType: dao.TaskTypeFormula,
|
||||||
|
Latex: mathpixResp.LatexStyled,
|
||||||
|
Markdown: mathpixResp.Text,
|
||||||
|
MathML: mathpixResp.GetMathML(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "processMathpixTask", "msg", "创建任务结果失败", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 更新现有结果
|
||||||
|
err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{
|
||||||
|
"latex": mathpixResp.LatexStyled,
|
||||||
|
"markdown": mathpixResp.Text,
|
||||||
|
"mathml": mathpixResp.GetMathML(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "processMathpixTask", "msg", "更新任务结果失败", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
isSuccess = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RecognitionService) processBaiduOCRTask(ctx context.Context, taskID int64, fileURL string) error {
|
||||||
|
isSuccess := false
|
||||||
|
logDao := dao.NewRecognitionLogDao()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if !isSuccess {
|
||||||
|
err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusFailed})
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "更新任务状态失败", "error", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err := dao.NewRecognitionTaskDao().Update(dao.DB.WithContext(ctx), map[string]interface{}{"id": taskID}, map[string]interface{}{"status": dao.TaskStatusCompleted})
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "更新任务状态失败", "error", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 从 OSS 下载文件
|
||||||
|
reader, err := oss.DownloadFile(ctx, fileURL)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "从OSS下载文件失败", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
// 读取文件内容
|
||||||
|
fileBytes, err := io.ReadAll(reader)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "读取文件内容失败", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Base64 编码
|
||||||
|
fileData := base64.StdEncoding.EncodeToString(fileBytes)
|
||||||
|
|
||||||
|
// 根据文件扩展名确定 fileType: 0=PDF, 1=图片
|
||||||
|
fileType := 1 // 默认为图片
|
||||||
|
lowerFileURL := strings.ToLower(fileURL)
|
||||||
|
if strings.HasSuffix(lowerFileURL, ".pdf") {
|
||||||
|
fileType = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建百度 OCR API 请求
|
||||||
|
baiduReq := BaiduOCRRequest{
|
||||||
|
File: fileData,
|
||||||
|
FileType: fileType,
|
||||||
|
UseDocOrientationClassify: false,
|
||||||
|
UseDocUnwarping: false,
|
||||||
|
UseChartRecognition: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(baiduReq)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "JSON编码失败", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
headers := map[string]string{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": fmt.Sprintf("token %s", config.GlobalConfig.BaiduOCR.Token),
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := "https://j5veh2l2r6ubk6cb.aistudio-app.com/layout-parsing"
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
log.Info(ctx, "func", "processBaiduOCRTask", "msg", "BaiduOCRApi_Start", "start_time", startTime)
|
||||||
|
|
||||||
|
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData), headers)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "百度 OCR API 请求失败", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
log.Info(ctx, "func", "processBaiduOCRTask", "msg", "BaiduOCRApi_End", "end_time", time.Now(), "duration", time.Since(startTime))
|
||||||
|
|
||||||
|
body := &bytes.Buffer{}
|
||||||
|
if _, err = body.ReadFrom(resp.Body); err != nil {
|
||||||
|
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "读取响应体失败", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建日志记录(不记录请求体中的 base64 数据以节省存储)
|
||||||
|
requestLogData := map[string]interface{}{
|
||||||
|
"fileType": fileType,
|
||||||
|
"useDocOrientationClassify": false,
|
||||||
|
"useDocUnwarping": false,
|
||||||
|
"useChartRecognition": false,
|
||||||
|
"fileSize": len(fileBytes),
|
||||||
|
}
|
||||||
|
requestLogBytes, _ := json.Marshal(requestLogData)
|
||||||
|
recognitionLog := &dao.RecognitionLog{
|
||||||
|
TaskID: taskID,
|
||||||
|
Provider: dao.ProviderBaiduOCR,
|
||||||
|
RequestBody: string(requestLogBytes),
|
||||||
|
ResponseBody: body.String(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析响应
|
||||||
|
var baiduResp BaiduOCRResponse
|
||||||
|
if err := json.Unmarshal(body.Bytes(), &baiduResp); err != nil {
|
||||||
|
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "解析响应失败", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查错误
|
||||||
|
if baiduResp.ErrorCode != 0 {
|
||||||
|
errMsg := fmt.Sprintf("errorCode: %d, errorMsg: %s", baiduResp.ErrorCode, baiduResp.ErrorMsg)
|
||||||
|
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "百度 OCR API 返回错误", "error", errMsg)
|
||||||
|
return fmt.Errorf("baidu ocr error: %s", errMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保存日志
|
||||||
|
err = logDao.Create(dao.DB.WithContext(ctx), recognitionLog)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "保存日志失败", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 合并所有页面的 markdown 结果
|
||||||
|
var markdownTexts []string
|
||||||
|
if baiduResp.Result != nil && len(baiduResp.Result.LayoutParsingResults) > 0 {
|
||||||
|
for _, res := range baiduResp.Result.LayoutParsingResults {
|
||||||
|
if res.Markdown.Text != "" {
|
||||||
|
markdownTexts = append(markdownTexts, res.Markdown.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
markdownResult := strings.Join(markdownTexts, "\n\n---\n\n")
|
||||||
|
|
||||||
|
latex, mml, e := s.HandleConvert(ctx, markdownResult)
|
||||||
|
if e != nil {
|
||||||
|
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "转换失败", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新或创建识别结果
|
||||||
|
resultDao := dao.NewRecognitionResultDao()
|
||||||
|
result, err := resultDao.GetByTaskID(dao.DB.WithContext(ctx), taskID)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "获取任务结果失败", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info(ctx, "func", "processBaiduOCRTask", "msg", "saveLog", "end_time", time.Now(), "duration", time.Since(startTime))
|
||||||
|
|
||||||
|
if result == nil {
|
||||||
|
// 创建新结果
|
||||||
|
err = resultDao.Create(dao.DB.WithContext(ctx), dao.RecognitionResult{
|
||||||
|
TaskID: taskID,
|
||||||
|
TaskType: dao.TaskTypeFormula,
|
||||||
|
Markdown: markdownResult,
|
||||||
|
Latex: latex,
|
||||||
|
MathML: mml,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "创建任务结果失败", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 更新现有结果
|
||||||
|
err = resultDao.Update(dao.DB.WithContext(ctx), result.ID, map[string]interface{}{
|
||||||
|
"markdown": markdownResult,
|
||||||
|
"latex": latex,
|
||||||
|
"mathml": mml,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "processBaiduOCRTask", "msg", "更新任务结果失败", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
isSuccess = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RecognitionService) TestProcessMathpixTask(ctx context.Context, taskID int64) error {
|
||||||
|
task, err := dao.NewRecognitionTaskDao().GetTaskByID(dao.DB.WithContext(ctx), taskID)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "TestProcessMathpixTask", "msg", "获取任务失败", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if task == nil {
|
||||||
|
log.Error(ctx, "func", "TestProcessMathpixTask", "msg", "任务不存在", "task_id", taskID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return s.processMathpixTask(ctx, taskID, task.FileURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertResponse is the JSON payload returned by the Python conversion
// endpoint that HandleConvert calls.
type ConvertResponse struct {
	// Latex is decoded from the "latex" key.
	Latex string `json:"latex"`
	// MathML is decoded from the "mathml" key.
	MathML string `json:"mathml"`
	// Error is decoded from the optional "error" key; HandleConvert treats a
	// non-empty value as a failed conversion.
	Error string `json:"error,omitempty"`
}
|
||||||
|
|
||||||
|
func (s *RecognitionService) HandleConvert(ctx context.Context, markdown string) (latex string, mml string, err error) {
|
||||||
|
url := "https://cloud.texpixel.com:10443/doc_converter/v1/convert"
|
||||||
|
|
||||||
|
// 构建 multipart form
|
||||||
|
body := &bytes.Buffer{}
|
||||||
|
writer := multipart.NewWriter(body)
|
||||||
|
_ = writer.WriteField("markdown_input", markdown)
|
||||||
|
writer.Close()
|
||||||
|
|
||||||
|
// 使用正确的 Content-Type(包含 boundary)
|
||||||
|
headers := map[string]string{
|
||||||
|
"Content-Type": writer.FormDataContentType(),
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, url, body, headers)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// 读取响应体
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查 HTTP 状态码
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return "", "", fmt.Errorf("convert failed: status %d, body: %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析 JSON 响应
|
||||||
|
var convertResp ConvertResponse
|
||||||
|
if err := json.Unmarshal(respBody, &convertResp); err != nil {
|
||||||
|
return "", "", fmt.Errorf("unmarshal response failed: %v, body: %s", err, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查业务错误
|
||||||
|
if convertResp.Error != "" {
|
||||||
|
return "", "", fmt.Errorf("convert error: %s", convertResp.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return convertResp.Latex, convertResp.MathML, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,27 +1,37 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gitea.com/bitwsd/document_ai/pkg/log"
|
|
||||||
"gitea.com/bitwsd/document_ai/internal/model/task"
|
"gitea.com/bitwsd/document_ai/internal/model/task"
|
||||||
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
||||||
"gorm.io/gorm"
|
"gitea.com/bitwsd/document_ai/pkg/log"
|
||||||
|
"gitea.com/bitwsd/document_ai/pkg/oss"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TaskService struct {
|
type TaskService struct {
|
||||||
db *gorm.DB
|
recognitionTaskDao *dao.RecognitionTaskDao
|
||||||
|
evaluateTaskDao *dao.EvaluateTaskDao
|
||||||
|
recognitionResultDao *dao.RecognitionResultDao
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTaskService() *TaskService {
|
func NewTaskService() *TaskService {
|
||||||
return &TaskService{dao.DB}
|
return &TaskService{
|
||||||
|
recognitionTaskDao: dao.NewRecognitionTaskDao(),
|
||||||
|
evaluateTaskDao: dao.NewEvaluateTaskDao(),
|
||||||
|
recognitionResultDao: dao.NewRecognitionResultDao(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (svc *TaskService) EvaluateTask(ctx context.Context, req *task.EvaluateTaskRequest) error {
|
func (svc *TaskService) EvaluateTask(ctx context.Context, req *task.EvaluateTaskRequest) error {
|
||||||
taskDao := dao.NewRecognitionTaskDao()
|
task, err := svc.recognitionTaskDao.GetByTaskNo(dao.DB.WithContext(ctx), req.TaskNo)
|
||||||
task, err := taskDao.GetByTaskNo(svc.db.WithContext(ctx), req.TaskNo)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(ctx, "func", "EvaluateTask", "msg", "get task by task no failed", "error", err)
|
log.Error(ctx, "func", "EvaluateTask", "msg", "get task by task no failed", "error", err)
|
||||||
return err
|
return err
|
||||||
@@ -36,14 +46,13 @@ func (svc *TaskService) EvaluateTask(ctx context.Context, req *task.EvaluateTask
|
|||||||
return errors.New("task not finished")
|
return errors.New("task not finished")
|
||||||
}
|
}
|
||||||
|
|
||||||
evaluateTaskDao := dao.NewEvaluateTaskDao()
|
|
||||||
evaluateTask := &dao.EvaluateTask{
|
evaluateTask := &dao.EvaluateTask{
|
||||||
TaskID: task.ID,
|
TaskID: task.ID,
|
||||||
Satisfied: req.Satisfied,
|
Satisfied: req.Satisfied,
|
||||||
Feedback: req.Feedback,
|
Feedback: req.Feedback,
|
||||||
Comment: strings.Join(req.Suggestion, ","),
|
Comment: strings.Join(req.Suggestion, ","),
|
||||||
}
|
}
|
||||||
err = evaluateTaskDao.Create(svc.db.WithContext(ctx), evaluateTask)
|
err = svc.evaluateTaskDao.Create(dao.DB.WithContext(ctx), evaluateTask)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(ctx, "func", "EvaluateTask", "msg", "create evaluate task failed", "error", err)
|
log.Error(ctx, "func", "EvaluateTask", "msg", "create evaluate task failed", "error", err)
|
||||||
return err
|
return err
|
||||||
@@ -53,26 +62,140 @@ func (svc *TaskService) EvaluateTask(ctx context.Context, req *task.EvaluateTask
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (svc *TaskService) GetTaskList(ctx context.Context, req *task.TaskListRequest) (*task.TaskListResponse, error) {
|
func (svc *TaskService) GetTaskList(ctx context.Context, req *task.TaskListRequest) (*task.TaskListResponse, error) {
|
||||||
taskDao := dao.NewRecognitionTaskDao()
|
tasks, total, err := svc.recognitionTaskDao.GetTaskList(dao.DB.WithContext(ctx), req.UserID, dao.TaskType(req.TaskType), req.Page, req.PageSize)
|
||||||
tasks, err := taskDao.GetTaskList(svc.db.WithContext(ctx), dao.TaskType(req.TaskType), req.Page, req.PageSize)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(ctx, "func", "GetTaskList", "msg", "get task list failed", "error", err)
|
log.Error(ctx, "func", "GetTaskList", "msg", "get task list failed", "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
taskIDs := make([]int64, 0, len(tasks))
|
||||||
|
for _, item := range tasks {
|
||||||
|
taskIDs = append(taskIDs, item.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
recognitionResults, err := svc.recognitionResultDao.GetByTaskIDs(dao.DB.WithContext(ctx), taskIDs)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "GetTaskList", "msg", "get recognition results failed", "error", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
recognitionResultMap := make(map[int64]*dao.RecognitionResult)
|
||||||
|
for _, item := range recognitionResults {
|
||||||
|
recognitionResultMap[item.TaskID] = item
|
||||||
|
}
|
||||||
|
|
||||||
resp := &task.TaskListResponse{
|
resp := &task.TaskListResponse{
|
||||||
TaskList: make([]*task.TaskListDTO, 0, len(tasks)),
|
TaskList: make([]*task.TaskListDTO, 0, len(tasks)),
|
||||||
HasMore: false,
|
Total: total,
|
||||||
NextPage: 0,
|
|
||||||
}
|
}
|
||||||
for _, item := range tasks {
|
for _, item := range tasks {
|
||||||
|
var latex string
|
||||||
|
var markdown string
|
||||||
|
var mathML string
|
||||||
|
recognitionResult := recognitionResultMap[item.ID]
|
||||||
|
if recognitionResult != nil {
|
||||||
|
latex = recognitionResult.Latex
|
||||||
|
markdown = recognitionResult.Markdown
|
||||||
|
mathML = recognitionResult.MathML
|
||||||
|
}
|
||||||
|
originURL, err := oss.GetDownloadURL(ctx, item.FileURL)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "GetTaskList", "msg", "get origin url failed", "error", err)
|
||||||
|
}
|
||||||
resp.TaskList = append(resp.TaskList, &task.TaskListDTO{
|
resp.TaskList = append(resp.TaskList, &task.TaskListDTO{
|
||||||
|
Latex: latex,
|
||||||
|
Markdown: markdown,
|
||||||
|
MathML: mathML,
|
||||||
TaskID: item.TaskUUID,
|
TaskID: item.TaskUUID,
|
||||||
FileName: item.FileName,
|
FileName: item.FileName,
|
||||||
Status: item.Status.String(),
|
Status: int(item.Status),
|
||||||
Path: item.FileURL,
|
OriginURL: originURL,
|
||||||
TaskType: item.TaskType.String(),
|
TaskType: item.TaskType.String(),
|
||||||
CreatedAt: item.CreatedAt.Format("2006-01-02 15:04:05"),
|
CreatedAt: item.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (svc *TaskService) ExportTask(ctx context.Context, req *task.ExportTaskRequest) ([]byte, string, error) {
|
||||||
|
recognitionTask, err := svc.recognitionTaskDao.GetByTaskNo(dao.DB.WithContext(ctx), req.TaskNo)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "ExportTask", "msg", "get task by task id failed", "error", err)
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if recognitionTask == nil {
|
||||||
|
log.Error(ctx, "func", "ExportTask", "msg", "task not found")
|
||||||
|
return nil, "", errors.New("task not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
if recognitionTask.Status != dao.TaskStatusCompleted {
|
||||||
|
log.Error(ctx, "func", "ExportTask", "msg", "task not finished")
|
||||||
|
return nil, "", errors.New("task not finished")
|
||||||
|
}
|
||||||
|
|
||||||
|
recognitionResult, err := svc.recognitionResultDao.GetByTaskID(dao.DB.WithContext(ctx), recognitionTask.ID)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "ExportTask", "msg", "get recognition result by task id failed", "error", err)
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if recognitionResult == nil {
|
||||||
|
log.Error(ctx, "func", "ExportTask", "msg", "recognition result not found")
|
||||||
|
return nil, "", errors.New("recognition result not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
markdown := recognitionResult.Markdown
|
||||||
|
if markdown == "" {
|
||||||
|
log.Error(ctx, "func", "ExportTask", "msg", "markdown not found")
|
||||||
|
return nil, "", errors.New("markdown not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取文件名(去掉扩展名)
|
||||||
|
filename := strings.TrimSuffix(recognitionTask.FileName, "."+strings.ToLower(strings.Split(recognitionTask.FileName, ".")[len(strings.Split(recognitionTask.FileName, "."))-1]))
|
||||||
|
if filename == "" {
|
||||||
|
filename = "texpixel"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建 JSON 请求体
|
||||||
|
requestBody := map[string]string{
|
||||||
|
"markdown": markdown,
|
||||||
|
"filename": filename,
|
||||||
|
}
|
||||||
|
jsonData, err := json.Marshal(requestBody)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "ExportTask", "msg", "json marshal failed", "error", err)
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://cloud.texpixel.com:10443/doc_process/v1/convert/file", bytes.NewReader(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "ExportTask", "msg", "create http request failed", "error", err)
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "ExportTask", "msg", "http request failed", "error", err)
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Error(ctx, "func", "ExportTask", "msg", "http request failed", "status", resp.StatusCode)
|
||||||
|
return nil, "", fmt.Errorf("export service returned status: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
fileData, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx, "func", "ExportTask", "msg", "read response body failed", "error", err)
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 新接口只返回 DOCX 格式
|
||||||
|
contentType := "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||||
|
|
||||||
|
return fileData, contentType, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,13 +6,16 @@ import (
|
|||||||
"gitea.com/bitwsd/document_ai/config"
|
"gitea.com/bitwsd/document_ai/config"
|
||||||
"gorm.io/driver/mysql"
|
"gorm.io/driver/mysql"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
var DB *gorm.DB
|
var DB *gorm.DB
|
||||||
|
|
||||||
func InitDB(conf config.DatabaseConfig) {
|
func InitDB(conf config.DatabaseConfig) {
|
||||||
dns := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Asia%%2FShanghai", conf.Username, conf.Password, conf.Host, conf.Port, conf.DBName)
|
dns := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Asia%%2FShanghai", conf.Username, conf.Password, conf.Host, conf.Port, conf.DBName)
|
||||||
db, err := gorm.Open(mysql.Open(dns), &gorm.Config{})
|
db, err := gorm.Open(mysql.Open(dns), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent), // 禁用 GORM 日志输出
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|||||||
53
internal/storage/dao/recognition_log.go
Normal file
53
internal/storage/dao/recognition_log.go
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
package dao
|
||||||
|
|
||||||
|
import (
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RecognitionLogProvider names the third-party service a recognition call was
// sent to.
type RecognitionLogProvider string

// Provider values stored in RecognitionLog.Provider.
const (
	ProviderMathpix     = RecognitionLogProvider("mathpix")
	ProviderSiliconflow = RecognitionLogProvider("siliconflow")
	ProviderTexpixel    = RecognitionLogProvider("texpixel")
	ProviderBaiduOCR    = RecognitionLogProvider("baidu_ocr")
)
|
||||||
|
|
||||||
|
// RecognitionLog 识别调用日志表,记录第三方API调用请求和响应
|
||||||
|
type RecognitionLog struct {
|
||||||
|
BaseModel
|
||||||
|
TaskID int64 `gorm:"column:task_id;bigint;not null;default:0;index;comment:关联任务ID" json:"task_id"`
|
||||||
|
Provider RecognitionLogProvider `gorm:"column:provider;varchar(32);not null;comment:服务提供商" json:"provider"`
|
||||||
|
RequestBody string `gorm:"column:request_body;type:longtext;comment:请求体" json:"request_body"`
|
||||||
|
ResponseBody string `gorm:"column:response_body;type:longtext;comment:响应体" json:"response_body"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (RecognitionLog) TableName() string {
|
||||||
|
return "recognition_log"
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecognitionLogDao groups the database operations on RecognitionLog rows.
// It holds no state; every method takes the *gorm.DB handle to run on.
type RecognitionLogDao struct{}

// NewRecognitionLogDao returns a new, empty RecognitionLogDao.
func NewRecognitionLogDao() *RecognitionLogDao {
	return new(RecognitionLogDao)
}
|
||||||
|
|
||||||
|
// Create 创建日志记录
|
||||||
|
func (d *RecognitionLogDao) Create(tx *gorm.DB, log *RecognitionLog) error {
|
||||||
|
return tx.Create(log).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByTaskID 根据任务ID获取日志
|
||||||
|
func (d *RecognitionLogDao) GetByTaskID(tx *gorm.DB, taskID int64) ([]*RecognitionLog, error) {
|
||||||
|
var logs []*RecognitionLog
|
||||||
|
err := tx.Where("task_id = ?", taskID).Order("created_at DESC").Find(&logs).Error
|
||||||
|
return logs, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByProvider 根据提供商获取日志
|
||||||
|
func (d *RecognitionLogDao) GetByProvider(tx *gorm.DB, provider RecognitionLogProvider, limit int) ([]*RecognitionLog, error) {
|
||||||
|
var logs []*RecognitionLog
|
||||||
|
err := tx.Where("provider = ?", provider).Order("created_at DESC").Limit(limit).Find(&logs).Error
|
||||||
|
return logs, err
|
||||||
|
}
|
||||||
@@ -1,66 +1,16 @@
|
|||||||
package dao
|
package dao
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type JSON []byte
|
|
||||||
|
|
||||||
// ContentCodec 定义内容编解码接口
|
|
||||||
type ContentCodec interface {
|
|
||||||
Encode() (JSON, error)
|
|
||||||
Decode() error
|
|
||||||
GetContent() interface{} // 更明确的方法名
|
|
||||||
}
|
|
||||||
|
|
||||||
type FormulaRecognitionContent struct {
|
|
||||||
content JSON
|
|
||||||
Latex string `json:"latex"`
|
|
||||||
AdjustLatex string `json:"adjust_latex"`
|
|
||||||
EnhanceLatex string `json:"enhance_latex"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *FormulaRecognitionContent) Encode() (JSON, error) {
|
|
||||||
b, err := json.Marshal(c)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return b, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *FormulaRecognitionContent) Decode() error {
|
|
||||||
return json.Unmarshal(c.content, c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPreferredContent 按优先级返回公式内容
|
|
||||||
func (c *FormulaRecognitionContent) GetContent() interface{} {
|
|
||||||
c.Decode()
|
|
||||||
if c.EnhanceLatex != "" {
|
|
||||||
return c.EnhanceLatex
|
|
||||||
} else if c.AdjustLatex != "" {
|
|
||||||
return c.AdjustLatex
|
|
||||||
} else {
|
|
||||||
return c.Latex
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type RecognitionResult struct {
|
type RecognitionResult struct {
|
||||||
BaseModel
|
BaseModel
|
||||||
TaskID int64 `gorm:"column:task_id;bigint;not null;default:0;comment:任务ID" json:"task_id"`
|
TaskID int64 `gorm:"column:task_id;bigint;not null;default:0;comment:任务ID" json:"task_id"`
|
||||||
TaskType TaskType `gorm:"column:task_type;varchar(16);not null;comment:任务类型;default:''" json:"task_type"`
|
TaskType TaskType `gorm:"column:task_type;varchar(16);not null;comment:任务类型;default:''" json:"task_type"`
|
||||||
Content JSON `gorm:"column:content;type:json;not null;comment:识别内容" json:"content"`
|
Latex string `json:"latex" gorm:"column:latex;type:text;not null;default:''"`
|
||||||
}
|
Markdown string `json:"markdown" gorm:"column:markdown;type:text;not null;default:''"` // Mathpix Markdown 格式
|
||||||
|
MathML string `json:"mathml" gorm:"column:mathml;type:text;not null;default:''"` // MathML 格式
|
||||||
// NewContentCodec 创建对应任务类型的内容编解码器
|
|
||||||
func (r *RecognitionResult) NewContentCodec() ContentCodec {
|
|
||||||
switch r.TaskType {
|
|
||||||
case TaskTypeFormula:
|
|
||||||
return &FormulaRecognitionContent{content: r.Content}
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type RecognitionResultDao struct {
|
type RecognitionResultDao struct {
|
||||||
@@ -84,6 +34,11 @@ func (dao *RecognitionResultDao) GetByTaskID(tx *gorm.DB, taskID int64) (result
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (dao *RecognitionResultDao) GetByTaskIDs(tx *gorm.DB, taskIDs []int64) (results []*RecognitionResult, err error) {
|
||||||
|
err = tx.Where("task_id IN (?)", taskIDs).Find(&results).Error
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (dao *RecognitionResultDao) Update(tx *gorm.DB, id int64, updates map[string]interface{}) error {
|
func (dao *RecognitionResultDao) Update(tx *gorm.DB, id int64, updates map[string]interface{}) error {
|
||||||
return tx.Model(&RecognitionResult{}).Where("id = ?", id).Updates(updates).Error
|
return tx.Model(&RecognitionResult{}).Where("id = ?", id).Updates(updates).Error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -69,9 +69,9 @@ func (dao *RecognitionTaskDao) GetByTaskNo(tx *gorm.DB, taskUUID string) (task *
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dao *RecognitionTaskDao) GetTaskByFileURL(tx *gorm.DB, userID int64, fileHash string) (task *RecognitionTask, err error) {
|
func (dao *RecognitionTaskDao) GetTaskByFileURL(tx *gorm.DB, fileHash string) (task *RecognitionTask, err error) {
|
||||||
task = &RecognitionTask{}
|
task = &RecognitionTask{}
|
||||||
err = tx.Model(RecognitionTask{}).Where("user_id = ? AND file_hash = ?", userID, fileHash).First(task).Error
|
err = tx.Model(RecognitionTask{}).Where("file_hash = ?", fileHash).Last(task).Error
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,8 +87,13 @@ func (dao *RecognitionTaskDao) GetTaskByID(tx *gorm.DB, id int64) (task *Recogni
|
|||||||
return task, nil
|
return task, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dao *RecognitionTaskDao) GetTaskList(tx *gorm.DB, taskType TaskType, page int, pageSize int) (tasks []*RecognitionTask, err error) {
|
func (dao *RecognitionTaskDao) GetTaskList(tx *gorm.DB, userID int64, taskType TaskType, page int, pageSize int) (tasks []*RecognitionTask, total int64, err error) {
|
||||||
offset := (page - 1) * pageSize
|
offset := (page - 1) * pageSize
|
||||||
err = tx.Model(RecognitionTask{}).Where("task_type = ?", taskType).Offset(offset).Limit(pageSize).Order(clause.OrderByColumn{Column: clause.Column{Name: "id"}, Desc: true}).Find(&tasks).Error
|
query := tx.Model(RecognitionTask{}).Where("user_id = ? AND task_type = ?", userID, taskType)
|
||||||
return
|
err = query.Count(&total).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
err = query.Offset(offset).Limit(pageSize).Order(clause.OrderByColumn{Column: clause.Column{Name: "id"}, Desc: true}).Find(&tasks).Error
|
||||||
|
return tasks, total, err
|
||||||
}
|
}
|
||||||
|
|||||||
21
main.go
21
main.go
@@ -10,23 +10,26 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gitea.com/bitwsd/document_ai/pkg/cors"
|
|
||||||
"gitea.com/bitwsd/document_ai/pkg/log"
|
|
||||||
"gitea.com/bitwsd/document_ai/pkg/middleware"
|
|
||||||
"gitea.com/bitwsd/document_ai/api"
|
"gitea.com/bitwsd/document_ai/api"
|
||||||
"gitea.com/bitwsd/document_ai/config"
|
"gitea.com/bitwsd/document_ai/config"
|
||||||
"gitea.com/bitwsd/document_ai/internal/storage/cache"
|
"gitea.com/bitwsd/document_ai/internal/storage/cache"
|
||||||
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
||||||
"gitea.com/bitwsd/document_ai/pkg/common"
|
"gitea.com/bitwsd/document_ai/pkg/common"
|
||||||
|
"gitea.com/bitwsd/document_ai/pkg/cors"
|
||||||
|
"gitea.com/bitwsd/document_ai/pkg/log"
|
||||||
|
"gitea.com/bitwsd/document_ai/pkg/middleware"
|
||||||
"gitea.com/bitwsd/document_ai/pkg/sms"
|
"gitea.com/bitwsd/document_ai/pkg/sms"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// 加载配置
|
// 加载配置
|
||||||
env := "dev"
|
env := ""
|
||||||
flag.StringVar(&env, "env", "dev", "environment (dev/prod)")
|
flag.StringVar(&env, "env", "dev", "environment (dev/prod)")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
|
fmt.Println("env:", env)
|
||||||
|
|
||||||
configPath := fmt.Sprintf("./config/config_%s.yaml", env)
|
configPath := fmt.Sprintf("./config/config_%s.yaml", env)
|
||||||
if err := config.Init(configPath); err != nil {
|
if err := config.Init(configPath); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
@@ -42,14 +45,6 @@ func main() {
|
|||||||
cache.InitRedisClient(config.GlobalConfig.Redis)
|
cache.InitRedisClient(config.GlobalConfig.Redis)
|
||||||
sms.InitSmsClient()
|
sms.InitSmsClient()
|
||||||
|
|
||||||
// 初始化Redis
|
|
||||||
// cache.InitRedis(config.GlobalConfig.Redis.Addr)
|
|
||||||
|
|
||||||
// 初始化OSS客户端
|
|
||||||
// if err := oss.InitOSS(config.GlobalConfig.OSS); err != nil {
|
|
||||||
// logger.Fatal("Failed to init OSS client", logger.Fields{"error": err})
|
|
||||||
// }
|
|
||||||
|
|
||||||
// 设置gin模式
|
// 设置gin模式
|
||||||
gin.SetMode(config.GlobalConfig.Server.Mode)
|
gin.SetMode(config.GlobalConfig.Server.Mode)
|
||||||
|
|
||||||
@@ -78,6 +73,6 @@ func main() {
|
|||||||
if err := srv.Shutdown(context.Background()); err != nil {
|
if err := srv.Shutdown(context.Background()); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
time.Sleep(time.Second * 3)
|
time.Sleep(time.Second * 5)
|
||||||
dao.CloseDB()
|
dao.CloseDB()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ const (
|
|||||||
CodeSuccess = 200
|
CodeSuccess = 200
|
||||||
CodeParamError = 400
|
CodeParamError = 400
|
||||||
CodeUnauthorized = 401
|
CodeUnauthorized = 401
|
||||||
|
CodeTokenExpired = 4011
|
||||||
CodeForbidden = 403
|
CodeForbidden = 403
|
||||||
CodeNotFound = 404
|
CodeNotFound = 404
|
||||||
CodeInvalidStatus = 405
|
CodeInvalidStatus = 405
|
||||||
@@ -23,6 +24,7 @@ const (
|
|||||||
CodeSuccessMsg = "success"
|
CodeSuccessMsg = "success"
|
||||||
CodeParamErrorMsg = "param error"
|
CodeParamErrorMsg = "param error"
|
||||||
CodeUnauthorizedMsg = "unauthorized"
|
CodeUnauthorizedMsg = "unauthorized"
|
||||||
|
CodeTokenExpiredMsg = "token expired"
|
||||||
CodeForbiddenMsg = "forbidden"
|
CodeForbiddenMsg = "forbidden"
|
||||||
CodeNotFoundMsg = "not found"
|
CodeNotFoundMsg = "not found"
|
||||||
CodeInvalidStatusMsg = "invalid status"
|
CodeInvalidStatusMsg = "invalid status"
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gitea.com/bitwsd/document_ai/pkg/constant"
|
"gitea.com/bitwsd/document_ai/pkg/constant"
|
||||||
"gitea.com/bitwsd/document_ai/pkg/jwt"
|
"gitea.com/bitwsd/document_ai/pkg/jwt"
|
||||||
@@ -45,6 +46,30 @@ func AuthMiddleware(ctx *gin.Context) {
|
|||||||
ctx.Set(constant.ContextUserID, claims.UserId)
|
ctx.Set(constant.ContextUserID, claims.UserId)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MustAuthMiddleware() gin.HandlerFunc {
|
||||||
|
return func(ctx *gin.Context) {
|
||||||
|
token := ctx.GetHeader("Authorization")
|
||||||
|
if token == "" {
|
||||||
|
ctx.JSON(http.StatusOK, ErrorResponse(ctx, CodeUnauthorized, CodeUnauthorizedMsg))
|
||||||
|
ctx.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
token = strings.TrimPrefix(token, "Bearer ")
|
||||||
|
claims, err := jwt.ParseToken(token)
|
||||||
|
if err != nil || claims == nil {
|
||||||
|
ctx.JSON(http.StatusOK, ErrorResponse(ctx, CodeUnauthorized, CodeUnauthorizedMsg))
|
||||||
|
ctx.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if claims.ExpiresAt < time.Now().Unix() {
|
||||||
|
ctx.JSON(http.StatusOK, ErrorResponse(ctx, CodeTokenExpired, CodeTokenExpiredMsg))
|
||||||
|
ctx.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx.Set(constant.ContextUserID, claims.UserId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func GetAuthMiddleware() gin.HandlerFunc {
|
func GetAuthMiddleware() gin.HandlerFunc {
|
||||||
return func(ctx *gin.Context) {
|
return func(ctx *gin.Context) {
|
||||||
token := ctx.GetHeader("Authorization")
|
token := ctx.GetHeader("Authorization")
|
||||||
|
|||||||
@@ -19,9 +19,9 @@ type Config struct {
|
|||||||
func DefaultConfig() Config {
|
func DefaultConfig() Config {
|
||||||
return Config{
|
return Config{
|
||||||
AllowOrigins: []string{"*"},
|
AllowOrigins: []string{"*"},
|
||||||
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"},
|
||||||
AllowHeaders: []string{"Origin", "Content-Type", "Accept"},
|
AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Requested-With"},
|
||||||
ExposeHeaders: []string{"Content-Length"},
|
ExposeHeaders: []string{"Content-Length", "Content-Type"},
|
||||||
AllowCredentials: true,
|
AllowCredentials: true,
|
||||||
MaxAge: 86400, // 24 hours
|
MaxAge: 86400, // 24 hours
|
||||||
}
|
}
|
||||||
@@ -30,16 +30,30 @@ func DefaultConfig() Config {
|
|||||||
func Cors(config Config) gin.HandlerFunc {
|
func Cors(config Config) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
origin := c.Request.Header.Get("Origin")
|
origin := c.Request.Header.Get("Origin")
|
||||||
|
if origin == "" {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 检查是否允许该来源
|
// 检查是否允许该来源
|
||||||
allowOrigin := "*"
|
allowOrigin := ""
|
||||||
for _, o := range config.AllowOrigins {
|
for _, o := range config.AllowOrigins {
|
||||||
|
if o == "*" {
|
||||||
|
// 通配符时,回显实际 origin(兼容 credentials)
|
||||||
|
allowOrigin = origin
|
||||||
|
break
|
||||||
|
}
|
||||||
if o == origin {
|
if o == origin {
|
||||||
allowOrigin = origin
|
allowOrigin = origin
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if allowOrigin == "" {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
c.Header("Access-Control-Allow-Origin", allowOrigin)
|
c.Header("Access-Control-Allow-Origin", allowOrigin)
|
||||||
c.Header("Access-Control-Allow-Methods", strings.Join(config.AllowMethods, ","))
|
c.Header("Access-Control-Allow-Methods", strings.Join(config.AllowMethods, ","))
|
||||||
c.Header("Access-Control-Allow-Headers", strings.Join(config.AllowHeaders, ","))
|
c.Header("Access-Control-Allow-Headers", strings.Join(config.AllowHeaders, ","))
|
||||||
|
|||||||
@@ -23,9 +23,9 @@ type RetryConfig struct {
|
|||||||
|
|
||||||
// DefaultRetryConfig 默认重试配置
|
// DefaultRetryConfig 默认重试配置
|
||||||
var DefaultRetryConfig = RetryConfig{
|
var DefaultRetryConfig = RetryConfig{
|
||||||
MaxRetries: 2,
|
MaxRetries: 1,
|
||||||
InitialInterval: 100 * time.Millisecond,
|
InitialInterval: 100 * time.Millisecond,
|
||||||
MaxInterval: 5 * time.Second,
|
MaxInterval: 30 * time.Second,
|
||||||
SkipTLSVerify: true,
|
SkipTLSVerify: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"gitea.com/bitwsd/document_ai/pkg/requestid"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"gopkg.in/natefinch/lumberjack.v2"
|
"gopkg.in/natefinch/lumberjack.v2"
|
||||||
)
|
)
|
||||||
@@ -67,8 +69,13 @@ func log(ctx context.Context, level zerolog.Level, logType LogType, kv ...interf
|
|||||||
// 添加日志类型
|
// 添加日志类型
|
||||||
event.Str("type", string(logType))
|
event.Str("type", string(logType))
|
||||||
|
|
||||||
// 添加请求ID
|
reqID := requestid.GetRequestID()
|
||||||
if reqID, exists := ctx.Value("request_id").(string); exists {
|
if reqID == "" {
|
||||||
|
if id, exists := ctx.Value("request_id").(string); exists {
|
||||||
|
reqID = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if reqID != "" {
|
||||||
event.Str("request_id", reqID)
|
event.Str("request_id", reqID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -149,4 +156,3 @@ func Fatal(ctx context.Context, kv ...interface{}) {
|
|||||||
func Access(ctx context.Context, kv ...interface{}) {
|
func Access(ctx context.Context, kv ...interface{}) {
|
||||||
log(ctx, zerolog.InfoLevel, TypeAccess, kv...)
|
log(ctx, zerolog.InfoLevel, TypeAccess, kv...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,19 +1,23 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"gitea.com/bitwsd/document_ai/pkg/requestid"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
func RequestID() gin.HandlerFunc {
|
func RequestID() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
requestID := c.Request.Header.Get("X-Request-ID")
|
reqID := c.Request.Header.Get("X-Request-ID")
|
||||||
if requestID == "" {
|
if reqID == "" {
|
||||||
requestID = uuid.New().String()
|
reqID = uuid.New().String()
|
||||||
}
|
}
|
||||||
c.Request.Header.Set("X-Request-ID", requestID)
|
c.Request.Header.Set("X-Request-ID", reqID)
|
||||||
c.Set("request_id", requestID)
|
c.Set("request_id", reqID)
|
||||||
c.Next()
|
|
||||||
|
requestid.SetRequestID(reqID, func() {
|
||||||
|
c.Next()
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -64,8 +64,7 @@ func GetPolicyToken() (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetPolicyURL(ctx context.Context, path string) (string, error) {
|
func GetPolicyURL(ctx context.Context, path string) (string, error) {
|
||||||
// Create OSS client
|
client, err := oss.New(config.GlobalConfig.Aliyun.OSS.Endpoint, config.GlobalConfig.Aliyun.OSS.AccessKeyID, config.GlobalConfig.Aliyun.OSS.AccessKeySecret, oss.UseCname(true))
|
||||||
client, err := oss.New(config.GlobalConfig.Aliyun.OSS.Endpoint, config.GlobalConfig.Aliyun.OSS.AccessKeyID, config.GlobalConfig.Aliyun.OSS.AccessKeySecret)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(ctx, "func", "GetPolicyURL", "msg", "create oss client failed", "error", err)
|
log.Error(ctx, "func", "GetPolicyURL", "msg", "create oss client failed", "error", err)
|
||||||
return "", err
|
return "", err
|
||||||
@@ -120,12 +119,16 @@ func GetPolicyURL(ctx context.Context, path string) (string, error) {
|
|||||||
// DownloadFile downloads a file from OSS and returns the reader, caller should close the reader
|
// DownloadFile downloads a file from OSS and returns the reader, caller should close the reader
|
||||||
func DownloadFile(ctx context.Context, ossPath string) (io.ReadCloser, error) {
|
func DownloadFile(ctx context.Context, ossPath string) (io.ReadCloser, error) {
|
||||||
endpoint := config.GlobalConfig.Aliyun.OSS.InnerEndpoint
|
endpoint := config.GlobalConfig.Aliyun.OSS.InnerEndpoint
|
||||||
|
useCname := false
|
||||||
if config.GlobalConfig.Server.IsDebug() {
|
if config.GlobalConfig.Server.IsDebug() {
|
||||||
endpoint = config.GlobalConfig.Aliyun.OSS.Endpoint
|
endpoint = config.GlobalConfig.Aliyun.OSS.Endpoint
|
||||||
|
useCname = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Info(ctx, "func", "DownloadFile", "msg", "endpoint", endpoint, "ossPath", ossPath)
|
||||||
|
|
||||||
// Create OSS client
|
// Create OSS client
|
||||||
client, err := oss.New(endpoint, config.GlobalConfig.Aliyun.OSS.AccessKeyID, config.GlobalConfig.Aliyun.OSS.AccessKeySecret)
|
client, err := oss.New(endpoint, config.GlobalConfig.Aliyun.OSS.AccessKeyID, config.GlobalConfig.Aliyun.OSS.AccessKeySecret, oss.UseCname(useCname))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(ctx, "func", "DownloadFile", "msg", "create oss client failed", "error", err)
|
log.Error(ctx, "func", "DownloadFile", "msg", "create oss client failed", "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -151,7 +154,7 @@ func DownloadFile(ctx context.Context, ossPath string) (io.ReadCloser, error) {
|
|||||||
func GetDownloadURL(ctx context.Context, ossPath string) (string, error) {
|
func GetDownloadURL(ctx context.Context, ossPath string) (string, error) {
|
||||||
endpoint := config.GlobalConfig.Aliyun.OSS.Endpoint
|
endpoint := config.GlobalConfig.Aliyun.OSS.Endpoint
|
||||||
|
|
||||||
client, err := oss.New(endpoint, config.GlobalConfig.Aliyun.OSS.AccessKeyID, config.GlobalConfig.Aliyun.OSS.AccessKeySecret)
|
client, err := oss.New(endpoint, config.GlobalConfig.Aliyun.OSS.AccessKeyID, config.GlobalConfig.Aliyun.OSS.AccessKeySecret, oss.UseCname(true))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(ctx, "func", "GetDownloadURL", "msg", "create oss client failed", "error", err)
|
log.Error(ctx, "func", "GetDownloadURL", "msg", "create oss client failed", "error", err)
|
||||||
return "", err
|
return "", err
|
||||||
@@ -163,11 +166,13 @@ func GetDownloadURL(ctx context.Context, ossPath string) (string, error) {
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
signURL, err := bucket.SignURL(ossPath, oss.HTTPGet, 60)
|
signURL, err := bucket.SignURL(ossPath, oss.HTTPGet, 3600)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(ctx, "func", "GetDownloadURL", "msg", "get object failed", "error", err)
|
log.Error(ctx, "func", "GetDownloadURL", "msg", "get object failed", "error", err)
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
signURL = strings.Replace(signURL, "http://", "https://", 1)
|
||||||
|
|
||||||
return signURL, nil
|
return signURL, nil
|
||||||
}
|
}
|
||||||
|
|||||||
27
pkg/requestid/requestid.go
Normal file
27
pkg/requestid/requestid.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package requestid
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/jtolds/gls"
|
||||||
|
)
|
||||||
|
|
||||||
|
// requestIDKey 是 gls 中存储 request_id 的 key
|
||||||
|
var requestIDKey = gls.GenSym()
|
||||||
|
|
||||||
|
// glsMgr 是 gls 管理器
|
||||||
|
var glsMgr = gls.NewContextManager()
|
||||||
|
|
||||||
|
// SetRequestID 在 gls 中设置 request_id,并在 fn 执行期间保持有效
|
||||||
|
func SetRequestID(requestID string, fn func()) {
|
||||||
|
glsMgr.SetValues(gls.Values{requestIDKey: requestID}, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRequestID 从 gls 中获取当前 goroutine 的 request_id
|
||||||
|
func GetRequestID() string {
|
||||||
|
val, ok := glsMgr.GetValue(requestIDKey)
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
reqID, _ := val.(string)
|
||||||
|
return reqID
|
||||||
|
}
|
||||||
|
|
||||||
10
prod_deploy.sh
Executable file
10
prod_deploy.sh
Executable file
@@ -0,0 +1,10 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
docker build -t crpi-8s2ierii2xan4klg.cn-beijing.personal.cr.aliyuncs.com/texpixel/doc_ai_backend:latest . && docker push crpi-8s2ierii2xan4klg.cn-beijing.personal.cr.aliyuncs.com/texpixel/doc_ai_backend:latest
|
||||||
|
|
||||||
|
ssh ecs << 'ENDSSH'
|
||||||
|
docker stop doc_ai doc_ai_backend 2>/dev/null || true
|
||||||
|
docker rm doc_ai doc_ai_backend 2>/dev/null || true
|
||||||
|
docker pull crpi-8s2ierii2xan4klg.cn-beijing.personal.cr.aliyuncs.com/texpixel/doc_ai_backend:latest
|
||||||
|
docker run -d --name doc_ai -p 8024:8024 --restart unless-stopped crpi-8s2ierii2xan4klg.cn-beijing.personal.cr.aliyuncs.com/texpixel/doc_ai_backend:latest -env=prod
|
||||||
|
ENDSSH
|
||||||
Reference in New Issue
Block a user