parquet/
parquet_thrift.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Structs used for encoding and decoding Parquet Thrift objects.
19//!
20//! These include:
21//! * [`ThriftCompactInputProtocol`]: Trait implemented by Thrift decoders.
22//!     * [`ThriftSliceInputProtocol`]: Thrift decoder that takes a slice of bytes as input.
23//!     * [`ThriftReadInputProtocol`]: Thrift decoder that takes a [`Read`] as input.
24//! * [`ReadThrift`]: Trait implemented by serializable objects.
25//! * [`ThriftCompactOutputProtocol`]: Thrift encoder.
26//! * [`WriteThrift`]: Trait implemented by serializable objects.
27//! * [`WriteThriftField`]: Trait implemented by serializable objects that are fields in Thrift structs.
28
29use std::{
30    cmp::Ordering,
31    io::{Read, Write},
32};
33
34use crate::{
35    errors::{ParquetError, Result},
36    write_thrift_field,
37};
38use std::io::Error;
39use std::str::Utf8Error;
40
/// Errors produced by the Thrift compact protocol decoders in this module.
///
/// Values of this type are converted into a [`ParquetError`] through its `From` implementation.
#[derive(Debug)]
pub(crate) enum ThriftProtocolError {
    /// The input ended before a complete value could be read.
    Eof,
    /// An I/O error reported by the underlying reader.
    IO(Error),
    /// A type id that does not map to a [`FieldType`].
    InvalidFieldType(u8),
    /// A type id that does not map to an [`ElementType`].
    InvalidElementType(u8),
    /// Adding `field_delta` to `last_field_id` overflowed `i16`.
    FieldDeltaOverflow { field_delta: u8, last_field_id: i16 },
    /// A boolean list element was encoded with an unrecognized byte value.
    InvalidBoolean(u8),
    /// A Thrift `binary` was not valid UTF-8 (the original error payload is dropped).
    Utf8Error,
    /// The skip depth limit was exhausted while skipping a field of the given type.
    SkipDepth(FieldType),
    /// A field of the given type cannot be skipped by the decoder.
    SkipUnsupportedType(FieldType),
}
53
54impl From<ThriftProtocolError> for ParquetError {
55    #[inline(never)]
56    fn from(e: ThriftProtocolError) -> Self {
57        match e {
58            ThriftProtocolError::Eof => eof_err!("Unexpected EOF"),
59            ThriftProtocolError::IO(e) => e.into(),
60            ThriftProtocolError::InvalidFieldType(value) => {
61                general_err!("Unexpected struct field type {}", value)
62            }
63            ThriftProtocolError::InvalidElementType(value) => {
64                general_err!("Unexpected list/set element type{}", value)
65            }
66            ThriftProtocolError::FieldDeltaOverflow {
67                field_delta,
68                last_field_id,
69            } => general_err!("cannot add {} to {}", field_delta, last_field_id),
70            ThriftProtocolError::InvalidBoolean(value) => {
71                general_err!("cannot convert {} into bool", value)
72            }
73            ThriftProtocolError::Utf8Error => general_err!("invalid utf8"),
74            ThriftProtocolError::SkipDepth(field_type) => {
75                general_err!("cannot parse past {:?}", field_type)
76            }
77            ThriftProtocolError::SkipUnsupportedType(field_type) => {
78                general_err!("cannot skip field type {:?}", field_type)
79            }
80        }
81    }
82}
83
84impl From<Utf8Error> for ThriftProtocolError {
85    fn from(_: Utf8Error) -> Self {
86        // ignore error payload to reduce the size of ThriftProtocolError
87        Self::Utf8Error
88    }
89}
90
91impl From<Error> for ThriftProtocolError {
92    fn from(e: Error) -> Self {
93        Self::IO(e)
94    }
95}
96
/// Result type returned by the Thrift decoding primitives in this module; the error type is
/// [`ThriftProtocolError`].
pub type ThriftProtocolResult<T> = Result<T, ThriftProtocolError>;
98
/// Wrapper for thrift `double` fields. This is used to provide
/// an implementation of `Eq` for floats. This implementation
/// uses IEEE 754 total order.
///
/// Equality is implemented by hand (rather than derived) so that it agrees with the total
/// order: every value, including NaN, is equal to itself, and `0.0` is distinct from `-0.0`.
#[derive(Debug, Clone, Copy)]
pub struct OrderedF64(f64);

