diff --git a/internal/service/analytics_service.go b/internal/service/analytics_service.go index f061fe7..d13a2d8 100644 --- a/internal/service/analytics_service.go +++ b/internal/service/analytics_service.go @@ -10,11 +10,9 @@ import ( "gitea.com/texpixel/document_ai/internal/storage/dao" "gitea.com/texpixel/document_ai/pkg/log" "gorm.io/datatypes" - "gorm.io/gorm" ) type AnalyticsService struct { - db *gorm.DB eventDao *dao.AnalyticsEventDao } @@ -54,7 +52,7 @@ func (s *AnalyticsService) TrackEvent(ctx context.Context, req *analytics.TrackE CreatedAt: time.Now(), } - if err := s.eventDao.Create(s.db, event); err != nil { + if err := s.eventDao.Create(dao.DB.WithContext(ctx), event); err != nil { log.Error(ctx, "create analytics event failed", "error", err) return fmt.Errorf("failed to track event") } @@ -105,7 +103,7 @@ func (s *AnalyticsService) BatchTrackEvents(ctx context.Context, req *analytics. return fmt.Errorf("no valid events to track") } - if err := s.eventDao.BatchCreate(s.db, events); err != nil { + if err := s.eventDao.BatchCreate(dao.DB.WithContext(ctx), events); err != nil { log.Error(ctx, "batch create analytics events failed", "error", err) return fmt.Errorf("failed to batch track events") } @@ -123,16 +121,16 @@ func (s *AnalyticsService) QueryEvents(ctx context.Context, req *analytics.Query // 根据不同条件查询 if req.UserID != nil && req.EventName != "" { // 查询用户的指定事件 - events, total, err = s.eventDao.GetUserEventsByName(s.db, *req.UserID, req.EventName, req.Page, req.PageSize) + events, total, err = s.eventDao.GetUserEventsByName(dao.DB.WithContext(ctx), *req.UserID, req.EventName, req.Page, req.PageSize) } else if req.UserID != nil { // 查询用户的所有事件 - events, total, err = s.eventDao.GetUserEvents(s.db, *req.UserID, req.Page, req.PageSize) + events, total, err = s.eventDao.GetUserEvents(dao.DB.WithContext(ctx), *req.UserID, req.Page, req.PageSize) } else if req.EventName != "" { // 查询指定事件 - events, total, err = s.eventDao.GetEventsByName(s.db, req.EventName, req.Page, req.PageSize) + events, total, err = s.eventDao.GetEventsByName(dao.DB.WithContext(ctx), req.EventName, req.Page, req.PageSize) } else if req.StartTime != nil && req.EndTime != nil { // 查询时间范围内的事件 - events, total, err = s.eventDao.GetEventsByTimeRange(s.db, *req.StartTime, *req.EndTime, req.Page, req.PageSize) + events, total, err = s.eventDao.GetEventsByTimeRange(dao.DB.WithContext(ctx), *req.StartTime, *req.EndTime, req.Page, req.PageSize) } else { return nil, fmt.Errorf("invalid query parameters") } @@ -178,7 +176,7 @@ func (s *AnalyticsService) QueryEvents(ctx context.Context, req *analytics.Query // GetEventStats 获取事件统计 func (s *AnalyticsService) GetEventStats(ctx context.Context, req *analytics.EventStatsRequest) (*analytics.EventStatsListResponse, error) { - results, err := s.eventDao.GetEventStats(s.db, req.StartTime, req.EndTime) + results, err := s.eventDao.GetEventStats(dao.DB.WithContext(ctx), req.StartTime, req.EndTime) if err != nil { log.Error(ctx, "get event stats failed", "error", err) return nil, fmt.Errorf("failed to get event stats") @@ -202,7 +200,7 @@ func (s *AnalyticsService) GetEventStats(ctx context.Context, req *analytics.Eve // CountUserEvents 统计用户事件数量 func (s *AnalyticsService) CountUserEvents(ctx context.Context, userID int64) (int64, error) { - count, err := s.eventDao.CountUserEvents(s.db, userID) + count, err := s.eventDao.CountUserEvents(dao.DB.WithContext(ctx), userID) if err != nil { log.Error(ctx, "count user events failed", "error", err, "user_id", userID) return 0, fmt.Errorf("failed to count user events") @@ -212,7 +210,7 @@ func (s *AnalyticsService) CountUserEvents(ctx context.Context, userID int64) (i // CountEventsByName 统计指定事件的数量 func (s *AnalyticsService) CountEventsByName(ctx context.Context, eventName string) (int64, error) { - count, err := s.eventDao.CountEventsByName(s.db, eventName) + count, err := s.eventDao.CountEventsByName(dao.DB.WithContext(ctx), eventName) if err != nil { log.Error(ctx, "count events by name failed", "error", err, "event_name", eventName) return 0, fmt.Errorf("failed to count events") @@ -224,7 +222,7 @@ func (s *AnalyticsService) CountEventsByName(ctx context.Context, eventName stri func (s *AnalyticsService) CleanOldEvents(ctx context.Context, retentionDays int) error { beforeTime := time.Now().AddDate(0, 0, -retentionDays) - if err := s.eventDao.DeleteOldEvents(s.db, beforeTime); err != nil { + if err := s.eventDao.DeleteOldEvents(dao.DB.WithContext(ctx), beforeTime); err != nil { log.Error(ctx, "clean old events failed", "error", err, "before_time", beforeTime) return fmt.Errorf("failed to clean old events") }