parquet/schema/
types.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Contains structs and methods to build Parquet schema and schema descriptors.
19
20use std::{collections::HashMap, fmt, sync::Arc};
21
22use crate::file::metadata::HeapSize;
23use crate::format::SchemaElement;
24
25use crate::basic::{
26    ColumnOrder, ConvertedType, LogicalType, Repetition, SortOrder, TimeUnit, Type as PhysicalType,
27};
28use crate::errors::{ParquetError, Result};
29
30// ----------------------------------------------------------------------
31// Parquet Type definitions
32
33/// Type alias for `Arc<Type>`.
34pub type TypePtr = Arc<Type>;
35/// Type alias for `Arc<SchemaDescriptor>`.
36pub type SchemaDescPtr = Arc<SchemaDescriptor>;
37/// Type alias for `Arc<ColumnDescriptor>`.
38pub type ColumnDescPtr = Arc<ColumnDescriptor>;
39
/// Representation of a Parquet type.
///
/// Used to describe primitive leaf fields and structs, including top-level schema.
///
/// Note that the top-level schema is represented using [`Type::GroupType`] whose
/// repetition is `None`.
///
/// Instances are normally created through [`Type::primitive_type_builder`] or
/// [`Type::group_type_builder`]; the primitive builder validates its attributes
/// before producing a value.
#[derive(Clone, Debug, PartialEq)]
pub enum Type {
    /// Represents a primitive leaf field.
    PrimitiveType {
        /// Basic information about the type.
        basic_info: BasicTypeInfo,
        /// Physical type of this primitive type.
        physical_type: PhysicalType,
        /// Length of this type in bytes; only meaningful for `FIXED_LEN_BYTE_ARRAY`.
        /// `-1` when built with [`PrimitiveTypeBuilder`] without calling `with_length`.
        type_length: i32,
        /// Scale of this type (used by DECIMAL).
        /// `-1` when built with [`PrimitiveTypeBuilder`] without calling `with_scale`.
        scale: i32,
        /// Precision of this type (used by DECIMAL).
        /// `-1` when built with [`PrimitiveTypeBuilder`] without calling `with_precision`.
        precision: i32,
    },
    /// Represents a group of fields (similar to struct).
    GroupType {
        /// Basic information about the type.
        basic_info: BasicTypeInfo,
        /// Fields of this group type.
        fields: Vec<TypePtr>,
    },
}
69
70impl HeapSize for Type {
71    fn heap_size(&self) -> usize {
72        match self {
73            Type::PrimitiveType { basic_info, .. } => basic_info.heap_size(),
74            Type::GroupType { basic_info, fields } => basic_info.heap_size() + fields.heap_size(),
75        }
76    }
77}
78
79impl Type {
80    /// Creates primitive type builder with provided field name and physical type.
81    pub fn primitive_type_builder(name: &str, physical_type: PhysicalType) -> PrimitiveTypeBuilder {
82        PrimitiveTypeBuilder::new(name, physical_type)
83    }
84
85    /// Creates group type builder with provided column name.
86    pub fn group_type_builder(name: &str) -> GroupTypeBuilder {
87        GroupTypeBuilder::new(name)
88    }
89
90    /// Returns [`BasicTypeInfo`] information about the type.
91    pub fn get_basic_info(&self) -> &BasicTypeInfo {
92        match *self {
93            Type::PrimitiveType { ref basic_info, .. } => basic_info,
94            Type::GroupType { ref basic_info, .. } => basic_info,
95        }
96    }
97
98    /// Returns this type's field name.
99    pub fn name(&self) -> &str {
100        self.get_basic_info().name()
101    }
102
103    /// Gets the fields from this group type.
104    /// Note that this will panic if called on a non-group type.
105    // TODO: should we return `&[&Type]` here?
106    pub fn get_fields(&self) -> &[TypePtr] {
107        match *self {
108            Type::GroupType { ref fields, .. } => &fields[..],
109            _ => panic!("Cannot call get_fields() on a non-group type"),
110        }
111    }
112
113    /// Gets physical type of this primitive type.
114    /// Note that this will panic if called on a non-primitive type.
115    pub fn get_physical_type(&self) -> PhysicalType {
116        match *self {
117            Type::PrimitiveType {
118                basic_info: _,
119                physical_type,
120                ..
121            } => physical_type,
122            _ => panic!("Cannot call get_physical_type() on a non-primitive type"),
123        }
124    }
125
126    /// Gets precision of this primitive type.
127    /// Note that this will panic if called on a non-primitive type.
128    pub fn get_precision(&self) -> i32 {
129        match *self {
130            Type::PrimitiveType { precision, .. } => precision,
131            _ => panic!("Cannot call get_precision() on non-primitive type"),
132        }
133    }
134
135    /// Gets scale of this primitive type.
136    /// Note that this will panic if called on a non-primitive type.
137    pub fn get_scale(&self) -> i32 {
138        match *self {
139            Type::PrimitiveType { scale, .. } => scale,
140            _ => panic!("Cannot call get_scale() on non-primitive type"),
141        }
142    }
143
144    /// Checks if `sub_type` schema is part of current schema.
145    /// This method can be used to check if projected columns are part of the root schema.
146    pub fn check_contains(&self, sub_type: &Type) -> bool {
147        // Names match, and repetitions match or not set for both
148        let basic_match = self.get_basic_info().name() == sub_type.get_basic_info().name()
149            && (self.is_schema() && sub_type.is_schema()
150                || !self.is_schema()
151                    && !sub_type.is_schema()
152                    && self.get_basic_info().repetition()
153                        == sub_type.get_basic_info().repetition());
154
155        match *self {
156            Type::PrimitiveType { .. } if basic_match && sub_type.is_primitive() => {
157                self.get_physical_type() == sub_type.get_physical_type()
158            }
159            Type::GroupType { .. } if basic_match && sub_type.is_group() => {
160                // build hashmap of name -> TypePtr
161                let mut field_map = HashMap::new();
162                for field in self.get_fields() {
163                    field_map.insert(field.name(), field);
164                }
165
166                for field in sub_type.get_fields() {
167                    if !field_map
168                        .get(field.name())
169                        .map(|tpe| tpe.check_contains(field))
170                        .unwrap_or(false)
171                    {
172                        return false;
173                    }
174                }
175                true
176            }
177            _ => false,
178        }
179    }
180
181    /// Returns `true` if this type is a primitive type, `false` otherwise.
182    pub fn is_primitive(&self) -> bool {
183        matches!(*self, Type::PrimitiveType { .. })
184    }
185
186    /// Returns `true` if this type is a group type, `false` otherwise.
187    pub fn is_group(&self) -> bool {
188        matches!(*self, Type::GroupType { .. })
189    }
190
191    /// Returns `true` if this type is the top-level schema type (message type).
192    pub fn is_schema(&self) -> bool {
193        match *self {
194            Type::GroupType { ref basic_info, .. } => !basic_info.has_repetition(),
195            _ => false,
196        }
197    }
198
199    /// Returns `true` if this type is repeated or optional.
200    /// If this type doesn't have repetition defined, we treat it as required.
201    pub fn is_optional(&self) -> bool {
202        self.get_basic_info().has_repetition()
203            && self.get_basic_info().repetition() != Repetition::REQUIRED
204    }
205
206    /// Returns `true` if this type is annotated as a list.
207    pub(crate) fn is_list(&self) -> bool {
208        if self.is_group() {
209            let basic_info = self.get_basic_info();
210            if let Some(logical_type) = basic_info.logical_type() {
211                return logical_type == LogicalType::List;
212            }
213            return basic_info.converted_type() == ConvertedType::LIST;
214        }
215        false
216    }
217
218    /// Returns `true` if this type is a group with a single child field that is `repeated`.
219    pub(crate) fn has_single_repeated_child(&self) -> bool {
220        if self.is_group() {
221            let children = self.get_fields();
222            return children.len() == 1
223                && children[0].get_basic_info().has_repetition()
224                && children[0].get_basic_info().repetition() == Repetition::REPEATED;
225        }
226        false
227    }
228}
229
/// A builder for primitive types. All attributes are optional
/// except the name and physical type.
/// Note that if not specified explicitly, `Repetition::OPTIONAL` is used.
pub struct PrimitiveTypeBuilder<'a> {
    /// Field name.
    name: &'a str,
    /// Repetition; `Repetition::OPTIONAL` unless set.
    repetition: Repetition,
    /// Physical (storage) type of the field.
    physical_type: PhysicalType,
    /// Legacy converted type; `ConvertedType::NONE` unless set. When it is left as
    /// `NONE` and a logical type is given, `build` derives it from the logical type.
    converted_type: ConvertedType,
    /// Logical type annotation, if any.
    logical_type: Option<LogicalType>,
    /// Byte length for FIXED_LEN_BYTE_ARRAY fields; `-1` means unset.
    length: i32,
    /// DECIMAL precision; `-1` means unset.
    precision: i32,
    /// DECIMAL scale; `-1` means unset.
    scale: i32,
    /// Optional field id.
    id: Option<i32>,
}
244
245impl<'a> PrimitiveTypeBuilder<'a> {
246    /// Creates new primitive type builder with provided field name and physical type.
247    pub fn new(name: &'a str, physical_type: PhysicalType) -> Self {
248        Self {
249            name,
250            repetition: Repetition::OPTIONAL,
251            physical_type,
252            converted_type: ConvertedType::NONE,
253            logical_type: None,
254            length: -1,
255            precision: -1,
256            scale: -1,
257            id: None,
258        }
259    }
260
261    /// Sets [`Repetition`] for this field and returns itself.
262    pub fn with_repetition(self, repetition: Repetition) -> Self {
263        Self { repetition, ..self }
264    }
265
266    /// Sets [`ConvertedType`] for this field and returns itself.
267    pub fn with_converted_type(self, converted_type: ConvertedType) -> Self {
268        Self {
269            converted_type,
270            ..self
271        }
272    }
273
274    /// Sets [`LogicalType`] for this field and returns itself.
275    /// If only the logical type is populated for a primitive type, the converted type
276    /// will be automatically populated, and can thus be omitted.
277    pub fn with_logical_type(self, logical_type: Option<LogicalType>) -> Self {
278        Self {
279            logical_type,
280            ..self
281        }
282    }
283
284    /// Sets type length and returns itself.
285    /// This is only applied to FIXED_LEN_BYTE_ARRAY and INT96 (INTERVAL) types, because
286    /// they maintain fixed size underlying byte array.
287    /// By default, value is `0`.
288    pub fn with_length(self, length: i32) -> Self {
289        Self { length, ..self }
290    }
291
292    /// Sets precision for Parquet DECIMAL physical type and returns itself.
293    /// By default, it equals to `0` and used only for decimal context.
294    pub fn with_precision(self, precision: i32) -> Self {
295        Self { precision, ..self }
296    }
297
298    /// Sets scale for Parquet DECIMAL physical type and returns itself.
299    /// By default, it equals to `0` and used only for decimal context.
300    pub fn with_scale(self, scale: i32) -> Self {
301        Self { scale, ..self }
302    }
303
304    /// Sets optional field id and returns itself.
305    pub fn with_id(self, id: Option<i32>) -> Self {
306        Self { id, ..self }
307    }
308
309    /// Creates a new `PrimitiveType` instance from the collected attributes.
310    /// Returns `Err` in case of any building conditions are not met.
311    pub fn build(self) -> Result<Type> {
312        let mut basic_info = BasicTypeInfo {
313            name: String::from(self.name),
314            repetition: Some(self.repetition),
315            converted_type: self.converted_type,
316            logical_type: self.logical_type.clone(),
317            id: self.id,
318        };
319
320        // Check length before logical type, since it is used for logical type validation.
321        if self.physical_type == PhysicalType::FIXED_LEN_BYTE_ARRAY && self.length < 0 {
322            return Err(general_err!(
323                "Invalid FIXED_LEN_BYTE_ARRAY length: {} for field '{}'",
324                self.length,
325                self.name
326            ));
327        }
328
329        if let Some(logical_type) = &self.logical_type {
330            // If a converted type is populated, check that it is consistent with
331            // its logical type
332            if self.converted_type != ConvertedType::NONE {
333                if ConvertedType::from(self.logical_type.clone()) != self.converted_type {
334                    return Err(general_err!(
335                        "Logical type {:?} is incompatible with converted type {} for field '{}'",
336                        logical_type,
337                        self.converted_type,
338                        self.name
339                    ));
340                }
341            } else {
342                // Populate the converted type for backwards compatibility
343                basic_info.converted_type = self.logical_type.clone().into();
344            }
345            // Check that logical type and physical type are compatible
346            match (logical_type, self.physical_type) {
347                (LogicalType::Map, _) | (LogicalType::List, _) => {
348                    return Err(general_err!(
349                        "{:?} cannot be applied to a primitive type for field '{}'",
350                        logical_type,
351                        self.name
352                    ));
353                }
354                (LogicalType::Enum, PhysicalType::BYTE_ARRAY) => {}
355                (LogicalType::Decimal { scale, precision }, _) => {
356                    // Check that scale and precision are consistent with legacy values
357                    if *scale != self.scale {
358                        return Err(general_err!(
359                            "DECIMAL logical type scale {} must match self.scale {} for field '{}'",
360                            scale,
361                            self.scale,
362                            self.name
363                        ));
364                    }
365                    if *precision != self.precision {
366                        return Err(general_err!(
367                            "DECIMAL logical type precision {} must match self.precision {} for field '{}'",
368                            precision,
369                            self.precision,
370                            self.name
371                        ));
372                    }
373                    self.check_decimal_precision_scale()?;
374                }
375                (LogicalType::Date, PhysicalType::INT32) => {}
376                (
377                    LogicalType::Time {
378                        unit: TimeUnit::MILLIS(_),
379                        ..
380                    },
381                    PhysicalType::INT32,
382                ) => {}
383                (LogicalType::Time { unit, .. }, PhysicalType::INT64) => {
384                    if *unit == TimeUnit::MILLIS(Default::default()) {
385                        return Err(general_err!(
386                            "Cannot use millisecond unit on INT64 type for field '{}'",
387                            self.name
388                        ));
389                    }
390                }
391                (LogicalType::Timestamp { .. }, PhysicalType::INT64) => {}
392                (LogicalType::Integer { bit_width, .. }, PhysicalType::INT32)
393                    if *bit_width <= 32 => {}
394                (LogicalType::Integer { bit_width, .. }, PhysicalType::INT64)
395                    if *bit_width == 64 => {}
396                // Null type
397                (LogicalType::Unknown, PhysicalType::INT32) => {}
398                (LogicalType::String, PhysicalType::BYTE_ARRAY) => {}
399                (LogicalType::Json, PhysicalType::BYTE_ARRAY) => {}
400                (LogicalType::Bson, PhysicalType::BYTE_ARRAY) => {}
401                (LogicalType::Uuid, PhysicalType::FIXED_LEN_BYTE_ARRAY) if self.length == 16 => {}
402                (LogicalType::Uuid, PhysicalType::FIXED_LEN_BYTE_ARRAY) => {
403                    return Err(general_err!(
404                        "UUID cannot annotate field '{}' because it is not a FIXED_LEN_BYTE_ARRAY(16) field",
405                        self.name
406                    ))
407                }
408                (LogicalType::Float16, PhysicalType::FIXED_LEN_BYTE_ARRAY)
409                    if self.length == 2 => {}
410                (LogicalType::Float16, PhysicalType::FIXED_LEN_BYTE_ARRAY) => {
411                    return Err(general_err!(
412                        "FLOAT16 cannot annotate field '{}' because it is not a FIXED_LEN_BYTE_ARRAY(2) field",
413                        self.name
414                    ))
415                }
416                (a, b) => {
417                    return Err(general_err!(
418                        "Cannot annotate {:?} from {} for field '{}'",
419                        a,
420                        b,
421                        self.name
422                    ))
423                }
424            }
425        }
426
427        match self.converted_type {
428            ConvertedType::NONE => {}
429            ConvertedType::UTF8 | ConvertedType::BSON | ConvertedType::JSON => {
430                if self.physical_type != PhysicalType::BYTE_ARRAY {
431                    return Err(general_err!(
432                        "{} cannot annotate field '{}' because it is not a BYTE_ARRAY field",
433                        self.converted_type,
434                        self.name
435                    ));
436                }
437            }
438            ConvertedType::DECIMAL => {
439                self.check_decimal_precision_scale()?;
440            }
441            ConvertedType::DATE
442            | ConvertedType::TIME_MILLIS
443            | ConvertedType::UINT_8
444            | ConvertedType::UINT_16
445            | ConvertedType::UINT_32
446            | ConvertedType::INT_8
447            | ConvertedType::INT_16
448            | ConvertedType::INT_32 => {
449                if self.physical_type != PhysicalType::INT32 {
450                    return Err(general_err!(
451                        "{} cannot annotate field '{}' because it is not a INT32 field",
452                        self.converted_type,
453                        self.name
454                    ));
455                }
456            }
457            ConvertedType::TIME_MICROS
458            | ConvertedType::TIMESTAMP_MILLIS
459            | ConvertedType::TIMESTAMP_MICROS
460            | ConvertedType::UINT_64
461            | ConvertedType::INT_64 => {
462                if self.physical_type != PhysicalType::INT64 {
463                    return Err(general_err!(
464                        "{} cannot annotate field '{}' because it is not a INT64 field",
465                        self.converted_type,
466                        self.name
467                    ));
468                }
469            }
470            ConvertedType::INTERVAL => {
471                if self.physical_type != PhysicalType::FIXED_LEN_BYTE_ARRAY || self.length != 12 {
472                    return Err(general_err!(
473                        "INTERVAL cannot annotate field '{}' because it is not a FIXED_LEN_BYTE_ARRAY(12) field",
474                        self.name
475                    ));
476                }
477            }
478            ConvertedType::ENUM => {
479                if self.physical_type != PhysicalType::BYTE_ARRAY {
480                    return Err(general_err!(
481                        "ENUM cannot annotate field '{}' because it is not a BYTE_ARRAY field",
482                        self.name
483                    ));
484                }
485            }
486            _ => {
487                return Err(general_err!(
488                    "{} cannot be applied to primitive field '{}'",
489                    self.converted_type,
490                    self.name
491                ));
492            }
493        }
494
495        Ok(Type::PrimitiveType {
496            basic_info,
497            physical_type: self.physical_type,
498            type_length: self.length,
499            scale: self.scale,
500            precision: self.precision,
501        })
502    }
503
504    #[inline]
505    fn check_decimal_precision_scale(&self) -> Result<()> {
506        match self.physical_type {
507            PhysicalType::INT32
508            | PhysicalType::INT64
509            | PhysicalType::BYTE_ARRAY
510            | PhysicalType::FIXED_LEN_BYTE_ARRAY => (),
511            _ => {
512                return Err(general_err!(
513                    "DECIMAL can only annotate INT32, INT64, BYTE_ARRAY and FIXED_LEN_BYTE_ARRAY"
514                ));
515            }
516        }
517
518        // Precision is required and must be a non-zero positive integer.
519        if self.precision < 1 {
520            return Err(general_err!(
521                "Invalid DECIMAL precision: {}",
522                self.precision
523            ));
524        }
525
526        // Scale must be zero or a positive integer less than the precision.
527        if self.scale < 0 {
528            return Err(general_err!("Invalid DECIMAL scale: {}", self.scale));
529        }
530
531        if self.scale > self.precision {
532            return Err(general_err!(
533                "Invalid DECIMAL: scale ({}) cannot be greater than precision \
534             ({})",
535                self.scale,
536                self.precision
537            ));
538        }
539
540        // Check precision and scale based on physical type limitations.
541        match self.physical_type {
542            PhysicalType::INT32 => {
543                if self.precision > 9 {
544                    return Err(general_err!(
545                        "Cannot represent INT32 as DECIMAL with precision {}",
546                        self.precision
547                    ));
548                }
549            }
550            PhysicalType::INT64 => {
551                if self.precision > 18 {
552                    return Err(general_err!(
553                        "Cannot represent INT64 as DECIMAL with precision {}",
554                        self.precision
555                    ));
556                }
557            }
558            PhysicalType::FIXED_LEN_BYTE_ARRAY => {
559                let length = self
560                    .length
561                    .checked_mul(8)
562                    .ok_or(general_err!("Invalid length {} for Decimal", self.length))?;
563                let max_precision = (2f64.powi(length - 1) - 1f64).log10().floor() as i32;
564
565                if self.precision > max_precision {
566                    return Err(general_err!(
567                        "Cannot represent FIXED_LEN_BYTE_ARRAY as DECIMAL with length {} and \
568                        precision {}. The max precision can only be {}",
569                        self.length,
570                        self.precision,
571                        max_precision
572                    ));
573                }
574            }
575            _ => (), // For BYTE_ARRAY precision is not limited
576        }
577
578        Ok(())
579    }
580}
581
/// A builder for group types. All attributes are optional except the name.
/// Note that if not specified explicitly, `None` is used as the repetition of the group,
/// which means it is a root (message) type.
pub struct GroupTypeBuilder<'a> {
    /// Field name of the group.
    name: &'a str,
    /// Repetition; `None` marks the root (message) type.
    repetition: Option<Repetition>,
    /// Legacy converted type; `ConvertedType::NONE` unless set.
    converted_type: ConvertedType,
    /// Logical type annotation, if any.
    logical_type: Option<LogicalType>,
    /// Child fields of the group.
    fields: Vec<TypePtr>,
    /// Optional field id.
    id: Option<i32>,
}
593
594impl<'a> GroupTypeBuilder<'a> {
595    /// Creates new group type builder with provided field name.
596    pub fn new(name: &'a str) -> Self {
597        Self {
598            name,
599            repetition: None,
600            converted_type: ConvertedType::NONE,
601            logical_type: None,
602            fields: Vec::new(),
603            id: None,
604        }
605    }
606
607    /// Sets [`Repetition`] for this field and returns itself.
608    pub fn with_repetition(mut self, repetition: Repetition) -> Self {
609        self.repetition = Some(repetition);
610        self
611    }
612
613    /// Sets [`ConvertedType`] for this field and returns itself.
614    pub fn with_converted_type(self, converted_type: ConvertedType) -> Self {
615        Self {
616            converted_type,
617            ..self
618        }
619    }
620
621    /// Sets [`LogicalType`] for this field and returns itself.
622    pub fn with_logical_type(self, logical_type: Option<LogicalType>) -> Self {
623        Self {
624            logical_type,
625            ..self
626        }
627    }
628
629    /// Sets a list of fields that should be child nodes of this field.
630    /// Returns updated self.
631    pub fn with_fields(self, fields: Vec<TypePtr>) -> Self {
632        Self { fields, ..self }
633    }
634
635    /// Sets optional field id and returns itself.
636    pub fn with_id(self, id: Option<i32>) -> Self {
637        Self { id, ..self }
638    }
639
640    /// Creates a new `GroupType` instance from the gathered attributes.
641    pub fn build(self) -> Result<Type> {
642        let mut basic_info = BasicTypeInfo {
643            name: String::from(self.name),
644            repetition: self.repetition,
645            converted_type: self.converted_type,
646            logical_type: self.logical_type.clone(),
647            id: self.id,
648        };
649        // Populate the converted type if only the logical type is populated
650        if self.logical_type.is_some() && self.converted_type == ConvertedType::NONE {
651            basic_info.converted_type = self.logical_type.into();
652        }
653        Ok(Type::GroupType {
654            basic_info,
655            fields: self.fields,
656        })
657    }
658}
659
/// Basic type info shared by primitive and group types: the field name, its optional
/// repetition, its converted and logical type annotations, and an optional field id.
///
/// Whether a type is a group or a primitive is not stored here; that is given by the
/// [`Type`] variant holding this info.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BasicTypeInfo {
    /// Field name.
    name: String,
    /// Repetition of the field; a group without one is treated as the root
    /// (message) type by `Type::is_schema`.
    repetition: Option<Repetition>,
    /// Legacy converted type annotation (`ConvertedType::NONE` when absent).
    converted_type: ConvertedType,
    /// Logical type annotation, if any.
    logical_type: Option<LogicalType>,
    /// Optional field id.
    id: Option<i32>,
}
670
671impl HeapSize for BasicTypeInfo {
672    fn heap_size(&self) -> usize {
673        // no heap allocations in any other subfield
674        self.name.heap_size()
675    }
676}
677
678impl BasicTypeInfo {
679    /// Returns field name.
680    pub fn name(&self) -> &str {
681        &self.name
682    }
683
684    /// Returns `true` if type has repetition field set, `false` otherwise.
685    /// This is mostly applied to group type, because primitive type always has
686    /// repetition set.
687    pub fn has_repetition(&self) -> bool {
688        self.repetition.is_some()
689    }
690
691    /// Returns [`Repetition`] value for the type.
692    pub fn repetition(&self) -> Repetition {
693        assert!(self.repetition.is_some());
694        self.repetition.unwrap()
695    }
696
697    /// Returns [`ConvertedType`] value for the type.
698    pub fn converted_type(&self) -> ConvertedType {
699        self.converted_type
700    }
701
702    /// Returns [`LogicalType`] value for the type.
703    pub fn logical_type(&self) -> Option<LogicalType> {
704        // Unlike ConvertedType, LogicalType cannot implement Copy, thus we clone it
705        self.logical_type.clone()
706    }
707
708    /// Returns `true` if id is set, `false` otherwise.
709    pub fn has_id(&self) -> bool {
710        self.id.is_some()
711    }
712
713    /// Returns id value for the type.
714    pub fn id(&self) -> i32 {
715        assert!(self.id.is_some());
716        self.id.unwrap()
717    }
718}
719
720// ----------------------------------------------------------------------
721// Parquet descriptor definitions
722
/// Represents the location of a column in a Parquet schema
///
/// The path is a sequence of field names from the outermost field down to the
/// column itself.
///
/// # Example: refer to column named `'my_column'`
/// ```
/// # use parquet::schema::types::ColumnPath;
/// let column_path = ColumnPath::from("my_column");
/// ```
///
/// # Example: refer to column named `c` in a nested struct `{a: {b: {c: ...}}}`
/// ```
/// # use parquet::schema::types::ColumnPath;
/// // form path 'a.b.c'
/// let column_path = ColumnPath::from(vec![
///   String::from("a"),
///   String::from("b"),
///   String::from("c")
/// ]);
/// ```
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
pub struct ColumnPath {
    /// Field names making up the path, outermost first.
    parts: Vec<String>,
}
745
746impl HeapSize for ColumnPath {
747    fn heap_size(&self) -> usize {
748        self.parts.heap_size()
749    }
750}
751
752impl ColumnPath {
753    /// Creates new column path from vector of field names.
754    pub fn new(parts: Vec<String>) -> Self {
755        ColumnPath { parts }
756    }
757
758    /// Returns string representation of this column path.
759    /// ```rust
760    /// use parquet::schema::types::ColumnPath;
761    ///
762    /// let path = ColumnPath::new(vec!["a".to_string(), "b".to_string(), "c".to_string()]);
763    /// assert_eq!(&path.string(), "a.b.c");
764    /// ```
765    pub fn string(&self) -> String {
766        self.parts.join(".")
767    }
768
769    /// Appends more components to end of column path.
770    /// ```rust
771    /// use parquet::schema::types::ColumnPath;
772    ///
773    /// let mut path = ColumnPath::new(vec!["a".to_string(), "b".to_string(), "c"
774    /// .to_string()]);
775    /// assert_eq!(&path.string(), "a.b.c");
776    ///
777    /// path.append(vec!["d".to_string(), "e".to_string()]);
778    /// assert_eq!(&path.string(), "a.b.c.d.e");
779    /// ```
780    pub fn append(&mut self, mut tail: Vec<String>) {
781        self.parts.append(&mut tail);
782    }
783
784    /// Returns a slice of path components.
785    pub fn parts(&self) -> &[String] {
786        &self.parts
787    }
788}
789
790impl fmt::Display for ColumnPath {
791    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
792        write!(f, "{:?}", self.string())
793    }
794}
795
796impl From<Vec<String>> for ColumnPath {
797    fn from(parts: Vec<String>) -> Self {
798        ColumnPath { parts }
799    }
800}
801
802impl From<&str> for ColumnPath {
803    fn from(single_path: &str) -> Self {
804        let s = String::from(single_path);
805        ColumnPath::from(s)
806    }
807}
808
809impl From<String> for ColumnPath {
810    fn from(single_path: String) -> Self {
811        let v = vec![single_path];
812        ColumnPath { parts: v }
813    }
814}
815
816impl AsRef<[String]> for ColumnPath {
817    fn as_ref(&self) -> &[String] {
818        &self.parts
819    }
820}
821
/// Descriptor for a leaf-level primitive column.
///
/// Holds the column's primitive type and its path in the schema, together with the
/// maximum definition and repetition levels required to re-assemble nested data.
#[derive(Debug, PartialEq)]
pub struct ColumnDescriptor {
    /// The "leaf" primitive type of this column
    primitive_type: TypePtr,