impl PartialEq for OrderedF64 {
    /// Two values are equal exactly when [`f64::total_cmp`] reports them as equal.
    ///
    /// A derived `PartialEq` would use IEEE `==`, under which NaN is not equal to itself.
    /// That breaks the reflexivity `Eq` requires and disagrees with a `total_cmp` ordering.
    fn eq(&self, other: &Self) -> bool {
        self.0.total_cmp(&other.0).is_eq()
    }
}
104
105impl From<f64> for OrderedF64 {
106    fn from(value: f64) -> Self {
107        Self(value)
108    }
109}
110
111impl From<OrderedF64> for f64 {
112    fn from(value: OrderedF64) -> Self {
113        value.0
114    }
115}
116
// `Eq` has no methods of its own; implementing it asserts that the `PartialEq`
// implementation of `OrderedF64` is a full equivalence relation.
impl Eq for OrderedF64 {} // Marker trait, requires PartialEq
118
119impl Ord for OrderedF64 {
120    fn cmp(&self, other: &Self) -> Ordering {
121        self.0.total_cmp(&other.0)
122    }
123}
124
125impl PartialOrd for OrderedF64 {
126    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
127        Some(self.cmp(other))
128    }
129}
130
/// Thrift compact protocol types for struct fields.
///
/// The discriminants are the type ids stored in the low four bits of a field header byte.
/// Booleans have two ids because the value itself is carried by the type.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum FieldType {
    /// Marks the end of a struct; carries no data.
    Stop = 0,
    /// A boolean field whose value is `true`.
    BooleanTrue = 1,
    /// A boolean field whose value is `false`.
    BooleanFalse = 2,
    Byte = 3,
    I16 = 4,
    I32 = 5,
    I64 = 6,
    Double = 7,
    Binary = 8,
    List = 9,
    Set = 10,
    Map = 11,
    Struct = 12,
}
148
149impl TryFrom<u8> for FieldType {
150    type Error = ThriftProtocolError;
151    fn try_from(value: u8) -> ThriftProtocolResult<Self> {
152        match value {
153            0 => Ok(Self::Stop),
154            1 => Ok(Self::BooleanTrue),
155            2 => Ok(Self::BooleanFalse),
156            3 => Ok(Self::Byte),
157            4 => Ok(Self::I16),
158            5 => Ok(Self::I32),
159            6 => Ok(Self::I64),
160            7 => Ok(Self::Double),
161            8 => Ok(Self::Binary),
162            9 => Ok(Self::List),
163            10 => Ok(Self::Set),
164            11 => Ok(Self::Map),
165            12 => Ok(Self::Struct),
166            _ => Err(ThriftProtocolError::InvalidFieldType(value)),
167        }
168    }
169}
170
171impl TryFrom<ElementType> for FieldType {
172    type Error = ThriftProtocolError;
173    fn try_from(value: ElementType) -> std::result::Result<Self, Self::Error> {
174        match value {
175            ElementType::Bool => Ok(Self::BooleanTrue),
176            ElementType::Byte => Ok(Self::Byte),
177            ElementType::I16 => Ok(Self::I16),
178            ElementType::I32 => Ok(Self::I32),
179            ElementType::I64 => Ok(Self::I64),
180            ElementType::Double => Ok(Self::Double),
181            ElementType::Binary => Ok(Self::Binary),
182            ElementType::List => Ok(Self::List),
183            ElementType::Struct => Ok(Self::Struct),
184            _ => Err(ThriftProtocolError::InvalidFieldType(value as u8)),
185        }
186    }
187}
188
/// Thrift compact protocol types for list elements.
///
/// The discriminants are the type ids stored in the low four bits of a list header byte.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum ElementType {
    /// Boolean elements; when decoding, type id `1` is accepted as an alias for this variant.
    Bool = 2,
    Byte = 3,
    I16 = 4,
    I32 = 5,
    I64 = 6,
    Double = 7,
    Binary = 8,
    List = 9,
    Set = 10,
    Map = 11,
    Struct = 12,
}
204
205impl TryFrom<u8> for ElementType {
206    type Error = ThriftProtocolError;
207    fn try_from(value: u8) -> ThriftProtocolResult<Self> {
208        match value {
209            // For historical and compatibility reasons, a reader should be capable to deal with both cases.
210            // The only valid value in the original spec was 2, but due to an widespread implementation bug
211            // the defacto standard across large parts of the library became 1 instead.
212            // As a result, both values are now allowed.
213            // https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#list-and-set
214            1 | 2 => Ok(Self::Bool),
215            3 => Ok(Self::Byte),
216            4 => Ok(Self::I16),
217            5 => Ok(Self::I32),
218            6 => Ok(Self::I64),
219            7 => Ok(Self::Double),
220            8 => Ok(Self::Binary),
221            9 => Ok(Self::List),
222            10 => Ok(Self::Set),
223            11 => Ok(Self::Map),
224            12 => Ok(Self::Struct),
225            _ => Err(ThriftProtocolError::InvalidElementType(value)),
226        }
227    }
228}
229
/// Struct used to describe a [thrift struct] field during decoding.
///
/// This is the decoded form of a field header: the field's type, its id, and (for boolean
/// fields only) its value.
///
/// [thrift struct]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#struct-encoding
pub(crate) struct FieldIdentifier {
    /// The type for the field.
    pub(crate) field_type: FieldType,
    /// The field's `id`. May be computed from delta or directly decoded.
    pub(crate) id: i16,
    /// Stores the value for booleans.
    ///
    /// Boolean fields store no data, instead the field type is either boolean true, or
    /// boolean false. This is `None` for every non-boolean field.
    pub(crate) bool_val: Option<bool>,
}
244
/// Struct used to describe a [thrift list].
///
/// This is the decoded form of a list header: the element type and the element count.
///
/// [thrift list]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#list-and-set
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct ListIdentifier {
    /// The type for each element in the list.
    pub(crate) element_type: ElementType,
    /// Number of elements contained in the list.
    ///
    /// The value is taken from the input as-is and is not validated.
    pub(crate) size: i32,
}
255
256/// Low-level object used to deserialize structs encoded with the Thrift [compact] protocol.
257///
258/// Implementation of this trait must provide the low-level functions `read_byte`, `read_bytes`,
259/// `skip_bytes`, and `read_double`. These primitives are used by the default functions provided
260/// here to perform deserialization.
261///
262/// [compact]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md
263pub(crate) trait ThriftCompactInputProtocol<'a> {
264    /// Read a single byte from the input.
265    fn read_byte(&mut self) -> ThriftProtocolResult<u8>;
266
267    /// Read a Thrift encoded [binary] from the input.
268    ///
269    /// [binary]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#binary-encoding
270    fn read_bytes(&mut self) -> ThriftProtocolResult<&'a [u8]>;
271
272    fn read_bytes_owned(&mut self) -> ThriftProtocolResult<Vec<u8>>;
273
274    /// Skip the next `n` bytes of input.
275    fn skip_bytes(&mut self, n: usize) -> ThriftProtocolResult<()>;
276
277    /// Read a ULEB128 encoded unsigned varint from the input.
278    fn read_vlq(&mut self) -> ThriftProtocolResult<u64> {
279        let mut in_progress = 0;
280        let mut shift = 0;
281        loop {
282            let byte = self.read_byte()?;
283            in_progress |= ((byte & 0x7F) as u64).wrapping_shl(shift);
284            if byte & 0x80 == 0 {
285                return Ok(in_progress);
286            }
287            shift += 7;
288        }
289    }
290
291    /// Read a zig-zag encoded signed varint from the input.
292    fn read_zig_zag(&mut self) -> ThriftProtocolResult<i64> {
293        let val = self.read_vlq()?;
294        Ok((val >> 1) as i64 ^ -((val & 1) as i64))
295    }
296
297    /// Read the [`ListIdentifier`] for a Thrift encoded list.
298    fn read_list_begin(&mut self) -> ThriftProtocolResult<ListIdentifier> {
299        let header = self.read_byte()?;
300        let element_type = ElementType::try_from(header & 0x0f)?;
301
302        let possible_element_count = (header & 0xF0) >> 4;
303        let element_count = if possible_element_count != 15 {
304            // high bits set high if count and type encoded separately
305            possible_element_count as i32
306        } else {
307            self.read_vlq()? as _
308        };
309
310        Ok(ListIdentifier {
311            element_type,
312            size: element_count,
313        })
314    }
315
316    // Full field ids are uncommon.
317    // Not inlining this method reduces the code size of `read_field_begin`, which then ideally gets
318    // inlined everywhere.
319    #[cold]
320    fn read_full_field_id(&mut self) -> ThriftProtocolResult<i16> {
321        self.read_i16()
322    }
323
324    /// Read the [`FieldIdentifier`] for a field in a Thrift encoded struct.
325    fn read_field_begin(&mut self, last_field_id: i16) -> ThriftProtocolResult<FieldIdentifier> {
326        // we can read at least one byte, which is:
327        // - the type
328        // - the field delta and the type
329        let field_type = self.read_byte()?;
330        let field_delta = (field_type & 0xf0) >> 4;
331        let field_type = FieldType::try_from(field_type & 0xf)?;
332        let mut bool_val: Option<bool> = None;
333
334        match field_type {
335            FieldType::Stop => Ok(FieldIdentifier {
336                field_type: FieldType::Stop,
337                id: 0,
338                bool_val,
339            }),
340            _ => {
341                // special handling for bools
342                if field_type == FieldType::BooleanFalse {
343                    bool_val = Some(false);
344                } else if field_type == FieldType::BooleanTrue {
345                    bool_val = Some(true);
346                }
347                let field_id = if field_delta != 0 {
348                    last_field_id.checked_add(field_delta as i16).ok_or(
349                        ThriftProtocolError::FieldDeltaOverflow {
350                            field_delta,
351                            last_field_id,
352                        },
353                    )?
354                } else {
355                    self.read_full_field_id()?
356                };
357
358                Ok(FieldIdentifier {
359                    field_type,
360                    id: field_id,
361                    bool_val,
362                })
363            }
364        }
365    }
366
367    /// This is a specialized version of [`Self::read_field_begin`], solely for use in parsing
368    /// simple structs. This function assumes that the delta field will always be less than 0xf,
369    /// fields will be in order, and no boolean fields will be read.
370    /// This also skips validation of the field type.
371    ///
372    /// Returns a tuple of `(field_type, field_delta)`.
373    fn read_field_header(&mut self) -> ThriftProtocolResult<(u8, u8)> {
374        let field_type = self.read_byte()?;
375        let field_delta = (field_type & 0xf0) >> 4;
376        let field_type = field_type & 0xf;
377        Ok((field_type, field_delta))
378    }
379
380    /// Read a boolean list element. This should not be used for struct fields. For the latter,
381    /// use the [`FieldIdentifier::bool_val`] field.
382    fn read_bool(&mut self) -> ThriftProtocolResult<bool> {
383        let b = self.read_byte()?;
384        // Previous versions of the thrift specification said to use 0 and 1 inside collections,
385        // but that differed from existing implementations.
386        // The specification was updated in https://github.com/apache/thrift/commit/2c29c5665bc442e703480bb0ee60fe925ffe02e8.
387        // At least the go implementation seems to have followed the previously documented values.
388        match b {
389            0x01 => Ok(true),
390            0x00 | 0x02 => Ok(false),
391            _ => Err(ThriftProtocolError::InvalidBoolean(b)),
392        }
393    }
394
395    /// Read a Thrift [binary] as a UTF-8 encoded string.
396    ///
397    /// [binary]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#binary-encoding
398    fn read_string(&mut self) -> ThriftProtocolResult<&'a str> {
399        let slice = self.read_bytes()?;
400        Ok(std::str::from_utf8(slice)?)
401    }
402
403    /// Read an `i8`.
404    fn read_i8(&mut self) -> ThriftProtocolResult<i8> {
405        Ok(self.read_byte()? as _)
406    }
407
408    /// Read an `i16`.
409    fn read_i16(&mut self) -> ThriftProtocolResult<i16> {
410        Ok(self.read_zig_zag()? as _)
411    }
412
413    /// Read an `i32`.
414    fn read_i32(&mut self) -> ThriftProtocolResult<i32> {
415        Ok(self.read_zig_zag()? as _)
416    }
417
418    /// Read an `i64`.
419    fn read_i64(&mut self) -> ThriftProtocolResult<i64> {
420        self.read_zig_zag()
421    }
422
423    /// Read a Thrift `double` as `f64`.
424    fn read_double(&mut self) -> ThriftProtocolResult<f64>;
425
426    /// Skip a ULEB128 encoded varint.
427    fn skip_vlq(&mut self) -> ThriftProtocolResult<()> {
428        loop {
429            let byte = self.read_byte()?;
430            if byte & 0x80 == 0 {
431                return Ok(());
432            }
433        }
434    }
435
436    /// Skip a thrift [binary].
437    ///
438    /// [binary]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#binary-encoding
439    fn skip_binary(&mut self) -> ThriftProtocolResult<()> {
440        let len = self.read_vlq()? as usize;
441        self.skip_bytes(len)
442    }
443
444    /// Skip a field with type `field_type` recursively until the default
445    /// maximum skip depth (currently 64) is reached.
446    fn skip(&mut self, field_type: FieldType) -> ThriftProtocolResult<()> {
447        const DEFAULT_SKIP_DEPTH: i8 = 64;
448        self.skip_till_depth(field_type, DEFAULT_SKIP_DEPTH)
449    }
450
451    /// Empty structs in unions consist of a single byte of 0 for the field stop record.
452    /// This skips that byte without encuring the cost of processing the [`FieldIdentifier`].
453    /// Will return an error if the struct is not actually empty.
454    fn skip_empty_struct(&mut self) -> Result<()> {
455        let b = self.read_byte()?;
456        if b != 0 {
457            Err(general_err!("Empty struct has fields"))
458        } else {
459            Ok(())
460        }
461    }
462
463    /// Skip a field with type `field_type` recursively up to `depth` levels.
464    fn skip_till_depth(&mut self, field_type: FieldType, depth: i8) -> ThriftProtocolResult<()> {
465        if depth == 0 {
466            return Err(ThriftProtocolError::SkipDepth(field_type));
467        }
468
469        match field_type {
470            // boolean field has no data
471            FieldType::BooleanFalse | FieldType::BooleanTrue => Ok(()),
472            FieldType::Byte => self.read_i8().map(|_| ()),
473            FieldType::I16 => self.skip_vlq().map(|_| ()),
474            FieldType::I32 => self.skip_vlq().map(|_| ()),
475            FieldType::I64 => self.skip_vlq().map(|_| ()),
476            FieldType::Double => self.skip_bytes(8).map(|_| ()),
477            FieldType::Binary => self.skip_binary().map(|_| ()),
478            FieldType::Struct => {
479                let mut last_field_id = 0i16;
480                loop {
481                    let field_ident = self.read_field_begin(last_field_id)?;
482                    if field_ident.field_type == FieldType::Stop {
483                        break;
484                    }
485                    self.skip_till_depth(field_ident.field_type, depth - 1)?;
486                    last_field_id = field_ident.id;
487                }
488                Ok(())
489            }
490            FieldType::List => {
491                let list_ident = self.read_list_begin()?;
492                for _ in 0..list_ident.size {
493                    let element_type = FieldType::try_from(list_ident.element_type)?;
494                    self.skip_till_depth(element_type, depth - 1)?;
495                }
496                Ok(())
497            }
498            // no list or map types in parquet format
499            _ => Err(ThriftProtocolError::SkipUnsupportedType(field_type)),
500        }
501    }
502}
503
/// A high performance Thrift reader that reads from a slice of bytes.
pub(crate) struct ThriftSliceInputProtocol<'a> {
    /// The not-yet-consumed input; every read advances the start of this slice.
    buf: &'a [u8],
}
508
509impl<'a> ThriftSliceInputProtocol<'a> {
510    /// Create a new `ThriftSliceInputProtocol` using the bytes in `buf`.
511    pub fn new(buf: &'a [u8]) -> Self {
512        Self { buf }
513    }
514
515    /// Return the current buffer as a slice.
516    pub fn as_slice(&self) -> &'a [u8] {
517        self.buf
518    }
519}
520
521impl<'b, 'a: 'b> ThriftCompactInputProtocol<'b> for ThriftSliceInputProtocol<'a> {
522    #[inline]
523    fn read_byte(&mut self) -> ThriftProtocolResult<u8> {
524        let ret = *self.buf.first().ok_or(ThriftProtocolError::Eof)?;
525        self.buf = &self.buf[1..];
526        Ok(ret)
527    }
528
529    fn read_bytes(&mut self) -> ThriftProtocolResult<&'b [u8]> {
530        let len = self.read_vlq()? as usize;
531        let ret = self.buf.get(..len).ok_or(ThriftProtocolError::Eof)?;
532        self.buf = &self.buf[len..];
533        Ok(ret)
534    }
535
536    fn read_bytes_owned(&mut self) -> ThriftProtocolResult<Vec<u8>> {
537        Ok(self.read_bytes()?.to_vec())
538    }
539
540    #[inline]
541    fn skip_bytes(&mut self, n: usize) -> ThriftProtocolResult<()> {
542        self.buf.get(..n).ok_or(ThriftProtocolError::Eof)?;
543        self.buf = &self.buf[n..];
544        Ok(())
545    }
546
547    fn read_double(&mut self) -> ThriftProtocolResult<f64> {
548        let slice = self.buf.get(..8).ok_or(ThriftProtocolError::Eof)?;
549        self.buf = &self.buf[8..];
550        match slice.try_into() {
551            Ok(slice) => Ok(f64::from_le_bytes(slice)),
552            Err(_) => unreachable!(),
553        }
554    }
555}
556
/// A Thrift input protocol that wraps a [`Read`] object.
///
/// Note that this is only intended for use in reading Parquet page headers. This will panic
/// if Thrift `binary` data is requested as a borrowed slice (`read_bytes`), because a slice of
/// that data cannot be returned.
pub(crate) struct ThriftReadInputProtocol<R: Read> {
    /// The source that all bytes are read from.
    reader: R,
}
564
565impl<R: Read> ThriftReadInputProtocol<R> {
566    pub(crate) fn new(reader: R) -> Self {
567        Self { reader }
568    }
569}
570
571impl<'a, R: Read> ThriftCompactInputProtocol<'a> for ThriftReadInputProtocol<R> {
572    #[inline]
573    fn read_byte(&mut self) -> ThriftProtocolResult<u8> {
574        let mut buf = [0_u8; 1];
575        self.reader.read_exact(&mut buf)?;
576        Ok(buf[0])
577    }
578
579    fn read_bytes(&mut self) -> ThriftProtocolResult<&'a [u8]> {
580        unimplemented!()
581    }
582
583    fn read_bytes_owned(&mut self) -> ThriftProtocolResult<Vec<u8>> {
584        let len = self.read_vlq()? as usize;
585        let mut v = Vec::with_capacity(len);
586        std::io::copy(&mut self.reader.by_ref().take(len as u64), &mut v)?;
587        Ok(v)
588    }
589
590    fn skip_bytes(&mut self, n: usize) -> ThriftProtocolResult<()> {
591        std::io::copy(
592            &mut self.reader.by_ref().take(n as u64),
593            &mut std::io::sink(),
594        )?;
595        Ok(())
596    }
597
598    fn read_double(&mut self) -> ThriftProtocolResult<f64> {
599        let mut buf = [0_u8; 8];
600        self.reader.read_exact(&mut buf)?;
601        Ok(f64::from_le_bytes(buf))
602    }
603}
604
/// Trait implemented for objects that can be deserialized from a Thrift input stream.
/// Implementations are provided for Thrift primitive types.
///
/// `'a` is the lifetime of data borrowed from the input and `R` is the decoder type.
pub(crate) trait ReadThrift<'a, R: ThriftCompactInputProtocol<'a>> {
    /// Read an object of type `Self` from the input protocol object.
    ///
    /// Decoding failures are reported as a `ParquetError` through the crate's `Result` alias.
    fn read_thrift(prot: &mut R) -> Result<Self>
    where
        Self: Sized;
}
613
614impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for bool {
615    fn read_thrift(prot: &mut R) -> Result<Self> {
616        Ok(prot.read_bool()?)
617    }
618}
619
620impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i8 {
621    fn read_thrift(prot: &mut R) -> Result<Self> {
622        Ok(prot.read_i8()?)
623    }
624}
625
626impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i16 {
627    fn read_thrift(prot: &mut R) -> Result<Self> {
628        Ok(prot.read_i16()?)
629    }
630}
631
632impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i32 {
633    fn read_thrift(prot: &mut R) -> Result<Self> {
634        Ok(prot.read_i32()?)
635    }
636}
637
638impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i64 {
639    fn read_thrift(prot: &mut R) -> Result<Self> {
640        Ok(prot.read_i64()?)
641    }
642}
643
644impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for OrderedF64 {
645    fn read_thrift(prot: &mut R) -> Result<Self> {
646        Ok(OrderedF64(prot.read_double()?))
647    }
648}
649
650impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for &'a str {
651    fn read_thrift(prot: &mut R) -> Result<Self> {
652        Ok(prot.read_string()?)
653    }
654}
655
656impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for String {
657    fn read_thrift(prot: &mut R) -> Result<Self> {
658        Ok(String::from_utf8(prot.read_bytes_owned()?)?)
659    }
660}
661
662impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for &'a [u8] {
663    fn read_thrift(prot: &mut R) -> Result<Self> {
664        Ok(prot.read_bytes()?)
665    }
666}
667
668/// Read a Thrift encoded [list] from the input protocol object.
669///
670/// [list]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#list-and-set
671pub(crate) fn read_thrift_vec<'a, T, R>(prot: &mut R) -> Result<Vec<T>>
672where
673    R: ThriftCompactInputProtocol<'a>,
674    T: ReadThrift<'a, R>,
675{
676    let list_ident = prot.read_list_begin()?;
677    let mut res = Vec::with_capacity(list_ident.size as usize);
678    for _ in 0..list_ident.size {
679        let val = T::read_thrift(prot)?;
680        res.push(val);
681    }
682    Ok(res)
683}
684
685/////////////////////////
686// thrift compact output
687
/// Low-level object used to serialize structs to the Thrift [compact output] protocol.
///
/// This struct serves as a wrapper around a [`Write`] object, to which thrift encoded data
/// will be written. The implementation provides functions to write Thrift primitive types, as
/// well as functions used in the encoding of lists and structs. This should rarely be used
/// directly, but is instead intended for use by implementers of [`WriteThrift`] and
/// [`WriteThriftField`].
///
/// [compact output]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md
pub(crate) struct ThriftCompactOutputProtocol<W: Write> {
    /// The byte sink that receives the encoded output.
    writer: W,
}
699
700impl<W: Write> ThriftCompactOutputProtocol<W> {
701    /// Create a new `ThriftCompactOutputProtocol` wrapping the byte sink `writer`.
702    pub(crate) fn new(writer: W) -> Self {
703        Self { writer }
704    }
705
706    /// Write a single byte to the output stream.
707    fn write_byte(&mut self, b: u8) -> Result<()> {
708        self.writer.write_all(&[b])?;
709        Ok(())
710    }
711
712    /// Write the given `u64` as a ULEB128 encoded varint.
713    fn write_vlq(&mut self, val: u64) -> Result<()> {
714        let mut v = val;
715        while v > 0x7f {
716            self.write_byte(v as u8 | 0x80)?;
717            v >>= 7;
718        }
719        self.write_byte(v as u8)
720    }
721
722    /// Write the given `i64` as a zig-zag encoded varint.
723    fn write_zig_zag(&mut self, val: i64) -> Result<()> {
724        let s = (val < 0) as i64;
725        self.write_vlq((((val ^ -s) << 1) + s) as u64)
726    }
727
728    /// Used to mark the start of a Thrift struct field of type `field_type`. `last_field_id`
729    /// is used to compute a delta to the given `field_id` per the compact protocol [spec].
730    ///
731    /// [spec]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#struct-encoding
732    pub(crate) fn write_field_begin(
733        &mut self,
734        field_type: FieldType,
735        field_id: i16,
736        last_field_id: i16,
737    ) -> Result<()> {
738        let delta = field_id.wrapping_sub(last_field_id);
739        if delta > 0 && delta <= 0xf {
740            self.write_byte((delta as u8) << 4 | field_type as u8)
741        } else {
742            self.write_byte(field_type as u8)?;
743            self.write_i16(field_id)
744        }
745    }
746
747    /// Used to indicate the start of a list of `element_type` elements.
748    pub(crate) fn write_list_begin(&mut self, element_type: ElementType, len: usize) -> Result<()> {
749        if len < 15 {
750            self.write_byte((len as u8) << 4 | element_type as u8)
751        } else {
752            self.write_byte(0xf0u8 | element_type as u8)?;
753            self.write_vlq(len as _)
754        }
755    }
756
757    /// Used to mark the end of a struct. This must be called after all fields of the struct have
758    /// been written.
759    pub(crate) fn write_struct_end(&mut self) -> Result<()> {
760        self.write_byte(0)
761    }
762
763    /// Serialize a slice of `u8`s. This will encode a length, and then write the bytes without
764    /// further encoding.
765    pub(crate) fn write_bytes(&mut self, val: &[u8]) -> Result<()> {
766        self.write_vlq(val.len() as u64)?;
767        self.writer.write_all(val)?;
768        Ok(())
769    }
770
771    /// Short-cut method used to encode structs that have no fields (often used in Thrift unions).
772    /// This simply encodes the field id and then immediately writes the end-of-struct marker.
773    pub(crate) fn write_empty_struct(&mut self, field_id: i16, last_field_id: i16) -> Result<i16> {
774        self.write_field_begin(FieldType::Struct, field_id, last_field_id)?;
775        self.write_struct_end()?;
776        Ok(last_field_id)
777    }
778
779    /// Write a boolean value.
780    pub(crate) fn write_bool(&mut self, val: bool) -> Result<()> {
781        match val {
782            true => self.write_byte(1),
783            false => self.write_byte(2),
784        }
785    }
786
787    /// Write a zig-zag encoded `i8` value.
788    pub(crate) fn write_i8(&mut self, val: i8) -> Result<()> {
789        self.write_byte(val as u8)
790    }
791
792    /// Write a zig-zag encoded `i16` value.
793    pub(crate) fn write_i16(&mut self, val: i16) -> Result<()> {
794        self.write_zig_zag(val as _)
795    }
796
797    /// Write a zig-zag encoded `i32` value.
798    pub(crate) fn write_i32(&mut self, val: i32) -> Result<()> {
799        self.write_zig_zag(val as _)
800    }
801
802    /// Write a zig-zag encoded `i64` value.
803    pub(crate) fn write_i64(&mut self, val: i64) -> Result<()> {
804        self.write_zig_zag(val as _)
805    }
806
807    /// Write a double value.
808    pub(crate) fn write_double(&mut self, val: f64) -> Result<()> {
809        self.writer.write_all(&val.to_le_bytes())?;
810        Ok(())
811    }
812}
813
/// Trait implemented by objects that are to be serialized to a Thrift [compact output] protocol
/// stream. Implementations are also provided for primitive Thrift types.
///
/// [compact output]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md
pub(crate) trait WriteThrift {
    /// The [`ElementType`] to use when a list of this object is written.
    ///
    /// The `Vec<T>` implementation passes this value to the list header it writes.
    const ELEMENT_TYPE: ElementType;

