Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@ use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_functions_aggregate::array_agg::array_agg_udaf;
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::expressions::{Column, col};
use datafusion_physical_expr_common::metrics::MetricsSet;
use datafusion_physical_expr_common::sort_expr::LexOrdering;
use datafusion_physical_plan::aggregates::{
AggregateExec, AggregateMode, PhysicalGroupBy,
};
use datafusion_physical_plan::metrics::MetricValue;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use futures::StreamExt;

Expand All @@ -69,16 +71,18 @@ async fn test_sort_with_limited_memory() -> Result<()> {

// Basic test with a lot of groups that cannot all fit in memory and 1 record batch
// from each spill file is too much memory
let spill_count = run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs {
let metrics = run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs {
pool_size,
task_ctx: Arc::new(task_ctx),
number_of_record_batches: 100,
get_size_of_record_batch_to_generate: Box::pin(move |_| record_batch_size),
memory_behavior: Default::default(),
assert_all_output_batches_roughly_match_batch_size_conf: true,
})
.await?;

let total_spill_files_size = spill_count * record_batch_size;
let total_spill_files_size =
metrics.spill_count().unwrap_or_default() * record_batch_size;
assert!(
total_spill_files_size > pool_size,
"Total spill files size {total_spill_files_size} should be greater than pool size {pool_size}",
Expand Down Expand Up @@ -119,6 +123,7 @@ async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch() ->
}
}),
memory_behavior: Default::default(),
assert_all_output_batches_roughly_match_batch_size_conf: true,
})
.await?;

Expand Down Expand Up @@ -157,6 +162,7 @@ async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_c
}
}),
memory_behavior: MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(10),
assert_all_output_batches_roughly_match_batch_size_conf: true,
})
.await?;

Expand Down Expand Up @@ -195,6 +201,7 @@ async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_t
}
}),
memory_behavior: MemoryBehavior::TakeAllMemoryAtTheBeginning,
assert_all_output_batches_roughly_match_batch_size_conf: true,
})
.await?;

Expand Down Expand Up @@ -227,6 +234,7 @@ async fn test_sort_with_limited_memory_and_large_record_batch() -> Result<()> {
number_of_record_batches: 100,
get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 6),
memory_behavior: Default::default(),
assert_all_output_batches_roughly_match_batch_size_conf: true,
})
.await?;

Expand All @@ -252,21 +260,33 @@ async fn test_sort_with_limited_memory_and_oversized_record_batch() -> Result<()
))
};

let number_of_record_batches = 100;

// Each spilled run's largest batch is so big that two merge streams cannot be
// reserved at once even at the smallest read-buffer size (`2 * (2 * batch) >
// pool`), yet a single stream still fits (`2 * batch < pool`). Reducing the
// buffer size therefore cannot help, the multi-level merge has to re-spill a
// run with a smaller batch size to make progress instead of failing with
// `ResourcesExhausted`.
run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs {
let metrics = run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs {
pool_size,
task_ctx: Arc::new(task_ctx),
number_of_record_batches: 100,
number_of_record_batches,
get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 3),
memory_behavior: Default::default(),

assert_all_output_batches_roughly_match_batch_size_conf: false,
})
.await?;

let output_batches = get_output_batches_from_metrics(&metrics);

// minimum 2 batches more
assert!(
output_batches >= number_of_record_batches + 2,
"output_batches {output_batches} should be greater than number_of_record_batches ({number_of_record_batches}) + 2"
);

Ok(())
}

Expand All @@ -277,6 +297,9 @@ struct RunTestWithLimitedMemoryArgs {
get_size_of_record_batch_to_generate:
Pin<Box<dyn Fn(usize) -> usize + Send + 'static>>,
memory_behavior: MemoryBehavior,

/// When true we would `assert_eq(the number of output_rows metric / output_batches metric == task_ctx.batch_size)`
assert_all_output_batches_roughly_match_batch_size_conf: bool,
}