    /// The maximum definition level for this column
    max_def_level: i16,

    /// The maximum repetition level for this column
    max_rep_level: i16,

    /// The path of this column. For instance, "a.b.c.d".
    path: ColumnPath,
}
840
841impl HeapSize for ColumnDescriptor {
842    fn heap_size(&self) -> usize {
843        self.primitive_type.heap_size() + self.path.heap_size()
844    }
845}
846
847impl ColumnDescriptor {
848    /// Creates new descriptor for leaf-level column.
849    pub fn new(
850        primitive_type: TypePtr,
851        max_def_level: i16,
852        max_rep_level: i16,
853        path: ColumnPath,
854    ) -> Self {
855        Self {
856            primitive_type,
857            max_def_level,
858            max_rep_level,
859            path,
860        }
861    }
862
863    /// Returns maximum definition level for this column.
864    #[inline]
865    pub fn max_def_level(&self) -> i16 {
866        self.max_def_level
867    }
868
869    /// Returns maximum repetition level for this column.
870    #[inline]
871    pub fn max_rep_level(&self) -> i16 {
872        self.max_rep_level
873    }
874
875    /// Returns [`ColumnPath`] for this column.
876    pub fn path(&self) -> &ColumnPath {
877        &self.path
878    }
879
880    /// Returns self type [`Type`] for this leaf column.
881    pub fn self_type(&self) -> &Type {
882        self.primitive_type.as_ref()
883    }
884
885    /// Returns self type [`TypePtr`]  for this leaf
886    /// column.
887    pub fn self_type_ptr(&self) -> TypePtr {
888        self.primitive_type.clone()
889    }
890
891    /// Returns column name.
892    pub fn name(&self) -> &str {
893        self.primitive_type.name()
894    }
895
896    /// Returns [`ConvertedType`] for this column.
897    pub fn converted_type(&self) -> ConvertedType {
898        self.primitive_type.get_basic_info().converted_type()
899    }
900
901    /// Returns [`LogicalType`] for this column.
902    pub fn logical_type(&self) -> Option<LogicalType> {
903        self.primitive_type.get_basic_info().logical_type()
904    }
905
906    /// Returns physical type for this column.
907    /// Note that it will panic if called on a non-primitive type.
908    pub fn physical_type(&self) -> PhysicalType {
909        match self.primitive_type.as_ref() {
910            Type::PrimitiveType { physical_type, .. } => *physical_type,
911            _ => panic!("Expected primitive type!"),
912        }
913    }
914
915    /// Returns type length for this column.
916    /// Note that it will panic if called on a non-primitive type.
917    pub fn type_length(&self) -> i32 {
918        match self.primitive_type.as_ref() {
919            Type::PrimitiveType { type_length, .. } => *type_length,
920            _ => panic!("Expected primitive type!"),
921        }
922    }
923
924    /// Returns type precision for this column.
925    /// Note that it will panic if called on a non-primitive type.
926    pub fn type_precision(&self) -> i32 {
927        match self.primitive_type.as_ref() {
928            Type::PrimitiveType { precision, .. } => *precision,
929            _ => panic!("Expected primitive type!"),
930        }
931    }
932
933    /// Returns type scale for this column.
934    /// Note that it will panic if called on a non-primitive type.
935    pub fn type_scale(&self) -> i32 {
936        match self.primitive_type.as_ref() {
937            Type::PrimitiveType { scale, .. } => *scale,
938            _ => panic!("Expected primitive type!"),
939        }
940    }
941
942    /// Returns the sort order for this column
943    pub fn sort_order(&self) -> SortOrder {
944        ColumnOrder::get_sort_order(
945            self.logical_type(),
946            self.converted_type(),
947            self.physical_type(),
948        )
949    }
950}
951
/// Schema of a Parquet file.
///
/// Encapsulates the file's schema ([`Type`]) and [`ColumnDescriptor`]s for
/// each primitive (leaf) column.
///
/// # Example
/// ```
/// # use std::sync::Arc;
/// use parquet::schema::types::{SchemaDescriptor, Type};
/// use parquet::basic; // note there are two `Type`s that are different
/// // Schema for a table with two columns: "a" (int64) and "b" (int32, stored as a date)
/// let descriptor = SchemaDescriptor::new(
///   Arc::new(
///     Type::group_type_builder("my_schema")
///       .with_fields(vec![
///         Arc::new(
///          Type::primitive_type_builder("a", basic::Type::INT64)
///           .build().unwrap()
///         ),
///         Arc::new(
///          Type::primitive_type_builder("b", basic::Type::INT32)
///           .with_converted_type(basic::ConvertedType::DATE)
///           .with_logical_type(Some(basic::LogicalType::Date))
///           .build().unwrap()
///         ),
///      ])
///      .build().unwrap()
///   )
/// );
/// ```
#[derive(PartialEq)]
pub struct SchemaDescriptor {
    /// The top-level logical schema (the "message" type).
    ///
    /// This must be a [`Type::GroupType`] where each field is a root
    /// column type in the schema.
    schema: TypePtr,