    /// Serialize this object to the given `writer`.
    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()>;
}
825
826/// Implementation for a vector of thrift serializable objects that implement [`WriteThrift`].
827/// This will write the necessary list header and then serialize the elements one-at-a-time.
828impl<T> WriteThrift for Vec<T>
829where
830    T: WriteThrift,
831{
832    const ELEMENT_TYPE: ElementType = ElementType::List;
833
834    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
835        writer.write_list_begin(T::ELEMENT_TYPE, self.len())?;
836        for item in self {
837            item.write_thrift(writer)?;
838        }
839        Ok(())
840    }
841}
842
843impl WriteThrift for bool {
844    const ELEMENT_TYPE: ElementType = ElementType::Bool;
845
846    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
847        writer.write_bool(*self)
848    }
849}
850
851impl WriteThrift for i8 {
852    const ELEMENT_TYPE: ElementType = ElementType::Byte;
853
854    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
855        writer.write_i8(*self)
856    }
857}
858
859impl WriteThrift for i16 {
860    const ELEMENT_TYPE: ElementType = ElementType::I16;
861
862    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
863        writer.write_i16(*self)
864    }
865}
866
867impl WriteThrift for i32 {
868    const ELEMENT_TYPE: ElementType = ElementType::I32;
869
870    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
871        writer.write_i32(*self)
872    }
873}
874
875impl WriteThrift for i64 {
876    const ELEMENT_TYPE: ElementType = ElementType::I64;
877
878    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
879        writer.write_i64(*self)
880    }
881}
882
883impl WriteThrift for OrderedF64 {
884    const ELEMENT_TYPE: ElementType = ElementType::Double;
885
886    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
887        writer.write_double(self.0)
888    }
889}
890
891impl WriteThrift for f64 {
892    const ELEMENT_TYPE: ElementType = ElementType::Double;
893
894    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
895        writer.write_double(*self)
896    }
897}
898
899impl WriteThrift for &[u8] {
900    const ELEMENT_TYPE: ElementType = ElementType::Binary;
901
902    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
903        writer.write_bytes(self)
904    }
905}
906
907impl WriteThrift for &str {
908    const ELEMENT_TYPE: ElementType = ElementType::Binary;
909
910    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
911        writer.write_bytes(self.as_bytes())
912    }
913}
914
915impl WriteThrift for String {
916    const ELEMENT_TYPE: ElementType = ElementType::Binary;
917
918    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
919        writer.write_bytes(self.as_bytes())
920    }
921}
922
923/// Trait implemented by objects that are fields of Thrift structs.
924///
925/// For example, given the Thrift struct definition
926/// ```ignore
927/// struct MyStruct {
928///   1: required i32 field1
929///   2: optional bool field2
930///   3: optional OtherStruct field3
931/// }
932/// ```
933///
934/// which becomes in Rust
935/// ```no_run
936/// # struct OtherStruct {}
937/// struct MyStruct {
938///   field1: i32,
939///   field2: Option<bool>,
940///   field3: Option<OtherStruct>,
941/// }
942/// ```
943/// the impl of `WriteThrift` for `MyStruct` will use the `WriteThriftField` impls for `i32`,
944/// `bool`, and `OtherStruct`.
945///
946/// ```ignore
947/// impl WriteThrift for MyStruct {
948///   fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
949///     let mut last_field_id = 0i16;
950///     last_field_id = self.field1.write_thrift_field(writer, 1, last_field_id)?;
951///     if self.field2.is_some() {
952///       // if field2 is `None` then this assignment won't happen and last_field_id will remain
953///       // `1` when writing `field3`
954///       last_field_id = self.field2.write_thrift_field(writer, 2, last_field_id)?;
955///     }
956///     if self.field3.is_some() {
957///       // no need to assign last_field_id since this is the final field.
958///       self.field3.write_thrift_field(writer, 3, last_field_id)?;
959///     }
960///     writer.write_struct_end()
961///   }
962/// }
963/// ```
964///
965pub(crate) trait WriteThriftField {
966    /// Used to write struct fields (which may be primitive or IDL defined types). This will
967    /// write the field marker for the given `field_id`, using `last_field_id` to compute the
968    /// field delta used by the Thrift [compact protocol]. On success this will return `field_id`
969    /// to be used in chaining.
970    ///
971    /// [compact protocol]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#struct-encoding
972    fn write_thrift_field<W: Write>(
973        &self,
974        writer: &mut ThriftCompactOutputProtocol<W>,
975        field_id: i16,
976        last_field_id: i16,
977    ) -> Result<i16>;
978}
979
980// bool struct fields are written differently to bool values
981impl WriteThriftField for bool {
982    fn write_thrift_field<W: Write>(
983        &self,
984        writer: &mut ThriftCompactOutputProtocol<W>,
985        field_id: i16,
986        last_field_id: i16,
987    ) -> Result<i16> {
988        // boolean only writes the field header
989        match *self {
990            true => writer.write_field_begin(FieldType::BooleanTrue, field_id, last_field_id)?,
991            false => writer.write_field_begin(FieldType::BooleanFalse, field_id, last_field_id)?,
992        }
993        Ok(field_id)
994    }
995}
996
997write_thrift_field!(i8, FieldType::Byte);
998write_thrift_field!(i16, FieldType::I16);
999write_thrift_field!(i32, FieldType::I32);
1000write_thrift_field!(i64, FieldType::I64);
1001write_thrift_field!(OrderedF64, FieldType::Double);
1002write_thrift_field!(f64, FieldType::Double);
1003write_thrift_field!(String, FieldType::Binary);
1004
1005impl WriteThriftField for &[u8] {
1006    fn write_thrift_field<W: Write>(
1007        &self,
1008        writer: &mut ThriftCompactOutputProtocol<W>,
1009        field_id: i16,
1010        last_field_id: i16,
1011    ) -> Result<i16> {
1012        writer.write_field_begin(FieldType::Binary, field_id, last_field_id)?;
1013        writer.write_bytes(self)?;
1014        Ok(field_id)
1015    }
1016}
1017
1018impl WriteThriftField for &str {
1019    fn write_thrift_field<W: Write>(
1020        &self,
1021        writer: &mut ThriftCompactOutputProtocol<W>,
1022        field_id: i16,
1023        last_field_id: i16,
1024    ) -> Result<i16> {
1025        writer.write_field_begin(FieldType::Binary, field_id, last_field_id)?;
1026        writer.write_bytes(self.as_bytes())?;
1027        Ok(field_id)
1028    }
1029}
1030
1031impl<T> WriteThriftField for Vec<T>
1032where
1033    T: WriteThrift,
1034{
1035    fn write_thrift_field<W: Write>(
1036        &self,
1037        writer: &mut ThriftCompactOutputProtocol<W>,
1038        field_id: i16,
1039        last_field_id: i16,
1040    ) -> Result<i16> {
1041        writer.write_field_begin(FieldType::List, field_id, last_field_id)?;
1042        self.write_thrift(writer)?;
1043        Ok(field_id)
1044    }
1045}
1046
#[cfg(test)]
pub(crate) mod tests {
    use crate::basic::{TimeUnit, Type};

