parquet/schema/
types.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Contains structs and methods to build Parquet schema and schema descriptors.
19
20use std::vec::IntoIter;
21use std::{collections::HashMap, fmt, sync::Arc};
22
23use crate::file::metadata::HeapSize;
24use crate::file::metadata::thrift::SchemaElement;
25
26use crate::basic::{
27    ColumnOrder, ConvertedType, LogicalType, Repetition, SortOrder, TimeUnit, Type as PhysicalType,
28};
29use crate::errors::{ParquetError, Result};
30
31// ----------------------------------------------------------------------
32// Parquet Type definitions
33
/// Shared, reference-counted pointer to a [`Type`] (alias for `Arc<Type>`).
pub type TypePtr = Arc<Type>;
/// Shared, reference-counted pointer to a [`SchemaDescriptor`]
/// (alias for `Arc<SchemaDescriptor>`).
pub type SchemaDescPtr = Arc<SchemaDescriptor>;
/// Shared, reference-counted pointer to a [`ColumnDescriptor`]
/// (alias for `Arc<ColumnDescriptor>`).
pub type ColumnDescPtr = Arc<ColumnDescriptor>;
40
/// Representation of a Parquet type.
///
/// Used to describe primitive leaf fields and structs, including top-level schema.
///
/// Note that the top-level schema is represented using [`Type::GroupType`] whose
/// repetition is `None`.
///
/// Values are normally created through [`Type::primitive_type_builder`] and
/// [`Type::group_type_builder`].
#[derive(Clone, Debug, PartialEq)]
pub enum Type {
    /// Represents a primitive leaf field.
    PrimitiveType {
        /// Basic information about the type (name, repetition, converted type,
        /// logical type and optional field id).
        basic_info: BasicTypeInfo,
        /// Physical type of this primitive type.
        physical_type: PhysicalType,
        /// Length of this type, as given to [`PrimitiveTypeBuilder::with_length`]
        /// (`-1` if the builder's initial value was left untouched).
        type_length: i32,
        /// Scale of this type, as given to [`PrimitiveTypeBuilder::with_scale`]
        /// (`-1` if the builder's initial value was left untouched).
        scale: i32,
        /// Precision of this type, as given to [`PrimitiveTypeBuilder::with_precision`]
        /// (`-1` if the builder's initial value was left untouched).
        precision: i32,
    },
    /// Represents a group of fields (similar to struct).
    GroupType {
        /// Basic information about the type (name, repetition, converted type,
        /// logical type and optional field id).
        basic_info: BasicTypeInfo,
        /// Fields of this group type, in declaration order.
        fields: Vec<TypePtr>,
    },
}
70
71impl HeapSize for Type {
72    fn heap_size(&self) -> usize {
73        match self {
74            Type::PrimitiveType { basic_info, .. } => basic_info.heap_size(),
75            Type::GroupType { basic_info, fields } => basic_info.heap_size() + fields.heap_size(),
76        }
77    }
78}
79
80impl Type {
    /// Creates primitive type builder with provided field name and physical type.
    ///
    /// Shorthand for [`PrimitiveTypeBuilder::new`]; the returned builder borrows `name`.
    pub fn primitive_type_builder(
        name: &str,
        physical_type: PhysicalType,
    ) -> PrimitiveTypeBuilder<'_> {
        PrimitiveTypeBuilder::new(name, physical_type)
    }
88
    /// Creates group type builder with provided field name.
    ///
    /// Shorthand for [`GroupTypeBuilder::new`]; the returned builder borrows `name`.
    pub fn group_type_builder(name: &str) -> GroupTypeBuilder<'_> {
        GroupTypeBuilder::new(name)
    }
93
94    /// Returns [`BasicTypeInfo`] information about the type.
95    pub fn get_basic_info(&self) -> &BasicTypeInfo {
96        match *self {
97            Type::PrimitiveType { ref basic_info, .. } => basic_info,
98            Type::GroupType { ref basic_info, .. } => basic_info,
99        }
100    }
101
    /// Returns this type's field name, as stored in its [`BasicTypeInfo`].
    pub fn name(&self) -> &str {
        self.get_basic_info().name()
    }
