init repo

This commit is contained in:
liuyuanchuang
2025-12-10 18:33:37 +08:00
commit 48e63894eb
2408 changed files with 1053045 additions and 0 deletions

33
internal/storage/cache/engine.go vendored Normal file
View File

@@ -0,0 +1,33 @@
package cache
import (
"context"
"fmt"
"time"
"gitea.com/bitwsd/document_ai/config"
"github.com/redis/go-redis/v9"
)
var RedisClient *redis.Client
func InitRedisClient(config config.RedisConfig) {
fmt.Println("Initializing Redis client...")
RedisClient = redis.NewClient(&redis.Options{
Addr: config.Addr,
Password: config.Password,
DB: config.DB,
DialTimeout: 10 * time.Second,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
})
fmt.Println("Pinging Redis server...")
_, err := RedisClient.Ping(context.Background()).Result()
if err != nil {
fmt.Printf("Init redis client failed, err: %v\n", err)
panic(err)
}
fmt.Println("Redis client initialized successfully.")
}

100
internal/storage/cache/formula.go vendored Normal file
View File

@@ -0,0 +1,100 @@
package cache
import (
"context"
"fmt"
"strconv"
"time"
"github.com/redis/go-redis/v9"
)
const (
FormulaRecognitionTaskCount = "formula_recognition_task"
FormulaRecognitionTaskQueue = "formula_recognition_queue"
FormulaRecognitionDistLock = "formula_recognition_dist_lock"
VLMFormulaCount = "vlm_formula_count:%s" // VLM公式识别次数 ip
VLMRecognitionTaskQueue = "vlm_recognition_queue"
DefaultLockTimeout = 60 * time.Second // 默认锁超时时间
)
// TODO the sigle queue not reliable, message maybe lost
func PushVLMRecognitionTask(ctx context.Context, taskID int64) (count int64, err error) {
count, err = RedisClient.LPush(ctx, VLMRecognitionTaskQueue, taskID).Result()
if err != nil {
return 0, err
}
return count, nil
}
func PopVLMRecognitionTask(ctx context.Context) (int64, error) {
result, err := RedisClient.BRPop(ctx, 0, VLMRecognitionTaskQueue).Result()
if err != nil {
return 0, err
}
return strconv.ParseInt(result[1], 10, 64)
}
func PushFormulaTask(ctx context.Context, taskID int64) (count int64, err error) {
count, err = RedisClient.LPush(ctx, FormulaRecognitionTaskQueue, taskID).Result()
if err != nil {
return 0, err
}
return count, nil
}
func PopFormulaTask(ctx context.Context) (int64, error) {
result, err := RedisClient.BRPop(ctx, 0, FormulaRecognitionTaskQueue).Result()
if err != nil {
return 0, err
}
return strconv.ParseInt(result[1], 10, 64)
}
func GetFormulaTaskCount(ctx context.Context) (int64, error) {
count, err := RedisClient.LLen(ctx, FormulaRecognitionTaskQueue).Result()
if err != nil {
return 0, err
}
return count, nil
}
// GetDistributedLock 获取分布式锁
func GetDistributedLock(ctx context.Context) (bool, error) {
return RedisClient.SetNX(ctx, FormulaRecognitionDistLock, "locked", DefaultLockTimeout).Result()
}
// ReleaseLock 释放分布式锁
func ReleaseLock(ctx context.Context) error {
return RedisClient.Del(ctx, FormulaRecognitionDistLock).Err()
}
func GetVLMFormulaCount(ctx context.Context, ip string) (int64, error) {
count, err := RedisClient.Get(ctx, fmt.Sprintf(VLMFormulaCount, ip)).Result()
if err != nil {
if err == redis.Nil {
return 0, nil
}
return 0, err
}
return strconv.ParseInt(count, 10, 64)
}
func IncrVLMFormulaCount(ctx context.Context, ip string) (int64, error) {
key := fmt.Sprintf(VLMFormulaCount, ip)
count, err := RedisClient.Incr(ctx, key).Result()
if err != nil {
return 0, err
}
if count == 1 {
now := time.Now()
nextMidnight := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, now.Location())
ttl := nextMidnight.Sub(now)
if err := RedisClient.Expire(ctx, key, ttl).Err(); err != nil {
return count, err
}
}
return count, nil
}

