test #2
2
.gitignore
vendored
2
.gitignore
vendored
@@ -6,3 +6,5 @@
|
||||
/upload
|
||||
texpixel
|
||||
/vendor
|
||||
|
||||
dev_deploy.sh
|
||||
@@ -1,6 +1,5 @@
|
||||
# Build stage
|
||||
FROM crpi-8s2ierii2xan4klg.cn-beijing.personal.cr.aliyuncs.com/texpixel/golang:1.20-apline AS builder
|
||||
|
||||
FROM crpi-8s2ierii2xan4klg.cn-beijing.personal.cr.aliyuncs.com/texpixel/golang:1.20-alpine AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -43,4 +42,4 @@ EXPOSE 8024
|
||||
ENTRYPOINT ["./doc_ai"]
|
||||
|
||||
# Default command (can be overridden)
|
||||
CMD ["-env", "prod"]
|
||||
CMD ["-env", "prod"]
|
||||
|
||||
@@ -5,15 +5,46 @@ import (
|
||||
"gitea.com/bitwsd/document_ai/api/v1/oss"
|
||||
"gitea.com/bitwsd/document_ai/api/v1/task"
|
||||
"gitea.com/bitwsd/document_ai/api/v1/user"
|
||||
"gitea.com/bitwsd/document_ai/pkg/common"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func SetupRouter(engine *gin.RouterGroup) {
|
||||
v1 := engine.Group("/v1")
|
||||
{
|
||||
formula.SetupRouter(v1)
|
||||
oss.SetupRouter(v1)
|
||||
task.SetupRouter(v1)
|
||||
user.SetupRouter(v1)
|
||||
formulaRouter := v1.Group("/formula", common.GetAuthMiddleware())
|
||||
{
|
||||
endpoint := formula.NewFormulaEndpoint()
|
||||
formulaRouter.POST("/recognition", endpoint.CreateTask)
|
||||
formulaRouter.POST("/ai_enhance", endpoint.AIEnhanceRecognition)
|
||||
formulaRouter.GET("/recognition/:task_no", endpoint.GetTaskStatus)
|
||||
}
|
||||
|
||||
taskRouter := v1.Group("/task", common.GetAuthMiddleware())
|
||||
{
|
||||
endpoint := task.NewTaskEndpoint()
|
||||
taskRouter.POST("/evaluate", endpoint.EvaluateTask)
|
||||
taskRouter.GET("/list", common.MustAuthMiddleware(), endpoint.GetTaskList)
|
||||
}
|
||||
|
||||
ossRouter := v1.Group("/oss", common.GetAuthMiddleware())
|
||||
{
|
||||
endpoint := oss.NewOSSEndpoint()
|
||||
ossRouter.POST("/signature", endpoint.GetPostObjectSignature)
|
||||
ossRouter.POST("/signature_url", endpoint.GetSignatureURL)
|
||||
ossRouter.POST("/file/upload", endpoint.UploadFile)
|
||||
}
|
||||
|
||||
userRouter := v1.Group("/user", common.GetAuthMiddleware())
|
||||
{
|
||||
userEndpoint := user.NewUserEndpoint()
|
||||
{
|
||||
userRouter.POST("/sms", userEndpoint.SendVerificationCode)
|
||||
userRouter.POST("/register", userEndpoint.RegisterByEmail)
|
||||
userRouter.POST("/login", userEndpoint.LoginByEmail)
|
||||
userRouter.GET("/info", common.MustAuthMiddleware(), userEndpoint.GetUserInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"gitea.com/bitwsd/document_ai/internal/service"
|
||||
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
||||
"gitea.com/bitwsd/document_ai/pkg/common"
|
||||
"gitea.com/bitwsd/document_ai/pkg/constant"
|
||||
"gitea.com/bitwsd/document_ai/pkg/utils"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -37,11 +38,14 @@ func NewFormulaEndpoint() *FormulaEndpoint {
|
||||
// @Router /v1/formula/recognition [post]
|
||||
func (endpoint *FormulaEndpoint) CreateTask(ctx *gin.Context) {
|
||||
var req formula.CreateFormulaRecognitionRequest
|
||||
uid := ctx.GetInt64(constant.ContextUserID)
|
||||
if err := ctx.BindJSON(&req); err != nil {
|
||||
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeParamError, "Invalid parameters"))
|
||||
return
|
||||
}
|
||||
|
||||
req.UserID = uid
|
||||
|
||||
if !utils.InArray(req.TaskType, []string{string(dao.TaskTypeFormula), string(dao.TaskTypeFormula)}) {
|
||||
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeParamError, "Invalid task type"))
|
||||
return
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
package formula
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func SetupRouter(engine *gin.RouterGroup) {
|
||||
endpoint := NewFormulaEndpoint()
|
||||
engine.POST("/formula/recognition", endpoint.CreateTask)
|
||||
engine.POST("/formula/ai_enhance", endpoint.AIEnhanceRecognition)
|
||||
engine.GET("/formula/recognition/:task_no", endpoint.GetTaskStatus)
|
||||
}
|
||||
@@ -11,14 +11,20 @@ import (
|
||||
"gitea.com/bitwsd/document_ai/config"
|
||||
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
||||
"gitea.com/bitwsd/document_ai/pkg/common"
|
||||
"gitea.com/bitwsd/document_ai/pkg/constant"
|
||||
"gitea.com/bitwsd/document_ai/pkg/oss"
|
||||
"gitea.com/bitwsd/document_ai/pkg/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func GetPostObjectSignature(ctx *gin.Context) {
|
||||
type OSSEndpoint struct {
|
||||
}
|
||||
|
||||
func NewOSSEndpoint() *OSSEndpoint {
|
||||
return &OSSEndpoint{}
|
||||
}
|
||||
|
||||
func (h *OSSEndpoint) GetPostObjectSignature(ctx *gin.Context) {
|
||||
policyToken, err := oss.GetPolicyToken()
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeSystemError, err.Error()))
|
||||
@@ -37,8 +43,7 @@ func GetPostObjectSignature(ctx *gin.Context) {
|
||||
// @Success 200 {object} common.Response{data=map[string]string{"sign_url":string, "repeat":bool, "path":string}} "Signed URL generated successfully"
|
||||
// @Failure 200 {object} common.Response "Error response"
|
||||
// @Router /signature_url [get]
|
||||
func GetSignatureURL(ctx *gin.Context) {
|
||||
userID := ctx.GetInt64(constant.ContextUserID)
|
||||
func (h *OSSEndpoint) GetSignatureURL(ctx *gin.Context) {
|
||||
type Req struct {
|
||||
FileHash string `json:"file_hash" binding:"required"`
|
||||
FileName string `json:"file_name" binding:"required"`
|
||||
@@ -51,7 +56,7 @@ func GetSignatureURL(ctx *gin.Context) {
|
||||
}
|
||||
taskDao := dao.NewRecognitionTaskDao()
|
||||
sess := dao.DB.WithContext(ctx)
|
||||
task, err := taskDao.GetTaskByFileURL(sess, userID, req.FileHash)
|
||||
task, err := taskDao.GetTaskByFileURL(sess, req.FileHash)
|
||||
if err != nil && err != gorm.ErrRecordNotFound {
|
||||
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeDBError, "failed to get task"))
|
||||
return
|
||||
@@ -78,7 +83,7 @@ func GetSignatureURL(ctx *gin.Context) {
|
||||
ctx.JSON(http.StatusOK, common.SuccessResponse(ctx, gin.H{"sign_url": url, "repeat": false, "path": path}))
|
||||
}
|
||||
|
||||
func UploadFile(ctx *gin.Context) {
|
||||
func (h *OSSEndpoint) UploadFile(ctx *gin.Context) {
|
||||
if err := os.MkdirAll(config.GlobalConfig.UploadDir, 0755); err != nil {
|
||||
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeSystemError, "Failed to create upload directory"))
|
||||
return
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
package oss
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
func SetupRouter(parent *gin.RouterGroup) {
|
||||
router := parent.Group("oss")
|
||||
{
|
||||
router.POST("/signature", GetPostObjectSignature)
|
||||
router.POST("/signature_url", GetSignatureURL)
|
||||
router.POST("/file/upload", UploadFile)
|
||||
}
|
||||
}
|
||||
@@ -3,10 +3,10 @@ package task
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"gitea.com/bitwsd/document_ai/pkg/log"
|
||||
"gitea.com/bitwsd/document_ai/internal/model/task"
|
||||
"gitea.com/bitwsd/document_ai/internal/service"
|
||||
"gitea.com/bitwsd/document_ai/pkg/common"
|
||||
"gitea.com/bitwsd/document_ai/pkg/log"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -43,6 +43,8 @@ func (h *TaskEndpoint) GetTaskList(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
req.UserID = common.GetUserIDFromContext(c)
|
||||
|
||||
if req.Page <= 0 {
|
||||
req.Page = 1
|
||||
}
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
package task
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func SetupRouter(engine *gin.RouterGroup) {
|
||||
endpoint := NewTaskEndpoint()
|
||||
engine.POST("/task/evaluate", endpoint.EvaluateTask)
|
||||
engine.GET("/task/list", endpoint.GetTaskList)
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"gitea.com/bitwsd/document_ai/pkg/common"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func SetupRouter(router *gin.RouterGroup) {
|
||||
userEndpoint := NewUserEndpoint()
|
||||
userRouter := router.Group("/user")
|
||||
{
|
||||
userRouter.POST("/get/sms", userEndpoint.SendVerificationCode)
|
||||
userRouter.POST("/login/phone", userEndpoint.LoginByPhoneCode)
|
||||
userRouter.POST("/register/email", userEndpoint.RegisterByEmail)
|
||||
userRouter.POST("/login/email", userEndpoint.LoginByEmail)
|
||||
userRouter.GET("/info", common.GetAuthMiddleware(), userEndpoint.GetUserInfo)
|
||||
}
|
||||
}
|
||||
@@ -4,16 +4,16 @@ server:
|
||||
|
||||
database:
|
||||
driver: mysql
|
||||
host: 182.92.150.161
|
||||
port: 3006
|
||||
host: mysql
|
||||
port: 3306
|
||||
username: root
|
||||
password: yoge@coder%%%123321!
|
||||
password: texpixel#pwd123!
|
||||
dbname: doc_ai
|
||||
max_idle: 10
|
||||
max_open: 100
|
||||
|
||||
redis:
|
||||
addr: 182.92.150.161:6379
|
||||
addr: redis:6379
|
||||
password: yoge@123321!
|
||||
db: 0
|
||||
|
||||
@@ -22,7 +22,7 @@ limit:
|
||||
|
||||
log:
|
||||
appName: document_ai
|
||||
level: info # debug, info, warn, error
|
||||
level: info
|
||||
format: console # json, console
|
||||
outputPath: ./logs/app.log # 日志文件路径
|
||||
maxSize: 2 # 单个日志文件最大尺寸,单位MB
|
||||
@@ -39,8 +39,8 @@ aliyun:
|
||||
template_code: "SMS_291510729"
|
||||
|
||||
oss:
|
||||
endpoint: oss-cn-beijing.aliyuncs.com
|
||||
endpoint: static.texpixel.com
|
||||
inner_endpoint: oss-cn-beijing-internal.aliyuncs.com
|
||||
access_key_id: LTAI5tKogxeiBb4gJGWEePWN
|
||||
access_key_secret: l4oCxtt5iLSQ1DAs40guTzKUfrxXwq
|
||||
bucket_name: bitwsd-doc-ai
|
||||
access_key_id: LTAI5t8qXhow6NCdYDtu1saF
|
||||
access_key_secret: qZ2SwYsNCEBckCVSOszH31yYwXU44A
|
||||
bucket_name: texpixel-doc
|
||||
|
||||
@@ -18,7 +18,7 @@ redis:
|
||||
db: 0
|
||||
|
||||
limit:
|
||||
formula_recognition: 2
|
||||
formula_recognition: 10
|
||||
|
||||
log:
|
||||
appName: document_ai
|
||||
@@ -38,7 +38,7 @@ aliyun:
|
||||
template_code: "SMS_291510729"
|
||||
|
||||
oss:
|
||||
endpoint: oss-cn-beijing.aliyuncs.com
|
||||
endpoint: static.texpixel.com
|
||||
inner_endpoint: oss-cn-beijing-internal.aliyuncs.com
|
||||
access_key_id: LTAI5t8qXhow6NCdYDtu1saF
|
||||
access_key_secret: qZ2SwYsNCEBckCVSOszH31yYwXU44A
|
||||
|
||||
@@ -1,27 +1,50 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
doc_ai:
|
||||
build: .
|
||||
container_name: doc_ai
|
||||
ports:
|
||||
- "8024:8024"
|
||||
volumes:
|
||||
- ./config:/app/config
|
||||
- ./logs:/app/logs
|
||||
networks:
|
||||
- backend
|
||||
depends_on:
|
||||
mysql:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_started
|
||||
command: ["-env", "dev"]
|
||||
restart: always
|
||||
|
||||
mysql:
|
||||
image: mysql:8.0
|
||||
container_name: mysql
|
||||
environment:
|
||||
MYSQL_ROOT_PASSWORD: 123456 # 设置root用户密码
|
||||
MYSQL_DATABASE: document_ai # 设置默认数据库名
|
||||
MYSQL_USER: bitwsd_document # 设置数据库用户名
|
||||
MYSQL_PASSWORD: 123456 # 设置数据库用户密码
|
||||
MYSQL_ROOT_PASSWORD: texpixel#pwd123!
|
||||
MYSQL_DATABASE: doc_ai
|
||||
MYSQL_USER: texpixel
|
||||
MYSQL_PASSWORD: texpixel#pwd123!
|
||||
ports:
|
||||
- "3306:3306" # 映射宿主机的3306端口到容器内的3306
|
||||
- "3006:3306"
|
||||
volumes:
|
||||
- mysql_data:/var/lib/mysql # 持久化MySQL数据
|
||||
- mysql_data:/var/lib/mysql
|
||||
networks:
|
||||
- backend
|
||||
healthcheck:
|
||||
test: ["CMD", "mysqladmin", "ping", "-h", "localhost", "-uroot", "-ptexpixel#pwd123!"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
start_period: 30s
|
||||
restart: always
|
||||
|
||||
redis:
|
||||
image: redis:latest
|
||||
container_name: redis
|
||||
command: redis-server --requirepass "yoge@123321!"
|
||||
ports:
|
||||
- "6379:6379" # 映射宿主机的6379端口到容器内的6379
|
||||
- "6079:6379"
|
||||
networks:
|
||||
- backend
|
||||
restart: always
|
||||
|
||||
2
go.mod
2
go.mod
@@ -14,6 +14,7 @@ require (
|
||||
github.com/redis/go-redis/v9 v9.7.0
|
||||
github.com/rs/zerolog v1.33.0
|
||||
github.com/spf13/viper v1.19.0
|
||||
golang.org/x/crypto v0.23.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gorm.io/driver/mysql v1.5.7
|
||||
gorm.io/gorm v1.25.12
|
||||
@@ -68,7 +69,6 @@ require (
|
||||
go.uber.org/atomic v1.9.0 // indirect
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.23.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
|
||||
golang.org/x/net v0.25.0 // indirect
|
||||
golang.org/x/sys v0.20.0 // indirect
|
||||
|
||||
@@ -5,6 +5,7 @@ type CreateFormulaRecognitionRequest struct {
|
||||
FileHash string `json:"file_hash" binding:"required"` // file hash
|
||||
FileName string `json:"file_name" binding:"required"` // file name
|
||||
TaskType string `json:"task_type" binding:"required,oneof=FORMULA"` // task type
|
||||
UserID int64 `json:"user_id"` // user id
|
||||
}
|
||||
|
||||
type GetRecognitionStatusRequest struct {
|
||||
|
||||
@@ -11,26 +11,26 @@ type TaskListRequest struct {
|
||||
TaskType string `json:"task_type" form:"task_type" binding:"required"`
|
||||
Page int `json:"page" form:"page"`
|
||||
PageSize int `json:"page_size" form:"page_size"`
|
||||
}
|
||||
|
||||
type PdfInfo struct {
|
||||
PageCount int `json:"page_count"`
|
||||
PageWidth int `json:"page_width"`
|
||||
PageHeight int `json:"page_height"`
|
||||
UserID int64 `json:"-"`
|
||||
}
|
||||
|
||||
type TaskListDTO struct {
|
||||
TaskID string `json:"task_id"`
|
||||
FileName string `json:"file_name"`
|
||||
Status string `json:"status"`
|
||||
Path string `json:"path"`
|
||||
TaskType string `json:"task_type"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
PdfInfo PdfInfo `json:"pdf_info"`
|
||||
TaskID string `json:"task_id"`
|
||||
FileName string `json:"file_name"`
|
||||
Status int `json:"status"`
|
||||
OriginURL string `json:"origin_url"`
|
||||
TaskType string `json:"task_type"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Latex string `json:"latex"`
|
||||
Markdown string `json:"markdown"`
|
||||
MathML string `json:"mathml"`
|
||||
MathMLMW string `json:"mathml_mw"`
|
||||
ImageBlob string `json:"image_blob"`
|
||||
DocxURL string `json:"docx_url"`
|
||||
PDFURL string `json:"pdf_url"`
|
||||
}
|
||||
|
||||
type TaskListResponse struct {
|
||||
TaskList []*TaskListDTO `json:"task_list"`
|
||||
HasMore bool `json:"has_more"`
|
||||
NextPage int `json:"next_page"`
|
||||
Total int64 `json:"total"`
|
||||
}
|
||||
|
||||
@@ -105,6 +105,7 @@ func (s *RecognitionService) CreateRecognitionTask(ctx context.Context, req *for
|
||||
sess := dao.DB.WithContext(ctx)
|
||||
taskDao := dao.NewRecognitionTaskDao()
|
||||
task := &dao.RecognitionTask{
|
||||
UserID: req.UserID,
|
||||
TaskUUID: utils.NewUUID(),
|
||||
TaskType: dao.TaskType(req.TaskType),
|
||||
Status: dao.TaskStatusPending,
|
||||
@@ -166,7 +167,8 @@ func (s *RecognitionService) GetFormualTask(ctx context.Context, taskNo string)
|
||||
return nil, common.NewError(common.CodeDBError, "查询任务结果失败", err)
|
||||
}
|
||||
latex := taskRet.NewContentCodec().GetContent().(string)
|
||||
return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Latex: latex, Status: int(task.Status)}, nil
|
||||
markdown := fmt.Sprintf("$$%s$$", latex)
|
||||
return &formula.GetFormulaTaskResponse{TaskNo: taskNo, Latex: latex, Markdown: markdown, Status: int(task.Status)}, nil
|
||||
}
|
||||
|
||||
func (s *RecognitionService) handleFormulaRecognition(ctx context.Context, taskID int64) error {
|
||||
@@ -281,7 +283,7 @@ func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int6
|
||||
headers := map[string]string{"Content-Type": "application/json", utils.RequestIDHeaderKey: utils.GetRequestIDFromContext(ctx)}
|
||||
|
||||
// 发送请求时会使用带超时的context
|
||||
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "http://cloud.texpixel.com:1080/formula/predict", bytes.NewReader(jsonData), headers)
|
||||
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "https://cloud.texpixel.com:10443/vlm/formula/predict", bytes.NewReader(jsonData), headers)
|
||||
if err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
log.Error(ctx, "func", "processFormulaTask", "msg", "请求超时")
|
||||
@@ -306,7 +308,8 @@ func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int6
|
||||
log.Error(ctx, "func", "processFormulaTask", "msg", "解析响应JSON失败", "error", err)
|
||||
return err
|
||||
}
|
||||
katex := utils.ToKatex(formulaResp.Result)
|
||||
// katex := utils.ToKatex(formulaResp.Result)
|
||||
katex := formulaResp.Result
|
||||
content := &dao.FormulaRecognitionContent{Latex: katex}
|
||||
b, _ := json.Marshal(content)
|
||||
// Save recognition result
|
||||
|
||||
@@ -3,25 +3,31 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gitea.com/bitwsd/document_ai/pkg/log"
|
||||
"gitea.com/bitwsd/document_ai/internal/model/task"
|
||||
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
||||
"gorm.io/gorm"
|
||||
"gitea.com/bitwsd/document_ai/pkg/log"
|
||||
"gitea.com/bitwsd/document_ai/pkg/oss"
|
||||
)
|
||||
|
||||
type TaskService struct {
|
||||
db *gorm.DB
|
||||
recognitionTaskDao *dao.RecognitionTaskDao
|
||||
evaluateTaskDao *dao.EvaluateTaskDao
|
||||
recognitionResultDao *dao.RecognitionResultDao
|
||||
}
|
||||
|
||||
func NewTaskService() *TaskService {
|
||||
return &TaskService{dao.DB}
|
||||
return &TaskService{
|
||||
recognitionTaskDao: dao.NewRecognitionTaskDao(),
|
||||
evaluateTaskDao: dao.NewEvaluateTaskDao(),
|
||||
recognitionResultDao: dao.NewRecognitionResultDao(),
|
||||
}
|
||||
}
|
||||
|
||||
func (svc *TaskService) EvaluateTask(ctx context.Context, req *task.EvaluateTaskRequest) error {
|
||||
taskDao := dao.NewRecognitionTaskDao()
|
||||
task, err := taskDao.GetByTaskNo(svc.db.WithContext(ctx), req.TaskNo)
|
||||
task, err := svc.recognitionTaskDao.GetByTaskNo(dao.DB.WithContext(ctx), req.TaskNo)
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "EvaluateTask", "msg", "get task by task no failed", "error", err)
|
||||
return err
|
||||
@@ -36,14 +42,13 @@ func (svc *TaskService) EvaluateTask(ctx context.Context, req *task.EvaluateTask
|
||||
return errors.New("task not finished")
|
||||
}
|
||||
|
||||
evaluateTaskDao := dao.NewEvaluateTaskDao()
|
||||
evaluateTask := &dao.EvaluateTask{
|
||||
TaskID: task.ID,
|
||||
Satisfied: req.Satisfied,
|
||||
Feedback: req.Feedback,
|
||||
Comment: strings.Join(req.Suggestion, ","),
|
||||
}
|
||||
err = evaluateTaskDao.Create(svc.db.WithContext(ctx), evaluateTask)
|
||||
err = svc.evaluateTaskDao.Create(dao.DB.WithContext(ctx), evaluateTask)
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "EvaluateTask", "msg", "create evaluate task failed", "error", err)
|
||||
return err
|
||||
@@ -53,23 +58,51 @@ func (svc *TaskService) EvaluateTask(ctx context.Context, req *task.EvaluateTask
|
||||
}
|
||||
|
||||
func (svc *TaskService) GetTaskList(ctx context.Context, req *task.TaskListRequest) (*task.TaskListResponse, error) {
|
||||
taskDao := dao.NewRecognitionTaskDao()
|
||||
tasks, err := taskDao.GetTaskList(svc.db.WithContext(ctx), dao.TaskType(req.TaskType), req.Page, req.PageSize)
|
||||
tasks, total, err := svc.recognitionTaskDao.GetTaskList(dao.DB.WithContext(ctx), req.UserID, dao.TaskType(req.TaskType), req.Page, req.PageSize)
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "GetTaskList", "msg", "get task list failed", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
taskIDs := make([]int64, 0, len(tasks))
|
||||
for _, item := range tasks {
|
||||
taskIDs = append(taskIDs, item.ID)
|
||||
}
|
||||
|
||||
recognitionResults, err := svc.recognitionResultDao.GetByTaskIDs(dao.DB.WithContext(ctx), taskIDs)
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "GetTaskList", "msg", "get recognition results failed", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
recognitionResultMap := make(map[int64]*dao.RecognitionResult)
|
||||
for _, item := range recognitionResults {
|
||||
recognitionResultMap[item.TaskID] = item
|
||||
}
|
||||
|
||||
resp := &task.TaskListResponse{
|
||||
TaskList: make([]*task.TaskListDTO, 0, len(tasks)),
|
||||
HasMore: false,
|
||||
NextPage: 0,
|
||||
Total: total,
|
||||
}
|
||||
for _, item := range tasks {
|
||||
var latex string
|
||||
var markdown string
|
||||
recognitionResult := recognitionResultMap[item.ID]
|
||||
if recognitionResult != nil {
|
||||
latex = recognitionResult.NewContentCodec().GetContent().(string)
|
||||
markdown = fmt.Sprintf("$$%s$$", latex)
|
||||
}
|
||||
originURL, err := oss.GetDownloadURL(ctx, item.FileURL)
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "GetTaskList", "msg", "get origin url failed", "error", err)
|
||||
}
|
||||
resp.TaskList = append(resp.TaskList, &task.TaskListDTO{
|
||||
Latex: latex,
|
||||
Markdown: markdown,
|
||||
TaskID: item.TaskUUID,
|
||||
FileName: item.FileName,
|
||||
Status: item.Status.String(),
|
||||
Path: item.FileURL,
|
||||
Status: int(item.Status),
|
||||
OriginURL: originURL,
|
||||
TaskType: item.TaskType.String(),
|
||||
CreatedAt: item.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||
})
|
||||
|
||||
@@ -84,6 +84,11 @@ func (dao *RecognitionResultDao) GetByTaskID(tx *gorm.DB, taskID int64) (result
|
||||
return
|
||||
}
|
||||
|
||||
func (dao *RecognitionResultDao) GetByTaskIDs(tx *gorm.DB, taskIDs []int64) (results []*RecognitionResult, err error) {
|
||||
err = tx.Where("task_id IN (?)", taskIDs).Find(&results).Error
|
||||
return
|
||||
}
|
||||
|
||||
func (dao *RecognitionResultDao) Update(tx *gorm.DB, id int64, updates map[string]interface{}) error {
|
||||
return tx.Model(&RecognitionResult{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
@@ -69,9 +69,9 @@ func (dao *RecognitionTaskDao) GetByTaskNo(tx *gorm.DB, taskUUID string) (task *
|
||||
return
|
||||
}
|
||||
|
||||
func (dao *RecognitionTaskDao) GetTaskByFileURL(tx *gorm.DB, userID int64, fileHash string) (task *RecognitionTask, err error) {
|
||||
func (dao *RecognitionTaskDao) GetTaskByFileURL(tx *gorm.DB, fileHash string) (task *RecognitionTask, err error) {
|
||||
task = &RecognitionTask{}
|
||||
err = tx.Model(RecognitionTask{}).Where("user_id = ? AND file_hash = ?", userID, fileHash).First(task).Error
|
||||
err = tx.Model(RecognitionTask{}).Where("file_hash = ?", fileHash).Last(task).Error
|
||||
return
|
||||
}
|
||||
|
||||
@@ -87,8 +87,13 @@ func (dao *RecognitionTaskDao) GetTaskByID(tx *gorm.DB, id int64) (task *Recogni
|
||||
return task, nil
|
||||
}
|
||||
|
||||
func (dao *RecognitionTaskDao) GetTaskList(tx *gorm.DB, taskType TaskType, page int, pageSize int) (tasks []*RecognitionTask, err error) {
|
||||
func (dao *RecognitionTaskDao) GetTaskList(tx *gorm.DB, userID int64, taskType TaskType, page int, pageSize int) (tasks []*RecognitionTask, total int64, err error) {
|
||||
offset := (page - 1) * pageSize
|
||||
err = tx.Model(RecognitionTask{}).Where("task_type = ?", taskType).Offset(offset).Limit(pageSize).Order(clause.OrderByColumn{Column: clause.Column{Name: "id"}, Desc: true}).Find(&tasks).Error
|
||||
return
|
||||
query := tx.Model(RecognitionTask{}).Where("user_id = ? AND task_type = ?", userID, taskType)
|
||||
err = query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
err = query.Offset(offset).Limit(pageSize).Order(clause.OrderByColumn{Column: clause.Column{Name: "id"}, Desc: true}).Find(&tasks).Error
|
||||
return tasks, total, err
|
||||
}
|
||||
|
||||
16
main.go
16
main.go
@@ -10,14 +10,14 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"gitea.com/bitwsd/document_ai/pkg/cors"
|
||||
"gitea.com/bitwsd/document_ai/pkg/log"
|
||||
"gitea.com/bitwsd/document_ai/pkg/middleware"
|
||||
"gitea.com/bitwsd/document_ai/api"
|
||||
"gitea.com/bitwsd/document_ai/config"
|
||||
"gitea.com/bitwsd/document_ai/internal/storage/cache"
|
||||
"gitea.com/bitwsd/document_ai/internal/storage/dao"
|
||||
"gitea.com/bitwsd/document_ai/pkg/common"
|
||||
"gitea.com/bitwsd/document_ai/pkg/cors"
|
||||
"gitea.com/bitwsd/document_ai/pkg/log"
|
||||
"gitea.com/bitwsd/document_ai/pkg/middleware"
|
||||
"gitea.com/bitwsd/document_ai/pkg/sms"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -42,14 +42,6 @@ func main() {
|
||||
cache.InitRedisClient(config.GlobalConfig.Redis)
|
||||
sms.InitSmsClient()
|
||||
|
||||
// 初始化Redis
|
||||
// cache.InitRedis(config.GlobalConfig.Redis.Addr)
|
||||
|
||||
// 初始化OSS客户端
|
||||
// if err := oss.InitOSS(config.GlobalConfig.OSS); err != nil {
|
||||
// logger.Fatal("Failed to init OSS client", logger.Fields{"error": err})
|
||||
// }
|
||||
|
||||
// 设置gin模式
|
||||
gin.SetMode(config.GlobalConfig.Server.Mode)
|
||||
|
||||
@@ -78,6 +70,6 @@ func main() {
|
||||
if err := srv.Shutdown(context.Background()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
time.Sleep(time.Second * 3)
|
||||
time.Sleep(time.Second * 5)
|
||||
dao.CloseDB()
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ const (
|
||||
CodeSuccess = 200
|
||||
CodeParamError = 400
|
||||
CodeUnauthorized = 401
|
||||
CodeTokenExpired = 4011
|
||||
CodeForbidden = 403
|
||||
CodeNotFound = 404
|
||||
CodeInvalidStatus = 405
|
||||
@@ -23,6 +24,7 @@ const (
|
||||
CodeSuccessMsg = "success"
|
||||
CodeParamErrorMsg = "param error"
|
||||
CodeUnauthorizedMsg = "unauthorized"
|
||||
CodeTokenExpiredMsg = "token expired"
|
||||
CodeForbiddenMsg = "forbidden"
|
||||
CodeNotFoundMsg = "not found"
|
||||
CodeInvalidStatusMsg = "invalid status"
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitea.com/bitwsd/document_ai/pkg/constant"
|
||||
"gitea.com/bitwsd/document_ai/pkg/jwt"
|
||||
@@ -45,6 +46,30 @@ func AuthMiddleware(ctx *gin.Context) {
|
||||
ctx.Set(constant.ContextUserID, claims.UserId)
|
||||
}
|
||||
|
||||
func MustAuthMiddleware() gin.HandlerFunc {
|
||||
return func(ctx *gin.Context) {
|
||||
token := ctx.GetHeader("Authorization")
|
||||
if token == "" {
|
||||
ctx.JSON(http.StatusOK, ErrorResponse(ctx, CodeUnauthorized, CodeUnauthorizedMsg))
|
||||
ctx.Abort()
|
||||
return
|
||||
}
|
||||
token = strings.TrimPrefix(token, "Bearer ")
|
||||
claims, err := jwt.ParseToken(token)
|
||||
if err != nil || claims == nil {
|
||||
ctx.JSON(http.StatusOK, ErrorResponse(ctx, CodeUnauthorized, CodeUnauthorizedMsg))
|
||||
ctx.Abort()
|
||||
return
|
||||
}
|
||||
if claims.ExpiresAt < time.Now().Unix() {
|
||||
ctx.JSON(http.StatusOK, ErrorResponse(ctx, CodeTokenExpired, CodeTokenExpiredMsg))
|
||||
ctx.Abort()
|
||||
return
|
||||
}
|
||||
ctx.Set(constant.ContextUserID, claims.UserId)
|
||||
}
|
||||
}
|
||||
|
||||
func GetAuthMiddleware() gin.HandlerFunc {
|
||||
return func(ctx *gin.Context) {
|
||||
token := ctx.GetHeader("Authorization")
|
||||
|
||||
@@ -19,9 +19,9 @@ type Config struct {
|
||||
func DefaultConfig() Config {
|
||||
return Config{
|
||||
AllowOrigins: []string{"*"},
|
||||
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||
AllowHeaders: []string{"Origin", "Content-Type", "Accept"},
|
||||
ExposeHeaders: []string{"Content-Length"},
|
||||
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"},
|
||||
AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Requested-With"},
|
||||
ExposeHeaders: []string{"Content-Length", "Content-Type"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 86400, // 24 hours
|
||||
}
|
||||
@@ -30,16 +30,30 @@ func DefaultConfig() Config {
|
||||
func Cors(config Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否允许该来源
|
||||
allowOrigin := "*"
|
||||
allowOrigin := ""
|
||||
for _, o := range config.AllowOrigins {
|
||||
if o == "*" {
|
||||
// 通配符时,回显实际 origin(兼容 credentials)
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
if o == origin {
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if allowOrigin == "" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Access-Control-Allow-Origin", allowOrigin)
|
||||
c.Header("Access-Control-Allow-Methods", strings.Join(config.AllowMethods, ","))
|
||||
c.Header("Access-Control-Allow-Headers", strings.Join(config.AllowHeaders, ","))
|
||||
|
||||
@@ -64,8 +64,7 @@ func GetPolicyToken() (string, error) {
|
||||
}
|
||||
|
||||
func GetPolicyURL(ctx context.Context, path string) (string, error) {
|
||||
// Create OSS client
|
||||
client, err := oss.New(config.GlobalConfig.Aliyun.OSS.Endpoint, config.GlobalConfig.Aliyun.OSS.AccessKeyID, config.GlobalConfig.Aliyun.OSS.AccessKeySecret)
|
||||
client, err := oss.New(config.GlobalConfig.Aliyun.OSS.Endpoint, config.GlobalConfig.Aliyun.OSS.AccessKeyID, config.GlobalConfig.Aliyun.OSS.AccessKeySecret, oss.UseCname(true))
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "GetPolicyURL", "msg", "create oss client failed", "error", err)
|
||||
return "", err
|
||||
@@ -125,7 +124,7 @@ func DownloadFile(ctx context.Context, ossPath string) (io.ReadCloser, error) {
|
||||
}
|
||||
|
||||
// Create OSS client
|
||||
client, err := oss.New(endpoint, config.GlobalConfig.Aliyun.OSS.AccessKeyID, config.GlobalConfig.Aliyun.OSS.AccessKeySecret)
|
||||
client, err := oss.New(endpoint, config.GlobalConfig.Aliyun.OSS.AccessKeyID, config.GlobalConfig.Aliyun.OSS.AccessKeySecret, oss.UseCname(true))
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "DownloadFile", "msg", "create oss client failed", "error", err)
|
||||
return nil, err
|
||||
@@ -151,7 +150,7 @@ func DownloadFile(ctx context.Context, ossPath string) (io.ReadCloser, error) {
|
||||
func GetDownloadURL(ctx context.Context, ossPath string) (string, error) {
|
||||
endpoint := config.GlobalConfig.Aliyun.OSS.Endpoint
|
||||
|
||||
client, err := oss.New(endpoint, config.GlobalConfig.Aliyun.OSS.AccessKeyID, config.GlobalConfig.Aliyun.OSS.AccessKeySecret)
|
||||
client, err := oss.New(endpoint, config.GlobalConfig.Aliyun.OSS.AccessKeyID, config.GlobalConfig.Aliyun.OSS.AccessKeySecret, oss.UseCname(true))
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "GetDownloadURL", "msg", "create oss client failed", "error", err)
|
||||
return "", err
|
||||
@@ -163,11 +162,13 @@ func GetDownloadURL(ctx context.Context, ossPath string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
signURL, err := bucket.SignURL(ossPath, oss.HTTPGet, 60)
|
||||
signURL, err := bucket.SignURL(ossPath, oss.HTTPGet, 3600)
|
||||
if err != nil {
|
||||
log.Error(ctx, "func", "GetDownloadURL", "msg", "get object failed", "error", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
signURL = strings.Replace(signURL, "http://", "https://", 1)
|
||||
|
||||
return signURL, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user