Skip to content
8 changes: 8 additions & 0 deletions docs/scheduling.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ For long-lived automation, define jobs in `ace.yaml` and let `./ace start` orche
- `schedule_jobs`: describes each job (what to run and with which arguments).
- `schedule_config`: pairs a job name with either a frequency or a cron expression and marks it enabled/disabled.

### Reloading the Configuration at Runtime

If you change the configuration and do not wish to restart the process, you can send SIGHUP (1) to reload it. This works for all long-running ACE modes:

- **`ace start` / `ace start --component=scheduler`** — waits for in-flight jobs to complete, then swaps in the new configuration.
- **`ace start --component=api` / `ace server`** — applies the new configuration immediately and reloads the mTLS security config (certificate revocation list and allowed CN list).
- **`ace start --component=all`** — both of the above.

### Sample Configuration

```yaml
Expand Down
92 changes: 92 additions & 0 deletions internal/api/http/config_reload_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package server

import (
"testing"

"github.com/pgedge/ace/pkg/config"
)

// TestResolversPickUpReloadedConfig verifies that the handler resolver
// functions read from the config snapshot passed to them, so that a
// SIGHUP-triggered config.Set takes effect for subsequent API requests
// while keeping each request internally consistent.
func TestResolversPickUpReloadedConfig(t *testing.T) {
original := config.Get()
t.Cleanup(func() {
if original != nil {
config.Set(original)
} else {
config.Set(&config.Config{})
}
})

// Set initial config.
config.Set(&config.Config{
TableDiff: config.DiffConfig{
DiffBlockSize: 5000,
ConcurrencyFactor: 0.25,
CompareUnitSize: 2000,
MaxDiffRows: 100,
},
Server: config.ServerConfig{
TaskStorePath: "/tmp/old-tasks.db",
},
})

s := &APIServer{}
cfg := config.Get()

// Verify resolvers return initial values (0 = "use config default").
if got := s.resolveBlockSize(cfg, 0); got != 5000 {
t.Errorf("resolveBlockSize: got %d, want 5000", got)
}
if got := s.resolveConcurrency(cfg, 0); got != 0.25 {
t.Errorf("resolveConcurrency: got %f, want 0.25", got)
}
if got := s.resolveCompareUnitSize(cfg, 0); got != 2000 {
t.Errorf("resolveCompareUnitSize: got %d, want 2000", got)
}
if got := s.resolveMaxDiffRows(cfg, 0); got != 100 {
t.Errorf("resolveMaxDiffRows: got %d, want 100", got)
}
if got := cfg.Server.TaskStorePath; got != "/tmp/old-tasks.db" {
t.Errorf("TaskStorePath: got %q, want /tmp/old-tasks.db", got)
}

// Simulate SIGHUP: swap in new config.
config.Set(&config.Config{
TableDiff: config.DiffConfig{
DiffBlockSize: 9999,
ConcurrencyFactor: 0.75,
CompareUnitSize: 4000,
MaxDiffRows: 500,
},
Server: config.ServerConfig{
TaskStorePath: "/tmp/new-tasks.db",
},
})

newCfg := config.Get()

// Verify resolvers now return the reloaded values.
if got := s.resolveBlockSize(newCfg, 0); got != 9999 {
t.Errorf("after reload resolveBlockSize: got %d, want 9999", got)
}
if got := s.resolveConcurrency(newCfg, 0); got != 0.75 {
t.Errorf("after reload resolveConcurrency: got %f, want 0.75", got)
}
if got := s.resolveCompareUnitSize(newCfg, 0); got != 4000 {
t.Errorf("after reload resolveCompareUnitSize: got %d, want 4000", got)
}
if got := s.resolveMaxDiffRows(newCfg, 0); got != 500 {
t.Errorf("after reload resolveMaxDiffRows: got %d, want 500", got)
}
if got := newCfg.Server.TaskStorePath; got != "/tmp/new-tasks.db" {
t.Errorf("after reload TaskStorePath: got %q, want /tmp/new-tasks.db", got)
}

// Verify old snapshot still returns old values (per-request consistency).
if got := s.resolveBlockSize(cfg, 0); got != 5000 {
t.Errorf("old snapshot resolveBlockSize: got %d, want 5000", got)
}
}
66 changes: 36 additions & 30 deletions internal/api/http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,14 +202,16 @@ func (s *APIServer) handleTableDiff(w http.ResponseWriter, r *http.Request) {
return
}

cfg := config.Get()

task := diff.NewTableDiffTask()
task.ClusterName = cluster
task.QualifiedTableName = tableName
task.DBName = strings.TrimSpace(req.DBName)
task.BlockSize = s.resolveBlockSize(req.BlockSize)
task.ConcurrencyFactor = s.resolveConcurrency(req.Concurrency)
task.CompareUnitSize = s.resolveCompareUnitSize(req.CompareUnitSize)
task.MaxDiffRows = s.resolveMaxDiffRows(req.MaxDiffRows)
task.BlockSize = s.resolveBlockSize(cfg, req.BlockSize)
task.ConcurrencyFactor = s.resolveConcurrency(cfg, req.Concurrency)
task.CompareUnitSize = s.resolveCompareUnitSize(cfg, req.CompareUnitSize)
task.MaxDiffRows = s.resolveMaxDiffRows(cfg, req.MaxDiffRows)
task.Output = "json"
task.Nodes = s.resolveNodes(req.Nodes)
task.TableFilter = strings.TrimSpace(req.TableFilter)
Expand All @@ -220,7 +222,7 @@ func (s *APIServer) handleTableDiff(w http.ResponseWriter, r *http.Request) {
task.QuietMode = req.Quiet
task.SkipDBUpdate = false
task.TaskStore = s.taskStore
task.TaskStorePath = s.cfg.Server.TaskStorePath
task.TaskStorePath = cfg.Server.TaskStorePath

if err := task.Validate(); err != nil {
writeError(w, http.StatusBadRequest, err.Error())
Expand Down Expand Up @@ -249,41 +251,41 @@ func (s *APIServer) handleTableDiff(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusAccepted, resp)
}

func (s *APIServer) resolveBlockSize(requested int) int {
func (s *APIServer) resolveBlockSize(cfg *config.Config, requested int) int {
if requested > 0 {
return requested
}
if cfg := s.cfg; cfg != nil && cfg.TableDiff.DiffBlockSize > 0 {
if cfg != nil && cfg.TableDiff.DiffBlockSize > 0 {
return cfg.TableDiff.DiffBlockSize
}
return 100000
}

func (s *APIServer) resolveConcurrency(requested float64) float64 {
func (s *APIServer) resolveConcurrency(cfg *config.Config, requested float64) float64 {
if requested > 0 {
return requested
}
if cfg := s.cfg; cfg != nil && cfg.TableDiff.ConcurrencyFactor > 0 {
if cfg != nil && cfg.TableDiff.ConcurrencyFactor > 0 {
return cfg.TableDiff.ConcurrencyFactor
}
return 0.5
}

func (s *APIServer) resolveCompareUnitSize(requested int) int {
func (s *APIServer) resolveCompareUnitSize(cfg *config.Config, requested int) int {
if requested > 0 {
return requested
}
if cfg := s.cfg; cfg != nil && cfg.TableDiff.CompareUnitSize > 0 {
if cfg != nil && cfg.TableDiff.CompareUnitSize > 0 {
return cfg.TableDiff.CompareUnitSize
}
return 10000
}

func (s *APIServer) resolveMaxDiffRows(requested int64) int64 {
func (s *APIServer) resolveMaxDiffRows(cfg *config.Config, requested int64) int64 {
if requested > 0 {
return requested
}
if cfg := s.cfg; cfg != nil && cfg.TableDiff.MaxDiffRows > 0 {
if cfg != nil && cfg.TableDiff.MaxDiffRows > 0 {
return cfg.TableDiff.MaxDiffRows
}
return 0
Expand Down Expand Up @@ -360,7 +362,7 @@ func (s *APIServer) handleTableRerun(w http.ResponseWriter, r *http.Request) {
task.InvokeMethod = "api"
task.SkipDBUpdate = false
task.TaskStore = s.taskStore
task.TaskStorePath = s.cfg.Server.TaskStorePath
task.TaskStorePath = config.Get().Server.TaskStorePath

if err := s.enqueueTask(task.TaskID, func(ctx context.Context) error {
task.Ctx = ctx
Expand Down Expand Up @@ -441,7 +443,7 @@ func (s *APIServer) handleTableRepair(w http.ResponseWriter, r *http.Request) {
task.InvokeMethod = "api"
task.SkipDBUpdate = false
task.TaskStore = s.taskStore
task.TaskStorePath = s.cfg.Server.TaskStorePath
task.TaskStorePath = config.Get().Server.TaskStorePath

if err := task.ValidateAndPrepare(); err != nil {
writeError(w, http.StatusBadRequest, err.Error())
Expand Down Expand Up @@ -506,7 +508,7 @@ func (s *APIServer) handleSpockDiff(w http.ResponseWriter, r *http.Request) {
task.InvokeMethod = "api"
task.SkipDBUpdate = false
task.TaskStore = s.taskStore
task.TaskStorePath = s.cfg.Server.TaskStorePath
task.TaskStorePath = config.Get().Server.TaskStorePath

if err := task.Validate(); err != nil {
writeError(w, http.StatusBadRequest, err.Error())
Expand Down Expand Up @@ -561,6 +563,8 @@ func (s *APIServer) handleSchemaDiff(w http.ResponseWriter, r *http.Request) {
return
}

cfg := config.Get()

task := diff.NewSchemaDiffTask()
task.ClusterName = cluster
task.SchemaName = schema
Expand All @@ -570,9 +574,9 @@ func (s *APIServer) handleSchemaDiff(w http.ResponseWriter, r *http.Request) {
task.SkipFile = req.SkipFile
task.Quiet = req.Quiet
task.DDLOnly = req.DDLOnly
task.BlockSize = s.resolveBlockSize(req.BlockSize)
task.ConcurrencyFactor = s.resolveConcurrency(req.Concurrency)
task.CompareUnitSize = s.resolveCompareUnitSize(req.CompareUnitSize)
task.BlockSize = s.resolveBlockSize(cfg, req.BlockSize)
task.ConcurrencyFactor = s.resolveConcurrency(cfg, req.Concurrency)
task.CompareUnitSize = s.resolveCompareUnitSize(cfg, req.CompareUnitSize)
task.Output = strings.TrimSpace(req.Output)
if task.Output == "" {
task.Output = "json"
Expand All @@ -581,7 +585,7 @@ func (s *APIServer) handleSchemaDiff(w http.ResponseWriter, r *http.Request) {
task.Ctx = r.Context()
task.SkipDBUpdate = false
task.TaskStore = s.taskStore
task.TaskStorePath = s.cfg.Server.TaskStorePath
task.TaskStorePath = cfg.Server.TaskStorePath

if err := task.Validate(); err != nil {
writeError(w, http.StatusBadRequest, err.Error())
Expand Down Expand Up @@ -642,6 +646,8 @@ func (s *APIServer) handleRepsetDiff(w http.ResponseWriter, r *http.Request) {
return
}

cfg := config.Get()

task := diff.NewRepsetDiffTask()
task.ClusterName = cluster
task.RepsetName = repset
Expand All @@ -650,9 +656,9 @@ func (s *APIServer) handleRepsetDiff(w http.ResponseWriter, r *http.Request) {
task.SkipTables = req.SkipTables
task.SkipFile = req.SkipFile
task.Quiet = req.Quiet
task.BlockSize = s.resolveBlockSize(req.BlockSize)
task.ConcurrencyFactor = s.resolveConcurrency(req.Concurrency)
task.CompareUnitSize = s.resolveCompareUnitSize(req.CompareUnitSize)
task.BlockSize = s.resolveBlockSize(cfg, req.BlockSize)
task.ConcurrencyFactor = s.resolveConcurrency(cfg, req.Concurrency)
task.CompareUnitSize = s.resolveCompareUnitSize(cfg, req.CompareUnitSize)
task.Output = strings.TrimSpace(req.Output)
if task.Output == "" {
task.Output = "json"
Expand All @@ -663,7 +669,7 @@ func (s *APIServer) handleRepsetDiff(w http.ResponseWriter, r *http.Request) {
task.InvokeMethod = "api"
task.SkipDBUpdate = false
task.TaskStore = s.taskStore
task.TaskStorePath = s.cfg.Server.TaskStorePath
task.TaskStorePath = cfg.Server.TaskStorePath

if err := task.Validate(); err != nil {
writeError(w, http.StatusBadRequest, err.Error())
Expand Down Expand Up @@ -741,7 +747,7 @@ func (s *APIServer) handleMtreeInit(w http.ResponseWriter, r *http.Request) {
task.InvokeMethod = "api"
task.SkipDBUpdate = false
task.TaskStore = s.taskStore
task.TaskStorePath = s.cfg.Server.TaskStorePath
task.TaskStorePath = config.Get().Server.TaskStorePath

if err := s.enqueueTask(task.TaskID, func(ctx context.Context) error {
task.Ctx = ctx
Expand Down Expand Up @@ -799,7 +805,7 @@ func (s *APIServer) handleMtreeTeardown(w http.ResponseWriter, r *http.Request)
task.InvokeMethod = "api"
task.SkipDBUpdate = false
task.TaskStore = s.taskStore
task.TaskStorePath = s.cfg.Server.TaskStorePath
task.TaskStorePath = config.Get().Server.TaskStorePath

if err := s.enqueueTask(task.TaskID, func(ctx context.Context) error {
task.Ctx = ctx
Expand Down Expand Up @@ -863,7 +869,7 @@ func (s *APIServer) handleMtreeTeardownTable(w http.ResponseWriter, r *http.Requ
task.InvokeMethod = "api"
task.SkipDBUpdate = false
task.TaskStore = s.taskStore
task.TaskStorePath = s.cfg.Server.TaskStorePath
task.TaskStorePath = config.Get().Server.TaskStorePath

if err := s.enqueueTask(task.TaskID, func(ctx context.Context) error {
task.Ctx = ctx
Expand Down Expand Up @@ -934,7 +940,7 @@ func (s *APIServer) handleMtreeBuild(w http.ResponseWriter, r *http.Request) {
task.InvokeMethod = "api"
task.SkipDBUpdate = false
task.TaskStore = s.taskStore
task.TaskStorePath = s.cfg.Server.TaskStorePath
task.TaskStorePath = config.Get().Server.TaskStorePath

if err := task.Validate(); err != nil {
writeError(w, http.StatusBadRequest, err.Error())
Expand Down Expand Up @@ -1009,7 +1015,7 @@ func (s *APIServer) handleMtreeUpdate(w http.ResponseWriter, r *http.Request) {
task.InvokeMethod = "api"
task.SkipDBUpdate = false
task.TaskStore = s.taskStore
task.TaskStorePath = s.cfg.Server.TaskStorePath
task.TaskStorePath = config.Get().Server.TaskStorePath

if err := task.Validate(); err != nil {
writeError(w, http.StatusBadRequest, err.Error())
Expand Down Expand Up @@ -1088,7 +1094,7 @@ func (s *APIServer) handleMtreeDiff(w http.ResponseWriter, r *http.Request) {
task.InvokeMethod = "api"
task.SkipDBUpdate = false
task.TaskStore = s.taskStore
task.TaskStorePath = s.cfg.Server.TaskStorePath
task.TaskStorePath = config.Get().Server.TaskStorePath

if err := task.Validate(); err != nil {
writeError(w, http.StatusBadRequest, err.Error())
Expand Down
22 changes: 17 additions & 5 deletions internal/api/http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net"
"net/http"
"sync"
"sync/atomic"
"time"

"github.com/pgedge/ace/pkg/config"
Expand All @@ -18,9 +19,8 @@ import (
)

type APIServer struct {
cfg *config.Config
server *http.Server
validator *certValidator
validator atomic.Pointer[certValidator]
taskStore *taskstore.Store
listenAddr string
jobCtx context.Context
Expand Down Expand Up @@ -62,12 +62,11 @@ func New(cfg *config.Config) (*APIServer, error) {
mux := http.NewServeMux()

apiServer := &APIServer{
cfg: cfg,
validator: validator,
taskStore: taskStore,
listenAddr: fmt.Sprintf("%s:%d", srvCfg.ListenAddress, srvCfg.ListenPort),
jobCtx: context.Background(),
}
apiServer.validator.Store(validator)

mux.Handle("/api/v1/table-diff", apiServer.authenticated(http.HandlerFunc(apiServer.handleTableDiff)))
mux.Handle("/api/v1/table-rerun", apiServer.authenticated(http.HandlerFunc(apiServer.handleTableRerun)))
Expand Down Expand Up @@ -175,14 +174,27 @@ type clientInfo struct {

type clientContextKey struct{}

// ReloadSecurityConfig rebuilds the certValidator from cfg and atomically
// swaps it in so that subsequent requests use the updated allowedCNs and CRL.
// Note: the TLS CA pool used for handshake verification requires a restart to
// change; only allowedCNs and CRL changes take effect without a restart.
func (s *APIServer) ReloadSecurityConfig(cfg *config.Config) error {
v, err := newCertValidator(cfg)
if err != nil {
return err
}
s.validator.Store(v)
return nil
}

func (s *APIServer) authenticated(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
writeError(w, http.StatusUnauthorized, "client certificate required")
return
}
clientCert := r.TLS.PeerCertificates[0]
role, err := s.validator.Validate(clientCert)
role, err := s.validator.Load().Validate(clientCert)
if err != nil {
logger.Warn("client certificate validation failed: %v", err)
writeError(w, http.StatusUnauthorized, err.Error())
Expand Down
Loading
Loading