parquet/
parquet_thrift.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Structs used for encoding and decoding Parquet Thrift objects.
19//!
20//! These include:
21//! * [`ThriftCompactInputProtocol`]: Trait implemented by Thrift decoders.
22//!     * [`ThriftSliceInputProtocol`]: Thrift decoder that takes a slice of bytes as input.
23//!     * [`ThriftReadInputProtocol`]: Thrift decoder that takes a [`Read`] as input.
24//! * [`ReadThrift`]: Trait implemented by serializable objects.
25//! * [`ThriftCompactOutputProtocol`]: Thrift encoder.
26//! * [`WriteThrift`]: Trait implemented by serializable objects.
27//! * [`WriteThriftField`]: Trait implemented by serializable objects that are fields in Thrift structs.
28
29use std::{
30    cmp::Ordering,
31    io::{Read, Write},
32};
33
34use crate::errors::{ParquetError, Result};
35
/// Wrapper for thrift `double` fields. This is used to provide
/// an implementation of `Eq` for floats.
///
/// Both equality and ordering use the IEEE 754 `totalOrder` predicate (via
/// [`f64::total_cmp`]). A `NaN` therefore compares equal to a bit-identical `NaN`, and `-0.0` is
/// distinct from `+0.0`. This keeps `PartialEq`, `Eq` and `Ord` mutually consistent, as the
/// contracts of those traits require.
#[derive(Debug, Clone, Copy)]
pub struct OrderedF64(f64);

