diff --git a/benchmark_test.go b/benchmark_test.go
new file mode 100644
index 00000000..2ab9bca9
--- /dev/null
+++ b/benchmark_test.go
@@ -0,0 +1,131 @@
+package modular
+
+import (
+	"context"
+	"fmt"
+	"testing"
+	"time"
+)
+
+// --- Benchmark helpers ---
+
+// benchModule is a minimal Module for bootstrap benchmarks; Init does nothing.
+type benchModule struct{ name string }
+
+func (m *benchModule) Name() string             { return m.name }
+func (m *benchModule) Init(_ Application) error { return nil }
+
+// benchReloadable is a fast Reloadable for reload benchmarks; Reload is a no-op.
+type benchReloadable struct{ name string }
+
+func (m *benchReloadable) Name() string             { return m.name }
+func (m *benchReloadable) Init(_ Application) error { return nil }
+func (m *benchReloadable) Reload(_ context.Context, _ []ConfigChange) error {
+	return nil
+}
+func (m *benchReloadable) CanReload() bool              { return true }
+func (m *benchReloadable) ReloadTimeout() time.Duration { return 5 * time.Second }
+
+// benchLogger is a no-op logger so logging cost does not skew benchmark results.
+type benchLogger struct{}
+
+func (l *benchLogger) Info(_ string, _ ...any)  {}
+func (l *benchLogger) Error(_ string, _ ...any) {}
+func (l *benchLogger) Warn(_ string, _ ...any)  {}
+func (l *benchLogger) Debug(_ string, _ ...any) {}
+
+// BenchmarkBootstrap measures NewApplication plus Init with 10 modules. Target: <150ms.
+func BenchmarkBootstrap(b *testing.B) {
+	modules := make([]Module, 10)
+	for i := range modules {
+		modules[i] = &benchModule{name: fmt.Sprintf("bench-mod-%d", i)}
+	}
+
+	// b.Loop resets the timer on its first call, so the setup above is not measured.
+	for b.Loop() {
+		app, err := NewApplication(
+			WithLogger(&benchLogger{}),
+			WithConfigProvider(NewStdConfigProvider(&struct{}{})),
+			WithModules(modules...),
+		)
+		if err != nil {
+			b.Fatalf("NewApplication failed: %v", err)
+		}
+
+		if err := app.Init(); err != nil {
+			b.Fatalf("Init failed: %v", err)
+		}
+	}
+}
+
+// BenchmarkServiceLookup measures service registry lookup. Target: <2us.
+func BenchmarkServiceLookup(b *testing.B) {
+	registry := NewEnhancedServiceRegistry()
+	if _, err := registry.RegisterService("bench-service", &struct{ Value int }{42}); err != nil {
+		b.Fatalf("RegisterService failed: %v", err)
+	}
+	svcReg := registry.AsServiceRegistry()
+
+	// b.Loop resets the timer on its first call, so the setup above is not measured.
+	for b.Loop() {
+		_ = svcReg["bench-service"]
+	}
+}
+
+// BenchmarkReload measures a single reload cycle with 5 modules. Target: <80ms.
+func BenchmarkReload(b *testing.B) {
+	log := &benchLogger{}
+	orchestrator := NewReloadOrchestrator(log, nil)
+
+	for i := 0; i < 5; i++ {
+		mod := &benchReloadable{name: fmt.Sprintf("reload-mod-%d", i)}
+		orchestrator.RegisterReloadable(mod.name, mod)
+	}
+
+	diff := ConfigDiff{
+		Changed: map[string]FieldChange{
+			"key1": {OldValue: "a", NewValue: "b", FieldPath: "key1", ChangeType: ChangeModified},
+		},
+		Added:   make(map[string]FieldChange),
+		Removed: make(map[string]FieldChange),
+	}
+
+	ctx := context.Background()
+
+	for b.Loop() {
+		req := ReloadRequest{
+			Trigger: ReloadManual,
+			Diff:    diff,
+			Ctx:     ctx,
+		}
+		// Call processReload directly to measure the actual reload cycle
+		// without channel/goroutine overhead.
+		if err := orchestrator.processReload(ctx, req); err != nil {
+			b.Fatalf("processReload failed: %v", err)
+		}
+	}
+}
+
+// BenchmarkHealthAggregation measures health check aggregation with 10 providers.
+// Target: <5ms.
+func BenchmarkHealthAggregation(b *testing.B) {
+	svc := NewAggregateHealthService(WithCacheTTL(0))
+
+	for i := 0; i < 10; i++ {
+		name := fmt.Sprintf("provider-%d", i)
+		provider := NewSimpleHealthProvider(name, "main", func(_ context.Context) (HealthStatus, string, error) {
+			return StatusHealthy, "ok", nil
+		})
+		svc.AddProvider(name, provider)
+	}
+
+	// Force refresh on every call by using ForceHealthRefreshKey.
+	ctx := context.WithValue(context.Background(), ForceHealthRefreshKey, true)
+
+	for b.Loop() {
+		_, err := svc.Check(ctx)
+		if err != nil {
+			b.Fatalf("Check failed: %v", err)
+		}
+	}
+}
diff --git a/builder.go b/builder.go
index b03b1e88..baf77b19 100644
--- a/builder.go
+++ b/builder.go
@@ -2,6 +2,7 @@ package modular
 
 import (
 	"context"
+	"fmt"
 
 	cloudevents "github.com/cloudevents/sdk-go/v2"
 )
@@ -21,6 +22,8 @@ type ApplicationBuilder struct {
 	enableObserver    bool
 	enableTenant      bool
 	configLoadedHooks []func(Application) error // Hooks to run after config loading
+	tenantGuard       *StandardTenantGuard
+	tenantGuardConfig *TenantGuardConfig
 }
 
 // ObserverFunc is a functional observer that can be registered with the application
@@ -97,6 +100,16 @@ func (b *ApplicationBuilder) Build() (Application, error) {
 		app = NewObservableDecorator(app, b.observers...)
 	}
 
+	// Create and register tenant guard if configured.
+	// Use RegisterService so that the EnhancedServiceRegistry (if enabled) tracks
+	// the entry and subsequent RegisterService calls don't overwrite it.
+	if b.tenantGuardConfig != nil {
+		b.tenantGuard = NewStandardTenantGuard(*b.tenantGuardConfig)
+		if err := app.RegisterService("tenant.guard", b.tenantGuard); err != nil {
+			return nil, fmt.Errorf("failed to register tenant guard service: %w", err)
+		}
+	}
+
 	// Register modules
 	for _, module := range b.modules {
 		app.RegisterModule(module)
@@ -194,6 +207,26 @@ func WithOnConfigLoaded(hooks ...func(Application) error) Option {
 	}
 }
 
+// WithTenantGuardMode enables the tenant guard with the specified mode using default config.
+func WithTenantGuardMode(mode TenantGuardMode) Option {
+	return func(b *ApplicationBuilder) error {
+		if b.tenantGuardConfig == nil {
+			cfg := DefaultTenantGuardConfig()
+			b.tenantGuardConfig = &cfg
+		}
+		b.tenantGuardConfig.Mode = mode
+		return nil
+	}
+}
+
+// WithTenantGuardConfig enables the tenant guard with a full configuration.
+func WithTenantGuardConfig(config TenantGuardConfig) Option {
+	return func(b *ApplicationBuilder) error {
+		b.tenantGuardConfig = &config
+		return nil
+	}
+}
+
 // Convenience functions for creating common decorators
 
 // InstanceAwareConfig creates an instance-aware configuration decorator
