Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions src/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type App struct {
wg sync.WaitGroup
log *slog.Logger
statusRegistry *StatusRegistry
authToken string // if set, Bearer token required for protected endpoints
}

func NewApp(redisAddr, gpuType string, log *slog.Logger) *App {
Expand All @@ -36,6 +37,8 @@ func NewApp(redisAddr, gpuType string, log *slog.Logger) *App {
consumerID := fmt.Sprintf("worker_%d", os.Getpid())
supervisor := NewSupervisor(redisAddr, consumerID, gpuType, log)

authToken := os.Getenv("AUTH_TOKEN")

mux := http.NewServeMux()
a := &App{
redisClient: client,
Expand All @@ -44,12 +47,14 @@ func NewApp(redisAddr, gpuType string, log *slog.Logger) *App {
httpServer: &http.Server{Addr: ":3000", Handler: mux},
log: log,
statusRegistry: statusRegistry,
authToken: authToken,
}

mux.HandleFunc("/auth/login", a.login)
mux.HandleFunc("/auth/refresh", a.refresh)
mux.HandleFunc("/jobs", a.handleJobs)
mux.HandleFunc("/jobs/status", a.getJobStatus)
mux.HandleFunc("/jobs/logs/", a.requireAuth(a.getJobLogs))
mux.HandleFunc("/supervisors/status", a.getSupervisorStatus)
mux.HandleFunc("/supervisors/status/", a.getSupervisorStatusByID)
mux.HandleFunc("/supervisors", a.getAllSupervisors)
Expand Down Expand Up @@ -297,6 +302,62 @@ func (a *App) getSupervisorStatusByID(w http.ResponseWriter, r *http.Request) {
}
}

// requireAuth wraps a handler and enforces Bearer token authentication.
// If AUTH_TOKEN is not set, returns 503 (logs feature not configured).
// If Authorization header is missing or invalid, returns 401.
func (a *App) requireAuth(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if a.authToken == "" {
a.log.Warn("job logs requested but AUTH_TOKEN not configured")
http.Error(w, "Logs require authentication to be configured", http.StatusServiceUnavailable)
return
}
auth := r.Header.Get("Authorization")
if auth == "" {
http.Error(w, "Authorization header required", http.StatusUnauthorized)
return
}
const prefix = "Bearer "
if !strings.HasPrefix(auth, prefix) || strings.TrimSpace(strings.TrimPrefix(auth, prefix)) != a.authToken {
http.Error(w, "Invalid or expired token", http.StatusUnauthorized)
return
}
next(w, r)
}
}

func (a *App) getJobLogs(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}

path := strings.TrimPrefix(r.URL.Path, "/jobs/logs/")
jobID := strings.Trim(path, "/")
if jobID == "" {
jobID = r.URL.Query().Get("id")
}
if jobID == "" {
http.Error(w, "Job ID is required", http.StatusBadRequest)
return
}

a.log.Info("getJobLogs handler accessed", "job_id", jobID, "remote_address", r.RemoteAddr)

logs, err := a.supervisor.GetContainerLogsForJob(jobID)
if err != nil {
a.log.Error("failed to get job logs", "job_id", jobID, "error", err)
http.Error(w, fmt.Sprintf("Logs not available for job: %s (container must be running)", jobID), http.StatusNotFound)
return
}

w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusOK)
if _, err := w.Write(logs); err != nil {
a.log.Error("failed to write job logs response", "job_id", jobID, "error", err)
}
}

