parquet/
thrift.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Custom thrift definitions
19
20pub use thrift::protocol::TCompactOutputProtocol;
21use thrift::protocol::{
22    TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, TMessageIdentifier,
23    TOutputProtocol, TSetIdentifier, TStructIdentifier, TType,
24};
25
26/// Reads and writes the struct to Thrift protocols.
27///
28/// Unlike [`thrift::protocol::TSerializable`] this uses generics instead of trait objects
29pub trait TSerializable: Sized {
30    /// Reads the struct from the input Thrift protocol
31    fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<Self>;
32    /// Writes the struct to the output Thrift protocol
33    fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()>;
34}
35
36/// Public function to aid benchmarking.
37pub fn bench_file_metadata(bytes: &bytes::Bytes) {
38    let mut input = TCompactSliceInputProtocol::new(bytes);
39    crate::format::FileMetaData::read_from_in_protocol(&mut input).unwrap();
40}
41
/// A more performant implementation of [`TCompactInputProtocol`] that reads a slice
///
/// [`TCompactInputProtocol`]: thrift::protocol::TCompactInputProtocol
pub(crate) struct TCompactSliceInputProtocol<'a> {
    /// Input that has not been consumed yet; each read advances this slice.
    buf: &'a [u8],
    /// Id of the most recently decoded field of the struct currently being read.
    last_read_field_id: i16,
    /// Saved `last_read_field_id` values of the enclosing structs: one entry is pushed
    /// when a struct begins and popped again when it ends.
    read_field_id_stack: Vec<i16>,
    /// Value of a boolean field whose header has been read but whose value has not
    /// been requested yet.
    ///
    /// A boolean field and its value share a single byte, so the value is already
    /// known when the field header is decoded and is parked here until `read_bool`.
    pending_read_bool_value: Option<bool>,
}
56
57impl<'a> TCompactSliceInputProtocol<'a> {
58    pub fn new(buf: &'a [u8]) -> Self {
59        Self {
60            buf,
61            last_read_field_id: 0,
62            read_field_id_stack: Vec::with_capacity(16),
63            pending_read_bool_value: None,
64        }
65    }
66
67    pub fn as_slice(&self) -> &'a [u8] {
68        self.buf
69    }
70
71    fn read_vlq(&mut self) -> thrift::Result<u64> {
72        let mut in_progress = 0;
73        let mut shift = 0;
74        loop {
75            let byte = self.read_byte()?;
76            in_progress |= ((byte & 0x7F) as u64).wrapping_shl(shift);
77            shift += 7;
78            if byte & 0x80 == 0 {
79                return Ok(in_progress);
80            }
81        }
82    }
83
84    fn read_zig_zag(&mut self) -> thrift::Result<i64> {
85        let val = self.read_vlq()?;
86        Ok((val >> 1) as i64 ^ -((val & 1) as i64))
87    }
88
89    fn read_list_set_begin(&mut self) -> thrift::Result<(TType, i32)> {
90        let header = self.read_byte()?;
91        let element_type = collection_u8_to_type(header & 0x0F)?;
92
93        let possible_element_count = (header & 0xF0) >> 4;
94        let element_count = if possible_element_count != 15 {
95            // high bits set high if count and type encoded separately
96            possible_element_count as i32
97        } else {
98            self.read_vlq()? as _
99        };
100
101        Ok((element_type, element_count))
102    }
103}
104
// Expands to an `Err` carrying a thrift protocol error of kind `NotImplemented`,
// for protocol operations this reader does not support.
macro_rules! thrift_unimplemented {
    () => {
        Err(thrift::Error::Protocol(thrift::ProtocolError {
            message: String::from("not implemented"),
            kind: thrift::ProtocolErrorKind::NotImplemented,
        }))
    };
}
113
114impl TInputProtocol for TCompactSliceInputProtocol<'_> {
115    fn read_message_begin(&mut self) -> thrift::Result<TMessageIdentifier> {
116        unimplemented!()
117    }
118
119    fn read_message_end(&mut self) -> thrift::Result<()> {
120        thrift_unimplemented!()
121    }
122
123    fn read_struct_begin(&mut self) -> thrift::Result<Option<TStructIdentifier>> {
124        self.read_field_id_stack.push(self.last_read_field_id);
125        self.last_read_field_id = 0;
126        Ok(None)
127    }
128
129    fn read_struct_end(&mut self) -> thrift::Result<()> {
130        self.last_read_field_id = self
131            .read_field_id_stack
132            .pop()
133            .expect("should have previous field ids");
134        Ok(())
135    }
136
137    fn read_field_begin(&mut self) -> thrift::Result<TFieldIdentifier> {
138        // we can read at least one byte, which is:
139        // - the type
140        // - the field delta and the type
141        let field_type = self.read_byte()?;
142        let field_delta = (field_type & 0xF0) >> 4;
143        let field_type = match field_type & 0x0F {
144            0x01 => {
145                self.pending_read_bool_value = Some(true);
146                Ok(TType::Bool)
147            }
148            0x02 => {
149                self.pending_read_bool_value = Some(false);
150                Ok(TType::Bool)
151            }
152            ttu8 => u8_to_type(ttu8),
153        }?;
154
155        match field_type {
156            TType::Stop => Ok(
157                TFieldIdentifier::new::<Option<String>, String, Option<i16>>(
158                    None,
159                    TType::Stop,
160                    None,
161                ),
162            ),
163            _ => {
164                if field_delta != 0 {
165                    self.last_read_field_id = self
166                        .last_read_field_id
167                        .checked_add(field_delta as i16)
168                        .map_or_else(
169                            || {
170                                Err(thrift::Error::Protocol(thrift::ProtocolError {
171                                    kind: thrift::ProtocolErrorKind::InvalidData,
172                                    message: format!(
173                                        "cannot add {} to {}",
174                                        field_delta, self.last_read_field_id
175                                    ),
176                                }))
177                            },
178                            Ok,
179                        )?;
180                } else {
181                    self.last_read_field_id = self.read_i16()?;
182                };
183
184                Ok(TFieldIdentifier {
185                    name: None,
186                    field_type,
187                    id: Some(self.last_read_field_id),
188                })
189            }
190        }
191    }
192
193    fn read_field_end(&mut self) -> thrift::Result<()> {
194        Ok(())
195    }
196
197    fn read_bool(&mut self) -> thrift::Result<bool> {
198        match self.pending_read_bool_value.take() {
199            Some(b) => Ok(b),
200            None => {
201                let b = self.read_byte()?;
202                // Previous versions of the thrift specification said to use 0 and 1 inside collections,
203                // but that differed from existing implementations.
204                // The specification was updated in https://github.com/apache/thrift/commit/2c29c5665bc442e703480bb0ee60fe925ffe02e8.
205                // At least the go implementation seems to have followed the previously documented values.
206                match b {
207                    0x01 => Ok(true),
208                    0x00 | 0x02 => Ok(false),
209                    unkn => Err(thrift::Error::Protocol(thrift::ProtocolError {
210                        kind: thrift::ProtocolErrorKind::InvalidData,
211                        message: format!("cannot convert {unkn} into bool"),
212                    })),
213                }
214            }
215        }
216    }
217
218    fn read_bytes(&mut self) -> thrift::Result<Vec<u8>> {
219        let len = self.read_vlq()? as usize;
220        let ret = self.buf.get(..len).ok_or_else(eof_error)?.to_vec();
221        self.buf = &self.buf[len..];
222        Ok(ret)
223    }
224
225    fn read_i8(&mut self) -> thrift::Result<i8> {
226        Ok(self.read_byte()? as _)
227    }
228
229    fn read_i16(&mut self) -> thrift::Result<i16> {
230        Ok(self.read_zig_zag()? as _)
231    }
232
233    fn read_i32(&mut self) -> thrift::Result<i32> {
234        Ok(self.read_zig_zag()? as _)
235    }
236
237    fn read_i64(&mut self) -> thrift::Result<i64> {
238        self.read_zig_zag()
239    }
240
241    fn read_double(&mut self) -> thrift::Result<f64> {
242        let slice = (self.buf[..8]).try_into().unwrap();
243        self.buf = &self.buf[8..];
244        Ok(f64::from_le_bytes(slice))
245    }
246
247    fn read_string(&mut self) -> thrift::Result<String> {
248        let bytes = self.read_bytes()?;
249        String::from_utf8(bytes).map_err(From::from)
250    }
251
252    fn read_list_begin(&mut self) -> thrift::Result<TListIdentifier> {
253        let (element_type, element_count) = self.read_list_set_begin()?;
254        Ok(TListIdentifier::new(element_type, element_count))
255    }
256
257    fn read_list_end(&mut self) -> thrift::Result<()> {
258        Ok(())
259    }
260
261    fn read_set_begin(&mut self) -> thrift::Result<TSetIdentifier> {
262        thrift_unimplemented!()
263    }
264
265    fn read_set_end(&mut self) -> thrift::Result<()> {
266        thrift_unimplemented!()
267    }
268
269    fn read_map_begin(&mut self) -> thrift::Result<TMapIdentifier> {
270        thrift_unimplemented!()
271    }
272
273    fn read_map_end(&mut self) -> thrift::Result<()> {
274        Ok(())
275    }
276
277    #[inline]
278    fn read_byte(&mut self) -> thrift::Result<u8> {
279        let ret = *self.buf.first().ok_or_else(eof_error)?;
280        self.buf = &self.buf[1..];
281        Ok(ret)
282    }
283}
284
285fn collection_u8_to_type(b: u8) -> thrift::Result<TType> {
286    match b {
287        // For historical and compatibility reasons, a reader should be capable to deal with both cases.
288        // The only valid value in the original spec was 2, but due to an widespread implementation bug
289        // the defacto standard across large parts of the library became 1 instead.
290        // As a result, both values are now allowed.
291        // https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#list-and-set
292        0x01 | 0x02 => Ok(TType::Bool),
293        o => u8_to_type(o),
294    }
295}
296
297fn u8_to_type(b: u8) -> thrift::Result<TType> {
298    match b {
299        0x00 => Ok(TType::Stop),
300        0x03 => Ok(TType::I08), // equivalent to TType::Byte
301        0x04 => Ok(TType::I16),
302        0x05 => Ok(TType::I32),
303        0x06 => Ok(TType::I64),
304        0x07 => Ok(TType::Double),
305        0x08 => Ok(TType::String),
306        0x09 => Ok(TType::List),
307        0x0A => Ok(TType::Set),
308        0x0B => Ok(TType::Map),
309        0x0C => Ok(TType::Struct),
310        unkn => Err(thrift::Error::Protocol(thrift::ProtocolError {
311            kind: thrift::ProtocolErrorKind::InvalidData,
312            message: format!("cannot convert {unkn} into TType"),
313        })),
314    }
315}
316
317fn eof_error() -> thrift::Error {
318    thrift::Error::Transport(thrift::TransportError {
319        kind: thrift::TransportErrorKind::EndOfFile,
320        message: "Unexpected EOF".to_string(),
321    })
322}
323
#[cfg(test)]
mod tests {
    use crate::format::{BoundaryOrder, ColumnIndex};
    use crate::thrift::{TCompactSliceInputProtocol, TSerializable};

