feat: use siliconflow model
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gitea.com/bitwsd/document_ai/config"
|
"gitea.com/bitwsd/document_ai/config"
|
||||||
@@ -59,7 +60,7 @@ func GetSignatureURL(ctx *gin.Context) {
|
|||||||
ctx.JSON(http.StatusOK, common.SuccessResponse(ctx, gin.H{"sign_url": "", "repeat": true, "path": task.FileURL}))
|
ctx.JSON(http.StatusOK, common.SuccessResponse(ctx, gin.H{"sign_url": "", "repeat": true, "path": task.FileURL}))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
extend := filepath.Ext(req.FileName)
|
extend := strings.ToLower(filepath.Ext(req.FileName))
|
||||||
if extend == "" {
|
if extend == "" {
|
||||||
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeParamError, "invalid file name"))
|
ctx.JSON(http.StatusOK, common.ErrorResponse(ctx, common.CodeParamError, "invalid file name"))
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -200,7 +200,7 @@ func (s *RecognitionService) processVLFormula(ctx context.Context, taskID int64)
|
|||||||
log.Info(ctx, "func", "processVLFormulaQueue", "msg", "获取任务成功", "task_id", taskID)
|
log.Info(ctx, "func", "processVLFormulaQueue", "msg", "获取任务成功", "task_id", taskID)
|
||||||
|
|
||||||
// 处理具体任务
|
// 处理具体任务
|
||||||
if err := s.processVLFormulaTask(ctx, taskID, task.FileURL); err != nil {
|
if err := s.processVLFormulaTask(ctx, taskID, task.FileURL, utils.ModelVLQwen32BInstruct); err != nil {
|
||||||
log.Error(ctx, "func", "processVLFormulaQueue", "msg", "处理任务失败", "error", err)
|
log.Error(ctx, "func", "processVLFormulaQueue", "msg", "处理任务失败", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -288,7 +288,7 @@ func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int6
|
|||||||
headers := map[string]string{"Content-Type": "application/json", utils.RequestIDHeaderKey: utils.GetRequestIDFromContext(ctx)}
|
headers := map[string]string{"Content-Type": "application/json", utils.RequestIDHeaderKey: utils.GetRequestIDFromContext(ctx)}
|
||||||
|
|
||||||
// 发送请求时会使用带超时的context
|
// 发送请求时会使用带超时的context
|
||||||
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, s.getURL(ctx), bytes.NewReader(jsonData), headers)
|
resp, err := s.httpClient.RequestWithRetry(ctx, http.MethodPost, "", bytes.NewReader(jsonData), headers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if ctx.Err() == context.DeadlineExceeded {
|
if ctx.Err() == context.DeadlineExceeded {
|
||||||
log.Error(ctx, "func", "processFormulaTask", "msg", "请求超时")
|
log.Error(ctx, "func", "processFormulaTask", "msg", "请求超时")
|
||||||
@@ -322,7 +322,7 @@ func (s *RecognitionService) processFormulaTask(ctx context.Context, taskID int6
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *RecognitionService) processVLFormulaTask(ctx context.Context, taskID int64, fileURL string) error {
|
func (s *RecognitionService) processVLFormulaTask(ctx context.Context, taskID int64, fileURL string, model string) error {
|
||||||
isSuccess := false
|
isSuccess := false
|
||||||
defer func() {
|
defer func() {
|
||||||
if !isSuccess {
|
if !isSuccess {
|
||||||
@@ -370,7 +370,7 @@ Important instructions:
|
|||||||
base64Image := base64.StdEncoding.EncodeToString(imageData)
|
base64Image := base64.StdEncoding.EncodeToString(imageData)
|
||||||
|
|
||||||
requestBody := formula.VLFormulaRequest{
|
requestBody := formula.VLFormulaRequest{
|
||||||
Model: "Qwen/Qwen2.5-VL-32B-Instruct",
|
Model: model,
|
||||||
Stream: false,
|
Stream: false,
|
||||||
MaxTokens: 512,
|
MaxTokens: 512,
|
||||||
Temperature: 0.1,
|
Temperature: 0.1,
|
||||||
@@ -518,14 +518,10 @@ func (s *RecognitionService) processOneTask(ctx context.Context) {
|
|||||||
log.Info(ctx, "func", "processFormulaQueue", "msg", "获取任务成功", "task_id", taskID)
|
log.Info(ctx, "func", "processFormulaQueue", "msg", "获取任务成功", "task_id", taskID)
|
||||||
|
|
||||||
// 处理具体任务
|
// 处理具体任务
|
||||||
if err := s.processFormulaTask(ctx, taskID, task.FileURL); err != nil {
|
if err := s.processVLFormulaTask(ctx, taskID, task.FileURL, utils.ModelVLDeepSeekOCR); err != nil {
|
||||||
log.Error(ctx, "func", "processFormulaQueue", "msg", "处理任务失败", "error", err)
|
log.Error(ctx, "func", "processFormulaQueue", "msg", "处理任务失败", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info(ctx, "func", "processFormulaQueue", "msg", "处理任务成功", "task_id", taskID)
|
log.Info(ctx, "func", "processFormulaQueue", "msg", "处理任务成功", "task_id", taskID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *RecognitionService) getURL(ctx context.Context) string {
|
|
||||||
return "http://cloud.srcstar.com:8045/formula/predict"
|
|
||||||
}
|
|
||||||
|
|||||||
6
pkg/utils/model.go
Normal file
6
pkg/utils/model.go
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
const (
|
||||||
|
ModelVLQwen32BInstruct = "Qwen/Qwen2.5-VL-32B-Instruct"
|
||||||
|
ModelVLDeepSeekOCR = "deepseek-ai/DeepSeek-OCR"
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user