parquet/arrow/arrow_writer/mod.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Contains a writer that writes Arrow data into Parquet data.

use bytes::Bytes;
use std::io::{Read, Write};
use std::iter::Peekable;
use std::slice::Iter;
use std::sync::{Arc, Mutex};
use std::vec::IntoIter;

use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::{ArrayRef, Int32Array, RecordBatch, RecordBatchWriter};
use arrow_schema::{
    ArrowError, DataType as ArrowDataType, Field, IntervalUnit, SchemaRef, TimeUnit,
};

use super::schema::{add_encoded_arrow_schema_to_metadata, decimal_length_from_precision};

use crate::arrow::ArrowSchemaConverter;
use crate::arrow::arrow_writer::byte_array::ByteArrayEncoder;
use crate::column::page::{CompressedPage, PageWriteSpec, PageWriter};
use crate::column::page_encryption::PageEncryptor;
use crate::column::writer::encoder::ColumnValueEncoder;
use crate::column::writer::{
    ColumnCloseResult, ColumnWriter, GenericColumnWriter, get_column_writer,
};
use crate::data_type::{ByteArray, FixedLenByteArray};
#[cfg(feature = "encryption")]
use crate::encryption::encrypt::FileEncryptor;
use crate::errors::{ParquetError, Result};
use crate::file::metadata::{KeyValue, ParquetMetaData, RowGroupMetaData};
use crate::file::properties::{WriterProperties, WriterPropertiesPtr};
use crate::file::reader::{ChunkReader, Length};
use crate::file::writer::{SerializedFileWriter, SerializedRowGroupWriter};
use crate::parquet_thrift::{ThriftCompactOutputProtocol, WriteThrift};
use crate::schema::types::{ColumnDescPtr, SchemaDescPtr, SchemaDescriptor};
use levels::{ArrayLevels, calculate_array_levels};

mod byte_array;
mod levels;

/// Encodes [`RecordBatch`] to parquet
///
/// Writes Arrow `RecordBatch`es to a Parquet writer. Multiple [`RecordBatch`]es will be encoded
/// into the same row group, up to `max_row_group_size` rows. Any remaining rows will be
/// flushed on close, so the final row group in the output file may
/// contain fewer than `max_row_group_size` rows.
///
/// # Example: Writing `RecordBatch`es
/// ```
/// # use std::sync::Arc;
/// # use bytes::Bytes;
/// # use arrow_array::{ArrayRef, Int64Array};
/// # use arrow_array::RecordBatch;
/// # use parquet::arrow::arrow_writer::ArrowWriter;
/// # use parquet::arrow::arrow_reader::ParquetRecordBatchReader;
/// let col = Arc::new(Int64Array::from_iter_values([1, 2, 3])) as ArrayRef;
/// let to_write = RecordBatch::try_from_iter([("col", col)]).unwrap();
///
/// let mut buffer = Vec::new();
/// let mut writer = ArrowWriter::try_new(&mut buffer, to_write.schema(), None).unwrap();
/// writer.write(&to_write).unwrap();
/// writer.close().unwrap();
///
/// let mut reader = ParquetRecordBatchReader::try_new(Bytes::from(buffer), 1024).unwrap();
/// let read = reader.next().unwrap().unwrap();
///
/// assert_eq!(to_write, read);
/// ```
///
/// # Memory Usage and Limiting
///
/// The nature of Parquet requires buffering of an entire row group before it can
/// be flushed to the underlying writer. Data is mostly buffered in its encoded
/// form, reducing memory usage. However, some data such as dictionary keys,
/// large strings or very nested data may still result in non-trivial memory
/// usage.
///
/// See Also:
/// * [`ArrowWriter::memory_size`]: the current memory usage of the writer.
/// * [`ArrowWriter::in_progress_size`]: the estimated size of the buffered row group.
///
/// Call [`Self::flush`] to trigger an early flush of a row group based on a
/// memory threshold and/or global memory pressure. However, smaller row groups
/// result in higher metadata overheads, and thus may worsen compression ratios
/// and query performance.
///
/// ```no_run
/// # use std::io::Write;
/// # use arrow_array::RecordBatch;
/// # use parquet::arrow::ArrowWriter;
/// # let mut writer: ArrowWriter<Vec<u8>> = todo!();
/// # let batch: RecordBatch = todo!();
/// writer.write(&batch).unwrap();
/// // Trigger an early flush if anticipated size exceeds 1_000_000
/// if writer.in_progress_size() > 1_000_000 {
///     writer.flush().unwrap();
/// }
/// ```
///
/// ## Type Support
///
/// The writer supports writing all Arrow [`DataType`]s that have a direct mapping to
/// Parquet types, including [`StructArray`] and [`ListArray`].
///
/// The following are not supported:
///
/// * [`IntervalMonthDayNanoArray`]: Parquet does not [support nanosecond intervals].
///
/// [`DataType`]: https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html
/// [`StructArray`]: https://docs.rs/arrow/latest/arrow/array/struct.StructArray.html
/// [`ListArray`]: https://docs.rs/arrow/latest/arrow/array/type.ListArray.html
/// [`IntervalMonthDayNanoArray`]: https://docs.rs/arrow/latest/arrow/array/type.IntervalMonthDayNanoArray.html
/// [support nanosecond intervals]: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#interval
///
/// ## Type Compatibility
/// The writer can write Arrow [`RecordBatch`]es that are logically equivalent. This means that for
/// a given column, the writer can accept multiple Arrow [`DataType`]s that contain the same
/// value type.
///
/// For example, the following [`DataType`]s are all logically equivalent and can be written
/// to the same column:
/// * String, LargeString, StringView
/// * Binary, LargeBinary, BinaryView
///
/// The writer will also accept both native and dictionary encoded arrays, provided the dictionaries
/// contain compatible values.
/// ```
/// # use std::sync::Arc;
/// # use arrow_array::{DictionaryArray, LargeStringArray, RecordBatch, StringArray, UInt8Array};
/// # use arrow_schema::{DataType, Field, Schema};
/// # use parquet::arrow::arrow_writer::ArrowWriter;
/// let record_batch1 = RecordBatch::try_new(
///    Arc::new(Schema::new(vec![Field::new("col", DataType::LargeUtf8, false)])),
///    vec![Arc::new(LargeStringArray::from_iter_values(vec!["a", "b"]))]
///  )
/// .unwrap();
///
/// let mut buffer = Vec::new();
/// let mut writer = ArrowWriter::try_new(&mut buffer, record_batch1.schema(), None).unwrap();
/// writer.write(&record_batch1).unwrap();
///
/// let record_batch2 = RecordBatch::try_new(
///     Arc::new(Schema::new(vec![Field::new(
///         "col",
///         DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
///          false,
///     )])),
///     vec![Arc::new(DictionaryArray::new(
///          UInt8Array::from_iter_values(vec![0, 1]),
///          Arc::new(StringArray::from_iter_values(vec!["b", "c"])),
///      ))],
///  )
///  .unwrap();
///  writer.write(&record_batch2).unwrap();
///  writer.close().unwrap();
/// ```
pub struct ArrowWriter<W: Write> {
    /// Underlying Parquet writer
    writer: SerializedFileWriter<W>,

    /// The in-progress row group if any
    in_progress: Option<ArrowRowGroupWriter>,

    /// A copy of the Arrow schema.
    ///
    /// The schema is used to verify that each record batch written has the correct schema
    arrow_schema: SchemaRef,

    /// Creates new [`ArrowRowGroupWriter`] instances as required
    row_group_writer_factory: ArrowRowGroupWriterFactory,

    /// The maximum number of rows to write to each row group, or None for unlimited
    max_row_group_row_count: Option<usize>,

    /// The maximum size in bytes for a row group, or None for unlimited
    max_row_group_bytes: Option<usize>,
}

impl<W: Write + Send> std::fmt::Debug for ArrowWriter<W> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let buffered_memory = self.in_progress_size();
        f.debug_struct("ArrowWriter")
            .field("writer", &self.writer)
            .field("in_progress_size", &format_args!("{buffered_memory} bytes"))
            .field("in_progress_rows", &self.in_progress_rows())
            .field("arrow_schema", &self.arrow_schema)
            .field("max_row_group_row_count", &self.max_row_group_row_count)
            .field("max_row_group_bytes", &self.max_row_group_bytes)
            .finish()
    }
}

impl<W: Write + Send> ArrowWriter<W> {
    /// Try to create a new Arrow writer
    ///
    /// The writer will fail if:
    ///  * a `SerializedFileWriter` cannot be created from the underlying writer
    ///  * the Arrow schema contains unsupported datatypes such as Unions
    pub fn try_new(
        writer: W,
        arrow_schema: SchemaRef,
        props: Option<WriterProperties>,
    ) -> Result<Self> {
        let options = ArrowWriterOptions::new().with_properties(props.unwrap_or_default());
        Self::try_new_with_options(writer, arrow_schema, options)
    }

    /// Try to create a new Arrow writer with [`ArrowWriterOptions`].
    ///
    /// The writer will fail if:
    ///  * a `SerializedFileWriter` cannot be created from the underlying writer
    ///  * the Arrow schema contains unsupported datatypes such as Unions
    pub fn try_new_with_options(
        writer: W,
        arrow_schema: SchemaRef,
        options: ArrowWriterOptions,
    ) -> Result<Self> {
        let mut props = options.properties;

        let schema = if let Some(parquet_schema) = options.schema_descr {
            parquet_schema.clone()
        } else {
            let mut converter = ArrowSchemaConverter::new().with_coerce_types(props.coerce_types());
            if let Some(schema_root) = &options.schema_root {
                converter = converter.schema_root(schema_root);
            }

            converter.convert(&arrow_schema)?
        };

        if !options.skip_arrow_metadata {
            // add serialized arrow schema
            add_encoded_arrow_schema_to_metadata(&arrow_schema, &mut props);
        }

        let max_row_group_row_count = props.max_row_group_row_count();
        let max_row_group_bytes = props.max_row_group_bytes();

        let props_ptr = Arc::new(props);
        let file_writer =
            SerializedFileWriter::new(writer, schema.root_schema_ptr(), Arc::clone(&props_ptr))?;

        let row_group_writer_factory =
            ArrowRowGroupWriterFactory::new(&file_writer, arrow_schema.clone());

        Ok(Self {
            writer: file_writer,
            in_progress: None,
            arrow_schema,
            row_group_writer_factory,
            max_row_group_row_count,
            max_row_group_bytes,
        })
    }

    /// Returns metadata for any flushed row groups
    pub fn flushed_row_groups(&self) -> &[RowGroupMetaData] {
        self.writer.flushed_row_groups()
    }

    /// Estimated memory usage, in bytes, of this `ArrowWriter`
    ///
    /// This estimate is formed by summing the values of
    /// [`ArrowColumnWriter::memory_size`] for all in-progress columns.
    pub fn memory_size(&self) -> usize {
        match &self.in_progress {
            Some(in_progress) => in_progress.writers.iter().map(|x| x.memory_size()).sum(),
            None => 0,
        }
    }

    /// Anticipated encoded size of the in-progress row group.
    ///
    /// This estimates the size of the row group once it has been
    /// completely encoded, and is formed by summing the values of
    /// [`ArrowColumnWriter::get_estimated_total_bytes`] for all in-progress
    /// columns.
    pub fn in_progress_size(&self) -> usize {
        match &self.in_progress {
            Some(in_progress) => in_progress
                .writers
                .iter()
                .map(|x| x.get_estimated_total_bytes())
                .sum(),
            None => 0,
        }
    }

    /// Returns the number of rows buffered in the in-progress row group
    pub fn in_progress_rows(&self) -> usize {
        self.in_progress
            .as_ref()
            .map(|x| x.buffered_rows)
            .unwrap_or_default()
    }

    /// Returns the number of bytes written by this instance
    pub fn bytes_written(&self) -> usize {
        self.writer.bytes_written()
    }

    /// Encodes the provided [`RecordBatch`]
    ///
    /// If this would cause the current row group to exceed [`WriterProperties::max_row_group_row_count`]
    /// rows or [`WriterProperties::max_row_group_bytes`] bytes, the contents of `batch` will be
    /// written to one or more row groups such that the limits are respected.
    ///
    /// If both limits are `None`, all data is written to a single row group.
    /// If one limit is set, that limit is respected.
    /// If both limits are set, whichever limit triggers first is respected.
    ///
    /// This will fail if the `batch`'s schema does not match the writer's schema.
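    ///
    /// # Example: Splitting Batches Across Row Groups
    ///
    /// A minimal sketch of how row group limits interact with `write`. It assumes a
    /// builder setter named `set_max_row_group_row_count`, matching the
    /// [`WriterProperties::max_row_group_row_count`] getter used above; check your
    /// version of [`WriterProperties`] for the exact name.
    /// ```ignore
    /// # use std::sync::Arc;
    /// # use arrow_array::{ArrayRef, Int64Array, RecordBatch};
    /// # use parquet::arrow::arrow_writer::ArrowWriter;
    /// # use parquet::file::properties::WriterProperties;
    /// // Limit each row group to at most 2 rows (setter name assumed)
    /// let props = WriterProperties::builder()
    ///     .set_max_row_group_row_count(2)
    ///     .build();
    ///
    /// let col = Arc::new(Int64Array::from_iter_values(0..5)) as ArrayRef;
    /// let batch = RecordBatch::try_from_iter([("col", col)]).unwrap();
    ///
    /// let mut buffer = Vec::new();
    /// let mut writer = ArrowWriter::try_new(&mut buffer, batch.schema(), Some(props)).unwrap();
    /// // The 5 rows are split into row groups of 2, 2 and 1 rows
    /// writer.write(&batch).unwrap();
    /// let metadata = writer.close().unwrap();
    /// assert_eq!(metadata.num_row_groups(), 3);
    /// ```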
    pub fn write(&mut self, batch: &RecordBatch) -> Result<()> {
        if batch.num_rows() == 0 {
            return Ok(());
        }

        let in_progress = match &mut self.in_progress {
            Some(in_progress) => in_progress,
            x => x.insert(
                self.row_group_writer_factory
                    .create_row_group_writer(self.writer.flushed_row_groups().len())?,
            ),
        };

        if let Some(max_rows) = self.max_row_group_row_count {
            if in_progress.buffered_rows + batch.num_rows() > max_rows {
                let to_write = max_rows - in_progress.buffered_rows;
                let a = batch.slice(0, to_write);
                let b = batch.slice(to_write, batch.num_rows() - to_write);
                self.write(&a)?;
                return self.write(&b);
            }
        }

        // Check byte limit: if we have buffered data, use measured average row size
        // to split batch proactively before exceeding byte limit
        if let Some(max_bytes) = self.max_row_group_bytes {
            if in_progress.buffered_rows > 0 {
                let current_bytes = in_progress.get_estimated_total_bytes();

                if current_bytes >= max_bytes {
                    self.flush()?;
                    return self.write(batch);
                }

                let avg_row_bytes = current_bytes / in_progress.buffered_rows;
                if avg_row_bytes > 0 {
                    // At this point, `current_bytes < max_bytes` (checked above)
                    let remaining_bytes = max_bytes - current_bytes;
                    let rows_that_fit = remaining_bytes / avg_row_bytes;

                    if batch.num_rows() > rows_that_fit {
                        if rows_that_fit > 0 {
                            let a = batch.slice(0, rows_that_fit);
                            let b = batch.slice(rows_that_fit, batch.num_rows() - rows_that_fit);
                            self.write(&a)?;
                            return self.write(&b);
                        } else {
                            self.flush()?;
                            return self.write(batch);
                        }
                    }
                }
            }
        }

        in_progress.write(batch)?;

        let should_flush = self
            .max_row_group_row_count
            .is_some_and(|max| in_progress.buffered_rows >= max)
            || self
                .max_row_group_bytes
                .is_some_and(|max| in_progress.get_estimated_total_bytes() >= max);

        if should_flush {
            self.flush()?
        }
        Ok(())
    }

    /// Writes the given `buf` bytes to the internal buffer.
    ///
    /// It is safe to use this method to write data to the underlying writer,
    /// because it ensures that the buffering and byte-counting layers are used.
    pub fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        self.writer.write_all(buf)
    }

    /// Flushes the underlying writer
    pub fn sync(&mut self) -> std::io::Result<()> {
        self.writer.flush()
    }

    /// Flushes all buffered rows into a new row group
    ///
    /// Note: the underlying writer is not flushed by this call.
    /// If that behavior is desired, call [`ArrowWriter::sync`].
    pub fn flush(&mut self) -> Result<()> {
        let in_progress = match self.in_progress.take() {
            Some(in_progress) => in_progress,
            None => return Ok(()),
        };

        let mut row_group_writer = self.writer.next_row_group()?;
        for chunk in in_progress.close()? {
            chunk.append_to_row_group(&mut row_group_writer)?;
        }
        row_group_writer.close()?;
        Ok(())
    }

    /// Additional [`KeyValue`] metadata to be written in addition to those from [`WriterProperties`]
    ///
    /// This method provides a way to append key-value metadata after writing `RecordBatch`es.
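    ///
    /// # Example
    ///
    /// A small sketch of appending custom metadata before closing the file,
    /// assuming the two-argument [`KeyValue::new`] constructor:
    /// ```
    /// # use std::io::Write;
    /// # use parquet::arrow::arrow_writer::ArrowWriter;
    /// # use parquet::file::metadata::KeyValue;
    /// # fn example<W: Write + Send>(writer: &mut ArrowWriter<W>) {
    /// // Recorded in the file footer alongside metadata from `WriterProperties`
    /// writer.append_key_value_metadata(KeyValue::new(
    ///     "my_app.version".to_string(),
    ///     "1.0.0".to_string(),
    /// ));
    /// # }
    /// ```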
    pub fn append_key_value_metadata(&mut self, kv_metadata: KeyValue) {
        self.writer.append_key_value_metadata(kv_metadata)
    }

    /// Returns a reference to the underlying writer.
    pub fn inner(&self) -> &W {
        self.writer.inner()
    }

    /// Returns a mutable reference to the underlying writer.
    ///
    /// **Warning**: if you write directly to this writer, you will skip
    /// the `TrackedWrite` buffering and byte-counting layers. That will cause
    /// the file footer's recorded offsets and sizes to diverge from reality,
    /// resulting in an unreadable or corrupted Parquet file.
    ///
    /// If you want to write safely to the underlying writer, use [`Self::write_all`].
    pub fn inner_mut(&mut self) -> &mut W {
        self.writer.inner_mut()
    }

    /// Flushes any outstanding data and returns the underlying writer.
    pub fn into_inner(mut self) -> Result<W> {
        self.flush()?;
        self.writer.into_inner()
    }

    /// Close and finalize the underlying Parquet writer
    ///
    /// Unlike [`Self::close`], this does not consume `self`.
    ///
    /// Attempting to write after calling finish will result in an error
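    ///
    /// # Example
    ///
    /// A short sketch of inspecting the returned [`ParquetMetaData`]:
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow_array::{ArrayRef, Int64Array, RecordBatch};
    /// # use parquet::arrow::arrow_writer::ArrowWriter;
    /// let col = Arc::new(Int64Array::from_iter_values([1, 2, 3])) as ArrayRef;
    /// let batch = RecordBatch::try_from_iter([("col", col)]).unwrap();
    ///
    /// let mut buffer = Vec::new();
    /// let mut writer = ArrowWriter::try_new(&mut buffer, batch.schema(), None).unwrap();
    /// writer.write(&batch).unwrap();
    ///
    /// let metadata = writer.finish().unwrap();
    /// assert_eq!(metadata.file_metadata().num_rows(), 3);
    /// ```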
    pub fn finish(&mut self) -> Result<ParquetMetaData> {
        self.flush()?;
        self.writer.finish()
    }

    /// Close and finalize the underlying Parquet writer
    pub fn close(mut self) -> Result<ParquetMetaData> {
        self.finish()
    }

    /// Create a new row group writer and return its column writers.
    #[deprecated(
        since = "56.2.0",
        note = "Use `ArrowRowGroupWriterFactory` instead, see `ArrowColumnWriter` for an example"
    )]
    pub fn get_column_writers(&mut self) -> Result<Vec<ArrowColumnWriter>> {
        self.flush()?;
        let in_progress = self
            .row_group_writer_factory
            .create_row_group_writer(self.writer.flushed_row_groups().len())?;
        Ok(in_progress.writers)
    }

    /// Append the given column chunks to the file as a new row group.
    #[deprecated(
        since = "56.2.0",
        note = "Use `SerializedFileWriter` directly instead, see `ArrowColumnWriter` for an example"
    )]
    pub fn append_row_group(&mut self, chunks: Vec<ArrowColumnChunk>) -> Result<()> {
        let mut row_group_writer = self.writer.next_row_group()?;
        for chunk in chunks {
            chunk.append_to_row_group(&mut row_group_writer)?;
        }
        row_group_writer.close()?;
        Ok(())
    }

    /// Converts this writer into a lower-level [`SerializedFileWriter`] and [`ArrowRowGroupWriterFactory`].
    ///
    /// Flushes any outstanding data before returning.
    ///
    /// This can be useful to provide more control over how files are written, for example
    /// to write columns in parallel. See the example on [`ArrowColumnWriter`].
    pub fn into_serialized_writer(
        mut self,
    ) -> Result<(SerializedFileWriter<W>, ArrowRowGroupWriterFactory)> {
        self.flush()?;
        Ok((self.writer, self.row_group_writer_factory))
    }
}

impl<W: Write + Send> RecordBatchWriter for ArrowWriter<W> {
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
        self.write(batch).map_err(|e| e.into())
    }

    fn close(self) -> std::result::Result<(), ArrowError> {
        self.close()?;
        Ok(())
    }
}

/// Arrow-specific configuration settings for writing parquet files.
///
/// See [`ArrowWriter`] for how to configure the writer.
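///
/// # Example
///
/// A brief sketch combining the options below with
/// [`ArrowWriter::try_new_with_options`]:
/// ```
/// # use std::sync::Arc;
/// # use arrow_array::{ArrayRef, Int64Array, RecordBatch};
/// # use parquet::arrow::arrow_writer::{ArrowWriter, ArrowWriterOptions};
/// # use parquet::file::properties::WriterProperties;
/// let col = Arc::new(Int64Array::from_iter_values([1, 2, 3])) as ArrayRef;
/// let batch = RecordBatch::try_from_iter([("col", col)]).unwrap();
///
/// // Configure writer properties and skip embedding the Arrow schema
/// let options = ArrowWriterOptions::new()
///     .with_properties(WriterProperties::default())
///     .with_skip_arrow_metadata(true);
///
/// let mut buffer = Vec::new();
/// let mut writer =
///     ArrowWriter::try_new_with_options(&mut buffer, batch.schema(), options).unwrap();
/// writer.write(&batch).unwrap();
/// writer.close().unwrap();
/// ```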
#[derive(Debug, Clone, Default)]
pub struct ArrowWriterOptions {
    properties: WriterProperties,
    skip_arrow_metadata: bool,
    schema_root: Option<String>,
    schema_descr: Option<SchemaDescriptor>,
}

impl ArrowWriterOptions {
    /// Creates a new [`ArrowWriterOptions`] with the default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the [`WriterProperties`] for writing parquet files.
    pub fn with_properties(self, properties: WriterProperties) -> Self {
        Self { properties, ..self }
    }

    /// Skip encoding the embedded arrow metadata (defaults to `false`)
    ///
    /// Parquet files generated by the [`ArrowWriter`] contain the embedded Arrow schema
    /// by default.
    ///
    /// Set `skip_arrow_metadata` to `true` to skip encoding the embedded metadata.
    pub fn with_skip_arrow_metadata(self, skip_arrow_metadata: bool) -> Self {
        Self {
            skip_arrow_metadata,
            ..self
        }
    }

    /// Set the name of the root parquet schema element (defaults to `"arrow_schema"`)
    pub fn with_schema_root(self, schema_root: String) -> Self {
        Self {
            schema_root: Some(schema_root),
            ..self
        }
    }

    /// Explicitly specify the Parquet schema to be used
    ///
    /// If omitted (the default), the [`ArrowSchemaConverter`] is used to compute the
    /// Parquet [`SchemaDescriptor`]. This may be used when the [`SchemaDescriptor`] is
    /// already known or must be calculated using custom logic.
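    ///
    /// # Example
    ///
    /// A minimal sketch that pre-computes the Parquet schema with
    /// [`ArrowSchemaConverter`] and passes it in explicitly:
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow_schema::{DataType, Field, Schema};
    /// # use parquet::arrow::ArrowSchemaConverter;
    /// # use parquet::arrow::arrow_writer::ArrowWriterOptions;
    /// let schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Int64, false)]));
    /// // Compute (or otherwise obtain) the Parquet SchemaDescriptor
    /// let parquet_schema = ArrowSchemaConverter::new().convert(&schema).unwrap();
    /// let options = ArrowWriterOptions::new().with_parquet_schema(parquet_schema);
    /// ```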
    pub fn with_parquet_schema(self, schema_descr: SchemaDescriptor) -> Self {
        Self {
            schema_descr: Some(schema_descr),
            ..self
        }
    }
}

/// A single column chunk produced by [`ArrowColumnWriter`]
#[derive(Default)]
struct ArrowColumnChunkData {
    length: usize,
    data: Vec<Bytes>,
}

impl Length for ArrowColumnChunkData {
    fn len(&self) -> u64 {
        self.length as _
    }
}

impl ChunkReader for ArrowColumnChunkData {
    type T = ArrowColumnChunkReader;

    fn get_read(&self, start: u64) -> Result<Self::T> {
        assert_eq!(start, 0); // Assume append_column writes all data in one-shot
        Ok(ArrowColumnChunkReader(
            self.data.clone().into_iter().peekable(),
        ))
    }

    fn get_bytes(&self, _start: u64, _length: usize) -> Result<Bytes> {
        unimplemented!()
    }
}

/// A [`Read`] for [`ArrowColumnChunkData`]
struct ArrowColumnChunkReader(Peekable<IntoIter<Bytes>>);

impl Read for ArrowColumnChunkReader {
    fn read(&mut self, out: &mut [u8]) -> std::io::Result<usize> {
        let buffer = loop {
            match self.0.peek_mut() {
                Some(b) if b.is_empty() => {
                    self.0.next();
                    continue;
                }
                Some(b) => break b,
                None => return Ok(0),
            }
        };

        let len = buffer.len().min(out.len());
        let b = buffer.split_to(len);
        out[..len].copy_from_slice(&b);
        Ok(len)
    }
}

/// A shared [`ArrowColumnChunkData`]
///
/// This allows it to be owned by [`ArrowPageWriter`] whilst allowing access via
/// [`ArrowRowGroupWriter`] on flush, without requiring self-referential borrows
type SharedColumnChunk = Arc<Mutex<ArrowColumnChunkData>>;

#[derive(Default)]
struct ArrowPageWriter {
    buffer: SharedColumnChunk,
    #[cfg(feature = "encryption")]
    page_encryptor: Option<PageEncryptor>,
}

impl ArrowPageWriter {
    #[cfg(feature = "encryption")]
    pub fn with_encryptor(mut self, page_encryptor: Option<PageEncryptor>) -> Self {
        self.page_encryptor = page_encryptor;
        self
    }

    #[cfg(feature = "encryption")]
    fn page_encryptor_mut(&mut self) -> Option<&mut PageEncryptor> {
        self.page_encryptor.as_mut()
    }

    #[cfg(not(feature = "encryption"))]
    fn page_encryptor_mut(&mut self) -> Option<&mut PageEncryptor> {
        None
    }
}

impl PageWriter for ArrowPageWriter {
    fn write_page(&mut self, page: CompressedPage) -> Result<PageWriteSpec> {
        let page = match self.page_encryptor_mut() {
            Some(page_encryptor) => page_encryptor.encrypt_compressed_page(page)?,
            None => page,
        };

        let page_header = page.to_thrift_header()?;
        let header = {
            let mut header = Vec::with_capacity(1024);

            match self.page_encryptor_mut() {
                Some(page_encryptor) => {
                    page_encryptor.encrypt_page_header(&page_header, &mut header)?;
                    if page.compressed_page().is_data_page() {
                        page_encryptor.increment_page();
                    }
                }
                None => {
                    let mut protocol = ThriftCompactOutputProtocol::new(&mut header);
                    page_header.write_thrift(&mut protocol)?;
                }
            };

            Bytes::from(header)
        };

        let mut buf = self.buffer.try_lock().unwrap();

        let data = page.compressed_page().buffer().clone();
        let compressed_size = data.len() + header.len();

        let mut spec = PageWriteSpec::new();
        spec.page_type = page.page_type();
        spec.num_values = page.num_values();
        spec.uncompressed_size = page.uncompressed_size() + header.len();
        spec.offset = buf.length as u64;
        spec.compressed_size = compressed_size;
        spec.bytes_written = compressed_size as u64;

        buf.length += compressed_size;
        buf.data.push(header);
        buf.data.push(data);

        Ok(spec)
    }

    fn close(&mut self) -> Result<()> {
        Ok(())
    }
}

/// A leaf column that can be encoded by [`ArrowColumnWriter`]
#[derive(Debug)]
pub struct ArrowLeafColumn(ArrayLevels);

/// Computes the [`ArrowLeafColumn`]s for a potentially nested [`ArrayRef`]
///
/// This function can be used along with [`get_column_writers`] to encode
/// individual columns in parallel. See the example on [`ArrowColumnWriter`]
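///
/// # Example
///
/// A small sketch showing that a flat column produces a single leaf:
/// ```
/// # use std::sync::Arc;
/// # use arrow_array::{ArrayRef, Int32Array};
/// # use arrow_schema::{DataType, Field};
/// # use parquet::arrow::arrow_writer::compute_leaves;
/// let field = Field::new("i32", DataType::Int32, false);
/// let array = Arc::new(Int32Array::from_iter_values([1, 2, 3])) as ArrayRef;
/// // A primitive column yields exactly one leaf; nested types yield more
/// let leaves = compute_leaves(&field, &array).unwrap();
/// assert_eq!(leaves.len(), 1);
/// ```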
pub fn compute_leaves(field: &Field, array: &ArrayRef) -> Result<Vec<ArrowLeafColumn>> {
    let levels = calculate_array_levels(array, field)?;
    Ok(levels.into_iter().map(ArrowLeafColumn).collect())
}

/// The data for a single column chunk, see [`ArrowColumnWriter`]
pub struct ArrowColumnChunk {
    data: ArrowColumnChunkData,
    close: ColumnCloseResult,
}

impl std::fmt::Debug for ArrowColumnChunk {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ArrowColumnChunk")
            .field("length", &self.data.length)
            .finish_non_exhaustive()
    }
}

impl ArrowColumnChunk {
    /// Calls [`SerializedRowGroupWriter::append_column`] with this column's data
    pub fn append_to_row_group<W: Write + Send>(
        self,
        writer: &mut SerializedRowGroupWriter<'_, W>,
    ) -> Result<()> {
        writer.append_column(&self.data, self.close)
    }
}

/// Encodes [`ArrowLeafColumn`] to [`ArrowColumnChunk`]
///
/// `ArrowColumnWriter` instances can be created using an [`ArrowRowGroupWriterFactory`].
///
/// Note: This is a low-level interface for applications that require
/// fine-grained control of encoding (e.g. encoding using multiple threads);
/// see [`ArrowWriter`] for a higher-level interface
///
/// # Example: Encoding two Arrow Arrays in Parallel
/// ```
/// // The arrow schema
/// # use std::sync::Arc;
/// # use arrow_array::*;
/// # use arrow_schema::*;
/// # use parquet::arrow::ArrowSchemaConverter;
/// # use parquet::arrow::arrow_writer::{compute_leaves, ArrowColumnChunk, ArrowLeafColumn, ArrowRowGroupWriterFactory};
/// # use parquet::file::properties::WriterProperties;
/// # use parquet::file::writer::{SerializedFileWriter, SerializedRowGroupWriter};
/// #
/// let schema = Arc::new(Schema::new(vec![
///     Field::new("i32", DataType::Int32, false),
///     Field::new("f32", DataType::Float32, false),
/// ]));
///
/// // Compute the parquet schema
/// let props = Arc::new(WriterProperties::default());
/// let parquet_schema = ArrowSchemaConverter::new()
///   .with_coerce_types(props.coerce_types())
///   .convert(&schema)
///   .unwrap();
///
/// // Create parquet writer
/// let root_schema = parquet_schema.root_schema_ptr();
/// // write to memory in the example, but this could be a File
/// let mut out = Vec::with_capacity(1024);
/// let mut writer = SerializedFileWriter::new(&mut out, root_schema, props.clone())
///   .unwrap();
///
/// // Create a factory for building Arrow column writers
/// let row_group_factory = ArrowRowGroupWriterFactory::new(&writer, Arc::clone(&schema));
/// // Create column writers for the 0th row group
/// let col_writers = row_group_factory.create_column_writers(0).unwrap();
///
/// // Spawn a worker thread for each column
/// //
/// // Note: This is for demonstration purposes; a thread pool, e.g. rayon or tokio, would be better.
/// // The `map` produces an iterator of `(thread handle, send channel)` tuples.
/// let mut workers: Vec<_> = col_writers
///     .into_iter()
///     .map(|mut col_writer| {
///         let (send, recv) = std::sync::mpsc::channel::<ArrowLeafColumn>();
///         let handle = std::thread::spawn(move || {
///             // receive Arrays to encode via the channel
///             for col in recv {
///                 col_writer.write(&col)?;
///             }
///             // once the input is complete, close the writer
///             // to return the newly created ArrowColumnChunk
///             col_writer.close()
///         });
///         (handle, send)
///     })
///     .collect();
///
/// // Start row group
/// let mut row_group_writer: SerializedRowGroupWriter<'_, _> = writer
///   .next_row_group()
///   .unwrap();
///
/// // Create some example input columns to encode
/// let to_write = vec![
///     Arc::new(Int32Array::from_iter_values([1, 2, 3])) as _,
///     Arc::new(Float32Array::from_iter_values([1., 45., -1.])) as _,
/// ];
///
/// // Send the input columns to the workers
/// let mut worker_iter = workers.iter_mut();
/// for (arr, field) in to_write.iter().zip(&schema.fields) {
///     for leaves in compute_leaves(field, arr).unwrap() {
///         worker_iter.next().unwrap().1.send(leaves).unwrap();
///     }
/// }
///
/// // Wait for the workers to complete encoding, and append
/// // the resulting column chunks to the row group (and the file)
/// for (handle, send) in workers {
///     drop(send); // Drop send side to signal termination
///     // wait for the worker to send the completed chunk
///     let chunk: ArrowColumnChunk = handle.join().unwrap().unwrap();
///     chunk.append_to_row_group(&mut row_group_writer).unwrap();
/// }
/// // Close the row group which writes to the underlying file
/// row_group_writer.close().unwrap();
///
/// let metadata = writer.close().unwrap();
/// assert_eq!(metadata.file_metadata().num_rows(), 3);
/// ```
pub struct ArrowColumnWriter {
    writer: ArrowColumnWriterImpl,
    chunk: SharedColumnChunk,
}

impl std::fmt::Debug for ArrowColumnWriter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ArrowColumnWriter").finish_non_exhaustive()
    }
}

enum ArrowColumnWriterImpl {
    ByteArray(GenericColumnWriter<'static, ByteArrayEncoder>),
    Column(ColumnWriter<'static>),
}

impl ArrowColumnWriter {
    /// Write an [`ArrowLeafColumn`]
    pub fn write(&mut self, col: &ArrowLeafColumn) -> Result<()> {
        match &mut self.writer {
            ArrowColumnWriterImpl::Column(c) => {
                let leaf = col.0.array();
                match leaf.as_any_dictionary_opt() {
                    Some(dictionary) => {
                        let materialized =
                            arrow_select::take::take(dictionary.values(), dictionary.keys(), None)?;
                        write_leaf(c, &materialized, &col.0)?
                    }
                    None => write_leaf(c, leaf, &col.0)?,
                };
            }
            ArrowColumnWriterImpl::ByteArray(c) => {
                write_primitive(c, col.0.array().as_ref(), &col.0)?;
            }
        }
        Ok(())
    }

    /// Close this column returning the written [`ArrowColumnChunk`]
    pub fn close(self) -> Result<ArrowColumnChunk> {
        let close = match self.writer {
            ArrowColumnWriterImpl::ByteArray(c) => c.close()?,
            ArrowColumnWriterImpl::Column(c) => c.close()?,
        };
        let chunk = Arc::try_unwrap(self.chunk).ok().unwrap();
        let data = chunk.into_inner().unwrap();
        Ok(ArrowColumnChunk { data, close })
    }

    /// Returns the estimated total memory usage by the writer.
    ///
    /// Unlike [`Self::get_estimated_total_bytes`], this is an estimate
    /// of the current memory usage and not its anticipated encoded size.
    ///
    /// This includes:
    /// 1. Data buffered in encoded form
    /// 2. Data buffered in un-encoded form (e.g. `usize` dictionary keys)
    ///
    /// This value should be greater than or equal to [`Self::get_estimated_total_bytes`]
    pub fn memory_size(&self) -> usize {
        match &self.writer {
            ArrowColumnWriterImpl::ByteArray(c) => c.memory_size(),
            ArrowColumnWriterImpl::Column(c) => c.memory_size(),
        }
    }

    /// Returns the estimated total encoded bytes for this column writer.
    ///
    /// This includes:
    /// 1. Data buffered in encoded form
    /// 2. An estimate of how large the data buffered in un-encoded form would be once encoded
    ///
    /// This value should be less than or equal to [`Self::memory_size`]
    pub fn get_estimated_total_bytes(&self) -> usize {
        match &self.writer {
            ArrowColumnWriterImpl::ByteArray(c) => c.get_estimated_total_bytes() as _,
            ArrowColumnWriterImpl::Column(c) => c.get_estimated_total_bytes() as _,
        }
    }
}

/// Encodes [`RecordBatch`] to a parquet row group
///
/// Note: this structure is created internally by [`ArrowRowGroupWriterFactory`],
/// but it is not exposed publicly.
///
/// See the example on [`ArrowColumnWriter`] for how to encode columns in parallel
#[derive(Debug)]
struct ArrowRowGroupWriter {
    writers: Vec<ArrowColumnWriter>,
    schema: SchemaRef,
    buffered_rows: usize,
}

impl ArrowRowGroupWriter {
    fn new(writers: Vec<ArrowColumnWriter>, arrow: &SchemaRef) -> Self {
        Self {
            writers,
            schema: arrow.clone(),
            buffered_rows: 0,
        }
    }

    fn write(&mut self, batch: &RecordBatch) -> Result<()> {
        self.buffered_rows += batch.num_rows();
        let mut writers = self.writers.iter_mut();
        for (field, column) in self.schema.fields().iter().zip(batch.columns()) {
            for leaf in compute_leaves(field.as_ref(), column)? {
                writers.next().unwrap().write(&leaf)?
            }
        }
        Ok(())
    }

    /// Returns the estimated total encoded bytes for this row group
    fn get_estimated_total_bytes(&self) -> usize {
        self.writers
            .iter()
            .map(|x| x.get_estimated_total_bytes())
            .sum()
    }

    fn close(self) -> Result<Vec<ArrowColumnChunk>> {
        self.writers
            .into_iter()
            .map(|writer| writer.close())
            .collect()
    }
}

/// Factory that creates new column writers for each row group in the Parquet file.
///
/// You can create this structure via [`ArrowWriter::into_serialized_writer`].
/// See the example on [`ArrowColumnWriter`] for how to encode columns in parallel
#[derive(Debug)]
pub struct ArrowRowGroupWriterFactory {
    schema: SchemaDescPtr,
    arrow_schema: SchemaRef,
    props: WriterPropertiesPtr,
    #[cfg(feature = "encryption")]
    file_encryptor: Option<Arc<FileEncryptor>>,
}

impl ArrowRowGroupWriterFactory {
    /// Create a new [`ArrowRowGroupWriterFactory`] for the provided file writer and Arrow schema
    pub fn new<W: Write + Send>(
        file_writer: &SerializedFileWriter<W>,
        arrow_schema: SchemaRef,
    ) -> Self {
        let schema = Arc::clone(file_writer.schema_descr_ptr());
        let props = Arc::clone(file_writer.properties());
        Self {
            schema,
            arrow_schema,
            props,
            #[cfg(feature = "encryption")]
            file_encryptor: file_writer.file_encryptor(),
        }
    }

    fn create_row_group_writer(&self, row_group_index: usize) -> Result<ArrowRowGroupWriter> {
        let writers = self.create_column_writers(row_group_index)?;
        Ok(ArrowRowGroupWriter::new(writers, &self.arrow_schema))
    }

    /// Create column writers for a new row group, with the given row group index
    pub fn create_column_writers(&self, row_group_index: usize) -> Result<Vec<ArrowColumnWriter>> {
        let mut writers = Vec::with_capacity(self.arrow_schema.fields.len());
        let mut leaves = self.schema.columns().iter();
        let column_factory = self.column_writer_factory(row_group_index);
        for field in &self.arrow_schema.fields {
            column_factory.get_arrow_column_writer(
                field.data_type(),
                &self.props,
                &mut leaves,
                &mut writers,
            )?;
        }
        Ok(writers)
    }

    #[cfg(feature = "encryption")]
    fn column_writer_factory(&self, row_group_idx: usize) -> ArrowColumnWriterFactory {
        ArrowColumnWriterFactory::new()
            .with_file_encryptor(row_group_idx, self.file_encryptor.clone())
    }

    #[cfg(not(feature = "encryption"))]
    fn column_writer_factory(&self, _row_group_idx: usize) -> ArrowColumnWriterFactory {
        ArrowColumnWriterFactory::new()
    }
}

/// Returns [`ArrowColumnWriter`]s for each column in a given schema
#[deprecated(since = "57.0.0", note = "Use `ArrowRowGroupWriterFactory` instead")]
pub fn get_column_writers(
    parquet: &SchemaDescriptor,
    props: &WriterPropertiesPtr,
    arrow: &SchemaRef,
) -> Result<Vec<ArrowColumnWriter>> {
    let mut writers = Vec::with_capacity(arrow.fields.len());
    let mut leaves = parquet.columns().iter();
    let column_factory = ArrowColumnWriterFactory::new();
    for field in &arrow.fields {
        column_factory.get_arrow_column_writer(
            field.data_type(),
            props,
            &mut leaves,
            &mut writers,
        )?;
    }
    Ok(writers)
}

/// Creates [`ArrowColumnWriter`] instances
struct ArrowColumnWriterFactory {
    #[cfg(feature = "encryption")]
    row_group_index: usize,
    #[cfg(feature = "encryption")]
    file_encryptor: Option<Arc<FileEncryptor>>,
}

impl ArrowColumnWriterFactory {
    pub fn new() -> Self {
        Self {
            #[cfg(feature = "encryption")]
            row_group_index: 0,
            #[cfg(feature = "encryption")]
            file_encryptor: None,
        }
    }

    #[cfg(feature = "encryption")]
    pub fn with_file_encryptor(
        mut self,
        row_group_index: usize,
        file_encryptor: Option<Arc<FileEncryptor>>,
    ) -> Self {
        self.row_group_index = row_group_index;
        self.file_encryptor = file_encryptor;
        self
    }

    #[cfg(feature = "encryption")]
    fn create_page_writer(
        &self,
        column_descriptor: &ColumnDescPtr,
        column_index: usize,
    ) -> Result<Box<ArrowPageWriter>> {
        let column_path = column_descriptor.path().string();
        let page_encryptor = PageEncryptor::create_if_column_encrypted(
            &self.file_encryptor,
            self.row_group_index,
            column_index,
            &column_path,
        )?;
        Ok(Box::new(
            ArrowPageWriter::default().with_encryptor(page_encryptor),
        ))
    }

    #[cfg(not(feature = "encryption"))]
    fn create_page_writer(
        &self,
        _column_descriptor: &ColumnDescPtr,
        _column_index: usize,
    ) -> Result<Box<ArrowPageWriter>> {
        Ok(Box::<ArrowPageWriter>::default())
    }

    /// Gets an [`ArrowColumnWriter`] for the given `data_type`, appending the
    /// output ColumnDesc to `leaves` and the column writers to `out`
    fn get_arrow_column_writer(
        &self,
        data_type: &ArrowDataType,
        props: &WriterPropertiesPtr,
        leaves: &mut Iter<'_, ColumnDescPtr>,
        out: &mut Vec<ArrowColumnWriter>,
    ) -> Result<()> {
        // Instantiate writers for normal columns
        let col = |desc: &ColumnDescPtr| -> Result<ArrowColumnWriter> {
            let page_writer = self.create_page_writer(desc, out.len())?;
            let chunk = page_writer.buffer.clone();
            let writer = get_column_writer(desc.clone(), props.clone(), page_writer);
            Ok(ArrowColumnWriter {
                chunk,
                writer: ArrowColumnWriterImpl::Column(writer),
            })
        };

        // Instantiate writers for byte arrays (e.g. Utf8, Binary, etc.)
        let bytes = |desc: &ColumnDescPtr| -> Result<ArrowColumnWriter> {
            let page_writer = self.create_page_writer(desc, out.len())?;
            let chunk = page_writer.buffer.clone();
            let writer = GenericColumnWriter::new(desc.clone(), props.clone(), page_writer);
            Ok(ArrowColumnWriter {
                chunk,
                writer: ArrowColumnWriterImpl::ByteArray(writer),
            })
        };

        match data_type {
            _ if data_type.is_primitive() => out.push(col(leaves.next().unwrap())?),
            ArrowDataType::FixedSizeBinary(_) | ArrowDataType::Boolean | ArrowDataType::Null => {
                out.push(col(leaves.next().unwrap())?)
            }
            ArrowDataType::LargeBinary
            | ArrowDataType::Binary
            | ArrowDataType::Utf8
            | ArrowDataType::LargeUtf8
            | ArrowDataType::BinaryView
            | ArrowDataType::Utf8View => out.push(bytes(leaves.next().unwrap())?),
            ArrowDataType::List(f)
            | ArrowDataType::LargeList(f)
            | ArrowDataType::FixedSizeList(f, _)
            | ArrowDataType::ListView(f)
            | ArrowDataType::LargeListView(f) => {
                self.get_arrow_column_writer(f.data_type(), props, leaves, out)?
            }
            ArrowDataType::Struct(fields) => {
                for field in fields {
                    self.get_arrow_column_writer(field.data_type(), props, leaves, out)?
                }
            }
            ArrowDataType::Map(f, _) => match f.data_type() {
                ArrowDataType::Struct(f) => {
                    self.get_arrow_column_writer(f[0].data_type(), props, leaves, out)?;
                    self.get_arrow_column_writer(f[1].data_type(), props, leaves, out)?
                }
                _ => unreachable!("invalid map type"),
            },
            ArrowDataType::Dictionary(_, value_type) => match value_type.as_ref() {
                ArrowDataType::Utf8
                | ArrowDataType::LargeUtf8
                | ArrowDataType::Binary
                | ArrowDataType::LargeBinary => out.push(bytes(leaves.next().unwrap())?),
                ArrowDataType::Utf8View | ArrowDataType::BinaryView => {
                    out.push(bytes(leaves.next().unwrap())?)
                }
                ArrowDataType::FixedSizeBinary(_) => out.push(bytes(leaves.next().unwrap())?),
                _ => out.push(col(leaves.next().unwrap())?),
            },
            _ => {
                return Err(ParquetError::NYI(format!(
                    "Attempting to write an Arrow type {data_type} to parquet that is not yet implemented"
                )));
            }
        }
        Ok(())
    }
}

fn write_leaf(
    writer: &mut ColumnWriter<'_>,
    column: &dyn arrow_array::Array,
    levels: &ArrayLevels,
) -> Result<usize> {
    let indices = levels.non_null_indices();

    match writer {
        // Note: this should match the contents of arrow_to_parquet_type
        ColumnWriter::Int32ColumnWriter(typed) => {
            match column.data_type() {
                ArrowDataType::Null => {
                    let array = Int32Array::new_null(column.len());
                    write_primitive(typed, array.values(), levels)
                }
                ArrowDataType::Int8 => {
                    let array: Int32Array = column.as_primitive::<Int8Type>().unary(|x| x as i32);
                    write_primitive(typed, array.values(), levels)
                }
                ArrowDataType::Int16 => {
                    let array: Int32Array = column.as_primitive::<Int16Type>().unary(|x| x as i32);
                    write_primitive(typed, array.values(), levels)
                }
                ArrowDataType::Int32 => {
                    write_primitive(typed, column.as_primitive::<Int32Type>().values(), levels)
                }
                ArrowDataType::UInt8 => {
                    let array: Int32Array = column.as_primitive::<UInt8Type>().unary(|x| x as i32);
                    write_primitive(typed, array.values(), levels)
                }
                ArrowDataType::UInt16 => {
                    let array: Int32Array = column.as_primitive::<UInt16Type>().unary(|x| x as i32);
                    write_primitive(typed, array.values(), levels)
                }
                ArrowDataType::UInt32 => {
                    // follow C++ implementation and use overflow/reinterpret cast from u32 to i32 which will map
                    // `(i32::MAX as u32)..u32::MAX` to `i32::MIN..0`
                    let array = column.as_primitive::<UInt32Type>();
                    write_primitive(typed, array.values().inner().typed_data(), levels)
                }
                ArrowDataType::Date32 => {
                    let array = column.as_primitive::<Date32Type>();
                    write_primitive(typed, array.values(), levels)
                }
                ArrowDataType::Time32(TimeUnit::Second) => {
                    let array = column.as_primitive::<Time32SecondType>();
                    write_primitive(typed, array.values(), levels)
                }
                ArrowDataType::Time32(TimeUnit::Millisecond) => {
                    let array = column.as_primitive::<Time32MillisecondType>();
                    write_primitive(typed, array.values(), levels)
                }
                ArrowDataType::Date64 => {
                    // If the column is a Date64 (milliseconds since epoch), truncate
                    // it to whole days (86,400,000 ms per day)
                    let array: Int32Array = column
                        .as_primitive::<Date64Type>()
                        .unary(|x| (x / 86_400_000) as _);

                    write_primitive(typed, array.values(), levels)
                }
1265                ArrowDataType::Decimal32(_, _) => {
1266                    let array = column
1267                        .as_primitive::<Decimal32Type>()
1268                        .unary::<_, Int32Type>(|v| v);
1269                    write_primitive(typed, array.values(), levels)
1270                }
1271                ArrowDataType::Decimal64(_, _) => {
1272                    // use the int32 to represent the decimal with low precision
1273                    let array = column
1274                        .as_primitive::<Decimal64Type>()
1275                        .unary::<_, Int32Type>(|v| v as i32);
1276                    write_primitive(typed, array.values(), levels)
1277                }
1278                ArrowDataType::Decimal128(_, _) => {
1279                    // use the int32 to represent the decimal with low precision
1280                    let array = column
1281                        .as_primitive::<Decimal128Type>()
1282                        .unary::<_, Int32Type>(|v| v as i32);
1283                    write_primitive(typed, array.values(), levels)
1284                }
1285                ArrowDataType::Decimal256(_, _) => {
1286                    // use the int32 to represent the decimal with low precision
1287                    let array = column
1288                        .as_primitive::<Decimal256Type>()
1289                        .unary::<_, Int32Type>(|v| v.as_i128() as i32);
1290                    write_primitive(typed, array.values(), levels)
1291                }
1292                d => Err(ParquetError::General(format!("Cannot coerce {d} to I32"))),
1293            }
1294        }
1295        ColumnWriter::BoolColumnWriter(typed) => {
1296            let array = column.as_boolean();
1297            typed.write_batch(
1298                get_bool_array_slice(array, indices).as_slice(),
1299                levels.def_levels(),
1300                levels.rep_levels(),
1301            )
1302        }
1303        ColumnWriter::Int64ColumnWriter(typed) => {
1304            match column.data_type() {
1305                ArrowDataType::Date64 => {
1306                    let array = column
1307                        .as_primitive::<Date64Type>()
1308                        .reinterpret_cast::<Int64Type>();
1309
1310                    write_primitive(typed, array.values(), levels)
1311                }
1312                ArrowDataType::Int64 => {
1313                    let array = column.as_primitive::<Int64Type>();
1314                    write_primitive(typed, array.values(), levels)
1315                }
1316                ArrowDataType::UInt64 => {
1317                    let values = column.as_primitive::<UInt64Type>().values();
1318                    // follow C++ implementation and use overflow/reinterpret cast from  u64 to i64 which will map
1319                    // `(i64::MAX as u64)..u64::MAX` to `i64::MIN..0`
                    let array = values.inner().typed_data::<i64>();
                    write_primitive(typed, array, levels)
                }
                ArrowDataType::Time64(TimeUnit::Microsecond) => {
                    let array = column.as_primitive::<Time64MicrosecondType>();
                    write_primitive(typed, array.values(), levels)
                }
                ArrowDataType::Time64(TimeUnit::Nanosecond) => {
                    let array = column.as_primitive::<Time64NanosecondType>();
                    write_primitive(typed, array.values(), levels)
                }
                ArrowDataType::Timestamp(unit, _) => match unit {
                    TimeUnit::Second => {
                        let array = column.as_primitive::<TimestampSecondType>();
                        write_primitive(typed, array.values(), levels)
                    }
                    TimeUnit::Millisecond => {
                        let array = column.as_primitive::<TimestampMillisecondType>();
                        write_primitive(typed, array.values(), levels)
                    }
                    TimeUnit::Microsecond => {
                        let array = column.as_primitive::<TimestampMicrosecondType>();
                        write_primitive(typed, array.values(), levels)
                    }
                    TimeUnit::Nanosecond => {
                        let array = column.as_primitive::<TimestampNanosecondType>();
                        write_primitive(typed, array.values(), levels)
                    }
                },
                ArrowDataType::Duration(unit) => match unit {
                    TimeUnit::Second => {
                        let array = column.as_primitive::<DurationSecondType>();
                        write_primitive(typed, array.values(), levels)
                    }
                    TimeUnit::Millisecond => {
                        let array = column.as_primitive::<DurationMillisecondType>();
                        write_primitive(typed, array.values(), levels)
                    }
                    TimeUnit::Microsecond => {
                        let array = column.as_primitive::<DurationMicrosecondType>();
                        write_primitive(typed, array.values(), levels)
                    }
                    TimeUnit::Nanosecond => {
                        let array = column.as_primitive::<DurationNanosecondType>();
                        write_primitive(typed, array.values(), levels)
                    }
                },
                ArrowDataType::Decimal64(_, _) => {
                    let array = column
                        .as_primitive::<Decimal64Type>()
                        .reinterpret_cast::<Int64Type>();
                    write_primitive(typed, array.values(), levels)
                }
                ArrowDataType::Decimal128(_, _) => {
                    // the precision is low enough for the value to fit in an i64, so narrow it
                    let array = column
                        .as_primitive::<Decimal128Type>()
                        .unary::<_, Int64Type>(|v| v as i64);
                    write_primitive(typed, array.values(), levels)
                }
                ArrowDataType::Decimal256(_, _) => {
                    // the precision is low enough for the value to fit in an i64, so narrow it
                    let array = column
                        .as_primitive::<Decimal256Type>()
                        .unary::<_, Int64Type>(|v| v.as_i128() as i64);
                    write_primitive(typed, array.values(), levels)
                }
                d => Err(ParquetError::General(format!("Cannot coerce {d} to I64"))),
            }
        }
        ColumnWriter::Int96ColumnWriter(_typed) => {
            unreachable!("Currently unreachable because data type not supported")
        }
        ColumnWriter::FloatColumnWriter(typed) => {
            let array = column.as_primitive::<Float32Type>();
            write_primitive(typed, array.values(), levels)
        }
        ColumnWriter::DoubleColumnWriter(typed) => {
            let array = column.as_primitive::<Float64Type>();
            write_primitive(typed, array.values(), levels)
        }
        ColumnWriter::ByteArrayColumnWriter(_) => {
            unreachable!("should use ByteArrayWriter")
        }
        ColumnWriter::FixedLenByteArrayColumnWriter(typed) => {
            let bytes = match column.data_type() {
                ArrowDataType::Interval(interval_unit) => match interval_unit {
                    IntervalUnit::YearMonth => {
                        let array = column.as_primitive::<IntervalYearMonthType>();
                        get_interval_ym_array_slice(array, indices)
                    }
                    IntervalUnit::DayTime => {
                        let array = column.as_primitive::<IntervalDayTimeType>();
                        get_interval_dt_array_slice(array, indices)
                    }
                    _ => {
                        return Err(ParquetError::NYI(format!(
                            "Attempting to write an Arrow interval type {interval_unit:?} to parquet that is not yet implemented"
                        )));
                    }
                },
                ArrowDataType::FixedSizeBinary(_) => {
                    let array = column.as_fixed_size_binary();
                    get_fsb_array_slice(array, indices)
                }
                ArrowDataType::Decimal32(_, _) => {
                    let array = column.as_primitive::<Decimal32Type>();
                    get_decimal_32_array_slice(array, indices)
                }
                ArrowDataType::Decimal64(_, _) => {
                    let array = column.as_primitive::<Decimal64Type>();
                    get_decimal_64_array_slice(array, indices)
                }
                ArrowDataType::Decimal128(_, _) => {
                    let array = column.as_primitive::<Decimal128Type>();
                    get_decimal_128_array_slice(array, indices)
                }
                ArrowDataType::Decimal256(_, _) => {
                    let array = column.as_primitive::<Decimal256Type>();
                    get_decimal_256_array_slice(array, indices)
                }
                ArrowDataType::Float16 => {
                    let array = column.as_primitive::<Float16Type>();
                    get_float_16_array_slice(array, indices)
                }
                _ => {
                    return Err(ParquetError::NYI(
                        "Attempting to write an Arrow type that is not yet implemented".to_string(),
                    ));
                }
            };
            typed.write_batch(bytes.as_slice(), levels.def_levels(), levels.rep_levels())
        }
    }
}

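/// Writes the value buffer of a primitive array via
/// [`GenericColumnWriter::write_batch_internal`].
///
/// `values` is the array's full values buffer, including slots for nulls;
/// `levels.non_null_indices()` selects the entries that are actually defined,
/// while the definition and repetition levels describe the nesting.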
fn write_primitive<E: ColumnValueEncoder>(
    writer: &mut GenericColumnWriter<E>,
    values: &E::Values,
    levels: &ArrayLevels,
) -> Result<usize> {
    writer.write_batch_internal(
        values,
        Some(levels.non_null_indices()),
        levels.def_levels(),
        levels.rep_levels(),
        None,
        None,
        None,
    )
}

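/// Gathers the boolean values at `indices` into a contiguous `Vec<bool>`.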
fn get_bool_array_slice(array: &arrow_array::BooleanArray, indices: &[usize]) -> Vec<bool> {
    let mut values = Vec::with_capacity(indices.len());
    for i in indices {
        values.push(array.value(*i))
    }
    values
}

/// Returns 12-byte values representing 3 values of months, days and milliseconds (4 bytes each).
/// An Arrow YearMonth interval only stores months, thus only the first 4 bytes are populated.
fn get_interval_ym_array_slice(
    array: &arrow_array::IntervalYearMonthArray,
    indices: &[usize],
) -> Vec<FixedLenByteArray> {
    let mut values = Vec::with_capacity(indices.len());
    for i in indices {
        let mut value = array.value(*i).to_le_bytes().to_vec();
        let mut suffix = vec![0; 8];
        value.append(&mut suffix);
        values.push(FixedLenByteArray::from(ByteArray::from(value)))
    }
    values
}

/// Returns 12-byte values representing 3 values of months, days and milliseconds (4 bytes each).
/// An Arrow DayTime interval only stores days and millis, thus the first 4 bytes are not populated.
fn get_interval_dt_array_slice(
    array: &arrow_array::IntervalDayTimeArray,
    indices: &[usize],
) -> Vec<FixedLenByteArray> {
    let mut values = Vec::with_capacity(indices.len());
    for i in indices {
        let mut out = [0; 12];
        let value = array.value(*i);
        out[4..8].copy_from_slice(&value.days.to_le_bytes());
        out[8..12].copy_from_slice(&value.milliseconds.to_le_bytes());
        values.push(FixedLenByteArray::from(ByteArray::from(out.to_vec())));
    }
    values
}

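/// Truncates each decimal to the minimal big-endian width implied by the
/// array's precision (see `decimal_length_from_precision`), dropping leading
/// bytes that are redundant sign extension. For example, precision 5 needs at
/// most 3 bytes, so only the low 3 big-endian bytes of each i32 are kept.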
fn get_decimal_32_array_slice(
    array: &arrow_array::Decimal32Array,
    indices: &[usize],
) -> Vec<FixedLenByteArray> {
    let mut values = Vec::with_capacity(indices.len());
    let size = decimal_length_from_precision(array.precision());
    for i in indices {
        let as_be_bytes = array.value(*i).to_be_bytes();
        let resized_value = as_be_bytes[(4 - size)..].to_vec();
        values.push(FixedLenByteArray::from(ByteArray::from(resized_value)));
    }
    values
}

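/// As [`get_decimal_32_array_slice`], truncating from an 8-byte source width.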
fn get_decimal_64_array_slice(
    array: &arrow_array::Decimal64Array,
    indices: &[usize],
) -> Vec<FixedLenByteArray> {
    let mut values = Vec::with_capacity(indices.len());
    let size = decimal_length_from_precision(array.precision());
    for i in indices {
        let as_be_bytes = array.value(*i).to_be_bytes();
        let resized_value = as_be_bytes[(8 - size)..].to_vec();
        values.push(FixedLenByteArray::from(ByteArray::from(resized_value)));
    }
    values
}

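/// As [`get_decimal_32_array_slice`], truncating from a 16-byte source width.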
fn get_decimal_128_array_slice(
    array: &arrow_array::Decimal128Array,
    indices: &[usize],
) -> Vec<FixedLenByteArray> {
    let mut values = Vec::with_capacity(indices.len());
    let size = decimal_length_from_precision(array.precision());
    for i in indices {
        let as_be_bytes = array.value(*i).to_be_bytes();
        let resized_value = as_be_bytes[(16 - size)..].to_vec();
        values.push(FixedLenByteArray::from(ByteArray::from(resized_value)));
    }
    values
}

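/// As [`get_decimal_32_array_slice`], truncating from a 32-byte source width.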
fn get_decimal_256_array_slice(
    array: &arrow_array::Decimal256Array,
    indices: &[usize],
) -> Vec<FixedLenByteArray> {
    let mut values = Vec::with_capacity(indices.len());
    let size = decimal_length_from_precision(array.precision());
    for i in indices {
        let as_be_bytes = array.value(*i).to_be_bytes();
        let resized_value = as_be_bytes[(32 - size)..].to_vec();
        values.push(FixedLenByteArray::from(ByteArray::from(resized_value)));
    }
    values
}

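/// Converts each Float16 value to its 2-byte little-endian representation,
/// the layout Parquet's Float16 logical type uses for its
/// FIXED_LEN_BYTE_ARRAY(2) physical values.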
fn get_float_16_array_slice(
    array: &arrow_array::Float16Array,
    indices: &[usize],
) -> Vec<FixedLenByteArray> {
    let mut values = Vec::with_capacity(indices.len());
    for i in indices {
        let value = array.value(*i).to_le_bytes().to_vec();
        values.push(FixedLenByteArray::from(ByteArray::from(value)));
    }
    values
}

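/// Copies the fixed-size binary values at `indices` into owned byte arrays.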
fn get_fsb_array_slice(
    array: &arrow_array::FixedSizeBinaryArray,
    indices: &[usize],
) -> Vec<FixedLenByteArray> {
    let mut values = Vec::with_capacity(indices.len());
    for i in indices {
        let value = array.value(*i).to_vec();
        values.push(FixedLenByteArray::from(ByteArray::from(value)))
    }
    values
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    use std::fs::File;

    use crate::arrow::arrow_reader::{ParquetRecordBatchReader, ParquetRecordBatchReaderBuilder};
    use crate::arrow::{ARROW_SCHEMA_META_KEY, PARQUET_FIELD_ID_META_KEY};
    use crate::column::page::{Page, PageReader};
    use crate::file::metadata::thrift::PageHeader;
    use crate::file::page_index::column_index::ColumnIndexMetaData;
    use crate::file::reader::SerializedPageReader;
    use crate::parquet_thrift::{ReadThrift, ThriftSliceInputProtocol};
    use crate::schema::types::ColumnPath;
    use arrow::datatypes::ToByteSlice;
    use arrow::datatypes::{DataType, Schema};
    use arrow::error::Result as ArrowResult;
    use arrow::util::data_gen::create_random_array;
    use arrow::util::pretty::pretty_format_batches;
    use arrow::{array::*, buffer::Buffer};
    use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, NullBuffer, OffsetBuffer, i256};
    use arrow_schema::Fields;
    use half::f16;
    use num_traits::{FromPrimitive, ToPrimitive};
    use tempfile::tempfile;

    use crate::basic::Encoding;
    use crate::data_type::AsBytes;
    use crate::file::metadata::{ColumnChunkMetaData, ParquetMetaData, ParquetMetaDataReader};
    use crate::file::properties::{
        BloomFilterPosition, EnabledStatistics, ReaderProperties, WriterVersion,
    };
    use crate::file::serialized_reader::ReadOptionsBuilder;
    use crate::file::{
        reader::{FileReader, SerializedFileReader},
        statistics::Statistics,
    };

    #[test]
    fn arrow_writer() {
        // define schema
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, true),
        ]);

        // create some data
        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = Int32Array::from(vec![Some(1), None, None, Some(4), Some(5)]);

        // build a record batch
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();

        roundtrip(batch, Some(SMALL_SIZE / 2));
    }

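    // The following two helpers exercise the two ways of recovering the encoded
    // bytes from an `ArrowWriter` over a `Vec<u8>`: closing the writer and
    // reusing the borrowed buffer, versus consuming the writer via `into_inner`.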
    fn get_bytes_after_close(schema: SchemaRef, expected_batch: &RecordBatch) -> Vec<u8> {
        let mut buffer = vec![];

        let mut writer = ArrowWriter::try_new(&mut buffer, schema, None).unwrap();
        writer.write(expected_batch).unwrap();
        writer.close().unwrap();

        buffer
    }

    fn get_bytes_by_into_inner(schema: SchemaRef, expected_batch: &RecordBatch) -> Vec<u8> {
        let mut writer = ArrowWriter::try_new(Vec::new(), schema, None).unwrap();
        writer.write(expected_batch).unwrap();
        writer.into_inner().unwrap()
    }

    #[test]
    fn roundtrip_bytes() {
        // define schema
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, true),
        ]));

        // create some data
        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = Int32Array::from(vec![Some(1), None, None, Some(4), Some(5)]);

        // build a record batch
        let expected_batch =
            RecordBatch::try_new(schema.clone(), vec![Arc::new(a), Arc::new(b)]).unwrap();

        for buffer in [
            get_bytes_after_close(schema.clone(), &expected_batch),
            get_bytes_by_into_inner(schema, &expected_batch),
        ] {
            let cursor = Bytes::from(buffer);
            let mut record_batch_reader = ParquetRecordBatchReader::try_new(cursor, 1024).unwrap();

            let actual_batch = record_batch_reader
                .next()
                .expect("No batch found")
                .expect("Unable to get batch");

            assert_eq!(expected_batch.schema(), actual_batch.schema());
            assert_eq!(expected_batch.num_columns(), actual_batch.num_columns());
            assert_eq!(expected_batch.num_rows(), actual_batch.num_rows());
            for i in 0..expected_batch.num_columns() {
                let expected_data = expected_batch.column(i).to_data();
                let actual_data = actual_batch.column(i).to_data();

                assert_eq!(expected_data, actual_data);
            }
        }
    }

    #[test]
    fn arrow_writer_non_null() {
        // define schema
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        // create some data
        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);

        // build a record batch
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap();

        roundtrip(batch, Some(SMALL_SIZE / 2));
    }

    #[test]
    fn arrow_writer_list() {
        // define schema
        let schema = Schema::new(vec![Field::new(
            "a",
            DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))),
            true,
        )]);

        // create some data
        let a_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);

        // Construct a buffer for value offsets, for the nested array:
        //  [[1], [2, 3], null, [4, 5, 6], [7, 8, 9, 10]]
        let a_value_offsets = arrow::buffer::Buffer::from([0, 1, 3, 3, 6, 10].to_byte_slice());

        // Construct a list array from the above two
        let a_list_data = ArrayData::builder(DataType::List(Arc::new(Field::new_list_field(
            DataType::Int32,
            false,
        ))))
        .len(5)
        .add_buffer(a_value_offsets)
        .add_child_data(a_values.into_data())
        .null_bit_buffer(Some(Buffer::from([0b00011011])))
        .build()
        .unwrap();
        let a = ListArray::from(a_list_data);

        // build a record batch
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap();

        assert_eq!(batch.column(0).null_count(), 1);

        // This test fails if the max row group size is less than the batch's length
        // see https://github.com/apache/arrow-rs/issues/518
        roundtrip(batch, None);
    }

    #[test]
    fn arrow_writer_list_non_null() {
        // define schema
        let schema = Schema::new(vec![Field::new(
            "a",
            DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))),
            false,
        )]);

        // create some data
        let a_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);

        // Construct a buffer for value offsets, for the nested array:
        //  [[1], [2, 3], [], [4, 5, 6], [7, 8, 9, 10]]
        let a_value_offsets = arrow::buffer::Buffer::from([0, 1, 3, 3, 6, 10].to_byte_slice());

        // Construct a list array from the above two
        let a_list_data = ArrayData::builder(DataType::List(Arc::new(Field::new_list_field(
            DataType::Int32,
            false,
        ))))
        .len(5)
        .add_buffer(a_value_offsets)
        .add_child_data(a_values.into_data())
        .build()
        .unwrap();
        let a = ListArray::from(a_list_data);

        // build a record batch
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap();

        // This test fails if the max row group size is less than the batch's length
        // see https://github.com/apache/arrow-rs/issues/518
        assert_eq!(batch.column(0).null_count(), 0);

        roundtrip(batch, None);
    }

    #[test]
    fn arrow_writer_list_view() {
        let list_field = Arc::new(Field::new_list_field(DataType::Int32, false));
        let schema = Schema::new(vec![Field::new(
            "a",
            DataType::ListView(list_field.clone()),
            true,
        )]);

        //  [[1], [2, 3], null, [4, 5, 6], [7, 8, 9, 10]]
        let a = ListViewArray::new(
            list_field,
            vec![0, 1, 0, 3, 6].into(),
            vec![1, 2, 0, 3, 4].into(),
            Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10])),
            Some(vec![true, true, false, true, true].into()),
        );

        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap();

        assert_eq!(batch.column(0).null_count(), 1);

        roundtrip(batch, None);
    }

    #[test]
    fn arrow_writer_list_view_non_null() {
        let list_field = Arc::new(Field::new_list_field(DataType::Int32, false));
        let schema = Schema::new(vec![Field::new(
            "a",
            DataType::ListView(list_field.clone()),
            false,
        )]);

        //  [[1], [2, 3], [], [4, 5, 6], [7, 8, 9, 10]]
        let a = ListViewArray::new(
            list_field,
            vec![0, 1, 0, 3, 6].into(),
            vec![1, 2, 0, 3, 4].into(),
            Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10])),
            None,
        );

        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap();

        assert_eq!(batch.column(0).null_count(), 0);

        roundtrip(batch, None);
    }

    #[test]
    fn arrow_writer_list_view_out_of_order() {
        let list_field = Arc::new(Field::new_list_field(DataType::Int32, false));
        let schema = Schema::new(vec![Field::new(
            "a",
            DataType::ListView(list_field.clone()),
            false,
        )]);

        // [[1], [2, 3], [], [7, 8, 9, 10], [4, 5, 6]] - out of order offsets
        let a = ListViewArray::new(
            list_field,
            vec![0, 1, 0, 6, 3].into(),
            vec![1, 2, 0, 4, 3].into(),
            Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10])),
            None,
        );

        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap();

        roundtrip(batch, None);
    }

    #[test]
    fn arrow_writer_large_list_view() {
        let list_field = Arc::new(Field::new_list_field(DataType::Int32, false));
        let schema = Schema::new(vec![Field::new(
            "a",
            DataType::LargeListView(list_field.clone()),
            true,
        )]);

        //  [[1], [2, 3], null, [4, 5, 6], [7, 8, 9, 10]]
        let a = LargeListViewArray::new(
            list_field,
            vec![0i64, 1, 0, 3, 6].into(),
            vec![1i64, 2, 0, 3, 4].into(),
            Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10])),
            Some(vec![true, true, false, true, true].into()),
        );

        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap();

        assert_eq!(batch.column(0).null_count(), 1);

        roundtrip(batch, None);
    }

    #[test]
    fn arrow_writer_list_view_with_struct() {
        // Test ListView containing Struct: ListView<Struct<Int32, Utf8>>
        let struct_fields = Fields::from(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]);
        let struct_type = DataType::Struct(struct_fields.clone());
        let list_field = Arc::new(Field::new("item", struct_type.clone(), false));

        let schema = Schema::new(vec![Field::new(
            "a",
            DataType::ListView(list_field.clone()),
            true,
        )]);

        // Create struct values
        let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let name_array = StringArray::from(vec!["a", "b", "c", "d", "e"]);
        let struct_array = StructArray::new(
            struct_fields,
            vec![Arc::new(id_array), Arc::new(name_array)],
            None,
        );

        // Create ListView: [{1, "a"}, {2, "b"}], null, [{3, "c"}, {4, "d"}, {5, "e"}]
        let list_view = ListViewArray::new(
            list_field,
            vec![0, 2, 2].into(), // offsets
            vec![2, 0, 3].into(), // sizes
            Arc::new(struct_array),
            Some(vec![true, false, true].into()),
        );

        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(list_view)]).unwrap();

        roundtrip(batch, None);
    }

    #[test]
    fn arrow_writer_binary() {
        let string_field = Field::new("a", DataType::Utf8, false);
        let binary_field = Field::new("b", DataType::Binary, false);
        let schema = Schema::new(vec![string_field, binary_field]);

        let raw_string_values = vec!["foo", "bar", "baz", "quux"];
        let raw_binary_values = [
            b"foo".to_vec(),
            b"bar".to_vec(),
            b"baz".to_vec(),
            b"quux".to_vec(),
        ];
        let raw_binary_value_refs = raw_binary_values
            .iter()
            .map(|x| x.as_slice())
            .collect::<Vec<_>>();

        let string_values = StringArray::from(raw_string_values.clone());
        let binary_values = BinaryArray::from(raw_binary_value_refs);
        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![Arc::new(string_values), Arc::new(binary_values)],
        )
        .unwrap();

        roundtrip(batch, Some(SMALL_SIZE / 2));
    }

    #[test]
    fn arrow_writer_binary_view() {
        let string_field = Field::new("a", DataType::Utf8View, false);
        let binary_field = Field::new("b", DataType::BinaryView, false);
        let nullable_string_field = Field::new("a", DataType::Utf8View, true);
        let schema = Schema::new(vec![string_field, binary_field, nullable_string_field]);

        let raw_string_values = vec!["foo", "bar", "large payload over 12 bytes", "lulu"];
        let raw_binary_values = vec![
            b"foo".to_vec(),
            b"bar".to_vec(),
            b"large payload over 12 bytes".to_vec(),
            b"lulu".to_vec(),
        ];
        let nullable_string_values =
            vec![Some("foo"), None, Some("large payload over 12 bytes"), None];

        let string_view_values = StringViewArray::from(raw_string_values);
        let binary_view_values = BinaryViewArray::from_iter_values(raw_binary_values);
        let nullable_string_view_values = StringViewArray::from(nullable_string_values);
        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(string_view_values),
                Arc::new(binary_view_values),
                Arc::new(nullable_string_view_values),
            ],
        )
        .unwrap();

        roundtrip(batch.clone(), Some(SMALL_SIZE / 2));
        roundtrip(batch, None);
    }

    #[test]
    fn arrow_writer_binary_view_long_value() {
        let string_field = Field::new("a", DataType::Utf8View, false);
        let binary_field = Field::new("b", DataType::BinaryView, false);
        let schema = Schema::new(vec![string_field, binary_field]);

        // There is special-case validation for long values (length 128 or greater).
        // 128 encodes as 0x80 0x00 0x00 0x00 in little endian, which should
        // trigger the long-string UTF-8 validation branch in the plain decoder.
        let long = "a".repeat(128);
        let raw_string_values = vec!["foo", long.as_str(), "bar"];
        let raw_binary_values = vec![b"foo".to_vec(), long.as_bytes().to_vec(), b"bar".to_vec()];

        let string_view_values: ArrayRef = Arc::new(StringViewArray::from(raw_string_values));
        let binary_view_values: ArrayRef =
            Arc::new(BinaryViewArray::from_iter_values(raw_binary_values));

        one_column_roundtrip(Arc::clone(&string_view_values), false);
        one_column_roundtrip(Arc::clone(&binary_view_values), false);

        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![string_view_values, binary_view_values],
        )
        .unwrap();

        // Disable dictionary to exercise plain encoding paths in the reader.
        for version in [WriterVersion::PARQUET_1_0, WriterVersion::PARQUET_2_0] {
            let props = WriterProperties::builder()
                .set_writer_version(version)
                .set_dictionary_enabled(false)
                .build();
            roundtrip_opts(&batch, props);
        }
    }

    fn get_decimal_batch(precision: u8, scale: i8) -> RecordBatch {
        let decimal_field = Field::new("a", DataType::Decimal128(precision, scale), false);
        let schema = Schema::new(vec![decimal_field]);

        let decimal_values = vec![10_000, 50_000, 0, -100]
            .into_iter()
            .map(Some)
            .collect::<Decimal128Array>()
            .with_precision_and_scale(precision, scale)
            .unwrap();

        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(decimal_values)]).unwrap()
    }

    #[test]
    fn arrow_writer_decimal() {
        // int32 to store the decimal value
        let batch_int32_decimal = get_decimal_batch(5, 2);
        roundtrip(batch_int32_decimal, Some(SMALL_SIZE / 2));
        // int64 to store the decimal value
        let batch_int64_decimal = get_decimal_batch(12, 2);
        roundtrip(batch_int64_decimal, Some(SMALL_SIZE / 2));
        // fixed_length_byte_array to store the decimal value
        let batch_fixed_len_byte_array_decimal = get_decimal_batch(30, 2);
        roundtrip(batch_fixed_len_byte_array_decimal, Some(SMALL_SIZE / 2));
    }

    #[test]
    fn arrow_writer_complex() {
        // define schema
        let struct_field_d = Arc::new(Field::new("d", DataType::Float64, true));
        let struct_field_f = Arc::new(Field::new("f", DataType::Float32, true));
        let struct_field_g = Arc::new(Field::new_list(
            "g",
            Field::new_list_field(DataType::Int16, true),
            false,
        ));
        let struct_field_h = Arc::new(Field::new_list(
            "h",
            Field::new_list_field(DataType::Int16, false),
            true,
        ));
        let struct_field_e = Arc::new(Field::new_struct(
            "e",
            vec![
                struct_field_f.clone(),
                struct_field_g.clone(),
                struct_field_h.clone(),
            ],
            false,
        ));
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, true),
            Field::new_struct(
                "c",
                vec![struct_field_d.clone(), struct_field_e.clone()],
                false,
            ),
        ]);

        // create some data
        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = Int32Array::from(vec![Some(1), None, None, Some(4), Some(5)]);
        let d = Float64Array::from(vec![None, None, None, Some(1.0), None]);
        let f = Float32Array::from(vec![Some(0.0), None, Some(333.3), None, Some(5.25)]);

        let g_value = Int16Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);

        // Construct a buffer for value offsets, for the nested array:
        //  [[1], [2, 3], [], [4, 5, 6], [7, 8, 9, 10]]
        let g_value_offsets = arrow::buffer::Buffer::from([0, 1, 3, 3, 6, 10].to_byte_slice());

        // Construct a list array from the above two
        let g_list_data = ArrayData::builder(struct_field_g.data_type().clone())
            .len(5)
            .add_buffer(g_value_offsets.clone())
            .add_child_data(g_value.to_data())
            .build()
            .unwrap();
        let g = ListArray::from(g_list_data);
        // The difference between g and h is that h has a null bitmap
        let h_list_data = ArrayData::builder(struct_field_h.data_type().clone())
            .len(5)
            .add_buffer(g_value_offsets)
            .add_child_data(g_value.to_data())
            .null_bit_buffer(Some(Buffer::from([0b00011011])))
            .build()
            .unwrap();
        let h = ListArray::from(h_list_data);

        let e = StructArray::from(vec![
            (struct_field_f, Arc::new(f) as ArrayRef),
            (struct_field_g, Arc::new(g) as ArrayRef),
            (struct_field_h, Arc::new(h) as ArrayRef),
        ]);

        let c = StructArray::from(vec![
            (struct_field_d, Arc::new(d) as ArrayRef),
            (struct_field_e, Arc::new(e) as ArrayRef),
        ]);

        // build a record batch
        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![Arc::new(a), Arc::new(b), Arc::new(c)],
        )
        .unwrap();

        roundtrip(batch.clone(), Some(SMALL_SIZE / 2));
        roundtrip(batch, Some(SMALL_SIZE / 3));
    }

    #[test]
    fn arrow_writer_complex_mixed() {
        // This test was added while investigating https://github.com/apache/arrow-rs/issues/244.
        // It was subsequently fixed while investigating https://github.com/apache/arrow-rs/issues/245.

        // define schema
        let offset_field = Arc::new(Field::new("offset", DataType::Int32, false));
        let partition_field = Arc::new(Field::new("partition", DataType::Int64, true));
        let topic_field = Arc::new(Field::new("topic", DataType::Utf8, true));
        let schema = Schema::new(vec![Field::new(
            "some_nested_object",
            DataType::Struct(Fields::from(vec![
                offset_field.clone(),
                partition_field.clone(),
                topic_field.clone(),
            ])),
            false,
        )]);

        // create some data
        let offset = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let partition = Int64Array::from(vec![Some(1), None, None, Some(4), Some(5)]);
        let topic = StringArray::from(vec![Some("A"), None, Some("A"), Some(""), None]);

        let some_nested_object = StructArray::from(vec![
            (offset_field, Arc::new(offset) as ArrayRef),
            (partition_field, Arc::new(partition) as ArrayRef),
            (topic_field, Arc::new(topic) as ArrayRef),
        ]);

        // build a record batch
        let batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(some_nested_object)]).unwrap();

        roundtrip(batch, Some(SMALL_SIZE / 2));
    }

    #[test]
    fn arrow_writer_map() {
        // Note: we are using the JSON Arrow reader for brevity
        let json_content = r#"
        {"stocks":{"long": "$AAA", "short": "$BBB"}}
        {"stocks":{"long": null, "long": "$CCC", "short": null}}
        {"stocks":{"hedged": "$YYY", "long": null, "short": "$D"}}
        "#;
        let entries_struct_type = DataType::Struct(Fields::from(vec![
            Field::new("key", DataType::Utf8, false),
            Field::new("value", DataType::Utf8, true),
        ]));
        let stocks_field = Field::new(
            "stocks",
            DataType::Map(
                Arc::new(Field::new("entries", entries_struct_type, false)),
                false,
            ),
            true,
        );
        let schema = Arc::new(Schema::new(vec![stocks_field]));
        let builder = arrow::json::ReaderBuilder::new(schema).with_batch_size(64);
        let mut reader = builder.build(std::io::Cursor::new(json_content)).unwrap();

        let batch = reader.next().unwrap().unwrap();
        roundtrip(batch, None);
    }

    #[test]
    fn arrow_writer_2_level_struct() {
        // tests writing struct<struct<primitive>>
        let field_c = Field::new("c", DataType::Int32, true);
        let field_b = Field::new("b", DataType::Struct(vec![field_c].into()), true);
        let type_a = DataType::Struct(vec![field_b.clone()].into());
        let field_a = Field::new("a", type_a, true);
        let schema = Schema::new(vec![field_a.clone()]);

        // create data
        let c = Int32Array::from(vec![Some(1), None, Some(3), None, None, Some(6)]);
        let b_data = ArrayDataBuilder::new(field_b.data_type().clone())
            .len(6)
            .null_bit_buffer(Some(Buffer::from([0b00100111])))
            .add_child_data(c.into_data())
            .build()
            .unwrap();
        let b = StructArray::from(b_data);
        let a_data = ArrayDataBuilder::new(field_a.data_type().clone())
            .len(6)
            .null_bit_buffer(Some(Buffer::from([0b00101111])))
            .add_child_data(b.into_data())
            .build()
            .unwrap();
        let a = StructArray::from(a_data);

        assert_eq!(a.null_count(), 1);
        assert_eq!(a.column(0).null_count(), 2);

        // build a record batch
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap();

        roundtrip(batch, Some(SMALL_SIZE / 2));
    }

    #[test]
    fn arrow_writer_2_level_struct_non_null() {
        // tests writing struct<struct<primitive>>
        let field_c = Field::new("c", DataType::Int32, false);
        let type_b = DataType::Struct(vec![field_c].into());
        let field_b = Field::new("b", type_b.clone(), false);
        let type_a = DataType::Struct(vec![field_b].into());
        let field_a = Field::new("a", type_a.clone(), false);
        let schema = Schema::new(vec![field_a]);

        // create data
        let c = Int32Array::from(vec![1, 2, 3, 4, 5, 6]);
        let b_data = ArrayDataBuilder::new(type_b)
            .len(6)
            .add_child_data(c.into_data())
            .build()
            .unwrap();
        let b = StructArray::from(b_data);
        let a_data = ArrayDataBuilder::new(type_a)
            .len(6)
            .add_child_data(b.into_data())
            .build()
            .unwrap();
        let a = StructArray::from(a_data);

        assert_eq!(a.null_count(), 0);
        assert_eq!(a.column(0).null_count(), 0);

        // build a record batch
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap();

        roundtrip(batch, Some(SMALL_SIZE / 2));
    }

    #[test]
    fn arrow_writer_2_level_struct_mixed_null() {
        // tests writing struct<struct<primitive>>
        let field_c = Field::new("c", DataType::Int32, false);
        let type_b = DataType::Struct(vec![field_c].into());
        let field_b = Field::new("b", type_b.clone(), true);
        let type_a = DataType::Struct(vec![field_b].into());
        let field_a = Field::new("a", type_a.clone(), false);
        let schema = Schema::new(vec![field_a]);

        // create data
        let c = Int32Array::from(vec![1, 2, 3, 4, 5, 6]);
        let b_data = ArrayDataBuilder::new(type_b)
            .len(6)
            .null_bit_buffer(Some(Buffer::from([0b00100111])))
            .add_child_data(c.into_data())
            .build()
            .unwrap();
        let b = StructArray::from(b_data);
        // `a` intentionally has no null buffer, to test that this is handled correctly
        let a_data = ArrayDataBuilder::new(type_a)
            .len(6)
            .add_child_data(b.into_data())
            .build()
            .unwrap();
        let a = StructArray::from(a_data);

        assert_eq!(a.null_count(), 0);
        assert_eq!(a.column(0).null_count(), 2);

        // build a record batch
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap();

        roundtrip(batch, Some(SMALL_SIZE / 2));
    }

    #[test]
    fn arrow_writer_2_level_struct_mixed_null_2() {
        // tests writing struct<struct<primitive>>, where the primitive columns are non-null.
        let field_c = Field::new("c", DataType::Int32, false);
        let field_d = Field::new("d", DataType::FixedSizeBinary(4), false);
        let field_e = Field::new(
            "e",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            false,
        );

        let field_b = Field::new(
            "b",
            DataType::Struct(vec![field_c, field_d, field_e].into()),
            false,
        );
        let type_a = DataType::Struct(vec![field_b.clone()].into());
        let field_a = Field::new("a", type_a, true);
        let schema = Schema::new(vec![field_a.clone()]);

        // create data
        let c = Int32Array::from_iter_values(0..6);
        let d = FixedSizeBinaryArray::try_from_iter(
            ["aaaa", "bbbb", "cccc", "dddd", "eeee", "ffff"].into_iter(),
        )
        .expect("four byte values");
        let e = Int32DictionaryArray::from_iter(["one", "two", "three", "four", "five", "one"]);
        let b_data = ArrayDataBuilder::new(field_b.data_type().clone())
            .len(6)
            .add_child_data(c.into_data())
            .add_child_data(d.into_data())
            .add_child_data(e.into_data())
            .build()
            .unwrap();
        let b = StructArray::from(b_data);
        let a_data = ArrayDataBuilder::new(field_a.data_type().clone())
            .len(6)
            .null_bit_buffer(Some(Buffer::from([0b00100101])))
            .add_child_data(b.into_data())
            .build()
            .unwrap();
        let a = StructArray::from(a_data);

        assert_eq!(a.null_count(), 3);
        assert_eq!(a.column(0).null_count(), 0);

        // build a record batch
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap();

        roundtrip(batch, Some(SMALL_SIZE / 2));
    }

    #[test]
    fn test_fixed_size_binary_in_dict() {
        fn test_fixed_size_binary_in_dict_inner<K>()
        where
            K: ArrowDictionaryKeyType,
            K::Native: FromPrimitive + ToPrimitive + TryFrom<u8>,
            <<K as arrow_array::ArrowPrimitiveType>::Native as TryFrom<u8>>::Error: std::fmt::Debug,
        {
            let field = Field::new(
                "a",
                DataType::Dictionary(
                    Box::new(K::DATA_TYPE),
                    Box::new(DataType::FixedSizeBinary(4)),
                ),
                false,
            );
            let schema = Schema::new(vec![field]);

            let keys: Vec<K::Native> = vec![
                K::Native::try_from(0u8).unwrap(),
                K::Native::try_from(0u8).unwrap(),
                K::Native::try_from(1u8).unwrap(),
            ];
            let keys = PrimitiveArray::<K>::from_iter_values(keys);
            let values = FixedSizeBinaryArray::try_from_iter(
                vec![vec![0, 0, 0, 0], vec![1, 1, 1, 1]].into_iter(),
            )
            .unwrap();

            let data = DictionaryArray::<K>::new(keys, Arc::new(values));
            let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)]).unwrap();
            roundtrip(batch, None);
        }

        test_fixed_size_binary_in_dict_inner::<UInt8Type>();
        test_fixed_size_binary_in_dict_inner::<UInt16Type>();
        test_fixed_size_binary_in_dict_inner::<UInt32Type>();
        test_fixed_size_binary_in_dict_inner::<UInt64Type>();
        test_fixed_size_binary_in_dict_inner::<Int8Type>();
        test_fixed_size_binary_in_dict_inner::<Int16Type>();
        test_fixed_size_binary_in_dict_inner::<Int32Type>();
        test_fixed_size_binary_in_dict_inner::<Int64Type>();
    }

    #[test]
    fn test_empty_dict() {
        let struct_fields = Fields::from(vec![Field::new(
            "dict",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            false,
        )]);

        let schema = Schema::new(vec![Field::new_struct(
            "struct",
            struct_fields.clone(),
            true,
        )]);
        let dictionary = Arc::new(DictionaryArray::new(
            Int32Array::new_null(5),
            Arc::new(StringArray::new_null(0)),
        ));

        let s = StructArray::new(
            struct_fields,
            vec![dictionary],
            Some(NullBuffer::new_null(5)),
        );

        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(s)]).unwrap();
        roundtrip(batch, None);
    }

    #[test]
    fn arrow_writer_page_size() {
        let schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Utf8, false)]));

        let mut builder = StringBuilder::with_capacity(100, 329 * 10_000);

        // Generate an array of 10 unique 10-character strings
        for i in 0..10 {
            let value = i
                .to_string()
                .repeat(10)
                .chars()
                .take(10)
                .collect::<String>();

            builder.append_value(value);
        }

        let array = Arc::new(builder.finish());

        let batch = RecordBatch::try_new(schema, vec![array]).unwrap();

        let file = tempfile::tempfile().unwrap();

        // Set everything very low so we fall back to PLAIN encoding after the first row
        let props = WriterProperties::builder()
            .set_data_page_size_limit(1)
            .set_dictionary_page_size_limit(1)
            .set_write_batch_size(1)
            .build();

        let mut writer =
            ArrowWriter::try_new(file.try_clone().unwrap(), batch.schema(), Some(props))
                .expect("Unable to write file");
        writer.write(&batch).unwrap();
        writer.close().unwrap();

        let options = ReadOptionsBuilder::new().with_page_index().build();
        let reader =
            SerializedFileReader::new_with_options(file.try_clone().unwrap(), options).unwrap();

        let column = reader.metadata().row_group(0).columns();

        assert_eq!(column.len(), 1);

        // We should write one row before falling back to PLAIN encoding, so there should still be a
        // dictionary page.
        assert!(
            column[0].dictionary_page_offset().is_some(),
            "Expected a dictionary page"
        );

        assert!(reader.metadata().offset_index().is_some());
        let offset_indexes = &reader.metadata().offset_index().unwrap()[0];

        let page_locations = offset_indexes[0].page_locations.clone();

        // We should fall back to PLAIN encoding after the first row, and our max page size is 1 byte,
        // so we expect one dictionary encoded page and then a page per row thereafter.
        assert_eq!(
            page_locations.len(),
            10,
            "Expected 10 pages but got {page_locations:#?}"
        );
    }

    #[test]
    fn arrow_writer_float_nans() {
        let f16_field = Field::new("a", DataType::Float16, false);
        let f32_field = Field::new("b", DataType::Float32, false);
        let f64_field = Field::new("c", DataType::Float64, false);
        let schema = Schema::new(vec![f16_field, f32_field, f64_field]);

        let f16_values = (0..MEDIUM_SIZE)
            .map(|i| {
                Some(if i % 2 == 0 {
                    f16::NAN
                } else {
                    f16::from_f32(i as f32)
                })
            })
            .collect::<Float16Array>();

        let f32_values = (0..MEDIUM_SIZE)
            .map(|i| Some(if i % 2 == 0 { f32::NAN } else { i as f32 }))
            .collect::<Float32Array>();

        let f64_values = (0..MEDIUM_SIZE)
            .map(|i| Some(if i % 2 == 0 { f64::NAN } else { i as f64 }))
            .collect::<Float64Array>();

        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(f16_values),
                Arc::new(f32_values),
                Arc::new(f64_values),
            ],
        )
        .unwrap();

        roundtrip(batch, None);
    }

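    // Row counts used by the tests below; passing `SMALL_SIZE / 2` as the max
    // row group size forces a batch to be split across multiple row groups.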
    const SMALL_SIZE: usize = 7;
    const MEDIUM_SIZE: usize = 63;

    // Write the batch to parquet and read it back out, ensuring
    // that what comes out is the same as what was written in
    fn roundtrip(expected_batch: RecordBatch, max_row_group_size: Option<usize>) -> Vec<Bytes> {
        let mut files = vec![];
        for version in [WriterVersion::PARQUET_1_0, WriterVersion::PARQUET_2_0] {
            let mut props = WriterProperties::builder().set_writer_version(version);

            if let Some(size) = max_row_group_size {
                props = props.set_max_row_group_row_count(Some(size))
            }

            let props = props.build();
            files.push(roundtrip_opts(&expected_batch, props))
        }
        files
    }

    // Round trip the specified record batch with the specified writer properties,
    // to an in-memory file, and validate the arrays using the specified function.
    // Returns the in-memory file.
    fn roundtrip_opts_with_array_validation<F>(
        expected_batch: &RecordBatch,
        props: WriterProperties,
        validate: F,
    ) -> Bytes
    where
        F: Fn(&ArrayData, &ArrayData),
    {
        let mut file = vec![];

        let mut writer = ArrowWriter::try_new(&mut file, expected_batch.schema(), Some(props))
            .expect("Unable to write file");
        writer.write(expected_batch).unwrap();
        writer.close().unwrap();

        let file = Bytes::from(file);
        let mut record_batch_reader =
            ParquetRecordBatchReader::try_new(file.clone(), 1024).unwrap();

        let actual_batch = record_batch_reader
            .next()
            .expect("No batch found")
            .expect("Unable to get batch");

        assert_eq!(expected_batch.schema(), actual_batch.schema());
        assert_eq!(expected_batch.num_columns(), actual_batch.num_columns());
        assert_eq!(expected_batch.num_rows(), actual_batch.num_rows());
        for i in 0..expected_batch.num_columns() {
            let expected_data = expected_batch.column(i).to_data();
            let actual_data = actual_batch.column(i).to_data();
            validate(&expected_data, &actual_data);
        }

        file
    }

    fn roundtrip_opts(expected_batch: &RecordBatch, props: WriterProperties) -> Bytes {
        roundtrip_opts_with_array_validation(expected_batch, props, |a, b| {
            a.validate_full().expect("valid expected data");
            b.validate_full().expect("valid actual data");
            assert_eq!(a, b)
        })
    }

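    /// Options for a single-column roundtrip: the array to write, the schema to
    /// write it under, and whether (and where) bloom filters are written.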
    struct RoundTripOptions {
        values: ArrayRef,
        schema: SchemaRef,
        bloom_filter: bool,
        bloom_filter_position: BloomFilterPosition,
    }

    impl RoundTripOptions {
        fn new(values: ArrayRef, nullable: bool) -> Self {
            let data_type = values.data_type().clone();
            let schema = Schema::new(vec![Field::new("col", data_type, nullable)]);
            Self {
                values,
                schema: Arc::new(schema),
                bloom_filter: false,
                bloom_filter_position: BloomFilterPosition::AfterRowGroup,
            }
        }
    }

    fn one_column_roundtrip(values: ArrayRef, nullable: bool) -> Vec<Bytes> {
        one_column_roundtrip_with_options(RoundTripOptions::new(values, nullable))
    }

    fn one_column_roundtrip_with_schema(values: ArrayRef, schema: SchemaRef) -> Vec<Bytes> {
        let mut options = RoundTripOptions::new(values, false);
        options.schema = schema;
        one_column_roundtrip_with_options(options)
    }

    fn one_column_roundtrip_with_options(options: RoundTripOptions) -> Vec<Bytes> {
        let RoundTripOptions {
            values,
            schema,
            bloom_filter,
            bloom_filter_position,
        } = options;

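        // Pick the encodings that apply to the value type; each roundtrip below
        // then runs once per (encoding, writer version, row group size,
        // dictionary setting) combination.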
        let encodings = match values.data_type() {
            DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary => {
                vec![
                    Encoding::PLAIN,
                    Encoding::DELTA_BYTE_ARRAY,
                    Encoding::DELTA_LENGTH_BYTE_ARRAY,
                ]
            }
            DataType::Int64
            | DataType::Int32
            | DataType::Int16
            | DataType::Int8
            | DataType::UInt64
            | DataType::UInt32
            | DataType::UInt16
            | DataType::UInt8 => vec![
                Encoding::PLAIN,
                Encoding::DELTA_BINARY_PACKED,
                Encoding::BYTE_STREAM_SPLIT,
            ],
            DataType::Float32 | DataType::Float64 => {
                vec![Encoding::PLAIN, Encoding::BYTE_STREAM_SPLIT]
            }
            _ => vec![Encoding::PLAIN],
        };

        let expected_batch = RecordBatch::try_new(schema, vec![values]).unwrap();

        let row_group_sizes = [1024, SMALL_SIZE, SMALL_SIZE / 2, SMALL_SIZE / 2 + 1, 10];

        let mut files = vec![];
        for dictionary_size in [0, 1, 1024] {
            for encoding in &encodings {
                for version in [WriterVersion::PARQUET_1_0, WriterVersion::PARQUET_2_0] {
                    for row_group_size in row_group_sizes {
                        let props = WriterProperties::builder()
                            .set_writer_version(version)
                            .set_max_row_group_row_count(Some(row_group_size))
                            .set_dictionary_enabled(dictionary_size != 0)
                            .set_dictionary_page_size_limit(dictionary_size.max(1))
                            .set_encoding(*encoding)
                            .set_bloom_filter_enabled(bloom_filter)
                            .set_bloom_filter_position(bloom_filter_position)
                            .build();

                        files.push(roundtrip_opts(&expected_batch, props))
                    }
                }
            }
        }
        files
    }

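    /// Roundtrips a required (non-nullable) column built from `iter`.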
    fn values_required<A, I>(iter: I) -> Vec<Bytes>
    where
        A: From<Vec<I::Item>> + Array + 'static,
        I: IntoIterator,
    {
        let raw_values: Vec<_> = iter.into_iter().collect();
        let values = Arc::new(A::from(raw_values));
        one_column_roundtrip(values, false)
    }

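    /// Roundtrips an optional (nullable) column built from `iter`, replacing
    /// every even-indexed value with a null.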
    fn values_optional<A, I>(iter: I) -> Vec<Bytes>
    where
        A: From<Vec<Option<I::Item>>> + Array + 'static,
        I: IntoIterator,
    {
        let optional_raw_values: Vec<_> = iter
            .into_iter()
            .enumerate()
            .map(|(i, v)| if i % 2 == 0 { None } else { Some(v) })
            .collect();
        let optional_values = Arc::new(A::from(optional_raw_values));
        one_column_roundtrip(optional_values, true)
    }

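    /// Runs both the required and the optional roundtrip over the same input.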
    fn required_and_optional<A, I>(iter: I)
    where
        A: From<Vec<I::Item>> + From<Vec<Option<I::Item>>> + Array + 'static,
        I: IntoIterator + Clone,
    {
        values_required::<A, I>(iter.clone());
        values_optional::<A, I>(iter);
    }

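    // Opens each file with bloom filter reading enabled and asserts that every
    // positive value is matched by some row group's filter while no negative
    // value is.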
2734    fn check_bloom_filter<T: AsBytes>(
2735        files: Vec<Bytes>,
2736        file_column: String,
2737        positive_values: Vec<T>,
2738        negative_values: Vec<T>,
2739    ) {
2740        files.into_iter().take(1).for_each(|file| {
2741            let file_reader = SerializedFileReader::new_with_options(
2742                file,
2743                ReadOptionsBuilder::new()
2744                    .with_reader_properties(
2745                        ReaderProperties::builder()
2746                            .set_read_bloom_filter(true)
2747                            .build(),
2748                    )
2749                    .build(),
2750            )
2751            .expect("Unable to open file as Parquet");
2752            let metadata = file_reader.metadata();
2753
2754            // Gets bloom filters from all row groups.
2755            let mut bloom_filters: Vec<_> = vec![];
2756            for (ri, row_group) in metadata.row_groups().iter().enumerate() {
2757                if let Some((column_index, _)) = row_group
2758                    .columns()
2759                    .iter()
2760                    .enumerate()
2761                    .find(|(_, column)| column.column_path().string() == file_column)
2762                {
2763                    let row_group_reader = file_reader
2764                        .get_row_group(ri)
2765                        .expect("Unable to read row group");
2766                    if let Some(sbbf) = row_group_reader.get_column_bloom_filter(column_index) {
2767                        bloom_filters.push(sbbf.clone());
2768                    } else {
2769                        panic!("No bloom filter for column named {file_column} found");
2770                    }
2771                } else {
2772                    panic!("No column named {file_column} found");
2773                }
2774            }
2775
2776            positive_values.iter().for_each(|value| {
2777                let found = bloom_filters.iter().find(|sbbf| sbbf.check(value));
2778                assert!(
2779                    found.is_some(),
2780                    "Value {:?} should be in bloom filter",
2781                    value.as_bytes()
2782                );
2783            });
2784
2785            negative_values.iter().for_each(|value| {
2786                let found = bloom_filters.iter().find(|sbbf| sbbf.check(value));
2787                assert!(
2788                    found.is_none(),
2789                    "Value {:?} should not be in bloom filter",
2790                    value.as_bytes()
2791                );
2792            });
2793        });
2794    }
2795
2796    #[test]
2797    fn all_null_primitive_single_column() {
2798        let values = Arc::new(Int32Array::from(vec![None; SMALL_SIZE]));
2799        one_column_roundtrip(values, true);
2800    }

2801    #[test]
2802    fn null_single_column() {
2803        let values = Arc::new(NullArray::new(SMALL_SIZE));
2804        one_column_roundtrip(values, true);
2805        // null arrays are always nullable; a test with non-nullable nulls would fail
2806    }
2807
2808    #[test]
2809    fn bool_single_column() {
2810        required_and_optional::<BooleanArray, _>(
2811            [true, false].iter().cycle().copied().take(SMALL_SIZE),
2812        );
2813    }
2814
2815    #[test]
2816    fn bool_large_single_column() {
2817        let values = Arc::new(
2818            [None, Some(true), Some(false)]
2819                .iter()
2820                .cycle()
2821                .copied()
2822                .take(200_000)
2823                .collect::<BooleanArray>(),
2824        );
2825        let schema = Schema::new(vec![Field::new("col", values.data_type().clone(), true)]);
2826        let expected_batch = RecordBatch::try_new(Arc::new(schema), vec![values]).unwrap();
2827        let file = tempfile::tempfile().unwrap();
2828
2829        let mut writer =
2830            ArrowWriter::try_new(file.try_clone().unwrap(), expected_batch.schema(), None)
2831                .expect("Unable to write file");
2832        writer.write(&expected_batch).unwrap();
2833        writer.close().unwrap();
2834    }
2835
2836    #[test]
2837    fn check_page_offset_index_with_nan() {
2838        let values = Arc::new(Float64Array::from(vec![f64::NAN; 10]));
2839        let schema = Schema::new(vec![Field::new("col", DataType::Float64, true)]);
2840        let batch = RecordBatch::try_new(Arc::new(schema), vec![values]).unwrap();
2841
2842        let mut out = Vec::with_capacity(1024);
2843        let mut writer =
2844            ArrowWriter::try_new(&mut out, batch.schema(), None).expect("Unable to write file");
2845        writer.write(&batch).unwrap();
2846        let file_meta_data = writer.close().unwrap();
2847        for row_group in file_meta_data.row_groups() {
2848            for column in row_group.columns() {
2849                assert!(column.offset_index_offset().is_some());
2850                assert!(column.offset_index_length().is_some());
2851                assert!(column.column_index_offset().is_none());
2852                assert!(column.column_index_length().is_none());
2853            }
2854        }
2855    }
2856
2857    #[test]
2858    fn i8_single_column() {
2859        required_and_optional::<Int8Array, _>(0..SMALL_SIZE as i8);
2860    }
2861
2862    #[test]
2863    fn i16_single_column() {
2864        required_and_optional::<Int16Array, _>(0..SMALL_SIZE as i16);
2865    }
2866
2867    #[test]
2868    fn i32_single_column() {
2869        required_and_optional::<Int32Array, _>(0..SMALL_SIZE as i32);
2870    }
2871
2872    #[test]
2873    fn i64_single_column() {
2874        required_and_optional::<Int64Array, _>(0..SMALL_SIZE as i64);
2875    }
2876
2877    #[test]
2878    fn u8_single_column() {
2879        required_and_optional::<UInt8Array, _>(0..SMALL_SIZE as u8);
2880    }
2881
2882    #[test]
2883    fn u16_single_column() {
2884        required_and_optional::<UInt16Array, _>(0..SMALL_SIZE as u16);
2885    }
2886
2887    #[test]
2888    fn u32_single_column() {
2889        required_and_optional::<UInt32Array, _>(0..SMALL_SIZE as u32);
2890    }
2891
2892    #[test]
2893    fn u64_single_column() {
2894        required_and_optional::<UInt64Array, _>(0..SMALL_SIZE as u64);
2895    }
2896
2897    #[test]
2898    fn f32_single_column() {
2899        required_and_optional::<Float32Array, _>((0..SMALL_SIZE).map(|i| i as f32));
2900    }
2901
2902    #[test]
2903    fn f64_single_column() {
2904        required_and_optional::<Float64Array, _>((0..SMALL_SIZE).map(|i| i as f64));
2905    }
2906
2907    // The timestamp array types don't implement From<Vec<T>> because they need the timezone
2908    // argument, and they also don't support building from a Vec<Option<T>>, so call
2909    // one_column_roundtrip manually instead of calling required_and_optional for these tests.
2910
2911    #[test]
2912    fn timestamp_second_single_column() {
2913        let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect();
2914        let values = Arc::new(TimestampSecondArray::from(raw_values));
2915
2916        one_column_roundtrip(values, false);
2917    }
2918
2919    #[test]
2920    fn timestamp_millisecond_single_column() {
2921        let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect();
2922        let values = Arc::new(TimestampMillisecondArray::from(raw_values));
2923
2924        one_column_roundtrip(values, false);
2925    }
2926
2927    #[test]
2928    fn timestamp_microsecond_single_column() {
2929        let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect();
2930        let values = Arc::new(TimestampMicrosecondArray::from(raw_values));
2931
2932        one_column_roundtrip(values, false);
2933    }
2934
2935    #[test]
2936    fn timestamp_nanosecond_single_column() {
2937        let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect();
2938        let values = Arc::new(TimestampNanosecondArray::from(raw_values));
2939
2940        one_column_roundtrip(values, false);
2941    }
2942
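    // A minimal sketch, not part of the original suite: the same round-trip
    // also covers timezone-annotated timestamps, since `with_timezone` (the
    // arrow-array API for attaching a timezone) only changes the logical
    // type, not the stored i64 values.
    #[test]
    fn timestamp_second_with_timezone_single_column() {
        let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect();
        let values = Arc::new(TimestampSecondArray::from(raw_values).with_timezone("UTC"));

        one_column_roundtrip(values, false);
    }
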
2943    #[test]
2944    fn date32_single_column() {
2945        required_and_optional::<Date32Array, _>(0..SMALL_SIZE as i32);
2946    }
2947
2948    #[test]
2949    fn date64_single_column() {
2950        // Date64 must be a multiple of 86400000, see ARROW-10925
2951        required_and_optional::<Date64Array, _>(
2952            (0..(SMALL_SIZE as i64 * 86400000)).step_by(86400000),
2953        );
2954    }
2955
2956    #[test]
2957    fn time32_second_single_column() {
2958        required_and_optional::<Time32SecondArray, _>(0..SMALL_SIZE as i32);
2959    }
2960
2961    #[test]
2962    fn time32_millisecond_single_column() {
2963        required_and_optional::<Time32MillisecondArray, _>(0..SMALL_SIZE as i32);
2964    }
2965
2966    #[test]
2967    fn time64_microsecond_single_column() {
2968        required_and_optional::<Time64MicrosecondArray, _>(0..SMALL_SIZE as i64);
2969    }
2970
2971    #[test]
2972    fn time64_nanosecond_single_column() {
2973        required_and_optional::<Time64NanosecondArray, _>(0..SMALL_SIZE as i64);
2974    }
2975
2976    #[test]
2977    fn duration_second_single_column() {
2978        required_and_optional::<DurationSecondArray, _>(0..SMALL_SIZE as i64);
2979    }
2980
2981    #[test]
2982    fn duration_millisecond_single_column() {
2983        required_and_optional::<DurationMillisecondArray, _>(0..SMALL_SIZE as i64);
2984    }
2985
2986    #[test]
2987    fn duration_microsecond_single_column() {
2988        required_and_optional::<DurationMicrosecondArray, _>(0..SMALL_SIZE as i64);
2989    }
2990
2991    #[test]
2992    fn duration_nanosecond_single_column() {
2993        required_and_optional::<DurationNanosecondArray, _>(0..SMALL_SIZE as i64);
2994    }
2995
2996    #[test]
2997    fn interval_year_month_single_column() {
2998        required_and_optional::<IntervalYearMonthArray, _>(0..SMALL_SIZE as i32);
2999    }
3000
3001    #[test]
3002    fn interval_day_time_single_column() {
3003        required_and_optional::<IntervalDayTimeArray, _>(vec![
3004            IntervalDayTime::new(0, 1),
3005            IntervalDayTime::new(0, 3),
3006            IntervalDayTime::new(3, -2),
3007            IntervalDayTime::new(-200, 4),
3008        ]);
3009    }
3010
3011    #[test]
3012    #[should_panic(
3013        expected = "Attempting to write an Arrow interval type MonthDayNano to parquet that is not yet implemented"
3014    )]
3015    fn interval_month_day_nano_single_column() {
3016        required_and_optional::<IntervalMonthDayNanoArray, _>(vec![
3017            IntervalMonthDayNano::new(0, 1, 5),
3018            IntervalMonthDayNano::new(0, 3, 2),
3019            IntervalMonthDayNano::new(3, -2, -5),
3020            IntervalMonthDayNano::new(-200, 4, -1),
3021        ]);
3022    }
3023
3024    #[test]
3025    fn binary_single_column() {
3026        let one_vec: Vec<u8> = (0..SMALL_SIZE as u8).collect();
3027        let many_vecs: Vec<_> = std::iter::repeat_n(one_vec, SMALL_SIZE).collect();
3028        let many_vecs_iter = many_vecs.iter().map(|v| v.as_slice());
3029
3030        // BinaryArrays can't be built from Vec<Option<&[u8]>>, so only call `values_required`
3031        values_required::<BinaryArray, _>(many_vecs_iter);
3032    }
3033
3034    #[test]
3035    fn binary_view_single_column() {
3036        let one_vec: Vec<u8> = (0..SMALL_SIZE as u8).collect();
3037        let many_vecs: Vec<_> = std::iter::repeat_n(one_vec, SMALL_SIZE).collect();
3038        let many_vecs_iter = many_vecs.iter().map(|v| v.as_slice());
3039
3040        // BinaryViewArrays can't be built from Vec<Option<&[u8]>>, so only call `values_required`
3041        values_required::<BinaryViewArray, _>(many_vecs_iter);
3042    }
3043
3044    #[test]
3045    fn i32_column_bloom_filter_at_end() {
3046        let array = Arc::new(Int32Array::from_iter(0..SMALL_SIZE as i32));
3047        let mut options = RoundTripOptions::new(array, false);
3048        options.bloom_filter = true;
3049        options.bloom_filter_position = BloomFilterPosition::End;
3050
3051        let files = one_column_roundtrip_with_options(options);
3052        check_bloom_filter(
3053            files,
3054            "col".to_string(),
3055            (0..SMALL_SIZE as i32).collect(),
3056            (SMALL_SIZE as i32 + 1..SMALL_SIZE as i32 + 10).collect(),
3057        );
3058    }
3059
3060    #[test]
3061    fn i32_column_bloom_filter() {
3062        let array = Arc::new(Int32Array::from_iter(0..SMALL_SIZE as i32));
3063        let mut options = RoundTripOptions::new(array, false);
3064        options.bloom_filter = true;
3065
3066        let files = one_column_roundtrip_with_options(options);
3067        check_bloom_filter(
3068            files,
3069            "col".to_string(),
3070            (0..SMALL_SIZE as i32).collect(),
3071            (SMALL_SIZE as i32 + 1..SMALL_SIZE as i32 + 10).collect(),
3072        );
3073    }
3074
3075    #[test]
3076    fn binary_column_bloom_filter() {
3077        let one_vec: Vec<u8> = (0..SMALL_SIZE as u8).collect();
3078        let many_vecs: Vec<_> = std::iter::repeat_n(one_vec, SMALL_SIZE).collect();
3079        let many_vecs_iter = many_vecs.iter().map(|v| v.as_slice());
3080
3081        let array = Arc::new(BinaryArray::from_iter_values(many_vecs_iter));
3082        let mut options = RoundTripOptions::new(array, false);
3083        options.bloom_filter = true;
3084
3085        let files = one_column_roundtrip_with_options(options);
3086        check_bloom_filter(
3087            files,
3088            "col".to_string(),
3089            many_vecs,
3090            vec![vec![(SMALL_SIZE + 1) as u8]],
3091        );
3092    }
3093
3094    #[test]
3095    fn empty_string_null_column_bloom_filter() {
3096        let raw_values: Vec<_> = (0..SMALL_SIZE).map(|i| i.to_string()).collect();
3097        let raw_strs = raw_values.iter().map(|s| s.as_str());
3098
3099        let array = Arc::new(StringArray::from_iter_values(raw_strs));
3100        let mut options = RoundTripOptions::new(array, false);
3101        options.bloom_filter = true;
3102
3103        let files = one_column_roundtrip_with_options(options);
3104
3105        let optional_raw_values: Vec<_> = raw_values
3106            .iter()
3107            .enumerate()
3108            .filter_map(|(i, v)| if i % 2 == 0 { None } else { Some(v.as_str()) })
3109            .collect();
3110        // Null slots should not contribute an empty string to the bloom filter.
3111        check_bloom_filter(files, "col".to_string(), optional_raw_values, vec![""]);
3112    }
3113
3114    #[test]
3115    fn large_binary_single_column() {
3116        let one_vec: Vec<u8> = (0..SMALL_SIZE as u8).collect();
3117        let many_vecs: Vec<_> = std::iter::repeat_n(one_vec, SMALL_SIZE).collect();
3118        let many_vecs_iter = many_vecs.iter().map(|v| v.as_slice());
3119
3120        // LargeBinaryArrays can't be built from Vec<Option<&[u8]>>, so only call `values_required`
3121        values_required::<LargeBinaryArray, _>(many_vecs_iter);
3122    }
3123
3124    #[test]
3125    fn fixed_size_binary_single_column() {
3126        let mut builder = FixedSizeBinaryBuilder::new(4);
3127        builder.append_value(b"0123").unwrap();
3128        builder.append_null();
3129        builder.append_value(b"8910").unwrap();
3130        builder.append_value(b"1112").unwrap();
3131        let array = Arc::new(builder.finish());
3132
3133        one_column_roundtrip(array, true);
3134    }
3135
3136    #[test]
3137    fn string_single_column() {
3138        let raw_values: Vec<_> = (0..SMALL_SIZE).map(|i| i.to_string()).collect();
3139        let raw_strs = raw_values.iter().map(|s| s.as_str());
3140
3141        required_and_optional::<StringArray, _>(raw_strs);
3142    }
3143
3144    #[test]
3145    fn large_string_single_column() {
3146        let raw_values: Vec<_> = (0..SMALL_SIZE).map(|i| i.to_string()).collect();
3147        let raw_strs = raw_values.iter().map(|s| s.as_str());
3148
3149        required_and_optional::<LargeStringArray, _>(raw_strs);
3150    }
3151
3152    #[test]
3153    fn string_view_single_column() {
3154        let raw_values: Vec<_> = (0..SMALL_SIZE).map(|i| i.to_string()).collect();
3155        let raw_strs = raw_values.iter().map(|s| s.as_str());
3156
3157        required_and_optional::<StringViewArray, _>(raw_strs);
3158    }
3159
3160    #[test]
3161    fn null_list_single_column() {
3162        let null_field = Field::new_list_field(DataType::Null, true);
3163        let list_field = Field::new("emptylist", DataType::List(Arc::new(null_field)), true);
3164
3165        let schema = Schema::new(vec![list_field]);
3166
3167        // Build [[], null, [null, null]]
3168        let a_values = NullArray::new(2);
3169        let a_value_offsets = arrow::buffer::Buffer::from([0, 0, 0, 2].to_byte_slice());
3170        let a_list_data = ArrayData::builder(DataType::List(Arc::new(Field::new_list_field(
3171            DataType::Null,
3172            true,
3173        ))))
3174        .len(3)
3175        .add_buffer(a_value_offsets)
3176        .null_bit_buffer(Some(Buffer::from([0b00000101])))
3177        .add_child_data(a_values.into_data())
3178        .build()
3179        .unwrap();
3180
3181        let a = ListArray::from(a_list_data);
3182
3183        assert!(a.is_valid(0));
3184        assert!(!a.is_valid(1));
3185        assert!(a.is_valid(2));
3186
3187        assert_eq!(a.value(0).len(), 0);
3188        assert_eq!(a.value(2).len(), 2);
3189        assert_eq!(a.value(2).logical_nulls().unwrap().null_count(), 2);
3190
3191        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap();
3192        roundtrip(batch, None);
3193    }
3194
3195    #[test]
3196    fn list_single_column() {
3197        let a_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
3198        let a_value_offsets = arrow::buffer::Buffer::from([0, 1, 3, 3, 6, 10].to_byte_slice());
3199        let a_list_data = ArrayData::builder(DataType::List(Arc::new(Field::new_list_field(
3200            DataType::Int32,
3201            false,
3202        ))))
3203        .len(5)
3204        .add_buffer(a_value_offsets)
3205        .null_bit_buffer(Some(Buffer::from([0b00011011])))
3206        .add_child_data(a_values.into_data())
3207        .build()
3208        .unwrap();
3209
3210        assert_eq!(a_list_data.null_count(), 1);
3211
3212        let a = ListArray::from(a_list_data);
3213        let values = Arc::new(a);
3214
3215        one_column_roundtrip(values, true);
3216    }
3217
3218    #[test]
3219    fn large_list_single_column() {
3220        let a_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
3221        let a_value_offsets = arrow::buffer::Buffer::from([0i64, 1, 3, 3, 6, 10].to_byte_slice());
3222        let a_list_data = ArrayData::builder(DataType::LargeList(Arc::new(Field::new(
3223            "large_item",
3224            DataType::Int32,
3225            true,
3226        ))))
3227        .len(5)
3228        .add_buffer(a_value_offsets)
3229        .add_child_data(a_values.into_data())
3230        .null_bit_buffer(Some(Buffer::from([0b00011011])))
3231        .build()
3232        .unwrap();
3233
3234        // The null bitmap 0b00011011 leaves index 2 unset, so exactly one list entry is null
3235        assert_eq!(a_list_data.null_count(), 1);
3236
3237        let a = LargeListArray::from(a_list_data);
3238        let values = Arc::new(a);
3239
3240        one_column_roundtrip(values, true);
3241    }
3242
3243    #[test]
3244    fn list_nested_nulls() {
3245        use arrow::datatypes::Int32Type;
3246        let data = vec![
3247            Some(vec![Some(1)]),
3248            Some(vec![Some(2), Some(3)]),
3249            None,
3250            Some(vec![Some(4), Some(5), None]),
3251            Some(vec![None]),
3252            Some(vec![Some(6), Some(7)]),
3253        ];
3254
3255        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(data.clone());
3256        one_column_roundtrip(Arc::new(list), true);
3257
3258        let list = LargeListArray::from_iter_primitive::<Int32Type, _, _>(data);
3259        one_column_roundtrip(Arc::new(list), true);
3260    }
3261
3262    #[test]
3263    fn struct_single_column() {
3264        let a_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
3265        let struct_field_a = Arc::new(Field::new("f", DataType::Int32, false));
3266        let s = StructArray::from(vec![(struct_field_a, Arc::new(a_values) as ArrayRef)]);
3267
3268        let values = Arc::new(s);
3269        one_column_roundtrip(values, false);
3270    }
3271
3272    #[test]
3273    fn list_and_map_coerced_names() {
3274        // Create map and list with non-Parquet naming
3275        let list_field =
3276            Field::new_list("my_list", Field::new("item", DataType::Int32, false), false);
3277        let map_field = Field::new_map(
3278            "my_map",
3279            "entries",
3280            Field::new("keys", DataType::Int32, false),
3281            Field::new("values", DataType::Int32, true),
3282            false,
3283            true,
3284        );
3285
3286        let list_array = create_random_array(&list_field, 100, 0.0, 0.0).unwrap();
3287        let map_array = create_random_array(&map_field, 100, 0.0, 0.0).unwrap();
3288
3289        let arrow_schema = Arc::new(Schema::new(vec![list_field, map_field]));
3290
3291        // Write data to Parquet but coerce names to match spec
3292        let props = Some(WriterProperties::builder().set_coerce_types(true).build());
3293        let file = tempfile::tempfile().unwrap();
3294        let mut writer =
3295            ArrowWriter::try_new(file.try_clone().unwrap(), arrow_schema.clone(), props).unwrap();
3296
3297        let batch = RecordBatch::try_new(arrow_schema, vec![list_array, map_array]).unwrap();
3298        writer.write(&batch).unwrap();
3299        let file_metadata = writer.close().unwrap();
3300
3301        let schema = file_metadata.file_metadata().schema();
3302        // Coerced name of "item" should be "element"
3303        let list_field = &schema.get_fields()[0].get_fields()[0];
3304        assert_eq!(list_field.get_fields()[0].name(), "element");
3305
3306        let map_field = &schema.get_fields()[1].get_fields()[0];
3307        // Coerced name of "entries" should be "key_value"
3308        assert_eq!(map_field.name(), "key_value");
3309        // Coerced name of "keys" should be "key"
3310        assert_eq!(map_field.get_fields()[0].name(), "key");
3311        // Coerced name of "values" should be "value"
3312        assert_eq!(map_field.get_fields()[1].name(), "value");
3313
3314        // Double check schema after reading from the file
3315        let reader = SerializedFileReader::new(file).unwrap();
3316        let file_schema = reader.metadata().file_metadata().schema();
3317        let fields = file_schema.get_fields();
3318        let list_field = &fields[0].get_fields()[0];
3319        assert_eq!(list_field.get_fields()[0].name(), "element");
3320        let map_field = &fields[1].get_fields()[0];
3321        assert_eq!(map_field.name(), "key_value");
3322        assert_eq!(map_field.get_fields()[0].name(), "key");
3323        assert_eq!(map_field.get_fields()[1].name(), "value");
3324    }
3325
3326    #[test]
3327    fn fallback_flush_data_page() {
3328        // Tests that Fallback::flush_data_page clears all buffers correctly
3329        let raw_values: Vec<_> = (0..MEDIUM_SIZE).map(|i| i.to_string()).collect();
3330        let values = Arc::new(StringArray::from(raw_values));
3331        let encodings = vec![
3332            Encoding::DELTA_BYTE_ARRAY,
3333            Encoding::DELTA_LENGTH_BYTE_ARRAY,
3334        ];
3335        let data_type = values.data_type().clone();
3336        let schema = Arc::new(Schema::new(vec![Field::new("col", data_type, false)]));
3337        let expected_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
3338
3339        let row_group_sizes = [1024, SMALL_SIZE, SMALL_SIZE / 2, SMALL_SIZE / 2 + 1, 10];
3340        let data_page_size_limit: usize = 32;
3341        let write_batch_size: usize = 16;
3342
3343        for encoding in &encodings {
3344            for row_group_size in row_group_sizes {
3345                let props = WriterProperties::builder()
3346                    .set_writer_version(WriterVersion::PARQUET_2_0)
3347                    .set_max_row_group_row_count(Some(row_group_size))
3348                    .set_dictionary_enabled(false)
3349                    .set_encoding(*encoding)
3350                    .set_data_page_size_limit(data_page_size_limit)
3351                    .set_write_batch_size(write_batch_size)
3352                    .build();
3353
3354                roundtrip_opts_with_array_validation(&expected_batch, props, |a, b| {
3355                    let string_array_a = StringArray::from(a.clone());
3356                    let string_array_b = StringArray::from(b.clone());
3357                    let vec_a: Vec<&str> = string_array_a.iter().map(|v| v.unwrap()).collect();
3358                    let vec_b: Vec<&str> = string_array_b.iter().map(|v| v.unwrap()).collect();
3359                    assert_eq!(
3360                        vec_a, vec_b,
3361                        "failed for encoder: {encoding:?} and row_group_size: {row_group_size:?}"
3362                    );
3363                });
3364            }
3365        }
3366    }
3367
3368    #[test]
3369    fn arrow_writer_string_dictionary() {
3370        // define schema
3371        #[allow(deprecated)]
3372        let schema = Arc::new(Schema::new(vec![Field::new_dict(
3373            "dictionary",
3374            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
3375            true,
3376            42,
3377            true,
3378        )]));
3379
3380        // create some data
3381        let d: Int32DictionaryArray = [Some("alpha"), None, Some("beta"), Some("alpha")]
3382            .iter()
3383            .copied()
3384            .collect();
3385
3386        // build a record batch
3387        one_column_roundtrip_with_schema(Arc::new(d), schema);
3388    }
3389
3390    #[test]
3391    fn arrow_writer_test_type_compatibility() {
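        // Writes `array1` and then `array2` (with a different but compatible
        // data type) through the same writer, and asserts that reading back
        // yields `expected_result` under the first schema.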
3392        fn ensure_compatible_write<T1, T2>(array1: T1, array2: T2, expected_result: T1)
3393        where
3394            T1: Array + 'static,
3395            T2: Array + 'static,
3396        {
3397            let schema1 = Arc::new(Schema::new(vec![Field::new(
3398                "a",
3399                array1.data_type().clone(),
3400                false,
3401            )]));
3402
3403            let file = tempfile().unwrap();
3404            let mut writer =
3405                ArrowWriter::try_new(file.try_clone().unwrap(), schema1.clone(), None).unwrap();
3406
3407            let rb1 = RecordBatch::try_new(schema1.clone(), vec![Arc::new(array1)]).unwrap();
3408            writer.write(&rb1).unwrap();
3409
3410            let schema2 = Arc::new(Schema::new(vec![Field::new(
3411                "a",
3412                array2.data_type().clone(),
3413                false,
3414            )]));
3415            let rb2 = RecordBatch::try_new(schema2, vec![Arc::new(array2)]).unwrap();
3416            writer.write(&rb2).unwrap();
3417
3418            writer.close().unwrap();
3419
3420            let mut record_batch_reader =
3421                ParquetRecordBatchReader::try_new(file.try_clone().unwrap(), 1024).unwrap();
3422            let actual_batch = record_batch_reader.next().unwrap().unwrap();
3423
3424            let expected_batch =
3425                RecordBatch::try_new(schema1, vec![Arc::new(expected_result)]).unwrap();
3426            assert_eq!(actual_batch, expected_batch);
3427        }
3428
3429        // check compatibility between native and dictionaries
3430
3431        ensure_compatible_write(
3432            DictionaryArray::new(
3433                UInt8Array::from_iter_values(vec![0]),
3434                Arc::new(StringArray::from_iter_values(vec!["parquet"])),
3435            ),
3436            StringArray::from_iter_values(vec!["barquet"]),
3437            DictionaryArray::new(
3438                UInt8Array::from_iter_values(vec![0, 1]),
3439                Arc::new(StringArray::from_iter_values(vec!["parquet", "barquet"])),
3440            ),
3441        );
3442
3443        ensure_compatible_write(
3444            StringArray::from_iter_values(vec!["parquet"]),
3445            DictionaryArray::new(
3446                UInt8Array::from_iter_values(vec![0]),
3447                Arc::new(StringArray::from_iter_values(vec!["barquet"])),
3448            ),
3449            StringArray::from_iter_values(vec!["parquet", "barquet"]),
3450        );
3451
3452        // check compatibility between dictionaries with different key types
3453
3454        ensure_compatible_write(
3455            DictionaryArray::new(
3456                UInt8Array::from_iter_values(vec![0]),
3457                Arc::new(StringArray::from_iter_values(vec!["parquet"])),
3458            ),
3459            DictionaryArray::new(
3460                UInt16Array::from_iter_values(vec![0]),
3461                Arc::new(StringArray::from_iter_values(vec!["barquet"])),
3462            ),
3463            DictionaryArray::new(
3464                UInt8Array::from_iter_values(vec![0, 1]),
3465                Arc::new(StringArray::from_iter_values(vec!["parquet", "barquet"])),
3466            ),
3467        );
3468
3469        // check compatibility between dictionaries with different value types
3470        ensure_compatible_write(
3471            DictionaryArray::new(
3472                UInt8Array::from_iter_values(vec![0]),
3473                Arc::new(StringArray::from_iter_values(vec!["parquet"])),
3474            ),
3475            DictionaryArray::new(
3476                UInt8Array::from_iter_values(vec![0]),
3477                Arc::new(LargeStringArray::from_iter_values(vec!["barquet"])),
3478            ),
3479            DictionaryArray::new(
3480                UInt8Array::from_iter_values(vec![0, 1]),
3481                Arc::new(StringArray::from_iter_values(vec!["parquet", "barquet"])),
3482            ),
3483        );
3484
3485        // check compatibility between a dictionary and a native array with a different type
3486        ensure_compatible_write(
3487            DictionaryArray::new(
3488                UInt8Array::from_iter_values(vec![0]),
3489                Arc::new(StringArray::from_iter_values(vec!["parquet"])),
3490            ),
3491            LargeStringArray::from_iter_values(vec!["barquet"]),
3492            DictionaryArray::new(
3493                UInt8Array::from_iter_values(vec![0, 1]),
3494                Arc::new(StringArray::from_iter_values(vec!["parquet", "barquet"])),
3495            ),
3496        );
3497
3498        // check compatibility for string types
3499
3500        ensure_compatible_write(
3501            StringArray::from_iter_values(vec!["parquet"]),
3502            LargeStringArray::from_iter_values(vec!["barquet"]),
3503            StringArray::from_iter_values(vec!["parquet", "barquet"]),
3504        );
3505
3506        ensure_compatible_write(
3507            LargeStringArray::from_iter_values(vec!["parquet"]),
3508            StringArray::from_iter_values(vec!["barquet"]),
3509            LargeStringArray::from_iter_values(vec!["parquet", "barquet"]),
3510        );
3511
3512        ensure_compatible_write(
3513            StringArray::from_iter_values(vec!["parquet"]),
3514            StringViewArray::from_iter_values(vec!["barquet"]),
3515            StringArray::from_iter_values(vec!["parquet", "barquet"]),
3516        );
3517
3518        ensure_compatible_write(
3519            StringViewArray::from_iter_values(vec!["parquet"]),
3520            StringArray::from_iter_values(vec!["barquet"]),
3521            StringViewArray::from_iter_values(vec!["parquet", "barquet"]),
3522        );
3523
3524        ensure_compatible_write(
3525            LargeStringArray::from_iter_values(vec!["parquet"]),
3526            StringViewArray::from_iter_values(vec!["barquet"]),
3527            LargeStringArray::from_iter_values(vec!["parquet", "barquet"]),
3528        );
3529
3530        ensure_compatible_write(
3531            StringViewArray::from_iter_values(vec!["parquet"]),
3532            LargeStringArray::from_iter_values(vec!["barquet"]),
3533            StringViewArray::from_iter_values(vec!["parquet", "barquet"]),
3534        );
3535
3536        // check compatibility for binary types
3537
3538        ensure_compatible_write(
3539            BinaryArray::from_iter_values(vec![b"parquet"]),
3540            LargeBinaryArray::from_iter_values(vec![b"barquet"]),
3541            BinaryArray::from_iter_values(vec![b"parquet", b"barquet"]),
3542        );
3543
3544        ensure_compatible_write(
3545            LargeBinaryArray::from_iter_values(vec![b"parquet"]),
3546            BinaryArray::from_iter_values(vec![b"barquet"]),
3547            LargeBinaryArray::from_iter_values(vec![b"parquet", b"barquet"]),
3548        );
3549
3550        ensure_compatible_write(
3551            BinaryArray::from_iter_values(vec![b"parquet"]),
3552            BinaryViewArray::from_iter_values(vec![b"barquet"]),
3553            BinaryArray::from_iter_values(vec![b"parquet", b"barquet"]),
3554        );
3555
3556        ensure_compatible_write(
3557            BinaryViewArray::from_iter_values(vec![b"parquet"]),
3558            BinaryArray::from_iter_values(vec![b"barquet"]),
3559            BinaryViewArray::from_iter_values(vec![b"parquet", b"barquet"]),
3560        );
3561
3562        ensure_compatible_write(
3563            BinaryViewArray::from_iter_values(vec![b"parquet"]),
3564            LargeBinaryArray::from_iter_values(vec![b"barquet"]),
3565            BinaryViewArray::from_iter_values(vec![b"parquet", b"barquet"]),
3566        );
3567
3568        ensure_compatible_write(
3569            LargeBinaryArray::from_iter_values(vec![b"parquet"]),
3570            BinaryViewArray::from_iter_values(vec![b"barquet"]),
3571            LargeBinaryArray::from_iter_values(vec![b"parquet", b"barquet"]),
3572        );
3573
3574        // check compatibility for list types
3575
3576        let list_field_metadata = HashMap::from_iter(vec![(
3577            PARQUET_FIELD_ID_META_KEY.to_string(),
3578            "1".to_string(),
3579        )]);
3580        let list_field = Field::new_list_field(DataType::Int32, false);
3581
3582        let values1 = Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4]));
3583        let offsets1 = OffsetBuffer::new(vec![0, 2, 5].into());
3584
3585        let values2 = Arc::new(Int32Array::from(vec![5, 6, 7, 8, 9]));
3586        let offsets2 = OffsetBuffer::new(vec![0, 3, 5].into());
3587
3588        let values_expected = Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]));
3589        let offsets_expected = OffsetBuffer::new(vec![0, 2, 5, 8, 10].into());
3590
3591        ensure_compatible_write(
3592            // when the initial schema has the metadata ...
3593            ListArray::try_new(
3594                Arc::new(
3595                    list_field
3596                        .clone()
3597                        .with_metadata(list_field_metadata.clone()),
3598                ),
3599                offsets1,
3600                values1,
3601                None,
3602            )
3603            .unwrap(),
3604            // ... and some intermediate schema doesn't have the metadata
3605            ListArray::try_new(Arc::new(list_field.clone()), offsets2, values2, None).unwrap(),
3606            // ... the write will still go through, and the resulting schema will inherit the initial metadata
3607            ListArray::try_new(
3608                Arc::new(
3609                    list_field
3610                        .clone()
3611                        .with_metadata(list_field_metadata.clone()),
3612                ),
3613                offsets_expected,
3614                values_expected,
3615                None,
3616            )
3617            .unwrap(),
3618        );
3619    }
3620
3621    #[test]
3622    fn arrow_writer_primitive_dictionary() {
3623        // define schema
3624        #[allow(deprecated)]
3625        let schema = Arc::new(Schema::new(vec![Field::new_dict(
3626            "dictionary",
3627            DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::UInt32)),
3628            true,
3629            42,
3630            true,
3631        )]));
3632
3633        // create some data
3634        let mut builder = PrimitiveDictionaryBuilder::<UInt8Type, UInt32Type>::new();
3635        builder.append(12345678).unwrap();
3636        builder.append_null();
3637        builder.append(22345678).unwrap();
3638        builder.append(12345678).unwrap();
3639        let d = builder.finish();
3640
3641        one_column_roundtrip_with_schema(Arc::new(d), schema);
3642    }
3643
3644    #[test]
3645    fn arrow_writer_decimal32_dictionary() {
3646        let integers = vec![12345, 56789, 34567];
3647
3648        let keys = UInt8Array::from(vec![Some(0), None, Some(1), Some(2), Some(1)]);
3649
3650        let values = Decimal32Array::from(integers.clone())
3651            .with_precision_and_scale(5, 2)
3652            .unwrap();
3653
3654        let array = DictionaryArray::new(keys, Arc::new(values));
3655        one_column_roundtrip(Arc::new(array.clone()), true);
3656
3657        let values = Decimal32Array::from(integers)
3658            .with_precision_and_scale(9, 2)
3659            .unwrap();
3660
3661        let array = array.with_values(Arc::new(values));
3662        one_column_roundtrip(Arc::new(array), true);
3663    }
3664
3665    #[test]
3666    fn arrow_writer_decimal64_dictionary() {
3667        let integers = vec![12345, 56789, 34567];
3668
3669        let keys = UInt8Array::from(vec![Some(0), None, Some(1), Some(2), Some(1)]);
3670
3671        let values = Decimal64Array::from(integers.clone())
3672            .with_precision_and_scale(5, 2)
3673            .unwrap();
3674
3675        let array = DictionaryArray::new(keys, Arc::new(values));
3676        one_column_roundtrip(Arc::new(array.clone()), true);
3677
3678        let values = Decimal64Array::from(integers)
3679            .with_precision_and_scale(12, 2)
3680            .unwrap();
3681
3682        let array = array.with_values(Arc::new(values));
3683        one_column_roundtrip(Arc::new(array), true);
3684    }
3685
3686    #[test]
3687    fn arrow_writer_decimal128_dictionary() {
3688        let integers = vec![12345, 56789, 34567];
3689
3690        let keys = UInt8Array::from(vec![Some(0), None, Some(1), Some(2), Some(1)]);
3691
3692        let values = Decimal128Array::from(integers.clone())
3693            .with_precision_and_scale(5, 2)
3694            .unwrap();
3695
3696        let array = DictionaryArray::new(keys, Arc::new(values));
3697        one_column_roundtrip(Arc::new(array.clone()), true);
3698
3699        let values = Decimal128Array::from(integers)
3700            .with_precision_and_scale(12, 2)
3701            .unwrap();
3702
3703        let array = array.with_values(Arc::new(values));
3704        one_column_roundtrip(Arc::new(array), true);
3705    }
3706
3707    #[test]
3708    fn arrow_writer_decimal256_dictionary() {
3709        let integers = vec![
3710            i256::from_i128(12345),
3711            i256::from_i128(56789),
3712            i256::from_i128(34567),
3713        ];
3714
3715        let keys = UInt8Array::from(vec![Some(0), None, Some(1), Some(2), Some(1)]);
3716
3717        let values = Decimal256Array::from(integers.clone())
3718            .with_precision_and_scale(5, 2)
3719            .unwrap();
3720
3721        let array = DictionaryArray::new(keys, Arc::new(values));
3722        one_column_roundtrip(Arc::new(array.clone()), true);
3723
3724        let values = Decimal256Array::from(integers)
3725            .with_precision_and_scale(12, 2)
3726            .unwrap();
3727
3728        let array = array.with_values(Arc::new(values));
3729        one_column_roundtrip(Arc::new(array), true);
3730    }
3731
3732    #[test]
3733    fn arrow_writer_string_dictionary_unsigned_index() {
3734        // define schema
3735        #[allow(deprecated)]
3736        let schema = Arc::new(Schema::new(vec![Field::new_dict(
3737            "dictionary",
3738            DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
3739            true,
3740            42,
3741            true,
3742        )]));
3743
3744        // create some data
3745        let d: UInt8DictionaryArray = [Some("alpha"), None, Some("beta"), Some("alpha")]
3746            .iter()
3747            .copied()
3748            .collect();
3749
3750        one_column_roundtrip_with_schema(Arc::new(d), schema);
3751    }
3752
3753    #[test]
3754    fn u32_min_max() {
3755        // check values roundtrip through parquet
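        // (u32 values are stored as physical INT32, so statistics come back
        // as Statistics::Int32 and are reinterpreted with `as u32` casts)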
3756        let src = [
3757            u32::MIN,
3758            u32::MIN + 1,
3759            (i32::MAX as u32) - 1,
3760            i32::MAX as u32,
3761            (i32::MAX as u32) + 1,
3762            u32::MAX - 1,
3763            u32::MAX,
3764        ];
3765        let values = Arc::new(UInt32Array::from_iter_values(src.iter().cloned()));
3766        let files = one_column_roundtrip(values, false);
3767
3768        for file in files {
3769            // check statistics are valid
3770            let reader = SerializedFileReader::new(file).unwrap();
3771            let metadata = reader.metadata();
3772
3773            let mut row_offset = 0;
3774            for row_group in metadata.row_groups() {
3775                assert_eq!(row_group.num_columns(), 1);
3776                let column = row_group.column(0);
3777
3778                let num_values = column.num_values() as usize;
3779                let src_slice = &src[row_offset..row_offset + num_values];
3780                row_offset += column.num_values() as usize;
3781
3782                let stats = column.statistics().unwrap();
3783                if let Statistics::Int32(stats) = stats {
3784                    assert_eq!(
3785                        *stats.min_opt().unwrap() as u32,
3786                        *src_slice.iter().min().unwrap()
3787                    );
3788                    assert_eq!(
3789                        *stats.max_opt().unwrap() as u32,
3790                        *src_slice.iter().max().unwrap()
3791                    );
3792                } else {
3793                    panic!("Statistics::Int32 missing")
3794                }
3795            }
3796        }
3797    }
3798
3799    #[test]
3800    fn u64_min_max() {
3801        // check values roundtrip through parquet
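        // (u64 values are stored as physical INT64, so statistics come back
        // as Statistics::Int64 and are reinterpreted with `as u64` casts)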
3802        let src = [
3803            u64::MIN,
3804            u64::MIN + 1,
3805            (i64::MAX as u64) - 1,
3806            i64::MAX as u64,
3807            (i64::MAX as u64) + 1,
3808            u64::MAX - 1,
3809            u64::MAX,
3810        ];
3811        let values = Arc::new(UInt64Array::from_iter_values(src.iter().cloned()));
3812        let files = one_column_roundtrip(values, false);
3813
3814        for file in files {
3815            // check statistics are valid
3816            let reader = SerializedFileReader::new(file).unwrap();
3817            let metadata = reader.metadata();
3818
3819            let mut row_offset = 0;
3820            for row_group in metadata.row_groups() {
3821                assert_eq!(row_group.num_columns(), 1);
3822                let column = row_group.column(0);
3823
3824                let num_values = column.num_values() as usize;
3825                let src_slice = &src[row_offset..row_offset + num_values];
3826                row_offset += column.num_values() as usize;
3827
3828                let stats = column.statistics().unwrap();
3829                if let Statistics::Int64(stats) = stats {
3830                    assert_eq!(
3831                        *stats.min_opt().unwrap() as u64,
3832                        *src_slice.iter().min().unwrap()
3833                    );
3834                    assert_eq!(
3835                        *stats.max_opt().unwrap() as u64,
3836                        *src_slice.iter().max().unwrap()
3837                    );
3838                } else {
3839                    panic!("Statistics::Int64 missing")
3840                }
3841            }
3842        }
3843    }
3844
3845    #[test]
3846    fn statistics_null_counts_only_nulls() {
3847        // check that null-count statistics for "only NULL"-columns are correct
3848        let values = Arc::new(UInt64Array::from(vec![None, None]));
3849        let files = one_column_roundtrip(values, true);
3850
3851        for file in files {
3852            // check statistics are valid
3853            let reader = SerializedFileReader::new(file).unwrap();
3854            let metadata = reader.metadata();
3855            assert_eq!(metadata.num_row_groups(), 1);
3856            let row_group = metadata.row_group(0);
3857            assert_eq!(row_group.num_columns(), 1);
3858            let column = row_group.column(0);
3859            let stats = column.statistics().unwrap();
3860            assert_eq!(stats.null_count_opt(), Some(2));
3861        }
3862    }
3863
3864    #[test]
3865    fn test_list_of_struct_roundtrip() {
3866        // define schema
3867        let int_field = Field::new("a", DataType::Int32, true);
3868        let int_field2 = Field::new("b", DataType::Int32, true);
3869
3870        let int_builder = Int32Builder::with_capacity(10);
3871        let int_builder2 = Int32Builder::with_capacity(10);
3872
3873        let struct_builder = StructBuilder::new(
3874            vec![int_field, int_field2],
3875            vec![Box::new(int_builder), Box::new(int_builder2)],
3876        );
3877        let mut list_builder = ListBuilder::new(struct_builder);
3878
3879        // Construct the following array
3880        // [{a: 1, b: 2}], [], null, [null, null], [{a: null, b: 3}], [{a: 2, b: null}]
3881
3882        // [{a: 1, b: 2}]
3883        let values = list_builder.values();
3884        values
3885            .field_builder::<Int32Builder>(0)
3886            .unwrap()
3887            .append_value(1);
3888        values
3889            .field_builder::<Int32Builder>(1)
3890            .unwrap()
3891            .append_value(2);
3892        values.append(true);
3893        list_builder.append(true);
3894
3895        // []
3896        list_builder.append(true);
3897
3898        // null
3899        list_builder.append(false);
3900
3901        // [null, null]
3902        let values = list_builder.values();
3903        values
3904            .field_builder::<Int32Builder>(0)
3905            .unwrap()
3906            .append_null();
3907        values
3908            .field_builder::<Int32Builder>(1)
3909            .unwrap()
3910            .append_null();
3911        values.append(false);
3912        values
3913            .field_builder::<Int32Builder>(0)
3914            .unwrap()
3915            .append_null();
3916        values
3917            .field_builder::<Int32Builder>(1)
3918            .unwrap()
3919            .append_null();
3920        values.append(false);
3921        list_builder.append(true);
3922
3923        // [{a: null, b: 3}]
3924        let values = list_builder.values();
3925        values
3926            .field_builder::<Int32Builder>(0)
3927            .unwrap()
3928            .append_null();
3929        values
3930            .field_builder::<Int32Builder>(1)
3931            .unwrap()
3932            .append_value(3);
3933        values.append(true);
3934        list_builder.append(true);
3935
3936        // [{a: 2, b: null}]
3937        let values = list_builder.values();
3938        values
3939            .field_builder::<Int32Builder>(0)
3940            .unwrap()
3941            .append_value(2);
3942        values
3943            .field_builder::<Int32Builder>(1)
3944            .unwrap()
3945            .append_null();
3946        values.append(true);
3947        list_builder.append(true);
3948
3949        let array = Arc::new(list_builder.finish());
3950
3951        one_column_roundtrip(array, true);
3952    }
3953
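    /// Returns the number of rows in each row group of `metadata`.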
3954    fn row_group_sizes(metadata: &ParquetMetaData) -> Vec<i64> {
3955        metadata.row_groups().iter().map(|x| x.num_rows()).collect()
3956    }
3957
3958    #[test]
3959    fn test_aggregates_records() {
3960        let arrays = [
3961            Int32Array::from((0..100).collect::<Vec<_>>()),
3962            Int32Array::from((0..50).collect::<Vec<_>>()),
3963            Int32Array::from((200..500).collect::<Vec<_>>()),
3964        ];
3965
3966        let schema = Arc::new(Schema::new(vec![Field::new(
3967            "int",
3968            ArrowDataType::Int32,
3969            false,
3970        )]));
3971
3972        let file = tempfile::tempfile().unwrap();
3973
3974        let props = WriterProperties::builder()
3975            .set_max_row_group_row_count(Some(200))
3976            .build();
3977
3978        let mut writer =
3979            ArrowWriter::try_new(file.try_clone().unwrap(), schema.clone(), Some(props)).unwrap();
3980
3981        for array in arrays {
3982            let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap();
3983            writer.write(&batch).unwrap();
3984        }
3985
3986        writer.close().unwrap();
3987
3988        let builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap();
3989        assert_eq!(&row_group_sizes(builder.metadata()), &[200, 200, 50]);
3990
3991        let batches = builder
3992            .with_batch_size(100)
3993            .build()
3994            .unwrap()
3995            .collect::<ArrowResult<Vec<_>>>()
3996            .unwrap();
3997
3998        assert_eq!(batches.len(), 5);
3999        assert!(batches.iter().all(|x| x.num_columns() == 1));
4000
4001        let batch_sizes: Vec<_> = batches.iter().map(|x| x.num_rows()).collect();
4002
4003        assert_eq!(&batch_sizes, &[100, 100, 100, 100, 50]);
4004
4005        let values: Vec<_> = batches
4006            .iter()
4007            .flat_map(|x| {
4008                x.column(0)
4009                    .as_any()
4010                    .downcast_ref::<Int32Array>()
4011                    .unwrap()
4012                    .values()
4013                    .iter()
4014                    .cloned()
4015            })
4016            .collect();
4017
4018        let expected_values: Vec<_> = [0..100, 0..50, 200..500].into_iter().flatten().collect();
4019        assert_eq!(&values, &expected_values)
4020    }
4021
4022    #[test]
4023    fn complex_aggregate() {
4024        // Tests aggregating nested data
4025        let field_a = Arc::new(Field::new("leaf_a", DataType::Int32, false));
4026        let field_b = Arc::new(Field::new("leaf_b", DataType::Int32, true));
4027        let struct_a = Arc::new(Field::new(
4028            "struct_a",
4029            DataType::Struct(vec![field_a.clone(), field_b.clone()].into()),
4030            true,
4031        ));
4032
4033        let list_a = Arc::new(Field::new("list", DataType::List(struct_a), true));
4034        let struct_b = Arc::new(Field::new(
4035            "struct_b",
4036            DataType::Struct(vec![list_a.clone()].into()),
4037            false,
4038        ));
4039
4040        let schema = Arc::new(Schema::new(vec![struct_b]));
4041
4042        // create nested data
4043        let field_a_array = Int32Array::from(vec![1, 2, 3, 4, 5, 6]);
4044        let field_b_array =
4045            Int32Array::from_iter(vec![Some(1), None, Some(2), None, None, Some(6)]);
4046
4047        let struct_a_array = StructArray::from(vec![
4048            (field_a.clone(), Arc::new(field_a_array) as ArrayRef),
4049            (field_b.clone(), Arc::new(field_b_array) as ArrayRef),
4050        ]);
4051
4052        let list_data = ArrayDataBuilder::new(list_a.data_type().clone())
4053            .len(5)
4054            .add_buffer(Buffer::from_iter(vec![
4055                0_i32, 1_i32, 1_i32, 3_i32, 3_i32, 5_i32,
4056            ]))
4057            .null_bit_buffer(Some(Buffer::from_iter(vec![
4058                true, false, true, false, true,
4059            ])))
4060            .child_data(vec![struct_a_array.into_data()])
4061            .build()
4062            .unwrap();
4063
4064        let list_a_array = Arc::new(ListArray::from(list_data)) as ArrayRef;
4065        let struct_b_array = StructArray::from(vec![(list_a.clone(), list_a_array)]);
4066
4067        let batch1 =
4068            RecordBatch::try_from_iter(vec![("struct_b", Arc::new(struct_b_array) as ArrayRef)])
4069                .unwrap();
4070
4071        let field_a_array = Int32Array::from(vec![6, 7, 8, 9, 10]);
4072        let field_b_array = Int32Array::from_iter(vec![None, None, None, Some(1), None]);
4073
4074        let struct_a_array = StructArray::from(vec![
4075            (field_a, Arc::new(field_a_array) as ArrayRef),
4076            (field_b, Arc::new(field_b_array) as ArrayRef),
4077        ]);
4078
4079        let list_data = ArrayDataBuilder::new(list_a.data_type().clone())
4080            .len(2)
4081            .add_buffer(Buffer::from_iter(vec![0_i32, 4_i32, 5_i32]))
4082            .child_data(vec![struct_a_array.into_data()])
4083            .build()
4084            .unwrap();
4085
4086        let list_a_array = Arc::new(ListArray::from(list_data)) as ArrayRef;
4087        let struct_b_array = StructArray::from(vec![(list_a, list_a_array)]);
4088
4089        let batch2 =
4090            RecordBatch::try_from_iter(vec![("struct_b", Arc::new(struct_b_array) as ArrayRef)])
4091                .unwrap();
4092
4093        let batches = &[batch1, batch2];
4094
4095        // Verify data is as expected
4096
4097        let expected = r#"
4098            +-------------------------------------------------------------------------------------------------------+
4099            | struct_b                                                                                              |
4100            +-------------------------------------------------------------------------------------------------------+
4101            | {list: [{leaf_a: 1, leaf_b: 1}]}                                                                      |
4102            | {list: }                                                                                              |
4103            | {list: [{leaf_a: 2, leaf_b: }, {leaf_a: 3, leaf_b: 2}]}                                               |
4104            | {list: }                                                                                              |
4105            | {list: [{leaf_a: 4, leaf_b: }, {leaf_a: 5, leaf_b: }]}                                                |
4106            | {list: [{leaf_a: 6, leaf_b: }, {leaf_a: 7, leaf_b: }, {leaf_a: 8, leaf_b: }, {leaf_a: 9, leaf_b: 1}]} |
4107            | {list: [{leaf_a: 10, leaf_b: }]}                                                                      |
4108            +-------------------------------------------------------------------------------------------------------+
4109        "#.trim().split('\n').map(|x| x.trim()).collect::<Vec<_>>().join("\n");
4110
4111        let actual = pretty_format_batches(batches).unwrap().to_string();
4112        assert_eq!(actual, expected);
4113
4114        // Write data
4115        let file = tempfile::tempfile().unwrap();
4116        let props = WriterProperties::builder()
4117            .set_max_row_group_row_count(Some(6))
4118            .build();
4119
4120        let mut writer =
4121            ArrowWriter::try_new(file.try_clone().unwrap(), schema, Some(props)).unwrap();
4122
4123        for batch in batches {
4124            writer.write(batch).unwrap();
4125        }
4126        writer.close().unwrap();
4127
4128        // Read Data
4129        // Should have written the entire first batch and the first row of the second to the
4130        // first row group, leaving a single row in the second row group
4131
4132        let builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap();
4133        assert_eq!(&row_group_sizes(builder.metadata()), &[6, 1]);
4134
4135        let batches = builder
4136            .with_batch_size(2)
4137            .build()
4138            .unwrap()
4139            .collect::<ArrowResult<Vec<_>>>()
4140            .unwrap();
4141
4142        assert_eq!(batches.len(), 4);
4143        let batch_counts: Vec<_> = batches.iter().map(|x| x.num_rows()).collect();
4144        assert_eq!(&batch_counts, &[2, 2, 2, 1]);
4145
4146        let actual = pretty_format_batches(&batches).unwrap().to_string();
4147        assert_eq!(actual, expected);
4148    }
4149
    #[test]
    fn test_arrow_writer_metadata() {
        let batch_schema = Schema::new(vec![Field::new("int32", DataType::Int32, false)]);
        let file_schema = batch_schema.clone().with_metadata(
            vec![("foo".to_string(), "bar".to_string())]
                .into_iter()
                .collect(),
        );

        let batch = RecordBatch::try_new(
            Arc::new(batch_schema),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as _],
        )
        .unwrap();

        let mut buf = Vec::with_capacity(1024);
        let mut writer = ArrowWriter::try_new(&mut buf, Arc::new(file_schema), None).unwrap();
        writer.write(&batch).unwrap();
        writer.close().unwrap();
    }

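    // A non-nullable batch column can be written to a nullable field in the file
    // schema; reading back yields the (nullable) file schema, not the batch schema.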
    #[test]
    fn test_arrow_writer_nullable() {
        let batch_schema = Schema::new(vec![Field::new("int32", DataType::Int32, false)]);
        let file_schema = Schema::new(vec![Field::new("int32", DataType::Int32, true)]);
        let file_schema = Arc::new(file_schema);

        let batch = RecordBatch::try_new(
            Arc::new(batch_schema),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as _],
        )
        .unwrap();

        let mut buf = Vec::with_capacity(1024);
        let mut writer = ArrowWriter::try_new(&mut buf, file_schema.clone(), None).unwrap();
        writer.write(&batch).unwrap();
        writer.close().unwrap();

        let mut read = ParquetRecordBatchReader::try_new(Bytes::from(buf), 1024).unwrap();
        let back = read.next().unwrap().unwrap();
        assert_eq!(back.schema(), file_schema);
        assert_ne!(back.schema(), batch.schema());
        assert_eq!(back.column(0).as_ref(), batch.column(0).as_ref());
    }

    #[test]
    fn in_progress_accounting() {
        // define schema
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        // create some data
        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);

        // build a record batch
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap();

        let mut writer = ArrowWriter::try_new(vec![], batch.schema(), None).unwrap();

        // starts empty
        assert_eq!(writer.in_progress_size(), 0);
        assert_eq!(writer.in_progress_rows(), 0);
        assert_eq!(writer.memory_size(), 0);
        assert_eq!(writer.bytes_written(), 4); // the 4-byte Parquet magic "PAR1" written on creation
        writer.write(&batch).unwrap();

        // updated on write
        let initial_size = writer.in_progress_size();
        assert!(initial_size > 0);
        assert_eq!(writer.in_progress_rows(), 5);
        let initial_memory = writer.memory_size();
        assert!(initial_memory > 0);
        // the memory estimate is at least as large as the estimated encoded size
        assert!(
            initial_size <= initial_memory,
            "{initial_size} <= {initial_memory}"
        );

        // updated on second write
        writer.write(&batch).unwrap();
        assert!(writer.in_progress_size() > initial_size);
        assert_eq!(writer.in_progress_rows(), 10);
        assert!(writer.memory_size() > initial_memory);
        assert!(
            writer.in_progress_size() <= writer.memory_size(),
            "in_progress_size {} <= memory_size {}",
            writer.in_progress_size(),
            writer.memory_size()
        );

        // in-progress tracking is cleared on flush, but the overall bytes written keep growing
        let pre_flush_bytes_written = writer.bytes_written();
        writer.flush().unwrap();
        assert_eq!(writer.in_progress_size(), 0);
        assert_eq!(writer.memory_size(), 0);
        assert!(writer.bytes_written() > pre_flush_bytes_written);

        writer.close().unwrap();
    }

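    // An all-null column is still written as a data page, and both columns get
    // offset index entries.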
    #[test]
    fn test_writer_all_null() {
        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = Int32Array::new(vec![0; 5].into(), Some(NullBuffer::new_null(5)));
        let batch = RecordBatch::try_from_iter(vec![
            ("a", Arc::new(a) as ArrayRef),
            ("b", Arc::new(b) as ArrayRef),
        ])
        .unwrap();

        let mut buf = Vec::with_capacity(1024);
        let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), None).unwrap();
        writer.write(&batch).unwrap();
        writer.close().unwrap();

        let bytes = Bytes::from(buf);
        let options = ReadOptionsBuilder::new().with_page_index().build();
        let reader = SerializedFileReader::new_with_options(bytes, options).unwrap();
        let index = reader.metadata().offset_index().unwrap();

        assert_eq!(index.len(), 1);
        assert_eq!(index[0].len(), 2); // 2 columns
        assert_eq!(index[0][0].page_locations().len(), 1); // 1 page
        assert_eq!(index[0][1].page_locations().len(), 1); // 1 page
    }

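    // Statistics are disabled file-wide but re-enabled at Page level for column "a":
    // "a" should get chunk statistics and a column index, "b" should get neither.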
    #[test]
    fn test_disabled_statistics_with_page() {
        let file_schema = Schema::new(vec![
            Field::new("a", DataType::Utf8, true),
            Field::new("b", DataType::Utf8, true),
        ]);
        let file_schema = Arc::new(file_schema);

        let batch = RecordBatch::try_new(
            file_schema.clone(),
            vec![
                Arc::new(StringArray::from(vec!["a", "b", "c", "d"])) as _,
                Arc::new(StringArray::from(vec!["w", "x", "y", "z"])) as _,
            ],
        )
        .unwrap();

        let props = WriterProperties::builder()
            .set_statistics_enabled(EnabledStatistics::None)
            .set_column_statistics_enabled("a".into(), EnabledStatistics::Page)
            .build();

        let mut buf = Vec::with_capacity(1024);
        let mut writer = ArrowWriter::try_new(&mut buf, file_schema.clone(), Some(props)).unwrap();
        writer.write(&batch).unwrap();

        let metadata = writer.close().unwrap();
        assert_eq!(metadata.num_row_groups(), 1);
        let row_group = metadata.row_group(0);
        assert_eq!(row_group.num_columns(), 2);
        // Column "a" has both offset and column index, as requested
        assert!(row_group.column(0).offset_index_offset().is_some());
        assert!(row_group.column(0).column_index_offset().is_some());
        // Column "b" should only have offset index
        assert!(row_group.column(1).offset_index_offset().is_some());
        assert!(row_group.column(1).column_index_offset().is_none());

        let options = ReadOptionsBuilder::new().with_page_index().build();
        let reader = SerializedFileReader::new_with_options(Bytes::from(buf), options).unwrap();

        let row_group = reader.get_row_group(0).unwrap();
        let a_col = row_group.metadata().column(0);
        let b_col = row_group.metadata().column(1);

        // Column chunk of column "a" should have chunk level statistics
        if let Statistics::ByteArray(byte_array_stats) = a_col.statistics().unwrap() {
            let min = byte_array_stats.min_opt().unwrap();
            let max = byte_array_stats.max_opt().unwrap();

            assert_eq!(min.as_bytes(), b"a");
            assert_eq!(max.as_bytes(), b"d");
        } else {
            panic!("expecting Statistics::ByteArray");
        }

        // The column chunk for column "b" shouldn't have statistics
        assert!(b_col.statistics().is_none());

        let offset_index = reader.metadata().offset_index().unwrap();
        assert_eq!(offset_index.len(), 1); // 1 row group
        assert_eq!(offset_index[0].len(), 2); // 2 columns

        let column_index = reader.metadata().column_index().unwrap();
        assert_eq!(column_index.len(), 1); // 1 row group
        assert_eq!(column_index[0].len(), 2); // 2 columns

        let a_idx = &column_index[0][0];
        assert!(
            matches!(a_idx, ColumnIndexMetaData::BYTE_ARRAY(_)),
            "{a_idx:?}"
        );
        let b_idx = &column_index[0][1];
        assert!(matches!(b_idx, ColumnIndexMetaData::NONE), "{b_idx:?}");
    }

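    // With Chunk-level statistics for column "a", chunk statistics are written
    // but no column index is produced for either column.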
    #[test]
    fn test_disabled_statistics_with_chunk() {
        let file_schema = Schema::new(vec![
            Field::new("a", DataType::Utf8, true),
            Field::new("b", DataType::Utf8, true),
        ]);
        let file_schema = Arc::new(file_schema);

        let batch = RecordBatch::try_new(
            file_schema.clone(),
            vec![
                Arc::new(StringArray::from(vec!["a", "b", "c", "d"])) as _,
                Arc::new(StringArray::from(vec!["w", "x", "y", "z"])) as _,
            ],
        )
        .unwrap();

        let props = WriterProperties::builder()
            .set_statistics_enabled(EnabledStatistics::None)
            .set_column_statistics_enabled("a".into(), EnabledStatistics::Chunk)
            .build();

        let mut buf = Vec::with_capacity(1024);
        let mut writer = ArrowWriter::try_new(&mut buf, file_schema.clone(), Some(props)).unwrap();
        writer.write(&batch).unwrap();

        let metadata = writer.close().unwrap();
        assert_eq!(metadata.num_row_groups(), 1);
        let row_group = metadata.row_group(0);
        assert_eq!(row_group.num_columns(), 2);
        // Column "a" should only have offset index
        assert!(row_group.column(0).offset_index_offset().is_some());
        assert!(row_group.column(0).column_index_offset().is_none());
        // Column "b" should only have offset index
        assert!(row_group.column(1).offset_index_offset().is_some());
        assert!(row_group.column(1).column_index_offset().is_none());

        let options = ReadOptionsBuilder::new().with_page_index().build();
        let reader = SerializedFileReader::new_with_options(Bytes::from(buf), options).unwrap();

        let row_group = reader.get_row_group(0).unwrap();
        let a_col = row_group.metadata().column(0);
        let b_col = row_group.metadata().column(1);

        // Column chunk of column "a" should have chunk level statistics
        if let Statistics::ByteArray(byte_array_stats) = a_col.statistics().unwrap() {
            let min = byte_array_stats.min_opt().unwrap();
            let max = byte_array_stats.max_opt().unwrap();

            assert_eq!(min.as_bytes(), b"a");
            assert_eq!(max.as_bytes(), b"d");
        } else {
            panic!("expecting Statistics::ByteArray");
        }

        // The column chunk for column "b" shouldn't have statistics
        assert!(b_col.statistics().is_none());

        let column_index = reader.metadata().column_index().unwrap();
        assert_eq!(column_index.len(), 1); // 1 row group
        assert_eq!(column_index[0].len(), 2); // 2 columns

        let a_idx = &column_index[0][0];
        assert!(matches!(a_idx, ColumnIndexMetaData::NONE), "{a_idx:?}");
        let b_idx = &column_index[0][1];
        assert!(matches!(b_idx, ColumnIndexMetaData::NONE), "{b_idx:?}");
    }

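    // `with_skip_arrow_metadata(true)` omits the `ARROW_SCHEMA_META_KEY` entry from
    // the file's key/value metadata; the Arrow schema is still recovered from the
    // Parquet schema, as the equality assertion below shows.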
    #[test]
    fn test_arrow_writer_skip_metadata() {
        let batch_schema = Schema::new(vec![Field::new("int32", DataType::Int32, false)]);
        let file_schema = Arc::new(batch_schema.clone());

        let batch = RecordBatch::try_new(
            Arc::new(batch_schema),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as _],
        )
        .unwrap();
        let skip_options = ArrowWriterOptions::new().with_skip_arrow_metadata(true);

        let mut buf = Vec::with_capacity(1024);
        let mut writer =
            ArrowWriter::try_new_with_options(&mut buf, file_schema.clone(), skip_options).unwrap();
        writer.write(&batch).unwrap();
        writer.close().unwrap();

        let bytes = Bytes::from(buf);
        let reader_builder = ParquetRecordBatchReaderBuilder::try_new(bytes).unwrap();
        assert_eq!(file_schema, *reader_builder.schema());
        if let Some(key_value_metadata) = reader_builder
            .metadata()
            .file_metadata()
            .key_value_metadata()
        {
            assert!(
                !key_value_metadata
                    .iter()
                    .any(|kv| kv.key.as_str() == ARROW_SCHEMA_META_KEY)
            );
        }
    }

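    // Writing a batch whose schema doesn't match the file schema fails with a
    // descriptive error rather than producing a corrupt file.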
    #[test]
    fn mismatched_schemas() {
        let batch_schema = Schema::new(vec![Field::new("count", DataType::Int32, false)]);
        let file_schema = Arc::new(Schema::new(vec![Field::new(
            "temperature",
            DataType::Float64,
            false,
        )]));

        let batch = RecordBatch::try_new(
            Arc::new(batch_schema),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as _],
        )
        .unwrap();

        let mut buf = Vec::with_capacity(1024);
        let mut writer = ArrowWriter::try_new(&mut buf, file_schema.clone(), None).unwrap();

        let err = writer.write(&batch).unwrap_err().to_string();
        assert_eq!(
            err,
            "Arrow: Incompatible type. Field 'temperature' has type Float64, array has type Int32"
        );
    }

    #[test]
    // https://github.com/apache/arrow-rs/issues/6988
    fn test_roundtrip_empty_schema() {
        // create empty record batch with empty schema
        let empty_batch = RecordBatch::try_new_with_options(
            Arc::new(Schema::empty()),
            vec![],
            &RecordBatchOptions::default().with_row_count(Some(0)),
        )
        .unwrap();

        // write to parquet
        let mut parquet_bytes: Vec<u8> = Vec::new();
        let mut writer =
            ArrowWriter::try_new(&mut parquet_bytes, empty_batch.schema(), None).unwrap();
        writer.write(&empty_batch).unwrap();
        writer.close().unwrap();

        // read from parquet
        let bytes = Bytes::from(parquet_bytes);
        let reader = ParquetRecordBatchReaderBuilder::try_new(bytes).unwrap();
        assert_eq!(reader.schema(), &empty_batch.schema());
        let batches: Vec<_> = reader
            .build()
            .unwrap()
            .collect::<ArrowResult<Vec<_>>>()
            .unwrap();
        assert_eq!(batches.len(), 0);
    }

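    // Even with Page-level statistics enabled, statistics are not embedded in the
    // page headers themselves unless `set_write_page_header_statistics` is set
    // (compare with the next test).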
    #[test]
    fn test_page_stats_not_written_by_default() {
        let string_field = Field::new("a", DataType::Utf8, false);
        let schema = Schema::new(vec![string_field]);
        let raw_string_values = vec!["Blart Versenwald III"];
        let string_values = StringArray::from(raw_string_values.clone());
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(string_values)]).unwrap();

        let props = WriterProperties::builder()
            .set_statistics_enabled(EnabledStatistics::Page)
            .set_dictionary_enabled(false)
            .set_encoding(Encoding::PLAIN)
            .set_compression(crate::basic::Compression::UNCOMPRESSED)
            .build();

        let file = roundtrip_opts(&batch, props);

        // read file and decode page headers
        // Note: use the thrift API as there is no Rust API to access the statistics in the page headers

        // decode first page header
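        // the first page header begins immediately after the 4-byte "PAR1" file magic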
        let first_page = &file[4..];
        let mut prot = ThriftSliceInputProtocol::new(first_page);
        let hdr = PageHeader::read_thrift(&mut prot).unwrap();
        let stats = hdr.data_page_header.unwrap().statistics;

        assert!(stats.is_none());
    }

    #[test]
    fn test_page_stats_when_enabled() {
        let string_field = Field::new("a", DataType::Utf8, false);
        let schema = Schema::new(vec![string_field]);
        let raw_string_values = vec!["Blart Versenwald III", "Andrew Lamb"];
        let string_values = StringArray::from(raw_string_values.clone());
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(string_values)]).unwrap();

        let props = WriterProperties::builder()
            .set_statistics_enabled(EnabledStatistics::Page)
            .set_dictionary_enabled(false)
            .set_encoding(Encoding::PLAIN)
            .set_write_page_header_statistics(true)
            .set_compression(crate::basic::Compression::UNCOMPRESSED)
            .build();

        let file = roundtrip_opts(&batch, props);

        // read file and decode page headers
        // Note: use the thrift API as there is no Rust API to access the statistics in the page headers

        // decode first page header
        let first_page = &file[4..];
        let mut prot = ThriftSliceInputProtocol::new(first_page);
        let hdr = PageHeader::read_thrift(&mut prot).unwrap();
        let stats = hdr.data_page_header.unwrap().statistics;

        let stats = stats.unwrap();
        // check that min/max were actually written to the page
        assert!(stats.is_max_value_exact.unwrap());
        assert!(stats.is_min_value_exact.unwrap());
        assert_eq!(stats.max_value.unwrap(), "Blart Versenwald III".as_bytes());
        assert_eq!(stats.min_value.unwrap(), "Andrew Lamb".as_bytes());
    }

    #[test]
    fn test_page_stats_truncation() {
        let string_field = Field::new("a", DataType::Utf8, false);
        let binary_field = Field::new("b", DataType::Binary, false);
        let schema = Schema::new(vec![string_field, binary_field]);

        let raw_string_values = vec!["Blart Versenwald III"];
        let raw_binary_values = [b"Blart Versenwald III".to_vec()];
        let raw_binary_value_refs = raw_binary_values
            .iter()
            .map(|x| x.as_slice())
            .collect::<Vec<_>>();

        let string_values = StringArray::from(raw_string_values.clone());
        let binary_values = BinaryArray::from(raw_binary_value_refs);
        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![Arc::new(string_values), Arc::new(binary_values)],
        )
        .unwrap();

        let props = WriterProperties::builder()
            .set_statistics_truncate_length(Some(2))
            .set_dictionary_enabled(false)
            .set_encoding(Encoding::PLAIN)
            .set_write_page_header_statistics(true)
            .set_compression(crate::basic::Compression::UNCOMPRESSED)
            .build();

        let file = roundtrip_opts(&batch, props);

        // read file and decode page headers
        // Note: use the thrift API as there is no Rust API to access the statistics in the page headers

        // decode first page header
        let first_page = &file[4..];
        let mut prot = ThriftSliceInputProtocol::new(first_page);
        let hdr = PageHeader::read_thrift(&mut prot).unwrap();
        let stats = hdr.data_page_header.unwrap().statistics;
        assert!(stats.is_some());
        let stats = stats.unwrap();
        // check that min/max were properly truncated
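        // the truncated max is incremented ("Bl" -> "Bm") so it still upper-bounds the
        // data, while the truncated min "Bl" remains a valid lower bound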
        assert!(!stats.is_max_value_exact.unwrap());
        assert!(!stats.is_min_value_exact.unwrap());
        assert_eq!(stats.max_value.unwrap(), "Bm".as_bytes());
        assert_eq!(stats.min_value.unwrap(), "Bl".as_bytes());

        // check second page now
        let second_page = &prot.as_slice()[hdr.compressed_page_size as usize..];
        let mut prot = ThriftSliceInputProtocol::new(second_page);
        let hdr = PageHeader::read_thrift(&mut prot).unwrap();
        let stats = hdr.data_page_header.unwrap().statistics;
        assert!(stats.is_some());
        let stats = stats.unwrap();
        // check that min/max were properly truncated
        assert!(!stats.is_max_value_exact.unwrap());
        assert!(!stats.is_min_value_exact.unwrap());
        assert_eq!(stats.max_value.unwrap(), "Bm".as_bytes());
        assert_eq!(stats.min_value.unwrap(), "Bl".as_bytes());
    }

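    // Page encoding statistics are recorded in the column chunk metadata and
    // survive a write/read round trip.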
    #[test]
    fn test_page_encoding_statistics_roundtrip() {
        let batch_schema = Schema::new(vec![Field::new(
            "int32",
            arrow_schema::DataType::Int32,
            false,
        )]);

        let batch = RecordBatch::try_new(
            Arc::new(batch_schema.clone()),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as _],
        )
        .unwrap();

        let mut file: File = tempfile::tempfile().unwrap();
        let mut writer = ArrowWriter::try_new(&mut file, Arc::new(batch_schema), None).unwrap();
        writer.write(&batch).unwrap();
        let file_metadata = writer.close().unwrap();

        assert_eq!(file_metadata.num_row_groups(), 1);
        assert_eq!(file_metadata.row_group(0).num_columns(), 1);
        assert!(
            file_metadata
                .row_group(0)
                .column(0)
                .page_encoding_stats()
                .is_some()
        );
        let chunk_page_stats = file_metadata
            .row_group(0)
            .column(0)
            .page_encoding_stats()
            .unwrap();

        // check that the read metadata is also correct
        let options = ReadOptionsBuilder::new()
            .with_page_index()
            .with_encoding_stats_as_mask(false)
            .build();
        let reader = SerializedFileReader::new_with_options(file, options).unwrap();

        let rowgroup = reader.get_row_group(0).expect("row group missing");
        assert_eq!(rowgroup.num_columns(), 1);
        let column = rowgroup.metadata().column(0);
        assert!(column.page_encoding_stats().is_some());
        let file_page_stats = column.page_encoding_stats().unwrap();
        assert_eq!(chunk_page_stats, file_page_stats);
    }

    #[test]
    fn test_different_dict_page_size_limit() {
        let array = Arc::new(Int64Array::from_iter(0..1024 * 1024));
        let schema = Arc::new(Schema::new(vec![
            Field::new("col0", arrow_schema::DataType::Int64, false),
            Field::new("col1", arrow_schema::DataType::Int64, false),
        ]));
        let batch =
            arrow_array::RecordBatch::try_new(schema.clone(), vec![array.clone(), array]).unwrap();

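        // "col1" overrides the file-wide 1 MiB dictionary page size limit with a 4 MiB one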
        let props = WriterProperties::builder()
            .set_dictionary_page_size_limit(1024 * 1024)
            .set_column_dictionary_page_size_limit(ColumnPath::from("col1"), 1024 * 1024 * 4)
            .build();
        let mut writer = ArrowWriter::try_new(Vec::new(), schema, Some(props)).unwrap();
        writer.write(&batch).unwrap();
        let data = Bytes::from(writer.into_inner().unwrap());

        let mut metadata = ParquetMetaDataReader::new();
        metadata.try_parse(&data).unwrap();
        let metadata = metadata.finish().unwrap();
        let col0_meta = metadata.row_group(0).column(0);
        let col1_meta = metadata.row_group(0).column(1);

        let get_dict_page_size = move |meta: &ColumnChunkMetaData| {
            let mut reader =
                SerializedPageReader::new(Arc::new(data.clone()), meta, 0, None).unwrap();
            let page = reader.get_next_page().unwrap().unwrap();
            match page {
                Page::DictionaryPage { buf, .. } => buf.len(),
                _ => panic!("expected DictionaryPage"),
            }
        };

        assert_eq!(get_dict_page_size(col0_meta), 1024 * 1024);
        assert_eq!(get_dict_page_size(col1_meta), 1024 * 1024 * 4);
    }

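    /// Shape of the test data written by `write_batches`: `num_batches` batches of
    /// `rows_per_batch` rows, each row a string zero-padded to `row_size` bytes.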
    struct WriteBatchesShape {
        num_batches: usize,
        rows_per_batch: usize,
        row_size: usize,
    }

    /// Helper that writes batches of the provided `WriteBatchesShape` through an `ArrowWriter`
    /// and returns a reader builder over the resulting file
    fn write_batches(
        WriteBatchesShape {
            num_batches,
            rows_per_batch,
            row_size,
        }: WriteBatchesShape,
        props: WriterProperties,
    ) -> ParquetRecordBatchReaderBuilder<File> {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "str",
            ArrowDataType::Utf8,
            false,
        )]));
        let file = tempfile::tempfile().unwrap();
        let mut writer =
            ArrowWriter::try_new(file.try_clone().unwrap(), schema.clone(), Some(props)).unwrap();

        for batch_idx in 0..num_batches {
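            // zero-pad each value to `row_size` bytes so every row has the same encoded size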
            let strings: Vec<String> = (0..rows_per_batch)
                .map(|i| format!("{:0>width$}", batch_idx * 10 + i, width = row_size))
                .collect();
            let array = StringArray::from(strings);
            let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap();
            writer.write(&batch).unwrap();
        }
        writer.close().unwrap();
        ParquetRecordBatchReaderBuilder::try_new(file).unwrap()
    }

    #[test]
    // When both limits are None, all data should go into a single row group
    fn test_row_group_limit_none_writes_single_row_group() {
        let props = WriterProperties::builder()
            .set_max_row_group_row_count(None)
            .set_max_row_group_bytes(None)
            .build();

        let builder = write_batches(
            WriteBatchesShape {
                num_batches: 1,
                rows_per_batch: 1000,
                row_size: 4,
            },
            props,
        );

        assert_eq!(
            &row_group_sizes(builder.metadata()),
            &[1000],
            "With no limits, all rows should be in a single row group"
        );
    }

    #[test]
    // When only max_row_group_row_count is set, respect the row limit
    fn test_row_group_limit_rows_only() {
        let props = WriterProperties::builder()
            .set_max_row_group_row_count(Some(300))
            .set_max_row_group_bytes(None)
            .build();

        let builder = write_batches(
            WriteBatchesShape {
                num_batches: 1,
                rows_per_batch: 1000,
                row_size: 4,
            },
            props,
        );

        assert_eq!(
            &row_group_sizes(builder.metadata()),
            &[300, 300, 300, 100],
            "Row groups should be split by row count"
        );
    }

    #[test]
    // When only max_row_group_bytes is set, respect the byte limit
    fn test_row_group_limit_bytes_only() {
        let props = WriterProperties::builder()
            .set_max_row_group_row_count(None)
            // Set byte limit to approximately fit ~30 rows' worth of data (~100 bytes each)
            .set_max_row_group_bytes(Some(3500))
            .build();

        let builder = write_batches(
            WriteBatchesShape {
                num_batches: 10,
                rows_per_batch: 10,
                row_size: 100,
            },
            props,
        );

        let sizes = row_group_sizes(builder.metadata());

        assert!(
            sizes.len() > 1,
            "Should have multiple row groups due to byte limit, got {sizes:?}",
        );

        let total_rows: i64 = sizes.iter().sum();
        assert_eq!(total_rows, 100, "Total rows should be preserved");
    }

    #[test]
    // If an in-progress row group is already oversized, it should be flushed before writing more.
    fn test_row_group_limit_bytes_flushes_when_current_group_already_too_large() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "str",
            ArrowDataType::Utf8,
            false,
        )]));
        let file = tempfile::tempfile().unwrap();

        // Start with no byte limit so we can intentionally build an oversized in-progress row group.
        let props = WriterProperties::builder()
            .set_max_row_group_row_count(None)
            .set_max_row_group_bytes(None)
            .build();
        let mut writer =
            ArrowWriter::try_new(file.try_clone().unwrap(), schema.clone(), Some(props)).unwrap();

        let first_array = StringArray::from(
            (0..10)
                .map(|i| format!("{:0>100}", i))
                .collect::<Vec<String>>(),
        );
        let first_batch =
            RecordBatch::try_new(schema.clone(), vec![Arc::new(first_array)]).unwrap();
        writer.write(&first_batch).unwrap();
        assert_eq!(writer.in_progress_rows(), 10);

        // Tighten the limit below the current in-progress bytes to exercise:
        // `if current_bytes >= max_bytes { self.flush()?; ... }`
        writer.max_row_group_bytes = Some(1);

        let second_array = StringArray::from(vec!["x".to_string()]);
        let second_batch =
            RecordBatch::try_new(schema.clone(), vec![Arc::new(second_array)]).unwrap();
        writer.write(&second_batch).unwrap();
        writer.close().unwrap();
        let builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap();

        assert_eq!(
            &row_group_sizes(builder.metadata()),
            &[10, 1],
            "The second write should flush an oversized in-progress row group first",
        );
    }

    #[test]
    // When both limits are set, the row limit triggers first
    fn test_row_group_limit_both_row_wins_single_batch() {
        let props = WriterProperties::builder()
            .set_max_row_group_row_count(Some(200)) // Will trigger at 200 rows
            .set_max_row_group_bytes(Some(1024 * 1024)) // 1 MB - won't trigger for this small string data
            .build();

        let builder = write_batches(
            WriteBatchesShape {
                num_batches: 1,
                row_size: 4,
                rows_per_batch: 1000,
            },
            props,
        );

        assert_eq!(
            &row_group_sizes(builder.metadata()),
            &[200, 200, 200, 200, 200],
            "Row limit should trigger before byte limit"
        );
    }

    #[test]
    // When both limits are set, the row limit triggers first
    fn test_row_group_limit_both_row_wins_multiple_batches() {
        let props = WriterProperties::builder()
            .set_max_row_group_row_count(Some(5)) // Will trigger every 5 rows
            .set_max_row_group_bytes(Some(9999)) // Won't trigger
            .build();

        let builder = write_batches(
            WriteBatchesShape {
                num_batches: 10,
                rows_per_batch: 10,
                row_size: 100,
            },
            props,
        );

        assert_eq!(
            &row_group_sizes(builder.metadata()),
            &[5; 20],
            "Row limit should trigger before byte limit"
        );
    }

    #[test]
    // When both limits are set, the byte limit triggers first
    fn test_row_group_limit_both_bytes_wins() {
        let props = WriterProperties::builder()
            .set_max_row_group_row_count(Some(1000)) // Won't trigger for 100 rows
            .set_max_row_group_bytes(Some(3500)) // Will trigger at ~30-35 rows
            .build();

        let builder = write_batches(
            WriteBatchesShape {
                num_batches: 10,
                rows_per_batch: 10,
                row_size: 100,
            },
            props,
        );

        let sizes = row_group_sizes(builder.metadata());

        assert!(
            sizes.len() > 1,
            "Byte limit should trigger before row limit, got {sizes:?}",
        );

        assert!(
            sizes.iter().all(|&s| s < 1000),
            "No row group should hit the row limit"
        );

        let total_rows: i64 = sizes.iter().sum();
        assert_eq!(total_rows, 100, "Total rows should be preserved");
    }
}