Skip to main content

parquet/
parquet_thrift.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Structs used for encoding and decoding Parquet Thrift objects.
19//!
20//! These include:
21//! * [`ThriftCompactInputProtocol`]: Trait implemented by Thrift decoders.
22//!     * [`ThriftSliceInputProtocol`]: Thrift decoder that takes a slice of bytes as input.
23//!     * [`ThriftReadInputProtocol`]: Thrift decoder that takes a [`Read`] as input.
24//! * [`ReadThrift`]: Trait implemented by deserializable objects.
25//! * [`ThriftCompactOutputProtocol`]: Thrift encoder.
26//! * [`WriteThrift`]: Trait implemented by serializable objects.
27//! * [`WriteThriftField`]: Trait implemented by serializable objects that are fields in Thrift structs.
28
29use std::{
30    cmp::Ordering,
31    io::{Read, Write},
32};
33
34use crate::{
35    errors::{ParquetError, Result},
36    write_thrift_field,
37};
38use std::io::Error;
39use std::num::TryFromIntError;
40use std::str::Utf8Error;
41
42#[derive(Debug)]
43pub(crate) enum ThriftProtocolError {
44    Eof,
45    IO(Error),
46    InvalidFieldType(u8),
47    InvalidElementType(u8),
48    FieldDeltaOverflow { field_delta: u8, last_field_id: i16 },
49    InvalidBoolean(u8),
50    IntegerOverflow,
51    Utf8Error,
52    SkipDepth(FieldType),
53    SkipUnsupportedType(FieldType),
54}
55
56impl From<ThriftProtocolError> for ParquetError {
57    #[inline(never)]
58    fn from(e: ThriftProtocolError) -> Self {
59        match e {
60            ThriftProtocolError::Eof => eof_err!("Unexpected EOF"),
61            ThriftProtocolError::IO(e) => e.into(),
62            ThriftProtocolError::InvalidFieldType(value) => match FieldType::try_from(value) {
63                Ok(fld_type) => general_err!("Unexpected struct field type {:?}", fld_type),
64                Err(_) => general_err!("Unexpected struct field type {}", value),
65            },
66            ThriftProtocolError::InvalidElementType(value) => {
67                general_err!("Unexpected list/set element type {}", value)
68            }
69            ThriftProtocolError::FieldDeltaOverflow {
70                field_delta,
71                last_field_id,
72            } => general_err!("cannot add {} to {}", field_delta, last_field_id),
73            ThriftProtocolError::InvalidBoolean(value) => {
74                general_err!("cannot convert {} into bool", value)
75            }
76            ThriftProtocolError::IntegerOverflow => {
77                general_err!("integer overflow decoding thrift value")
78            }
79            ThriftProtocolError::Utf8Error => general_err!("invalid utf8"),
80            ThriftProtocolError::SkipDepth(field_type) => {
81                general_err!("cannot parse past {:?}", field_type)
82            }
83            ThriftProtocolError::SkipUnsupportedType(field_type) => {
84                general_err!("cannot skip field type {:?}", field_type)
85            }
86        }
87    }
88}
89
90impl From<Utf8Error> for ThriftProtocolError {
91    fn from(_: Utf8Error) -> Self {
92        // ignore error payload to reduce the size of ThriftProtocolError
93        Self::Utf8Error
94    }
95}
96
97impl From<Error> for ThriftProtocolError {
98    fn from(e: Error) -> Self {
99        Self::IO(e)
100    }
101}
102
103impl From<TryFromIntError> for ThriftProtocolError {
104    fn from(_: TryFromIntError) -> Self {
105        // ignore error payload to reduce the size of ThriftProtocolError
106        Self::IntegerOverflow
107    }
108}
109
110pub type ThriftProtocolResult<T> = Result<T, ThriftProtocolError>;
111
/// Wrapper for thrift `double` fields. This is used to provide
/// implementations of `Eq` and `Ord` for floats. Both equality and
/// ordering use the IEEE 754 total order.
#[derive(Debug, Clone, Copy)]
pub struct OrderedF64(f64);

/// Equality under the IEEE 754 total order.
///
/// This is written by hand instead of derived: the derived implementation
/// compares with `f64`'s `==`, under which `NaN != NaN` (so the `Eq` marker
/// implemented for this type would be unsound) and `0.0 == -0.0` (which
/// disagrees with the `total_cmp`-based `Ord` implementation of this type).
impl PartialEq for OrderedF64 {
    fn eq(&self, other: &Self) -> bool {
        self.0.total_cmp(&other.0).is_eq()
    }
}
117
118impl From<f64> for OrderedF64 {
119    fn from(value: f64) -> Self {
120        Self(value)
121    }
122}
123
124impl From<OrderedF64> for f64 {
125    fn from(value: OrderedF64) -> Self {
126        value.0
127    }
128}
129
130impl Eq for OrderedF64 {} // Marker trait, requires PartialEq
131
132impl Ord for OrderedF64 {
133    fn cmp(&self, other: &Self) -> Ordering {
134        self.0.total_cmp(&other.0)
135    }
136}
137
138impl PartialOrd for OrderedF64 {
139    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
140        Some(self.cmp(other))
141    }
142}
143
/// Thrift compact protocol type codes used in struct field headers.
///
/// The discriminant of each variant is the code that appears on the wire.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum FieldType {
    /// Marks the end of a struct.
    Stop = 0,
    /// A boolean field whose value is `true`.
    BooleanTrue = 1,
    /// A boolean field whose value is `false`.
    BooleanFalse = 2,
    /// An 8-bit integer.
    Byte = 3,
    /// A 16-bit integer.
    I16 = 4,
    /// A 32-bit integer.
    I32 = 5,
    /// A 64-bit integer.
    I64 = 6,
    /// A double precision float.
    Double = 7,
    /// Length-prefixed bytes.
    Binary = 8,
    /// A list of elements.
    List = 9,
    /// A set of elements.
    Set = 10,
    /// A map of key/value pairs.
    Map = 11,
    /// A nested struct.
    Struct = 12,
    /// A UUID.
    Uuid = 13,
}
162
163impl TryFrom<u8> for FieldType {
164    type Error = ThriftProtocolError;
165    fn try_from(value: u8) -> ThriftProtocolResult<Self> {
166        match value {
167            0 => Ok(Self::Stop),
168            1 => Ok(Self::BooleanTrue),
169            2 => Ok(Self::BooleanFalse),
170            3 => Ok(Self::Byte),
171            4 => Ok(Self::I16),
172            5 => Ok(Self::I32),
173            6 => Ok(Self::I64),
174            7 => Ok(Self::Double),
175            8 => Ok(Self::Binary),
176            9 => Ok(Self::List),
177            10 => Ok(Self::Set),
178            11 => Ok(Self::Map),
179            12 => Ok(Self::Struct),
180            13 => Ok(Self::Uuid),
181            _ => Err(ThriftProtocolError::InvalidFieldType(value)),
182        }
183    }
184}
185
186impl From<ElementType> for FieldType {
187    fn from(value: ElementType) -> Self {
188        match value {
189            ElementType::Bool => Self::BooleanTrue,
190            ElementType::Byte => Self::Byte,
191            ElementType::I16 => Self::I16,
192            ElementType::I32 => Self::I32,
193            ElementType::I64 => Self::I64,
194            ElementType::Double => Self::Double,
195            ElementType::Binary => Self::Binary,
196            ElementType::List => Self::List,
197            ElementType::Set => Self::Set,
198            ElementType::Map => Self::Map,
199            ElementType::Struct => Self::Struct,
200            ElementType::Uuid => Self::Uuid,
201        }
202    }
203}
204
/// Thrift compact protocol type codes used for list elements.
///
/// The discriminant of each variant is the code that appears on the wire.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum ElementType {
    /// A boolean element.
    Bool = 2,
    /// An 8-bit integer.
    Byte = 3,
    /// A 16-bit integer.
    I16 = 4,
    /// A 32-bit integer.
    I32 = 5,
    /// A 64-bit integer.
    I64 = 6,
    /// A double precision float.
    Double = 7,
    /// Length-prefixed bytes.
    Binary = 8,
    /// A nested list.
    List = 9,
    /// A nested set.
    Set = 10,
    /// A nested map.
    Map = 11,
    /// A struct.
    Struct = 12,
    /// A UUID.
    Uuid = 13,
}
221
222impl TryFrom<u8> for ElementType {
223    type Error = ThriftProtocolError;
224    fn try_from(value: u8) -> ThriftProtocolResult<Self> {
225        match value {
226            // For historical and compatibility reasons, a reader should be capable to deal with both cases.
227            // The only valid value in the original spec was 2, but due to an widespread implementation bug
228            // the defacto standard across large parts of the library became 1 instead.
229            // As a result, both values are now allowed.
230            // https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#list-and-set
231            1 | 2 => Ok(Self::Bool),
232            3 => Ok(Self::Byte),
233            4 => Ok(Self::I16),
234            5 => Ok(Self::I32),
235            6 => Ok(Self::I64),
236            7 => Ok(Self::Double),
237            8 => Ok(Self::Binary),
238            9 => Ok(Self::List),
239            10 => Ok(Self::Set),
240            11 => Ok(Self::Map),
241            12 => Ok(Self::Struct),
242            13 => Ok(Self::Uuid),
243            _ => Err(ThriftProtocolError::InvalidElementType(value)),
244        }
245    }
246}
247
248/// Struct used to describe a [thrift struct] field during decoding.
249///
250/// [thrift struct]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#struct-encoding
251pub(crate) struct FieldIdentifier {
252    /// The type for the field.
253    pub(crate) field_type: FieldType,
254    /// The field's `id`. May be computed from delta or directly decoded.
255    pub(crate) id: i16,
256}
257
258impl FieldIdentifier {
259    pub(crate) fn bool_val(&self) -> ThriftProtocolResult<bool> {
260        match self.field_type {
261            FieldType::BooleanTrue => Ok(true),
262            FieldType::BooleanFalse => Ok(false),
263            _ => Err(ThriftProtocolError::InvalidFieldType(self.field_type as u8)),
264        }
265    }
266}
267
268/// Struct used to describe a [thrift list].
269///
270/// [thrift list]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#list-and-set
271#[derive(Clone, Debug, Eq, PartialEq)]
272pub(crate) struct ListIdentifier {
273    /// The type for each element in the list.
274    pub(crate) element_type: ElementType,
275    /// Number of elements contained in the list.
276    pub(crate) size: i32,
277}
278
279/// Low-level object used to deserialize structs encoded with the Thrift [compact] protocol.
280///
281/// Implementation of this trait must provide the low-level functions `read_byte`, `read_bytes`,
282/// `skip_bytes`, and `read_double`. These primitives are used by the default functions provided
283/// here to perform deserialization.
284///
285/// [compact]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md
286pub(crate) trait ThriftCompactInputProtocol<'a> {
287    /// Read a single byte from the input.
288    fn read_byte(&mut self) -> ThriftProtocolResult<u8>;
289
290    /// Read a Thrift encoded [binary] from the input.
291    ///
292    /// [binary]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#binary-encoding
293    fn read_bytes(&mut self) -> ThriftProtocolResult<&'a [u8]>;
294
295    fn read_bytes_owned(&mut self) -> ThriftProtocolResult<Vec<u8>>;
296
297    /// Skip the next `n` bytes of input.
298    fn skip_bytes(&mut self, n: usize) -> ThriftProtocolResult<()>;
299
300    /// Read a ULEB128 encoded unsigned varint from the input.
301    fn read_vlq(&mut self) -> ThriftProtocolResult<u64> {
302        // try the happy path first
303        let byte = self.read_byte()?;
304        if byte & 0x80 == 0 {
305            return Ok(byte as u64);
306        }
307        let mut in_progress = (byte & 0x7f) as u64;
308        let mut shift = 7;
309        loop {
310            let byte = self.read_byte()?;
311            in_progress |= ((byte & 0x7F) as u64).wrapping_shl(shift);
312            if byte & 0x80 == 0 {
313                return Ok(in_progress);
314            }
315            shift += 7;
316        }
317    }
318
319    /// Read a zig-zag encoded signed varint from the input.
320    fn read_zig_zag(&mut self) -> ThriftProtocolResult<i64> {
321        let val = self.read_vlq()?;
322        Ok((val >> 1) as i64 ^ -((val & 1) as i64))
323    }
324
325    /// Read the [`ListIdentifier`] for a Thrift encoded list.
326    fn read_list_begin(&mut self) -> ThriftProtocolResult<ListIdentifier> {
327        let header = self.read_byte()?;
328        // some parquet writers will have an element_type of 0 for an empty list.
329        // account for that and return a bogus but valid element_type.
330        if header == 0 {
331            return Ok(ListIdentifier {
332                element_type: ElementType::Byte,
333                size: 0,
334            });
335        }
336        let element_type = ElementType::try_from(header & 0x0f)?;
337
338        let possible_element_count = (header & 0xF0) >> 4;
339        let element_count = if possible_element_count != 15 {
340            // high bits set high if count and type encoded separately
341            possible_element_count as i32
342        } else {
343            // The list size on the wire is an unsigned varint, but we represent
344            // it as `i32` (matching Java's `int` and the Thrift schema).
345            // A varint that decodes above `i32::MAX` is malformed input — reject
346            // it here at the protocol layer rather than letting the cast wrap
347            // into a negative size that downstream allocation code has to
348            // re-validate.
349            i32::try_from(self.read_vlq()?)?
350        };
351
352        Ok(ListIdentifier {
353            element_type,
354            size: element_count,
355        })
356    }
357
358    // Full field ids are uncommon.
359    // Not inlining this method reduces the code size of `read_field_begin`, which then ideally gets
360    // inlined everywhere.
361    #[cold]
362    fn read_full_field_id(&mut self) -> ThriftProtocolResult<i16> {
363        self.read_i16()
364    }
365
366    /// Read the [`FieldIdentifier`] for a field in a Thrift encoded struct.
367    fn read_field_begin(&mut self, last_field_id: i16) -> ThriftProtocolResult<FieldIdentifier> {
368        // we can read at least one byte, which is:
369        // - the type
370        // - the field delta and the type
371        let field_type = self.read_byte()?;
372        if field_type & 0xf == 0 {
373            return Ok(FieldIdentifier {
374                field_type: FieldType::Stop,
375                id: 0,
376            });
377        }
378
379        let field_delta = (field_type & 0xf0) >> 4;
380        let field_type = FieldType::try_from(field_type & 0xf)?;
381
382        let id = if field_delta != 0 {
383            last_field_id.checked_add(field_delta as i16).ok_or(
384                ThriftProtocolError::FieldDeltaOverflow {
385                    field_delta,
386                    last_field_id,
387                },
388            )?
389        } else {
390            self.read_full_field_id()?
391        };
392
393        Ok(FieldIdentifier { field_type, id })
394    }
395
396    /// This is a specialized version of [`Self::read_field_begin`], solely for use in parsing
397    /// simple structs. This function assumes that the delta field will always be less than 0xf,
398    /// fields will be in order, and no boolean fields will be read.
399    /// This also skips validation of the field type.
400    ///
401    /// Returns a tuple of `(field_type, field_delta)`.
402    fn read_field_header(&mut self) -> ThriftProtocolResult<(u8, u8)> {
403        let field_type = self.read_byte()?;
404        let field_delta = (field_type & 0xf0) >> 4;
405        let field_type = field_type & 0xf;
406        Ok((field_type, field_delta))
407    }
408
409    /// Read a boolean list element. This should not be used for struct fields. For the latter,
410    /// use the [`FieldIdentifier::bool_val`] field.
411    fn read_bool(&mut self) -> ThriftProtocolResult<bool> {
412        let b = self.read_byte()?;
413        // Previous versions of the thrift specification said to use 0 and 1 inside collections,
414        // but that differed from existing implementations.
415        // The specification was updated in https://github.com/apache/thrift/commit/2c29c5665bc442e703480bb0ee60fe925ffe02e8.
416        // At least the go implementation seems to have followed the previously documented values.
417        match b {
418            0x01 => Ok(true),
419            0x00 | 0x02 => Ok(false),
420            _ => Err(ThriftProtocolError::InvalidBoolean(b)),
421        }
422    }
423
424    /// Read a Thrift [binary] as a UTF-8 encoded string.
425    ///
426    /// [binary]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#binary-encoding
427    fn read_string(&mut self) -> ThriftProtocolResult<&'a str> {
428        let slice = self.read_bytes()?;
429        Ok(std::str::from_utf8(slice)?)
430    }
431
432    /// Read an `i8`.
433    fn read_i8(&mut self) -> ThriftProtocolResult<i8> {
434        Ok(self.read_byte()? as _)
435    }
436
437    /// Read an `i16`.
438    fn read_i16(&mut self) -> ThriftProtocolResult<i16> {
439        Ok(self.read_zig_zag()? as _)
440    }
441
442    /// Read an `i32`.
443    fn read_i32(&mut self) -> ThriftProtocolResult<i32> {
444        Ok(self.read_zig_zag()? as _)
445    }
446
447    /// Read an `i64`.
448    fn read_i64(&mut self) -> ThriftProtocolResult<i64> {
449        self.read_zig_zag()
450    }
451
452    /// Read a Thrift `double` as `f64`.
453    fn read_double(&mut self) -> ThriftProtocolResult<f64>;
454
455    /// Skip a ULEB128 encoded varint.
456    fn skip_vlq(&mut self) -> ThriftProtocolResult<()> {
457        loop {
458            let byte = self.read_byte()?;
459            if byte & 0x80 == 0 {
460                return Ok(());
461            }
462        }
463    }
464
465    /// Skip a thrift [binary].
466    ///
467    /// [binary]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#binary-encoding
468    fn skip_binary(&mut self) -> ThriftProtocolResult<()> {
469        let len = self.read_vlq()? as usize;
470        self.skip_bytes(len)
471    }
472
473    /// Skip a field with type `field_type` recursively until the default
474    /// maximum skip depth (currently 64) is reached.
475    fn skip(&mut self, field_type: FieldType) -> ThriftProtocolResult<()> {
476        const DEFAULT_SKIP_DEPTH: i8 = 64;
477        self.skip_till_depth(field_type, DEFAULT_SKIP_DEPTH)
478    }
479
480    /// Empty structs in unions consist of a single byte of 0 for the field stop record.
481    /// This skips that byte without encuring the cost of processing the [`FieldIdentifier`].
482    /// Will return an error if the struct is not actually empty.
483    fn skip_empty_struct(&mut self) -> Result<()> {
484        let b = self.read_byte()?;
485        if b != 0 {
486            Err(general_err!("Empty struct has fields"))
487        } else {
488            Ok(())
489        }
490    }
491
492    /// Skip a field with type `field_type` recursively up to `depth` levels.
493    fn skip_till_depth(&mut self, field_type: FieldType, depth: i8) -> ThriftProtocolResult<()> {
494        if depth == 0 {
495            return Err(ThriftProtocolError::SkipDepth(field_type));
496        }
497
498        match field_type {
499            // boolean field has no data
500            FieldType::BooleanFalse | FieldType::BooleanTrue => Ok(()),
501            FieldType::Byte => self.read_i8().map(|_| ()),
502            FieldType::I16 => self.skip_vlq().map(|_| ()),
503            FieldType::I32 => self.skip_vlq().map(|_| ()),
504            FieldType::I64 => self.skip_vlq().map(|_| ()),
505            FieldType::Double => self.skip_bytes(8).map(|_| ()),
506            FieldType::Binary => self.skip_binary().map(|_| ()),
507            // see https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#struct
508            FieldType::Struct => {
509                loop {
510                    // we don't need field id for skipping, so always pass 0 for last id
511                    let field_ident = self.read_field_begin(0)?;
512                    if field_ident.field_type == FieldType::Stop {
513                        break;
514                    }
515                    self.skip_till_depth(field_ident.field_type, depth - 1)?;
516                }
517                Ok(())
518            }
519            // lists and sets are encoded the same
520            // see https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#list-and-set
521            FieldType::List | FieldType::Set => {
522                let list_ident = self.read_list_begin()?;
523                let element_type = FieldType::from(list_ident.element_type);
524                for _ in 0..list_ident.size {
525                    self.skip_till_depth(element_type, depth - 1)?;
526                }
527                Ok(())
528            }
529            // see https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#map
530            FieldType::Map => {
531                let size = i32::try_from(self.read_vlq()?)?;
532                if size > 0 {
533                    let kv = self.read_byte()?;
534                    let key_type = FieldType::from(ElementType::try_from(kv >> 4)?);
535                    let val_type = FieldType::from(ElementType::try_from(kv & 0xf)?);
536                    for _ in 0..size {
537                        self.skip_till_depth(key_type, depth - 1)?;
538                        self.skip_till_depth(val_type, depth - 1)?;
539                    }
540                }
541                Ok(())
542            }
543            // see https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#universal-unique-identifier-encoding
544            FieldType::Uuid => self.skip_bytes(16).map(|_| ()),
545            _ => Err(ThriftProtocolError::SkipUnsupportedType(field_type)),
546        }
547    }
548}
549
/// A high performance Thrift reader that decodes directly from a slice of bytes.
pub(crate) struct ThriftSliceInputProtocol<'a> {
    /// The part of the input that has not been consumed yet.
    buf: &'a [u8],
}
554
555impl<'a> ThriftSliceInputProtocol<'a> {
556    /// Create a new `ThriftSliceInputProtocol` using the bytes in `buf`.
557    pub fn new(buf: &'a [u8]) -> Self {
558        Self { buf }
559    }
560
561    /// Return the current buffer as a slice.
562    pub fn as_slice(&self) -> &'a [u8] {
563        self.buf
564    }
565}
566
567impl<'b, 'a: 'b> ThriftCompactInputProtocol<'b> for ThriftSliceInputProtocol<'a> {
568    #[inline]
569    fn read_byte(&mut self) -> ThriftProtocolResult<u8> {
570        let ret = *self.buf.first().ok_or(ThriftProtocolError::Eof)?;
571        self.buf = &self.buf[1..];
572        Ok(ret)
573    }
574
575    fn read_bytes(&mut self) -> ThriftProtocolResult<&'b [u8]> {
576        let len = self.read_vlq()? as usize;
577        let ret = self.buf.get(..len).ok_or(ThriftProtocolError::Eof)?;
578        self.buf = &self.buf[len..];
579        Ok(ret)
580    }
581
582    fn read_bytes_owned(&mut self) -> ThriftProtocolResult<Vec<u8>> {
583        Ok(self.read_bytes()?.to_vec())
584    }
585
586    #[inline]
587    fn skip_bytes(&mut self, n: usize) -> ThriftProtocolResult<()> {
588        self.buf.get(..n).ok_or(ThriftProtocolError::Eof)?;
589        self.buf = &self.buf[n..];
590        Ok(())
591    }
592
593    fn read_double(&mut self) -> ThriftProtocolResult<f64> {
594        let slice = self.buf.get(..8).ok_or(ThriftProtocolError::Eof)?;
595        self.buf = &self.buf[8..];
596        match slice.try_into() {
597            Ok(slice) => Ok(f64::from_le_bytes(slice)),
598            Err(_) => unreachable!(),
599        }
600    }
601}
602
/// A Thrift input protocol that decodes from a wrapped [`Read`] object.
///
/// Note that this is only intended for use in reading Parquet page headers. Asking it for
/// borrowed Thrift `binary` data will panic, because a slice of that data cannot be returned.
pub(crate) struct ThriftReadInputProtocol<R: Read> {
    /// The source the encoded bytes are pulled from.
    reader: R,
}
610
611impl<R: Read> ThriftReadInputProtocol<R> {
612    pub(crate) fn new(reader: R) -> Self {
613        Self { reader }
614    }
615}
616
617impl<'a, R: Read> ThriftCompactInputProtocol<'a> for ThriftReadInputProtocol<R> {
618    #[inline]
619    fn read_byte(&mut self) -> ThriftProtocolResult<u8> {
620        let mut buf = [0_u8; 1];
621        self.reader.read_exact(&mut buf)?;
622        Ok(buf[0])
623    }
624
625    fn read_bytes(&mut self) -> ThriftProtocolResult<&'a [u8]> {
626        unimplemented!()
627    }
628
629    fn read_bytes_owned(&mut self) -> ThriftProtocolResult<Vec<u8>> {
630        let len = self.read_vlq()? as usize;
631        let mut v = Vec::with_capacity(len);
632        std::io::copy(&mut self.reader.by_ref().take(len as u64), &mut v)?;
633        Ok(v)
634    }
635
636    fn skip_bytes(&mut self, n: usize) -> ThriftProtocolResult<()> {
637        std::io::copy(
638            &mut self.reader.by_ref().take(n as u64),
639            &mut std::io::sink(),
640        )?;
641        Ok(())
642    }
643
644    fn read_double(&mut self) -> ThriftProtocolResult<f64> {
645        let mut buf = [0_u8; 8];
646        self.reader.read_exact(&mut buf)?;
647        Ok(f64::from_le_bytes(buf))
648    }
649}
650
651/// Trait implemented for objects that can be deserialized from a Thrift input stream.
652/// Implementations are provided for Thrift primitive types.
653pub(crate) trait ReadThrift<'a, R: ThriftCompactInputProtocol<'a>> {
654    /// Read an object of type `Self` from the input protocol object.
655    fn read_thrift(prot: &mut R) -> Result<Self>
656    where
657        Self: Sized;
658}
659
660impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for bool {
661    fn read_thrift(prot: &mut R) -> Result<Self> {
662        Ok(prot.read_bool()?)
663    }
664}
665
666impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i8 {
667    fn read_thrift(prot: &mut R) -> Result<Self> {
668        Ok(prot.read_i8()?)
669    }
670}
671
672impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i16 {
673    fn read_thrift(prot: &mut R) -> Result<Self> {
674        Ok(prot.read_i16()?)
675    }
676}
677
678impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i32 {
679    fn read_thrift(prot: &mut R) -> Result<Self> {
680        Ok(prot.read_i32()?)
681    }
682}
683
684impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i64 {
685    fn read_thrift(prot: &mut R) -> Result<Self> {
686        Ok(prot.read_i64()?)
687    }
688}
689
690impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for OrderedF64 {
691    fn read_thrift(prot: &mut R) -> Result<Self> {
692        Ok(OrderedF64(prot.read_double()?))
693    }
694}
695
696impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for &'a str {
697    fn read_thrift(prot: &mut R) -> Result<Self> {
698        Ok(prot.read_string()?)
699    }
700}
701
702impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for String {
703    fn read_thrift(prot: &mut R) -> Result<Self> {
704        Ok(String::from_utf8(prot.read_bytes_owned()?)?)
705    }
706}
707
708impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for &'a [u8] {
709    fn read_thrift(prot: &mut R) -> Result<Self> {
710        Ok(prot.read_bytes()?)
711    }
712}
713
714/// Read a Thrift encoded [list] from the input protocol object.
715///
716/// [list]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#list-and-set
717pub(crate) fn read_thrift_vec<'a, T, R>(prot: &mut R) -> Result<Vec<T>>
718where
719    R: ThriftCompactInputProtocol<'a>,
720    T: ReadThrift<'a, R> + WriteThrift,
721{
722    let list_ident = prot.read_list_begin()?;
723    validate_list_type(T::ELEMENT_TYPE, &list_ident)?;
724    let mut res = Vec::with_capacity(list_ident.size as usize);
725    for _ in 0..list_ident.size {
726        let val = T::read_thrift(prot)?;
727        res.push(val);
728    }
729    Ok(res)
730}
731
732pub(crate) fn validate_list_type(expected: ElementType, got: &ListIdentifier) -> Result<()> {
733    if got.element_type != expected {
734        return Err(general_err!(
735            "Expected list element type of {:?} but got {:?}",
736            expected,
737            got.element_type
738        ));
739    }
740    Ok(())
741}
742
743/////////////////////////
744// thrift compact output
745
/// Low-level object used to serialize structs to the Thrift [compact output] protocol.
///
/// This struct serves as a wrapper around a [`Write`] object, to which thrift encoded data
/// will be written. The implementation provides functions to write Thrift primitive types, as
/// well as functions used in the encoding of lists and structs. This should rarely be used
/// directly, but is instead intended for use by implementers of [`WriteThrift`] and
/// [`WriteThriftField`].
///
/// [compact output]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md
pub(crate) struct ThriftCompactOutputProtocol<W: Write> {
    /// Sink that receives the encoded bytes.
    writer: W,
    /// Whether `path_in_schema` should be emitted.
    write_path_in_schema: bool,
}
758
759impl<W: Write> ThriftCompactOutputProtocol<W> {
760    /// Create a new `ThriftCompactOutputProtocol` wrapping the byte sink `writer`.
761    pub(crate) fn new(writer: W) -> Self {
762        Self {
763            writer,
764            write_path_in_schema: true,
765        }
766    }
767
768    // TODO(ets): at some point there should probably be a properties object
769    // to control aspects of thrift output. But since this is the only option to date
770    // I'm choosing a simpler API.
771    /// Control the writing of the `path_in_schema` element of the `ColumnMetaData`
772    pub(crate) fn set_write_path_in_schema(&mut self, val: bool) {
773        self.write_path_in_schema = val;
774    }
775
776    /// Indicate whether or not to emit `path_in_schema`.
777    pub(crate) fn write_path_in_schema(&self) -> bool {
778        self.write_path_in_schema
779    }
780
781    /// Write a single byte to the output stream.
782    fn write_byte(&mut self, b: u8) -> Result<()> {
783        self.writer.write_all(&[b])?;
784        Ok(())
785    }
786
787    /// Write the given `u64` as a ULEB128 encoded varint.
788    fn write_vlq(&mut self, val: u64) -> Result<()> {
789        let mut v = val;
790        while v > 0x7f {
791            self.write_byte(v as u8 | 0x80)?;
792            v >>= 7;
793        }
794        self.write_byte(v as u8)
795    }
796
797    /// Write the given `i64` as a zig-zag encoded varint.
798    fn write_zig_zag(&mut self, val: i64) -> Result<()> {
799        let s = (val < 0) as i64;
800        self.write_vlq((((val ^ -s) << 1) + s) as u64)
801    }
802
803    /// Used to mark the start of a Thrift struct field of type `field_type`. `last_field_id`
804    /// is used to compute a delta to the given `field_id` per the compact protocol [spec].
805    ///
806    /// [spec]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#struct-encoding
807    pub(crate) fn write_field_begin(
808        &mut self,
809        field_type: FieldType,
810        field_id: i16,
811        last_field_id: i16,
812    ) -> Result<()> {
813        let delta = field_id.wrapping_sub(last_field_id);
814        if delta > 0 && delta <= 0xf {
815            self.write_byte((delta as u8) << 4 | field_type as u8)
816        } else {
817            self.write_byte(field_type as u8)?;
818            self.write_i16(field_id)
819        }
820    }
821
822    /// Used to indicate the start of a list of `element_type` elements.
823    pub(crate) fn write_list_begin(&mut self, element_type: ElementType, len: usize) -> Result<()> {
824        if len < 15 {
825            self.write_byte((len as u8) << 4 | element_type as u8)
826        } else {
827            self.write_byte(0xf0u8 | element_type as u8)?;
828            self.write_vlq(len as _)
829        }
830    }
831
832    /// Used to mark the end of a struct. This must be called after all fields of the struct have
833    /// been written.
834    pub(crate) fn write_struct_end(&mut self) -> Result<()> {
835        self.write_byte(0)
836    }
837
838    /// Serialize a slice of `u8`s. This will encode a length, and then write the bytes without
839    /// further encoding.
840    pub(crate) fn write_bytes(&mut self, val: &[u8]) -> Result<()> {
841        self.write_vlq(val.len() as u64)?;
842        self.writer.write_all(val)?;
843        Ok(())
844    }
845
846    /// Short-cut method used to encode structs that have no fields (often used in Thrift unions).
847    /// This simply encodes the field id and then immediately writes the end-of-struct marker.
848    pub(crate) fn write_empty_struct(&mut self, field_id: i16, last_field_id: i16) -> Result<i16> {
849        self.write_field_begin(FieldType::Struct, field_id, last_field_id)?;
850        self.write_struct_end()?;
851        Ok(last_field_id)
852    }
853
854    /// Write a boolean value.
855    pub(crate) fn write_bool(&mut self, val: bool) -> Result<()> {
856        match val {
857            true => self.write_byte(1),
858            false => self.write_byte(2),
859        }
860    }
861
862    /// Write a zig-zag encoded `i8` value.
863    pub(crate) fn write_i8(&mut self, val: i8) -> Result<()> {
864        self.write_byte(val as u8)
865    }
866
867    /// Write a zig-zag encoded `i16` value.
868    pub(crate) fn write_i16(&mut self, val: i16) -> Result<()> {
869        self.write_zig_zag(val as _)
870    }
871
872    /// Write a zig-zag encoded `i32` value.
873    pub(crate) fn write_i32(&mut self, val: i32) -> Result<()> {
874        self.write_zig_zag(val as _)
875    }
876
877    /// Write a zig-zag encoded `i64` value.
878    pub(crate) fn write_i64(&mut self, val: i64) -> Result<()> {
879        self.write_zig_zag(val as _)
880    }
881
882    /// Write a double value.
883    pub(crate) fn write_double(&mut self, val: f64) -> Result<()> {
884        self.writer.write_all(&val.to_le_bytes())?;
885        Ok(())
886    }
887}
888
889/// Trait implemented by objects that are to be serialized to a Thrift [compact output] protocol
890/// stream. Implementations are also provided for primitive Thrift types.
891///
892/// [compact output]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md
893pub(crate) trait WriteThrift {
894    /// The [`ElementType`] to use when a list of this object is written.
895    const ELEMENT_TYPE: ElementType;
896
897    /// Serialize this object to the given `writer`.
898    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()>;
899}
900
901/// Implementation for a vector of thrift serializable objects that implement [`WriteThrift`].
902/// This will write the necessary list header and then serialize the elements one-at-a-time.
903impl<T> WriteThrift for Vec<T>
904where
905    T: WriteThrift,
906{
907    const ELEMENT_TYPE: ElementType = ElementType::List;
908
909    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
910        writer.write_list_begin(T::ELEMENT_TYPE, self.len())?;
911        for item in self {
912            item.write_thrift(writer)?;
913        }
914        Ok(())
915    }
916}
917
918impl WriteThrift for bool {
919    const ELEMENT_TYPE: ElementType = ElementType::Bool;
920
921    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
922        writer.write_bool(*self)
923    }
924}
925
926impl WriteThrift for i8 {
927    const ELEMENT_TYPE: ElementType = ElementType::Byte;
928
929    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
930        writer.write_i8(*self)
931    }
932}
933
934impl WriteThrift for i16 {
935    const ELEMENT_TYPE: ElementType = ElementType::I16;
936
937    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
938        writer.write_i16(*self)
939    }
940}
941
942impl WriteThrift for i32 {
943    const ELEMENT_TYPE: ElementType = ElementType::I32;
944
945    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
946        writer.write_i32(*self)
947    }
948}
949
950impl WriteThrift for i64 {
951    const ELEMENT_TYPE: ElementType = ElementType::I64;
952
953    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
954        writer.write_i64(*self)
955    }
956}
957
958impl WriteThrift for OrderedF64 {
959    const ELEMENT_TYPE: ElementType = ElementType::Double;
960
961    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
962        writer.write_double(self.0)
963    }
964}
965
966impl WriteThrift for f64 {
967    const ELEMENT_TYPE: ElementType = ElementType::Double;
968
969    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
970        writer.write_double(*self)
971    }
972}
973
974impl WriteThrift for &[u8] {
975    const ELEMENT_TYPE: ElementType = ElementType::Binary;
976
977    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
978        writer.write_bytes(self)
979    }
980}
981
982impl WriteThrift for &str {
983    const ELEMENT_TYPE: ElementType = ElementType::Binary;
984
985    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
986        writer.write_bytes(self.as_bytes())
987    }
988}
989
990impl WriteThrift for String {
991    const ELEMENT_TYPE: ElementType = ElementType::Binary;
992
993    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
994        writer.write_bytes(self.as_bytes())
995    }
996}
997
998/// Trait implemented by objects that are fields of Thrift structs.
999///
1000/// For example, given the Thrift struct definition
1001/// ```ignore
1002/// struct MyStruct {
1003///   1: required i32 field1
1004///   2: optional bool field2
1005///   3: optional OtherStruct field3
1006/// }
1007/// ```
1008///
1009/// which becomes in Rust
1010/// ```no_run
1011/// # struct OtherStruct {}
1012/// struct MyStruct {
1013///   field1: i32,
1014///   field2: Option<bool>,
1015///   field3: Option<OtherStruct>,
1016/// }
1017/// ```
1018/// the impl of `WriteThrift` for `MyStruct` will use the `WriteThriftField` impls for `i32`,
1019/// `bool`, and `OtherStruct`.
1020///
1021/// ```ignore
1022/// impl WriteThrift for MyStruct {
1023///   fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
1024///     let mut last_field_id = 0i16;
1025///     last_field_id = self.field1.write_thrift_field(writer, 1, last_field_id)?;
1026///     if self.field2.is_some() {
1027///       // if field2 is `None` then this assignment won't happen and last_field_id will remain
1028///       // `1` when writing `field3`
1029///       last_field_id = self.field2.write_thrift_field(writer, 2, last_field_id)?;
1030///     }
1031///     if self.field3.is_some() {
1032///       // no need to assign last_field_id since this is the final field.
1033///       self.field3.write_thrift_field(writer, 3, last_field_id)?;
1034///     }
1035///     writer.write_struct_end()
1036///   }
1037/// }
1038/// ```
1039///
1040pub(crate) trait WriteThriftField {
1041    /// Used to write struct fields (which may be primitive or IDL defined types). This will
1042    /// write the field marker for the given `field_id`, using `last_field_id` to compute the
1043    /// field delta used by the Thrift [compact protocol]. On success this will return `field_id`
1044    /// to be used in chaining.
1045    ///
1046    /// [compact protocol]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#struct-encoding
1047    fn write_thrift_field<W: Write>(
1048        &self,
1049        writer: &mut ThriftCompactOutputProtocol<W>,
1050        field_id: i16,
1051        last_field_id: i16,
1052    ) -> Result<i16>;
1053}
1054
1055// bool struct fields are written differently to bool values
1056impl WriteThriftField for bool {
1057    fn write_thrift_field<W: Write>(
1058        &self,
1059        writer: &mut ThriftCompactOutputProtocol<W>,
1060        field_id: i16,
1061        last_field_id: i16,
1062    ) -> Result<i16> {
1063        // boolean only writes the field header
1064        match *self {
1065            true => writer.write_field_begin(FieldType::BooleanTrue, field_id, last_field_id)?,
1066            false => writer.write_field_begin(FieldType::BooleanFalse, field_id, last_field_id)?,
1067        }
1068        Ok(field_id)
1069    }
1070}
1071
// `WriteThriftField` impls for the remaining primitive types, each paired with the `FieldType`
// to use in its field header. The `write_thrift_field!` macro is defined elsewhere in the
// crate; presumably it expands to an impl that writes the field header and then the value via
// `WriteThrift` -- confirm against the macro definition.
write_thrift_field!(i8, FieldType::Byte);
write_thrift_field!(i16, FieldType::I16);
write_thrift_field!(i32, FieldType::I32);
write_thrift_field!(i64, FieldType::I64);
write_thrift_field!(OrderedF64, FieldType::Double);
write_thrift_field!(f64, FieldType::Double);
write_thrift_field!(String, FieldType::Binary);
1079
1080impl WriteThriftField for &[u8] {
1081    fn write_thrift_field<W: Write>(
1082        &self,
1083        writer: &mut ThriftCompactOutputProtocol<W>,
1084        field_id: i16,
1085        last_field_id: i16,
1086    ) -> Result<i16> {
1087        writer.write_field_begin(FieldType::Binary, field_id, last_field_id)?;
1088        writer.write_bytes(self)?;
1089        Ok(field_id)
1090    }
1091}
1092
1093impl WriteThriftField for &str {
1094    fn write_thrift_field<W: Write>(
1095        &self,
1096        writer: &mut ThriftCompactOutputProtocol<W>,
1097        field_id: i16,
1098        last_field_id: i16,
1099    ) -> Result<i16> {
1100        writer.write_field_begin(FieldType::Binary, field_id, last_field_id)?;
1101        writer.write_bytes(self.as_bytes())?;
1102        Ok(field_id)
1103    }
1104}
1105
1106impl<T> WriteThriftField for Vec<T>
1107where
1108    T: WriteThrift,
1109{
1110    fn write_thrift_field<W: Write>(
1111        &self,
1112        writer: &mut ThriftCompactOutputProtocol<W>,
1113        field_id: i16,
1114        last_field_id: i16,
1115    ) -> Result<i16> {
1116        writer.write_field_begin(FieldType::List, field_id, last_field_id)?;
1117        self.write_thrift(writer)?;
1118        Ok(field_id)
1119    }
1120}
1121
#[cfg(test)]
pub(crate) mod tests {
    use crate::basic::{TimeUnit, Type};

    use super::*;
    use std::fmt::Debug;

    /// Serialize `val` into an in-memory buffer, decode it again with a
    /// [`ThriftSliceInputProtocol`], and assert that the decoded value equals the input.
    pub(crate) fn test_roundtrip<T>(val: T)
    where
        T: for<'a> ReadThrift<'a, ThriftSliceInputProtocol<'a>> + WriteThrift + PartialEq + Debug,
    {
        let mut encoded: Vec<u8> = Vec::new();
        // the writer is a temporary, so its borrow of `encoded` ends with this statement
        val.write_thrift(&mut ThriftCompactOutputProtocol::new(&mut encoded))
            .unwrap();

        let mut decoder = ThriftSliceInputProtocol::new(&encoded);
        let decoded = T::read_thrift(&mut decoder).unwrap();
        assert_eq!(val, decoded);
    }

    #[test]
    fn test_enum_roundtrip() {
        let physical_types = [
            Type::BOOLEAN,
            Type::INT32,
            Type::INT64,
            Type::INT96,
            Type::FLOAT,
            Type::DOUBLE,
            Type::BYTE_ARRAY,
            Type::FIXED_LEN_BYTE_ARRAY,
        ];
        for physical_type in physical_types {
            test_roundtrip(physical_type);
        }
    }