    /// Decodes `encoded` as a `ColumnIndex`, panicking if decoding fails.
    fn decode_column_index(encoded: &[u8]) -> ColumnIndex {
        let mut protocol = TCompactSliceInputProtocol::new(encoded);
        ColumnIndex::read_from_in_protocol(&mut protocol).unwrap()
    }

    /// The `ColumnIndex` both encodings below are expected to decode to.
    fn expected_column_index() -> ColumnIndex {
        ColumnIndex {
            null_pages: vec![false, true],
            min_values: vec![],
            max_values: vec![],
            boundary_order: BoundaryOrder::UNORDERED,
            null_counts: None,
            repetition_level_histograms: None,
            definition_level_histograms: None,
        }
    }

    #[test]
    pub fn read_boolean_list_field_type() {
        // Boolean collection type encoded as 0x01, as used by this crate when writing.
        // Values encoded as 1 (true) or 2 (false) as in the current version of the thrift
        // documentation.
        let bytes: &[u8] = &[0x19, 0x21, 2, 1, 0x19, 8, 0x19, 8, 0x15, 0, 0];

        let index = decode_column_index(bytes);
        let expected = expected_column_index();

        assert_eq!(&index, &expected);
    }

    #[test]
    pub fn read_boolean_list_alternative_encoding() {
        // Boolean collection type encoded as 0x02, as allowed by the spec.
        // Values encoded as 1 (true) or 0 (false) as before the thrift documentation change on 2024-12-13.
        let bytes: &[u8] = &[0x19, 0x22, 0, 1, 0x19, 8, 0x19, 8, 0x15, 0, 0];

        let index = decode_column_index(bytes);
        let expected = expected_column_index();

        assert_eq!(&index, &expected);
    }
}
371}