Skip to content

Commit f3f3da3

Browse files
committed
refactor(server): normalize compute driver config acquisition
Signed-off-by: Evan Lezar <elezar@nvidia.com>
1 parent d74119a commit f3f3da3

4 files changed

Lines changed: 467 additions & 323 deletions

File tree

crates/openshell-server/src/cli.rs

Lines changed: 92 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ use tracing::{info, warn};
1414
use tracing_subscriber::EnvFilter;
1515

1616
use crate::certgen;
17-
use crate::compute::{DockerComputeConfig, VmComputeConfig};
17+
use crate::compute::driver_config::GuestTlsPaths;
1818
use crate::config_file::{self, ConfigFile, GatewayFileSection};
1919
use crate::defaults::{self, LocalTlsPaths};
20-
use crate::{run_server, tracing_bus::TracingLogBus};
20+
use crate::{ServerStartupConfig, run_server, tracing_bus::TracingLogBus};
2121

2222
/// `OpenShell` gateway process - gRPC and HTTP server with protocol multiplexing.
2323
///
@@ -222,33 +222,29 @@ pub async fn run_cli() -> Result<()> {
222222
}
223223
}
224224

225-
async fn run_from_args(mut args: RunArgs, matches: ArgMatches) -> Result<()> {
225+
fn prepare_server_config(args: &mut RunArgs, matches: &ArgMatches) -> Result<ServerStartupConfig> {
226226
// Load TOML when explicitly requested, or from the default XDG location
227227
// when that file exists. Missing default config is not an error: runtime
228228
// defaults and OPENSHELL_* env vars are enough for package-managed starts.
229-
let config_path = resolve_config_path(&args)?;
229+
let config_path = resolve_config_path(args)?;
230230
let file: Option<ConfigFile> = if let Some(path) = config_path {
231231
Some(config_file::load(&path).map_err(|e| miette::miette!("{e}"))?)
232232
} else {
233233
None
234234
};
235235
if let Some(file) = file.as_ref() {
236-
merge_file_into_args(&mut args, &file.openshell.gateway, &matches);
236+
merge_file_into_args(args, &file.openshell.gateway, matches);
237237
}
238238

239-
let local_tls = apply_runtime_defaults(&mut args)?;
239+
let local_tls = apply_runtime_defaults(args)?;
240+
let guest_tls = local_tls.as_ref().map(GuestTlsPaths::from);
240241
let local_jwt = defaults::complete_local_jwt_config()?;
241242

242-
let tracing_log_bus = TracingLogBus::new();
243-
tracing_log_bus.install_subscriber(
244-
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&args.log_level)),
245-
);
246-
247243
let bind = SocketAddr::new(args.bind_address, args.port);
248244

249245
let has_client_ca = args.tls_client_ca.is_some();
250246
let has_oidc = args.oidc_issuer.is_some();
251-
let mtls_auth_enabled = resolve_mtls_auth_enabled(&args, &matches, file.as_ref());
247+
let mtls_auth_enabled = resolve_mtls_auth_enabled(args, matches, file.as_ref());
252248