diff --git a/contract_verifier.go b/contract_verifier.go
new file mode 100644
index 00000000..298ab942
--- /dev/null
+++ b/contract_verifier.go
@@ -0,0 +1,262 @@
+package modular
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+// ContractViolation describes a single violation found during contract verification.
+type ContractViolation struct {
+	Contract    string // "reload" or "health"
+	Rule        string // e.g., "must-return-positive-timeout"
+	Description string
+	Severity    string // "error" or "warning"
+}
+
+// ContractVerifier verifies that implementations of Reloadable and HealthProvider
+// satisfy their behavioral contracts beyond what the type system enforces.
+type ContractVerifier interface {
+	VerifyReloadContract(module Reloadable) []ContractViolation
+	VerifyHealthContract(provider HealthProvider) []ContractViolation
+}
+
+// StandardContractVerifier is the default implementation of ContractVerifier.
+type StandardContractVerifier struct{}
+
+// NewStandardContractVerifier creates a new StandardContractVerifier.
+func NewStandardContractVerifier() *StandardContractVerifier {
+	return &StandardContractVerifier{}
+}
+
+// VerifyReloadContract checks that a Reloadable module satisfies its behavioral contract:
+//  1. ReloadTimeout() returns a positive duration
+//  2. CanReload() is safe to call concurrently (no panics)
+//  3. Reload() with empty changes is idempotent
+//  4. Reload() respects context cancellation
+func (v *StandardContractVerifier) VerifyReloadContract(module Reloadable) []ContractViolation {
+	var violations []ContractViolation
+
+	// 1. ReloadTimeout must return a positive duration.
+	if timeout := module.ReloadTimeout(); timeout <= 0 {
+		violations = append(violations, ContractViolation{
+			Contract:    "reload",
+			Rule:        "must-return-positive-timeout",
+			Description: fmt.Sprintf("ReloadTimeout() returned %v, must be > 0", timeout),
+			Severity:    "error",
+		})
+	}
+
+	// 2. CanReload must be safe to call concurrently (no panics).
+	if panicked := v.checkCanReloadConcurrency(module); panicked {
+		violations = append(violations, ContractViolation{
+			Contract:    "reload",
+			Rule:        "can-reload-must-not-panic",
+			Description: "CanReload() panicked during concurrent invocation",
+			Severity:    "warning",
+		})
+	}
+
+	// 3. Reload with empty changes should be idempotent.
+	if err := v.checkReloadIdempotent(module); err != nil {
+		violations = append(violations, ContractViolation{
+			Contract:    "reload",
+			Rule:        "empty-reload-must-be-idempotent",
+			Description: fmt.Sprintf("Reload() with empty changes failed: %v", err),
+			Severity:    "warning",
+		})
+	}
+
+	// 4. Reload must respect context cancellation.
+	if !v.checkReloadRespectsCancel(module) {
+		violations = append(violations, ContractViolation{
+			Contract:    "reload",
+			Rule:        "must-respect-context-cancellation",
+			Description: "Reload() with cancelled context did not return an error",
+			Severity:    "warning",
+		})
+	}
+
+	return violations
+}
+
+// checkCanReloadConcurrency calls CanReload 100 times concurrently and reports
+// whether any invocation panicked.
+func (v *StandardContractVerifier) checkCanReloadConcurrency(module Reloadable) bool {
+	var (
+		wg       sync.WaitGroup
+		panicked atomic.Bool
+	)
+
+	for i := 0; i < 100; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			defer func() {
+				if r := recover(); r != nil {
+					panicked.Store(true)
+				}
+			}()
+			module.CanReload()
+		}()
+	}
+	wg.Wait()
+	return panicked.Load()
+}
+
+// checkReloadIdempotent calls Reload with empty changes twice and returns an error
+// if either call fails or hangs beyond the timeout. Each call is guarded by a
+// goroutine so a misbehaving module cannot block the verifier indefinitely.
+func (v *StandardContractVerifier) checkReloadIdempotent(module Reloadable) error {
+	for _, label := range []string{"first", "second"} {
+		if err := v.runReloadWithGuard(module, label); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// runReloadWithGuard runs module.Reload in a goroutine and returns an error if
+// it fails or exceeds the 5-second timeout. On timeout the goroutine is left
+// running; the buffered channel lets it finish without blocking once Reload
+// eventually returns.
+func (v *StandardContractVerifier) runReloadWithGuard(module Reloadable, label string) error {
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	type result struct{ err error }
+	ch := make(chan result, 1)
+	go func() {
+		ch <- result{err: module.Reload(ctx, nil)}
+	}()
+
+	select {
+	case r := <-ch:
+		if r.err != nil {
+			return fmt.Errorf("%s call: %w", label, r.err)
+		}
+		return nil
+	case <-ctx.Done():
+		return fmt.Errorf("%s call: %w", label, ErrReloadTimeout)
+	}
+}
+
+// checkReloadRespectsCancel calls Reload with an already-cancelled context and
+// returns true if Reload returned an error (i.e., it respected the cancellation).
+// The call runs in a goroutine behind a 5-second guard: a module that ignores the
+// cancelled context and blocks is reported as not respecting cancellation instead
+// of hanging the verifier.
+func (v *StandardContractVerifier) checkReloadRespectsCancel(module Reloadable) bool {
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel() // cancel immediately
+
+	ch := make(chan error, 1)
+	go func() {
+		ch <- module.Reload(ctx, nil)
+	}()
+
+	timer := time.NewTimer(5 * time.Second)
+	defer timer.Stop()
+
+	select {
+	case err := <-ch:
+		return err != nil
+	case <-timer.C:
+		return false
+	}
+}
+
+// VerifyHealthContract checks that a HealthProvider satisfies its behavioral contract:
+//  1. HealthCheck returns within 5 seconds
+//  2. Reports have non-empty Module field
+//  3. Reports have non-empty Component field
+//  4. HealthCheck with cancelled context returns an error
+func (v *StandardContractVerifier) VerifyHealthContract(provider HealthProvider) []ContractViolation {
+	var violations []ContractViolation
+
+	// 1 + 2 + 3: Check that HealthCheck returns in time and reports have required fields.
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	type result struct {
+		reports []HealthReport
+		err     error
+	}
+	ch := make(chan result, 1)
+	go func() {
+		reports, err := provider.HealthCheck(ctx)
+		ch <- result{reports, err}
+	}()
+
+	select {
+	case <-ctx.Done():
+		violations = append(violations, ContractViolation{
+			Contract:    "health",
+			Rule:        "must-return-within-timeout",
+			Description: "HealthCheck() did not return within 5 seconds",
+			Severity:    "error",
+		})
+		// Can't check fields if we timed out.
+		return violations
+	case res := <-ch:
+		if res.err == nil {
+			for _, report := range res.reports {
+				if report.Module == "" {
+					violations = append(violations, ContractViolation{
+						Contract:    "health",
+						Rule:        "must-have-module-field",
+						Description: "HealthReport has empty Module field",
+						Severity:    "error",
+					})
+				}
+				if report.Component == "" {
+					violations = append(violations, ContractViolation{
+						Contract:    "health",
+						Rule:        "must-have-component-field",
+						Description: "HealthReport has empty Component field",
+						Severity:    "error",
+					})
+				}
+			}
+		}
+	}
+
+	// 4. HealthCheck with cancelled context should return an error.
+	if !v.checkHealthRespectsCancel(provider) {
+		violations = append(violations, ContractViolation{
+			Contract:    "health",
+			Rule:        "must-respect-context-cancellation",
+			Description: "HealthCheck() with cancelled context did not return an error",
+			Severity:    "warning",
+		})
+	}
+
+	return violations
+}
+
+// checkHealthRespectsCancel calls HealthCheck with an already-cancelled context
+// and returns true if it returned an error. The call is guarded by a 5-second
+// timer so a provider that ignores cancellation and blocks cannot hang the
+// verifier; such a provider is reported as not respecting cancellation.
+func (v *StandardContractVerifier) checkHealthRespectsCancel(provider HealthProvider) bool {
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel() // cancel immediately
+
+	ch := make(chan error, 1)
+	go func() {
+		_, err := provider.HealthCheck(ctx)
+		ch <- err
+	}()
+
+	timer := time.NewTimer(5 * time.Second)
+	defer timer.Stop()
+
+	select {
+	case err := <-ch:
+		return err != nil
+	case <-timer.C:
+		return false
+	}
+}
diff --git a/contract_verifier_test.go b/contract_verifier_test.go
new file mode 100644
index 00000000..a4237365
--- /dev/null
+++ b/contract_verifier_test.go
@@ -0,0 +1,164 @@
+package modular
+
+import (
+	"context"
+	"testing"
+	"time"
+)
+
+// --- Mock Reloadable modules for contract tests ---
+
+// wellBehavedReloadable satisfies all reload contract rules.
+type wellBehavedReloadable struct{}
+
+func (w *wellBehavedReloadable) Reload(ctx context.Context, _ []ConfigChange) error {
+	if err := ctx.Err(); err != nil { // a cancelled or expired context is surfaced as the error
+		return err
+	}
+	return nil
+}
+func (w *wellBehavedReloadable) CanReload() bool              { return true }
+func (w *wellBehavedReloadable) ReloadTimeout() time.Duration { return 5 * time.Second }
+
+// zeroTimeoutReloadable overrides ReloadTimeout to return zero; everything else is inherited.
+type zeroTimeoutReloadable struct{ wellBehavedReloadable }
+
+func (z *zeroTimeoutReloadable) ReloadTimeout() time.Duration { return 0 }
+
+// panickyReloadable overrides CanReload to panic; everything else is inherited.
+type panickyReloadable struct{ wellBehavedReloadable }
+
+func (p *panickyReloadable) CanReload() bool { panic("boom") }
+
+// --- Mock HealthProviders for contract tests ---
+
+// wellBehavedHealthProvider returns a fully populated report and respects cancellation.
+type wellBehavedHealthProvider struct{}
+
+func (w *wellBehavedHealthProvider) HealthCheck(ctx context.Context) ([]HealthReport, error) {
+	if err := ctx.Err(); err != nil { // a cancelled or expired context is surfaced as the error
+		return nil, err
+	}
+	return []HealthReport{
+		{
+			Module:    "test-module",
+			Component: "test-component",
+			Status:    StatusHealthy,
+			Message:   "ok",
+			CheckedAt: time.Now(),
+		},
+	}, nil
+}
+
+// emptyModuleHealthProvider returns a report with an empty Module field but still respects cancellation.
+type emptyModuleHealthProvider struct{}
+
+func (e *emptyModuleHealthProvider) HealthCheck(ctx context.Context) ([]HealthReport, error) {
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+	return []HealthReport{
+		{
+			Module:    "", // deliberately empty: the defect this mock exists to exhibit
+			Component: "comp",
+			Status:    StatusHealthy,
+			CheckedAt: time.Now(),
+		},
+	}, nil
+}
+
+// cancelIgnoringHealthProvider ignores context cancellation.
+type cancelIgnoringHealthProvider struct{}
+
+func (c *cancelIgnoringHealthProvider) HealthCheck(_ context.Context) ([]HealthReport, error) {
+	return []HealthReport{ // returned unconditionally; the context argument is never consulted
+		{
+			Module:    "mod",
+			Component: "comp",
+			Status:    StatusHealthy,
+			CheckedAt: time.Now(),
+		},
+	}, nil
+}
+
+// --- Tests ---
+
+func TestContractVerifier_ReloadWellBehaved(t *testing.T) {
+	verifier := NewStandardContractVerifier()
+	violations := verifier.VerifyReloadContract(&wellBehavedReloadable{})
+	if len(violations) != 0 {
+		t.Fatalf("expected 0 violations for well-behaved reloadable, got %d: %+v", len(violations), violations)
+	}
+}
+
+func TestContractVerifier_ReloadZeroTimeout(t *testing.T) {
+	verifier := NewStandardContractVerifier()
+	violations := verifier.VerifyReloadContract(&zeroTimeoutReloadable{})
+
+	found := false // set once the expected rule/severity pair is seen
+	for _, v := range violations {
+		if v.Rule == "must-return-positive-timeout" && v.Severity == "error" {
+			found = true
+			break
+		}
+	}
+	if !found {
+		t.Fatalf("expected violation for zero timeout, got: %+v", violations)
+	}
+}
+
+func TestContractVerifier_ReloadPanicsOnCanReload(t *testing.T) {
+	verifier := NewStandardContractVerifier()
+	violations := verifier.VerifyReloadContract(&panickyReloadable{})
+
+	found := false // set once the expected rule/severity pair is seen
+	for _, v := range violations {
+		if v.Rule == "can-reload-must-not-panic" && v.Severity == "warning" {
+			found = true
+			break
+		}
+	}
+	if !found {
+		t.Fatalf("expected violation for panicky CanReload, got: %+v", violations)
+	}
+}
+
+func TestContractVerifier_HealthWellBehaved(t *testing.T) {
+	verifier := NewStandardContractVerifier()
+	violations := verifier.VerifyHealthContract(&wellBehavedHealthProvider{})
+	if len(violations) != 0 {
+		t.Fatalf("expected 0 violations for well-behaved health provider, got %d: %+v", len(violations), violations)
+	}
+}
+
+func TestContractVerifier_HealthEmptyModule(t *testing.T) {
+	verifier := NewStandardContractVerifier()
+	violations := verifier.VerifyHealthContract(&emptyModuleHealthProvider{})
+
+	found := false // set once the expected rule/severity pair is seen
+	for _, v := range violations {
+		if v.Rule == "must-have-module-field" && v.Severity == "error" {
+			found = true
+			break
+		}
+	}
+	if !found {
+		t.Fatalf("expected violation for empty Module field, got: %+v", violations)
+	}
+}
+
+func TestContractVerifier_HealthIgnoresCancellation(t *testing.T) {
+	verifier := NewStandardContractVerifier()
+	violations := verifier.VerifyHealthContract(&cancelIgnoringHealthProvider{})
+
+	found := false // set once the expected rule/severity pair is seen
+	for _, v := range violations {
+		if v.Rule == "must-respect-context-cancellation" && v.Severity == "warning" {
+			found = true
+			break
+		}
+	}
+	if !found {
+		t.Fatalf("expected violation for ignoring cancellation, got: %+v", violations)
+	}
+}
diff --git a/docs/plans/aggregate-health.md b/docs/plans/aggregate-health.md
index 9077e023..7f3ae47b 100644
--- a/docs/plans/aggregate-health.md
+++ b/docs/plans/aggregate-health.md
@@ -1,17 +1,32 @@
-# Aggregate Health Service — Reimplementation Plan
-
-> Previously implemented in GoCodeAlone/modular (v1.4.3). Dropped during reset to GoCodeAlone/modular upstream.
-> This document captures the design for future reimplementation.
-
-## Overview
-
-The Aggregate Health Service collects health reports from registered providers, aggregates them into readiness and overall health statuses, and caches results with a configurable TTL. It supports concurrent health checks with panic recovery, emits status change events, and provides adapter patterns for simple, static, and composite health providers.
+# Aggregate Health Service — Revised Implementation Plan
+
+> Reset from CrisisTextLine/modular upstream (2026-03-09). This revision reflects what already exists.
+ +## Gap Analysis + +**Already exists (~15%):** +- ReverseProxy `HealthChecker` with concurrent backend checks, events, debug endpoints (`modules/reverseproxy/health_checker.go`) +- Backend health events: `EventTypeBackendHealthy`, `EventTypeBackendUnhealthy` (`modules/reverseproxy/events.go`) +- Observer pattern with CloudEvents (`observer.go`) — event emission infrastructure +- ReverseProxy circuit breaker (`modules/reverseproxy/circuit_breaker.go`) +- Database BDD health check stubs (`modules/database/bdd_connections_test.go`) +- HTTP server health monitoring BDD stubs (`modules/httpserver/bdd_health_monitoring_test.go`) + +**Must implement (entire core service is new):** +- `HealthStatus` enum (Unknown/Healthy/Degraded/Unhealthy) +- `HealthProvider` interface +- `HealthReport` and `AggregatedHealth` structs +- `AggregateHealthService` with provider registry, concurrent fan-out, caching +- Per-provider panic recovery +- Temporary error detection (→ Degraded) +- Provider adapters: Simple, Static, Composite +- Health events: `HealthEvaluatedEvent`, `HealthStatusChangedEvent` +- Cache with TTL + force refresh context key ## Key Interfaces ```go type HealthStatus int - const ( StatusUnknown HealthStatus = iota StatusHealthy @@ -19,9 +34,6 @@ const ( StatusUnhealthy ) -func (s HealthStatus) String() string { /* "unknown", "healthy", "degraded", "unhealthy" */ } -func (s HealthStatus) IsHealthy() bool { return s == StatusHealthy } - type HealthProvider interface { HealthCheck(ctx context.Context) ([]HealthReport, error) } @@ -38,8 +50,8 @@ type HealthReport struct { } type AggregatedHealth struct { - Readiness HealthStatus // Worst of non-optional providers only - Health HealthStatus // Worst of all providers + Readiness HealthStatus + Health HealthStatus Reports []HealthReport GeneratedAt time.Time } @@ -47,77 +59,49 @@ type AggregatedHealth struct { ## Architecture -**AggregateHealthService** is the central coordinator: - Provider registry: 
`map[string]HealthProvider` behind `sync.RWMutex` -- Cache: single `AggregatedHealth` with timestamp, TTL default 250ms -- Force refresh: context value key to bypass cache - -**Aggregation rules**: -- **Readiness**: worst status among non-optional providers only. Used for load balancer probes. -- **Health**: worst status among all providers. Used for monitoring/alerting. -- Ordering: Healthy < Degraded < Unhealthy (higher = worse, worst wins). -- Unknown treated as Unhealthy for aggregation purposes. - -**Concurrent collection**: -- Fan-out goroutines to all providers simultaneously -- Per-provider panic recovery (panic -> Unhealthy report with panic details) -- Results collected via channel, aggregated after all complete or context cancels -- Temporary errors (implementing `interface{ Temporary() bool }`) produce Degraded; other errors produce Unhealthy - -**Caching**: -- Enabled by default, TTL 250ms -- Invalidated when providers are added or removed -- Force refresh via `context.WithValue(ctx, ForceHealthRefreshKey, true)` - -**Provider adapters**: -```go -// Wrap a function as a provider -func NewSimpleHealthProvider(name string, fn func(ctx context.Context) (HealthStatus, string, error)) HealthProvider - -// Fixed status, useful for testing or static components -func NewStaticHealthProvider(reports ...HealthReport) HealthProvider - -// Combine multiple providers into one -func NewCompositeHealthProvider(providers ...HealthProvider) HealthProvider -``` +- Cache: single `AggregatedHealth` with TTL (default 250ms), invalidated on provider add/remove +- Force refresh: `context.WithValue(ctx, ForceHealthRefreshKey, true)` +- Concurrent collection: fan-out goroutines, per-provider panic recovery, channel-based results +- Aggregation: Readiness = worst non-optional, Health = worst all. 
Unknown → Unhealthy for aggregation +- Temporary errors (`interface{ Temporary() bool }`) → Degraded; other errors → Unhealthy -**Events**: -- `HealthEvaluatedEvent{Metrics}` — emitted after each aggregation with `HealthEvaluationMetrics` (components evaluated, failed, avg response time, bottleneck component name + duration) -- `HealthStatusChangedEvent{Previous, Current, ChangedAt}` — emitted only when aggregated status transitions +## Files -**Module-specific implementations** (examples for built-in modules): -- **Cache**: connectivity check (Set/Get/Delete cycle), capacity reporting -- **Database**: connection pool stats, ping latency -- **EventBus**: publish test event, worker count vs expected -- **ReverseProxy**: backend reachability with per-backend circuit breaker +| Action | File | What | +|--------|------|------| +| Create | `health.go` | HealthStatus enum, HealthProvider, HealthReport, AggregatedHealth, provider adapters | +| Create | `health_service.go` | AggregateHealthService implementation | +| Modify | `observer.go` | Add EventTypeHealthEvaluated, EventTypeHealthStatusChanged | +| Create | `health_test.go` | Unit + concurrency + panic recovery tests | ## Implementation Checklist -- [ ] Define `HealthStatus` enum with `String()` and `IsHealthy()` -- [ ] Define `HealthProvider` interface -- [ ] Define `HealthReport` and `AggregatedHealth` structs -- [ ] Implement `AggregateHealthService` with provider registry and RWMutex -- [ ] Implement concurrent fan-out health collection with goroutines and channel -- [ ] Implement per-provider panic recovery +- [ ] Define HealthStatus enum with String() and IsHealthy() +- [ ] Define HealthProvider interface +- [ ] Define HealthReport and AggregatedHealth structs +- [ ] Add health event constants to observer.go +- [ ] Implement AggregateHealthService with provider registry + RWMutex +- [ ] Implement concurrent fan-out collection with goroutines + channel +- [ ] Implement per-provider panic recovery (panic → 
Unhealthy with details) - [ ] Implement aggregation logic (readiness = worst non-optional, health = worst all) -- [ ] Implement cache with TTL (default 250ms) and force-refresh context key +- [ ] Implement cache with TTL (250ms default) and force-refresh context key - [ ] Implement cache invalidation on provider add/remove -- [ ] Implement `NewSimpleHealthProvider` adapter -- [ ] Implement `NewStaticHealthProvider` adapter -- [ ] Implement `NewCompositeHealthProvider` adapter -- [ ] Define and emit `HealthEvaluatedEvent` with metrics -- [ ] Define and emit `HealthStatusChangedEvent` on transitions +- [ ] Implement NewSimpleHealthProvider adapter +- [ ] Implement NewStaticHealthProvider adapter +- [ ] Implement NewCompositeHealthProvider adapter - [ ] Implement temporary error detection (Degraded vs Unhealthy) -- [ ] Write unit tests: single provider, multiple providers, optional vs required aggregation -- [ ] Write unit tests: cache hit/miss/invalidation, force refresh -- [ ] Write concurrency tests: parallel health checks, provider registration during check +- [ ] Emit HealthEvaluatedEvent after each aggregation +- [ ] Emit HealthStatusChangedEvent on status transitions only +- [ ] Write unit tests: single provider, multiple providers, optional vs required +- [ ] Write cache tests: hit, miss, invalidation, force refresh +- [ ] Write concurrency tests: parallel checks, registration during check - [ ] Write panic recovery tests -- [ ] Implement module-specific health providers (cache, database, eventbus, reverseproxy) as examples ## Notes -- The 250ms cache TTL prevents health check storms under high request rates while keeping results fresh. -- Panic recovery ensures one misbehaving provider cannot crash the entire health system. -- `ObservedSince` in `HealthReport` tracks when the current status was first seen, enabling duration-based alerting. 
-- Optional providers affect `Health` but not `Readiness`, allowing non-critical components to degrade without failing readiness probes. -- The bottleneck detection in `HealthEvaluationMetrics` identifies the slowest provider to aid performance tuning. +- 250ms cache TTL prevents health check storms while keeping results fresh. +- Panic recovery ensures one misbehaving provider cannot crash the health system. +- `ObservedSince` tracks when current status was first seen, enabling duration-based alerting. +- Optional providers affect Health but not Readiness. +- Module-specific providers (cache, database, eventbus, reverseproxy) are examples, not required for core. diff --git a/docs/plans/bdd-contract-testing.md b/docs/plans/bdd-contract-testing.md index 14f5b8d9..469e12bc 100644 --- a/docs/plans/bdd-contract-testing.md +++ b/docs/plans/bdd-contract-testing.md @@ -1,132 +1,80 @@ -# BDD/Contract Testing Framework — Reimplementation Plan - -> Previously implemented in GoCodeAlone/modular (v1.4.3). Dropped during reset to GoCodeAlone/modular upstream. -> This document captures the design for future reimplementation. - -## Overview - -The BDD/Contract Testing framework uses Cucumber/Godog for behavior-driven development with Gherkin feature files and Go step definitions. It defines formal contracts for the reload and health subsystems, establishes performance baselines, and enforces a TDD discipline (RED-GREEN-REFACTOR) across a 58-task, 6-phase implementation structure. It also includes API contract management tooling for breaking change detection. 
- -## Key Interfaces - -```go -// Contract verification — modules assert compliance with behavioral contracts -type ContractVerifier interface { - VerifyReloadContract(module Reloadable) []ContractViolation - VerifyHealthContract(provider HealthProvider) []ContractViolation -} - -type ContractViolation struct { - Contract string // e.g., "reload", "health" - Rule string // e.g., "must-emit-started-event" - Description string - Severity string // "error", "warning" -} - -// Contract extraction for API versioning -type ContractExtractor interface { - Extract(version string) ContractSnapshot - Compare(old, new ContractSnapshot) []BreakingChange -} - -type ContractSnapshot struct { - Version string - Interfaces map[string]InterfaceContract - Events []string - Timestamp time.Time -} - -type BreakingChange struct { - Type string // "interface-widened", "method-removed", "signature-changed" - Interface string - Method string - Description string -} -``` - -## Architecture - -**Gherkin feature files** cover core framework behaviors: -- `application_lifecycle.feature` — startup, shutdown, signal handling -- `configuration_management.feature` — config loading, validation, env overrides -- `cycle_detection.feature` — module dependency cycle detection and reporting -- `logger_decorator.feature` — structured logging decoration -- `service_registry.feature` — service registration, lookup, type safety -- `base_config.feature` — default config, merging, precedence - -**Contract specifications** define formal behavioral requirements: - -*Reload contract*: -- Modules implementing `Reloadable` must handle `Reload()` idempotently -- `CanReload()` must be safe to call concurrently and return deterministically -- `ReloadTimeout()` must return a positive duration -- Events must fire in order: Started -> (Completed | Failed) -- On failure, previously applied modules must be rolled back -- Constraint: reload must not block longer than the sum of all module timeouts - -*Health contract*: -- 
`HealthCheck()` must return within a reasonable timeout (default 5s) -- Reports must have non-empty Module and Component fields -- JSON schema validation for health response format -- Aggregation: worst-of for readiness (non-optional), worst-of for health (all) -- Events: `HealthEvaluatedEvent` after every check, `HealthStatusChangedEvent` on transitions only - -**Design briefs** (FR-045 and FR-048) provide detailed functional requirements: -- FR-045 (Dynamic Reload): atomic semantics, circuit breaker, event lifecycle, rollback behavior -- FR-048 (Aggregate Health): provider pattern, caching, concurrent collection, panic recovery - -**Task structure** — 58 tasks across 6 phases: -1. **Setup** (tasks 1-8): project scaffolding, Godog integration, build tags for pending tests -2. **Tests First** (tasks 9-20): write failing Gherkin scenarios and step definitions -3. **Core Implementation** (tasks 21-35): implement to make tests pass (RED -> GREEN) -4. **Integration** (tasks 36-44): cross-module integration tests, event flow verification -5. **Hardening** (tasks 45-52): performance benchmarks, concurrency stress tests, edge cases -6. 
**Finalization** (tasks 53-58): documentation, contract extraction tooling, CI integration - -**Performance targets**: -- Bootstrap: <150ms P50 with 10 modules -- Service lookup: <2us -- Reload: <80ms P50 -- Health aggregation: <5ms P50 - -**Constitution rules** (non-negotiable design constraints): -- No interface widening — existing interfaces are frozen after v1.0 -- Additive only — new functionality via new interfaces or builder options -- Builder options preferred over config struct changes - -**API contract management** via `modcli`: -- `modcli contract extract` — snapshot current interfaces, events, types -- `modcli contract compare v1 v2` — detect breaking changes between versions -- CI integration: fail build on breaking changes in non-major version bumps +# BDD/Contract Testing Framework — Revised Implementation Plan + +> Reset from CrisisTextLine/modular upstream (2026-03-09). This revision reflects what already exists. + +## Gap Analysis + +**Already exists (~65%):** +- Godog dependency (`go.mod`: `github.com/cucumber/godog v0.15.1`) +- 21 Gherkin feature files across core + modules +- 121 BDD test files across the codebase +- Core framework BDD: lifecycle, config, cycle detection, service registry, logger decorator +- Module BDD: auth, cache, database, eventbus, httpserver, httpclient, scheduler, reverseproxy, etc. 
+- Contract CLI: `modcli contract extract|compare|git-diff|tags` (`cmd/modcli/cmd/contract.go`, 636 lines) +- Contract types: `Contract`, `InterfaceContract`, `BreakingChange`, `ContractDiff` (`cmd/modcli/internal/contract/`) +- Contract extractor + differ with tests (1715 lines across 6 files) +- CI: `contract-check.yml` (241 lines) — extracts, compares, comments on PRs +- CI: `bdd-matrix.yml` (215 lines) — parallel module BDD, coverage merging +- BDD scripts: `run-module-bdd-parallel.sh`, `verify-bdd-tests.sh` + +**Must implement (depends on Dynamic Reload + Aggregate Health):** +- Reload contract feature file + step definitions (depends on Reloadable interface) +- Health contract feature file + step definitions (depends on HealthProvider interface) +- `ContractVerifier` interface for reload + health contracts +- Performance benchmark BDD (4 targets: bootstrap, lookup, reload, health) +- Concurrency stress test BDD scenarios + +## What to Build + +Since the BDD infrastructure and contract tooling are fully operational, the remaining work is: + +1. **Reload contract BDD** — write after Dynamic Reload is implemented +2. **Health contract BDD** — write after Aggregate Health is implemented +3. **ContractVerifier** — programmatic verification of reload/health behavioral contracts +4. 
**Performance benchmarks** — formalize the 4 targets as Go benchmarks + +## Files + +| Action | File | What | +|--------|------|------| +| Create | `features/reload_contract.feature` | Gherkin scenarios for Reloadable contract | +| Create | `features/health_contract.feature` | Gherkin scenarios for HealthProvider contract | +| Create | `reload_contract_bdd_test.go` | Step definitions for reload scenarios | +| Create | `health_contract_bdd_test.go` | Step definitions for health scenarios | +| Create | `contract_verifier.go` | ContractVerifier interface + implementations | +| Create | `contract_verifier_test.go` | Verifier tests | +| Create | `benchmark_test.go` | Performance benchmarks for 4 targets | ## Implementation Checklist -- [ ] Add `github.com/cucumber/godog` dependency -- [ ] Create `features/` directory with Gherkin feature files (6 files listed above) -- [ ] Write Go step definitions for application lifecycle scenarios -- [ ] Write Go step definitions for configuration management scenarios -- [ ] Write Go step definitions for cycle detection scenarios -- [ ] Write Go step definitions for service registry scenarios -- [ ] Define reload contract spec as testable assertions -- [ ] Define health contract spec as testable assertions -- [ ] Implement `ContractVerifier` for reload and health contracts -- [ ] Write FR-045 (dynamic reload) Gherkin scenarios before implementation -- [ ] Write FR-048 (aggregate health) Gherkin scenarios before implementation -- [ ] Set up build tags (`//go:build pending`) for tests written before implementation exists -- [ ] Implement core features to pass tests (GREEN phase) -- [ ] Refactor for clarity and performance (REFACTOR phase) -- [ ] Write performance benchmarks for all 4 targets (bootstrap, lookup, reload, health) -- [ ] Write concurrency stress tests (parallel reloads, concurrent health checks, registration races) -- [ ] Implement `ContractExtractor` and `ContractSnapshot` types -- [ ] Implement `modcli contract extract` 
command -- [ ] Implement `modcli contract compare` command with breaking change detection -- [ ] Add CI step: contract comparison on PRs targeting main +- [x] ~~Add godog dependency~~ (exists) +- [x] ~~Create features/ directory with core Gherkin files~~ (6 files exist) +- [x] ~~Write step definitions for lifecycle, config, cycle detection, service registry~~ (121 BDD test files) +- [x] ~~Implement ContractExtractor and ContractSnapshot~~ (contract package complete) +- [x] ~~Implement modcli contract extract/compare~~ (636-line CLI) +- [x] ~~Add CI contract comparison on PRs~~ (contract-check.yml) +- [ ] Create reload_contract.feature (after Dynamic Reload is implemented) +- [ ] Write reload contract step definitions +- [ ] Create health_contract.feature (after Aggregate Health is implemented) +- [ ] Write health contract step definitions +- [ ] Implement ContractVerifier for reload contracts +- [ ] Implement ContractVerifier for health contracts +- [ ] Write performance benchmarks (bootstrap <150ms, lookup <2us, reload <80ms, health <5ms) +- [ ] Write concurrency stress test scenarios + +## Performance Targets + +| Metric | Target (P50) | +|--------|-------------| +| Bootstrap (10 modules) | <150ms | +| Service lookup | <2us | +| Reload | <80ms | +| Health aggregation | <5ms | ## Notes -- Use `//go:build pending` to keep failing tests compiling but excluded from default `go test` runs until implementation catches up. -- The 58-task structure is a guide, not rigid. Tasks can be parallelized within phases but phases should be sequential. -- Performance targets are P50 values measured on commodity hardware. CI benchmarks should track regressions, not enforce absolute thresholds. -- Constitution rules exist to maintain backward compatibility. Breaking changes require a major version bump and must be flagged by contract tooling. -- Godog integrates with `testing.T` via `godog.TestSuite` — no separate test runner needed. 
-- Feature files should be human-readable enough for non-engineers to review behavioral expectations. +- Reload/health contract BDD depends on those features being implemented first. +- Performance targets are P50 on commodity hardware; CI tracks regressions, not absolutes. +- Constitution rules (no interface widening, additive only) are already enforced by contract-check.yml. +- Godog integrates with `testing.T` via `godog.TestSuite`. +- Feature files should be readable by non-engineers. diff --git a/docs/plans/dynamic-reload.md b/docs/plans/dynamic-reload.md index cd882008..1bee18a9 100644 --- a/docs/plans/dynamic-reload.md +++ b/docs/plans/dynamic-reload.md @@ -1,11 +1,28 @@ -# Dynamic Reload Manager — Reimplementation Plan - -> Previously implemented in GoCodeAlone/modular (v1.4.3). Dropped during reset to GoCodeAlone/modular upstream. -> This document captures the design for future reimplementation. - -## Overview - -The Dynamic Reload Manager enables live configuration reloading for modules that implement the `Reloadable` interface. It uses a channel-based request queue, atomic processing guards, an exponential backoff circuit breaker for failure resilience, and emits lifecycle events via the observer pattern. Reloads have atomic semantics: all modules apply or all roll back. +# Dynamic Reload Manager — Revised Implementation Plan + +> Reset from CrisisTextLine/modular upstream (2026-03-09). This revision reflects what already exists. 
+ +## Gap Analysis + +**Already exists:** +- Observer pattern with CloudEvents (`observer.go`) — foundation for reload events +- Config field tracking (`config_field_tracking.go`) — `FieldPopulation`, `StructStateDiffer` +- Config providers with thread-safe variants (`config_provider.go`) — `ImmutableConfigProvider` (atomic.Value) +- Circuit breaker pattern in reverseproxy (`modules/reverseproxy/circuit_breaker.go`) — reference implementation +- `EventTypeConfigChanged` event constant +- Module interfaces: `Module`, `Configurable`, `Startable`, `Stoppable`, `DependencyAware` +- Builder pattern with `WithOnConfigLoaded()` option + +**Must implement:** +- `Reloadable` interface (add to `module.go`) +- `ConfigChange`, `ConfigDiff`, `FieldChange` types +- `ReloadTrigger` enum +- `ReloadOrchestrator` with request queue, CAS guard, circuit breaker +- Atomic reload with rollback semantics +- Reload lifecycle events (4 new event types) +- `RequestReload()` on Application interface +- `WithDynamicReload()` builder option +- Tests ## Key Interfaces @@ -60,58 +77,58 @@ const ( - Background goroutine drains the request queue **Circuit breaker** with exponential backoff: -- Base delay: 2 seconds -- Max delay cap: 2 minutes +- Base delay: 2 seconds, max delay cap: 2 minutes - Formula: `min(base * 2^(failures-1), cap)` -- Resets to zero on successful reload -- Rejects requests while circuit is open (returns error immediately) +- Resets on successful reload, rejects while open **Atomic reload semantics**: 1. Compute `ConfigDiff` between old and new config 2. Filter modules by affected sections -3. Check `CanReload()` on each; abort if any critical module refuses -4. Apply changes to each module with per-module timeout from `ReloadTimeout()` -5. On first failure: roll back already-applied modules with reverse changes +3. Check `CanReload()` on each; skip those returning false +4. Apply changes with per-module timeout from `ReloadTimeout()` +5. 
On failure: roll back already-applied modules with reverse changes 6. Emit completion or failure event -**Events** (via existing observer/event bus): -- `ConfigReloadStarted{ReloadID, Trigger, Sections}` -- `ConfigReloadCompleted{ReloadID, Duration, ModulesReloaded}` -- `ConfigReloadFailed{ReloadID, Error, ModulesFailed}` -- `ConfigReloadNoop{ReloadID, Reason}` — emitted when diff has no changes +**Events** (add to observer.go): +- `EventTypeConfigReloadStarted` +- `EventTypeConfigReloadCompleted` +- `EventTypeConfigReloadFailed` +- `EventTypeConfigReloadNoop` -**ConfigDiff methods**: -- `HasChanges() bool` — true if any Changed/Added/Removed entries -- `FilterByPrefix(prefix) ConfigDiff` — returns subset matching field path prefix -- `RedactSensitiveFields() ConfigDiff` — replaces sensitive values with `"[REDACTED]"` -- `ChangeSummary() string` — human-readable summary of changes +## Files -**HealthEvaluationMetrics** tracks per-reload stats: components evaluated, failed, skipped, timed out, and identifies the slowest component. 
+| Action | File | What | +|--------|------|------| +| Create | `reload.go` | ConfigChange, ConfigDiff, FieldChange, ReloadTrigger types + ConfigDiff methods | +| Modify | `module.go` | Add Reloadable interface | +| Create | `reload_orchestrator.go` | ReloadOrchestrator implementation | +| Modify | `observer.go` | Add 4 reload event type constants | +| Modify | `application.go` | Add RequestReload() method | +| Modify | `builder.go` | Add WithDynamicReload() option | +| Create | `reload_test.go` | Unit + concurrency tests | ## Implementation Checklist -- [ ] Define `Reloadable` interface -- [ ] Define `ConfigChange`, `ConfigDiff`, `FieldChange` structs -- [ ] Implement `ConfigDiff` methods (HasChanges, FilterByPrefix, RedactSensitiveFields, ChangeSummary) -- [ ] Define `ReloadTrigger` enum -- [ ] Implement `ReloadOrchestrator` with module registry and RWMutex +- [ ] Define `Reloadable` interface in module.go +- [ ] Create reload.go with ConfigChange, ConfigDiff, FieldChange, ChangeType, ReloadTrigger +- [ ] Implement ConfigDiff methods: HasChanges, FilterByPrefix, RedactSensitiveFields, ChangeSummary +- [ ] Add 4 reload event constants to observer.go +- [ ] Implement ReloadOrchestrator with module registry + RWMutex - [ ] Implement channel-based request queue (buffered, size 100) - [ ] Implement atomic CAS processing guard -- [ ] Implement exponential backoff circuit breaker (base 2s, cap 2m, factor 2^(n-1)) +- [ ] Implement exponential backoff circuit breaker - [ ] Implement atomic reload with rollback on failure -- [ ] Implement per-module timeout via `ReloadTimeout()` and context cancellation -- [ ] Define and emit reload lifecycle events -- [ ] Implement `HealthEvaluationMetrics` tracking -- [ ] Add `RequestReload(sections ...string)` to application interface -- [ ] Add `WithDynamicReload()` builder option -- [ ] Write unit tests: successful reload, partial failure + rollback, circuit breaker backoff -- [ ] Write concurrency tests: concurrent reload requests, 
CAS contention -- [ ] Write example: HTTP server with reloadable timeouts (read/write/idle) and non-reloadable address/port +- [ ] Implement per-module timeout via context cancellation +- [ ] Emit reload lifecycle events via observer +- [ ] Add RequestReload() to Application interface + StdApplication +- [ ] Add WithDynamicReload() builder option +- [ ] Write unit tests: successful reload, partial failure + rollback, circuit breaker +- [ ] Write concurrency tests: concurrent requests, CAS contention ## Notes -- Modules that return `CanReload() == false` are skipped, not treated as errors. -- Rollback applies reverse `ConfigChange` entries (swap Old/New) in reverse module order. -- The request queue drops requests when full (capacity 100) and returns an error to the caller. -- Circuit breaker state is internal to the orchestrator; not exposed to modules. -- Sensitive field detection can use a configurable list of field path patterns (e.g., `*password*`, `*secret*`). +- Modules returning `CanReload() == false` are skipped, not errors. +- Rollback applies reverse ConfigChange entries in reverse module order. +- Queue drops requests when full (capacity 100) and returns error. +- Circuit breaker state is internal to orchestrator; not exposed to modules. +- Sensitive field detection uses configurable field path patterns (e.g., `*password*`, `*secret*`). diff --git a/docs/plans/tenant-guard.md b/docs/plans/tenant-guard.md index 9983abec..7f7fc915 100644 --- a/docs/plans/tenant-guard.md +++ b/docs/plans/tenant-guard.md @@ -1,99 +1,118 @@ -# TenantGuard Framework — Reimplementation Plan - -> Previously implemented in GoCodeAlone/modular (v1.4.3). Dropped during reset to GoCodeAlone/modular upstream. -> This document captures the design for future reimplementation. - -## Overview - -TenantGuard provides multi-tenant isolation enforcement for the modular framework. 
It validates cross-tenant access at runtime with configurable strictness (strict/lenient/disabled), tracks violations with severity levels, and integrates with the application builder via decorator and builder option patterns. All tenant state is RWMutex-protected for concurrent access. - -## Key Interfaces +# TenantGuard Framework — Revised Implementation Plan + +> Reset from CrisisTextLine/modular upstream (2026-03-09). This revision reflects what already exists. + +## Gap Analysis + +**Already exists (~50% complete):** +- `TenantContext` with context propagation (`tenant.go:51-94`) +- `TenantService` interface + `StandardTenantService` implementation (`tenant.go`, `tenant_service.go`) +- `TenantAwareModule` interface with lifecycle hooks (`tenant.go:211-230`) +- `TenantConfigProvider` with RWMutex, isolation, immutability variants (`tenant_config_provider.go`) +- `TenantConfigLoader` + file-based implementation (`tenant_config_loader.go`, `tenant_config_file_loader.go`) +- `TenantAwareConfig` context-aware resolution (`tenant_aware_config.go`) +- `TenantAwareDecorator` application decorator (`decorator_tenant.go`) +- `TenantAffixedEnvFeeder` for tenant-specific env vars (`feeders/tenant_affixed_env.go`) +- `WithTenantAware()` builder option (`builder.go:163-169`) +- 8 tenant sentinel errors in `errors.go` +- ~28 test files covering tenant basics + +**Must implement:** +- `TenantGuard` interface + `StandardTenantGuard` implementation +- `TenantGuardMode` enum (Strict/Lenient/Disabled) +- `ViolationType` + `Severity` enums +- `TenantViolation` struct +- `TenantGuardConfig` with defaults +- Ring buffer for bounded violation history +- Whitelist support +- `WithTenantGuardMode()` + `WithTenantGuardModeConfig()` builder options +- 2 missing sentinel errors +- Violation event emission via observer +- Mode-specific tests + concurrency tests + +## Key Types (new) ```go type TenantGuardMode int - const ( - TenantGuardStrict TenantGuardMode = iota // Block cross-tenant 
access - TenantGuardLenient // Allow but log violations - TenantGuardDisabled // No enforcement + TenantGuardStrict TenantGuardMode = iota + TenantGuardLenient + TenantGuardDisabled ) -type TenantGuard interface { - GetMode() TenantGuardMode - ValidateAccess(ctx context.Context, violation TenantViolation) error - GetRecentViolations() []TenantViolation -} - -type TenantService interface { - GetTenantConfig(tenantID string) (TenantConfig, error) - GetTenants() []string - RegisterTenant(tenantID string, config TenantConfig) error - RegisterTenantAwareModule(module TenantAwareModule) -} +type ViolationType int +const ( + CrossTenant ViolationType = iota + InvalidContext + MissingContext + Unauthorized +) -type TenantAwareModule interface { - OnTenantRegistered(tenantID string, config TenantConfig) - OnTenantRemoved(tenantID string) -} -``` +type Severity int +const ( + SeverityLow Severity = iota + SeverityMedium + SeverityHigh + SeverityCritical +) -```go type TenantViolation struct { - Type ViolationType // CrossTenant, InvalidContext, MissingContext, Unauthorized - Severity Severity // Low, Medium, High, Critical + Type ViolationType + Severity Severity TenantID string TargetID string Timestamp time.Time Details string } +type TenantGuard interface { + GetMode() TenantGuardMode + ValidateAccess(ctx context.Context, violation TenantViolation) error + GetRecentViolations() []TenantViolation +} + type TenantGuardConfig struct { Mode TenantGuardMode EnforceIsolation bool AllowCrossTenant bool ValidationTimeout time.Duration - CacheSize int - CacheTTL time.Duration - Whitelist map[string][]string // tenantID -> allowed target tenant IDs + Whitelist map[string][]string + MaxViolations int LogViolations bool - BlockViolations bool } ``` -## Architecture - -**Context propagation**: `TenantContext` wraps `context.Context` with a tenant ID value. `GetTenantIDFromContext(ctx)` extracts it. All tenant-scoped operations must carry tenant context. 
- -**Config isolation**: `TenantConfigProvider` stores per-tenant config sections behind an `RWMutex`. Config reads return deep copies to prevent mutation. `TenantAffixedEnvFeeder` loads environment variables with tenant-specific prefixes/suffixes (e.g., `TENANT_ACME_DB_HOST`). - -**Decorator pattern**: `TenantAwareDecorator` wraps the application to inject tenant context into request processing. It intercepts module lifecycle calls and routes them through the tenant service. - -**Concurrency model**: All mutable state (`violations` slice, `config` maps, `whitelist`) protected by `sync.RWMutex`. `GetRecentViolations()` returns a deep copy to prevent data races. Violation tracking uses a bounded ring buffer to cap memory. +## Files -**Error types**: Sentinel errors (`ErrTenantNotFound`, `ErrTenantConfigNotFound`, `ErrTenantIsolationViolation`, `ErrTenantContextMissing`) for typed error handling. +| Action | File | What | +|--------|------|------| +| Create | `tenant_guard.go` | TenantGuardMode, ViolationType, Severity enums, TenantViolation, TenantGuardConfig, TenantGuard interface, StandardTenantGuard with ring buffer | +| Modify | `errors.go` | Add ErrTenantContextMissing, ErrTenantIsolationViolation | +| Modify | `builder.go` | Add WithTenantGuardMode(), WithTenantGuardModeConfig() | +| Modify | `observer.go` | Add EventTypeTenantViolation constant | +| Create | `tenant_guard_test.go` | Unit + concurrency tests | ## Implementation Checklist -- [ ] Define `TenantGuardMode` enum with String() method -- [ ] Define `ViolationType` and `Severity` enums -- [ ] Implement `TenantViolation` struct with timestamp tracking -- [ ] Implement `TenantGuardConfig` with sane defaults -- [ ] Implement `TenantGuard` interface and default implementation with RWMutex-protected violation ring buffer -- [ ] Implement `TenantContext` with `context.WithValue` / `GetTenantIDFromContext()` -- [ ] Implement `TenantService` interface and default implementation -- [ ] Implement 
`TenantAwareModule` lifecycle hook dispatch (fan-out on register/remove) -- [ ] Implement `TenantConfigProvider` with per-tenant config sections and deep copy reads -- [ ] Implement `TenantAffixedEnvFeeder` for tenant-specific env var loading -- [ ] Implement `TenantAwareDecorator` application decorator -- [ ] Add builder options: `WithTenantGuardMode()`, `WithTenantGuardModeConfig()`, `WithTenantAware()` -- [ ] Define sentinel error types -- [ ] Write unit tests for all modes (strict blocks, lenient logs, disabled skips) -- [ ] Write concurrency tests (parallel ValidateAccess, concurrent tenant registration) +- [ ] Create tenant_guard.go with TenantGuardMode enum + String() +- [ ] Add ViolationType and Severity enums with String() methods +- [ ] Implement TenantViolation struct +- [ ] Implement TenantGuardConfig with defaults (MaxViolations: 1000, LogViolations: true) +- [ ] Implement StandardTenantGuard with RWMutex-protected ring buffer +- [ ] Implement ValidateAccess: strict returns error, lenient logs, disabled no-op +- [ ] Implement whitelist checking in ValidateAccess +- [ ] Implement GetRecentViolations with deep copy +- [ ] Add ErrTenantContextMissing and ErrTenantIsolationViolation to errors.go +- [ ] Add EventTypeTenantViolation to observer.go +- [ ] Add WithTenantGuardMode() and WithTenantGuardModeConfig() to builder.go +- [ ] Write tests: strict blocks, lenient logs, disabled skips +- [ ] Write tests: whitelist bypass, ring buffer FIFO eviction +- [ ] Write concurrency tests: parallel ValidateAccess, concurrent violations ## Notes -- Whitelist map allows explicit cross-tenant access for service accounts or admin tenants. -- Violation buffer should be bounded (e.g., 1000 entries) to prevent unbounded memory growth. -- Strict mode returns an error from `ValidateAccess`; lenient mode logs and returns nil. -- `GetRecentViolations()` must deep-copy to avoid callers mutating internal state. 
-- Consider emitting events via the observer pattern for violation tracking integration. +- Ring buffer bounded at MaxViolations (default 1000) entries; FIFO eviction when full. +- Strict mode returns ErrTenantIsolationViolation; lenient logs + returns nil. +- GetRecentViolations() deep-copies to prevent caller mutation. +- Whitelist allows explicit cross-tenant access for service accounts. +- Emit EventTypeTenantViolation via observer for external monitoring integration. diff --git a/errors.go b/errors.go index 8693c401..fb008a59 100644 --- a/errors.go +++ b/errors.go @@ -82,6 +82,17 @@ var ( ErrMockTenantConfigsNotInitialized = errors.New("mock tenant configs not initialized") ErrConfigSectionNotFoundForTenant = errors.New("config section not found for tenant") + // Tenant guard errors + ErrTenantContextMissing = errors.New("tenant context is missing") + ErrTenantIsolationViolation = errors.New("tenant isolation violation") + + // Reload errors + ErrReloadCircuitBreakerOpen = errors.New("reload circuit breaker is open; backing off") + ErrReloadChannelFull = errors.New("reload request channel is full") + ErrReloadInProgress = errors.New("reload already in progress") + ErrReloadStopped = errors.New("reload orchestrator is stopped") + ErrReloadTimeout = errors.New("reload timed out waiting for module") + // Observer/Event emission errors ErrNoSubjectForEventEmission = errors.New("no subject available for event emission") diff --git a/features/health_contract.feature b/features/health_contract.feature new file mode 100644 index 00000000..354d0b32 --- /dev/null +++ b/features/health_contract.feature @@ -0,0 +1,47 @@ +Feature: Aggregate Health Contract + The health service must aggregate provider reports correctly. 
+ + Scenario: Single healthy provider produces healthy status + Given a health service with one healthy provider + When health is checked + Then the overall status should be "healthy" + And readiness should be "healthy" + + Scenario: One unhealthy provider degrades overall health + Given a health service with one healthy and one unhealthy provider + When health is checked + Then the overall health should be "unhealthy" + And readiness should be "unhealthy" + + Scenario: Optional unhealthy provider does not affect readiness + Given a health service with one healthy required and one unhealthy optional provider + When health is checked + Then the overall health should be "unhealthy" + But readiness should be "healthy" + + Scenario: Provider panic is recovered gracefully + Given a health service with a provider that panics + When health is checked + Then the panicking provider should report "unhealthy" + And other providers should still be checked + + Scenario: Temporary error produces degraded status + Given a health service with a provider returning a temporary error + When health is checked + Then the provider status should be "degraded" + + Scenario: Cache returns previous result within TTL + Given a health service with a 100ms cache TTL + And a healthy provider + When health is checked twice within 50ms + Then the provider should only be called once + + Scenario: Force refresh bypasses cache + Given a health service with cached results + When health is checked with force refresh + Then the provider should be called again + + Scenario: Status change emits event + Given a health service with a provider that transitions from healthy to unhealthy + When health is checked after the transition + Then a health status changed event should be emitted diff --git a/features/reload_contract.feature b/features/reload_contract.feature new file mode 100644 index 00000000..4c3388ec --- /dev/null +++ b/features/reload_contract.feature @@ -0,0 +1,38 @@ +Feature: Dynamic Reload 
Contract + Modules implementing Reloadable must follow these behavioral contracts. + + Scenario: Successful reload applies changes to all reloadable modules + Given a reload orchestrator with 3 reloadable modules + When a reload is requested with configuration changes + Then all 3 modules should receive the changes + And a reload completed event should be emitted + + Scenario: Module refusing reload is skipped + Given a reload orchestrator with a module that cannot reload + When a reload is requested + Then the non-reloadable module should be skipped + And other modules should still be reloaded + + Scenario: Partial failure triggers rollback + Given a reload orchestrator with 3 modules where the second fails + When a reload is requested + Then the first module should be rolled back + And a reload failed event should be emitted + + Scenario: Circuit breaker activates after repeated failures + Given a reload orchestrator with a failing module + When 3 consecutive reloads fail + Then subsequent reload requests should be rejected + And the circuit breaker should eventually reset + + Scenario: Empty diff produces noop event + Given a reload orchestrator with reloadable modules + When a reload is requested with no changes + Then a reload noop event should be emitted + And no modules should be called + + Scenario: Concurrent reload requests are serialized + Given a reload orchestrator with reloadable modules + When 10 reload requests are submitted concurrently + Then all requests should be processed + And no race conditions should occur diff --git a/health.go b/health.go new file mode 100644 index 00000000..7344e98f --- /dev/null +++ b/health.go @@ -0,0 +1,141 @@ +package modular + +import ( + "context" + "fmt" + "time" +) + +// HealthStatus represents the health state of a component. +type HealthStatus int + +const ( + // StatusUnknown indicates the health state has not been determined. 
+	StatusUnknown HealthStatus = iota
+	// StatusHealthy indicates the component is functioning normally.
+	StatusHealthy
+	// StatusDegraded indicates the component is functioning with reduced capability.
+	StatusDegraded
+	// StatusUnhealthy indicates the component is not functioning.
+	StatusUnhealthy
+)
+
+// String returns the lowercase name of s: "unknown", "healthy", "degraded" or "unhealthy".
+func (s HealthStatus) String() string {
+	switch s {
+	case StatusUnknown:
+		return "unknown"
+	case StatusHealthy:
+		return "healthy"
+	case StatusDegraded:
+		return "degraded"
+	case StatusUnhealthy:
+		return "unhealthy"
+	default: // values outside the declared constants are reported like StatusUnknown
+		return "unknown"
+	}
+}
+
+// IsHealthy reports whether s is exactly StatusHealthy; StatusDegraded does not count as healthy.
+func (s HealthStatus) IsHealthy() bool {
+	return s == StatusHealthy
+}
+
+// HealthProvider is implemented by components that can report their own health.
+type HealthProvider interface {
+	HealthCheck(ctx context.Context) ([]HealthReport, error) // one call may yield several reports
+}
+
+// HealthReport represents the health status of a single component at one point in time.
+type HealthReport struct {
+	Module        string         // module the report belongs to
+	Component     string         // component within Module
+	Status        HealthStatus   // status reported for the component
+	Message       string         // free-form text accompanying Status
+	CheckedAt     time.Time      // stamped with time.Now() by the providers in this file on each HealthCheck
+	ObservedSince time.Time      // NOTE(review): never set in this file; presumably when Status was first observed, confirm
+	Optional      bool           // per features/health_contract.feature, an unhealthy optional report does not affect readiness
+	Details       map[string]any // extra key/value data; not populated by the providers in this file
+}
+
+// AggregatedHealth represents the combined health of all providers.
+type AggregatedHealth struct {
+	Readiness   HealthStatus   // per the health contract feature: not affected by optional providers
+	Health      HealthStatus   // per the health contract feature: reflects every provider, optional or not
+	Reports     []HealthReport // NOTE(review): presumably the reports the two statuses were derived from, confirm against the aggregator
+	GeneratedAt time.Time      // NOTE(review): presumably when the aggregate was produced, confirm against the aggregator
+}
+
+// forceHealthRefreshKeyType is an unexported type, so no other package can construct a colliding context key.
+type forceHealthRefreshKeyType struct{}
+
+// ForceHealthRefreshKey is the context key used to force a health refresh,
+// bypassing the cache. Usage: context.WithValue(ctx, modular.ForceHealthRefreshKey, true)
+var ForceHealthRefreshKey = forceHealthRefreshKeyType{}
+
+// simpleHealthProvider adapts a function into a HealthProvider.
+type simpleHealthProvider struct {
+	module    string // copied into HealthReport.Module
+	component string // copied into HealthReport.Component
+	fn        func(ctx context.Context) (HealthStatus, string, error) // invoked on every HealthCheck call
+}
+
+// NewSimpleHealthProvider creates a HealthProvider that calls fn on every check and wraps the
+// returned status and message in a single HealthReport tagged with module and component.
+func NewSimpleHealthProvider(module, component string, fn func(ctx context.Context) (HealthStatus, string, error)) HealthProvider {
+	return &simpleHealthProvider{module: module, component: component, fn: fn}
+}
+
+func (p *simpleHealthProvider) HealthCheck(ctx context.Context) ([]HealthReport, error) {
+	status, msg, err := p.fn(ctx) // ctx is passed through to fn unchanged
+	report := HealthReport{
+		Module:    p.module,
+		Component: p.component,
+		Status:    status,
+		Message:   msg,
+		CheckedAt: time.Now(),
+	}
+	return []HealthReport{report}, err // the report is returned together with fn's error, even when err is non-nil
+}
+
+// staticHealthProvider returns a fixed set of reports, re-stamping CheckedAt on every call.
+type staticHealthProvider struct {
+	reports []HealthReport // the variadic slice handed to NewStaticHealthProvider, stored without copying
+}
+
+// NewStaticHealthProvider creates a HealthProvider that always returns copies of the given reports and a nil error.
+func NewStaticHealthProvider(reports ...HealthReport) HealthProvider {
+	return &staticHealthProvider{reports: reports}
+}
+
+func (p *staticHealthProvider) HealthCheck(_ context.Context) ([]HealthReport, error) {
+	now := time.Now() // one timestamp shared by every report of this call
+	result := make([]HealthReport, len(p.reports))
+	copy(result, p.reports) // shallow copy: Details maps are shared with the stored reports
+	for i := range result {
+		result[i].CheckedAt = now // only the copies are stamped; p.reports is left untouched
+	}
+	return result, nil
+}
+
+// compositeHealthProvider aggregates multiple providers into one.
+type compositeHealthProvider struct {
+	providers []HealthProvider // queried in slice order by HealthCheck
+}
+
+// NewCompositeHealthProvider creates a HealthProvider that delegates to multiple providers.
+func NewCompositeHealthProvider(providers ...HealthProvider) HealthProvider {
+	return &compositeHealthProvider{providers: providers} // the variadic slice is stored as-is, without copying
+}
+
+func (p *compositeHealthProvider) HealthCheck(ctx context.Context) ([]HealthReport, error) {
+	var all []HealthReport // stays nil when there are no providers or none of them returns a report
+	for _, provider := range p.providers { // providers are queried one after another, in slice order
+		reports, err := provider.HealthCheck(ctx)
+		if err != nil { // stop at the first failure: this provider's reports are dropped and later providers are not called
+			return all, fmt.Errorf("composite health check: %w", err) // reports gathered so far accompany the wrapped error
+		}
+		all = append(all, reports...)
+	}
+	return all, nil
+}
diff --git a/health_contract_bdd_test.go b/health_contract_bdd_test.go
new file mode 100644
index 00000000..517bae85
--- /dev/null
+++ b/health_contract_bdd_test.go
@@ -0,0 +1,456 @@
+package modular
+
+import (
+	"context"
+	"errors"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	cloudevents "github.com/cloudevents/sdk-go/v2"
+	"github.com/cucumber/godog"
+)
+
+// Static errors returned by the health contract BDD step definitions when an expectation is not met.
+var (
+	errExpectedOverallHealthy        = errors.New("expected overall status to be healthy")
+	errExpectedOverallUnhealthy      = errors.New("expected overall status to be unhealthy")
+	errExpectedReadinessHealthy      = errors.New("expected readiness to be healthy")
+	errExpectedReadinessUnhealthy    = errors.New("expected readiness to be unhealthy")
+	errExpectedPanicUnhealthy        = errors.New("expected panicking provider to report unhealthy")
+	errExpectedOtherProvidersChecked = errors.New("expected other providers to still be checked")
+	errExpectedDegradedStatus        = errors.New("expected provider status to be degraded")
+	errExpectedSingleCall            = errors.New("expected provider to be called only once")
+	errExpectedRefreshCall           = errors.New("expected provider to be called again on refresh")
+	errExpectedStatusChangedEvent    = errors.New("expected health status changed event")
+)
+
+// healthBDDProvider is a configurable mock HealthProvider for BDD tests.
+type healthBDDProvider struct { + reports []HealthReport + err error + callCount atomic.Int32 + panicMsg string + mu sync.Mutex +} + +func (p *healthBDDProvider) HealthCheck(_ context.Context) ([]HealthReport, error) { + p.callCount.Add(1) + if p.panicMsg != "" { + panic(p.panicMsg) + } + p.mu.Lock() + defer p.mu.Unlock() + if p.err != nil { + return nil, p.err + } + reports := make([]HealthReport, len(p.reports)) + copy(reports, p.reports) + for i := range reports { + reports[i].CheckedAt = time.Now() + } + return reports, nil +} + +func (p *healthBDDProvider) setReports(reports []HealthReport) { + p.mu.Lock() + defer p.mu.Unlock() + p.reports = reports +} + +// bddTemporaryError implements the Temporary() bool interface for degraded status. +type bddTemporaryError struct { + msg string +} + +func (e *bddTemporaryError) Error() string { return e.msg } +func (e *bddTemporaryError) Temporary() bool { return true } + +// healthBDDSubject captures events for BDD health contract tests. +type healthBDDSubject struct { + mu sync.Mutex + events []cloudevents.Event +} + +func (s *healthBDDSubject) RegisterObserver(_ Observer, _ ...string) error { return nil } +func (s *healthBDDSubject) UnregisterObserver(_ Observer) error { return nil } +func (s *healthBDDSubject) GetObservers() []ObserverInfo { return nil } +func (s *healthBDDSubject) NotifyObservers(_ context.Context, event cloudevents.Event) error { + s.mu.Lock() + s.events = append(s.events, event) + s.mu.Unlock() + return nil +} + +func (s *healthBDDSubject) hasEventType(eventType string) bool { + s.mu.Lock() + defer s.mu.Unlock() + for _, e := range s.events { + if e.Type() == eventType { + return true + } + } + return false +} + +func (s *healthBDDSubject) reset() { + s.mu.Lock() + s.events = nil + s.mu.Unlock() +} + +// HealthBDDContext holds state for health contract BDD scenarios. 
+type HealthBDDContext struct { + service *AggregateHealthService + subject *healthBDDSubject + providers map[string]*healthBDDProvider + result *AggregatedHealth + checkErr error +} + +func (hc *HealthBDDContext) reset() { + hc.subject = &healthBDDSubject{} + hc.providers = make(map[string]*healthBDDProvider) + hc.service = nil + hc.result = nil + hc.checkErr = nil +} + +func (hc *HealthBDDContext) ensureService() { + if hc.service == nil { + hc.service = NewAggregateHealthService( + WithSubject(hc.subject), + WithCacheTTL(250*time.Millisecond), + ) + } +} + +// Step definitions + +func (hc *HealthBDDContext) aHealthServiceWithOneHealthyProvider() error { + hc.ensureService() + p := &healthBDDProvider{ + reports: []HealthReport{{ + Module: "healthy-mod", + Component: "main", + Status: StatusHealthy, + Message: "ok", + }}, + } + hc.providers["healthy"] = p + hc.service.AddProvider("healthy", p) + return nil +} + +func (hc *HealthBDDContext) healthIsChecked() error { + hc.result, hc.checkErr = hc.service.Check(context.Background()) + return nil +} + +func (hc *HealthBDDContext) theOverallStatusShouldBe(expected string) error { + if hc.result.Health.String() != expected { + if expected == "healthy" { + return errExpectedOverallHealthy + } + return errExpectedOverallUnhealthy + } + return nil +} + +func (hc *HealthBDDContext) readinessShouldBe(expected string) error { + if hc.result.Readiness.String() != expected { + if expected == "healthy" { + return errExpectedReadinessHealthy + } + return errExpectedReadinessUnhealthy + } + return nil +} + +func (hc *HealthBDDContext) aHealthServiceWithOneHealthyAndOneUnhealthyProvider() error { + hc.ensureService() + healthy := &healthBDDProvider{ + reports: []HealthReport{{ + Module: "healthy-mod", + Component: "main", + Status: StatusHealthy, + Message: "ok", + }}, + } + unhealthy := &healthBDDProvider{ + reports: []HealthReport{{ + Module: "unhealthy-mod", + Component: "main", + Status: StatusUnhealthy, + Message: "down", + 
}}, + } + hc.providers["healthy"] = healthy + hc.providers["unhealthy"] = unhealthy + hc.service.AddProvider("healthy", healthy) + hc.service.AddProvider("unhealthy", unhealthy) + return nil +} + +func (hc *HealthBDDContext) theOverallHealthShouldBe(expected string) error { + return hc.theOverallStatusShouldBe(expected) +} + +func (hc *HealthBDDContext) aHealthServiceWithOneHealthyRequiredAndOneUnhealthyOptionalProvider() error { + hc.ensureService() + required := &healthBDDProvider{ + reports: []HealthReport{{ + Module: "required-mod", + Component: "main", + Status: StatusHealthy, + Message: "ok", + Optional: false, + }}, + } + optional := &healthBDDProvider{ + reports: []HealthReport{{ + Module: "optional-mod", + Component: "aux", + Status: StatusUnhealthy, + Message: "not critical", + Optional: true, + }}, + } + hc.providers["required"] = required + hc.providers["optional"] = optional + hc.service.AddProvider("required", required) + hc.service.AddProvider("optional", optional) + return nil +} + +func (hc *HealthBDDContext) aHealthServiceWithAProviderThatPanics() error { + hc.ensureService() + panicker := &healthBDDProvider{ + panicMsg: "something went terribly wrong", + } + stable := &healthBDDProvider{ + reports: []HealthReport{{ + Module: "stable-mod", + Component: "main", + Status: StatusHealthy, + Message: "ok", + }}, + } + hc.providers["panicker"] = panicker + hc.providers["stable"] = stable + hc.service.AddProvider("panicker", panicker) + hc.service.AddProvider("stable", stable) + return nil +} + +func (hc *HealthBDDContext) thePanickingProviderShouldReport(expected string) error { + for _, r := range hc.result.Reports { + if r.Component == "panic-recovery" { + if r.Status.String() != expected { + return errExpectedPanicUnhealthy + } + return nil + } + } + return errExpectedPanicUnhealthy +} + +func (hc *HealthBDDContext) otherProvidersShouldStillBeChecked() error { + for _, r := range hc.result.Reports { + if r.Module == "stable-mod" { + return nil + } + 
} + return errExpectedOtherProvidersChecked +} + +func (hc *HealthBDDContext) aHealthServiceWithAProviderReturningATemporaryError() error { + hc.ensureService() + p := &healthBDDProvider{ + err: &bddTemporaryError{msg: "transient issue"}, + } + hc.providers["temp-err"] = p + hc.service.AddProvider("temp-err", p) + return nil +} + +func (hc *HealthBDDContext) theProviderStatusShouldBe(expected string) error { + for _, r := range hc.result.Reports { + if r.Status.String() == expected { + return nil + } + } + return errExpectedDegradedStatus +} + +func (hc *HealthBDDContext) aHealthServiceWithA100msCacheTTL() error { + hc.service = NewAggregateHealthService( + WithSubject(hc.subject), + WithCacheTTL(100*time.Millisecond), + ) + return nil +} + +func (hc *HealthBDDContext) aHealthyProvider() error { + p := &healthBDDProvider{ + reports: []HealthReport{{ + Module: "cached-mod", + Component: "main", + Status: StatusHealthy, + Message: "ok", + }}, + } + hc.providers["cached"] = p + hc.service.AddProvider("cached", p) + return nil +} + +func (hc *HealthBDDContext) healthIsCheckedTwiceWithin50ms() error { + hc.result, hc.checkErr = hc.service.Check(context.Background()) + if hc.checkErr != nil { + return hc.checkErr + } + // Second check within cache TTL + time.Sleep(10 * time.Millisecond) + hc.result, hc.checkErr = hc.service.Check(context.Background()) + return nil +} + +func (hc *HealthBDDContext) theProviderShouldOnlyBeCalledOnce() error { + p := hc.providers["cached"] + if p.callCount.Load() != 1 { + return errExpectedSingleCall + } + return nil +} + +func (hc *HealthBDDContext) aHealthServiceWithCachedResults() error { + hc.service = NewAggregateHealthService( + WithSubject(hc.subject), + WithCacheTTL(10*time.Second), + ) + p := &healthBDDProvider{ + reports: []HealthReport{{ + Module: "refresh-mod", + Component: "main", + Status: StatusHealthy, + Message: "ok", + }}, + } + hc.providers["refresh"] = p + hc.service.AddProvider("refresh", p) + // Prime the cache + _, _ 
= hc.service.Check(context.Background()) + return nil +} + +func (hc *HealthBDDContext) healthIsCheckedWithForceRefresh() error { + ctx := context.WithValue(context.Background(), ForceHealthRefreshKey, true) + hc.result, hc.checkErr = hc.service.Check(ctx) + return nil +} + +func (hc *HealthBDDContext) theProviderShouldBeCalledAgain() error { + p := hc.providers["refresh"] + if p.callCount.Load() < 2 { + return errExpectedRefreshCall + } + return nil +} + +func (hc *HealthBDDContext) aHealthServiceWithAProviderThatTransitionsFromHealthyToUnhealthy() error { + hc.ensureService() + p := &healthBDDProvider{ + reports: []HealthReport{{ + Module: "transitioning-mod", + Component: "main", + Status: StatusHealthy, + Message: "ok", + }}, + } + hc.providers["transitioning"] = p + hc.service.AddProvider("transitioning", p) + + // Do initial check to establish healthy baseline, then invalidate cache. + _, _ = hc.service.Check(context.Background()) + hc.service.invalidateCache() + + // Transition to unhealthy. + p.setReports([]HealthReport{{ + Module: "transitioning-mod", + Component: "main", + Status: StatusUnhealthy, + Message: "went down", + }}) + return nil +} + +func (hc *HealthBDDContext) healthIsCheckedAfterTheTransition() error { + hc.result, hc.checkErr = hc.service.Check(context.Background()) + return nil +} + +func (hc *HealthBDDContext) aHealthStatusChangedEventShouldBeEmitted() error { + if hc.subject.hasEventType(EventTypeHealthStatusChanged) { + return nil + } + return errExpectedStatusChangedEvent +} + +// InitializeHealthContractScenario wires up all health contract BDD steps. 
+func InitializeHealthContractScenario(ctx *godog.ScenarioContext) { + hc := &HealthBDDContext{} + + ctx.Before(func(ctx context.Context, _ *godog.Scenario) (context.Context, error) { + hc.reset() + return ctx, nil + }) + + ctx.Step(`^a health service with one healthy provider$`, hc.aHealthServiceWithOneHealthyProvider) + ctx.Step(`^health is checked$`, hc.healthIsChecked) + ctx.Step(`^the overall status should be "([^"]*)"$`, hc.theOverallStatusShouldBe) + ctx.Step(`^readiness should be "([^"]*)"$`, hc.readinessShouldBe) + + ctx.Step(`^a health service with one healthy and one unhealthy provider$`, hc.aHealthServiceWithOneHealthyAndOneUnhealthyProvider) + ctx.Step(`^the overall health should be "([^"]*)"$`, hc.theOverallHealthShouldBe) + + ctx.Step(`^a health service with one healthy required and one unhealthy optional provider$`, hc.aHealthServiceWithOneHealthyRequiredAndOneUnhealthyOptionalProvider) + + ctx.Step(`^a health service with a provider that panics$`, hc.aHealthServiceWithAProviderThatPanics) + ctx.Step(`^the panicking provider should report "([^"]*)"$`, hc.thePanickingProviderShouldReport) + ctx.Step(`^other providers should still be checked$`, hc.otherProvidersShouldStillBeChecked) + + ctx.Step(`^a health service with a provider returning a temporary error$`, hc.aHealthServiceWithAProviderReturningATemporaryError) + ctx.Step(`^the provider status should be "([^"]*)"$`, hc.theProviderStatusShouldBe) + + ctx.Step(`^a health service with a 100ms cache TTL$`, hc.aHealthServiceWithA100msCacheTTL) + ctx.Step(`^a healthy provider$`, hc.aHealthyProvider) + ctx.Step(`^health is checked twice within 50ms$`, hc.healthIsCheckedTwiceWithin50ms) + ctx.Step(`^the provider should only be called once$`, hc.theProviderShouldOnlyBeCalledOnce) + + ctx.Step(`^a health service with cached results$`, hc.aHealthServiceWithCachedResults) + ctx.Step(`^health is checked with force refresh$`, hc.healthIsCheckedWithForceRefresh) + ctx.Step(`^the provider should be called 
again$`, hc.theProviderShouldBeCalledAgain) + + ctx.Step(`^a health service with a provider that transitions from healthy to unhealthy$`, hc.aHealthServiceWithAProviderThatTransitionsFromHealthyToUnhealthy) + ctx.Step(`^health is checked after the transition$`, hc.healthIsCheckedAfterTheTransition) + ctx.Step(`^a health status changed event should be emitted$`, hc.aHealthStatusChangedEventShouldBeEmitted) +} + +// TestHealthContractBDD runs the BDD tests for the health contract. +func TestHealthContractBDD(t *testing.T) { + suite := godog.TestSuite{ + ScenarioInitializer: InitializeHealthContractScenario, + Options: &godog.Options{ + Format: "pretty", + Paths: []string{"features/health_contract.feature"}, + TestingT: t, + Strict: true, + }, + } + + if suite.Run() != 0 { + t.Fatal("non-zero status returned, failed to run health contract feature tests") + } +} diff --git a/health_service.go b/health_service.go new file mode 100644 index 00000000..e23ef08a --- /dev/null +++ b/health_service.go @@ -0,0 +1,294 @@ +package modular + +import ( + "context" + "fmt" + "sync" + "time" +) + +// AggregateHealthService collects health reports from registered providers +// and produces an aggregated health result with caching and event emission. +type AggregateHealthService struct { + providers map[string]HealthProvider + mu sync.RWMutex + cache *AggregatedHealth + cacheMu sync.RWMutex + cacheExpiry time.Time + cacheTTL time.Duration + lastStatus HealthStatus + subject Subject + logger Logger +} + +// HealthServiceOption configures an AggregateHealthService. +type HealthServiceOption func(*AggregateHealthService) + +// WithCacheTTL sets the cache time-to-live for health check results. +func WithCacheTTL(d time.Duration) HealthServiceOption { + return func(s *AggregateHealthService) { + s.cacheTTL = d + } +} + +// WithSubject sets the event subject for health event emission. 
+func WithSubject(sub Subject) HealthServiceOption { + return func(s *AggregateHealthService) { + s.subject = sub + } +} + +// WithHealthLogger sets the structured logger for the health service. +func WithHealthLogger(l Logger) HealthServiceOption { + return func(s *AggregateHealthService) { + s.logger = l + } +} + +// NewAggregateHealthService creates a new AggregateHealthService with the given options. +func NewAggregateHealthService(opts ...HealthServiceOption) *AggregateHealthService { + svc := &AggregateHealthService{ + providers: make(map[string]HealthProvider), + cacheTTL: 250 * time.Millisecond, + lastStatus: StatusUnknown, + } + for _, opt := range opts { + opt(svc) + } + return svc +} + +// AddProvider registers a named health provider and invalidates the cache. +func (s *AggregateHealthService) AddProvider(name string, provider HealthProvider) { + s.mu.Lock() + s.providers[name] = provider + s.mu.Unlock() + s.invalidateCache() +} + +// RemoveProvider removes a named health provider and invalidates the cache. +func (s *AggregateHealthService) RemoveProvider(name string) { + s.mu.Lock() + delete(s.providers, name) + s.mu.Unlock() + s.invalidateCache() +} + +func (s *AggregateHealthService) invalidateCache() { + s.cacheMu.Lock() + s.cache = nil + s.cacheExpiry = time.Time{} + s.cacheMu.Unlock() +} + +// providerResult is used to collect results from concurrent provider checks. +type providerResult struct { + reports []HealthReport + err error + name string +} + +// Check evaluates all registered providers and returns an aggregated health result. +// Results are cached for the configured TTL unless ForceHealthRefreshKey is set in the context. +// The returned AggregatedHealth is a deep copy and safe to mutate. 
+func (s *AggregateHealthService) Check(ctx context.Context) (*AggregatedHealth, error) { + // Check cache validity + forceRefresh, _ := ctx.Value(ForceHealthRefreshKey).(bool) + if !forceRefresh { + s.cacheMu.RLock() + if s.cache != nil && time.Now().Before(s.cacheExpiry) { + copied := s.deepCopyAggregated(s.cache) + s.cacheMu.RUnlock() + return copied, nil + } + s.cacheMu.RUnlock() + } + + // Snapshot providers under read lock + s.mu.RLock() + providers := make(map[string]HealthProvider, len(s.providers)) + for k, v := range s.providers { + providers[k] = v + } + s.mu.RUnlock() + + // Fan-out to all providers + ch := make(chan providerResult, len(providers)) + for name, provider := range providers { + go func(name string, provider HealthProvider) { + result := providerResult{name: name} + defer func() { + if r := recover(); r != nil { + result.reports = []HealthReport{{ + Module: name, + Component: "panic-recovery", + Status: StatusUnhealthy, + Message: fmt.Sprintf("provider panicked: %v", r), + CheckedAt: time.Now(), + }} + result.err = nil + ch <- result + } + }() + reports, err := provider.HealthCheck(ctx) + result.reports = reports + result.err = err + ch <- result + }(name, provider) + } + + // Collect results + var allReports []HealthReport + readiness := StatusHealthy + health := StatusHealthy + + for range len(providers) { + var result providerResult + select { + case result = <-ch: + case <-ctx.Done(): + return nil, fmt.Errorf("health check interrupted: %w", ctx.Err()) + } + + if result.err != nil { + // Check if error is temporary + status := StatusUnhealthy + if te, ok := result.err.(interface{ Temporary() bool }); ok && te.Temporary() { + status = StatusDegraded + } + // Add error report + allReports = append(allReports, HealthReport{ + Module: result.name, + Component: "error", + Status: status, + Message: result.err.Error(), + CheckedAt: time.Now(), + }) + readiness = worstStatus(readiness, status) + health = worstStatus(health, status) + continue + 
} + + for _, report := range result.reports { + allReports = append(allReports, report) + health = worstStatus(health, report.Status) + if !report.Optional { + readiness = worstStatus(readiness, report.Status) + } + } + } + + aggregated := &AggregatedHealth{ + Readiness: readiness, + Health: health, + Reports: allReports, + GeneratedAt: time.Now(), + } + + // Cache result + s.cacheMu.Lock() + s.cache = aggregated + s.cacheExpiry = time.Now().Add(s.cacheTTL) + s.cacheMu.Unlock() + + // Emit events + s.emitHealthEvaluated(ctx, aggregated) + + s.cacheMu.Lock() + previousStatus := s.lastStatus + s.lastStatus = aggregated.Health + s.cacheMu.Unlock() + + if previousStatus != aggregated.Health { + s.emitHealthStatusChanged(ctx, previousStatus, aggregated.Health) + } + + return s.deepCopyAggregated(aggregated), nil +} + +// deepCopyAggregated returns a deep copy of an AggregatedHealth, including +// reports and their Details maps, so callers cannot mutate cached state. +func (s *AggregateHealthService) deepCopyAggregated(src *AggregatedHealth) *AggregatedHealth { + if src == nil { + return nil + } + dst := &AggregatedHealth{ + Readiness: src.Readiness, + Health: src.Health, + GeneratedAt: src.GeneratedAt, + Reports: make([]HealthReport, len(src.Reports)), + } + for i, r := range src.Reports { + dst.Reports[i] = r + if r.Details != nil { + dst.Reports[i].Details = make(map[string]any, len(r.Details)) + for k, v := range r.Details { + dst.Reports[i].Details[k] = v + } + } + } + return dst +} + +func (s *AggregateHealthService) emitHealthEvaluated(ctx context.Context, agg *AggregatedHealth) { + if s.subject == nil { + return + } + event := NewCloudEvent(EventTypeHealthEvaluated, "modular/health-service", map[string]any{ + "readiness": agg.Readiness.String(), + "health": agg.Health.String(), + "report_count": len(agg.Reports), + }, nil) + if err := s.subject.NotifyObservers(ctx, event); err != nil && s.logger != nil { + s.logger.Debug("Failed to emit health evaluated event", 
"error", err) + } +} + +func (s *AggregateHealthService) emitHealthStatusChanged(ctx context.Context, from, to HealthStatus) { + if s.subject == nil { + return + } + event := NewCloudEvent(EventTypeHealthStatusChanged, "modular/health-service", map[string]any{ + "previous_status": from.String(), + "current_status": to.String(), + }, nil) + if err := s.subject.NotifyObservers(ctx, event); err != nil && s.logger != nil { + s.logger.Debug("Failed to emit health status changed event", "error", err) + } +} + +// worstStatus returns the worse of two health statuses. +// StatusUnknown is treated as StatusUnhealthy for aggregation purposes: +// if either status is Unknown, it is mapped to Unhealthy in the result +// so that the aggregated output consistently reflects unhealthy severity. +func worstStatus(a, b HealthStatus) HealthStatus { + ar := normalizeForAggregation(a) + br := normalizeForAggregation(b) + var winner HealthStatus + if ar >= br { + winner = a + } else { + winner = b + } + // Map Unknown → Unhealthy so aggregated health never reports "unknown". + if winner == StatusUnknown { + return StatusUnhealthy + } + return winner +} + +// normalizeForAggregation maps StatusUnknown to StatusUnhealthy severity for comparison. 
+func normalizeForAggregation(s HealthStatus) int { + switch s { + case StatusHealthy: + return 0 + case StatusDegraded: + return 1 + case StatusUnhealthy: + return 2 + case StatusUnknown: + return 2 // Unknown treated as Unhealthy + default: + return 2 + } +} diff --git a/health_test.go b/health_test.go new file mode 100644 index 00000000..5dec4945 --- /dev/null +++ b/health_test.go @@ -0,0 +1,437 @@ +package modular + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + cloudevents "github.com/cloudevents/sdk-go/v2" +) + +func TestHealthStatus_String(t *testing.T) { + tests := []struct { + status HealthStatus + want string + }{ + {StatusUnknown, "unknown"}, + {StatusHealthy, "healthy"}, + {StatusDegraded, "degraded"}, + {StatusUnhealthy, "unhealthy"}, + {HealthStatus(99), "unknown"}, + } + for _, tt := range tests { + if got := tt.status.String(); got != tt.want { + t.Errorf("HealthStatus(%d).String() = %q, want %q", tt.status, got, tt.want) + } + } +} + +func TestHealthStatus_IsHealthy(t *testing.T) { + if !StatusHealthy.IsHealthy() { + t.Error("StatusHealthy.IsHealthy() should be true") + } + for _, s := range []HealthStatus{StatusUnknown, StatusDegraded, StatusUnhealthy} { + if s.IsHealthy() { + t.Errorf("%v.IsHealthy() should be false", s) + } + } +} + +func TestSimpleHealthProvider(t *testing.T) { + provider := NewSimpleHealthProvider("mymod", "db", func(_ context.Context) (HealthStatus, string, error) { + return StatusHealthy, "all good", nil + }) + reports, err := provider.HealthCheck(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(reports) != 1 { + t.Fatalf("expected 1 report, got %d", len(reports)) + } + r := reports[0] + if r.Module != "mymod" || r.Component != "db" || r.Status != StatusHealthy || r.Message != "all good" { + t.Errorf("unexpected report: %+v", r) + } + if r.CheckedAt.IsZero() { + t.Error("CheckedAt should be set") + } +} + +func TestStaticHealthProvider(t *testing.T) { + report 
:= HealthReport{ + Module: "static", + Component: "cache", + Status: StatusDegraded, + Message: "warming up", + } + provider := NewStaticHealthProvider(report) + reports, err := provider.HealthCheck(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(reports) != 1 { + t.Fatalf("expected 1 report, got %d", len(reports)) + } + if reports[0].Status != StatusDegraded { + t.Errorf("expected degraded, got %v", reports[0].Status) + } + if reports[0].CheckedAt.IsZero() { + t.Error("CheckedAt should be set by static provider") + } +} + +func TestCompositeHealthProvider(t *testing.T) { + p1 := NewStaticHealthProvider(HealthReport{Module: "a", Component: "1", Status: StatusHealthy}) + p2 := NewStaticHealthProvider(HealthReport{Module: "b", Component: "2", Status: StatusDegraded}) + composite := NewCompositeHealthProvider(p1, p2) + + reports, err := composite.HealthCheck(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(reports) != 2 { + t.Fatalf("expected 2 reports, got %d", len(reports)) + } +} + +// testSubject is a minimal Subject implementation for testing event emission. 
+type testSubject struct { + mu sync.Mutex + events []cloudevents.Event +} + +func (s *testSubject) RegisterObserver(_ Observer, _ ...string) error { return nil } +func (s *testSubject) UnregisterObserver(_ Observer) error { return nil } +func (s *testSubject) GetObservers() []ObserverInfo { return nil } +func (s *testSubject) NotifyObservers(_ context.Context, event cloudevents.Event) error { + s.mu.Lock() + s.events = append(s.events, event) + s.mu.Unlock() + return nil +} +func (s *testSubject) getEvents() []cloudevents.Event { + s.mu.Lock() + defer s.mu.Unlock() + result := make([]cloudevents.Event, len(s.events)) + copy(result, s.events) + return result +} + +func TestAggregateHealthService_SingleProvider(t *testing.T) { + svc := NewAggregateHealthService() + svc.AddProvider("db", NewStaticHealthProvider(HealthReport{ + Module: "db", Component: "conn", Status: StatusHealthy, Message: "ok", + })) + + result, err := svc.Check(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Health != StatusHealthy { + t.Errorf("expected healthy, got %v", result.Health) + } + if result.Readiness != StatusHealthy { + t.Errorf("expected readiness healthy, got %v", result.Readiness) + } + if len(result.Reports) != 1 { + t.Errorf("expected 1 report, got %d", len(result.Reports)) + } +} + +func TestAggregateHealthService_MultipleProviders(t *testing.T) { + svc := NewAggregateHealthService() + svc.AddProvider("db", NewStaticHealthProvider(HealthReport{ + Module: "db", Component: "conn", Status: StatusHealthy, + })) + svc.AddProvider("cache", NewStaticHealthProvider(HealthReport{ + Module: "cache", Component: "redis", Status: StatusDegraded, + })) + + result, err := svc.Check(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Health != StatusDegraded { + t.Errorf("expected degraded health, got %v", result.Health) + } + if result.Readiness != StatusDegraded { + t.Errorf("expected degraded 
readiness, got %v", result.Readiness) + } + if len(result.Reports) != 2 { + t.Errorf("expected 2 reports, got %d", len(result.Reports)) + } +} + +func TestAggregateHealthService_OptionalVsRequired(t *testing.T) { + svc := NewAggregateHealthService() + svc.AddProvider("db", NewStaticHealthProvider(HealthReport{ + Module: "db", Component: "conn", Status: StatusHealthy, + })) + svc.AddProvider("metrics", NewStaticHealthProvider(HealthReport{ + Module: "metrics", Component: "export", Status: StatusUnhealthy, Optional: true, + })) + + result, err := svc.Check(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Health reflects all components (worst = unhealthy) + if result.Health != StatusUnhealthy { + t.Errorf("expected unhealthy health (includes optional), got %v", result.Health) + } + // Readiness only reflects required components (should be healthy) + if result.Readiness != StatusHealthy { + t.Errorf("expected healthy readiness (optional excluded), got %v", result.Readiness) + } +} + +func TestAggregateHealthService_CacheHit(t *testing.T) { + callCount := 0 + provider := NewSimpleHealthProvider("mod", "comp", func(_ context.Context) (HealthStatus, string, error) { + callCount++ + return StatusHealthy, "ok", nil + }) + + svc := NewAggregateHealthService(WithCacheTTL(1 * time.Second)) + svc.AddProvider("test", provider) + + // First call + _, err := svc.Check(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if callCount != 1 { + t.Fatalf("expected 1 call, got %d", callCount) + } + + // Second call within TTL should be cached + _, err = svc.Check(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if callCount != 1 { + t.Errorf("expected 1 call (cached), got %d", callCount) + } +} + +func TestAggregateHealthService_CacheMiss(t *testing.T) { + callCount := 0 + provider := NewSimpleHealthProvider("mod", "comp", func(_ context.Context) (HealthStatus, string, error) { 
+ callCount++ + return StatusHealthy, "ok", nil + }) + + svc := NewAggregateHealthService(WithCacheTTL(1 * time.Millisecond)) + svc.AddProvider("test", provider) + + _, err := svc.Check(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Wait for cache to expire + time.Sleep(5 * time.Millisecond) + + _, err = svc.Check(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if callCount != 2 { + t.Errorf("expected 2 calls after cache expiry, got %d", callCount) + } +} + +func TestAggregateHealthService_CacheInvalidation(t *testing.T) { + callCount := 0 + provider := NewSimpleHealthProvider("mod", "comp", func(_ context.Context) (HealthStatus, string, error) { + callCount++ + return StatusHealthy, "ok", nil + }) + + svc := NewAggregateHealthService(WithCacheTTL(10 * time.Second)) + svc.AddProvider("test", provider) + + _, _ = svc.Check(context.Background()) + if callCount != 1 { + t.Fatalf("expected 1 call, got %d", callCount) + } + + // AddProvider should invalidate cache + svc.AddProvider("another", NewStaticHealthProvider(HealthReport{ + Module: "x", Component: "y", Status: StatusHealthy, + })) + + _, _ = svc.Check(context.Background()) + if callCount != 2 { + t.Errorf("expected 2 calls after AddProvider invalidation, got %d", callCount) + } + + // RemoveProvider should also invalidate + svc.RemoveProvider("another") + _, _ = svc.Check(context.Background()) + if callCount != 3 { + t.Errorf("expected 3 calls after RemoveProvider invalidation, got %d", callCount) + } +} + +func TestAggregateHealthService_ForceRefresh(t *testing.T) { + callCount := 0 + provider := NewSimpleHealthProvider("mod", "comp", func(_ context.Context) (HealthStatus, string, error) { + callCount++ + return StatusHealthy, "ok", nil + }) + + svc := NewAggregateHealthService(WithCacheTTL(10 * time.Second)) + svc.AddProvider("test", provider) + + _, _ = svc.Check(context.Background()) + if callCount != 1 { + t.Fatalf("expected 1 
call, got %d", callCount) + } + + // Force refresh bypasses cache + ctx := context.WithValue(context.Background(), ForceHealthRefreshKey, true) + _, _ = svc.Check(ctx) + if callCount != 2 { + t.Errorf("expected 2 calls after force refresh, got %d", callCount) + } +} + +func TestAggregateHealthService_PanicRecovery(t *testing.T) { + panicProvider := NewSimpleHealthProvider("panicky", "boom", func(_ context.Context) (HealthStatus, string, error) { + panic("something went wrong") + }) + + svc := NewAggregateHealthService() + svc.AddProvider("panicky", panicProvider) + + result, err := svc.Check(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Health != StatusUnhealthy { + t.Errorf("expected unhealthy after panic, got %v", result.Health) + } + // Check that the panic report is present + found := false + for _, r := range result.Reports { + if r.Status == StatusUnhealthy && r.Component == "panic-recovery" { + found = true + break + } + } + if !found { + t.Error("expected panic recovery report in results") + } +} + +// temporaryError implements the Temporary() interface. 
+type temporaryError struct { + msg string +} + +func (e *temporaryError) Error() string { return e.msg } +func (e *temporaryError) Temporary() bool { return true } + +func TestAggregateHealthService_TemporaryError(t *testing.T) { + provider := NewSimpleHealthProvider("net", "conn", func(_ context.Context) (HealthStatus, string, error) { + return StatusUnknown, "", &temporaryError{msg: "connection timeout"} + }) + + svc := NewAggregateHealthService() + svc.AddProvider("net", provider) + + result, err := svc.Check(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Health != StatusDegraded { + t.Errorf("expected degraded for temporary error, got %v", result.Health) + } +} + +func TestAggregateHealthService_PermanentError(t *testing.T) { + provider := NewSimpleHealthProvider("db", "conn", func(_ context.Context) (HealthStatus, string, error) { + return StatusUnknown, "", errors.New("connection refused") + }) + + svc := NewAggregateHealthService() + svc.AddProvider("db", provider) + + result, err := svc.Check(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Health != StatusUnhealthy { + t.Errorf("expected unhealthy for permanent error, got %v", result.Health) + } +} + +func TestAggregateHealthService_EventEmission(t *testing.T) { + sub := &testSubject{} + svc := NewAggregateHealthService(WithSubject(sub)) + svc.AddProvider("db", NewStaticHealthProvider(HealthReport{ + Module: "db", Component: "conn", Status: StatusHealthy, + })) + + _, _ = svc.Check(context.Background()) + + events := sub.getEvents() + // First check: should emit evaluated + status changed (unknown -> healthy) + if len(events) < 1 { + t.Fatal("expected at least 1 event") + } + + hasEvaluated := false + hasChanged := false + for _, e := range events { + switch e.Type() { + case EventTypeHealthEvaluated: + hasEvaluated = true + case EventTypeHealthStatusChanged: + hasChanged = true + } + } + if !hasEvaluated { + 
t.Error("expected health evaluated event") + } + if !hasChanged { + t.Error("expected health status changed event (unknown -> healthy)") + } +} + +func TestAggregateHealthService_ConcurrentChecks(t *testing.T) { + svc := NewAggregateHealthService(WithCacheTTL(1 * time.Millisecond)) + svc.AddProvider("db", NewStaticHealthProvider(HealthReport{ + Module: "db", Component: "conn", Status: StatusHealthy, + })) + + const goroutines = 20 + var wg sync.WaitGroup + errs := make(chan error, goroutines) + + for range goroutines { + wg.Add(1) + go func() { + defer wg.Done() + result, err := svc.Check(context.Background()) + if err != nil { + errs <- err + return + } + if result == nil { + errs <- errors.New("nil result") + return + } + }() + } + + wg.Wait() + close(errs) + + for err := range errs { + t.Errorf("concurrent check error: %v", err) + } +} diff --git a/module.go b/module.go index 506f4cd9..eed64041 100644 --- a/module.go +++ b/module.go @@ -16,7 +16,10 @@ // } package modular -import "context" +import ( + "context" + "time" +) // Module represents a registrable component in the application. // All modules must implement this interface to be managed by the application. @@ -248,6 +251,28 @@ type ModuleWithConstructor interface { Constructable } +// Reloadable is an optional interface for modules that support dynamic configuration reloading. +// Modules implementing this interface can have their configuration updated at runtime +// without requiring a full application restart. +// +// The reload process is coordinated by the ReloadOrchestrator, which detects configuration +// changes, computes diffs, and calls Reload on each module that supports it. +type Reloadable interface { + // Reload applies configuration changes to the module. + // The changes slice contains only the changes relevant to this module. + // Implementations should apply changes atomically where possible. 
+ Reload(ctx context.Context, changes []ConfigChange) error + + // CanReload reports whether the module can currently accept a reload. + // Modules may return false if they are in a state where reloading is unsafe + // (e.g., mid-transaction, shutting down). + CanReload() bool + + // ReloadTimeout returns the maximum duration allowed for a reload operation. + // The orchestrator will cancel the reload context if this timeout is exceeded. + ReloadTimeout() time.Duration +} + // ModuleRegistry represents a registry of modules keyed by their names. // This is used internally by the application to manage registered modules // and resolve dependencies between them. diff --git a/modules/letsencrypt/go.sum b/modules/letsencrypt/go.sum index b337dbdd..0330822d 100644 --- a/modules/letsencrypt/go.sum +++ b/modules/letsencrypt/go.sum @@ -31,8 +31,8 @@ github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/GoCodeAlone/modular v1.12.0 h1:C4tLfJe65rrUQsbtndiVfldtT8IRKZcHczNRNbBK4wo= github.com/GoCodeAlone/modular v1.12.0/go.mod h1:ET7mlekRjkRq9mwJdWmaC2KDUWvjla2IqKVFrYO2JnY= -github.com/GoCodeAlone/modular/modules/httpserver v1.12.0 h1:KxH4WgdEMSzSw9xY1yNwHbQ4/pGxRM9ml5psNujR6F4= -github.com/GoCodeAlone/modular/modules/httpserver v1.12.0/go.mod h1:CTV3eBq7st01TDw+sE0CjUhkr4vmG0e1j7j4EhxM6v8= +github.com/GoCodeAlone/modular/modules/httpserver v1.12.0 h1:nVaeiC59OEqMj0jcDZwIUHrba4CdPT9ntcGBAw81iKs= +github.com/GoCodeAlone/modular/modules/httpserver v1.12.0/go.mod h1:sVklMEsxKxKihMDz5Zh2RFqnwpgXd/IT9lbAVGlkWEE= github.com/aws/aws-sdk-go-v2 v1.39.0 h1:xm5WV/2L4emMRmMjHFykqiA4M/ra0DJVSWUkDyBjbg4= github.com/aws/aws-sdk-go-v2 v1.39.0/go.mod h1:sDioUELIUO9Znk23YVmIk86/9DOpkbyyVb1i/gUNFXY= github.com/aws/aws-sdk-go-v2/config v1.31.8 h1:kQjtOLlTU4m4A64TsRcqwNChhGCwaPBt+zCQt/oWsHU= diff --git a/observer.go b/observer.go index 5077919b..9bb202a9 100644 --- 
a/observer.go +++ b/observer.go @@ -90,6 +90,19 @@ const ( EventTypeApplicationStarted = "com.modular.application.started" EventTypeApplicationStopped = "com.modular.application.stopped" EventTypeApplicationFailed = "com.modular.application.failed" + + // Tenant guard events + EventTypeTenantViolation = "com.modular.tenant.violation" + + // Configuration reload events + EventTypeConfigReloadStarted = "com.modular.config.reload.started" + EventTypeConfigReloadCompleted = "com.modular.config.reload.completed" + EventTypeConfigReloadFailed = "com.modular.config.reload.failed" + EventTypeConfigReloadNoop = "com.modular.config.reload.noop" + + // Health events + EventTypeHealthEvaluated = "com.modular.health.evaluated" + EventTypeHealthStatusChanged = "com.modular.health.status.changed" ) // ObservableModule is an optional interface that modules can implement diff --git a/reload.go b/reload.go new file mode 100644 index 00000000..bcfb1878 --- /dev/null +++ b/reload.go @@ -0,0 +1,167 @@ +package modular + +import ( + "fmt" + "strings" + "time" +) + +// ChangeType represents the type of configuration change. +type ChangeType int + +const ( + // ChangeAdded indicates a new configuration field was added. + ChangeAdded ChangeType = iota + // ChangeModified indicates an existing configuration field was modified. + ChangeModified + // ChangeRemoved indicates a configuration field was removed. + ChangeRemoved +) + +// String returns the string representation of a ChangeType. +func (ct ChangeType) String() string { + switch ct { + case ChangeAdded: + return "added" + case ChangeModified: + return "modified" + case ChangeRemoved: + return "removed" + default: + return "unknown" + } +} + +// ConfigChange represents a single configuration change detected during reload. +type ConfigChange struct { + Section string + FieldPath string + OldValue string + NewValue string + Source string +} + +// FieldChange represents a detailed field-level change with validation metadata. 
+type FieldChange struct { + OldValue any + NewValue any + FieldPath string + ChangeType ChangeType + IsSensitive bool + ValidationResult error +} + +// ConfigDiff represents the complete set of configuration changes between two states. +type ConfigDiff struct { + Changed map[string]FieldChange + Added map[string]FieldChange + Removed map[string]FieldChange + Timestamp time.Time + DiffID string +} + +// HasChanges reports whether the diff contains any changes. +func (d ConfigDiff) HasChanges() bool { + return len(d.Changed) > 0 || len(d.Added) > 0 || len(d.Removed) > 0 +} + +// FilterByPrefix returns a new ConfigDiff containing only changes whose field paths +// start with the given prefix. +func (d ConfigDiff) FilterByPrefix(prefix string) ConfigDiff { + filtered := ConfigDiff{ + Changed: make(map[string]FieldChange), + Added: make(map[string]FieldChange), + Removed: make(map[string]FieldChange), + Timestamp: d.Timestamp, + DiffID: d.DiffID, + } + for k, v := range d.Changed { + if strings.HasPrefix(k, prefix) { + filtered.Changed[k] = v + } + } + for k, v := range d.Added { + if strings.HasPrefix(k, prefix) { + filtered.Added[k] = v + } + } + for k, v := range d.Removed { + if strings.HasPrefix(k, prefix) { + filtered.Removed[k] = v + } + } + return filtered +} + +// RedactSensitiveFields returns a copy of the diff with sensitive field values replaced +// by a redaction placeholder. 
+func (d ConfigDiff) RedactSensitiveFields() ConfigDiff { + redacted := ConfigDiff{ + Changed: make(map[string]FieldChange, len(d.Changed)), + Added: make(map[string]FieldChange, len(d.Added)), + Removed: make(map[string]FieldChange, len(d.Removed)), + Timestamp: d.Timestamp, + DiffID: d.DiffID, + } + redactMap := func(src map[string]FieldChange, dst map[string]FieldChange) { + for k, v := range src { + if v.IsSensitive { + v.OldValue = "[REDACTED]" + v.NewValue = "[REDACTED]" + } + dst[k] = v + } + } + redactMap(d.Changed, redacted.Changed) + redactMap(d.Added, redacted.Added) + redactMap(d.Removed, redacted.Removed) + return redacted +} + +// ChangeSummary returns a human-readable summary of all changes in the diff. +func (d ConfigDiff) ChangeSummary() string { + if !d.HasChanges() { + return "no changes" + } + var parts []string + if n := len(d.Added); n > 0 { + parts = append(parts, fmt.Sprintf("%d added", n)) + } + if n := len(d.Changed); n > 0 { + parts = append(parts, fmt.Sprintf("%d modified", n)) + } + if n := len(d.Removed); n > 0 { + parts = append(parts, fmt.Sprintf("%d removed", n)) + } + return strings.Join(parts, ", ") +} + +// ReloadTrigger indicates what initiated a configuration reload. +type ReloadTrigger int + +const ( + // ReloadManual indicates a reload triggered by an explicit API or CLI call. + ReloadManual ReloadTrigger = iota + // ReloadFileChange indicates a reload triggered by a file system change. + ReloadFileChange + // ReloadAPIRequest indicates a reload triggered by an API request. + ReloadAPIRequest + // ReloadScheduled indicates a reload triggered by a periodic schedule. + ReloadScheduled +) + +// String returns the string representation of a ReloadTrigger. 
+func (rt ReloadTrigger) String() string { + switch rt { + case ReloadManual: + return "manual" + case ReloadFileChange: + return "file_change" + case ReloadAPIRequest: + return "api_request" + case ReloadScheduled: + return "scheduled" + default: + return "unknown" + } +} diff --git a/reload_contract_bdd_test.go b/reload_contract_bdd_test.go new file mode 100644 index 00000000..5bdc3615 --- /dev/null +++ b/reload_contract_bdd_test.go @@ -0,0 +1,488 @@ +package modular + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + cloudevents "github.com/cloudevents/sdk-go/v2" + "github.com/cucumber/godog" +) + +// Static errors for reload contract BDD tests. +var ( + errExpectedModuleReceiveChanges = errors.New("expected module to receive changes") + errExpectedCompletedEvent = errors.New("expected reload completed event") + errExpectedFailedEvent = errors.New("expected reload failed event") + errExpectedNoopEvent = errors.New("expected reload noop event") + errExpectedModuleSkipped = errors.New("expected non-reloadable module to be skipped") + errExpectedOtherModulesReloaded = errors.New("expected other modules to still be reloaded") + errExpectedRollback = errors.New("expected first module to be rolled back") + errExpectedCircuitBreakerReject = errors.New("expected circuit breaker to reject request") + errExpectedCircuitBreakerReset = errors.New("expected circuit breaker to eventually reset") + errExpectedNoModuleCalls = errors.New("expected no modules to be called") + errExpectedRequestsProcessed = errors.New("expected all requests to be processed") +) + +// reloadBDDMockReloadable is a mock Reloadable for BDD reload contract tests. 
+type reloadBDDMockReloadable struct { + name string + canReload bool + timeout time.Duration + reloadErr error + reloadCalls atomic.Int32 + lastChanges []ConfigChange + mu sync.Mutex +} + +func (m *reloadBDDMockReloadable) Reload(_ context.Context, changes []ConfigChange) error { + m.reloadCalls.Add(1) + m.mu.Lock() + m.lastChanges = changes + m.mu.Unlock() + return m.reloadErr +} + +func (m *reloadBDDMockReloadable) CanReload() bool { return m.canReload } +func (m *reloadBDDMockReloadable) ReloadTimeout() time.Duration { return m.timeout } + +// reloadBDDSubject captures events for BDD reload contract tests. +type reloadBDDSubject struct { + mu sync.Mutex + events []cloudevents.Event +} + +func (s *reloadBDDSubject) RegisterObserver(_ Observer, _ ...string) error { return nil } +func (s *reloadBDDSubject) UnregisterObserver(_ Observer) error { return nil } +func (s *reloadBDDSubject) GetObservers() []ObserverInfo { return nil } +func (s *reloadBDDSubject) NotifyObservers(_ context.Context, event cloudevents.Event) error { + s.mu.Lock() + s.events = append(s.events, event) + s.mu.Unlock() + return nil +} + +func (s *reloadBDDSubject) eventTypes() []string { + s.mu.Lock() + defer s.mu.Unlock() + var types []string + for _, e := range s.events { + types = append(types, e.Type()) + } + return types +} + +func (s *reloadBDDSubject) reset() { + s.mu.Lock() + s.events = nil + s.mu.Unlock() +} + +// reloadBDDLogger implements Logger for BDD reload contract tests. +type reloadBDDLogger struct{} + +func (l *reloadBDDLogger) Info(_ string, _ ...any) {} +func (l *reloadBDDLogger) Error(_ string, _ ...any) {} +func (l *reloadBDDLogger) Warn(_ string, _ ...any) {} +func (l *reloadBDDLogger) Debug(_ string, _ ...any) {} + +// bddWaitForEvent polls until the subject has recorded an event of the given type, +// or the timeout elapses. Returns true if the event was observed. 
+func bddWaitForEvent(subject *reloadBDDSubject, eventType string, timeout time.Duration) bool { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + for _, et := range subject.eventTypes() { + if et == eventType { + return true + } + } + time.Sleep(5 * time.Millisecond) + } + return false +} + +// bddWaitForCalls polls until the total reload calls across modules reaches +// at least n, or the timeout elapses. +func bddWaitForCalls(modules []*reloadBDDMockReloadable, n int32, timeout time.Duration) bool { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + var total int32 + for _, m := range modules { + total += m.reloadCalls.Load() + } + if total >= n { + return true + } + time.Sleep(5 * time.Millisecond) + } + return false +} + +// ReloadBDDContext holds state for reload contract BDD scenarios. +type ReloadBDDContext struct { + orchestrator *ReloadOrchestrator + modules []*reloadBDDMockReloadable + subject *reloadBDDSubject + logger *reloadBDDLogger + ctx context.Context + cancel context.CancelFunc + reloadErr error + raceDetected atomic.Bool +} + +func (rc *ReloadBDDContext) reset() { + if rc.cancel != nil { + rc.cancel() + } + rc.subject = &reloadBDDSubject{} + rc.logger = &reloadBDDLogger{} + rc.modules = nil + rc.reloadErr = nil + rc.raceDetected.Store(false) + rc.ctx, rc.cancel = context.WithCancel(context.Background()) +} + +func (rc *ReloadBDDContext) newDiff() ConfigDiff { + return ConfigDiff{ + Changed: map[string]FieldChange{ + "db.host": {OldValue: "localhost", NewValue: "remotehost", ChangeType: ChangeModified}, + }, + Added: make(map[string]FieldChange), + Removed: make(map[string]FieldChange), + Timestamp: time.Now(), + DiffID: "bdd-test-diff", + } +} + +func (rc *ReloadBDDContext) emptyDiff() ConfigDiff { + return ConfigDiff{ + Changed: make(map[string]FieldChange), + Added: make(map[string]FieldChange), + Removed: make(map[string]FieldChange), + DiffID: "bdd-empty-diff", + } +} + +// Step definitions 
+ +func (rc *ReloadBDDContext) aReloadOrchestratorWithNReloadableModules(n int) error { + rc.orchestrator = NewReloadOrchestrator(rc.logger, rc.subject) + for i := range n { + mod := &reloadBDDMockReloadable{ + name: string(rune('a'+i)) + "_mod", + canReload: true, + timeout: 5 * time.Second, + } + rc.modules = append(rc.modules, mod) + rc.orchestrator.RegisterReloadable(mod.name, mod) + } + rc.orchestrator.Start(rc.ctx) + return nil +} + +func (rc *ReloadBDDContext) aReloadIsRequestedWithConfigurationChanges() error { + diff := rc.newDiff() + rc.reloadErr = rc.orchestrator.RequestReload(rc.ctx, ReloadManual, diff) + bddWaitForEvent(rc.subject, EventTypeConfigReloadCompleted, 2*time.Second) + return nil +} + +func (rc *ReloadBDDContext) allNModulesShouldReceiveTheChanges(n int) error { + received := 0 + for _, mod := range rc.modules { + if mod.reloadCalls.Load() > 0 { + received++ + } + } + if received != n { + return errExpectedModuleReceiveChanges + } + return nil +} + +func (rc *ReloadBDDContext) aReloadCompletedEventShouldBeEmitted() error { + for _, et := range rc.subject.eventTypes() { + if et == EventTypeConfigReloadCompleted { + return nil + } + } + return errExpectedCompletedEvent +} + +func (rc *ReloadBDDContext) aReloadOrchestratorWithAModuleThatCannotReload() error { + rc.orchestrator = NewReloadOrchestrator(rc.logger, rc.subject) + + disabledMod := &reloadBDDMockReloadable{ + name: "disabled_mod", + canReload: false, + timeout: 5 * time.Second, + } + rc.modules = append(rc.modules, disabledMod) + rc.orchestrator.RegisterReloadable(disabledMod.name, disabledMod) + + enabledMod := &reloadBDDMockReloadable{ + name: "enabled_mod", + canReload: true, + timeout: 5 * time.Second, + } + rc.modules = append(rc.modules, enabledMod) + rc.orchestrator.RegisterReloadable(enabledMod.name, enabledMod) + + rc.orchestrator.Start(rc.ctx) + return nil +} + +func (rc *ReloadBDDContext) aReloadIsRequested() error { + diff := rc.newDiff() + rc.reloadErr = 
rc.orchestrator.RequestReload(rc.ctx, ReloadManual, diff) + // Wait for either completed or failed event (covers both success and failure scenarios). + bddWaitForEvent(rc.subject, EventTypeConfigReloadCompleted, 2*time.Second) + bddWaitForEvent(rc.subject, EventTypeConfigReloadFailed, 100*time.Millisecond) + return nil +} + +func (rc *ReloadBDDContext) theNonReloadableModuleShouldBeSkipped() error { + for _, mod := range rc.modules { + if !mod.canReload && mod.reloadCalls.Load() != 0 { + return errExpectedModuleSkipped + } + } + return nil +} + +func (rc *ReloadBDDContext) otherModulesShouldStillBeReloaded() error { + for _, mod := range rc.modules { + if mod.canReload && mod.reloadCalls.Load() == 0 { + return errExpectedOtherModulesReloaded + } + } + return nil +} + +func (rc *ReloadBDDContext) aReloadOrchestratorWith3ModulesWhereTheSecondFails() error { + rc.orchestrator = NewReloadOrchestrator(rc.logger, rc.subject) + + // Use names that sort deterministically to control ordering. + mod1 := &reloadBDDMockReloadable{ + name: "aaa_first", + canReload: true, + timeout: 5 * time.Second, + } + mod2 := &reloadBDDMockReloadable{ + name: "bbb_second", + canReload: true, + timeout: 5 * time.Second, + reloadErr: errors.New("reload failure"), + } + mod3 := &reloadBDDMockReloadable{ + name: "ccc_third", + canReload: true, + timeout: 5 * time.Second, + } + rc.modules = append(rc.modules, mod1, mod2, mod3) + rc.orchestrator.RegisterReloadable(mod1.name, mod1) + rc.orchestrator.RegisterReloadable(mod2.name, mod2) + rc.orchestrator.RegisterReloadable(mod3.name, mod3) + + rc.orchestrator.Start(rc.ctx) + return nil +} + +func (rc *ReloadBDDContext) theFirstModuleShouldBeRolledBack() error { + // Reload targets are sorted by name. aaa_first runs before bbb_second (which + // fails), so aaa_first is always applied and then rolled back (2 calls total). 
+ mod1 := rc.modules[0] + calls := mod1.reloadCalls.Load() + if calls != 2 { + return fmt.Errorf("%w: expected aaa_first to be called 2 times (apply + rollback), got %d", errExpectedRollback, calls) + } + return nil +} + +func (rc *ReloadBDDContext) aReloadFailedEventShouldBeEmitted() error { + for _, et := range rc.subject.eventTypes() { + if et == EventTypeConfigReloadFailed { + return nil + } + } + return errExpectedFailedEvent +} + +func (rc *ReloadBDDContext) aReloadOrchestratorWithAFailingModule() error { + rc.orchestrator = NewReloadOrchestrator(rc.logger, rc.subject) + + mod := &reloadBDDMockReloadable{ + name: "failing_mod", + canReload: true, + timeout: 5 * time.Second, + reloadErr: errors.New("always fails"), + } + rc.modules = append(rc.modules, mod) + rc.orchestrator.RegisterReloadable(mod.name, mod) + + rc.orchestrator.Start(rc.ctx) + return nil +} + +func (rc *ReloadBDDContext) nConsecutiveReloadsFail(n int) error { + diff := rc.newDiff() + for i := range n { + _ = rc.orchestrator.RequestReload(rc.ctx, ReloadManual, diff) + expected := int32(i + 1) + bddWaitForCalls(rc.modules, expected, 2*time.Second) + } + return nil +} + +func (rc *ReloadBDDContext) subsequentReloadRequestsShouldBeRejected() error { + diff := rc.newDiff() + err := rc.orchestrator.RequestReload(rc.ctx, ReloadManual, diff) + if err == nil || !strings.Contains(err.Error(), "circuit breaker") { + return errExpectedCircuitBreakerReject + } + return nil +} + +func (rc *ReloadBDDContext) theCircuitBreakerShouldEventuallyReset() error { + // Simulate that the backoff period has elapsed by moving lastFailure + // sufficiently into the past. This validates isCircuitOpen()/backoffDuration() + // rather than bypassing them. 
+ rc.orchestrator.cbMu.Lock() + rc.orchestrator.lastFailure = time.Now().Add(-circuitBreakerMaxDelay - time.Second) + rc.orchestrator.cbMu.Unlock() + + diff := rc.newDiff() + err := rc.orchestrator.RequestReload(rc.ctx, ReloadManual, diff) + if err != nil && strings.Contains(err.Error(), "circuit breaker") { + return errExpectedCircuitBreakerReset + } + return nil +} + +func (rc *ReloadBDDContext) aReloadOrchestratorWithReloadableModules() error { + return rc.aReloadOrchestratorWithNReloadableModules(2) +} + +func (rc *ReloadBDDContext) aReloadIsRequestedWithNoChanges() error { + diff := rc.emptyDiff() + rc.reloadErr = rc.orchestrator.RequestReload(rc.ctx, ReloadManual, diff) + bddWaitForEvent(rc.subject, EventTypeConfigReloadNoop, 2*time.Second) + return nil +} + +func (rc *ReloadBDDContext) aReloadNoopEventShouldBeEmitted() error { + for _, et := range rc.subject.eventTypes() { + if et == EventTypeConfigReloadNoop { + return nil + } + } + return errExpectedNoopEvent +} + +func (rc *ReloadBDDContext) noModulesShouldBeCalled() error { + for _, mod := range rc.modules { + if mod.reloadCalls.Load() != 0 { + return errExpectedNoModuleCalls + } + } + return nil +} + +func (rc *ReloadBDDContext) tenReloadRequestsAreSubmittedConcurrently() error { + diff := rc.newDiff() + var wg sync.WaitGroup + for range 10 { + wg.Add(1) + go func() { + defer wg.Done() + _ = rc.orchestrator.RequestReload(rc.ctx, ReloadManual, diff) + }() + } + wg.Wait() + bddWaitForCalls(rc.modules, 1, 2*time.Second) + return nil +} + +func (rc *ReloadBDDContext) allRequestsShouldBeProcessed() error { + totalCalls := int32(0) + for _, mod := range rc.modules { + totalCalls += mod.reloadCalls.Load() + } + if totalCalls < 1 { + return errExpectedRequestsProcessed + } + return nil +} + +func (rc *ReloadBDDContext) noRaceConditionsShouldOccur() error { + // The race detector (go test -race) validates this at runtime. + // If we got here without a panic, there are no races. 
+ return nil +} + +// InitializeReloadContractScenario wires up all reload contract BDD steps. +func InitializeReloadContractScenario(ctx *godog.ScenarioContext) { + rc := &ReloadBDDContext{} + + ctx.Before(func(ctx context.Context, _ *godog.Scenario) (context.Context, error) { + rc.reset() + return ctx, nil + }) + + ctx.After(func(ctx context.Context, _ *godog.Scenario, _ error) (context.Context, error) { + if rc.cancel != nil { + rc.cancel() + } + return ctx, nil + }) + + ctx.Step(`^a reload orchestrator with (\d+) reloadable modules$`, rc.aReloadOrchestratorWithNReloadableModules) + ctx.Step(`^a reload is requested with configuration changes$`, rc.aReloadIsRequestedWithConfigurationChanges) + ctx.Step(`^all (\d+) modules should receive the changes$`, rc.allNModulesShouldReceiveTheChanges) + ctx.Step(`^a reload completed event should be emitted$`, rc.aReloadCompletedEventShouldBeEmitted) + + ctx.Step(`^a reload orchestrator with a module that cannot reload$`, rc.aReloadOrchestratorWithAModuleThatCannotReload) + ctx.Step(`^a reload is requested$`, rc.aReloadIsRequested) + ctx.Step(`^the non-reloadable module should be skipped$`, rc.theNonReloadableModuleShouldBeSkipped) + ctx.Step(`^other modules should still be reloaded$`, rc.otherModulesShouldStillBeReloaded) + + ctx.Step(`^a reload orchestrator with 3 modules where the second fails$`, rc.aReloadOrchestratorWith3ModulesWhereTheSecondFails) + ctx.Step(`^the first module should be rolled back$`, rc.theFirstModuleShouldBeRolledBack) + ctx.Step(`^a reload failed event should be emitted$`, rc.aReloadFailedEventShouldBeEmitted) + + ctx.Step(`^a reload orchestrator with a failing module$`, rc.aReloadOrchestratorWithAFailingModule) + ctx.Step(`^(\d+) consecutive reloads fail$`, rc.nConsecutiveReloadsFail) + ctx.Step(`^subsequent reload requests should be rejected$`, rc.subsequentReloadRequestsShouldBeRejected) + ctx.Step(`^the circuit breaker should eventually reset$`, rc.theCircuitBreakerShouldEventuallyReset) + + 
ctx.Step(`^a reload orchestrator with reloadable modules$`, rc.aReloadOrchestratorWithReloadableModules) + ctx.Step(`^a reload is requested with no changes$`, rc.aReloadIsRequestedWithNoChanges) + ctx.Step(`^a reload noop event should be emitted$`, rc.aReloadNoopEventShouldBeEmitted) + ctx.Step(`^no modules should be called$`, rc.noModulesShouldBeCalled) + + ctx.Step(`^10 reload requests are submitted concurrently$`, rc.tenReloadRequestsAreSubmittedConcurrently) + ctx.Step(`^all requests should be processed$`, rc.allRequestsShouldBeProcessed) + ctx.Step(`^no race conditions should occur$`, rc.noRaceConditionsShouldOccur) +} + +// TestReloadContractBDD runs the BDD tests for the reload contract. +func TestReloadContractBDD(t *testing.T) { + suite := godog.TestSuite{ + ScenarioInitializer: InitializeReloadContractScenario, + Options: &godog.Options{ + Format: "pretty", + Paths: []string{"features/reload_contract.feature"}, + TestingT: t, + Strict: true, + }, + } + + if suite.Run() != 0 { + t.Fatal("non-zero status returned, failed to run reload contract feature tests") + } +} diff --git a/reload_orchestrator.go b/reload_orchestrator.go new file mode 100644 index 00000000..ea34005d --- /dev/null +++ b/reload_orchestrator.go @@ -0,0 +1,393 @@ +package modular + +import ( + "context" + "fmt" + "sort" + "sync" + "sync/atomic" + "time" +) + +// ReloadRequest represents a pending configuration reload request. +type ReloadRequest struct { + Trigger ReloadTrigger + Diff ConfigDiff + Ctx context.Context +} + +// reloadEntry pairs a module name with its Reloadable implementation. +type reloadEntry struct { + name string + module Reloadable +} + +// defaultReloadTimeout is used when a module returns a non-positive ReloadTimeout. +const defaultReloadTimeout = 30 * time.Second + +// ReloadOrchestrator coordinates configuration reloading across all registered +// Reloadable modules. 
It provides single-flight execution, circuit breaking, +// rollback on partial failure, and event emission via the observer pattern. +// +// Note: Application-level integration (Application.RequestReload(), WithDynamicReload() +// builder option) will be added when the Application interface is extended in a follow-up. +type ReloadOrchestrator struct { + mu sync.RWMutex + reloadables map[string]Reloadable + + requestCh chan ReloadRequest + stopped atomic.Bool + stopOnce sync.Once + + processing atomic.Bool + + // Circuit breaker state + cbMu sync.Mutex + failures int + lastFailure time.Time + circuitOpen bool + + logger Logger + subject Subject +} + +// nopLogger is a no-op Logger used when nil is passed. +type nopLogger struct{} + +func (nopLogger) Info(_ string, _ ...any) {} +func (nopLogger) Error(_ string, _ ...any) {} +func (nopLogger) Warn(_ string, _ ...any) {} +func (nopLogger) Debug(_ string, _ ...any) {} + +// NewReloadOrchestrator creates a new ReloadOrchestrator with the given logger and event subject. +// If logger is nil, a no-op logger is used. +func NewReloadOrchestrator(logger Logger, subject Subject) *ReloadOrchestrator { + if logger == nil { + logger = nopLogger{} + } + return &ReloadOrchestrator{ + reloadables: make(map[string]Reloadable), + requestCh: make(chan ReloadRequest, 100), + logger: logger, + subject: subject, + } +} + +// RegisterReloadable registers a named module as reloadable. +func (o *ReloadOrchestrator) RegisterReloadable(name string, module Reloadable) { + o.mu.Lock() + defer o.mu.Unlock() + o.reloadables[name] = module +} + +// RequestReload enqueues a reload request. It returns an error if the orchestrator +// is stopped, the request channel is full, or the circuit breaker is open. +// +// The method is safe to call concurrently with Stop(). A recover guard protects +// against the send-on-closed-channel panic that can occur when Stop() closes +// requestCh between the stopped check and the channel send. 
+func (o *ReloadOrchestrator) RequestReload(ctx context.Context, trigger ReloadTrigger, diff ConfigDiff) (retErr error) { + if o.stopped.Load() { + return ErrReloadStopped + } + if o.isCircuitOpen() { + return ErrReloadCircuitBreakerOpen + } + + // Recover from a send on closed channel if Stop() races between the + // stopped check above and the channel send below. + defer func() { + if r := recover(); r != nil { + retErr = ErrReloadStopped + } + }() + + select { + case o.requestCh <- ReloadRequest{Trigger: trigger, Diff: diff, Ctx: ctx}: + return nil + default: + return ErrReloadChannelFull + } +} + +// Start begins the background goroutine that drains the reload request queue. +func (o *ReloadOrchestrator) Start(ctx context.Context) { + go func() { + for { + select { + case <-ctx.Done(): + return + case req, ok := <-o.requestCh: + if !ok { + return + } + o.handleReload(ctx, req) + } + } + }() +} + +// handleReload derives a properly scoped context for a single reload request and +// processes it. The context is cancelled immediately after processReload returns +// to avoid resource leaks from accumulated timers in the processing loop. +// +// The reload context is rooted in parentCtx (the Start context) so that stopping +// the orchestrator always cancels in-flight work. When the request carries its +// own context, both its deadline and cancellation are wired in: deadline via +// context.WithDeadline, and cancellation via a background goroutine that watches +// req.Ctx.Done(). This ensures callers who cancel req.Ctx abort the reload. +func (o *ReloadOrchestrator) handleReload(parentCtx context.Context, req ReloadRequest) { + rctx, cancel := context.WithCancel(parentCtx) + defer cancel() + + if req.Ctx != nil { + // Apply deadline if present. + if deadline, ok := req.Ctx.Deadline(); ok { + rctx, cancel = context.WithDeadline(rctx, deadline) //nolint:contextcheck // deadline from request + defer cancel() + } + + // Propagate cancellation from the request context. 
When req.Ctx is + // cancelled, cancel rctx so module Reload calls see cancellation. + go func() { + select { + case <-req.Ctx.Done(): + cancel() + case <-rctx.Done(): + // rctx already done (parent cancelled or reload finished); stop goroutine. + } + }() + } + + if err := o.processReload(rctx, req); err != nil { + o.logger.Error("Reload failed", "trigger", req.Trigger.String(), "error", err) + } +} + +// Stop signals the background goroutine to exit. It is safe to call multiple times. +func (o *ReloadOrchestrator) Stop() { + o.stopOnce.Do(func() { + o.stopped.Store(true) + close(o.requestCh) + }) +} + +// processReload executes a single reload request with atomic single-flight semantics, +// rollback on partial failure, and event emission. +func (o *ReloadOrchestrator) processReload(ctx context.Context, req ReloadRequest) error { + // Single-flight: only one reload at a time. + if !o.processing.CompareAndSwap(false, true) { + o.logger.Warn("Reload already in progress, skipping request") + return ErrReloadInProgress + } + defer o.processing.Store(false) + + // Noop if no changes — emit noop without a misleading "started" event. + if !req.Diff.HasChanges() { + o.emitEvent(ctx, EventTypeConfigReloadNoop, map[string]interface{}{ + "trigger": req.Trigger.String(), + "diffId": req.Diff.DiffID, + }) + return nil + } + + // Emit started event only when there are actual changes to apply. + o.emitEvent(ctx, EventTypeConfigReloadStarted, map[string]interface{}{ + "trigger": req.Trigger.String(), + "diffId": req.Diff.DiffID, + "summary": req.Diff.ChangeSummary(), + }) + + // Build the list of changes for the Reloadable interface. + changes := o.buildChanges(req.Diff) + + // Snapshot current reloadables under read lock, sorted by name for + // deterministic reload/rollback ordering. 
+ o.mu.RLock() + targets := make([]reloadEntry, 0, len(o.reloadables)) + for name, mod := range o.reloadables { + targets = append(targets, reloadEntry{name: name, module: mod}) + } + o.mu.RUnlock() + + sort.Slice(targets, func(i, j int) bool { + return targets[i].name < targets[j].name + }) + + // Track which modules have been successfully reloaded (for rollback). + var applied []reloadEntry + + for _, t := range targets { + if !t.module.CanReload() { + o.logger.Info("Module cannot reload, skipping", "module", t.name) + continue + } + + timeout := t.module.ReloadTimeout() + if timeout <= 0 { + timeout = defaultReloadTimeout + } + rctx, cancel := context.WithTimeout(ctx, timeout) + + err := t.module.Reload(rctx, changes) + cancel() + + if err != nil { + o.logger.Error("Module reload failed, initiating rollback", + "module", t.name, "error", err) + + // Rollback already-applied modules in reverse order. + o.rollback(ctx, applied, changes) + + o.recordFailure() + o.emitEvent(ctx, EventTypeConfigReloadFailed, map[string]interface{}{ + "trigger": req.Trigger.String(), + "diffId": req.Diff.DiffID, + "failedModule": t.name, + "error": err.Error(), + }) + return fmt.Errorf("reload failed at module %s: %w", t.name, err) + } + + applied = append(applied, t) + } + + o.recordSuccess() + o.emitEvent(ctx, EventTypeConfigReloadCompleted, map[string]interface{}{ + "trigger": req.Trigger.String(), + "diffId": req.Diff.DiffID, + "modulesLoaded": len(applied), + }) + return nil +} + +// buildChanges converts a ConfigDiff into a flat slice of ConfigChange entries. 
+func (o *ReloadOrchestrator) buildChanges(diff ConfigDiff) []ConfigChange { + var changes []ConfigChange + for path, fc := range diff.Added { + changes = append(changes, ConfigChange{ + FieldPath: path, + OldValue: fmt.Sprintf("%v", fc.OldValue), + NewValue: fmt.Sprintf("%v", fc.NewValue), + Source: "diff", + }) + } + for path, fc := range diff.Changed { + changes = append(changes, ConfigChange{ + FieldPath: path, + OldValue: fmt.Sprintf("%v", fc.OldValue), + NewValue: fmt.Sprintf("%v", fc.NewValue), + Source: "diff", + }) + } + for path, fc := range diff.Removed { + changes = append(changes, ConfigChange{ + FieldPath: path, + OldValue: fmt.Sprintf("%v", fc.OldValue), + NewValue: fmt.Sprintf("%v", fc.NewValue), + Source: "diff", + }) + } + return changes +} + +// rollback attempts to reverse already-applied changes on modules in reverse order. +// This is best-effort: errors are logged but not propagated. +func (o *ReloadOrchestrator) rollback(ctx context.Context, applied []reloadEntry, originalChanges []ConfigChange) { + // Build reverse changes (swap old and new values). + reverseChanges := make([]ConfigChange, len(originalChanges)) + for i, c := range originalChanges { + reverseChanges[i] = ConfigChange{ + Section: c.Section, + FieldPath: c.FieldPath, + OldValue: c.NewValue, + NewValue: c.OldValue, + Source: "rollback", + } + } + + // Apply in reverse order. + for i := len(applied) - 1; i >= 0; i-- { + t := applied[i] + timeout := t.module.ReloadTimeout() + if timeout <= 0 { + timeout = defaultReloadTimeout + } + rctx, cancel := context.WithTimeout(ctx, timeout) + + if err := t.module.Reload(rctx, reverseChanges); err != nil { + o.logger.Error("Rollback failed for module", "module", t.name, "error", err) + } else { + o.logger.Info("Rollback succeeded for module", "module", t.name) + } + cancel() + } +} + +// emitEvent sends a CloudEvent via the configured subject. 
+func (o *ReloadOrchestrator) emitEvent(ctx context.Context, eventType string, data map[string]interface{}) { + if o.subject == nil { + return + } + event := NewCloudEvent(eventType, "modular.reload.orchestrator", data, nil) + if err := o.subject.NotifyObservers(ctx, event); err != nil { + o.logger.Debug("Failed to emit reload event", "eventType", eventType, "error", err) + } +} + +// Circuit breaker methods. + +const ( + circuitBreakerThreshold = 3 + circuitBreakerBaseDelay = 2 * time.Second + circuitBreakerMaxDelay = 2 * time.Minute +) + +func (o *ReloadOrchestrator) isCircuitOpen() bool { + o.cbMu.Lock() + defer o.cbMu.Unlock() + if !o.circuitOpen { + return false + } + // Check if the backoff period has elapsed. + if time.Since(o.lastFailure) > o.backoffDuration() { + o.circuitOpen = false + o.logger.Info("Reload circuit breaker reset after backoff") + return false + } + return true +} + +func (o *ReloadOrchestrator) recordSuccess() { + o.cbMu.Lock() + defer o.cbMu.Unlock() + o.failures = 0 + o.circuitOpen = false +} + +func (o *ReloadOrchestrator) recordFailure() { + o.cbMu.Lock() + defer o.cbMu.Unlock() + o.failures++ + o.lastFailure = time.Now() + if o.failures >= circuitBreakerThreshold { + o.circuitOpen = true + o.logger.Warn("Reload circuit breaker opened", + "failures", o.failures, + "backoff", o.backoffDuration().String()) + } +} + +func (o *ReloadOrchestrator) backoffDuration() time.Duration { + if o.failures <= 0 { + return circuitBreakerBaseDelay + } + d := circuitBreakerBaseDelay + for i := 1; i < o.failures; i++ { + d *= 2 + if d > circuitBreakerMaxDelay { + return circuitBreakerMaxDelay + } + } + return d +} diff --git a/reload_test.go b/reload_test.go new file mode 100644 index 00000000..a5f77d71 --- /dev/null +++ b/reload_test.go @@ -0,0 +1,486 @@ +package modular + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + cloudevents "github.com/cloudevents/sdk-go/v2" +) + +// mockReloadable is a 
test double for the Reloadable interface. +type mockReloadable struct { + canReload bool + timeout time.Duration + reloadErr error + reloadCalls atomic.Int32 + lastChanges []ConfigChange + mu sync.Mutex +} + +func (m *mockReloadable) Reload(_ context.Context, changes []ConfigChange) error { + m.reloadCalls.Add(1) + m.mu.Lock() + m.lastChanges = changes + m.mu.Unlock() + return m.reloadErr +} + +func (m *mockReloadable) CanReload() bool { return m.canReload } +func (m *mockReloadable) ReloadTimeout() time.Duration { return m.timeout } + +func (m *mockReloadable) getLastChanges() []ConfigChange { + m.mu.Lock() + defer m.mu.Unlock() + result := make([]ConfigChange, len(m.lastChanges)) + copy(result, m.lastChanges) + return result +} + +// reloadTestLogger implements Logger for testing. +type reloadTestLogger struct { + mu sync.Mutex + messages []string +} + +func (l *reloadTestLogger) Info(msg string, args ...any) { l.record("INFO", msg, args...) } +func (l *reloadTestLogger) Error(msg string, args ...any) { l.record("ERROR", msg, args...) } +func (l *reloadTestLogger) Warn(msg string, args ...any) { l.record("WARN", msg, args...) } +func (l *reloadTestLogger) Debug(msg string, args ...any) { l.record("DEBUG", msg, args...) } + +func (l *reloadTestLogger) record(level, msg string, args ...any) { + l.mu.Lock() + defer l.mu.Unlock() + l.messages = append(l.messages, fmt.Sprintf("[%s] %s %v", level, msg, args)) +} + +// reloadTestSubject is a minimal Subject for capturing events in reload tests. 
+type reloadTestSubject struct { + mu sync.Mutex + events []cloudevents.Event +} + +func (s *reloadTestSubject) RegisterObserver(_ Observer, _ ...string) error { return nil } +func (s *reloadTestSubject) UnregisterObserver(_ Observer) error { return nil } +func (s *reloadTestSubject) GetObservers() []ObserverInfo { return nil } +func (s *reloadTestSubject) NotifyObservers(_ context.Context, event cloudevents.Event) error { + s.mu.Lock() + s.events = append(s.events, event) + s.mu.Unlock() + return nil +} + +func (s *reloadTestSubject) getEvents() []cloudevents.Event { + s.mu.Lock() + defer s.mu.Unlock() + result := make([]cloudevents.Event, len(s.events)) + copy(result, s.events) + return result +} + +func (s *reloadTestSubject) eventTypes() []string { + s.mu.Lock() + defer s.mu.Unlock() + var types []string + for _, e := range s.events { + types = append(types, e.Type()) + } + return types +} + +// --- ConfigDiff tests --- + +func TestConfigDiff_HasChanges(t *testing.T) { + t.Run("empty diff", func(t *testing.T) { + d := ConfigDiff{ + Changed: make(map[string]FieldChange), + Added: make(map[string]FieldChange), + Removed: make(map[string]FieldChange), + } + if d.HasChanges() { + t.Error("expected no changes") + } + }) + t.Run("with changed", func(t *testing.T) { + d := ConfigDiff{ + Changed: map[string]FieldChange{"a": {OldValue: 1, NewValue: 2}}, + Added: make(map[string]FieldChange), + Removed: make(map[string]FieldChange), + } + if !d.HasChanges() { + t.Error("expected changes") + } + }) + t.Run("with added", func(t *testing.T) { + d := ConfigDiff{ + Changed: make(map[string]FieldChange), + Added: map[string]FieldChange{"b": {NewValue: "x"}}, + Removed: make(map[string]FieldChange), + } + if !d.HasChanges() { + t.Error("expected changes") + } + }) + t.Run("with removed", func(t *testing.T) { + d := ConfigDiff{ + Changed: make(map[string]FieldChange), + Added: make(map[string]FieldChange), + Removed: map[string]FieldChange{"c": {OldValue: "y"}}, + } + if 
!d.HasChanges() { + t.Error("expected changes") + } + }) +} + +func TestConfigDiff_FilterByPrefix(t *testing.T) { + d := ConfigDiff{ + Changed: map[string]FieldChange{ + "db.host": {OldValue: "old", NewValue: "new"}, + "db.port": {OldValue: 3306, NewValue: 5432}, + "cache.ttl": {OldValue: 30, NewValue: 60}, + }, + Added: map[string]FieldChange{ + "db.ssl": {NewValue: true}, + }, + Removed: map[string]FieldChange{ + "cache.max": {OldValue: 100}, + }, + } + + filtered := d.FilterByPrefix("db.") + if len(filtered.Changed) != 2 { + t.Errorf("expected 2 changed, got %d", len(filtered.Changed)) + } + if len(filtered.Added) != 1 { + t.Errorf("expected 1 added, got %d", len(filtered.Added)) + } + if len(filtered.Removed) != 0 { + t.Errorf("expected 0 removed, got %d", len(filtered.Removed)) + } + + cacheFiltered := d.FilterByPrefix("cache.") + if len(cacheFiltered.Changed) != 1 { + t.Errorf("expected 1 changed for cache prefix, got %d", len(cacheFiltered.Changed)) + } + if len(cacheFiltered.Removed) != 1 { + t.Errorf("expected 1 removed for cache prefix, got %d", len(cacheFiltered.Removed)) + } +} + +func TestConfigDiff_RedactSensitiveFields(t *testing.T) { + d := ConfigDiff{ + Changed: map[string]FieldChange{ + "db.password": {OldValue: "secret1", NewValue: "secret2", IsSensitive: true}, + "db.host": {OldValue: "old", NewValue: "new", IsSensitive: false}, + }, + Added: make(map[string]FieldChange), + Removed: make(map[string]FieldChange), + } + + redacted := d.RedactSensitiveFields() + + pw := redacted.Changed["db.password"] + if pw.OldValue != "[REDACTED]" || pw.NewValue != "[REDACTED]" { + t.Errorf("sensitive field not redacted: old=%v new=%v", pw.OldValue, pw.NewValue) + } + + host := redacted.Changed["db.host"] + if host.OldValue != "old" || host.NewValue != "new" { + t.Errorf("non-sensitive field should not be redacted: old=%v new=%v", host.OldValue, host.NewValue) + } + + // Verify original is not mutated. 
+ origPw := d.Changed["db.password"] + if origPw.OldValue != "secret1" { + t.Error("original diff should not be mutated") + } +} + +func TestConfigDiff_ChangeSummary(t *testing.T) { + t.Run("no changes", func(t *testing.T) { + d := ConfigDiff{ + Changed: make(map[string]FieldChange), + Added: make(map[string]FieldChange), + Removed: make(map[string]FieldChange), + } + s := d.ChangeSummary() + if s != "no changes" { + t.Errorf("expected 'no changes', got %q", s) + } + }) + t.Run("mixed changes", func(t *testing.T) { + d := ConfigDiff{ + Changed: map[string]FieldChange{"a": {}}, + Added: map[string]FieldChange{"b": {}, "c": {}}, + Removed: map[string]FieldChange{"d": {}}, + } + s := d.ChangeSummary() + if !strings.Contains(s, "2 added") { + t.Errorf("summary missing added count: %q", s) + } + if !strings.Contains(s, "1 modified") { + t.Errorf("summary missing modified count: %q", s) + } + if !strings.Contains(s, "1 removed") { + t.Errorf("summary missing removed count: %q", s) + } + }) +} + +// waitFor polls cond every 5ms until it returns true or timeout elapses. +// Returns true if cond was satisfied, false on timeout. 
+func waitFor(t *testing.T, timeout time.Duration, cond func() bool) bool { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if cond() { + return true + } + time.Sleep(5 * time.Millisecond) + } + return false +} + +// --- ReloadOrchestrator tests --- + +func newTestDiff() ConfigDiff { + return ConfigDiff{ + Changed: map[string]FieldChange{ + "db.host": {OldValue: "localhost", NewValue: "remotehost", ChangeType: ChangeModified}, + }, + Added: make(map[string]FieldChange), + Removed: make(map[string]FieldChange), + Timestamp: time.Now(), + DiffID: "test-diff-1", + } +} + +func TestReloadOrchestrator_SuccessfulReload(t *testing.T) { + logger := &reloadTestLogger{} + subject := &reloadTestSubject{} + orch := NewReloadOrchestrator(logger, subject) + + mod := &mockReloadable{canReload: true, timeout: 5 * time.Second} + orch.RegisterReloadable("testmod", mod) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + orch.Start(ctx) + + diff := newTestDiff() + if err := orch.RequestReload(ctx, ReloadManual, diff); err != nil { + t.Fatalf("RequestReload failed: %v", err) + } + + if !waitFor(t, 2*time.Second, func() bool { return mod.reloadCalls.Load() >= 1 }) { + t.Fatalf("timed out waiting for reload call, got %d", mod.reloadCalls.Load()) + } + + if !waitFor(t, 2*time.Second, func() bool { return len(subject.eventTypes()) >= 2 }) { + t.Fatalf("timed out waiting for events, got %d", len(subject.eventTypes())) + } + + events := subject.eventTypes() + if events[0] != EventTypeConfigReloadStarted { + t.Errorf("expected started event, got %s", events[0]) + } + if events[len(events)-1] != EventTypeConfigReloadCompleted { + t.Errorf("expected completed event, got %s", events[len(events)-1]) + } +} + +func TestReloadOrchestrator_PartialFailure_Rollback(t *testing.T) { + logger := &reloadTestLogger{} + subject := &reloadTestSubject{} + orch := NewReloadOrchestrator(logger, subject) + + mod1 := &mockReloadable{canReload: 
true, timeout: 5 * time.Second} + mod2 := &mockReloadable{canReload: true, timeout: 5 * time.Second, reloadErr: errors.New("boom")} + orch.RegisterReloadable("aaa_first", mod1) + orch.RegisterReloadable("zzz_second", mod2) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + orch.Start(ctx) + + diff := newTestDiff() + if err := orch.RequestReload(ctx, ReloadManual, diff); err != nil { + t.Fatalf("RequestReload failed: %v", err) + } + + if !waitFor(t, 2*time.Second, func() bool { + return len(subject.eventTypes()) > 0 && subject.eventTypes()[len(subject.eventTypes())-1] == EventTypeConfigReloadFailed + }) { + t.Fatal("timed out waiting for reload failure event") + } + + // Targets are sorted by name: aaa_first runs before zzz_second. + // aaa_first succeeds, then zzz_second fails, triggering rollback of aaa_first. + // So aaa_first gets 2 calls (apply + rollback) and zzz_second gets 1 call (the failure). + calls1 := mod1.reloadCalls.Load() + calls2 := mod2.reloadCalls.Load() + + if calls1 != 2 { + t.Errorf("expected aaa_first to be called 2 times (apply+rollback), got %d", calls1) + } + + if calls2 != 1 { + t.Errorf("expected zzz_second to be called 1 time (the failure), got %d", calls2) + } + + // Verify a failed event was emitted. 
+ hasFailedEvent := false + for _, et := range subject.eventTypes() { + if et == EventTypeConfigReloadFailed { + hasFailedEvent = true + } + } + if !hasFailedEvent { + t.Error("expected ConfigReloadFailed event") + } +} + +func TestReloadOrchestrator_CircuitBreaker(t *testing.T) { + logger := &reloadTestLogger{} + subject := &reloadTestSubject{} + orch := NewReloadOrchestrator(logger, subject) + + failMod := &mockReloadable{canReload: true, timeout: 5 * time.Second, reloadErr: errors.New("fail")} + orch.RegisterReloadable("failing", failMod) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + orch.Start(ctx) + + diff := newTestDiff() + + // Trigger enough failures to open the circuit breaker. + for i := 0; i < circuitBreakerThreshold; i++ { + if err := orch.RequestReload(ctx, ReloadManual, diff); err != nil { + t.Fatalf("RequestReload %d failed: %v", i, err) + } + expected := int32(i + 1) + if !waitFor(t, 2*time.Second, func() bool { return failMod.reloadCalls.Load() >= expected }) { + t.Fatalf("timed out waiting for reload call %d", i+1) + } + } + + // Next request should be rejected by the circuit breaker. 
+ err := orch.RequestReload(ctx, ReloadManual, diff) + if err == nil { + t.Error("expected circuit breaker error, got nil") + } + if err != nil && !strings.Contains(err.Error(), "circuit breaker") { + t.Errorf("expected circuit breaker error, got: %v", err) + } +} + +func TestReloadOrchestrator_CanReloadFalse_Skipped(t *testing.T) { + logger := &reloadTestLogger{} + subject := &reloadTestSubject{} + orch := NewReloadOrchestrator(logger, subject) + + mod := &mockReloadable{canReload: false, timeout: 5 * time.Second} + orch.RegisterReloadable("disabled", mod) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + orch.Start(ctx) + + diff := newTestDiff() + if err := orch.RequestReload(ctx, ReloadManual, diff); err != nil { + t.Fatalf("RequestReload failed: %v", err) + } + + if !waitFor(t, 2*time.Second, func() bool { + for _, et := range subject.eventTypes() { + if et == EventTypeConfigReloadCompleted { + return true + } + } + return false + }) { + t.Fatal("timed out waiting for ConfigReloadCompleted event") + } + + if mod.reloadCalls.Load() != 0 { + t.Errorf("expected 0 reload calls for disabled module, got %d", mod.reloadCalls.Load()) + } +} + +func TestReloadOrchestrator_ConcurrentRequests(t *testing.T) { + logger := &reloadTestLogger{} + subject := &reloadTestSubject{} + orch := NewReloadOrchestrator(logger, subject) + + mod := &mockReloadable{canReload: true, timeout: 5 * time.Second} + orch.RegisterReloadable("concurrent", mod) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + orch.Start(ctx) + + diff := newTestDiff() + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _ = orch.RequestReload(ctx, ReloadManual, diff) + }() + } + wg.Wait() + + if !waitFor(t, 2*time.Second, func() bool { return mod.reloadCalls.Load() >= 1 }) { + t.Fatalf("timed out waiting for at least 1 reload call, got %d", mod.reloadCalls.Load()) + } + + calls := mod.reloadCalls.Load() + // Due 
to single-flight, some may be skipped — that's expected. + t.Logf("concurrent test: %d reload calls processed out of 10 requests", calls) +} + +func TestReloadOrchestrator_NoopOnEmptyDiff(t *testing.T) { + logger := &reloadTestLogger{} + subject := &reloadTestSubject{} + orch := NewReloadOrchestrator(logger, subject) + + mod := &mockReloadable{canReload: true, timeout: 5 * time.Second} + orch.RegisterReloadable("mod", mod) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + orch.Start(ctx) + + emptyDiff := ConfigDiff{ + Changed: make(map[string]FieldChange), + Added: make(map[string]FieldChange), + Removed: make(map[string]FieldChange), + DiffID: "empty", + } + if err := orch.RequestReload(ctx, ReloadManual, emptyDiff); err != nil { + t.Fatalf("RequestReload failed: %v", err) + } + + if !waitFor(t, 2*time.Second, func() bool { + for _, et := range subject.eventTypes() { + if et == EventTypeConfigReloadNoop { + return true + } + } + return false + }) { + t.Fatal("timed out waiting for ConfigReloadNoop event") + } + + if mod.reloadCalls.Load() != 0 { + t.Errorf("expected 0 reload calls for empty diff, got %d", mod.reloadCalls.Load()) + } +} diff --git a/tenant_guard.go b/tenant_guard.go new file mode 100644 index 00000000..625ab801 --- /dev/null +++ b/tenant_guard.go @@ -0,0 +1,294 @@ +package modular + +import ( + "context" + "fmt" + "sync" + "time" +) + +// TenantGuardMode controls how the tenant guard responds to violations. +type TenantGuardMode int + +const ( + // TenantGuardStrict blocks the operation and returns an error on violation. + TenantGuardStrict TenantGuardMode = iota + // TenantGuardLenient records the violation and allows the operation to proceed. + // Violations are logged when LogViolations is true and a logger is configured. + TenantGuardLenient + // TenantGuardDisabled performs no validation at all. + TenantGuardDisabled +) + +// String returns the string representation of a TenantGuardMode. 
+func (m TenantGuardMode) String() string { + switch m { + case TenantGuardStrict: + return "strict" + case TenantGuardLenient: + return "lenient" + case TenantGuardDisabled: + return "disabled" + default: + return fmt.Sprintf("unknown(%d)", int(m)) + } +} + +// ViolationType categorizes the kind of tenant boundary violation. +type ViolationType int + +const ( + // CrossTenant indicates an attempt to access another tenant's resources. + CrossTenant ViolationType = iota + // InvalidContext indicates the tenant context is malformed or invalid. + InvalidContext + // MissingContext indicates no tenant context was provided. + MissingContext + // Unauthorized indicates the caller lacks permission for the tenant operation. + Unauthorized +) + +// String returns the string representation of a ViolationType. +func (v ViolationType) String() string { + switch v { + case CrossTenant: + return "cross_tenant" + case InvalidContext: + return "invalid_context" + case MissingContext: + return "missing_context" + case Unauthorized: + return "unauthorized" + default: + return fmt.Sprintf("unknown(%d)", int(v)) + } +} + +// Severity indicates the severity level of a tenant violation. +type Severity int + +const ( + // SeverityLow indicates a minor violation. + SeverityLow Severity = iota + // SeverityMedium indicates a moderate violation. + SeverityMedium + // SeverityHigh indicates a serious violation. + SeverityHigh + // SeverityCritical indicates a critical violation requiring immediate attention. + SeverityCritical +) + +// String returns the string representation of a Severity. +func (s Severity) String() string { + switch s { + case SeverityLow: + return "low" + case SeverityMedium: + return "medium" + case SeverityHigh: + return "high" + case SeverityCritical: + return "critical" + default: + return fmt.Sprintf("unknown(%d)", int(s)) + } +} + +// TenantViolation represents a detected tenant boundary violation. 
+type TenantViolation struct { + Type ViolationType + Severity Severity + TenantID string + TargetID string + Timestamp time.Time + Details string +} + +// TenantGuard validates tenant access and tracks violations. +type TenantGuard interface { + // GetMode returns the current guard mode. + GetMode() TenantGuardMode + + // ValidateAccess checks whether the given violation should be blocked. + // In Strict mode, it returns an error. In Lenient mode, it records the + // violation but returns nil. In Disabled mode, it is a no-op. + ValidateAccess(ctx context.Context, violation TenantViolation) error + + // GetRecentViolations returns a deep copy of recent violations, ordered oldest-first. + GetRecentViolations() []TenantViolation +} + +// TenantGuardConfig holds configuration for a StandardTenantGuard. +type TenantGuardConfig struct { + Mode TenantGuardMode + Whitelist map[string][]string // tenantID -> allowed target IDs + MaxViolations int // ring buffer capacity, default 1000 + LogViolations bool // whether to log violations, default true +} + +// DefaultTenantGuardConfig returns a TenantGuardConfig with sensible defaults. +func DefaultTenantGuardConfig() TenantGuardConfig { + return TenantGuardConfig{ + Mode: TenantGuardStrict, + Whitelist: make(map[string][]string), + MaxViolations: 1000, + LogViolations: true, + } +} + +// TenantGuardOption is a functional option for configuring a StandardTenantGuard. +type TenantGuardOption func(*StandardTenantGuard) + +// WithTenantGuardLogger sets a structured logger on the guard. +func WithTenantGuardLogger(l Logger) TenantGuardOption { + return func(g *StandardTenantGuard) { + g.logger = l + } +} + +// WithTenantGuardSubject sets a Subject for event emission on the guard. +func WithTenantGuardSubject(s Subject) TenantGuardOption { + return func(g *StandardTenantGuard) { + g.subject = s + } +} + +// StandardTenantGuard is the default TenantGuard implementation. 
+// It uses a ring buffer to store recent violations and optionally emits +// CloudEvents when violations are detected. +type StandardTenantGuard struct { + config TenantGuardConfig + whitelist map[string]map[string]struct{} // deep-copied set for fast lookups + violations []TenantViolation + head int + count int + mu sync.RWMutex + logger Logger + subject Subject +} + +// NewStandardTenantGuard creates a new StandardTenantGuard with the given config and options. +// The whitelist is deep-copied and converted to a set for safe, fast lookups. +func NewStandardTenantGuard(config TenantGuardConfig, opts ...TenantGuardOption) *StandardTenantGuard { + if config.MaxViolations <= 0 { + config.MaxViolations = 1000 + } + + // Deep-copy and convert whitelist to set + wl := make(map[string]map[string]struct{}, len(config.Whitelist)) + for tenant, targets := range config.Whitelist { + set := make(map[string]struct{}, len(targets)) + for _, t := range targets { + set[t] = struct{}{} + } + wl[tenant] = set + } + + g := &StandardTenantGuard{ + config: config, + whitelist: wl, + violations: make([]TenantViolation, config.MaxViolations), + } + + for _, opt := range opts { + opt(g) + } + + return g +} + +// GetMode returns the current guard mode. +func (g *StandardTenantGuard) GetMode() TenantGuardMode { + return g.config.Mode +} + +// ValidateAccess checks the violation against the guard's policy. 
+func (g *StandardTenantGuard) ValidateAccess(ctx context.Context, violation TenantViolation) error { + if g.config.Mode == TenantGuardDisabled { + return nil + } + + // Set timestamp if not provided + if violation.Timestamp.IsZero() { + violation.Timestamp = time.Now() + } + + // Check whitelist (set-based O(1) lookup) + if targets, ok := g.whitelist[violation.TenantID]; ok { + if _, allowed := targets[violation.TargetID]; allowed { + return nil + } + } + + // Record violation + g.mu.Lock() + g.addViolation(violation) + g.mu.Unlock() + + // Log if configured + if g.config.LogViolations && g.logger != nil { + g.logger.Warn("Tenant violation detected", + "type", violation.Type.String(), + "severity", violation.Severity.String(), + "tenant", violation.TenantID, + "target", violation.TargetID, + "details", violation.Details, + ) + } + + // Emit event using NewCloudEvent helper (sets ID, specversion, time) + if g.subject != nil { + event := NewCloudEvent(EventTypeTenantViolation, "com.modular.tenant.guard", violation, nil) + if err := g.subject.NotifyObservers(ctx, event); err != nil && g.logger != nil { + g.logger.Warn("Failed to emit tenant violation event", + "error", err, + "tenant", violation.TenantID, + "type", violation.Type.String(), + ) + } + } + + // In strict mode, return error + if g.config.Mode == TenantGuardStrict { + return ErrTenantIsolationViolation + } + + // Lenient mode: violation recorded, but allow the operation + return nil +} + +// GetRecentViolations returns a deep copy of recent violations ordered oldest-first. 
+func (g *StandardTenantGuard) GetRecentViolations() []TenantViolation { + g.mu.RLock() + defer g.mu.RUnlock() + + if g.count == 0 { + return nil + } + + result := make([]TenantViolation, g.count) + max := g.config.MaxViolations + + if g.count < max { + // Buffer not yet full — entries are at indices 0..count-1 + copy(result, g.violations[:g.count]) + } else { + // Buffer full — oldest is at head, wrap around + oldest := g.head % max + n := copy(result, g.violations[oldest:]) + copy(result[n:], g.violations[:oldest]) + } + + return result +} + +// addViolation writes a violation into the ring buffer. +// Caller must hold the write lock. +func (g *StandardTenantGuard) addViolation(v TenantViolation) { + max := g.config.MaxViolations + g.violations[g.head%max] = v + g.head++ + if g.count < max { + g.count++ + } +} diff --git a/tenant_guard_test.go b/tenant_guard_test.go new file mode 100644 index 00000000..60d8b3fd --- /dev/null +++ b/tenant_guard_test.go @@ -0,0 +1,335 @@ +package modular + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" +) + +// tenantGuardTestLogger is a Logger implementation that counts Warn calls for testing. 
+type tenantGuardTestLogger struct { + warnCalls atomic.Int32 +} + +func (l *tenantGuardTestLogger) Info(_ string, _ ...any) {} +func (l *tenantGuardTestLogger) Error(_ string, _ ...any) {} +func (l *tenantGuardTestLogger) Debug(_ string, _ ...any) {} +func (l *tenantGuardTestLogger) Warn(_ string, _ ...any) { + l.warnCalls.Add(1) +} + +func TestTenantGuardMode_String(t *testing.T) { + tests := []struct { + mode TenantGuardMode + want string + }{ + {TenantGuardStrict, "strict"}, + {TenantGuardLenient, "lenient"}, + {TenantGuardDisabled, "disabled"}, + {TenantGuardMode(99), "unknown(99)"}, + } + + for _, tt := range tests { + got := tt.mode.String() + if got != tt.want { + t.Errorf("TenantGuardMode(%d).String() = %q, want %q", int(tt.mode), got, tt.want) + } + } +} + +func TestViolationType_String(t *testing.T) { + tests := []struct { + vt ViolationType + want string + }{ + {CrossTenant, "cross_tenant"}, + {InvalidContext, "invalid_context"}, + {MissingContext, "missing_context"}, + {Unauthorized, "unauthorized"}, + {ViolationType(99), "unknown(99)"}, + } + + for _, tt := range tests { + got := tt.vt.String() + if got != tt.want { + t.Errorf("ViolationType(%d).String() = %q, want %q", int(tt.vt), got, tt.want) + } + } +} + +func TestSeverity_String(t *testing.T) { + tests := []struct { + sev Severity + want string + }{ + {SeverityLow, "low"}, + {SeverityMedium, "medium"}, + {SeverityHigh, "high"}, + {SeverityCritical, "critical"}, + {Severity(99), "unknown(99)"}, + } + + for _, tt := range tests { + got := tt.sev.String() + if got != tt.want { + t.Errorf("Severity(%d).String() = %q, want %q", int(tt.sev), got, tt.want) + } + } +} + +func TestStandardTenantGuard_StrictMode(t *testing.T) { + config := DefaultTenantGuardConfig() + config.Mode = TenantGuardStrict + guard := NewStandardTenantGuard(config) + + err := guard.ValidateAccess(context.Background(), TenantViolation{ + Type: CrossTenant, + Severity: SeverityHigh, + TenantID: "tenant-1", + TargetID: "tenant-2", + 
Details: "cross-tenant data access", + }) + + if err == nil { + t.Fatal("expected error in strict mode, got nil") + } + if !errors.Is(err, ErrTenantIsolationViolation) { + t.Errorf("expected ErrTenantIsolationViolation, got %v", err) + } + + violations := guard.GetRecentViolations() + if len(violations) != 1 { + t.Fatalf("expected 1 violation recorded, got %d", len(violations)) + } + if violations[0].TenantID != "tenant-1" { + t.Errorf("expected tenant-1, got %s", violations[0].TenantID) + } +} + +func TestStandardTenantGuard_LenientMode(t *testing.T) { + config := DefaultTenantGuardConfig() + config.Mode = TenantGuardLenient + + logger := &tenantGuardTestLogger{} + guard := NewStandardTenantGuard(config, WithTenantGuardLogger(logger)) + + err := guard.ValidateAccess(context.Background(), TenantViolation{ + Type: CrossTenant, + Severity: SeverityMedium, + TenantID: "tenant-1", + TargetID: "tenant-2", + Details: "lenient test", + }) + + if err != nil { + t.Fatalf("expected nil error in lenient mode, got %v", err) + } + + violations := guard.GetRecentViolations() + if len(violations) != 1 { + t.Fatalf("expected 1 violation recorded, got %d", len(violations)) + } + + if logger.warnCalls.Load() == 0 { + t.Error("expected log output for violation, got none") + } +} + +func TestStandardTenantGuard_DisabledMode(t *testing.T) { + config := DefaultTenantGuardConfig() + config.Mode = TenantGuardDisabled + guard := NewStandardTenantGuard(config) + + err := guard.ValidateAccess(context.Background(), TenantViolation{ + Type: CrossTenant, + Severity: SeverityCritical, + TenantID: "tenant-1", + TargetID: "tenant-2", + }) + + if err != nil { + t.Fatalf("expected nil error in disabled mode, got %v", err) + } + + violations := guard.GetRecentViolations() + if len(violations) != 0 { + t.Errorf("expected 0 violations in disabled mode, got %d", len(violations)) + } +} + +func TestStandardTenantGuard_Whitelist(t *testing.T) { + config := DefaultTenantGuardConfig() + config.Mode = 
TenantGuardStrict + config.Whitelist = map[string][]string{ + "tenant-1": {"tenant-2", "tenant-3"}, + } + guard := NewStandardTenantGuard(config) + + // Whitelisted access should succeed + err := guard.ValidateAccess(context.Background(), TenantViolation{ + Type: CrossTenant, + Severity: SeverityHigh, + TenantID: "tenant-1", + TargetID: "tenant-2", + }) + if err != nil { + t.Fatalf("expected nil for whitelisted access, got %v", err) + } + + // Non-whitelisted access should fail in strict mode + err = guard.ValidateAccess(context.Background(), TenantViolation{ + Type: CrossTenant, + Severity: SeverityHigh, + TenantID: "tenant-1", + TargetID: "tenant-99", + }) + if !errors.Is(err, ErrTenantIsolationViolation) { + t.Errorf("expected ErrTenantIsolationViolation for non-whitelisted access, got %v", err) + } + + // Only the non-whitelisted violation should be recorded + violations := guard.GetRecentViolations() + if len(violations) != 1 { + t.Fatalf("expected 1 violation, got %d", len(violations)) + } +} + +func TestStandardTenantGuard_RingBuffer(t *testing.T) { + config := DefaultTenantGuardConfig() + config.Mode = TenantGuardLenient + config.MaxViolations = 5 + config.LogViolations = false + guard := NewStandardTenantGuard(config) + + // Add 8 violations to a buffer of size 5 + for i := 0; i < 8; i++ { + _ = guard.ValidateAccess(context.Background(), TenantViolation{ + Type: CrossTenant, + Severity: SeverityLow, + TenantID: "tenant-1", + TargetID: "target-" + string(rune('A'+i)), + Details: "violation", + }) + } + + violations := guard.GetRecentViolations() + if len(violations) != 5 { + t.Fatalf("expected 5 violations (buffer size), got %d", len(violations)) + } + + // Oldest should be violation index 3 (target-D), newest should be index 7 (target-H) + expectedTargets := []string{"target-D", "target-E", "target-F", "target-G", "target-H"} + for i, v := range violations { + if v.TargetID != expectedTargets[i] { + t.Errorf("violation[%d].TargetID = %q, want %q", i, 
v.TargetID, expectedTargets[i]) + } + } +} + +func TestStandardTenantGuard_GetRecentViolations_DeepCopy(t *testing.T) { + config := DefaultTenantGuardConfig() + config.Mode = TenantGuardLenient + config.LogViolations = false + guard := NewStandardTenantGuard(config) + + _ = guard.ValidateAccess(context.Background(), TenantViolation{ + Type: CrossTenant, + Severity: SeverityHigh, + TenantID: "tenant-1", + TargetID: "tenant-2", + Details: "original", + }) + + // Get a copy and modify it + copy1 := guard.GetRecentViolations() + copy1[0].Details = "modified" + + // Get another copy — it should still have the original value + copy2 := guard.GetRecentViolations() + if copy2[0].Details != "original" { + t.Errorf("internal state was mutated: expected 'original', got %q", copy2[0].Details) + } +} + +func TestStandardTenantGuard_ConcurrentAccess(t *testing.T) { + config := DefaultTenantGuardConfig() + config.Mode = TenantGuardLenient + config.MaxViolations = 100 + config.LogViolations = false + guard := NewStandardTenantGuard(config) + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + _ = guard.ValidateAccess(context.Background(), TenantViolation{ + Type: CrossTenant, + Severity: SeverityLow, + TenantID: "tenant-1", + TargetID: "tenant-2", + }) + }(i) + } + wg.Wait() + + violations := guard.GetRecentViolations() + if len(violations) != 100 { + t.Errorf("expected 100 violations from concurrent access, got %d", len(violations)) + } +} + +func TestStandardTenantGuard_TimestampAutoSet(t *testing.T) { + config := DefaultTenantGuardConfig() + config.Mode = TenantGuardLenient + config.LogViolations = false + guard := NewStandardTenantGuard(config) + + before := time.Now() + _ = guard.ValidateAccess(context.Background(), TenantViolation{ + Type: MissingContext, + Severity: SeverityMedium, + TenantID: "tenant-1", + }) + after := time.Now() + + violations := guard.GetRecentViolations() + if len(violations) != 1 { + 
t.Fatalf("expected 1 violation, got %d", len(violations)) + } + + ts := violations[0].Timestamp + if ts.Before(before) || ts.After(after) { + t.Errorf("timestamp %v not between %v and %v", ts, before, after) + } +} + +func TestStandardTenantGuard_GetMode(t *testing.T) { + for _, mode := range []TenantGuardMode{TenantGuardStrict, TenantGuardLenient, TenantGuardDisabled} { + config := DefaultTenantGuardConfig() + config.Mode = mode + guard := NewStandardTenantGuard(config) + if guard.GetMode() != mode { + t.Errorf("GetMode() = %v, want %v", guard.GetMode(), mode) + } + } +} + +func TestStandardTenantGuard_DefaultMaxViolations(t *testing.T) { + config := DefaultTenantGuardConfig() + if config.MaxViolations != 1000 { + t.Errorf("DefaultTenantGuardConfig().MaxViolations = %d, want 1000", config.MaxViolations) + } + if !config.LogViolations { + t.Error("DefaultTenantGuardConfig().LogViolations should be true") + } + if config.Mode != TenantGuardStrict { + t.Errorf("DefaultTenantGuardConfig().Mode = %v, want strict", config.Mode) + } +} + +// Verify StandardTenantGuard satisfies the TenantGuard interface +var _ TenantGuard = (*StandardTenantGuard)(nil)