Files
doc_ai_backed/cmd/migrate/main.go
2026-01-27 21:56:21 +08:00

256 lines
7.5 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
	"context"
	"errors"
	"flag"
	"fmt"
	"log"
	"time"

	"gitea.com/texpixel/document_ai/config"
	"gitea.com/texpixel/document_ai/internal/storage/dao"
	"github.com/spf13/viper"
	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)
func main() {
// 解析命令行参数
testEnv := flag.String("test-env", "dev", "测试环境配置 (dev/prod)")
prodEnv := flag.String("prod-env", "prod", "生产环境配置 (dev/prod)")
flag.Parse()
// 加载测试环境配置
testConfigPath := fmt.Sprintf("./config/config_%s.yaml", *testEnv)
testConfig, err := loadDatabaseConfig(testConfigPath)
if err != nil {
log.Fatalf("加载测试环境配置失败: %v", err)
}
// 连接测试数据库
testDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Asia%%2FShanghai",
testConfig.Username, testConfig.Password, testConfig.Host, testConfig.Port, testConfig.DBName)
testDB, err := gorm.Open(mysql.Open(testDSN), &gorm.Config{
Logger: logger.Default.LogMode(logger.Info),
})
if err != nil {
log.Fatalf("连接测试数据库失败: %v", err)
}
// 加载生产环境配置
prodConfigPath := fmt.Sprintf("./config/config_%s.yaml", *prodEnv)
prodConfig, err := loadDatabaseConfig(prodConfigPath)
if err != nil {
log.Fatalf("加载生产环境配置失败: %v", err)
}
// 连接生产数据库
prodDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Asia%%2FShanghai",
prodConfig.Username, prodConfig.Password, prodConfig.Host, prodConfig.Port, prodConfig.DBName)
prodDB, err := gorm.Open(mysql.Open(prodDSN), &gorm.Config{
Logger: logger.Default.LogMode(logger.Info),
})
if err != nil {
log.Fatalf("连接生产数据库失败: %v", err)
}
// 执行迁移
if err := migrateData(testDB, prodDB); err != nil {
log.Fatalf("数据迁移失败: %v", err)
}
log.Println("数据迁移完成!")
}
// migrateData copies every recognition task, together with its recognition
// result when one exists, from testDB into prodDB.
//
// All source rows are read up front with one LEFT JOIN of recognition_tasks
// and recognition_results. Each row is then written inside its own
// transaction on prodDB, so a failure on one task does not affect the others.
// A task whose task_uuid already exists in prodDB is skipped. Inserted rows do
// not carry over their source IDs; the result row is linked to the newly
// inserted task's ID. The original created_at / updated_at values are kept.
//
// It returns an error only when the source rows cannot be read. Per-task
// failures are logged and counted in the final summary, and do not produce a
// non-nil return value.
//
// NOTE(review): if one task has several recognition_results rows, the join
// yields that task several times; once the first copy commits, the remaining
// rows are skipped by the task_uuid check, so only the first result is
// migrated. Confirm the relation is one-to-one.
func migrateData(testDB, prodDB *gorm.DB) error {
	ctx := context.Background()

	// TaskWithResult is one joined row: the task columns plus the (nullable)
	// columns of its result.
	type TaskWithResult struct {
		// Task columns.
		TaskID   int64  `gorm:"column:id"`
		UserID   int64  `gorm:"column:user_id"`
		TaskUUID string `gorm:"column:task_uuid"`
		FileName string `gorm:"column:file_name"`
		FileHash string `gorm:"column:file_hash"`
		FileURL  string `gorm:"column:file_url"`
		TaskType string `gorm:"column:task_type"`
		Status   int    `gorm:"column:status"`
		// NOTE(review): scanned into a non-pointer time.Time; presumably
		// completed_at is never NULL in the source table — confirm.
		CompletedAt   time.Time `gorm:"column:completed_at"`
		Remark        string    `gorm:"column:remark"`
		IP            string    `gorm:"column:ip"`
		TaskCreatedAt time.Time `gorm:"column:created_at"`
		TaskUpdatedAt time.Time `gorm:"column:updated_at"`
		// Result columns; pointers because the LEFT JOIN leaves them NULL when
		// the task has no result.
		ResultID        *int64     `gorm:"column:result_id"`
		ResultTaskID    *int64     `gorm:"column:result_task_id"`
		ResultTaskType  *string    `gorm:"column:result_task_type"`
		Latex           *string    `gorm:"column:latex"`
		Markdown        *string    `gorm:"column:markdown"`
		MathML          *string    `gorm:"column:mathml"`
		ResultCreatedAt *time.Time `gorm:"column:result_created_at"`
		ResultUpdatedAt *time.Time `gorm:"column:result_updated_at"`
	}

	var tasksWithResults []TaskWithResult
	query := `
		SELECT
			t.id,
			t.user_id,
			t.task_uuid,
			t.file_name,
			t.file_hash,
			t.file_url,
			t.task_type,
			t.status,
			t.completed_at,
			t.remark,
			t.ip,
			t.created_at,
			t.updated_at,
			r.id as result_id,
			r.task_id as result_task_id,
			r.task_type as result_task_type,
			r.latex,
			r.markdown,
			r.mathml,
			r.created_at as result_created_at,
			r.updated_at as result_updated_at
		FROM recognition_tasks t
		LEFT JOIN recognition_results r ON t.id = r.task_id
		ORDER BY t.id
	`
	if err := testDB.WithContext(ctx).Raw(query).Scan(&tasksWithResults).Error; err != nil {
		return fmt.Errorf("读取测试数据失败: %w", err)
	}

	total := len(tasksWithResults)
	log.Printf("从测试数据库读取到 %d 条任务记录", total)

	successCount := 0
	skipCount := 0
	errorCount := 0

	// One transaction per task, so a single failing task never affects the rest.
	for i, item := range tasksWithResults {
		seq := i + 1

		tx := prodDB.WithContext(ctx).Begin()
		if err := tx.Error; err != nil {
			// Without a transaction there is nothing to roll back; count the
			// task as failed and move on.
			log.Printf("[%d/%d] 开启事务失败: task_uuid=%s, error=%v", seq, total, item.TaskUUID, err)
			errorCount++
			continue
		}

		// Skip tasks whose task_uuid is already present in the destination.
		var existingTask dao.RecognitionTask
		err := tx.Where("task_uuid = ?", item.TaskUUID).First(&existingTask).Error
		if err == nil {
			log.Printf("[%d/%d] 跳过已存在的任务: task_uuid=%s, id=%d", seq, total, item.TaskUUID, existingTask.ID)
			tx.Rollback()
			skipCount++
			continue
		}
		// errors.Is also matches a wrapped ErrRecordNotFound, which a direct
		// comparison would misreport as a lookup failure.
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			log.Printf("[%d/%d] 检查任务是否存在时出错: task_uuid=%s, error=%v", seq, total, item.TaskUUID, err)
			tx.Rollback()
			errorCount++
			continue
		}

		// Create the task without an ID so the destination assigns a new one.
		newTask := &dao.RecognitionTask{
			UserID:      item.UserID,
			TaskUUID:    item.TaskUUID,
			FileName:    item.FileName,
			FileHash:    item.FileHash,
			FileURL:     item.FileURL,
			TaskType:    dao.TaskType(item.TaskType),
			Status:      dao.TaskStatus(item.Status),
			CompletedAt: item.CompletedAt,
			Remark:      item.Remark,
			IP:          item.IP,
		}
		// Keep the original timestamps.
		newTask.CreatedAt = item.TaskCreatedAt
		newTask.UpdatedAt = item.TaskUpdatedAt

		if err := tx.Create(newTask).Error; err != nil {
			log.Printf("[%d/%d] 创建任务失败: task_uuid=%s, error=%v", seq, total, item.TaskUUID, err)
			tx.Rollback()
			errorCount++
			continue
		}
		log.Printf("[%d/%d] 创建任务成功: task_uuid=%s, 新ID=%d", seq, total, item.TaskUUID, newTask.ID)

		// A non-nil ResultID means the join found a result row for this task.
		if item.ResultID != nil {
			// NULL text columns become empty strings.
			latex := ""
			if item.Latex != nil {
				latex = *item.Latex
			}
			markdown := ""
			if item.Markdown != nil {
				markdown = *item.Markdown
			}
			mathml := ""
			if item.MathML != nil {
				mathml = *item.MathML
			}

			// Copy the result row's own task_type; fall back to the task's
			// type only when the result has none.
			resultTaskType := item.TaskType
			if item.ResultTaskType != nil && *item.ResultTaskType != "" {
				resultTaskType = *item.ResultTaskType
			}

			newResult := dao.RecognitionResult{
				TaskID:   newTask.ID, // link to the newly created task
				TaskType: dao.TaskType(resultTaskType),
				Latex:    latex,
				Markdown: markdown,
				MathML:   mathml,
			}
			// Keep the original timestamps.
			if item.ResultCreatedAt != nil {
				newResult.CreatedAt = *item.ResultCreatedAt
			}
			if item.ResultUpdatedAt != nil {
				newResult.UpdatedAt = *item.ResultUpdatedAt
			}

			if err := tx.Create(&newResult).Error; err != nil {
				log.Printf("[%d/%d] 创建结果失败: task_id=%d, error=%v", seq, total, newTask.ID, err)
				tx.Rollback() // undoes the task insert as well
				errorCount++
				continue
			}
			log.Printf("[%d/%d] 创建结果成功: task_id=%d", seq, total, newTask.ID)
		}

		if err := tx.Commit().Error; err != nil {
			log.Printf("[%d/%d] 提交事务失败: task_uuid=%s, error=%v", seq, total, item.TaskUUID, err)
			errorCount++
			continue
		}
		successCount++
	}

	log.Printf("迁移完成统计:")
	log.Printf(" 成功: %d 条", successCount)
	log.Printf(" 跳过: %d 条", skipCount)
	log.Printf(" 失败: %d 条", errorCount)
	return nil
}
// loadDatabaseConfig reads the configuration file at configPath with a fresh
// viper instance and decodes its "database" key into a config.DatabaseConfig.
// If the file cannot be read or the key cannot be decoded, it returns the
// zero-value config together with the underlying error.
func loadDatabaseConfig(configPath string) (config.DatabaseConfig, error) {
	var (
		empty  config.DatabaseConfig // returned untouched on any failure
		parsed config.DatabaseConfig
	)

	reader := viper.New()
	reader.SetConfigFile(configPath)

	err := reader.ReadInConfig()
	if err != nil {
		return empty, err
	}

	err = reader.UnmarshalKey("database", &parsed)
	if err != nil {
		return empty, err
	}

	return parsed, nil
}