package main import ( "context" "flag" "fmt" "log" "time" "gitea.com/bitwsd/document_ai/config" "gitea.com/bitwsd/document_ai/internal/storage/dao" "github.com/spf13/viper" "gorm.io/driver/mysql" "gorm.io/gorm" "gorm.io/gorm/logger" ) func main() { // 解析命令行参数 testEnv := flag.String("test-env", "dev", "测试环境配置 (dev/prod)") prodEnv := flag.String("prod-env", "prod", "生产环境配置 (dev/prod)") flag.Parse() // 加载测试环境配置 testConfigPath := fmt.Sprintf("./config/config_%s.yaml", *testEnv) testConfig, err := loadDatabaseConfig(testConfigPath) if err != nil { log.Fatalf("加载测试环境配置失败: %v", err) } // 连接测试数据库 testDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Asia%%2FShanghai", testConfig.Username, testConfig.Password, testConfig.Host, testConfig.Port, testConfig.DBName) testDB, err := gorm.Open(mysql.Open(testDSN), &gorm.Config{ Logger: logger.Default.LogMode(logger.Info), }) if err != nil { log.Fatalf("连接测试数据库失败: %v", err) } // 加载生产环境配置 prodConfigPath := fmt.Sprintf("./config/config_%s.yaml", *prodEnv) prodConfig, err := loadDatabaseConfig(prodConfigPath) if err != nil { log.Fatalf("加载生产环境配置失败: %v", err) } // 连接生产数据库 prodDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Asia%%2FShanghai", prodConfig.Username, prodConfig.Password, prodConfig.Host, prodConfig.Port, prodConfig.DBName) prodDB, err := gorm.Open(mysql.Open(prodDSN), &gorm.Config{ Logger: logger.Default.LogMode(logger.Info), }) if err != nil { log.Fatalf("连接生产数据库失败: %v", err) } // 执行迁移 if err := migrateData(testDB, prodDB); err != nil { log.Fatalf("数据迁移失败: %v", err) } log.Println("数据迁移完成!") } func migrateData(testDB, prodDB *gorm.DB) error { _ = context.Background() // 保留以备将来使用 // 从测试数据库读取所有任务数据(包含结果) type TaskWithResult struct { // Task 字段 TaskID int64 `gorm:"column:id"` UserID int64 `gorm:"column:user_id"` TaskUUID string `gorm:"column:task_uuid"` FileName string `gorm:"column:file_name"` FileHash string `gorm:"column:file_hash"` FileURL string `gorm:"column:file_url"` TaskType string `gorm:"column:task_type"` Status int `gorm:"column:status"` CompletedAt time.Time `gorm:"column:completed_at"` Remark string `gorm:"column:remark"` IP string `gorm:"column:ip"` TaskCreatedAt time.Time `gorm:"column:created_at"` TaskUpdatedAt time.Time `gorm:"column:updated_at"` // Result 字段 ResultID *int64 `gorm:"column:result_id"` ResultTaskID *int64 `gorm:"column:result_task_id"` ResultTaskType *string `gorm:"column:result_task_type"` Latex *string `gorm:"column:latex"` Markdown *string `gorm:"column:markdown"` MathML *string `gorm:"column:mathml"` ResultCreatedAt *time.Time `gorm:"column:result_created_at"` ResultUpdatedAt *time.Time `gorm:"column:result_updated_at"` } var tasksWithResults []TaskWithResult query := ` SELECT t.id, t.user_id, t.task_uuid, t.file_name, t.file_hash, t.file_url, t.task_type, t.status, t.completed_at, t.remark, t.ip, t.created_at, t.updated_at, r.id as result_id, r.task_id as result_task_id, r.task_type as result_task_type, r.latex, r.markdown, r.mathml, r.created_at as result_created_at, r.updated_at as result_updated_at FROM recognition_tasks t LEFT JOIN recognition_results r ON t.id = r.task_id ORDER BY t.id ` if err := testDB.Raw(query).Scan(&tasksWithResults).Error; err != nil { return fmt.Errorf("读取测试数据失败: %v", err) } log.Printf("从测试数据库读取到 %d 条任务记录", len(tasksWithResults)) successCount := 0 skipCount := 0 errorCount := 0 // 为每个任务使用独立事务,确保单个任务失败不影响其他任务 for i, item := range tasksWithResults { // 开始事务 tx := prodDB.Begin() // 检查生产数据库中是否已存在相同的 task_uuid var existingTask dao.RecognitionTask err := tx.Where("task_uuid = ?", item.TaskUUID).First(&existingTask).Error if err == nil { log.Printf("[%d/%d] 跳过已存在的任务: task_uuid=%s, id=%d", i+1, len(tasksWithResults), item.TaskUUID, existingTask.ID) tx.Rollback() skipCount++ continue } if err != gorm.ErrRecordNotFound { log.Printf("[%d/%d] 检查任务是否存在时出错: task_uuid=%s, error=%v", i+1, len(tasksWithResults), item.TaskUUID, err) tx.Rollback() errorCount++ continue } // 创建新任务(不指定ID,让数据库自动生成) newTask := &dao.RecognitionTask{ UserID: item.UserID, TaskUUID: item.TaskUUID, FileName: item.FileName, FileHash: item.FileHash, FileURL: item.FileURL, TaskType: dao.TaskType(item.TaskType), Status: dao.TaskStatus(item.Status), CompletedAt: item.CompletedAt, Remark: item.Remark, IP: item.IP, } // 保留原始时间戳 newTask.CreatedAt = item.TaskCreatedAt newTask.UpdatedAt = item.TaskUpdatedAt if err := tx.Create(newTask).Error; err != nil { log.Printf("[%d/%d] 创建任务失败: task_uuid=%s, error=%v", i+1, len(tasksWithResults), item.TaskUUID, err) tx.Rollback() errorCount++ continue } log.Printf("[%d/%d] 创建任务成功: task_uuid=%s, 新ID=%d", i+1, len(tasksWithResults), item.TaskUUID, newTask.ID) // 如果有结果数据,创建结果记录 if item.ResultID != nil { // 处理可能为NULL的字段 latex := "" if item.Latex != nil { latex = *item.Latex } markdown := "" if item.Markdown != nil { markdown = *item.Markdown } mathml := "" if item.MathML != nil { mathml = *item.MathML } newResult := dao.RecognitionResult{ TaskID: newTask.ID, // 使用新任务的ID TaskType: dao.TaskType(item.TaskType), Latex: latex, Markdown: markdown, MathML: mathml, } // 保留原始时间戳 if item.ResultCreatedAt != nil { newResult.CreatedAt = *item.ResultCreatedAt } if item.ResultUpdatedAt != nil { newResult.UpdatedAt = *item.ResultUpdatedAt } if err := tx.Create(&newResult).Error; err != nil { log.Printf("[%d/%d] 创建结果失败: task_id=%d, error=%v", i+1, len(tasksWithResults), newTask.ID, err) tx.Rollback() // 回滚整个事务(包括任务) errorCount++ continue } log.Printf("[%d/%d] 创建结果成功: task_id=%d", i+1, len(tasksWithResults), newTask.ID) } // 提交事务 if err := tx.Commit().Error; err != nil { log.Printf("[%d/%d] 提交事务失败: task_uuid=%s, error=%v", i+1, len(tasksWithResults), item.TaskUUID, err) errorCount++ continue } successCount++ } log.Printf("迁移完成统计:") log.Printf(" 成功: %d 条", successCount) log.Printf(" 跳过: %d 条", skipCount) log.Printf(" 失败: %d 条", errorCount) return nil } // loadDatabaseConfig 从配置文件加载数据库配置 func loadDatabaseConfig(configPath string) (config.DatabaseConfig, error) { v := viper.New() v.SetConfigFile(configPath) if err := v.ReadInConfig(); err != nil { return config.DatabaseConfig{}, err } var dbConfig config.DatabaseConfig if err := v.UnmarshalKey("database", &dbConfig); err != nil { return config.DatabaseConfig{}, err } return dbConfig, nil }