#[derive(Default)]
Expand All @@ -289,7 +312,7 @@ enum MemoryBehavior {

async fn run_sort_test_with_limited_memory(
mut args: RunTestWithLimitedMemoryArgs,
) -> Result<usize> {
) -> Result<MetricsSet> {
let get_size_of_record_batch_to_generate = std::mem::replace(
&mut args.get_size_of_record_batch_to_generate,
Box::pin(move |_| unreachable!("should not be called after take")),
Expand Down Expand Up @@ -349,7 +372,23 @@ async fn run_sort_test_with_limited_memory(

let result = sort_exec.execute(0, Arc::clone(&args.task_ctx))?;

run_test(args, sort_exec, result).await
let number_of_record_batches = args.number_of_record_batches;
let assert_output_batch_size =
args.assert_all_output_batches_roughly_match_batch_size_conf;

let metrics = run_test(args, sort_exec, result).await?;

assert_baseline_metrics_for_non_empty_output(
&metrics,
number_of_record_batches * record_batch_size as usize,
if assert_output_batch_size {
Some(record_batch_size as usize)
} else {
None
},
);

Ok(metrics)
}

fn grow_memory_as_much_as_possible(
Expand Down Expand Up @@ -383,17 +422,19 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory() -> Result<()

// Basic test with a lot of groups that cannot all fit in memory and 1 record batch
// from each spill file is too much memory
let spill_count =
let metrics =
run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs {
pool_size,
task_ctx: Arc::new(task_ctx),
number_of_record_batches: 100,
get_size_of_record_batch_to_generate: Box::pin(move |_| record_batch_size),
memory_behavior: Default::default(),
assert_all_output_batches_roughly_match_batch_size_conf: true,
})
.await?;

let total_spill_files_size = spill_count * record_batch_size;
let total_spill_files_size =
metrics.spill_count().unwrap_or_default() * record_batch_size;
assert!(
total_spill_files_size > pool_size,
"Total spill files size {total_spill_files_size} should be greater than pool size {pool_size}",
Expand Down Expand Up @@ -430,6 +471,7 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_
}
}),
memory_behavior: Default::default(),
assert_all_output_batches_roughly_match_batch_size_conf: true,
})
.await?;

Expand Down Expand Up @@ -464,6 +506,7 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_
}
}),
memory_behavior: MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(10),
assert_all_output_batches_roughly_match_batch_size_conf: true,
})
.await?;

Expand Down Expand Up @@ -498,6 +541,7 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_
}
}),
memory_behavior: MemoryBehavior::TakeAllMemoryAtTheBeginning,
assert_all_output_batches_roughly_match_batch_size_conf: true,
})
.await?;

Expand Down Expand Up @@ -527,6 +571,7 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory_and_large_reco
number_of_record_batches: 100,
get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 6),
memory_behavior: Default::default(),
assert_all_output_batches_roughly_match_batch_size_conf: true,
})
.await?;

Expand All @@ -535,7 +580,7 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory_and_large_reco

