2026-01-27 17:40:15 +08:00
|
|
|
|
package main
|
|
|
|
|
|
|
|
|
|
|
|
import (
	"context"
	"errors"
	"flag"
	"fmt"
	"log"
	"time"

	"gitea.com/texpixel/document_ai/config"
	"gitea.com/texpixel/document_ai/internal/storage/dao"
	"github.com/spf13/viper"
	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)
|
|
|
|
|
|
|
|
|
|
|
|
func main() {
|
|
|
|
|
|
// 解析命令行参数
|
|
|
|
|
|
testEnv := flag.String("test-env", "dev", "测试环境配置 (dev/prod)")
|
|
|
|
|
|
prodEnv := flag.String("prod-env", "prod", "生产环境配置 (dev/prod)")
|
|
|
|
|
|
flag.Parse()
|
|
|
|
|
|
|
|
|
|
|
|
// 加载测试环境配置
|
|
|
|
|
|
testConfigPath := fmt.Sprintf("./config/config_%s.yaml", *testEnv)
|
|
|
|
|
|
testConfig, err := loadDatabaseConfig(testConfigPath)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Fatalf("加载测试环境配置失败: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 连接测试数据库
|
|
|
|
|
|
testDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Asia%%2FShanghai",
|
|
|
|
|
|
testConfig.Username, testConfig.Password, testConfig.Host, testConfig.Port, testConfig.DBName)
|
|
|
|
|
|
testDB, err := gorm.Open(mysql.Open(testDSN), &gorm.Config{
|
|
|
|
|
|
Logger: logger.Default.LogMode(logger.Info),
|
|
|
|
|
|
})
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Fatalf("连接测试数据库失败: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 加载生产环境配置
|
|
|
|
|
|
prodConfigPath := fmt.Sprintf("./config/config_%s.yaml", *prodEnv)
|
|
|
|
|
|
prodConfig, err := loadDatabaseConfig(prodConfigPath)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Fatalf("加载生产环境配置失败: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 连接生产数据库
|
|
|
|
|
|
prodDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Asia%%2FShanghai",
|
|
|
|
|
|
prodConfig.Username, prodConfig.Password, prodConfig.Host, prodConfig.Port, prodConfig.DBName)
|
|
|
|
|
|
prodDB, err := gorm.Open(mysql.Open(prodDSN), &gorm.Config{
|
|
|
|
|
|
Logger: logger.Default.LogMode(logger.Info),
|
|
|
|
|
|
})
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
log.Fatalf("连接生产数据库失败: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 执行迁移
|
|
|
|
|
|
if err := migrateData(testDB, prodDB); err != nil {
|
|
|
|
|
|
log.Fatalf("数据迁移失败: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
log.Println("数据迁移完成!")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func migrateData(testDB, prodDB *gorm.DB) error {
|
|
|
|
|
|
_ = context.Background() // 保留以备将来使用
|
|
|
|
|
|
|
|
|
|
|
|
// 从测试数据库读取所有任务数据(包含结果)
|
|
|
|
|
|
type TaskWithResult struct {
|
|
|
|
|
|
// Task 字段
|
2026-01-27 21:56:21 +08:00
|
|
|
|
TaskID int64 `gorm:"column:id"`
|
|
|
|
|
|
UserID int64 `gorm:"column:user_id"`
|
|
|
|
|
|
TaskUUID string `gorm:"column:task_uuid"`
|
|
|
|
|
|
FileName string `gorm:"column:file_name"`
|
|
|
|
|
|
FileHash string `gorm:"column:file_hash"`
|
|
|
|
|
|
FileURL string `gorm:"column:file_url"`
|
|
|
|
|
|
TaskType string `gorm:"column:task_type"`
|
|
|
|
|
|
Status int `gorm:"column:status"`
|
|
|
|
|
|
CompletedAt time.Time `gorm:"column:completed_at"`
|
|
|
|
|
|
Remark string `gorm:"column:remark"`
|
|
|
|
|
|
IP string `gorm:"column:ip"`
|
2026-01-27 17:40:15 +08:00
|
|
|
|
TaskCreatedAt time.Time `gorm:"column:created_at"`
|
|
|
|
|
|
TaskUpdatedAt time.Time `gorm:"column:updated_at"`
|
|
|
|
|
|
// Result 字段
|
2026-01-27 21:56:21 +08:00
|
|
|
|
ResultID *int64 `gorm:"column:result_id"`
|
|
|
|
|
|
ResultTaskID *int64 `gorm:"column:result_task_id"`
|
|
|
|
|
|
ResultTaskType *string `gorm:"column:result_task_type"`
|
|
|
|
|
|
Latex *string `gorm:"column:latex"`
|
|
|
|
|
|
Markdown *string `gorm:"column:markdown"`
|
|
|
|
|
|
MathML *string `gorm:"column:mathml"`
|
2026-01-27 17:40:15 +08:00
|
|
|
|
ResultCreatedAt *time.Time `gorm:"column:result_created_at"`
|
|
|
|
|
|
ResultUpdatedAt *time.Time `gorm:"column:result_updated_at"`
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
var tasksWithResults []TaskWithResult
|
|
|
|
|
|
query := `
|
|
|
|
|
|
SELECT
|
|
|
|
|
|
t.id,
|
|
|
|
|
|
t.user_id,
|
|
|
|
|
|
t.task_uuid,
|
|
|
|
|
|
t.file_name,
|
|
|
|
|
|
t.file_hash,
|
|
|
|
|
|
t.file_url,
|
|
|
|
|
|
t.task_type,
|
|
|
|
|
|
t.status,
|
|
|
|
|
|
t.completed_at,
|
|
|
|
|
|
t.remark,
|
|
|
|
|
|
t.ip,
|
|
|
|
|
|
t.created_at,
|
|
|
|
|
|
t.updated_at,
|
|
|
|
|
|
r.id as result_id,
|
|
|
|
|
|
r.task_id as result_task_id,
|
|
|
|
|
|
r.task_type as result_task_type,
|
|
|
|
|
|
r.latex,
|
|
|
|
|
|
r.markdown,
|
|
|
|
|
|
r.mathml,
|
|
|
|
|
|
r.created_at as result_created_at,
|
|
|
|
|
|
r.updated_at as result_updated_at
|
|
|
|
|
|
FROM recognition_tasks t
|
|
|
|
|
|
LEFT JOIN recognition_results r ON t.id = r.task_id
|
|
|
|
|
|
ORDER BY t.id
|
|
|
|
|
|
`
|
|
|
|
|
|
|
|
|
|
|
|
if err := testDB.Raw(query).Scan(&tasksWithResults).Error; err != nil {
|
|
|
|
|
|
return fmt.Errorf("读取测试数据失败: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
log.Printf("从测试数据库读取到 %d 条任务记录", len(tasksWithResults))
|
|
|
|
|
|
|
|
|
|
|
|
successCount := 0
|
|
|
|
|
|
skipCount := 0
|
|
|
|
|
|
errorCount := 0
|
|
|
|
|
|
|
|
|
|
|
|
// 为每个任务使用独立事务,确保单个任务失败不影响其他任务
|
|
|
|
|
|
for i, item := range tasksWithResults {
|
|
|
|
|
|
// 开始事务
|
|
|
|
|
|
tx := prodDB.Begin()
|
|
|
|
|
|
|
|
|
|
|
|
// 检查生产数据库中是否已存在相同的 task_uuid
|
|
|
|
|
|
var existingTask dao.RecognitionTask
|
|
|
|
|
|
err := tx.Where("task_uuid = ?", item.TaskUUID).First(&existingTask).Error
|
|
|
|
|
|
if err == nil {
|
|
|
|
|
|
log.Printf("[%d/%d] 跳过已存在的任务: task_uuid=%s, id=%d", i+1, len(tasksWithResults), item.TaskUUID, existingTask.ID)
|
|
|
|
|
|
tx.Rollback()
|
|
|
|
|
|
skipCount++
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
if err != gorm.ErrRecordNotFound {
|
|
|
|
|
|
log.Printf("[%d/%d] 检查任务是否存在时出错: task_uuid=%s, error=%v", i+1, len(tasksWithResults), item.TaskUUID, err)
|
|
|
|
|
|
tx.Rollback()
|
|
|
|
|
|
errorCount++
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 创建新任务(不指定ID,让数据库自动生成)
|
|
|
|
|
|
newTask := &dao.RecognitionTask{
|
|
|
|
|
|
UserID: item.UserID,
|
|
|
|
|
|
TaskUUID: item.TaskUUID,
|
|
|
|
|
|
FileName: item.FileName,
|
|
|
|
|
|
FileHash: item.FileHash,
|
|
|
|
|
|
FileURL: item.FileURL,
|
|
|
|
|
|
TaskType: dao.TaskType(item.TaskType),
|
|
|
|
|
|
Status: dao.TaskStatus(item.Status),
|
|
|
|
|
|
CompletedAt: item.CompletedAt,
|
|
|
|
|
|
Remark: item.Remark,
|
|
|
|
|
|
IP: item.IP,
|
|
|
|
|
|
}
|
|
|
|
|
|
// 保留原始时间戳
|
|
|
|
|
|
newTask.CreatedAt = item.TaskCreatedAt
|
|
|
|
|
|
newTask.UpdatedAt = item.TaskUpdatedAt
|
|
|
|
|
|
|
|
|
|
|
|
if err := tx.Create(newTask).Error; err != nil {
|
|
|
|
|
|
log.Printf("[%d/%d] 创建任务失败: task_uuid=%s, error=%v", i+1, len(tasksWithResults), item.TaskUUID, err)
|
|
|
|
|
|
tx.Rollback()
|
|
|
|
|
|
errorCount++
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
log.Printf("[%d/%d] 创建任务成功: task_uuid=%s, 新ID=%d", i+1, len(tasksWithResults), item.TaskUUID, newTask.ID)
|
|
|
|
|
|
|
|
|
|
|
|
// 如果有结果数据,创建结果记录
|
|
|
|
|
|
if item.ResultID != nil {
|
|
|
|
|
|
// 处理可能为NULL的字段
|
|
|
|
|
|
latex := ""
|
|
|
|
|
|
if item.Latex != nil {
|
|
|
|
|
|
latex = *item.Latex
|
|
|
|
|
|
}
|
|
|
|
|
|
markdown := ""
|
|
|
|
|
|
if item.Markdown != nil {
|
|
|
|
|
|
markdown = *item.Markdown
|
|
|
|
|
|
}
|
|
|
|
|
|
mathml := ""
|
|
|
|
|
|
if item.MathML != nil {
|
|
|
|
|
|
mathml = *item.MathML
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
newResult := dao.RecognitionResult{
|
|
|
|
|
|
TaskID: newTask.ID, // 使用新任务的ID
|
|
|
|
|
|
TaskType: dao.TaskType(item.TaskType),
|
|
|
|
|
|
Latex: latex,
|
|
|
|
|
|
Markdown: markdown,
|
|
|
|
|
|
MathML: mathml,
|
|
|
|
|
|
}
|
|
|
|
|
|
// 保留原始时间戳
|
|
|
|
|
|
if item.ResultCreatedAt != nil {
|
|
|
|
|
|
newResult.CreatedAt = *item.ResultCreatedAt
|
|
|
|
|
|
}
|
|
|
|
|
|
if item.ResultUpdatedAt != nil {
|
|
|
|
|
|
newResult.UpdatedAt = *item.ResultUpdatedAt
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if err := tx.Create(&newResult).Error; err != nil {
|
|
|
|
|
|
log.Printf("[%d/%d] 创建结果失败: task_id=%d, error=%v", i+1, len(tasksWithResults), newTask.ID, err)
|
|
|
|
|
|
tx.Rollback() // 回滚整个事务(包括任务)
|
|
|
|
|
|
errorCount++
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
log.Printf("[%d/%d] 创建结果成功: task_id=%d", i+1, len(tasksWithResults), newTask.ID)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 提交事务
|
|
|
|
|
|
if err := tx.Commit().Error; err != nil {
|
|
|
|
|
|
log.Printf("[%d/%d] 提交事务失败: task_uuid=%s, error=%v", i+1, len(tasksWithResults), item.TaskUUID, err)
|
|
|
|
|
|
errorCount++
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
successCount++
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
log.Printf("迁移完成统计:")
|
|
|
|
|
|
log.Printf(" 成功: %d 条", successCount)
|
|
|
|
|
|
log.Printf(" 跳过: %d 条", skipCount)
|
|
|
|
|
|
log.Printf(" 失败: %d 条", errorCount)
|
|
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// loadDatabaseConfig 从配置文件加载数据库配置
|
|
|
|
|
|
func loadDatabaseConfig(configPath string) (config.DatabaseConfig, error) {
|
|
|
|
|
|
v := viper.New()
|
|
|
|
|
|
v.SetConfigFile(configPath)
|
|
|
|
|
|
if err := v.ReadInConfig(); err != nil {
|
|
|
|
|
|
return config.DatabaseConfig{}, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
var dbConfig config.DatabaseConfig
|
|
|
|
|
|
if err := v.UnmarshalKey("database", &dbConfig); err != nil {
|
|
|
|
|
|
return config.DatabaseConfig{}, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return dbConfig, nil
|
|
|
|
|
|
}
|