diff --git a/cmd/migrate/README.md b/cmd/migrate/README.md new file mode 100644 index 0000000..cc67ca9 --- /dev/null +++ b/cmd/migrate/README.md @@ -0,0 +1,73 @@ +# 数据迁移工具 + +用于将测试数据库的数据迁移到生产数据库,避免ID冲突,使用事务确保数据一致性。 + +## 功能特性 + +- ✅ 自动避免ID冲突(使用数据库自增ID) +- ✅ 使用事务确保每个任务和结果数据的一致性 +- ✅ 自动跳过已存在的任务(基于task_uuid) +- ✅ 保留原始时间戳 +- ✅ 处理NULL值 +- ✅ 详细的日志输出和统计信息 + +## 使用方法 + +### 基本用法 + +```bash +# 从dev环境迁移到prod环境 +go run cmd/migrate/main.go -test-env=dev -prod-env=prod + +# 从prod环境迁移到dev环境(测试反向迁移) +go run cmd/migrate/main.go -test-env=prod -prod-env=dev +``` + +### 参数说明 + +- `-test-env`: 测试环境配置文件名(dev/prod),默认值:dev +- `-prod-env`: 生产环境配置文件名(dev/prod),默认值:prod + +### 编译后使用 + +```bash +# 编译 +go build -o migrate cmd/migrate/main.go + +# 运行 +./migrate -test-env=dev -prod-env=prod +``` + +## 工作原理 + +1. **连接数据库**:同时连接测试数据库和生产数据库 +2. **读取数据**:从测试数据库读取所有任务和结果数据(LEFT JOIN) +3. **检查重复**:基于`task_uuid`检查生产数据库中是否已存在 +4. **事务迁移**:为每个任务创建独立事务: + - 创建任务记录(自动生成新ID) + - 如果存在结果数据,创建结果记录(关联新任务ID) + - 提交事务或回滚 +5. **统计报告**:输出迁移统计信息 + +## 注意事项 + +1. **配置文件**:确保`config/config_dev.yaml`和`config/config_prod.yaml`存在且配置正确 +2. **数据库权限**:确保数据库用户有读写权限 +3. **网络连接**:确保能同时连接到两个数据库 +4. **数据备份**:迁移前建议备份生产数据库 +5. **ID冲突**:脚本会自动处理ID冲突,使用数据库自增ID,不会覆盖现有数据 + +## 输出示例 + +``` +从测试数据库读取到 100 条任务记录 +[1/100] 创建任务成功: task_uuid=xxx, 新ID=1001 +[1/100] 创建结果成功: task_id=1001 +[2/100] 跳过已存在的任务: task_uuid=yyy, id=1002 +... +迁移完成统计: + 成功: 95 条 + 跳过: 3 条 + 失败: 2 条 +数据迁移完成! +``` diff --git a/cmd/migrate/main.go b/cmd/migrate/main.go new file mode 100644 index 0000000..77838c6 --- /dev/null +++ b/cmd/migrate/main.go @@ -0,0 +1,255 @@ +package main + +import ( + "context" + "flag" + "fmt" + "log" + "time" + + "gitea.com/bitwsd/document_ai/config" + "gitea.com/bitwsd/document_ai/internal/storage/dao" + "github.com/spf13/viper" + "gorm.io/driver/mysql" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +func main() { + // 解析命令行参数 + testEnv := flag.String("test-env", "dev", "测试环境配置 (dev/prod)") + prodEnv := flag.String("prod-env", "prod", "生产环境配置 (dev/prod)") + flag.Parse() + + // 加载测试环境配置 + testConfigPath := fmt.Sprintf("./config/config_%s.yaml", *testEnv) + testConfig, err := loadDatabaseConfig(testConfigPath) + if err != nil { + log.Fatalf("加载测试环境配置失败: %v", err) + } + + // 连接测试数据库 + testDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Asia%%2FShanghai", + testConfig.Username, testConfig.Password, testConfig.Host, testConfig.Port, testConfig.DBName) + testDB, err := gorm.Open(mysql.Open(testDSN), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Info), + }) + if err != nil { + log.Fatalf("连接测试数据库失败: %v", err) + } + + // 加载生产环境配置 + prodConfigPath := fmt.Sprintf("./config/config_%s.yaml", *prodEnv) + prodConfig, err := loadDatabaseConfig(prodConfigPath) + if err != nil { + log.Fatalf("加载生产环境配置失败: %v", err) + } + + // 连接生产数据库 + prodDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Asia%%2FShanghai", + prodConfig.Username, prodConfig.Password, prodConfig.Host, prodConfig.Port, prodConfig.DBName) + prodDB, err := gorm.Open(mysql.Open(prodDSN), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Info), + }) + if err != nil { + log.Fatalf("连接生产数据库失败: %v", err) + } + + // 执行迁移 + if err := migrateData(testDB, prodDB); err != nil { + log.Fatalf("数据迁移失败: %v", err) + } + + log.Println("数据迁移完成!") +} + +func migrateData(testDB, prodDB *gorm.DB) error { + _ = context.Background() // 保留以备将来使用 + + // 从测试数据库读取所有任务数据(包含结果) + type TaskWithResult struct { + // Task 字段 + TaskID int64 `gorm:"column:id"` + UserID int64 `gorm:"column:user_id"` + TaskUUID string `gorm:"column:task_uuid"` + FileName string `gorm:"column:file_name"` + FileHash string `gorm:"column:file_hash"` + FileURL string `gorm:"column:file_url"` + TaskType string `gorm:"column:task_type"` + Status int `gorm:"column:status"` + CompletedAt time.Time `gorm:"column:completed_at"` + Remark string `gorm:"column:remark"` + IP string `gorm:"column:ip"` + TaskCreatedAt time.Time `gorm:"column:created_at"` + TaskUpdatedAt time.Time `gorm:"column:updated_at"` + // Result 字段 + ResultID *int64 `gorm:"column:result_id"` + ResultTaskID *int64 `gorm:"column:result_task_id"` + ResultTaskType *string `gorm:"column:result_task_type"` + Latex *string `gorm:"column:latex"` + Markdown *string `gorm:"column:markdown"` + MathML *string `gorm:"column:mathml"` + ResultCreatedAt *time.Time `gorm:"column:result_created_at"` + ResultUpdatedAt *time.Time `gorm:"column:result_updated_at"` + } + + var tasksWithResults []TaskWithResult + query := ` + SELECT + t.id, + t.user_id, + t.task_uuid, + t.file_name, + t.file_hash, + t.file_url, + t.task_type, + t.status, + t.completed_at, + t.remark, + t.ip, + t.created_at, + t.updated_at, + r.id as result_id, + r.task_id as result_task_id, + r.task_type as result_task_type, + r.latex, + r.markdown, + r.mathml, + r.created_at as result_created_at, + r.updated_at as result_updated_at + FROM recognition_tasks t + LEFT JOIN recognition_results r ON t.id = r.task_id + ORDER BY t.id + ` + + if err := testDB.Raw(query).Scan(&tasksWithResults).Error; err != nil { + return fmt.Errorf("读取测试数据失败: %v", err) + } + + log.Printf("从测试数据库读取到 %d 条任务记录", len(tasksWithResults)) + + successCount := 0 + skipCount := 0 + errorCount := 0 + + // 为每个任务使用独立事务,确保单个任务失败不影响其他任务 + for i, item := range tasksWithResults { + // 开始事务 + tx := prodDB.Begin() + + // 检查生产数据库中是否已存在相同的 task_uuid + var existingTask dao.RecognitionTask + err := tx.Where("task_uuid = ?", item.TaskUUID).First(&existingTask).Error + if err == nil { + log.Printf("[%d/%d] 跳过已存在的任务: task_uuid=%s, id=%d", i+1, len(tasksWithResults), item.TaskUUID, existingTask.ID) + tx.Rollback() + skipCount++ + continue + } + if err != gorm.ErrRecordNotFound { + log.Printf("[%d/%d] 检查任务是否存在时出错: task_uuid=%s, error=%v", i+1, len(tasksWithResults), item.TaskUUID, err) + tx.Rollback() + errorCount++ + continue + } + + // 创建新任务(不指定ID,让数据库自动生成) + newTask := &dao.RecognitionTask{ + UserID: item.UserID, + TaskUUID: item.TaskUUID, + FileName: item.FileName, + FileHash: item.FileHash, + FileURL: item.FileURL, + TaskType: dao.TaskType(item.TaskType), + Status: dao.TaskStatus(item.Status), + CompletedAt: item.CompletedAt, + Remark: item.Remark, + IP: item.IP, + } + // 保留原始时间戳 + newTask.CreatedAt = item.TaskCreatedAt + newTask.UpdatedAt = item.TaskUpdatedAt + + if err := tx.Create(newTask).Error; err != nil { + log.Printf("[%d/%d] 创建任务失败: task_uuid=%s, error=%v", i+1, len(tasksWithResults), item.TaskUUID, err) + tx.Rollback() + errorCount++ + continue + } + + log.Printf("[%d/%d] 创建任务成功: task_uuid=%s, 新ID=%d", i+1, len(tasksWithResults), item.TaskUUID, newTask.ID) + + // 如果有结果数据,创建结果记录 + if item.ResultID != nil { + // 处理可能为NULL的字段 + latex := "" + if item.Latex != nil { + latex = *item.Latex + } + markdown := "" + if item.Markdown != nil { + markdown = *item.Markdown + } + mathml := "" + if item.MathML != nil { + mathml = *item.MathML + } + + newResult := dao.RecognitionResult{ + TaskID: newTask.ID, // 使用新任务的ID + TaskType: dao.TaskType(item.TaskType), + Latex: latex, + Markdown: markdown, + MathML: mathml, + } + // 保留原始时间戳 + if item.ResultCreatedAt != nil { + newResult.CreatedAt = *item.ResultCreatedAt + } + if item.ResultUpdatedAt != nil { + newResult.UpdatedAt = *item.ResultUpdatedAt + } + + if err := tx.Create(&newResult).Error; err != nil { + log.Printf("[%d/%d] 创建结果失败: task_id=%d, error=%v", i+1, len(tasksWithResults), newTask.ID, err) + tx.Rollback() // 回滚整个事务(包括任务) + errorCount++ + continue + } + + log.Printf("[%d/%d] 创建结果成功: task_id=%d", i+1, len(tasksWithResults), newTask.ID) + } + + // 提交事务 + if err := tx.Commit().Error; err != nil { + log.Printf("[%d/%d] 提交事务失败: task_uuid=%s, error=%v", i+1, len(tasksWithResults), item.TaskUUID, err) + errorCount++ + continue + } + + successCount++ + } + + log.Printf("迁移完成统计:") + log.Printf(" 成功: %d 条", successCount) + log.Printf(" 跳过: %d 条", skipCount) + log.Printf(" 失败: %d 条", errorCount) + + return nil +} + +// loadDatabaseConfig 从配置文件加载数据库配置 +func loadDatabaseConfig(configPath string) (config.DatabaseConfig, error) { + v := viper.New() + v.SetConfigFile(configPath) + if err := v.ReadInConfig(); err != nil { + return config.DatabaseConfig{}, err + } + + var dbConfig config.DatabaseConfig + if err := v.UnmarshalKey("database", &dbConfig); err != nil { + return config.DatabaseConfig{}, err + } + + return dbConfig, nil +} diff --git a/internal/storage/dao/engine.go b/internal/storage/dao/engine.go index 41db511..9bd1fcd 100644 --- a/internal/storage/dao/engine.go +++ b/internal/storage/dao/engine.go @@ -6,13 +6,16 @@ import ( "gitea.com/bitwsd/document_ai/config" "gorm.io/driver/mysql" "gorm.io/gorm" + "gorm.io/gorm/logger" ) var DB *gorm.DB func InitDB(conf config.DatabaseConfig) { dns := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Asia%%2FShanghai", conf.Username, conf.Password, conf.Host, conf.Port, conf.DBName) - db, err := gorm.Open(mysql.Open(dns), &gorm.Config{}) + db, err := gorm.Open(mysql.Open(dns), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), // 禁用 GORM 日志输出 + }) if err != nil { panic(err) } diff --git a/main.go b/main.go index fe4a998..6a41973 100644 --- a/main.go +++ b/main.go @@ -28,6 +28,8 @@ func main() { flag.StringVar(&env, "env", "dev", "environment (dev/prod)") flag.Parse() + fmt.Println("env:", env) + configPath := fmt.Sprintf("./config/config_%s.yaml", env) if err := config.Init(configPath); err != nil { panic(err) diff --git a/prod_deploy.sh b/prod_deploy.sh old mode 100644 new mode 100755 index e69de29..37d1236 --- a/prod_deploy.sh +++ b/prod_deploy.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +docker build -t crpi-8s2ierii2xan4klg.cn-beijing.personal.cr.aliyuncs.com/texpixel/doc_ai_backend:latest . && docker push crpi-8s2ierii2xan4klg.cn-beijing.personal.cr.aliyuncs.com/texpixel/doc_ai_backend:latest + +ssh ecs << 'ENDSSH' +docker stop doc_ai doc_ai_backend 2>/dev/null || true +docker rm doc_ai doc_ai_backend 2>/dev/null || true +docker pull crpi-8s2ierii2xan4klg.cn-beijing.personal.cr.aliyuncs.com/texpixel/doc_ai_backend:latest +docker run -d --name doc_ai -p 8024:8024 --restart unless-stopped crpi-8s2ierii2xan4klg.cn-beijing.personal.cr.aliyuncs.com/texpixel/doc_ai_backend:latest -env=prod +ENDSSH \ No newline at end of file