/// Equality defined by the same total order used by `Ord`, so `Eq` is truly reflexive.
impl PartialEq for OrderedF64 {
    fn eq(&self, other: &Self) -> bool {
        self.0.total_cmp(&other.0).is_eq()
    }
}
41
42impl From<f64> for OrderedF64 {
43    fn from(value: f64) -> Self {
44        Self(value)
45    }
46}
47
48impl From<OrderedF64> for f64 {
49    fn from(value: OrderedF64) -> Self {
50        value.0
51    }
52}
53
54impl Eq for OrderedF64 {} // Marker trait, requires PartialEq
55
56impl Ord for OrderedF64 {
57    fn cmp(&self, other: &Self) -> Ordering {
58        self.0.total_cmp(&other.0)
59    }
60}
61
62impl PartialOrd for OrderedF64 {
63    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
64        Some(self.cmp(other))
65    }
66}
67
/// Thrift compact protocol type codes used in struct field headers.
///
/// Each variant's discriminant is the 4-bit type code written to the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum FieldType {
    /// Marks the end of a struct.
    Stop = 0,
    /// A boolean field whose value is `true`; no payload follows the header.
    BooleanTrue = 1,
    /// A boolean field whose value is `false`; no payload follows the header.
    BooleanFalse = 2,
    /// A single signed byte.
    Byte = 3,
    /// A zig-zag varint `i16`.
    I16 = 4,
    /// A zig-zag varint `i32`.
    I32 = 5,
    /// A zig-zag varint `i64`.
    I64 = 6,
    /// An 8-byte little-endian `f64`.
    Double = 7,
    /// A length-prefixed byte string.
    Binary = 8,
    /// A list of homogeneous elements.
    List = 9,
    /// A set of homogeneous elements.
    Set = 10,
    /// A key/value map.
    Map = 11,
    /// A nested struct.
    Struct = 12,
}
85
86impl TryFrom<u8> for FieldType {
87    type Error = ParquetError;
88    fn try_from(value: u8) -> Result<Self> {
89        match value {
90            0 => Ok(Self::Stop),
91            1 => Ok(Self::BooleanTrue),
92            2 => Ok(Self::BooleanFalse),
93            3 => Ok(Self::Byte),
94            4 => Ok(Self::I16),
95            5 => Ok(Self::I32),
96            6 => Ok(Self::I64),
97            7 => Ok(Self::Double),
98            8 => Ok(Self::Binary),
99            9 => Ok(Self::List),
100            10 => Ok(Self::Set),
101            11 => Ok(Self::Map),
102            12 => Ok(Self::Struct),
103            _ => Err(general_err!("Unexpected struct field type{}", value)),
104        }
105    }
106}
107
108impl TryFrom<ElementType> for FieldType {
109    type Error = ParquetError;
110    fn try_from(value: ElementType) -> std::result::Result<Self, Self::Error> {
111        match value {
112            ElementType::Bool => Ok(Self::BooleanTrue),
113            ElementType::Byte => Ok(Self::Byte),
114            ElementType::I16 => Ok(Self::I16),
115            ElementType::I32 => Ok(Self::I32),
116            ElementType::I64 => Ok(Self::I64),
117            ElementType::Double => Ok(Self::Double),
118            ElementType::Binary => Ok(Self::Binary),
119            ElementType::List => Ok(Self::List),
120            ElementType::Struct => Ok(Self::Struct),
121            _ => Err(general_err!("Unexpected list element type{:?}", value)),
122        }
123    }
124}
125
/// Thrift compact protocol type codes used for list (and set) elements.
///
/// Each discriminant is the 4-bit code written in the low nibble of a list header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ElementType {
    /// A boolean, encoded as one byte per element.
    Bool = 2,
    /// A single signed byte.
    Byte = 3,
    /// A zig-zag varint `i16`.
    I16 = 4,
    /// A zig-zag varint `i32`.
    I32 = 5,
    /// A zig-zag varint `i64`.
    I64 = 6,
    /// An 8-byte little-endian `f64`.
    Double = 7,
    /// A length-prefixed byte string.
    Binary = 8,
    /// A nested list.
    List = 9,
    /// A nested set.
    Set = 10,
    /// A nested map.
    Map = 11,
    /// A nested struct.
    Struct = 12,
}
141
142impl TryFrom<u8> for ElementType {
143    type Error = ParquetError;
144    fn try_from(value: u8) -> Result<Self> {
145        match value {
146            // For historical and compatibility reasons, a reader should be capable to deal with both cases.
147            // The only valid value in the original spec was 2, but due to an widespread implementation bug
148            // the defacto standard across large parts of the library became 1 instead.
149            // As a result, both values are now allowed.
150            // https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#list-and-set
151            1 | 2 => Ok(Self::Bool),
152            3 => Ok(Self::Byte),
153            4 => Ok(Self::I16),
154            5 => Ok(Self::I32),
155            6 => Ok(Self::I64),
156            7 => Ok(Self::Double),
157            8 => Ok(Self::Binary),
158            9 => Ok(Self::List),
159            10 => Ok(Self::Set),
160            11 => Ok(Self::Map),
161            12 => Ok(Self::Struct),
162            _ => Err(general_err!("Unexpected list/set element type{}", value)),
163        }
164    }
165}
166
167/// Struct used to describe a [thrift struct] field during decoding.
168///
169/// [thrift struct]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#struct-encoding
170pub(crate) struct FieldIdentifier {
171    /// The type for the field.
172    pub(crate) field_type: FieldType,
173    /// The field's `id`. May be computed from delta or directly decoded.
174    pub(crate) id: i16,
175    /// Stores the value for booleans.
176    ///
177    /// Boolean fields store no data, instead the field type is either boolean true, or
178    /// boolean false.
179    pub(crate) bool_val: Option<bool>,
180}
181
182/// Struct used to describe a [thrift list].
183///
184/// [thrift list]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#list-and-set
185#[derive(Clone, Debug, Eq, PartialEq)]
186pub(crate) struct ListIdentifier {
187    /// The type for each element in the list.
188    pub(crate) element_type: ElementType,
189    /// Number of elements contained in the list.
190    pub(crate) size: i32,
191}
192
193/// Low-level object used to deserialize structs encoded with the Thrift [compact] protocol.
194///
195/// Implementation of this trait must provide the low-level functions `read_byte`, `read_bytes`,
196/// `skip_bytes`, and `read_double`. These primitives are used by the default functions provided
197/// here to perform deserialization.
198///
199/// [compact]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md
200pub(crate) trait ThriftCompactInputProtocol<'a> {
201    /// Read a single byte from the input.
202    fn read_byte(&mut self) -> Result<u8>;
203
204    /// Read a Thrift encoded [binary] from the input.
205    ///
206    /// [binary]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#binary-encoding
207    fn read_bytes(&mut self) -> Result<&'a [u8]>;
208
209    fn read_bytes_owned(&mut self) -> Result<Vec<u8>>;
210
211    /// Skip the next `n` bytes of input.
212    fn skip_bytes(&mut self, n: usize) -> Result<()>;
213
214    /// Read a ULEB128 encoded unsigned varint from the input.
215    fn read_vlq(&mut self) -> Result<u64> {
216        let mut in_progress = 0;
217        let mut shift = 0;
218        loop {
219            let byte = self.read_byte()?;
220            in_progress |= ((byte & 0x7F) as u64).wrapping_shl(shift);
221            if byte & 0x80 == 0 {
222                return Ok(in_progress);
223            }
224            shift += 7;
225        }
226    }
227
228    /// Read a zig-zag encoded signed varint from the input.
229    fn read_zig_zag(&mut self) -> Result<i64> {
230        let val = self.read_vlq()?;
231        Ok((val >> 1) as i64 ^ -((val & 1) as i64))
232    }
233
234    /// Read the [`ListIdentifier`] for a Thrift encoded list.
235    fn read_list_begin(&mut self) -> Result<ListIdentifier> {
236        let header = self.read_byte()?;
237        let element_type = ElementType::try_from(header & 0x0f)?;
238
239        let possible_element_count = (header & 0xF0) >> 4;
240        let element_count = if possible_element_count != 15 {
241            // high bits set high if count and type encoded separately
242            possible_element_count as i32
243        } else {
244            self.read_vlq()? as _
245        };
246
247        Ok(ListIdentifier {
248            element_type,
249            size: element_count,
250        })
251    }
252
253    /// Read the [`FieldIdentifier`] for a field in a Thrift encoded struct.
254    fn read_field_begin(&mut self, last_field_id: i16) -> Result<FieldIdentifier> {
255        // we can read at least one byte, which is:
256        // - the type
257        // - the field delta and the type
258        let field_type = self.read_byte()?;
259        let field_delta = (field_type & 0xf0) >> 4;
260        let field_type = FieldType::try_from(field_type & 0xf)?;
261        let mut bool_val: Option<bool> = None;
262
263        match field_type {
264            FieldType::Stop => Ok(FieldIdentifier {
265                field_type: FieldType::Stop,
266                id: 0,
267                bool_val,
268            }),
269            _ => {
270                // special handling for bools
271                if field_type == FieldType::BooleanFalse {
272                    bool_val = Some(false);
273                } else if field_type == FieldType::BooleanTrue {
274                    bool_val = Some(true);
275                }
276                let field_id = if field_delta != 0 {
277                    last_field_id.checked_add(field_delta as i16).map_or_else(
278                        || {
279                            Err(general_err!(format!(
280                                "cannot add {} to {}",
281                                field_delta, last_field_id
282                            )))
283                        },
284                        Ok,
285                    )?
286                } else {
287                    self.read_i16()?
288                };
289
290                Ok(FieldIdentifier {
291                    field_type,
292                    id: field_id,
293                    bool_val,
294                })
295            }
296        }
297    }
298
299    /// This is a specialized version of [`Self::read_field_begin`], solely for use in parsing
300    /// simple structs. This function assumes that the delta field will always be less than 0xf,
301    /// fields will be in order, and no boolean fields will be read.
302    /// This also skips validation of the field type.
303    ///
304    /// Returns a tuple of `(field_type, field_delta)`.
305    fn read_field_header(&mut self) -> Result<(u8, u8)> {
306        let field_type = self.read_byte()?;
307        let field_delta = (field_type & 0xf0) >> 4;
308        let field_type = field_type & 0xf;
309        Ok((field_type, field_delta))
310    }
311
312    /// Read a boolean list element. This should not be used for struct fields. For the latter,
313    /// use the [`FieldIdentifier::bool_val`] field.
314    fn read_bool(&mut self) -> Result<bool> {
315        let b = self.read_byte()?;
316        // Previous versions of the thrift specification said to use 0 and 1 inside collections,
317        // but that differed from existing implementations.
318        // The specification was updated in https://github.com/apache/thrift/commit/2c29c5665bc442e703480bb0ee60fe925ffe02e8.
319        // At least the go implementation seems to have followed the previously documented values.
320        match b {
321            0x01 => Ok(true),
322            0x00 | 0x02 => Ok(false),
323            unkn => Err(general_err!(format!("cannot convert {unkn} into bool"))),
324        }
325    }
326
327    /// Read a Thrift [binary] as a UTF-8 encoded string.
328    ///
329    /// [binary]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#binary-encoding
330    fn read_string(&mut self) -> Result<&'a str> {
331        let slice = self.read_bytes()?;
332        Ok(std::str::from_utf8(slice)?)
333    }
334
335    /// Read an `i8`.
336    fn read_i8(&mut self) -> Result<i8> {
337        Ok(self.read_byte()? as _)
338    }
339
340    /// Read an `i16`.
341    fn read_i16(&mut self) -> Result<i16> {
342        Ok(self.read_zig_zag()? as _)
343    }
344
345    /// Read an `i32`.
346    fn read_i32(&mut self) -> Result<i32> {
347        Ok(self.read_zig_zag()? as _)
348    }
349
350    /// Read an `i64`.
351    fn read_i64(&mut self) -> Result<i64> {
352        self.read_zig_zag()
353    }
354
355    /// Read a Thrift `double` as `f64`.
356    fn read_double(&mut self) -> Result<f64>;
357
358    /// Skip a ULEB128 encoded varint.
359    fn skip_vlq(&mut self) -> Result<()> {
360        loop {
361            let byte = self.read_byte()?;
362            if byte & 0x80 == 0 {
363                return Ok(());
364            }
365        }
366    }
367
368    /// Skip a thrift [binary].
369    ///
370    /// [binary]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#binary-encoding
371    fn skip_binary(&mut self) -> Result<()> {
372        let len = self.read_vlq()? as usize;
373        self.skip_bytes(len)
374    }
375
376    /// Skip a field with type `field_type` recursively until the default
377    /// maximum skip depth (currently 64) is reached.
378    fn skip(&mut self, field_type: FieldType) -> Result<()> {
379        const DEFAULT_SKIP_DEPTH: i8 = 64;
380        self.skip_till_depth(field_type, DEFAULT_SKIP_DEPTH)
381    }
382
383    /// Empty structs in unions consist of a single byte of 0 for the field stop record.
384    /// This skips that byte without encuring the cost of processing the [`FieldIdentifier`].
385    /// Will return an error if the struct is not actually empty.
386    fn skip_empty_struct(&mut self) -> Result<()> {
387        let b = self.read_byte()?;
388        if b != 0 {
389            Err(general_err!("Empty struct has fields"))
390        } else {
391            Ok(())
392        }
393    }
394
395    /// Skip a field with type `field_type` recursively up to `depth` levels.
396    fn skip_till_depth(&mut self, field_type: FieldType, depth: i8) -> Result<()> {
397        if depth == 0 {
398            return Err(general_err!(format!("cannot parse past {:?}", field_type)));
399        }
400
401        match field_type {
402            // boolean field has no data
403            FieldType::BooleanFalse | FieldType::BooleanTrue => Ok(()),
404            FieldType::Byte => self.read_i8().map(|_| ()),
405            FieldType::I16 => self.skip_vlq().map(|_| ()),
406            FieldType::I32 => self.skip_vlq().map(|_| ()),
407            FieldType::I64 => self.skip_vlq().map(|_| ()),
408            FieldType::Double => self.skip_bytes(8).map(|_| ()),
409            FieldType::Binary => self.skip_binary().map(|_| ()),
410            FieldType::Struct => {
411                let mut last_field_id = 0i16;
412                loop {
413                    let field_ident = self.read_field_begin(last_field_id)?;
414                    if field_ident.field_type == FieldType::Stop {
415                        break;
416                    }
417                    self.skip_till_depth(field_ident.field_type, depth - 1)?;
418                    last_field_id = field_ident.id;
419                }
420                Ok(())
421            }
422            FieldType::List => {
423                let list_ident = self.read_list_begin()?;
424                for _ in 0..list_ident.size {
425                    let element_type = FieldType::try_from(list_ident.element_type)?;
426                    self.skip_till_depth(element_type, depth - 1)?;
427                }
428                Ok(())
429            }
430            // no list or map types in parquet format
431            u => Err(general_err!(format!("cannot skip field type {:?}", &u))),
432        }
433    }
434}
435
/// A high performance Thrift reader that decodes directly from a borrowed byte slice.
///
/// Binary and string values are handed out as sub-slices of the input rather than copies.
pub(crate) struct ThriftSliceInputProtocol<'a> {
    /// The bytes that have not been consumed yet.
    buf: &'a [u8],
}
440
441impl<'a> ThriftSliceInputProtocol<'a> {
442    /// Create a new `ThriftSliceInputProtocol` using the bytes in `buf`.
443    pub fn new(buf: &'a [u8]) -> Self {
444        Self { buf }
445    }
446
447    /// Return the current buffer as a slice.
448    pub fn as_slice(&self) -> &'a [u8] {
449        self.buf
450    }
451}
452
453impl<'b, 'a: 'b> ThriftCompactInputProtocol<'b> for ThriftSliceInputProtocol<'a> {
454    #[inline]
455    fn read_byte(&mut self) -> Result<u8> {
456        let ret = *self.buf.first().ok_or_else(eof_error)?;
457        self.buf = &self.buf[1..];
458        Ok(ret)
459    }
460
461    fn read_bytes(&mut self) -> Result<&'b [u8]> {
462        let len = self.read_vlq()? as usize;
463        let ret = self.buf.get(..len).ok_or_else(eof_error)?;
464        self.buf = &self.buf[len..];
465        Ok(ret)
466    }
467
468    fn read_bytes_owned(&mut self) -> Result<Vec<u8>> {
469        Ok(self.read_bytes()?.to_vec())
470    }
471
472    #[inline]
473    fn skip_bytes(&mut self, n: usize) -> Result<()> {
474        self.buf.get(..n).ok_or_else(eof_error)?;
475        self.buf = &self.buf[n..];
476        Ok(())
477    }
478
479    fn read_double(&mut self) -> Result<f64> {
480        let slice = self.buf.get(..8).ok_or_else(eof_error)?;
481        self.buf = &self.buf[8..];
482        match slice.try_into() {
483            Ok(slice) => Ok(f64::from_le_bytes(slice)),
484            Err(_) => Err(general_err!("Unexpected error converting slice")),
485        }
486    }
487}
488
489fn eof_error() -> ParquetError {
490    eof_err!("Unexpected EOF")
491}
492
/// A Thrift input protocol that pulls bytes from a [`Read`] object.
///
/// This is only intended for reading Parquet page headers. Data read from a stream cannot be
/// returned as a borrowed slice, so `read_bytes` panics; binary values must be decoded with
/// `read_bytes_owned`, or skipped.
pub(crate) struct ThriftReadInputProtocol<R: Read> {
    /// The underlying byte source; no extra buffering is added on top of it.
    reader: R,
}
500
501impl<R: Read> ThriftReadInputProtocol<R> {
502    pub(crate) fn new(reader: R) -> Self {
503        Self { reader }
504    }
505}
506
507impl<'a, R: Read> ThriftCompactInputProtocol<'a> for ThriftReadInputProtocol<R> {
508    #[inline]
509    fn read_byte(&mut self) -> Result<u8> {
510        let mut buf = [0_u8; 1];
511        self.reader.read_exact(&mut buf)?;
512        Ok(buf[0])
513    }
514
515    fn read_bytes(&mut self) -> Result<&'a [u8]> {
516        unimplemented!()
517    }
518
519    fn read_bytes_owned(&mut self) -> Result<Vec<u8>> {
520        let len = self.read_vlq()? as usize;
521        let mut v = Vec::with_capacity(len);
522        std::io::copy(&mut self.reader.by_ref().take(len as u64), &mut v)?;
523        Ok(v)
524    }
525
526    fn skip_bytes(&mut self, n: usize) -> Result<()> {
527        std::io::copy(
528            &mut self.reader.by_ref().take(n as u64),
529            &mut std::io::sink(),
530        )?;
531        Ok(())
532    }
533
534    fn read_double(&mut self) -> Result<f64> {
535        let mut buf = [0_u8; 8];
536        self.reader.read_exact(&mut buf)?;
537        Ok(f64::from_le_bytes(buf))
538    }
539}
540
541/// Trait implemented for objects that can be deserialized from a Thrift input stream.
542/// Implementations are provided for Thrift primitive types.
543pub(crate) trait ReadThrift<'a, R: ThriftCompactInputProtocol<'a>> {
544    /// Read an object of type `Self` from the input protocol object.
545    fn read_thrift(prot: &mut R) -> Result<Self>
546    where
547        Self: Sized;
548}
549
550impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for bool {
551    fn read_thrift(prot: &mut R) -> Result<Self> {
552        prot.read_bool()
553    }
554}
555
556impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i8 {
557    fn read_thrift(prot: &mut R) -> Result<Self> {
558        prot.read_i8()
559    }
560}
561
562impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i16 {
563    fn read_thrift(prot: &mut R) -> Result<Self> {
564        prot.read_i16()
565    }
566}
567
568impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i32 {
569    fn read_thrift(prot: &mut R) -> Result<Self> {
570        prot.read_i32()
571    }
572}
573
574impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i64 {
575    fn read_thrift(prot: &mut R) -> Result<Self> {
576        prot.read_i64()
577    }
578}
579
580impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for OrderedF64 {
581    fn read_thrift(prot: &mut R) -> Result<Self> {
582        Ok(OrderedF64(prot.read_double()?))
583    }
584}
585
586impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for &'a str {
587    fn read_thrift(prot: &mut R) -> Result<Self> {
588        prot.read_string()
589    }
590}
591
592impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for String {
593    fn read_thrift(prot: &mut R) -> Result<Self> {
594        Ok(String::from_utf8(prot.read_bytes_owned()?)?)
595    }
596}
597
598impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for &'a [u8] {
599    fn read_thrift(prot: &mut R) -> Result<Self> {
600        prot.read_bytes()
601    }
602}
603
604/// Read a Thrift encoded [list] from the input protocol object.
605///
606/// [list]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#list-and-set
607pub(crate) fn read_thrift_vec<'a, T, R>(prot: &mut R) -> Result<Vec<T>>
608where
609    R: ThriftCompactInputProtocol<'a>,
610    T: ReadThrift<'a, R>,
611{
612    let list_ident = prot.read_list_begin()?;
613    let mut res = Vec::with_capacity(list_ident.size as usize);
614    for _ in 0..list_ident.size {
615        let val = T::read_thrift(prot)?;
616        res.push(val);
617    }
618    Ok(res)
619}
620
621/////////////////////////
622// thrift compact output
623
/// Low-level object used to serialize structs to the Thrift [compact output] protocol.
///
/// This struct wraps a [`Write`] object, to which Thrift encoded data will be written. It
/// provides functions to write Thrift primitive types, plus the helpers needed to encode lists
/// and structs. It should rarely be used directly; it is intended for implementers of
/// [`WriteThrift`] and [`WriteThriftField`].
///
/// [compact output]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md
pub(crate) struct ThriftCompactOutputProtocol<W: Write> {
    /// Destination for the encoded bytes.
    writer: W,
}
635
636impl<W: Write> ThriftCompactOutputProtocol<W> {
637    /// Create a new `ThriftCompactOutputProtocol` wrapping the byte sink `writer`.
638    pub(crate) fn new(writer: W) -> Self {
639        Self { writer }
640    }
641
642    /// Write a single byte to the output stream.
643    fn write_byte(&mut self, b: u8) -> Result<()> {
644        self.writer.write_all(&[b])?;
645        Ok(())
646    }
647
648    /// Write the given `u64` as a ULEB128 encoded varint.
649    fn write_vlq(&mut self, val: u64) -> Result<()> {
650        let mut v = val;
651        while v > 0x7f {
652            self.write_byte(v as u8 | 0x80)?;
653            v >>= 7;
654        }
655        self.write_byte(v as u8)
656    }
657
658    /// Write the given `i64` as a zig-zag encoded varint.
659    fn write_zig_zag(&mut self, val: i64) -> Result<()> {
660        let s = (val < 0) as i64;
661        self.write_vlq((((val ^ -s) << 1) + s) as u64)
662    }
663
664    /// Used to mark the start of a Thrift struct field of type `field_type`. `last_field_id`
665    /// is used to compute a delta to the given `field_id` per the compact protocol [spec].
666    ///
667    /// [spec]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#struct-encoding
668    pub(crate) fn write_field_begin(
669        &mut self,
670        field_type: FieldType,
671        field_id: i16,
672        last_field_id: i16,
673    ) -> Result<()> {
674        let delta = field_id.wrapping_sub(last_field_id);
675        if delta > 0 && delta <= 0xf {
676            self.write_byte((delta as u8) << 4 | field_type as u8)
677        } else {
678            self.write_byte(field_type as u8)?;
679            self.write_i16(field_id)
680        }
681    }
682
683    /// Used to indicate the start of a list of `element_type` elements.
684    pub(crate) fn write_list_begin(&mut self, element_type: ElementType, len: usize) -> Result<()> {
685        if len < 15 {
686            self.write_byte((len as u8) << 4 | element_type as u8)
687        } else {
688            self.write_byte(0xf0u8 | element_type as u8)?;
689            self.write_vlq(len as _)
690        }
691    }
692
693    /// Used to mark the end of a struct. This must be called after all fields of the struct have
694    /// been written.
695    pub(crate) fn write_struct_end(&mut self) -> Result<()> {
696        self.write_byte(0)
697    }
698
699    /// Serialize a slice of `u8`s. This will encode a length, and then write the bytes without
700    /// further encoding.
701    pub(crate) fn write_bytes(&mut self, val: &[u8]) -> Result<()> {
702        self.write_vlq(val.len() as u64)?;
703        self.writer.write_all(val)?;
704        Ok(())
705    }
706
707    /// Short-cut method used to encode structs that have no fields (often used in Thrift unions).
708    /// This simply encodes the field id and then immediately writes the end-of-struct marker.
709    pub(crate) fn write_empty_struct(&mut self, field_id: i16, last_field_id: i16) -> Result<i16> {
710        self.write_field_begin(FieldType::Struct, field_id, last_field_id)?;
711        self.write_struct_end()?;
712        Ok(last_field_id)
713    }
714
715    /// Write a boolean value.
716    pub(crate) fn write_bool(&mut self, val: bool) -> Result<()> {
717        match val {
718            true => self.write_byte(1),
719            false => self.write_byte(2),
720        }
721    }
722
723    /// Write a zig-zag encoded `i8` value.
724    pub(crate) fn write_i8(&mut self, val: i8) -> Result<()> {
725        self.write_byte(val as u8)
726    }
727
728    /// Write a zig-zag encoded `i16` value.
729    pub(crate) fn write_i16(&mut self, val: i16) -> Result<()> {
730        self.write_zig_zag(val as _)
731    }
732
733    /// Write a zig-zag encoded `i32` value.
734    pub(crate) fn write_i32(&mut self, val: i32) -> Result<()> {
735        self.write_zig_zag(val as _)
736    }
737
738    /// Write a zig-zag encoded `i64` value.
739    pub(crate) fn write_i64(&mut self, val: i64) -> Result<()> {
740        self.write_zig_zag(val as _)
741    }
742
743    /// Write a double value.
744    pub(crate) fn write_double(&mut self, val: f64) -> Result<()> {
745        self.writer.write_all(&val.to_le_bytes())?;
746        Ok(())
747    }
748}
749
750/// Trait implemented by objects that are to be serialized to a Thrift [compact output] protocol
751/// stream. Implementations are also provided for primitive Thrift types.
752///
753/// [compact output]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md
754pub(crate) trait WriteThrift {
755    /// The [`ElementType`] to use when a list of this object is written.
756    const ELEMENT_TYPE: ElementType;
757
758    /// Serialize this object to the given `writer`.
759    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()>;
760}
761
762/// Implementation for a vector of thrift serializable objects that implement [`WriteThrift`].
763/// This will write the necessary list header and then serialize the elements one-at-a-time.
764impl<T> WriteThrift for Vec<T>
765where
766    T: WriteThrift,
767{
768    const ELEMENT_TYPE: ElementType = ElementType::List;
769
770    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
771        writer.write_list_begin(T::ELEMENT_TYPE, self.len())?;
772        for item in self {
773            item.write_thrift(writer)?;
774        }
775        Ok(())
776    }
777}
778
779impl WriteThrift for bool {
780    const ELEMENT_TYPE: ElementType = ElementType::Bool;
781
782    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
783        writer.write_bool(*self)
784    }
785}
786
787impl WriteThrift for i8 {
788    const ELEMENT_TYPE: ElementType = ElementType::Byte;
789
790    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
791        writer.write_i8(*self)
792    }
793}
794
795impl WriteThrift for i16 {
796    const ELEMENT_TYPE: ElementType = ElementType::I16;
797
798    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
799        writer.write_i16(*self)
800    }
801}
802
803impl WriteThrift for i32 {
804    const ELEMENT_TYPE: ElementType = ElementType::I32;
805
806    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
807        writer.write_i32(*self)
808    }
809}
810
811impl WriteThrift for i64 {
812    const ELEMENT_TYPE: ElementType = ElementType::I64;
813
814    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
815        writer.write_i64(*self)
816    }
817}
818
819impl WriteThrift for OrderedF64 {
820    const ELEMENT_TYPE: ElementType = ElementType::Double;
821
822    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
823        writer.write_double(self.0)
824    }
825}
826
827impl WriteThrift for f64 {
828    const ELEMENT_TYPE: ElementType = ElementType::Double;
829
830    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
831        writer.write_double(*self)
832    }
833}
834
835impl WriteThrift for &[u8] {
836    const ELEMENT_TYPE: ElementType = ElementType::Binary;
837
838    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
839        writer.write_bytes(self)
840    }
841}
842
843impl WriteThrift for &str {
844    const ELEMENT_TYPE: ElementType = ElementType::Binary;
845
846    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
847        writer.write_bytes(self.as_bytes())
848    }
849}
850
851impl WriteThrift for String {
852    const ELEMENT_TYPE: ElementType = ElementType::Binary;
853
854    fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
855        writer.write_bytes(self.as_bytes())
856    }
857}
858
859/// Trait implemented by objects that are fields of Thrift structs.
860///
861/// For example, given the Thrift struct definition
862/// ```ignore
863/// struct MyStruct {
864///   1: required i32 field1
865///   2: optional bool field2
866///   3: optional OtherStruct field3
867/// }
868/// ```
869///
870/// which becomes in Rust
871/// ```no_run
872/// # struct OtherStruct {}
873/// struct MyStruct {
874///   field1: i32,
875///   field2: Option<bool>,
876///   field3: Option<OtherStruct>,
877/// }
878/// ```
879/// the impl of `WriteThrift` for `MyStruct` will use the `WriteThriftField` impls for `i32`,
880/// `bool`, and `OtherStruct`.
881///
882/// ```ignore
883/// impl WriteThrift for MyStruct {
884///   fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
885///     let mut last_field_id = 0i16;
886///     last_field_id = self.field1.write_thrift_field(writer, 1, last_field_id)?;
887///     if self.field2.is_some() {
888///       // if field2 is `None` then this assignment won't happen and last_field_id will remain
889///       // `1` when writing `field3`
890///       last_field_id = self.field2.write_thrift_field(writer, 2, last_field_id)?;
891///     }
892///     if self.field3.is_some() {
893///       // no need to assign last_field_id since this is the final field.
894///       self.field3.write_thrift_field(writer, 3, last_field_id)?;
895///     }
896///     writer.write_struct_end()
897///   }
898/// }
899/// ```
900///
901pub(crate) trait WriteThriftField {
902    /// Used to write struct fields (which may be primitive or IDL defined types). This will
903    /// write the field marker for the given `field_id`, using `last_field_id` to compute the
904    /// field delta used by the Thrift [compact protocol]. On success this will return `field_id`
905    /// to be used in chaining.
906    ///
907    /// [compact protocol]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#struct-encoding
908    fn write_thrift_field<W: Write>(
909        &self,
910        writer: &mut ThriftCompactOutputProtocol<W>,
911        field_id: i16,
912        last_field_id: i16,
913    ) -> Result<i16>;
914}
915
916impl WriteThriftField for bool {
917    fn write_thrift_field<W: Write>(
918        &self,
919        writer: &mut ThriftCompactOutputProtocol<W>,
920        field_id: i16,
921        last_field_id: i16,
922    ) -> Result<i16> {
923        // boolean only writes the field header
924        match *self {
925            true => writer.write_field_begin(FieldType::BooleanTrue, field_id, last_field_id)?,
926            false => writer.write_field_begin(FieldType::BooleanFalse, field_id, last_field_id)?,
927        }
928        Ok(field_id)
929    }
930}
931
932impl WriteThriftField for i8 {
933    fn write_thrift_field<W: Write>(
934        &self,
935        writer: &mut ThriftCompactOutputProtocol<W>,
936        field_id: i16,
937        last_field_id: i16,
938    ) -> Result<i16> {
939        writer.write_field_begin(FieldType::Byte, field_id, last_field_id)?;
940        writer.write_i8(*self)?;
941        Ok(field_id)
942    }
943}
944
945impl WriteThriftField for i16 {
946    fn write_thrift_field<W: Write>(
947        &self,
948        writer: &mut ThriftCompactOutputProtocol<W>,
949        field_id: i16,
950        last_field_id: i16,
951    ) -> Result<i16> {
952        writer.write_field_begin(FieldType::I16, field_id, last_field_id)?;
953        writer.write_i16(*self)?;
954        Ok(field_id)
955    }
956}
957
958impl WriteThriftField for i32 {
959    fn write_thrift_field<W: Write>(
960        &self,
961        writer: &mut ThriftCompactOutputProtocol<W>,
962        field_id: i16,
963        last_field_id: i16,
964    ) -> Result<i16> {
965        writer.write_field_begin(FieldType::I32, field_id, last_field_id)?;
966        writer.write_i32(*self)?;
967        Ok(field_id)
968    }
969}
970
971impl WriteThriftField for i64 {
972    fn write_thrift_field<W: Write>(
973        &self,
974        writer: &mut ThriftCompactOutputProtocol<W>,
975        field_id: i16,
976        last_field_id: i16,
977    ) -> Result<i16> {
978        writer.write_field_begin(FieldType::I64, field_id, last_field_id)?;
979        writer.write_i64(*self)?;
980        Ok(field_id)
981    }
982}
983
984impl WriteThriftField for OrderedF64 {
985    fn write_thrift_field<W: Write>(
986        &self,
987        writer: &mut ThriftCompactOutputProtocol<W>,
988        field_id: i16,
989        last_field_id: i16,
990    ) -> Result<i16> {
991        writer.write_field_begin(FieldType::Double, field_id, last_field_id)?;
992        writer.write_double(self.0)?;
993        Ok(field_id)
994    }
995}
996
997impl WriteThriftField for f64 {
998    fn write_thrift_field<W: Write>(
999        &self,
1000        writer: &mut ThriftCompactOutputProtocol<W>,
1001        field_id: i16,
1002        last_field_id: i16,
1003    ) -> Result<i16> {
1004        writer.write_field_begin(FieldType::Double, field_id, last_field_id)?;
1005        writer.write_double(*self)?;
1006        Ok(field_id)
1007    }
1008}
1009
1010impl WriteThriftField for &[u8] {
1011    fn write_thrift_field<W: Write>(
1012        &self,
1013        writer: &mut ThriftCompactOutputProtocol<W>,
1014        field_id: i16,
1015        last_field_id: i16,
1016    ) -> Result<i16> {
1017        writer.write_field_begin(FieldType::Binary, field_id, last_field_id)?;
1018        writer.write_bytes(self)?;
1019        Ok(field_id)
1020    }
1021}
1022
1023impl WriteThriftField for &str {
1024    fn write_thrift_field<W: Write>(
1025        &self,
1026        writer: &mut ThriftCompactOutputProtocol<W>,
1027        field_id: i16,
1028        last_field_id: i16,
1029    ) -> Result<i16> {
1030        writer.write_field_begin(FieldType::Binary, field_id, last_field_id)?;
1031        writer.write_bytes(self.as_bytes())?;
1032        Ok(field_id)
1033    }
1034}
1035
1036impl WriteThriftField for String {
1037    fn write_thrift_field<W: Write>(
1038        &self,
1039        writer: &mut ThriftCompactOutputProtocol<W>,
1040        field_id: i16,
1041        last_field_id: i16,
1042    ) -> Result<i16> {
1043        writer.write_field_begin(FieldType::Binary, field_id, last_field_id)?;
1044        writer.write_bytes(self.as_bytes())?;
1045        Ok(field_id)
1046    }
1047}
1048
1049impl<T> WriteThriftField for Vec<T>
1050where
1051    T: WriteThrift,
1052{
1053    fn write_thrift_field<W: Write>(
1054        &self,
1055        writer: &mut ThriftCompactOutputProtocol<W>,
1056        field_id: i16,
1057        last_field_id: i16,
1058    ) -> Result<i16> {
1059        writer.write_field_begin(FieldType::List, field_id, last_field_id)?;
1060        self.write_thrift(writer)?;
1061        Ok(field_id)
1062    }
1063}
1064
#[cfg(test)]
pub(crate) mod tests {
    use crate::basic::{TimeUnit, Type};