12
internal/storage/cache/url.go vendored Normal file
View File

@@ -0,0 +1,12 @@
package cache
import "context"
func IncrURLCount(ctx context.Context) (int64, error) {
key := "formula_recognition:url_count"
count, err := RedisClient.Incr(ctx, key).Result()
if err != nil {
return 0, err
}
return count, nil
}

63
internal/storage/cache/user.go vendored Normal file
View File

@@ -0,0 +1,63 @@
package cache
import (
"context"
"errors"
"fmt"
"strconv"
"time"
"github.com/redis/go-redis/v9"
)
const (
UserSmsCodeTTL = 10 * time.Minute
UserSendSmsLimitTTL = 24 * time.Hour
UserSendSmsLimitCount = 5
)
const (
UserSmsCodePrefix = "user:sms_code:%s"
UserSendSmsLimit = "user:send_sms_limit:%s"
)
func GetUserSmsCode(ctx context.Context, phone string) (string, error) {
code, err := RedisClient.Get(ctx, fmt.Sprintf(UserSmsCodePrefix, phone)).Result()
if err != nil {
if err == redis.Nil {
return "", nil
}
return "", err
}
return code, nil
}
func SetUserSmsCode(ctx context.Context, phone, code string) error {
return RedisClient.Set(ctx, fmt.Sprintf(UserSmsCodePrefix, phone), code, UserSmsCodeTTL).Err()
}
func GetUserSendSmsLimit(ctx context.Context, phone string) (int, error) {
limit, err := RedisClient.Get(ctx, fmt.Sprintf(UserSendSmsLimit, phone)).Result()
if err != nil {
if err == redis.Nil {
return 0, nil
}
return 0, err
}
return strconv.Atoi(limit)
}
func SetUserSendSmsLimit(ctx context.Context, phone string) error {
count, err := RedisClient.Incr(ctx, fmt.Sprintf(UserSendSmsLimit, phone)).Result()
if err != nil {
return err
}
if count > UserSendSmsLimitCount {
return errors.New("send sms limit")
}
return RedisClient.Expire(ctx, fmt.Sprintf(UserSendSmsLimit, phone), UserSendSmsLimitTTL).Err()
}
func DeleteUserSmsCode(ctx context.Context, phone string) error {
return RedisClient.Del(ctx, fmt.Sprintf(UserSmsCodePrefix, phone)).Err()
}

View File

@@ -0,0 +1,11 @@
package dao
import (
"time"
)
type BaseModel struct {
ID int64 `gorm:"bigint;primaryKey;autoIncrement;column:id;comment:主键ID" json:"id"`
CreatedAt time.Time `gorm:"column:created_at;comment:创建时间;not null;default:current_timestamp" json:"created_at"`
UpdatedAt time.Time `gorm:"column:updated_at;comment:更新时间;not null;default:current_timestamp on update current_timestamp" json:"updated_at"`
}

View File

@@ -0,0 +1,31 @@
package dao
import (
"fmt"
"gitea.com/bitwsd/document_ai/config"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
var DB *gorm.DB
func InitDB(conf config.DatabaseConfig) {
dns := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Asia%%2FShanghai", conf.Username, conf.Password, conf.Host, conf.Port, conf.DBName)
db, err := gorm.Open(mysql.Open(dns), &gorm.Config{})
if err != nil {
panic(err)
}
sqlDB, err := db.DB()
if err != nil {
panic(err)
}
sqlDB.SetMaxIdleConns(10)
sqlDB.SetMaxOpenConns(100)
DB = db
}
func CloseDB() {
sqlDB, _ := DB.DB()
sqlDB.Close()
}

View File

@@ -0,0 +1,26 @@
package dao
import "gorm.io/gorm"
type EvaluateTask struct {
BaseModel
TaskID int64 `gorm:"column:task_id;type:int;not null;comment:任务ID"`
Satisfied int `gorm:"column:satisfied;type:int;not null;comment:满意"`
Feedback string `gorm:"column:feedback;type:text;not null;comment:反馈"`
Comment string `gorm:"column:comment;type:text;not null;comment:评论"`
}
func (EvaluateTask) TableName() string {
return "evaluate_tasks"
}
type EvaluateTaskDao struct {
}
func NewEvaluateTaskDao() *EvaluateTaskDao {
return &EvaluateTaskDao{}
}
func (dao *EvaluateTaskDao) Create(sess *gorm.DB, data *EvaluateTask) error {
return sess.Create(data).Error
}

View File

@@ -0,0 +1,89 @@
package dao
import (
"encoding/json"
"gorm.io/gorm"
)
type JSON []byte
// ContentCodec 定义内容编解码接口
type ContentCodec interface {
Encode() (JSON, error)
Decode() error
GetContent() interface{} // 更明确的方法名
}
type FormulaRecognitionContent struct {
content JSON
Latex string `json:"latex"`
AdjustLatex string `json:"adjust_latex"`
EnhanceLatex string `json:"enhance_latex"`
}
func (c *FormulaRecognitionContent) Encode() (JSON, error) {
b, err := json.Marshal(c)
if err != nil {
return nil, err
}
return b, nil
}
func (c *FormulaRecognitionContent) Decode() error {
return json.Unmarshal(c.content, c)
}
// GetPreferredContent 按优先级返回公式内容
func (c *FormulaRecognitionContent) GetContent() interface{} {
c.Decode()
if c.EnhanceLatex != "" {
return c.EnhanceLatex
} else if c.AdjustLatex != "" {
return c.AdjustLatex
} else {
return c.Latex
}
}
type RecognitionResult struct {
BaseModel
TaskID int64 `gorm:"column:task_id;bigint;not null;default:0;comment:任务ID" json:"task_id"`
TaskType TaskType `gorm:"column:task_type;varchar(16);not null;comment:任务类型;default:''" json:"task_type"`
Content JSON `gorm:"column:content;type:json;not null;comment:识别内容" json:"content"`
}
// NewContentCodec 创建对应任务类型的内容编解码器
func (r *RecognitionResult) NewContentCodec() ContentCodec {
switch r.TaskType {
case TaskTypeFormula:
return &FormulaRecognitionContent{content: r.Content}
default:
return nil
}
}
type RecognitionResultDao struct {
}
func NewRecognitionResultDao() *RecognitionResultDao {
return &RecognitionResultDao{}
}
// 模型方法
func (dao *RecognitionResultDao) Create(tx *gorm.DB, data RecognitionResult) error {
return tx.Create(&data).Error
}
func (dao *RecognitionResultDao) GetByTaskID(tx *gorm.DB, taskID int64) (result *RecognitionResult, err error) {
result = &RecognitionResult{}
err = tx.Where("task_id = ?", taskID).First(result).Error
if err != nil && err == gorm.ErrRecordNotFound {
return nil, nil
}
return
}
func (dao *RecognitionResultDao) Update(tx *gorm.DB, id int64, updates map[string]interface{}) error {
return tx.Model(&RecognitionResult{}).Where("id = ?", id).Updates(updates).Error
}

View File

@@ -0,0 +1,94 @@
package dao
import (
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type TaskStatus int
type TaskType string
const (
TaskStatusPending TaskStatus = 0
TaskStatusProcessing TaskStatus = 1
TaskStatusCompleted TaskStatus = 2
TaskStatusFailed TaskStatus = 3
TaskTypeFormula TaskType = "FORMULA"
TaskTypeText TaskType = "TEXT"
TaskTypeTable TaskType = "TABLE"
TaskTypeLayout TaskType = "LAYOUT"
)
func (t TaskType) String() string {
return string(t)
}
func (t TaskStatus) String() string {
return []string{"PENDING", "PROCESSING", "COMPLETED", "FAILED"}[t]
}
type RecognitionTask struct {
BaseModel
UserID int64 `gorm:"column:user_id;not null;default:0;comment:用户ID" json:"user_id"`
TaskUUID string `gorm:"column:task_uuid;varchar(64);not null;default:'';comment:任务唯一标识" json:"task_uuid"`
FileName string `gorm:"column:file_name;varchar(256);not null;default:'';comment:文件名" json:"file_name"`
FileHash string `gorm:"column:file_hash;varchar(64);not null;default:'';comment:文件hash" json:"file_hash"`
FileURL string `gorm:"column:file_url;varchar(128);not null;comment:oss文件地址;default:''" json:"file_url"`
TaskType TaskType `gorm:"column:task_type;varchar(16);not null;comment:任务类型;default:''" json:"task_type"`
Status TaskStatus `gorm:"column:status;tinyint(2);not null;comment:任务状态;default:0" json:"status"`
CompletedAt time.Time `gorm:"column:completed_at;not null;default:current_timestamp;comment:完成时间" json:"completed_at"`
Remark string `gorm:"column:remark;varchar(64);comment:备注;not null;default:''" json:"remark"`
IP string `gorm:"column:ip;varchar(16);comment:IP地址;not null;default:''" json:"ip"`
}
func (t *RecognitionTask) TableName() string {
return "recognition_tasks"
}
type RecognitionTaskDao struct{}
func NewRecognitionTaskDao() *RecognitionTaskDao {
return &RecognitionTaskDao{}
}
// 模型方法
func (dao *RecognitionTaskDao) Create(tx *gorm.DB, data *RecognitionTask) error {
return tx.Create(data).Error
}
func (dao *RecognitionTaskDao) Update(tx *gorm.DB, filter map[string]interface{}, data map[string]interface{}) error {
return tx.Model(RecognitionTask{}).Where(filter).Updates(data).Error
}
func (dao *RecognitionTaskDao) GetByTaskNo(tx *gorm.DB, taskUUID string) (task *RecognitionTask, err error) {
task = &RecognitionTask{}
err = tx.Model(RecognitionTask{}).Where("task_uuid = ?", taskUUID).First(task).Error
return
}
func (dao *RecognitionTaskDao) GetTaskByFileURL(tx *gorm.DB, userID int64, fileHash string) (task *RecognitionTask, err error) {
task = &RecognitionTask{}
err = tx.Model(RecognitionTask{}).Where("user_id = ? AND file_hash = ?", userID, fileHash).First(task).Error
return
}
func (dao *RecognitionTaskDao) GetTaskByID(tx *gorm.DB, id int64) (task *RecognitionTask, err error) {
task = &RecognitionTask{}
err = tx.Model(RecognitionTask{}).Where("id = ?", id).First(task).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, err
}
return task, nil
}
func (dao *RecognitionTaskDao) GetTaskList(tx *gorm.DB, taskType TaskType, page int, pageSize int) (tasks []*RecognitionTask, err error) {
offset := (page - 1) * pageSize
err = tx.Model(RecognitionTask{}).Where("task_type = ?", taskType).Offset(offset).Limit(pageSize).Order(clause.OrderByColumn{Column: clause.Column{Name: "id"}, Desc: true}).Find(&tasks).Error
return
}

View File

@@ -0,0 +1,53 @@
package dao
import (
"errors"
"gorm.io/gorm"
)
type User struct {
BaseModel
Username string `gorm:"column:username" json:"username"`
Phone string `gorm:"column:phone" json:"phone"`
Password string `gorm:"column:password" json:"password"`
WechatOpenID string `gorm:"column:wechat_open_id" json:"wechat_open_id"`
WechatUnionID string `gorm:"column:wechat_union_id" json:"wechat_union_id"`
}
func (u *User) TableName() string {
return "users"
}
type UserDao struct {
}
func NewUserDao() *UserDao {
return &UserDao{}
}
func (dao *UserDao) Create(tx *gorm.DB, user *User) error {
return tx.Create(user).Error
}
func (dao *UserDao) GetByPhone(tx *gorm.DB, phone string) (*User, error) {
var user User
if err := tx.Where("phone = ?", phone).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &user, nil
}
func (dao *UserDao) GetByID(tx *gorm.DB, id int64) (*User, error) {
var user User
if err := tx.Where("id = ?", id).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &user, nil
}