106
107    /// Gets the fields from this group type.
108    /// Note that this will panic if called on a non-group type.
109    // TODO: should we return `&[&Type]` here?
110    pub fn get_fields(&self) -> &[TypePtr] {
111        match *self {
112            Type::GroupType { ref fields, .. } => &fields[..],
113            _ => panic!("Cannot call get_fields() on a non-group type"),
114        }
115    }
116
117    /// Gets physical type of this primitive type.
118    /// Note that this will panic if called on a non-primitive type.
119    pub fn get_physical_type(&self) -> PhysicalType {
120        match *self {
121            Type::PrimitiveType {
122                basic_info: _,
123                physical_type,
124                ..
125            } => physical_type,
126            _ => panic!("Cannot call get_physical_type() on a non-primitive type"),
127        }
128    }
129
130    /// Gets precision of this primitive type.
131    /// Note that this will panic if called on a non-primitive type.
132    pub fn get_precision(&self) -> i32 {
133        match *self {
134            Type::PrimitiveType { precision, .. } => precision,
135            _ => panic!("Cannot call get_precision() on non-primitive type"),
136        }
137    }
138
139    /// Gets scale of this primitive type.
140    /// Note that this will panic if called on a non-primitive type.
141    pub fn get_scale(&self) -> i32 {
142        match *self {
143            Type::PrimitiveType { scale, .. } => scale,
144            _ => panic!("Cannot call get_scale() on non-primitive type"),
145        }
146    }
147
148    /// Checks if `sub_type` schema is part of current schema.
149    /// This method can be used to check if projected columns are part of the root schema.
150    pub fn check_contains(&self, sub_type: &Type) -> bool {
151        // Names match, and repetitions match or not set for both
152        let basic_match = self.get_basic_info().name() == sub_type.get_basic_info().name()
153            && (self.is_schema() && sub_type.is_schema()
154                || !self.is_schema()
155                    && !sub_type.is_schema()
156                    && self.get_basic_info().repetition()
157                        == sub_type.get_basic_info().repetition());
158
159        match *self {
160            Type::PrimitiveType { .. } if basic_match && sub_type.is_primitive() => {
161                self.get_physical_type() == sub_type.get_physical_type()
162            }
163            Type::GroupType { .. } if basic_match && sub_type.is_group() => {
164                // build hashmap of name -> TypePtr
165                let mut field_map = HashMap::new();
166                for field in self.get_fields() {
167                    field_map.insert(field.name(), field);
168                }
169
170                for field in sub_type.get_fields() {
171                    if !field_map
172                        .get(field.name())
173                        .map(|tpe| tpe.check_contains(field))
174                        .unwrap_or(false)
175                    {
176                        return false;
177                    }
178                }
179                true
180            }
181            _ => false,
182        }
183    }
184
185    /// Returns `true` if this type is a primitive type, `false` otherwise.
186    pub fn is_primitive(&self) -> bool {
187        matches!(*self, Type::PrimitiveType { .. })
188    }
189
190    /// Returns `true` if this type is a group type, `false` otherwise.
191    pub fn is_group(&self) -> bool {
192        matches!(*self, Type::GroupType { .. })
193    }
194
195    /// Returns `true` if this type is the top-level schema type (message type).
196    pub fn is_schema(&self) -> bool {
197        match *self {
198            Type::GroupType { ref basic_info, .. } => !basic_info.has_repetition(),
199            _ => false,
200        }
201    }
202
203    /// Returns `true` if this type is repeated or optional.
204    /// If this type doesn't have repetition defined, we treat it as required.
205    pub fn is_optional(&self) -> bool {
206        self.get_basic_info().has_repetition()
207            && self.get_basic_info().repetition() != Repetition::REQUIRED
208    }
209
210    /// Returns `true` if this type is annotated as a list.
211    pub(crate) fn is_list(&self) -> bool {
212        if self.is_group() {
213            let basic_info = self.get_basic_info();
214            if let Some(logical_type) = basic_info.logical_type() {
215                return logical_type == LogicalType::List;
216            }
217            return basic_info.converted_type() == ConvertedType::LIST;
218        }
219        false
220    }
221
222    /// Returns `true` if this type is a group with a single child field that is `repeated`.
223    pub(crate) fn has_single_repeated_child(&self) -> bool {
224        if self.is_group() {
225            let children = self.get_fields();
226            return children.len() == 1
227                && children[0].get_basic_info().has_repetition()
228                && children[0].get_basic_info().repetition() == Repetition::REPEATED;
229        }
230        false
231    }
232}
233
/// A builder for primitive types. All attributes are optional
/// except the name and physical type.
/// Note that if not specified explicitly, `Repetition::OPTIONAL` is used.
pub struct PrimitiveTypeBuilder<'a> {
    /// Field name, borrowed from the caller until `build` copies it.
    name: &'a str,
    /// Repetition of the field.
    repetition: Repetition,
    /// Physical type of the field.
    physical_type: PhysicalType,
    /// Converted (legacy) type annotation.
    converted_type: ConvertedType,
    /// Logical type annotation, if any.
    logical_type: Option<LogicalType>,
    /// Type length; validated by `build` for FIXED_LEN_BYTE_ARRAY fields.
    length: i32,
    /// Precision, validated by `build` for DECIMAL annotations.
    precision: i32,
    /// Scale, validated by `build` for DECIMAL annotations.
    scale: i32,
    /// Optional field id.
    id: Option<i32>,
}
248
249impl<'a> PrimitiveTypeBuilder<'a> {
    /// Creates new primitive type builder with provided field name and physical type.
    ///
    /// Initial values of the remaining attributes: repetition `OPTIONAL`, converted
    /// type `NONE`, no logical type, no id, and `-1` for length, precision and scale.
    pub fn new(name: &'a str, physical_type: PhysicalType) -> Self {
        Self {
            name,
            repetition: Repetition::OPTIONAL,
            physical_type,
            converted_type: ConvertedType::NONE,
            logical_type: None,
            // `-1` marks length, precision and scale as "not provided".
            length: -1,
            precision: -1,
            scale: -1,
            id: None,
        }
    }
264
265    /// Sets [`Repetition`] for this field and returns itself.
266    pub fn with_repetition(self, repetition: Repetition) -> Self {
267        Self { repetition, ..self }
268    }
269
270    /// Sets [`ConvertedType`] for this field and returns itself.
271    pub fn with_converted_type(self, converted_type: ConvertedType) -> Self {
272        Self {
273            converted_type,
274            ..self
275        }
276    }
277
278    /// Sets [`LogicalType`] for this field and returns itself.
279    /// If only the logical type is populated for a primitive type, the converted type
280    /// will be automatically populated, and can thus be omitted.
281    pub fn with_logical_type(self, logical_type: Option<LogicalType>) -> Self {
282        Self {
283            logical_type,
284            ..self
285        }
286    }
287
288    /// Sets type length and returns itself.
289    /// This is only applied to FIXED_LEN_BYTE_ARRAY and INT96 (INTERVAL) types, because
290    /// they maintain fixed size underlying byte array.
291    /// By default, value is `0`.
292    pub fn with_length(self, length: i32) -> Self {
293        Self { length, ..self }
294    }
295
296    /// Sets precision for Parquet DECIMAL physical type and returns itself.
297    /// By default, it equals to `0` and used only for decimal context.
298    pub fn with_precision(self, precision: i32) -> Self {
299        Self { precision, ..self }
300    }
301
302    /// Sets scale for Parquet DECIMAL physical type and returns itself.
303    /// By default, it equals to `0` and used only for decimal context.
304    pub fn with_scale(self, scale: i32) -> Self {
305        Self { scale, ..self }
306    }
307
308    /// Sets optional field id and returns itself.
309    pub fn with_id(self, id: Option<i32>) -> Self {
310        Self { id, ..self }
311    }
312
    /// Creates a new `PrimitiveType` instance from the collected attributes.
    /// Returns `Err` in case of any building conditions are not met.
    ///
    /// Validation runs in this order, returning the first failure:
    /// 1. a FIXED_LEN_BYTE_ARRAY field must not have a negative length;
    /// 2. if a logical type is set, it must agree with an explicitly set converted
    ///    type and be compatible with the physical type (and, for DECIMAL, with the
    ///    builder's scale and precision);
    /// 3. the converted type supplied to the builder must be compatible with the
    ///    physical type.
    ///
    /// If a logical type is set and the converted type was left as `NONE`, the
    /// converted type stored in the result is derived from the logical type.
    pub fn build(self) -> Result<Type> {
        // `converted_type` may be overwritten below when it is derived from the
        // logical type.
        let mut basic_info = BasicTypeInfo {
            name: String::from(self.name),
            repetition: Some(self.repetition),
            converted_type: self.converted_type,
            logical_type: self.logical_type.clone(),
            id: self.id,
        };

        // Check length before logical type, since it is used for logical type validation.
        if self.physical_type == PhysicalType::FIXED_LEN_BYTE_ARRAY && self.length < 0 {
            return Err(general_err!(
                "Invalid FIXED_LEN_BYTE_ARRAY length: {} for field '{}'",
                self.length,
                self.name
            ));
        }

        if let Some(logical_type) = &self.logical_type {
            // If a converted type is populated, check that it is consistent with
            // its logical type
            if self.converted_type != ConvertedType::NONE {
                if ConvertedType::from(self.logical_type.clone()) != self.converted_type {
                    return Err(general_err!(
                        "Logical type {:?} is incompatible with converted type {} for field '{}'",
                        logical_type,
                        self.converted_type,
                        self.name
                    ));
                }
            } else {
                // Populate the converted type for backwards compatibility
                basic_info.converted_type = self.logical_type.clone().into();
            }
            // Check that logical type and physical type are compatible
            match (logical_type, self.physical_type) {
                (LogicalType::Map, _) | (LogicalType::List, _) => {
                    return Err(general_err!(
                        "{:?} cannot be applied to a primitive type for field '{}'",
                        logical_type,
                        self.name
                    ));
                }
                (LogicalType::Enum, PhysicalType::BYTE_ARRAY) => {}
                // DECIMAL is matched for every physical type; unsupported physical
                // types are rejected inside `check_decimal_precision_scale`.
                (LogicalType::Decimal { scale, precision }, _) => {
                    // Check that scale and precision are consistent with legacy values
                    if *scale != self.scale {
                        return Err(general_err!(
                            "DECIMAL logical type scale {} must match self.scale {} for field '{}'",
                            scale,
                            self.scale,
                            self.name
                        ));
                    }
                    if *precision != self.precision {
                        return Err(general_err!(
                            "DECIMAL logical type precision {} must match self.precision {} for field '{}'",
                            precision,
                            self.precision,
                            self.name
                        ));
                    }
                    self.check_decimal_precision_scale()?;
                }
                (LogicalType::Date, PhysicalType::INT32) => {}
                (
                    LogicalType::Time {
                        unit: TimeUnit::MILLIS,
                        ..
                    },
                    PhysicalType::INT32,
                ) => {}
                (LogicalType::Time { unit, .. }, PhysicalType::INT64) => {
                    if *unit == TimeUnit::MILLIS {
                        return Err(general_err!(
                            "Cannot use millisecond unit on INT64 type for field '{}'",
                            self.name
                        ));
                    }
                }
                (LogicalType::Timestamp { .. }, PhysicalType::INT64) => {}
                // Integer annotations whose guard fails fall through to the
                // catch-all error arm at the end of this match.
                (LogicalType::Integer { bit_width, .. }, PhysicalType::INT32)
                    if *bit_width <= 32 => {}
                (LogicalType::Integer { bit_width, .. }, PhysicalType::INT64)
                    if *bit_width == 64 => {}
                // Null type
                (LogicalType::Unknown, PhysicalType::INT32) => {}
                (LogicalType::String, PhysicalType::BYTE_ARRAY) => {}
                (LogicalType::Json, PhysicalType::BYTE_ARRAY) => {}
                (LogicalType::Bson, PhysicalType::BYTE_ARRAY) => {}
                (LogicalType::Geometry { .. }, PhysicalType::BYTE_ARRAY) => {}
                (LogicalType::Geography { .. }, PhysicalType::BYTE_ARRAY) => {}
                // The guarded arm must come first: any other FIXED_LEN_BYTE_ARRAY
                // length reaches the unguarded arm and is rejected.
                (LogicalType::Uuid, PhysicalType::FIXED_LEN_BYTE_ARRAY) if self.length == 16 => {}
                (LogicalType::Uuid, PhysicalType::FIXED_LEN_BYTE_ARRAY) => {
                    return Err(general_err!(
                        "UUID cannot annotate field '{}' because it is not a FIXED_LEN_BYTE_ARRAY(16) field",
                        self.name
                    ));
                }
                // Same guarded/unguarded pairing as for UUID above.
                (LogicalType::Float16, PhysicalType::FIXED_LEN_BYTE_ARRAY) if self.length == 2 => {}
                (LogicalType::Float16, PhysicalType::FIXED_LEN_BYTE_ARRAY) => {
                    return Err(general_err!(
                        "FLOAT16 cannot annotate field '{}' because it is not a FIXED_LEN_BYTE_ARRAY(2) field",
                        self.name
                    ));
                }
                // Every logical/physical combination not accepted above is invalid.
                (a, b) => {
                    return Err(general_err!(
                        "Cannot annotate {:?} from {} for field '{}'",
                        a,
                        b,
                        self.name
                    ));
                }
            }
        }

        // Validate the converted type as supplied to the builder (not the value
        // derived from the logical type above) against the physical type.
        match self.converted_type {
            ConvertedType::NONE => {}
            ConvertedType::UTF8 | ConvertedType::BSON | ConvertedType::JSON => {
                if self.physical_type != PhysicalType::BYTE_ARRAY {
                    return Err(general_err!(
                        "{} cannot annotate field '{}' because it is not a BYTE_ARRAY field",
                        self.converted_type,
                        self.name
                    ));
                }
            }
            ConvertedType::DECIMAL => {
                self.check_decimal_precision_scale()?;
            }
            ConvertedType::DATE
            | ConvertedType::TIME_MILLIS
            | ConvertedType::UINT_8
            | ConvertedType::UINT_16
            | ConvertedType::UINT_32
            | ConvertedType::INT_8
            | ConvertedType::INT_16
            | ConvertedType::INT_32 => {
                if self.physical_type != PhysicalType::INT32 {
                    return Err(general_err!(
                        "{} cannot annotate field '{}' because it is not a INT32 field",
                        self.converted_type,
                        self.name
                    ));
                }
            }
            ConvertedType::TIME_MICROS
            | ConvertedType::TIMESTAMP_MILLIS
            | ConvertedType::TIMESTAMP_MICROS
            | ConvertedType::UINT_64
            | ConvertedType::INT_64 => {
                if self.physical_type != PhysicalType::INT64 {
                    return Err(general_err!(
                        "{} cannot annotate field '{}' because it is not a INT64 field",
                        self.converted_type,
                        self.name
                    ));
                }
            }
            ConvertedType::INTERVAL => {
                if self.physical_type != PhysicalType::FIXED_LEN_BYTE_ARRAY || self.length != 12 {
                    return Err(general_err!(
                        "INTERVAL cannot annotate field '{}' because it is not a FIXED_LEN_BYTE_ARRAY(12) field",
                        self.name
                    ));
                }
            }
            ConvertedType::ENUM => {
                if self.physical_type != PhysicalType::BYTE_ARRAY {
                    return Err(general_err!(
                        "ENUM cannot annotate field '{}' because it is not a BYTE_ARRAY field",
                        self.name
                    ));
                }
            }
            // Any converted type not listed above cannot annotate a primitive field.
            _ => {
                return Err(general_err!(
                    "{} cannot be applied to primitive field '{}'",
                    self.converted_type,
                    self.name
                ));
            }
        }

        Ok(Type::PrimitiveType {
            basic_info,
            physical_type: self.physical_type,
            type_length: self.length,
            scale: self.scale,
            precision: self.precision,
        })
    }
508
509    #[inline]
510    fn check_decimal_precision_scale(&self) -> Result<()> {
511        match self.physical_type {
512            PhysicalType::INT32
513            | PhysicalType::INT64
514            | PhysicalType::BYTE_ARRAY
515            | PhysicalType::FIXED_LEN_BYTE_ARRAY => (),
516            _ => {
517                return Err(general_err!(
518                    "DECIMAL can only annotate INT32, INT64, BYTE_ARRAY and FIXED_LEN_BYTE_ARRAY"
519                ));
520            }
521        }
522
523        // Precision is required and must be a non-zero positive integer.
524        if self.precision < 1 {
525            return Err(general_err!(
526                "Invalid DECIMAL precision: {}",
527                self.precision
528            ));
529        }
530
531        // Scale must be zero or a positive integer less than the precision.
532        if self.scale < 0 {
533            return Err(general_err!("Invalid DECIMAL scale: {}", self.scale));
534        }
535
536        if self.scale > self.precision {
537            return Err(general_err!(
538                "Invalid DECIMAL: scale ({}) cannot be greater than precision \
539             ({})",
540                self.scale,
541                self.precision
542            ));
543        }
544
545        // Check precision and scale based on physical type limitations.
546        match self.physical_type {
547            PhysicalType::INT32 => {
548                if self.precision > 9 {
549                    return Err(general_err!(
550                        "Cannot represent INT32 as DECIMAL with precision {}",
551                        self.precision
552                    ));
553                }
554            }
555            PhysicalType::INT64 => {
556                if self.precision > 18 {
557                    return Err(general_err!(
558                        "Cannot represent INT64 as DECIMAL with precision {}",
559                        self.precision
560                    ));
561                }
562            }
563            PhysicalType::FIXED_LEN_BYTE_ARRAY => {
564                let length = self
565                    .length
566                    .checked_mul(8)
567                    .ok_or(general_err!("Invalid length {} for Decimal", self.length))?;
568                let max_precision = (2f64.powi(length - 1) - 1f64).log10().floor() as i32;
569
570                if self.precision > max_precision {
571                    return Err(general_err!(
572                        "Cannot represent FIXED_LEN_BYTE_ARRAY as DECIMAL with length {} and \
573                        precision {}. The max precision can only be {}",
574                        self.length,
575                        self.precision,
576                        max_precision
577                    ));
578                }
579            }
580            _ => (), // For BYTE_ARRAY precision is not limited
581        }
582
583        Ok(())
584    }
585}
586
/// A builder for group types. All attributes are optional except the name.
/// Note that if not specified explicitly, `None` is used as the repetition of the group,
/// which means it is a root (message) type.
pub struct GroupTypeBuilder<'a> {
    /// Group name, borrowed from the caller until `build` copies it.
    name: &'a str,
    /// Repetition of the group; `None` marks the root (message) type.
    repetition: Option<Repetition>,
    /// Converted (legacy) type annotation.
    converted_type: ConvertedType,
    /// Logical type annotation, if any.
    logical_type: Option<LogicalType>,
    /// Child fields of the group.
    fields: Vec<TypePtr>,
    /// Optional field id.
    id: Option<i32>,
}
598
599impl<'a> GroupTypeBuilder<'a> {
600    /// Creates new group type builder with provided field name.
601    pub fn new(name: &'a str) -> Self {
602        Self {
603            name,
604            repetition: None,
605            converted_type: ConvertedType::NONE,
606            logical_type: None,
607            fields: Vec::new(),
608            id: None,
609        }
610    }
611
612    /// Sets [`Repetition`] for this field and returns itself.
613    pub fn with_repetition(mut self, repetition: Repetition) -> Self {
614        self.repetition = Some(repetition);
615        self
616    }
617
618    /// Sets [`ConvertedType`] for this field and returns itself.
619    pub fn with_converted_type(self, converted_type: ConvertedType) -> Self {
620        Self {
621            converted_type,
622            ..self
623        }
624    }
625
626    /// Sets [`LogicalType`] for this field and returns itself.
627    pub fn with_logical_type(self, logical_type: Option<LogicalType>) -> Self {
628        Self {
629            logical_type,
630            ..self
631        }
632    }
633
634    /// Sets a list of fields that should be child nodes of this field.
635    /// Returns updated self.
636    pub fn with_fields(self, fields: Vec<TypePtr>) -> Self {
637        Self { fields, ..self }
638    }
639
640    /// Sets optional field id and returns itself.
641    pub fn with_id(self, id: Option<i32>) -> Self {
642        Self { id, ..self }
643    }
644
645    /// Creates a new `GroupType` instance from the gathered attributes.
646    pub fn build(self) -> Result<Type> {
647        let mut basic_info = BasicTypeInfo {
648            name: String::from(self.name),
649            repetition: self.repetition,
650            converted_type: self.converted_type,
651            logical_type: self.logical_type.clone(),
652            id: self.id,
653        };
654        // Populate the converted type if only the logical type is populated
655        if self.logical_type.is_some() && self.converted_type == ConvertedType::NONE {
656            basic_info.converted_type = self.logical_type.into();
657        }
658        Ok(Type::GroupType {
659            basic_info,
660            fields: self.fields,
661        })
662    }
663}
664
/// Basic type info. This contains the information shared by group and primitive
/// types: the name of the type, its repetition, the converted and logical type
/// annotations, and the optional field id.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BasicTypeInfo {
    /// Field name.
    name: String,
    /// Repetition of the field; `None` when no repetition is set.
    repetition: Option<Repetition>,
    /// Converted (legacy) type annotation.
    converted_type: ConvertedType,
    /// Logical type annotation, if any.
    logical_type: Option<LogicalType>,
    /// Field id, if any.
    id: Option<i32>,
}
675
impl HeapSize for BasicTypeInfo {
    /// Returns the heap bytes attributed to this info, which is the heap size of
    /// `name` only.
    fn heap_size(&self) -> usize {
        // no heap allocations in any other subfield
        // NOTE(review): `logical_type` is not counted here; confirm that no
        // `LogicalType` variant owns heap data (e.g. a `String`).
        self.name.heap_size()
    }
}
682
683impl BasicTypeInfo {
684    /// Returns field name.
685    pub fn name(&self) -> &str {
686        &self.name
687    }
688
689    /// Returns `true` if type has repetition field set, `false` otherwise.
690    /// This is mostly applied to group type, because primitive type always has
691    /// repetition set.
692    pub fn has_repetition(&self) -> bool {
693        self.repetition.is_some()
694    }
695
696    /// Returns [`Repetition`] value for the type.
697    pub fn repetition(&self) -> Repetition {
698        assert!(self.repetition.is_some());
699        self.repetition.unwrap()
700    }
701
702    /// Returns [`ConvertedType`] value for the type.
703    pub fn converted_type(&self) -> ConvertedType {
704        self.converted_type
705    }
706
707    /// Returns [`LogicalType`] value for the type.
708    pub fn logical_type(&self) -> Option<LogicalType> {
709        // Unlike ConvertedType, LogicalType cannot implement Copy, thus we clone it
710        self.logical_type.clone()
711    }
712
713    /// Returns `true` if id is set, `false` otherwise.
714    pub fn has_id(&self) -> bool {
715        self.id.is_some()
716    }
717
718    /// Returns id value for the type.
719    pub fn id(&self) -> i32 {
720        assert!(self.id.is_some());
721        self.id.unwrap()
722    }
723}
724
725// ----------------------------------------------------------------------
726// Parquet descriptor definitions
727
/// Represents the location of a column in a Parquet schema
///
/// A path is an ordered list of field names; [`ColumnPath::string`] joins them
/// with `.`.
///
/// # Example: refer to column named `'my_column'`
/// ```
/// # use parquet::schema::types::ColumnPath;
/// let column_path = ColumnPath::from("my_column");
/// ```
///
/// # Example: refer to column named `c` in a nested struct `{a: {b: {c: ...}}}`
/// ```
/// # use parquet::schema::types::ColumnPath;
/// // form path 'a.b.c'
/// let column_path = ColumnPath::from(vec![
///   String::from("a"),
///   String::from("b"),
///   String::from("c")
/// ]);
/// ```
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
pub struct ColumnPath {
    /// Path components, ordered from the outermost field to the leaf.
    parts: Vec<String>,
}
750
impl HeapSize for ColumnPath {
    /// Returns the heap size of the vector of path components.
    fn heap_size(&self) -> usize {
        self.parts.heap_size()
    }
}
756
757impl ColumnPath {
758    /// Creates new column path from vector of field names.
759    pub fn new(parts: Vec<String>) -> Self {
760        ColumnPath { parts }
761    }
762
763    /// Returns string representation of this column path.
764    /// ```rust
765    /// use parquet::schema::types::ColumnPath;
766    ///
767    /// let path = ColumnPath::new(vec!["a".to_string(), "b".to_string(), "c".to_string()]);
768    /// assert_eq!(&path.string(), "a.b.c");
769    /// ```
770    pub fn string(&self) -> String {
771        self.parts.join(".")
772    }
773
774    /// Appends more components to end of column path.
775    /// ```rust
776    /// use parquet::schema::types::ColumnPath;
777    ///
778    /// let mut path = ColumnPath::new(vec!["a".to_string(), "b".to_string(), "c"
779    /// .to_string()]);
780    /// assert_eq!(&path.string(), "a.b.c");
781    ///
782    /// path.append(vec!["d".to_string(), "e".to_string()]);
783    /// assert_eq!(&path.string(), "a.b.c.d.e");
784    /// ```
785    pub fn append(&mut self, mut tail: Vec<String>) {
786        self.parts.append(&mut tail);
787    }
788
789    /// Returns a slice of path components.
790    pub fn parts(&self) -> &[String] {
791        &self.parts
792    }
793}
794
795impl fmt::Display for ColumnPath {
796    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
797        write!(f, "{:?}", self.string())
798    }
799}
800
801impl From<Vec<String>> for ColumnPath {
802    fn from(parts: Vec<String>) -> Self {
803        ColumnPath { parts }
804    }
805}
806
807impl From<&str> for ColumnPath {
808    fn from(single_path: &str) -> Self {
809        let s = String::from(single_path);
810        ColumnPath::from(s)
811    }
812}
813
814impl From<String> for ColumnPath {
815    fn from(single_path: String) -> Self {
816        let v = vec![single_path];
817        ColumnPath { parts: v }
818    }
819}
820
821impl AsRef<[String]> for ColumnPath {
822    fn as_ref(&self) -> &[String] {
823        &self.parts
824    }
825}
826
/// Descriptor for a leaf-level primitive column.
///
/// Holds the column's primitive type and path, and also includes the maximum
/// definition and repetition levels required to re-assemble nested data.
#[derive(Debug, PartialEq)]
pub struct ColumnDescriptor {
    /// The "leaf" primitive type of this column
    primitive_type: TypePtr,

