Skip to main content

parquet/arrow/schema/
primitive.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::basic::{
19    ConvertedType, IntType, LogicalType, TimeUnit as ParquetTimeUnit, Type as PhysicalType,
20};
21use crate::errors::{ParquetError, Result};
22use crate::schema::types::{BasicTypeInfo, Type};
23use arrow_schema::{DECIMAL128_MAX_PRECISION, DataType, IntervalUnit, TimeUnit};
24
25/// Converts [`Type`] to [`DataType`] with an optional `arrow_type_hint`
26/// provided by the arrow schema
27///
28/// Note: the values embedded in the schema are advisory,
29pub fn convert_primitive(
30    parquet_type: &Type,
31    arrow_type_hint: Option<DataType>,
32) -> Result<DataType> {
33    let physical_type = from_parquet(parquet_type)?;
34    Ok(match arrow_type_hint {
35        Some(hint) => apply_hint(physical_type, hint),
36        None => physical_type,
37    })
38}
39
40/// Uses an type hint from the embedded arrow schema to aid in faithfully
41/// reproducing the data as it was written into parquet
42fn apply_hint(parquet: DataType, hint: DataType) -> DataType {
43    match (&parquet, &hint) {
44        // Not all time units can be represented as LogicalType / ConvertedType
45        (DataType::Int32 | DataType::Int64, DataType::Timestamp(_, _)) => hint,
46        (DataType::Int32, DataType::Time32(_)) => hint,
47        (DataType::Int64, DataType::Time64(_)) => hint,
48        (DataType::Int64, DataType::Duration(_)) => hint,
49
50        // Date64 doesn't have a corresponding LogicalType / ConvertedType
51        (DataType::Int64, DataType::Date64) => hint,
52
53        // Coerce Date32 back to Date64 (#1666)
54        (DataType::Date32, DataType::Date64) => hint,
55
56        // Timestamps of the same resolution can be converted to a a different timezone.
57        (DataType::Timestamp(p, _), DataType::Timestamp(h, Some(_))) if p == h => hint,
58
59        // INT96 default to Timestamp(TimeUnit::Nanosecond, None) (see from_parquet below).
60        // Allow different resolutions to support larger date ranges.
61        (
62            DataType::Timestamp(TimeUnit::Nanosecond, None),
63            DataType::Timestamp(TimeUnit::Second, _),
64        ) => hint,
65        (
66            DataType::Timestamp(TimeUnit::Nanosecond, None),
67            DataType::Timestamp(TimeUnit::Millisecond, _),
68        ) => hint,
69        (
70            DataType::Timestamp(TimeUnit::Nanosecond, None),
71            DataType::Timestamp(TimeUnit::Microsecond, _),
72        ) => hint,
73
74        // Determine offset size
75        (DataType::Utf8, DataType::LargeUtf8) => hint,
76        (DataType::Binary, DataType::LargeBinary) => hint,
77
78        // Read as Utf8
79        (DataType::Binary, DataType::Utf8) => hint,
80        (DataType::Binary, DataType::LargeUtf8) => hint,
81        (DataType::Binary, DataType::Utf8View) => hint,
82
83        // Determine view type
84        (DataType::Utf8, DataType::Utf8View) => hint,
85        (DataType::Binary, DataType::BinaryView) => hint,
86
87        // Determine interval time unit (#1666)
88        (DataType::Interval(_), DataType::Interval(_)) => hint,
89
90        // Promote to Decimal256 or narrow to Decimal32 or Decimal64
91        (DataType::Decimal128(_, _), DataType::Decimal32(_, _)) => hint,
92        (DataType::Decimal128(_, _), DataType::Decimal64(_, _)) => hint,
93        (DataType::Decimal128(_, _), DataType::Decimal256(_, _)) => hint,
94
95        // Potentially preserve dictionary encoding
96        (_, DataType::Dictionary(_, value)) => {
97            // Apply hint to inner type
98            let hinted = apply_hint(parquet, value.as_ref().clone());
99
100            // If matches dictionary value - preserve dictionary
101            // otherwise use hinted inner type
102            match &hinted == value.as_ref() {
103                true => hint,
104                false => hinted,
105            }
106        }
107        _ => parquet,
108    }
109}
110
111fn from_parquet(parquet_type: &Type) -> Result<DataType> {
112    match parquet_type {
113        Type::PrimitiveType {
114            physical_type,
115            basic_info,
116            type_length,
117            scale,
118            precision,
119            ..
120        } => match basic_info.logical_type_ref() {
121            // Any physical type can have the UNKNOWN logical type annotation. Check for that first.
122            // https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#unknown-always-null
123            Some(&LogicalType::Unknown) => Ok(DataType::Null),
124            _ => match physical_type {
125                PhysicalType::BOOLEAN => Ok(DataType::Boolean),
126                PhysicalType::INT32 => from_int32(basic_info, *scale, *precision),
127                PhysicalType::INT64 => from_int64(basic_info, *scale, *precision),
128                PhysicalType::INT96 => Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)),
129                PhysicalType::FLOAT => Ok(DataType::Float32),
130                PhysicalType::DOUBLE => Ok(DataType::Float64),
131                PhysicalType::BYTE_ARRAY => from_byte_array(basic_info, *precision, *scale),
132                PhysicalType::FIXED_LEN_BYTE_ARRAY => {
133                    from_fixed_len_byte_array(basic_info, *scale, *precision, *type_length)
134                }
135            },
136        },
137        Type::GroupType { .. } => unreachable!(),
138    }
139}
140
141fn decimal_type(scale: i32, precision: i32) -> Result<DataType> {
142    if precision <= DECIMAL128_MAX_PRECISION as i32 {
143        decimal_128_type(scale, precision)
144    } else {
145        decimal_256_type(scale, precision)
146    }
147}
148
149fn decimal_128_type(scale: i32, precision: i32) -> Result<DataType> {
150    let scale = scale
151        .try_into()
152        .map_err(|_| arrow_err!("scale cannot be negative: {}", scale))?;
153
154    let precision = precision
155        .try_into()
156        .map_err(|_| arrow_err!("precision cannot be negative: {}", precision))?;
157
158    Ok(DataType::Decimal128(precision, scale))
159}
160
161fn decimal_256_type(scale: i32, precision: i32) -> Result<DataType> {
162    let scale = scale
163        .try_into()
164        .map_err(|_| arrow_err!("scale cannot be negative: {}", scale))?;
165
166    let precision = precision
167        .try_into()
168        .map_err(|_| arrow_err!("precision cannot be negative: {}", precision))?;
169
170    Ok(DataType::Decimal256(precision, scale))
171}
172
173#[allow(clippy::manual_range_contains)]
174fn check_decimal_length(type_length: i32) -> Result<()> {
175    if type_length < 1 || type_length > 32 {
176        return Err(ParquetError::General(format!(
177            "DECIMAL must be a Fixed Length Byte Array with length 1 to 32, got {type_length}"
178        )));
179    }
180    Ok(())
181}
182
183fn from_int32(info: &BasicTypeInfo, scale: i32, precision: i32) -> Result<DataType> {
184    match (info.logical_type_ref(), info.converted_type()) {
185        (None, ConvertedType::NONE) => Ok(DataType::Int32),
186        (Some(ref t @ LogicalType::Integer(int)), _) => match (int.bit_width, int.is_signed) {
187            (8, true) => Ok(DataType::Int8),
188            (16, true) => Ok(DataType::Int16),
189            (32, true) => Ok(DataType::Int32),
190            (8, false) => Ok(DataType::UInt8),
191            (16, false) => Ok(DataType::UInt16),
192            (32, false) => Ok(DataType::UInt32),
193            _ => Err(arrow_err!("Cannot create INT32 physical type from {:?}", t)),
194        },
195        (Some(LogicalType::Decimal(decimal)), _) => {
196            decimal_128_type(decimal.scale, decimal.precision)
197        }
198        (Some(LogicalType::Date), _) => Ok(DataType::Date32),
199        (Some(LogicalType::Time(time)), _) => match time.unit {
200            ParquetTimeUnit::MILLIS => Ok(DataType::Time32(TimeUnit::Millisecond)),
201            _ => Err(arrow_err!(
202                "Cannot create INT32 physical type from {:?}",
203                time.unit
204            )),
205        },
206        (None, ConvertedType::UINT_8) => Ok(DataType::UInt8),
207        (None, ConvertedType::UINT_16) => Ok(DataType::UInt16),
208        (None, ConvertedType::UINT_32) => Ok(DataType::UInt32),
209        (None, ConvertedType::INT_8) => Ok(DataType::Int8),
210        (None, ConvertedType::INT_16) => Ok(DataType::Int16),
211        (None, ConvertedType::INT_32) => Ok(DataType::Int32),
212        (None, ConvertedType::DATE) => Ok(DataType::Date32),
213        (None, ConvertedType::TIME_MILLIS) => Ok(DataType::Time32(TimeUnit::Millisecond)),
214        (None, ConvertedType::DECIMAL) => decimal_128_type(scale, precision),
215        (logical, converted) => Err(arrow_err!(
216            "Unable to convert parquet INT32 logical type {:?} or converted type {}",
217            logical,
218            converted
219        )),
220    }
221}
222
223fn from_int64(info: &BasicTypeInfo, scale: i32, precision: i32) -> Result<DataType> {
224    match (info.logical_type_ref(), info.converted_type()) {
225        (None, ConvertedType::NONE) => Ok(DataType::Int64),
226        (
227            Some(LogicalType::Integer(IntType {
228                bit_width: 64,
229                is_signed,
230            })),
231            _,
232        ) => match is_signed {
233            true => Ok(DataType::Int64),
234            false => Ok(DataType::UInt64),
235        },
236        (Some(LogicalType::Time(time)), _) => match time.unit {
237            ParquetTimeUnit::MILLIS => {
238                Err(arrow_err!("Cannot create INT64 from MILLIS time unit",))
239            }
240            ParquetTimeUnit::MICROS => Ok(DataType::Time64(TimeUnit::Microsecond)),
241            ParquetTimeUnit::NANOS => Ok(DataType::Time64(TimeUnit::Nanosecond)),
242        },
243        (Some(LogicalType::Timestamp(timestamp)), _) => Ok(DataType::Timestamp(
244            match timestamp.unit {
245                ParquetTimeUnit::MILLIS => TimeUnit::Millisecond,
246                ParquetTimeUnit::MICROS => TimeUnit::Microsecond,
247                ParquetTimeUnit::NANOS => TimeUnit::Nanosecond,
248            },
249            if timestamp.is_adjusted_to_u_t_c {
250                Some("UTC".into())
251            } else {
252                None
253            },
254        )),
255        (None, ConvertedType::INT_64) => Ok(DataType::Int64),
256        (None, ConvertedType::UINT_64) => Ok(DataType::UInt64),
257        (None, ConvertedType::TIME_MICROS) => Ok(DataType::Time64(TimeUnit::Microsecond)),
258        (None, ConvertedType::TIMESTAMP_MILLIS) => Ok(DataType::Timestamp(
259            TimeUnit::Millisecond,
260            Some("UTC".into()),
261        )),
262        (None, ConvertedType::TIMESTAMP_MICROS) => Ok(DataType::Timestamp(
263            TimeUnit::Microsecond,
264            Some("UTC".into()),
265        )),
266        (Some(LogicalType::Decimal(dec)), _) => decimal_128_type(dec.scale, dec.precision),
267        (None, ConvertedType::DECIMAL) => decimal_128_type(scale, precision),
268        (logical, converted) => Err(arrow_err!(
269            "Unable to convert parquet INT64 logical type {:?} or converted type {}",
270            logical,
271            converted
272        )),
273    }
274}
275
276fn from_byte_array(info: &BasicTypeInfo, precision: i32, scale: i32) -> Result<DataType> {
277    match (info.logical_type_ref(), info.converted_type()) {
278        (Some(LogicalType::String), _) => Ok(DataType::Utf8),
279        (Some(LogicalType::Json), _) => Ok(DataType::Utf8),
280        (Some(LogicalType::Bson), _) => Ok(DataType::Binary),
281        (Some(LogicalType::Enum), _) => Ok(DataType::Binary),
282        (Some(LogicalType::Geometry { .. }), _) => Ok(DataType::Binary),
283        (Some(LogicalType::Geography { .. }), _) => Ok(DataType::Binary),
284        (Some(LogicalType::_Unknown { .. }), _) => Ok(DataType::Binary),
285        (None, ConvertedType::NONE) => Ok(DataType::Binary),
286        (None, ConvertedType::JSON) => Ok(DataType::Utf8),
287        (None, ConvertedType::BSON) => Ok(DataType::Binary),
288        (None, ConvertedType::ENUM) => Ok(DataType::Binary),
289        (None, ConvertedType::UTF8) => Ok(DataType::Utf8),
290        (Some(LogicalType::Decimal(decimal)), _) => decimal_type(decimal.scale, decimal.precision),
291        (None, ConvertedType::DECIMAL) => decimal_type(scale, precision),
292        (logical, converted) => Err(arrow_err!(
293            "Unable to convert parquet BYTE_ARRAY logical type {:?} or converted type {}",
294            logical,
295            converted
296        )),
297    }
298}
299
300fn from_fixed_len_byte_array(
301    info: &BasicTypeInfo,
302    scale: i32,
303    precision: i32,
304    type_length: i32,
305) -> Result<DataType> {
306    match (info.logical_type_ref(), info.converted_type()) {
307        (Some(LogicalType::Decimal(decimal)), _) => {
308            check_decimal_length(type_length)?;
309            // lengths 1..=16 map to Decimal128, 17..=32 to Decimal256
310            if type_length <= 16 {
311                decimal_128_type(decimal.scale, decimal.precision)
312            } else {
313                decimal_256_type(decimal.scale, decimal.precision)
314            }
315        }
316        (None, ConvertedType::DECIMAL) => {
317            check_decimal_length(type_length)?;
318            if type_length <= 16 {
319                decimal_128_type(scale, precision)
320            } else {
321                decimal_256_type(scale, precision)
322            }
323        }
324        (None, ConvertedType::INTERVAL) => {
325            if type_length != 12 {
326                return Err(ParquetError::General(format!(
327                    "INTERVAL must be a Fixed Length Byte Array with length 12, got {type_length}"
328                )));
329            }
330            // There is currently no reliable way of determining which IntervalUnit
331            // to return. Thus without the original Arrow schema, the results
332            // would be incorrect if all 12 bytes of the interval are populated
333            Ok(DataType::Interval(IntervalUnit::DayTime))
334        }
335        (Some(LogicalType::Float16), _) => {
336            if type_length == 2 {
337                Ok(DataType::Float16)
338            } else {
339                Err(ParquetError::General(
340                    "FLOAT16 logical type must be Fixed Length Byte Array with length 2"
341                        .to_string(),
342                ))
343            }
344        }
345        _ => Ok(DataType::FixedSizeBinary(type_length)),
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use crate::basic::{DecimalType, Repetition};
    use crate::schema::types::Type;

    // The PrimitiveTypeBuilder rejects bad lengths at construction. To exercise
    // the reader-side checks, build a valid type then overwrite its type_length,
    // simulating a schema decoded from a file that wasn't produced via the builder.
    fn with_type_length(ty: Type, type_length: i32) -> Type {
        match ty {
            Type::PrimitiveType {
                basic_info,
                physical_type,
                precision,
                scale,
                ..
            } => Type::PrimitiveType {
                basic_info,
                physical_type,
                type_length,
                precision,
                scale,
            },
            _ => unreachable!(),
        }
    }

    /// FIXED_LEN_BYTE_ARRAY column with a logical DECIMAL(5, 2) annotation and
    /// the given (possibly invalid) byte length.
    fn flba_decimal_logical(type_length: i32) -> Type {
        let valid = Type::primitive_type_builder("c", PhysicalType::FIXED_LEN_BYTE_ARRAY)
            .with_repetition(Repetition::REQUIRED)
            .with_logical_type(Some(LogicalType::Decimal(DecimalType {
                precision: 5,
                scale: 2,
            })))
            .with_length(16)
            .with_precision(5)
            .with_scale(2)
            .build()
            .unwrap();
        with_type_length(valid, type_length)
    }

    /// FIXED_LEN_BYTE_ARRAY column with a converted type DECIMAL annotation
    /// (precision 5, scale 2) and the given (possibly invalid) byte length.
    fn flba_decimal_converted(type_length: i32) -> Type {
        let valid = Type::primitive_type_builder("c", PhysicalType::FIXED_LEN_BYTE_ARRAY)
            .with_repetition(Repetition::REQUIRED)
            .with_converted_type(ConvertedType::DECIMAL)
            .with_length(16)
            .with_precision(5)
            .with_scale(2)
            .build()
            .unwrap();
        with_type_length(valid, type_length)
    }

    /// FIXED_LEN_BYTE_ARRAY column with a converted type INTERVAL annotation
    /// and the given (possibly invalid) byte length.
    fn flba_interval(type_length: i32) -> Type {
        let valid = Type::primitive_type_builder("c", PhysicalType::FIXED_LEN_BYTE_ARRAY)
            .with_repetition(Repetition::REQUIRED)
            .with_converted_type(ConvertedType::INTERVAL)
            .with_length(12)
            .build()
            .unwrap();
        with_type_length(valid, type_length)
    }

    /// Asserts that converting `ty` without a hint fails with a message
    /// containing `needle`.
    fn assert_err_contains(ty: &Type, needle: &str) {
        let err = convert_primitive(ty, None).expect_err("expected an error");
        let msg = err.to_string();
        assert!(msg.contains(needle), "expected {needle:?} in error: {msg}");
    }

    #[test]
    fn decimal_logical_rejects_invalid_length() {
        for bad in [-1, 0, 33] {
            assert_err_contains(&flba_decimal_logical(bad), "DECIMAL");
        }
    }

    #[test]
    fn decimal_converted_rejects_invalid_length() {
        for bad in [-1, 0, 33] {
            assert_err_contains(&flba_decimal_converted(bad), "DECIMAL");
        }
    }

    #[test]
    fn decimal_accepts_valid_lengths() {
        // 1 and 16 bound the Decimal128 range, 17 and 32 the Decimal256 range.
        // Both the logical and the converted annotation must agree.
        for len in [1, 16] {
            for ty in [flba_decimal_logical(len), flba_decimal_converted(len)] {
                assert_eq!(
                    convert_primitive(&ty, None).unwrap(),
                    DataType::Decimal128(5, 2),
                    "length {len}"
                );
            }
        }
        for len in [17, 32] {
            for ty in [flba_decimal_logical(len), flba_decimal_converted(len)] {
                assert_eq!(
                    convert_primitive(&ty, None).unwrap(),
                    DataType::Decimal256(5, 2),
                    "length {len}"
                );
            }
        }
    }

    #[test]
    fn interval_rejects_wrong_length() {
        for bad in [-1, 0, 11, 13] {
            assert_err_contains(&flba_interval(bad), "INTERVAL");
        }
    }

    #[test]
    fn interval_accepts_length_12() {
        assert_eq!(
            convert_primitive(&flba_interval(12), None).unwrap(),
            DataType::Interval(IntervalUnit::DayTime)
        );
    }
}