package lxzap

import (
	"context"
	"errors"
	"path/filepath"
	"runtime"
	"strings"
	"time"

	"go.uber.org/zap"
	"gorm.io/gorm"
	gormlogger "gorm.io/gorm/logger"
)

// Logger implements gormlogger.Interface by writing gorm's log output
// through a *zap.Logger.
type Logger struct {
	// ZapLogger receives every record emitted by this Logger.
	ZapLogger *zap.Logger
	// LogLevel is the gorm verbosity threshold; messages more verbose than
	// this level are dropped.
	LogLevel gormlogger.LogLevel
	// SlowThreshold makes Trace report statements that take longer than this
	// at Warn level. Zero disables slow-statement reporting.
	SlowThreshold time.Duration
	// SkipCallerLookup, when true, disables the stack walk that attributes
	// records to the first frame outside gorm, reducing per-record overhead
	// at the cost of less precise caller information.
	SkipCallerLookup bool
	// IgnoreRecordNotFoundError stops Trace from logging
	// gorm.ErrRecordNotFound at Error level.
	IgnoreRecordNotFoundError bool
}

// Compile-time check that Logger satisfies gorm's logger interface.
var _ gormlogger.Interface = Logger{}

// NewGormZap returns a Logger that writes to zapLogger at gorm's Info level,
// reports statements slower than 100ms, performs caller lookup, and logs
// gorm.ErrRecordNotFound as an error.
func NewGormZap(zapLogger *zap.Logger) Logger {
	return Logger{
		ZapLogger:                 zapLogger,
		LogLevel:                  gormlogger.Info,
		SlowThreshold:             100 * time.Millisecond,
		SkipCallerLookup:          false,
		IgnoreRecordNotFoundError: false,
	}
}

// SetAsDefault installs l as gorm's package-level default logger
// (gormlogger.Default).
func (l Logger) SetAsDefault() {
	gormlogger.Default = l
}

// LogMode returns a copy of l whose LogLevel is level; l itself is not
// modified. Copying the whole value (rather than listing fields) keeps any
// field added to Logger in the future.
func (l Logger) LogMode(level gormlogger.LogLevel) gormlogger.Interface {
	l.LogLevel = level
	return l
}

// Info logs a formatted gorm informational message when LogLevel is at
// least gormlogger.Info. Note: the record is emitted at zap's Debug level.
func (l Logger) Info(ctx context.Context, str string, args ...interface{}) {
	if l.LogLevel < gormlogger.Info {
		return
	}
	l.logger().Sugar().Debugf(str, args...)
}

// Warn logs a formatted gorm warning at zap's Warn level when LogLevel is
// at least gormlogger.Warn.
func (l Logger) Warn(ctx context.Context, str string, args ...interface{}) {
	if l.LogLevel < gormlogger.Warn {
		return
	}
	l.logger().Sugar().Warnf(str, args...)
}

// Error logs a formatted gorm error at zap's Error level when LogLevel is
// at least gormlogger.Error.
func (l Logger) Error(ctx context.Context, str string, args ...interface{}) {
	if l.LogLevel < gormlogger.Error {
		return
	}
	l.logger().Sugar().Errorf(str, args...)
}

// Trace logs one executed SQL statement with its elapsed time, affected
// rows and the "X-Span-ID" value from ctx (if any):
//   - at Error level when err is non-nil and LogLevel >= Error, unless
//     IgnoreRecordNotFoundError is set and err is gorm.ErrRecordNotFound;
//   - otherwise at Warn level when SlowThreshold is non-zero, was exceeded,
//     and LogLevel >= Warn;
//   - otherwise at Info level when LogLevel >= Info.
//
// An ignored gorm.ErrRecordNotFound therefore falls through to the slow or
// Info cases without the error field. fc is only invoked when a record is
// actually emitted.
func (l Logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
	if l.LogLevel <= gormlogger.Silent {
		return
	}

	// The span ID is looked up with a plain string key, so whatever stores
	// it must use the same string. A nil ctx would panic on Value, so it is
	// treated as "no span".
	var spanID interface{}
	if ctx != nil {
		spanID = ctx.Value("X-Span-ID")
	}

	elapsed := time.Since(begin)
	switch {
	case err != nil && l.LogLevel >= gormlogger.Error && (!l.IgnoreRecordNotFoundError || !errors.Is(err, gorm.ErrRecordNotFound)):
		sql, rows := fc()
		l.logger().Error("trace",
			zap.Any("spanId", spanID),
			zap.Error(err),
			zap.Duration("elapsed", elapsed),
			zap.Int64("rows", rows),
			zap.String("sql", sql),
		)
	case l.SlowThreshold != 0 && elapsed > l.SlowThreshold && l.LogLevel >= gormlogger.Warn:
		sql, rows := fc()
		l.logger().Warn("trace",
			zap.Any("spanId", spanID),
			zap.Duration("elapsed", elapsed),
			zap.Int64("rows", rows),
			zap.String("sql", sql),
		)
	case l.LogLevel >= gormlogger.Info:
		sql, rows := fc()
		l.logger().Info("trace",
			zap.Any("spanId", spanID),
			zap.Duration("elapsed", elapsed),
			zap.Int64("rows", rows),
			zap.String("sql", sql),
		)
	}
}

var (
	// gormPackage and zapgormPackage are path fragments identifying stack
	// frames that logger() skips while searching for the application call
	// site. filepath.Join makes them use the OS separator, matching the
	// filepath.Clean'ed frame paths they are compared against.
	gormPackage    = filepath.Join("gorm.io", "gorm")
	zapgormPackage = filepath.Join("moul.io", "zapgorm2")
)

// logger returns l.ZapLogger adjusted so that zap attributes the record to
// the first stack frame outside gorm, zapgorm2 and _test.go files. If
// SkipCallerLookup is set, or no such frame is found within the first 15
// frames, l.ZapLogger is returned unchanged.
func (l Logger) logger() *zap.Logger {
	if l.SkipCallerLookup {
		return l.ZapLogger
	}
	// Frame 0 is logger itself and frame 1 is the Logger method that called
	// it (Info, Warn, Error or Trace), so the search starts at frame 2.
	for i := 2; i < 15; i++ {
		_, file, _, ok := runtime.Caller(i)
		if !ok {
			// The stack is shallower than i; deeper frames cannot exist.
			break
		}
		file = filepath.Clean(file)
		switch {
		case strings.HasSuffix(file, "_test.go"),
			strings.Contains(file, gormPackage),
			strings.Contains(file, zapgormPackage):
			// Not the application call site; keep walking.
		default:
			// Frame numbers here are relative to logger(). zap's default call
			// site is the Logger method that invoked it (frame 1 here), so
			// reaching frame i takes i-1 additional skips.
			return l.ZapLogger.WithOptions(zap.AddCallerSkip(i - 1))
		}
	}
	return l.ZapLogger
}