diff --git a/cli/azd/pkg/azdext/config_helper.go b/cli/azd/pkg/azdext/config_helper.go new file mode 100644 index 00000000000..60917274fb3 --- /dev/null +++ b/cli/azd/pkg/azdext/config_helper.go @@ -0,0 +1,442 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "regexp" + "strings" +) + +// ConfigHelper provides typed, ergonomic access to azd configuration through +// the gRPC UserConfig and Environment services. It eliminates the boilerplate +// of raw gRPC calls and JSON marshaling that extension authors otherwise need. +// +// Configuration sources (in merge priority, lowest to highest): +// 1. User config (global azd config) — via UserConfigService +// 2. Environment config (per-env) — via EnvironmentService +// +// Usage: +// +// ch := azdext.NewConfigHelper(client) +// port, err := ch.GetUserString(ctx, "extensions.myext.port") +// var cfg MyConfig +// err = ch.GetUserJSON(ctx, "extensions.myext", &cfg) +type ConfigHelper struct { + client *AzdClient +} + +// NewConfigHelper creates a [ConfigHelper] for the given AZD client. +func NewConfigHelper(client *AzdClient) (*ConfigHelper, error) { + if client == nil { + return nil, errors.New("azdext.NewConfigHelper: client must not be nil") + } + + return &ConfigHelper{client: client}, nil +} + +// --- User Config (global) --- + +// GetUserString retrieves a string value from the global user config at the +// given dot-separated path. Returns ("", false, nil) when the path does not +// exist, and ("", false, err) on gRPC errors. 
+func (ch *ConfigHelper) GetUserString(ctx context.Context, path string) (string, bool, error) { + if err := validatePath(path); err != nil { + return "", false, err + } + + resp, err := ch.client.UserConfig().GetString(ctx, &GetUserConfigStringRequest{Path: path}) + if err != nil { + return "", false, fmt.Errorf("azdext.ConfigHelper.GetUserString: gRPC call failed for path %q: %w", path, err) + } + + return resp.GetValue(), resp.GetFound(), nil +} + +// GetUserJSON retrieves a value from the global user config and unmarshals it +// into out. Returns (false, nil) when the path does not exist. +func (ch *ConfigHelper) GetUserJSON(ctx context.Context, path string, out any) (bool, error) { + if err := validatePath(path); err != nil { + return false, err + } + + if out == nil { + return false, errors.New("azdext.ConfigHelper.GetUserJSON: out must not be nil") + } + + resp, err := ch.client.UserConfig().Get(ctx, &GetUserConfigRequest{Path: path}) + if err != nil { + return false, fmt.Errorf("azdext.ConfigHelper.GetUserJSON: gRPC call failed for path %q: %w", path, err) + } + + if !resp.GetFound() { + return false, nil + } + + data := resp.GetValue() + if len(data) == 0 { + return false, nil + } + + if err := json.Unmarshal(data, out); err != nil { + return true, &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to unmarshal config at path %q: %w", path, err), + } + } + + return true, nil +} + +// SetUserJSON marshals value as JSON and writes it to the global user config +// at the given path. 
+func (ch *ConfigHelper) SetUserJSON(ctx context.Context, path string, value any) error { + if err := validatePath(path); err != nil { + return err + } + + if value == nil { + return errors.New("azdext.ConfigHelper.SetUserJSON: value must not be nil") + } + + data, err := json.Marshal(value) + if err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to marshal value for path %q: %w", path, err), + } + } + + _, err = ch.client.UserConfig().Set(ctx, &SetUserConfigRequest{ + Path: path, + Value: data, + }) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.SetUserJSON: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// UnsetUser removes a value from the global user config. +func (ch *ConfigHelper) UnsetUser(ctx context.Context, path string) error { + if err := validatePath(path); err != nil { + return err + } + + _, err := ch.client.UserConfig().Unset(ctx, &UnsetUserConfigRequest{Path: path}) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.UnsetUser: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// --- Environment Config (per-environment) --- + +// GetEnvString retrieves a string config value from the current environment. +// Returns ("", false, nil) when the path does not exist. +func (ch *ConfigHelper) GetEnvString(ctx context.Context, path string) (string, bool, error) { + if err := validatePath(path); err != nil { + return "", false, err + } + + resp, err := ch.client.Environment().GetConfigString(ctx, &GetConfigStringRequest{Path: path}) + if err != nil { + return "", false, fmt.Errorf("azdext.ConfigHelper.GetEnvString: gRPC call failed for path %q: %w", path, err) + } + + return resp.GetValue(), resp.GetFound(), nil +} + +// GetEnvJSON retrieves a value from the current environment's config and +// unmarshals it into out. Returns (false, nil) when the path does not exist. 
+func (ch *ConfigHelper) GetEnvJSON(ctx context.Context, path string, out any) (bool, error) { + if err := validatePath(path); err != nil { + return false, err + } + + if out == nil { + return false, errors.New("azdext.ConfigHelper.GetEnvJSON: out must not be nil") + } + + resp, err := ch.client.Environment().GetConfig(ctx, &GetConfigRequest{Path: path}) + if err != nil { + return false, fmt.Errorf("azdext.ConfigHelper.GetEnvJSON: gRPC call failed for path %q: %w", path, err) + } + + if !resp.GetFound() { + return false, nil + } + + data := resp.GetValue() + if len(data) == 0 { + return false, nil + } + + if err := json.Unmarshal(data, out); err != nil { + return true, &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to unmarshal env config at path %q: %w", path, err), + } + } + + return true, nil +} + +// SetEnvJSON marshals value as JSON and writes it to the current environment's config. +func (ch *ConfigHelper) SetEnvJSON(ctx context.Context, path string, value any) error { + if err := validatePath(path); err != nil { + return err + } + + if value == nil { + return errors.New("azdext.ConfigHelper.SetEnvJSON: value must not be nil") + } + + data, err := json.Marshal(value) + if err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to marshal value for env config path %q: %w", path, err), + } + } + + _, err = ch.client.Environment().SetConfig(ctx, &SetConfigRequest{ + Path: path, + Value: data, + }) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.SetEnvJSON: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// UnsetEnv removes a value from the current environment's config. 
+func (ch *ConfigHelper) UnsetEnv(ctx context.Context, path string) error { + if err := validatePath(path); err != nil { + return err + } + + _, err := ch.client.Environment().UnsetConfig(ctx, &UnsetConfigRequest{Path: path}) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.UnsetEnv: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// --- Merge --- + +// MergeJSON performs a shallow merge of override into base, returning a new map. +// Both inputs must be JSON-compatible maps (map[string]any). Keys in override +// take precedence over keys in base. +// +// This is NOT a deep merge — nested maps are replaced entirely by the override +// value. For predictable extension config behavior, keep config structures flat +// or use explicit path-based Set operations for nested values. +func MergeJSON(base, override map[string]any) map[string]any { + merged := make(map[string]any, len(base)+len(override)) + + for k, v := range base { + merged[k] = v + } + + for k, v := range override { + merged[k] = v + } + + return merged +} + +// deepMergeMaxDepth is the maximum recursion depth for [DeepMergeJSON]. +// This prevents stack overflow from deeply nested or adversarial JSON +// structures. 32 levels is far deeper than any legitimate config hierarchy. +const deepMergeMaxDepth = 32 + +// DeepMergeJSON performs a recursive merge of override into base. +// When both base and override have a map value for the same key, those maps +// are merged recursively. Otherwise the override value replaces the base value. +// +// Recursion is bounded to [deepMergeMaxDepth] levels to prevent stack overflow +// from deeply nested or adversarial inputs. Beyond the limit, the override +// value replaces the base value (merge degrades to shallow at that level). 
+func DeepMergeJSON(base, override map[string]any) map[string]any { + return deepMergeJSON(base, override, 0) +} + +func deepMergeJSON(base, override map[string]any, depth int) map[string]any { + merged := make(map[string]any, len(base)+len(override)) + + for k, v := range base { + merged[k] = v + } + + for k, v := range override { + baseVal, exists := merged[k] + if !exists { + merged[k] = v + continue + } + + baseMap, baseIsMap := baseVal.(map[string]any) + overMap, overIsMap := v.(map[string]any) + + if baseIsMap && overIsMap && depth < deepMergeMaxDepth { + merged[k] = deepMergeJSON(baseMap, overMap, depth+1) + } else { + merged[k] = v + } + } + + return merged +} + +// --- Validation --- + +// ConfigValidator defines a function that validates a config value. +// It returns nil if valid, or an error describing the validation failure. +type ConfigValidator func(value any) error + +// ValidateConfig unmarshals the raw JSON data and runs all supplied validators. +// Returns the first validation error encountered, wrapped in a [*ConfigError]. +func ValidateConfig(path string, data []byte, validators ...ConfigValidator) error { + if len(data) == 0 { + return &ConfigError{ + Path: path, + Reason: ConfigReasonMissing, + Err: fmt.Errorf("config at path %q is empty", path), + } + } + + var value any + if err := json.Unmarshal(data, &value); err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("config at path %q is not valid JSON: %w", path, err), + } + } + + for _, v := range validators { + if err := v(value); err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonValidationFailed, + Err: fmt.Errorf("config validation failed at path %q: %w", path, err), + } + } + } + + return nil +} + +// RequiredKeys returns a [ConfigValidator] that checks for the presence of +// the specified keys in a map value. 
+func RequiredKeys(keys ...string) ConfigValidator { + return func(value any) error { + m, ok := value.(map[string]any) + if !ok { + return fmt.Errorf("expected object, got %T", value) + } + + for _, key := range keys { + if _, exists := m[key]; !exists { + return fmt.Errorf("required key %q is missing", key) + } + } + + return nil + } +} + +// --- Error types --- + +// ConfigReason classifies the cause of a [ConfigError]. +type ConfigReason int + +const ( + // ConfigReasonMissing indicates the config path does not exist or is empty. + ConfigReasonMissing ConfigReason = iota + + // ConfigReasonInvalidFormat indicates the config value is not valid JSON + // or cannot be unmarshaled into the target type. + ConfigReasonInvalidFormat + + // ConfigReasonValidationFailed indicates a validator rejected the config value. + ConfigReasonValidationFailed +) + +// String returns a human-readable label. +func (r ConfigReason) String() string { + switch r { + case ConfigReasonMissing: + return "missing" + case ConfigReasonInvalidFormat: + return "invalid_format" + case ConfigReasonValidationFailed: + return "validation_failed" + default: + return "unknown" + } +} + +// ConfigError is returned by [ConfigHelper] methods on domain-level failures. +type ConfigError struct { + // Path is the config path that was being accessed. + Path string + + // Reason classifies the failure. + Reason ConfigReason + + // Err is the underlying error. 
+ Err error +} + +func (e *ConfigError) Error() string { + return fmt.Sprintf("azdext.ConfigHelper: %s (path=%s): %v", e.Reason, e.Path, e.Err) +} + +func (e *ConfigError) Unwrap() error { + return e.Err +} + +var configSegmentRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]{0,62}$`) + +func validatePath(path string) error { + if path == "" { + return errors.New("azdext.ConfigHelper: config path must not be empty") + } + if strings.HasPrefix(path, ".") || strings.HasSuffix(path, ".") || strings.Contains(path, "..") { + return errors.New( + "azdext.ConfigHelper: config path must not have empty segments " + + "(no leading/trailing dots or consecutive dots)", + ) + } + for _, seg := range strings.Split(path, ".") { + if !configSegmentRe.MatchString(seg) { + return fmt.Errorf( + "azdext.ConfigHelper: config path segment %q must start with alphanumeric "+ + "and contain only [a-zA-Z0-9_-], max 63 chars", + truncateConfigValue(seg, 64), + ) + } + } + return nil +} + +func truncateConfigValue(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} diff --git a/cli/azd/pkg/azdext/config_helper_test.go b/cli/azd/pkg/azdext/config_helper_test.go new file mode 100644 index 00000000000..c68464f16bc --- /dev/null +++ b/cli/azd/pkg/azdext/config_helper_test.go @@ -0,0 +1,967 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package azdext + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "google.golang.org/grpc" +) + +// --- Stub UserConfigService --- + +type stubUserConfigService struct { + getResp *GetUserConfigResponse + getStringResp *GetUserConfigStringResponse + getSectionErr error + getErr error + getStringErr error + setErr error + unsetErr error +} + +func (s *stubUserConfigService) Get( + _ context.Context, _ *GetUserConfigRequest, _ ...grpc.CallOption, +) (*GetUserConfigResponse, error) { + return s.getResp, s.getErr +} + +func (s *stubUserConfigService) GetString( + _ context.Context, _ *GetUserConfigStringRequest, _ ...grpc.CallOption, +) (*GetUserConfigStringResponse, error) { + return s.getStringResp, s.getStringErr +} + +func (s *stubUserConfigService) GetSection( + _ context.Context, _ *GetUserConfigSectionRequest, _ ...grpc.CallOption, +) (*GetUserConfigSectionResponse, error) { + return nil, s.getSectionErr +} + +func (s *stubUserConfigService) Set( + _ context.Context, _ *SetUserConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.setErr +} + +func (s *stubUserConfigService) Unset( + _ context.Context, _ *UnsetUserConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.unsetErr +} + +// --- Stub EnvironmentService --- + +type stubEnvironmentService struct { + getConfigResp *GetConfigResponse + getConfigStringResp *GetConfigStringResponse + getConfigErr error + getConfigStringErr error + setConfigErr error + unsetConfigErr error +} + +func (s *stubEnvironmentService) GetCurrent( + _ context.Context, _ *EmptyRequest, _ ...grpc.CallOption, +) (*EnvironmentResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) List( + _ context.Context, _ *EmptyRequest, _ ...grpc.CallOption, +) (*EnvironmentListResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) Get( + _ context.Context, _ *GetEnvironmentRequest, _ ...grpc.CallOption, +) 
(*EnvironmentResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) Select( + _ context.Context, _ *SelectEnvironmentRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) GetValues( + _ context.Context, _ *GetEnvironmentRequest, _ ...grpc.CallOption, +) (*KeyValueListResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) GetValue( + _ context.Context, _ *GetEnvRequest, _ ...grpc.CallOption, +) (*KeyValueResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) SetValue( + _ context.Context, _ *SetEnvRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) GetConfig( + _ context.Context, _ *GetConfigRequest, _ ...grpc.CallOption, +) (*GetConfigResponse, error) { + return s.getConfigResp, s.getConfigErr +} + +func (s *stubEnvironmentService) GetConfigString( + _ context.Context, _ *GetConfigStringRequest, _ ...grpc.CallOption, +) (*GetConfigStringResponse, error) { + return s.getConfigStringResp, s.getConfigStringErr +} + +func (s *stubEnvironmentService) GetConfigSection( + _ context.Context, _ *GetConfigSectionRequest, _ ...grpc.CallOption, +) (*GetConfigSectionResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) SetConfig( + _ context.Context, _ *SetConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.setConfigErr +} + +func (s *stubEnvironmentService) UnsetConfig( + _ context.Context, _ *UnsetConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.unsetConfigErr +} + +// --- NewConfigHelper --- + +func TestNewConfigHelper_NilClient(t *testing.T) { + t.Parallel() + + _, err := NewConfigHelper(nil) + if err == nil { + t.Fatal("expected error for nil client") + } +} + +func TestNewConfigHelper_Success(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, err := 
NewConfigHelper(client) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if ch == nil { + t.Fatal("expected non-nil ConfigHelper") + } +} + +// --- GetUserString --- + +func TestGetUserString_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + _, _, err := ch.GetUserString(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +func TestGetUserString_Found(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getStringResp: &GetUserConfigStringResponse{Value: "8080", Found: true}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + val, found, err := ch.GetUserString(context.Background(), "extensions.myext.port") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if val != "8080" { + t.Errorf("value = %q, want %q", val, "8080") + } +} + +func TestGetUserString_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getStringResp: &GetUserConfigStringResponse{Value: "", Found: false}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + val, found, err := ch.GetUserString(context.Background(), "nonexistent.path") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } + + if val != "" { + t.Errorf("value = %q, want empty", val) + } +} + +func TestGetUserString_GRPCError(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getStringErr: errors.New("grpc unavailable"), + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + _, _, err := ch.GetUserString(context.Background(), "some.path") + if err == nil { + t.Fatal("expected error for gRPC failure") + } +} + +// --- GetUserJSON --- + +func TestGetUserJSON_Found(t *testing.T) { + t.Parallel() + + type myConfig 
struct { + Port int `json:"port"` + Host string `json:"host"` + } + + data, _ := json.Marshal(myConfig{Port: 3000, Host: "localhost"}) + stub := &stubUserConfigService{ + getResp: &GetUserConfigResponse{Value: data, Found: true}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg myConfig + found, err := ch.GetUserJSON(context.Background(), "extensions.myext", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if cfg.Port != 3000 { + t.Errorf("Port = %d, want 3000", cfg.Port) + } + + if cfg.Host != "localhost" { + t.Errorf("Host = %q, want %q", cfg.Host, "localhost") + } +} + +func TestGetUserJSON_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getResp: &GetUserConfigResponse{Value: nil, Found: false}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + found, err := ch.GetUserJSON(context.Background(), "nonexistent", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } +} + +func TestGetUserJSON_NilOut(t *testing.T) { + t.Parallel() + + client := &AzdClient{userConfigClient: &stubUserConfigService{}} + ch, _ := NewConfigHelper(client) + + _, err := ch.GetUserJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil out parameter") + } +} + +func TestGetUserJSON_InvalidJSON(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getResp: &GetUserConfigResponse{Value: []byte("not json"), Found: true}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + _, err := ch.GetUserJSON(context.Background(), "bad.json", &cfg) + if err == nil { + t.Fatal("expected error for invalid JSON") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = 
%T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonInvalidFormat { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonInvalidFormat) + } +} + +func TestGetUserJSON_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + _, err := ch.GetUserJSON(context.Background(), "", &cfg) + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- SetUserJSON --- + +func TestSetUserJSON_Success(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{} + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "extensions.myext.port", 3000) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestSetUserJSON_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "", "value") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +func TestSetUserJSON_NilValue(t *testing.T) { + t.Parallel() + + client := &AzdClient{userConfigClient: &stubUserConfigService{}} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil value") + } +} + +func TestSetUserJSON_GRPCError(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + setErr: errors.New("grpc write error"), + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "some.path", "value") + if err == nil { + t.Fatal("expected error for gRPC failure") + } +} + +func TestSetUserJSON_UnmarshalableValue(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{} + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + // Channels cannot be marshaled to JSON + err := 
ch.SetUserJSON(context.Background(), "some.path", make(chan int)) + if err == nil { + t.Fatal("expected error for unmarshalable value") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonInvalidFormat { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonInvalidFormat) + } +} + +// --- UnsetUser --- + +func TestUnsetUser_Success(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{} + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetUser(context.Background(), "some.path") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestUnsetUser_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetUser(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- GetEnvString --- + +func TestGetEnvString_Found(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{ + getConfigStringResp: &GetConfigStringResponse{Value: "prod", Found: true}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + val, found, err := ch.GetEnvString(context.Background(), "extensions.myext.mode") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if val != "prod" { + t.Errorf("value = %q, want %q", val, "prod") + } +} + +func TestGetEnvString_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{ + getConfigStringResp: &GetConfigStringResponse{Value: "", Found: false}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + _, found, err := ch.GetEnvString(context.Background(), "nonexistent") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } +} + 
+func TestGetEnvString_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + _, _, err := ch.GetEnvString(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- GetEnvJSON --- + +func TestGetEnvJSON_Found(t *testing.T) { + t.Parallel() + + type envConfig struct { + Debug bool `json:"debug"` + } + + data, _ := json.Marshal(envConfig{Debug: true}) + stub := &stubEnvironmentService{ + getConfigResp: &GetConfigResponse{Value: data, Found: true}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg envConfig + found, err := ch.GetEnvJSON(context.Background(), "extensions.myext", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if !cfg.Debug { + t.Error("expected Debug = true") + } +} + +func TestGetEnvJSON_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{ + getConfigResp: &GetConfigResponse{Value: nil, Found: false}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + found, err := ch.GetEnvJSON(context.Background(), "nonexistent", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } +} + +func TestGetEnvJSON_NilOut(t *testing.T) { + t.Parallel() + + client := &AzdClient{environmentClient: &stubEnvironmentService{}} + ch, _ := NewConfigHelper(client) + + _, err := ch.GetEnvJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil out parameter") + } +} + +// --- SetEnvJSON --- + +func TestSetEnvJSON_Success(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{} + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.SetEnvJSON(context.Background(), "extensions.myext.mode", "prod") + if err != 
nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestSetEnvJSON_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.SetEnvJSON(context.Background(), "", "value") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +func TestSetEnvJSON_NilValue(t *testing.T) { + t.Parallel() + + client := &AzdClient{environmentClient: &stubEnvironmentService{}} + ch, _ := NewConfigHelper(client) + + err := ch.SetEnvJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil value") + } +} + +// --- UnsetEnv --- + +func TestUnsetEnv_Success(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{} + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetEnv(context.Background(), "some.path") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestUnsetEnv_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetEnv(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- MergeJSON --- + +func TestMergeJSON_Basic(t *testing.T) { + t.Parallel() + + base := map[string]any{"a": 1, "b": 2} + override := map[string]any{"b": 3, "c": 4} + + result := MergeJSON(base, override) + + if result["a"] != 1 { + t.Errorf("a = %v, want 1", result["a"]) + } + + if result["b"] != 3 { + t.Errorf("b = %v, want 3 (override wins)", result["b"]) + } + + if result["c"] != 4 { + t.Errorf("c = %v, want 4", result["c"]) + } +} + +func TestMergeJSON_EmptyBase(t *testing.T) { + t.Parallel() + + result := MergeJSON(nil, map[string]any{"x": "y"}) + + if result["x"] != "y" { + t.Errorf("x = %v, want y", result["x"]) + } +} + +func TestMergeJSON_EmptyOverride(t *testing.T) { + t.Parallel() + + result := MergeJSON(map[string]any{"x": "y"}, nil) + + if result["x"] != "y" { + t.Errorf("x = %v, want y", 
result["x"]) + } +} + +func TestMergeJSON_BothEmpty(t *testing.T) { + t.Parallel() + + result := MergeJSON(nil, nil) + + if len(result) != 0 { + t.Errorf("len(result) = %d, want 0", len(result)) + } +} + +func TestMergeJSON_DoesNotMutateInputs(t *testing.T) { + t.Parallel() + + base := map[string]any{"a": 1} + override := map[string]any{"b": 2} + + _ = MergeJSON(base, override) + + if _, ok := base["b"]; ok { + t.Error("MergeJSON mutated base map") + } + + if _, ok := override["a"]; ok { + t.Error("MergeJSON mutated override map") + } +} + +// --- DeepMergeJSON --- + +func TestDeepMergeJSON_RecursiveMerge(t *testing.T) { + t.Parallel() + + base := map[string]any{ + "server": map[string]any{ + "host": "localhost", + "port": 3000, + }, + "debug": false, + } + + override := map[string]any{ + "server": map[string]any{ + "port": 8080, + "tls": true, + }, + "version": "1.0", + } + + result := DeepMergeJSON(base, override) + + server, ok := result["server"].(map[string]any) + if !ok { + t.Fatal("server should be a map") + } + + if server["host"] != "localhost" { + t.Errorf("server.host = %v, want localhost", server["host"]) + } + + if server["port"] != 8080 { + t.Errorf("server.port = %v, want 8080 (override wins)", server["port"]) + } + + if server["tls"] != true { + t.Errorf("server.tls = %v, want true", server["tls"]) + } + + if result["debug"] != false { + t.Errorf("debug = %v, want false", result["debug"]) + } + + if result["version"] != "1.0" { + t.Errorf("version = %v, want 1.0", result["version"]) + } +} + +func TestDeepMergeJSON_OverrideReplacesNonMap(t *testing.T) { + t.Parallel() + + base := map[string]any{"x": "string-value"} + override := map[string]any{"x": map[string]any{"nested": true}} + + result := DeepMergeJSON(base, override) + + nested, ok := result["x"].(map[string]any) + if !ok { + t.Fatal("override should replace string with map") + } + + if nested["nested"] != true { + t.Errorf("x.nested = %v, want true", nested["nested"]) + } +} + +func 
TestDeepMergeJSON_DoesNotMutateInputs(t *testing.T) { + t.Parallel() + + base := map[string]any{"a": map[string]any{"x": 1}} + override := map[string]any{"a": map[string]any{"y": 2}} + + _ = DeepMergeJSON(base, override) + + baseA := base["a"].(map[string]any) + if _, ok := baseA["y"]; ok { + t.Error("DeepMergeJSON mutated base nested map") + } +} + +// --- ValidateConfig --- + +func TestValidateConfig_EmptyData(t *testing.T) { + t.Parallel() + + err := ValidateConfig("test.path", nil) + if err == nil { + t.Fatal("expected error for empty data") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonMissing { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonMissing) + } +} + +func TestValidateConfig_InvalidJSON(t *testing.T) { + t.Parallel() + + err := ValidateConfig("test.path", []byte("not json")) + if err == nil { + t.Fatal("expected error for invalid JSON") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonInvalidFormat { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonInvalidFormat) + } +} + +func TestValidateConfig_ValidatorFails(t *testing.T) { + t.Parallel() + + data, _ := json.Marshal(map[string]any{"a": 1}) + failValidator := func(_ any) error { return errors.New("validation failed") } + + err := ValidateConfig("test.path", data, failValidator) + if err == nil { + t.Fatal("expected error from failing validator") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonValidationFailed { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonValidationFailed) + } +} + +func TestValidateConfig_AllValidatorsPass(t *testing.T) { + t.Parallel() + + data, _ := json.Marshal(map[string]any{"a": 1, "b": 2}) + passValidator 
:= func(_ any) error { return nil } + + err := ValidateConfig("test.path", data, passValidator, passValidator) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidateConfig_NoValidators(t *testing.T) { + t.Parallel() + + data, _ := json.Marshal(map[string]any{"a": 1}) + + err := ValidateConfig("test.path", data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +// --- RequiredKeys --- + +func TestRequiredKeys_AllPresent(t *testing.T) { + t.Parallel() + + validator := RequiredKeys("host", "port") + value := map[string]any{"host": "localhost", "port": 3000, "extra": true} + + err := validator(value) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestRequiredKeys_MissingKey(t *testing.T) { + t.Parallel() + + validator := RequiredKeys("host", "port") + value := map[string]any{"host": "localhost"} + + err := validator(value) + if err == nil { + t.Fatal("expected error for missing key") + } +} + +func TestRequiredKeys_NotAMap(t *testing.T) { + t.Parallel() + + validator := RequiredKeys("key") + + err := validator("not a map") + if err == nil { + t.Fatal("expected error for non-map value") + } +} + +// --- ConfigError --- + +func TestConfigError_Error(t *testing.T) { + t.Parallel() + + err := &ConfigError{ + Path: "test.path", + Reason: ConfigReasonMissing, + Err: errors.New("not found"), + } + + got := err.Error() + if got == "" { + t.Fatal("Error() returned empty string") + } +} + +func TestConfigError_Unwrap(t *testing.T) { + t.Parallel() + + inner := errors.New("inner error") + err := &ConfigError{ + Path: "test.path", + Reason: ConfigReasonInvalidFormat, + Err: inner, + } + + if !errors.Is(err, inner) { + t.Error("Unwrap should expose inner error via errors.Is") + } +} + +func TestConfigReason_String(t *testing.T) { + t.Parallel() + + tests := []struct { + reason ConfigReason + want string + }{ + {ConfigReasonMissing, "missing"}, + {ConfigReasonInvalidFormat, "invalid_format"}, + 
{ConfigReasonValidationFailed, "validation_failed"}, + {ConfigReason(99), "unknown"}, + } + + for _, tt := range tests { + if got := tt.reason.String(); got != tt.want { + t.Errorf("ConfigReason(%d).String() = %q, want %q", tt.reason, got, tt.want) + } + } +} diff --git a/cli/azd/pkg/azdext/keyvault_resolver.go b/cli/azd/pkg/azdext/keyvault_resolver.go new file mode 100644 index 00000000000..9818f8d7960 --- /dev/null +++ b/cli/azd/pkg/azdext/keyvault_resolver.go @@ -0,0 +1,320 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "context" + "errors" + "fmt" + "net/http" + "regexp" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets" + "github.com/azure/azure-dev/cli/azd/pkg/keyvault" +) + +// KeyVaultResolver resolves Azure Key Vault secret references for extension +// scenarios. It uses the extension's [TokenProvider] for authentication and +// the Azure SDK data-plane client for secret retrieval. +// +// Secret references use the akvs:// URI scheme: +// +// akvs://// +// +// Usage: +// +// tp, _ := azdext.NewTokenProvider(ctx, client, nil) +// resolver, _ := azdext.NewKeyVaultResolver(tp, nil) +// value, err := resolver.Resolve(ctx, "akvs://sub-id/my-vault/my-secret") +type KeyVaultResolver struct { + credential azcore.TokenCredential + clientFactory secretClientFactory + opts KeyVaultResolverOptions +} + +// secretClientFactory abstracts secret client creation for testability. +type secretClientFactory func(vaultURL string, credential azcore.TokenCredential) (secretGetter, error) + +// secretGetter abstracts the Azure SDK secret client's GetSecret method. +type secretGetter interface { + GetSecret( + ctx context.Context, + name string, + version string, + options *azsecrets.GetSecretOptions, + ) (azsecrets.GetSecretResponse, error) +} + +// KeyVaultResolverOptions configures a [KeyVaultResolver]. 
+type KeyVaultResolverOptions struct { + // VaultSuffix overrides the default Key Vault DNS suffix. + // Defaults to "vault.azure.net" (Azure public cloud). + VaultSuffix string + + // ClientFactory overrides the default secret client constructor. + // Useful for testing. When nil, the production [azsecrets.NewClient] is used. + ClientFactory func(vaultURL string, credential azcore.TokenCredential) (secretGetter, error) +} + +// NewKeyVaultResolver creates a [KeyVaultResolver] with the given credential. +// +// credential must not be nil; it is typically a [*TokenProvider] from P1-1. +// If opts is nil, production defaults are used. +func NewKeyVaultResolver(credential azcore.TokenCredential, opts *KeyVaultResolverOptions) (*KeyVaultResolver, error) { + if credential == nil { + return nil, errors.New("azdext.NewKeyVaultResolver: credential must not be nil") + } + + if opts == nil { + opts = &KeyVaultResolverOptions{} + } + + if opts.VaultSuffix == "" { + opts.VaultSuffix = "vault.azure.net" + } + + factory := defaultSecretClientFactory + if opts.ClientFactory != nil { + factory = opts.ClientFactory + } + + return &KeyVaultResolver{ + credential: credential, + clientFactory: factory, + opts: *opts, + }, nil +} + +// defaultSecretClientFactory creates a real Azure SDK secrets client. +func defaultSecretClientFactory(vaultURL string, credential azcore.TokenCredential) (secretGetter, error) { + client, err := azsecrets.NewClient(vaultURL, credential, nil) + if err != nil { + return nil, err + } + + return client, nil +} + +// Resolve fetches the secret value for an akvs:// reference. +// +// The reference must match the format: akvs://// +// +// Returns a [*KeyVaultResolveError] for all domain errors (invalid reference, +// secret not found, authentication failure). No silent fallbacks or hidden retries. 
+func (r *KeyVaultResolver) Resolve(ctx context.Context, ref string) (string, error) { + if ctx == nil { + return "", errors.New("azdext.KeyVaultResolver.Resolve: context must not be nil") + } + + parsed, err := ParseSecretReference(ref) + if err != nil { + return "", &KeyVaultResolveError{ + Reference: ref, + Reason: ResolveReasonInvalidReference, + Err: err, + } + } + + vaultURL := fmt.Sprintf("https://%s.%s", parsed.VaultName, r.opts.VaultSuffix) + + client, err := r.clientFactory(vaultURL, r.credential) + if err != nil { + return "", &KeyVaultResolveError{ + Reference: ref, + Reason: ResolveReasonClientCreation, + Err: fmt.Errorf("failed to create Key Vault client for %s: %w", vaultURL, err), + } + } + + resp, err := client.GetSecret(ctx, parsed.SecretName, "", nil) + if err != nil { + reason := ResolveReasonAccessDenied + + var respErr *azcore.ResponseError + if errors.As(err, &respErr) { + switch respErr.StatusCode { + case http.StatusNotFound: + reason = ResolveReasonNotFound + case http.StatusForbidden, http.StatusUnauthorized: + reason = ResolveReasonAccessDenied + default: + reason = ResolveReasonServiceError + } + } + + return "", &KeyVaultResolveError{ + Reference: ref, + Reason: reason, + Err: fmt.Errorf( + "failed to retrieve secret %q from vault %q: %w", + parsed.SecretName, + parsed.VaultName, + err, + ), + } + } + + if resp.Value == nil { + return "", &KeyVaultResolveError{ + Reference: ref, + Reason: ResolveReasonNotFound, + Err: fmt.Errorf("secret %q in vault %q has a nil value", parsed.SecretName, parsed.VaultName), + } + } + + return *resp.Value, nil +} + +// ResolveMap resolves a map of key → akvs:// references, returning a map of +// key → resolved secret values. Processing stops at the first error. +// +// Non-akvs:// values are passed through unchanged, so callers can safely +// resolve a mixed map of plain values and secret references. 
+func (r *KeyVaultResolver) ResolveMap(ctx context.Context, refs map[string]string) (map[string]string, error) { + if ctx == nil { + return nil, errors.New("azdext.KeyVaultResolver.ResolveMap: context must not be nil") + } + + result := make(map[string]string, len(refs)) + + for key, value := range refs { + if !IsSecretReference(value) { + result[key] = value + continue + } + + resolved, err := r.Resolve(ctx, value) + if err != nil { + return nil, fmt.Errorf("azdext.KeyVaultResolver.ResolveMap: key %q: %w", key, err) + } + + result[key] = resolved + } + + return result, nil +} + +// SecretReference represents a parsed akvs:// URI. +type SecretReference struct { + // SubscriptionID is the Azure subscription containing the Key Vault. + SubscriptionID string + + // VaultName is the Key Vault name (not the full URL). + VaultName string + + // SecretName is the name of the secret within the vault. + SecretName string +} + +// IsSecretReference reports whether s uses the akvs:// scheme. +func IsSecretReference(s string) bool { + return keyvault.IsAzureKeyVaultSecret(s) +} + +// vaultNameRe validates Azure Key Vault names per Azure naming rules: +// - 3–24 characters +// - starts with a letter +// - contains only alphanumeric and hyphens +// - does not end with a hyphen +var vaultNameRe = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9-]{1,22}[a-zA-Z0-9]$`) + +// ParseSecretReference parses an akvs:// URI into its components. +// +// Expected format: akvs://// +// +// The vault name is validated against Azure Key Vault naming rules (3–24 +// characters, starts with letter, alphanumeric and hyphens only, does not +// end with a hyphen). 
+func ParseSecretReference(ref string) (*SecretReference, error) { + parsed, err := keyvault.ParseAzureKeyVaultSecret(ref) + if err != nil { + return nil, err + } + + if strings.TrimSpace(parsed.SubscriptionId) == "" { + return nil, fmt.Errorf("invalid akvs:// reference %q: subscription-id must not be empty", ref) + } + if strings.TrimSpace(parsed.VaultName) == "" { + return nil, fmt.Errorf("invalid akvs:// reference %q: vault-name must not be empty", ref) + } + if !vaultNameRe.MatchString(parsed.VaultName) { + return nil, fmt.Errorf( + "invalid akvs:// reference %q: vault name %q must be 3-24 characters, "+ + "start with a letter, end with alphanumeric, and contain only alphanumeric characters and hyphens", + ref, parsed.VaultName, + ) + } + if strings.TrimSpace(parsed.SecretName) == "" { + return nil, fmt.Errorf("invalid akvs:// reference %q: secret-name must not be empty", ref) + } + + return &SecretReference{ + SubscriptionID: parsed.SubscriptionId, + VaultName: parsed.VaultName, + SecretName: parsed.SecretName, + }, nil +} + +// ResolveReason classifies the cause of a [KeyVaultResolveError]. +type ResolveReason int + +const ( + // ResolveReasonInvalidReference indicates the akvs:// URI is malformed. + ResolveReasonInvalidReference ResolveReason = iota + + // ResolveReasonClientCreation indicates failure to create the Key Vault client. + ResolveReasonClientCreation + + // ResolveReasonNotFound indicates the secret does not exist. + ResolveReasonNotFound + + // ResolveReasonAccessDenied indicates an authentication or authorization failure. + ResolveReasonAccessDenied + + // ResolveReasonServiceError indicates an unexpected Key Vault service error. + ResolveReasonServiceError +) + +// String returns a human-readable label for the reason. 
+func (r ResolveReason) String() string { + switch r { + case ResolveReasonInvalidReference: + return "invalid_reference" + case ResolveReasonClientCreation: + return "client_creation" + case ResolveReasonNotFound: + return "not_found" + case ResolveReasonAccessDenied: + return "access_denied" + case ResolveReasonServiceError: + return "service_error" + default: + return "unknown" + } +} + +// KeyVaultResolveError is returned when [KeyVaultResolver.Resolve] fails. +type KeyVaultResolveError struct { + // Reference is the original akvs:// URI that was being resolved. + Reference string + + // Reason classifies the failure. + Reason ResolveReason + + // Err is the underlying error. + Err error +} + +func (e *KeyVaultResolveError) Error() string { + return fmt.Sprintf( + "azdext.KeyVaultResolver: %s (ref=%s): %v", + e.Reason, e.Reference, e.Err, + ) +} + +func (e *KeyVaultResolveError) Unwrap() error { + return e.Err +} diff --git a/cli/azd/pkg/azdext/keyvault_resolver_test.go b/cli/azd/pkg/azdext/keyvault_resolver_test.go new file mode 100644 index 00000000000..628f301d5d3 --- /dev/null +++ b/cli/azd/pkg/azdext/keyvault_resolver_test.go @@ -0,0 +1,575 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets" +) + +// stubSecretGetter is a test double for the Key Vault data-plane client. +type stubSecretGetter struct { + resp azsecrets.GetSecretResponse + err error +} + +func (s *stubSecretGetter) GetSecret( + _ context.Context, _ string, _ string, _ *azsecrets.GetSecretOptions, +) (azsecrets.GetSecretResponse, error) { + return s.resp, s.err +} + +// stubSecretFactory returns a factory that always returns the given stubSecretGetter. 
+func stubSecretFactory(g secretGetter, factoryErr error) func(string, azcore.TokenCredential) (secretGetter, error) { + return func(_ string, _ azcore.TokenCredential) (secretGetter, error) { + if factoryErr != nil { + return nil, factoryErr + } + return g, nil + } +} + +// --- NewKeyVaultResolver --- + +func TestNewKeyVaultResolver_NilCredential(t *testing.T) { + t.Parallel() + + _, err := NewKeyVaultResolver(nil, nil) + if err == nil { + t.Fatal("expected error for nil credential") + } +} + +func TestNewKeyVaultResolver_Defaults(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, err := NewKeyVaultResolver(cred, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if resolver.opts.VaultSuffix != "vault.azure.net" { + t.Errorf("VaultSuffix = %q, want %q", resolver.opts.VaultSuffix, "vault.azure.net") + } +} + +func TestNewKeyVaultResolver_CustomSuffix(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, err := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + VaultSuffix: "vault.azure.cn", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if resolver.opts.VaultSuffix != "vault.azure.cn" { + t.Errorf("VaultSuffix = %q, want %q", resolver.opts.VaultSuffix, "vault.azure.cn") + } +} + +// --- IsSecretReference --- + +func TestIsSecretReference(t *testing.T) { + t.Parallel() + + tests := []struct { + input string + want bool + }{ + {"akvs://sub/vault/secret", true}, + {"akvs://", true}, + {"AKVS://sub/vault/secret", false}, // case-sensitive + {"https://vault.azure.net", false}, + {"", false}, + } + + for _, tt := range tests { + if got := IsSecretReference(tt.input); got != tt.want { + t.Errorf("IsSecretReference(%q) = %v, want %v", tt.input, got, tt.want) + } + } +} + +// --- ParseSecretReference --- + +func TestParseSecretReference_Valid(t *testing.T) { + t.Parallel() + + ref, err := ParseSecretReference("akvs://sub-123/my-vault/my-secret") + if err != nil { + t.Fatalf("unexpected 
error: %v", err) + } + + if ref.SubscriptionID != "sub-123" { + t.Errorf("SubscriptionID = %q, want %q", ref.SubscriptionID, "sub-123") + } + if ref.VaultName != "my-vault" { + t.Errorf("VaultName = %q, want %q", ref.VaultName, "my-vault") + } + if ref.SecretName != "my-secret" { + t.Errorf("SecretName = %q, want %q", ref.SecretName, "my-secret") + } +} + +func TestParseSecretReference_NotAkvsScheme(t *testing.T) { + t.Parallel() + + _, err := ParseSecretReference("https://vault.azure.net/secrets/x") + if err == nil { + t.Fatal("expected error for non-akvs scheme") + } +} + +func TestParseSecretReference_TooFewParts(t *testing.T) { + t.Parallel() + + _, err := ParseSecretReference("akvs://sub/vault") + if err == nil { + t.Fatal("expected error for two-part ref") + } +} + +func TestParseSecretReference_TooManyParts(t *testing.T) { + t.Parallel() + + _, err := ParseSecretReference("akvs://sub/vault/secret/extra") + if err == nil { + t.Fatal("expected error for four-part ref") + } +} + +func TestParseSecretReference_EmptyComponent(t *testing.T) { + t.Parallel() + + cases := []string{ + "akvs:///vault/secret", // empty subscription + "akvs://sub//secret", // empty vault + "akvs://sub/vault/", // empty secret + "akvs:// /vault/secret", // whitespace subscription + "akvs://sub/ /secret", // whitespace vault + "akvs://sub/vault/ ", // whitespace secret + } + + for _, ref := range cases { + _, err := ParseSecretReference(ref) + if err == nil { + t.Errorf("ParseSecretReference(%q) expected error, got nil", ref) + } + } +} + +// --- Resolve --- + +func TestResolve_Success(t *testing.T) { + t.Parallel() + + secretValue := "super-secret-value" + getter := &stubSecretGetter{ + resp: azsecrets.GetSecretResponse{ + Secret: azsecrets.Secret{ + Value: &secretValue, + }, + }, + } + + cred := &stubCredential{} + resolver, err := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + if err != nil { + t.Fatalf("unexpected error: 
%v", err) + } + + val, err := resolver.Resolve(context.Background(), "akvs://sub-id/my-vault/my-secret") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if val != secretValue { + t.Errorf("Resolve() = %q, want %q", val, secretValue) + } +} + +func TestResolve_NilContext(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + //nolint:staticcheck // intentionally testing nil context + _, err := resolver.Resolve(nil, "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for nil context") + } +} + +func TestResolve_InvalidReference(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + _, err := resolver.Resolve(context.Background(), "not-akvs://x") + if err == nil { + t.Fatal("expected error for invalid reference") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonInvalidReference { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonInvalidReference) + } +} + +func TestResolve_ClientCreationFailure(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(nil, errors.New("connection refused")), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for client creation failure") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonClientCreation { + t.Errorf("Reason = %v, want %v", 
resolveErr.Reason, ResolveReasonClientCreation) + } +} + +func TestResolve_SecretNotFound(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusNotFound}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/missing-secret") + if err == nil { + t.Fatal("expected error for missing secret") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonNotFound { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonNotFound) + } +} + +func TestResolve_AccessDenied(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusForbidden}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for forbidden access") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonAccessDenied { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonAccessDenied) + } +} + +func TestResolve_Unauthorized(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusUnauthorized}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + 
t.Fatal("expected error for unauthorized access") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonAccessDenied { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonAccessDenied) + } +} + +func TestResolve_ServiceError(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusInternalServerError}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for server error") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonServiceError { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonServiceError) + } +} + +func TestResolve_NilValue(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + resp: azsecrets.GetSecretResponse{ + Secret: azsecrets.Secret{ + Value: nil, + }, + }, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for nil secret value") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonNotFound { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonNotFound) + } +} + +func TestResolve_NonResponseError(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: 
errors.New("network timeout"), + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for network failure") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + // Non-ResponseError defaults to access_denied + if resolveErr.Reason != ResolveReasonAccessDenied { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonAccessDenied) + } +} + +// --- ResolveMap --- + +func TestResolveMap_MixedValues(t *testing.T) { + t.Parallel() + + secretValue := "resolved-secret" + getter := &stubSecretGetter{ + resp: azsecrets.GetSecretResponse{ + Secret: azsecrets.Secret{ + Value: &secretValue, + }, + }, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + input := map[string]string{ + "plain": "hello-world", + "secret": "akvs://sub/vault/secret", + } + + result, err := resolver.ResolveMap(context.Background(), input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result["plain"] != "hello-world" { + t.Errorf("result[plain] = %q, want %q", result["plain"], "hello-world") + } + + if result["secret"] != secretValue { + t.Errorf("result[secret] = %q, want %q", result["secret"], secretValue) + } +} + +func TestResolveMap_Empty(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + result, err := resolver.ResolveMap(context.Background(), map[string]string{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result) != 0 { + t.Errorf("len(result) = %d, want 0", 
len(result)) + } +} + +func TestResolveMap_ErrorStopsProcessing(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusNotFound}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + input := map[string]string{ + "secret": "akvs://sub/vault/missing", + } + + _, err := resolver.ResolveMap(context.Background(), input) + if err == nil { + t.Fatal("expected error when resolution fails") + } +} + +func TestResolveMap_NilContext(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + //nolint:staticcheck // intentionally testing nil context + _, err := resolver.ResolveMap(nil, map[string]string{"k": "v"}) + if err == nil { + t.Fatal("expected error for nil context") + } +} + +// --- Error types --- + +func TestKeyVaultResolveError_Error(t *testing.T) { + t.Parallel() + + err := &KeyVaultResolveError{ + Reference: "akvs://sub/vault/secret", + Reason: ResolveReasonNotFound, + Err: errors.New("secret not found"), + } + + got := err.Error() + if got == "" { + t.Fatal("Error() returned empty string") + } +} + +func TestKeyVaultResolveError_Unwrap(t *testing.T) { + t.Parallel() + + inner := errors.New("inner error") + err := &KeyVaultResolveError{ + Reference: "akvs://sub/vault/secret", + Reason: ResolveReasonServiceError, + Err: inner, + } + + if !errors.Is(err, inner) { + t.Error("Unwrap should expose inner error via errors.Is") + } +} + +func TestResolveReason_String(t *testing.T) { + t.Parallel() + + tests := []struct { + reason ResolveReason + want string + }{ + {ResolveReasonInvalidReference, "invalid_reference"}, + {ResolveReasonClientCreation, "client_creation"}, + {ResolveReasonNotFound, "not_found"}, + {ResolveReasonAccessDenied, "access_denied"}, + 
{ResolveReasonServiceError, "service_error"}, + {ResolveReason(99), "unknown"}, + } + + for _, tt := range tests { + if got := tt.reason.String(); got != tt.want { + t.Errorf("ResolveReason(%d).String() = %q, want %q", tt.reason, got, tt.want) + } + } +} diff --git a/cli/azd/pkg/azdext/logger.go b/cli/azd/pkg/azdext/logger.go new file mode 100644 index 00000000000..d0d156ca565 --- /dev/null +++ b/cli/azd/pkg/azdext/logger.go @@ -0,0 +1,175 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "io" + "log/slog" + "os" + "strings" +) + +// LoggerOptions configures [SetupLogging] and [NewLogger]. +type LoggerOptions struct { + // Debug enables debug-level logging. When false, messages below Info are + // suppressed. If not set explicitly, [NewLogger] checks the AZD_DEBUG + // environment variable. + Debug bool + // Structured selects JSON output when true, human-readable text when false. + Structured bool + // Writer overrides the output destination. Defaults to os.Stderr. + Writer io.Writer +} + +// SetupLogging configures the process-wide default [slog.Logger]. +// It is typically called once at startup (for example from +// [NewExtensionRootCommand]'s PersistentPreRunE callback). +// +// Calling SetupLogging is optional — [NewLogger] works without it and creates +// loggers that inherit from [slog.Default]. SetupLogging is provided for +// extensions that want explicit control over the global log level and format. +func SetupLogging(opts LoggerOptions) { + handler := newHandler(opts) + slog.SetDefault(slog.New(handler)) +} + +// Logger provides component-scoped structured logging built on [log/slog]. +// +// Each Logger carries a "component" attribute so log lines can be filtered or +// routed by subsystem. Additional context can be attached via [Logger.With], +// [Logger.WithComponent], or [Logger.WithOperation]. 
+// +// Logger writes to stderr by default and never writes to stdout, so it does +// not interfere with command output or JSON-mode piping. +type Logger struct { + slogger *slog.Logger + component string +} + +// NewLogger creates a Logger scoped to the given component name. +// +// If the AZD_DEBUG environment variable is set to a truthy value ("1", "true", +// "yes") and opts.Debug is false, debug logging is enabled automatically. This +// lets extension authors respect the framework's debug flag without extra +// plumbing. +// +// When opts is omitted (zero value), the logger uses Info level with text +// format on stderr. +func NewLogger(component string, opts ...LoggerOptions) *Logger { + var o LoggerOptions + if len(opts) > 0 { + o = opts[0] + } + + // Auto-detect debug from environment when not explicitly set. + if !o.Debug { + o.Debug = isDebugEnv() + } + + handler := newHandler(o) + base := slog.New(handler).With("component", component) + + return &Logger{ + slogger: base, + component: component, + } +} + +// Component returns the component name this logger was created with. +func (l *Logger) Component() string { + return l.component +} + +// Debug logs a message at debug level with optional key-value pairs. +func (l *Logger) Debug(msg string, args ...any) { + l.slogger.Debug(msg, args...) +} + +// Info logs a message at info level with optional key-value pairs. +func (l *Logger) Info(msg string, args ...any) { + l.slogger.Info(msg, args...) +} + +// Warn logs a message at warn level with optional key-value pairs. +func (l *Logger) Warn(msg string, args ...any) { + l.slogger.Warn(msg, args...) +} + +// Error logs a message at error level with optional key-value pairs. +func (l *Logger) Error(msg string, args ...any) { + l.slogger.Error(msg, args...) +} + +// With returns a new Logger that includes the given key-value pairs in every +// subsequent log entry. Keys must be strings; values can be any type +// supported by [slog]. 
+// +// Example: +// +// l := logger.With("request_id", reqID) +// l.Info("processing") // includes component + request_id +func (l *Logger) With(args ...any) *Logger { + return &Logger{ + slogger: l.slogger.With(args...), + component: l.component, + } +} + +// WithComponent returns a new Logger with a different component name. The +// original component is preserved as "parent_component". +func (l *Logger) WithComponent(name string) *Logger { + return &Logger{ + slogger: l.slogger.With("parent_component", l.component, "component", name), + component: name, + } +} + +// WithOperation returns a new Logger with an "operation" attribute. +func (l *Logger) WithOperation(name string) *Logger { + return &Logger{ + slogger: l.slogger.With("operation", name), + component: l.component, + } +} + +// Slogger returns the underlying [*slog.Logger] for advanced use cases such +// as passing to libraries that accept a standard slog logger. +func (l *Logger) Slogger() *slog.Logger { + return l.slogger +} + +// --------------------------------------------------------------------------- +// Internal helpers +// --------------------------------------------------------------------------- + +// newHandler creates an slog.Handler from LoggerOptions. +func newHandler(opts LoggerOptions) slog.Handler { + w := opts.Writer + if w == nil { + w = os.Stderr + } + + level := slog.LevelInfo + if opts.Debug { + level = slog.LevelDebug + } + + handlerOpts := &slog.HandlerOptions{Level: level} + + if opts.Structured { + return slog.NewJSONHandler(w, handlerOpts) + } + return slog.NewTextHandler(w, handlerOpts) +} + +// isDebugEnv checks the AZD_DEBUG environment variable. +// +// Security note: AZD_DEBUG enables verbose logging that may include +// request details, configuration paths, and internal state. It should +// NOT be enabled in production deployments. The variable is intended +// for local development and CI debugging only. 
+func isDebugEnv() bool { + v := strings.ToLower(os.Getenv("AZD_DEBUG")) + return v == "1" || v == "true" || v == "yes" +} diff --git a/cli/azd/pkg/azdext/logger_test.go b/cli/azd/pkg/azdext/logger_test.go new file mode 100644 index 00000000000..9e21b763ada --- /dev/null +++ b/cli/azd/pkg/azdext/logger_test.go @@ -0,0 +1,295 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "bytes" + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// NewLogger — basic construction +// --------------------------------------------------------------------------- + +func TestNewLogger_DefaultOptions(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("test-component", LoggerOptions{Writer: &buf}) + + require.NotNil(t, logger) + require.Equal(t, "test-component", logger.Component()) +} + +func TestNewLogger_ZeroOptions(t *testing.T) { + // Zero-value opts should not panic (writes to stderr). + logger := NewLogger("safe") + require.NotNil(t, logger) + require.Equal(t, "safe", logger.Component()) +} + +func TestNewLogger_NoOpts(t *testing.T) { + // Calling without variadic opts should not panic. 
+ logger := NewLogger("minimal") + require.NotNil(t, logger) +} + +// --------------------------------------------------------------------------- +// Log levels — Info +// --------------------------------------------------------------------------- + +func TestLogger_Info(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("mycomp", LoggerOptions{Writer: &buf}) + + logger.Info("hello world", "key", "val") + + output := buf.String() + require.Contains(t, output, "hello world") + require.Contains(t, output, "key=val") + require.Contains(t, output, "component=mycomp") +} + +// --------------------------------------------------------------------------- +// Log levels — Debug +// --------------------------------------------------------------------------- + +func TestLogger_Debug_Enabled(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("dbg", LoggerOptions{Debug: true, Writer: &buf}) + + logger.Debug("debug message", "detail", "x") + + require.Contains(t, buf.String(), "debug message") + require.Contains(t, buf.String(), "detail=x") +} + +func TestLogger_Debug_Disabled(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("dbg", LoggerOptions{Debug: false, Writer: &buf}) + + logger.Debug("should not appear") + + require.Empty(t, buf.String()) +} + +func TestLogger_Debug_AZD_DEBUG_EnvVar(t *testing.T) { + t.Setenv("AZD_DEBUG", "true") + + var buf bytes.Buffer + logger := NewLogger("env-debug", LoggerOptions{Writer: &buf}) + + logger.Debug("from env var") + + require.Contains(t, buf.String(), "from env var") +} + +func TestLogger_Debug_AZD_DEBUG_EnvVar_One(t *testing.T) { + t.Setenv("AZD_DEBUG", "1") + + var buf bytes.Buffer + logger := NewLogger("env-one", LoggerOptions{Writer: &buf}) + + logger.Debug("debug via 1") + + require.Contains(t, buf.String(), "debug via 1") +} + +func TestLogger_Debug_AZD_DEBUG_EnvVar_Yes(t *testing.T) { + t.Setenv("AZD_DEBUG", "yes") + + var buf bytes.Buffer + logger := NewLogger("env-yes", LoggerOptions{Writer: 
&buf}) + + logger.Debug("debug via yes") + + require.Contains(t, buf.String(), "debug via yes") +} + +func TestLogger_Debug_AZD_DEBUG_Unset(t *testing.T) { + t.Setenv("AZD_DEBUG", "") + + var buf bytes.Buffer + logger := NewLogger("env-empty", LoggerOptions{Writer: &buf}) + + logger.Debug("hidden") + + require.Empty(t, buf.String()) +} + +// --------------------------------------------------------------------------- +// Log levels — Warn / Error +// --------------------------------------------------------------------------- + +func TestLogger_Warn(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("warn-test", LoggerOptions{Writer: &buf}) + + logger.Warn("something concerning", "retries", 3) + + require.Contains(t, buf.String(), "something concerning") + require.Contains(t, buf.String(), "retries=3") +} + +func TestLogger_Error(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("err-test", LoggerOptions{Writer: &buf}) + + logger.Error("bad thing happened", "code", 500) + + require.Contains(t, buf.String(), "bad thing happened") + require.Contains(t, buf.String(), "code=500") +} + +// --------------------------------------------------------------------------- +// Structured (JSON) output +// --------------------------------------------------------------------------- + +func TestLogger_StructuredJSON(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("json-comp", LoggerOptions{Structured: true, Writer: &buf}) + + logger.Info("structured entry", "env", "prod") + + // Each line should be valid JSON. 
+ var parsed map[string]any + err := json.Unmarshal(buf.Bytes(), &parsed) + require.NoError(t, err) + require.Equal(t, "structured entry", parsed["msg"]) + require.Equal(t, "prod", parsed["env"]) + require.Equal(t, "json-comp", parsed["component"]) +} + +// --------------------------------------------------------------------------- +// Context chaining — With +// --------------------------------------------------------------------------- + +func TestLogger_With(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("base", LoggerOptions{Writer: &buf}) + + child := logger.With("request_id", "abc-123") + child.Info("processing") + + output := buf.String() + require.Contains(t, output, "request_id=abc-123") + require.Contains(t, output, "component=base") + require.Contains(t, output, "processing") + + // Child should have the same component. + require.Equal(t, "base", child.Component()) +} + +func TestLogger_With_ChainMultiple(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("chain", LoggerOptions{Writer: &buf}) + + child := logger.With("a", "1").With("b", "2") + child.Info("chained") + + output := buf.String() + require.Contains(t, output, "a=1") + require.Contains(t, output, "b=2") +} + +// --------------------------------------------------------------------------- +// Context chaining — WithComponent +// --------------------------------------------------------------------------- + +func TestLogger_WithComponent(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("parent", LoggerOptions{Structured: true, Writer: &buf}) + + child := logger.WithComponent("child-subsystem") + child.Info("from child") + + require.Equal(t, "child-subsystem", child.Component()) + + var parsed map[string]any + err := json.Unmarshal(buf.Bytes(), &parsed) + require.NoError(t, err) + require.Equal(t, "child-subsystem", parsed["component"]) + require.Equal(t, "parent", parsed["parent_component"]) +} + +// 
--------------------------------------------------------------------------- +// Context chaining — WithOperation +// --------------------------------------------------------------------------- + +func TestLogger_WithOperation(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("ops", LoggerOptions{Writer: &buf}) + + child := logger.WithOperation("deploy") + child.Info("starting deploy") + + output := buf.String() + require.Contains(t, output, "operation=deploy") + require.Equal(t, "ops", child.Component()) +} + +// --------------------------------------------------------------------------- +// Slogger accessor +// --------------------------------------------------------------------------- + +func TestLogger_Slogger(t *testing.T) { + logger := NewLogger("access", LoggerOptions{Writer: &bytes.Buffer{}}) + require.NotNil(t, logger.Slogger()) +} + +// --------------------------------------------------------------------------- +// SetupLogging — global logger configuration +// --------------------------------------------------------------------------- + +func TestSetupLogging_DoesNotPanic(t *testing.T) { + // SetupLogging modifies slog.Default which is global state. + // We only verify it does not panic here. + var buf bytes.Buffer + SetupLogging(LoggerOptions{Debug: true, Structured: true, Writer: &buf}) + + // Restore a sensible default after the test. 
+ SetupLogging(LoggerOptions{Writer: &bytes.Buffer{}}) +} + +// --------------------------------------------------------------------------- +// isDebugEnv internal helper +// --------------------------------------------------------------------------- + +func TestIsDebugEnv_Truthy(t *testing.T) { + truthy := []string{"1", "true", "TRUE", "True", "yes", "YES", "Yes"} + for _, v := range truthy { + t.Run(v, func(t *testing.T) { + t.Setenv("AZD_DEBUG", v) + require.True(t, isDebugEnv()) + }) + } +} + +func TestIsDebugEnv_Falsy(t *testing.T) { + falsy := []string{"", "0", "false", "no", "maybe"} + for _, v := range falsy { + t.Run("value="+v, func(t *testing.T) { + t.Setenv("AZD_DEBUG", v) + require.False(t, isDebugEnv()) + }) + } +} + +// --------------------------------------------------------------------------- +// Text format verification +// --------------------------------------------------------------------------- + +func TestLogger_TextFormat_ContainsLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("lvl", LoggerOptions{Writer: &buf}) + + logger.Info("test level") + + output := buf.String() + require.True(t, + strings.Contains(output, "INFO") || strings.Contains(output, "level=INFO"), + "expected level indicator in text output: %s", output) +} diff --git a/cli/azd/pkg/azdext/mcp_security.go b/cli/azd/pkg/azdext/mcp_security.go index c9c2122eaeb..a0e715d25e0 100644 --- a/cli/azd/pkg/azdext/mcp_security.go +++ b/cli/azd/pkg/azdext/mcp_security.go @@ -6,6 +6,7 @@ package azdext import ( "fmt" "net" + "net/http" "net/url" "os" "path/filepath" @@ -25,6 +26,10 @@ type MCPSecurityPolicy struct { blockedHosts map[string]bool // lookupHost is used for DNS resolution; override in tests. lookupHost func(string) ([]string, error) + // onBlocked is an optional callback invoked when a URL or path is blocked. + // Parameters: action ("url_blocked", "path_blocked", "redirect_blocked"), + // detail (human-readable explanation). Safe for concurrent use. 
+	onBlocked func(action, detail string)
 }
 
 // NewMCPSecurityPolicy creates an empty security policy.