    use super::*;
    use std::fmt::Debug;

    /// Serializes `val` into an in-memory buffer, decodes that buffer again through a
    /// [`ThriftSliceInputProtocol`], and asserts that the decoded value equals the input.
    ///
    /// Panics if writing or reading fails, or if the two values differ.
    pub(crate) fn test_roundtrip<T>(val: T)
    where
        T: for<'a> ReadThrift<'a, ThriftSliceInputProtocol<'a>> + WriteThrift + PartialEq + Debug,
    {
        let mut encoded: Vec<u8> = Vec::new();
        {
            // Scoped so the encoder's mutable borrow of `encoded` ends before decoding starts.
            let mut encoder = ThriftCompactOutputProtocol::new(&mut encoded);
            val.write_thrift(&mut encoder).unwrap();
        }

        let mut decoder = ThriftSliceInputProtocol::new(&encoded);
        let read_val = T::read_thrift(&mut decoder).unwrap();
        assert_eq!(val, read_val);
    }

    /// Round-trips each listed `Type` variant.
    #[test]
    fn test_enum_roundtrip() {
        let variants = [
            Type::BOOLEAN,
            Type::INT32,
            Type::INT64,
            Type::INT96,
            Type::FLOAT,
            Type::DOUBLE,
            Type::BYTE_ARRAY,
            Type::FIXED_LEN_BYTE_ARRAY,
        ];
        for variant in variants {
            test_roundtrip(variant);
        }
    }

    /// Round-trips each listed `TimeUnit` variant.
    #[test]
    fn test_union_all_empty_roundtrip() {
        let units = [TimeUnit::MILLIS, TimeUnit::MICROS, TimeUnit::NANOS];
        for unit in units {
            test_roundtrip(unit);
        }
    }
}