    /// The descriptors for the physical type of each leaf column in this schema.
    ///
    /// Constructed from `schema` in DFS order, with one entry per primitive
    /// leaf. A leaf's position in this vector is its column index.
    leaves: Vec<ColumnDescPtr>,

    /// Mapping from a leaf column's index to the root column index that it
    /// comes from.
    ///
    /// Built alongside `leaves`, so it holds exactly one entry per leaf.
    ///
    /// For instance: the leaf `a.b.c.d` would have a link back to `a`:
    /// ```text
    /// -- a  <-----+
    /// -- -- b     |
    /// -- -- -- c  |
    /// -- -- -- -- d
    /// ```
    leaf_to_base: Vec<usize>,
}
1007
1008impl fmt::Debug for SchemaDescriptor {
1009    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1010        // Skip leaves and leaf_to_base as they only a cache information already found in `schema`
1011        f.debug_struct("SchemaDescriptor")
1012            .field("schema", &self.schema)
1013            .finish()
1014    }
1015}
1016
1017// Need to implement HeapSize in this module as the fields are private
1018impl HeapSize for SchemaDescriptor {
1019    fn heap_size(&self) -> usize {
1020        self.schema.heap_size() + self.leaves.heap_size() + self.leaf_to_base.heap_size()
1021    }
1022}
1023
1024impl SchemaDescriptor {
1025    /// Creates new schema descriptor from Parquet schema.
1026    pub fn new(tp: TypePtr) -> Self {
1027        assert!(tp.is_group(), "SchemaDescriptor should take a GroupType");
1028        let mut leaves = vec![];
1029        let mut leaf_to_base = Vec::new();
1030        for (root_idx, f) in tp.get_fields().iter().enumerate() {
1031            let mut path = vec![];
1032            build_tree(f, root_idx, 0, 0, &mut leaves, &mut leaf_to_base, &mut path);
1033        }
1034
1035        Self {
1036            schema: tp,
1037            leaves,
1038            leaf_to_base,
1039        }
1040    }
1041
1042    /// Returns [`ColumnDescriptor`] for a field position.
1043    pub fn column(&self, i: usize) -> ColumnDescPtr {
1044        assert!(
1045            i < self.leaves.len(),
1046            "Index out of bound: {} not in [0, {})",
1047            i,
1048            self.leaves.len()
1049        );
1050        self.leaves[i].clone()
1051    }
1052
1053    /// Returns slice of [`ColumnDescriptor`].
1054    pub fn columns(&self) -> &[ColumnDescPtr] {
1055        &self.leaves
1056    }
1057
1058    /// Returns number of leaf-level columns.
1059    pub fn num_columns(&self) -> usize {
1060        self.leaves.len()
1061    }
1062
1063    /// Returns column root [`Type`] for a leaf position.
1064    pub fn get_column_root(&self, i: usize) -> &Type {
1065        let result = self.column_root_of(i);
1066        result.as_ref()
1067    }
1068
1069    /// Returns column root [`Type`] pointer for a leaf position.
1070    pub fn get_column_root_ptr(&self, i: usize) -> TypePtr {
1071        let result = self.column_root_of(i);
1072        result.clone()
1073    }
1074
1075    /// Returns the index of the root column for a field position
1076    pub fn get_column_root_idx(&self, leaf: usize) -> usize {
1077        assert!(
1078            leaf < self.leaves.len(),
1079            "Index out of bound: {} not in [0, {})",
1080            leaf,
1081            self.leaves.len()
1082        );
1083
1084        *self
1085            .leaf_to_base
1086            .get(leaf)
1087            .unwrap_or_else(|| panic!("Expected a value for index {leaf} but found None"))
1088    }
1089
1090    fn column_root_of(&self, i: usize) -> &TypePtr {
1091        &self.schema.get_fields()[self.get_column_root_idx(i)]
1092    }
1093
1094    /// Returns schema as [`Type`].
1095    pub fn root_schema(&self) -> &Type {
1096        self.schema.as_ref()
1097    }
1098
1099    /// Returns schema as [`TypePtr`] for cheap cloning.
1100    pub fn root_schema_ptr(&self) -> TypePtr {
1101        self.schema.clone()
1102    }
1103
1104    /// Returns schema name.
1105    pub fn name(&self) -> &str {
1106        self.schema.name()
1107    }
1108}
1109
1110fn build_tree<'a>(
1111    tp: &'a TypePtr,
1112    root_idx: usize,
1113    mut max_rep_level: i16,
1114    mut max_def_level: i16,
1115    leaves: &mut Vec<ColumnDescPtr>,
1116    leaf_to_base: &mut Vec<usize>,
1117    path_so_far: &mut Vec<&'a str>,
1118) {
1119    assert!(tp.get_basic_info().has_repetition());
1120
1121    path_so_far.push(tp.name());
1122    match tp.get_basic_info().repetition() {
1123        Repetition::OPTIONAL => {
1124            max_def_level += 1;
1125        }
1126        Repetition::REPEATED => {
1127            max_def_level += 1;
1128            max_rep_level += 1;
1129        }
1130        _ => {}
1131    }
1132
1133    match tp.as_ref() {
1134        Type::PrimitiveType { .. } => {
1135            let mut path: Vec<String> = vec![];
1136            path.extend(path_so_far.iter().copied().map(String::from));
1137            leaves.push(Arc::new(ColumnDescriptor::new(
1138                tp.clone(),
1139                max_def_level,
1140                max_rep_level,
1141                ColumnPath::new(path),
1142            )));
1143            leaf_to_base.push(root_idx);
1144        }
1145        Type::GroupType { ref fields, .. } => {
1146            for f in fields {
1147                build_tree(
1148                    f,
1149                    root_idx,
1150                    max_rep_level,
1151                    max_def_level,
1152                    leaves,
1153                    leaf_to_base,
1154                    path_so_far,
1155                );
1156                path_so_far.pop();
1157            }
1158        }
1159    }
1160}
1161
1162/// Method to convert from Thrift.
1163pub fn from_thrift(elements: &[SchemaElement]) -> Result<TypePtr> {
1164    let mut index = 0;
1165    let mut schema_nodes = Vec::new();
1166    while index < elements.len() {
1167        let t = from_thrift_helper(elements, index)?;
1168        index = t.0;
1169        schema_nodes.push(t.1);
1170    }
1171    if schema_nodes.len() != 1 {
1172        return Err(general_err!(
1173            "Expected exactly one root node, but found {}",
1174            schema_nodes.len()
1175        ));
1176    }
1177
1178    if !schema_nodes[0].is_group() {
1179        return Err(general_err!("Expected root node to be a group type"));
1180    }
1181
1182    Ok(schema_nodes.remove(0))
1183}
1184
1185/// Checks if the logical type is valid.
1186fn check_logical_type(logical_type: &Option<LogicalType>) -> Result<()> {
1187    if let Some(LogicalType::Integer { bit_width, .. }) = *logical_type {
1188        if bit_width != 8 && bit_width != 16 && bit_width != 32 && bit_width != 64 {
1189            return Err(general_err!(
1190                "Bit width must be 8, 16, 32, or 64 for Integer logical type"
1191            ));
1192        }
1193    }
1194    Ok(())
1195}
1196
1197/// Constructs a new Type from the `elements`, starting at index `index`.
1198/// The first result is the starting index for the next Type after this one. If it is
1199/// equal to `elements.len()`, then this Type is the last one.
1200/// The second result is the result Type.
1201fn from_thrift_helper(elements: &[SchemaElement], index: usize) -> Result<(usize, TypePtr)> {
1202    // Whether or not the current node is root (message type).
1203    // There is only one message type node in the schema tree.
1204    let is_root_node = index == 0;
1205
1206    if index >= elements.len() {
1207        return Err(general_err!(
1208            "Index out of bound, index = {}, len = {}",
1209            index,
1210            elements.len()
1211        ));
1212    }
1213    let element = &elements[index];
1214
1215    // Check for empty schema
1216    if let (true, None | Some(0)) = (is_root_node, element.num_children) {
1217        let builder = Type::group_type_builder(&element.name);
1218        return Ok((index + 1, Arc::new(builder.build().unwrap())));
1219    }
1220
1221    let converted_type = ConvertedType::try_from(element.converted_type)?;
1222    // LogicalType is only present in v2 Parquet files. ConvertedType is always
1223    // populated, regardless of the version of the file (v1 or v2).
1224    let logical_type = element
1225        .logical_type
1226        .as_ref()
1227        .map(|value| LogicalType::from(value.clone()));
1228
1229    check_logical_type(&logical_type)?;
1230
1231    let field_id = elements[index].field_id;
1232    match elements[index].num_children {
1233        // From parquet-format:
1234        //   The children count is used to construct the nested relationship.
1235        //   This field is not set when the element is a primitive type
1236        // Sometimes parquet-cpp sets num_children field to 0 for primitive types, so we
1237        // have to handle this case too.
1238        None | Some(0) => {
1239            // primitive type
1240            if elements[index].repetition_type.is_none() {
1241                return Err(general_err!(
1242                    "Repetition level must be defined for a primitive type"
1243                ));
1244            }
1245            let repetition = Repetition::try_from(elements[index].repetition_type.unwrap())?;
1246            if let Some(type_) = elements[index].type_ {
1247                let physical_type = PhysicalType::try_from(type_)?;
1248                let length = elements[index].type_length.unwrap_or(-1);
1249                let scale = elements[index].scale.unwrap_or(-1);
1250                let precision = elements[index].precision.unwrap_or(-1);
1251                let name = &elements[index].name;
1252                let builder = Type::primitive_type_builder(name, physical_type)
1253                    .with_repetition(repetition)
1254                    .with_converted_type(converted_type)
1255                    .with_logical_type(logical_type)
1256                    .with_length(length)
1257                    .with_precision(precision)
1258                    .with_scale(scale)
1259                    .with_id(field_id);
1260                Ok((index + 1, Arc::new(builder.build()?)))
1261            } else {
1262                let mut builder = Type::group_type_builder(&elements[index].name)
1263                    .with_converted_type(converted_type)
1264                    .with_logical_type(logical_type)
1265                    .with_id(field_id);
1266                if !is_root_node {
1267                    // Sometimes parquet-cpp and parquet-mr set repetition level REQUIRED or
1268                    // REPEATED for root node.
1269                    //
1270                    // We only set repetition for group types that are not top-level message
1271                    // type. According to parquet-format:
1272                    //   Root of the schema does not have a repetition_type.
1273                    //   All other types must have one.
1274                    builder = builder.with_repetition(repetition);
1275                }
1276                Ok((index + 1, Arc::new(builder.build().unwrap())))
1277            }
1278        }
1279        Some(n) => {
1280            let repetition = elements[index]
1281                .repetition_type
1282                .map(Repetition::try_from)
1283                .transpose()?;
1284
1285            let mut fields = vec![];
1286            let mut next_index = index + 1;
1287            for _ in 0..n {
1288                let child_result = from_thrift_helper(elements, next_index)?;
1289                next_index = child_result.0;
1290                fields.push(child_result.1);
1291            }
1292
1293            let mut builder = Type::group_type_builder(&elements[index].name)
1294                .with_converted_type(converted_type)
1295                .with_logical_type(logical_type)
1296                .with_fields(fields)
1297                .with_id(field_id);
1298            if let Some(rep) = repetition {
1299                // Sometimes parquet-cpp and parquet-mr set repetition level REQUIRED or
1300                // REPEATED for root node.
1301                //
1302                // We only set repetition for group types that are not top-level message
1303                // type. According to parquet-format:
1304                //   Root of the schema does not have a repetition_type.
1305                //   All other types must have one.
1306                if !is_root_node {
1307                    builder = builder.with_repetition(rep);
1308                }
1309            }
1310            Ok((next_index, Arc::new(builder.build().unwrap())))
1311        }
1312    }
1313}
1314
1315/// Method to convert to Thrift.
1316pub fn to_thrift(schema: &Type) -> Result<Vec<SchemaElement>> {
1317    if !schema.is_group() {
1318        return Err(general_err!("Root schema must be Group type"));
1319    }
1320    let mut elements: Vec<SchemaElement> = Vec::new();
1321    to_thrift_helper(schema, &mut elements);
1322    Ok(elements)
1323}
1324
1325/// Constructs list of `SchemaElement` from the schema using depth-first traversal.
1326/// Here we assume that schema is always valid and starts with group type.
1327fn to_thrift_helper(schema: &Type, elements: &mut Vec<SchemaElement>) {
1328    match *schema {
1329        Type::PrimitiveType {
1330            ref basic_info,
1331            physical_type,
1332            type_length,
1333            scale,
1334            precision,
1335        } => {
1336            let element = SchemaElement {
1337                type_: Some(physical_type.into()),
1338                type_length: if type_length >= 0 {
1339                    Some(type_length)
1340                } else {
1341                    None
1342                },
1343                repetition_type: Some(basic_info.repetition().into()),
1344                name: basic_info.name().to_owned(),
1345                num_children: None,
1346                converted_type: basic_info.converted_type().into(),
1347                scale: if scale >= 0 { Some(scale) } else { None },
1348                precision: if precision >= 0 {
1349                    Some(precision)
1350                } else {
1351                    None
1352                },
1353                field_id: if basic_info.has_id() {
1354                    Some(basic_info.id())
1355                } else {
1356                    None
1357                },
1358                logical_type: basic_info.logical_type().map(|value| value.into()),
1359            };
1360
1361            elements.push(element);
1362        }
1363        Type::GroupType {
1364            ref basic_info,
1365            ref fields,
1366        } => {
1367            let repetition = if basic_info.has_repetition() {
1368                Some(basic_info.repetition().into())
1369            } else {
1370                None
1371            };
1372
1373            let element = SchemaElement {
1374                type_: None,
1375                type_length: None,
1376                repetition_type: repetition,
1377                name: basic_info.name().to_owned(),
1378                num_children: Some(fields.len() as i32),
1379                converted_type: basic_info.converted_type().into(),
1380                scale: None,
1381                precision: None,
1382                field_id: if basic_info.has_id() {
1383                    Some(basic_info.id())
1384                } else {
1385                    None
1386                },
1387                logical_type: basic_info.logical_type().map(|value| value.into()),
1388            };
1389
1390            elements.push(element);
1391
1392            // Add child elements for a group
1393            for field in fields {
1394                to_thrift_helper(field, elements);
1395            }
1396        }
1397    }
1398}
1399
1400#[cfg(test)]
1401mod tests {
1402    use super::*;
1403
1404    use crate::schema::parser::parse_message_type;
1405
1406    // TODO: add tests for v2 types
1407
1408    #[test]
1409    fn test_primitive_type() {
1410        let mut result = Type::primitive_type_builder("foo", PhysicalType::INT32)
1411            .with_logical_type(Some(LogicalType::Integer {
1412                bit_width: 32,
1413                is_signed: true,
1414            }))
1415            .with_id(Some(0))
1416            .build();
1417        assert!(result.is_ok());
1418
1419        if let Ok(tp) = result {
1420            assert!(tp.is_primitive());
1421            assert!(!tp.is_group());
1422            let basic_info = tp.get_basic_info();
1423            assert_eq!(basic_info.repetition(), Repetition::OPTIONAL);
1424            assert_eq!(
1425                basic_info.logical_type(),
1426                Some(LogicalType::Integer {
1427                    bit_width: 32,
1428                    is_signed: true
1429                })
1430            );
1431            assert_eq!(basic_info.converted_type(), ConvertedType::INT_32);
1432            assert_eq!(basic_info.id(), 0);
1433            match tp {
1434                Type::PrimitiveType { physical_type, .. } => {
1435                    assert_eq!(physical_type, PhysicalType::INT32);
1436                }
1437                _ => panic!(),
1438            }
1439        }
1440
1441        // Test illegal inputs with logical type
1442        result = Type::primitive_type_builder("foo", PhysicalType::INT64)
1443            .with_repetition(Repetition::REPEATED)
1444            .with_logical_type(Some(LogicalType::Integer {
1445                is_signed: true,
1446                bit_width: 8,
1447            }))
1448            .build();
1449        assert!(result.is_err());
1450        if let Err(e) = result {
1451            assert_eq!(
1452                format!("{e}"),
1453                "Parquet error: Cannot annotate Integer { bit_width: 8, is_signed: true } from INT64 for field 'foo'"
1454            );
1455        }
1456
1457        // Test illegal inputs with converted type
1458        result = Type::primitive_type_builder("foo", PhysicalType::INT64)
1459            .with_repetition(Repetition::REPEATED)
1460            .with_converted_type(ConvertedType::BSON)
1461            .build();
1462        assert!(result.is_err());
1463        if let Err(e) = result {
1464            assert_eq!(
1465                format!("{e}"),
1466                "Parquet error: BSON cannot annotate field 'foo' because it is not a BYTE_ARRAY field"
1467            );
1468        }
1469
1470        result = Type::primitive_type_builder("foo", PhysicalType::INT96)
1471            .with_repetition(Repetition::REQUIRED)
1472            .with_converted_type(ConvertedType::DECIMAL)
1473            .with_precision(-1)
1474            .with_scale(-1)
1475            .build();
1476        assert!(result.is_err());
1477        if let Err(e) = result {
1478            assert_eq!(
1479                format!("{e}"),
1480                "Parquet error: DECIMAL can only annotate INT32, INT64, BYTE_ARRAY and FIXED_LEN_BYTE_ARRAY"
1481            );
1482        }
1483
1484        result = Type::primitive_type_builder("foo", PhysicalType::BYTE_ARRAY)
1485            .with_repetition(Repetition::REQUIRED)
1486            .with_logical_type(Some(LogicalType::Decimal {
1487                scale: 32,
1488                precision: 12,
1489            }))
1490            .with_precision(-1)
1491            .with_scale(-1)
1492            .build();
1493        assert!(result.is_err());
1494        if let Err(e) = result {
1495            assert_eq!(
1496                format!("{e}"),
1497                "Parquet error: DECIMAL logical type scale 32 must match self.scale -1 for field 'foo'"
1498            );
1499        }
1500
1501        result = Type::primitive_type_builder("foo", PhysicalType::BYTE_ARRAY)
1502            .with_repetition(Repetition::REQUIRED)
1503            .with_converted_type(ConvertedType::DECIMAL)
1504            .with_precision(-1)
1505            .with_scale(-1)
1506            .build();
1507        assert!(result.is_err());
1508        if let Err(e) = result {
1509            assert_eq!(
1510                format!("{e}"),
1511                "Parquet error: Invalid DECIMAL precision: -1"
1512            );
1513        }
1514
1515        result = Type::primitive_type_builder("foo", PhysicalType::BYTE_ARRAY)
1516            .with_repetition(Repetition::REQUIRED)
1517            .with_converted_type(ConvertedType::DECIMAL)
1518            .with_precision(0)
1519            .with_scale(-1)
1520            .build();
1521        assert!(result.is_err());
1522        if let Err(e) = result {
1523            assert_eq!(
1524                format!("{e}"),
1525                "Parquet error: Invalid DECIMAL precision: 0"
1526            );
1527        }
1528
1529        result = Type::primitive_type_builder("foo", PhysicalType::BYTE_ARRAY)
1530            .with_repetition(Repetition::REQUIRED)
1531            .with_converted_type(ConvertedType::DECIMAL)
1532            .with_precision(1)
1533            .with_scale(-1)
1534            .build();
1535        assert!(result.is_err());
1536        if let Err(e) = result {
1537            assert_eq!(format!("{e}"), "Parquet error: Invalid DECIMAL scale: -1");
1538        }
1539
1540        result = Type::primitive_type_builder("foo", PhysicalType::BYTE_ARRAY)
1541            .with_repetition(Repetition::REQUIRED)
1542            .with_converted_type(ConvertedType::DECIMAL)
1543            .with_precision(1)
1544            .with_scale(2)
1545            .build();
1546        assert!(result.is_err());
1547        if let Err(e) = result {
1548            assert_eq!(
1549                format!("{e}"),
1550                "Parquet error: Invalid DECIMAL: scale (2) cannot be greater than precision (1)"
1551            );
1552        }
1553
1554        // It is OK if precision == scale
1555        result = Type::primitive_type_builder("foo", PhysicalType::BYTE_ARRAY)
1556            .with_repetition(Repetition::REQUIRED)
1557            .with_converted_type(ConvertedType::DECIMAL)
1558            .with_precision(1)
1559            .with_scale(1)
1560            .build();
1561        assert!(result.is_ok());
1562
1563        result = Type::primitive_type_builder("foo", PhysicalType::INT32)
1564            .with_repetition(Repetition::REQUIRED)
1565            .with_converted_type(ConvertedType::DECIMAL)
1566            .with_precision(18)
1567            .with_scale(2)
1568            .build();
1569        assert!(result.is_err());
1570        if let Err(e) = result {
1571            assert_eq!(
1572                format!("{e}"),
1573                "Parquet error: Cannot represent INT32 as DECIMAL with precision 18"
1574            );
1575        }
1576
1577        result = Type::primitive_type_builder("foo", PhysicalType::INT64)
1578            .with_repetition(Repetition::REQUIRED)
1579            .with_converted_type(ConvertedType::DECIMAL)
1580            .with_precision(32)
1581            .with_scale(2)
1582            .build();
1583        assert!(result.is_err());
1584        if let Err(e) = result {
1585            assert_eq!(
1586                format!("{e}"),
1587                "Parquet error: Cannot represent INT64 as DECIMAL with precision 32"
1588            );
1589        }
1590
1591        result = Type::primitive_type_builder("foo", PhysicalType::FIXED_LEN_BYTE_ARRAY)
1592            .with_repetition(Repetition::REQUIRED)
1593            .with_converted_type(ConvertedType::DECIMAL)
1594            .with_length(5)
1595            .with_precision(12)
1596            .with_scale(2)
1597            .build();
1598        assert!(result.is_err());
1599        if let Err(e) = result {
1600            assert_eq!(
1601                format!("{e}"),
1602                "Parquet error: Cannot represent FIXED_LEN_BYTE_ARRAY as DECIMAL with length 5 and precision 12. The max precision can only be 11"
1603            );
1604        }
1605
1606        result = Type::primitive_type_builder("foo", PhysicalType::INT64)
1607            .with_repetition(Repetition::REQUIRED)
1608            .with_converted_type(ConvertedType::UINT_8)
1609            .build();
1610        assert!(result.is_err());
1611        if let Err(e) = result {
1612            assert_eq!(
1613                format!("{e}"),
1614                "Parquet error: UINT_8 cannot annotate field 'foo' because it is not a INT32 field"
1615            );
1616        }
1617
1618        result = Type::primitive_type_builder("foo", PhysicalType::INT32)
1619            .with_repetition(Repetition::REQUIRED)
1620            .with_converted_type(ConvertedType::TIME_MICROS)
1621            .build();
1622        assert!(result.is_err());
1623        if let Err(e) = result {
1624            assert_eq!(
1625                format!("{e}"),
1626                "Parquet error: TIME_MICROS cannot annotate field 'foo' because it is not a INT64 field"
1627            );
1628        }
1629
1630        result = Type::primitive_type_builder("foo", PhysicalType::BYTE_ARRAY)
1631            .with_repetition(Repetition::REQUIRED)
1632            .with_converted_type(ConvertedType::INTERVAL)
1633            .build();
1634        assert!(result.is_err());
1635        if let Err(e) = result {
1636            assert_eq!(
1637                format!("{e}"),
1638                "Parquet error: INTERVAL cannot annotate field 'foo' because it is not a FIXED_LEN_BYTE_ARRAY(12) field"
1639            );
1640        }
1641
1642        result = Type::primitive_type_builder("foo", PhysicalType::FIXED_LEN_BYTE_ARRAY)
1643            .with_repetition(Repetition::REQUIRED)
1644            .with_converted_type(ConvertedType::INTERVAL)
1645            .with_length(1)
1646            .build();
1647        assert!(result.is_err());
1648        if let Err(e) = result {
1649            assert_eq!(
1650                format!("{e}"),
1651                "Parquet error: INTERVAL cannot annotate field 'foo' because it is not a FIXED_LEN_BYTE_ARRAY(12) field"
1652            );
1653        }
1654
1655        result = Type::primitive_type_builder("foo", PhysicalType::INT32)
1656            .with_repetition(Repetition::REQUIRED)
1657            .with_converted_type(ConvertedType::ENUM)
1658            .build();
1659        assert!(result.is_err());
1660        if let Err(e) = result {
1661            assert_eq!(
1662                format!("{e}"),
1663                "Parquet error: ENUM cannot annotate field 'foo' because it is not a BYTE_ARRAY field"
1664            );
1665        }
1666
1667        result = Type::primitive_type_builder("foo", PhysicalType::INT32)
1668            .with_repetition(Repetition::REQUIRED)
1669            .with_converted_type(ConvertedType::MAP)
1670            .build();
1671        assert!(result.is_err());
1672        if let Err(e) = result {
1673            assert_eq!(
1674                format!("{e}"),
1675                "Parquet error: MAP cannot be applied to primitive field 'foo'"
1676            );
1677        }
1678
1679        result = Type::primitive_type_builder("foo", PhysicalType::FIXED_LEN_BYTE_ARRAY)
1680            .with_repetition(Repetition::REQUIRED)
1681            .with_converted_type(ConvertedType::DECIMAL)
1682            .with_length(-1)
1683            .build();
1684        assert!(result.is_err());
1685        if let Err(e) = result {
1686            assert_eq!(
1687                format!("{e}"),
1688                "Parquet error: Invalid FIXED_LEN_BYTE_ARRAY length: -1 for field 'foo'"
1689            );
1690        }
1691
1692        result = Type::primitive_type_builder("foo", PhysicalType::FIXED_LEN_BYTE_ARRAY)
1693            .with_repetition(Repetition::REQUIRED)
1694            .with_logical_type(Some(LogicalType::Float16))
1695            .with_length(2)
1696            .build();
1697        assert!(result.is_ok());
1698
1699        // Can't be other than FIXED_LEN_BYTE_ARRAY for physical type
1700        result = Type::primitive_type_builder("foo", PhysicalType::FLOAT)
1701            .with_repetition(Repetition::REQUIRED)
1702            .with_logical_type(Some(LogicalType::Float16))
1703            .with_length(2)
1704            .build();
1705        assert!(result.is_err());
1706        if let Err(e) = result {
1707            assert_eq!(
1708                format!("{e}"),
1709                "Parquet error: Cannot annotate Float16 from FLOAT for field 'foo'"
1710            );
1711        }
1712
1713        // Must have length 2
1714        result = Type::primitive_type_builder("foo", PhysicalType::FIXED_LEN_BYTE_ARRAY)
1715            .with_repetition(Repetition::REQUIRED)
1716            .with_logical_type(Some(LogicalType::Float16))
1717            .with_length(4)
1718            .build();
1719        assert!(result.is_err());
1720        if let Err(e) = result {
1721            assert_eq!(
1722                format!("{e}"),
1723                "Parquet error: FLOAT16 cannot annotate field 'foo' because it is not a FIXED_LEN_BYTE_ARRAY(2) field"
1724            );
1725        }
1726
1727        // Must have length 16
1728        result = Type::primitive_type_builder("foo", PhysicalType::FIXED_LEN_BYTE_ARRAY)
1729            .with_repetition(Repetition::REQUIRED)
1730            .with_logical_type(Some(LogicalType::Uuid))
1731            .with_length(15)
1732            .build();
1733        assert!(result.is_err());
1734        if let Err(e) = result {
1735            assert_eq!(
1736                format!("{e}"),
1737                "Parquet error: UUID cannot annotate field 'foo' because it is not a FIXED_LEN_BYTE_ARRAY(16) field"
1738            );
1739        }
1740    }
1741
1742    #[test]
1743    fn test_group_type() {
1744        let f1 = Type::primitive_type_builder("f1", PhysicalType::INT32)
1745            .with_converted_type(ConvertedType::INT_32)
1746            .with_id(Some(0))
1747            .build();
1748        assert!(f1.is_ok());
1749        let f2 = Type::primitive_type_builder("f2", PhysicalType::BYTE_ARRAY)
1750            .with_converted_type(ConvertedType::UTF8)
1751            .with_id(Some(1))
1752            .build();
1753        assert!(f2.is_ok());
1754
1755        let fields = vec![Arc::new(f1.unwrap()), Arc::new(f2.unwrap())];
1756
1757        let result = Type::group_type_builder("foo")
1758            .with_repetition(Repetition::REPEATED)
1759            .with_logical_type(Some(LogicalType::List))
1760            .with_fields(fields)
1761            .with_id(Some(1))
1762            .build();
1763        assert!(result.is_ok());
1764
1765        let tp = result.unwrap();
1766        let basic_info = tp.get_basic_info();
1767        assert!(tp.is_group());
1768        assert!(!tp.is_primitive());
1769        assert_eq!(basic_info.repetition(), Repetition::REPEATED);
1770        assert_eq!(basic_info.logical_type(), Some(LogicalType::List));
1771        assert_eq!(basic_info.converted_type(), ConvertedType::LIST);
1772        assert_eq!(basic_info.id(), 1);
1773        assert_eq!(tp.get_fields().len(), 2);
1774        assert_eq!(tp.get_fields()[0].name(), "f1");
1775        assert_eq!(tp.get_fields()[1].name(), "f2");
1776    }
1777
1778    #[test]
1779    fn test_column_descriptor() {
1780        let result = test_column_descriptor_helper();
1781        assert!(
1782            result.is_ok(),
1783            "Expected result to be OK but got err:\n {}",
1784            result.unwrap_err()
1785        );
1786    }
1787
1788    fn test_column_descriptor_helper() -> Result<()> {
1789        let tp = Type::primitive_type_builder("name", PhysicalType::BYTE_ARRAY)
1790            .with_converted_type(ConvertedType::UTF8)
1791            .build()?;
1792
1793        let descr = ColumnDescriptor::new(Arc::new(tp), 4, 1, ColumnPath::from("name"));
1794
1795        assert_eq!(descr.path(), &ColumnPath::from("name"));
1796        assert_eq!(descr.converted_type(), ConvertedType::UTF8);
1797        assert_eq!(descr.physical_type(), PhysicalType::BYTE_ARRAY);
1798        assert_eq!(descr.max_def_level(), 4);
1799        assert_eq!(descr.max_rep_level(), 1);
1800        assert_eq!(descr.name(), "name");
1801        assert_eq!(descr.type_length(), -1);
1802        assert_eq!(descr.type_precision(), -1);
1803        assert_eq!(descr.type_scale(), -1);
1804
1805        Ok(())
1806    }
1807
1808    #[test]
1809    fn test_schema_descriptor() {
1810        let result = test_schema_descriptor_helper();
1811        assert!(
1812            result.is_ok(),
1813            "Expected result to be OK but got err:\n {}",
1814            result.unwrap_err()
1815        );
1816    }
1817
1818    // A helper fn to avoid handling the results from type creation
1819    fn test_schema_descriptor_helper() -> Result<()> {
1820        let mut fields = vec![];
1821
1822        let inta = Type::primitive_type_builder("a", PhysicalType::INT32)
1823            .with_repetition(Repetition::REQUIRED)
1824            .with_converted_type(ConvertedType::INT_32)
1825            .build()?;
1826        fields.push(Arc::new(inta));
1827        let intb = Type::primitive_type_builder("b", PhysicalType::INT64)
1828            .with_converted_type(ConvertedType::INT_64)
1829            .build()?;
1830        fields.push(Arc::new(intb));
1831        let intc = Type::primitive_type_builder("c", PhysicalType::BYTE_ARRAY)
1832            .with_repetition(Repetition::REPEATED)
1833            .with_converted_type(ConvertedType::UTF8)
1834            .build()?;
1835        fields.push(Arc::new(intc));
1836
1837        // 3-level list encoding
1838        let item1 = Type::primitive_type_builder("item1", PhysicalType::INT64)
1839            .with_repetition(Repetition::REQUIRED)
1840            .with_converted_type(ConvertedType::INT_64)
1841            .build()?;
1842        let item2 = Type::primitive_type_builder("item2", PhysicalType::BOOLEAN).build()?;
1843        let item3 = Type::primitive_type_builder("item3", PhysicalType::INT32)
1844            .with_repetition(Repetition::REPEATED)
1845            .with_converted_type(ConvertedType::INT_32)
1846            .build()?;
1847        let list = Type::group_type_builder("records")
1848            .with_repetition(Repetition::REPEATED)
1849            .with_converted_type(ConvertedType::LIST)
1850            .with_fields(vec![Arc::new(item1), Arc::new(item2), Arc::new(item3)])
1851            .build()?;
1852        let bag = Type::group_type_builder("bag")
1853            .with_repetition(Repetition::OPTIONAL)
1854            .with_fields(vec![Arc::new(list)])
1855            .build()?;
1856        fields.push(Arc::new(bag));
1857
1858        let schema = Type::group_type_builder("schema")
1859            .with_repetition(Repetition::REPEATED)
1860            .with_fields(fields)
1861            .build()?;
1862        let descr = SchemaDescriptor::new(Arc::new(schema));
1863
1864        let nleaves = 6;
1865        assert_eq!(descr.num_columns(), nleaves);
1866
1867        //                             mdef mrep
1868        // required int32 a            0    0
1869        // optional int64 b            1    0
1870        // repeated byte_array c       1    1
1871        // optional group bag          1    0
1872        //   repeated group records    2    1
1873        //     required int64 item1    2    1
1874        //     optional boolean item2  3    1
1875        //     repeated int32 item3    3    2
1876        let ex_max_def_levels = [0, 1, 1, 2, 3, 3];
1877        let ex_max_rep_levels = [0, 0, 1, 1, 1, 2];
1878
1879        for i in 0..nleaves {
1880            let col = descr.column(i);
1881            assert_eq!(col.max_def_level(), ex_max_def_levels[i], "{i}");
1882            assert_eq!(col.max_rep_level(), ex_max_rep_levels[i], "{i}");
1883        }
1884
1885        assert_eq!(descr.column(0).path().string(), "a");
1886        assert_eq!(descr.column(1).path().string(), "b");
1887        assert_eq!(descr.column(2).path().string(), "c");
1888        assert_eq!(descr.column(3).path().string(), "bag.records.item1");
1889        assert_eq!(descr.column(4).path().string(), "bag.records.item2");
1890        assert_eq!(descr.column(5).path().string(), "bag.records.item3");
1891
1892        assert_eq!(descr.get_column_root(0).name(), "a");
1893        assert_eq!(descr.get_column_root(3).name(), "bag");
1894        assert_eq!(descr.get_column_root(4).name(), "bag");
1895        assert_eq!(descr.get_column_root(5).name(), "bag");
1896
1897        Ok(())
1898    }
1899
1900    #[test]
1901    fn test_schema_build_tree_def_rep_levels() {
1902        let message_type = "
1903    message spark_schema {
1904      REQUIRED INT32 a;
1905      OPTIONAL group b {
1906        OPTIONAL INT32 _1;
1907        OPTIONAL INT32 _2;
1908      }
1909      OPTIONAL group c (LIST) {
1910        REPEATED group list {
1911          OPTIONAL INT32 element;
1912        }
1913      }
1914    }
1915    ";
1916        let schema = parse_message_type(message_type).expect("should parse schema");
1917        let descr = SchemaDescriptor::new(Arc::new(schema));
1918        // required int32 a
1919        assert_eq!(descr.column(0).max_def_level(), 0);
1920        assert_eq!(descr.column(0).max_rep_level(), 0);
1921        // optional int32 b._1
1922        assert_eq!(descr.column(1).max_def_level(), 2);
1923        assert_eq!(descr.column(1).max_rep_level(), 0);
1924        // optional int32 b._2
1925        assert_eq!(descr.column(2).max_def_level(), 2);
1926        assert_eq!(descr.column(2).max_rep_level(), 0);
1927        // repeated optional int32 c.list.element
1928        assert_eq!(descr.column(3).max_def_level(), 3);
1929        assert_eq!(descr.column(3).max_rep_level(), 1);
1930    }
1931
1932    #[test]
1933    #[should_panic(expected = "Cannot call get_physical_type() on a non-primitive type")]
1934    fn test_get_physical_type_panic() {
1935        let list = Type::group_type_builder("records")
1936            .with_repetition(Repetition::REPEATED)
1937            .build()
1938            .unwrap();
1939        list.get_physical_type();
1940    }
1941
1942    #[test]
1943    fn test_get_physical_type_primitive() {
1944        let f = Type::primitive_type_builder("f", PhysicalType::INT64)
1945            .build()
1946            .unwrap();
1947        assert_eq!(f.get_physical_type(), PhysicalType::INT64);
1948
1949        let f = Type::primitive_type_builder("f", PhysicalType::BYTE_ARRAY)
1950            .build()
1951            .unwrap();
1952        assert_eq!(f.get_physical_type(), PhysicalType::BYTE_ARRAY);
1953    }
1954
1955    #[test]
1956    fn test_check_contains_primitive_primitive() {
1957        // OK
1958        let f1 = Type::primitive_type_builder("f", PhysicalType::INT32)
1959            .build()
1960            .unwrap();
1961        let f2 = Type::primitive_type_builder("f", PhysicalType::INT32)
1962            .build()
1963            .unwrap();
1964        assert!(f1.check_contains(&f2));
1965
1966        // OK: different logical type does not affect check_contains
1967        let f1 = Type::primitive_type_builder("f", PhysicalType::INT32)
1968            .with_converted_type(ConvertedType::UINT_8)
1969            .build()
1970            .unwrap();
1971        let f2 = Type::primitive_type_builder("f", PhysicalType::INT32)
1972            .with_converted_type(ConvertedType::UINT_16)
1973            .build()
1974            .unwrap();
1975        assert!(f1.check_contains(&f2));
1976
1977        // KO: different name
1978        let f1 = Type::primitive_type_builder("f1", PhysicalType::INT32)
1979            .build()
1980            .unwrap();
1981        let f2 = Type::primitive_type_builder("f2", PhysicalType::INT32)
1982            .build()
1983            .unwrap();
1984        assert!(!f1.check_contains(&f2));
1985
1986        // KO: different type
1987        let f1 = Type::primitive_type_builder("f", PhysicalType::INT32)
1988            .build()
1989            .unwrap();
1990        let f2 = Type::primitive_type_builder("f", PhysicalType::INT64)
1991            .build()
1992            .unwrap();
1993        assert!(!f1.check_contains(&f2));
1994
1995        // KO: different repetition
1996        let f1 = Type::primitive_type_builder("f", PhysicalType::INT32)
1997            .with_repetition(Repetition::REQUIRED)
1998            .build()
1999            .unwrap();
2000        let f2 = Type::primitive_type_builder("f", PhysicalType::INT32)
2001            .with_repetition(Repetition::OPTIONAL)
2002            .build()
2003            .unwrap();
2004        assert!(!f1.check_contains(&f2));
2005    }
2006
2007    // function to create a new group type for testing
2008    fn test_new_group_type(name: &str, repetition: Repetition, types: Vec<Type>) -> Type {
2009        Type::group_type_builder(name)
2010            .with_repetition(repetition)
2011            .with_fields(types.into_iter().map(Arc::new).collect())
2012            .build()
2013            .unwrap()
2014    }
2015
2016    #[test]
2017    fn test_check_contains_group_group() {
2018        // OK: should match okay with empty fields
2019        let f1 = Type::group_type_builder("f").build().unwrap();
2020        let f2 = Type::group_type_builder("f").build().unwrap();
2021        assert!(f1.check_contains(&f2));
2022        assert!(!f1.is_optional());
2023
2024        // OK: fields match
2025        let f1 = test_new_group_type(
2026            "f",
2027            Repetition::REPEATED,
2028            vec![
2029                Type::primitive_type_builder("f1", PhysicalType::INT32)
2030                    .build()
2031                    .unwrap(),
2032                Type::primitive_type_builder("f2", PhysicalType::INT64)
2033                    .build()
2034                    .unwrap(),
2035            ],
2036        );
2037        let f2 = test_new_group_type(
2038            "f",
2039            Repetition::REPEATED,
2040            vec![
2041                Type::primitive_type_builder("f1", PhysicalType::INT32)
2042                    .build()
2043                    .unwrap(),
2044                Type::primitive_type_builder("f2", PhysicalType::INT64)
2045                    .build()
2046                    .unwrap(),
2047            ],
2048        );
2049        assert!(f1.check_contains(&f2));
2050
2051        // OK: subset of fields
2052        let f1 = test_new_group_type(
2053            "f",
2054            Repetition::REPEATED,
2055            vec![
2056                Type::primitive_type_builder("f1", PhysicalType::INT32)
2057                    .build()
2058                    .unwrap(),
2059                Type::primitive_type_builder("f2", PhysicalType::INT64)
2060                    .build()
2061                    .unwrap(),
2062            ],
2063        );
2064        let f2 = test_new_group_type(
2065            "f",
2066            Repetition::REPEATED,
2067            vec![Type::primitive_type_builder("f2", PhysicalType::INT64)
2068                .build()
2069                .unwrap()],
2070        );
2071        assert!(f1.check_contains(&f2));
2072
2073        // KO: different name
2074        let f1 = Type::group_type_builder("f1").build().unwrap();
2075        let f2 = Type::group_type_builder("f2").build().unwrap();
2076        assert!(!f1.check_contains(&f2));
2077
2078        // KO: different repetition
2079        let f1 = Type::group_type_builder("f")
2080            .with_repetition(Repetition::OPTIONAL)
2081            .build()
2082            .unwrap();
2083        let f2 = Type::group_type_builder("f")
2084            .with_repetition(Repetition::REPEATED)
2085            .build()
2086            .unwrap();
2087        assert!(!f1.check_contains(&f2));
2088
2089        // KO: different fields
2090        let f1 = test_new_group_type(
2091            "f",
2092            Repetition::REPEATED,
2093            vec![
2094                Type::primitive_type_builder("f1", PhysicalType::INT32)
2095                    .build()
2096                    .unwrap(),
2097                Type::primitive_type_builder("f2", PhysicalType::INT64)
2098                    .build()
2099                    .unwrap(),
2100            ],
2101        );
2102        let f2 = test_new_group_type(
2103            "f",
2104            Repetition::REPEATED,
2105            vec![
2106                Type::primitive_type_builder("f1", PhysicalType::INT32)
2107                    .build()
2108                    .unwrap(),
2109                Type::primitive_type_builder("f2", PhysicalType::BOOLEAN)
2110                    .build()
2111                    .unwrap(),
2112            ],
2113        );
2114        assert!(!f1.check_contains(&f2));
2115
2116        // KO: different fields
2117        let f1 = test_new_group_type(
2118            "f",
2119            Repetition::REPEATED,
2120            vec![
2121                Type::primitive_type_builder("f1", PhysicalType::INT32)
2122                    .build()
2123                    .unwrap(),
2124                Type::primitive_type_builder("f2", PhysicalType::INT64)
2125                    .build()
2126                    .unwrap(),
2127            ],
2128        );
2129        let f2 = test_new_group_type(
2130            "f",
2131            Repetition::REPEATED,
2132            vec![Type::primitive_type_builder("f3", PhysicalType::INT32)
2133                .build()
2134                .unwrap()],
2135        );
2136        assert!(!f1.check_contains(&f2));
2137    }
2138
2139    #[test]
2140    fn test_check_contains_group_primitive() {
2141        // KO: should not match
2142        let f1 = Type::group_type_builder("f").build().unwrap();
2143        let f2 = Type::primitive_type_builder("f", PhysicalType::INT64)
2144            .build()
2145            .unwrap();
2146        assert!(!f1.check_contains(&f2));
2147        assert!(!f2.check_contains(&f1));
2148
2149        // KO: should not match when primitive field is part of group type
2150        let f1 = test_new_group_type(
2151            "f",
2152            Repetition::REPEATED,
2153            vec![Type::primitive_type_builder("f1", PhysicalType::INT32)
2154                .build()
2155                .unwrap()],
2156        );
2157        let f2 = Type::primitive_type_builder("f1", PhysicalType::INT32)
2158            .build()
2159            .unwrap();
2160        assert!(!f1.check_contains(&f2));
2161        assert!(!f2.check_contains(&f1));
2162
2163        // OK: match nested types
2164        let f1 = test_new_group_type(
2165            "a",
2166            Repetition::REPEATED,
2167            vec![
2168                test_new_group_type(
2169                    "b",
2170                    Repetition::REPEATED,
2171                    vec![Type::primitive_type_builder("c", PhysicalType::INT32)
2172                        .build()
2173                        .unwrap()],
2174                ),
2175                Type::primitive_type_builder("d", PhysicalType::INT64)
2176                    .build()
2177                    .unwrap(),
2178                Type::primitive_type_builder("e", PhysicalType::BOOLEAN)
2179                    .build()
2180                    .unwrap(),
2181            ],
2182        );
2183        let f2 = test_new_group_type(
2184            "a",
2185            Repetition::REPEATED,
2186            vec![test_new_group_type(
2187                "b",
2188                Repetition::REPEATED,
2189                vec![Type::primitive_type_builder("c", PhysicalType::INT32)
2190                    .build()
2191                    .unwrap()],
2192            )],
2193        );
2194        assert!(f1.check_contains(&f2)); // should match
2195        assert!(!f2.check_contains(&f1)); // should fail
2196    }
2197
2198    #[test]
2199    fn test_schema_type_thrift_conversion_err() {
2200        let schema = Type::primitive_type_builder("col", PhysicalType::INT32)
2201            .build()
2202            .unwrap();
2203        let thrift_schema = to_thrift(&schema);
2204        assert!(thrift_schema.is_err());
2205        if let Err(e) = thrift_schema {
2206            assert_eq!(
2207                format!("{e}"),
2208                "Parquet error: Root schema must be Group type"
2209            );
2210        }
2211    }
2212
2213    #[test]
2214    fn test_schema_type_thrift_conversion() {
2215        let message_type = "
2216    message conversions {
2217      REQUIRED INT64 id;
2218      OPTIONAL FIXED_LEN_BYTE_ARRAY (2) f16 (FLOAT16);
2219      OPTIONAL group int_array_Array (LIST) {
2220        REPEATED group list {
2221          OPTIONAL group element (LIST) {
2222            REPEATED group list {
2223              OPTIONAL INT32 element;
2224            }
2225          }
2226        }
2227      }
2228      OPTIONAL group int_map (MAP) {
2229        REPEATED group map (MAP_KEY_VALUE) {
2230          REQUIRED BYTE_ARRAY key (UTF8);
2231          OPTIONAL INT32 value;
2232        }
2233      }
2234      OPTIONAL group int_Map_Array (LIST) {
2235        REPEATED group list {
2236          OPTIONAL group g (MAP) {
2237            REPEATED group map (MAP_KEY_VALUE) {
2238              REQUIRED BYTE_ARRAY key (UTF8);
2239              OPTIONAL group value {
2240                OPTIONAL group H {
2241                  OPTIONAL group i (LIST) {
2242                    REPEATED group list {
2243                      OPTIONAL DOUBLE element;
2244                    }
2245                  }
2246                }
2247              }
2248            }
2249          }
2250        }
2251      }
2252      OPTIONAL group nested_struct {
2253        OPTIONAL INT32 A;
2254        OPTIONAL group b (LIST) {
2255          REPEATED group list {
2256            REQUIRED FIXED_LEN_BYTE_ARRAY (16) element;
2257          }
2258        }
2259      }
2260    }
2261    ";
2262        let expected_schema = parse_message_type(message_type).unwrap();
2263        let thrift_schema = to_thrift(&expected_schema).unwrap();
2264        let result_schema = from_thrift(&thrift_schema).unwrap();
2265        assert_eq!(result_schema, Arc::new(expected_schema));
2266    }
2267
2268    #[test]
2269    fn test_schema_type_thrift_conversion_decimal() {
2270        let message_type = "
2271    message decimals {
2272      OPTIONAL INT32 field0;
2273      OPTIONAL INT64 field1 (DECIMAL (18, 2));
2274      OPTIONAL FIXED_LEN_BYTE_ARRAY (16) field2 (DECIMAL (38, 18));
2275      OPTIONAL BYTE_ARRAY field3 (DECIMAL (9));
2276    }
2277    ";
2278        let expected_schema = parse_message_type(message_type).unwrap();
2279        let thrift_schema = to_thrift(&expected_schema).unwrap();
2280        let result_schema = from_thrift(&thrift_schema).unwrap();
2281        assert_eq!(result_schema, Arc::new(expected_schema));
2282    }
2283
2284    // Tests schema conversion from thrift, when num_children is set to Some(0) for a
2285    // primitive type.
2286    #[test]
2287    fn test_schema_from_thrift_with_num_children_set() {
2288        // schema definition written by parquet-cpp version 1.3.2-SNAPSHOT
2289        let message_type = "
2290    message schema {
2291      OPTIONAL BYTE_ARRAY id (UTF8);
2292      OPTIONAL BYTE_ARRAY name (UTF8);
2293      OPTIONAL BYTE_ARRAY message (UTF8);
2294      OPTIONAL INT32 type (UINT_8);
2295      OPTIONAL INT64 author_time (TIMESTAMP_MILLIS);
2296      OPTIONAL INT64 __index_level_0__;
2297    }
2298    ";
2299
2300        let expected_schema = parse_message_type(message_type).unwrap();
2301        let mut thrift_schema = to_thrift(&expected_schema).unwrap();
2302        // Change all of None to Some(0)
2303        for elem in &mut thrift_schema[..] {
2304            if elem.num_children.is_none() {
2305                elem.num_children = Some(0);
2306            }
2307        }
2308
2309        let result_schema = from_thrift(&thrift_schema).unwrap();
2310        assert_eq!(result_schema, Arc::new(expected_schema));
2311    }
2312
2313    // Sometimes parquet-cpp sets repetition level for the root node, which is against
2314    // the format definition, but we need to handle it by setting it back to None.
2315    #[test]
2316    fn test_schema_from_thrift_root_has_repetition() {
2317        // schema definition written by parquet-cpp version 1.3.2-SNAPSHOT
2318        let message_type = "
2319    message schema {
2320      OPTIONAL BYTE_ARRAY a (UTF8);
2321      OPTIONAL INT32 b (UINT_8);
2322    }
2323    ";
2324
2325        let expected_schema = parse_message_type(message_type).unwrap();
2326        let mut thrift_schema = to_thrift(&expected_schema).unwrap();
2327        thrift_schema[0].repetition_type = Some(Repetition::REQUIRED.into());
2328
2329        let result_schema = from_thrift(&thrift_schema).unwrap();
2330        assert_eq!(result_schema, Arc::new(expected_schema));
2331    }
2332
2333    #[test]
2334    fn test_schema_from_thrift_group_has_no_child() {
2335        let message_type = "message schema {}";
2336
2337        let expected_schema = parse_message_type(message_type).unwrap();
2338        let mut thrift_schema = to_thrift(&expected_schema).unwrap();
2339        thrift_schema[0].repetition_type = Some(Repetition::REQUIRED.into());
2340
2341        let result_schema = from_thrift(&thrift_schema).unwrap();
2342        assert_eq!(result_schema, Arc::new(expected_schema));
2343    }
2344}