From 21bfcf3419ddc71625a9cf2e529eadc72533d89f Mon Sep 17 00:00:00 2001 From: ccoVeille <3875889+ccoVeille@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:37:02 +0100 Subject: [PATCH 1/2] fix: panic on nil error in handoffWorkerManager closeConnFromRequest --- .../maintnotifications/logs/log_messages.go | 26 +++++++++++++------ maintnotifications/handoff_worker.go | 6 ++--- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/internal/maintnotifications/logs/log_messages.go b/internal/maintnotifications/logs/log_messages.go index 34cb1692d9..44ae006096 100644 --- a/internal/maintnotifications/logs/log_messages.go +++ b/internal/maintnotifications/logs/log_messages.go @@ -288,19 +288,29 @@ func OperationNotTracked(connID uint64, seqID int64) string { // Connection pool functions func RemovingConnectionFromPool(connID uint64, reason error) string { - message := fmt.Sprintf("conn[%d] %s due to: %v", connID, RemovingConnectionFromPoolMessage, reason) - return appendJSONIfDebug(message, map[string]interface{}{ + metadata := map[string]interface{}{ "connID": connID, - "reason": reason.Error(), - }) + "reason": "unknown", // this will be overwritten if reason is not nil + } + if reason != nil { + metadata["reason"] = reason.Error() + } + + message := fmt.Sprintf("conn[%d] %s due to: %v", connID, RemovingConnectionFromPoolMessage, reason) + return appendJSONIfDebug(message, metadata) } func NoPoolProvidedCannotRemove(connID uint64, reason error) string { - message := fmt.Sprintf("conn[%d] %s due to: %v", connID, NoPoolProvidedMessageCannotRemoveMessage, reason) - return appendJSONIfDebug(message, map[string]interface{}{ + metadata := map[string]interface{}{ "connID": connID, - "reason": reason.Error(), - }) + "reason": "unknown", // this will be overwritten if reason is not nil + } + if reason != nil { + metadata["reason"] = reason.Error() + } + + message := fmt.Sprintf("conn[%d] %s due to: %v", connID, NoPoolProvidedMessageCannotRemoveMessage, reason) + return appendJSONIfDebug(message, metadata) } // Circuit breaker functions diff --git a/maintnotifications/handoff_worker.go b/maintnotifications/handoff_worker.go index 5b60e39b59..53f28f49c8 100644 --- a/maintnotifications/handoff_worker.go +++ b/maintnotifications/handoff_worker.go @@ -501,9 +501,9 @@ func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, reque internal.Logger.Printf(ctx, logs.RemovingConnectionFromPool(conn.GetID(), err)) } } else { - err := conn.Close() // Close the connection if no pool provided - if err != nil { - internal.Logger.Printf(ctx, "redis: failed to close connection: %v", err) + errClose := conn.Close() // Close the connection if no pool provided + if errClose != nil { + internal.Logger.Printf(ctx, "redis: failed to close connection: %v", errClose) } if internal.LogLevel.WarnOrAbove() { internal.Logger.Printf(ctx, logs.NoPoolProvidedCannotRemove(conn.GetID(), err)) From f92e566b13a5db9fab6d2bcf6c00e358970928d6 Mon Sep 17 00:00:00 2001 From: ccoVeille <3875889+ccoVeille@users.noreply.github.com> Date: Fri, 24 Oct 2025 14:59:40 +0200 Subject: [PATCH 2/2] feat: add optional logger wherever possible This commit introduces an optional logger parameter to various structs. This enhancement allows users to provide custom logging implementations. An issue will remain for direct calls to internal.Logger.Printf, as the call depth cannot be adjusted there without changing the function signature. 
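For illustration, a minimal usage sketch of the new option (assumptions: the Options.Logger field added by this patch, Go 1.21+ so that a *slog.Logger satisfies logging.LoggerWithLevelI as verified by the tests below, and a placeholder address):

    package main

    import (
        "log/slog"
        "os"

        "github.com/redis/go-redis/v9"
    )

    func main() {
        // A *slog.Logger can be passed directly because it satisfies logging.LoggerWithLevelI.
        logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo}))

        rdb := redis.NewClient(&redis.Options{
            Addr:   "localhost:6379", // placeholder address for the example
            Logger: logger,           // per-client logger; the global logger is used when nil
        })
        defer rdb.Close()
    }

The same optional Logger field is also exposed on maintnotifications.Config and the internal pool options, falling back to the legacy global logger when unset.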
While most of the changes come from ccoVeille's work, Nedyalko Dyakov made some changes in other files to ensure the legacy logger is replaced with the new logger. Co-authored-by: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com> --- commands.go | 13 +- export_test.go | 4 +- internal/auth/streaming/pool_hook.go | 3 +- internal/log.go | 6 +- internal/pool/conn.go | 6 +- internal/pool/pool.go | 35 ++- logging/custom.go | 155 ++++++++++++ logging/custom_after_go_121_test.go | 227 ++++++++++++++++++ logging/legacy.go | 98 ++++++++ logging/level_after_go_121.go | 48 ++++ logging/level_before_go_121.go | 16 ++ maintnotifications/circuit_breaker.go | 36 +-- maintnotifications/config.go | 21 +- maintnotifications/handoff_worker.go | 69 +++--- maintnotifications/manager.go | 25 +- maintnotifications/pool_hook.go | 13 +- .../push_notification_handler.go | 80 +++--- options.go | 8 + osscluster.go | 35 ++- pubsub.go | 36 ++- redis.go | 43 ++-- redis_after_go_121_test.go | 107 +++++++++ ring.go | 14 +- sentinel.go | 42 +++- 24 files changed, 944 insertions(+), 196 deletions(-) create mode 100644 logging/custom.go create mode 100644 logging/custom_after_go_121_test.go create mode 100644 logging/legacy.go create mode 100644 logging/level_after_go_121.go create mode 100644 logging/level_before_go_121.go create mode 100644 redis_after_go_121_test.go diff --git a/commands.go b/commands.go index daee5505e1..da6d225d04 100644 --- a/commands.go +++ b/commands.go @@ -13,6 +13,7 @@ import ( "time" "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/logging" ) // KeepTTL is a Redis KEEPTTL option to keep existing TTL, it requires your redis-server version >= 6.0, @@ -28,11 +29,7 @@ func usePrecise(dur time.Duration) bool { func formatMs(ctx context.Context, dur time.Duration) int64 { if dur > 0 && dur < time.Millisecond { - internal.Logger.Printf( - ctx, - "specified duration is %s, but minimal supported value is %s - truncating to 1ms", - dur, time.Millisecond, - ) + logging.LoggerWithLevel().Infof(ctx, "specified duration is %s, but minimal supported value is %s - truncating to 1ms", dur, time.Millisecond) return 1 } return int64(dur / time.Millisecond) @@ -40,11 +37,7 @@ func formatMs(ctx context.Context, dur time.Duration) int64 { func formatSec(ctx context.Context, dur time.Duration) int64 { if dur > 0 && dur < time.Second { - internal.Logger.Printf( - ctx, - "specified duration is %s, but minimal supported value is %s - truncating to 1s", - dur, time.Second, - ) + logging.LoggerWithLevel().Infof(ctx, "specified duration is %s, but minimal supported value is %s - truncating to 1s", dur, time.Second) return 1 } return int64(dur / time.Second) diff --git a/export_test.go b/export_test.go index 97b6179a44..7d2d1833be 100644 --- a/export_test.go +++ b/export_test.go @@ -6,9 +6,9 @@ import ( "net" "strings" - "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/hashtag" "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/logging" ) func (c *baseClient) Pool() pool.Pooler { @@ -87,7 +87,7 @@ func (c *clusterState) IsConsistent(ctx context.Context) bool { func GetSlavesAddrByName(ctx context.Context, c *SentinelClient, name string) []string { addrs, err := c.Replicas(ctx, name).Result() if err != nil { - internal.Logger.Printf(ctx, "sentinel: Replicas name=%q failed: %s", + logging.LoggerWithLevel().Errorf(ctx, "sentinel: Replicas name=%q failed: %s", name, err) return []string{} } diff --git a/internal/auth/streaming/pool_hook.go
b/internal/auth/streaming/pool_hook.go index aaf4f6099f..29c4210338 100644 --- a/internal/auth/streaming/pool_hook.go +++ b/internal/auth/streaming/pool_hook.go @@ -7,6 +7,7 @@ import ( "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/logging" ) // ReAuthPoolHook is a pool hook that manages background re-authentication of connections @@ -166,7 +167,7 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, defer func() { if rec := recover(); rec != nil { // once again - safety first - internal.Logger.Printf(context.Background(), "panic in reauth worker: %v", rec) + logging.LoggerWithLevel().Errorf(context.Background(), "panic in reauth worker: %v", rec) } r.scheduledLock.Lock() delete(r.scheduledReAuth, connID) diff --git a/internal/log.go b/internal/log.go index 0bfffc311b..d462655d26 100644 --- a/internal/log.go +++ b/internal/log.go @@ -7,9 +7,6 @@ import ( "os" ) -// TODO (ned): Revisit logging -// Add more standardized approach with log levels and configurability - type Logging interface { Printf(ctx context.Context, format string, v ...interface{}) } @@ -19,7 +16,7 @@ type DefaultLogger struct { } func (l *DefaultLogger) Printf(ctx context.Context, format string, v ...interface{}) { - _ = l.log.Output(2, fmt.Sprintf(format, v...)) + _ = l.log.Output(4, fmt.Sprintf(format, v...)) } func NewDefaultLogger() Logging { @@ -38,6 +35,7 @@ var LogLevel LogLevelT = LogLevelError type LogLevelT int // Log level constants for the entire go-redis library +// TODO(ndyakov): In v10 align those levels with slog.Level const ( LogLevelError LogLevelT = iota // 0 - errors only LogLevelWarn // 1 - warnings and errors diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 95d83bfde4..4d96e82626 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -11,9 +11,9 @@ import ( "sync/atomic" "time" - "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/maintnotifications/logs" "github.com/redis/go-redis/v9/internal/proto" + "github.com/redis/go-redis/v9/logging" ) var noDeadline = time.Time{} @@ -508,7 +508,7 @@ func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Durati // Deadline has passed, clear relaxed timeouts atomically and use normal timeout newCount := cn.relaxedCounter.Add(-1) if newCount <= 0 { - internal.Logger.Printf(context.Background(), logs.UnrelaxedTimeoutAfterDeadline(cn.GetID())) + logging.LoggerWithLevel().Infof(context.Background(), logs.UnrelaxedTimeoutAfterDeadline(cn.GetID())) cn.clearRelaxedTimeout() } return normalTimeout @@ -542,7 +542,7 @@ func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Durat // Deadline has passed, clear relaxed timeouts atomically and use normal timeout newCount := cn.relaxedCounter.Add(-1) if newCount <= 0 { - internal.Logger.Printf(context.Background(), logs.UnrelaxedTimeoutAfterDeadline(cn.GetID())) + logging.LoggerWithLevel().Infof(context.Background(), logs.UnrelaxedTimeoutAfterDeadline(cn.GetID())) cn.clearRelaxedTimeout() } return normalTimeout diff --git a/internal/pool/pool.go b/internal/pool/pool.go index d757d1f4fa..e1e29c2313 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -11,6 +11,7 @@ import ( "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/proto" "github.com/redis/go-redis/v9/internal/util" + "github.com/redis/go-redis/v9/logging" ) var ( @@ -119,6 +120,9 @@ type Options struct { // DialerRetryTimeout is the 
backoff duration between retry attempts. // Default: 100ms DialerRetryTimeout time.Duration + + // Optional logger for connection pool operations. + Logger logging.LoggerWithLevelI } type lastDialErrorWrap struct { @@ -254,7 +258,7 @@ func (p *ConnPool) checkMinIdleConns() { p.idleConnsLen.Add(-1) p.freeTurn() - internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err) + p.logger().Errorf(context.Background(), "addIdleConn panic: %+v", err) } }() @@ -416,7 +420,7 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) { return cn, nil } - internal.Logger.Printf(ctx, "redis: connection pool: failed to dial after %d attempts: %v", attempt, lastErr) + p.logger().Errorf(ctx, "redis: connection pool: failed to dial after %d attempts: %v", attempt, lastErr) // All retries failed - handle error tracking p.setLastDialError(lastErr) if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.cfg.PoolSize) { @@ -510,10 +514,10 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { acceptConn, err := hookManager.ProcessOnGet(ctx, cn, false) if err != nil || !acceptConn { if err != nil { - internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err) + p.logger().Errorf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err) _ = p.CloseConn(cn) } else { - internal.Logger.Printf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID()) + p.logger().Errorf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID()) // Return connection to pool without freeing the turn that this Get() call holds. // We use putConnWithoutTurn() to run all the Put hooks and logic without freeing a turn. p.putConnWithoutTurn(ctx, cn) @@ -541,7 +545,7 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { // this should not happen with a new connection, but we handle it gracefully if err != nil || !acceptConn { // Failed to process connection, discard it - internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection conn[%d] by hook: accept=%v, err=%v", newcn.GetID(), acceptConn, err) + p.logger().Errorf(ctx, "redis: connection pool: failed to process new connection conn[%d] by hook: accept=%v, err=%v", newcn.GetID(), acceptConn, err) _ = p.CloseConn(newcn) return nil, err } @@ -583,7 +587,7 @@ func (p *ConnPool) queuedNewConn(ctx context.Context) (*Conn, error) { if !freeTurnCalled { p.freeTurn() } - internal.Logger.Printf(context.Background(), "queuedNewConn panic: %+v", err) + p.logger().Errorf(ctx, "queuedNewConn panic: %+v", err) } }() @@ -736,7 +740,7 @@ func (p *ConnPool) popIdle() (*Conn, error) { // If we exhausted all attempts without finding a usable connection, return nil if attempts > 1 && attempts >= maxAttempts && int32(attempts) >= p.poolSize.Load() { - internal.Logger.Printf(context.Background(), "redis: connection pool: failed to get a usable connection after %d attempts", attempts) + p.logger().Errorf(context.Background(), "redis: connection pool: failed to get a usable connection after %d attempts", attempts) return nil, nil } @@ -765,7 +769,7 @@ func (p *ConnPool) putConn(ctx context.Context, cn *Conn, freeTurn bool) { // Peek at the reply type to check if it's a push notification if replyType, err := cn.PeekReplyTypeSafe(); err != nil || replyType != proto.RespPush { // Not a push notification or error peeking, remove connection - internal.Logger.Printf(ctx, "Conn has unread data (not push 
notification), removing it") + p.logger().Errorf(ctx, "Conn has unread data (not push notification), removing it") p.removeConnInternal(ctx, cn, err, freeTurn) return } @@ -778,7 +782,7 @@ func (p *ConnPool) putConn(ctx context.Context, cn *Conn, freeTurn bool) { if hookManager != nil { shouldPool, shouldRemove, err = hookManager.ProcessOnPut(ctx, cn) if err != nil { - internal.Logger.Printf(ctx, "Connection hook error: %v", err) + p.logger().Errorf(ctx, "Connection hook error: %v", err) p.removeConnInternal(ctx, cn, err, freeTurn) return } @@ -811,12 +815,12 @@ func (p *ConnPool) putConn(ctx context.Context, cn *Conn, freeTurn bool) { case StateUnusable: // expected state, don't log it case StateClosed: - internal.Logger.Printf(ctx, "Unexpected conn[%d] state changed by hook to %v, closing it", cn.GetID(), currentState) + p.logger().Errorf(ctx, "Unexpected conn[%d] state changed by hook to %v, closing it", cn.GetID(), currentState) shouldCloseConn = true p.removeConnWithLock(cn) default: // Pool as-is - internal.Logger.Printf(ctx, "Unexpected conn[%d] state changed by hook to %v, pooling as-is", cn.GetID(), currentState) + p.logger().Warnf(ctx, "Unexpected conn[%d] state changed by hook to %v, pooling as-is", cn.GetID(), currentState) } } @@ -1030,7 +1034,7 @@ func (p *ConnPool) isHealthyConn(cn *Conn, nowNs int64) bool { if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush { // For RESP3 connections with push notifications, we allow some buffered data // The client will process these notifications before using the connection - internal.Logger.Printf( + p.logger().Infof( context.Background(), "push: conn[%d] has buffered data, likely push notifications - will be processed by client", cn.GetID(), @@ -1053,3 +1057,10 @@ func (p *ConnPool) isHealthyConn(cn *Conn, nowNs int64) bool { cn.SetUsedAtNs(nowNs) return true } + +func (p *ConnPool) logger() *logging.LoggerWrapper { + if p.cfg != nil && p.cfg.Logger != nil { + return logging.NewLoggerWrapper(p.cfg.Logger) + } + return logging.LoggerWithLevel() +} diff --git a/logging/custom.go b/logging/custom.go new file mode 100644 index 0000000000..3f9c098d8a --- /dev/null +++ b/logging/custom.go @@ -0,0 +1,155 @@ +package logging + +import ( + "context" + "fmt" +) + +// LoggerWrapper is a slog.Logger wrapper that implements the Lgr interface. +type LoggerWrapper struct { + logger LoggerWithLevelI + loggerLevel *LogLevelT + printfAdapter PrintfAdapter +} + +func NewLoggerWrapper(logger LoggerWithLevelI, opts ...LoggerWrapperOption) *LoggerWrapper { + cl, ok := logger.(*LoggerWrapper) + if !ok { + cl = &LoggerWrapper{ + logger: logger, + } + } + for _, opt := range opts { + opt(cl) + } + return cl +} + +type LoggerWrapperOption func(*LoggerWrapper) + +func WithPrintfAdapter(adapter PrintfAdapter) LoggerWrapperOption { + return func(cl *LoggerWrapper) { + cl.printfAdapter = adapter + } +} + +func WithLoggerLevel(level LogLevelT) LoggerWrapperOption { + return func(cl *LoggerWrapper) { + cl.loggerLevel = &level + } +} + +// PrintfAdapter is a function that converts Printf-style log messages into structured log messages. +// It can be used to extract key-value pairs from the formatted message. +type PrintfAdapter func(ctx context.Context, format string, v ...any) (context.Context, string, []any) + +// ErrorContext is a structured error level logging method with context and arguments. 
+func (cl *LoggerWrapper) ErrorContext(ctx context.Context, msg string, args ...any) { + if cl == nil || cl.logger == nil { + legacyLoggerWithLevel.Errorf(ctx, msg, args...) + return + } + cl.logger.ErrorContext(ctx, msg, args...) +} + +func (cl *LoggerWrapper) Errorf(ctx context.Context, format string, v ...any) { + if cl == nil || cl.logger == nil { + legacyLoggerWithLevel.Errorf(ctx, format, v...) + return + } + ctx, msg, args := cl.printfToStructured(ctx, format, v...) + cl.logger.ErrorContext(ctx, msg, args...) +} + +// WarnContext is a structured warning level logging method with context and arguments. +func (cl *LoggerWrapper) WarnContext(ctx context.Context, msg string, args ...any) { + if cl == nil || cl.logger == nil { + legacyLoggerWithLevel.Warnf(ctx, msg, args...) + return + } + cl.logger.WarnContext(ctx, msg, args...) +} + +func (cl *LoggerWrapper) Warnf(ctx context.Context, format string, v ...any) { + if cl == nil || cl.logger == nil { + legacyLoggerWithLevel.Warnf(ctx, format, v...) + return + } + ctx, msg, args := cl.printfToStructured(ctx, format, v...) + cl.logger.WarnContext(ctx, msg, args...) +} + +// InfoContext is a structured info level logging method with context and arguments. +func (cl *LoggerWrapper) InfoContext(ctx context.Context, msg string, args ...any) { + if cl == nil || cl.logger == nil { + legacyLoggerWithLevel.Infof(ctx, msg, args...) + return + } + cl.logger.InfoContext(ctx, msg, args...) +} + +// DebugContext is a structured debug level logging method with context and arguments. +func (cl *LoggerWrapper) DebugContext(ctx context.Context, msg string, args ...any) { + if cl == nil || cl.logger == nil { + legacyLoggerWithLevel.Debugf(ctx, msg, args...) + return + } + cl.logger.DebugContext(ctx, msg, args...) +} + +func (cl *LoggerWrapper) Infof(ctx context.Context, format string, v ...any) { + if cl == nil || cl.logger == nil { + legacyLoggerWithLevel.Infof(ctx, format, v...) + return + } + + ctx, msg, args := cl.printfToStructured(ctx, format, v...) + cl.logger.InfoContext(ctx, msg, args...) +} + +func (cl *LoggerWrapper) Debugf(ctx context.Context, format string, v ...any) { + if cl == nil || cl.logger == nil { + legacyLoggerWithLevel.Debugf(ctx, format, v...) + return + } + ctx, msg, args := cl.printfToStructured(ctx, format, v...) + cl.logger.DebugContext(ctx, msg, args...) +} + +func (cl *LoggerWrapper) printfToStructured(ctx context.Context, format string, v ...any) (context.Context, string, []any) { + if cl != nil && cl.printfAdapter != nil { + return cl.printfAdapter(ctx, format, v...) + } + return ctx, fmt.Sprintf(format, v...), nil +} + +func (cl *LoggerWrapper) Enabled(ctx context.Context, level LogLevelT) bool { + if cl != nil && cl.loggerLevel != nil { + return level >= *cl.loggerLevel + } + + // delegate to a method that use go:build tags to determine how to check level + return isLevelEnabled(ctx, cl.logger, level) +} + +// LoggerWithLevelI is a logger interface with leveled logging methods. +// +// [slog.Logger] from the standard library satisfies this interface. 
+type LoggerWithLevelI interface { + // InfoContext logs an info level message + InfoContext(ctx context.Context, format string, v ...any) + + // WarnContext logs a warning level message + WarnContext(ctx context.Context, format string, v ...any) + + // Debugf logs a debug level message + DebugContext(ctx context.Context, format string, v ...any) + + // Errorf logs an error level message + ErrorContext(ctx context.Context, format string, v ...any) + + //TODO(ndyakov): add Enabled when Go 1.21 is min supported +} + +// Verify that LoggerWrapper implements LoggerWithLevelI +var _ LoggerWithLevelI = (*LoggerWrapper)(nil) diff --git a/logging/custom_after_go_121_test.go b/logging/custom_after_go_121_test.go new file mode 100644 index 0000000000..ff66da55e0 --- /dev/null +++ b/logging/custom_after_go_121_test.go @@ -0,0 +1,227 @@ +//go:build go1.21 + +package logging + +// The purpose of this file is to provide tests for [LoggerWrapper] with [slog.Logger]. +// These tests require Go 1.21 or above because they use [slog.Logger] which was not available +// before Go 1.21. + +import ( + "bytes" + "context" + "encoding/json" + "log/slog" + "os" + "testing" +) + +// validation that [slog.Logger] implements [LoggerWithLevelI] +var _ LoggerWithLevelI = &slog.Logger{} + +var _ *LoggerWrapper = NewLoggerWrapper(&slog.Logger{}) + +func TestLoggerWrapper_slog(t *testing.T) { + + ctx := context.Background() + + t.Run("Debug", func(t *testing.T) { + wrapped, buf := NewTestLogger(t) + wrapped.DebugContext(ctx, "debug message", "foo", "bar") + + checkLog(t, buf, map[string]any{ + "level": "DEBUG", + "msg": "debug message", + "foo": "bar", + }) + }) + + t.Run("Info", func(t *testing.T) { + wrapped, buf := NewTestLogger(t) + wrapped.InfoContext(ctx, "info message", "foo", "bar") + + checkLog(t, buf, map[string]any{ + "level": "INFO", + "msg": "info message", + "foo": "bar", + }) + }) + t.Run("Warn", func(t *testing.T) { + wrapped, buf := NewTestLogger(t) + wrapped.WarnContext(ctx, "warn message", "foo", "bar") + + checkLog(t, buf, map[string]any{ + "level": "WARN", + "msg": "warn message", + "foo": "bar", + }) + }) + t.Run("Error", func(t *testing.T) { + wrapped, buf := NewTestLogger(t) + wrapped.ErrorContext(ctx, "error message", "foo", "bar") + + checkLog(t, buf, map[string]any{ + "level": "ERROR", + "msg": "error message", + "foo": "bar", + }) + }) + + t.Run("Errorf", func(t *testing.T) { + wrapped, buf := NewTestLogger(t) + wrapped.Errorf(ctx, "%d is the answer to %s", 42, "everything") + + checkLog(t, buf, map[string]any{ + "level": "ERROR", + "msg": "42 is the answer to everything", + }) + }) + + t.Run("Infof", func(t *testing.T) { + wrapped, buf := NewTestLogger(t) + wrapped.Infof(ctx, "%d is the answer to %s", 42, "everything") + + checkLog(t, buf, map[string]any{ + "level": "INFO", + "msg": "42 is the answer to everything", + }) + }) + + t.Run("Warnf", func(t *testing.T) { + wrapped, buf := NewTestLogger(t) + wrapped.Warnf(ctx, "%d is the answer to %s", 42, "everything") + + checkLog(t, buf, map[string]any{ + "level": "WARN", + "msg": "42 is the answer to everything", + }) + }) + + t.Run("Debugf", func(t *testing.T) { + wrapped, buf := NewTestLogger(t) + wrapped.Debugf(ctx, "%d is the answer to %s", 42, "everything") + + checkLog(t, buf, map[string]any{ + "level": "DEBUG", + "msg": "42 is the answer to everything", + }) + }) + + t.Run("Insufficient loglevel: default", func(t *testing.T) { + wrapped, buf := NewTestLoggerWithLevel(t, nil) + + wrapped.DebugContext(ctx, "debug message", "foo", "bar") + if 
buf.Len() != 0 { + t.Errorf("expected no log message, got %s", buf.String()) + } + }) + + t.Run("Insufficient loglevel: error", func(t *testing.T) { + wrapped, buf := NewTestLoggerWithLevel(t, slog.LevelError) + wrapped.DebugContext(ctx, "debug message", "foo", "bar") + wrapped.WarnContext(ctx, "warn message", "foo", "bar") + wrapped.InfoContext(ctx, "info message", "foo", "bar") + if buf.Len() != 0 { + t.Errorf("expected no log message, got %s", buf.String()) + } + wrapped.ErrorContext(ctx, "error message", "foo", "bar") + + checkLog(t, buf, map[string]any{ + "level": "ERROR", + "msg": "error message", + "foo": "bar", + }) + }) + + t.Run("Enabled loglevel: debug", func(t *testing.T) { + + wrapped, buf := NewTestLoggerWithLevel(t, slog.LevelInfo) + if wrapped.Enabled(ctx, LogLevelDebug) { + t.Errorf("expected debug level to be disabled") + } + if !wrapped.Enabled(ctx, LogLevelInfo) { + t.Errorf("expected info level to be enabled") + } + if !wrapped.Enabled(ctx, LogLevelWarn) { + t.Errorf("expected warn level to be enabled") + } + + if !wrapped.Enabled(ctx, LogLevelError) { + t.Errorf("expected error level to be enabled") + } + + wrapped.DebugContext(ctx, "debug message", "foo", "bar") + if buf.Len() != 0 { + t.Errorf("expected no log message, got %s", buf.String()) + } + + wrapped.InfoContext(ctx, "info message", "foo", "bar") + checkLog(t, buf, map[string]any{ + "level": "INFO", + "msg": "info message", + "foo": "bar", + }) + }) +} + +func NewTestLogger(t *testing.T) (*LoggerWrapper, *bytes.Buffer) { + return NewTestLoggerWithLevel(t, slog.LevelDebug) +} + +func NewTestLoggerWithLevel(t *testing.T, level slog.Leveler) (*LoggerWrapper, *bytes.Buffer) { + var buf bytes.Buffer + wrapped := NewLoggerWrapper(slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: level}))) + return wrapped, &buf +} + +func checkLog(t *testing.T, buf *bytes.Buffer, attrs map[string]any) { + t.Helper() + + res := buf.Bytes() + + var m map[string]any + err := json.Unmarshal(res, &m) + if err != nil { + t.Fatalf("failed to unmarshal log message: %v", err) + } + + delete(m, "time") // remove time for testing purposes + + if len(m) != len(attrs) { + t.Errorf("expected %d attributes, got %d", len(attrs), len(m)) + } + + for k, expected := range attrs { + v, ok := m[k] + if !ok { + t.Errorf("expected log to have key %s", k) + continue + } + if v != expected { + t.Errorf("expected %s to be %v, got %v", k, expected, v) + } + } + + if t.Failed() { + t.Logf("log message: %s", res) + t.Log("time is ignored in comparison") + } +} + +func ExampleNewLoggerWrapper_slog_logger() { + // assuming you have a context + ctx := context.Background() + + // assuming you have a slogLogger + slogLogger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: slog.LevelDebug, + })) + + // Create a LoggerWrapper with the custom logger + redisLogger := NewLoggerWrapper(slogLogger) + + // Use the wrapped logger with Printf-style logging + redisLogger.Infof(ctx, "This is an info message: %s", "hello world") + + // Use the wrapped logger with structured logging + redisLogger.ErrorContext(ctx, "This is an error message", "key", "value") +} diff --git a/logging/legacy.go b/logging/legacy.go new file mode 100644 index 0000000000..44974cc0bd --- /dev/null +++ b/logging/legacy.go @@ -0,0 +1,98 @@ +package logging + +import ( + "context" + + "github.com/redis/go-redis/v9/internal" +) + +// legacyLoggerAdapter is a logger that implements [LoggerWithLevelI] interface +// using the global [internal.Logger] and 
[internal.LogLevel] variables. +type legacyLoggerAdapter struct{} + +var _ LoggerWithLevelI = (*legacyLoggerAdapter)(nil) + +// structuredToPrintf converts a structured log message and key-value pairs into something a Printf-style logger can understand. +func (l *legacyLoggerAdapter) structuredToPrintf(msg string, v ...any) (string, []any) { + format := msg + var args []any + + for i := 0; i < len(v); i += 2 { + format += " %v=%v" + if i+1 >= len(v) { + // Odd number of arguments, append a placeholder for the missing value + // adapted from https://cs.opensource.google/go/go/+/master:src/log/slog/record.go;l=160-182;drc=8c41a482f9b7a101404cd0b417ac45abd441e598 + args = append(args, "!BADKEY", v[i]) + break + } + args = append(args, v[i], v[i+1]) + } + + return format, args +} + +func (l legacyLoggerAdapter) Errorf(ctx context.Context, format string, v ...any) { + internal.Logger.Printf(ctx, format, v...) +} + +func (l *legacyLoggerAdapter) ErrorContext(ctx context.Context, msg string, args ...any) { + format, v := l.structuredToPrintf(msg, args...) + l.Errorf(ctx, format, v...) +} + +func (l *legacyLoggerAdapter) WarnContext(ctx context.Context, msg string, args ...any) { + format, v := l.structuredToPrintf(msg, args...) + l.Warnf(ctx, format, v...) +} + +func (l *legacyLoggerAdapter) Warnf(ctx context.Context, format string, v ...any) { + if !internal.LogLevel.WarnOrAbove() { + // Skip logging + return + } + internal.Logger.Printf(ctx, format, v...) +} + +func (l *legacyLoggerAdapter) InfoContext(ctx context.Context, msg string, args ...any) { + format, v := l.structuredToPrintf(msg, args...) + l.Infof(ctx, format, v...) +} + +func (l *legacyLoggerAdapter) Infof(ctx context.Context, format string, v ...any) { + if !internal.LogLevel.InfoOrAbove() { + // Skip logging + return + } + internal.Logger.Printf(ctx, format, v...) +} + +func (l *legacyLoggerAdapter) DebugContext(ctx context.Context, msg string, args ...any) { + format, v := l.structuredToPrintf(msg, args...) + l.Debugf(ctx, format, v...) +} + +func (l *legacyLoggerAdapter) Debugf(ctx context.Context, format string, v ...any) { + if !internal.LogLevel.DebugOrAbove() { + // Skip logging + return + } + internal.Logger.Printf(ctx, format, v...) +} + +func (l *legacyLoggerAdapter) Enabled(ctx context.Context, level LogLevelT) bool { + switch level { + case LogLevelDebug: + return internal.LogLevel.DebugOrAbove() + case LogLevelWarn: + return internal.LogLevel.WarnOrAbove() + case LogLevelInfo: + return internal.LogLevel.InfoOrAbove() + } + return true +} + +var legacyLoggerWithLevel = &legacyLoggerAdapter{} + +func LoggerWithLevel() *LoggerWrapper { + return NewLoggerWrapper(legacyLoggerWithLevel) +} diff --git a/logging/level_after_go_121.go b/logging/level_after_go_121.go new file mode 100644 index 0000000000..9282b49104 --- /dev/null +++ b/logging/level_after_go_121.go @@ -0,0 +1,48 @@ +//go:build go1.21 + +package logging + +// The purpose of this file is to provide an implementation of isLevelEnabled +// that uses [slog.Logger.Enabled] method when available. + +import ( + "context" + "log/slog" +) + +// isLevelEnabled checks whether the given logging level is enabled +// for the provided logger. If the logger is of type [slog.Logger], +// it uses its [slog.Logger.Enabled] method to determine whether +// the level is enabled. 
+func isLevelEnabled(ctx context.Context, logger LoggerWithLevelI, level LogLevelT) bool { + sl, ok := logger.(slogEnabler) + if !ok { + // unknown logger type, fall back to legacy logger + return legacyLoggerWithLevel.Enabled(ctx, level) + } + + // map our [LogLevelT] to [slog.Level] + // TODO(ccoVeille): simplify in v10 align when levels will be aligned with slog.Level + slogLevel, ok := levelMap[level] + if !ok { + // unknown level, assume enabled + return true + } + return sl.Enabled(ctx, slogLevel) + +} + +// TODO(ccoVeille): simplify in v10 align when levels will be aligned with slog.Level +var levelMap = map[LogLevelT]slog.Level{ + LogLevelDebug: slog.LevelDebug, + LogLevelInfo: slog.LevelInfo, + LogLevelWarn: slog.LevelWarn, + LogLevelError: slog.LevelError, +} + +type slogEnabler interface { + Enabled(ctx context.Context, level slog.Level) bool +} + +// Verify that [slog.Logger] implements [slogEnabler] +var _ slogEnabler = (*slog.Logger)(nil) diff --git a/logging/level_before_go_121.go b/logging/level_before_go_121.go new file mode 100644 index 0000000000..5b4a180849 --- /dev/null +++ b/logging/level_before_go_121.go @@ -0,0 +1,16 @@ +//go:build !go1.21 + +package logging + +// The purpose of this file is to provide an implementation of isLevelEnabled +// when [slog.Logger.Enabled] method is not available (before Go 1.21). + +import "context" + +// isLevelEnabled checks whether the given logging level is enabled +// for the provided logger. Since before Go 1.21 we don't have +// [slog.Logger.Enabled], we always fall back to the legacy logger. +func isLevelEnabled(ctx context.Context, logger LoggerWithLevelI, level LogLevelT) bool { + // unknown logger type, fall back to legacy logger + return legacyLoggerWithLevel.Enabled(ctx, level) +} diff --git a/maintnotifications/circuit_breaker.go b/maintnotifications/circuit_breaker.go index cb76b6447f..0477f1caf4 100644 --- a/maintnotifications/circuit_breaker.go +++ b/maintnotifications/circuit_breaker.go @@ -6,8 +6,8 @@ import ( "sync/atomic" "time" - "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/maintnotifications/logs" + "github.com/redis/go-redis/v9/logging" ) // CircuitBreakerState represents the state of a circuit breaker @@ -102,9 +102,7 @@ func (cb *CircuitBreaker) Execute(fn func() error) error { if cb.state.CompareAndSwap(int32(CircuitBreakerOpen), int32(CircuitBreakerHalfOpen)) { cb.requests.Store(0) cb.successes.Store(0) - if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.CircuitBreakerTransitioningToHalfOpen(cb.endpoint)) - } + cb.logger().Infof(context.Background(), logs.CircuitBreakerTransitioningToHalfOpen(cb.endpoint)) // Fall through to half-open logic } else { return ErrCircuitBreakerOpen @@ -144,17 +142,13 @@ func (cb *CircuitBreaker) recordFailure() { case CircuitBreakerClosed: if failures >= int64(cb.failureThreshold) { if cb.state.CompareAndSwap(int32(CircuitBreakerClosed), int32(CircuitBreakerOpen)) { - if internal.LogLevel.WarnOrAbove() { - internal.Logger.Printf(context.Background(), logs.CircuitBreakerOpened(cb.endpoint, failures)) - } + cb.logger().Warnf(context.Background(), logs.CircuitBreakerOpened(cb.endpoint, failures)) } } case CircuitBreakerHalfOpen: // Any failure in half-open state immediately opens the circuit if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerOpen)) { - if internal.LogLevel.WarnOrAbove() { - internal.Logger.Printf(context.Background(), logs.CircuitBreakerReopened(cb.endpoint)) - } + 
cb.logger().Warnf(context.Background(), logs.CircuitBreakerReopened(cb.endpoint)) } } } @@ -176,9 +170,7 @@ func (cb *CircuitBreaker) recordSuccess() { if successes >= int64(cb.maxRequests) { if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerClosed)) { cb.failures.Store(0) - if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.CircuitBreakerClosed(cb.endpoint, successes)) - } + cb.logger().Infof(context.Background(), logs.CircuitBreakerClosed(cb.endpoint, successes)) } } } @@ -202,6 +194,13 @@ func (cb *CircuitBreaker) GetStats() CircuitBreakerStats { } } +func (cb *CircuitBreaker) logger() *logging.LoggerWrapper { + if cb.config != nil && cb.config.Logger != nil { + return logging.NewLoggerWrapper(cb.config.Logger) + } + return logging.LoggerWithLevel() +} + // CircuitBreakerStats provides statistics about a circuit breaker type CircuitBreakerStats struct { Endpoint string @@ -325,8 +324,8 @@ func (cbm *CircuitBreakerManager) cleanup() { } // Log cleanup results - if len(toDelete) > 0 && internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.CircuitBreakerCleanup(len(toDelete), count)) + if len(toDelete) > 0 { + cbm.logger().Infof(context.Background(), logs.CircuitBreakerCleanup(len(toDelete), count)) } cbm.lastCleanup.Store(now.Unix()) @@ -351,3 +350,10 @@ func (cbm *CircuitBreakerManager) Reset() { return true }) } + +func (cbm *CircuitBreakerManager) logger() *logging.LoggerWrapper { + if cbm.config != nil && cbm.config.Logger != nil { + return logging.NewLoggerWrapper(cbm.config.Logger) + } + return logging.LoggerWithLevel() +} diff --git a/maintnotifications/config.go b/maintnotifications/config.go index cbf4f6b22b..a280c4fb5f 100644 --- a/maintnotifications/config.go +++ b/maintnotifications/config.go @@ -7,9 +7,9 @@ import ( "strings" "time" - "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/maintnotifications/logs" "github.com/redis/go-redis/v9/internal/util" + "github.com/redis/go-redis/v9/logging" ) // Mode represents the maintenance notifications mode @@ -128,6 +128,9 @@ type Config struct { // After this many retries, the connection will be removed from the pool. // Default: 3 MaxHandoffRetries int + + // Logger is an optional custom logger for maintenance notifications. + Logger logging.LoggerWithLevelI } func (c *Config) IsEnabled() bool { @@ -312,10 +315,9 @@ func (c *Config) ApplyDefaultsWithPoolConfig(poolSize int, maxActiveConns int) * result.CircuitBreakerMaxRequests = c.CircuitBreakerMaxRequests } - if internal.LogLevel.DebugOrAbove() { - internal.Logger.Printf(context.Background(), logs.DebugLoggingEnabled()) - internal.Logger.Printf(context.Background(), logs.ConfigDebug(result)) - } + c.logger().Debugf(context.Background(), logs.DebugLoggingEnabled()) + c.logger().Debugf(context.Background(), logs.ConfigDebug(result)) + return result } @@ -341,6 +343,8 @@ func (c *Config) Clone() *Config { // Configuration fields MaxHandoffRetries: c.MaxHandoffRetries, + + Logger: c.Logger, } } @@ -365,6 +369,13 @@ func (c *Config) applyWorkerDefaults(poolSize int) { } } +func (c *Config) logger() *logging.LoggerWrapper { + if c.Logger != nil { + return logging.NewLoggerWrapper(c.Logger) + } + return logging.LoggerWithLevel() +} + // DetectEndpointType automatically detects the appropriate endpoint type // based on the connection address and TLS configuration. 
// diff --git a/maintnotifications/handoff_worker.go b/maintnotifications/handoff_worker.go index 53f28f49c8..ee616ef34c 100644 --- a/maintnotifications/handoff_worker.go +++ b/maintnotifications/handoff_worker.go @@ -11,6 +11,7 @@ import ( "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/maintnotifications/logs" "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/logging" ) // handoffWorkerManager manages background workers and queue for connection handoffs @@ -121,7 +122,7 @@ func (hwm *handoffWorkerManager) onDemandWorker() { defer func() { // Handle panics to ensure proper cleanup if r := recover(); r != nil { - internal.Logger.Printf(context.Background(), logs.WorkerPanicRecovered(r)) + hwm.logger().Errorf(context.Background(), logs.WorkerPanicRecovered(r)) } // Decrement active worker count when exiting @@ -145,23 +146,17 @@ func (hwm *handoffWorkerManager) onDemandWorker() { select { case <-hwm.shutdown: - if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToShutdown()) - } + hwm.logger().Infof(context.Background(), logs.WorkerExitingDueToShutdown()) return case <-timer.C: // Worker has been idle for too long, exit to save resources - if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToInactivityTimeout(hwm.workerTimeout)) - } + hwm.logger().Infof(context.Background(), logs.WorkerExitingDueToInactivityTimeout(hwm.workerTimeout)) return case request := <-hwm.handoffQueue: // Check for shutdown before processing select { case <-hwm.shutdown: - if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToShutdownWhileProcessing()) - } + hwm.logger().Infof(context.Background(), logs.WorkerExitingDueToShutdownWhileProcessing()) // Clean up the request before exiting hwm.pending.Delete(request.ConnID) return @@ -175,9 +170,7 @@ func (hwm *handoffWorkerManager) onDemandWorker() { // processHandoffRequest processes a single handoff request func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) { - if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.HandoffStarted(request.Conn.GetID(), request.Endpoint)) - } + hwm.logger().Infof(context.Background(), logs.HandoffStarted(request.Conn.GetID(), request.Endpoint)) // Create a context with handoff timeout from config handoffTimeout := 15 * time.Second // Default timeout @@ -217,21 +210,22 @@ func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) { afterTime = minRetryBackoff } - if internal.LogLevel.InfoOrAbove() { + // the HandoffRetries() requires locking resource via [atomic.Uint32.Load], + // so we check the log level first before calling it + if hwm.logger().Enabled(context.Background(), internal.LogLevelInfo) { + // Get current retry count for better logging currentRetries := request.Conn.HandoffRetries() maxRetries := 3 // Default fallback if hwm.config != nil { maxRetries = hwm.config.MaxHandoffRetries } - internal.Logger.Printf(context.Background(), logs.HandoffFailed(request.ConnID, request.Endpoint, currentRetries, maxRetries, err)) + hwm.logger().Infof(context.Background(), logs.HandoffFailed(request.ConnID, request.Endpoint, currentRetries, maxRetries, err)) } // Schedule retry - keep connection in pending map until retry is queued time.AfterFunc(afterTime, func() { if err := hwm.queueHandoff(request.Conn); err != nil { - if 
internal.LogLevel.WarnOrAbove() { - internal.Logger.Printf(context.Background(), logs.CannotQueueHandoffForRetry(err)) - } + hwm.logger().Warnf(context.Background(), logs.CannotQueueHandoffForRetry(err)) // Failed to queue retry - remove from pending and close connection hwm.pending.Delete(request.Conn.GetID()) hwm.closeConnFromRequest(context.Background(), request, err) @@ -268,9 +262,7 @@ func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error { // on retries the connection will not be marked for handoff, but it will have retries > 0 // if shouldHandoff is false and retries is 0, then we are not retrying and not do a handoff if !shouldHandoff && conn.HandoffRetries() == 0 { - if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.ConnectionNotMarkedForHandoff(conn.GetID())) - } + hwm.logger().Infof(context.Background(), logs.ConnectionNotMarkedForHandoff(conn.GetID())) return errors.New(logs.ConnectionNotMarkedForHandoffError(conn.GetID())) } @@ -311,9 +303,7 @@ func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error { // Queue is full - log and attempt scaling queueLen := len(hwm.handoffQueue) queueCap := cap(hwm.handoffQueue) - if internal.LogLevel.WarnOrAbove() { - internal.Logger.Printf(context.Background(), logs.HandoffQueueFull(queueLen, queueCap)) - } + hwm.logger().Warnf(context.Background(), logs.HandoffQueueFull(queueLen, queueCap)) } } } @@ -366,7 +356,7 @@ func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, c // Check if circuit breaker is open before attempting handoff if circuitBreaker.IsOpen() { - internal.Logger.Printf(ctx, logs.CircuitBreakerOpen(connID, newEndpoint)) + hwm.logger().Infof(ctx, logs.CircuitBreakerOpen(connID, newEndpoint)) return false, ErrCircuitBreakerOpen // Don't retry when circuit breaker is open } @@ -395,16 +385,14 @@ func (hwm *handoffWorkerManager) performHandoffInternal( connID uint64, ) (shouldRetry bool, err error) { retries := conn.IncrementAndGetHandoffRetries(1) - internal.Logger.Printf(ctx, logs.HandoffRetryAttempt(connID, retries, newEndpoint, conn.RemoteAddr().String())) + hwm.logger().Infof(ctx, logs.HandoffRetryAttempt(connID, retries, newEndpoint, conn.RemoteAddr().String())) maxRetries := 3 // Default fallback if hwm.config != nil { maxRetries = hwm.config.MaxHandoffRetries } if retries > maxRetries { - if internal.LogLevel.WarnOrAbove() { - internal.Logger.Printf(ctx, logs.ReachedMaxHandoffRetries(connID, newEndpoint, maxRetries)) - } + hwm.logger().Warnf(ctx, logs.ReachedMaxHandoffRetries(connID, newEndpoint, maxRetries)) // won't retry on ErrMaxHandoffRetriesReached return false, ErrMaxHandoffRetriesReached } @@ -415,7 +403,7 @@ func (hwm *handoffWorkerManager) performHandoffInternal( // Create new connection to the new endpoint newNetConn, err := endpointDialer(ctx) if err != nil { - internal.Logger.Printf(ctx, logs.FailedToDialNewEndpoint(connID, newEndpoint, err)) + hwm.logger().Errorf(ctx, logs.FailedToDialNewEndpoint(connID, newEndpoint, err)) // will retry // Maybe a network error - retry after a delay return true, err @@ -434,9 +422,7 @@ func (hwm *handoffWorkerManager) performHandoffInternal( deadline := time.Now().Add(hwm.config.PostHandoffRelaxedDuration) conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline) - if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.ApplyingRelaxedTimeoutDueToPostHandoff(connID, relaxedTimeout, deadline.Format("15:04:05.000"))) - } + 
hwm.logger().Infof(context.Background(), logs.ApplyingRelaxedTimeoutDueToPostHandoff(connID, relaxedTimeout, deadline.Format("15:04:05.000"))) } // Replace the connection and execute initialization @@ -459,7 +445,7 @@ func (hwm *handoffWorkerManager) performHandoffInternal( // Note: Theoretically there may be a short window where the connection is in the pool // and IDLE (initConn completed) but still has handoff state set. conn.ClearHandoffState() - internal.Logger.Printf(ctx, logs.HandoffSucceeded(connID, newEndpoint)) + hwm.logger().Infof(ctx, logs.HandoffSucceeded(connID, newEndpoint)) // successfully completed the handoff, no retry needed and no error return false, nil @@ -497,16 +483,19 @@ func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, reque // Remove() is meant to be called after Get() and frees a turn. // RemoveWithoutTurn() removes and closes the connection without affecting the queue. pooler.RemoveWithoutTurn(ctx, conn, err) - if internal.LogLevel.WarnOrAbove() { - internal.Logger.Printf(ctx, logs.RemovingConnectionFromPool(conn.GetID(), err)) - } + hwm.logger().Warnf(ctx, logs.RemovingConnectionFromPool(conn.GetID(), err)) } else { errClose := conn.Close() // Close the connection if no pool provided if errClose != nil { internal.Logger.Printf(ctx, "redis: failed to close connection: %v", errClose) } - if internal.LogLevel.WarnOrAbove() { - internal.Logger.Printf(ctx, logs.NoPoolProvidedCannotRemove(conn.GetID(), err)) - } + hwm.logger().Warnf(ctx, logs.NoPoolProvidedCannotRemove(conn.GetID(), err)) + } +} + +func (hwm *handoffWorkerManager) logger() *logging.LoggerWrapper { + if hwm.config != nil && hwm.config.Logger != nil { + return logging.NewLoggerWrapper(hwm.config.Logger) } + return logging.LoggerWithLevel() } diff --git a/maintnotifications/manager.go b/maintnotifications/manager.go index 775c163e14..f394230b9e 100644 --- a/maintnotifications/manager.go +++ b/maintnotifications/manager.go @@ -9,10 +9,10 @@ import ( "sync/atomic" "time" - "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/interfaces" "github.com/redis/go-redis/v9/internal/maintnotifications/logs" "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/logging" "github.com/redis/go-redis/v9/push" ) @@ -150,14 +150,10 @@ func (hm *Manager) TrackMovingOperationWithConnID(ctx context.Context, newEndpoi // Use LoadOrStore for atomic check-and-set operation if _, loaded := hm.activeMovingOps.LoadOrStore(key, movingOp); loaded { // Duplicate MOVING notification, ignore - if internal.LogLevel.DebugOrAbove() { // Debug level - internal.Logger.Printf(context.Background(), logs.DuplicateMovingOperation(connID, newEndpoint, seqID)) - } + hm.logger().Debugf(context.Background(), logs.DuplicateMovingOperation(connID, newEndpoint, seqID)) return nil } - if internal.LogLevel.DebugOrAbove() { // Debug level - internal.Logger.Printf(context.Background(), logs.TrackingMovingOperation(connID, newEndpoint, seqID)) - } + hm.logger().Debugf(context.Background(), logs.TrackingMovingOperation(connID, newEndpoint, seqID)) // Increment active operation count atomically hm.activeOperationCount.Add(1) @@ -175,15 +171,11 @@ func (hm *Manager) UntrackOperationWithConnID(seqID int64, connID uint64) { // Remove from active operations atomically if _, loaded := hm.activeMovingOps.LoadAndDelete(key); loaded { - if internal.LogLevel.DebugOrAbove() { // Debug level - internal.Logger.Printf(context.Background(), logs.UntrackingMovingOperation(connID, seqID)) - } + 
hm.logger().Debugf(context.Background(), logs.UntrackingMovingOperation(connID, seqID)) // Decrement active operation count only if operation existed hm.activeOperationCount.Add(-1) } else { - if internal.LogLevel.DebugOrAbove() { // Debug level - internal.Logger.Printf(context.Background(), logs.OperationNotTracked(connID, seqID)) - } + hm.logger().Debugf(context.Background(), logs.OperationNotTracked(connID, seqID)) } } @@ -318,3 +310,10 @@ func (hm *Manager) AddNotificationHook(notificationHook NotificationHook) { defer hm.hooksMu.Unlock() hm.hooks = append(hm.hooks, notificationHook) } + +func (hm *Manager) logger() *logging.LoggerWrapper { + if hm.config != nil && hm.config.Logger != nil { + return logging.NewLoggerWrapper(hm.config.Logger) + } + return logging.LoggerWithLevel() +} diff --git a/maintnotifications/pool_hook.go b/maintnotifications/pool_hook.go index 9ea0558bf8..721ca1e750 100644 --- a/maintnotifications/pool_hook.go +++ b/maintnotifications/pool_hook.go @@ -6,9 +6,9 @@ import ( "sync" "time" - "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/maintnotifications/logs" "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/logging" ) // OperationsManagerInterface defines the interface for completing handoff operations @@ -148,7 +148,7 @@ func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool if err := ph.workerManager.queueHandoff(conn); err != nil { // Failed to queue handoff, remove the connection - internal.Logger.Printf(ctx, logs.FailedToQueueHandoff(conn.GetID(), err)) + ph.logger().Errorf(ctx, logs.FailedToQueueHandoff(conn.GetID(), err)) // Don't pool, remove connection, no error to caller return false, true, nil } @@ -168,7 +168,7 @@ func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool // Other error - remove the connection return false, true, nil } - internal.Logger.Printf(ctx, logs.MarkedForHandoff(conn.GetID())) + ph.logger().Errorf(ctx, logs.MarkedForHandoff(conn.GetID())) return true, false, nil } @@ -180,3 +180,10 @@ func (ph *PoolHook) OnRemove(_ context.Context, _ *pool.Conn, _ error) { func (ph *PoolHook) Shutdown(ctx context.Context) error { return ph.workerManager.shutdownWorkers(ctx) } + +func (ph *PoolHook) logger() *logging.LoggerWrapper { + if ph.config != nil && ph.config.Logger != nil { + return logging.NewLoggerWrapper(ph.config.Logger) + } + return logging.LoggerWithLevel() +} diff --git a/maintnotifications/push_notification_handler.go b/maintnotifications/push_notification_handler.go index 937b4ae82e..8dcf9dfea6 100644 --- a/maintnotifications/push_notification_handler.go +++ b/maintnotifications/push_notification_handler.go @@ -9,6 +9,7 @@ import ( "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/maintnotifications/logs" "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/logging" "github.com/redis/go-redis/v9/push" ) @@ -21,13 +22,13 @@ type NotificationHandler struct { // HandlePushNotification processes push notifications with hook support. 
func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { if len(notification) == 0 { - internal.Logger.Printf(ctx, logs.InvalidNotificationFormat(notification)) + snh.logger().Errorf(ctx, logs.InvalidNotificationFormat(notification)) return ErrInvalidNotification } notificationType, ok := notification[0].(string) if !ok { - internal.Logger.Printf(ctx, logs.InvalidNotificationTypeFormat(notification[0])) + snh.logger().Errorf(ctx, logs.InvalidNotificationTypeFormat(notification[0])) return ErrInvalidNotification } @@ -64,19 +65,19 @@ func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, hand // ["MOVING", seqNum, timeS, endpoint] - per-connection handoff func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { if len(notification) < 3 { - internal.Logger.Printf(ctx, logs.InvalidNotification("MOVING", notification)) + snh.logger().Errorf(ctx, logs.InvalidNotification("MOVING", notification)) return ErrInvalidNotification } seqID, ok := notification[1].(int64) if !ok { - internal.Logger.Printf(ctx, logs.InvalidSeqIDInMovingNotification(notification[1])) + snh.logger().Errorf(ctx, logs.InvalidSeqIDInMovingNotification(notification[1])) return ErrInvalidNotification } // Extract timeS timeS, ok := notification[2].(int64) if !ok { - internal.Logger.Printf(ctx, logs.InvalidTimeSInMovingNotification(notification[2])) + snh.logger().Errorf(ctx, logs.InvalidTimeSInMovingNotification(notification[2])) return ErrInvalidNotification } @@ -90,7 +91,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus if notification[3] == nil || stringified == internal.RedisNull { newEndpoint = "" } else { - internal.Logger.Printf(ctx, logs.InvalidNewEndpointInMovingNotification(notification[3])) + snh.logger().Errorf(ctx, logs.InvalidNewEndpointInMovingNotification(notification[3])) return ErrInvalidNotification } } @@ -99,7 +100,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus // Get the connection that received this notification conn := handlerCtx.Conn if conn == nil { - internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MOVING")) + snh.logger().Errorf(ctx, logs.NoConnectionInHandlerContext("MOVING")) return ErrInvalidNotification } @@ -108,7 +109,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus if pc, ok := conn.(*pool.Conn); ok { poolConn = pc } else { - internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MOVING", conn, handlerCtx)) + snh.logger().Errorf(ctx, logs.InvalidConnectionTypeInHandlerContext("MOVING", conn, handlerCtx)) return ErrInvalidNotification } @@ -124,9 +125,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus deadline := time.Now().Add(time.Duration(timeS) * time.Second) // If newEndpoint is empty, we should schedule a handoff to the current endpoint in timeS/2 seconds if newEndpoint == "" || newEndpoint == internal.RedisNull { - if internal.LogLevel.DebugOrAbove() { - internal.Logger.Printf(ctx, logs.SchedulingHandoffToCurrentEndpoint(poolConn.GetID(), float64(timeS)/2)) - } + snh.logger().Debugf(ctx, logs.SchedulingHandoffToCurrentEndpoint(poolConn.GetID(), float64(timeS)/2)) // same as current endpoint newEndpoint = snh.manager.options.GetAddr() // delay the handoff for timeS/2 seconds to the same endpoint @@ -139,7 
+138,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus } if err := snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline); err != nil { // Log error but don't fail the goroutine - use background context since original may be cancelled - internal.Logger.Printf(context.Background(), logs.FailedToMarkForHandoff(poolConn.GetID(), err)) + snh.logger().Errorf(context.Background(), logs.FailedToMarkForHandoff(poolConn.GetID(), err)) } }) return nil @@ -150,7 +149,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus func (snh *NotificationHandler) markConnForHandoff(conn *pool.Conn, newEndpoint string, seqID int64, deadline time.Time) error { if err := conn.MarkForHandoff(newEndpoint, seqID); err != nil { - internal.Logger.Printf(context.Background(), logs.FailedToMarkForHandoff(conn.GetID(), err)) + snh.logger().Errorf(context.Background(), logs.FailedToMarkForHandoff(conn.GetID(), err)) // Connection is already marked for handoff, which is acceptable // This can happen if multiple MOVING notifications are received for the same connection return nil @@ -171,25 +170,23 @@ func (snh *NotificationHandler) handleMigrating(ctx context.Context, handlerCtx // MIGRATING notifications indicate that a connection is about to be migrated // Apply relaxed timeouts to the specific connection that received this notification if len(notification) < 2 { - internal.Logger.Printf(ctx, logs.InvalidNotification("MIGRATING", notification)) + snh.logger().Errorf(ctx, logs.InvalidNotification("MIGRATING", notification)) return ErrInvalidNotification } if handlerCtx.Conn == nil { - internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MIGRATING")) + snh.logger().Errorf(ctx, logs.NoConnectionInHandlerContext("MIGRATING")) return ErrInvalidNotification } conn, ok := handlerCtx.Conn.(*pool.Conn) if !ok { - internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATING", handlerCtx.Conn, handlerCtx)) + snh.logger().Errorf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATING", handlerCtx.Conn, handlerCtx)) return ErrInvalidNotification } // Apply relaxed timeout to this specific connection - if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(ctx, logs.RelaxedTimeoutDueToNotification(conn.GetID(), "MIGRATING", snh.manager.config.RelaxedTimeout)) - } + snh.logger().Infof(ctx, logs.RelaxedTimeoutDueToNotification(conn.GetID(), "MIGRATING", snh.manager.config.RelaxedTimeout)) conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout) return nil } @@ -199,26 +196,25 @@ func (snh *NotificationHandler) handleMigrated(ctx context.Context, handlerCtx p // MIGRATED notifications indicate that a connection migration has completed // Restore normal timeouts for the specific connection that received this notification if len(notification) < 2 { - internal.Logger.Printf(ctx, logs.InvalidNotification("MIGRATED", notification)) + snh.logger().Errorf(ctx, logs.InvalidNotification("MIGRATED", notification)) return ErrInvalidNotification } if handlerCtx.Conn == nil { - internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MIGRATED")) + snh.logger().Errorf(ctx, logs.NoConnectionInHandlerContext("MIGRATED")) return ErrInvalidNotification } conn, ok := handlerCtx.Conn.(*pool.Conn) if !ok { - internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATED", handlerCtx.Conn, handlerCtx)) + snh.logger().Errorf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATED", 
handlerCtx.Conn, handlerCtx)) return ErrInvalidNotification } // Clear relaxed timeout for this specific connection - if internal.LogLevel.InfoOrAbove() { - connID := conn.GetID() - internal.Logger.Printf(ctx, logs.UnrelaxedTimeout(connID)) - } + connID := conn.GetID() + snh.logger().Infof(ctx, logs.UnrelaxedTimeout(connID)) + conn.ClearRelaxedTimeout() return nil } @@ -228,26 +224,25 @@ func (snh *NotificationHandler) handleFailingOver(ctx context.Context, handlerCt // FAILING_OVER notifications indicate that a connection is about to failover // Apply relaxed timeouts to the specific connection that received this notification if len(notification) < 2 { - internal.Logger.Printf(ctx, logs.InvalidNotification("FAILING_OVER", notification)) + snh.logger().Errorf(ctx, logs.InvalidNotification("FAILING_OVER", notification)) return ErrInvalidNotification } if handlerCtx.Conn == nil { - internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("FAILING_OVER")) + snh.logger().Errorf(ctx, logs.NoConnectionInHandlerContext("FAILING_OVER")) return ErrInvalidNotification } conn, ok := handlerCtx.Conn.(*pool.Conn) if !ok { - internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILING_OVER", handlerCtx.Conn, handlerCtx)) + snh.logger().Errorf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILING_OVER", handlerCtx.Conn, handlerCtx)) return ErrInvalidNotification } // Apply relaxed timeout to this specific connection - if internal.LogLevel.InfoOrAbove() { - connID := conn.GetID() - internal.Logger.Printf(ctx, logs.RelaxedTimeoutDueToNotification(connID, "FAILING_OVER", snh.manager.config.RelaxedTimeout)) - } + connID := conn.GetID() + snh.logger().Infof(ctx, logs.RelaxedTimeoutDueToNotification(connID, "FAILING_OVER", snh.manager.config.RelaxedTimeout)) + conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout) return nil } @@ -257,26 +252,31 @@ func (snh *NotificationHandler) handleFailedOver(ctx context.Context, handlerCtx // FAILED_OVER notifications indicate that a connection failover has completed // Restore normal timeouts for the specific connection that received this notification if len(notification) < 2 { - internal.Logger.Printf(ctx, logs.InvalidNotification("FAILED_OVER", notification)) + snh.logger().Errorf(ctx, logs.InvalidNotification("FAILED_OVER", notification)) return ErrInvalidNotification } if handlerCtx.Conn == nil { - internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("FAILED_OVER")) + snh.logger().Errorf(ctx, logs.NoConnectionInHandlerContext("FAILED_OVER")) return ErrInvalidNotification } conn, ok := handlerCtx.Conn.(*pool.Conn) if !ok { - internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILED_OVER", handlerCtx.Conn, handlerCtx)) + snh.logger().Errorf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILED_OVER", handlerCtx.Conn, handlerCtx)) return ErrInvalidNotification } // Clear relaxed timeout for this specific connection - if internal.LogLevel.InfoOrAbove() { - connID := conn.GetID() - internal.Logger.Printf(ctx, logs.UnrelaxedTimeout(connID)) - } + connID := conn.GetID() + snh.logger().Infof(ctx, logs.UnrelaxedTimeout(connID)) conn.ClearRelaxedTimeout() return nil } + +func (snh *NotificationHandler) logger() *logging.LoggerWrapper { + if snh.manager != nil { + return snh.manager.logger() + } + return logging.LoggerWithLevel() +} diff --git a/options.go b/options.go index 9773e86f77..cba4cb3776 100644 --- a/options.go +++ b/options.go @@ -17,6 +17,7 @@ import ( 
"github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/proto" "github.com/redis/go-redis/v9/internal/util" + "github.com/redis/go-redis/v9/logging" "github.com/redis/go-redis/v9/maintnotifications" "github.com/redis/go-redis/v9/push" ) @@ -267,6 +268,13 @@ type Options struct { // transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications. // If nil, maintnotifications are in "auto" mode and will be enabled if the server supports it. MaintNotificationsConfig *maintnotifications.Config + + // Logger is the logger used by the client for logging. + // If none is provided, the global logger [internal.Logger] is used. + // Keep in mind that the global logger is shared by all clients in the library, and at this time + // it is still the only logger for some internal components. This will change in the future and the global + // logger will be removed. + Logger logging.LoggerWithLevelI } func (opt *Options) init() { diff --git a/osscluster.go b/osscluster.go index 6994ae83f6..dff2af8851 100644 --- a/osscluster.go +++ b/osscluster.go @@ -22,6 +22,7 @@ import ( "github.com/redis/go-redis/v9/internal/proto" "github.com/redis/go-redis/v9/internal/rand" "github.com/redis/go-redis/v9/internal/routing" + "github.com/redis/go-redis/v9/logging" "github.com/redis/go-redis/v9/maintnotifications" "github.com/redis/go-redis/v9/push" ) @@ -159,9 +160,13 @@ type ClusterOptions struct { // If nil, maintnotifications upgrades are in "auto" mode and will be enabled if the server supports it. // The ClusterClient does not directly work with maintnotifications, it is up to the clients in the Nodes map to work with maintnotifications. MaintNotificationsConfig *maintnotifications.Config + // ShardPicker is used to pick a shard when the request_policy is // ReqDefault and the command has no keys. ShardPicker routing.ShardPicker + + // Logger is an optional logger for logging cluster-related messages. 
+ Logger logging.LoggerWithLevelI } func (opt *ClusterOptions) init() { @@ -408,6 +413,8 @@ func (opt *ClusterOptions) clientOptions() *Options { UnstableResp3: opt.UnstableResp3, MaintNotificationsConfig: maintNotificationsConfig, PushNotificationProcessor: opt.PushNotificationProcessor, + + Logger: opt.Logger, } } @@ -721,6 +728,13 @@ func (c *clusterNodes) Random() (*clusterNode, error) { return c.GetOrCreate(addrs[n]) } +func (c *clusterNodes) logger() *logging.LoggerWrapper { + if c.opt != nil && c.opt.Logger != nil { + return logging.NewLoggerWrapper(c.opt.Logger) + } + return logging.LoggerWithLevel() +} + //------------------------------------------------------------------------------ type clusterSlot struct { @@ -918,12 +932,12 @@ func (c *clusterState) slotClosestNode(slot int) (*clusterNode, error) { // if all nodes are failing, we will pick the temporarily failing node with lowest latency if minLatency < maximumNodeLatency && closestNode != nil { - internal.Logger.Printf(context.TODO(), "redis: all nodes are marked as failed, picking the temporarily failing node with lowest latency") + c.nodes.logger().Errorf(context.TODO(), "redis: all nodes are marked as failed, picking the temporarily failing node with lowest latency") return closestNode, nil } // If all nodes are having the maximum latency(all pings are failing) - return a random node across the cluster - internal.Logger.Printf(context.TODO(), "redis: pings to all nodes are failing, picking a random node across the cluster") + c.nodes.logger().Errorf(context.TODO(), "redis: pings to all nodes are failing, picking a random node across the cluster") return c.nodes.Random() } @@ -1799,7 +1813,7 @@ func (c *ClusterClient) txPipelineReadQueued( if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { // Log the error but don't fail the command execution // Push notification processing errors shouldn't break normal Redis operations - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before reading reply: %v", err) } if err := statusCmd.readReply(rd); err != nil { return err @@ -1810,7 +1824,7 @@ func (c *ClusterClient) txPipelineReadQueued( if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { // Log the error but don't fail the command execution // Push notification processing errors shouldn't break normal Redis operations - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before reading reply: %v", err) } err := statusCmd.readReply(rd) if err != nil { @@ -1829,7 +1843,7 @@ func (c *ClusterClient) txPipelineReadQueued( if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { // Log the error but don't fail the command execution // Push notification processing errors shouldn't break normal Redis operations - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before reading reply: %v", err) } // Parse number of replies. 
line, err := rd.ReadLine() @@ -2091,13 +2105,13 @@ func (c *ClusterClient) cmdInfo(ctx context.Context, name string) *CommandInfo { cmdsInfo, err := c.cmdsInfoCache.Get(cmdInfoCtx) if err != nil { - internal.Logger.Printf(cmdInfoCtx, "getting command info: %s", err) + c.logger().Errorf(cmdInfoCtx, "getting command info: %s", err) return nil } info := cmdsInfo[name] if info == nil { - internal.Logger.Printf(cmdInfoCtx, "info for cmd=%s not found", name) + c.logger().Errorf(cmdInfoCtx, "info for cmd=%s not found", name) } return info @@ -2256,6 +2270,13 @@ func (c *ClusterClient) NewDynamicResolver() *commandInfoResolver { } } +func (c *ClusterClient) logger() *logging.LoggerWrapper { + if c.opt.Logger != nil { + return logging.NewLoggerWrapper(c.opt.Logger) + } + return logging.LoggerWithLevel() +} + func appendIfNotExist[T comparable](vals []T, newVal T) []T { for _, v := range vals { if v == newVal { diff --git a/pubsub.go b/pubsub.go index 1b9d4e7fe1..81e1bed672 100644 --- a/pubsub.go +++ b/pubsub.go @@ -10,6 +10,7 @@ import ( "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/proto" + "github.com/redis/go-redis/v9/logging" "github.com/redis/go-redis/v9/push" ) @@ -141,6 +142,16 @@ func mapKeys(m map[string]struct{}) []string { return s } +// logger is a wrapper around the logger to log messages with context. +// +// it uses the client logger if set, otherwise it uses the global logger. +func (c *PubSub) logger() *logging.LoggerWrapper { + if c.opt != nil && c.opt.Logger != nil { + return logging.NewLoggerWrapper(c.opt.Logger) + } + return logging.LoggerWithLevel() +} + func (c *PubSub) _subscribe( ctx context.Context, cn *pool.Conn, redisCmd string, channels []string, ) error { @@ -190,7 +201,7 @@ func (c *PubSub) reconnect(ctx context.Context, reason error) { // Update the address in the options oldAddr := c.cn.RemoteAddr().String() c.opt.Addr = newEndpoint - internal.Logger.Printf(ctx, "pubsub: reconnecting to new endpoint %s (was %s)", newEndpoint, oldAddr) + c.logger().Infof(ctx, "pubsub: reconnecting to new endpoint %s (was %s)", newEndpoint, oldAddr) } } _ = c.closeTheCn(reason) @@ -475,7 +486,7 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { // Log the error but don't fail the command execution // Push notification processing errors shouldn't break normal Redis operations - internal.Logger.Printf(ctx, "push: conn[%d] error processing pending notifications before reading reply: %v", cn.GetID(), err) + c.logger().Errorf(ctx, "push: conn[%d] error processing pending notifications before reading reply: %v", cn.GetID(), err) } return c.cmd.readReply(rd) }) @@ -634,6 +645,9 @@ func WithChannelSendTimeout(d time.Duration) ChannelOption { type channel struct { pubSub *PubSub + // Optional logger for logging channel-related messages. 
+ Logger logging.LoggerWithLevelI + msgCh chan *Message allCh chan interface{} ping chan struct{} @@ -733,12 +747,10 @@ func (c *channel) initMsgChan() { <-timer.C } case <-timer.C: - internal.Logger.Printf( - ctx, "redis: %v channel is full for %s (message is dropped)", - c, c.chanSendTimeout) + c.logger().Errorf(ctx, "redis: %v channel is full for %s (message is dropped)", c, c.chanSendTimeout) } default: - internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg) + c.logger().Errorf(ctx, "redis: unknown message type: %T", msg) } } }() @@ -787,13 +799,19 @@ func (c *channel) initAllChan() { <-timer.C } case <-timer.C: - internal.Logger.Printf( - ctx, "redis: %v channel is full for %s (message is dropped)", + c.logger().Errorf(ctx, "redis: %v channel is full for %s (message is dropped)", c, c.chanSendTimeout) } default: - internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg) + c.logger().Errorf(ctx, "redis: unknown message type: %T", msg) } } }() } + +func (c *channel) logger() *logging.LoggerWrapper { + if c.Logger != nil { + return logging.NewLoggerWrapper(c.Logger) + } + return logging.LoggerWithLevel() +} diff --git a/redis.go b/redis.go index a6a7106779..b69fc8096c 100644 --- a/redis.go +++ b/redis.go @@ -15,6 +15,7 @@ import ( "github.com/redis/go-redis/v9/internal/hscan" "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/proto" + "github.com/redis/go-redis/v9/logging" "github.com/redis/go-redis/v9/maintnotifications" "github.com/redis/go-redis/v9/push" ) @@ -336,16 +337,16 @@ func (c *baseClient) onAuthenticationErr() func(poolCn *pool.Conn, err error) { // Close the connection to force a reconnection. err := c.connPool.CloseConn(poolCn) if err != nil { - internal.Logger.Printf(context.Background(), "redis: failed to close connection: %v", err) + c.logger().Errorf(context.Background(), "redis: failed to close connection: %v", err) // try to close the network connection directly // so that no resource is leaked err := poolCn.Close() if err != nil { - internal.Logger.Printf(context.Background(), "redis: failed to close network connection: %v", err) + c.logger().Errorf(context.Background(), "redis: failed to close network connection: %v", err) } } } - internal.Logger.Printf(context.Background(), "redis: re-authentication failed: %v", err) + c.logger().Errorf(context.Background(), "redis: re-authentication failed: %v", err) } } } @@ -562,14 +563,13 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { return fmt.Errorf("failed to enable maintnotifications: %w", maintNotifHandshakeErr) default: // will handle auto and any other // Disabling logging here as it's too noisy. 
- // TODO: Enable when we have a better logging solution for log levels - // internal.Logger.Printf(ctx, "auto mode fallback: maintnotifications disabled due to handshake error: %v", maintNotifHandshakeErr) + // c.logger().Errorf(ctx, "auto mode fallback: maintnotifications disabled due to handshake error: %v", maintNotifHandshakeErr) c.opt.MaintNotificationsConfig.Mode = maintnotifications.ModeDisabled c.optLock.Unlock() // auto mode, disable maintnotifications and continue if initErr := c.disableMaintNotificationsUpgrades(); initErr != nil { // Log error but continue - auto mode should be resilient - internal.Logger.Printf(ctx, "failed to disable maintnotifications in auto mode: %v", initErr) + c.logger().Errorf(ctx, "failed to disable maintnotifications in auto mode: %v", initErr) } } } else { @@ -633,7 +633,7 @@ func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) } else { // process any pending push notifications before returning the connection to the pool if err := c.processPushNotifications(ctx, cn); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before releasing connection: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before releasing connection: %v", err) } c.connPool.Put(ctx, cn) } @@ -700,7 +700,7 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool if err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { // Process any pending push notifications before executing the command if err := c.processPushNotifications(ctx, cn); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before command: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before command: %v", err) } if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { @@ -723,7 +723,7 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool if err := cn.WithReader(c.context(ctx), c.cmdTimeout(cmd), func(rd *proto.Reader) error { // To be sure there are no buffered push notifications, we process them before reading the reply if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before reading reply: %v", err) } return readReplyFunc(rd) }); err != nil { @@ -769,6 +769,15 @@ func (c *baseClient) context(ctx context.Context) context.Context { return context.Background() } +// logger is a wrapper around the logger to log messages with context. +// it uses the client logger if set, otherwise it uses the global logger. +func (c *baseClient) logger() *logging.LoggerWrapper { + if c.opt != nil && c.opt.Logger != nil { + return logging.NewLoggerWrapper(c.opt.Logger) + } + return logging.LoggerWithLevel() +} + // createInitConnFunc creates a connection initialization function that can be used for reconnections. 
func (c *baseClient) createInitConnFunc() func(context.Context, *pool.Conn) error { return func(ctx context.Context, cn *pool.Conn) error { @@ -880,7 +889,7 @@ func (c *baseClient) generalProcessPipeline( lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { // Process any pending push notifications before executing the pipeline if err := c.processPushNotifications(ctx, cn); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before processing pipeline: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before processing pipeline: %v", err) } var err error canRetry, err = p(ctx, cn, cmds) @@ -902,7 +911,7 @@ func (c *baseClient) pipelineProcessCmds( ) (bool, error) { // Process any pending push notifications before executing the pipeline if err := c.processPushNotifications(ctx, cn); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before writing pipeline: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before writing pipeline: %v", err) } if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { @@ -926,7 +935,7 @@ func (c *baseClient) pipelineReadCmds(ctx context.Context, cn *pool.Conn, rd *pr for i, cmd := range cmds { // To be sure there are no buffered push notifications, we process them before reading the reply if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before reading reply: %v", err) } err := cmd.readReply(rd) cmd.SetErr(err) @@ -944,7 +953,7 @@ func (c *baseClient) txPipelineProcessCmds( ) (bool, error) { // Process any pending push notifications before executing the transaction pipeline if err := c.processPushNotifications(ctx, cn); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before transaction: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before transaction: %v", err) } if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { @@ -978,7 +987,7 @@ func (c *baseClient) txPipelineProcessCmds( func (c *baseClient) txPipelineReadQueued(ctx context.Context, cn *pool.Conn, rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error { // To be sure there are no buffered push notifications, we process them before reading the reply if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before reading reply: %v", err) } // Parse +OK. 
if err := statusCmd.readReply(rd); err != nil { @@ -989,7 +998,7 @@ func (c *baseClient) txPipelineReadQueued(ctx context.Context, cn *pool.Conn, rd for _, cmd := range cmds { // To be sure there are no buffered push notifications, we process them before reading the reply if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before reading reply: %v", err) } if err := statusCmd.readReply(rd); err != nil { cmd.SetErr(err) @@ -1001,7 +1010,7 @@ func (c *baseClient) txPipelineReadQueued(ctx context.Context, cn *pool.Conn, rd // To be sure there are no buffered push notifications, we process them before reading the reply if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before reading reply: %v", err) } // Parse number of replies. line, err := rd.ReadLine() @@ -1075,7 +1084,7 @@ func NewClient(opt *Options) *Client { if opt.MaintNotificationsConfig != nil && opt.MaintNotificationsConfig.Mode != maintnotifications.ModeDisabled && opt.Protocol == 3 { err := c.enableMaintNotificationsUpgrades() if err != nil { - internal.Logger.Printf(context.Background(), "failed to initialize maintnotifications: %v", err) + c.logger().Errorf(context.Background(), "failed to initialize maintnotifications: %v", err) if opt.MaintNotificationsConfig.Mode == maintnotifications.ModeEnabled { /* Design decision: panic here to fail fast if maintnotifications cannot be enabled when explicitly requested. diff --git a/redis_after_go_121_test.go b/redis_after_go_121_test.go new file mode 100644 index 0000000000..3dc37ad83a --- /dev/null +++ b/redis_after_go_121_test.go @@ -0,0 +1,107 @@ +//go:build go1.21 + +package redis_test + +// The purpose of this file is to provide an implementation of isLevelEnabled +// that uses [slog.Logger.Enabled] method when available. + +import ( + "context" + "log/slog" + "os" + "strings" + + "github.com/redis/go-redis/v9" + "github.com/redis/go-redis/v9/logging" +) + +// You can pass a custom logger that implements [logging.LoggerWithLevelI] interface +// to the Redis [Client]. Here's an example using slog.Logger. +func ExampleNewClient_with_slog_logger() { + // assuming you have a context + ctx := context.Background() + + // assuming you have a logger + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: slog.LevelDebug, + })) + + // Create a Redis client with the custom logger + rdb := redis.NewClient(&redis.Options{ + Addr: ":6379", + Logger: logger, + }) + + // Use the Redis client + _ = rdb.Set(ctx, "key", "value", 0) +} + +// You can also wrap your slog.Logger with [logging.NewLoggerWrapper] to +// customize its behavior when used with Redis client. 
Here's an example:
+func ExampleNewClient_with_logger_wrapper() {
+	// assuming you have a context
+	ctx := context.Background()
+
+	// assuming you have a logger
+	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
+		Level: slog.LevelDebug,
+	}))
+
+	// Wrap the slog.Logger with LoggerWrapper
+	wrappedLogger := logging.NewLoggerWrapper(
+		logger,
+		logging.WithLoggerLevel(logging.LogLevelDebug), // set minimum log level to Debug
+		logging.WithPrintfAdapter(
+			func(ctx context.Context, format string, v ...any) (context.Context, string, []any) {
+
+				if after, ok := strings.CutPrefix(format, "redis:"); ok {
+					// adjust the format string to remove "redis:" prefix when present
+					// some log messages from go-redis have this prefix
+					format = after
+				}
+
+				// Here is an example of customizing the log format:
+				// if any of the arguments is an error, add "ERROR:" prefix to the format
+				// you can customize this logic as needed
+				//
+				// for example, you might want to inject the `err` value into the context,
+				// if you have a slog.Logger that supports it.
+				for i := range v {
+					if _, ok := v[i].(error); ok {
+						format = "ERROR: " + format
+						break
+					}
+				}
+
+				// add a prefix to indicate these are Redis logs
+				format = "redis-logger: " + format
+
+				return ctx, format, v
+			},
+		),
+	)
+
+	// Create a Redis client with the wrapped logger
+	rdb := redis.NewClient(&redis.Options{
+		Addr:   ":6379",
+		Logger: wrappedLogger,
+	})
+
+	// Use the Redis client
+	_ = rdb.Set(ctx, "key", "value", 0)
+}
+
+// You can use your own logger that satisfies the [logging.LoggerWithLevelI] interface. Here's an example:
+func ExampleNewClient_with_custom_logger() {
+	// assuming you have your own logger that implements [logging.LoggerWithLevelI]
+	var myLogger logging.LoggerWithLevelI
+
+	// Create a Redis client with the custom logger
+	rdb := redis.NewClient(&redis.Options{
+		Addr:   ":6379",
+		Logger: myLogger,
+	})
+
+	// Use the Redis client
+	_ = rdb.Set(context.Background(), "key", "value", 0)
+}
diff --git a/ring.go b/ring.go
index 3381460abd..ccdfcb8aac 100644
--- a/ring.go
+++ b/ring.go
@@ -20,6 +20,7 @@ import (
 	"github.com/redis/go-redis/v9/internal/pool"
 	"github.com/redis/go-redis/v9/internal/proto"
 	"github.com/redis/go-redis/v9/internal/rand"
+	"github.com/redis/go-redis/v9/logging"
 )
 
 var errRingShardsDown = errors.New("redis: all ring shards are down")
@@ -154,6 +155,8 @@ type RingOptions struct {
 	DisableIdentity bool
 	IdentitySuffix  string
 	UnstableResp3   bool
+
+	Logger logging.LoggerWithLevelI
 }
 
 func (opt *RingOptions) init() {
@@ -345,7 +348,7 @@ func (c *ringSharding) SetAddrs(addrs map[string]string) {
 	cleanup := func(shards map[string]*ringShard) {
 		for addr, shard := range shards {
 			if err := shard.Client.Close(); err != nil {
-				internal.Logger.Printf(context.Background(), "shard.Close %s failed: %s", addr, err)
+				c.logger().Errorf(context.Background(), "shard.Close %s failed: %s", addr, err)
 			}
 		}
 	}
@@ -490,7 +493,7 @@ func (c *ringSharding) Heartbeat(ctx context.Context, frequency time.Duration) {
 			for _, shard := range c.List() {
 				isUp := c.opt.HeartbeatFn(ctx, shard.Client)
 				if shard.Vote(isUp) {
-					internal.Logger.Printf(ctx, "ring shard state changed: %s", shard)
+					c.logger().Infof(ctx, "ring shard state changed: %s", shard)
 					rebalance = true
 				}
 			}
@@ -559,6 +562,13 @@ func (c *ringSharding) Close() error {
 	return firstErr
 }
 
+func (c *ringSharding) logger() *logging.LoggerWrapper {
+	if c.opt != nil && c.opt.Logger != nil {
+		return logging.NewLoggerWrapper(c.opt.Logger)
+	}
+	return logging.LoggerWithLevel()
+} + //------------------------------------------------------------------------------ // Ring is a Redis client that uses consistent hashing to distribute diff --git a/sentinel.go b/sentinel.go index 663f7b1ad9..a5491de114 100644 --- a/sentinel.go +++ b/sentinel.go @@ -13,10 +13,10 @@ import ( "time" "github.com/redis/go-redis/v9/auth" - "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/rand" "github.com/redis/go-redis/v9/internal/util" + "github.com/redis/go-redis/v9/logging" "github.com/redis/go-redis/v9/maintnotifications" "github.com/redis/go-redis/v9/push" ) @@ -149,6 +149,9 @@ type FailoverOptions struct { // If nil, maintnotifications upgrades are disabled. // (however if Mode is nil, it defaults to "auto" - enable if server supports it) //MaintNotificationsConfig *maintnotifications.Config + + // Optional logger for logging + Logger logging.LoggerWithLevelI } func (opt *FailoverOptions) clientOptions() *Options { @@ -199,6 +202,8 @@ func (opt *FailoverOptions) clientOptions() *Options { MaintNotificationsConfig: &maintnotifications.Config{ Mode: maintnotifications.ModeDisabled, }, + + Logger: opt.Logger, } } @@ -247,6 +252,8 @@ func (opt *FailoverOptions) sentinelOptions(addr string) *Options { MaintNotificationsConfig: &maintnotifications.Config{ Mode: maintnotifications.ModeDisabled, }, + + Logger: opt.Logger, } } @@ -300,6 +307,8 @@ func (opt *FailoverOptions) clusterOptions() *ClusterOptions { MaintNotificationsConfig: &maintnotifications.Config{ Mode: maintnotifications.ModeDisabled, }, + + Logger: opt.Logger, } } @@ -831,7 +840,7 @@ func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) { return "", err } // Continue on other errors - internal.Logger.Printf(ctx, "sentinel: GetMasterAddrByName name=%q failed: %s", + c.logger().Errorf(ctx, "sentinel: GetMasterAddrByName name=%q failed: %s", c.opt.MasterName, err) } else { return addr, nil @@ -849,7 +858,7 @@ func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) { return "", err } // Continue on other errors - internal.Logger.Printf(ctx, "sentinel: GetMasterAddrByName name=%q failed: %s", + c.logger().Errorf(ctx, "sentinel: GetMasterAddrByName name=%q failed: %s", c.opt.MasterName, err) } else { return addr, nil @@ -878,7 +887,7 @@ func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) { sentinelCli := NewSentinelClient(c.opt.sentinelOptions(addr)) addrVal, err := sentinelCli.GetMasterAddrByName(ctx, c.opt.MasterName).Result() if err != nil { - internal.Logger.Printf(ctx, "sentinel: GetMasterAddrByName addr=%s, master=%q failed: %s", + c.logger().Errorf(ctx, "sentinel: GetMasterAddrByName addr=%s, master=%q failed: %s", addr, c.opt.MasterName, err) _ = sentinelCli.Close() errCh <- err @@ -889,7 +898,7 @@ func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) { // Push working sentinel to the top c.sentinelAddrs[0], c.sentinelAddrs[i] = c.sentinelAddrs[i], c.sentinelAddrs[0] c.setSentinel(ctx, sentinelCli) - internal.Logger.Printf(ctx, "sentinel: selected addr=%s masterAddr=%s", addr, masterAddr) + c.logger().Infof(ctx, "sentinel: selected addr=%s masterAddr=%s", addr, masterAddr) cancel() }) }(i, sentinelAddr) @@ -934,7 +943,7 @@ func (c *sentinelFailover) replicaAddrs(ctx context.Context, useDisconnected boo return nil, err } // Continue on other errors - internal.Logger.Printf(ctx, "sentinel: Replicas name=%q failed: %s", + c.logger().Errorf(ctx, "sentinel: Replicas name=%q 
failed: %s", c.opt.MasterName, err) } else if len(addrs) > 0 { return addrs, nil @@ -952,7 +961,7 @@ func (c *sentinelFailover) replicaAddrs(ctx context.Context, useDisconnected boo return nil, err } // Continue on other errors - internal.Logger.Printf(ctx, "sentinel: Replicas name=%q failed: %s", + c.logger().Errorf(ctx, "sentinel: Replicas name=%q failed: %s", c.opt.MasterName, err) } else if len(addrs) > 0 { return addrs, nil @@ -973,7 +982,7 @@ func (c *sentinelFailover) replicaAddrs(ctx context.Context, useDisconnected boo if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return nil, err } - internal.Logger.Printf(ctx, "sentinel: Replicas master=%q failed: %s", + c.logger().Errorf(ctx, "sentinel: Replicas master=%q failed: %s", c.opt.MasterName, err) continue } @@ -1006,7 +1015,7 @@ func (c *sentinelFailover) getMasterAddr(ctx context.Context, sentinel *Sentinel func (c *sentinelFailover) getReplicaAddrs(ctx context.Context, sentinel *SentinelClient) ([]string, error) { addrs, err := sentinel.Replicas(ctx, c.opt.MasterName).Result() if err != nil { - internal.Logger.Printf(ctx, "sentinel: Replicas name=%q failed: %s", + c.logger().Errorf(ctx, "sentinel: Replicas name=%q failed: %s", c.opt.MasterName, err) return nil, err } @@ -1054,7 +1063,7 @@ func (c *sentinelFailover) trySwitchMaster(ctx context.Context, addr string) { } c.masterAddr = addr - internal.Logger.Printf(ctx, "sentinel: new master=%q addr=%q", + c.logger().Infof(ctx, "sentinel: new master=%q addr=%q", c.opt.MasterName, addr) if c.onFailover != nil { c.onFailover(ctx, addr) @@ -1075,7 +1084,7 @@ func (c *sentinelFailover) setSentinel(ctx context.Context, sentinel *SentinelCl func (c *sentinelFailover) discoverSentinels(ctx context.Context) { sentinels, err := c.sentinel.Sentinels(ctx, c.opt.MasterName).Result() if err != nil { - internal.Logger.Printf(ctx, "sentinel: Sentinels master=%q failed: %s", c.opt.MasterName, err) + c.logger().Errorf(ctx, "sentinel: Sentinels master=%q failed: %s", c.opt.MasterName, err) return } for _, sentinel := range sentinels { @@ -1090,7 +1099,7 @@ func (c *sentinelFailover) discoverSentinels(ctx context.Context) { if ip != "" && port != "" { sentinelAddr := net.JoinHostPort(ip, port) if !contains(c.sentinelAddrs, sentinelAddr) { - internal.Logger.Printf(ctx, "sentinel: discovered new sentinel=%q for master=%q", + c.logger().Infof(ctx, "sentinel: discovered new sentinel=%q for master=%q", sentinelAddr, c.opt.MasterName) c.sentinelAddrs = append(c.sentinelAddrs, sentinelAddr) } @@ -1110,7 +1119,7 @@ func (c *sentinelFailover) listen(pubsub *PubSub) { if msg.Channel == "+switch-master" { parts := strings.Split(msg.Payload, " ") if parts[0] != c.opt.MasterName { - internal.Logger.Printf(pubsub.getContext(), "sentinel: ignore addr for master=%q", parts[0]) + c.logger().Infof(pubsub.getContext(), "sentinel: ignore addr for master=%q", parts[0]) continue } addr := net.JoinHostPort(parts[3], parts[4]) @@ -1123,6 +1132,13 @@ func (c *sentinelFailover) listen(pubsub *PubSub) { } } +func (c *sentinelFailover) logger() *logging.LoggerWrapper { + if c.opt != nil && c.opt.Logger != nil { + return logging.NewLoggerWrapper(c.opt.Logger) + } + return logging.LoggerWithLevel() +} + func contains(slice []string, str string) bool { for _, s := range slice { if s == str {
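
Since the Options.Logger field added by this patch accepts any logging.LoggerWithLevelI, and the go1.21-gated examples above pass a *slog.Logger directly, a test can route a client's log output to an in-memory handler and assert on it. The following is only a usage sketch under those same assumptions: the helper name, the address, and the test flow are hypothetical and are not part of this patch.

package redis_test

import (
	"bytes"
	"log/slog"

	"github.com/redis/go-redis/v9"
)

// newClientWithCapturedLogs is a hypothetical test helper: it points the
// client's optional per-client logger at an in-memory slog handler so a test
// can inspect what the client logged (push-notification errors, handoff
// messages, and so on). The global logger is left untouched.
func newClientWithCapturedLogs() (*redis.Client, *bytes.Buffer) {
	var buf bytes.Buffer
	logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{
		Level: slog.LevelDebug, // also keep Debugf output from the wrapper
	}))

	rdb := redis.NewClient(&redis.Options{
		Addr:   ":6379", // assumed local test instance
		Logger: logger,  // per-client logger introduced by this patch
	})
	return rdb, &buf
}

A test would then run a command, read buf.String(), and assert on the expected fragments; clients created without a Logger keep using the global wrapper, so existing tests are unaffected.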
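
The per-client logger plumbing above repeats one pattern: wrap the user-supplied logging.LoggerWithLevelI when it is set, otherwise fall back to the package-level wrapper. The sketch below distills that fallback into a single helper for readers; the helper name and its placement are hypothetical (it is not part of this patch), and it uses only the exported logging constructors that the patch itself calls.

package redis

import "github.com/redis/go-redis/v9/logging"

// loggerOrGlobal is a hypothetical helper illustrating the fallback used by the
// logger() methods in this patch: prefer the configured per-client logger,
// otherwise return the shared global wrapper.
func loggerOrGlobal(l logging.LoggerWithLevelI) *logging.LoggerWrapper {
	if l != nil {
		return logging.NewLoggerWrapper(l)
	}
	return logging.LoggerWithLevel()
}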