diff --git a/server/internal/orchestrator/swarm/orchestrator.go b/server/internal/orchestrator/swarm/orchestrator.go index 940b9c97..012670f1 100644 --- a/server/internal/orchestrator/swarm/orchestrator.go +++ b/server/internal/orchestrator/swarm/orchestrator.go @@ -578,41 +578,42 @@ func (o *Orchestrator) generateMCPInstanceResources(spec *database.ServiceInstan orchestratorResources = append(orchestratorResources, serviceInstanceSpec, serviceInstance) // Append per-node ServiceUserRole resources for each additional database node. - // The canonical resources (above) cover the first node; nodes [1:] each get + // The canonical resources (above) cover spec.NodeName; all other nodes get // their own RO and RW role that sources credentials from the canonical. - if len(spec.DatabaseNodes) > 1 { - for _, nodeInst := range spec.DatabaseNodes[1:] { - perNodeRWID := ServiceUserRolePerNodeIdentifier(spec.ServiceSpec.ServiceID, ServiceUserRoleRW, nodeInst.NodeName) + for _, nodeInst := range spec.DatabaseNodes { + if nodeInst.NodeName == spec.NodeName { + continue + } + perNodeRWID := ServiceUserRolePerNodeIdentifier(spec.ServiceSpec.ServiceID, ServiceUserRoleRW, nodeInst.NodeName) + orchestratorResources = append(orchestratorResources, + &ServiceUserRole{ + ServiceID: spec.ServiceSpec.ServiceID, + DatabaseID: spec.DatabaseID, + DatabaseName: spec.DatabaseName, + NodeName: nodeInst.NodeName, + Mode: ServiceUserRoleRO, + CredentialSource: &canonicalROID, + }, + &ServiceUserRole{ + ServiceID: spec.ServiceSpec.ServiceID, + DatabaseID: spec.DatabaseID, + DatabaseName: spec.DatabaseName, + NodeName: nodeInst.NodeName, + Mode: ServiceUserRoleRW, + CredentialSource: &canonicalRWID, + }, + ) + if spec.ServiceSpec.ServiceType == "postgrest" { orchestratorResources = append(orchestratorResources, - &ServiceUserRole{ - ServiceID: spec.ServiceSpec.ServiceID, - DatabaseID: spec.DatabaseID, - DatabaseName: spec.DatabaseName, - NodeName: nodeInst.NodeName, - Mode: ServiceUserRoleRO, - 
CredentialSource: &canonicalROID, - }, - &ServiceUserRole{ - ServiceID: spec.ServiceSpec.ServiceID, - DatabaseID: spec.DatabaseID, - DatabaseName: spec.DatabaseName, - NodeName: nodeInst.NodeName, - Mode: ServiceUserRoleRW, - CredentialSource: &canonicalRWID, + &PostgRESTAuthenticatorResource{ + ServiceID: spec.ServiceSpec.ServiceID, + DatabaseID: spec.DatabaseID, + DatabaseName: spec.DatabaseName, + NodeName: nodeInst.NodeName, + DBAnonRole: parsedPostgRESTConfig.DBAnonRole, + UserRoleID: perNodeRWID, }, ) - if spec.ServiceSpec.ServiceType == "postgrest" { - orchestratorResources = append(orchestratorResources, - &PostgRESTAuthenticatorResource{ - ServiceID: spec.ServiceSpec.ServiceID, - DatabaseID: spec.DatabaseID, - DatabaseName: spec.DatabaseName, - NodeName: nodeInst.NodeName, - DBAnonRole: parsedPostgRESTConfig.DBAnonRole, - UserRoleID: perNodeRWID, - }, - ) - } } } @@ -647,6 +648,12 @@ func (o *Orchestrator) buildServiceInstanceResources(spec *database.ServiceInsta // instance. RAG only requires read access, so a single ServiceUserRoleRO is // created per database node using the same canonical+per-node pattern as MCP. func (o *Orchestrator) generateRAGInstanceResources(spec *database.ServiceInstanceSpec) (*database.ServiceInstanceResources, error) { + // Parse the RAG service config to extract API keys. + ragConfig, errs := database.ParseRAGServiceConfig(spec.ServiceSpec.Config, false) + if len(errs) > 0 { + return nil, fmt.Errorf("failed to parse RAG service config: %w", errors.Join(errs...)) + } + canonicalROID := ServiceUserRoleIdentifier(spec.ServiceSpec.ServiceID, ServiceUserRoleRO) // Canonical read-only role — runs on the node co-located with this instance. @@ -676,6 +683,24 @@ func (o *Orchestrator) generateRAGInstanceResources(spec *database.ServiceInstan }) } + // Service data directory resource (host-side bind mount directory). 
+	dataDirID := spec.ServiceInstanceID + "-data"
+	dataDir := &filesystem.DirResource{
+		ID:     dataDirID,
+		HostID: spec.HostID,
+		Path:   filepath.Join(o.cfg.DataDir, "services", spec.ServiceInstanceID),
+	}
+
+	// API key files resource — writes provider keys into a "keys" subdirectory.
+	keysResource := &RAGServiceKeysResource{
+		ServiceInstanceID: spec.ServiceInstanceID,
+		HostID:            spec.HostID,
+		ParentID:          dataDirID,
+		Keys:              extractRAGAPIKeys(ragConfig),
+	}
+
+	orchestratorResources = append(orchestratorResources, dataDir, keysResource)
+
 	return o.buildServiceInstanceResources(spec, orchestratorResources)
 }
 