    use super::*;
    use std::fmt::Debug;

    /// Encodes `val` with the compact protocol, decodes the produced bytes back into a `T`,
    /// and asserts that the decoded value equals the original.
    pub(crate) fn test_roundtrip<T>(val: T)
    where
        T: for<'a> ReadThrift<'a, ThriftSliceInputProtocol<'a>> + WriteThrift + PartialEq + Debug,
    {
        let mut encoded: Vec<u8> = Vec::new();
        // Scope the writer so its mutable borrow of `encoded` ends before we read it back.
        {
            let mut out = ThriftCompactOutputProtocol::new(&mut encoded);
            val.write_thrift(&mut out).unwrap();
        }

        let mut input = ThriftSliceInputProtocol::new(&encoded);
        let decoded = T::read_thrift(&mut input).unwrap();
        assert_eq!(val, decoded);
    }

    #[test]
    fn test_enum_roundtrip() {
        let physical_types = [
            Type::BOOLEAN,
            Type::INT32,
            Type::INT64,
            Type::INT96,
            Type::FLOAT,
            Type::DOUBLE,
            Type::BYTE_ARRAY,
            Type::FIXED_LEN_BYTE_ARRAY,
        ];
        for physical_type in physical_types {
            test_roundtrip(physical_type);
        }
    }

    #[test]
    fn test_union_all_empty_roundtrip() {
        // `TimeUnit` is a union whose variants carry no data.
        for unit in [TimeUnit::MILLIS, TimeUnit::MICROS, TimeUnit::NANOS] {
            test_roundtrip(unit);
        }
    }
}