From 6b0c62ef92d607ef210ea6465484f5ec07b4147a Mon Sep 17 00:00:00 2001 From: Ian Hodge Date: Fri, 1 May 2026 19:02:15 -0400 Subject: [PATCH 1/3] self review --- internal/metrics/metrics.go | 16 +--------------- internal/worker/worker.go | 13 +++---------- 2 files changed, 4 insertions(+), 25 deletions(-) diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index bca056c..2e99931 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -70,7 +70,6 @@ type instruments struct { connected metric.Int64Gauge tasksActive metric.Int64UpDownCounter tasksMaxConcurrent metric.Int64Gauge - tasksClaimed metric.Int64Counter tasksRejected metric.Int64Counter tasksCompleted metric.Int64Counter taskDuration metric.Float64Histogram @@ -120,6 +119,7 @@ const ( TaskFailureReasonInvalidImage = "invalid_image" TaskFailureReasonActiveDeadline = "active_deadline" TaskFailureReasonCleanup = "cleanup" + ) var ( @@ -275,7 +275,6 @@ func shouldInitTraces() bool { func primeInstruments(ctx context.Context, set *instruments) { set.connected.Record(ctx, 0) set.tasksActive.Add(ctx, 0) - set.tasksClaimed.Add(ctx, 0) set.tasksRejected.Add(ctx, 0, metric.WithAttributes(attribute.String("reason", RejectReasonAtCapacity)), ) @@ -343,13 +342,6 @@ func buildInstruments(m metric.Meter) (*instruments, error) { if err != nil { return nil, err } - tasksClaimed, err := m.Int64Counter( - "oz_worker_tasks_claimed_total", - metric.WithDescription("Total tasks the worker has claimed since process start."), - ) - if err != nil { - return nil, err - } tasksRejected, err := m.Int64Counter( "oz_worker_tasks_rejected_total", metric.WithDescription("Total tasks the worker has rejected since process start."), @@ -397,7 +389,6 @@ func buildInstruments(m metric.Meter) (*instruments, error) { connected: connected, tasksActive: tasksActive, tasksMaxConcurrent: tasksMaxConcurrent, - tasksClaimed: tasksClaimed, tasksRejected: tasksRejected, tasksCompleted: tasksCompleted, taskDuration: taskDuration, @@ -437,11 +428,6 @@ func SetMaxConcurrent(n int) { current().tasksMaxConcurrent.Record(context.Background(), int64(n)) } -// RecordTaskClaim records a successful task claim (the worker has accepted a task). -func RecordTaskClaim() { - current().tasksClaimed.Add(context.Background(), 1) -} - // RecordTaskRejected records a task that the worker rejected, e.g. because // it was at the configured concurrency limit. The reason label is intended // to be a small bounded enum. diff --git a/internal/worker/worker.go b/internal/worker/worker.go index 891f206..c2e64a9 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -358,7 +358,6 @@ func (w *Worker) handleTaskAssignment(assignment *types.TaskAssignmentMessage) { if err := w.sendTaskClaimed(assignment.TaskID); err != nil { log.Errorf(w.ctx, "Failed to send task claimed message: %v", err) } - metrics.RecordTaskClaim() metrics.AddTaskEvent(taskCtx, "task.claimed") metrics.IncTasksActive() taskCtx, taskCancel := context.WithCancel(taskCtx) @@ -392,20 +391,14 @@ func (w *Worker) prepareTaskParams(assignment *types.TaskAssignmentMessage) *Tas baseArgs := []string{ "agent", "run", - } - // Only share with the team when the task is team-owned. User-owned tasks - // (created with "Team visible" unchecked) use user-scoped API keys that - // cannot set up team-level session sharing. - if task.Owner.IsTeamOwned() { - baseArgs = append(baseArgs, "--share", "team:edit") - } - baseArgs = append(baseArgs, + "--share", + "team:edit", "--task-id", task.ID, "--sandboxed", "--server-root-url", w.config.ServerRootURL, - ) + } baseArgs = common.AugmentArgsForTask(task, baseArgs, common.TaskAugmentOptions{ IdleOnComplete: w.config.IdleOnComplete, }) From f536fc6d4a072fdbf1171da97eee1316b216bf30 Mon Sep 17 00:00:00 2001 From: Ian Hodge Date: Fri, 1 May 2026 18:36:29 -0400 Subject: [PATCH 2/3] Add worker cancellation controls Co-Authored-By: Oz --- internal/metrics/metrics.go | 16 +++++- internal/types/messages.go | 35 ++++++++---- internal/worker/direct.go | 6 +++ internal/worker/docker.go | 4 +- internal/worker/worker.go | 97 +++++++++++++++++++++++++++++++--- internal/worker/worker_test.go | 74 ++++++++++++++++++++++++-- 6 files changed, 209 insertions(+), 23 deletions(-) diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index 2e99931..bca056c 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -70,6 +70,7 @@ type instruments struct { connected metric.Int64Gauge tasksActive metric.Int64UpDownCounter tasksMaxConcurrent metric.Int64Gauge + tasksClaimed metric.Int64Counter tasksRejected metric.Int64Counter tasksCompleted metric.Int64Counter taskDuration metric.Float64Histogram @@ -119,7 +120,6 @@ const ( TaskFailureReasonInvalidImage = "invalid_image" TaskFailureReasonActiveDeadline = "active_deadline" TaskFailureReasonCleanup = "cleanup" - ) var ( @@ -275,6 +275,7 @@ func shouldInitTraces() bool { func primeInstruments(ctx context.Context, set *instruments) { set.connected.Record(ctx, 0) set.tasksActive.Add(ctx, 0) + set.tasksClaimed.Add(ctx, 0) set.tasksRejected.Add(ctx, 0, metric.WithAttributes(attribute.String("reason", RejectReasonAtCapacity)), ) @@ -342,6 +343,13 @@ func buildInstruments(m metric.Meter) (*instruments, error) { if err != nil { return nil, err } + tasksClaimed, err := m.Int64Counter( + "oz_worker_tasks_claimed_total", + metric.WithDescription("Total tasks the worker has claimed since process start."), + ) + if err != nil { + return nil, err + } tasksRejected, err := m.Int64Counter( "oz_worker_tasks_rejected_total", metric.WithDescription("Total tasks the worker has rejected since process start."), @@ -389,6 +397,7 @@ func buildInstruments(m metric.Meter) (*instruments, error) { connected: connected, tasksActive: tasksActive, tasksMaxConcurrent: tasksMaxConcurrent, + tasksClaimed: tasksClaimed, tasksRejected: tasksRejected, tasksCompleted: tasksCompleted, taskDuration: taskDuration, @@ -428,6 +437,11 @@ func SetMaxConcurrent(n int) { current().tasksMaxConcurrent.Record(context.Background(), int64(n)) } +// RecordTaskClaim records a successful task claim (the worker has accepted a task). +func RecordTaskClaim() { + current().tasksClaimed.Add(context.Background(), 1) +} + // RecordTaskRejected records a task that the worker rejected, e.g. because // it was at the configured concurrency limit. The reason label is intended // to be a small bounded enum. diff --git a/internal/types/messages.go b/internal/types/messages.go index c72ee68..34c67e7 100644 --- a/internal/types/messages.go +++ b/internal/types/messages.go @@ -9,12 +9,13 @@ import ( type MessageType string const ( - MessageTypeTaskAssignment MessageType = "task_assignment" - MessageTypeTaskClaimed MessageType = "task_claimed" - MessageTypeTaskCompleted MessageType = "task_completed" - MessageTypeTaskFailed MessageType = "task_failed" - MessageTypeTaskRejected MessageType = "task_rejected" - MessageTypeHeartbeat MessageType = "heartbeat" + MessageTypeTaskAssignment MessageType = "task_assignment" + MessageTypeTaskClaimed MessageType = "task_claimed" + MessageTypeTaskCompleted MessageType = "task_completed" + MessageTypeTaskFailed MessageType = "task_failed" + MessageTypeTaskRejected MessageType = "task_rejected" + MessageTypeTaskCancellation MessageType = "task_cancellation" + MessageTypeHeartbeat MessageType = "heartbeat" ) // WebSocketMessage is the base structure for all WebSocket messages @@ -51,14 +52,16 @@ type TaskClaimedMessage struct { // TaskCompletedMessage tells the server to end the active run execution after a successful agent process exit. type TaskCompletedMessage struct { - TaskID string `json:"task_id"` - Message string `json:"message"` + TaskID string `json:"task_id"` + Message string `json:"message"` + TaskState *TaskState `json:"task_state,omitempty"` } // TaskFailedMessage is sent from worker to server if task launch fails type TaskFailedMessage struct { - TaskID string `json:"task_id"` - Message string `json:"message"` + TaskID string `json:"task_id"` + Message string `json:"message"` + TaskState *TaskState `json:"task_state,omitempty"` } // TaskRejectedMessage is sent from worker to server when the worker cannot accept the task @@ -68,6 +71,18 @@ type TaskRejectedMessage struct { Reason string `json:"reason"` } +// TaskCancellationMessage is sent from server to worker to cancel an active task. +type TaskCancellationMessage struct { + TaskID string `json:"task_id"` +} + +// TaskState is the serialized terminal task state accepted by warp-server. +type TaskState string + +const ( + TaskStateCancelled TaskState = "CANCELLED" +) + type TaskDefinition struct { Prompt string `json:"prompt"` } diff --git a/internal/worker/direct.go b/internal/worker/direct.go index 1848d77..ad3a80d 100644 --- a/internal/worker/direct.go +++ b/internal/worker/direct.go @@ -154,6 +154,9 @@ func (b *DirectBackend) ExecuteTask(ctx context.Context, params *TaskParams) err log.Infof(ctx, "Running setup command: %s", b.config.SetupCommand) if err := b.runCommand(ctx, b.config.SetupCommand, workspaceDir, setupEnv); err != nil { + if ctx.Err() != nil { + return newBackendFailure(metrics.TaskFailurePhaseBackend, metrics.TaskFailureReasonTaskCancelled, ctx.Err()) + } return newBackendFailure(metrics.TaskFailurePhaseBackend, metrics.TaskFailureReasonSetupCommand, fmt.Errorf("setup command failed: %w", err)) } } @@ -183,6 +186,9 @@ func (b *DirectBackend) ExecuteTask(ctx context.Context, params *TaskParams) err log.Debugf(ctx, "Command: %s %s", b.ozPath, strings.Join(params.BaseArgs, " ")) if err := cmd.Run(); err != nil { + if ctx.Err() != nil { + return newBackendFailure(metrics.TaskFailurePhaseBackend, metrics.TaskFailureReasonTaskCancelled, ctx.Err()) + } return newBackendFailure(metrics.TaskFailurePhaseBackend, metrics.TaskFailureReasonAgentInvocation, fmt.Errorf("oz agent exited with error: %w", err)) } diff --git a/internal/worker/docker.go b/internal/worker/docker.go index 935dac5..a2ebd3b 100644 --- a/internal/worker/docker.go +++ b/internal/worker/docker.go @@ -145,7 +145,9 @@ func (b *DockerBackend) ExecuteTask(ctx context.Context, params *TaskParams) err defer func() { if containerID != "" && !b.config.NoCleanup { - if removeErr := dockerClient.ContainerRemove(ctx, containerID, container.RemoveOptions{Force: true}); removeErr != nil { + cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), BackendShutdownTimeout) + defer cleanupCancel() + if removeErr := dockerClient.ContainerRemove(cleanupCtx, containerID, container.RemoveOptions{Force: true}); removeErr != nil { log.Debugf(ctx, "Container %s already removed or removal failed: %v", containerID, removeErr) } } diff --git a/internal/worker/worker.go b/internal/worker/worker.go index c2e64a9..2ec5659 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -3,6 +3,7 @@ package worker import ( "context" "encoding/json" + "errors" "fmt" "net/url" "sync" @@ -59,12 +60,17 @@ type Worker struct { reconnectDelay time.Duration lastHeartbeat time.Time sendChan chan []byte - activeTasks map[string]context.CancelFunc + activeTasks map[string]activeTask tasksMutex sync.Mutex backend Backend taskSemaphore *semaphore.Weighted // nil when unlimited } +type activeTask struct { + ctx context.Context + cancel context.CancelFunc +} + func New(ctx context.Context, config Config) (*Worker, error) { workerCtx, cancel := context.WithCancel(ctx) @@ -109,7 +115,7 @@ func New(ctx context.Context, config Config) (*Worker, error) { cancel: cancel, reconnectDelay: InitialReconnectDelay, sendChan: make(chan []byte, 256), - activeTasks: make(map[string]context.CancelFunc), + activeTasks: make(map[string]activeTask), backend: backend, taskSemaphore: taskSemaphore, }, nil @@ -322,11 +328,36 @@ func (w *Worker) handleMessage(message []byte) { } w.handleTaskAssignment(&assignment) + case types.MessageTypeTaskCancellation: + var cancellation types.TaskCancellationMessage + if err := json.Unmarshal(msg.Data, &cancellation); err != nil { + log.Errorf(w.ctx, "Failed to unmarshal task cancellation: %v", err) + return + } + w.handleTaskCancellation(&cancellation) + default: log.Warnf(w.ctx, "Unknown message type: %s", msg.Type) } } +func (w *Worker) handleTaskCancellation(cancellation *types.TaskCancellationMessage) { + w.tasksMutex.Lock() + task, ok := w.activeTasks[cancellation.TaskID] + w.tasksMutex.Unlock() + if !ok { + log.Warnf(w.ctx, "Received cancellation for inactive task: taskID=%s", cancellation.TaskID) + return + } + + log.Infof(w.ctx, "Cancelling task from server request: taskID=%s", cancellation.TaskID) + metrics.AddTaskEvent(task.ctx, "task.cancellation_requested", + attribute.String("source", "server"), + attribute.String("task.id", cancellation.TaskID), + ) + task.cancel() +} + func (w *Worker) handleTaskAssignment(assignment *types.TaskAssignmentMessage) { receivedAt := time.Now() log.Infof(w.ctx, "Received task assignment: taskID=%s, title=%s", assignment.TaskID, assignment.Task.Title) @@ -363,7 +394,10 @@ func (w *Worker) handleTaskAssignment(assignment *types.TaskAssignmentMessage) { taskCtx, taskCancel := context.WithCancel(taskCtx) w.tasksMutex.Lock() - w.activeTasks[assignment.TaskID] = taskCancel + w.activeTasks[assignment.TaskID] = activeTask{ + ctx: taskCtx, + cancel: taskCancel, + } w.tasksMutex.Unlock() go w.executeTask(taskCtx, span, assignment, receivedAt) } @@ -391,14 +425,20 @@ func (w *Worker) prepareTaskParams(assignment *types.TaskAssignmentMessage) *Tas baseArgs := []string{ "agent", "run", - "--share", - "team:edit", + } + // Only share with the team when the task is team-owned. User-owned tasks + // (created with "Team visible" unchecked) use user-scoped API keys that + // cannot set up team-level session sharing. + if task.Owner.IsTeamOwned() { + baseArgs = append(baseArgs, "--share", "team:edit") + } + baseArgs = append(baseArgs, "--task-id", task.ID, "--sandboxed", "--server-root-url", w.config.ServerRootURL, - } + ) baseArgs = common.AugmentArgsForTask(task, baseArgs, common.TaskAugmentOptions{ IdleOnComplete: w.config.IdleOnComplete, }) @@ -482,6 +522,17 @@ func (w *Worker) executeTask(ctx context.Context, span trace.Span, assignment *t err := w.backend.ExecuteTask(ctx, params) if err != nil { + if errors.Is(err, context.Canceled) { + result = metrics.TaskResultCancelled + metrics.AddTaskEvent(ctx, "task.cancelled") + span.SetStatus(codes.Ok, "task cancelled") + log.Infof(ctx, "Task execution cancelled: taskID=%s", taskID) + if statusErr := w.sendTaskCancelled(taskID, "Task cancelled."); statusErr != nil { + log.Errorf(ctx, "Failed to send task cancelled message: %v", statusErr) + } + return + } + result = metrics.TaskResultFailed phase, reason := taskFailureLabels(err) metrics.RecordTaskFailure(phase, reason) @@ -531,6 +582,32 @@ func (w *Worker) sendTaskClaimed(taskID string) error { return w.sendMessage(msgBytes) } +func (w *Worker) sendTaskCancelled(taskID, message string) error { + taskState := types.TaskStateCancelled + completedMsg := types.TaskCompletedMessage{ + TaskID: taskID, + Message: message, + TaskState: &taskState, + } + + data, err := json.Marshal(completedMsg) + if err != nil { + return fmt.Errorf("failed to marshal task cancelled message: %w", err) + } + + msg := types.WebSocketMessage{ + Type: types.MessageTypeTaskCompleted, + Data: data, + } + + msgBytes, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal websocket message: %w", err) + } + + return w.sendMessage(msgBytes) +} + func (w *Worker) sendTaskRejected(taskID, reason string) error { rejectedMsg := types.TaskRejectedMessage{ TaskID: taskID, @@ -621,9 +698,13 @@ func (w *Worker) Shutdown() { activeTaskCount := len(w.activeTasks) if activeTaskCount > 0 { log.Infof(w.ctx, "Cancelling %d active tasks", activeTaskCount) - for taskID, cancel := range w.activeTasks { + for taskID, task := range w.activeTasks { log.Debugf(w.ctx, "Cancelling task: %s", taskID) - cancel() + metrics.AddTaskEvent(task.ctx, "task.cancellation_requested", + attribute.String("source", "signal"), + attribute.String("task.id", taskID), + ) + task.cancel() } } w.tasksMutex.Unlock() diff --git a/internal/worker/worker_test.go b/internal/worker/worker_test.go index 9faa96c..9f33ceb 100644 --- a/internal/worker/worker_test.go +++ b/internal/worker/worker_test.go @@ -84,12 +84,80 @@ func TestTaskFailureLabels(t *testing.T) { } } +func TestExecuteTaskReportsTaskCancelledOnContextCancellation(t *testing.T) { + w := &Worker{ + ctx: context.Background(), + config: Config{}, + sendChan: make(chan []byte, 1), + activeTasks: map[string]activeTask{"task-1": {cancel: func() {}}}, + backend: &recordingBackend{err: context.Canceled}, + } + + w.executeTask(context.Background(), trace.SpanFromContext(context.Background()), &types.TaskAssignmentMessage{ + TaskID: "task-1", + Task: &types.Task{ID: "task-1", Title: "test task"}, + }, time.Now()) + + msg := readWebSocketMessage(t, w.sendChan) + if msg.Type != types.MessageTypeTaskCompleted { + t.Fatalf("message type = %q, want %q", msg.Type, types.MessageTypeTaskCompleted) + } + + var completed types.TaskCompletedMessage + if err := json.Unmarshal(msg.Data, &completed); err != nil { + t.Fatalf("failed to unmarshal task completed message: %v", err) + } + if completed.TaskID != "task-1" { + t.Errorf("task ID = %q, want %q", completed.TaskID, "task-1") + } + if completed.TaskState == nil || *completed.TaskState != types.TaskStateCancelled { + t.Fatalf("task state = %v, want %q", completed.TaskState, types.TaskStateCancelled) + } + if _, ok := w.activeTasks["task-1"]; ok { + t.Fatal("task should be removed from active tasks") + } +} + +func TestHandleMessageCancelsActiveTask(t *testing.T) { + taskCtx, taskCancel := context.WithCancel(context.Background()) + defer taskCancel() + + w := &Worker{ + ctx: context.Background(), + sendChan: make(chan []byte, 1), + activeTasks: map[string]activeTask{ + "task-1": { + ctx: taskCtx, + cancel: taskCancel, + }, + }, + } + + data, err := json.Marshal(types.TaskCancellationMessage{TaskID: "task-1"}) + if err != nil { + t.Fatalf("failed to marshal cancellation message: %v", err) + } + message, err := json.Marshal(types.WebSocketMessage{ + Type: types.MessageTypeTaskCancellation, + Data: data, + }) + if err != nil { + t.Fatalf("failed to marshal websocket message: %v", err) + } + + w.handleMessage(message) + + if taskCtx.Err() != context.Canceled { + t.Fatalf("task context error = %v, want %v", taskCtx.Err(), context.Canceled) + } +} + func TestExecuteTaskReportsTaskCompletedOnSuccess(t *testing.T) { w := &Worker{ ctx: context.Background(), config: Config{}, sendChan: make(chan []byte, 1), - activeTasks: map[string]context.CancelFunc{"task-1": func() {}}, + activeTasks: map[string]activeTask{"task-1": {cancel: func() {}}}, backend: &recordingBackend{}, } @@ -123,7 +191,7 @@ func TestExecuteTaskReportsTaskFailedOnBackendError(t *testing.T) { ctx: context.Background(), config: Config{}, sendChan: make(chan []byte, 1), - activeTasks: map[string]context.CancelFunc{"task-1": func() {}}, + activeTasks: map[string]activeTask{"task-1": {cancel: func() {}}}, backend: &recordingBackend{err: errors.New("boom")}, } @@ -363,7 +431,7 @@ func TestWorkerShutdownUsesFreshContextForBackendCleanup(t *testing.T) { w := &Worker{ ctx: workerCtx, cancel: cancel, - activeTasks: make(map[string]context.CancelFunc), + activeTasks: make(map[string]activeTask), backend: backend, } From cf855b21907c5f66884472fbd61f3e3756893c19 Mon Sep 17 00:00:00 2001 From: Oz Date: Thu, 7 May 2026 15:46:32 +0000 Subject: [PATCH 3/3] Restore metrics.RecordTaskClaim() accidentally removed during rebase Co-Authored-By: Oz --- internal/worker/worker.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/worker/worker.go b/internal/worker/worker.go index 2ec5659..ff74946 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -389,6 +389,7 @@ func (w *Worker) handleTaskAssignment(assignment *types.TaskAssignmentMessage) { if err := w.sendTaskClaimed(assignment.TaskID); err != nil { log.Errorf(w.ctx, "Failed to send task claimed message: %v", err) } + metrics.RecordTaskClaim() metrics.AddTaskEvent(taskCtx, "task.claimed") metrics.IncTasksActive() taskCtx, taskCancel := context.WithCancel(taskCtx)