    /// The maximum definition level for this column
    max_def_level: i16,

    /// The maximum repetition level for this column
    max_rep_level: i16,

    /// The path of this column. For instance, "a.b.c.d".
    path: ColumnPath,
}
845
846impl HeapSize for ColumnDescriptor {
847    fn heap_size(&self) -> usize {
848        // Don't include the heap size of primitive_type, this is already
849        // accounted for via SchemaDescriptor::schema
850        self.path.heap_size()
851    }
852}
853
854impl ColumnDescriptor {
855    /// Creates new descriptor for leaf-level column.
856    pub fn new(
857        primitive_type: TypePtr,
858        max_def_level: i16,
859        max_rep_level: i16,
860        path: ColumnPath,
861    ) -> Self {
862        Self {
863            primitive_type,
864            max_def_level,
865            max_rep_level,
866            path,
867        }
868    }
869
870    /// Returns maximum definition level for this column.
871    #[inline]
872    pub fn max_def_level(&self) -> i16 {
873        self.max_def_level
874    }
875
876    /// Returns maximum repetition level for this column.
877    #[inline]
878    pub fn max_rep_level(&self) -> i16 {
879        self.max_rep_level
880    }
881
882    /// Returns [`ColumnPath`] for this column.
883    pub fn path(&self) -> &ColumnPath {
884        &self.path
885    }
886
887    /// Returns self type [`Type`] for this leaf column.
888    pub fn self_type(&self) -> &Type {
889        self.primitive_type.as_ref()
890    }
891
892    /// Returns self type [`TypePtr`]  for this leaf
893    /// column.
894    pub fn self_type_ptr(&self) -> TypePtr {
895        self.primitive_type.clone()
896    }
897
898    /// Returns column name.
899    pub fn name(&self) -> &str {
900        self.primitive_type.name()
901    }
902
903    /// Returns [`ConvertedType`] for this column.
904    pub fn converted_type(&self) -> ConvertedType {
905        self.primitive_type.get_basic_info().converted_type()
906    }
907
908    /// Returns [`LogicalType`] for this column.
909    pub fn logical_type(&self) -> Option<LogicalType> {
910        self.primitive_type.get_basic_info().logical_type()
911    }
912
913    /// Returns physical type for this column.
914    /// Note that it will panic if called on a non-primitive type.
915    pub fn physical_type(&self) -> PhysicalType {
916        match self.primitive_type.as_ref() {
917            Type::PrimitiveType { physical_type, .. } => *physical_type,
918            _ => panic!("Expected primitive type!"),
919        }
920    }
921
922    /// Returns type length for this column.
923    /// Note that it will panic if called on a non-primitive type.
924    pub fn type_length(&self) -> i32 {
925        match self.primitive_type.as_ref() {
926            Type::PrimitiveType { type_length, .. } => *type_length,
927            _ => panic!("Expected primitive type!"),
928        }
929    }
930
931    /// Returns type precision for this column.
932    /// Note that it will panic if called on a non-primitive type.
933    pub fn type_precision(&self) -> i32 {
934        match self.primitive_type.as_ref() {
935            Type::PrimitiveType { precision, .. } => *precision,
936            _ => panic!("Expected primitive type!"),
937        }
938    }
939
940    /// Returns type scale for this column.
941    /// Note that it will panic if called on a non-primitive type.
942    pub fn type_scale(&self) -> i32 {
943        match self.primitive_type.as_ref() {
944            Type::PrimitiveType { scale, .. } => *scale,
945            _ => panic!("Expected primitive type!"),
946        }
947    }
948
949    /// Returns the sort order for this column
950    pub fn sort_order(&self) -> SortOrder {
951        ColumnOrder::get_sort_order(
952            self.logical_type(),
953            self.converted_type(),
954            self.physical_type(),
955        )
956    }
957}
958
/// Schema of a Parquet file.
///
/// Encapsulates the file's schema ([`Type`]) and [`ColumnDescriptor`]s for
/// each primitive (leaf) column.
///
/// The leaf descriptors and the leaf-to-root mapping are derived from the
/// schema tree when the descriptor is created with [`SchemaDescriptor::new`].
///
/// # Example
/// ```
/// # use std::sync::Arc;
/// use parquet::schema::types::{SchemaDescriptor, Type};
/// use parquet::basic; // note there are two `Type`s that are different
/// // Schema for a table with two columns: "a" (int64) and "b" (int32, stored as a date)
/// let descriptor = SchemaDescriptor::new(
///   Arc::new(
///     Type::group_type_builder("my_schema")
///       .with_fields(vec![
///         Arc::new(
///          Type::primitive_type_builder("a", basic::Type::INT64)
///           .build().unwrap()
///         ),
///         Arc::new(
///          Type::primitive_type_builder("b", basic::Type::INT32)
///           .with_converted_type(basic::ConvertedType::DATE)
///           .with_logical_type(Some(basic::LogicalType::Date))
///           .build().unwrap()
///         ),
///      ])
///      .build().unwrap()
///   )
/// );
/// ```
#[derive(PartialEq, Clone)]
pub struct SchemaDescriptor {
    /// The top-level logical schema (the "message" type).
    ///
    /// This must be a [`Type::GroupType`] where each field is a root
    /// column type in the schema.
    schema: TypePtr,

