Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 64 additions & 8 deletions datafusion/physical-plan/src/repartition/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1299,6 +1299,12 @@ impl ExecutionPlan for RepartitionExec {
if preserve_order {
// Store streams from all the input partitions:
// Each input partition gets its own spill reader to maintain proper FIFO ordering
//
// Pass None for metrics here — these intermediate streams feed into
// StreamingMerge which is the actual output. Only the merge's
// BaselineMetrics should contribute to the operator's reported
// output_rows. Without this, every row would be counted twice
// (once by PerPartitionStream, once by StreamingMerge).
let input_streams = rx
.into_iter()
.zip(spill_readers)
Expand All @@ -1311,7 +1317,7 @@ impl ExecutionPlan for RepartitionExec {
Arc::clone(&reservation),
spill_stream,
1, // Each receiver handles one input partition
BaselineMetrics::new(&metrics, partition),
None,
)) as SendableRecordBatchStream
})
.collect::<Vec<_>>();
Expand Down Expand Up @@ -1349,7 +1355,7 @@ impl ExecutionPlan for RepartitionExec {
reservation,
spill_stream,
num_input_partitions,
BaselineMetrics::new(&metrics, partition),
Some(BaselineMetrics::new(&metrics, partition)),
)) as SendableRecordBatchStream)
}
})
Expand Down Expand Up @@ -1862,8 +1868,8 @@ struct PerPartitionStream {
/// each sending None when complete. We must wait for all of them.
remaining_partitions: usize,

/// Execution metrics
baseline_metrics: BaselineMetrics,
/// Execution metrics (None in preserve-order mode where StreamingMerge owns the metrics)
baseline_metrics: Option<BaselineMetrics>,
}

impl PerPartitionStream {
Expand All @@ -1874,7 +1880,7 @@ impl PerPartitionStream {
reservation: SharedMemoryReservation,
spill_stream: SendableRecordBatchStream,
num_input_partitions: usize,
baseline_metrics: BaselineMetrics,
baseline_metrics: Option<BaselineMetrics>,
) -> Self {
Self {
schema,
Expand All @@ -1893,8 +1899,11 @@ impl PerPartitionStream {
cx: &mut Context<'_>,
) -> Poll<Option<Result<RecordBatch>>> {
use futures::StreamExt;
let cloned_time = self.baseline_metrics.elapsed_compute().clone();
let _timer = cloned_time.timer();
let elapsed = self
.baseline_metrics
.as_ref()
.map(|m| m.elapsed_compute().clone());
let _timer = elapsed.as_ref().map(|t| t.timer());

loop {
match self.state {
Expand Down Expand Up @@ -1980,7 +1989,11 @@ impl Stream for PerPartitionStream {
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let poll = self.poll_next_inner(cx);
self.baseline_metrics.record_poll(poll)
if let Some(metrics) = &self.baseline_metrics {
metrics.record_poll(poll)
} else {
poll
}
}
}

Expand Down Expand Up @@ -3294,4 +3307,47 @@ mod test {
let exec = Arc::new(exec);
Arc::new(TestMemoryExec::update_cache(&exec))
}

/// preserve_order repartition should not double-count
/// output rows.
#[tokio::test]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I verified this test covers the new code


  cargo test -p datafusion-physical-plan -- --nocapture

This fails without the code change

  thread 'repartition::test::test_preserve_order_output_rows_not_double_counted' panicked at datafusion/physical-plan/src/repartition/mod.rs:2991:9:
  assertion `left == right` failed: metrics output_rows (8) should match actual rows collected (4), not double-count
    left: 8
   right: 4

  Verification

async fn test_preserve_order_output_rows_not_double_counted() -> Result<()> {
    use datafusion_execution::TaskContext;

    // Input: two partitions, each already sorted on `c0`, two rows apiece,
    // so exactly four rows flow through the operator.
    let first_batch = record_batch!(("c0", UInt32, [1, 3])).unwrap();
    let second_batch = record_batch!(("c0", UInt32, [2, 4])).unwrap();
    let schema = first_batch.schema();
    let ordering = sort_exprs(&schema);

    // Build the in-memory source and declare one ordering per partition.
    let partitions = vec![vec![first_batch], vec![second_batch]];
    let source = TestMemoryExec::try_new(&partitions, Arc::clone(&schema), None)?
        .try_with_sort_information(vec![ordering.clone(), ordering])?;
    let source = Arc::new(source);
    let source = Arc::new(TestMemoryExec::update_cache(&source));

    // Order-preserving round-robin repartition into three output partitions.
    let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(3))?
        .with_preserve_order();

    // Drain every output partition and tally the rows actually produced.
    let ctx = Arc::new(TaskContext::default());
    let partition_count = exec.partitioning().partition_count();
    let mut total_rows = 0;
    for partition in 0..partition_count {
        let mut output = exec.execute(partition, Arc::clone(&ctx))?;
        while let Some(batch) = output.next().await.transpose()? {
            total_rows += batch.num_rows();
        }
    }

    assert_eq!(total_rows, 4, "actual rows collected should be 4");

    // The operator-level `output_rows` metric must agree with what was
    // really emitted; a value of 8 here would mean rows were counted twice.
    let reported_output_rows = exec.metrics().unwrap().output_rows().unwrap();
    assert_eq!(
        reported_output_rows, total_rows,
        "metrics output_rows ({reported_output_rows}) should match \
actual rows collected ({total_rows}), not double-count"
    );

    Ok(())
}
}
Loading