diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index 6c17669956b7..8ff45d6c9a4d 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -20,7 +20,7 @@ use std::{collections::VecDeque, fmt::Debug, pin::Pin, sync::Arc, task::Poll}; use crate::{FlightData, FlightDescriptor, SchemaAsIpc, error::Result}; use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, UnionArray}; -use arrow_ipc::writer::{CompressionContext, DictionaryTracker, IpcDataGenerator, IpcWriteOptions}; +use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteContext, IpcWriteOptions}; use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef, UnionMode}; use bytes::Bytes; @@ -372,15 +372,15 @@ impl FlightDataEncoder { DictionaryHandling::Resend => batch, DictionaryHandling::Hydrate => hydrate_dictionaries(&batch, schema)?, }; - for batch in split_batch_for_grpc_response(batch, self.max_flight_data_size) { - let (flight_dictionaries, flight_batch) = self.encoder.encode_batch(&batch)?; + let (flight_dictionaries, flight_batch) = self + .encoder + .encode_batch(&batch, self.max_flight_data_size)?; for dict in flight_dictionaries { self.queue_message(dict); } self.queue_message(flight_batch); } - Ok(()) } } @@ -701,7 +701,7 @@ struct FlightIpcEncoder { options: IpcWriteOptions, data_gen: IpcDataGenerator, dictionary_tracker: DictionaryTracker, - compression_context: CompressionContext, + compression_context: IpcWriteContext, } impl FlightIpcEncoder { @@ -710,7 +710,7 @@ impl FlightIpcEncoder { options, data_gen: IpcDataGenerator::default(), dictionary_tracker: DictionaryTracker::new(error_on_replacement), - compression_context: CompressionContext::default(), + compression_context: IpcWriteContext::default(), } } @@ -724,7 +724,11 @@ impl FlightIpcEncoder { fn encode_batch( &mut self, batch: &RecordBatch, - ) -> Result<(impl Iterator + use<>, FlightData)> { + max_flight_data_size: usize, + ) -> Result<(impl Iterator + 'static, 
FlightData)> { + self.compression_context + .scratch + .reserve(max_flight_data_size); let (encoded_dictionaries, encoded_batch) = self.data_gen.encode( batch, &mut self.dictionary_tracker, @@ -733,7 +737,7 @@ impl FlightIpcEncoder { )?; let flight_dictionaries = encoded_dictionaries.into_iter().map(|e| e.into()); - let flight_batch = encoded_batch.into(); + let flight_batch: FlightData = encoded_batch.into(); Ok((flight_dictionaries, flight_batch)) } @@ -1833,7 +1837,7 @@ mod tests { ) -> (Vec, FlightData) { let data_gen = IpcDataGenerator::default(); let mut dictionary_tracker = DictionaryTracker::new(false); - let mut compression_context = CompressionContext::default(); + let mut compression_context = IpcWriteContext::default(); let (encoded_dictionaries, encoded_batch) = data_gen .encode( diff --git a/arrow-flight/src/utils.rs b/arrow-flight/src/utils.rs index 6effb5f86aaf..796688100797 100644 --- a/arrow-flight/src/utils.rs +++ b/arrow-flight/src/utils.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use arrow_array::{ArrayRef, RecordBatch}; use arrow_buffer::Buffer; use arrow_ipc::convert::fb_to_schema; -use arrow_ipc::writer::CompressionContext; +use arrow_ipc::writer::IpcWriteContext; use arrow_ipc::{reader, root_as_message, writer, writer::IpcWriteOptions}; use arrow_schema::{ArrowError, Schema, SchemaRef}; @@ -92,7 +92,7 @@ pub fn batches_to_flight_data( let data_gen = writer::IpcDataGenerator::default(); let mut dictionary_tracker = writer::DictionaryTracker::new(false); - let mut compression_context = CompressionContext::default(); + let mut compression_context = IpcWriteContext::default(); for batch in batches.iter() { let (encoded_dictionaries, encoded_batch) = data_gen.encode( diff --git a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs index 05ca5627ecd8..3c01410e3873 100644 --- a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs +++ 
b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs @@ -26,7 +26,7 @@ use arrow::{ datatypes::SchemaRef, ipc::{ self, reader, - writer::{self, CompressionContext}, + writer::{self, IpcWriteContext}, }, record_batch::RecordBatch, }; @@ -95,7 +95,7 @@ async fn upload_data( let mut original_data_iter = original_data.iter().enumerate(); - let mut compression_context = CompressionContext::default(); + let mut compression_context = IpcWriteContext::default(); if let Some((counter, first_batch)) = original_data_iter.next() { let metadata = counter.to_string().into_bytes(); @@ -159,7 +159,7 @@ async fn send_batch( batch: &RecordBatch, options: &writer::IpcWriteOptions, dictionary_tracker: &mut writer::DictionaryTracker, - compression_context: &mut CompressionContext, + compression_context: &mut IpcWriteContext, ) -> Result { let data_gen = writer::IpcDataGenerator::default(); diff --git a/arrow-ipc/src/compression.rs b/arrow-ipc/src/compression.rs index ff6e83dfdd0b..2b15d00a0278 100644 --- a/arrow-ipc/src/compression.rs +++ b/arrow-ipc/src/compression.rs @@ -18,22 +18,24 @@ use crate::CompressionType; use arrow_buffer::Buffer; use arrow_schema::ArrowError; +use flatbuffers::FlatBufferBuilder; const LENGTH_NO_COMPRESSED_DATA: i64 = -1; const LENGTH_OF_PREFIX_DATA: i64 = 8; -/// Additional context that may be needed for compression. -/// -/// In the case of zstd, this will contain the zstd context, which can be reused between subsequent -/// compression calls to avoid the performance overhead of initialising a new context for every -/// compression. +/// Reusable IPC write state: the flatbuffer builder (`fbb`) is reset and reused across calls, +/// and the zstd compressor (when enabled) is kept alive to avoid re-initialisation overhead. #[derive(Default)] -pub struct CompressionContext { +pub struct IpcWriteContext { #[cfg(feature = "zstd")] compressor: Option>, + pub(crate) fbb: FlatBufferBuilder<'static>, + /// Scratch buffer for the IPC arrow data body. 
When set by the caller before + /// encode(), the existing allocation is reused instead of creating a fresh Vec. + pub scratch: Vec, } -impl CompressionContext { +impl IpcWriteContext { #[cfg(feature = "zstd")] fn zstd_compressor(&mut self) -> &mut zstd::bulk::Compressor<'static> { self.compressor.get_or_insert_with(|| { @@ -43,9 +45,9 @@ impl CompressionContext { } } -impl std::fmt::Debug for CompressionContext { +impl std::fmt::Debug for IpcWriteContext { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut ds = f.debug_struct("CompressionContext"); + let mut ds = f.debug_struct("IpcWriteContext"); #[cfg(feature = "zstd")] ds.field( @@ -143,7 +145,7 @@ impl CompressionCodec { &self, input: &[u8], output: &mut Vec, - context: &mut CompressionContext, + context: &mut IpcWriteContext, ) -> Result { let uncompressed_data_len = input.len(); let original_output_len = output.len(); @@ -209,7 +211,7 @@ impl CompressionCodec { &self, input: &[u8], output: &mut Vec, - context: &mut CompressionContext, + context: &mut IpcWriteContext, ) -> Result<(), ArrowError> { match self { CompressionCodec::Lz4Frame => compress_lz4(input, output), @@ -278,7 +280,7 @@ fn decompress_lz4(_input: &[u8], _decompressed_size: usize) -> Result, A fn compress_zstd( input: &[u8], output: &mut Vec, - context: &mut CompressionContext, + context: &mut IpcWriteContext, ) -> Result<(), ArrowError> { let result = context.zstd_compressor().compress(input)?; output.extend_from_slice(&result); @@ -290,7 +292,7 @@ fn compress_zstd( fn compress_zstd( _input: &[u8], _output: &mut Vec, - _context: &mut CompressionContext, + _context: &mut IpcWriteContext, ) -> Result<(), ArrowError> { Err(ArrowError::InvalidArgumentError( "zstd IPC compression requires the zstd feature".to_string(), diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index 4142858ce80c..e0f109ff8072 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -43,7 +43,7 @@ use arrow_schema::*; 
use crate::CONTINUATION_MARKER; use crate::compression::CompressionCodec; -pub use crate::compression::CompressionContext; +pub use crate::compression::IpcWriteContext; use crate::convert::IpcSchemaEncoder; /// IPC write options used to control the behaviour of the [`IpcDataGenerator`] @@ -244,7 +244,7 @@ impl Default for IpcWriteOptions { /// # use std::sync::Arc; /// # use arrow_array::UInt64Array; /// # use arrow_array::RecordBatch; -/// # use arrow_ipc::writer::{CompressionContext, DictionaryTracker, IpcDataGenerator, IpcWriteOptions}; +/// # use arrow_ipc::writer::{IpcWriteContext, DictionaryTracker, IpcDataGenerator, IpcWriteOptions}; /// /// // Create a record batch /// let batch = RecordBatch::try_from_iter(vec![ @@ -256,7 +256,7 @@ impl Default for IpcWriteOptions { /// let options = IpcWriteOptions::default(); /// let mut dictionary_tracker = DictionaryTracker::new(error_on_replacement); /// -/// let mut compression_context = CompressionContext::default(); +/// let mut compression_context = IpcWriteContext::default(); /// /// // encode the batch into zero or more encoded dictionaries /// // and the data for the actual array. 
@@ -310,7 +310,7 @@ impl IpcDataGenerator { dictionary_tracker: &mut DictionaryTracker, write_options: &IpcWriteOptions, dict_id: &mut I, - compression_context: &mut CompressionContext, + compression_context: &mut IpcWriteContext, ) -> Result<(), ArrowError> { match column.data_type() { DataType::Struct(fields) => { @@ -471,7 +471,7 @@ impl IpcDataGenerator { dictionary_tracker: &mut DictionaryTracker, write_options: &IpcWriteOptions, dict_id_seq: &mut I, - compression_context: &mut CompressionContext, + compression_context: &mut IpcWriteContext, ) -> Result<(), ArrowError> { match column.data_type() { DataType::Dictionary(_key_type, _value_type) => { @@ -548,7 +548,7 @@ impl IpcDataGenerator { batch: &RecordBatch, dictionary_tracker: &mut DictionaryTracker, write_options: &IpcWriteOptions, - compression_context: &mut CompressionContext, + compression_context: &mut IpcWriteContext, ) -> Result<(Vec, EncodedData), ArrowError> { let encoded_dictionaries = self.encode_all_dicts( batch, @@ -556,13 +556,26 @@ impl IpcDataGenerator { write_options, compression_context, )?; - let mut arrow_data = Vec::new(); - let (ipc_message, _, tail_pad) = self.record_batch_to_bytes( + let capacity = batch + .columns() + .iter() + .map(|a| estimate_encoded_buffer_count(a.data_type())) + .sum(); + let mut encoded_buffers: Vec = Vec::with_capacity(capacity); + let (ipc_message, body_len, tail_pad) = self.record_batch_to_bytes( batch, write_options, compression_context, - &mut IpcBodySink::Write(&mut arrow_data), + &mut IpcBodySink::Collect(&mut encoded_buffers), )?; + let alignment = write_options.alignment; + let mut arrow_data = std::mem::take(&mut compression_context.scratch); + arrow_data.clear(); + arrow_data.reserve(body_len); // reserve the whole body up front so large (e.g. string) data does not regrow the buffer + for enc in &encoded_buffers { + arrow_data.extend_from_slice(enc.as_slice()); + arrow_data.extend_from_slice(&PADDING[..pad_to_alignment(alignment, enc.len())]); + } 
arrow_data.extend_from_slice(&PADDING[..tail_pad]); Ok(( encoded_dictionaries, @@ -579,7 +592,7 @@ impl IpcDataGenerator { batch: &RecordBatch, dictionary_tracker: &mut DictionaryTracker, write_options: &IpcWriteOptions, - compression_context: &mut CompressionContext, + compression_context: &mut IpcWriteContext, ) -> Result, ArrowError> { let schema = batch.schema(); let mut encoded_dictionaries = Vec::with_capacity(schema.flattened_fields().len()); @@ -606,7 +619,7 @@ impl IpcDataGenerator { batch: &RecordBatch, dictionary_tracker: &mut DictionaryTracker, write_options: &IpcWriteOptions, - compression_context: &mut CompressionContext, + compression_context: &mut IpcWriteContext, writer: &mut W, ) -> Result { let encoded_dictionaries = self.encode_all_dicts( @@ -689,20 +702,11 @@ impl IpcDataGenerator { &self, batch: &RecordBatch, write_options: &IpcWriteOptions, - compression_context: &mut CompressionContext, + compression_context: &mut IpcWriteContext, sink: &mut IpcBodySink<'_>, ) -> Result<(Vec, usize, usize), ArrowError> { - let mut fbb = FlatBufferBuilder::new(); - let batch_compression_type = write_options.batch_compression_type; - let compression = batch_compression_type.map(|batch_compression_type| { - let mut c = crate::BodyCompressionBuilder::new(&mut fbb); - c.add_method(crate::BodyCompressionMethod::BUFFER); - c.add_codec(batch_compression_type); - c.finish() - }); - let compression_codec: Option = batch_compression_type.map(TryInto::try_into).transpose()?; @@ -728,6 +732,15 @@ impl IpcDataGenerator { let tail_pad = pad_to_alignment(alignment, offset as usize); let body_len = offset as usize + tail_pad; + let fbb = &mut compression_context.fbb; + + let compression = batch_compression_type.map(|batch_compression_type| { + let mut c = crate::BodyCompressionBuilder::new(fbb); + c.add_method(crate::BodyCompressionMethod::BUFFER); + c.add_codec(batch_compression_type); + c.finish() + }); + let buffers = fbb.create_vector(&meta.buffers); let nodes = 
fbb.create_vector(&meta.nodes); let variadic_buffer = if variadic_buffer_counts.is_empty() { @@ -737,7 +750,7 @@ impl IpcDataGenerator { }; let root = { - let mut batch_builder = crate::RecordBatchBuilder::new(&mut fbb); + let mut batch_builder = crate::RecordBatchBuilder::new(fbb); batch_builder.add_length(batch.num_rows() as i64); batch_builder.add_nodes(nodes); batch_builder.add_buffers(buffers); @@ -750,7 +763,7 @@ impl IpcDataGenerator { batch_builder.finish().as_union_value() }; // create an crate::Message - let mut message = crate::MessageBuilder::new(&mut fbb); + let mut message = crate::MessageBuilder::new(fbb); message.add_version(write_options.metadata_version); message.add_header_type(crate::MessageHeader::RecordBatch); message.add_bodyLength(body_len as i64); @@ -758,7 +771,9 @@ impl IpcDataGenerator { let root = message.finish(); fbb.finish(root, None); - Ok((fbb.finished_data().to_vec(), body_len, tail_pad)) + let ipc_message = fbb.finished_data().to_vec(); + fbb.reset(); + Ok((ipc_message, body_len, tail_pad)) } /// Write dictionary values into two sets of bytes, one for the header (crate::Message) and the @@ -769,22 +784,13 @@ impl IpcDataGenerator { array_data: &ArrayData, write_options: &IpcWriteOptions, is_delta: bool, - compression_context: &mut CompressionContext, + compression_context: &mut IpcWriteContext, ) -> Result { - let mut fbb = FlatBufferBuilder::new(); - let mut arrow_data: Vec = vec![]; // get the type of compression let batch_compression_type = write_options.batch_compression_type; - let compression = batch_compression_type.map(|batch_compression_type| { - let mut c = crate::BodyCompressionBuilder::new(&mut fbb); - c.add_method(crate::BodyCompressionMethod::BUFFER); - c.add_codec(batch_compression_type); - c.finish() - }); - let compression_codec: Option = batch_compression_type .map(|batch_compression_type| batch_compression_type.try_into()) .transpose()?; @@ -810,7 +816,15 @@ impl IpcDataGenerator { let body_len = offset as 
usize + tail_pad; arrow_data.extend_from_slice(&PADDING[..tail_pad]); - // write data + let fbb = &mut compression_context.fbb; + + let compression = batch_compression_type.map(|batch_compression_type| { + let mut c = crate::BodyCompressionBuilder::new(fbb); + c.add_method(crate::BodyCompressionMethod::BUFFER); + c.add_codec(batch_compression_type); + c.finish() + }); + let buffers = fbb.create_vector(&meta.buffers); let nodes = fbb.create_vector(&meta.nodes); let variadic_buffer = if variadic_buffer_counts.is_empty() { @@ -820,7 +834,7 @@ impl IpcDataGenerator { }; let root = { - let mut batch_builder = crate::RecordBatchBuilder::new(&mut fbb); + let mut batch_builder = crate::RecordBatchBuilder::new(fbb); batch_builder.add_length(array_data.len() as i64); batch_builder.add_nodes(nodes); batch_builder.add_buffers(buffers); @@ -834,7 +848,7 @@ impl IpcDataGenerator { }; let root = { - let mut batch_builder = crate::DictionaryBatchBuilder::new(&mut fbb); + let mut batch_builder = crate::DictionaryBatchBuilder::new(fbb); batch_builder.add_id(dict_id); batch_builder.add_data(root); batch_builder.add_isDelta(is_delta); @@ -842,7 +856,7 @@ impl IpcDataGenerator { }; let root = { - let mut message_builder = crate::MessageBuilder::new(&mut fbb); + let mut message_builder = crate::MessageBuilder::new(fbb); message_builder.add_version(write_options.metadata_version); message_builder.add_header_type(crate::MessageHeader::DictionaryBatch); message_builder.add_bodyLength(body_len as i64); @@ -851,10 +865,11 @@ impl IpcDataGenerator { }; fbb.finish(root, None); - let finished_data = fbb.finished_data(); + let ipc_message = fbb.finished_data().to_vec(); + fbb.reset(); Ok(EncodedData { - ipc_message: finished_data.to_vec(), + ipc_message, arrow_data, }) } @@ -1238,7 +1253,7 @@ pub struct FileWriter { data_gen: IpcDataGenerator, - compression_context: CompressionContext, + compression_context: IpcWriteContext, } impl FileWriter> { @@ -1300,7 +1315,7 @@ impl FileWriter { 
dictionary_tracker, custom_metadata: HashMap::new(), data_gen, - compression_context: CompressionContext::default(), + compression_context: IpcWriteContext::default(), }) } @@ -1529,7 +1544,7 @@ pub struct StreamWriter { data_gen: IpcDataGenerator, - compression_context: CompressionContext, + compression_context: IpcWriteContext, } impl StreamWriter> { @@ -1580,7 +1595,7 @@ impl StreamWriter { finished: false, dictionary_tracker, data_gen, - compression_context: CompressionContext::default(), + compression_context: IpcWriteContext::default(), }) } @@ -1957,7 +1972,7 @@ fn write_array_data( sink: &mut IpcBodySink<'_>, offset: i64, compression_codec: Option, - compression_context: &mut CompressionContext, + compression_context: &mut IpcWriteContext, write_options: &IpcWriteOptions, ) -> Result { let mut offset = offset; @@ -2252,7 +2267,7 @@ fn encode_sink_buffer( sink: &mut IpcBodySink<'_>, offset: i64, compression_codec: Option, - compression_context: &mut CompressionContext, + compression_context: &mut IpcWriteContext, alignment: u8, ) -> Result { let (encoded, len) = match compression_codec { @@ -4446,7 +4461,7 @@ mod tests { let data_gen = IpcDataGenerator::default(); let mut dictionary_tracker = DictionaryTracker::new(false); let writer_options = IpcWriteOptions::default(); - let mut compression_ctx = CompressionContext::default(); + let mut compression_ctx = IpcWriteContext::default(); let schema = Arc::new(Schema::new(vec![Field::new( "a",