diff --git a/server/internal/orchestrator/swarm/rag_service_keys_resource.go b/server/internal/orchestrator/swarm/rag_service_keys_resource.go
new file mode 100644
index 00000000..15108323
--- /dev/null
+++ b/server/internal/orchestrator/swarm/rag_service_keys_resource.go
@@ -0,0 +1,287 @@
+package swarm
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"path/filepath"
+	"strings"
+
+	"github.com/samber/do"
+	"github.com/spf13/afero"
+
+	"github.com/pgEdge/control-plane/server/internal/database"
+	"github.com/pgEdge/control-plane/server/internal/filesystem"
+	"github.com/pgEdge/control-plane/server/internal/resource"
+)
+
+var _ resource.Resource = (*RAGServiceKeysResource)(nil)
+
+// ResourceTypeRAGServiceKeys is the resource type of RAGServiceKeysResource.
+const ResourceTypeRAGServiceKeys resource.Type = "swarm.rag_service_keys"
+
+// RAGServiceKeysResourceIdentifier returns the identifier of the keys resource
+// for the given service instance.
+func RAGServiceKeysResourceIdentifier(serviceInstanceID string) resource.Identifier {
+	return resource.Identifier{
+		ID:   serviceInstanceID,
+		Type: ResourceTypeRAGServiceKeys,
+	}
+}
+
+// RAGServiceKeysResource manages provider API key files on the host filesystem.
+// Keys are written to a "keys" subdirectory under the service data directory
+// and bind-mounted read-only into the RAG container.
+// The directory and all files are removed when the service is deleted.
+//
+// NOTE(review): Keys carries the API key values in plaintext and is serialized
+// together with the other fields (it has a json tag); confirm that the place
+// where resource state is stored is acceptable for secrets.
+type RAGServiceKeysResource struct {
+	ServiceInstanceID string            `json:"service_instance_id"`
+	HostID            string            `json:"host_id"`
+	ParentID          string            `json:"parent_id"` // DirResource ID for the service data directory
+	Keys              map[string]string `json:"keys"`      // filename → key value
+}
+
+// ResourceVersion reports the version string of this resource type.
+func (r *RAGServiceKeysResource) ResourceVersion() string {
+	return "1"
+}
+
+// DiffIgnore returns nil: this resource declares nothing to ignore.
+func (r *RAGServiceKeysResource) DiffIgnore() []string {
+	return nil
+}
+
+// Identifier returns the identifier derived from ServiceInstanceID.
+func (r *RAGServiceKeysResource) Identifier() resource.Identifier {
+	return RAGServiceKeysResourceIdentifier(r.ServiceInstanceID)
+}
+
+// Executor returns the host executor for HostID.
+func (r *RAGServiceKeysResource) Executor() resource.Executor {
+	return resource.HostExecutor(r.HostID)
+}
+
+// Dependencies returns the service data directory resource (ParentID), whose
+// path keysDir resolves.
+func (r *RAGServiceKeysResource) Dependencies() []resource.Identifier {
+	return []resource.Identifier{
+		filesystem.DirResourceIdentifier(r.ParentID),
+	}
+}
+
+// TypeDependencies returns nil: there are no type-level dependencies.
+func (r *RAGServiceKeysResource) TypeDependencies() []resource.Type {
+	return nil
+}
+
+// keysDir returns the path of the "keys" subdirectory inside the parent
+// DirResource's full path.
+func (r *RAGServiceKeysResource) keysDir(rc *resource.Context) (string, error) {
+	parentPath, err := filesystem.DirResourceFullPath(rc, r.ParentID)
+	if err != nil {
+		return "", fmt.Errorf("failed to get service data dir path: %w", err)
+	}
+	return filepath.Join(parentPath, "keys"), nil
+}
+
+// validateKeys checks every filename in r.Keys with validateKeyFilename.
+func (r *RAGServiceKeysResource) validateKeys() error {
+	for name := range r.Keys {
+		if err := validateKeyFilename(name); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// prepareKeysDir validates r.Keys, then makes sure the keys directory exists
+// and is owned by the RAG container user. Validation runs before any
+// filesystem mutation so that an invalid name never leaves the directory in a
+// partially-written or partially-deleted state. It returns the filesystem and
+// the keys directory path for the caller to continue with.
+func (r *RAGServiceKeysResource) prepareKeysDir(rc *resource.Context) (afero.Fs, string, error) {
+	if err := r.validateKeys(); err != nil {
+		return nil, "", err
+	}
+	fs, err := do.Invoke[afero.Fs](rc.Injector)
+	if err != nil {
+		return nil, "", err
+	}
+	keysDir, err := r.keysDir(rc)
+	if err != nil {
+		return nil, "", err
+	}
+	if err := fs.MkdirAll(keysDir, 0o700); err != nil {
+		return nil, "", fmt.Errorf("failed to create keys directory: %w", err)
+	}
+	if err := fs.Chown(keysDir, ragContainerUID, ragContainerUID); err != nil {
+		return nil, "", fmt.Errorf("failed to set keys directory ownership: %w", err)
+	}
+	return fs, keysDir, nil
+}
+
+// Refresh verifies that the keys directory and every file named in r.Keys
+// exist. It returns resource.ErrNotFound when the directory or any key file is
+// missing. File contents are not compared.
+func (r *RAGServiceKeysResource) Refresh(ctx context.Context, rc *resource.Context) error {
+	fs, err := do.Invoke[afero.Fs](rc.Injector)
+	if err != nil {
+		return err
+	}
+
+	keysDir, err := r.keysDir(rc)
+	if err != nil {
+		return err
+	}
+
+	info, err := fs.Stat(keysDir)
+	if errors.Is(err, afero.ErrFileNotFound) {
+		return resource.ErrNotFound
+	}
+	if err != nil {
+		return fmt.Errorf("failed to stat keys directory: %w", err)
+	}
+	if !info.IsDir() {
+		return fmt.Errorf("expected %q to be a directory", keysDir)
+	}
+
+	for name := range r.Keys {
+		if err := validateKeyFilename(name); err != nil {
+			return fmt.Errorf("invalid key filename in state: %w", err)
+		}
+		if _, err := fs.Stat(filepath.Join(keysDir, name)); err != nil {
+			if errors.Is(err, afero.ErrFileNotFound) {
+				return resource.ErrNotFound
+			}
+			return fmt.Errorf("failed to stat key file %q: %w", name, err)
+		}
+	}
+
+	return nil
+}
+
+// Create creates the keys directory and writes one file per entry in r.Keys.
+// All filenames are validated before anything is created.
+func (r *RAGServiceKeysResource) Create(ctx context.Context, rc *resource.Context) error {
+	fs, keysDir, err := r.prepareKeysDir(rc)
+	if err != nil {
+		return err
+	}
+	return r.writeKeyFiles(fs, keysDir)
+}
+
+// Update brings the keys directory in line with r.Keys: files that are no
+// longer listed are removed and every listed file is (re)written. All
+// filenames are validated before any file is removed or written.
+func (r *RAGServiceKeysResource) Update(ctx context.Context, rc *resource.Context) error {
+	fs, keysDir, err := r.prepareKeysDir(rc)
+	if err != nil {
+		return err
+	}
+	if err := r.removeStaleKeyFiles(fs, keysDir); err != nil {
+		return err
+	}
+	return r.writeKeyFiles(fs, keysDir)
+}
+
+// Delete removes the keys directory and everything in it.
+func (r *RAGServiceKeysResource) Delete(ctx context.Context, rc *resource.Context) error {
+	fs, err := do.Invoke[afero.Fs](rc.Injector)
+	if err != nil {
+		return err
+	}
+	keysDir, err := r.keysDir(rc)
+	if err != nil {
+		// Parent dir is gone or unresolvable; nothing to clean up.
+		return nil
+	}
+	if err := fs.RemoveAll(keysDir); err != nil {
+		return fmt.Errorf("failed to remove keys directory: %w", err)
+	}
+	return nil
+}
+
+// writeKeyFiles writes each entry of r.Keys to keysDir as a 0600 file owned by
+// the RAG container user.
+func (r *RAGServiceKeysResource) writeKeyFiles(fs afero.Fs, keysDir string) error {
+	for name, key := range r.Keys {
+		// Re-checked here so this helper is safe no matter who calls it.
+		if err := validateKeyFilename(name); err != nil {
+			return err
+		}
+		path := filepath.Join(keysDir, name)
+		if err := afero.WriteFile(fs, path, []byte(key), 0o600); err != nil {
+			return fmt.Errorf("failed to write key file %q: %w", name, err)
+		}
+		// WriteFile applies the mode only when it creates the file; re-assert
+		// it so a pre-existing file cannot keep looser permissions.
+		if err := fs.Chmod(path, 0o600); err != nil {
+			return fmt.Errorf("failed to set key file %q permissions: %w", name, err)
+		}
+		if err := fs.Chown(path, ragContainerUID, ragContainerUID); err != nil {
+			return fmt.Errorf("failed to set key file %q ownership: %w", name, err)
+		}
+	}
+	return nil
+}
+
+// removeStaleKeyFiles deletes key files in keysDir that are no longer in r.Keys.
+// This handles the case where a pipeline (and its key files) has been removed.
+// Subdirectories are left alone.
+func (r *RAGServiceKeysResource) removeStaleKeyFiles(fs afero.Fs, keysDir string) error {
+	entries, err := afero.ReadDir(fs, keysDir)
+	if errors.Is(err, afero.ErrFileNotFound) {
+		return nil
+	}
+	if err != nil {
+		return fmt.Errorf("failed to read keys directory: %w", err)
+	}
+	for _, entry := range entries {
+		if entry.IsDir() {
+			continue
+		}
+		if _, ok := r.Keys[entry.Name()]; ok {
+			continue
+		}
+		path := filepath.Join(keysDir, entry.Name())
+		if err := fs.Remove(path); err != nil && !errors.Is(err, afero.ErrFileNotFound) {
+			return fmt.Errorf("failed to remove stale key file %q: %w", entry.Name(), err)
+		}
+	}
+	return nil
+}
+
+// validateKeyFilename rejects filenames that could escape the keys directory via path traversal.
+// Only a plain, already-clean, single path element is accepted.
+func validateKeyFilename(name string) error {
+	if name == "" || name == "." || name == ".." ||
+		filepath.IsAbs(name) ||
+		strings.ContainsAny(name, `/\`) ||
+		filepath.Clean(name) != name {
+		return fmt.Errorf("invalid key filename %q", name)
+	}
+	return nil
+}
+
+// extractRAGAPIKeys builds the filename→value map from a parsed RAGServiceConfig.
+// Filenames follow the convention: {pipeline_name}_embedding.key and {pipeline_name}_rag.key.
+// Providers that do not require an API key (e.g. ollama) produce no entry.
+// A nil cfg yields an empty map.
+func extractRAGAPIKeys(cfg *database.RAGServiceConfig) map[string]string {
+	keys := make(map[string]string)
+	if cfg == nil {
+		return keys
+	}
+	for _, p := range cfg.Pipelines {
+		if p.EmbeddingLLM.APIKey != nil && *p.EmbeddingLLM.APIKey != "" {
+			keys[p.Name+"_embedding.key"] = *p.EmbeddingLLM.APIKey
+		}
+		if p.RAGLLM.APIKey != nil && *p.RAGLLM.APIKey != "" {
+			keys[p.Name+"_rag.key"] = *p.RAGLLM.APIKey
+		}
+	}
+	return keys
+}
diff --git a/server/internal/orchestrator/swarm/rag_service_keys_resource_test.go b/server/internal/orchestrator/swarm/rag_service_keys_resource_test.go
new file mode 100644
index 00000000..704fb32c
--- /dev/null
+++ b/server/internal/orchestrator/swarm/rag_service_keys_resource_test.go
@@ -0,0 +1,487 @@
+package swarm
+
+import (
+	"context"
+	"errors"
+	"path/filepath"
+	"testing"
+
+	"github.com/samber/do"
+	"github.com/spf13/afero"
+
+	"github.com/pgEdge/control-plane/server/internal/database"
+	"github.com/pgEdge/control-plane/server/internal/filesystem"
+	"github.com/pgEdge/control-plane/server/internal/resource"
+)
+
+// ragKeysRCAndFs returns a resource.Context backed by an in-memory afero.Fs
+// with the given parent directory pre-created, alongside the Fs itself so
+// callers can inspect written files without touching the real filesystem.
+func ragKeysRCAndFs(t *testing.T, parentID, parentFullPath string) (*resource.Context, afero.Fs) {
+	t.Helper()
+	fs := afero.NewMemMapFs()
+	// Fail loudly if the fixture cannot be set up instead of letting a later
+	// assertion fail for an unrelated-looking reason.
+	if err := fs.MkdirAll(parentFullPath, 0o700); err != nil {
+		t.Fatalf("MkdirAll(%q) error = %v", parentFullPath, err)
+	}
+
+	// The resource under test obtains its filesystem from the injector.
+	injector := do.New()
+	do.Provide(injector, func(i *do.Injector) (afero.Fs, error) {
+		return fs, nil
+	})
+
+	// Seed the state with the parent DirResource so that the keys resource
+	// can resolve its parent path from rc.
+	parentDir := &filesystem.DirResource{
+		ID:       parentID,
+		HostID:   "host-1",
+		Path:     parentFullPath,
+		FullPath: parentFullPath,
+	}
+	data, err := resource.ToResourceData(parentDir)
+	if err != nil {
+		t.Fatalf("ToResourceData() error = %v", err)
+	}
+	state := resource.NewState()
+	state.Add(data)
+	return &resource.Context{State: state, Injector: injector}, fs
+}
+
+// ragKeysRC returns a resource.Context for tests that only need rc (no file assertions).
+func ragKeysRC(t *testing.T, parentID, parentFullPath string) *resource.Context {
+	t.Helper()
+	rc, _ := ragKeysRCAndFs(t, parentID, parentFullPath)
+	return rc
+}
+
+// ragKeysRCWithTempDir returns a resource.Context, its backing afero.Fs, and
+// a stable parent path for use in tests that create and inspect key files.
+func ragKeysRCWithTempDir(t *testing.T, parentID string) (*resource.Context, afero.Fs, string) {
+	t.Helper()
+	parentPath := "/tmp/rag-test-" + t.Name() + "-" + parentID
+	rc, fs := ragKeysRCAndFs(t, parentID, parentPath)
+	return rc, fs, parentPath
+}
+
+func TestRAGServiceKeysResource_ResourceVersion(t *testing.T) {
+	r := &RAGServiceKeysResource{}
+	if got := r.ResourceVersion(); got != "1" {
+		t.Errorf("ResourceVersion() = %q, want %q", got, "1")
+	}
+}
+
+func TestRAGServiceKeysResource_Identifier(t *testing.T) {
+	r := &RAGServiceKeysResource{ServiceInstanceID: "storefront-rag-host1"}
+	id := r.Identifier()
+	if id.ID != "storefront-rag-host1" {
+		t.Errorf("Identifier().ID = %q, want %q", id.ID, "storefront-rag-host1")
+	}
+	if id.Type != ResourceTypeRAGServiceKeys {
+		t.Errorf("Identifier().Type = %q, want %q", id.Type, ResourceTypeRAGServiceKeys)
+	}
+}
+
+func TestRAGServiceKeysResource_Executor(t *testing.T) {
+	r := &RAGServiceKeysResource{HostID: "host-1"}
+	exec := r.Executor()
+	if exec != resource.HostExecutor("host-1") {
+		t.Errorf("Executor() = %v, want HostExecutor(%q)", exec, "host-1")
+	}
+}
+
+func TestRAGServiceKeysResource_DiffIgnore(t *testing.T) {
+	r := &RAGServiceKeysResource{}
+	if got := r.DiffIgnore(); len(got) != 0 {
+		t.Errorf("DiffIgnore() = %v, want empty", got)
+	}
+}
+
+func TestRAGServiceKeysResource_Dependencies(t *testing.T) {
+	r := &RAGServiceKeysResource{ParentID: "storefront-rag-host1-data"}
+	deps := r.Dependencies()
+	if len(deps) != 1 {
+		t.Fatalf("Dependencies() len = %d, want 1", len(deps))
+	}
+	want := filesystem.DirResourceIdentifier("storefront-rag-host1-data")
+	if deps[0] != want {
+		t.Errorf("Dependencies()[0] = %v, want %v", deps[0], want)
+	}
+}
+
+func TestRAGServiceKeysResource_RefreshMissingDir(t *testing.T) {
+	parentID := "inst1-data"
+	rc := ragKeysRC(t, parentID, "/nonexistent/path/that/does/not/exist")
+	r := &RAGServiceKeysResource{
+		ServiceInstanceID: "inst1",
+		HostID:            "host-1",
+		ParentID:          parentID,
+		Keys:              map[string]string{"default_rag.key": "sk-test"},
+	}
+	err := r.Refresh(context.Background(), rc)
+	// errors.Is keeps the assertion valid even if the sentinel is ever wrapped.
+	if !errors.Is(err, resource.ErrNotFound) {
+		t.Errorf("Refresh() = %v, want ErrNotFound", err)
+	}
+}
+
+func TestRAGServiceKeysResourceIdentifier(t *testing.T) {
+	id := RAGServiceKeysResourceIdentifier("my-instance")
+	if id.ID != "my-instance" {
+		t.Errorf("ID = %q, want %q", id.ID, "my-instance")
+	}
+	if id.Type != ResourceTypeRAGServiceKeys {
+		t.Errorf("Type = %q, want %q", id.Type, ResourceTypeRAGServiceKeys)
+	}
+}
+
+func TestExtractRAGAPIKeys_AllProviders(t *testing.T) {
+	embKey := "sk-embed-key"
+	ragKey := "sk-ant-key"
+	cfg := &database.RAGServiceConfig{
+		Pipelines: []database.RAGPipeline{
+			{
+				Name: "default",
+				EmbeddingLLM: database.RAGPipelineLLMConfig{
+					Provider: "openai",
+					Model:    "text-embedding-3-small",
+					APIKey:   &embKey,
+				},
+				RAGLLM: database.RAGPipelineLLMConfig{
+					Provider: "anthropic",
+					Model:    "claude-sonnet-4-5",
+					APIKey:   &ragKey,
+				},
+			},
+		},
+	}
+
+	keys := extractRAGAPIKeys(cfg)
+	if keys["default_embedding.key"] != embKey {
+		t.Errorf("default_embedding.key = %q, want %q", keys["default_embedding.key"], embKey)
+	}
+	if keys["default_rag.key"] != ragKey {
+		t.Errorf("default_rag.key = %q, want %q", keys["default_rag.key"], ragKey)
+	}
+	if len(keys) != 2 {
+		t.Errorf("len(keys) = %d, want 2", len(keys))
+	}
+}
+
+func TestExtractRAGAPIKeys_OllamaSkipped(t *testing.T) {
+	cfg := &database.RAGServiceConfig{
+		Pipelines: []database.RAGPipeline{
+			{
+				Name: "local",
+				EmbeddingLLM: database.RAGPipelineLLMConfig{
+					Provider: "ollama",
+					Model:    "nomic-embed-text",
+					// APIKey is nil
+				},
+				RAGLLM: database.RAGPipelineLLMConfig{
+					Provider: "ollama",
+					Model:    "llama3",
+					// APIKey is nil
+				},
+			},
+		},
+	}
+
+	keys := extractRAGAPIKeys(cfg)
+	if len(keys) != 0 {
+		t.Errorf("len(keys) = %d, want 0 (ollama has no api_key)", len(keys))
+	}
+}
+
+func TestExtractRAGAPIKeys_MultiPipeline(t *testing.T) {
+	k1 := "sk-openai-1"
+	k2 := "sk-ant-2"
+	cfg := &database.RAGServiceConfig{
+		Pipelines: []database.RAGPipeline{
+			{
+				Name: "pipeline-a",
+				EmbeddingLLM: database.RAGPipelineLLMConfig{
+					Provider: "openai",
+					Model:    "text-embedding-3-small",
+					APIKey:   &k1,
+				},
+				RAGLLM: database.RAGPipelineLLMConfig{
+					Provider: "anthropic",
+					Model:    "claude-sonnet-4-5",
+					APIKey:   &k2,
+				},
+			},
+			{
+				Name: "pipeline-b",
+				EmbeddingLLM: database.RAGPipelineLLMConfig{
+					Provider: "ollama",
+					Model:    "nomic-embed-text",
+				},
+				RAGLLM: database.RAGPipelineLLMConfig{
+					Provider: "ollama",
+					Model:    "llama3",
+				},
+			},
+		},
+	}
+
+	keys := extractRAGAPIKeys(cfg)
+	if _, ok := keys["pipeline-a_embedding.key"]; !ok {
+		t.Error("missing pipeline-a_embedding.key")
+	}
+	if _, ok := keys["pipeline-a_rag.key"]; !ok {
+		t.Error("missing pipeline-a_rag.key")
+	}
+	if _, ok := keys["pipeline-b_embedding.key"]; ok {
+		t.Error("unexpected pipeline-b_embedding.key (ollama has no api_key)")
+	}
+	if len(keys) != 2 {
+		t.Errorf("len(keys) = %d, want 2", len(keys))
+	}
+}
+
+func TestGenerateRAGInstanceResources_IncludesKeysResource(t *testing.T) {
+	o := &Orchestrator{}
+	spec := &database.ServiceInstanceSpec{
+		ServiceInstanceID: "storefront-rag-host1",
+		ServiceSpec: &database.ServiceSpec{
+			ServiceID:   "rag",
+			ServiceType: "rag",
+			Version:     "latest",
+			// Shared fixture from rag_service_user_role_test.go (same package).
+			Config: minimalRAGConfig(),
+		},
+		DatabaseID:   "storefront",
+		DatabaseName: "storefront",
+		HostID:       "host-1",
+		NodeName:     "n1",
+	}
+
+	result, err := o.generateRAGInstanceResources(spec)
+	if err != nil {
+		t.Fatalf("generateRAGInstanceResources() error = %v", err)
+	}
+
+	var foundKeys bool
+	for _, rd := range result.Resources {
+		if rd.Identifier.Type == ResourceTypeRAGServiceKeys {
+			foundKeys = true
+			break
+		}
+	}
+	if !foundKeys {
+		t.Errorf("expected ResourceTypeRAGServiceKeys in resources, not found")
+	}
+}
+
+func TestRAGServiceKeysResource_Create(t *testing.T) {
+	parentID := "inst1-data"
+	rc, fs, parentPath := ragKeysRCWithTempDir(t, parentID)
+
+	r := &RAGServiceKeysResource{
+		ServiceInstanceID: "inst1",
+		HostID:            "host-1",
+		ParentID:          parentID,
+		Keys: map[string]string{
+			"default_embedding.key": "sk-embed",
+			"default_rag.key":       "sk-rag",
+		},
+	}
+
+	if err := r.Create(context.Background(), rc); err != nil {
+		t.Fatalf("Create() error = %v", err)
+	}
+
+	keysDir := filepath.Join(parentPath, "keys")
+	for name, want := range r.Keys {
+		got, err := afero.ReadFile(fs, filepath.Join(keysDir, name))
+		if err != nil {
+			t.Errorf("ReadFile(%q) error = %v", name, err)
+			continue
+		}
+		if string(got) != want {
+			t.Errorf("key file %q = %q, want %q", name, string(got), want)
+		}
+		info, err := fs.Stat(filepath.Join(keysDir, name))
+		if err != nil {
+			t.Errorf("Stat(%q) error = %v", name, err)
+			continue
+		}
+		if perm := info.Mode().Perm(); perm != 0o600 {
+			t.Errorf("key file %q perm = %04o, want 0600", name, perm)
+		}
+	}
+
+	// Refresh must succeed now that the directory and files exist.
+	if err := r.Refresh(context.Background(), rc); err != nil {
+		t.Errorf("Refresh() after Create = %v, want nil", err)
+	}
+}
+
+func TestRAGServiceKeysResource_Update_WritesNewKeys(t *testing.T) {
+	parentID := "inst1-data"
+	rc, fs, parentPath := ragKeysRCWithTempDir(t, parentID)
+
+	r := &RAGServiceKeysResource{
+		ServiceInstanceID: "inst1",
+		HostID:            "host-1",
+		ParentID:          parentID,
+		Keys:              map[string]string{"old_rag.key": "sk-old"},
+	}
+	if err := r.Create(context.Background(), rc); err != nil {
+		t.Fatalf("Create() error = %v", err)
+	}
+
+	r.Keys = map[string]string{"new_rag.key": "sk-new"}
+	if err := r.Update(context.Background(), rc); err != nil {
+		t.Fatalf("Update() error = %v", err)
+	}
+
+	keysDir := filepath.Join(parentPath, "keys")
+
+	if _, err := fs.Stat(filepath.Join(keysDir, "old_rag.key")); !errors.Is(err, afero.ErrFileNotFound) {
+		t.Errorf("old_rag.key should be removed after Update, got err = %v", err)
+	}
+
+	got, err := afero.ReadFile(fs, filepath.Join(keysDir, "new_rag.key"))
+	if err != nil {
+		t.Fatalf("ReadFile(new_rag.key) error = %v", err)
+	}
+	if string(got) != "sk-new" {
+		t.Errorf("new_rag.key = %q, want %q", string(got), "sk-new")
+	}
+}
+
+func TestRAGServiceKeysResource_Delete(t *testing.T) {
+	parentID := "inst1-data"
+	rc, fs, parentPath := ragKeysRCWithTempDir(t, parentID)
+
+	r := &RAGServiceKeysResource{
+		ServiceInstanceID: "inst1",
+		HostID:            "host-1",
+		ParentID:          parentID,
+		Keys:              map[string]string{"default_rag.key": "sk-test"},
+	}
+	if err := r.Create(context.Background(), rc); err != nil {
+		t.Fatalf("Create() error = %v", err)
+	}
+
+	if err := r.Delete(context.Background(), rc); err != nil {
+		t.Fatalf("Delete() error = %v", err)
+	}
+
+	keysDir := filepath.Join(parentPath, "keys")
+	if _, err := fs.Stat(keysDir); !errors.Is(err, afero.ErrFileNotFound) {
+		t.Errorf("keys directory should not exist after Delete, got err = %v", err)
+	}
+}
+
+func TestValidateKeyFilename(t *testing.T) {
+	valid := []string{
+		"default_rag.key",
+		"pipeline-a_embedding.key",
+		"foo.key",
+	}
+	for _, name := range valid {
+		if err := validateKeyFilename(name); err != nil {
+			t.Errorf("validateKeyFilename(%q) = %v, want nil", name, err)
+		}
+	}
+
+	invalid := []string{
+		"",
+		"../escape.key",
+		"/absolute/path.key",
+		"sub/dir.key",
+		`sub\dir.key`,
+		"./relative.key",
+		".",
+		"..",
+	}
+	for _, name := range invalid {
+		if err := validateKeyFilename(name); err == nil {
+			t.Errorf("validateKeyFilename(%q) = nil, want error", name)
+		}
+	}
+}
+
+func TestRAGServiceKeysResource_Create_DirPermissions(t *testing.T) {
+	parentID := "inst1-data"
+	rc, fs, parentPath := ragKeysRCWithTempDir(t, parentID)
+
+	r := &RAGServiceKeysResource{
+		ServiceInstanceID: "inst1",
+		HostID:            "host-1",
+		ParentID:          parentID,
+		Keys:              map[string]string{"default_rag.key": "sk-test"},
+	}
+	if err := r.Create(context.Background(), rc); err != nil {
+		t.Fatalf("Create() error = %v", err)
+	}
+
+	keysDir := filepath.Join(parentPath, "keys")
+	info, err := fs.Stat(keysDir)
+	if err != nil {
+		t.Fatalf("Stat(keysDir) error = %v", err)
+	}
+	if perm := info.Mode().Perm(); perm != 0o700 {
+		t.Errorf("keys dir perm = %04o, want 0700", perm)
+	}
+}
+
+func TestRAGServiceKeysResource_Refresh_InvalidFilenameInState(t *testing.T) {
+	parentID := "inst1-data"
+	rc, fs, parentPath := ragKeysRCWithTempDir(t, parentID)
+
+	// The keys directory must exist, otherwise Refresh returns ErrNotFound
+	// before it ever looks at the filenames and the test would pass without
+	// exercising the validation.
+	if err := fs.MkdirAll(filepath.Join(parentPath, "keys"), 0o700); err != nil {
+		t.Fatalf("MkdirAll(keys) error = %v", err)
+	}
+
+	r := &RAGServiceKeysResource{
+		ServiceInstanceID: "inst1",
+		HostID:            "host-1",
+		ParentID:          parentID,
+		Keys:              map[string]string{"../escape.key": "sk-bad"},
+	}
+	err := r.Refresh(context.Background(), rc)
+	if err == nil {
+		t.Fatal("Refresh() = nil, want error for invalid key filename in state")
+	}
+	if errors.Is(err, resource.ErrNotFound) {
+		t.Errorf("Refresh() = %v, want a filename validation error", err)
+	}
+}
+
+func TestRAGServiceKeysResource_Update_InvalidFilenameIsNonDestructive(t *testing.T) {
+	parentID := "inst1-data"
+	rc, fs, parentPath := ragKeysRCWithTempDir(t, parentID)
+
+	r := &RAGServiceKeysResource{
+		ServiceInstanceID: "inst1",
+		HostID:            "host-1",
+		ParentID:          parentID,
+		Keys:              map[string]string{"default_rag.key": "sk-good"},
+	}
+	if err := r.Create(context.Background(), rc); err != nil {
+		t.Fatalf("Create() error = %v", err)
+	}
+
+	// Attempt Update with an invalid filename — must fail before any deletion.
+	r.Keys = map[string]string{"../escape.key": "sk-bad"}
+	if err := r.Update(context.Background(), rc); err == nil {
+		t.Fatal("Update() = nil, want error for invalid key filename")
+	}
+
+	// The original file must still be present — Update must not have deleted it.
+	existing := filepath.Join(parentPath, "keys", "default_rag.key")
+	if _, err := fs.Stat(existing); err != nil {
+		t.Errorf("default_rag.key should still exist after failed Update, got err = %v", err)
+	}
+}
diff --git a/server/internal/orchestrator/swarm/rag_service_user_role_test.go b/server/internal/orchestrator/swarm/rag_service_user_role_test.go
index fed61b40..681cadf6 100644
--- a/server/internal/orchestrator/swarm/rag_service_user_role_test.go
+++ b/server/internal/orchestrator/swarm/rag_service_user_role_test.go
@@ -4,9 +4,38 @@ import (
 	"testing"
 
 	"github.com/pgEdge/control-plane/server/internal/database"
+	"github.com/pgEdge/control-plane/server/internal/filesystem"
 	"github.com/pgEdge/control-plane/server/internal/resource"
 )
 
+// minimalRAGConfig returns a minimal valid RAG service config suitable for unit tests.
+func minimalRAGConfig() map[string]any {
+	// Built bottom-up: table, the two LLM settings, then the pipeline.
+	docsTable := map[string]any{
+		"table":         "docs",
+		"text_column":   "content",
+		"vector_column": "embedding",
+	}
+	embeddingLLM := map[string]any{
+		"provider": "openai",
+		"model":    "text-embedding-3-small",
+		"api_key":  "sk-embed",
+	}
+	ragLLM := map[string]any{
+		"provider": "anthropic",
+		"model":    "claude-sonnet-4-5",
+		"api_key":  "sk-ant",
+	}
+	pipeline := map[string]any{
+		"name":          "default",
+		"tables":        []any{docsTable},
+		"embedding_llm": embeddingLLM,
+		"rag_llm":       ragLLM,
+	}
+	// The config holds exactly one pipeline.
+	return map[string]any{"pipelines": []any{pipeline}}
+}
+
 func TestGenerateRAGInstanceResources_ResourceList(t *testing.T) {
 	o := &Orchestrator{}
 	spec := &database.ServiceInstanceSpec{
@@ -15,6 +44,7 @@ func TestGenerateRAGInstanceResources_ResourceList(t *testing.T) {
 			ServiceID:   "rag",
 			ServiceType: "rag",
 			Version:     "latest",
+			Config:      minimalRAGConfig(),
 		},
 		DatabaseID:   "storefront",
 		DatabaseName: "storefront",
@@ -43,9 +73,9 @@ func TestGenerateRAGInstanceResources_ResourceList(t *testing.T) {
 			result.ServiceInstance.State, database.ServiceInstanceStateCreating)
 	}
 
-	// Single node: one canonical RO ServiceUserRole.
-	if len(result.Resources) != 1 {
-		t.Fatalf("len(Resources) = %d, want 1", len(result.Resources))
+	// Single node: canonical RO ServiceUserRole + data DirResource + RAGServiceKeysResource.
+ if len(result.Resources) != 3 { + t.Fatalf("len(Resources) = %d, want 3", len(result.Resources)) } if result.Resources[0].Identifier.Type != ResourceTypeServiceUserRole { t.Errorf("Resources[0].Identifier.Type = %q, want %q", @@ -55,6 +85,14 @@ func TestGenerateRAGInstanceResources_ResourceList(t *testing.T) { if result.Resources[0].Identifier != wantID { t.Errorf("Resources[0].Identifier = %v, want %v", result.Resources[0].Identifier, wantID) } + if result.Resources[1].Identifier.Type != filesystem.ResourceTypeDir { + t.Errorf("Resources[1].Identifier.Type = %q, want %q", + result.Resources[1].Identifier.Type, filesystem.ResourceTypeDir) + } + if result.Resources[2].Identifier.Type != ResourceTypeRAGServiceKeys { + t.Errorf("Resources[2].Identifier.Type = %q, want %q", + result.Resources[2].Identifier.Type, ResourceTypeRAGServiceKeys) + } } func TestGenerateRAGInstanceResources_MultiNode(t *testing.T) { @@ -65,6 +103,7 @@ func TestGenerateRAGInstanceResources_MultiNode(t *testing.T) { ServiceID: "rag", ServiceType: "rag", Version: "latest", + Config: minimalRAGConfig(), }, DatabaseID: "storefront", DatabaseName: "storefront", @@ -82,13 +121,14 @@ func TestGenerateRAGInstanceResources_MultiNode(t *testing.T) { t.Fatalf("generateRAGInstanceResources() error = %v", err) } - // 3 nodes → canonical(n1) + per-node(n2) + per-node(n3) = 3 RO resources - if len(result.Resources) != 3 { - t.Fatalf("len(Resources) = %d, want 3", len(result.Resources)) + // 3 nodes → canonical(n1) + per-node(n2) + per-node(n3) + data dir + keys = 5 resources. + if len(result.Resources) != 5 { + t.Fatalf("len(Resources) = %d, want 5", len(result.Resources)) } - for _, rd := range result.Resources { - if rd.Identifier.Type != ResourceTypeServiceUserRole { - t.Errorf("resource type = %q, want %q", rd.Identifier.Type, ResourceTypeServiceUserRole) + // First three must be ServiceUserRole resources. 
+ for i := 0; i < 3; i++ { + if result.Resources[i].Identifier.Type != ResourceTypeServiceUserRole { + t.Errorf("resource[%d] type = %q, want %q", i, result.Resources[i].Identifier.Type, ResourceTypeServiceUserRole) } } @@ -106,7 +146,7 @@ func TestGenerateRAGInstanceResources_MultiNode(t *testing.T) { // Per-node resources point back to canonical canonicalID := ServiceUserRoleIdentifier("rag", ServiceUserRoleRO) - for i, rd := range result.Resources[1:] { + for i, rd := range result.Resources[1:3] { perNode, err := resource.ToResource[*ServiceUserRole](rd) if err != nil { t.Fatalf("ToResource per-node[%d]: %v", i, err) @@ -118,6 +158,16 @@ func TestGenerateRAGInstanceResources_MultiNode(t *testing.T) { t.Errorf("per-node[%d].Mode = %q, want %q", i, perNode.Mode, ServiceUserRoleRO) } } + + // Data dir and keys resource are appended last. + if result.Resources[3].Identifier.Type != filesystem.ResourceTypeDir { + t.Errorf("Resources[3].Identifier.Type = %q, want %q", + result.Resources[3].Identifier.Type, filesystem.ResourceTypeDir) + } + if result.Resources[4].Identifier.Type != ResourceTypeRAGServiceKeys { + t.Errorf("Resources[4].Identifier.Type = %q, want %q", + result.Resources[4].Identifier.Type, ResourceTypeRAGServiceKeys) + } } func TestGenerateRAGInstanceResources_MultiNode_CanonicalNotFirst(t *testing.T) { @@ -128,6 +178,7 @@ func TestGenerateRAGInstanceResources_MultiNode_CanonicalNotFirst(t *testing.T) ServiceID: "rag", ServiceType: "rag", Version: "latest", + Config: minimalRAGConfig(), }, DatabaseID: "storefront", DatabaseName: "storefront", @@ -145,9 +196,9 @@ func TestGenerateRAGInstanceResources_MultiNode_CanonicalNotFirst(t *testing.T) t.Fatalf("generateRAGInstanceResources() error = %v", err) } - // 3 nodes → canonical(n2) + per-node(n1) + per-node(n3) = 3 RO resources - if len(result.Resources) != 3 { - t.Fatalf("len(Resources) = %d, want 3", len(result.Resources)) + // 3 nodes → canonical(n2) + per-node(n1) + per-node(n3) + data dir + keys = 5 
resources. + if len(result.Resources) != 5 { + t.Fatalf("len(Resources) = %d, want 5", len(result.Resources)) } // Canonical (index 0) must be n2 with no CredentialSource @@ -165,7 +216,7 @@ func TestGenerateRAGInstanceResources_MultiNode_CanonicalNotFirst(t *testing.T) // Per-node resources must cover n1 and n3, not n2 canonicalID := ServiceUserRoleIdentifier("rag", ServiceUserRoleRO) perNodeNames := make(map[string]bool) - for i, rd := range result.Resources[1:] { + for i, rd := range result.Resources[1:3] { perNode, err := resource.ToResource[*ServiceUserRole](rd) if err != nil { t.Fatalf("ToResource per-node[%d]: %v", i, err) @@ -191,6 +242,7 @@ func TestGenerateServiceInstanceResources_RAGDispatch(t *testing.T) { ServiceID: "rag", ServiceType: "rag", Version: "latest", + Config: minimalRAGConfig(), }, DatabaseID: "db1", DatabaseName: "db1", diff --git a/server/internal/orchestrator/swarm/resources.go b/server/internal/orchestrator/swarm/resources.go index 3ad755e8..a049ff10 100644 --- a/server/internal/orchestrator/swarm/resources.go +++ b/server/internal/orchestrator/swarm/resources.go @@ -24,4 +24,5 @@ func RegisterResourceTypes(registry *resource.Registry) { resource.RegisterResourceType[*PostgRESTPreflightResource](registry, ResourceTypePostgRESTPreflightResource) resource.RegisterResourceType[*PostgRESTConfigResource](registry, ResourceTypePostgRESTConfig) resource.RegisterResourceType[*PostgRESTAuthenticatorResource](registry, ResourceTypePostgRESTAuthenticator) + resource.RegisterResourceType[*RAGServiceKeysResource](registry, ResourceTypeRAGServiceKeys) } diff --git a/server/internal/orchestrator/swarm/service_spec.go b/server/internal/orchestrator/swarm/service_spec.go index 9901620d..56b17ebe 100644 --- a/server/internal/orchestrator/swarm/service_spec.go +++ b/server/internal/orchestrator/swarm/service_spec.go @@ -15,6 +15,9 @@ import ( // mcpContainerUID is the UID of the MCP container user. 
const mcpContainerUID = 1001 +// ragContainerUID is the UID of the RAG server container user. +const ragContainerUID = 1001 + // postgrestContainerUID is the UID of the PostgREST container user. // See: https://github.com/PostgREST/postgrest/blob/main/Dockerfile (USER 1000) const postgrestContainerUID = 1000 @@ -54,6 +57,9 @@ type ServiceContainerSpecOptions struct { Port *int // DataPath is the host-side directory path for the bind mount DataPath string + // KeysPath is the host-side directory containing API key files. + // When non-empty, it is bind-mounted read-only into the container at /app/keys. + KeysPath string } // ServiceContainerSpec builds a Docker Swarm service spec for a service instance.