diff --git a/server/internal/orchestrator/swarm/orchestrator.go b/server/internal/orchestrator/swarm/orchestrator.go index ba2fa7d5..4e2105ff 100644 --- a/server/internal/orchestrator/swarm/orchestrator.go +++ b/server/internal/orchestrator/swarm/orchestrator.go @@ -704,7 +704,25 @@ func (o *Orchestrator) generateRAGInstanceResources(spec *database.ServiceInstan Keys: extractRAGAPIKeys(ragConfig), } - orchestratorResources = append(orchestratorResources, dataDir, keysResource) + // RAG config resource — generates pgedge-rag-server.yaml in the data directory. + var dbHost string + var dbPort int + if len(spec.DatabaseHosts) > 0 { + dbHost = spec.DatabaseHosts[0].Host + dbPort = spec.DatabaseHosts[0].Port + } + ragConfigRes := &RAGConfigResource{ + ServiceInstanceID: spec.ServiceInstanceID, + ServiceID: spec.ServiceSpec.ServiceID, + HostID: spec.HostID, + DirResourceID: dataDirID, + Config: ragConfig, + DatabaseName: spec.DatabaseName, + DatabaseHost: dbHost, + DatabasePort: dbPort, + } + + orchestratorResources = append(orchestratorResources, dataDir, keysResource, ragConfigRes) return o.buildServiceInstanceResources(spec, orchestratorResources) } diff --git a/server/internal/orchestrator/swarm/rag_config.go b/server/internal/orchestrator/swarm/rag_config.go new file mode 100644 index 00000000..cf5c9a96 --- /dev/null +++ b/server/internal/orchestrator/swarm/rag_config.go @@ -0,0 +1,249 @@ +package swarm + +import ( + "fmt" + "path" + + "github.com/goccy/go-yaml" + + "github.com/pgEdge/control-plane/server/internal/database" +) + +// ragYAMLConfig mirrors the pgedge-rag-server Config struct for YAML generation. +// Only the fields the control plane needs to set are included. +type ragYAMLConfig struct { + Server ragServerYAML `yaml:"server"` + Pipelines []ragPipelineYAML `yaml:"pipelines"` + Defaults *ragDefaultsYAML `yaml:"defaults,omitempty"` +} + +type ragServerYAML struct { + ListenAddress string `yaml:"listen_address"` + Port int `yaml:"port"` +} + +type ragPipelineYAML struct { + Name string `yaml:"name"` + Description string `yaml:"description,omitempty"` + Database ragDatabaseYAML `yaml:"database"` + Tables []ragTableYAML `yaml:"tables"` + EmbeddingLLM ragLLMYAML `yaml:"embedding_llm"` + RAGLLM ragLLMYAML `yaml:"rag_llm"` + APIKeys *ragAPIKeysYAML `yaml:"api_keys,omitempty"` + TokenBudget *int `yaml:"token_budget,omitempty"` + TopN *int `yaml:"top_n,omitempty"` + SystemPrompt string `yaml:"system_prompt,omitempty"` + Search *ragSearchYAML `yaml:"search,omitempty"` +} + +type ragDatabaseYAML struct { + Host string `yaml:"host"` + Port int `yaml:"port"` + Database string `yaml:"database"` + Username string `yaml:"username"` + Password string `yaml:"password"` + SSLMode string `yaml:"ssl_mode"` +} + +type ragTableYAML struct { + Table string `yaml:"table"` + TextColumn string `yaml:"text_column"` + VectorColumn string `yaml:"vector_column"` + IDColumn string `yaml:"id_column,omitempty"` +} + +type ragLLMYAML struct { + Provider string `yaml:"provider"` + Model string `yaml:"model"` + BaseURL string `yaml:"base_url,omitempty"` +} + +// ragAPIKeysYAML holds container-side file paths for each provider's API key. +type ragAPIKeysYAML struct { + Anthropic string `yaml:"anthropic,omitempty"` + OpenAI string `yaml:"openai,omitempty"` + Voyage string `yaml:"voyage,omitempty"` +} + +type ragSearchYAML struct { + HybridEnabled *bool `yaml:"hybrid_enabled,omitempty"` + VectorWeight *float64 `yaml:"vector_weight,omitempty"` +} + +type ragDefaultsYAML struct { + TokenBudget *int `yaml:"token_budget,omitempty"` + TopN *int `yaml:"top_n,omitempty"` +} + +// RAGConfigParams holds all inputs needed to generate pgedge-rag-server.yaml. +type RAGConfigParams struct { + Config *database.RAGServiceConfig + DatabaseName string + DatabaseHost string + DatabasePort int + Username string + Password string + // KeysDir is the container-side directory where API key files are mounted, + // e.g. "/app/keys". Key filenames follow the {pipeline}_{embedding|rag}.key + // convention produced by extractRAGAPIKeys. + KeysDir string +} + +// GenerateRAGConfig generates the pgedge-rag-server.yaml content from the +// given parameters. API key paths in the generated YAML reference files under +// KeysDir so the RAG server reads them from the bind-mounted keys directory. +func GenerateRAGConfig(params *RAGConfigParams) ([]byte, error) { + pipelines := make([]ragPipelineYAML, 0, len(params.Config.Pipelines)) + for _, p := range params.Config.Pipelines { + pl, err := buildRAGPipelineYAML(p, params) + if err != nil { + return nil, err + } + pipelines = append(pipelines, pl) + } + + var defaults *ragDefaultsYAML + if params.Config.Defaults != nil { + src := params.Config.Defaults + if src.TokenBudget != nil || src.TopN != nil { + defaults = &ragDefaultsYAML{ + TokenBudget: src.TokenBudget, + TopN: src.TopN, + } + } + } + + cfg := &ragYAMLConfig{ + Server: ragServerYAML{ + ListenAddress: "0.0.0.0", + Port: 8080, + }, + Pipelines: pipelines, + Defaults: defaults, + } + + data, err := yaml.Marshal(cfg) + if err != nil { + return nil, err + } + return data, nil +} + +func buildRAGPipelineYAML(p database.RAGPipeline, params *RAGConfigParams) (ragPipelineYAML, error) { + tables := make([]ragTableYAML, 0, len(p.Tables)) + for _, t := range p.Tables { + tbl := ragTableYAML{ + Table: t.Table, + TextColumn: t.TextColumn, + VectorColumn: t.VectorColumn, + } + if t.IDColumn != nil { + tbl.IDColumn = *t.IDColumn + } + tables = append(tables, tbl) + } + + embLLM := ragLLMYAML{ + Provider: p.EmbeddingLLM.Provider, + Model: p.EmbeddingLLM.Model, + } + if p.EmbeddingLLM.BaseURL != nil { + embLLM.BaseURL = *p.EmbeddingLLM.BaseURL + } + + ragLLM := ragLLMYAML{ + Provider: p.RAGLLM.Provider, + Model: p.RAGLLM.Model, + } + if p.RAGLLM.BaseURL != nil { + ragLLM.BaseURL = *p.RAGLLM.BaseURL + } + + apiKeys, err := buildRAGAPIKeysYAML(p, params.KeysDir) + if err != nil { + return ragPipelineYAML{}, err + } + + pipeline := ragPipelineYAML{ + Name: p.Name, + Database: ragDatabaseYAML{ + Host: params.DatabaseHost, + Port: params.DatabasePort, + Database: params.DatabaseName, + Username: params.Username, + Password: params.Password, + SSLMode: "prefer", + }, + Tables: tables, + EmbeddingLLM: embLLM, + RAGLLM: ragLLM, + APIKeys: apiKeys, + } + + if p.Description != nil { + pipeline.Description = *p.Description + } + pipeline.TokenBudget = p.TokenBudget + pipeline.TopN = p.TopN + if p.SystemPrompt != nil { + pipeline.SystemPrompt = *p.SystemPrompt + } + if p.Search != nil { + pipeline.Search = &ragSearchYAML{ + HybridEnabled: p.Search.HybridEnabled, + VectorWeight: p.Search.VectorWeight, + } + } + + return pipeline, nil +} + +// buildRAGAPIKeysYAML maps each LLM provider that requires a key to the +// corresponding bind-mounted key file path inside the container. +// Embedding key: {keysDir}/{pipeline}_embedding.key +// RAG key: {keysDir}/{pipeline}_rag.key +// If embedding and RAG use the same provider, the RAG key path takes precedence +// (both files contain the same value). Returns an error if both LLMs share a +// provider but were configured with different API keys. +func buildRAGAPIKeysYAML(p database.RAGPipeline, keysDir string) (*ragAPIKeysYAML, error) { + // Reject mismatched keys for the same provider — the RAG server has a + // single key slot per provider and cannot reconcile two different values. + if p.EmbeddingLLM.Provider == p.RAGLLM.Provider && + p.EmbeddingLLM.APIKey != nil && *p.EmbeddingLLM.APIKey != "" && + p.RAGLLM.APIKey != nil && *p.RAGLLM.APIKey != "" && + *p.EmbeddingLLM.APIKey != *p.RAGLLM.APIKey { + return nil, fmt.Errorf("pipeline %q: embedding_llm and rag_llm share provider %q but have different API keys", + p.Name, p.EmbeddingLLM.Provider) + } + + keys := &ragAPIKeysYAML{} + + // Embedding provider key + if p.EmbeddingLLM.APIKey != nil && *p.EmbeddingLLM.APIKey != "" { + keyPath := path.Join(keysDir, p.Name+"_embedding.key") + switch p.EmbeddingLLM.Provider { + case "anthropic": + keys.Anthropic = keyPath + case "openai": + keys.OpenAI = keyPath + case "voyage": + keys.Voyage = keyPath + } + } + + // RAG provider key (overwrites if same provider as embedding) + if p.RAGLLM.APIKey != nil && *p.RAGLLM.APIKey != "" { + keyPath := path.Join(keysDir, p.Name+"_rag.key") + switch p.RAGLLM.Provider { + case "anthropic": + keys.Anthropic = keyPath + case "openai": + keys.OpenAI = keyPath + } + } + + if keys.Anthropic == "" && keys.OpenAI == "" && keys.Voyage == "" { + return nil, nil + } + return keys, nil +} diff --git a/server/internal/orchestrator/swarm/rag_config_resource.go b/server/internal/orchestrator/swarm/rag_config_resource.go new file mode 100644 index 00000000..25ca3cff --- /dev/null +++ b/server/internal/orchestrator/swarm/rag_config_resource.go @@ -0,0 +1,157 @@ +package swarm + +import ( + "context" + "fmt" + "path/filepath" + + "github.com/samber/do" + "github.com/spf13/afero" + + "github.com/pgEdge/control-plane/server/internal/database" + "github.com/pgEdge/control-plane/server/internal/filesystem" + "github.com/pgEdge/control-plane/server/internal/resource" +) + +var _ resource.Resource = (*RAGConfigResource)(nil) + +const ResourceTypeRAGConfig resource.Type = "swarm.rag_config" + +// ragConfigFilename is the config file name expected by pgedge-rag-server. +const ragConfigFilename = "pgedge-rag-server.yaml" + +// ragKeysContainerDir is the container-side mount path for the keys directory. +const ragKeysContainerDir = "/app/keys" + +func RAGConfigResourceIdentifier(serviceInstanceID string) resource.Identifier { + return resource.Identifier{ + ID: serviceInstanceID, + Type: ResourceTypeRAGConfig, + } +} + +// RAGConfigResource manages the pgedge-rag-server.yaml config file on the +// host filesystem. The file is written to the service data directory +// (managed by a DirResource) which is bind-mounted into the container at +// /app/data. On every Create or Update the file is regenerated from the +// current RAGServiceConfig and database credentials. +type RAGConfigResource struct { + ServiceInstanceID string `json:"service_instance_id"` + ServiceID string `json:"service_id"` + HostID string `json:"host_id"` + DirResourceID string `json:"dir_resource_id"` + Config *database.RAGServiceConfig `json:"config"` + DatabaseName string `json:"database_name"` + DatabaseHost string `json:"database_host"` + DatabasePort int `json:"database_port"` +} + +func (r *RAGConfigResource) ResourceVersion() string { + return "1" +} + +func (r *RAGConfigResource) DiffIgnore() []string { + return nil +} + +func (r *RAGConfigResource) Identifier() resource.Identifier { + return RAGConfigResourceIdentifier(r.ServiceInstanceID) +} + +func (r *RAGConfigResource) Executor() resource.Executor { + return resource.HostExecutor(r.HostID) +} + +func (r *RAGConfigResource) Dependencies() []resource.Identifier { + return []resource.Identifier{ + filesystem.DirResourceIdentifier(r.DirResourceID), + ServiceUserRoleIdentifier(r.ServiceID, ServiceUserRoleRO), + RAGServiceKeysResourceIdentifier(r.ServiceInstanceID), + } +} + +func (r *RAGConfigResource) TypeDependencies() []resource.Type { + return nil +} + +func (r *RAGConfigResource) Refresh(ctx context.Context, rc *resource.Context) error { + fs, err := do.Invoke[afero.Fs](rc.Injector) + if err != nil { + return err + } + + dirPath, err := filesystem.DirResourceFullPath(rc, r.DirResourceID) + if err != nil { + return fmt.Errorf("failed to get service data dir path: %w", err) + } + + _, err = readResourceFile(fs, filepath.Join(dirPath, ragConfigFilename)) + if err != nil { + return fmt.Errorf("failed to read RAG config: %w", err) + } + + return nil +} + +func (r *RAGConfigResource) Create(ctx context.Context, rc *resource.Context) error { + fs, err := do.Invoke[afero.Fs](rc.Injector) + if err != nil { + return err + } + + dirPath, err := filesystem.DirResourceFullPath(rc, r.DirResourceID) + if err != nil { + return fmt.Errorf("failed to get service data dir path: %w", err) + } + + return r.writeConfigFile(fs, dirPath, rc) +} + +func (r *RAGConfigResource) Update(ctx context.Context, rc *resource.Context) error { + fs, err := do.Invoke[afero.Fs](rc.Injector) + if err != nil { + return err + } + + dirPath, err := filesystem.DirResourceFullPath(rc, r.DirResourceID) + if err != nil { + return fmt.Errorf("failed to get service data dir path: %w", err) + } + + return r.writeConfigFile(fs, dirPath, rc) +} + +func (r *RAGConfigResource) Delete(ctx context.Context, rc *resource.Context) error { + // Cleanup is handled by the parent DirResource deletion. + return nil +} + +func (r *RAGConfigResource) writeConfigFile(fs afero.Fs, dirPath string, rc *resource.Context) error { + userRole, err := resource.FromContext[*ServiceUserRole](rc, ServiceUserRoleIdentifier(r.ServiceID, ServiceUserRoleRO)) + if err != nil { + return fmt.Errorf("failed to get RAG service user role from state: %w", err) + } + + content, err := GenerateRAGConfig(&RAGConfigParams{ + Config: r.Config, + DatabaseName: r.DatabaseName, + DatabaseHost: r.DatabaseHost, + DatabasePort: r.DatabasePort, + Username: userRole.Username, + Password: userRole.Password, + KeysDir: ragKeysContainerDir, + }) + if err != nil { + return fmt.Errorf("failed to generate RAG config: %w", err) + } + + configPath := filepath.Join(dirPath, ragConfigFilename) + if err := afero.WriteFile(fs, configPath, content, 0o600); err != nil { + return fmt.Errorf("failed to write %s: %w", configPath, err) + } + if err := fs.Chown(configPath, ragContainerUID, ragContainerUID); err != nil { + return fmt.Errorf("failed to change ownership for %s: %w", configPath, err) + } + + return nil +} diff --git a/server/internal/orchestrator/swarm/rag_config_resource_test.go b/server/internal/orchestrator/swarm/rag_config_resource_test.go new file mode 100644 index 00000000..0a58f1b7 --- /dev/null +++ b/server/internal/orchestrator/swarm/rag_config_resource_test.go @@ -0,0 +1,76 @@ +package swarm + +import ( + "testing" + + "github.com/pgEdge/control-plane/server/internal/filesystem" + "github.com/pgEdge/control-plane/server/internal/resource" +) + +func TestRAGConfigResource_ResourceVersion(t *testing.T) { + r := &RAGConfigResource{} + if got := r.ResourceVersion(); got != "1" { + t.Errorf("ResourceVersion() = %q, want %q", got, "1") + } +} + +func TestRAGConfigResource_Identifier(t *testing.T) { + r := &RAGConfigResource{ServiceInstanceID: "storefront-rag-host1"} + id := r.Identifier() + if id.ID != "storefront-rag-host1" { + t.Errorf("Identifier().ID = %q, want %q", id.ID, "storefront-rag-host1") + } + if id.Type != ResourceTypeRAGConfig { + t.Errorf("Identifier().Type = %q, want %q", id.Type, ResourceTypeRAGConfig) + } +} + +func TestRAGConfigResourceIdentifier(t *testing.T) { + id := RAGConfigResourceIdentifier("my-instance") + if id.ID != "my-instance" { + t.Errorf("ID = %q, want %q", id.ID, "my-instance") + } + if id.Type != ResourceTypeRAGConfig { + t.Errorf("Type = %q, want %q", id.Type, ResourceTypeRAGConfig) + } +} + +func TestRAGConfigResource_Executor(t *testing.T) { + r := &RAGConfigResource{HostID: "host-1"} + exec := r.Executor() + if exec != resource.HostExecutor("host-1") { + t.Errorf("Executor() = %v, want HostExecutor(%q)", exec, "host-1") + } +} + +func TestRAGConfigResource_DiffIgnore(t *testing.T) { + r := &RAGConfigResource{} + ignored := r.DiffIgnore() + if len(ignored) != 0 { + t.Errorf("DiffIgnore() = %v, want empty", ignored) + } +} + +func TestRAGConfigResource_Dependencies(t *testing.T) { + r := &RAGConfigResource{ + ServiceInstanceID: "storefront-rag-host1", + ServiceID: "rag", + DirResourceID: "storefront-rag-host1-data", + } + deps := r.Dependencies() + + if len(deps) != 3 { + t.Fatalf("Dependencies() len = %d, want 3", len(deps)) + } + + wantDeps := []resource.Identifier{ + filesystem.DirResourceIdentifier("storefront-rag-host1-data"), + ServiceUserRoleIdentifier("rag", ServiceUserRoleRO), + RAGServiceKeysResourceIdentifier("storefront-rag-host1"), + } + for i, want := range wantDeps { + if deps[i] != want { + t.Errorf("Dependencies()[%d] = %v, want %v", i, deps[i], want) + } + } +} diff --git a/server/internal/orchestrator/swarm/rag_config_test.go b/server/internal/orchestrator/swarm/rag_config_test.go new file mode 100644 index 00000000..33e4f0dd --- /dev/null +++ b/server/internal/orchestrator/swarm/rag_config_test.go @@ -0,0 +1,551 @@ +package swarm + +import ( + "testing" + + "github.com/goccy/go-yaml" + + "github.com/pgEdge/control-plane/server/internal/database" +) + +// parseRAGYAML unmarshals GenerateRAGConfig output into ragYAMLConfig for assertion. +func parseRAGYAML(t *testing.T, data []byte) *ragYAMLConfig { + t.Helper() + var cfg ragYAMLConfig + if err := yaml.Unmarshal(data, &cfg); err != nil { + t.Fatalf("failed to unmarshal YAML: %v\nYAML:\n%s", err, string(data)) + } + return &cfg +} + +func minimalRAGParams() *RAGConfigParams { + apiKey := "sk-ant-test" + embedKey := "sk-openai-test" + return &RAGConfigParams{ + Config: &database.RAGServiceConfig{ + Pipelines: []database.RAGPipeline{ + { + Name: "default", + Tables: []database.RAGPipelineTable{ + {Table: "docs", TextColumn: "content", VectorColumn: "embedding"}, + }, + EmbeddingLLM: database.RAGPipelineLLMConfig{ + Provider: "openai", + Model: "text-embedding-3-small", + APIKey: &embedKey, + }, + RAGLLM: database.RAGPipelineLLMConfig{ + Provider: "anthropic", + Model: "claude-sonnet-4-5", + APIKey: &apiKey, + }, + }, + }, + }, + DatabaseName: "mydb", + DatabaseHost: "pg-host", + DatabasePort: 5432, + Username: "svc_rag", + Password: "secret", + KeysDir: "/app/keys", + } +} + +func TestGenerateRAGConfig_ServerDefaults(t *testing.T) { + data, err := GenerateRAGConfig(minimalRAGParams()) + if err != nil { + t.Fatalf("GenerateRAGConfig() error = %v", err) + } + + cfg := parseRAGYAML(t, data) + + if cfg.Server.ListenAddress != "0.0.0.0" { + t.Errorf("server.listen_address = %q, want %q", cfg.Server.ListenAddress, "0.0.0.0") + } + if cfg.Server.Port != 8080 { + t.Errorf("server.port = %d, want 8080", cfg.Server.Port) + } +} + +func TestGenerateRAGConfig_DatabaseConnection(t *testing.T) { + params := minimalRAGParams() + params.DatabaseHost = "pg-primary.internal" + params.DatabasePort = 5433 + params.DatabaseName = "storefront" + params.Username = "svc_storefront_rag" + params.Password = "supersecret" + + data, err := GenerateRAGConfig(params) + if err != nil { + t.Fatalf("GenerateRAGConfig() error = %v", err) + } + + cfg := parseRAGYAML(t, data) + + if len(cfg.Pipelines) != 1 { + t.Fatalf("pipelines len = %d, want 1", len(cfg.Pipelines)) + } + db := cfg.Pipelines[0].Database + + if db.Host != "pg-primary.internal" { + t.Errorf("database.host = %q, want %q", db.Host, "pg-primary.internal") + } + if db.Port != 5433 { + t.Errorf("database.port = %d, want 5433", db.Port) + } + if db.Database != "storefront" { + t.Errorf("database.database = %q, want %q", db.Database, "storefront") + } + if db.Username != "svc_storefront_rag" { + t.Errorf("database.username = %q, want %q", db.Username, "svc_storefront_rag") + } + if db.Password != "supersecret" { + t.Errorf("database.password = %q, want %q", db.Password, "supersecret") + } + if db.SSLMode != "prefer" { + t.Errorf("database.ssl_mode = %q, want %q", db.SSLMode, "prefer") + } +} + +func TestGenerateRAGConfig_APIKeyPaths_DifferentProviders(t *testing.T) { + // embedding = openai, rag = anthropic → separate api_keys paths + data, err := GenerateRAGConfig(minimalRAGParams()) + if err != nil { + t.Fatalf("GenerateRAGConfig() error = %v", err) + } + + cfg := parseRAGYAML(t, data) + keys := cfg.Pipelines[0].APIKeys + + if keys == nil { + t.Fatal("api_keys should be present") + } + if keys.OpenAI != "/app/keys/default_embedding.key" { + t.Errorf("api_keys.openai = %q, want %q", keys.OpenAI, "/app/keys/default_embedding.key") + } + if keys.Anthropic != "/app/keys/default_rag.key" { + t.Errorf("api_keys.anthropic = %q, want %q", keys.Anthropic, "/app/keys/default_rag.key") + } + if keys.Voyage != "" { + t.Errorf("api_keys.voyage should be empty, got %q", keys.Voyage) + } +} + +func TestGenerateRAGConfig_APIKeyPaths_VoyageEmbedding(t *testing.T) { + voyageKey := "pa-voyage-key" + antKey := "sk-ant-test" + params := &RAGConfigParams{ + Config: &database.RAGServiceConfig{ + Pipelines: []database.RAGPipeline{ + { + Name: "search", + Tables: []database.RAGPipelineTable{{Table: "docs", TextColumn: "body", VectorColumn: "vec"}}, + EmbeddingLLM: database.RAGPipelineLLMConfig{ + Provider: "voyage", + Model: "voyage-3", + APIKey: &voyageKey, + }, + RAGLLM: database.RAGPipelineLLMConfig{ + Provider: "anthropic", + Model: "claude-sonnet-4-5", + APIKey: &antKey, + }, + }, + }, + }, + DatabaseName: "mydb", DatabaseHost: "host", DatabasePort: 5432, + Username: "u", Password: "p", KeysDir: "/app/keys", + } + + data, err := GenerateRAGConfig(params) + if err != nil { + t.Fatalf("GenerateRAGConfig() error = %v", err) + } + + cfg := parseRAGYAML(t, data) + keys := cfg.Pipelines[0].APIKeys + + if keys == nil { + t.Fatal("api_keys should be present") + } + if keys.Voyage != "/app/keys/search_embedding.key" { + t.Errorf("api_keys.voyage = %q, want %q", keys.Voyage, "/app/keys/search_embedding.key") + } + if keys.Anthropic != "/app/keys/search_rag.key" { + t.Errorf("api_keys.anthropic = %q, want %q", keys.Anthropic, "/app/keys/search_rag.key") + } +} + +func TestGenerateRAGConfig_APIKeyPaths_SameProvider_RAGTakesPrecedence(t *testing.T) { + // Both embedding and rag use openai → rag key path overwrites embedding key path. + key := "sk-openai-key" + params := &RAGConfigParams{ + Config: &database.RAGServiceConfig{ + Pipelines: []database.RAGPipeline{ + { + Name: "default", + Tables: []database.RAGPipelineTable{{Table: "t", TextColumn: "c", VectorColumn: "v"}}, + EmbeddingLLM: database.RAGPipelineLLMConfig{ + Provider: "openai", Model: "text-embedding-3-small", APIKey: &key, + }, + RAGLLM: database.RAGPipelineLLMConfig{ + Provider: "openai", Model: "gpt-4o", APIKey: &key, + }, + }, + }, + }, + DatabaseName: "mydb", DatabaseHost: "host", DatabasePort: 5432, + Username: "u", Password: "p", KeysDir: "/app/keys", + } + + data, err := GenerateRAGConfig(params) + if err != nil { + t.Fatalf("GenerateRAGConfig() error = %v", err) + } + + cfg := parseRAGYAML(t, data) + keys := cfg.Pipelines[0].APIKeys + + if keys == nil { + t.Fatal("api_keys should be present") + } + // rag key takes precedence when provider is the same + if keys.OpenAI != "/app/keys/default_rag.key" { + t.Errorf("api_keys.openai = %q, want rag key path %q", keys.OpenAI, "/app/keys/default_rag.key") + } +} + +func TestGenerateRAGConfig_OllamaNoAPIKey(t *testing.T) { + // ollama providers have no API key — api_keys section must be absent. + params := &RAGConfigParams{ + Config: &database.RAGServiceConfig{ + Pipelines: []database.RAGPipeline{ + { + Name: "default", + Tables: []database.RAGPipelineTable{{Table: "t", TextColumn: "c", VectorColumn: "v"}}, + EmbeddingLLM: database.RAGPipelineLLMConfig{ + Provider: "ollama", Model: "nomic-embed-text", + }, + RAGLLM: database.RAGPipelineLLMConfig{ + Provider: "ollama", Model: "llama3", + }, + }, + }, + }, + DatabaseName: "mydb", DatabaseHost: "host", DatabasePort: 5432, + Username: "u", Password: "p", KeysDir: "/app/keys", + } + + data, err := GenerateRAGConfig(params) + if err != nil { + t.Fatalf("GenerateRAGConfig() error = %v", err) + } + + cfg := parseRAGYAML(t, data) + + if cfg.Pipelines[0].APIKeys != nil { + t.Errorf("api_keys should be absent for ollama providers, got %+v", cfg.Pipelines[0].APIKeys) + } +} + +func TestGenerateRAGConfig_LLMConfig(t *testing.T) { + baseURL := "http://ollama.internal:11434" + params := &RAGConfigParams{ + Config: &database.RAGServiceConfig{ + Pipelines: []database.RAGPipeline{ + { + Name: "default", + Tables: []database.RAGPipelineTable{{Table: "t", TextColumn: "c", VectorColumn: "v"}}, + EmbeddingLLM: database.RAGPipelineLLMConfig{ + Provider: "ollama", Model: "nomic-embed-text", BaseURL: &baseURL, + }, + RAGLLM: database.RAGPipelineLLMConfig{ + Provider: "ollama", Model: "llama3", + }, + }, + }, + }, + DatabaseName: "mydb", DatabaseHost: "host", DatabasePort: 5432, + Username: "u", Password: "p", KeysDir: "/app/keys", + } + + data, err := GenerateRAGConfig(params) + if err != nil { + t.Fatalf("GenerateRAGConfig() error = %v", err) + } + + cfg := parseRAGYAML(t, data) + p := cfg.Pipelines[0] + + if p.EmbeddingLLM.Provider != "ollama" { + t.Errorf("embedding_llm.provider = %q, want %q", p.EmbeddingLLM.Provider, "ollama") + } + if p.EmbeddingLLM.Model != "nomic-embed-text" { + t.Errorf("embedding_llm.model = %q, want %q", p.EmbeddingLLM.Model, "nomic-embed-text") + } + if p.EmbeddingLLM.BaseURL != baseURL { + t.Errorf("embedding_llm.base_url = %q, want %q", p.EmbeddingLLM.BaseURL, baseURL) + } + if p.RAGLLM.Provider != "ollama" { + t.Errorf("rag_llm.provider = %q, want %q", p.RAGLLM.Provider, "ollama") + } + if p.RAGLLM.Model != "llama3" { + t.Errorf("rag_llm.model = %q, want %q", p.RAGLLM.Model, "llama3") + } +} + +func TestGenerateRAGConfig_Tables(t *testing.T) { + idCol := "id" + antKey := "sk-ant" + params := &RAGConfigParams{ + Config: &database.RAGServiceConfig{ + Pipelines: []database.RAGPipeline{ + { + Name: "default", + Tables: []database.RAGPipelineTable{ + {Table: "docs", TextColumn: "content", VectorColumn: "embedding", IDColumn: &idCol}, + {Table: "notes", TextColumn: "body", VectorColumn: "vec"}, + }, + EmbeddingLLM: database.RAGPipelineLLMConfig{Provider: "anthropic", Model: "m", APIKey: &antKey}, + RAGLLM: database.RAGPipelineLLMConfig{Provider: "anthropic", Model: "m", APIKey: &antKey}, + }, + }, + }, + DatabaseName: "mydb", DatabaseHost: "host", DatabasePort: 5432, + Username: "u", Password: "p", KeysDir: "/app/keys", + } + + data, err := GenerateRAGConfig(params) + if err != nil { + t.Fatalf("GenerateRAGConfig() error = %v", err) + } + + cfg := parseRAGYAML(t, data) + tables := cfg.Pipelines[0].Tables + + if len(tables) != 2 { + t.Fatalf("tables len = %d, want 2", len(tables)) + } + if tables[0].Table != "docs" { + t.Errorf("tables[0].table = %q, want %q", tables[0].Table, "docs") + } + if tables[0].IDColumn != "id" { + t.Errorf("tables[0].id_column = %q, want %q", tables[0].IDColumn, "id") + } + if tables[1].Table != "notes" { + t.Errorf("tables[1].table = %q, want %q", tables[1].Table, "notes") + } + if tables[1].IDColumn != "" { + t.Errorf("tables[1].id_column should be empty (omitted), got %q", tables[1].IDColumn) + } +} + +func TestGenerateRAGConfig_OptionalPipelineFields(t *testing.T) { + antKey := "sk-ant" + desc := "My pipeline" + budget := 500 + topN := 5 + prompt := "You are a helpful assistant." + hybridEnabled := false + weight := 0.7 + + params := &RAGConfigParams{ + Config: &database.RAGServiceConfig{ + Pipelines: []database.RAGPipeline{ + { + Name: "default", + Description: &desc, + Tables: []database.RAGPipelineTable{{Table: "t", TextColumn: "c", VectorColumn: "v"}}, + EmbeddingLLM: database.RAGPipelineLLMConfig{Provider: "anthropic", Model: "m", APIKey: &antKey}, + RAGLLM: database.RAGPipelineLLMConfig{Provider: "anthropic", Model: "m", APIKey: &antKey}, + TokenBudget: &budget, + TopN: &topN, + SystemPrompt: &prompt, + Search: &database.RAGPipelineSearch{ + HybridEnabled: &hybridEnabled, + VectorWeight: &weight, + }, + }, + }, + }, + DatabaseName: "mydb", DatabaseHost: "host", DatabasePort: 5432, + Username: "u", Password: "p", KeysDir: "/app/keys", + } + + data, err := GenerateRAGConfig(params) + if err != nil { + t.Fatalf("GenerateRAGConfig() error = %v", err) + } + + cfg := parseRAGYAML(t, data) + p := cfg.Pipelines[0] + + if p.Description != desc { + t.Errorf("description = %q, want %q", p.Description, desc) + } + if p.TokenBudget == nil || *p.TokenBudget != 500 { + t.Errorf("token_budget = %v, want 500", p.TokenBudget) + } + if p.TopN == nil || *p.TopN != 5 { + t.Errorf("top_n = %v, want 5", p.TopN) + } + if p.SystemPrompt != prompt { + t.Errorf("system_prompt = %q, want %q", p.SystemPrompt, prompt) + } + if p.Search == nil { + t.Fatal("search should be present") + } + if p.Search.HybridEnabled == nil || *p.Search.HybridEnabled != false { + t.Errorf("search.hybrid_enabled = %v, want false", p.Search.HybridEnabled) + } + if p.Search.VectorWeight == nil || *p.Search.VectorWeight != 0.7 { + t.Errorf("search.vector_weight = %v, want 0.7", p.Search.VectorWeight) + } +} + +func TestGenerateRAGConfig_OptionalPipelineFieldsOmitted(t *testing.T) { + // When optional fields are not set, they must be absent from the YAML. + data, err := GenerateRAGConfig(minimalRAGParams()) + if err != nil { + t.Fatalf("GenerateRAGConfig() error = %v", err) + } + + cfg := parseRAGYAML(t, data) + p := cfg.Pipelines[0] + + if p.Description != "" { + t.Errorf("description should be empty (omitted), got %q", p.Description) + } + if p.TokenBudget != nil { + t.Errorf("token_budget should be nil (omitted), got %v", *p.TokenBudget) + } + if p.TopN != nil { + t.Errorf("top_n should be nil (omitted), got %v", *p.TopN) + } + if p.SystemPrompt != "" { + t.Errorf("system_prompt should be empty (omitted), got %q", p.SystemPrompt) + } + if p.Search != nil { + t.Errorf("search should be nil (omitted), got %+v", p.Search) + } +} + +func TestGenerateRAGConfig_MultiplePipelines(t *testing.T) { + key1 := "sk-ant-1" + key2 := "sk-openai-2" + params := &RAGConfigParams{ + Config: &database.RAGServiceConfig{ + Pipelines: []database.RAGPipeline{ + { + Name: "pipeline-a", + Tables: []database.RAGPipelineTable{{Table: "t1", TextColumn: "c", VectorColumn: "v"}}, + EmbeddingLLM: database.RAGPipelineLLMConfig{Provider: "anthropic", Model: "m1", APIKey: &key1}, + RAGLLM: database.RAGPipelineLLMConfig{Provider: "anthropic", Model: "m1", APIKey: &key1}, + }, + { + Name: "pipeline-b", + Tables: []database.RAGPipelineTable{{Table: "t2", TextColumn: "c", VectorColumn: "v"}}, + EmbeddingLLM: database.RAGPipelineLLMConfig{Provider: "openai", Model: "m2", APIKey: &key2}, + RAGLLM: database.RAGPipelineLLMConfig{Provider: "openai", Model: "m2", APIKey: &key2}, + }, + }, + }, + DatabaseName: "mydb", DatabaseHost: "host", DatabasePort: 5432, + Username: "u", Password: "p", KeysDir: "/app/keys", + } + + data, err := GenerateRAGConfig(params) + if err != nil { + t.Fatalf("GenerateRAGConfig() error = %v", err) + } + + cfg := parseRAGYAML(t, data) + + if len(cfg.Pipelines) != 2 { + t.Fatalf("pipelines len = %d, want 2", len(cfg.Pipelines)) + } + if cfg.Pipelines[0].Name != "pipeline-a" { + t.Errorf("pipelines[0].name = %q, want %q", cfg.Pipelines[0].Name, "pipeline-a") + } + if cfg.Pipelines[1].Name != "pipeline-b" { + t.Errorf("pipelines[1].name = %q, want %q", cfg.Pipelines[1].Name, "pipeline-b") + } + // Each pipeline's api_keys paths use their own name prefix + if cfg.Pipelines[0].APIKeys.Anthropic != "/app/keys/pipeline-a_rag.key" { + t.Errorf("pipelines[0].api_keys.anthropic = %q, want %q", + cfg.Pipelines[0].APIKeys.Anthropic, "/app/keys/pipeline-a_rag.key") + } + if cfg.Pipelines[1].APIKeys.OpenAI != "/app/keys/pipeline-b_rag.key" { + t.Errorf("pipelines[1].api_keys.openai = %q, want %q", + cfg.Pipelines[1].APIKeys.OpenAI, "/app/keys/pipeline-b_rag.key") + } +} + +func TestGenerateRAGConfig_DefaultsSection(t *testing.T) { + budget := 2000 + topN := 20 + params := minimalRAGParams() + params.Config.Defaults = &database.RAGDefaults{ + TokenBudget: &budget, + TopN: &topN, + } + + data, err := GenerateRAGConfig(params) + if err != nil { + t.Fatalf("GenerateRAGConfig() error = %v", err) + } + + cfg := parseRAGYAML(t, data) + + if cfg.Defaults == nil { + t.Fatal("defaults section should be present when configured") + } + if cfg.Defaults.TokenBudget == nil || *cfg.Defaults.TokenBudget != 2000 { + t.Errorf("defaults.token_budget = %v, want 2000", cfg.Defaults.TokenBudget) + } + if cfg.Defaults.TopN == nil || *cfg.Defaults.TopN != 20 { + t.Errorf("defaults.top_n = %v, want 20", cfg.Defaults.TopN) + } +} + +func TestGenerateRAGConfig_DefaultsAbsent(t *testing.T) { + // No Defaults set — defaults section must be omitted. + data, err := GenerateRAGConfig(minimalRAGParams()) + if err != nil { + t.Fatalf("GenerateRAGConfig() error = %v", err) + } + + cfg := parseRAGYAML(t, data) + + if cfg.Defaults != nil { + t.Errorf("defaults section should be absent when not configured, got %+v", cfg.Defaults) + } +} + +func TestGenerateRAGConfig_SameProviderDifferentKeys_ReturnsError(t *testing.T) { + key1 := "sk-openai-embed" + key2 := "sk-openai-rag-different" + params := &RAGConfigParams{ + Config: &database.RAGServiceConfig{ + Pipelines: []database.RAGPipeline{ + { + Name: "default", + Tables: []database.RAGPipelineTable{{Table: "t", TextColumn: "c", VectorColumn: "v"}}, + EmbeddingLLM: database.RAGPipelineLLMConfig{ + Provider: "openai", Model: "text-embedding-3-small", APIKey: &key1, + }, + RAGLLM: database.RAGPipelineLLMConfig{ + Provider: "openai", Model: "gpt-4o", APIKey: &key2, + }, + }, + }, + }, + DatabaseName: "mydb", DatabaseHost: "host", DatabasePort: 5432, + Username: "u", Password: "p", KeysDir: "/app/keys", + } + + _, err := GenerateRAGConfig(params) + if err == nil { + t.Fatal("expected error for same-provider mismatched API keys, got nil") + } +} diff --git a/server/internal/orchestrator/swarm/rag_service_keys_resource.go b/server/internal/orchestrator/swarm/rag_service_keys_resource.go index 15108323..fea28f62 100644 --- a/server/internal/orchestrator/swarm/rag_service_keys_resource.go +++ b/server/internal/orchestrator/swarm/rag_service_keys_resource.go @@ -156,6 +156,7 @@ func (r *RAGServiceKeysResource) Update(ctx context.Context, rc *resource.Contex } func (r *RAGServiceKeysResource) Delete(ctx context.Context, rc *resource.Context) error { + fs, err := do.Invoke[afero.Fs](rc.Injector) if err != nil { return err diff --git a/server/internal/orchestrator/swarm/rag_service_user_role_test.go b/server/internal/orchestrator/swarm/rag_service_user_role_test.go index 681cadf6..e7cd5a41 100644 --- a/server/internal/orchestrator/swarm/rag_service_user_role_test.go +++ b/server/internal/orchestrator/swarm/rag_service_user_role_test.go @@ -73,9 +73,9 @@ func TestGenerateRAGInstanceResources_ResourceList(t *testing.T) { result.ServiceInstance.State, database.ServiceInstanceStateCreating) } - // Single node: canonical RO ServiceUserRole + data DirResource + RAGServiceKeysResource. - if len(result.Resources) != 3 { - t.Fatalf("len(Resources) = %d, want 3", len(result.Resources)) + // Single node: canonical RO ServiceUserRole + DirResource + RAGServiceKeysResource + RAGConfigResource. + if len(result.Resources) != 4 { + t.Fatalf("len(Resources) = %d, want 4", len(result.Resources)) } if result.Resources[0].Identifier.Type != ResourceTypeServiceUserRole { t.Errorf("Resources[0].Identifier.Type = %q, want %q", @@ -93,6 +93,10 @@ func TestGenerateRAGInstanceResources_ResourceList(t *testing.T) { t.Errorf("Resources[2].Identifier.Type = %q, want %q", result.Resources[2].Identifier.Type, ResourceTypeRAGServiceKeys) } + if result.Resources[3].Identifier.Type != ResourceTypeRAGConfig { + t.Errorf("Resources[3].Identifier.Type = %q, want %q", + result.Resources[3].Identifier.Type, ResourceTypeRAGConfig) + } } func TestGenerateRAGInstanceResources_MultiNode(t *testing.T) { @@ -121,9 +125,9 @@ func TestGenerateRAGInstanceResources_MultiNode(t *testing.T) { t.Fatalf("generateRAGInstanceResources() error = %v", err) } - // 3 nodes → canonical(n1) + per-node(n2) + per-node(n3) + data dir + keys = 5 resources. - if len(result.Resources) != 5 { - t.Fatalf("len(Resources) = %d, want 5", len(result.Resources)) + // 3 nodes → canonical(n1) + per-node(n2) + per-node(n3) + dir + keys + config = 6 resources. + if len(result.Resources) != 6 { + t.Fatalf("len(Resources) = %d, want 6", len(result.Resources)) } // First three must be ServiceUserRole resources. for i := 0; i < 3; i++ { @@ -159,7 +163,7 @@ func TestGenerateRAGInstanceResources_MultiNode(t *testing.T) { } } - // Data dir and keys resource are appended last. + // Data dir, keys, and config resource are appended last. if result.Resources[3].Identifier.Type != filesystem.ResourceTypeDir { t.Errorf("Resources[3].Identifier.Type = %q, want %q", result.Resources[3].Identifier.Type, filesystem.ResourceTypeDir) @@ -168,6 +172,10 @@ func TestGenerateRAGInstanceResources_MultiNode(t *testing.T) { t.Errorf("Resources[4].Identifier.Type = %q, want %q", result.Resources[4].Identifier.Type, ResourceTypeRAGServiceKeys) } + if result.Resources[5].Identifier.Type != ResourceTypeRAGConfig { + t.Errorf("Resources[5].Identifier.Type = %q, want %q", + result.Resources[5].Identifier.Type, ResourceTypeRAGConfig) + } } func TestGenerateRAGInstanceResources_MultiNode_CanonicalNotFirst(t *testing.T) { @@ -196,9 +204,9 @@ func TestGenerateRAGInstanceResources_MultiNode_CanonicalNotFirst(t *testing.T) t.Fatalf("generateRAGInstanceResources() error = %v", err) } - // 3 nodes → canonical(n2) + per-node(n1) + per-node(n3) + data dir + keys = 5 resources. - if len(result.Resources) != 5 { - t.Fatalf("len(Resources) = %d, want 5", len(result.Resources)) + // 3 nodes → canonical(n2) + per-node(n1) + per-node(n3) + dir + keys + config = 6 resources. + if len(result.Resources) != 6 { + t.Fatalf("len(Resources) = %d, want 6", len(result.Resources)) } // Canonical (index 0) must be n2 with no CredentialSource diff --git a/server/internal/orchestrator/swarm/resources.go b/server/internal/orchestrator/swarm/resources.go index a049ff10..3bbeccbc 100644 --- a/server/internal/orchestrator/swarm/resources.go +++ b/server/internal/orchestrator/swarm/resources.go @@ -25,4 +25,5 @@ func RegisterResourceTypes(registry *resource.Registry) { resource.RegisterResourceType[*PostgRESTConfigResource](registry, ResourceTypePostgRESTConfig) resource.RegisterResourceType[*PostgRESTAuthenticatorResource](registry, ResourceTypePostgRESTAuthenticator) resource.RegisterResourceType[*RAGServiceKeysResource](registry, ResourceTypeRAGServiceKeys) + resource.RegisterResourceType[*RAGConfigResource](registry, ResourceTypeRAGConfig) }