func (a *App) getAllSupervisors(w http.ResponseWriter, r *http.Request) {
activeOnly := r.URL.Query().Get("active") == "true"

Expand Down
216 changes: 216 additions & 0 deletions src/api_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
package main

import (
"context"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"

"mist/docker"

"github.com/docker/docker/client"
"github.com/redis/go-redis/v9"
)

func TestGetJobLogs_RequiresAuth(t *testing.T) {
redisAddr := "localhost:6379"
client := redis.NewClient(&redis.Options{Addr: redisAddr})
if err := client.Ping(context.Background()).Err(); err != nil {
t.Skipf("Redis not running, skipping: %v", err)
}
client.FlushDB(context.Background())

log := slog.New(slog.NewJSONHandler(io.Discard, nil))
app := NewApp(redisAddr, "AMD", log)
app.authToken = "secret-token"

req := httptest.NewRequest(http.MethodGet, "/jobs/logs/job_123", nil)
rr := httptest.NewRecorder()

app.requireAuth(app.getJobLogs)(rr, req)

if rr.Code != http.StatusUnauthorized {
t.Errorf("expected 401 without auth, got %d", rr.Code)
}
}

func TestGetJobLogs_ValidAuth(t *testing.T) {
redisAddr := "localhost:6379"
redisClient := redis.NewClient(&redis.Options{Addr: redisAddr})
if err := redisClient.Ping(context.Background()).Err(); err != nil {
t.Skipf("Redis not running, skipping: %v", err)
}
redisClient.FlushDB(context.Background())

dockerCli, err := client.NewClientWithOpts(client.FromEnv)
if err != nil {
t.Skipf("Docker not available: %v", err)
}
defer dockerCli.Close()
if _, err := dockerCli.Ping(context.Background()); err != nil {
t.Skipf("Docker daemon not reachable: %v", err)
}
_, _, err = dockerCli.ImageInspectWithRaw(context.Background(), "pytorch-cpu")
if err != nil {
t.Skipf("pytorch-cpu image not found: %v", err)
}

// Start a running container named with job ID (simulating supervisor)
mgr := docker.NewDockerMgr(dockerCli, 10, 100)
volName := "test_logs_vol"
_, _ = mgr.CreateVolume(volName)
defer mgr.RemoveVolume(volName, true)

containerID, err := mgr.RunContainer("pytorch-cpu", "runc", volName, "job_123")
if err != nil {
t.Fatalf("failed to run container: %v", err)
}
defer func() {
_ = mgr.StopContainer(containerID)
_ = mgr.RemoveContainer(containerID)
}()

time.Sleep(500 * time.Millisecond) // let container produce output

log := slog.New(slog.NewJSONHandler(io.Discard, nil))
app := NewApp(redisAddr, "AMD", log)
app.authToken = "secret-token"

req := httptest.NewRequest(http.MethodGet, "/jobs/logs/job_123", nil)
req.Header.Set("Authorization", "Bearer secret-token")
rr := httptest.NewRecorder()

app.requireAuth(app.getJobLogs)(rr, req)

if rr.Code != http.StatusOK {
t.Errorf("expected 200 with valid auth, got %d: %s", rr.Code, rr.Body.String())
}
if !strings.Contains(rr.Body.String(), "hello-from-container") {
t.Errorf("expected logs to contain 'hello-from-container', got %q", rr.Body.String())
}
}

func TestGetJobLogs_NotFound(t *testing.T) {
redisAddr := "localhost:6379"
client := redis.NewClient(&redis.Options{Addr: redisAddr})
if err := client.Ping(context.Background()).Err(); err != nil {
t.Skipf("Redis not running, skipping: %v", err)
}
client.FlushDB(context.Background())

log := slog.New(slog.NewJSONHandler(io.Discard, nil))
app := NewApp(redisAddr, "AMD", log)
app.authToken = "secret-token"

req := httptest.NewRequest(http.MethodGet, "/jobs/logs/nonexistent_job", nil)
req.Header.Set("Authorization", "Bearer secret-token")
rr := httptest.NewRecorder()

app.requireAuth(app.getJobLogs)(rr, req)

if rr.Code != http.StatusNotFound {
t.Errorf("expected 404 for missing logs, got %d", rr.Code)
}
}

func TestGetJobLogs_NoAuthConfigured(t *testing.T) {
redisAddr := "localhost:6379"
client := redis.NewClient(&redis.Options{Addr: redisAddr})
if err := client.Ping(context.Background()).Err(); err != nil {
t.Skipf("Redis not running, skipping: %v", err)
}

log := slog.New(slog.NewJSONHandler(io.Discard, nil))
app := NewApp(redisAddr, "AMD", log)
app.authToken = "" // no auth configured

req := httptest.NewRequest(http.MethodGet, "/jobs/logs/job_123", nil)
rr := httptest.NewRecorder()

app.requireAuth(app.getJobLogs)(rr, req)

if rr.Code != http.StatusServiceUnavailable {
t.Errorf("expected 503 when auth not configured, got %d", rr.Code)
}
}

func TestGetJobLogs_InvalidToken(t *testing.T) {
redisAddr := "localhost:6379"
client := redis.NewClient(&redis.Options{Addr: redisAddr})
if err := client.Ping(context.Background()).Err(); err != nil {
t.Skipf("Redis not running, skipping: %v", err)
}

log := slog.New(slog.NewJSONHandler(io.Discard, nil))
app := NewApp(redisAddr, "AMD", log)
app.authToken = "correct-token"

req := httptest.NewRequest(http.MethodGet, "/jobs/logs/job_123", nil)
req.Header.Set("Authorization", "Bearer wrong-token")
rr := httptest.NewRecorder()

app.requireAuth(app.getJobLogs)(rr, req)

if rr.Code != http.StatusUnauthorized {
t.Errorf("expected 401 for invalid token, got %d", rr.Code)
}
}

func TestGetJobLogs_QueryParam(t *testing.T) {
redisAddr := "localhost:6379"
redisClient := redis.NewClient(&redis.Options{Addr: redisAddr})
if err := redisClient.Ping(context.Background()).Err(); err != nil {
t.Skipf("Redis not running, skipping: %v", err)
}
redisClient.FlushDB(context.Background())

dockerCli, err := client.NewClientWithOpts(client.FromEnv)
if err != nil {
t.Skipf("Docker not available: %v", err)
}
defer dockerCli.Close()
if _, err := dockerCli.Ping(context.Background()); err != nil {
t.Skipf("Docker daemon not reachable: %v", err)
}
_, _, err = dockerCli.ImageInspectWithRaw(context.Background(), "pytorch-cpu")
if err != nil {
t.Skipf("pytorch-cpu image not found: %v", err)
}

mgr := docker.NewDockerMgr(dockerCli, 10, 100)
volName := "test_logs_query_vol"
_, _ = mgr.CreateVolume(volName)
defer mgr.RemoveVolume(volName, true)

containerID, err := mgr.RunContainer("pytorch-cpu", "runc", volName, "job_456")
if err != nil {
t.Fatalf("failed to run container: %v", err)
}
defer func() {
_ = mgr.StopContainer(containerID)
_ = mgr.RemoveContainer(containerID)
}()

time.Sleep(500 * time.Millisecond)

log := slog.New(slog.NewJSONHandler(io.Discard, nil))
app := NewApp(redisAddr, "AMD", log)
app.authToken = "token"

req := httptest.NewRequest(http.MethodGet, "/jobs/logs/?id=job_456", nil)
req.Header.Set("Authorization", "Bearer token")
rr := httptest.NewRecorder()

app.requireAuth(app.getJobLogs)(rr, req)

if rr.Code != http.StatusOK {
t.Errorf("expected 200, got %d", rr.Code)
}
if !strings.Contains(rr.Body.String(), "hello-from-container") {
t.Errorf("expected logs to contain 'hello-from-container', got %q", rr.Body.String())
}
}
2 changes: 1 addition & 1 deletion src/container_job_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func TestRunContainerCPUIntegration(t *testing.T) {
}
defer mgr.RemoveVolume(volName, true)

containerID, err := mgr.RunContainer("pytorch-cpu", "runc", volName)
containerID, err := mgr.RunContainer("pytorch-cpu", "runc", volName, "test_run_cpu_integration")
if err != nil {
t.Fatalf("run container: %v", err)
}
Expand Down
Loading