diff --git a/bottlecap/Cargo.toml b/bottlecap/Cargo.toml index b77902083..68a9af7bc 100644 --- a/bottlecap/Cargo.toml +++ b/bottlecap/Cargo.toml @@ -132,6 +132,11 @@ fips = [ "rustls/fips", "rustls-native-certs", ] +# Exposes test-only constructors (for example, +# `InvocationProcessorHandle::noop()`) to the upcoming testmode binary. +# Not enabled in `default` or `fips`, so the items it gates do not appear +# in production builds. +test-mode = [] [lints.rust] unexpected_cfgs = { level = "warn", check-cfg = ['cfg(coverage,coverage_nightly)'] } diff --git a/bottlecap/src/bin/bottlecap/main.rs b/bottlecap/src/bin/bottlecap/main.rs index 8ee88bb26..59de10197 100644 --- a/bottlecap/src/bin/bottlecap/main.rs +++ b/bottlecap/src/bin/bottlecap/main.rs @@ -61,21 +61,10 @@ use bottlecap::{ provider::Provider as TagProvider, }, traces::{ - http_client as trace_http_client, propagation::DatadogCompositePropagator, - proxy_aggregator, - proxy_flusher::Flusher as ProxyFlusher, - span_dedup_service, - stats_aggregator::StatsAggregator, - stats_concentrator_service::{StatsConcentratorHandle, StatsConcentratorService}, - stats_flusher, + stats_concentrator_service::StatsConcentratorHandle, stats_generator::StatsGenerator, - stats_processor, trace_agent, trace_aggregator::SendDataBuilderInfo, - trace_aggregator_service::{ - AggregatorHandle as TraceAggregatorHandle, AggregatorService as TraceAggregatorService, - }, - trace_flusher, trace_processor::{self, SendingTraceProcessor}, }, }; @@ -95,7 +84,6 @@ use dogstatsd::{ flusher::{Flusher as MetricsFlusher, FlusherConfig as MetricsFlusherConfig}, metric::{EMPTY_TAGS, SortedTags}, }; -use libdd_trace_obfuscation::obfuscation_config; use reqwest::Client; use std::{collections::hash_map, env, path::Path, str::FromStr, sync::Arc}; use tokio::time::Instant; @@ -356,16 +344,16 @@ async fn extension_loop_active( } }; - let ( - trace_agent_channel, + let bottlecap::startup::TraceAgentPipeline { + trace_tx: trace_agent_channel, trace_flusher, 
trace_processor, stats_flusher, proxy_flusher, - trace_agent_shutdown_token, + shutdown_token: trace_agent_shutdown_token, stats_concentrator, trace_aggregator_handle, - ) = start_trace_agent( + } = bottlecap::startup::start_trace_agent( config, &api_key_factory, &tags_provider, @@ -1092,120 +1080,6 @@ fn start_logs_agent( ) } -#[allow(clippy::type_complexity)] -fn start_trace_agent( - config: &Arc, - api_key_factory: &Arc, - tags_provider: &Arc, - invocation_processor_handle: InvocationProcessorHandle, - appsec_processor: Option>>, - client: &Client, -) -> ( - Sender, - Arc, - Arc, - Arc, - Arc, - tokio_util::sync::CancellationToken, - StatsConcentratorHandle, - TraceAggregatorHandle, -) { - // Build one shared hyper-based HTTP client for trace and stats flushing. - // This client type is required by libdd_trace_utils for SendData::send(). - let trace_http_client = trace_http_client::create_client( - config.proxy_https.as_ref(), - config.tls_cert_file.as_ref(), - config.skip_ssl_validation, - ) - .expect("Failed to create trace HTTP client"); - - // Stats - let (stats_concentrator_service, stats_concentrator_handle) = - StatsConcentratorService::new(Arc::clone(config)); - tokio::spawn(stats_concentrator_service.run()); - let stats_aggregator: Arc> = Arc::new(TokioMutex::new( - StatsAggregator::new_with_concentrator(stats_concentrator_handle.clone()), - )); - let stats_flusher = Arc::new(stats_flusher::StatsFlusher::new( - api_key_factory.clone(), - stats_aggregator.clone(), - Arc::clone(config), - trace_http_client.clone(), - )); - - let stats_processor = Arc::new(stats_processor::ServerlessStatsProcessor {}); - - // Traces - let (trace_aggregator_service, trace_aggregator_handle) = TraceAggregatorService::default(); - tokio::spawn(trace_aggregator_service.run()); - - let trace_flusher = Arc::new(trace_flusher::TraceFlusher::new( - trace_aggregator_handle.clone(), - config.clone(), - api_key_factory.clone(), - trace_http_client, - )); - - let obfuscation_config = 
obfuscation_config::ObfuscationConfig { - tag_replace_rules: config.apm_replace_tags.clone(), - http_remove_path_digits: config.apm_config_obfuscation_http_remove_paths_with_digits, - http_remove_query_string: config.apm_config_obfuscation_http_remove_query_string, - obfuscate_memcached: false, - obfuscation_redis_enabled: false, - obfuscation_redis_remove_all_args: false, - }; - - let trace_processor = Arc::new(trace_processor::ServerlessTraceProcessor { - obfuscation_config: Arc::new(obfuscation_config), - }); - - let (span_dedup_service, span_dedup_handle) = span_dedup_service::DedupService::new(); - tokio::spawn(span_dedup_service.run()); - - // Proxy - let proxy_aggregator = Arc::new(TokioMutex::new(proxy_aggregator::Aggregator::default())); - let proxy_flusher = Arc::new(ProxyFlusher::new( - api_key_factory.clone(), - Arc::clone(&proxy_aggregator), - Arc::clone(tags_provider), - Arc::clone(config), - client.clone(), - )); - - let trace_agent = trace_agent::TraceAgent::new( - Arc::clone(config), - trace_aggregator_handle.clone(), - trace_processor.clone(), - stats_aggregator, - stats_processor, - proxy_aggregator, - invocation_processor_handle, - appsec_processor, - Arc::clone(tags_provider), - stats_concentrator_handle.clone(), - span_dedup_handle, - ); - let trace_agent_channel = trace_agent.get_sender_copy(); - let shutdown_token = trace_agent.shutdown_token(); - - tokio::spawn(async move { - if let Err(e) = trace_agent.start().await { - error!("Error starting trace agent: {e:?}"); - } - }); - - ( - trace_agent_channel, - trace_flusher, - trace_processor, - stats_flusher, - proxy_flusher, - shutdown_token, - stats_concentrator_handle, - trace_aggregator_handle, - ) -} - async fn start_dogstatsd( tags_provider: Arc, api_key_factory: Arc, diff --git a/bottlecap/src/lib.rs b/bottlecap/src/lib.rs index df94fd246..97784956c 100644 --- a/bottlecap/src/lib.rs +++ b/bottlecap/src/lib.rs @@ -35,6 +35,7 @@ pub mod otlp; pub mod proc; pub mod proxy; pub mod secrets; 
+pub mod startup; pub mod tags; pub mod traces; diff --git a/bottlecap/src/lifecycle/invocation/processor_service.rs b/bottlecap/src/lifecycle/invocation/processor_service.rs index a41a95b26..1f48f1960 100644 --- a/bottlecap/src/lifecycle/invocation/processor_service.rs +++ b/bottlecap/src/lifecycle/invocation/processor_service.rs @@ -135,6 +135,68 @@ pub struct InvocationProcessorHandle { } impl InvocationProcessorHandle { + /// Returns a handle backed by a background task that acknowledges every + /// command with a sensible default. Use this for callers that need an + /// `InvocationProcessorHandle` (for example, to reuse `TraceAgent` or + /// `handle_traces`) but have no Lambda lifecycle state to drive. + /// + /// The match below is exhaustive: adding a new `ProcessorCommand` variant + /// will fail to compile here, so the noop behavior for it must be + /// decided explicitly. A response-carrying variant placed in the + /// fire-and-forget arm would silently drop its sender, causing the + /// caller to receive `ProcessorError::ChannelReceive` instead of the + /// intended default. + /// + /// Compiled only under unit tests and when the `test-mode` feature is + /// enabled, so this constructor does not exist in production builds. + /// + /// Must be called from within an active Tokio runtime: the constructor + /// uses `tokio::spawn` to start the draining task, which panics if no + /// runtime is registered for the current thread. + #[cfg(any(test, feature = "test-mode"))] + #[must_use] + pub fn noop() -> Self { + let (sender, mut receiver) = mpsc::channel::(1000); + tokio::spawn(async move { + while let Some(command) = receiver.recv().await { + match command { + // Request-response commands: reply with a default so the + // caller doesn't block on the oneshot forever. + ProcessorCommand::GetReparentingInfo { response } => { + let _ = response.send(Ok(std::collections::VecDeque::new())); + } + ProcessorCommand::UpdateReparenting { response, .. 
} => { + let _ = response.send(Ok(Vec::new())); + } + ProcessorCommand::SetColdStartSpanTraceId { response, .. } => { + let _ = response.send(Ok(None)); + } + ProcessorCommand::PlatformRuntimeDone { response, .. } + | ProcessorCommand::PlatformReport { response, .. } => { + let _ = response.send(()); + } + // Fire-and-forget commands: drop silently. + ProcessorCommand::InvokeEvent { .. } + | ProcessorCommand::PlatformInitStart { .. } + | ProcessorCommand::PlatformInitReport { .. } + | ProcessorCommand::PlatformRestoreStart { .. } + | ProcessorCommand::PlatformRestoreReport { .. } + | ProcessorCommand::PlatformStart { .. } + | ProcessorCommand::UniversalInstrumentationStart { .. } + | ProcessorCommand::UniversalInstrumentationEnd { .. } + | ProcessorCommand::AddReparenting { .. } + | ProcessorCommand::AddTracerSpan { .. } + | ProcessorCommand::ForwardDurableContext { .. } + | ProcessorCommand::OnOutOfMemoryError { .. } + | ProcessorCommand::OnShutdownEvent + | ProcessorCommand::SendCtxSpans { .. 
} + | ProcessorCommand::Shutdown => {} + } + } + }); + InvocationProcessorHandle { sender } + } + pub async fn on_invoke_event( &self, request_id: String, @@ -657,3 +719,221 @@ impl InvocationProcessorService { debug!("InvocationProcessorService stopped"); } } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn noop_request_response_methods_return_defaults() { + let handle = InvocationProcessorHandle::noop(); + + let info = handle + .get_reparenting_info() + .await + .expect("noop get_reparenting_info"); + assert!(info.is_empty()); + + let contexts = handle + .update_reparenting(std::collections::VecDeque::new()) + .await + .expect("noop update_reparenting"); + assert!(contexts.is_empty()); + + let cold_start = handle + .set_cold_start_span_trace_id(42) + .await + .expect("noop set_cold_start_span_trace_id"); + assert!(cold_start.is_none()); + } + + #[tokio::test] + async fn noop_fire_and_forget_commands_do_not_panic() { + let handle = InvocationProcessorHandle::noop(); + + handle + .on_invoke_event("rid".to_string()) + .await + .expect("noop on_invoke_event"); + handle + .on_shutdown_event() + .await + .expect("noop on_shutdown_event"); + handle + .on_out_of_memory_error(0) + .await + .expect("noop on_out_of_memory_error"); + } + + #[tokio::test] + async fn noop_platform_runtime_done_and_report_respond_without_blocking() { + use crate::{ + LAMBDA_RUNTIME_SLUG, config, + extension::telemetry::events::OnDemandReportMetrics, + traces::{ + stats_concentrator_service::StatsConcentratorService, + stats_generator::StatsGenerator, trace_processor::ServerlessTraceProcessor, + }, + }; + use libdd_trace_obfuscation::obfuscation_config::ObfuscationConfig; + use std::collections::HashMap; + + let config = Arc::new(config::Config::default()); + let (svc, concentrator) = StatsConcentratorService::new(Arc::clone(&config)); + tokio::spawn(svc.run()); + let trace_sender = Arc::new(SendingTraceProcessor { + appsec: None, + processor: 
Arc::new(ServerlessTraceProcessor { + obfuscation_config: Arc::new(ObfuscationConfig::new().expect("ObfuscationConfig")), + }), + trace_tx: tokio::sync::mpsc::channel(1).0, + stats_generator: Arc::new(StatsGenerator::new(concentrator)), + }); + let tags_provider = Arc::new(provider::Provider::new( + Arc::clone(&config), + LAMBDA_RUNTIME_SLUG.to_string(), + &HashMap::from([("function_arn".to_string(), "test-arn".to_string())]), + )); + + let handle = InvocationProcessorHandle::noop(); + + handle + .on_platform_runtime_done( + "rid".to_string(), + RuntimeDoneMetrics { + duration_ms: 0.0, + produced_bytes: None, + }, + Status::Success, + None, + Arc::clone(&tags_provider), + Arc::clone(&trace_sender), + 0, + ) + .await + .expect("noop on_platform_runtime_done"); + + handle + .on_platform_report( + "rid", + ReportMetrics::OnDemand(OnDemandReportMetrics { + duration_ms: 0.0, + billed_duration_ms: 0, + memory_size_mb: 0, + max_memory_used_mb: 0, + init_duration_ms: None, + restore_duration_ms: None, + }), + 0, + Status::Success, + &None, + &None, + tags_provider, + trace_sender, + ) + .await + .expect("noop on_platform_report"); + } + + /// Guards the request-response arms of `noop()`. A `response`-carrying + /// variant moved into the fire-and-forget arm drops its sender, so the + /// call fails at once with `ProcessorError::ChannelReceive`. The explicit + /// timeout covers the other failure mode, a sender kept alive but never + /// answered, which would otherwise hang the test suite. 
+ #[tokio::test] + async fn noop_request_response_variants_complete_within_timeout() { + use crate::{ + LAMBDA_RUNTIME_SLUG, config, + extension::telemetry::events::OnDemandReportMetrics, + traces::{ + stats_concentrator_service::StatsConcentratorService, + stats_generator::StatsGenerator, trace_processor::ServerlessTraceProcessor, + }, + }; + use libdd_trace_obfuscation::obfuscation_config::ObfuscationConfig; + use std::collections::HashMap; + use tokio::time::{Duration, timeout}; + + let timeout_dur = Duration::from_millis(500); + + let config = Arc::new(config::Config::default()); + let (svc, concentrator) = StatsConcentratorService::new(Arc::clone(&config)); + tokio::spawn(svc.run()); + let trace_sender = Arc::new(SendingTraceProcessor { + appsec: None, + processor: Arc::new(ServerlessTraceProcessor { + obfuscation_config: Arc::new(ObfuscationConfig::new().expect("ObfuscationConfig")), + }), + trace_tx: tokio::sync::mpsc::channel(1).0, + stats_generator: Arc::new(StatsGenerator::new(concentrator)), + }); + let tags_provider = Arc::new(provider::Provider::new( + Arc::clone(&config), + LAMBDA_RUNTIME_SLUG.to_string(), + &HashMap::from([("function_arn".to_string(), "test-arn".to_string())]), + )); + + let handle = InvocationProcessorHandle::noop(); + + timeout(timeout_dur, handle.get_reparenting_info()) + .await + .expect("get_reparenting_info timed out") + .expect("get_reparenting_info"); + + timeout( + timeout_dur, + handle.update_reparenting(std::collections::VecDeque::new()), + ) + .await + .expect("update_reparenting timed out") + .expect("update_reparenting"); + + timeout(timeout_dur, handle.set_cold_start_span_trace_id(42)) + .await + .expect("set_cold_start_span_trace_id timed out") + .expect("set_cold_start_span_trace_id"); + + timeout( + timeout_dur, + handle.on_platform_runtime_done( + "rid".to_string(), + RuntimeDoneMetrics { + duration_ms: 0.0, + produced_bytes: None, + }, + Status::Success, + None, + Arc::clone(&tags_provider), + 
Arc::clone(&trace_sender), + 0, + ), + ) + .await + .expect("on_platform_runtime_done timed out") + .expect("on_platform_runtime_done"); + + timeout( + timeout_dur, + handle.on_platform_report( + "rid", + ReportMetrics::OnDemand(OnDemandReportMetrics { + duration_ms: 0.0, + billed_duration_ms: 0, + memory_size_mb: 0, + max_memory_used_mb: 0, + init_duration_ms: None, + restore_duration_ms: None, + }), + 0, + Status::Success, + &None, + &None, + tags_provider, + trace_sender, + ), + ) + .await + .expect("on_platform_report timed out") + .expect("on_platform_report"); + } +} diff --git a/bottlecap/src/startup.rs b/bottlecap/src/startup.rs new file mode 100644 index 000000000..88365e941 --- /dev/null +++ b/bottlecap/src/startup.rs @@ -0,0 +1,201 @@ +// Copyright 2023-Present Datadog, Inc. https://www.datadoghq.com/ +// SPDX-License-Identifier: Apache-2.0 + +use std::sync::Arc; + +use dogstatsd::api_key::ApiKeyFactory; +use libdd_trace_obfuscation::obfuscation_config; +use tokio::sync::{Mutex as TokioMutex, mpsc::Sender}; +use tokio_util::sync::CancellationToken; +use tracing::error; + +use crate::appsec::processor::Processor as AppSecProcessor; +use crate::config::Config; +use crate::lifecycle::invocation::processor_service::InvocationProcessorHandle; +use crate::tags::provider::Provider as TagProvider; +use crate::traces::{ + http_client as trace_http_client, proxy_aggregator, + proxy_flusher::Flusher as ProxyFlusher, + span_dedup_service, + stats_aggregator::StatsAggregator, + stats_concentrator_service::{self, StatsConcentratorHandle}, + stats_flusher, stats_processor, trace_agent, + trace_aggregator::SendDataBuilderInfo, + trace_aggregator_service::{self, AggregatorHandle as TraceAggregatorHandle}, + trace_flusher, trace_processor, +}; + +/// Handles produced by [`build_trace_agent`] / [`start_trace_agent`]. 
Holds +/// the trace-channel sender, the per-domain flushers, the shutdown token, and +/// the aggregator/concentrator handles the caller needs to drive flushes and +/// shut the pipeline down. +pub struct TraceAgentPipeline { + pub trace_tx: Sender, + pub trace_flusher: Arc, + pub trace_processor: Arc, + pub stats_flusher: Arc, + pub proxy_flusher: Arc, + pub shutdown_token: CancellationToken, + pub stats_concentrator: StatsConcentratorHandle, + pub trace_aggregator_handle: TraceAggregatorHandle, +} + +/// Builds the full trace-processing pipeline (trace + stats + proxy +/// aggregators, services, flushers) and the [`trace_agent::TraceAgent`] that +/// owns the HTTP listener. Spawns the aggregator/concentrator/dedup services +/// onto the current tokio runtime; `TraceAgent::new` additionally spawns a +/// trace-payload drain task. Does **not** spawn the `TraceAgent` itself. +/// The caller owns `trace_agent` and is responsible for spawning +/// `trace_agent.start()`, optionally after further configuring it (for +/// example, via [`trace_agent::TraceAgent::with_router_extension`]). +/// +/// Note: the four background tasks started during this call (aggregator, +/// concentrator, dedup, and the trace-payload drain task inside +/// `TraceAgent::new`) have no external shutdown signal; they run until +/// their command channels are dropped. Callers that abandon the returned +/// `TraceAgent` without either spawning it or dropping the pipeline handles +/// will leak those background tasks for the lifetime of the process. +/// +/// Most callers want [`start_trace_agent`] instead, which handles the spawn. +pub fn build_trace_agent( + config: &Arc, + api_key_factory: &Arc, + tags_provider: &Arc, + invocation_processor_handle: InvocationProcessorHandle, + appsec_processor: Option>>, + client: &reqwest::Client, +) -> (trace_agent::TraceAgent, TraceAgentPipeline) { + // Build one shared hyper-based HTTP client for trace and stats flushing. 
+ // This client type is required by libdd_trace_utils for SendData::send(). + let trace_http_client = trace_http_client::create_client( + config.proxy_https.as_ref(), + config.tls_cert_file.as_ref(), + config.skip_ssl_validation, + ) + .expect("Failed to create trace HTTP client"); + + // Stats + let (stats_concentrator_service, stats_concentrator_handle) = + stats_concentrator_service::StatsConcentratorService::new(Arc::clone(config)); + tokio::spawn(stats_concentrator_service.run()); + let stats_aggregator: Arc> = Arc::new(TokioMutex::new( + StatsAggregator::new_with_concentrator(stats_concentrator_handle.clone()), + )); + let stats_flusher = Arc::new(stats_flusher::StatsFlusher::new( + api_key_factory.clone(), + stats_aggregator.clone(), + Arc::clone(config), + trace_http_client.clone(), + )); + + let stats_processor = Arc::new(stats_processor::ServerlessStatsProcessor {}); + + // Traces + let (trace_aggregator_service, trace_aggregator_handle) = + trace_aggregator_service::AggregatorService::default(); + tokio::spawn(trace_aggregator_service.run()); + + let trace_flusher = Arc::new(trace_flusher::TraceFlusher::new( + trace_aggregator_handle.clone(), + config.clone(), + api_key_factory.clone(), + trace_http_client, + )); + + let obfuscation_config = obfuscation_config::ObfuscationConfig { + tag_replace_rules: config.apm_replace_tags.clone(), + http_remove_path_digits: config.apm_config_obfuscation_http_remove_paths_with_digits, + http_remove_query_string: config.apm_config_obfuscation_http_remove_query_string, + obfuscate_memcached: false, + obfuscation_redis_enabled: false, + obfuscation_redis_remove_all_args: false, + }; + + let trace_processor = Arc::new(trace_processor::ServerlessTraceProcessor { + obfuscation_config: Arc::new(obfuscation_config), + }); + + let (span_dedup_service, span_dedup_handle) = span_dedup_service::DedupService::new(); + tokio::spawn(span_dedup_service.run()); + + // Proxy + let proxy_aggregator = 
Arc::new(TokioMutex::new(proxy_aggregator::Aggregator::default())); + let proxy_flusher = Arc::new(ProxyFlusher::new( + api_key_factory.clone(), + Arc::clone(&proxy_aggregator), + Arc::clone(tags_provider), + Arc::clone(config), + client.clone(), + )); + + let trace_agent = trace_agent::TraceAgent::new( + Arc::clone(config), + trace_aggregator_handle.clone(), + trace_processor.clone(), + stats_aggregator, + stats_processor, + proxy_aggregator, + invocation_processor_handle, + appsec_processor, + Arc::clone(tags_provider), + stats_concentrator_handle.clone(), + span_dedup_handle, + ); + let pipeline = TraceAgentPipeline { + trace_tx: trace_agent.get_sender_copy(), + trace_flusher, + trace_processor, + stats_flusher, + proxy_flusher, + shutdown_token: trace_agent.shutdown_token(), + stats_concentrator: stats_concentrator_handle, + trace_aggregator_handle, + }; + + (trace_agent, pipeline) +} + +/// Builds the trace-processing pipeline with [`build_trace_agent`] and spawns +/// the [`trace_agent::TraceAgent`] HTTP listener onto the current tokio +/// runtime. Convenience entry point for callers that do not need to +/// further configure the `TraceAgent` before spawning it. +/// +/// Errors from `TraceAgent::start` (TCP bind failures, router-extension +/// failures, axum serve errors) are logged and discarded; this preserves +/// the pre-extraction behavior from `main.rs` and means the surrounding +/// pipeline keeps running with a dead trace channel. Callers that need to +/// react to startup errors should use [`build_trace_agent`] and spawn the +/// agent themselves. +/// +/// Callers that need to customize the `TraceAgent` (for example via +/// [`trace_agent::TraceAgent::with_router_extension`]) should use +/// [`build_trace_agent`] directly and spawn the returned `TraceAgent` +/// themselves after applying the extra configuration. 
+pub fn start_trace_agent( + config: &Arc, + api_key_factory: &Arc, + tags_provider: &Arc, + invocation_processor_handle: InvocationProcessorHandle, + appsec_processor: Option>>, + client: &reqwest::Client, +) -> TraceAgentPipeline { + let (trace_agent, pipeline) = build_trace_agent( + config, + api_key_factory, + tags_provider, + invocation_processor_handle, + appsec_processor, + client, + ); + + // Log-only error handling preserved from the pre-extraction code in + // main.rs. See the doc comment above for callers that need reactive + // error handling. + tokio::spawn(async move { + if let Err(e) = trace_agent.start().await { + error!("Error starting trace agent: {e:?}"); + } + }); + + pipeline +} diff --git a/bottlecap/src/traces/trace_agent.rs b/bottlecap/src/traces/trace_agent.rs index fb7e8343a..c2205b027 100644 --- a/bottlecap/src/traces/trace_agent.rs +++ b/bottlecap/src/traces/trace_agent.rs @@ -105,6 +105,26 @@ pub struct ProxyState { pub proxy_aggregator: Arc>, } +/// Extension seam for the [`TraceAgent`] HTTP router. Implementors receive +/// the fully-assembled production router and return it with any additional +/// routes merged in. Used to attach optional routes (for example, a +/// deterministic drain endpoint) without adding a dedicated field per +/// route to [`TraceAgent`]. +/// +/// Returning `Err` propagates out of [`TraceAgent::start`], aborting the +/// HTTP listener task. Note that the production convenience entry point +/// [`crate::startup::start_trace_agent`] spawns `start` and only logs its +/// error; the surrounding pipeline does not observe the failure. Callers +/// that need to react to startup errors must use +/// [`crate::startup::build_trace_agent`] and spawn the agent themselves. +/// +/// Implementors that carry state must call `.with_state(...)` on their +/// sub-router before merging, because `Router::merge` requires both +/// routers to share the same state type. 
+pub trait RouterExtension: Send + Sync { + fn extend(&self, router: Router) -> Result>; +} + pub struct TraceAgent { pub config: Arc, pub trace_processor: Arc, @@ -118,6 +138,9 @@ pub struct TraceAgent { tx: Sender, stats_concentrator: StatsConcentratorHandle, span_deduper: DedupHandle, + /// `None` when the caller wants no extra routes. See + /// [`TraceAgent::with_router_extension`]. + router_extension: Option>, } #[derive(Clone, Copy)] @@ -170,9 +193,20 @@ impl TraceAgent { shutdown_token: CancellationToken::new(), stats_concentrator, span_deduper, + router_extension: None, } } + /// Attaches a [`RouterExtension`] that will be applied to the + /// fully-assembled production router before the outer fallback and + /// body-limit layers. Without this call, the agent exposes only its + /// production routes. + #[must_use] + pub fn with_router_extension(mut self, extension: Arc) -> Self { + self.router_extension = Some(extension); + self + } + #[allow(clippy::cast_possible_truncation)] pub async fn start(&self) -> Result<(), Box> { let now = Instant::now(); @@ -192,7 +226,7 @@ impl TraceAgent { } }); - let router = self.make_router(stats_tx); + let router = self.make_router(stats_tx)?; let port = u16::try_from(TRACE_AGENT_PORT).expect("TRACE_AGENT_PORT is too large"); let socket = SocketAddr::from(([127, 0, 0, 1], port)); @@ -212,7 +246,10 @@ impl TraceAgent { Ok(()) } - fn make_router(&self, stats_tx: Sender) -> Router { + fn make_router( + &self, + stats_tx: Sender, + ) -> Result> { let stats_generator = Arc::new(StatsGenerator::new(self.stats_concentrator.clone())); let trace_state = TraceState { config: Arc::clone(&self.config), @@ -281,14 +318,20 @@ impl TraceAgent { let info_router = Router::new().route(INFO_ENDPOINT_PATH, any(Self::info)); - Router::new() + let mut router = Router::new() .merge(trace_router) .merge(stats_router) .merge(proxy_router) - .merge(info_router) + .merge(info_router); + + if let Some(extension) = &self.router_extension { + router = 
extension.extend(router)?; + } + + Ok(router .fallback(handler_not_found) // Disable the default body limit so we can use our own limit - .layer(DefaultBodyLimit::disable()) + .layer(DefaultBodyLimit::disable())) } async fn graceful_shutdown(shutdown_token: CancellationToken) { @@ -765,3 +808,149 @@ fn success_response(message: &str) -> Response { debug!("{}", message); (StatusCode::OK, json!({"rate_by_service": {}}).to_string()).into_response() } + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + use crate::{ + LAMBDA_RUNTIME_SLUG, config, + traces::{ + span_dedup_service::DedupService, stats_concentrator_service::StatsConcentratorService, + trace_aggregator_service::AggregatorService, + }, + }; + use axum::body::Body; + use axum::http::Request; + use libdd_trace_obfuscation::obfuscation_config::ObfuscationConfig; + use std::collections::HashMap; + use std::sync::atomic::{AtomicUsize, Ordering}; + use tower::ServiceExt; + + /// Test extension that records how many times its route was hit. + struct SpyExtension { + hits: Arc, + } + + impl RouterExtension for SpyExtension { + fn extend(&self, router: Router) -> Result> { + Ok(router.merge( + Router::new() + .route( + "/spy", + post(|State(hits): State>| async move { + hits.fetch_add(1, Ordering::SeqCst); + StatusCode::NO_CONTENT + }), + ) + .with_state(Arc::clone(&self.hits)), + )) + } + } + + /// Test extension that always fails, used to assert that `make_router` + /// surfaces extension errors instead of swallowing them. 
+ struct FailingExtension; + + impl RouterExtension for FailingExtension { + fn extend(&self, _router: Router) -> Result> { + Err("extension failed during make_router".into()) + } + } + + fn build_test_agent() -> TraceAgent { + let config = Arc::new(config::Config::default()); + let (concentrator_svc, concentrator) = StatsConcentratorService::new(Arc::clone(&config)); + tokio::spawn(concentrator_svc.run()); + let (aggregator_svc, aggregator_handle) = AggregatorService::default(); + tokio::spawn(aggregator_svc.run()); + let (dedup_svc, dedup_handle) = DedupService::new(); + tokio::spawn(dedup_svc.run()); + let trace_processor = Arc::new(trace_processor::ServerlessTraceProcessor { + obfuscation_config: Arc::new(ObfuscationConfig::new().expect("ObfuscationConfig")), + }); + let stats_aggregator = Arc::new(Mutex::new( + stats_aggregator::StatsAggregator::new_with_concentrator(concentrator.clone()), + )); + let proxy_aggregator = Arc::new(Mutex::new(proxy_aggregator::Aggregator::default())); + let tags_provider = Arc::new(provider::Provider::new( + Arc::clone(&config), + LAMBDA_RUNTIME_SLUG.to_string(), + &HashMap::from([("function_arn".to_string(), "test-arn".to_string())]), + )); + + TraceAgent::new( + config, + aggregator_handle, + trace_processor, + stats_aggregator, + Arc::new(stats_processor::ServerlessStatsProcessor {}), + proxy_aggregator, + InvocationProcessorHandle::noop(), + None, + tags_provider, + concentrator, + dedup_handle, + ) + } + + #[tokio::test] + async fn with_router_extension_adds_reachable_route_to_make_router() { + let hits = Arc::new(AtomicUsize::new(0)); + let agent = build_test_agent().with_router_extension(Arc::new(SpyExtension { + hits: Arc::clone(&hits), + })); + let (stats_tx, _stats_rx) = mpsc::channel::(1); + let router = agent.make_router(stats_tx).expect("make_router"); + + let response = router + .oneshot( + Request::builder() + .method("POST") + .uri("/spy") + .body(Body::empty()) + .expect("build request"), + ) + .await + 
.expect("route response"); + + assert_eq!(response.status(), StatusCode::NO_CONTENT); + assert_eq!(hits.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn make_router_propagates_extension_error() { + let agent = build_test_agent().with_router_extension(Arc::new(FailingExtension)); + let (stats_tx, _stats_rx) = mpsc::channel::(1); + + let err = agent + .make_router(stats_tx) + .expect_err("make_router should surface extension error"); + + assert!( + err.to_string() + .contains("extension failed during make_router"), + "unexpected error: {err}" + ); + } + + #[tokio::test] + async fn make_router_returns_404_for_extension_route_when_none_attached() { + let agent = build_test_agent(); + let (stats_tx, _stats_rx) = mpsc::channel::(1); + let router = agent.make_router(stats_tx).expect("make_router"); + + let response = router + .oneshot( + Request::builder() + .method("POST") + .uri("/spy") + .body(Body::empty()) + .expect("build request"), + ) + .await + .expect("route response"); + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + } +}