parquet/
parquet_thrift.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Structs used for encoding and decoding Parquet Thrift objects.
19//!
20//! These include:
21//! * [`ThriftCompactInputProtocol`]: Trait implemented by Thrift decoders.
22//!     * [`ThriftSliceInputProtocol`]: Thrift decoder that takes a slice of bytes as input.
23//!     * [`ThriftReadInputProtocol`]: Thrift decoder that takes a [`Read`] as input.
//! * [`ReadThrift`]: Trait implemented by deserializable objects.
25//! * [`ThriftCompactOutputProtocol`]: Thrift encoder.
26//! * [`WriteThrift`]: Trait implemented by serializable objects.
27//! * [`WriteThriftField`]: Trait implemented by serializable objects that are fields in Thrift structs.
28
29use std::{
30    cmp::Ordering,
31    io::{Read, Write},
32};
33
34use crate::{
35    errors::{ParquetError, Result},
36    write_thrift_field,
37};
38use std::io::Error;
39use std::str::Utf8Error;
40
/// Errors produced by the low-level Thrift compact protocol decoder.
///
/// These are converted into [`ParquetError`] by `From<ThriftProtocolError> for ParquetError`.
/// Payloads are kept small (e.g. the UTF-8 error details are dropped).
#[derive(Debug)]
pub(crate) enum ThriftProtocolError {
    /// The input ended before the requested data could be read.
    Eof,
    /// An I/O error, e.g. from the reader backing a protocol object.
    IO(Error),
    /// A type id that could not be converted into a [`FieldType`].
    InvalidFieldType(u8),
    /// A list/set header contained a type id that is not a valid [`ElementType`].
    InvalidElementType(u8),
    /// Adding a short-form field id delta to the previous field id overflowed `i16`.
    FieldDeltaOverflow { field_delta: u8, last_field_id: i16 },
    /// A boolean byte had a value other than 0, 1 or 2.
    InvalidBoolean(u8),
    /// A Thrift string was not valid UTF-8 (the original error payload is discarded).
    Utf8Error,
    /// Skipping a value exceeded the maximum allowed nesting depth.
    SkipDepth(FieldType),
    /// Skipping was requested for a field type the skipper does not handle.
    SkipUnsupportedType(FieldType),
}
53
54impl From<ThriftProtocolError> for ParquetError {
55    #[inline(never)]
56    fn from(e: ThriftProtocolError) -> Self {
57        match e {
58            ThriftProtocolError::Eof => eof_err!("Unexpected EOF"),
59            ThriftProtocolError::IO(e) => e.into(),
60            ThriftProtocolError::InvalidFieldType(value) => {
61                general_err!("Unexpected struct field type {}", value)
62            }
63            ThriftProtocolError::InvalidElementType(value) => {
64                general_err!("Unexpected list/set element type {}", value)
65            }
66            ThriftProtocolError::FieldDeltaOverflow {
67                field_delta,
68                last_field_id,
69            } => general_err!("cannot add {} to {}", field_delta, last_field_id),
70            ThriftProtocolError::InvalidBoolean(value) => {
71                general_err!("cannot convert {} into bool", value)
72            }
73            ThriftProtocolError::Utf8Error => general_err!("invalid utf8"),
74            ThriftProtocolError::SkipDepth(field_type) => {
75                general_err!("cannot parse past {:?}", field_type)
76            }
77            ThriftProtocolError::SkipUnsupportedType(field_type) => {
78                general_err!("cannot skip field type {:?}", field_type)
79            }
80        }
81    }
82}
83
84impl From<Utf8Error> for ThriftProtocolError {
85    fn from(_: Utf8Error) -> Self {
86        // ignore error payload to reduce the size of ThriftProtocolError
87        Self::Utf8Error
88    }
89}
90
91impl From<Error> for ThriftProtocolError {
92    fn from(e: Error) -> Self {
93        Self::IO(e)
94    }
95}
96
97pub type ThriftProtocolResult<T> = Result<T, ThriftProtocolError>;
98
/// Wrapper for thrift `double` fields. This is used to provide
/// an implementation of `Eq` for floats.
///
/// Equality and ordering both use the IEEE 754 total order ([`f64::total_cmp`]), so
/// `PartialEq`, `Eq` and `Ord` agree with each other. In particular a NaN compares equal to a
/// NaN with the same bit pattern (making `==` reflexive, as `Eq` requires), and `-0.0` is not
/// equal to `0.0`.
#[derive(Debug, Clone, Copy)]
pub struct OrderedF64(f64);

impl PartialEq for OrderedF64 {
    fn eq(&self, other: &Self) -> bool {
        // Must match `Ord::cmp`; IEEE `==` would make NaN != NaN and break the `Eq` contract.
        self.0.total_cmp(&other.0).is_eq()
    }
}
104
105impl From<f64> for OrderedF64 {
106    fn from(value: f64) -> Self {
107        Self(value)
108    }
109}
110
111impl From<OrderedF64> for f64 {
112    fn from(value: OrderedF64) -> Self {
113        value.0
114    }
115}
116
// `Eq` is a marker trait with no methods. It asserts that `PartialEq` is a full equivalence
// relation, i.e. also reflexive (`x == x` for every value, NaN included).
impl Eq for OrderedF64 {}
118
119impl Ord for OrderedF64 {
120    fn cmp(&self, other: &Self) -> Ordering {
121        self.0.total_cmp(&other.0)
122    }
123}
124
125impl PartialOrd for OrderedF64 {
126    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
127        Some(self.cmp(other))
128    }
129}
130
/// Thrift compact protocol type ids for struct fields.
///
/// These are the values stored in the low nibble of a struct field header byte. Boolean
/// fields carry no payload; their value is encoded by the type id itself
/// (`BooleanTrue` / `BooleanFalse`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum FieldType {
    /// End-of-struct marker.
    Stop = 0,
    /// Boolean field whose value is `true`.
    BooleanTrue = 1,
    /// Boolean field whose value is `false`.
    BooleanFalse = 2,
    Byte = 3,
    I16 = 4,
    I32 = 5,
    I64 = 6,
    Double = 7,
    /// Binary data or a string.
    Binary = 8,
    List = 9,
    Set = 10,
    Map = 11,
    Struct = 12,
}
148
149impl TryFrom<u8> for FieldType {
150    type Error = ThriftProtocolError;
151    fn try_from(value: u8) -> ThriftProtocolResult<Self> {
152        match value {
153            0 => Ok(Self::Stop),
154            1 => Ok(Self::BooleanTrue),
155            2 => Ok(Self::BooleanFalse),
156            3 => Ok(Self::Byte),
157            4 => Ok(Self::I16),
158            5 => Ok(Self::I32),
159            6 => Ok(Self::I64),
160            7 => Ok(Self::Double),
161            8 => Ok(Self::Binary),
162            9 => Ok(Self::List),
163            10 => Ok(Self::Set),
164            11 => Ok(Self::Map),
165            12 => Ok(Self::Struct),
166            _ => Err(ThriftProtocolError::InvalidFieldType(value)),
167        }
168    }
169}
170
171impl TryFrom<ElementType> for FieldType {
172    type Error = ThriftProtocolError;
173    fn try_from(value: ElementType) -> std::result::Result<Self, Self::Error> {
174        match value {
175            ElementType::Bool => Ok(Self::BooleanTrue),
176            ElementType::Byte => Ok(Self::Byte),
177            ElementType::I16 => Ok(Self::I16),
178            ElementType::I32 => Ok(Self::I32),
179            ElementType::I64 => Ok(Self::I64),
180            ElementType::Double => Ok(Self::Double),
181            ElementType::Binary => Ok(Self::Binary),
182            ElementType::List => Ok(Self::List),
183            ElementType::Struct => Ok(Self::Struct),
184            _ => Err(ThriftProtocolError::InvalidFieldType(value as u8)),
185        }
186    }
187}
188
/// Thrift compact protocol type ids for list/set elements.
///
/// These are stored in the low nibble of a list header byte. Unlike [`FieldType`] there is a
/// single boolean type; boolean list elements are encoded as one byte each.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum ElementType {
    /// Boolean elements (type id 1 is also accepted when decoding).
    Bool = 2,
    Byte = 3,
    I16 = 4,
    I32 = 5,
    I64 = 6,
    Double = 7,
    /// Binary data or strings.
    Binary = 8,
    List = 9,
    Set = 10,
    Map = 11,
    Struct = 12,
}
204
205impl TryFrom<u8> for ElementType {
206    type Error = ThriftProtocolError;
207    fn try_from(value: u8) -> ThriftProtocolResult<Self> {
208        match value {
209            // For historical and compatibility reasons, a reader should be capable to deal with both cases.
210            // The only valid value in the original spec was 2, but due to an widespread implementation bug
211            // the defacto standard across large parts of the library became 1 instead.
212            // As a result, both values are now allowed.
213            // https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#list-and-set
214            1 | 2 => Ok(Self::Bool),
215            3 => Ok(Self::Byte),
216            4 => Ok(Self::I16),
217            5 => Ok(Self::I32),
218            6 => Ok(Self::I64),
219            7 => Ok(Self::Double),
220            8 => Ok(Self::Binary),
221            9 => Ok(Self::List),
222            10 => Ok(Self::Set),
223            11 => Ok(Self::Map),
224            12 => Ok(Self::Struct),
225            _ => Err(ThriftProtocolError::InvalidElementType(value)),
226        }
227    }
228}
229
/// Struct used to describe a [thrift struct] field during decoding.
///
/// Produced by `ThriftCompactInputProtocol::read_field_begin`. A `field_type` of
/// [`FieldType::Stop`] marks the end of the struct; in that case `id` is 0 and `bool_val` is
/// `None`.
///
/// [thrift struct]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#struct-encoding
pub(crate) struct FieldIdentifier {
    /// The type for the field.
    pub(crate) field_type: FieldType,
    /// The field's `id`. May be computed from delta or directly decoded.
    pub(crate) id: i16,
    /// Stores the value for booleans.
    ///
    /// Boolean fields store no data, instead the field type is either boolean true, or
    /// boolean false. This is `None` for every non-boolean field type.
    pub(crate) bool_val: Option<bool>,
}
244
/// Struct used to describe a [thrift list].
///
/// [thrift list]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#list-and-set
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct ListIdentifier {
    /// The type for each element in the list.
    pub(crate) element_type: ElementType,
    /// Number of elements contained in the list.
    ///
    /// NOTE: `read_list_begin` casts the decoded varint to `i32`, so corrupt input can yield a
    /// negative value here; consumers should validate it before using it as a length.
    pub(crate) size: i32,
}
255
256/// Low-level object used to deserialize structs encoded with the Thrift [compact] protocol.
257///
258/// Implementation of this trait must provide the low-level functions `read_byte`, `read_bytes`,
259/// `skip_bytes`, and `read_double`. These primitives are used by the default functions provided
260/// here to perform deserialization.
261///
262/// [compact]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md
263pub(crate) trait ThriftCompactInputProtocol<'a> {
264    /// Read a single byte from the input.
265    fn read_byte(&mut self) -> ThriftProtocolResult<u8>;
266
267    /// Read a Thrift encoded [binary] from the input.
268    ///
269    /// [binary]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#binary-encoding
270    fn read_bytes(&mut self) -> ThriftProtocolResult<&'a [u8]>;
271
272    fn read_bytes_owned(&mut self) -> ThriftProtocolResult<Vec<u8>>;
273
274    /// Skip the next `n` bytes of input.
275    fn skip_bytes(&mut self, n: usize) -> ThriftProtocolResult<()>;
276
277    /// Read a ULEB128 encoded unsigned varint from the input.
278    fn read_vlq(&mut self) -> ThriftProtocolResult<u64> {
279        // try the happy path first
280        let byte = self.read_byte()?;
281        if byte & 0x80 == 0 {
282            return Ok(byte as u64);
283        }
284        let mut in_progress = (byte & 0x7f) as u64;
285        let mut shift = 7;
286        loop {
287            let byte = self.read_byte()?;
288            in_progress |= ((byte & 0x7F) as u64).wrapping_shl(shift);
289            if byte & 0x80 == 0 {
290                return Ok(in_progress);
291            }
292            shift += 7;
293        }
294    }
295
296    /// Read a zig-zag encoded signed varint from the input.
297    fn read_zig_zag(&mut self) -> ThriftProtocolResult<i64> {
298        let val = self.read_vlq()?;
299        Ok((val >> 1) as i64 ^ -((val & 1) as i64))
300    }
301
302    /// Read the [`ListIdentifier`] for a Thrift encoded list.
303    fn read_list_begin(&mut self) -> ThriftProtocolResult<ListIdentifier> {
304        let header = self.read_byte()?;
305        // some parquet writers will have an element_type of 0 for an empty list.
306        // account for that and return a bogus but valid element_type.
307        if header == 0 {
308            return Ok(ListIdentifier {
309                element_type: ElementType::Byte,
310                size: 0,
311            });
312        }
313        let element_type = ElementType::try_from(header & 0x0f)?;
314
315        let possible_element_count = (header & 0xF0) >> 4;
316        let element_count = if possible_element_count != 15 {
317            // high bits set high if count and type encoded separately
318            possible_element_count as i32
319        } else {
320            self.read_vlq()? as _
321        };
322
323        Ok(ListIdentifier {
324            element_type,
325            size: element_count,
326        })
327    }
328
329    // Full field ids are uncommon.
330    // Not inlining this method reduces the code size of `read_field_begin`, which then ideally gets
331    // inlined everywhere.
332    #[cold]
333    fn read_full_field_id(&mut self) -> ThriftProtocolResult<i16> {
334        self.read_i16()
335    }
336
337    /// Read the [`FieldIdentifier`] for a field in a Thrift encoded struct.
338    fn read_field_begin(&mut self, last_field_id: i16) -> ThriftProtocolResult<FieldIdentifier> {
339        // we can read at least one byte, which is:
340        // - the type
341        // - the field delta and the type
342        let field_type = self.read_byte()?;
343        let field_delta = (field_type & 0xf0) >> 4;
344        let field_type = FieldType::try_from(field_type & 0xf)?;
345        let mut bool_val: Option<bool> = None;
346
347        match field_type {
348            FieldType::Stop => Ok(FieldIdentifier {
349                field_type: FieldType::Stop,
350                id: 0,
351                bool_val,
352            }),
353            _ => {
354                // special handling for bools
355                if field_type == FieldType::BooleanFalse {
356                    bool_val = Some(false);
357                } else if field_type == FieldType::BooleanTrue {
358                    bool_val = Some(true);
359                }
360                let field_id = if field_delta != 0 {
361                    last_field_id.checked_add(field_delta as i16).ok_or(
362                        ThriftProtocolError::FieldDeltaOverflow {
363                            field_delta,
364                            last_field_id,
365                        },
366                    )?
367                } else {
368                    self.read_full_field_id()?
369                };
370
371                Ok(FieldIdentifier {
372                    field_type,
373                    id: field_id,
374                    bool_val,
375                })
376            }
377        }
378    }
379
380    /// This is a specialized version of [`Self::read_field_begin`], solely for use in parsing
381    /// simple structs. This function assumes that the delta field will always be less than 0xf,
382    /// fields will be in order, and no boolean fields will be read.
383    /// This also skips validation of the field type.
384    ///
385    /// Returns a tuple of `(field_type, field_delta)`.
386    fn read_field_header(&mut self) -> ThriftProtocolResult<(u8, u8)> {
387        let field_type = self.read_byte()?;
388        let field_delta = (field_type & 0xf0) >> 4;
389        let field_type = field_type & 0xf;
390        Ok((field_type, field_delta))
391    }
392
393    /// Read a boolean list element. This should not be used for struct fields. For the latter,
394    /// use the [`FieldIdentifier::bool_val`] field.
395    fn read_bool(&mut self) -> ThriftProtocolResult<bool> {
396        let b = self.read_byte()?;
397        // Previous versions of the thrift specification said to use 0 and 1 inside collections,
398        // but that differed from existing implementations.
399        // The specification was updated in https://github.com/apache/thrift/commit/2c29c5665bc442e703480bb0ee60fe925ffe02e8.
400        // At least the go implementation seems to have followed the previously documented values.
401        match b {
402            0x01 => Ok(true),
403            0x00 | 0x02 => Ok(false),
404            _ => Err(ThriftProtocolError::InvalidBoolean(b)),
405        }
406    }
407
408    /// Read a Thrift [binary] as a UTF-8 encoded string.
409    ///
410    /// [binary]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#binary-encoding
411    fn read_string(&mut self) -> ThriftProtocolResult<&'a str> {
412        let slice = self.read_bytes()?;
413        Ok(std::str::from_utf8(slice)?)
414    }
415
416    /// Read an `i8`.
417    fn read_i8(&mut self) -> ThriftProtocolResult<i8> {
418        Ok(self.read_byte()? as _)
419    }
420
421    /// Read an `i16`.
422    fn read_i16(&mut self) -> ThriftProtocolResult<i16> {
423        Ok(self.read_zig_zag()? as _)
424    }
425
426    /// Read an `i32`.
427    fn read_i32(&mut self) -> ThriftProtocolResult<i32> {
428        Ok(self.read_zig_zag()? as _)
429    }
430
431    /// Read an `i64`.
432    fn read_i64(&mut self) -> ThriftProtocolResult<i64> {
433        self.read_zig_zag()
434    }
435
436    /// Read a Thrift `double` as `f64`.
437    fn read_double(&mut self) -> ThriftProtocolResult<f64>;
438
439    /// Skip a ULEB128 encoded varint.
440    fn skip_vlq(&mut self) -> ThriftProtocolResult<()> {
441        loop {
442            let byte = self.read_byte()?;
443            if byte & 0x80 == 0 {
444                return Ok(());
445            }
446        }
447    }
448
449    /// Skip a thrift [binary].
450    ///
451    /// [binary]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#binary-encoding
452    fn skip_binary(&mut self) -> ThriftProtocolResult<()> {
453        let len = self.read_vlq()? as usize;
454        self.skip_bytes(len)
455    }
456
457    /// Skip a field with type `field_type` recursively until the default
458    /// maximum skip depth (currently 64) is reached.
459    fn skip(&mut self, field_type: FieldType) -> ThriftProtocolResult<()> {
460        const DEFAULT_SKIP_DEPTH: i8 = 64;
461        self.skip_till_depth(field_type, DEFAULT_SKIP_DEPTH)
462    }
463
464    /// Empty structs in unions consist of a single byte of 0 for the field stop record.
465    /// This skips that byte without encuring the cost of processing the [`FieldIdentifier`].
466    /// Will return an error if the struct is not actually empty.
467    fn skip_empty_struct(&mut self) -> Result<()> {
468        let b = self.read_byte()?;
469        if b != 0 {
470            Err(general_err!("Empty struct has fields"))
471        } else {
472            Ok(())
473        }
474    }
475
476    /// Skip a field with type `field_type` recursively up to `depth` levels.
477    fn skip_till_depth(&mut self, field_type: FieldType, depth: i8) -> ThriftProtocolResult<()> {
478        if depth == 0 {
479            return Err(ThriftProtocolError::SkipDepth(field_type));
480        }
481
482        match field_type {
483            // boolean field has no data
484            FieldType::BooleanFalse | FieldType::BooleanTrue => Ok(()),
485            FieldType::Byte => self.read_i8().map(|_| ()),
486            FieldType::I16 => self.skip_vlq().map(|_| ()),
487            FieldType::I32 => self.skip_vlq().map(|_| ()),
488            FieldType::I64 => self.skip_vlq().map(|_| ()),
489            FieldType::Double => self.skip_bytes(8).map(|_| ()),
490            FieldType::Binary => self.skip_binary().map(|_| ()),
491            FieldType::Struct => {
492                let mut last_field_id = 0i16;
493                loop {
494                    let field_ident = self.read_field_begin(last_field_id)?;
495                    if field_ident.field_type == FieldType::Stop {
496                        break;
497                    }
498                    self.skip_till_depth(field_ident.field_type, depth - 1)?;
499                    last_field_id = field_ident.id;
500                }
501                Ok(())
502            }
503            FieldType::List => {
504                let list_ident = self.read_list_begin()?;
505                for _ in 0..list_ident.size {
506                    let element_type = FieldType::try_from(list_ident.element_type)?;
507                    self.skip_till_depth(element_type, depth - 1)?;
508                }
509                Ok(())
510            }
511            // no list or map types in parquet format
512            _ => Err(ThriftProtocolError::SkipUnsupportedType(field_type)),
513        }
514    }
515}
516
/// A high performance Thrift reader that decodes directly from a borrowed byte slice.
///
/// Every read consumes bytes from the front of the slice. Binary and string values can be
/// returned as sub-slices of the input without copying.
pub(crate) struct ThriftSliceInputProtocol<'a> {
    buf: &'a [u8],
}

impl<'a> ThriftSliceInputProtocol<'a> {
    /// Creates a reader positioned at the start of `buf`.
    pub fn new(buf: &'a [u8]) -> Self {
        ThriftSliceInputProtocol { buf }
    }

    /// Returns the bytes that have not been consumed yet.
    pub fn as_slice(&self) -> &'a [u8] {
        let Self { buf } = self;
        buf
    }
}
533
534impl<'b, 'a: 'b> ThriftCompactInputProtocol<'b> for ThriftSliceInputProtocol<'a> {
535    #[inline]
536    fn read_byte(&mut self) -> ThriftProtocolResult<u8> {
537        let ret = *self.buf.first().ok_or(ThriftProtocolError::Eof)?;
538        self.buf = &self.buf[1..];
539        Ok(ret)
540    }
541
542    fn read_bytes(&mut self) -> ThriftProtocolResult<&'b [u8]> {
543        let len = self.read_vlq()? as usize;
544        let ret = self.buf.get(..len).ok_or(ThriftProtocolError::Eof)?;
545        self.buf = &self.buf[len..];
546        Ok(ret)
547    }
548
549    fn read_bytes_owned(&mut self) -> ThriftProtocolResult<Vec<u8>> {
550        Ok(self.read_bytes()?.to_vec())
551    }
552
553    #[inline]
554    fn skip_bytes(&mut self, n: usize) -> ThriftProtocolResult<()> {
555        self.buf.get(..n).ok_or(ThriftProtocolError::Eof)?;
556        self.buf = &self.buf[n..];
557        Ok(())
558    }
559
560    fn read_double(&mut self) -> ThriftProtocolResult<f64> {
561        let slice = self.buf.get(..8).ok_or(ThriftProtocolError::Eof)?;
562        self.buf = &self.buf[8..];
563        match slice.try_into() {
564            Ok(slice) => Ok(f64::from_le_bytes(slice)),
565            Err(_) => unreachable!(),
566        }
567    }
568}
569
/// A Thrift input protocol that decodes from any [`Read`] implementation.
///
/// This is only intended for reading Parquet page headers. Because the data is pulled from a
/// stream, binary values cannot be handed out as slices borrowed from the input. `read_bytes`
/// is therefore not supported by this protocol; binary data must be read with
/// `read_bytes_owned` or skipped.
pub(crate) struct ThriftReadInputProtocol<R: Read> {
    reader: R,
}

impl<R: Read> ThriftReadInputProtocol<R> {
    /// Wraps `reader` so Thrift compact data can be decoded from it.
    pub(crate) fn new(reader: R) -> Self {
        ThriftReadInputProtocol { reader }
    }
}
583
584impl<'a, R: Read> ThriftCompactInputProtocol<'a> for ThriftReadInputProtocol<R> {
585    #[inline]
586    fn read_byte(&mut self) -> ThriftProtocolResult<u8> {
587        let mut buf = [0_u8; 1];
588        self.reader.read_exact(&mut buf)?;
589        Ok(buf[0])
590    }
591
592    fn read_bytes(&mut self) -> ThriftProtocolResult<&'a [u8]> {
593        unimplemented!()
594    }
595
596    fn read_bytes_owned(&mut self) -> ThriftProtocolResult<Vec<u8>> {
597        let len = self.read_vlq()? as usize;
598        let mut v = Vec::with_capacity(len);
599        std::io::copy(&mut self.reader.by_ref().take(len as u64), &mut v)?;
600        Ok(v)
601    }
602
603    fn skip_bytes(&mut self, n: usize) -> ThriftProtocolResult<()> {
604        std::io::copy(
605            &mut self.reader.by_ref().take(n as u64),
606            &mut std::io::sink(),
607        )?;
608        Ok(())
609    }
610
611    fn read_double(&mut self) -> ThriftProtocolResult<f64> {
612        let mut buf = [0_u8; 8];
613        self.reader.read_exact(&mut buf)?;
614        Ok(f64::from_le_bytes(buf))
615    }
616}
617
/// Trait implemented for objects that can be deserialized from a Thrift input stream.
/// Implementations are provided for Thrift primitive types.
///
/// `'a` is the lifetime of the protocol's input data. This lets implementations such as
/// `&'a str` and `&'a [u8]` borrow directly from the input instead of copying.
pub(crate) trait ReadThrift<'a, R: ThriftCompactInputProtocol<'a>> {
    /// Read an object of type `Self` from the input protocol object.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying protocol reports one, e.g. on truncated input.
    fn read_thrift(prot: &mut R) -> Result<Self>
    where
        Self: Sized;
}
626
627impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for bool {
628    fn read_thrift(prot: &mut R) -> Result<Self> {
629        Ok(prot.read_bool()?)
630    }
631}
632
633impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i8 {
634    fn read_thrift(prot: &mut R) -> Result<Self> {
635        Ok(prot.read_i8()?)
636    }
637}
638
639impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i16 {
640    fn read_thrift(prot: &mut R) -> Result<Self> {
641        Ok(prot.read_i16()?)
642    }
643}
644
645impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i32 {
646    fn read_thrift(prot: &mut R) -> Result<Self> {
647        Ok(prot.read_i32()?)
648    }
649}
650
651impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i64 {
652    fn read_thrift(prot: &mut R) -> Result<Self> {
653        Ok(prot.read_i64()?)
654    }
655}
656
657impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for OrderedF64 {
658    fn read_thrift(prot: &mut R) -> Result<Self> {
659        Ok(OrderedF64(prot.read_double()?))
660    }
661}
662
663impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for &'a str {
664    fn read_thrift(prot: &mut R) -> Result<Self> {
665        Ok(prot.read_string()?)
666    }
667}
668
669impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for String {
670    fn read_thrift(prot: &mut R) -> Result<Self> {
671        Ok(String::from_utf8(prot.read_bytes_owned()?)?)
672    }
673}
674
675impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for &'a [u8] {
676    fn read_thrift(prot: &mut R) -> Result<Self> {
677        Ok(prot.read_bytes()?)
678    }
679}
680
681/// Read a Thrift encoded [list] from the input protocol object.
682///
683/// [list]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#list-and-set
684pub(crate) fn read_thrift_vec<'a, T, R>(prot: &mut R) -> Result<Vec<T>>
685where
686    R: ThriftCompactInputProtocol<'a>,
687    T: ReadThrift<'a, R>,
688{
689    let list_ident = prot.read_list_begin()?;
690    let mut res = Vec::with_capacity(list_ident.size as usize);
691    for _ in 0..list_ident.size {
692        let val = T::read_thrift(prot)?;
693        res.push(val);
694    }
695    Ok(res)
696}
697
698/////////////////////////
699// thrift compact output
700
701/// Low-level object used to serialize structs to the Thrift [compact output] protocol.
702///
703/// This struct serves as a wrapper around a [`Write`] object, to which thrift encoded data
704/// will written. The implementation provides functions to write Thrift primitive types, as well
705/// as functions used in the encoding of lists and structs. This should rarely be used directly,
706/// but is instead intended for use by implementers of [`WriteThrift`] and [`WriteThriftField`].
707///
708/// [compact output]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md
709pub(crate) struct ThriftCompactOutputProtocol<W: Write> {
710    writer: W,
711}
712
713impl<W: Write> ThriftCompactOutputProtocol<W> {
714    /// Create a new `ThriftCompactOutputProtocol` wrapping the byte sink `writer`.
715    pub(crate) fn new(writer: W) -> Self {
716        Self { writer }
717    }
718
719    /// Write a single byte to the output stream.
720    fn write_byte(&mut self, b: u8) -> Result<()> {
721        self.writer.write_all(&[b])?;
722        Ok(())
723    }
724
725    /// Write the given `u64` as a ULEB128 encoded varint.
726    fn write_vlq(&mut self, val: u64) -> Result<()> {
727        let mut v = val;
728        while v > 0x7f {
729            self.write_byte(v as u8 | 0x80)?;
730            v >>= 7;
731        }
732        self.write_byte(v as u8)
733    }
734
735    /// Write the given `i64` as a zig-zag encoded varint.
736    fn write_zig_zag(&mut self, val: i64) -> Result<()> {
737        let s = (val < 0) as i64;
738        self.write_vlq((((val ^ -s) << 1) + s) as u64)
739    }
740
741    /// Used to mark the start of a Thrift struct field of type `field_type`. `last_field_id`
742    /// is used to compute a delta to the given `field_id` per the compact protocol [spec].
743    ///
744    /// [spec]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#struct-encoding
745    pub(crate) fn write_field_begin(
746        &mut self,
747        field_type: FieldType,
748        field_id: i16,
749        last_field_id: i16,
750    ) -> Result<()> {
751        let delta = field_id.wrapping_sub(last_field_id);
752        if delta > 0 && delta <= 0xf {
753            self.write_byte((delta as u8) << 4 | field_type as u8)
754        } else {
755            self.write_byte(field_type as u8)?;
756            self.write_i16(field_id)
757        }
758    }
759
760    /// Used to indicate the start of a list of `element_type` elements.
761    pub(crate) fn write_list_begin(&mut self, element_type: ElementType, len: usize) -> Result<()> {
762        if len < 15 {
763            self.write_byte((len as u8) << 4 | element_type as u8)
764        } else {
765            self.write_byte(0xf0u8 | element_type as u8)?;
766            self.write_vlq(len as _)
767        }
768    }
769
770    /// Used to mark the end of a struct. This must be called after all fields of the struct have
771    /// been written.
772    pub(crate) fn write_struct_end(&mut self) -> Result<()> {
773        self.write_byte(0)
774    }
775
776    /// Serialize a slice of `u8`s. This will encode a length, and then write the bytes without
777    /// further encoding.
778    pub(crate) fn write_bytes(&mut self, val: &[u8]) -> Result<()> {
779        self.write_vlq(val.len() as u64)?;
780        self.writer.write_all(val)?;
781        Ok(())
782    }
783
784    /// Short-cut method used to encode structs that have no fields (often used in Thrift unions).
785    /// This simply encodes the field id and then immediately writes the end-of-struct marker.
786    pub(crate) fn write_empty_struct(&mut self, field_id: i16, last_field_id: i16) -> Result<i16> {
787        self.write_field_begin(FieldType::Struct, field_id, last_field_id)?;
788        self.write_struct_end()?;
789        Ok(last_field_id)
790    }
791
792    /// Write a boolean value.
793    pub(crate) fn write_bool(&mut self, val: bool) -> Result<()> {
794        match val {
795            true => self.write_byte(1),
796            false => self.write_byte(2),
797        }
798    }
799
800    /// Write a zig-zag encoded `i8` value.
801    pub(crate) fn write_i8(&mut self, val: i8) -> Result<()> {
802        self.write_byte(val as u8)
803    }
804
805    /// Write a zig-zag encoded `i16` value.
806    pub(crate) fn write_i16(&mut self, val: i16) -> Result<()> {
807        self.write_zig_zag(val as _)
808    }
809
810    /// Write a zig-zag encoded `i32` value.
811    pub(crate) fn write_i32(&mut self, val: i32) -> Result<()> {
812        self.write_zig_zag(val as _)
813    }
814
815    /// Write a zig-zag encoded `i64` value.
816    pub(crate) fn write_i64(&mut self, val: i64) -> Result<()> {
817        self.write_zig_zag(val as _)
818    }
819
820    /// Write a double value.
821    pub(crate) fn write_double(&mut self, val: f64) -> Result<()> {
822        self.writer.write_all(&val.to_le_bytes())?;
823        Ok(())
824    }
825}
826
/// Trait implemented by objects that are to be serialized to a Thrift [compact output] protocol
/// stream. Implementations are also provided for primitive Thrift types.
///
/// [compact output]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md
pub(crate) trait WriteThrift {
    /// The [`ElementType`] to use when a list of this object is written.
    ///
    /// The `Vec<T>` implementation writes `T::ELEMENT_TYPE` into the list header.
    const ELEMENT_TYPE: ElementType;

    /// Serialize this object to the given `writer`.
    ///
    /// # Errors
    ///
    /// Returns any error produced while writing to the underlying sink.
    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()>;
}
838
839/// Implementation for a vector of thrift serializable objects that implement [`WriteThrift`].
840/// This will write the necessary list header and then serialize the elements one-at-a-time.
841impl<T> WriteThrift for Vec<T>
842where
843    T: WriteThrift,
844{
845    const ELEMENT_TYPE: ElementType = ElementType::List;
846
847    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
848        writer.write_list_begin(T::ELEMENT_TYPE, self.len())?;
849        for item in self {
850            item.write_thrift(writer)?;
851        }
852        Ok(())
853    }
854}
855
856impl WriteThrift for bool {
857    const ELEMENT_TYPE: ElementType = ElementType::Bool;
858
859    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
860        writer.write_bool(*self)
861    }
862}
863
864impl WriteThrift for i8 {
865    const ELEMENT_TYPE: ElementType = ElementType::Byte;
866
867    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
868        writer.write_i8(*self)
869    }
870}
871
872impl WriteThrift for i16 {
873    const ELEMENT_TYPE: ElementType = ElementType::I16;
874
875    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
876        writer.write_i16(*self)
877    }
878}
879
880impl WriteThrift for i32 {
881    const ELEMENT_TYPE: ElementType = ElementType::I32;
882
883    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
884        writer.write_i32(*self)
885    }
886}
887
888impl WriteThrift for i64 {
889    const ELEMENT_TYPE: ElementType = ElementType::I64;
890
891    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
892        writer.write_i64(*self)
893    }
894}
895
896impl WriteThrift for OrderedF64 {
897    const ELEMENT_TYPE: ElementType = ElementType::Double;
898
899    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
900        writer.write_double(self.0)
901    }
902}
903
904impl WriteThrift for f64 {
905    const ELEMENT_TYPE: ElementType = ElementType::Double;
906
907    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
908        writer.write_double(*self)
909    }
910}
911
912impl WriteThrift for &[u8] {
913    const ELEMENT_TYPE: ElementType = ElementType::Binary;
914
915    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
916        writer.write_bytes(self)
917    }
918}
919
920impl WriteThrift for &str {
921    const ELEMENT_TYPE: ElementType = ElementType::Binary;
922
923    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
924        writer.write_bytes(self.as_bytes())
925    }
926}
927
928impl WriteThrift for String {
929    const ELEMENT_TYPE: ElementType = ElementType::Binary;
930
931    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
932        writer.write_bytes(self.as_bytes())
933    }
934}
935
936/// Trait implemented by objects that are fields of Thrift structs.
937///
938/// For example, given the Thrift struct definition
939/// ```ignore
940/// struct MyStruct {
941///   1: required i32 field1
942///   2: optional bool field2
943///   3: optional OtherStruct field3
944/// }
945/// ```
946///
947/// which becomes in Rust
948/// ```no_run
949/// # struct OtherStruct {}
950/// struct MyStruct {
951///   field1: i32,
952///   field2: Option<bool>,
953///   field3: Option<OtherStruct>,
954/// }
955/// ```
956/// the impl of `WriteThrift` for `MyStruct` will use the `WriteThriftField` impls for `i32`,
957/// `bool`, and `OtherStruct`.
958///
959/// ```ignore
960/// impl WriteThrift for MyStruct {
961///   fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
962///     let mut last_field_id = 0i16;
963///     last_field_id = self.field1.write_thrift_field(writer, 1, last_field_id)?;
964///     if self.field2.is_some() {
965///       // if field2 is `None` then this assignment won't happen and last_field_id will remain
966///       // `1` when writing `field3`
967///       last_field_id = self.field2.write_thrift_field(writer, 2, last_field_id)?;
968///     }
969///     if self.field3.is_some() {
970///       // no need to assign last_field_id since this is the final field.
971///       self.field3.write_thrift_field(writer, 3, last_field_id)?;
972///     }
973///     writer.write_struct_end()
974///   }
975/// }
976/// ```
977///
978pub(crate) trait WriteThriftField {
979    /// Used to write struct fields (which may be primitive or IDL defined types). This will
980    /// write the field marker for the given `field_id`, using `last_field_id` to compute the
981    /// field delta used by the Thrift [compact protocol]. On success this will return `field_id`
982    /// to be used in chaining.
983    ///
984    /// [compact protocol]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#struct-encoding
985    fn write_thrift_field<W: Write>(
986        &self,
987        writer: &mut ThriftCompactOutputProtocol<W>,
988        field_id: i16,
989        last_field_id: i16,
990    ) -> Result<i16>;
991}
992
993// bool struct fields are written differently to bool values
994impl WriteThriftField for bool {
995    fn write_thrift_field<W: Write>(
996        &self,
997        writer: &mut ThriftCompactOutputProtocol<W>,
998        field_id: i16,
999        last_field_id: i16,
1000    ) -> Result<i16> {
1001        // boolean only writes the field header
1002        match *self {
1003            true => writer.write_field_begin(FieldType::BooleanTrue, field_id, last_field_id)?,
1004            false => writer.write_field_begin(FieldType::BooleanFalse, field_id, last_field_id)?,
1005        }
1006        Ok(field_id)
1007    }
1008}
1009
1010write_thrift_field!(i8, FieldType::Byte);
1011write_thrift_field!(i16, FieldType::I16);
1012write_thrift_field!(i32, FieldType::I32);
1013write_thrift_field!(i64, FieldType::I64);
1014write_thrift_field!(OrderedF64, FieldType::Double);
1015write_thrift_field!(f64, FieldType::Double);
1016write_thrift_field!(String, FieldType::Binary);
1017
1018impl WriteThriftField for &[u8] {
1019    fn write_thrift_field<W: Write>(
1020        &self,
1021        writer: &mut ThriftCompactOutputProtocol<W>,
1022        field_id: i16,
1023        last_field_id: i16,
1024    ) -> Result<i16> {
1025        writer.write_field_begin(FieldType::Binary, field_id, last_field_id)?;
1026        writer.write_bytes(self)?;
1027        Ok(field_id)
1028    }
1029}
1030
1031impl WriteThriftField for &str {
1032    fn write_thrift_field<W: Write>(
1033        &self,
1034        writer: &mut ThriftCompactOutputProtocol<W>,
1035        field_id: i16,
1036        last_field_id: i16,
1037    ) -> Result<i16> {
1038        writer.write_field_begin(FieldType::Binary, field_id, last_field_id)?;
1039        writer.write_bytes(self.as_bytes())?;
1040        Ok(field_id)
1041    }
1042}
1043
1044impl<T> WriteThriftField for Vec<T>
1045where
1046    T: WriteThrift,
1047{
1048    fn write_thrift_field<W: Write>(
1049        &self,
1050        writer: &mut ThriftCompactOutputProtocol<W>,
1051        field_id: i16,
1052        last_field_id: i16,
1053    ) -> Result<i16> {
1054        writer.write_field_begin(FieldType::List, field_id, last_field_id)?;
1055        self.write_thrift(writer)?;
1056        Ok(field_id)
1057    }
1058}
1059
#[cfg(test)]
pub(crate) mod tests {
    use crate::basic::{TimeUnit, Type};

    use super::*;
    use std::fmt::Debug;

    /// Encodes `val` with the compact output protocol, decodes it again with the slice input
    /// protocol, and asserts the decoded value equals the original.
    pub(crate) fn test_roundtrip<T>(val: T)
    where
        T: for<'a> ReadThrift<'a, ThriftSliceInputProtocol<'a>> + WriteThrift + PartialEq + Debug,
    {
        let mut encoded = Vec::<u8>::new();
        // The temporary writer, and its `&mut` borrow of `encoded`, is dropped at the end of
        // this statement, before the buffer is read back.
        val.write_thrift(&mut ThriftCompactOutputProtocol::new(&mut encoded))
            .unwrap();

        let mut reader = ThriftSliceInputProtocol::new(&encoded);
        let decoded = T::read_thrift(&mut reader).unwrap();
        assert_eq!(val, decoded);
    }

    #[test]
    fn test_enum_roundtrip() {
        let physical_types = [
            Type::BOOLEAN,
            Type::INT32,
            Type::INT64,
            Type::INT96,
            Type::FLOAT,
            Type::DOUBLE,
            Type::BYTE_ARRAY,
            Type::FIXED_LEN_BYTE_ARRAY,
        ];
        for physical_type in physical_types {
            test_roundtrip(physical_type);
        }
    }

    #[test]
    fn test_union_all_empty_roundtrip() {
        for unit in [TimeUnit::MILLIS, TimeUnit::MICROS, TimeUnit::NANOS] {
            test_roundtrip(unit);
        }
    }

    #[test]
    fn test_decode_empty_list() {
        // A lone zero byte decodes as an empty list header whose element type is `Byte`.
        let encoded = [0u8];
        let mut reader = ThriftSliceInputProtocol::new(&encoded);
        let list_header = reader
            .read_list_begin()
            .expect("error reading list header");
        assert_eq!(list_header.size, 0);
        assert_eq!(list_header.element_type, ElementType::Byte);
    }
}