async fn run_test_aggregate_with_high_cardinality(
mut args: RunTestWithLimitedMemoryArgs,
) -> Result<usize> {
) -> Result<MetricsSet> {
let get_size_of_record_batch_to_generate = std::mem::replace(
&mut args.get_size_of_record_batch_to_generate,
Box::pin(move |_| unreachable!("should not be called after take")),
Expand Down Expand Up @@ -624,20 +669,21 @@ async fn run_test(
args: RunTestWithLimitedMemoryArgs,
plan: Arc<dyn ExecutionPlan>,
result_stream: SendableRecordBatchStream,
) -> Result<usize> {
) -> Result<MetricsSet> {
let number_of_record_batches = args.number_of_record_batches;

consume_stream_and_simulate_other_running_memory_consumers(args, result_stream)
.await?;

let metrics = plan.metrics().expect("must have metrics");
let spill_count = assert_spill_count_metric(true, plan);

assert!(
spill_count > 0,
"Expected spill, but did not, number of record batches: {number_of_record_batches}",
);

Ok(spill_count)
Ok(metrics)
}

/// Consume the stream and change the amount of memory used while consuming it based on the [`MemoryBehavior`] provided
Expand Down Expand Up @@ -693,3 +739,56 @@ async fn consume_stream_and_simulate_other_running_memory_consumers(

Ok(())
}

/// Assert baseline metrics are as expected or around that
///
/// `output_batch_size` should be `None` when you expect to not get batched at the same size
/// `Some(session conf batch size)` for the rest
fn assert_baseline_metrics_for_non_empty_output(
metrics: &MetricsSet,
expected_output_rows: usize,
output_batch_size: Option<usize>,
) {
let end_time = metrics
.iter()
.find_map(|item| match item.value() {
MetricValue::EndTimestamp(end) => Some(end),
_ => None,
})
.expect("Must have end time metric since it exists in the baseline");

assert_ne!(end_time.value(), None);

assert_eq!(metrics.output_rows(), Some(expected_output_rows));

let output_bytes = metrics
.iter()
.find_map(|item| match item.value() {
MetricValue::OutputBytes(total) => Some(total),
_ => None,
})
.expect("Must have output_bytes metric since it exists in the baseline");

assert_ne!(output_bytes.value(), 0_usize);

let output_batches = get_output_batches_from_metrics(metrics);

if let Some(output_batch_size) = output_batch_size {
assert_eq!(
output_batches,
expected_output_rows.div_ceil(output_batch_size)
);
} else {
assert_ne!(output_batches, 0,);
}
}

fn get_output_batches_from_metrics(metrics: &MetricsSet) -> usize {
metrics
.iter()
.find_map(|item| match item.value() {
MetricValue::OutputBatches(total) => Some(total.value()),
_ => None,
})
.expect("Must have output_batches metric since it exists in the baseline")
}
32 changes: 23 additions & 9 deletions datafusion/physical-plan/src/sorts/multi_level_merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use datafusion_execution::memory_pool::MemoryReservation;
use crate::sorts::builder::try_grow_reservation_to_at_least;
use crate::sorts::sort::get_reserved_bytes_for_record_batch_size;
use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder};
use crate::stream::RecordBatchStreamAdapter;
use crate::stream::{ObservedStream, RecordBatchStreamAdapter};
use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream};
use datafusion_physical_expr_common::sort_expr::LexOrdering;
use futures::TryStreamExt;
Expand Down Expand Up @@ -242,27 +242,34 @@ impl MultiLevelMergeBuilder {
fn merge_sorted_runs_within_mem_limit(&mut self) -> Result<MergeStep> {
match (self.sorted_spill_files.len(), self.sorted_streams.len()) {
// No data so empty batch
(0, 0) => Ok(MergeStep::Stream(Box::pin(EmptyRecordBatchStream::new(
Arc::clone(&self.schema),
)))),
(0, 0) => {
let empty_stream =
Box::pin(EmptyRecordBatchStream::new(Arc::clone(&self.schema)));
Ok(MergeStep::Stream(self.observe_output(empty_stream)))
}

// Only in-memory stream, return that
(0, 1) => Ok(MergeStep::Stream(self.sorted_streams.remove(0))),
(0, 1) => {
let output_stream = self.sorted_streams.remove(0);
Ok(MergeStep::Stream(self.observe_output(output_stream)))
}

// Only single sorted spill file so return it
(1, 0) => {
let spill_file = self.sorted_spill_files.remove(0);

// Not reserving any memory for this disk as we are not holding it in memory
Ok(MergeStep::Stream(
self.spill_manager
.read_spill_as_stream(spill_file.file, None)?,
))
let output_stream = self
.spill_manager
.read_spill_as_stream(spill_file.file, None)?;

Ok(MergeStep::Stream(self.observe_output(output_stream)))
}

// Only in memory streams, so merge them all in a single pass
(0, _) => {
let sorted_stream = mem::take(&mut self.sorted_streams);
// No need to wrap with observed stream since merge sort will update the observed metrics
Ok(MergeStep::Stream(self.create_new_merge_sort(
sorted_stream,
// If we have no sorted spill files left, this is the last run
Expand Down Expand Up @@ -574,6 +581,13 @@ impl MultiLevelMergeBuilder {

Ok(())
}

fn observe_output(
&self,
stream: SendableRecordBatchStream,
) -> SendableRecordBatchStream {
Box::pin(ObservedStream::new(stream, self.metrics.clone(), None))
}
}

/// Outcome of trying to reserve memory for one multi-level merge pass.
Expand Down
Loading
Loading