diff --git a/docs/installation/configuration.md b/docs/installation/configuration.md index f1d1498f..f802bc79 100644 --- a/docs/installation/configuration.md +++ b/docs/installation/configuration.md @@ -68,3 +68,4 @@ This is the current list of components that can be configured in the `logging.co - `migration_runner` - `scheduler_service` - `workflows_worker` +- `ports_service` diff --git a/server/cmd/root.go b/server/cmd/root.go index 88c6c2a1..c0945fcf 100644 --- a/server/cmd/root.go +++ b/server/cmd/root.go @@ -23,6 +23,7 @@ import ( "github.com/pgEdge/control-plane/server/internal/monitor" "github.com/pgEdge/control-plane/server/internal/orchestrator" "github.com/pgEdge/control-plane/server/internal/orchestrator/swarm" + "github.com/pgEdge/control-plane/server/internal/ports" "github.com/pgEdge/control-plane/server/internal/resource" "github.com/pgEdge/control-plane/server/internal/scheduler" "github.com/pgEdge/control-plane/server/internal/task" @@ -76,6 +77,7 @@ func newRootCmd(i *do.Injector) *cobra.Command { logging.Provide(i) migrate.Provide(i) monitor.Provide(i) + ports.Provide(i) resource.Provide(i) scheduler.Provide(i) workflows.Provide(i) diff --git a/server/internal/config/config.go b/server/internal/config/config.go index 29d69143..beea4e9c 100644 --- a/server/internal/config/config.go +++ b/server/internal/config/config.go @@ -194,6 +194,32 @@ const ( EtcdModeClient EtcdMode = "client" ) +type RandomPorts struct { + Min int `koanf:"min" json:"min,omitempty"` + Max int `koanf:"max" json:"max,omitempty"` +} + +func (r RandomPorts) validate() []error { + var errs []error + if r.Min < 1 { + errs = append(errs, errors.New("min: cannot be less than 1")) + } + if r.Max > 65535 { + errs = append(errs, errors.New("max: cannot be greater than 65535")) + } + if r.Max <= r.Min { + errs = append(errs, errors.New("max: must be greater than min")) + } + return errs +} + +// We're intentionally using a range that's well below the ephemeral port range +// to reduce the risk of interference from the OS. +var defaultRandomPorts = RandomPorts{ + Min: 5432, + Max: 15432, +} + type Config struct { TenantID string `koanf:"tenant_id" json:"tenant_id,omitempty"` HostID string `koanf:"host_id" json:"host_id,omitempty"` @@ -216,6 +242,7 @@ type Config struct { DockerSwarm DockerSwarm `koanf:"docker_swarm" json:"docker_swarm,omitzero"` DatabaseOwnerUID int `koanf:"database_owner_uid" json:"database_owner_uid,omitempty"` ProfilingEnabled bool `koanf:"profiling_enabled" json:"profiling_enabled,omitempty"` + RandomPorts RandomPorts `koanf:"random_ports" json:"random_ports,omitzero"` } // ClientAddress is a convenience function to return the first client address. 
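A hypothetical usage sketch: the new range is surfaced through the `random_ports.min` / `random_ports.max` keys and checked by `Config.Validate`. The helper below is illustrative only (it is not part of this patch) and assumes the remaining config fields validate on their own.

	import "github.com/pgEdge/control-plane/server/internal/config"

	// checkRandomPortRange is a hypothetical helper: it overrides the default
	// 5432-15432 range and lets Config.Validate apply the rules shown above,
	// e.g. failing with "random_ports.max: must be greater than min".
	func checkRandomPortRange(cfg config.Config, min, max int) error {
		cfg.RandomPorts = config.RandomPorts{Min: min, Max: max}
		return cfg.Validate()
	}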
@@ -310,6 +337,9 @@ func (c Config) Validate() error { for _, err := range c.Logging.validate() { errs = append(errs, fmt.Errorf("logging.%w", err)) } + for _, err := range c.RandomPorts.validate() { + errs = append(errs, fmt.Errorf("random_ports.%w", err)) + } if c.Orchestrator != OrchestratorSwarm { errs = append(errs, fmt.Errorf("orchestrator: unsupported orchestrator %q", c.Orchestrator)) } @@ -361,6 +391,7 @@ func DefaultConfig() (Config, error) { EtcdClient: etcdClientDefault, DockerSwarm: defaultDockerSwarm, DatabaseOwnerUID: 26, + RandomPorts: defaultRandomPorts, }, nil } diff --git a/server/internal/database/instance_spec_store.go b/server/internal/database/instance_spec_store.go new file mode 100644 index 00000000..76f3e803 --- /dev/null +++ b/server/internal/database/instance_spec_store.go @@ -0,0 +1,66 @@ +package database + +import ( + clientv3 "go.etcd.io/etcd/client/v3" + + "github.com/pgEdge/control-plane/server/internal/storage" +) + +type StoredInstanceSpec struct { + storage.StoredValue + Spec *InstanceSpec `json:"spec"` +} + +type InstanceSpecStore struct { + client *clientv3.Client + root string +} + +func NewInstanceSpecStore(client *clientv3.Client, root string) *InstanceSpecStore { + return &InstanceSpecStore{ + client: client, + root: root, + } +} + +func (s *InstanceSpecStore) Prefix() string { + return storage.Prefix("/", s.root, "instance_specs") +} + +func (s *InstanceSpecStore) DatabasePrefix(databaseID string) string { + return storage.Prefix(s.Prefix(), databaseID) +} + +func (s *InstanceSpecStore) Key(databaseID, instanceID string) string { + return storage.Key(s.DatabasePrefix(databaseID), instanceID) +} + +func (s *InstanceSpecStore) GetByKey(databaseID, instanceID string) storage.GetOp[*StoredInstanceSpec] { + key := s.Key(databaseID, instanceID) + return storage.NewGetOp[*StoredInstanceSpec](s.client, key) +} + +func (s *InstanceSpecStore) GetByDatabaseID(databaseID string) storage.GetMultipleOp[*StoredInstanceSpec] { + prefix := s.DatabasePrefix(databaseID) + return storage.NewGetPrefixOp[*StoredInstanceSpec](s.client, prefix) +} + +func (s *InstanceSpecStore) GetAll() storage.GetMultipleOp[*StoredInstanceSpec] { + prefix := s.Prefix() + return storage.NewGetPrefixOp[*StoredInstanceSpec](s.client, prefix) +} + +func (s *InstanceSpecStore) Update(item *StoredInstanceSpec) storage.PutOp[*StoredInstanceSpec] { + key := s.Key(item.Spec.DatabaseID, item.Spec.InstanceID) + return storage.NewUpdateOp(s.client, key, item) +} + +func (s *InstanceSpecStore) DeleteByKey(databaseID, instanceID string) storage.DeleteOp { + key := s.Key(databaseID, instanceID) + return storage.NewDeleteKeyOp(s.client, key) +} + +func (s *InstanceSpecStore) DeleteByDatabaseID(databaseID string) storage.DeleteOp { + prefix := s.DatabasePrefix(databaseID) + return storage.NewDeletePrefixOp(s.client, prefix) +} diff --git a/server/internal/database/provide.go b/server/internal/database/provide.go index 9613dd53..ce7c6848 100644 --- a/server/internal/database/provide.go +++ b/server/internal/database/provide.go @@ -6,6 +6,7 @@ import ( "github.com/pgEdge/control-plane/server/internal/config" "github.com/pgEdge/control-plane/server/internal/host" + "github.com/pgEdge/control-plane/server/internal/ports" ) func Provide(i *do.Injector) { @@ -15,6 +16,10 @@ func Provide(i *do.Injector) { func provideService(i *do.Injector) { do.Provide(i, func(i *do.Injector) (*Service, error) { + cfg, err := do.Invoke[config.Config](i) + if err != nil { + return nil, err + } orch, err := 
do.Invoke[Orchestrator](i) if err != nil { return nil, err @@ -27,7 +32,11 @@ func provideService(i *do.Injector) { if err != nil { return nil, err } - return NewService(orch, store, hostSvc), nil + portsSvc, err := do.Invoke[*ports.Service](i) + if err != nil { + return nil, err + } + return NewService(cfg, orch, store, hostSvc, portsSvc), nil }) } diff --git a/server/internal/database/service.go b/server/internal/database/service.go index 29d536a8..458bc699 100644 --- a/server/internal/database/service.go +++ b/server/internal/database/service.go @@ -8,8 +8,11 @@ import ( "github.com/google/uuid" + "github.com/pgEdge/control-plane/server/internal/config" "github.com/pgEdge/control-plane/server/internal/host" + "github.com/pgEdge/control-plane/server/internal/ports" "github.com/pgEdge/control-plane/server/internal/storage" + "github.com/pgEdge/control-plane/server/internal/utils" ) var ( @@ -24,16 +27,26 @@ var ( ) type Service struct { + cfg config.Config orchestrator Orchestrator store *Store hostSvc *host.Service + portsSvc *ports.Service } -func NewService(orchestrator Orchestrator, store *Store, hostSvc *host.Service) *Service { +func NewService( + cfg config.Config, + orchestrator Orchestrator, + store *Store, + hostSvc *host.Service, + portsSvc *ports.Service, +) *Service { return &Service{ + cfg: cfg, orchestrator: orchestrator, store: store, hostSvc: hostSvc, + portsSvc: portsSvc, } } @@ -119,6 +132,18 @@ func (s *Service) UpdateDatabase(ctx context.Context, state DatabaseState, spec } func (s *Service) DeleteDatabase(ctx context.Context, databaseID string) error { + specs, err := s.store.InstanceSpec. + GetByDatabaseID(databaseID). + Exec(ctx) + if err != nil { + return fmt.Errorf("failed to get instance specs: %w", err) + } + for _, spec := range specs { + if err := s.releaseInstancePorts(ctx, spec.Spec); err != nil { + return err + } + } + var ops []storage.TxnOperation spec, err := s.store.Spec.GetByKey(databaseID).Exec(ctx) @@ -141,6 +166,7 @@ func (s *Service) DeleteDatabase(ctx context.Context, databaseID string) error { ops = append(ops, s.store.Instance.DeleteByDatabaseID(databaseID), + s.store.InstanceSpec.DeleteByDatabaseID(databaseID), s.store.InstanceStatus.DeleteByDatabaseID(databaseID), ) @@ -279,7 +305,7 @@ func (s *Service) DeleteInstance(ctx context.Context, databaseID, instanceID str return fmt.Errorf("failed to delete stored instance status: %w", err) } - return nil + return s.DeleteInstanceSpec(ctx, databaseID, instanceID) } func (s *Service) UpdateInstanceStatus( @@ -490,6 +516,134 @@ func (s *Service) PopulateSpecDefaults(ctx context.Context, spec *Spec) error { return nil } +func (s *Service) ReconcileInstanceSpec(ctx context.Context, spec *InstanceSpec) (*InstanceSpec, error) { + if s.cfg.HostID != spec.HostID { + return nil, fmt.Errorf("this instance belongs to another host - this host='%s', instance host='%s'", s.cfg.HostID, spec.HostID) + } + + var previous *InstanceSpec + stored, err := s.store.InstanceSpec. + GetByKey(spec.DatabaseID, spec.InstanceID). 
+ Exec(ctx) + switch { + case err == nil: + previous = stored.Spec + spec.CopySettingsFrom(previous) + case errors.Is(err, storage.ErrNotFound): + stored = &StoredInstanceSpec{} + default: + return nil, fmt.Errorf("failed to get current spec for instance '%s': %w", spec.InstanceID, err) + } + + var allocated []int + rollback := func(cause error) error { + rollbackCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + return errors.Join(cause, s.portsSvc.ReleasePort(rollbackCtx, spec.HostID, allocated...)) + } + + if spec.Port != nil && *spec.Port == 0 { + port, err := s.portsSvc.AllocatePort(ctx, spec.HostID) + if err != nil { + return nil, fmt.Errorf("failed to allocate port: %w", err) + } + allocated = append(allocated, port) + spec.Port = utils.PointerTo(port) + } + + if spec.PatroniPort != nil && *spec.PatroniPort == 0 { + port, err := s.portsSvc.AllocatePort(ctx, spec.HostID) + if err != nil { + return nil, rollback(fmt.Errorf("failed to allocate patroni port: %w", err)) + } + allocated = append(allocated, port) + spec.PatroniPort = utils.PointerTo(port) + } + + stored.Spec = spec + err = s.store.InstanceSpec. + Update(stored). + Exec(ctx) + if err != nil { + return nil, rollback(fmt.Errorf("failed to persist updated instance spec: %w", err)) + } + + if err := s.releasePreviousSpecPorts(ctx, previous, spec); err != nil { + return nil, err + } + + return spec, nil +} + +func (s *Service) DeleteInstanceSpec(ctx context.Context, databaseID, instanceID string) error { + spec, err := s.store.InstanceSpec. + GetByKey(databaseID, instanceID). + Exec(ctx) + if errors.Is(err, storage.ErrNotFound) { + // Spec has already been deleted + return nil + } else if err != nil { + return fmt.Errorf("failed to check if instance spec exists: %w", err) + } + + if err := s.releaseInstancePorts(ctx, spec.Spec); err != nil { + return err + } + + _, err = s.store.InstanceSpec. + DeleteByKey(databaseID, instanceID). + Exec(ctx) + if err != nil { + return fmt.Errorf("failed to delete instance spec: %w", err) + } + + return nil +} + +func (s *Service) releaseInstancePorts(ctx context.Context, spec *InstanceSpec) error { + err := s.portsSvc.ReleasePortIfDefined(ctx, spec.HostID, spec.Port, spec.PatroniPort) + if err != nil { + return fmt.Errorf("failed to release ports for instance '%s': %w", spec.InstanceID, err) + } + + return nil +} + +func (s *Service) releasePreviousSpecPorts(ctx context.Context, previous, new *InstanceSpec) error { + if previous == nil { + return nil + } + if portShouldBeReleased(previous.Port, new.Port) { + err := s.portsSvc.ReleasePortIfDefined(ctx, previous.HostID, previous.Port) + if err != nil { + return fmt.Errorf("failed to release previous port: %w", err) + } + } + if portShouldBeReleased(previous.PatroniPort, new.PatroniPort) { + err := s.portsSvc.ReleasePortIfDefined(ctx, previous.HostID, previous.PatroniPort) + if err != nil { + return fmt.Errorf("failed to release previous patroni port: %w", err) + } + } + return nil +} + +func portShouldBeReleased(current *int, new *int) bool { + if current == nil || *current == 0 { + // we didn't previously have an assigned port + return false + } + if new == nil || *current != *new { + // we had a previously assigned port and now the port is either nil or + // different + return true + } + + // the current and new ports are equal, so it should not be released. 
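+	// Summarizing the branches above: 5440 -> nil and 5440 -> 5441 release the
+	// old port, while 5440 -> 5440 and nil -> 5440 do not.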
+ return false +} + func ValidateChangedSpec(current, updated *Spec) error { var errs []error diff --git a/server/internal/database/spec.go b/server/internal/database/spec.go index bc219b43..9df1de55 100644 --- a/server/internal/database/spec.go +++ b/server/internal/database/spec.go @@ -508,6 +508,25 @@ type InstanceSpec struct { InPlaceRestore bool `json:"in_place_restore,omitempty"` } +func (s *InstanceSpec) CopySettingsFrom(current *InstanceSpec) { + s.Port = reconcilePort(current.Port, s.Port) + s.PatroniPort = reconcilePort(current.PatroniPort, s.PatroniPort) +} + +func reconcilePort(current, new *int) *int { + if new == nil || *new != 0 { + // no action needed if the new port is unexposed or explicitly set + return new + } + if current != nil && *current != 0 { + // we've already assigned a stable random port here + return utils.PointerTo(*current) // create new pointer + } + + // return 0 to signal that we need to assign a new random port + return utils.PointerTo(0) +} + type InstanceSpecChange struct { Previous *InstanceSpec Current *InstanceSpec diff --git a/server/internal/database/store.go b/server/internal/database/store.go index c7aed96f..b280aefc 100644 --- a/server/internal/database/store.go +++ b/server/internal/database/store.go @@ -12,6 +12,7 @@ type Store struct { Database *DatabaseStore Instance *InstanceStore InstanceStatus *InstanceStatusStore + InstanceSpec *InstanceSpecStore ServiceInstance *ServiceInstanceStore ServiceInstanceStatus *ServiceInstanceStatusStore } @@ -23,6 +24,7 @@ func NewStore(client *clientv3.Client, root string) *Store { Database: NewDatabaseStore(client, root), Instance: NewInstanceStore(client, root), InstanceStatus: NewInstanceStatusStore(client, root), + InstanceSpec: NewInstanceSpecStore(client, root), ServiceInstance: NewServiceInstanceStore(client, root), ServiceInstanceStatus: NewServiceInstanceStatusStore(client, root), } diff --git a/server/internal/ports/ports.go b/server/internal/ports/ports.go new file mode 100644 index 00000000..e9b591d8 --- /dev/null +++ b/server/internal/ports/ports.go @@ -0,0 +1,151 @@ +package ports + +import ( + "errors" + "fmt" + "math/big" + "math/rand" +) + +const ( + MinValidPort = 1 + MaxValidPort = 65535 + // bitmapBytes is the fixed size of the serialized port bitmap. + // Bit N of the big.Int represents port N (1-indexed). + bitmapBytes = (MaxValidPort + 7) / 8 // 8192 bytes +) + +var ( + ErrFull = errors.New("range is full") + ErrAllocated = errors.New("provided port is already allocated") +) + +type ErrNotInRange struct { + Min int + Max int +} + +func (e *ErrNotInRange) Error() string { + return fmt.Sprintf("provided port is not in the valid range. The range of valid ports is [%d, %d]", e.Min, e.Max) +} + +// PortRange tracks allocated ports across the full valid range [1, 65535] +// using a big.Int bitmap, where bit N represents port N. Random allocation +// draws only from [min, max], but any port in [1, 65535] can be recorded via +// Allocate. This means min and max can be reconfigured without affecting +// previously stored state. +type PortRange struct { + min int + max int + bits big.Int +} + +// NewPortRange creates a PortRange for the given spec. 
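+// Both bounds are inclusive, and min may equal max to describe a
+// single-port range.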
+func NewPortRange(min, max int) (*PortRange, error) {
+	if min < MinValidPort || max > MaxValidPort {
+		return nil, fmt.Errorf("ports must be in the range [%d, %d]", MinValidPort, MaxValidPort)
+	}
+	if min > max {
+		return nil, fmt.Errorf("min port %d is greater than max port %d", min, max)
+	}
+	return &PortRange{min: min, max: max}, nil
+}
+
+// Free returns the count of unallocated ports within [min, max].
+func (r *PortRange) Free() int {
+	return r.max - r.min + 1 - r.Used()
+}
+
+// Used returns the count of allocated ports within [min, max].
+func (r *PortRange) Used() int {
+	count := 0
+	for p := r.min; p <= r.max; p++ {
+		if r.bits.Bit(p) == 1 {
+			count++
+		}
+	}
+	return count
+}
+
+// Allocate reserves the given port. Any port in [1, 65535] may be recorded,
+// including ports outside [min, max]. ErrAllocated is returned if the port is
+// already reserved.
+func (r *PortRange) Allocate(port int) error {
+	if port < MinValidPort || port > MaxValidPort {
+		return &ErrNotInRange{MinValidPort, MaxValidPort}
+	}
+	if r.bits.Bit(port) == 1 {
+		return ErrAllocated
+	}
+	r.bits.SetBit(&r.bits, port, 1)
+	return nil
+}
+
+// AllocateNext reserves a random unallocated port from [min, max]. ErrFull is
+// returned if all ports in the range are allocated.
+func (r *PortRange) AllocateNext() (int, error) {
+	free := r.Free()
+	if free == 0 {
+		return 0, ErrFull
+	}
+	n := rand.Intn(free)
+	for p := r.min; p <= r.max; p++ {
+		if r.bits.Bit(p) == 0 {
+			if n == 0 {
+				r.bits.SetBit(&r.bits, p, 1)
+				return p, nil
+			}
+			n--
+		}
+	}
+	return 0, ErrFull // unreachable
+}
+
+// Release clears the port's allocated bit. Out-of-range or unallocated ports are
+// silently ignored.
+func (r *PortRange) Release(port int) error {
+	if port < MinValidPort || port > MaxValidPort {
+		return nil
+	}
+	r.bits.SetBit(&r.bits, port, 0)
+	return nil
+}
+
+// Has returns true if the given port is currently allocated.
+func (r *PortRange) Has(port int) bool {
+	if port < MinValidPort || port > MaxValidPort {
+		return false
+	}
+	return r.bits.Bit(port) == 1
+}
+
+// ForEach calls fn for every allocated port across the full valid range
+// [1, 65535].
+func (r *PortRange) ForEach(fn func(int)) {
+	for p := MinValidPort; p <= MaxValidPort; p++ {
+		if r.bits.Bit(p) == 1 {
+			fn(p)
+		}
+	}
+}
+
+// Snapshot returns the current allocation state as a fixed-size, big-endian
+// bitmap covering all 65535 valid ports.
+func (r *PortRange) Snapshot() []byte {
+	data := make([]byte, bitmapBytes)
+	r.bits.FillBytes(data)
+	return data
+}
+
+// Restore loads a previously saved bitmap. The snapshot carries no min/max, so
+// the current range's bounds are preserved, allowing the configuration to be
+// changed without losing allocation history. An error is returned only if the
+// data is not exactly bitmapBytes long.
+func (r *PortRange) Restore(data []byte) error { + if len(data) != bitmapBytes { + return fmt.Errorf("snapshot data size mismatch: expected %d bytes, got %d", bitmapBytes, len(data)) + } + r.bits.SetBytes(data) + return nil +} diff --git a/server/internal/ports/ports_test.go b/server/internal/ports/ports_test.go new file mode 100644 index 00000000..0687e282 --- /dev/null +++ b/server/internal/ports/ports_test.go @@ -0,0 +1,190 @@ +package ports_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pgEdge/control-plane/server/internal/ports" +) + +func TestPortRange(t *testing.T) { + t.Run("basic usage", func(t *testing.T) { + r, err := ports.NewPortRange(5432, 5531) + require.NoError(t, err) + require.NotNil(t, r) + assert.Equal(t, 100, r.Free()) + assert.Equal(t, 0, r.Used()) + + // AllocateNext returns a port within the range + port, err := r.AllocateNext() + require.NoError(t, err) + assert.GreaterOrEqual(t, port, 5432) + assert.LessOrEqual(t, port, 5531) + assert.True(t, r.Has(port)) + assert.Equal(t, 99, r.Free()) + assert.Equal(t, 1, r.Used()) + + // ForEach visits allocated ports + var seen []int + r.ForEach(func(p int) { seen = append(seen, p) }) + assert.Equal(t, []int{port}, seen) + + // Release returns the port to the pool + require.NoError(t, r.Release(port)) + assert.False(t, r.Has(port)) + assert.Equal(t, 100, r.Free()) + + // Allocate reserves a specific port + require.NoError(t, r.Allocate(5500)) + assert.True(t, r.Has(5500)) + assert.Equal(t, 99, r.Free()) + + // Release the manually-allocated port + require.NoError(t, r.Release(5500)) + assert.False(t, r.Has(5500)) + assert.Equal(t, 100, r.Free()) + }) + + t.Run("single port range", func(t *testing.T) { + r, err := ports.NewPortRange(8080, 8080) + require.NoError(t, err) + assert.Equal(t, 1, r.Free()) + + port, err := r.AllocateNext() + require.NoError(t, err) + assert.Equal(t, 8080, port) + + _, err = r.AllocateNext() + assert.ErrorIs(t, err, ports.ErrFull) + }) + + t.Run("allocate specific port already in use returns ErrAllocated", func(t *testing.T) { + r, err := ports.NewPortRange(5432, 5531) + require.NoError(t, err) + + require.NoError(t, r.Allocate(5440)) + err = r.Allocate(5440) + assert.ErrorIs(t, err, ports.ErrAllocated) + }) + + t.Run("allocate port outside valid range [1, 65535] returns ErrNotInRange", func(t *testing.T) { + r, err := ports.NewPortRange(5432, 5531) + require.NoError(t, err) + + err = r.Allocate(0) + var notInRange *ports.ErrNotInRange + assert.ErrorAs(t, err, ¬InRange) + assert.Equal(t, ports.MinValidPort, notInRange.Min) + assert.Equal(t, ports.MaxValidPort, notInRange.Max) + + err = r.Allocate(65536) + assert.ErrorAs(t, err, ¬InRange) + }) + + t.Run("allocate port outside [min, max] but within [1, 65535] succeeds", func(t *testing.T) { + r, err := ports.NewPortRange(5432, 5531) + require.NoError(t, err) + + // Ports outside [min, max] can be recorded; they don't count toward Free/Used. 
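+		// This is what lets random_ports.min/max shrink or move in config
+		// without orphaning ports that were allocated under the old range.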
+ require.NoError(t, r.Allocate(5431)) + require.NoError(t, r.Allocate(5532)) + assert.True(t, r.Has(5431)) + assert.True(t, r.Has(5532)) + assert.Equal(t, 100, r.Free()) + }) + + t.Run("release out-of-range port is a no-op", func(t *testing.T) { + r, err := ports.NewPortRange(5432, 5531) + require.NoError(t, err) + assert.NoError(t, r.Release(9999)) + assert.Equal(t, 100, r.Free()) + }) + + t.Run("Has returns false for out-of-range port", func(t *testing.T) { + r, err := ports.NewPortRange(5432, 5531) + require.NoError(t, err) + assert.False(t, r.Has(0)) + assert.False(t, r.Has(65536)) + }) + + t.Run("allocates all ports in range before exhausting", func(t *testing.T) { + r, err := ports.NewPortRange(3000, 3004) + require.NoError(t, err) + + seen := make(map[int]bool) + for range 5 { + port, err := r.AllocateNext() + require.NoError(t, err) + assert.GreaterOrEqual(t, port, 3000) + assert.LessOrEqual(t, port, 3004) + assert.False(t, seen[port], "port %d allocated twice", port) + seen[port] = true + } + assert.Len(t, seen, 5) + + _, err = r.AllocateNext() + assert.ErrorIs(t, err, ports.ErrFull) + }) + + t.Run("snapshot restore", func(t *testing.T) { + r, err := ports.NewPortRange(5432, 5531) + require.NoError(t, err) + + port, err := r.AllocateNext() + require.NoError(t, err) + + snapshot := r.Snapshot() + assert.NotEmpty(t, snapshot) + + restored, err := ports.NewPortRange(5432, 5531) + require.NoError(t, err) + require.NoError(t, restored.Restore(snapshot)) + + assert.True(t, restored.Has(port)) + assert.Equal(t, 99, restored.Free()) + }) + + t.Run("restore with different min/max succeeds and preserves allocations", func(t *testing.T) { + r, err := ports.NewPortRange(5432, 5531) + require.NoError(t, err) + + require.NoError(t, r.Allocate(5440)) + + snapshot := r.Snapshot() + + // Restore into a range with different min/max — should succeed. + other, err := ports.NewPortRange(6000, 6099) + require.NoError(t, err) + require.NoError(t, other.Restore(snapshot)) + + // The previously allocated port is still marked as allocated. 
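+		// 5440 sits outside the new [6000, 6099] window, so AllocateNext will
+		// never hand it out again, but it can still be released later.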
+ assert.True(t, other.Has(5440)) + }) + + t.Run("restore with wrong-size data returns error", func(t *testing.T) { + r, err := ports.NewPortRange(5432, 5531) + require.NoError(t, err) + + err = r.Restore([]byte("too short")) + assert.Error(t, err) + }) +} + +func TestNewPortRange_validation(t *testing.T) { + t.Run("min greater than max", func(t *testing.T) { + _, err := ports.NewPortRange(5531, 5432) + assert.Error(t, err) + }) + + t.Run("port zero is invalid", func(t *testing.T) { + _, err := ports.NewPortRange(0, 100) + assert.Error(t, err) + }) + + t.Run("port above 65535 is invalid", func(t *testing.T) { + _, err := ports.NewPortRange(1, 65536) + assert.Error(t, err) + }) +} diff --git a/server/internal/ports/provide.go b/server/internal/ports/provide.go new file mode 100644 index 00000000..62d6b882 --- /dev/null +++ b/server/internal/ports/provide.go @@ -0,0 +1,53 @@ +package ports + +import ( + "fmt" + + "github.com/samber/do" + clientv3 "go.etcd.io/etcd/client/v3" + + "github.com/pgEdge/control-plane/server/internal/config" + "github.com/pgEdge/control-plane/server/internal/host" + "github.com/pgEdge/control-plane/server/internal/logging" +) + +func Provide(i *do.Injector) { + provideStore(i) + provideService(i) +} + +func provideService(i *do.Injector) { + do.Provide(i, func(i *do.Injector) (*Service, error) { + cfg, err := do.Invoke[config.Config](i) + if err != nil { + return nil, fmt.Errorf("failed to get config: %w", err) + } + store, err := do.Invoke[*Store](i) + if err != nil { + return nil, fmt.Errorf("failed to get store: %w", err) + } + loggerFactory, err := do.Invoke[*logging.Factory](i) + if err != nil { + return nil, fmt.Errorf("failed to get logger: %w", err) + } + hostSvc, err := do.Invoke[*host.Service](i) + if err != nil { + return nil, fmt.Errorf("failed to get host service: %w", err) + } + return NewService(cfg, loggerFactory, store, DefaultPortChecker, hostSvc), nil + }) +} + +func provideStore(i *do.Injector) { + do.Provide(i, func(i *do.Injector) (*Store, error) { + cfg, err := do.Invoke[config.Config](i) + if err != nil { + return nil, err + } + client, err := do.Invoke[*clientv3.Client](i) + if err != nil { + return nil, err + } + return NewStore(client, cfg.EtcdKeyRoot), nil + }) +} diff --git a/server/internal/ports/service.go b/server/internal/ports/service.go new file mode 100644 index 00000000..07a736af --- /dev/null +++ b/server/internal/ports/service.go @@ -0,0 +1,260 @@ +package ports + +import ( + "context" + "errors" + "fmt" + "math/rand/v2" + "net" + "sync" + "time" + + "github.com/rs/zerolog" + + "github.com/pgEdge/control-plane/server/internal/config" + "github.com/pgEdge/control-plane/server/internal/host" + "github.com/pgEdge/control-plane/server/internal/logging" + "github.com/pgEdge/control-plane/server/internal/storage" + "github.com/pgEdge/control-plane/server/internal/utils" +) + +// Using a higher number of retries here because we expect high contention in +// some topologies. +const maxRetries = 6 +const sleepBase = 100 * time.Millisecond +const sleepJitterMS = 100 + +// PortChecker reports whether a port is available for binding. +type PortChecker func(port int) bool + +// DefaultPortChecker returns true if the given TCP port is not currently bound +// by any process on the local host. 
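+// Note that this is a point-in-time check: a port reported as free can still
+// be taken by another process before the database instance binds it.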
+func DefaultPortChecker(port int) bool { + l, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) + if err != nil { + return false + } + l.Close() + return true +} + +type Service struct { + mu sync.Mutex + cfg config.Config + logger zerolog.Logger + store *Store + portChecker PortChecker + hostSvc *host.Service +} + +func NewService( + cfg config.Config, + loggerFactory *logging.Factory, + store *Store, + portChecker PortChecker, + hostSvc *host.Service, +) *Service { + return &Service{ + cfg: cfg, + logger: loggerFactory.Logger("ports_service"), + store: store, + portChecker: portChecker, + hostSvc: hostSvc, + } +} + +// AllocatePort allocates the next available port in [min, max] that is not +// already recorded in the persistent range and is not currently bound on the +// local host. +func (s *Service) AllocatePort(ctx context.Context, hostID string) (int, error) { + if hostID != s.cfg.HostID { + return 0, fmt.Errorf("cannot allocate a new port for another host, this host='%s', requested host='%s'", s.cfg.HostID, hostID) + } + + s.mu.Lock() + defer s.mu.Unlock() + + return s.allocatePort(ctx, maxRetries) +} + +func (s *Service) allocatePort(ctx context.Context, retriesRemaining int) (int, error) { + if retriesRemaining < 1 { + // This can happen if there's too much contention for this port range + // across multiple hosts. + return 0, errors.New("failed to allocate port: exhausted retries") + } + + logger := s.logger.With(). + Int("retries_remaining", retriesRemaining). + Logger() + + logger.Debug(). + Int("range_min", s.cfg.RandomPorts.Min). + Int("range_max", s.cfg.RandomPorts.Max). + Msg("attempting to allocate a random port") + + name := s.cfg.ClientAddress() + min := s.cfg.RandomPorts.Min + max := s.cfg.RandomPorts.Max + + r, err := NewPortRange(min, max) + if err != nil { + return 0, fmt.Errorf("failed to create port allocator: %w", err) + } + + stored, err := s.restoreAllocator(ctx, r, name) + if err != nil { + return 0, fmt.Errorf("failed to restore port allocator from storage: %w", err) + } + + port, err := s.allocateAvailablePort(r) + if err != nil { + return 0, err + } + + stored.Snapshot = r.Snapshot() + + err = s.store.Update(stored).Exec(ctx) + if errors.Is(err, storage.ErrValueVersionMismatch) { + sleepDuration := addJitter(sleepBase, sleepJitterMS) + + logger.Debug(). + Int64("sleep_milliseconds", sleepDuration.Milliseconds()). + Msg("encountered conflict. sleeping before reattempting.") + + time.Sleep(sleepDuration) + + return s.allocatePort(ctx, retriesRemaining-1) + } else if err != nil { + return 0, fmt.Errorf("failed to store port allocator: %w", err) + } + + logger.Debug(). + Int("port", port). + Msg("successfully allocated random port") + + return port, nil +} + +func (s *Service) ReleasePortIfDefined(ctx context.Context, hostID string, ports ...*int) error { + defined := make([]int, 0, len(ports)) + for _, port := range ports { + if p := utils.FromPointer(port); p != 0 { + defined = append(defined, p) + } + } + + return s.ReleasePort(ctx, hostID, defined...) +} + +// ReleasePort releases the given port back to the pool, persisting the updated +// state to storage. 
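+// Multiple ports may be released in a single call. Unlike AllocatePort, any
+// host may be targeted: the allocator is looked up by the host's first
+// client address.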
+func (s *Service) ReleasePort(ctx context.Context, hostID string, ports ...int) error { + host, err := s.hostSvc.GetHost(ctx, hostID) + if err != nil { + return fmt.Errorf("failed to get host '%s': %w", hostID, err) + } + if len(host.ClientAddresses) == 0 { + return fmt.Errorf("host '%s' has no client addresses", hostID) + } + + name := host.ClientAddresses[0] + + s.mu.Lock() + defer s.mu.Unlock() + + return s.releasePort(ctx, name, ports, maxRetries) +} + +func (s *Service) releasePort(ctx context.Context, name string, ports []int, retriesRemaining int) error { + if retriesRemaining < 1 { + return errors.New("failed to release port: exhausted retries") + } + + logger := s.logger.With(). + Int("retries_remaining", retriesRemaining). + Ints("ports", ports). + Logger() + + logger.Debug().Msg("attempting to release port") + + r, err := NewPortRange(s.cfg.RandomPorts.Min, s.cfg.RandomPorts.Max) + if err != nil { + return fmt.Errorf("failed to create port allocator: %w", err) + } + + stored, err := s.restoreAllocator(ctx, r, name) + if err != nil { + return fmt.Errorf("failed to restore port allocator from storage: %w", err) + } + + for _, port := range ports { + if err := r.Release(port); err != nil { + return fmt.Errorf("failed to release port: %w", err) + } + } + + stored.Snapshot = r.Snapshot() + + err = s.store.Update(stored).Exec(ctx) + if errors.Is(err, storage.ErrValueVersionMismatch) { + sleepDuration := addJitter(sleepBase, sleepJitterMS) + + logger.Debug(). + Int64("sleep_milliseconds", sleepDuration.Milliseconds()). + Msg("encountered conflict. sleeping before reattempting.") + + time.Sleep(sleepDuration) + + return s.releasePort(ctx, name, ports, retriesRemaining-1) + } else if err != nil { + return fmt.Errorf("failed to store port allocator: %w", err) + } + + logger.Debug().Msg("successfully released port") + + return nil +} + +// allocateAvailablePort calls AllocateNext until it finds a port that passes +// the OS availability check. Ports that are occupied by external processes are +// kept as allocated in the range so they are not retried on future calls. +func (s *Service) allocateAvailablePort(r *PortRange) (int, error) { + for { + port, err := r.AllocateNext() + if err != nil { + return 0, fmt.Errorf("failed to allocate port: %w", err) + } + if s.portChecker(port) { + return port, nil + } + s.logger.Debug(). + Int("port", port). + Msg("port is in use by another process, skipping") + } +} + +func (s *Service) restoreAllocator(ctx context.Context, r *PortRange, name string) (*StoredPortRange, error) { + stored, err := s.store.GetByKey(name).Exec(ctx) + if errors.Is(err, storage.ErrNotFound) { + return &StoredPortRange{ + Name: name, + }, nil + } else if err != nil { + return nil, fmt.Errorf("failed to get port allocator spec from storage: %w", err) + } + if err := r.Restore(stored.Snapshot); err != nil { + // An error can happen here if the config has changed. In this case, + // continue without restoring and overwrite the old allocator on the + // next allocation. + s.logger.Warn(). + Err(err). 
+ Msg("failed to restore port allocator") + } + return stored, nil +} + +func addJitter(base time.Duration, jitterMS uint) time.Duration { + jitter := time.Duration(rand.N(jitterMS)) * time.Millisecond + return base + jitter +} diff --git a/server/internal/ports/store.go b/server/internal/ports/store.go new file mode 100644 index 00000000..1006c9ff --- /dev/null +++ b/server/internal/ports/store.go @@ -0,0 +1,48 @@ +package ports + +import ( + clientv3 "go.etcd.io/etcd/client/v3" + + "github.com/pgEdge/control-plane/server/internal/storage" +) + +type StoredPortRange struct { + storage.StoredValue + Name string `json:"name"` + Snapshot []byte `json:"snapshot"` +} + +type Store struct { + client *clientv3.Client + root string +} + +func NewStore(client *clientv3.Client, root string) *Store { + return &Store{ + client: client, + root: root, + } +} + +func (s *Store) Prefix() string { + return storage.Prefix("/", s.root, "ports") +} + +func (s *Store) Key(allocatorName string) string { + return storage.Key(s.Prefix(), allocatorName) +} + +func (s *Store) ExistsByKey(allocatorName string) storage.ExistsOp { + key := s.Key(allocatorName) + return storage.NewExistsOp(s.client, key) +} + +func (s *Store) GetByKey(allocatorName string) storage.GetOp[*StoredPortRange] { + key := s.Key(allocatorName) + return storage.NewGetOp[*StoredPortRange](s.client, key) +} + +func (s *Store) Update(item *StoredPortRange) storage.PutOp[*StoredPortRange] { + key := s.Key(item.Name) + return storage.NewUpdateOp(s.client, key, item) +} diff --git a/server/internal/workflows/activities/get_instance_resources.go b/server/internal/workflows/activities/get_instance_resources.go index 7cc0ad46..672d3a7b 100644 --- a/server/internal/workflows/activities/get_instance_resources.go +++ b/server/internal/workflows/activities/get_instance_resources.go @@ -39,7 +39,12 @@ func (a *Activities) GetInstanceResources(ctx context.Context, input *GetInstanc ) logger.Info("getting instance resources") - resources, err := a.Orchestrator.GenerateInstanceResources(input.Spec) + spec, err := a.DatabaseService.ReconcileInstanceSpec(ctx, input.Spec) + if err != nil { + return nil, fmt.Errorf("failed to reconcile instance spec: %w", err) + } + + resources, err := a.Orchestrator.GenerateInstanceResources(spec) if err != nil { return nil, fmt.Errorf("failed to generate instance resources: %w", err) } diff --git a/server/internal/workflows/activities/get_restore_resources.go b/server/internal/workflows/activities/get_restore_resources.go index 85848db9..ceeb06d0 100644 --- a/server/internal/workflows/activities/get_restore_resources.go +++ b/server/internal/workflows/activities/get_restore_resources.go @@ -43,12 +43,17 @@ func (a *Activities) GetRestoreResources(ctx context.Context, input *GetRestoreR ) logger.Info("getting restore resources") - resources, err := a.Orchestrator.GenerateInstanceResources(input.Spec) + spec, err := a.DatabaseService.ReconcileInstanceSpec(ctx, input.Spec) + if err != nil { + return nil, fmt.Errorf("failed to reconcile instance spec: %w", err) + } + + resources, err := a.Orchestrator.GenerateInstanceResources(spec) if err != nil { return nil, fmt.Errorf("failed to generate instance resources: %w", err) } - restoreSpec := input.Spec.Clone() + restoreSpec := spec.Clone() restoreSpec.RestoreConfig = input.RestoreConfig restoreResources, err := a.Orchestrator.GenerateInstanceRestoreResources(restoreSpec, input.TaskID) if err != nil {