    #[test]
    fn test_union_all_empty_roundtrip() {
        for unit in [TimeUnit::MILLIS, TimeUnit::MICROS, TimeUnit::NANOS] {
            test_roundtrip(unit);
        }
    }

    #[test]
    fn test_decode_empty_list() {
        // a lone zero byte is a list header with size 0
        let data = [0u8];
        let mut prot = ThriftSliceInputProtocol::new(&data);
        let header = prot.read_list_begin().expect("error reading list header");
        assert_eq!(header.size, 0);
        assert_eq!(header.element_type, ElementType::Byte);
    }

    /// If the `size` varint of a Thrift list header decodes to more than `i32::MAX`, the
    /// protocol layer must report an error instead of letting the value wrap into a negative
    /// `i32` that could then reach downstream allocation code.
    #[test]
    fn test_read_list_begin_size_above_i32_max_returns_err() {
        // 0xF8: element_type=8 (Binary), size nibble 0xF = length follows as a varint.
        // 80 80 80 80 08: varint for 0x8000_0000 = i32::MAX + 1.
        let data: [u8; 6] = [0xF8, 0x80, 0x80, 0x80, 0x80, 0x08];
        let mut prot = ThriftSliceInputProtocol::new(&data);
        let result = prot.read_list_begin();
        assert!(result.is_err(), "expected error, got {result:?}");
    }

    #[test]
    fn test_read_list_wrong_type() {
        // list header: 4 elements of `Boolean`
        let data = [0x42, 0x01];
        let mut prot = ThriftSliceInputProtocol::new(&data);
        // decoding it as list<i32> must fail with an element type mismatch
        let result = read_thrift_vec::<i32, ThriftSliceInputProtocol>(&mut prot);
        println!("{result:?}");
        let message = result.unwrap_err().to_string();
        assert!(message.contains("Expected list element type of I32 but got Bool"));
    }
}