arrow_schema/
datatype_parse.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::{fmt::Display, iter::Peekable, str::Chars, sync::Arc};
19
20use crate::{ArrowError, DataType, Field, Fields, IntervalUnit, TimeUnit};
21
22/// Parses a DataType from a string representation
23///
24/// For example, the string "Int32" would be parsed into [`DataType::Int32`]
25pub(crate) fn parse_data_type(val: &str) -> ArrowResult<DataType> {
26    Parser::new(val).parse()
27}
28
29type ArrowResult<T> = Result<T, ArrowError>;
30
31fn make_error(val: &str, msg: &str) -> ArrowError {
32    let msg = format!(
33        "Unsupported type '{val}'. Must be a supported arrow type name such as 'Int32' or 'Timestamp(ns)'. Error {msg}"
34    );
35    ArrowError::ParseError(msg)
36}
37
38fn make_error_expected(val: &str, expected: &Token, actual: &Token) -> ArrowError {
39    make_error(val, &format!("Expected '{expected}', got '{actual}'"))
40}
41
42/// Implementation of `parse_data_type`, modeled after <https://github.com/sqlparser-rs/sqlparser-rs>
43#[derive(Debug)]
44struct Parser<'a> {
45    val: &'a str,
46    tokenizer: Peekable<Tokenizer<'a>>,
47}
48
49impl<'a> Parser<'a> {
50    fn new(val: &'a str) -> Self {
51        Self {
52            val,
53            tokenizer: Tokenizer::new(val).peekable(),
54        }
55    }
56
57    fn parse(mut self) -> ArrowResult<DataType> {
58        let data_type = self.parse_next_type()?;
59        // ensure that there is no trailing content
60        if self.tokenizer.next().is_some() {
61            Err(make_error(
62                self.val,
63                &format!("checking trailing content after parsing '{data_type}'"),
64            ))
65        } else {
66            Ok(data_type)
67        }
68    }
69
70    /// parses the next full DataType
71    fn parse_next_type(&mut self) -> ArrowResult<DataType> {
72        match self.next_token()? {
73            Token::SimpleType(data_type) => Ok(data_type),
74            Token::Timestamp => self.parse_timestamp(),
75            Token::Time32 => self.parse_time32(),
76            Token::Time64 => self.parse_time64(),
77            Token::Duration => self.parse_duration(),
78            Token::Interval => self.parse_interval(),
79            Token::FixedSizeBinary => self.parse_fixed_size_binary(),
80            Token::Decimal32 => self.parse_decimal_32(),
81            Token::Decimal64 => self.parse_decimal_64(),
82            Token::Decimal128 => self.parse_decimal_128(),
83            Token::Decimal256 => self.parse_decimal_256(),
84            Token::Dictionary => self.parse_dictionary(),
85            Token::List => self.parse_list(),
86            Token::LargeList => self.parse_large_list(),
87            Token::FixedSizeList => self.parse_fixed_size_list(),
88            Token::Struct => self.parse_struct(),
89            tok => Err(make_error(
90                self.val,
91                &format!("finding next type, got unexpected '{tok}'"),
92            )),
93        }
94    }
95
96    /// Parses the List type
97    fn parse_list(&mut self) -> ArrowResult<DataType> {
98        self.expect_token(Token::LParen)?;
99        let data_type = self.parse_next_type()?;
100        self.expect_token(Token::RParen)?;
101        Ok(DataType::List(Arc::new(Field::new_list_field(
102            data_type, true,
103        ))))
104    }
105
106    /// Parses the LargeList type
107    fn parse_large_list(&mut self) -> ArrowResult<DataType> {
108        self.expect_token(Token::LParen)?;
109        let data_type = self.parse_next_type()?;
110        self.expect_token(Token::RParen)?;
111        Ok(DataType::LargeList(Arc::new(Field::new_list_field(
112            data_type, true,
113        ))))
114    }
115
116    /// Parses the FixedSizeList type
117    fn parse_fixed_size_list(&mut self) -> ArrowResult<DataType> {
118        self.expect_token(Token::LParen)?;
119        let length = self.parse_i32("FixedSizeList")?;
120        self.expect_token(Token::Comma)?;
121        let data_type = self.parse_next_type()?;
122        self.expect_token(Token::RParen)?;
123        Ok(DataType::FixedSizeList(
124            Arc::new(Field::new_list_field(data_type, true)),
125            length,
126        ))
127    }
128
129    /// Parses the next timeunit
130    fn parse_time_unit(&mut self, context: &str) -> ArrowResult<TimeUnit> {
131        match self.next_token()? {
132            Token::TimeUnit(time_unit) => Ok(time_unit),
133            tok => Err(make_error(
134                self.val,
135                &format!("finding TimeUnit for {context}, got {tok}"),
136            )),
137        }
138    }
139
140    /// Parses the next double quoted string
141    fn parse_double_quoted_string(&mut self, context: &str) -> ArrowResult<String> {
142        let token = self.next_token()?;
143        if let Token::DoubleQuotedString(string) = token {
144            Ok(string)
145        } else {
146            Err(make_error(
147                self.val,
148                &format!("expected double quoted string for {context}, got '{token}'"),
149            ))
150        }
151    }
152
153    /// Parses the next integer value
154    fn parse_i64(&mut self, context: &str) -> ArrowResult<i64> {
155        match self.next_token()? {
156            Token::Integer(v) => Ok(v),
157            tok => Err(make_error(
158                self.val,
159                &format!("finding i64 for {context}, got '{tok}'"),
160            )),
161        }
162    }
163
164    /// Parses the next i32 integer value
165    fn parse_i32(&mut self, context: &str) -> ArrowResult<i32> {
166        let length = self.parse_i64(context)?;
167        length.try_into().map_err(|e| {
168            make_error(
169                self.val,
170                &format!("converting {length} into i32 for {context}: {e}"),
171            )
172        })
173    }
174
175    /// Parses the next i8 integer value
176    fn parse_i8(&mut self, context: &str) -> ArrowResult<i8> {
177        let length = self.parse_i64(context)?;
178        length.try_into().map_err(|e| {
179            make_error(
180                self.val,
181                &format!("converting {length} into i8 for {context}: {e}"),
182            )
183        })
184    }
185
186    /// Parses the next u8 integer value
187    fn parse_u8(&mut self, context: &str) -> ArrowResult<u8> {
188        let length = self.parse_i64(context)?;
189        length.try_into().map_err(|e| {
190            make_error(
191                self.val,
192                &format!("converting {length} into u8 for {context}: {e}"),
193            )
194        })
195    }
196
197    /// Parses the next timestamp (called after `Timestamp` has been consumed)
198    fn parse_timestamp(&mut self) -> ArrowResult<DataType> {
199        self.expect_token(Token::LParen)?;
200        let time_unit = self.parse_time_unit("Timestamp")?;
201
202        let timezone;
203        match self.next_token()? {
204            Token::Comma => {
205                match self.next_token()? {
206                    // Support old style `Timestamp(Nanosecond, None)`
207                    Token::None => {
208                        timezone = None;
209                    }
210                    // Support old style `Timestamp(Nanosecond, Some("Timezone"))`
211                    Token::Some => {
212                        self.expect_token(Token::LParen)?;
213                        timezone = Some(self.parse_double_quoted_string("Timezone")?);
214                        self.expect_token(Token::RParen)?;
215                    }
216                    Token::DoubleQuotedString(tz) => {
217                        // Support new style `Timestamp(Nanosecond, "Timezone")`
218                        timezone = Some(tz);
219                    }
220                    tok => {
221                        return Err(make_error(
222                            self.val,
223                            &format!("Expected None, Some, or a timezone string, got {tok:?}"),
224                        ));
225                    }
226                };
227                self.expect_token(Token::RParen)?;
228            }
229            // No timezone (e.g `Timestamp(ns)`)
230            Token::RParen => {
231                timezone = None;
232            }
233            next_token => {
234                return Err(make_error(
235                    self.val,
236                    &format!("Expected comma followed by a timezone, or an ), got {next_token:?}"),
237                ));
238            }
239        }
240        Ok(DataType::Timestamp(time_unit, timezone.map(Into::into)))
241    }
242
243    /// Parses the next Time32 (called after `Time32` has been consumed)
244    fn parse_time32(&mut self) -> ArrowResult<DataType> {
245        self.expect_token(Token::LParen)?;
246        let time_unit = self.parse_time_unit("Time32")?;
247        self.expect_token(Token::RParen)?;
248        Ok(DataType::Time32(time_unit))
249    }
250
251    /// Parses the next Time64 (called after `Time64` has been consumed)
252    fn parse_time64(&mut self) -> ArrowResult<DataType> {
253        self.expect_token(Token::LParen)?;
254        let time_unit = self.parse_time_unit("Time64")?;
255        self.expect_token(Token::RParen)?;
256        Ok(DataType::Time64(time_unit))
257    }
258
259    /// Parses the next Duration (called after `Duration` has been consumed)
260    fn parse_duration(&mut self) -> ArrowResult<DataType> {
261        self.expect_token(Token::LParen)?;
262        let time_unit = self.parse_time_unit("Duration")?;
263        self.expect_token(Token::RParen)?;
264        Ok(DataType::Duration(time_unit))
265    }
266
267    /// Parses the next Interval (called after `Interval` has been consumed)
268    fn parse_interval(&mut self) -> ArrowResult<DataType> {
269        self.expect_token(Token::LParen)?;
270        let interval_unit = match self.next_token()? {
271            Token::IntervalUnit(interval_unit) => interval_unit,
272            tok => {
273                return Err(make_error(
274                    self.val,
275                    &format!("finding IntervalUnit for Interval, got {tok}"),
276                ));
277            }
278        };
279        self.expect_token(Token::RParen)?;
280        Ok(DataType::Interval(interval_unit))
281    }
282
283    /// Parses the next FixedSizeBinary (called after `FixedSizeBinary` has been consumed)
284    fn parse_fixed_size_binary(&mut self) -> ArrowResult<DataType> {
285        self.expect_token(Token::LParen)?;
286        let length = self.parse_i32("FixedSizeBinary")?;
287        self.expect_token(Token::RParen)?;
288        Ok(DataType::FixedSizeBinary(length))
289    }
290
291    /// Parses the next Decimal32 (called after `Decimal32` has been consumed)
292    fn parse_decimal_32(&mut self) -> ArrowResult<DataType> {
293        self.expect_token(Token::LParen)?;
294        let precision = self.parse_u8("Decimal32")?;
295        self.expect_token(Token::Comma)?;
296        let scale = self.parse_i8("Decimal32")?;
297        self.expect_token(Token::RParen)?;
298        Ok(DataType::Decimal32(precision, scale))
299    }
300
301    /// Parses the next Decimal64 (called after `Decimal64` has been consumed)
302    fn parse_decimal_64(&mut self) -> ArrowResult<DataType> {
303        self.expect_token(Token::LParen)?;
304        let precision = self.parse_u8("Decimal64")?;
305        self.expect_token(Token::Comma)?;
306        let scale = self.parse_i8("Decimal64")?;
307        self.expect_token(Token::RParen)?;
308        Ok(DataType::Decimal64(precision, scale))
309    }
310
311    /// Parses the next Decimal128 (called after `Decimal128` has been consumed)
312    fn parse_decimal_128(&mut self) -> ArrowResult<DataType> {
313        self.expect_token(Token::LParen)?;
314        let precision = self.parse_u8("Decimal128")?;
315        self.expect_token(Token::Comma)?;
316        let scale = self.parse_i8("Decimal128")?;
317        self.expect_token(Token::RParen)?;
318        Ok(DataType::Decimal128(precision, scale))
319    }
320
321    /// Parses the next Decimal256 (called after `Decimal256` has been consumed)
322    fn parse_decimal_256(&mut self) -> ArrowResult<DataType> {
323        self.expect_token(Token::LParen)?;
324        let precision = self.parse_u8("Decimal256")?;
325        self.expect_token(Token::Comma)?;
326        let scale = self.parse_i8("Decimal256")?;
327        self.expect_token(Token::RParen)?;
328        Ok(DataType::Decimal256(precision, scale))
329    }
330
331    /// Parses the next Dictionary (called after `Dictionary` has been consumed)
332    fn parse_dictionary(&mut self) -> ArrowResult<DataType> {
333        self.expect_token(Token::LParen)?;
334        let key_type = self.parse_next_type()?;
335        self.expect_token(Token::Comma)?;
336        let value_type = self.parse_next_type()?;
337        self.expect_token(Token::RParen)?;
338        Ok(DataType::Dictionary(
339            Box::new(key_type),
340            Box::new(value_type),
341        ))
342    }
343    fn parse_struct(&mut self) -> ArrowResult<DataType> {
344        self.expect_token(Token::LParen)?;
345        let mut fields = Vec::new();
346        loop {
347            // expects:   "field name": [nullable] #datatype
348
349            let field_name = match self.next_token()? {
350                Token::RParen => {
351                    break;
352                }
353                Token::DoubleQuotedString(field_name) => field_name,
354                tok => {
355                    return Err(make_error(
356                        self.val,
357                        &format!("Expected a quoted string for a field name; got {tok:?}"),
358                    ));
359                }
360            };
361            self.expect_token(Token::Colon)?;
362
363            let nullable = self
364                .tokenizer
365                .next_if(|next| matches!(next, Ok(Token::Nullable)))
366                .is_some();
367            let field_type = self.parse_next_type()?;
368            fields.push(Arc::new(Field::new(field_name, field_type, nullable)));
369            match self.next_token()? {
370                Token::Comma => continue,
371                Token::RParen => break,
372                tok => {
373                    return Err(make_error(
374                        self.val,
375                        &format!(
376                            "Unexpected token while parsing Struct fields. Expected ',' or ')', but got '{tok}'"
377                        ),
378                    ));
379                }
380            }
381        }
382        Ok(DataType::Struct(Fields::from(fields)))
383    }
384
385    /// return the next token, or an error if there are none left
386    fn next_token(&mut self) -> ArrowResult<Token> {
387        match self.tokenizer.next() {
388            None => Err(make_error(self.val, "finding next token")),
389            Some(token) => token,
390        }
391    }
392
393    /// consume the next token, returning OK(()) if it matches tok, and Err if not
394    fn expect_token(&mut self, tok: Token) -> ArrowResult<()> {
395        let next_token = self.next_token()?;
396        if next_token == tok {
397            Ok(())
398        } else {
399            Err(make_error_expected(self.val, &tok, &next_token))
400        }
401    }
402}
403
/// Reports whether `c` ends a bare word: a parenthesis, comma, colon or space.
fn is_separator(c: char) -> bool {
    matches!(c, '(' | ')' | ',' | ':' | ' ')
}
408
/// Splits a string like `Dictionary(Int32, Int64)` into tokens suitable for parsing
///
/// For example the string "Timestamp(ns)" would be split into:
///
/// * Token::Timestamp
/// * Token::LParen
/// * Token::TimeUnit(TimeUnit::Nanosecond)
/// * Token::RParen
#[derive(Debug)]
struct Tokenizer<'a> {
    /// The full input, kept for error messages
    val: &'a str,
    /// Characters of `val` not yet consumed
    chars: Peekable<Chars<'a>>,
    // temporary buffer for parsing words
    word: String,
}
424
425impl<'a> Tokenizer<'a> {
426    fn new(val: &'a str) -> Self {
427        Self {
428            val,
429            chars: val.chars().peekable(),
430            word: String::new(),
431        }
432    }
433
434    /// returns the next char, without consuming it
435    fn peek_next_char(&mut self) -> Option<char> {
436        self.chars.peek().copied()
437    }
438
439    /// returns the next char, and consuming it
440    fn next_char(&mut self) -> Option<char> {
441        self.chars.next()
442    }
443
444    /// parse the characters in val starting at pos, until the next
445    /// `,`, `(`, or `)` or end of line
446    fn parse_word(&mut self) -> ArrowResult<Token> {
447        // reset temp space
448        self.word.clear();
449        loop {
450            match self.peek_next_char() {
451                None => break,
452                Some(c) if is_separator(c) => break,
453                Some(c) => {
454                    self.next_char();
455                    self.word.push(c);
456                }
457            }
458        }
459
460        if let Some(c) = self.word.chars().next() {
461            // if it started with a number, try parsing it as an integer
462            if c == '-' || c.is_numeric() {
463                let val: i64 = self.word.parse().map_err(|e| {
464                    make_error(self.val, &format!("parsing {} as integer: {e}", self.word))
465                })?;
466                return Ok(Token::Integer(val));
467            }
468        }
469
470        // figure out what the word was
471        let token = match self.word.as_str() {
472            "Null" => Token::SimpleType(DataType::Null),
473            "Boolean" => Token::SimpleType(DataType::Boolean),
474
475            "Int8" => Token::SimpleType(DataType::Int8),
476            "Int16" => Token::SimpleType(DataType::Int16),
477            "Int32" => Token::SimpleType(DataType::Int32),
478            "Int64" => Token::SimpleType(DataType::Int64),
479
480            "UInt8" => Token::SimpleType(DataType::UInt8),
481            "UInt16" => Token::SimpleType(DataType::UInt16),
482            "UInt32" => Token::SimpleType(DataType::UInt32),
483            "UInt64" => Token::SimpleType(DataType::UInt64),
484
485            "Utf8" => Token::SimpleType(DataType::Utf8),
486            "LargeUtf8" => Token::SimpleType(DataType::LargeUtf8),
487            "Utf8View" => Token::SimpleType(DataType::Utf8View),
488            "Binary" => Token::SimpleType(DataType::Binary),
489            "BinaryView" => Token::SimpleType(DataType::BinaryView),
490            "LargeBinary" => Token::SimpleType(DataType::LargeBinary),
491
492            "Float16" => Token::SimpleType(DataType::Float16),
493            "Float32" => Token::SimpleType(DataType::Float32),
494            "Float64" => Token::SimpleType(DataType::Float64),
495
496            "Date32" => Token::SimpleType(DataType::Date32),
497            "Date64" => Token::SimpleType(DataType::Date64),
498
499            "List" => Token::List,
500            "LargeList" => Token::LargeList,
501            "FixedSizeList" => Token::FixedSizeList,
502
503            "s" | "Second" => Token::TimeUnit(TimeUnit::Second),
504            "ms" | "Millisecond" => Token::TimeUnit(TimeUnit::Millisecond),
505            "µs" | "us" | "Microsecond" => Token::TimeUnit(TimeUnit::Microsecond),
506            "ns" | "Nanosecond" => Token::TimeUnit(TimeUnit::Nanosecond),
507
508            "Timestamp" => Token::Timestamp,
509            "Time32" => Token::Time32,
510            "Time64" => Token::Time64,
511            "Duration" => Token::Duration,
512            "Interval" => Token::Interval,
513            "Dictionary" => Token::Dictionary,
514
515            "FixedSizeBinary" => Token::FixedSizeBinary,
516
517            "Decimal32" => Token::Decimal32,
518            "Decimal64" => Token::Decimal64,
519            "Decimal128" => Token::Decimal128,
520            "Decimal256" => Token::Decimal256,
521
522            "YearMonth" => Token::IntervalUnit(IntervalUnit::YearMonth),
523            "DayTime" => Token::IntervalUnit(IntervalUnit::DayTime),
524            "MonthDayNano" => Token::IntervalUnit(IntervalUnit::MonthDayNano),
525
526            "Some" => Token::Some,
527            "None" => Token::None,
528
529            "nullable" => Token::Nullable,
530
531            "Struct" => Token::Struct,
532
533            token => {
534                return Err(make_error(self.val, &format!("unknown token: {token}")));
535            }
536        };
537        Ok(token)
538    }
539
540    /// Parses e.g. `"foo bar"`
541    fn parse_quoted_string(&mut self) -> ArrowResult<Token> {
542        if self.next_char() != Some('\"') {
543            return Err(make_error(self.val, "Expected \""));
544        }
545
546        // reset temp space
547        self.word.clear();
548
549        let mut is_escaped = false;
550
551        loop {
552            match self.next_char() {
553                None => {
554                    return Err(ArrowError::ParseError(format!(
555                        "Unterminated string at: \"{}",
556                        self.word
557                    )));
558                }
559                Some(c) => match c {
560                    '\\' => {
561                        is_escaped = true;
562                        self.word.push(c);
563                    }
564                    '"' => {
565                        if is_escaped {
566                            self.word.push(c);
567                            is_escaped = false;
568                        } else {
569                            break;
570                        }
571                    }
572                    c => {
573                        self.word.push(c);
574                    }
575                },
576            }
577        }
578
579        let val: String = self.word.parse().map_err(|err| {
580            ArrowError::ParseError(format!("Failed to parse string: \"{}\": {err}", self.word))
581        })?;
582
583        if val.is_empty() {
584            // Using empty strings as field names is just asking for trouble
585            return Err(make_error(self.val, "empty strings aren't allowed"));
586        }
587
588        Ok(Token::DoubleQuotedString(val))
589    }
590}
591
592impl Iterator for Tokenizer<'_> {
593    type Item = ArrowResult<Token>;
594
595    fn next(&mut self) -> Option<Self::Item> {
596        loop {
597            match self.peek_next_char()? {
598                ' ' => {
599                    // skip whitespace
600                    self.next_char();
601                    continue;
602                }
603                '"' => {
604                    return Some(self.parse_quoted_string());
605                }
606                '(' => {
607                    self.next_char();
608                    return Some(Ok(Token::LParen));
609                }
610                ')' => {
611                    self.next_char();
612                    return Some(Ok(Token::RParen));
613                }
614                ',' => {
615                    self.next_char();
616                    return Some(Ok(Token::Comma));
617                }
618                ':' => {
619                    self.next_char();
620                    return Some(Ok(Token::Colon));
621                }
622                _ => return Some(self.parse_word()),
623            }
624        }
625    }
626}
627
628/// Grammar is
629///
630#[derive(Debug, PartialEq)]
631enum Token {
632    // Null, or Int32
633    SimpleType(DataType),
634    Timestamp,
635    Time32,
636    Time64,
637    Duration,
638    Interval,
639    FixedSizeBinary,
640    Decimal32,
641    Decimal64,
642    Decimal128,
643    Decimal256,
644    Dictionary,
645    TimeUnit(TimeUnit),
646    IntervalUnit(IntervalUnit),
647    LParen,
648    RParen,
649    Comma,
650    Colon,
651    Some,
652    None,
653    Integer(i64),
654    DoubleQuotedString(String),
655    List,
656    LargeList,
657    FixedSizeList,
658    Struct,
659    Nullable,
660}
661
662impl Display for Token {
663    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
664        match self {
665            Token::SimpleType(t) => write!(f, "{t}"),
666            Token::List => write!(f, "List"),
667            Token::LargeList => write!(f, "LargeList"),
668            Token::FixedSizeList => write!(f, "FixedSizeList"),
669            Token::Timestamp => write!(f, "Timestamp"),
670            Token::Time32 => write!(f, "Time32"),
671            Token::Time64 => write!(f, "Time64"),
672            Token::Duration => write!(f, "Duration"),
673            Token::Interval => write!(f, "Interval"),
674            Token::TimeUnit(u) => write!(f, "TimeUnit({u:?})"),
675            Token::IntervalUnit(u) => write!(f, "IntervalUnit({u:?})"),
676            Token::LParen => write!(f, "("),
677            Token::RParen => write!(f, ")"),
678            Token::Comma => write!(f, ","),
679            Token::Colon => write!(f, ":"),
680            Token::Some => write!(f, "Some"),
681            Token::None => write!(f, "None"),
682            Token::FixedSizeBinary => write!(f, "FixedSizeBinary"),
683            Token::Decimal32 => write!(f, "Decimal32"),
684            Token::Decimal64 => write!(f, "Decimal64"),
685            Token::Decimal128 => write!(f, "Decimal128"),
686            Token::Decimal256 => write!(f, "Decimal256"),
687            Token::Dictionary => write!(f, "Dictionary"),
688            Token::Integer(v) => write!(f, "Integer({v})"),
689            Token::DoubleQuotedString(s) => write!(f, "DoubleQuotedString({s})"),
690            Token::Struct => write!(f, "Struct"),
691            Token::Nullable => write!(f, "nullable"),
692        }
693    }
694}
695
696#[cfg(test)]
697mod test {
698    use super::*;
699
700    #[test]
701    fn test_parse_data_type() {
702        // this ensures types can be parsed correctly from their string representations
703        for dt in list_datatypes() {
704            round_trip(dt)
705        }
706    }
707
708    /// Ensure we converting data_type to a string, and then parse it as a type
709    /// verifying it is the same
710    fn round_trip(data_type: DataType) {
711        let data_type_string = data_type.to_string();
712        println!("Input '{data_type_string}' ({data_type:?})");
713        let parsed_type = parse_data_type(&data_type_string).unwrap();
714        assert_eq!(
715            data_type, parsed_type,
716            "Mismatch parsing {data_type_string}"
717        );
718    }
719
720    fn list_datatypes() -> Vec<DataType> {
721        vec![
722            // ---------
723            // Non Nested types
724            // ---------
725            DataType::Null,
726            DataType::Boolean,
727            DataType::Int8,
728            DataType::Int16,
729            DataType::Int32,
730            DataType::Int64,
731            DataType::UInt8,
732            DataType::UInt16,
733            DataType::UInt32,
734            DataType::UInt64,
735            DataType::Float16,
736            DataType::Float32,
737            DataType::Float64,
738            DataType::Timestamp(TimeUnit::Second, None),
739            DataType::Timestamp(TimeUnit::Millisecond, None),
740            DataType::Timestamp(TimeUnit::Microsecond, None),
741            DataType::Timestamp(TimeUnit::Nanosecond, None),
742            // we can't cover all possible timezones, here we only test utc and +08:00
743            DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())),
744            DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())),
745            DataType::Timestamp(TimeUnit::Millisecond, Some("+00:00".into())),
746            DataType::Timestamp(TimeUnit::Second, Some("+00:00".into())),
747            DataType::Timestamp(TimeUnit::Nanosecond, Some("+08:00".into())),
748            DataType::Timestamp(TimeUnit::Microsecond, Some("+08:00".into())),
749            DataType::Timestamp(TimeUnit::Millisecond, Some("+08:00".into())),
750            DataType::Timestamp(TimeUnit::Second, Some("+08:00".into())),
751            DataType::Date32,
752            DataType::Date64,
753            DataType::Time32(TimeUnit::Second),
754            DataType::Time32(TimeUnit::Millisecond),
755            DataType::Time32(TimeUnit::Microsecond),
756            DataType::Time32(TimeUnit::Nanosecond),
757            DataType::Time64(TimeUnit::Second),
758            DataType::Time64(TimeUnit::Millisecond),
759            DataType::Time64(TimeUnit::Microsecond),
760            DataType::Time64(TimeUnit::Nanosecond),
761            DataType::Duration(TimeUnit::Second),
762            DataType::Duration(TimeUnit::Millisecond),
763            DataType::Duration(TimeUnit::Microsecond),
764            DataType::Duration(TimeUnit::Nanosecond),
765            DataType::Interval(IntervalUnit::YearMonth),
766            DataType::Interval(IntervalUnit::DayTime),
767            DataType::Interval(IntervalUnit::MonthDayNano),
768            DataType::Binary,
769            DataType::BinaryView,
770            DataType::FixedSizeBinary(0),
771            DataType::FixedSizeBinary(1234),
772            DataType::FixedSizeBinary(-432),
773            DataType::LargeBinary,
774            DataType::Utf8,
775            DataType::Utf8View,
776            DataType::LargeUtf8,
777            DataType::Decimal32(7, 8),
778            DataType::Decimal64(6, 9),
779            DataType::Decimal128(7, 12),
780            DataType::Decimal256(6, 13),
781            // ---------
782            // Nested types
783            // ---------
784            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
785            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
786            DataType::Dictionary(
787                Box::new(DataType::Int8),
788                Box::new(DataType::Timestamp(TimeUnit::Nanosecond, None)),
789            ),
790            DataType::Dictionary(
791                Box::new(DataType::Int8),
792                Box::new(DataType::FixedSizeBinary(23)),
793            ),
794            DataType::Dictionary(
795                Box::new(DataType::Int8),
796                Box::new(
797                    // nested dictionaries are probably a bad idea but they are possible
798                    DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
799                ),
800            ),
801            DataType::Struct(Fields::from(vec![
802                Field::new("f1", DataType::Int64, true),
803                Field::new("f2", DataType::Float64, true),
804                Field::new(
805                    "f3",
806                    DataType::Timestamp(TimeUnit::Second, Some("+08:00".into())),
807                    true,
808                ),
809                Field::new(
810                    "f4",
811                    DataType::Dictionary(
812                        Box::new(DataType::Int8),
813                        Box::new(DataType::FixedSizeBinary(23)),
814                    ),
815                    true,
816                ),
817            ])),
818            DataType::Struct(Fields::from(vec![
819                Field::new("Int64", DataType::Int64, true),
820                Field::new("Float64", DataType::Float64, true),
821            ])),
822            DataType::Struct(Fields::from(vec![
823                Field::new("f1", DataType::Int64, true),
824                Field::new(
825                    "nested_struct",
826                    DataType::Struct(Fields::from(vec![Field::new("n1", DataType::Int64, true)])),
827                    true,
828                ),
829            ])),
830            DataType::Struct(Fields::empty()),
831            // TODO support more structured types (List, LargeList, Union, Map, RunEndEncoded, etc)
832        ]
833    }
834
835    #[test]
836    fn test_parse_data_type_whitespace_tolerance() {
837        // (string to parse, expected DataType)
838        let cases = [
839            ("Int8", DataType::Int8),
840            (
841                "Timestamp        (ns)",
842                DataType::Timestamp(TimeUnit::Nanosecond, None),
843            ),
844            (
845                "Timestamp        (ns)  ",
846                DataType::Timestamp(TimeUnit::Nanosecond, None),
847            ),
848            (
849                "          Timestamp        (ns               )",
850                DataType::Timestamp(TimeUnit::Nanosecond, None),
851            ),
852            (
853                "Timestamp        (ns               )  ",
854                DataType::Timestamp(TimeUnit::Nanosecond, None),
855            ),
856        ];
857
858        for (data_type_string, expected_data_type) in cases {
859            let parsed_data_type = parse_data_type(data_type_string).unwrap();
860            assert_eq!(
861                parsed_data_type, expected_data_type,
862                "Parsing '{data_type_string}', expecting '{expected_data_type}'"
863            );
864        }
865    }
866
867    /// Ensure that old style types can still be parsed
868    #[test]
869    fn test_parse_data_type_backwards_compatibility() {
870        use DataType::*;
871        use IntervalUnit::*;
872        use TimeUnit::*;
873        // List below created with:
874        // for t in list_datatypes() {
875        // println!(r#"("{t}", {t:?}),"#)
876        // }
877        // (string to parse, expected DataType)
878        let cases = [
879            ("Timestamp(Nanosecond, None)", Timestamp(Nanosecond, None)),
880            ("Timestamp(Microsecond, None)", Timestamp(Microsecond, None)),
881            ("Timestamp(Millisecond, None)", Timestamp(Millisecond, None)),
882            ("Timestamp(Second, None)", Timestamp(Second, None)),
883            ("Timestamp(Nanosecond, None)", Timestamp(Nanosecond, None)),
884            // Timezones
885            (
886                r#"Timestamp(Nanosecond, Some("+00:00"))"#,
887                Timestamp(Nanosecond, Some("+00:00".into())),
888            ),
889            (
890                r#"Timestamp(Microsecond, Some("+00:00"))"#,
891                Timestamp(Microsecond, Some("+00:00".into())),
892            ),
893            (
894                r#"Timestamp(Millisecond, Some("+00:00"))"#,
895                Timestamp(Millisecond, Some("+00:00".into())),
896            ),
897            (
898                r#"Timestamp(Second, Some("+00:00"))"#,
899                Timestamp(Second, Some("+00:00".into())),
900            ),
901            ("Null", Null),
902            ("Boolean", Boolean),
903            ("Int8", Int8),
904            ("Int16", Int16),
905            ("Int32", Int32),
906            ("Int64", Int64),
907            ("UInt8", UInt8),
908            ("UInt16", UInt16),
909            ("UInt32", UInt32),
910            ("UInt64", UInt64),
911            ("Float16", Float16),
912            ("Float32", Float32),
913            ("Float64", Float64),
914            ("Timestamp(s)", Timestamp(Second, None)),
915            ("Timestamp(ms)", Timestamp(Millisecond, None)),
916            ("Timestamp(µs)", Timestamp(Microsecond, None)),
917            ("Timestamp(ns)", Timestamp(Nanosecond, None)),
918            (
919                r#"Timestamp(ns, "+00:00")"#,
920                Timestamp(Nanosecond, Some("+00:00".into())),
921            ),
922            (
923                r#"Timestamp(µs, "+00:00")"#,
924                Timestamp(Microsecond, Some("+00:00".into())),
925            ),
926            (
927                r#"Timestamp(ms, "+00:00")"#,
928                Timestamp(Millisecond, Some("+00:00".into())),
929            ),
930            (
931                r#"Timestamp(s, "+00:00")"#,
932                Timestamp(Second, Some("+00:00".into())),
933            ),
934            (
935                r#"Timestamp(ns, "+08:00")"#,
936                Timestamp(Nanosecond, Some("+08:00".into())),
937            ),
938            (
939                r#"Timestamp(µs, "+08:00")"#,
940                Timestamp(Microsecond, Some("+08:00".into())),
941            ),
942            (
943                r#"Timestamp(ms, "+08:00")"#,
944                Timestamp(Millisecond, Some("+08:00".into())),
945            ),
946            (
947                r#"Timestamp(s, "+08:00")"#,
948                Timestamp(Second, Some("+08:00".into())),
949            ),
950            ("Date32", Date32),
951            ("Date64", Date64),
952            ("Time32(s)", Time32(Second)),
953            ("Time32(ms)", Time32(Millisecond)),
954            ("Time32(µs)", Time32(Microsecond)),
955            ("Time32(ns)", Time32(Nanosecond)),
956            ("Time64(s)", Time64(Second)),
957            ("Time64(ms)", Time64(Millisecond)),
958            ("Time64(µs)", Time64(Microsecond)),
959            ("Time64(ns)", Time64(Nanosecond)),
960            ("Duration(s)", Duration(Second)),
961            ("Duration(ms)", Duration(Millisecond)),
962            ("Duration(µs)", Duration(Microsecond)),
963            ("Duration(ns)", Duration(Nanosecond)),
964            ("Interval(YearMonth)", Interval(YearMonth)),
965            ("Interval(DayTime)", Interval(DayTime)),
966            ("Interval(MonthDayNano)", Interval(MonthDayNano)),
967            ("Binary", Binary),
968            ("BinaryView", BinaryView),
969            ("FixedSizeBinary(0)", FixedSizeBinary(0)),
970            ("FixedSizeBinary(1234)", FixedSizeBinary(1234)),
971            ("FixedSizeBinary(-432)", FixedSizeBinary(-432)),
972            ("LargeBinary", LargeBinary),
973            ("Utf8", Utf8),
974            ("Utf8View", Utf8View),
975            ("LargeUtf8", LargeUtf8),
976            ("Decimal32(7, 8)", Decimal32(7, 8)),
977            ("Decimal64(6, 9)", Decimal64(6, 9)),
978            ("Decimal128(7, 12)", Decimal128(7, 12)),
979            ("Decimal256(6, 13)", Decimal256(6, 13)),
980            (
981                "Dictionary(Int32, Utf8)",
982                Dictionary(Box::new(Int32), Box::new(Utf8)),
983            ),
984            (
985                "Dictionary(Int8, Utf8)",
986                Dictionary(Box::new(Int8), Box::new(Utf8)),
987            ),
988            (
989                "Dictionary(Int8, Timestamp(ns))",
990                Dictionary(Box::new(Int8), Box::new(Timestamp(Nanosecond, None))),
991            ),
992            (
993                "Dictionary(Int8, FixedSizeBinary(23))",
994                Dictionary(Box::new(Int8), Box::new(FixedSizeBinary(23))),
995            ),
996            (
997                "Dictionary(Int8, Dictionary(Int8, Utf8))",
998                Dictionary(
999                    Box::new(Int8),
1000                    Box::new(Dictionary(Box::new(Int8), Box::new(Utf8))),
1001                ),
1002            ),
1003            (
1004                r#"Struct("f1": nullable Int64, "f2": nullable Float64, "f3": nullable Timestamp(s, "+08:00"), "f4": nullable Dictionary(Int8, FixedSizeBinary(23)))"#,
1005                Struct(Fields::from(vec![
1006                    Field::new("f1", Int64, true),
1007                    Field::new("f2", Float64, true),
1008                    Field::new("f3", Timestamp(Second, Some("+08:00".into())), true),
1009                    Field::new(
1010                        "f4",
1011                        Dictionary(Box::new(Int8), Box::new(FixedSizeBinary(23))),
1012                        true,
1013                    ),
1014                ])),
1015            ),
1016            (
1017                r#"Struct("Int64": nullable Int64, "Float64": nullable Float64)"#,
1018                Struct(Fields::from(vec![
1019                    Field::new("Int64", Int64, true),
1020                    Field::new("Float64", Float64, true),
1021                ])),
1022            ),
1023            (
1024                r#"Struct("f1": nullable Int64, "nested_struct": nullable Struct("n1": nullable Int64))"#,
1025                Struct(Fields::from(vec![
1026                    Field::new("f1", Int64, true),
1027                    Field::new(
1028                        "nested_struct",
1029                        Struct(Fields::from(vec![Field::new("n1", Int64, true)])),
1030                        true,
1031                    ),
1032                ])),
1033            ),
1034            (r#"Struct()"#, Struct(Fields::empty())),
1035        ];
1036
1037        for (data_type_string, expected_data_type) in cases {
1038            let parsed_data_type = parse_data_type(data_type_string).unwrap();
1039            assert_eq!(
1040                parsed_data_type, expected_data_type,
1041                "Parsing '{data_type_string}', expecting '{expected_data_type}'"
1042            );
1043        }
1044    }
1045
1046    #[test]
1047    fn parse_data_type_errors() {
1048        // (string to parse, expected error message)
1049        let cases = [
1050            ("", "Unsupported type ''"),
1051            ("", "Error finding next token"),
1052            ("null", "Unsupported type 'null'"),
1053            ("Nu", "Unsupported type 'Nu'"),
1054            (r#"Timestamp(ns, +00:00)"#, "Error unknown token: +00"),
1055            (
1056                r#"Timestamp(ns, "+00:00)"#,
1057                r#"Unterminated string at: "+00:00)"#,
1058            ),
1059            (r#"Timestamp(ns, "")"#, r#"empty strings aren't allowed"#),
1060            (
1061                r#"Timestamp(ns, "+00:00"")"#,
1062                r#"Parser error: Unterminated string at: ")"#,
1063            ),
1064            ("Timestamp(ns, ", "Error finding next token"),
1065            (
1066                "Float32 Float32",
1067                "trailing content after parsing 'Float32'",
1068            ),
1069            ("Int32, ", "trailing content after parsing 'Int32'"),
1070            ("Int32(3), ", "trailing content after parsing 'Int32'"),
1071            (
1072                "FixedSizeBinary(Int32), ",
1073                "Error finding i64 for FixedSizeBinary, got 'Int32'",
1074            ),
1075            (
1076                "FixedSizeBinary(3.0), ",
1077                "Error parsing 3.0 as integer: invalid digit found in string",
1078            ),
1079            // too large for i32
1080            (
1081                "FixedSizeBinary(4000000000), ",
1082                "Error converting 4000000000 into i32 for FixedSizeBinary: out of range integral type conversion attempted",
1083            ),
1084            // can't have negative precision
1085            (
1086                "Decimal32(-3, 5)",
1087                "Error converting -3 into u8 for Decimal32: out of range integral type conversion attempted",
1088            ),
1089            (
1090                "Decimal64(-3, 5)",
1091                "Error converting -3 into u8 for Decimal64: out of range integral type conversion attempted",
1092            ),
1093            (
1094                "Decimal128(-3, 5)",
1095                "Error converting -3 into u8 for Decimal128: out of range integral type conversion attempted",
1096            ),
1097            (
1098                "Decimal256(-3, 5)",
1099                "Error converting -3 into u8 for Decimal256: out of range integral type conversion attempted",
1100            ),
1101            (
1102                "Decimal32(3, 500)",
1103                "Error converting 500 into i8 for Decimal32: out of range integral type conversion attempted",
1104            ),
1105            (
1106                "Decimal64(3, 500)",
1107                "Error converting 500 into i8 for Decimal64: out of range integral type conversion attempted",
1108            ),
1109            (
1110                "Decimal128(3, 500)",
1111                "Error converting 500 into i8 for Decimal128: out of range integral type conversion attempted",
1112            ),
1113            (
1114                "Decimal256(3, 500)",
1115                "Error converting 500 into i8 for Decimal256: out of range integral type conversion attempted",
1116            ),
1117            ("Struct(f1 Int64)", "Error unknown token: f1"),
1118            ("Struct(\"f1\" Int64)", "Expected ':'"),
1119            (
1120                "Struct(\"f1\": )",
1121                "Error finding next type, got unexpected ')'",
1122            ),
1123        ];
1124
1125        for (data_type_string, expected_message) in cases {
1126            println!("Parsing '{data_type_string}', expecting '{expected_message}'");
1127            match parse_data_type(data_type_string) {
1128                Ok(d) => panic!("Expected error while parsing '{data_type_string}', but got '{d}'"),
1129                Err(e) => {
1130                    let message = e.to_string();
1131                    assert!(
1132                        message.contains(expected_message),
1133                        "\n\ndid not find expected in actual.\n\nexpected: {expected_message}\nactual: {message}\n"
1134                    );
1135
1136                    if !message.contains("Unterminated string") {
1137                        // errors should also contain a help message
1138                        assert!(message.contains("Must be a supported arrow type name such as 'Int32' or 'Timestamp(ns)'"), "message: {message}");
1139                    }
1140                }
1141            }
1142        }
1143    }
1144
1145    #[test]
1146    fn parse_error_type() {
1147        let err = parse_data_type("foobar").unwrap_err();
1148        assert!(matches!(err, ArrowError::ParseError(_)));
1149        assert_eq!(
1150            err.to_string(),
1151            "Parser error: Unsupported type 'foobar'. Must be a supported arrow type name such as 'Int32' or 'Timestamp(ns)'. Error unknown token: foobar"
1152        );
1153    }
1154}