diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 9cffb243b..c7ab0702e 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -1029,6 +1029,7 @@ enum DoctorCommands { Check, } +#[allow(clippy::large_enum_variant)] #[derive(Subcommand, Debug)] enum SandboxCommands { /// Create a sandbox. @@ -1087,9 +1088,13 @@ enum SandboxCommands { /// Target a driver-specific GPU device. Docker and Podman use CDI device IDs /// (for example "nvidia.com/gpu=0"); VM uses a PCI BDF or index. /// Only valid with --gpu. When omitted with --gpu, the driver uses its default GPU selection. - #[arg(long, requires = "gpu")] + #[arg(long, requires = "gpu", conflicts_with = "gpu_count")] gpu_device: Option, + /// Request a specific number of GPUs. Mutually exclusive with --gpu-device. + #[arg(long, value_parser = clap::value_parser!(u32).range(1..), conflicts_with = "gpu_device")] + gpu_count: Option, + /// Provider names to attach to this sandbox. #[arg(long = "provider")] providers: Vec, @@ -2365,6 +2370,7 @@ async fn main() -> Result<()> { editor, gpu, gpu_device, + gpu_count, providers, policy, forward, @@ -2431,6 +2437,7 @@ async fn main() -> Result<()> { keep, gpu, gpu_device.as_deref(), + gpu_count, editor, &providers, policy.as_deref(), @@ -3641,6 +3648,52 @@ mod tests { } } + #[test] + fn sandbox_create_gpu_count_parses_without_gpu_flag() { + let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu-count", "2"]) + .expect("sandbox create --gpu-count should parse"); + + if let Some(Commands::Sandbox { + command: Some(SandboxCommands::Create { gpu, gpu_count, .. }), + .. 
+ }) = cli.command + { + assert!(!gpu); + assert_eq!(gpu_count, Some(2)); + } else { + panic!("expected SandboxCommands::Create"); + } + } + + #[test] + fn sandbox_create_gpu_count_rejects_zero() { + let result = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu-count", "0"]); + + assert!( + result.is_err(), + "sandbox create --gpu-count 0 should be rejected" + ); + } + + #[test] + fn sandbox_create_gpu_count_conflicts_with_gpu_device() { + let result = Cli::try_parse_from([ + "openshell", + "sandbox", + "create", + "--gpu", + "--gpu-device", + "0", + "--gpu-count", + "2", + ]); + + assert!( + result.is_err(), + "sandbox create should reject --gpu-count with --gpu-device" + ); + } + #[test] fn service_expose_accepts_positional_target_port_and_service() { let cli = Cli::try_parse_from([ diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 3205b8f68..e89b258dc 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -32,11 +32,11 @@ use openshell_core::proto::{ GetClusterInferenceRequest, GetDraftHistoryRequest, GetDraftPolicyRequest, GetGatewayConfigRequest, GetProviderProfileRequest, GetProviderRequest, GetSandboxConfigRequest, GetSandboxLogsRequest, GetSandboxPolicyStatusRequest, - GetSandboxRequest, GetServiceRequest, HealthRequest, ImportProviderProfilesRequest, + GetSandboxRequest, GetServiceRequest, GpuSpec, HealthRequest, ImportProviderProfilesRequest, LintProviderProfilesRequest, ListProviderProfilesRequest, ListProvidersRequest, ListSandboxPoliciesRequest, ListSandboxProvidersRequest, ListSandboxesRequest, - ListServicesRequest, PolicySource, PolicyStatus, Provider, ProviderProfile, - ProviderProfileDiagnostic, ProviderProfileImportItem, RejectDraftChunkRequest, + ListServicesRequest, PlacementRequirements, PolicySource, PolicyStatus, Provider, + ProviderProfile, ProviderProfileDiagnostic, ProviderProfileImportItem, RejectDraftChunkRequest, RevokeSshSessionRequest, Sandbox, 
SandboxPhase, SandboxPolicy, SandboxSpec, SandboxTemplate, ServiceEndpointResponse, SetClusterInferenceRequest, SettingScope, SettingValue, TcpForwardFrame, TcpForwardInit, TcpRelayTarget, UpdateConfigRequest, UpdateProviderRequest, @@ -1468,6 +1468,7 @@ pub async fn sandbox_create( keep: bool, gpu: bool, gpu_device: Option<&str>, + gpu_count: Option, editor: Option, providers: &[String], policy: Option<&str>, @@ -1518,7 +1519,8 @@ pub async fn sandbox_create( } None => None, }; - let requested_gpu = gpu || image.as_deref().is_some_and(image_requests_gpu); + let requested_gpu = + gpu || gpu_count.is_some() || image.as_deref().is_some_and(image_requests_gpu); let inferred_types: Vec = inferred_provider_type(command).into_iter().collect(); let configured_providers = ensure_required_providers( @@ -1538,8 +1540,7 @@ pub async fn sandbox_create( let request = CreateSandboxRequest { spec: Some(SandboxSpec { - gpu: requested_gpu, - gpu_device: gpu_device.unwrap_or_default().to_string(), + placement: placement_requirements_from_cli(requested_gpu, gpu_device, gpu_count), policy, providers: configured_providers, template, @@ -1971,6 +1972,27 @@ pub async fn sandbox_create( } } +fn placement_requirements_from_cli( + requested_gpu: bool, + gpu_device: Option<&str>, + gpu_count: Option, +) -> Option { + let requested_gpu = requested_gpu || gpu_count.is_some(); + requested_gpu.then(|| PlacementRequirements { + gpu: Some(GpuSpec { + device_id: if gpu_count.is_none() { + gpu_device + .filter(|device_id| !device_id.is_empty()) + .map(|device_id| vec![device_id.to_string()]) + .unwrap_or_default() + } else { + Vec::new() + }, + count: gpu_count, + }), + }) +} + /// Resolved source for the `--from` flag on `sandbox create`. 
#[derive(Debug)] enum ResolvedSource { @@ -6017,9 +6039,9 @@ mod tests { gateway_env_override_warning, gateway_select_with, gateway_type_label, git_sync_files, http_health_check, image_requests_gpu, import_local_package_mtls_bundle, inferred_provider_type, package_managed_tls_dirs, parse_cli_setting_value, - parse_credential_pairs, plaintext_gateway_is_remote, provisioning_timeout_message, - ready_false_condition_message, resolve_from, sandbox_should_persist, - service_expose_status_error, service_url_for_gateway, + parse_credential_pairs, placement_requirements_from_cli, plaintext_gateway_is_remote, + provisioning_timeout_message, ready_false_condition_message, resolve_from, + sandbox_should_persist, service_expose_status_error, service_url_for_gateway, }; use crate::TEST_ENV_LOCK; use hyper::StatusCode; @@ -6296,6 +6318,47 @@ mod tests { } } + #[test] + fn image_requests_gpu_detects_known_community_gpu_name() { + assert!(image_requests_gpu("nvidia-gpu")); + assert!(!image_requests_gpu("base")); + } + + #[test] + fn placement_requirements_from_cli_uses_presence_with_empty_device_ids_for_default_gpu() { + let request = placement_requirements_from_cli(true, None, None) + .expect("placement requirements should be present"); + let gpu = request.gpu.expect("gpu request should be present"); + + assert!(gpu.device_id.is_empty()); + assert_eq!(gpu.count, None); + } + + #[test] + fn placement_requirements_from_cli_maps_gpu_device_to_one_device_id() { + let request = placement_requirements_from_cli(true, Some("0000:2d:00.0"), None) + .expect("placement requirements should be present"); + let gpu = request.gpu.expect("gpu request should be present"); + + assert_eq!(gpu.device_id, vec!["0000:2d:00.0"]); + assert_eq!(gpu.count, None); + } + + #[test] + fn placement_requirements_from_cli_maps_gpu_count() { + let request = placement_requirements_from_cli(false, None, Some(2)) + .expect("placement requirements should be present"); + let gpu = request.gpu.expect("gpu request 
should be present"); + + assert!(gpu.device_id.is_empty()); + assert_eq!(gpu.count, Some(2)); + } + + #[test] + fn placement_requirements_from_cli_omits_placement_when_not_requested() { + assert!(placement_requirements_from_cli(false, Some("0"), None).is_none()); + } + #[test] fn resolve_from_classifies_existing_dockerfile_path() { let temp = tempfile::tempdir().expect("failed to create tempdir"); diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index 6e7d66d11..ab6aa1d38 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -671,6 +671,7 @@ async fn sandbox_create_keeps_command_sessions_by_default() { false, None, None, + None, &[], None, None, @@ -710,6 +711,7 @@ async fn sandbox_create_deletes_command_sessions_with_no_keep() { false, None, None, + None, &[], None, None, @@ -752,6 +754,7 @@ async fn sandbox_create_deletes_shell_sessions_with_no_keep() { false, None, None, + None, &[], None, None, @@ -794,6 +797,7 @@ async fn sandbox_create_keeps_sandbox_with_hidden_keep_flag() { false, None, None, + None, &[], None, None, @@ -836,6 +840,7 @@ async fn sandbox_create_keeps_sandbox_with_forwarding() { false, None, None, + None, &[], None, Some(openshell_core::forward::ForwardSpec::new(forward_port)), diff --git a/crates/openshell-core/src/gpu.rs b/crates/openshell-core/src/gpu.rs index 5df8702ed..85b6c9644 100644 --- a/crates/openshell-core/src/gpu.rs +++ b/crates/openshell-core/src/gpu.rs @@ -4,21 +4,19 @@ //! Shared GPU request helpers. use crate::config::CDI_GPU_DEVICE_ALL; +use crate::proto::compute::v1::GpuSpec; -/// Resolve the existing GPU request fields into CDI device identifiers. +/// Resolve a driver GPU request into CDI device identifiers. /// -/// `None` means no GPU was requested. 
A GPU request with no explicit device -/// ID uses the CDI all-GPU request; otherwise the driver-native ID passes -/// through unchanged. +/// `None` means no GPU was requested. Presence with no explicit device IDs +/// uses the CDI all-GPU request; otherwise the driver-native IDs pass through. #[must_use] -pub fn cdi_gpu_device_ids(gpu: bool, gpu_device: &str) -> Option> { - gpu.then(|| { - if gpu_device.is_empty() { - vec![CDI_GPU_DEVICE_ALL.to_string()] - } else { - vec![gpu_device.to_string()] - } - }) +pub fn cdi_gpu_device_ids(gpu: Option<&GpuSpec>) -> Option> { + match gpu { + Some(gpu) if gpu.device_id.is_empty() => Some(vec![CDI_GPU_DEVICE_ALL.to_string()]), + Some(gpu) => Some(gpu.device_id.clone()), + None => None, + } } #[cfg(test)] @@ -27,22 +25,51 @@ mod tests { #[test] fn cdi_gpu_device_ids_returns_none_when_absent() { - assert_eq!(cdi_gpu_device_ids(false, ""), None); + assert_eq!(cdi_gpu_device_ids(None), None); } #[test] fn cdi_gpu_device_ids_defaults_empty_request_to_all_gpus() { + let request = GpuSpec { + device_id: vec![], + count: None, + }; + assert_eq!( - cdi_gpu_device_ids(true, ""), + cdi_gpu_device_ids(Some(&request)), Some(vec![CDI_GPU_DEVICE_ALL.to_string()]) ); } #[test] - fn cdi_gpu_device_ids_passes_explicit_device_id_through() { + fn cdi_gpu_device_ids_passes_single_device_id_through() { + let request = GpuSpec { + device_id: vec!["nvidia.com/gpu=0".to_string()], + count: None, + }; + assert_eq!( - cdi_gpu_device_ids(true, "nvidia.com/gpu=0"), + cdi_gpu_device_ids(Some(&request)), Some(vec!["nvidia.com/gpu=0".to_string()]) ); } + + #[test] + fn cdi_gpu_device_ids_passes_multiple_device_ids_through() { + let request = GpuSpec { + device_id: vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string(), + ], + count: None, + }; + + assert_eq!( + cdi_gpu_device_ids(Some(&request)), + Some(vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string() + ]) + ); + } } diff --git 
a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index 6059596ab..67182a8b7 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -24,7 +24,7 @@ use openshell_core::proto::compute::v1::{ CreateSandboxRequest, CreateSandboxResponse, DeleteSandboxRequest, DeleteSandboxResponse, DriverCondition, DriverSandbox, DriverSandboxStatus, DriverSandboxTemplate, GetCapabilitiesRequest, GetCapabilitiesResponse, GetSandboxRequest, GetSandboxResponse, - ListSandboxesRequest, ListSandboxesResponse, StopSandboxRequest, StopSandboxResponse, + GpuSpec, ListSandboxesRequest, ListSandboxesResponse, StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateRequest, ValidateSandboxCreateResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, WatchSandboxesRequest, WatchSandboxesSandboxEvent, compute_driver_server::ComputeDriver, watch_sandboxes_event, @@ -310,7 +310,12 @@ impl DockerComputeDriver { "docker sandboxes require a template image", )); } - Self::validate_gpu_request(spec.gpu, config.supports_gpu)?; + Self::validate_gpu_request( + spec.placement + .as_ref() + .and_then(|placement| placement.gpu.as_ref()), + config.supports_gpu, + )?; if !template.agent_socket_path.trim().is_empty() { return Err(Status::failed_precondition( "docker compute driver does not support template.agent_socket_path", @@ -330,8 +335,23 @@ impl DockerComputeDriver { Ok(()) } - fn validate_gpu_request(gpu: bool, supports_gpu: bool) -> Result<(), Status> { - if gpu && !supports_gpu { + fn validate_gpu_request(gpu: Option<&GpuSpec>, supports_gpu: bool) -> Result<(), Status> { + if let Some(gpu) = gpu { + if gpu.count == Some(0) { + return Err(Status::invalid_argument("gpu.count must be greater than 0")); + } + if gpu.count.is_some() && !gpu.device_id.is_empty() { + return Err(Status::invalid_argument( + "gpu.count is mutually exclusive with gpu.device_id", + )); + } + if gpu.count.is_some() { + return 
Err(Status::invalid_argument( + "docker compute driver does not support GPU count requests", + )); + } + } + if gpu.is_some() && !supports_gpu { return Err(Status::failed_precondition( "docker GPU sandboxes require Docker CDI support. Enable CDI on the Docker daemon, then restart the OpenShell gateway/server so GPU capability is detected.", )); @@ -945,8 +965,8 @@ fn build_environment(sandbox: &DriverSandbox, config: &DockerDriverRuntimeConfig .collect() } -fn docker_gpu_device_requests(gpu: bool, gpu_device: &str) -> Option> { - cdi_gpu_device_ids(gpu, gpu_device).map(|device_ids| { +fn docker_gpu_device_requests(gpu: Option<&GpuSpec>) -> Option> { + cdi_gpu_device_ids(gpu).map(|device_ids| { vec![DeviceRequest { driver: Some("cdi".to_string()), device_ids: Some(device_ids), @@ -996,7 +1016,11 @@ fn build_container_create_body( host_config: Some(HostConfig { nano_cpus: resource_limits.nano_cpus, memory: resource_limits.memory_bytes, - device_requests: docker_gpu_device_requests(spec.gpu, &spec.gpu_device), + device_requests: docker_gpu_device_requests( + spec.placement + .as_ref() + .and_then(|placement| placement.gpu.as_ref()), + ), binds: Some(build_binds(config)), restart_policy: Some(RestartPolicy { name: Some(RestartPolicyNameEnum::UNLESS_STOPPED), diff --git a/crates/openshell-driver-docker/src/tests.rs b/crates/openshell-driver-docker/src/tests.rs index df68d39d6..ab4173f5f 100644 --- a/crates/openshell-driver-docker/src/tests.rs +++ b/crates/openshell-driver-docker/src/tests.rs @@ -4,7 +4,8 @@ use super::*; use openshell_core::config::{CDI_GPU_DEVICE_ALL, DEFAULT_SERVER_PORT}; use openshell_core::proto::compute::v1::{ - DriverResourceRequirements, DriverSandboxSpec, DriverSandboxTemplate, + DriverResourceRequirements, DriverSandboxSpec, DriverSandboxTemplate, GpuSpec, + PlacementRequirements, }; use std::fs; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; @@ -31,8 +32,7 @@ fn test_sandbox() -> DriverSandbox { resources: None, platform_config: None, }), - 
gpu: false, - gpu_device: String::new(), + placement: None, }), status: None, } @@ -487,7 +487,12 @@ fn build_container_create_body_clears_inherited_cmd() { fn validate_sandbox_rejects_gpu_when_cdi_unavailable() { let config = runtime_config(); let mut sandbox = test_sandbox(); - sandbox.spec.as_mut().unwrap().gpu = true; + sandbox.spec.as_mut().unwrap().placement = Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![], + count: None, + }), + }); let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); @@ -495,12 +500,35 @@ fn validate_sandbox_rejects_gpu_when_cdi_unavailable() { assert!(err.message().contains("Docker CDI")); } +#[test] +fn validate_sandbox_rejects_gpu_count() { + let mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + sandbox.spec.as_mut().unwrap().placement = Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![], + count: Some(2), + }), + }); + + let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("does not support GPU count")); +} + #[test] fn build_container_create_body_maps_gpu_to_all_cdi_device() { let mut config = runtime_config(); config.supports_gpu = true; let mut sandbox = test_sandbox(); - sandbox.spec.as_mut().unwrap().gpu = true; + sandbox.spec.as_mut().unwrap().placement = Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![], + count: None, + }), + }); let create_body = build_container_create_body(&sandbox, &config).unwrap(); let request = create_body @@ -518,13 +546,19 @@ fn build_container_create_body_maps_gpu_to_all_cdi_device() { } #[test] -fn build_container_create_body_passes_explicit_cdi_device_id_through() { +fn build_container_create_body_passes_explicit_cdi_device_ids_through() { let mut config = runtime_config(); config.supports_gpu = true; let mut sandbox = test_sandbox(); - let spec 
= sandbox.spec.as_mut().unwrap(); - spec.gpu = true; - spec.gpu_device = "nvidia.com/gpu=0".to_string(); + sandbox.spec.as_mut().unwrap().placement = Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string(), + ], + count: None, + }), + }); let create_body = build_container_create_body(&sandbox, &config).unwrap(); let request = create_body @@ -537,7 +571,10 @@ fn build_container_create_body_passes_explicit_cdi_device_id_through() { assert_eq!(request.driver.as_deref(), Some("cdi")); assert_eq!( request.device_ids.as_ref().unwrap(), - &vec!["nvidia.com/gpu=0".to_string()] + &vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string() + ] ); } diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index 56b73447a..9ea55815d 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -15,7 +15,7 @@ use openshell_core::proto::compute::v1::{ DriverCondition as SandboxCondition, DriverPlatformEvent as PlatformEvent, DriverSandbox as Sandbox, DriverSandboxSpec as SandboxSpec, DriverSandboxStatus as SandboxStatus, DriverSandboxTemplate as SandboxTemplate, - GetCapabilitiesResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, + GetCapabilitiesResponse, GpuSpec, WatchSandboxesDeletedEvent, WatchSandboxesEvent, WatchSandboxesPlatformEvent, WatchSandboxesSandboxEvent, watch_sandboxes_event, }; use std::collections::BTreeMap; @@ -59,7 +59,16 @@ const SANDBOX_ID_LABEL: &str = "openshell.ai/sandbox-id"; const SANDBOX_MANAGED_LABEL: &str = "openshell.ai/managed-by"; const SANDBOX_MANAGED_VALUE: &str = "openshell"; const GPU_RESOURCE_NAME: &str = "nvidia.com/gpu"; -const GPU_RESOURCE_QUANTITY: &str = "1"; +const DEFAULT_GPU_COUNT: u32 = 1; + +fn gpu_from_spec(spec: Option<&SandboxSpec>) -> Option<&GpuSpec> { + spec.and_then(|spec| spec.placement.as_ref()) + 
.and_then(|placement| placement.gpu.as_ref()) +} + +fn gpu_has_explicit_device_ids(gpu: Option<&GpuSpec>) -> bool { + gpu.is_some_and(|gpu| !gpu.device_id.is_empty()) +} // --------------------------------------------------------------------------- // Default workspace persistence (temporary — will be replaced by snapshotting) @@ -194,8 +203,29 @@ impl KubernetesComputeDriver { } pub async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), tonic::Status> { - let gpu_requested = sandbox.spec.as_ref().is_some_and(|spec| spec.gpu); - if gpu_requested + let gpu = gpu_from_spec(sandbox.spec.as_ref()); + self.validate_gpu_request(gpu).await + } + + async fn validate_gpu_request(&self, gpu: Option<&GpuSpec>) -> Result<(), tonic::Status> { + if let Some(gpu) = gpu { + if gpu.count == Some(0) { + return Err(tonic::Status::invalid_argument( + "gpu.count must be greater than 0", + )); + } + if gpu.count.is_some() && !gpu.device_id.is_empty() { + return Err(tonic::Status::invalid_argument( + "gpu.count is mutually exclusive with gpu.device_id", + )); + } + } + if gpu_has_explicit_device_ids(gpu) { + return Err(tonic::Status::invalid_argument( + "kubernetes compute driver does not support explicit GPU device IDs", + )); + } + if gpu.is_some() && !self.has_gpu_capacity().await.map_err(|err| { tonic::Status::internal(format!("check GPU node capacity failed: {err}")) })? 
@@ -291,6 +321,25 @@ impl KubernetesComputeDriver { } pub async fn create_sandbox(&self, sandbox: &Sandbox) -> Result<(), KubernetesDriverError> { + if let Some(gpu) = gpu_from_spec(sandbox.spec.as_ref()) { + if gpu.count == Some(0) { + return Err(KubernetesDriverError::Precondition( + "gpu.count must be greater than 0".to_string(), + )); + } + if gpu.count.is_some() && !gpu.device_id.is_empty() { + return Err(KubernetesDriverError::Precondition( + "gpu.count is mutually exclusive with gpu.device_id".to_string(), + )); + } + if gpu_has_explicit_device_ids(Some(gpu)) { + return Err(KubernetesDriverError::Precondition( + "kubernetes compute driver does not support explicit GPU device IDs" + .to_string(), + )); + } + } + let name = sandbox.name.as_str(); info!( sandbox_id = %sandbox.id, @@ -1011,7 +1060,13 @@ fn sandbox_to_k8s_spec( if let Some(template) = spec.template.as_ref() { root.insert( "podTemplate".to_string(), - sandbox_template_to_k8s(template, spec.gpu, &pod_env, inject_workspace, params), + sandbox_template_to_k8s( + template, + gpu_from_spec(Some(spec)), + &pod_env, + inject_workspace, + params, + ), ); if !template.agent_socket_path.is_empty() { root.insert( @@ -1043,7 +1098,7 @@ fn sandbox_to_k8s_spec( "podTemplate".to_string(), sandbox_template_to_k8s( &SandboxTemplate::default(), - spec.is_some_and(|s| s.gpu), + gpu_from_spec(spec), &pod_env, inject_workspace, params, @@ -1058,7 +1113,7 @@ fn sandbox_to_k8s_spec( fn sandbox_template_to_k8s( template: &SandboxTemplate, - gpu: bool, + gpu: Option<&GpuSpec>, spec_environment: &std::collections::HashMap, inject_workspace: bool, params: &SandboxPodParams<'_>, @@ -1091,7 +1146,7 @@ fn sandbox_template_to_k8s( if use_user_namespaces { spec.insert("hostUsers".to_string(), serde_json::json!(false)); - if gpu { + if gpu.is_some() { warn!( "GPU sandbox with user namespaces enabled — \ NVIDIA device plugin compatibility is unverified" @@ -1225,7 +1280,10 @@ fn sandbox_template_to_k8s( result } -fn 
container_resources(template: &SandboxTemplate, gpu: bool) -> Option { +fn container_resources( + template: &SandboxTemplate, + gpu: Option<&GpuSpec>, +) -> Option { // Start from the raw resources passthrough in platform_config (preserves // custom resource types like GPU limits that users set via the public API // Struct), then overlay the typed DriverResourceRequirements on top. @@ -1247,8 +1305,9 @@ fn container_resources(template: &SandboxTemplate, gpu: bool) -> Option Option) -> GpuSpec { + GpuSpec { + device_id: vec![], + count, + } + } + #[test] fn apply_required_env_always_injects_ssh_handshake_secret() { let mut env = Vec::new(); @@ -1787,7 +1853,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &SandboxTemplate::default(), - true, + Some(&gpu_spec(None)), &std::collections::HashMap::new(), true, ¶ms, @@ -1800,10 +1866,44 @@ mod tests { ); assert_eq!( pod_template["spec"]["containers"][0]["resources"]["limits"][GPU_RESOURCE_NAME], - serde_json::json!(GPU_RESOURCE_QUANTITY) + serde_json::json!(DEFAULT_GPU_COUNT.to_string()) + ); + } + + #[test] + fn gpu_sandbox_uses_requested_gpu_count() { + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &SandboxTemplate::default(), + Some(&gpu_spec(Some(2))), + &std::collections::HashMap::new(), + true, + ¶ms, + ) + }; + + assert_eq!( + pod_template["spec"]["containers"][0]["resources"]["limits"][GPU_RESOURCE_NAME], + serde_json::json!("2") ); } + #[test] + fn gpu_has_explicit_device_ids_only_when_ids_are_present() { + use openshell_core::proto::compute::v1::GpuSpec; + + assert!(!gpu_has_explicit_device_ids(None)); + assert!(!gpu_has_explicit_device_ids(Some(&GpuSpec { + device_id: vec![], + count: None, + }))); + assert!(gpu_has_explicit_device_ids(Some(&GpuSpec { + device_id: vec!["nvidia.com/gpu=0".to_string()], + count: None, + }))); + } + #[test] fn gpu_sandbox_uses_template_runtime_class_name_when_set() { let template = SandboxTemplate 
{ @@ -1823,7 +1923,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - true, + Some(&gpu_spec(None)), &std::collections::HashMap::new(), true, ¶ms, @@ -1855,7 +1955,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -1883,7 +1983,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - true, + Some(&gpu_spec(None)), &std::collections::HashMap::new(), true, ¶ms, @@ -1894,7 +1994,7 @@ mod tests { assert_eq!(limits["cpu"], serde_json::json!("2")); assert_eq!( limits[GPU_RESOURCE_NAME], - serde_json::json!(GPU_RESOURCE_QUANTITY) + serde_json::json!(DEFAULT_GPU_COUNT.to_string()) ); } @@ -1907,7 +2007,7 @@ mod tests { }; sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -1932,7 +2032,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -1955,7 +2055,7 @@ mod tests { }; sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2094,7 +2194,7 @@ mod tests { }; let pod_template = sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), false, // user provided custom VCTs ¶ms, @@ -2132,7 +2232,7 @@ mod tests { }; sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2197,7 +2297,7 @@ mod tests { let params = SandboxPodParams::default(); // cluster default is off let pod_template = sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2235,7 +2335,7 @@ mod tests { }; let pod_template = sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ 
-2261,7 +2361,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2356,7 +2456,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), false, ¶ms, @@ -2417,7 +2517,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), false, ¶ms, diff --git a/crates/openshell-driver-podman/src/container.rs b/crates/openshell-driver-podman/src/container.rs index 5b9b0d735..dcad04a64 100644 --- a/crates/openshell-driver-podman/src/container.rs +++ b/crates/openshell-driver-podman/src/container.rs @@ -345,8 +345,12 @@ fn build_resource_limits(sandbox: &DriverSandbox) -> ResourceLimits { /// Build CDI GPU device list if GPU is requested. fn build_devices(sandbox: &DriverSandbox) -> Option> { - let spec = sandbox.spec.as_ref()?; - cdi_gpu_device_ids(spec.gpu, &spec.gpu_device).map(|device_ids| { + let gpu = sandbox + .spec + .as_ref() + .and_then(|spec| spec.placement.as_ref()) + .and_then(|placement| placement.gpu.as_ref()); + cdi_gpu_device_ids(gpu).map(|device_ids| { device_ids .into_iter() .map(|path| LinuxDevice { path }) @@ -687,53 +691,6 @@ mod tests { assert_eq!(short_id("short"), "short"); } - #[test] - fn container_spec_omits_devices_without_gpu_request() { - let sandbox = test_sandbox("test-id", "test-name"); - let config = test_config(); - let spec = build_container_spec(&sandbox, &config); - - assert!(spec.get("devices").is_none()); - } - - #[test] - fn container_spec_maps_empty_gpu_request_to_all_cdi_device() { - use openshell_core::config::CDI_GPU_DEVICE_ALL; - use openshell_core::proto::compute::v1::DriverSandboxSpec; - - let mut sandbox = test_sandbox("test-id", "test-name"); - sandbox.spec = Some(DriverSandboxSpec { - gpu: true, - ..Default::default() - }); - let 
config = test_config(); - let spec = build_container_spec(&sandbox, &config); - - assert_eq!( - spec["devices"][0]["path"].as_str(), - Some(CDI_GPU_DEVICE_ALL) - ); - } - - #[test] - fn container_spec_passes_explicit_cdi_device_id_through() { - use openshell_core::proto::compute::v1::DriverSandboxSpec; - - let mut sandbox = test_sandbox("test-id", "test-name"); - sandbox.spec = Some(DriverSandboxSpec { - gpu: true, - gpu_device: "nvidia.com/gpu=0".to_string(), - ..Default::default() - }); - let config = test_config(); - let spec = build_container_spec(&sandbox, &config); - - assert_eq!( - spec["devices"][0]["path"].as_str(), - Some("nvidia.com/gpu=0") - ); - } - #[test] fn container_spec_includes_required_capabilities() { let sandbox = test_sandbox("test-id", "test-name"); @@ -782,6 +739,73 @@ mod tests { ); } + #[test] + fn container_spec_omits_devices_without_gpu_request() { + let sandbox = test_sandbox("test-id", "test-name"); + let config = test_config(); + let spec = build_container_spec(&sandbox, &config); + + assert!(spec.get("devices").is_none()); + } + + #[test] + fn container_spec_maps_empty_gpu_request_to_all_cdi_device() { + use openshell_core::config::CDI_GPU_DEVICE_ALL; + use openshell_core::proto::compute::v1::{ + DriverSandboxSpec, GpuSpec, PlacementRequirements, + }; + + let mut sandbox = test_sandbox("test-id", "test-name"); + sandbox.spec = Some(DriverSandboxSpec { + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![], + count: None, + }), + }), + ..Default::default() + }); + let config = test_config(); + let spec = build_container_spec(&sandbox, &config); + + assert_eq!( + spec["devices"][0]["path"].as_str(), + Some(CDI_GPU_DEVICE_ALL) + ); + } + + #[test] + fn container_spec_passes_explicit_cdi_device_ids_through() { + use openshell_core::proto::compute::v1::{ + DriverSandboxSpec, GpuSpec, PlacementRequirements, + }; + + let mut sandbox = test_sandbox("test-id", "test-name"); + sandbox.spec = Some(DriverSandboxSpec 
{ + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string(), + ], + count: None, + }), + }), + ..Default::default() + }); + let config = test_config(); + let spec = build_container_spec(&sandbox, &config); + + assert_eq!( + spec["devices"][0]["path"].as_str(), + Some("nvidia.com/gpu=0") + ); + assert_eq!( + spec["devices"][1]["path"].as_str(), + Some("nvidia.com/gpu=1") + ); + } + #[test] fn container_spec_uses_secret_env_not_plaintext() { let sandbox = test_sandbox("test-id", "test-name"); diff --git a/crates/openshell-driver-podman/src/driver.rs b/crates/openshell-driver-podman/src/driver.rs index f78c5c730..7cd814529 100644 --- a/crates/openshell-driver-podman/src/driver.rs +++ b/crates/openshell-driver-podman/src/driver.rs @@ -10,7 +10,7 @@ use crate::watcher::{ self, WatchStream, driver_sandbox_from_inspect, driver_sandbox_from_list_entry, }; use openshell_core::ComputeDriverError; -use openshell_core::proto::compute::v1::{DriverSandbox, GetCapabilitiesResponse}; +use openshell_core::proto::compute::v1::{DriverSandbox, GetCapabilitiesResponse, GpuSpec}; use tracing::{info, warn}; impl From for ComputeDriverError { @@ -198,12 +198,33 @@ impl PodmanComputeDriver { &self, sandbox: &DriverSandbox, ) -> Result<(), ComputeDriverError> { - let gpu_requested = sandbox.spec.as_ref().is_some_and(|s| s.gpu); - Self::validate_gpu_request(gpu_requested) - } - - fn validate_gpu_request(gpu_requested: bool) -> Result<(), ComputeDriverError> { - if gpu_requested && !Self::has_gpu_capacity() { + let gpu = sandbox + .spec + .as_ref() + .and_then(|spec| spec.placement.as_ref()) + .and_then(|placement| placement.gpu.as_ref()); + Self::validate_gpu_request(gpu) + } + + fn validate_gpu_request(gpu: Option<&GpuSpec>) -> Result<(), ComputeDriverError> { + if let Some(gpu) = gpu { + if gpu.count == Some(0) { + return Err(ComputeDriverError::Precondition( + "gpu.count must be greater than 
0".to_string(), + )); + } + if gpu.count.is_some() && !gpu.device_id.is_empty() { + return Err(ComputeDriverError::Precondition( + "gpu.count is mutually exclusive with gpu.device_id".to_string(), + )); + } + if gpu.count.is_some() { + return Err(ComputeDriverError::Precondition( + "podman compute driver does not support GPU count requests".to_string(), + )); + } + } + if gpu.is_some() && !Self::has_gpu_capacity() { return Err(ComputeDriverError::Precondition( "GPU sandbox requested, but no NVIDIA GPU devices are available.".to_string(), )); @@ -223,6 +244,7 @@ impl PodmanComputeDriver { "sandbox id is required".into(), )); } + self.validate_sandbox_create(sandbox)?; // Validate the composed container name early, before creating any // resources (secret, volume), so we don't leave orphans when the @@ -575,6 +597,19 @@ mod tests { assert!(matches!(err, ComputeDriverError::Message(_))); } + #[test] + fn validate_gpu_request_rejects_count() { + let err = PodmanComputeDriver::validate_gpu_request(Some(&GpuSpec { + device_id: vec![], + count: Some(2), + })) + .expect_err("GPU count should be rejected"); + + assert!( + matches!(err, ComputeDriverError::Precondition(message) if message.contains("does not support GPU count")) + ); + } + // ── grpc_endpoint auto-detection ─────────────────────────────────── // // PodmanComputeDriver::new() fills grpc_endpoint when it is empty. 
diff --git a/crates/openshell-driver-vm/src/driver.rs b/crates/openshell-driver-vm/src/driver.rs index b797f4835..b5b579f81 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -25,7 +25,7 @@ use openshell_core::proto::compute::v1::{ CreateSandboxRequest, CreateSandboxResponse, DeleteSandboxRequest, DeleteSandboxResponse, DriverCondition as SandboxCondition, DriverPlatformEvent as PlatformEvent, DriverSandbox as Sandbox, DriverSandboxStatus as SandboxStatus, GetCapabilitiesRequest, - GetCapabilitiesResponse, GetSandboxRequest, GetSandboxResponse, ListSandboxesRequest, + GetCapabilitiesResponse, GetSandboxRequest, GetSandboxResponse, GpuSpec, ListSandboxesRequest, ListSandboxesResponse, StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateRequest, ValidateSandboxCreateResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, WatchSandboxesPlatformEvent, WatchSandboxesRequest, WatchSandboxesSandboxEvent, @@ -358,9 +358,7 @@ impl VmDriver { return Err(Status::already_exists("sandbox already exists")); } - let spec = sandbox.spec.as_ref(); - let is_gpu = spec.is_some_and(|s| s.gpu); - let gpu_device = spec.map_or("", |s| s.gpu_device.as_str()); + let gpu_device = requested_gpu_device(sandbox_gpu(sandbox)); let state_dir = sandbox_state_dir(&self.config.state_dir, &sandbox.id)?; let rootfs = state_dir.join("rootfs"); @@ -437,7 +435,7 @@ impl VmDriver { ))); } - let gpu_bdf = if is_gpu { + let gpu_bdf = if let Some(gpu_device) = gpu_device { let inventory = self .gpu_inventory .as_ref() @@ -1461,15 +1459,7 @@ fn validate_vm_sandbox(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Statu .as_ref() .ok_or_else(|| Status::invalid_argument("sandbox spec is required"))?; - if spec.gpu && !gpu_enabled { - return Err(Status::failed_precondition( - "GPU support is not enabled on this driver; start with --gpu", - )); - } - - if !spec.gpu && !spec.gpu_device.is_empty() { - return 
Err(Status::invalid_argument("gpu_device requires gpu=true")); - } + validate_gpu_request(sandbox_gpu(sandbox), gpu_enabled)?; if let Some(template) = spec.template.as_ref() { if !template.agent_socket_path.is_empty() { @@ -1491,7 +1481,6 @@ fn validate_vm_sandbox(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Statu Ok(()) } -#[allow(clippy::result_large_err)] fn validate_sandbox_id(sandbox_id: &str) -> Result<(), Status> { if sandbox_id.is_empty() { return Err(Status::invalid_argument("sandbox id is required")); @@ -1517,6 +1506,51 @@ fn validate_sandbox_id(sandbox_id: &str) -> Result<(), Status> { Ok(()) } +fn sandbox_gpu(sandbox: &Sandbox) -> Option<&GpuSpec> { + sandbox + .spec + .as_ref() + .and_then(|spec| spec.placement.as_ref()) + .and_then(|placement| placement.gpu.as_ref()) +} + +fn requested_gpu_device(gpu: Option<&GpuSpec>) -> Option<&str> { + let gpu = gpu?; + Some(gpu.device_id.first().map_or("", String::as_str)) +} + +#[allow(clippy::result_large_err)] +fn validate_gpu_request(gpu: Option<&GpuSpec>, gpu_enabled: bool) -> Result<(), Status> { + if gpu.is_some() && !gpu_enabled { + return Err(Status::failed_precondition( + "GPU support is not enabled on this driver; start with --gpu", + )); + } + + if let Some(gpu) = gpu { + if gpu.count == Some(0) { + return Err(Status::invalid_argument("gpu.count must be greater than 0")); + } + if gpu.count.is_some() && !gpu.device_id.is_empty() { + return Err(Status::invalid_argument( + "gpu.count is mutually exclusive with gpu.device_id", + )); + } + if gpu.count.is_some_and(|count| count > 1) { + return Err(Status::invalid_argument( + "vm compute driver supports at most one GPU", + )); + } + } + + if gpu.is_some_and(|gpu| gpu.device_id.len() > 1) { + return Err(Status::invalid_argument( + "vm compute driver supports at most one GPU device ID", + )); + } + Ok(()) +} + #[allow(clippy::result_large_err)] fn parse_registry_reference(image_ref: &str) -> Result { Reference::try_from(image_ref).map_err(|err| { @@ 
-2519,7 +2553,8 @@ mod tests { use super::*; use crate::gpu::{SubnetAllocator, allocate_vsock_cid, mac_from_sandbox_id, tap_device_name}; use openshell_core::proto::compute::v1::{ - DriverSandboxSpec as SandboxSpec, DriverSandboxTemplate as SandboxTemplate, + DriverSandboxSpec as SandboxSpec, DriverSandboxTemplate as SandboxTemplate, GpuSpec, + PlacementRequirements, }; use prost_types::{Struct, Value, value::Kind}; use std::fs; @@ -2533,7 +2568,12 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![], + count: None, + }), + }), ..Default::default() }), ..Default::default() @@ -2549,7 +2589,12 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![], + count: None, + }), + }), ..Default::default() }), ..Default::default() @@ -2558,20 +2603,109 @@ mod tests { } #[test] - fn validate_vm_sandbox_rejects_gpu_device_without_gpu() { + fn validate_vm_sandbox_accepts_gpu_count_one_when_enabled() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![], + count: Some(1), + }), + }), + ..Default::default() + }), + ..Default::default() + }; + validate_vm_sandbox(&sandbox, true).expect("gpu count one should be accepted"); + } + + #[test] + fn validate_vm_sandbox_rejects_gpu_count_greater_than_one() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![], + count: Some(2), + }), + }), + ..Default::default() + }), + ..Default::default() + }; + let err = + validate_vm_sandbox(&sandbox, true).expect_err("gpu count > 1 should be rejected"); + assert_eq!(err.code(), 
Code::InvalidArgument); + assert!(err.message().contains("at most one GPU")); + } + + #[test] + fn validate_vm_sandbox_rejects_gpu_count_with_device_id() { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: false, - gpu_device: "0000:2d:00.0".to_string(), + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec!["0000:2d:00.0".to_string()], + count: Some(1), + }), + }), ..Default::default() }), ..Default::default() }; let err = validate_vm_sandbox(&sandbox, true) - .expect_err("gpu_device without gpu should be rejected"); + .expect_err("gpu count with device ID should be rejected"); assert_eq!(err.code(), Code::InvalidArgument); - assert!(err.message().contains("gpu_device requires gpu=true")); + assert!(err.message().contains("mutually exclusive")); + } + + #[test] + fn validate_vm_sandbox_rejects_multiple_gpu_device_ids() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec!["0000:2d:00.0".to_string(), "0000:3d:00.0".to_string()], + count: None, + }), + }), + ..Default::default() + }), + ..Default::default() + }; + let err = validate_vm_sandbox(&sandbox, true) + .expect_err("multiple GPU device IDs should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("at most one GPU device ID")); + } + + #[test] + fn requested_gpu_device_returns_none_without_gpu_request() { + assert_eq!(requested_gpu_device(None), None); + } + + #[test] + fn requested_gpu_device_defaults_empty_request_to_inventory_choice() { + let gpu = GpuSpec { + device_id: vec![], + count: None, + }; + + assert_eq!(requested_gpu_device(Some(&gpu)), Some("")); + } + + #[test] + fn requested_gpu_device_returns_first_explicit_device_id() { + let gpu = GpuSpec { + device_id: vec!["0000:2d:00.0".to_string()], + count: None, + }; + + assert_eq!(requested_gpu_device(Some(&gpu)), 
Some("0000:2d:00.0")); } #[test] diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index d2fd34011..b9cc31aaa 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -18,7 +18,8 @@ use futures::{Stream, StreamExt}; use openshell_core::proto::compute::v1::{ CreateSandboxRequest, DeleteSandboxRequest, DriverCondition, DriverPlatformEvent, DriverResourceRequirements, DriverSandbox, DriverSandboxSpec, DriverSandboxStatus, - DriverSandboxTemplate, GetCapabilitiesRequest, GetSandboxRequest, ListSandboxesRequest, + DriverSandboxTemplate, GetCapabilitiesRequest, GetSandboxRequest, GpuSpec as DriverGpuSpec, + ListSandboxesRequest, PlacementRequirements as DriverPlacementRequirements, ValidateSandboxCreateRequest, WatchSandboxesEvent, WatchSandboxesRequest, compute_driver_client::ComputeDriverClient, compute_driver_server::ComputeDriver, watch_sandboxes_event, @@ -1130,8 +1131,15 @@ fn driver_sandbox_spec_from_public(spec: &SandboxSpec) -> DriverSandboxSpec { .template .as_ref() .map(driver_sandbox_template_from_public), - gpu: spec.gpu, - gpu_device: spec.gpu_device.clone(), + placement: spec + .placement + .as_ref() + .map(|placement| DriverPlacementRequirements { + gpu: placement.gpu.as_ref().map(|gpu| DriverGpuSpec { + device_id: gpu.device_id.clone(), + count: gpu.count, + }), + }), } } @@ -1491,7 +1499,9 @@ fn derive_phase(status: Option<&DriverSandboxStatus>) -> SandboxPhase { } fn rewrite_user_facing_conditions(status: &mut Option, spec: Option<&SandboxSpec>) { - let gpu_requested = spec.is_some_and(|sandbox_spec| sandbox_spec.gpu); + let gpu_requested = spec + .and_then(|sandbox_spec| sandbox_spec.placement.as_ref()) + .is_some_and(|placement| placement.gpu.is_some()); if !gpu_requested { return; } @@ -1653,6 +1663,7 @@ mod tests { CreateSandboxResponse, DeleteSandboxResponse, GetCapabilitiesResponse, GetSandboxRequest, GetSandboxResponse, StopSandboxRequest, 
StopSandboxResponse, ValidateSandboxCreateResponse, }; + use openshell_core::proto::{GpuSpec, PlacementRequirements}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{mpsc, oneshot}; @@ -1669,6 +1680,56 @@ mod tests { } } + #[test] + fn driver_sandbox_spec_from_public_preserves_gpu_request_device_ids() { + let public = SandboxSpec { + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec!["nvidia.com/gpu=0".to_string()], + count: None, + }), + }), + ..Default::default() + }; + + let driver = driver_sandbox_spec_from_public(&public); + + assert_eq!( + driver + .placement + .expect("driver placement requirements should be present") + .gpu + .expect("driver GPU request should be present") + .device_id, + vec!["nvidia.com/gpu=0".to_string()] + ); + } + + #[test] + fn driver_sandbox_spec_from_public_preserves_gpu_count() { + let public = SandboxSpec { + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![], + count: Some(2), + }), + }), + ..Default::default() + }; + + let driver = driver_sandbox_spec_from_public(&public); + + assert_eq!( + driver + .placement + .expect("driver placement requirements should be present") + .gpu + .expect("driver GPU request should be present") + .count, + Some(2) + ); + } + fn struct_value( fields: impl IntoIterator<Item = (String, prost_types::Value)>, ) -> prost_types::Value { @@ -2117,7 +2178,12 @@ mod tests { rewrite_user_facing_conditions( &mut status, Some(&SandboxSpec { - gpu: true, + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![], + count: None, + }), + }), ..Default::default() }), ); @@ -2149,7 +2215,7 @@ mod tests { rewrite_user_facing_conditions( &mut status, Some(&SandboxSpec { - gpu: false, + placement: None, ..Default::default() }), ); @@ -2376,7 +2442,12 @@ mod tests { let sandbox = Sandbox { spec: Some(SandboxSpec { - gpu: true, + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![], + count: None, + 
}), + }), ..Default::default() }), ..sandbox_record("sb-1", "sandbox-a", SandboxPhase::Provisioning) @@ -2399,7 +2470,13 @@ mod tests { SandboxPhase::try_from(stored.phase).unwrap(), SandboxPhase::Ready ); - assert!(stored.spec.as_ref().is_some_and(|spec| spec.gpu)); + assert!( + stored + .spec + .as_ref() + .and_then(|spec| spec.placement.as_ref()) + .is_some_and(|placement| placement.gpu.is_some()) + ); } #[tokio::test] diff --git a/crates/openshell-server/src/grpc/validation.rs b/crates/openshell-server/src/grpc/validation.rs index 160b7e031..0f891a7d0 100644 --- a/crates/openshell-server/src/grpc/validation.rs +++ b/crates/openshell-server/src/grpc/validation.rs @@ -131,6 +131,13 @@ pub(super) fn validate_sandbox_spec( validate_sandbox_template(tmpl)?; } + // --- spec.placement --- + if let Some(placement) = spec.placement.as_ref() + && let Some(gpu) = placement.gpu.as_ref() + { + validate_gpu_spec(gpu)?; + } + // --- spec.policy serialized size --- if let Some(ref policy) = spec.policy { let size = policy.encoded_len(); @@ -144,6 +151,20 @@ pub(super) fn validate_sandbox_spec( Ok(()) } +fn validate_gpu_spec(gpu: &openshell_core::proto::GpuSpec) -> Result<(), Status> { + if gpu.count == Some(0) { + return Err(Status::invalid_argument( + "placement.gpu.count must be greater than 0", + )); + } + if gpu.count.is_some() && !gpu.device_id.is_empty() { + return Err(Status::invalid_argument( + "placement.gpu.count is mutually exclusive with placement.gpu.device_id", + )); + } + Ok(()) +} + /// Validate template-level field sizes. fn validate_sandbox_template(tmpl: &SandboxTemplate) -> Result<(), Status> { // String fields. 
@@ -642,7 +663,7 @@ pub(super) fn level_matches(log_level: &str, min_level: &str) -> bool { #[cfg(test)] mod tests { use super::*; - use openshell_core::proto::SandboxSpec; + use openshell_core::proto::{GpuSpec, PlacementRequirements, SandboxSpec}; use std::collections::HashMap; use tonic::Code; @@ -668,12 +689,63 @@ mod tests { #[test] fn validate_sandbox_spec_accepts_gpu_flag() { let spec = SandboxSpec { - gpu: true, + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![], + count: None, + }), + }), ..Default::default() }; assert!(validate_sandbox_spec("gpu-sandbox", &spec).is_ok()); } + #[test] + fn validate_sandbox_spec_accepts_gpu_count() { + let spec = SandboxSpec { + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![], + count: Some(2), + }), + }), + ..Default::default() + }; + assert!(validate_sandbox_spec("gpu-sandbox", &spec).is_ok()); + } + + #[test] + fn validate_sandbox_spec_rejects_zero_gpu_count() { + let spec = SandboxSpec { + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![], + count: Some(0), + }), + }), + ..Default::default() + }; + let err = validate_sandbox_spec("gpu-sandbox", &spec).unwrap_err(); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("count must be greater than 0")); + } + + #[test] + fn validate_sandbox_spec_rejects_gpu_count_with_device_ids() { + let spec = SandboxSpec { + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec!["nvidia.com/gpu=0".to_string()], + count: Some(1), + }), + }), + ..Default::default() + }; + let err = validate_sandbox_spec("gpu-sandbox", &spec).unwrap_err(); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("mutually exclusive")); + } + #[test] fn validate_sandbox_spec_accepts_empty_defaults() { assert!(validate_sandbox_spec("", &default_spec()).is_ok()); diff --git a/docs/sandboxes/manage-sandboxes.mdx 
b/docs/sandboxes/manage-sandboxes.mdx index 3eddecec5..52df28aa6 100644 --- a/docs/sandboxes/manage-sandboxes.mdx +++ b/docs/sandboxes/manage-sandboxes.mdx @@ -35,6 +35,19 @@ To request GPU resources, add `--gpu`: openshell sandbox create --gpu -- claude ``` +To request a specific number of GPUs, use `--gpu-count`. GPU count requests are +mutually exclusive with explicit GPU device IDs, and the count must be greater +than zero. + +```shell +openshell sandbox create --gpu-count 2 -- claude +``` + +Kubernetes-backed sandboxes honor `--gpu-count` by setting the `nvidia.com/gpu` +resource limit to the requested count. VM-backed sandboxes accept only +`--gpu-count 1`. Docker-backed and Podman-backed sandboxes currently reject GPU +count requests. + For Docker-backed sandboxes, GPU injection uses Docker CDI. If you enable Docker CDI after the gateway starts, restart the gateway so OpenShell can detect the updated Docker daemon capability. diff --git a/e2e/python/conftest.py b/e2e/python/conftest.py index 712704929..5b3b1b882 100644 --- a/e2e/python/conftest.py +++ b/e2e/python/conftest.py @@ -101,6 +101,8 @@ def gpu_sandbox_spec() -> datamodel_pb2.SandboxSpec: # override (e.g. a locally-built or registry-mirrored image). image = os.environ.get("OPENSHELL_E2E_GPU_IMAGE", "") return datamodel_pb2.SandboxSpec( - gpu=True, + placement=datamodel_pb2.PlacementRequirements( + gpu=datamodel_pb2.GPUSpec(), + ), template=datamodel_pb2.SandboxTemplate(image=image), ) diff --git a/proto/compute_driver.proto b/proto/compute_driver.proto index 3c4308f3f..4524591d2 100644 --- a/proto/compute_driver.proto +++ b/proto/compute_driver.proto @@ -78,18 +78,31 @@ message DriverSandbox { // Driver-owned provisioning inputs required to create a sandbox. message DriverSandboxSpec { + reserved 9, 10; + // Log level exposed to processes running inside the sandbox. string log_level = 1; // Environment variables injected into the sandbox runtime. 
map<string, string> environment = 5; // Runtime template consumed by the driver during provisioning. DriverSandboxTemplate template = 6; - // Request NVIDIA GPU resources for this sandbox. - bool gpu = 9; - // Optional PCI BDF address (e.g. "0000:2d:00.0") or device index - // (e.g. "0", "1"). When empty with gpu=true, the driver assigns the - // first available GPU. - string gpu_device = 10; + // Optional placement requirements for the sandbox workload. + PlacementRequirements placement = 11; +} + +// Driver-owned placement requirements for selecting compatible compute resources. +message PlacementRequirements { + // Request GPU resources for this sandbox. Presence indicates a GPU request. + GPUSpec gpu = 1; +} + +// Driver-native GPU placement details. +message GPUSpec { + // Optional driver-native device identifiers. Mutually exclusive with count. + // Empty means the driver chooses its default GPU assignment behavior. + repeated string device_id = 1; + // Optional number of GPUs requested. Mutually exclusive with device_id. + optional uint32 count = 2; } // Driver-owned runtime template consumed by the compute platform. diff --git a/proto/openshell.proto b/proto/openshell.proto index bb2ce6cec..9454febbb 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -240,6 +240,8 @@ message Sandbox { // Desired sandbox configuration provided through the public API. message SandboxSpec { + reserved 9, 10; + // Log level exposed to processes running inside the sandbox. string log_level = 1; // Environment variables injected into the sandbox runtime. @@ -250,12 +252,24 @@ message SandboxSpec { openshell.sandbox.v1.SandboxPolicy policy = 7; // Provider names to attach to this sandbox. repeated string providers = 8; - // Request NVIDIA GPU resources for this sandbox. - bool gpu = 9; - // Optional PCI BDF address (e.g. "0000:2d:00.0") or device index - // (e.g. "0", "1"). When empty with gpu=true, the driver assigns the - // first available GPU. 
- string gpu_device = 10; + // Optional placement requirements for the sandbox workload. + PlacementRequirements placement = 11; +} + +// Public placement requirements for selecting compatible compute resources. +message PlacementRequirements { + // Request GPU resources for this sandbox. Presence indicates a GPU request. + GPUSpec gpu = 1; +} + +// Public GPU placement details. Device identifiers are interpreted by the +// selected compute driver. +message GPUSpec { + // Optional driver-native device identifiers. Mutually exclusive with count. + // Empty means the driver chooses its default GPU assignment behavior. + repeated string device_id = 1; + // Optional number of GPUs requested. Mutually exclusive with device_id. + optional uint32 count = 2; } // Public sandbox template mapped onto compute-driver template inputs. diff --git a/python/openshell/_proto/__init__.py b/python/openshell/_proto/__init__.py index 3ace22421..0763b7b53 100644 --- a/python/openshell/_proto/__init__.py +++ b/python/openshell/_proto/__init__.py @@ -2,7 +2,13 @@ # Sandbox messages and phase enums moved into openshell.proto. Keep aliases on # datamodel_pb2 so existing Python callers and E2E tests continue to work. -for _name in ("Sandbox", "SandboxSpec", "SandboxTemplate"): +for _name in ( + "Sandbox", + "SandboxSpec", + "SandboxTemplate", + "PlacementRequirements", + "GPUSpec", +): if not hasattr(datamodel_pb2, _name): setattr(datamodel_pb2, _name, getattr(openshell_pb2, _name))