@@ -111,6 +116,20 @@ func (p *MCPSecurityPolicy) ValidatePathsWithinBase(basePaths ...string) *MCPSec
+	return p
+}
+
+// OnBlocked registers a callback that is invoked whenever a URL, path, or
+// redirect is blocked by the security policy. This enables security audit
+// logging without coupling the policy to a specific logging framework.
+//
+// The callback receives an action tag ("url_blocked", "path_blocked",
+// "redirect_blocked") and a human-readable detail string. It must be safe
+// for concurrent invocation.
+func (p *MCPSecurityPolicy) OnBlocked(fn func(action, detail string)) *MCPSecurityPolicy {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	p.onBlocked = fn
+	return p
+}
+
 // isLocalhostHost returns true if the host is localhost or a loopback address.
 func isLocalhostHost(host string) bool {
 	h := strings.ToLower(host)
@@ -125,8 +144,16 @@ func isLocalhostHost(host string) bool {
 // Returns an error describing the violation, or nil if allowed.
 func (p *MCPSecurityPolicy) CheckURL(rawURL string) error {
 	p.mu.RLock()
-	defer p.mu.RUnlock()
+	onBlocked := p.onBlocked
+	err := p.checkURLCore(rawURL)
+	p.mu.RUnlock()
+	// The callback fires after the lock is released, so a callback that
+	// re-enters the policy (e.g. calls CheckURL again) cannot deadlock.
+	if err != nil && onBlocked != nil {
+		onBlocked("url_blocked", err.Error())
+	}
+	return err
+}
 
+// checkURLCore holds the validation logic; callers must hold p.mu (read).
+func (p *MCPSecurityPolicy) checkURLCore(rawURL string) error {
 	u, err := url.Parse(rawURL)
 	if err != nil {
 		return fmt.Errorf("invalid URL: %w", err)
@@ -153,6 +180,13 @@ func (p *MCPSecurityPolicy) CheckURL(rawURL string) error {
 
 	// If the host is an IP literal, check it directly against blocked CIDRs.
 	if ip := net.ParseIP(host); ip != nil {
+		// Check normalized IP form against blocked hosts — catches IPv4-mapped
+		// IPv6 forms like ::ffff:169.254.169.254 that bypass string matching.
+		if normalizedIP := ip.String(); normalizedIP != host {
+			if p.blockedHosts[strings.ToLower(normalizedIP)] {
+				return fmt.Errorf("blocked host: %s (normalized: %s)", host, normalizedIP)
+			}
+		}
 		if err := p.checkIP(ip, host); err != nil {
 			return err
 		}
@@ -251,10 +285,33 @@ func (p *MCPSecurityPolicy) checkIP(ip net.IP, originalHost string) error {
 
 // CheckPath validates a file path against the security policy.
 // Resolves symlinks and checks for directory traversal.
+//
+// Security note (TOCTOU): There is an inherent time-of-check to time-of-use
+// gap between the symlink resolution performed here and the caller's
+// subsequent file operation. An adversary with write access to the filesystem
+// could create or modify a symlink between the check and the use. This is a
+// fundamental limitation of path-based validation on POSIX systems.
+//
+// Mitigations callers should consider:
+//   - Use O_NOFOLLOW when opening files after validation (prevents symlink
+//     following at the final component).
+//   - Use file-descriptor-based approaches (openat2 with RESOLVE_BENEATH on
+//     Linux 5.6+) where possible.
+//   - Avoid writing to directories that untrusted users can modify.
+//   - Consider validating the opened fd's path post-open via /proc/self/fd/N
+//     or fstat.
 func (p *MCPSecurityPolicy) CheckPath(path string) error {
 	p.mu.RLock()
-	defer p.mu.RUnlock()
+	onBlocked := p.onBlocked
+	err := p.checkPathCore(path)
+	p.mu.RUnlock()
+	// Invoked outside the lock — see the matching note in CheckURL.
+	if err != nil && onBlocked != nil {
+		onBlocked("path_blocked", err.Error())
+	}
+	return err
+}
 
+// checkPathCore holds the validation logic; callers must hold p.mu (read).
+func (p *MCPSecurityPolicy) checkPathCore(path string) error {
 	if len(p.allowedBasePaths) == 0 {
 		return nil
 	}
@@ -348,3 +405,86 @@ func resolveExistingPrefix(p string) string {
 		}
 	}
 }
+
+// ---------------------------------------------------------------------------
+// Redirect SSRF protection
+// ---------------------------------------------------------------------------
+
+// redirectBlockedHosts lists cloud metadata service endpoints that must never
+// be the target of an HTTP redirect.
+var redirectBlockedHosts = map[string]bool{
+	"169.254.169.254":          true,
+	"fd00:ec2::254":            true,
+	"metadata.google.internal": true,
+	"100.100.100.200":          true,
+}
+
+// SSRFSafeRedirect is an [http.Client] CheckRedirect function that blocks
+// redirects to private networks and cloud metadata endpoints. It prevents
+// redirect-based SSRF attacks where an attacker-controlled URL redirects to
+// an internal service.
+//
+// Hostnames in redirect targets are resolved via DNS and all resulting IPs
+// are checked. DNS resolution failures are treated as blocked (fail-closed)
+// to prevent bypass via transient DNS errors or rebinding attacks.
+//
+// NOTE(review): this function uses net.LookupHost directly and cannot reuse
+// MCPSecurityPolicy.lookupHost, so the hostname branch has no test seam —
+// consider plumbing a resolver hook; confirm whether that is intended.
+//
+// Usage:
+//
+//	client := &http.Client{CheckRedirect: azdext.SSRFSafeRedirect}
+func SSRFSafeRedirect(req *http.Request, via []*http.Request) error {
+	const maxRedirects = 10
+	if len(via) >= maxRedirects {
+		return fmt.Errorf("stopped after %d redirects", maxRedirects)
+	}
+
+	host := req.URL.Hostname()
+	lowerHost := strings.ToLower(host)
+
+	// Block redirects to known metadata endpoints (string match).
+	if redirectBlockedHosts[lowerHost] {
+		return fmt.Errorf("redirect to metadata endpoint %s blocked (SSRF protection)", host)
+	}
+
+	if ip := net.ParseIP(host); ip != nil {
+		// Check normalized form against metadata list — catches IPv4-mapped IPv6
+		// forms like [::ffff:169.254.169.254] that bypass string matching.
+		if normalizedIP := ip.String(); redirectBlockedHosts[strings.ToLower(normalizedIP)] {
+			return fmt.Errorf("redirect to metadata endpoint %s blocked (SSRF protection)", host)
+		}
+		if err := checkRedirectIP(ip, host); err != nil {
+			return err
+		}
+	} else {
+		// Hostname — resolve and check all IPs (fail-closed on DNS failure).
+		addrs, err := net.LookupHost(host)
+		if err != nil {
+			return fmt.Errorf(
+				"redirect to %s blocked: DNS resolution failed (fail-closed, SSRF protection)", host,
+			)
+		}
+		for _, addr := range addrs {
+			if redirectBlockedHosts[strings.ToLower(addr)] {
+				return fmt.Errorf(
+					"redirect to %s blocked: resolves to metadata endpoint %s (SSRF protection)",
+					host, addr,
+				)
+			}
+			if resolvedIP := net.ParseIP(addr); resolvedIP != nil {
+				if err := checkRedirectIP(resolvedIP, host); err != nil {
+					return err
+				}
+			}
+		}
+	}
+
+	return nil
+}
+
+// checkRedirectIP checks whether an IP is in a private, loopback, link-local,
+// or unspecified range. Used by [SSRFSafeRedirect] to block redirect-based SSRF.
+//
+// NOTE(review): net.IP.IsPrivate covers RFC 1918 / RFC 4193 only — shared
+// address space (100.64.0.0/10, CGNAT) passes this check unless the exact IP
+// appears in redirectBlockedHosts; confirm that gap is acceptable.
+func checkRedirectIP(ip net.IP, host string) error {
+	if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsUnspecified() {
+		return fmt.Errorf("redirect to private/loopback IP %s blocked (SSRF protection, host: %s)", ip, host)
+	}
+	return nil
+}
diff --git a/cli/azd/pkg/azdext/mcp_security_test.go b/cli/azd/pkg/azdext/mcp_security_test.go
index 6cc0d1d44da..f962cf5f119 100644
--- a/cli/azd/pkg/azdext/mcp_security_test.go
+++ b/cli/azd/pkg/azdext/mcp_security_test.go
@@ -5,6 +5,8 @@ package azdext
 
 import (
 	"fmt"
+	"net/http"
+	"net/url"
 	"os"
 	"path/filepath"
 	"strings"
@@ -19,6 +21,9 @@ func TestMCPSecurityCheckURL_BlocksMetadataEndpoints(t *testing.T) {
 		"http://fd00:ec2::254/latest/meta-data/",
 		"http://metadata.google.internal/computeMetadata/v1/",
 		"http://100.100.100.200/latest/meta-data/",
+		// IPv4-mapped forms of metadata IPs — must be caught by IP normalization.
+		"http://[::ffff:169.254.169.254]/latest/meta-data/",
+		"http://[::ffff:100.100.100.200]/latest/meta-data/",
 	}
 	for _, u := range blocked {
 		if err := policy.CheckURL(u); err == nil {
@@ -27,6 +32,29 @@
 	}
 }
 
+func TestSSRFSafeRedirect_BlocksPrivateHostnames(t *testing.T) {
+	t.Parallel()
+	// SSRFSafeRedirect resolves hostnames through the process resolver, so
+	// hostname-based cases are not deterministic here; this table therefore
+	// exercises IP-literal redirect targets only (metadata, loopback, private).
+ tests := []struct { + host string + blocked bool + }{ + {"169.254.169.254", true}, + {"127.0.0.1", true}, + {"10.0.0.1", true}, + } + for _, tc := range tests { + req := &http.Request{URL: mustParseURL(t, "http://"+tc.host+"/path")} + err := SSRFSafeRedirect(req, nil) + if tc.blocked && err == nil { + t.Errorf("SSRFSafeRedirect(%s) = nil, want error", tc.host) + } + } +} + func TestMCPSecurityCheckURL_BlocksPrivateIPs(t *testing.T) { policy := NewMCPSecurityPolicy().BlockPrivateNetworks() @@ -335,3 +363,12 @@ func TestMCPSecurityFluentBuilder(t *testing.T) { t.Errorf("expected 1 base path, got %d", len(policy.allowedBasePaths)) } } + +func mustParseURL(t *testing.T, rawURL string) *url.URL { + t.Helper() + u, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("mustParseURL(%q): %v", rawURL, err) + } + return u +} diff --git a/cli/azd/pkg/azdext/output.go b/cli/azd/pkg/azdext/output.go new file mode 100644 index 00000000000..80c305d4712 --- /dev/null +++ b/cli/azd/pkg/azdext/output.go @@ -0,0 +1,288 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "encoding/json" + "fmt" + "io" + "os" + "strings" + + "github.com/fatih/color" +) + +// OutputFormat represents the output format for extension commands. +type OutputFormat string + +const ( + // OutputFormatDefault is human-readable text with optional color. + OutputFormatDefault OutputFormat = "default" + // OutputFormatJSON outputs structured JSON for machine consumption. + OutputFormatJSON OutputFormat = "json" +) + +// ParseOutputFormat converts a string to an OutputFormat. +// Returns OutputFormatDefault for unrecognized values and a non-nil error. 
+func ParseOutputFormat(s string) (OutputFormat, error) { + switch strings.ToLower(s) { + case "default", "": + return OutputFormatDefault, nil + case "json": + return OutputFormatJSON, nil + default: + return OutputFormatDefault, fmt.Errorf("invalid output format %q (valid: default, json)", s) + } +} + +// OutputOptions configures an [Output] instance. +type OutputOptions struct { + // Format controls the output style. Defaults to OutputFormatDefault. + Format OutputFormat + // Writer is the destination for normal output. Defaults to os.Stdout. + Writer io.Writer + // ErrWriter is the destination for error/warning output. Defaults to os.Stderr. + ErrWriter io.Writer +} + +// Output provides formatted, format-aware output for extension commands. +// In default mode it writes human-readable text with ANSI color; in JSON mode +// it writes structured JSON objects to stdout and suppresses decorative output. +// +// Output is safe for use from a single goroutine. If concurrent use is needed +// callers should synchronize externally. +type Output struct { + writer io.Writer + errWriter io.Writer + format OutputFormat + + // Color printers — configured once at construction. + successColor *color.Color + warningColor *color.Color + errorColor *color.Color + infoColor *color.Color + headerColor *color.Color + dimColor *color.Color +} + +// NewOutput creates an Output configured by opts. +// If opts.Writer or opts.ErrWriter are nil they default to os.Stdout / os.Stderr. 
+func NewOutput(opts OutputOptions) *Output { + w := opts.Writer + if w == nil { + w = os.Stdout + } + ew := opts.ErrWriter + if ew == nil { + ew = os.Stderr + } + + return &Output{ + writer: w, + errWriter: ew, + format: opts.Format, + successColor: color.New(color.FgGreen), + warningColor: color.New(color.FgYellow), + errorColor: color.New(color.FgRed), + infoColor: color.New(color.FgCyan), + headerColor: color.New(color.Bold), + dimColor: color.New(color.Faint), + } +} + +// IsJSON returns true when the output format is JSON. +// Callers can use this to skip decorative output that is only relevant in +// human-readable mode. +func (o *Output) IsJSON() bool { + return o.format == OutputFormatJSON +} + +// Success prints a success message prefixed with a green check mark. +// In JSON mode the call is a no-op (use [Output.JSON] for structured data). +func (o *Output) Success(format string, args ...any) { + if o.IsJSON() { + return + } + msg := sanitizeOutputText(fmt.Sprintf(format, args...)) + o.successColor.Fprintf(o.writer, "(✓) Done: %s\n", msg) +} + +// Warning prints a warning message prefixed with a yellow exclamation mark. +// Warnings are written to ErrWriter in both default and JSON mode so they are +// visible even when stdout is piped through a JSON consumer. +func (o *Output) Warning(format string, args ...any) { + msg := fmt.Sprintf(format, args...) + if o.IsJSON() { + // In JSON mode emit a structured warning to stderr. + _ = json.NewEncoder(o.errWriter).Encode(map[string]string{ + "level": "warning", + "message": msg, + }) + return + } + o.warningColor.Fprintf(o.errWriter, "(!) Warning: %s\n", sanitizeOutputText(msg)) +} + +// Error prints an error message prefixed with a red cross. +// Errors are always written to ErrWriter. +func (o *Output) Error(format string, args ...any) { + msg := fmt.Sprintf(format, args...) 
+ if o.IsJSON() { + _ = json.NewEncoder(o.errWriter).Encode(map[string]string{ + "level": "error", + "message": msg, + }) + return + } + o.errorColor.Fprintf(o.errWriter, "(✗) Error: %s\n", sanitizeOutputText(msg)) +} + +// Info prints an informational message prefixed with an info symbol. +// In JSON mode the call is a no-op (use [Output.JSON] for structured data). +func (o *Output) Info(format string, args ...any) { + if o.IsJSON() { + return + } + msg := sanitizeOutputText(fmt.Sprintf(format, args...)) + o.infoColor.Fprintf(o.writer, "(i) %s\n", msg) +} + +// Message prints an undecorated message to stdout. +// In JSON mode the call is a no-op. +func (o *Output) Message(format string, args ...any) { + if o.IsJSON() { + return + } + msg := sanitizeOutputText(fmt.Sprintf(format, args...)) + fmt.Fprintln(o.writer, msg) +} + +// JSON writes data as a pretty-printed JSON object to stdout. +// It is active in all output modes so callers can unconditionally emit +// structured payloads (in default mode the JSON is still human-readable). +func (o *Output) JSON(data any) error { + enc := json.NewEncoder(o.writer) + enc.SetIndent("", " ") + if err := enc.Encode(data); err != nil { + return fmt.Errorf("output: failed to encode JSON: %w", err) + } + return nil +} + +// Table prints a formatted text table with headers and rows. +// In JSON mode the table is emitted as a JSON array of objects instead. +// +// headers defines the column names. Each row is a slice of cell values +// with the same length as headers. Rows with fewer cells are padded with +// empty strings; extra cells are silently ignored. +func (o *Output) Table(headers []string, rows [][]string) { + if len(headers) == 0 { + return + } + + if o.IsJSON() { + o.tableJSON(headers, rows) + return + } + + o.tableText(headers, rows) +} + +// tableJSON emits the table as a JSON array of objects keyed by header name. 
+func (o *Output) tableJSON(headers []string, rows [][]string) { + out := make([]map[string]string, 0, len(rows)) + for _, row := range rows { + obj := make(map[string]string, len(headers)) + for i, h := range headers { + if i < len(row) { + obj[h] = row[i] + } else { + obj[h] = "" + } + } + out = append(out, obj) + } + _ = o.JSON(out) +} + +// tableText renders an aligned text table with a header separator. +func (o *Output) tableText(headers []string, rows [][]string) { + // Calculate column widths. + widths := make([]int, len(headers)) + for i, h := range headers { + widths[i] = len(h) + } + for _, row := range rows { + for i := range headers { + if i < len(row) && len(row[i]) > widths[i] { + widths[i] = len(row[i]) + } + } + } + + // Print header row. + for i, h := range headers { + if i > 0 { + fmt.Fprint(o.writer, " ") + } + o.headerColor.Fprintf(o.writer, "%-*s", widths[i], h) + } + fmt.Fprintln(o.writer) + + // Print separator. + for i, w := range widths { + if i > 0 { + fmt.Fprint(o.writer, " ") + } + fmt.Fprint(o.writer, strings.Repeat("─", w)) + } + fmt.Fprintln(o.writer) + + // Print data rows. + for _, row := range rows { + for i := range headers { + if i > 0 { + fmt.Fprint(o.writer, " ") + } + cell := "" + if i < len(row) { + cell = sanitizeOutputText(row[i]) + } + fmt.Fprintf(o.writer, "%-*s", widths[i], cell) + } + fmt.Fprintln(o.writer) + } +} + +// sanitizeOutputText replaces CR, LF, and other ASCII control characters +// (except tab) with a space to prevent log forging and terminal escape +// sequence injection in text-mode output. +// +// JSON-mode output is not sanitized here because JSON encoding handles +// escaping of control characters and the structured fields must remain +// machine-readable. +func sanitizeOutputText(s string) string { + // Fast path: if no control characters are present, return as-is. 
+ clean := true + for _, r := range s { + if (r < 0x20 && r != '\t') || r == 0x7F { + clean = false + break + } + } + if clean { + return s + } + + var b strings.Builder + b.Grow(len(s)) + for _, r := range s { + if (r < 0x20 && r != '\t') || r == 0x7F { + b.WriteRune(' ') + } else { + b.WriteRune(r) + } + } + return b.String() +} diff --git a/cli/azd/pkg/azdext/output_test.go b/cli/azd/pkg/azdext/output_test.go new file mode 100644 index 00000000000..4a31fe7ccae --- /dev/null +++ b/cli/azd/pkg/azdext/output_test.go @@ -0,0 +1,358 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "bytes" + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// ParseOutputFormat +// --------------------------------------------------------------------------- + +func TestParseOutputFormat(t *testing.T) { + tests := []struct { + name string + input string + expected OutputFormat + expectErr bool + }{ + {name: "default string", input: "default", expected: OutputFormatDefault}, + {name: "empty string", input: "", expected: OutputFormatDefault}, + {name: "json lowercase", input: "json", expected: OutputFormatJSON}, + {name: "JSON uppercase", input: "JSON", expected: OutputFormatJSON}, + {name: "Json mixed case", input: "Json", expected: OutputFormatJSON}, + {name: "DEFAULT uppercase", input: "DEFAULT", expected: OutputFormatDefault}, + {name: "invalid format", input: "xml", expected: OutputFormatDefault, expectErr: true}, + {name: "invalid format yaml", input: "yaml", expected: OutputFormatDefault, expectErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseOutputFormat(tt.input) + if tt.expectErr { + require.Error(t, err) + require.Contains(t, err.Error(), "invalid output format") + } else { + require.NoError(t, err) + } + require.Equal(t, 
tt.expected, got) + }) + } +} + +// --------------------------------------------------------------------------- +// NewOutput defaults +// --------------------------------------------------------------------------- + +func TestNewOutput_DefaultWriters(t *testing.T) { + out := NewOutput(OutputOptions{}) + require.NotNil(t, out) + // Default format should be "default" (zero-value of OutputFormat). + require.False(t, out.IsJSON()) +} + +func TestNewOutput_JSONMode(t *testing.T) { + out := NewOutput(OutputOptions{Format: OutputFormatJSON}) + require.True(t, out.IsJSON()) +} + +// --------------------------------------------------------------------------- +// Success +// --------------------------------------------------------------------------- + +func TestOutput_Success_DefaultFormat(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + out.Success("deployed %s", "myapp") + + // Should contain the message text (color codes may wrap it). + require.Contains(t, buf.String(), "Done: deployed myapp") +} + +func TestOutput_Success_JSONFormat_IsNoop(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf, Format: OutputFormatJSON}) + + out.Success("should not appear") + + require.Empty(t, buf.String()) +} + +// --------------------------------------------------------------------------- +// Warning +// --------------------------------------------------------------------------- + +func TestOutput_Warning_DefaultFormat(t *testing.T) { + var errBuf bytes.Buffer + out := NewOutput(OutputOptions{ErrWriter: &errBuf}) + + out.Warning("deprecated %s", "v1") + + require.Contains(t, errBuf.String(), "Warning: deprecated v1") +} + +func TestOutput_Warning_JSONFormat(t *testing.T) { + var errBuf bytes.Buffer + out := NewOutput(OutputOptions{ErrWriter: &errBuf, Format: OutputFormatJSON}) + + out.Warning("api deprecated") + + var parsed map[string]string + err := json.Unmarshal(errBuf.Bytes(), &parsed) + require.NoError(t, 
err) + require.Equal(t, "warning", parsed["level"]) + require.Equal(t, "api deprecated", parsed["message"]) +} + +// --------------------------------------------------------------------------- +// Error +// --------------------------------------------------------------------------- + +func TestOutput_Error_DefaultFormat(t *testing.T) { + var errBuf bytes.Buffer + out := NewOutput(OutputOptions{ErrWriter: &errBuf}) + + out.Error("connection failed: %s", "timeout") + + require.Contains(t, errBuf.String(), "Error: connection failed: timeout") +} + +func TestOutput_Error_JSONFormat(t *testing.T) { + var errBuf bytes.Buffer + out := NewOutput(OutputOptions{ErrWriter: &errBuf, Format: OutputFormatJSON}) + + out.Error("disk full") + + var parsed map[string]string + err := json.Unmarshal(errBuf.Bytes(), &parsed) + require.NoError(t, err) + require.Equal(t, "error", parsed["level"]) + require.Equal(t, "disk full", parsed["message"]) +} + +// --------------------------------------------------------------------------- +// Info +// --------------------------------------------------------------------------- + +func TestOutput_Info_DefaultFormat(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + out.Info("fetching %d items", 5) + + require.Contains(t, buf.String(), "fetching 5 items") +} + +func TestOutput_Info_JSONFormat_IsNoop(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf, Format: OutputFormatJSON}) + + out.Info("hidden") + + require.Empty(t, buf.String()) +} + +// --------------------------------------------------------------------------- +// Message +// --------------------------------------------------------------------------- + +func TestOutput_Message_DefaultFormat(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + out.Message("plain text %d", 42) + + require.Equal(t, "plain text 42\n", buf.String()) +} + +func TestOutput_Message_JSONFormat_IsNoop(t 
*testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf, Format: OutputFormatJSON}) + + out.Message("should not appear") + + require.Empty(t, buf.String()) +} + +// --------------------------------------------------------------------------- +// JSON +// --------------------------------------------------------------------------- + +func TestOutput_JSON_Struct(t *testing.T) { + type result struct { + Name string `json:"name"` + Count int `json:"count"` + } + + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + err := out.JSON(result{Name: "test", Count: 7}) + require.NoError(t, err) + + var decoded result + require.NoError(t, json.Unmarshal(buf.Bytes(), &decoded)) + require.Equal(t, "test", decoded.Name) + require.Equal(t, 7, decoded.Count) +} + +func TestOutput_JSON_Map(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + err := out.JSON(map[string]string{"key": "value"}) + require.NoError(t, err) + + var decoded map[string]string + require.NoError(t, json.Unmarshal(buf.Bytes(), &decoded)) + require.Equal(t, "value", decoded["key"]) +} + +func TestOutput_JSON_Unmarshalable(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + err := out.JSON(make(chan int)) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to encode JSON") +} + +func TestOutput_JSON_PrettyPrinted(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + err := out.JSON(map[string]int{"a": 1}) + require.NoError(t, err) + + // Verify indentation is present (pretty-printed). 
+ require.Contains(t, buf.String(), " ") +} + +// --------------------------------------------------------------------------- +// Table — default format +// --------------------------------------------------------------------------- + +func TestOutput_Table_DefaultFormat(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + headers := []string{"Name", "Status"} + rows := [][]string{ + {"api", "running"}, + {"web", "stopped"}, + } + + out.Table(headers, rows) + + text := buf.String() + require.Contains(t, text, "Name") + require.Contains(t, text, "Status") + require.Contains(t, text, "api") + require.Contains(t, text, "running") + require.Contains(t, text, "web") + require.Contains(t, text, "stopped") + + // Separator line should be present. + require.Contains(t, text, "─") +} + +func TestOutput_Table_EmptyHeaders(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + out.Table(nil, [][]string{{"a"}}) + + require.Empty(t, buf.String()) +} + +func TestOutput_Table_EmptyRows(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + out.Table([]string{"Name"}, nil) + + // Header + separator should still be printed. + text := buf.String() + require.Contains(t, text, "Name") + require.Contains(t, text, "─") +} + +func TestOutput_Table_ShortRow(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + // Row has fewer cells than headers — should pad with empty strings. + out.Table([]string{"A", "B", "C"}, [][]string{{"only-a"}}) + + text := buf.String() + require.Contains(t, text, "only-a") + // No panic from short row. 
+ lines := strings.Split(strings.TrimSpace(text), "\n") + require.Len(t, lines, 3) // header + separator + 1 data row +} + +func TestOutput_Table_ColumnAlignment(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + headers := []string{"ID", "LongerName"} + rows := [][]string{ + {"1", "short"}, + {"2", "a-much-longer-value"}, + } + + out.Table(headers, rows) + + lines := strings.Split(strings.TrimSpace(buf.String()), "\n") + require.GreaterOrEqual(t, len(lines), 3) + + // All separator dashes should align with header width. + sepLine := lines[1] + require.NotEmpty(t, sepLine) +} + +// --------------------------------------------------------------------------- +// Table — JSON format +// --------------------------------------------------------------------------- + +func TestOutput_Table_JSONFormat(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf, Format: OutputFormatJSON}) + + headers := []string{"Service", "Port"} + rows := [][]string{ + {"api", "8080"}, + {"web", "3000"}, + } + + out.Table(headers, rows) + + var decoded []map[string]string + require.NoError(t, json.Unmarshal(buf.Bytes(), &decoded)) + require.Len(t, decoded, 2) + require.Equal(t, "api", decoded[0]["Service"]) + require.Equal(t, "8080", decoded[0]["Port"]) + require.Equal(t, "web", decoded[1]["Service"]) + require.Equal(t, "3000", decoded[1]["Port"]) +} + +func TestOutput_Table_JSONFormat_ShortRow(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf, Format: OutputFormatJSON}) + + out.Table([]string{"A", "B"}, [][]string{{"only-a"}}) + + var decoded []map[string]string + require.NoError(t, json.Unmarshal(buf.Bytes(), &decoded)) + require.Len(t, decoded, 1) + require.Equal(t, "only-a", decoded[0]["A"]) + require.Equal(t, "", decoded[0]["B"]) +} diff --git a/cli/azd/pkg/azdext/pagination.go b/cli/azd/pkg/azdext/pagination.go index 6f95645686d..2376dda9fea 100644 --- 
a/cli/azd/pkg/azdext/pagination.go +++ b/cli/azd/pkg/azdext/pagination.go @@ -14,10 +14,19 @@ import ( "strings" ) +const ( + // defaultMaxPages is the default upper bound on pages fetched by [Pager.Collect]. + // Individual callers can override this via [PagerOptions.MaxPages]. + // A value of 0 means unlimited (no cap), which is the default for manual + // NextPage iteration. Collect uses this default when MaxPages is unset. + defaultMaxPages = 500 +) + const ( // maxPageResponseSize limits the maximum size of a single page response // body to prevent excessive memory consumption from malicious or - // misconfigured servers. + // misconfigured servers. 10 MB is intentionally above typical Azure list + // payloads while still bounding memory use. maxPageResponseSize int64 = 10 << 20 // 10 MB // maxErrorBodySize limits the size of error response bodies captured @@ -44,6 +53,8 @@ type Pager[T any] struct { done bool opts PagerOptions originHost string // host of the initial URL for SSRF protection + pageCount int // number of pages fetched so far + truncated bool } // PageResponse is a single page returned by [Pager.NextPage]. @@ -59,6 +70,18 @@ type PageResponse[T any] struct { type PagerOptions struct { // Method overrides the HTTP method used for page requests. Defaults to GET. Method string + + // MaxPages limits the maximum number of pages that [Pager.Collect] will + // fetch. When set to a positive value, Collect stops after fetching that + // many pages. A value of 0 means unlimited (no cap) for manual NextPage + // iteration; Collect applies [defaultMaxPages] when this is 0. + MaxPages int + + // MaxItems limits the maximum total items that [Pager.Collect] will + // accumulate. When the collected items reach this count, Collect stops + // and returns the items gathered so far (truncated to MaxItems). + // A value of 0 means unlimited (no cap). 
+ MaxItems int } // HTTPDoer abstracts the HTTP call so that [ResilientClient] or any @@ -117,6 +140,11 @@ func (p *Pager[T]) More() bool { return !p.done && p.nextURL != "" } +// Truncated reports whether [Collect] stopped due to MaxPages or MaxItems limits. +func (p *Pager[T]) Truncated() bool { + return p.truncated +} + // NextPage fetches the next page of results. Returns an error if the request // fails, the response is not 2xx, or the body cannot be decoded. // @@ -145,7 +173,7 @@ func (p *Pager[T]) NextPage(ctx context.Context) (*PageResponse[T], error) { return nil, &PaginationError{ StatusCode: resp.StatusCode, URL: p.nextURL, - Body: string(body), + Body: sanitizeErrorBody(string(body)), } } @@ -170,6 +198,9 @@ func (p *Pager[T]) NextPage(ctx context.Context) (*PageResponse[T], error) { p.nextURL = page.NextLink } + // Track page count for MaxPages enforcement in Collect. + p.pageCount++ + return &page, nil } @@ -199,12 +230,23 @@ func (p *Pager[T]) validateNextLink(nextLink string) error { } // Collect is a convenience method that fetches all remaining pages and -// returns all items in a single slice. Use with caution on large result sets. +// returns all items in a single slice. +// +// To prevent unbounded memory growth from runaway pagination, Collect +// enforces [PagerOptions.MaxPages] (defaults to [defaultMaxPages] when +// unset) and [PagerOptions.MaxItems]. When either limit is reached, +// iteration stops and the items collected so far are returned. // // If NextPage returns both page data and an error (e.g. rejected nextLink), // the page data is included in the returned slice before returning the error. 
func (p *Pager[T]) Collect(ctx context.Context) ([]T, error) { var all []T + p.truncated = false + + maxPages := p.opts.MaxPages + if maxPages <= 0 { + maxPages = defaultMaxPages + } for p.More() { page, err := p.NextPage(ctx) @@ -214,11 +256,30 @@ func (p *Pager[T]) Collect(ctx context.Context) ([]T, error) { if err != nil { return all, err } + + // Enforce MaxItems: truncate and stop if exceeded. + if p.opts.MaxItems > 0 && len(all) >= p.opts.MaxItems { + if len(all) > p.opts.MaxItems { + all = all[:p.opts.MaxItems] + } + p.truncated = true + break + } + + // Enforce MaxPages: stop after collecting the configured number of pages. + if p.pageCount >= maxPages { + p.truncated = true + break + } } return all, nil } +// maxPaginationErrorBodyLen limits the response body length stored in +// PaginationError to prevent sensitive data leakage through error messages. +const maxPaginationErrorBodyLen = 1024 + // PaginationError is returned when a page request receives a non-2xx response. type PaginationError struct { StatusCode int @@ -229,6 +290,40 @@ type PaginationError struct { func (e *PaginationError) Error() string { return fmt.Sprintf( "azdext.Pager: page request returned HTTP %d (url=%s)", - e.StatusCode, e.URL, + e.StatusCode, redactURL(e.URL), ) } + +func sanitizeErrorBody(body string) string { + if len(body) > maxPaginationErrorBodyLen { + body = body[:maxPaginationErrorBodyLen] + "...[truncated]" + } + return stripControlChars(body) +} + +func stripControlChars(s string) string { + var b strings.Builder + b.Grow(len(s)) + for _, r := range s { + if r < 0x20 && r != '\t' { + b.WriteRune(' ') + } else if r == 0x7F { + b.WriteRune(' ') + } else { + b.WriteRune(r) + } + } + return b.String() +} + +// redactURL strips query parameters and fragments from a URL to avoid leaking +// tokens, SAS signatures, or other secrets in log/error messages. 
+func redactURL(rawURL string) string { + u, err := url.Parse(rawURL) + if err != nil { + return "" + } + u.RawQuery = "" + u.Fragment = "" + return u.String() +} diff --git a/cli/azd/pkg/azdext/pagination_test.go b/cli/azd/pkg/azdext/pagination_test.go index 1f8c3537279..58e80ef9f4c 100644 --- a/cli/azd/pkg/azdext/pagination_test.go +++ b/cli/azd/pkg/azdext/pagination_test.go @@ -7,6 +7,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "io" "net/http" "strings" @@ -481,3 +482,106 @@ func TestPager_CollectWithSSRFError(t *testing.T) { t.Errorf("all = %v, want [a b] (partial results before SSRF error)", all) } } + +func TestPager_TruncatedByMaxPages(t *testing.T) { + t.Parallel() + + var responses []*doerResponse + for i := 1; i <= 5; i++ { + nextLink := "" + if i < 5 { + nextLink = fmt.Sprintf("https://example.com/api?page=%d", i+1) + } + body := pageJSON([]int{i}, nextLink) + responses = append(responses, &doerResponse{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{}, + }, + }) + } + + doer := &mockDoer{responses: responses} + pager := NewPager[int](doer, "https://example.com/api?page=1", &PagerOptions{MaxPages: 3}) + + all, err := pager.Collect(context.Background()) + if err != nil { + t.Fatalf("Collect failed: %v", err) + } + + if len(all) != 3 { + t.Errorf("len(all) = %d, want 3", len(all)) + } + + if !pager.Truncated() { + t.Error("Truncated() = false, want true (stopped at MaxPages)") + } +} + +func TestPager_TruncatedByMaxItems(t *testing.T) { + t.Parallel() + + var responses []*doerResponse + for i := 0; i < 3; i++ { + items := []int{i*4 + 1, i*4 + 2, i*4 + 3, i*4 + 4} + nextLink := "" + if i < 2 { + nextLink = fmt.Sprintf("https://example.com/api?page=%d", i+2) + } + body := pageJSON(items, nextLink) + responses = append(responses, &doerResponse{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{}, + }, 
+ }) + } + + doer := &mockDoer{responses: responses} + pager := NewPager[int](doer, "https://example.com/api?page=1", &PagerOptions{MaxItems: 5}) + + all, err := pager.Collect(context.Background()) + if err != nil { + t.Fatalf("Collect failed: %v", err) + } + + if len(all) != 5 { + t.Errorf("len(all) = %d, want 5", len(all)) + } + + if !pager.Truncated() { + t.Error("Truncated() = false, want true (stopped at MaxItems)") + } +} + +func TestPager_NotTruncatedOnNaturalEnd(t *testing.T) { + t.Parallel() + + body := pageJSON([]string{"a", "b"}, "") + doer := &mockDoer{ + responses: []*doerResponse{ + {resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{}, + }}, + }, + } + + pager := NewPager[string](doer, "https://example.com/api", nil) + + all, err := pager.Collect(context.Background()) + if err != nil { + t.Fatalf("Collect failed: %v", err) + } + + if len(all) != 2 { + t.Errorf("len(all) = %d, want 2", len(all)) + } + + if pager.Truncated() { + t.Error("Truncated() = true, want false (natural end)") + } +} diff --git a/cli/azd/pkg/azdext/resilient_http_client.go b/cli/azd/pkg/azdext/resilient_http_client.go index 2916b885334..3565e40ca93 100644 --- a/cli/azd/pkg/azdext/resilient_http_client.go +++ b/cli/azd/pkg/azdext/resilient_http_client.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "math" + "math/rand/v2" "net/http" "strconv" "time" @@ -107,8 +108,9 @@ func NewResilientClient(tokenProvider azcore.TokenCredential, opts *ResilientCli return &ResilientClient{ httpClient: &http.Client{ - Transport: transport, - Timeout: opts.Timeout, + Transport: transport, + Timeout: opts.Timeout, + CheckRedirect: SSRFSafeRedirect, }, tokenProvider: tokenProvider, scopeDetector: sd, @@ -126,6 +128,13 @@ func (rc *ResilientClient) Do(ctx context.Context, method, url string, body io.R if ctx == nil { return nil, errors.New("azdext.ResilientClient.Do: context must not be nil") } + if body != nil && rc.opts.MaxRetries > 0 { + 
if _, ok := body.(io.ReadSeeker); !ok { + return nil, errors.New( + "azdext.ResilientClient.Do: request body does not implement io.ReadSeeker; " + + "retries require a seekable body (use bytes.NewReader or strings.NewReader)") + } + } var lastErr error var retryAfterOverride time.Duration @@ -240,7 +249,8 @@ func (rc *ResilientClient) backoff(attempt int) time.Duration { delay = rc.opts.MaxDelay } - return delay + jitter := 0.5 + rand.Float64()*0.5 + return time.Duration(float64(delay) * jitter) } // isRetryable returns true for status codes that indicate a transient failure. @@ -292,6 +302,12 @@ func retryAfterFromResponse(resp *http.Response) time.Duration { } if n, _ := strconv.Atoi(v); n > 0 { + // Cap parsed value before multiplication to prevent integer overflow + // (a crafted Retry-After header could wrap int64, bypassing maxRetryAfterDuration). + maxN := int(maxRetryAfterDuration / rh.units) + if n > maxN { + return maxRetryAfterDuration + } return time.Duration(n) * rh.units } diff --git a/cli/azd/pkg/azdext/resilient_http_client_test.go b/cli/azd/pkg/azdext/resilient_http_client_test.go index 4ffd67af103..24ca2dffbf9 100644 --- a/cli/azd/pkg/azdext/resilient_http_client_test.go +++ b/cli/azd/pkg/azdext/resilient_http_client_test.go @@ -530,9 +530,9 @@ func TestResilientClient_NonSeekableBodyRetryError(t *testing.T) { t.Errorf("error = %q, want mention of io.ReadSeeker", err.Error()) } - // Should have made exactly 1 attempt (first gets 503 → retry → fail on body check). - if attempts != 1 { - t.Errorf("attempts = %d, want 1 (fail before second attempt)", attempts) + // Upfront check should fail before any request attempt. + if attempts != 0 { + t.Errorf("attempts = %d, want 0 (fail fast before any request)", attempts) } } @@ -609,14 +609,73 @@ func TestResilientClient_RetryAfterCapped(t *testing.T) { t.Errorf("maxRetryAfterDuration = %v, should be <= 5m", maxRetryAfterDuration) } - // A large Retry-After value should be capped in Do(). 
+ // A large Retry-After value should be capped at parse time to prevent + // integer overflow (crafted values could wrap int64 and bypass the cap in Do). h := http.Header{} h.Set("retry-after", "999999") resp := &http.Response{Header: h} got := retryAfterFromResponse(resp) - // retryAfterFromResponse itself doesn't cap (pure parser), but Do() caps it. - if got != 999999*time.Second { - t.Errorf("retryAfterFromResponse() = %v, want %v (capping happens in Do)", got, 999999*time.Second) + // retryAfterFromResponse now caps values to maxRetryAfterDuration to prevent overflow. + if got != maxRetryAfterDuration { + t.Errorf("retryAfterFromResponse() = %v, want %v (capped at parse time)", got, maxRetryAfterDuration) + } +} + +func TestResilientClient_BackoffJitter(t *testing.T) { + t.Parallel() + rc := NewResilientClient(nil, &ResilientClientOptions{InitialDelay: 100 * time.Millisecond, MaxDelay: 10 * time.Second}) + seen := make(map[time.Duration]bool) + for range 20 { + d := rc.backoff(1) + seen[d] = true + if d < 50*time.Millisecond || d >= 100*time.Millisecond { + t.Errorf("backoff(1) = %v, want in [50ms, 100ms)", d) + } + } + if len(seen) < 2 { + t.Error("backoff jitter produced identical values across 20 calls") + } +} + +func TestResilientClient_NonSeekableBodyFailsFast(t *testing.T) { + t.Parallel() + var attempts int + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + attempts++ + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("ok")), Header: http.Header{}}, nil + }) + rc := NewResilientClient(nil, &ResilientClientOptions{Transport: transport, MaxRetries: 2, InitialDelay: time.Millisecond}) + body := io.NopCloser(strings.NewReader("payload")) + _, err := rc.Do(context.Background(), http.MethodPost, "https://example.com/api", body) + if err == nil { + t.Fatal("expected error for non-seekable body with retries enabled") + } + if !strings.Contains(err.Error(), "io.ReadSeeker") { + t.Errorf("error = 
%q, want mention of io.ReadSeeker", err.Error()) + } + if attempts != 0 { + t.Errorf("attempts = %d, want 0 (fail fast before any request)", attempts) + } +} + +func TestResilientClient_RetryAfterCappedInDo(t *testing.T) { + t.Parallel() + var attempts int + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + attempts++ + h := http.Header{} + h.Set("retry-after", "999999") + return &http.Response{StatusCode: http.StatusTooManyRequests, Body: io.NopCloser(strings.NewReader("throttled")), Header: h}, nil + }) + rc := NewResilientClient(nil, &ResilientClientOptions{Transport: transport, MaxRetries: 1, InitialDelay: time.Millisecond}) + ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) + defer cancel() + _, err := rc.Do(ctx, http.MethodGet, "https://example.com/api", nil) + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("expected context.DeadlineExceeded (proving cap was applied), got: %v", err) + } + if attempts != 1 { + t.Errorf("attempts = %d, want 1", attempts) } }