feat: use siliconflow model

This commit is contained in:
liuyuanchuang
2025-12-11 19:39:35 +08:00
parent ea0f5d8765
commit 696919611c
3 changed files with 13 additions and 10 deletions

View File

@@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"time" "time"
"gitea.com/bitwsd/document_ai/config" "gitea.com/bitwsd/document_ai/config"
@@ -59,7 +60,7 @@ func GetSignatureURL(ctx *gin.Context) {
ctx.JSON(http.StatusOK, common.SuccessResponse(ctx, gin.H{"sign_url": "", "repeat": true, "path": task.FileURL})) ctx.JSON(http.StatusOK, common.SuccessResponse(ctx, gin.H{"sign_url": "", "repeat": true, "path": task.FileURL}))
return return
} }
extend := filepath.Ext(req.FileName) extend := strings.ToLower(filepath.Ext(req.FileName))
if extend == "" { if extend == "" {
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeParamError, "invalid file name")) ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeParamError, "invalid file name"))
return return

View File

@@ -200,7 +200,7 @@ func (s *RecognitionService) processVLFormula(ctx context.Context, taskID int64)
log.Info(ctx, "func", "processVLFormulaQueue", "msg", "获取任务成功", "task_id", taskID) log.Info(ctx, "func", "processVLFormulaQueue", "msg", "获取任务成功", "task_id", taskID)
// 处理具体任务 // 处理具体任务
if err := s.processVLFormulaTask(ctx, taskID, task.FileURL); err != nil { if err := s.processVLFormulaTask(ctx, taskID, task.FileURL, utils.ModelVLQwen32BInstruct); err != nil {
log.Error(ctx, "func", "processVLFormulaQueue", "msg", "处理任务失败", "error", err) log.Error(ctx, "func", "processVLFormulaQueue", "msg", "处理任务失败", "error", err)
return return
} }
@@ -288,7 +288,7 @@ func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int6
headers := map[string]string{"Content-Type": "application/json", utils.RequestIDHeaderKey: utils.GetRequestIDFromContext(ctx)} headers := map[string]string{"Content-Type": "application/json", utils.RequestIDHeaderKey: utils.GetRequestIDFromContext(ctx)}
// 发送请求时会使用带超时的context // 发送请求时会使用带超时的context
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, s.getURL(ctx), bytes.NewReader(jsonData), headers) resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "", bytes.NewReader(jsonData), headers)
if err != nil { if err != nil {
if ctx.Err() == context.DeadlineExceeded { if ctx.Err() == context.DeadlineExceeded {
log.Error(ctx, "func", "processFormulaTask", "msg", "请求超时") log.Error(ctx, "func", "processFormulaTask", "msg", "请求超时")
@@ -322,7 +322,7 @@ func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int6
return nil return nil
} }
func (s *RecognitionService) processVLFormulaTask(ctx context.Context, taskID int64, fileURL string) error { func (s *RecognitionService) processVLFormulaTask(ctx context.Context, taskID int64, fileURL string, model string) error {
isSuccess := false isSuccess := false
defer func() { defer func() {
if !isSuccess { if !isSuccess {
@@ -370,7 +370,7 @@ Important instructions:
base64Image := base64.StdEncoding.EncodeToString(imageData) base64Image := base64.StdEncoding.EncodeToString(imageData)
requestBody := formula.VLFormulaRequest{ requestBody := formula.VLFormulaRequest{
Model: "Qwen/Qwen2.5-VL-32B-Instruct", Model: model,
Stream: false, Stream: false,
MaxTokens: 512, MaxTokens: 512,
Temperature: 0.1, Temperature: 0.1,
@@ -518,14 +518,10 @@ func (s *RecognitionService) processOneTask(ctx context.Context) {
log.Info(ctx, "func", "processFormulaQueue", "msg", "获取任务成功", "task_id", taskID) log.Info(ctx, "func", "processFormulaQueue", "msg", "获取任务成功", "task_id", taskID)
// 处理具体任务 // 处理具体任务
if err := s.processFormulaTask(ctx, taskID, task.FileURL); err != nil { if err := s.processVLFormulaTask(ctx, taskID, task.FileURL, utils.ModelVLDeepSeekOCR); err != nil {
log.Error(ctx, "func", "processFormulaQueue", "msg", "处理任务失败", "error", err) log.Error(ctx, "func", "processFormulaQueue", "msg", "处理任务失败", "error", err)
return return
} }
log.Info(ctx, "func", "processFormulaQueue", "msg", "处理任务成功", "task_id", taskID) log.Info(ctx, "func", "processFormulaQueue", "msg", "处理任务成功", "task_id", taskID)
} }
func (s *RecognitionService) getURL(ctx context.Context) string {
return "http://cloud.srcstar.com:8045/formula/predict"
}

6
pkg/utils/model.go Normal file
View File

@@ -0,0 +1,6 @@
package utils
const (
ModelVLQwen32BInstruct = "Qwen/Qwen2.5-VL-32B-Instruct"
ModelVLDeepSeekOCR = "deepseek-ai/DeepSeek-OCR"
)