253249
if args.disable_tls && has_client_ca {
254250
return Err(miette::miette!(
@@ -267,7 +263,7 @@ async fn run_from_args(mut args: RunArgs, matches: ArgMatches) -> Result<()> {
267263
}
268264
if mtls_auth_enabled
269265
&& matches!(
270-
effective_single_driver(&args),
266+
effective_single_driver(args),
271267
Some(ComputeDriverKind::Kubernetes)
272268
)
273269
{
@@ -318,14 +314,14 @@ async fn run_from_args(mut args: RunArgs, matches: ArgMatches) -> Result<()> {
318314
let health_bind = resolve_aux_listener(
319315
args.bind_address,
320316
args.health_port,
321-
&matches,
317+
matches,
322318
"health_port",
323319
|| file_gateway.and_then(|g| g.health_bind_address),
324320
);
325321
let metrics_bind = resolve_aux_listener(
326322
args.bind_address,
327323
args.metrics_port,
328-
&matches,
324+
matches,
329325
"metrics_port",
330326
|| file_gateway.and_then(|g| g.metrics_bind_address),
331327
);
@@ -404,15 +400,31 @@ async fn run_from_args(mut args: RunArgs, matches: ArgMatches) -> Result<()> {
404400
config.gateway_jwt = Some(jwt);
405401
}
406402

407-
let vm_config = build_vm_config(
408-
file.as_ref(),
409-
local_tls.as_ref(),
410-
args.disable_tls,
411-
args.port,
412-
)?;
413-
let docker_config = build_docker_config(file.as_ref(), local_tls.as_ref())?;
403+
Ok(ServerStartupConfig {
404+
config,
405+
config_file: file,
406+
guest_tls,
407+
})
408+
}
409+
410+
async fn run_from_args(mut args: RunArgs, matches: ArgMatches) -> Result<()> {
411+
let prepared = prepare_server_config(&mut args, &matches)?;
412+
413+
let tracing_log_bus = TracingLogBus::new();
414+
tracing_log_bus.install_subscriber(
415+
EnvFilter::try_from_default_env()
416+
.unwrap_or_else(|_| EnvFilter::new(&prepared.config.log_level)),
417+
);
414418

415-
if args.disable_tls {
419+
let has_client_ca = prepared
420+
.config
421+
.tls
422+
.as_ref()
423+
.and_then(|tls| tls.client_ca_path.as_ref())
424+
.is_some();
425+
let has_oidc = prepared.config.oidc.is_some();
426+
427+
if prepared.config.tls.is_none() {
416428
warn!("TLS disabled — listening on plaintext HTTP");
417429
} else {
418430
info!("TLS enabled — listening on encrypted HTTPS");
@@ -421,40 +433,34 @@ async fn run_from_args(mut args: RunArgs, matches: ArgMatches) -> Result<()> {
421433
if has_client_ca {
422434
info!("TLS client certificate verification enabled");
423435
}
424-
if config.mtls_auth.enabled {
436+
if prepared.config.mtls_auth.enabled {
425437
info!("mTLS user authentication enabled");
426438
}
427439
if has_oidc {
428440
info!("OIDC authentication enabled");
429441
}
430-
if config.auth.allow_unauthenticated_users {
442+
if prepared.config.auth.allow_unauthenticated_users {
431443
warn!(
432444
"Unauthenticated user access enabled — only use this for trusted local development or a fully trusted fronting proxy"
433445
);
434446
}
435447

436-
if !config.auth.allow_unauthenticated_users
437-
&& !config.mtls_auth.enabled
448+
if !prepared.config.auth.allow_unauthenticated_users
449+
&& !prepared.config.mtls_auth.enabled
438450
&& !has_oidc
439-
&& config.gateway_jwt.is_none()
451+
&& prepared.config.gateway_jwt.is_none()
440452
{
441453
warn!(
442454
"Neither mTLS user auth nor OIDC nor sandbox JWT auth is configured — \
443455
the gateway has no authentication mechanism"
444456
);
445457
}
446458

447-
info!(bind = %config.bind_address, "Starting OpenShell server");
459+
info!(bind = %prepared.config.bind_address, "Starting OpenShell server");
448460

449-
Box::pin(run_server(
450-
config,
451-
vm_config,
452-
docker_config,
453-
file,
454-
tracing_log_bus,
455-
))
456-
.await
457-
.into_diagnostic()
461+
Box::pin(run_server(prepared, tracing_log_bus))
462+
.await
463+
.into_diagnostic()
458464
}
459465

460466
fn parse_compute_driver(value: &str) -> std::result::Result<ComputeDriverKind, String> {
@@ -691,87 +697,6 @@ fn resolve_mtls_auth_enabled(
691697
is_singleplayer_driver(args)
692698
}
693699

694-
/// Build [`VmComputeConfig`] from the `[openshell.drivers.vm]` table
695-
/// inherited from `[openshell.gateway]`.
696-
fn build_vm_config(
697-
file: Option<&ConfigFile>,
698-
local_tls: Option<&LocalTlsPaths>,
699-
disable_tls: bool,
700-
gateway_port: u16,
701-
) -> Result<VmComputeConfig> {
702-
let mut cfg = if let Some(file) = file {
703-
let merged = config_file::driver_table(
704-
ComputeDriverKind::Vm,
705-
&file.openshell.gateway,
706-
file.openshell.drivers.get("vm"),
707-
);
708-
merged
709-
.try_into::<VmComputeConfig>()
710-
.map_err(|e| miette::miette!("invalid [openshell.drivers.vm] table: {e}"))?
711-
} else {
712-
VmComputeConfig::default()
713-
};
714-
715-
if cfg.state_dir.as_os_str().is_empty() {
716-
cfg.state_dir = VmComputeConfig::default_state_dir();
717-
}
718-
if cfg.grpc_endpoint.trim().is_empty() && (disable_tls || local_tls.is_some()) {
719-
let scheme = if disable_tls { "http" } else { "https" };
720-
cfg.grpc_endpoint = format!("{scheme}://127.0.0.1:{gateway_port}");
721-
}
722-
apply_guest_tls_defaults(
723-
&mut cfg.guest_tls_ca,
724-
&mut cfg.guest_tls_cert,
725-
&mut cfg.guest_tls_key,
726-
local_tls,
727-
);
728-
Ok(cfg)
729-
}
730-
731-
/// Build [`DockerComputeConfig`] using the same inheritance pattern as
732-
/// [`build_vm_config`].
733-
fn build_docker_config(
734-
file: Option<&ConfigFile>,
735-
local_tls: Option<&LocalTlsPaths>,
736-
) -> Result<DockerComputeConfig> {
737-
let mut cfg = if let Some(file) = file {
738-
let merged = config_file::driver_table(
739-
ComputeDriverKind::Docker,
740-
&file.openshell.gateway,
741-
file.openshell.drivers.get("docker"),
742-
);
743-
merged
744-
.try_into::<DockerComputeConfig>()
745-
.map_err(|e| miette::miette!("invalid [openshell.drivers.docker] table: {e}"))?
746-
} else {
747-
DockerComputeConfig::default()
748-
};
749-
apply_guest_tls_defaults(
750-
&mut cfg.guest_tls_ca,
751-
&mut cfg.guest_tls_cert,
752-
&mut cfg.guest_tls_key,
753-
local_tls,
754-
);
755-
Ok(cfg)
756-
}
757-
758-
fn apply_guest_tls_defaults(
759-
ca: &mut Option<PathBuf>,
760-
cert: &mut Option<PathBuf>,
761-
key: &mut Option<PathBuf>,
762-
local_tls: Option<&LocalTlsPaths>,
763-
) {
764-
if ca.is_none()
765-
&& cert.is_none()
766-
&& key.is_none()
767-
&& let Some(paths) = local_tls
768-
{
769-
*ca = Some(paths.ca.clone());
770-
*cert = Some(paths.client_cert.clone());
771-
*key = Some(paths.client_key.clone());
772-
}
773-
}
774-
775700
#[cfg(test)]
776701
mod tests {
777702
use super::{Cli, command};
@@ -1613,6 +1538,54 @@ enable_loopback_service_http = false
16131538
);
16141539
}
16151540

1541+
#[test]
1542+
fn server_config_preparation_ignores_unselected_driver_tables() {
1543+
let _lock = ENV_LOCK
1544+
.lock()
1545+
.unwrap_or_else(std::sync::PoisonError::into_inner);
1546+
let state = tempfile::tempdir().unwrap();
1547+
let local_tls = tempfile::tempdir().unwrap();
1548+
let _g1 = EnvVarGuard::set("XDG_STATE_HOME", state.path().to_str().unwrap());
1549+
let _g2 = EnvVarGuard::set(
1550+
"OPENSHELL_LOCAL_TLS_DIR",
1551+
local_tls.path().to_str().unwrap(),
1552+
);
1553+
let config_path = state.path().join("gateway.toml");
1554+
std::fs::write(
1555+
&config_path,
1556+
r#"
1557+
[openshell.drivers.docker]
1558+
unknown_docker_key = true
1559+
1560+
[openshell.drivers.vm]
1561+
mem_mib = "not-a-number"
1562+
"#,
1563+
)
1564+
.unwrap();
1565+
1566+
let (mut args, matches) = parse_with_args(&[
1567+
"openshell-gateway",
1568+
"--config",
1569+
config_path.to_str().unwrap(),
1570+
"--db-url",
1571+
"sqlite::memory:",
1572+
"--drivers",
1573+
"podman",
1574+
"--disable-tls",
1575+
]);
1576+
1577+
let prepared =
1578+
super::prepare_server_config(&mut args, &matches).expect("server config is prepared");
1579+
1580+
assert_eq!(
1581+
prepared.config.compute_drivers,
1582+
vec![super::ComputeDriverKind::Podman]
1583+
);
1584+
let file = prepared.config_file.expect("config file is preserved");
1585+
assert!(file.openshell.drivers.contains_key("docker"));
1586+
assert!(file.openshell.drivers.contains_key("vm"));
1587+
}
1588+
16161589
#[test]
16171590
fn driver_inherits_shared_image_from_gateway_section() {
16181591
// [openshell.gateway].default_image inherits into the K8s driver
@@ -1659,18 +1632,4 @@ default_image = "k8s-specific:1.0"
16591632
.expect("deserializes");
16601633
assert_eq!(parsed.default_image, "k8s-specific:1.0");
16611634
}
1662-
1663-
#[test]
1664-
fn docker_config_reads_bind_mount_opt_in_from_driver_table() {
1665-
let file = config_file_from_toml(
1666-
r"
1667-
[openshell.drivers.docker]
1668-
enable_bind_mounts = true
1669-
",
1670-
);
1671-
1672-
let cfg = super::build_docker_config(Some(&file), None).expect("docker config");
1673-
1674-
assert!(cfg.enable_bind_mounts);
1675-
}
16761635
}

0 commit comments

Comments
 (0)