    /// The descriptors for the physical type of each leaf column in this schema
    ///
    /// Constructed from `schema` in DFS order.
    leaves: Vec<ColumnDescPtr>,

    /// Mapping from a leaf column's index to the root column index that it
    /// comes from.
    ///
    /// Entry `i` is the position, among the fields of `schema`, of the
    /// top-level column that `leaves[i]` descends from, so this vector has one
    /// entry per leaf.
    ///
    /// For instance: the leaf `a.b.c.d` would have a link back to `a`:
    /// ```text
    /// -- a  <-----+
    /// -- -- b     |
    /// -- -- -- c  |
    /// -- -- -- -- d
    /// ```
    leaf_to_base: Vec<usize>,
}
1014
1015impl fmt::Debug for SchemaDescriptor {
1016    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1017        // Skip leaves and leaf_to_base as they only a cache information already found in `schema`
1018        f.debug_struct("SchemaDescriptor")
1019            .field("schema", &self.schema)
1020            .finish()
1021    }
1022}
1023
1024// Need to implement HeapSize in this module as the fields are private
1025impl HeapSize for SchemaDescriptor {
1026    fn heap_size(&self) -> usize {
1027        self.schema.heap_size() + self.leaves.heap_size() + self.leaf_to_base.heap_size()
1028    }
1029}
1030
1031impl SchemaDescriptor {
1032    /// Creates new schema descriptor from Parquet schema.
1033    pub fn new(tp: TypePtr) -> Self {
1034        const INIT_SCHEMA_DEPTH: usize = 16;
1035        assert!(tp.is_group(), "SchemaDescriptor should take a GroupType");
1036        // unwrap should be safe since we just asserted tp is a group
1037        let n_leaves = num_leaves(&tp).unwrap();
1038        let mut leaves = Vec::with_capacity(n_leaves);
1039        let mut leaf_to_base = Vec::with_capacity(n_leaves);
1040        let mut path = Vec::with_capacity(INIT_SCHEMA_DEPTH);
1041        for (root_idx, f) in tp.get_fields().iter().enumerate() {
1042            path.clear();
1043            build_tree(f, root_idx, 0, 0, &mut leaves, &mut leaf_to_base, &mut path);
1044        }
1045
1046        Self {
1047            schema: tp,
1048            leaves,
1049            leaf_to_base,
1050        }
1051    }
1052
1053    /// Returns [`ColumnDescriptor`] for a field position.
1054    pub fn column(&self, i: usize) -> ColumnDescPtr {
1055        assert!(
1056            i < self.leaves.len(),
1057            "Index out of bound: {} not in [0, {})",
1058            i,
1059            self.leaves.len()
1060        );
1061        self.leaves[i].clone()
1062    }
1063
1064    /// Returns slice of [`ColumnDescriptor`].
1065    pub fn columns(&self) -> &[ColumnDescPtr] {
1066        &self.leaves
1067    }
1068
1069    /// Returns number of leaf-level columns.
1070    pub fn num_columns(&self) -> usize {
1071        self.leaves.len()
1072    }
1073
1074    /// Returns column root [`Type`] for a leaf position.
1075    pub fn get_column_root(&self, i: usize) -> &Type {
1076        let result = self.column_root_of(i);
1077        result.as_ref()
1078    }
1079
1080    /// Returns column root [`Type`] pointer for a leaf position.
1081    pub fn get_column_root_ptr(&self, i: usize) -> TypePtr {
1082        let result = self.column_root_of(i);
1083        result.clone()
1084    }
1085
1086    /// Returns the index of the root column for a field position
1087    pub fn get_column_root_idx(&self, leaf: usize) -> usize {
1088        assert!(
1089            leaf < self.leaves.len(),
1090            "Index out of bound: {} not in [0, {})",
1091            leaf,
1092            self.leaves.len()
1093        );
1094
1095        *self
1096            .leaf_to_base
1097            .get(leaf)
1098            .unwrap_or_else(|| panic!("Expected a value for index {leaf} but found None"))
1099    }
1100
1101    fn column_root_of(&self, i: usize) -> &TypePtr {
1102        &self.schema.get_fields()[self.get_column_root_idx(i)]
1103    }
1104
1105    /// Returns schema as [`Type`].
1106    pub fn root_schema(&self) -> &Type {
1107        self.schema.as_ref()
1108    }
1109
1110    /// Returns schema as [`TypePtr`] for cheap cloning.
1111    pub fn root_schema_ptr(&self) -> TypePtr {
1112        self.schema.clone()
1113    }
1114
1115    /// Returns schema name.
1116    pub fn name(&self) -> &str {
1117        self.schema.name()
1118    }
1119}
1120
1121// walk tree and count nodes
1122pub(crate) fn num_nodes(tp: &TypePtr) -> Result<usize> {
1123    if !tp.is_group() {
1124        return Err(general_err!("Root schema must be Group type"));
1125    }
1126    let mut n_nodes = 1usize; // count root
1127    for f in tp.get_fields().iter() {
1128        count_nodes(f, &mut n_nodes);
1129    }
1130    Ok(n_nodes)
1131}
1132
1133pub(crate) fn count_nodes(tp: &TypePtr, n_nodes: &mut usize) {
1134    *n_nodes += 1;
1135    if let Type::GroupType { fields, .. } = tp.as_ref() {
1136        for f in fields {
1137            count_nodes(f, n_nodes);
1138        }
1139    }
1140}
1141
1142// do a quick walk of the tree to get proper sizing for SchemaDescriptor arrays
1143fn num_leaves(tp: &TypePtr) -> Result<usize> {
1144    if !tp.is_group() {
1145        return Err(general_err!("Root schema must be Group type"));
1146    }
1147    let mut n_leaves = 0usize;
1148    for f in tp.get_fields().iter() {
1149        count_leaves(f, &mut n_leaves);
1150    }
1151    Ok(n_leaves)
1152}
1153
1154fn count_leaves(tp: &TypePtr, n_leaves: &mut usize) {
1155    match tp.as_ref() {
1156        Type::PrimitiveType { .. } => *n_leaves += 1,
1157        Type::GroupType { fields, .. } => {
1158            for f in fields {
1159                count_leaves(f, n_leaves);
1160            }
1161        }
1162    }
1163}
1164
1165fn build_tree<'a>(
1166    tp: &'a TypePtr,
1167    root_idx: usize,
1168    mut max_rep_level: i16,
1169    mut max_def_level: i16,
1170    leaves: &mut Vec<ColumnDescPtr>,
1171    leaf_to_base: &mut Vec<usize>,
1172    path_so_far: &mut Vec<&'a str>,
1173) {
1174    assert!(tp.get_basic_info().has_repetition());
1175
1176    path_so_far.push(tp.name());
1177    match tp.get_basic_info().repetition() {
1178        Repetition::OPTIONAL => {
1179            max_def_level += 1;
1180        }
1181        Repetition::REPEATED => {
1182            max_def_level += 1;
1183            max_rep_level += 1;
1184        }
1185        _ => {}
1186    }
1187
1188    match tp.as_ref() {
1189        Type::PrimitiveType { .. } => {
1190            let mut path: Vec<String> = vec![];
1191            path.extend(path_so_far.iter().copied().map(String::from));
1192            leaves.push(Arc::new(ColumnDescriptor::new(
1193                tp.clone(),
1194                max_def_level,
1195                max_rep_level,
1196                ColumnPath::new(path),
1197            )));
1198            leaf_to_base.push(root_idx);
1199        }
1200        Type::GroupType { fields, .. } => {
1201            for f in fields {
1202                build_tree(
1203                    f,
1204                    root_idx,
1205                    max_rep_level,
1206                    max_def_level,
1207                    leaves,
1208                    leaf_to_base,
1209                    path_so_far,
1210                );
1211                path_so_far.pop();
1212            }
1213        }
1214    }
1215}
1216
1217/// Checks if the logical type is valid.
1218fn check_logical_type(logical_type: &Option<LogicalType>) -> Result<()> {
1219    if let Some(LogicalType::Integer { bit_width, .. }) = *logical_type {
1220        if bit_width != 8 && bit_width != 16 && bit_width != 32 && bit_width != 64 {
1221            return Err(general_err!(
1222                "Bit width must be 8, 16, 32, or 64 for Integer logical type"
1223            ));
1224        }
1225    }
1226    Ok(())
1227}
1228
1229// convert thrift decoded array of `SchemaElement` into this crate's representation of
1230// parquet types. this function consumes `elements`.
1231pub(crate) fn parquet_schema_from_array<'a>(elements: Vec<SchemaElement<'a>>) -> Result<TypePtr> {
1232    let mut index = 0;
1233    let num_elements = elements.len();
1234    let mut schema_nodes = Vec::with_capacity(1); // there should only be one element when done
1235
1236    // turn into iterator so we can take ownership of elements of the vector
1237    let mut elements = elements.into_iter();
1238
1239    while index < num_elements {
1240        let t = schema_from_array_helper(&mut elements, num_elements, index)?;
1241        index = t.0;
1242        schema_nodes.push(t.1);
1243    }
1244    if schema_nodes.len() != 1 {
1245        return Err(general_err!(
1246            "Expected exactly one root node, but found {}",
1247            schema_nodes.len()
1248        ));
1249    }
1250
1251    if !schema_nodes[0].is_group() {
1252        return Err(general_err!("Expected root node to be a group type"));
1253    }
1254
1255    Ok(schema_nodes.remove(0))
1256}
1257
1258// recursive helper function for schema conversion
1259fn schema_from_array_helper<'a>(
1260    elements: &mut IntoIter<SchemaElement<'a>>,
1261    num_elements: usize,
1262    index: usize,
1263) -> Result<(usize, TypePtr)> {
1264    // Whether or not the current node is root (message type).
1265    // There is only one message type node in the schema tree.
1266    let is_root_node = index == 0;
1267
1268    if index >= num_elements {
1269        return Err(general_err!(
1270            "Index out of bound, index = {}, len = {}",
1271            index,
1272            num_elements
1273        ));
1274    }
1275    let element = elements.next().expect("schema vector should not be empty");
1276
1277    // Check for empty schema
1278    if let (true, None | Some(0)) = (is_root_node, element.num_children) {
1279        let builder = Type::group_type_builder(element.name);
1280        return Ok((index + 1, Arc::new(builder.build().unwrap())));
1281    }
1282
1283    let converted_type = element.converted_type.unwrap_or(ConvertedType::NONE);
1284
1285    // LogicalType is prefered to ConvertedType, but both may be present.
1286    let logical_type = element.logical_type;
1287
1288    check_logical_type(&logical_type)?;
1289
1290    let field_id = element.field_id;
1291    match element.num_children {
1292        // From parquet-format:
1293        //   The children count is used to construct the nested relationship.
1294        //   This field is not set when the element is a primitive type
1295        // Sometimes parquet-cpp sets num_children field to 0 for primitive types, so we
1296        // have to handle this case too.
1297        None | Some(0) => {
1298            // primitive type
1299            if element.repetition_type.is_none() {
1300                return Err(general_err!(
1301                    "Repetition level must be defined for a primitive type"
1302                ));
1303            }
1304            let repetition = element.repetition_type.unwrap();
1305            if let Some(physical_type) = element.r#type {
1306                let length = element.type_length.unwrap_or(-1);
1307                let scale = element.scale.unwrap_or(-1);
1308                let precision = element.precision.unwrap_or(-1);
1309                let name = element.name;
1310                let builder = Type::primitive_type_builder(name, physical_type)
1311                    .with_repetition(repetition)
1312                    .with_converted_type(converted_type)
1313                    .with_logical_type(logical_type)
1314                    .with_length(length)
1315                    .with_precision(precision)
1316                    .with_scale(scale)
1317                    .with_id(field_id);
1318                Ok((index + 1, Arc::new(builder.build()?)))
1319            } else {
1320                let mut builder = Type::group_type_builder(element.name)
1321                    .with_converted_type(converted_type)
1322                    .with_logical_type(logical_type)
1323                    .with_id(field_id);
1324                if !is_root_node {
1325                    // Sometimes parquet-cpp and parquet-mr set repetition level REQUIRED or
1326                    // REPEATED for root node.
1327                    //
1328                    // We only set repetition for group types that are not top-level message
1329                    // type. According to parquet-format:
1330                    //   Root of the schema does not have a repetition_type.
1331                    //   All other types must have one.
1332                    builder = builder.with_repetition(repetition);
1333                }
1334                Ok((index + 1, Arc::new(builder.build().unwrap())))
1335            }
1336        }
1337        Some(n) => {
1338            let repetition = element.repetition_type;
1339
1340            let mut fields = Vec::with_capacity(n as usize);
1341            let mut next_index = index + 1;
1342            for _ in 0..n {
1343                let child_result = schema_from_array_helper(elements, num_elements, next_index)?;
1344                next_index = child_result.0;
1345                fields.push(child_result.1);
1346            }
1347
1348            let mut builder = Type::group_type_builder(element.name)
1349                .with_converted_type(converted_type)
1350                .with_logical_type(logical_type)
1351                .with_fields(fields)
1352                .with_id(field_id);
1353
1354            // Sometimes parquet-cpp and parquet-mr set repetition level REQUIRED or
1355            // REPEATED for root node.
1356            //
1357            // We only set repetition for group types that are not top-level message
1358            // type. According to parquet-format:
1359            //   Root of the schema does not have a repetition_type.
1360            //   All other types must have one.
1361            if !is_root_node {
1362                let Some(rep) = repetition else {
1363                    return Err(general_err!(
1364                        "Repetition level must be defined for non-root types"
1365                    ));
1366                };
1367                builder = builder.with_repetition(rep);
1368            }
1369            Ok((next_index, Arc::new(builder.build()?)))
1370        }
1371    }
1372}
1373
1374#[cfg(test)]
1375mod tests {
1376    use super::*;
1377
1378    use crate::{
1379        file::metadata::thrift::tests::{buf_to_schema_list, roundtrip_schema, schema_to_buf},
1380        schema::parser::parse_message_type,
1381    };
1382
1383    // TODO: add tests for v2 types
1384
1385    #[test]
1386    fn test_primitive_type() {
1387        let mut result = Type::primitive_type_builder("foo", PhysicalType::INT32)
1388            .with_logical_type(Some(LogicalType::Integer {
1389                bit_width: 32,
1390                is_signed: true,
1391            }))
1392            .with_id(Some(0))
1393            .build();
1394        assert!(result.is_ok());
1395
1396        if let Ok(tp) = result {
1397            assert!(tp.is_primitive());
1398            assert!(!tp.is_group());
1399            let basic_info = tp.get_basic_info();
1400            assert_eq!(basic_info.repetition(), Repetition::OPTIONAL);
1401            assert_eq!(
1402                basic_info.logical_type(),
1403                Some(LogicalType::Integer {
1404                    bit_width: 32,
1405                    is_signed: true
1406                })
1407            );
1408            assert_eq!(basic_info.converted_type(), ConvertedType::INT_32);
1409            assert_eq!(basic_info.id(), 0);
1410            match tp {
1411                Type::PrimitiveType { physical_type, .. } => {
1412                    assert_eq!(physical_type, PhysicalType::INT32);
1413                }
1414                _ => panic!(),
1415            }
1416        }
1417
1418        // Test illegal inputs with logical type
1419        result = Type::primitive_type_builder("foo", PhysicalType::INT64)
1420            .with_repetition(Repetition::REPEATED)
1421            .with_logical_type(Some(LogicalType::Integer {
1422                is_signed: true,
1423                bit_width: 8,
1424            }))
1425            .build();
1426        assert!(result.is_err());
1427        if let Err(e) = result {
1428            assert_eq!(
1429                format!("{e}"),
1430                "Parquet error: Cannot annotate Integer { bit_width: 8, is_signed: true } from INT64 for field 'foo'"
1431            );
1432        }
1433
1434        // Test illegal inputs with converted type
1435        result = Type::primitive_type_builder("foo", PhysicalType::INT64)
1436            .with_repetition(Repetition::REPEATED)
1437            .with_converted_type(ConvertedType::BSON)
1438            .build();
1439        assert!(result.is_err());
1440        if let Err(e) = result {
1441            assert_eq!(
1442                format!("{e}"),
1443                "Parquet error: BSON cannot annotate field 'foo' because it is not a BYTE_ARRAY field"
1444            );
1445        }
1446
1447        result = Type::primitive_type_builder("foo", PhysicalType::INT96)
1448            .with_repetition(Repetition::REQUIRED)
1449            .with_converted_type(ConvertedType::DECIMAL)
1450            .with_precision(-1)
1451            .with_scale(-1)
1452            .build();
1453        assert!(result.is_err());
1454        if let Err(e) = result {
1455            assert_eq!(
1456                format!("{e}"),
1457                "Parquet error: DECIMAL can only annotate INT32, INT64, BYTE_ARRAY and FIXED_LEN_BYTE_ARRAY"
1458            );
1459        }
1460
1461        result = Type::primitive_type_builder("foo", PhysicalType::BYTE_ARRAY)
1462            .with_repetition(Repetition::REQUIRED)
1463            .with_logical_type(Some(LogicalType::Decimal {
1464                scale: 32,
1465                precision: 12,
1466            }))
1467            .with_precision(-1)
1468            .with_scale(-1)
1469            .build();
1470        assert!(result.is_err());
1471        if let Err(e) = result {
1472            assert_eq!(
1473                format!("{e}"),
1474                "Parquet error: DECIMAL logical type scale 32 must match self.scale -1 for field 'foo'"
1475            );
1476        }
1477
1478        result = Type::primitive_type_builder("foo", PhysicalType::BYTE_ARRAY)
1479            .with_repetition(Repetition::REQUIRED)
1480            .with_converted_type(ConvertedType::DECIMAL)
1481            .with_precision(-1)
1482            .with_scale(-1)
1483            .build();
1484        assert!(result.is_err());
1485        if let Err(e) = result {
1486            assert_eq!(
1487                format!("{e}"),
1488                "Parquet error: Invalid DECIMAL precision: -1"
1489            );
1490        }
1491
1492        result = Type::primitive_type_builder("foo", PhysicalType::BYTE_ARRAY)
1493            .with_repetition(Repetition::REQUIRED)
1494            .with_converted_type(ConvertedType::DECIMAL)
1495            .with_precision(0)
1496            .with_scale(-1)
1497            .build();
1498        assert!(result.is_err());
1499        if let Err(e) = result {
1500            assert_eq!(
1501                format!("{e}"),
1502                "Parquet error: Invalid DECIMAL precision: 0"
1503            );
1504        }
1505
1506        result = Type::primitive_type_builder("foo", PhysicalType::BYTE_ARRAY)
1507            .with_repetition(Repetition::REQUIRED)
1508            .with_converted_type(ConvertedType::DECIMAL)
1509            .with_precision(1)
1510            .with_scale(-1)
1511            .build();
1512        assert!(result.is_err());
1513        if let Err(e) = result {
1514            assert_eq!(format!("{e}"), "Parquet error: Invalid DECIMAL scale: -1");
1515        }
1516
1517        result = Type::primitive_type_builder("foo", PhysicalType::BYTE_ARRAY)
1518            .with_repetition(Repetition::REQUIRED)
1519            .with_converted_type(ConvertedType::DECIMAL)
1520            .with_precision(1)
1521            .with_scale(2)
1522            .build();
1523        assert!(result.is_err());
1524        if let Err(e) = result {
1525            assert_eq!(
1526                format!("{e}"),
1527                "Parquet error: Invalid DECIMAL: scale (2) cannot be greater than precision (1)"
1528            );
1529        }
1530
1531        // It is OK if precision == scale
1532        result = Type::primitive_type_builder("foo", PhysicalType::BYTE_ARRAY)
1533            .with_repetition(Repetition::REQUIRED)
1534            .with_converted_type(ConvertedType::DECIMAL)
1535            .with_precision(1)
1536            .with_scale(1)
1537            .build();
1538        assert!(result.is_ok());
1539
1540        result = Type::primitive_type_builder("foo", PhysicalType::INT32)
1541            .with_repetition(Repetition::REQUIRED)
1542            .with_converted_type(ConvertedType::DECIMAL)
1543            .with_precision(18)
1544            .with_scale(2)
1545            .build();
1546        assert!(result.is_err());
1547        if let Err(e) = result {
1548            assert_eq!(
1549                format!("{e}"),
1550                "Parquet error: Cannot represent INT32 as DECIMAL with precision 18"
1551            );
1552        }
1553
1554        result = Type::primitive_type_builder("foo", PhysicalType::INT64)
1555            .with_repetition(Repetition::REQUIRED)
1556            .with_converted_type(ConvertedType::DECIMAL)
1557            .with_precision(32)
1558            .with_scale(2)
1559            .build();
1560        assert!(result.is_err());
1561        if let Err(e) = result {
1562            assert_eq!(
1563                format!("{e}"),
1564                "Parquet error: Cannot represent INT64 as DECIMAL with precision 32"
1565            );
1566        }
1567
1568        result = Type::primitive_type_builder("foo", PhysicalType::FIXED_LEN_BYTE_ARRAY)
1569            .with_repetition(Repetition::REQUIRED)
1570            .with_converted_type(ConvertedType::DECIMAL)
1571            .with_length(5)
1572            .with_precision(12)
1573            .with_scale(2)
1574            .build();
1575        assert!(result.is_err());
1576        if let Err(e) = result {
1577            assert_eq!(
1578                format!("{e}"),
1579                "Parquet error: Cannot represent FIXED_LEN_BYTE_ARRAY as DECIMAL with length 5 and precision 12. The max precision can only be 11"
1580            );
1581        }
1582
1583        result = Type::primitive_type_builder("foo", PhysicalType::INT64)
1584            .with_repetition(Repetition::REQUIRED)
1585            .with_converted_type(ConvertedType::UINT_8)
1586            .build();
1587        assert!(result.is_err());
1588        if let Err(e) = result {
1589            assert_eq!(
1590                format!("{e}"),
1591                "Parquet error: UINT_8 cannot annotate field 'foo' because it is not a INT32 field"
1592            );
1593        }
1594
1595        result = Type::primitive_type_builder("foo", PhysicalType::INT32)
1596            .with_repetition(Repetition::REQUIRED)
1597            .with_converted_type(ConvertedType::TIME_MICROS)
1598            .build();
1599        assert!(result.is_err());
1600        if let Err(e) = result {
1601            assert_eq!(
1602                format!("{e}"),
1603                "Parquet error: TIME_MICROS cannot annotate field 'foo' because it is not a INT64 field"
1604            );
1605        }
1606
1607        result = Type::primitive_type_builder("foo", PhysicalType::BYTE_ARRAY)
1608            .with_repetition(Repetition::REQUIRED)
1609            .with_converted_type(ConvertedType::INTERVAL)
1610            .build();
1611        assert!(result.is_err());
1612        if let Err(e) = result {
1613            assert_eq!(
1614                format!("{e}"),
1615                "Parquet error: INTERVAL cannot annotate field 'foo' because it is not a FIXED_LEN_BYTE_ARRAY(12) field"
1616            );
1617        }
1618
1619        result = Type::primitive_type_builder("foo", PhysicalType::FIXED_LEN_BYTE_ARRAY)
1620            .with_repetition(Repetition::REQUIRED)
1621            .with_converted_type(ConvertedType::INTERVAL)
1622            .with_length(1)
1623            .build();
1624        assert!(result.is_err());
1625        if let Err(e) = result {
1626            assert_eq!(
1627                format!("{e}"),
1628                "Parquet error: INTERVAL cannot annotate field 'foo' because it is not a FIXED_LEN_BYTE_ARRAY(12) field"
1629            );
1630        }
1631
1632        result = Type::primitive_type_builder("foo", PhysicalType::INT32)
1633            .with_repetition(Repetition::REQUIRED)
1634            .with_converted_type(ConvertedType::ENUM)
1635            .build();
1636        assert!(result.is_err());
1637        if let Err(e) = result {
1638            assert_eq!(
1639                format!("{e}"),
1640                "Parquet error: ENUM cannot annotate field 'foo' because it is not a BYTE_ARRAY field"
1641            );
1642        }
1643
1644        result = Type::primitive_type_builder("foo", PhysicalType::INT32)
1645            .with_repetition(Repetition::REQUIRED)
1646            .with_converted_type(ConvertedType::MAP)
1647            .build();
1648        assert!(result.is_err());
1649        if let Err(e) = result {
1650            assert_eq!(
1651                format!("{e}"),
1652                "Parquet error: MAP cannot be applied to primitive field 'foo'"
1653            );
1654        }
1655
1656        result = Type::primitive_type_builder("foo", PhysicalType::FIXED_LEN_BYTE_ARRAY)
1657            .with_repetition(Repetition::REQUIRED)
1658            .with_converted_type(ConvertedType::DECIMAL)
1659            .with_length(-1)
1660            .build();
1661        assert!(result.is_err());
1662        if let Err(e) = result {
1663            assert_eq!(
1664                format!("{e}"),
1665                "Parquet error: Invalid FIXED_LEN_BYTE_ARRAY length: -1 for field 'foo'"
1666            );
1667        }
1668
1669        result = Type::primitive_type_builder("foo", PhysicalType::FIXED_LEN_BYTE_ARRAY)
1670            .with_repetition(Repetition::REQUIRED)
1671            .with_logical_type(Some(LogicalType::Float16))
1672            .with_length(2)
1673            .build();
1674        assert!(result.is_ok());
1675
1676        // Can't be other than FIXED_LEN_BYTE_ARRAY for physical type
1677        result = Type::primitive_type_builder("foo", PhysicalType::FLOAT)
1678            .with_repetition(Repetition::REQUIRED)
1679            .with_logical_type(Some(LogicalType::Float16))
1680            .with_length(2)
1681            .build();
1682        assert!(result.is_err());
1683        if let Err(e) = result {
1684            assert_eq!(
1685                format!("{e}"),
1686                "Parquet error: Cannot annotate Float16 from FLOAT for field 'foo'"
1687            );
1688        }
1689
1690        // Must have length 2
1691        result = Type::primitive_type_builder("foo", PhysicalType::FIXED_LEN_BYTE_ARRAY)
1692            .with_repetition(Repetition::REQUIRED)
1693            .with_logical_type(Some(LogicalType::Float16))
1694            .with_length(4)
1695            .build();
1696        assert!(result.is_err());
1697        if let Err(e) = result {
1698            assert_eq!(
1699                format!("{e}"),
1700                "Parquet error: FLOAT16 cannot annotate field 'foo' because it is not a FIXED_LEN_BYTE_ARRAY(2) field"
1701            );
1702        }
1703
1704        // Must have length 16
1705        result = Type::primitive_type_builder("foo", PhysicalType::FIXED_LEN_BYTE_ARRAY)
1706            .with_repetition(Repetition::REQUIRED)
1707            .with_logical_type(Some(LogicalType::Uuid))
1708            .with_length(15)
1709            .build();
1710        assert!(result.is_err());
1711        if let Err(e) = result {
1712            assert_eq!(
1713                format!("{e}"),
1714                "Parquet error: UUID cannot annotate field 'foo' because it is not a FIXED_LEN_BYTE_ARRAY(16) field"
1715            );
1716        }
1717    }
1718
1719    #[test]
1720    fn test_group_type() {
1721        let f1 = Type::primitive_type_builder("f1", PhysicalType::INT32)
1722            .with_converted_type(ConvertedType::INT_32)
1723            .with_id(Some(0))
1724            .build();
1725        assert!(f1.is_ok());
1726        let f2 = Type::primitive_type_builder("f2", PhysicalType::BYTE_ARRAY)
1727            .with_converted_type(ConvertedType::UTF8)
1728            .with_id(Some(1))
1729            .build();
1730        assert!(f2.is_ok());
1731
1732        let fields = vec![Arc::new(f1.unwrap()), Arc::new(f2.unwrap())];
1733
1734        let result = Type::group_type_builder("foo")
1735            .with_repetition(Repetition::REPEATED)
1736            .with_logical_type(Some(LogicalType::List))
1737            .with_fields(fields)
1738            .with_id(Some(1))
1739            .build();
1740        assert!(result.is_ok());
1741
1742        let tp = result.unwrap();
1743        let basic_info = tp.get_basic_info();
1744        assert!(tp.is_group());
1745        assert!(!tp.is_primitive());
1746        assert_eq!(basic_info.repetition(), Repetition::REPEATED);
1747        assert_eq!(basic_info.logical_type(), Some(LogicalType::List));
1748        assert_eq!(basic_info.converted_type(), ConvertedType::LIST);
1749        assert_eq!(basic_info.id(), 1);
1750        assert_eq!(tp.get_fields().len(), 2);
1751        assert_eq!(tp.get_fields()[0].name(), "f1");
1752        assert_eq!(tp.get_fields()[1].name(), "f2");
1753    }
1754
1755    #[test]
1756    fn test_column_descriptor() {
1757        let result = test_column_descriptor_helper();
1758        assert!(
1759            result.is_ok(),
1760            "Expected result to be OK but got err:\n {}",
1761            result.unwrap_err()
1762        );
1763    }
1764
1765    fn test_column_descriptor_helper() -> Result<()> {
1766        let tp = Type::primitive_type_builder("name", PhysicalType::BYTE_ARRAY)
1767            .with_converted_type(ConvertedType::UTF8)
1768            .build()?;
1769
1770        let descr = ColumnDescriptor::new(Arc::new(tp), 4, 1, ColumnPath::from("name"));
1771
1772        assert_eq!(descr.path(), &ColumnPath::from("name"));
1773        assert_eq!(descr.converted_type(), ConvertedType::UTF8);
1774        assert_eq!(descr.physical_type(), PhysicalType::BYTE_ARRAY);
1775        assert_eq!(descr.max_def_level(), 4);
1776        assert_eq!(descr.max_rep_level(), 1);
1777        assert_eq!(descr.name(), "name");
1778        assert_eq!(descr.type_length(), -1);
1779        assert_eq!(descr.type_precision(), -1);
1780        assert_eq!(descr.type_scale(), -1);
1781
1782        Ok(())
1783    }
1784
1785    #[test]
1786    fn test_schema_descriptor() {
1787        let result = test_schema_descriptor_helper();
1788        assert!(
1789            result.is_ok(),
1790            "Expected result to be OK but got err:\n {}",
1791            result.unwrap_err()
1792        );
1793    }
1794
1795    // A helper fn to avoid handling the results from type creation
1796    fn test_schema_descriptor_helper() -> Result<()> {
1797        let mut fields = vec![];
1798
1799        let inta = Type::primitive_type_builder("a", PhysicalType::INT32)
1800            .with_repetition(Repetition::REQUIRED)
1801            .with_converted_type(ConvertedType::INT_32)
1802            .build()?;
1803        fields.push(Arc::new(inta));
1804        let intb = Type::primitive_type_builder("b", PhysicalType::INT64)
1805            .with_converted_type(ConvertedType::INT_64)
1806            .build()?;
1807        fields.push(Arc::new(intb));
1808        let intc = Type::primitive_type_builder("c", PhysicalType::BYTE_ARRAY)
1809            .with_repetition(Repetition::REPEATED)
1810            .with_converted_type(ConvertedType::UTF8)
1811            .build()?;
1812        fields.push(Arc::new(intc));
1813
1814        // 3-level list encoding
1815        let item1 = Type::primitive_type_builder("item1", PhysicalType::INT64)
1816            .with_repetition(Repetition::REQUIRED)
1817            .with_converted_type(ConvertedType::INT_64)
1818            .build()?;
1819        let item2 = Type::primitive_type_builder("item2", PhysicalType::BOOLEAN).build()?;
1820        let item3 = Type::primitive_type_builder("item3", PhysicalType::INT32)
1821            .with_repetition(Repetition::REPEATED)
1822            .with_converted_type(ConvertedType::INT_32)
1823            .build()?;
1824        let list = Type::group_type_builder("records")
1825            .with_repetition(Repetition::REPEATED)
1826            .with_converted_type(ConvertedType::LIST)
1827            .with_fields(vec![Arc::new(item1), Arc::new(item2), Arc::new(item3)])
1828            .build()?;
1829        let bag = Type::group_type_builder("bag")
1830            .with_repetition(Repetition::OPTIONAL)
1831            .with_fields(vec![Arc::new(list)])
1832            .build()?;
1833        fields.push(Arc::new(bag));
1834
1835        let schema = Type::group_type_builder("schema")
1836            .with_repetition(Repetition::REPEATED)
1837            .with_fields(fields)
1838            .build()?;
1839        let descr = SchemaDescriptor::new(Arc::new(schema));
1840
1841        let nleaves = 6;
1842        assert_eq!(descr.num_columns(), nleaves);
1843
1844        //                             mdef mrep
1845        // required int32 a            0    0
1846        // optional int64 b            1    0
1847        // repeated byte_array c       1    1
1848        // optional group bag          1    0
1849        //   repeated group records    2    1
1850        //     required int64 item1    2    1
1851        //     optional boolean item2  3    1
1852        //     repeated int32 item3    3    2
1853        let ex_max_def_levels = [0, 1, 1, 2, 3, 3];
1854        let ex_max_rep_levels = [0, 0, 1, 1, 1, 2];
1855
1856        for i in 0..nleaves {
1857            let col = descr.column(i);
1858            assert_eq!(col.max_def_level(), ex_max_def_levels[i], "{i}");
1859            assert_eq!(col.max_rep_level(), ex_max_rep_levels[i], "{i}");
1860        }
1861
1862        assert_eq!(descr.column(0).path().string(), "a");
1863        assert_eq!(descr.column(1).path().string(), "b");
1864        assert_eq!(descr.column(2).path().string(), "c");
1865        assert_eq!(descr.column(3).path().string(), "bag.records.item1");
1866        assert_eq!(descr.column(4).path().string(), "bag.records.item2");
1867        assert_eq!(descr.column(5).path().string(), "bag.records.item3");
1868
1869        assert_eq!(descr.get_column_root(0).name(), "a");
1870        assert_eq!(descr.get_column_root(3).name(), "bag");
1871        assert_eq!(descr.get_column_root(4).name(), "bag");
1872        assert_eq!(descr.get_column_root(5).name(), "bag");
1873
1874        Ok(())
1875    }
1876
1877    #[test]
1878    fn test_schema_build_tree_def_rep_levels() {
1879        let message_type = "
1880    message spark_schema {
1881      REQUIRED INT32 a;
1882      OPTIONAL group b {
1883        OPTIONAL INT32 _1;
1884        OPTIONAL INT32 _2;
1885      }
1886      OPTIONAL group c (LIST) {
1887        REPEATED group list {
1888          OPTIONAL INT32 element;
1889        }
1890      }
1891    }
1892    ";
1893        let schema = parse_message_type(message_type).expect("should parse schema");
1894        let descr = SchemaDescriptor::new(Arc::new(schema));
1895        // required int32 a
1896        assert_eq!(descr.column(0).max_def_level(), 0);
1897        assert_eq!(descr.column(0).max_rep_level(), 0);
1898        // optional int32 b._1
1899        assert_eq!(descr.column(1).max_def_level(), 2);
1900        assert_eq!(descr.column(1).max_rep_level(), 0);
1901        // optional int32 b._2
1902        assert_eq!(descr.column(2).max_def_level(), 2);
1903        assert_eq!(descr.column(2).max_rep_level(), 0);
1904        // repeated optional int32 c.list.element
1905        assert_eq!(descr.column(3).max_def_level(), 3);
1906        assert_eq!(descr.column(3).max_rep_level(), 1);
1907    }
1908
1909    #[test]
1910    #[should_panic(expected = "Cannot call get_physical_type() on a non-primitive type")]
1911    fn test_get_physical_type_panic() {
1912        let list = Type::group_type_builder("records")
1913            .with_repetition(Repetition::REPEATED)
1914            .build()
1915            .unwrap();
1916        list.get_physical_type();
1917    }
1918
1919    #[test]
1920    fn test_get_physical_type_primitive() {
1921        let f = Type::primitive_type_builder("f", PhysicalType::INT64)
1922            .build()
1923            .unwrap();
1924        assert_eq!(f.get_physical_type(), PhysicalType::INT64);
1925
1926        let f = Type::primitive_type_builder("f", PhysicalType::BYTE_ARRAY)
1927            .build()
1928            .unwrap();
1929        assert_eq!(f.get_physical_type(), PhysicalType::BYTE_ARRAY);
1930    }
1931
1932    #[test]
1933    fn test_check_contains_primitive_primitive() {
1934        // OK
1935        let f1 = Type::primitive_type_builder("f", PhysicalType::INT32)
1936            .build()
1937            .unwrap();
1938        let f2 = Type::primitive_type_builder("f", PhysicalType::INT32)
1939            .build()
1940            .unwrap();
1941        assert!(f1.check_contains(&f2));
1942
1943        // OK: different logical type does not affect check_contains
1944        let f1 = Type::primitive_type_builder("f", PhysicalType::INT32)
1945            .with_converted_type(ConvertedType::UINT_8)
1946            .build()
1947            .unwrap();
1948        let f2 = Type::primitive_type_builder("f", PhysicalType::INT32)
1949            .with_converted_type(ConvertedType::UINT_16)
1950            .build()
1951            .unwrap();
1952        assert!(f1.check_contains(&f2));
1953
1954        // KO: different name
1955        let f1 = Type::primitive_type_builder("f1", PhysicalType::INT32)
1956            .build()
1957            .unwrap();
1958        let f2 = Type::primitive_type_builder("f2", PhysicalType::INT32)
1959            .build()
1960            .unwrap();
1961        assert!(!f1.check_contains(&f2));
1962
1963        // KO: different type
1964        let f1 = Type::primitive_type_builder("f", PhysicalType::INT32)
1965            .build()
1966            .unwrap();
1967        let f2 = Type::primitive_type_builder("f", PhysicalType::INT64)
1968            .build()
1969            .unwrap();
1970        assert!(!f1.check_contains(&f2));
1971
1972        // KO: different repetition
1973        let f1 = Type::primitive_type_builder("f", PhysicalType::INT32)
1974            .with_repetition(Repetition::REQUIRED)
1975            .build()
1976            .unwrap();
1977        let f2 = Type::primitive_type_builder("f", PhysicalType::INT32)
1978            .with_repetition(Repetition::OPTIONAL)
1979            .build()
1980            .unwrap();
1981        assert!(!f1.check_contains(&f2));
1982    }
1983
1984    // function to create a new group type for testing
1985    fn test_new_group_type(name: &str, repetition: Repetition, types: Vec<Type>) -> Type {
1986        Type::group_type_builder(name)
1987            .with_repetition(repetition)
1988            .with_fields(types.into_iter().map(Arc::new).collect())
1989            .build()
1990            .unwrap()
1991    }
1992
1993    #[test]
1994    fn test_check_contains_group_group() {
1995        // OK: should match okay with empty fields
1996        let f1 = Type::group_type_builder("f").build().unwrap();
1997        let f2 = Type::group_type_builder("f").build().unwrap();
1998        assert!(f1.check_contains(&f2));
1999        assert!(!f1.is_optional());
2000
2001        // OK: fields match
2002        let f1 = test_new_group_type(
2003            "f",
2004            Repetition::REPEATED,
2005            vec![
2006                Type::primitive_type_builder("f1", PhysicalType::INT32)
2007                    .build()
2008                    .unwrap(),
2009                Type::primitive_type_builder("f2", PhysicalType::INT64)
2010                    .build()
2011                    .unwrap(),
2012            ],
2013        );
2014        let f2 = test_new_group_type(
2015            "f",
2016            Repetition::REPEATED,
2017            vec![
2018                Type::primitive_type_builder("f1", PhysicalType::INT32)
2019                    .build()
2020                    .unwrap(),
2021                Type::primitive_type_builder("f2", PhysicalType::INT64)
2022                    .build()
2023                    .unwrap(),
2024            ],
2025        );
2026        assert!(f1.check_contains(&f2));
2027
2028        // OK: subset of fields
2029        let f1 = test_new_group_type(
2030            "f",
2031            Repetition::REPEATED,
2032            vec![
2033                Type::primitive_type_builder("f1", PhysicalType::INT32)
2034                    .build()
2035                    .unwrap(),
2036                Type::primitive_type_builder("f2", PhysicalType::INT64)
2037                    .build()
2038                    .unwrap(),
2039            ],
2040        );
2041        let f2 = test_new_group_type(
2042            "f",
2043            Repetition::REPEATED,
2044            vec![
2045                Type::primitive_type_builder("f2", PhysicalType::INT64)
2046                    .build()
2047                    .unwrap(),
2048            ],
2049        );
2050        assert!(f1.check_contains(&f2));
2051
2052        // KO: different name
2053        let f1 = Type::group_type_builder("f1").build().unwrap();
2054        let f2 = Type::group_type_builder("f2").build().unwrap();
2055        assert!(!f1.check_contains(&f2));
2056
2057        // KO: different repetition
2058        let f1 = Type::group_type_builder("f")
2059            .with_repetition(Repetition::OPTIONAL)
2060            .build()
2061            .unwrap();
2062        let f2 = Type::group_type_builder("f")
2063            .with_repetition(Repetition::REPEATED)
2064            .build()
2065            .unwrap();
2066        assert!(!f1.check_contains(&f2));
2067
2068        // KO: different fields
2069        let f1 = test_new_group_type(
2070            "f",
2071            Repetition::REPEATED,
2072            vec![
2073                Type::primitive_type_builder("f1", PhysicalType::INT32)
2074                    .build()
2075                    .unwrap(),
2076                Type::primitive_type_builder("f2", PhysicalType::INT64)
2077                    .build()
2078                    .unwrap(),
2079            ],
2080        );
2081        let f2 = test_new_group_type(
2082            "f",
2083            Repetition::REPEATED,
2084            vec![
2085                Type::primitive_type_builder("f1", PhysicalType::INT32)
2086                    .build()
2087                    .unwrap(),
2088                Type::primitive_type_builder("f2", PhysicalType::BOOLEAN)
2089                    .build()
2090                    .unwrap(),
2091            ],
2092        );
2093        assert!(!f1.check_contains(&f2));
2094
2095        // KO: different fields
2096        let f1 = test_new_group_type(
2097            "f",
2098            Repetition::REPEATED,
2099            vec![
2100                Type::primitive_type_builder("f1", PhysicalType::INT32)
2101                    .build()
2102                    .unwrap(),
2103                Type::primitive_type_builder("f2", PhysicalType::INT64)
2104                    .build()
2105                    .unwrap(),
2106            ],
2107        );
2108        let f2 = test_new_group_type(
2109            "f",
2110            Repetition::REPEATED,
2111            vec![
2112                Type::primitive_type_builder("f3", PhysicalType::INT32)
2113                    .build()
2114                    .unwrap(),
2115            ],
2116        );
2117        assert!(!f1.check_contains(&f2));
2118    }
2119
2120    #[test]
2121    fn test_check_contains_group_primitive() {
2122        // KO: should not match
2123        let f1 = Type::group_type_builder("f").build().unwrap();
2124        let f2 = Type::primitive_type_builder("f", PhysicalType::INT64)
2125            .build()
2126            .unwrap();
2127        assert!(!f1.check_contains(&f2));
2128        assert!(!f2.check_contains(&f1));
2129
2130        // KO: should not match when primitive field is part of group type
2131        let f1 = test_new_group_type(
2132            "f",
2133            Repetition::REPEATED,
2134            vec![
2135                Type::primitive_type_builder("f1", PhysicalType::INT32)
2136                    .build()
2137                    .unwrap(),
2138            ],
2139        );
2140        let f2 = Type::primitive_type_builder("f1", PhysicalType::INT32)
2141            .build()
2142            .unwrap();
2143        assert!(!f1.check_contains(&f2));
2144        assert!(!f2.check_contains(&f1));
2145
2146        // OK: match nested types
2147        let f1 = test_new_group_type(
2148            "a",
2149            Repetition::REPEATED,
2150            vec![
2151                test_new_group_type(
2152                    "b",
2153                    Repetition::REPEATED,
2154                    vec![
2155                        Type::primitive_type_builder("c", PhysicalType::INT32)
2156                            .build()
2157                            .unwrap(),
2158                    ],
2159                ),
2160                Type::primitive_type_builder("d", PhysicalType::INT64)
2161                    .build()
2162                    .unwrap(),
2163                Type::primitive_type_builder("e", PhysicalType::BOOLEAN)
2164                    .build()
2165                    .unwrap(),
2166            ],
2167        );
2168        let f2 = test_new_group_type(
2169            "a",
2170            Repetition::REPEATED,
2171            vec![test_new_group_type(
2172                "b",
2173                Repetition::REPEATED,
2174                vec![
2175                    Type::primitive_type_builder("c", PhysicalType::INT32)
2176                        .build()
2177                        .unwrap(),
2178                ],
2179            )],
2180        );
2181        assert!(f1.check_contains(&f2)); // should match
2182        assert!(!f2.check_contains(&f1)); // should fail
2183    }
2184
2185    #[test]
2186    fn test_schema_type_thrift_conversion_err() {
2187        let schema = Type::primitive_type_builder("col", PhysicalType::INT32)
2188            .build()
2189            .unwrap();
2190        let schema = Arc::new(schema);
2191        let thrift_schema = schema_to_buf(&schema);
2192        assert!(thrift_schema.is_err());
2193        if let Err(e) = thrift_schema {
2194            assert_eq!(
2195                format!("{e}"),
2196                "Parquet error: Root schema must be Group type"
2197            );
2198        }
2199    }
2200
2201    #[test]
2202    fn test_schema_type_thrift_conversion() {
2203        let message_type = "
2204    message conversions {
2205      REQUIRED INT64 id;
2206      OPTIONAL FIXED_LEN_BYTE_ARRAY (2) f16 (FLOAT16);
2207      OPTIONAL group int_array_Array (LIST) {
2208        REPEATED group list {
2209          OPTIONAL group element (LIST) {
2210            REPEATED group list {
2211              OPTIONAL INT32 element;
2212            }
2213          }
2214        }
2215      }
2216      OPTIONAL group int_map (MAP) {
2217        REPEATED group map (MAP_KEY_VALUE) {
2218          REQUIRED BYTE_ARRAY key (UTF8);
2219          OPTIONAL INT32 value;
2220        }
2221      }
2222      OPTIONAL group int_Map_Array (LIST) {
2223        REPEATED group list {
2224          OPTIONAL group g (MAP) {
2225            REPEATED group map (MAP_KEY_VALUE) {
2226              REQUIRED BYTE_ARRAY key (UTF8);
2227              OPTIONAL group value {
2228                OPTIONAL group H {
2229                  OPTIONAL group i (LIST) {
2230                    REPEATED group list {
2231                      OPTIONAL DOUBLE element;
2232                    }
2233                  }
2234                }
2235              }
2236            }
2237          }
2238        }
2239      }
2240      OPTIONAL group nested_struct {
2241        OPTIONAL INT32 A;
2242        OPTIONAL group b (LIST) {
2243          REPEATED group list {
2244            REQUIRED FIXED_LEN_BYTE_ARRAY (16) element;
2245          }
2246        }
2247      }
2248    }
2249    ";
2250        let expected_schema = parse_message_type(message_type).unwrap();
2251        let result_schema = roundtrip_schema(Arc::new(expected_schema.clone())).unwrap();
2252        assert_eq!(result_schema, Arc::new(expected_schema));
2253    }
2254
2255    #[test]
2256    fn test_schema_type_thrift_conversion_decimal() {
2257        let message_type = "
2258    message decimals {
2259      OPTIONAL INT32 field0;
2260      OPTIONAL INT64 field1 (DECIMAL (18, 2));
2261      OPTIONAL FIXED_LEN_BYTE_ARRAY (16) field2 (DECIMAL (38, 18));
2262      OPTIONAL BYTE_ARRAY field3 (DECIMAL (9));
2263    }
2264    ";
2265        let expected_schema = parse_message_type(message_type).unwrap();
2266        let result_schema = roundtrip_schema(Arc::new(expected_schema.clone())).unwrap();
2267        assert_eq!(result_schema, Arc::new(expected_schema));
2268    }
2269
2270    // Tests schema conversion from thrift, when num_children is set to Some(0) for a
2271    // primitive type.
2272    #[test]
2273    fn test_schema_from_thrift_with_num_children_set() {
2274        // schema definition written by parquet-cpp version 1.3.2-SNAPSHOT
2275        let message_type = "
2276    message schema {
2277      OPTIONAL BYTE_ARRAY id (UTF8);
2278      OPTIONAL BYTE_ARRAY name (UTF8);
2279      OPTIONAL BYTE_ARRAY message (UTF8);
2280      OPTIONAL INT32 type (UINT_8);
2281      OPTIONAL INT64 author_time (TIMESTAMP_MILLIS);
2282      OPTIONAL INT64 __index_level_0__;
2283    }
2284    ";
2285
2286        let expected_schema = Arc::new(parse_message_type(message_type).unwrap());
2287        let mut buf = schema_to_buf(&expected_schema).unwrap();
2288        let mut thrift_schema = buf_to_schema_list(&mut buf).unwrap();
2289
2290        // Change all of None to Some(0)
2291        for elem in &mut thrift_schema[..] {
2292            if elem.num_children.is_none() {
2293                elem.num_children = Some(0);
2294            }
2295        }
2296
2297        let result_schema = parquet_schema_from_array(thrift_schema).unwrap();
2298        assert_eq!(result_schema, expected_schema);
2299    }
2300
2301    // Sometimes parquet-cpp sets repetition level for the root node, which is against
2302    // the format definition, but we need to handle it by setting it back to None.
2303    #[test]
2304    fn test_schema_from_thrift_root_has_repetition() {
2305        // schema definition written by parquet-cpp version 1.3.2-SNAPSHOT
2306        let message_type = "
2307    message schema {
2308      OPTIONAL BYTE_ARRAY a (UTF8);
2309      OPTIONAL INT32 b (UINT_8);
2310    }
2311    ";
2312
2313        let expected_schema = Arc::new(parse_message_type(message_type).unwrap());
2314        let mut buf = schema_to_buf(&expected_schema).unwrap();
2315        let mut thrift_schema = buf_to_schema_list(&mut buf).unwrap();
2316        thrift_schema[0].repetition_type = Some(Repetition::REQUIRED);
2317
2318        let result_schema = parquet_schema_from_array(thrift_schema).unwrap();
2319        assert_eq!(result_schema, expected_schema);
2320    }
2321
2322    #[test]
2323    fn test_schema_from_thrift_group_has_no_child() {
2324        let message_type = "message schema {}";
2325
2326        let expected_schema = Arc::new(parse_message_type(message_type).unwrap());
2327        let mut buf = schema_to_buf(&expected_schema).unwrap();
2328        let mut thrift_schema = buf_to_schema_list(&mut buf).unwrap();
2329        thrift_schema[0].repetition_type = Some(Repetition::REQUIRED);
2330
2331        let result_schema = parquet_schema_from_array(thrift_schema).unwrap();
2332        assert_eq!(result_schema, expected_schema);
2333    }
2334}