1use std::{
30 cmp::Ordering,
31 io::{Read, Write},
32};
33
34use crate::{
35 errors::{ParquetError, Result},
36 write_thrift_field,
37};
38use std::io::Error;
39use std::num::TryFromIntError;
40use std::str::Utf8Error;
41
42#[derive(Debug)]
43pub(crate) enum ThriftProtocolError {
44 Eof,
45 IO(Error),
46 InvalidFieldType(u8),
47 InvalidElementType(u8),
48 FieldDeltaOverflow { field_delta: u8, last_field_id: i16 },
49 InvalidBoolean(u8),
50 IntegerOverflow,
51 Utf8Error,
52 SkipDepth(FieldType),
53 SkipUnsupportedType(FieldType),
54}
55
56impl From<ThriftProtocolError> for ParquetError {
57 #[inline(never)]
58 fn from(e: ThriftProtocolError) -> Self {
59 match e {
60 ThriftProtocolError::Eof => eof_err!("Unexpected EOF"),
61 ThriftProtocolError::IO(e) => e.into(),
62 ThriftProtocolError::InvalidFieldType(value) => match FieldType::try_from(value) {
63 Ok(fld_type) => general_err!("Unexpected struct field type {:?}", fld_type),
64 Err(_) => general_err!("Unexpected struct field type {}", value),
65 },
66 ThriftProtocolError::InvalidElementType(value) => {
67 general_err!("Unexpected list/set element type {}", value)
68 }
69 ThriftProtocolError::FieldDeltaOverflow {
70 field_delta,
71 last_field_id,
72 } => general_err!("cannot add {} to {}", field_delta, last_field_id),
73 ThriftProtocolError::InvalidBoolean(value) => {
74 general_err!("cannot convert {} into bool", value)
75 }
76 ThriftProtocolError::IntegerOverflow => {
77 general_err!("integer overflow decoding thrift value")
78 }
79 ThriftProtocolError::Utf8Error => general_err!("invalid utf8"),
80 ThriftProtocolError::SkipDepth(field_type) => {
81 general_err!("cannot parse past {:?}", field_type)
82 }
83 ThriftProtocolError::SkipUnsupportedType(field_type) => {
84 general_err!("cannot skip field type {:?}", field_type)
85 }
86 }
87 }
88}
89
90impl From<Utf8Error> for ThriftProtocolError {
91 fn from(_: Utf8Error) -> Self {
92 Self::Utf8Error
94 }
95}
96
97impl From<Error> for ThriftProtocolError {
98 fn from(e: Error) -> Self {
99 Self::IO(e)
100 }
101}
102
103impl From<TryFromIntError> for ThriftProtocolError {
104 fn from(_: TryFromIntError) -> Self {
105 Self::IntegerOverflow
107 }
108}
109
110pub type ThriftProtocolResult<T> = Result<T, ThriftProtocolError>;
111
/// An `f64` wrapper with a total order, so it can implement [`Eq`] and [`Ord`].
///
/// Ordering follows [`f64::total_cmp`] (IEEE 754 `totalOrder`). For example,
/// `-0.0` sorts before `+0.0`, and a NaN compares equal to a NaN with the same
/// bit pattern.
///
/// Equality is defined from the same comparison, so `a == b` holds exactly when
/// `a.cmp(&b) == Ordering::Equal`. The `Eq`/`PartialOrd`/`Ord` contracts require
/// this; a derived `PartialEq` (plain `f64 ==`) would make NaN unequal to itself
/// and `0.0 == -0.0` while `cmp` disagrees.
#[derive(Debug, Clone, Copy)]
pub struct OrderedF64(f64);

impl From<f64> for OrderedF64 {
    fn from(value: f64) -> Self {
        Self(value)
    }
}

impl From<OrderedF64> for f64 {
    fn from(value: OrderedF64) -> Self {
        value.0
    }
}

impl PartialEq for OrderedF64 {
    fn eq(&self, other: &Self) -> bool {
        self.0.total_cmp(&other.0) == Ordering::Equal
    }
}

impl Eq for OrderedF64 {}

impl Ord for OrderedF64 {
    fn cmp(&self, other: &Self) -> Ordering {
        self.0.total_cmp(&other.0)
    }
}

impl PartialOrd for OrderedF64 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
143
/// Compact-protocol type codes as they appear in the low nibble of a field header.
///
/// Boolean fields carry their value in the type code itself (`BooleanTrue` /
/// `BooleanFalse`). `Stop` marks the end of a struct.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum FieldType {
    /// End-of-struct marker.
    Stop = 0,
    /// Boolean field whose value is `true`.
    BooleanTrue = 1,
    /// Boolean field whose value is `false`.
    BooleanFalse = 2,
    /// Single signed byte.
    Byte = 3,
    /// Zigzag varint `i16`.
    I16 = 4,
    /// Zigzag varint `i32`.
    I32 = 5,
    /// Zigzag varint `i64`.
    I64 = 6,
    /// Eight little-endian bytes.
    Double = 7,
    /// Varint length followed by raw bytes (also used for strings).
    Binary = 8,
    /// List header followed by elements.
    List = 9,
    /// Set header followed by elements.
    Set = 10,
    /// Map header followed by key/value pairs.
    Map = 11,
    /// Nested struct terminated by `Stop`.
    Struct = 12,
    /// Sixteen raw bytes.
    Uuid = 13,
}
162
163impl TryFrom<u8> for FieldType {
164 type Error = ThriftProtocolError;
165 fn try_from(value: u8) -> ThriftProtocolResult<Self> {
166 match value {
167 0 => Ok(Self::Stop),
168 1 => Ok(Self::BooleanTrue),
169 2 => Ok(Self::BooleanFalse),
170 3 => Ok(Self::Byte),
171 4 => Ok(Self::I16),
172 5 => Ok(Self::I32),
173 6 => Ok(Self::I64),
174 7 => Ok(Self::Double),
175 8 => Ok(Self::Binary),
176 9 => Ok(Self::List),
177 10 => Ok(Self::Set),
178 11 => Ok(Self::Map),
179 12 => Ok(Self::Struct),
180 13 => Ok(Self::Uuid),
181 _ => Err(ThriftProtocolError::InvalidFieldType(value)),
182 }
183 }
184}
185
186impl From<ElementType> for FieldType {
187 fn from(value: ElementType) -> Self {
188 match value {
189 ElementType::Bool => Self::BooleanTrue,
190 ElementType::Byte => Self::Byte,
191 ElementType::I16 => Self::I16,
192 ElementType::I32 => Self::I32,
193 ElementType::I64 => Self::I64,
194 ElementType::Double => Self::Double,
195 ElementType::Binary => Self::Binary,
196 ElementType::List => Self::List,
197 ElementType::Set => Self::Set,
198 ElementType::Map => Self::Map,
199 ElementType::Struct => Self::Struct,
200 ElementType::Uuid => Self::Uuid,
201 }
202 }
203}
204
/// Compact-protocol type codes for list, set, and map elements.
///
/// These match [`FieldType`] codes except for booleans. A collection has one
/// `Bool` element type (written as `2`); the reader also accepts `1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ElementType {
    /// One byte per element.
    Bool = 2,
    Byte = 3,
    I16 = 4,
    I32 = 5,
    I64 = 6,
    Double = 7,
    Binary = 8,
    List = 9,
    Set = 10,
    Map = 11,
    Struct = 12,
    Uuid = 13,
}
221
222impl TryFrom<u8> for ElementType {
223 type Error = ThriftProtocolError;
224 fn try_from(value: u8) -> ThriftProtocolResult<Self> {
225 match value {
226 1 | 2 => Ok(Self::Bool),
232 3 => Ok(Self::Byte),
233 4 => Ok(Self::I16),
234 5 => Ok(Self::I32),
235 6 => Ok(Self::I64),
236 7 => Ok(Self::Double),
237 8 => Ok(Self::Binary),
238 9 => Ok(Self::List),
239 10 => Ok(Self::Set),
240 11 => Ok(Self::Map),
241 12 => Ok(Self::Struct),
242 13 => Ok(Self::Uuid),
243 _ => Err(ThriftProtocolError::InvalidElementType(value)),
244 }
245 }
246}
247
248pub(crate) struct FieldIdentifier {
252 pub(crate) field_type: FieldType,
254 pub(crate) id: i16,
256}
257
258impl FieldIdentifier {
259 pub(crate) fn bool_val(&self) -> ThriftProtocolResult<bool> {
260 match self.field_type {
261 FieldType::BooleanTrue => Ok(true),
262 FieldType::BooleanFalse => Ok(false),
263 _ => Err(ThriftProtocolError::InvalidFieldType(self.field_type as u8)),
264 }
265 }
266}
267
268#[derive(Clone, Debug, Eq, PartialEq)]
272pub(crate) struct ListIdentifier {
273 pub(crate) element_type: ElementType,
275 pub(crate) size: i32,
277}
278
279pub(crate) trait ThriftCompactInputProtocol<'a> {
287 fn read_byte(&mut self) -> ThriftProtocolResult<u8>;
289
290 fn read_bytes(&mut self) -> ThriftProtocolResult<&'a [u8]>;
294
295 fn read_bytes_owned(&mut self) -> ThriftProtocolResult<Vec<u8>>;
296
297 fn skip_bytes(&mut self, n: usize) -> ThriftProtocolResult<()>;
299
300 fn read_vlq(&mut self) -> ThriftProtocolResult<u64> {
302 let byte = self.read_byte()?;
304 if byte & 0x80 == 0 {
305 return Ok(byte as u64);
306 }
307 let mut in_progress = (byte & 0x7f) as u64;
308 let mut shift = 7;
309 loop {
310 let byte = self.read_byte()?;
311 in_progress |= ((byte & 0x7F) as u64).wrapping_shl(shift);
312 if byte & 0x80 == 0 {
313 return Ok(in_progress);
314 }
315 shift += 7;
316 }
317 }
318
319 fn read_zig_zag(&mut self) -> ThriftProtocolResult<i64> {
321 let val = self.read_vlq()?;
322 Ok((val >> 1) as i64 ^ -((val & 1) as i64))
323 }
324
325 fn read_list_begin(&mut self) -> ThriftProtocolResult<ListIdentifier> {
327 let header = self.read_byte()?;
328 if header == 0 {
331 return Ok(ListIdentifier {
332 element_type: ElementType::Byte,
333 size: 0,
334 });
335 }
336 let element_type = ElementType::try_from(header & 0x0f)?;
337
338 let possible_element_count = (header & 0xF0) >> 4;
339 let element_count = if possible_element_count != 15 {
340 possible_element_count as i32
342 } else {
343 i32::try_from(self.read_vlq()?)?
350 };
351
352 Ok(ListIdentifier {
353 element_type,
354 size: element_count,
355 })
356 }
357
358 #[cold]
362 fn read_full_field_id(&mut self) -> ThriftProtocolResult<i16> {
363 self.read_i16()
364 }
365
366 fn read_field_begin(&mut self, last_field_id: i16) -> ThriftProtocolResult<FieldIdentifier> {
368 let field_type = self.read_byte()?;
372 if field_type & 0xf == 0 {
373 return Ok(FieldIdentifier {
374 field_type: FieldType::Stop,
375 id: 0,
376 });
377 }
378
379 let field_delta = (field_type & 0xf0) >> 4;
380 let field_type = FieldType::try_from(field_type & 0xf)?;
381
382 let id = if field_delta != 0 {
383 last_field_id.checked_add(field_delta as i16).ok_or(
384 ThriftProtocolError::FieldDeltaOverflow {
385 field_delta,
386 last_field_id,
387 },
388 )?
389 } else {
390 self.read_full_field_id()?
391 };
392
393 Ok(FieldIdentifier { field_type, id })
394 }
395
396 fn read_field_header(&mut self) -> ThriftProtocolResult<(u8, u8)> {
403 let field_type = self.read_byte()?;
404 let field_delta = (field_type & 0xf0) >> 4;
405 let field_type = field_type & 0xf;
406 Ok((field_type, field_delta))
407 }
408
409 fn read_bool(&mut self) -> ThriftProtocolResult<bool> {
412 let b = self.read_byte()?;
413 match b {
418 0x01 => Ok(true),
419 0x00 | 0x02 => Ok(false),
420 _ => Err(ThriftProtocolError::InvalidBoolean(b)),
421 }
422 }
423
424 fn read_string(&mut self) -> ThriftProtocolResult<&'a str> {
428 let slice = self.read_bytes()?;
429 Ok(std::str::from_utf8(slice)?)
430 }
431
432 fn read_i8(&mut self) -> ThriftProtocolResult<i8> {
434 Ok(self.read_byte()? as _)
435 }
436
437 fn read_i16(&mut self) -> ThriftProtocolResult<i16> {
439 Ok(self.read_zig_zag()? as _)
440 }
441
442 fn read_i32(&mut self) -> ThriftProtocolResult<i32> {
444 Ok(self.read_zig_zag()? as _)
445 }
446
447 fn read_i64(&mut self) -> ThriftProtocolResult<i64> {
449 self.read_zig_zag()
450 }
451
452 fn read_double(&mut self) -> ThriftProtocolResult<f64>;
454
455 fn skip_vlq(&mut self) -> ThriftProtocolResult<()> {
457 loop {
458 let byte = self.read_byte()?;
459 if byte & 0x80 == 0 {
460 return Ok(());
461 }
462 }
463 }
464
465 fn skip_binary(&mut self) -> ThriftProtocolResult<()> {
469 let len = self.read_vlq()? as usize;
470 self.skip_bytes(len)
471 }
472
473 fn skip(&mut self, field_type: FieldType) -> ThriftProtocolResult<()> {
476 const DEFAULT_SKIP_DEPTH: i8 = 64;
477 self.skip_till_depth(field_type, DEFAULT_SKIP_DEPTH)
478 }
479
480 fn skip_empty_struct(&mut self) -> Result<()> {
484 let b = self.read_byte()?;
485 if b != 0 {
486 Err(general_err!("Empty struct has fields"))
487 } else {
488 Ok(())
489 }
490 }
491
492 fn skip_till_depth(&mut self, field_type: FieldType, depth: i8) -> ThriftProtocolResult<()> {
494 if depth == 0 {
495 return Err(ThriftProtocolError::SkipDepth(field_type));
496 }
497
498 match field_type {
499 FieldType::BooleanFalse | FieldType::BooleanTrue => Ok(()),
501 FieldType::Byte => self.read_i8().map(|_| ()),
502 FieldType::I16 => self.skip_vlq().map(|_| ()),
503 FieldType::I32 => self.skip_vlq().map(|_| ()),
504 FieldType::I64 => self.skip_vlq().map(|_| ()),
505 FieldType::Double => self.skip_bytes(8).map(|_| ()),
506 FieldType::Binary => self.skip_binary().map(|_| ()),
507 FieldType::Struct => {
509 loop {
510 let field_ident = self.read_field_begin(0)?;
512 if field_ident.field_type == FieldType::Stop {
513 break;
514 }
515 self.skip_till_depth(field_ident.field_type, depth - 1)?;
516 }
517 Ok(())
518 }
519 FieldType::List | FieldType::Set => {
522 let list_ident = self.read_list_begin()?;
523 let element_type = FieldType::from(list_ident.element_type);
524 for _ in 0..list_ident.size {
525 self.skip_till_depth(element_type, depth - 1)?;
526 }
527 Ok(())
528 }
529 FieldType::Map => {
531 let size = i32::try_from(self.read_vlq()?)?;
532 if size > 0 {
533 let kv = self.read_byte()?;
534 let key_type = FieldType::from(ElementType::try_from(kv >> 4)?);
535 let val_type = FieldType::from(ElementType::try_from(kv & 0xf)?);
536 for _ in 0..size {
537 self.skip_till_depth(key_type, depth - 1)?;
538 self.skip_till_depth(val_type, depth - 1)?;
539 }
540 }
541 Ok(())
542 }
543 FieldType::Uuid => self.skip_bytes(16).map(|_| ()),
545 _ => Err(ThriftProtocolError::SkipUnsupportedType(field_type)),
546 }
547 }
548}
549
/// Compact-protocol reader over an in-memory byte slice.
///
/// Values returned by `read_bytes` / `read_string` borrow directly from the slice.
pub(crate) struct ThriftSliceInputProtocol<'a> {
    /// The bytes not yet consumed; reads advance this slice.
    buf: &'a [u8],
}
554
555impl<'a> ThriftSliceInputProtocol<'a> {
556 pub fn new(buf: &'a [u8]) -> Self {
558 Self { buf }
559 }
560
561 pub fn as_slice(&self) -> &'a [u8] {
563 self.buf
564 }
565}
566
567impl<'b, 'a: 'b> ThriftCompactInputProtocol<'b> for ThriftSliceInputProtocol<'a> {
568 #[inline]
569 fn read_byte(&mut self) -> ThriftProtocolResult<u8> {
570 let ret = *self.buf.first().ok_or(ThriftProtocolError::Eof)?;
571 self.buf = &self.buf[1..];
572 Ok(ret)
573 }
574
575 fn read_bytes(&mut self) -> ThriftProtocolResult<&'b [u8]> {
576 let len = self.read_vlq()? as usize;
577 let ret = self.buf.get(..len).ok_or(ThriftProtocolError::Eof)?;
578 self.buf = &self.buf[len..];
579 Ok(ret)
580 }
581
582 fn read_bytes_owned(&mut self) -> ThriftProtocolResult<Vec<u8>> {
583 Ok(self.read_bytes()?.to_vec())
584 }
585
586 #[inline]
587 fn skip_bytes(&mut self, n: usize) -> ThriftProtocolResult<()> {
588 self.buf.get(..n).ok_or(ThriftProtocolError::Eof)?;
589 self.buf = &self.buf[n..];
590 Ok(())
591 }
592
593 fn read_double(&mut self) -> ThriftProtocolResult<f64> {
594 let slice = self.buf.get(..8).ok_or(ThriftProtocolError::Eof)?;
595 self.buf = &self.buf[8..];
596 match slice.try_into() {
597 Ok(slice) => Ok(f64::from_le_bytes(slice)),
598 Err(_) => unreachable!(),
599 }
600 }
601}
602
/// Compact-protocol reader over any [`Read`] implementation.
///
/// Because data is streamed, values cannot be borrowed from the input; use the
/// owned accessors instead.
pub(crate) struct ThriftReadInputProtocol<R: Read> {
    /// Source of the encoded bytes.
    reader: R,
}
610
611impl<R: Read> ThriftReadInputProtocol<R> {
612 pub(crate) fn new(reader: R) -> Self {
613 Self { reader }
614 }
615}
616
617impl<'a, R: Read> ThriftCompactInputProtocol<'a> for ThriftReadInputProtocol<R> {
618 #[inline]
619 fn read_byte(&mut self) -> ThriftProtocolResult<u8> {
620 let mut buf = [0_u8; 1];
621 self.reader.read_exact(&mut buf)?;
622 Ok(buf[0])
623 }
624
625 fn read_bytes(&mut self) -> ThriftProtocolResult<&'a [u8]> {
626 unimplemented!()
627 }
628
629 fn read_bytes_owned(&mut self) -> ThriftProtocolResult<Vec<u8>> {
630 let len = self.read_vlq()? as usize;
631 let mut v = Vec::with_capacity(len);
632 std::io::copy(&mut self.reader.by_ref().take(len as u64), &mut v)?;
633 Ok(v)
634 }
635
636 fn skip_bytes(&mut self, n: usize) -> ThriftProtocolResult<()> {
637 std::io::copy(
638 &mut self.reader.by_ref().take(n as u64),
639 &mut std::io::sink(),
640 )?;
641 Ok(())
642 }
643
644 fn read_double(&mut self) -> ThriftProtocolResult<f64> {
645 let mut buf = [0_u8; 8];
646 self.reader.read_exact(&mut buf)?;
647 Ok(f64::from_le_bytes(buf))
648 }
649}
650
651pub(crate) trait ReadThrift<'a, R: ThriftCompactInputProtocol<'a>> {
654 fn read_thrift(prot: &mut R) -> Result<Self>
656 where
657 Self: Sized;
658}
659
660impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for bool {
661 fn read_thrift(prot: &mut R) -> Result<Self> {
662 Ok(prot.read_bool()?)
663 }
664}
665
666impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i8 {
667 fn read_thrift(prot: &mut R) -> Result<Self> {
668 Ok(prot.read_i8()?)
669 }
670}
671
672impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i16 {
673 fn read_thrift(prot: &mut R) -> Result<Self> {
674 Ok(prot.read_i16()?)
675 }
676}
677
678impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i32 {
679 fn read_thrift(prot: &mut R) -> Result<Self> {
680 Ok(prot.read_i32()?)
681 }
682}
683
684impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i64 {
685 fn read_thrift(prot: &mut R) -> Result<Self> {
686 Ok(prot.read_i64()?)
687 }
688}
689
690impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for OrderedF64 {
691 fn read_thrift(prot: &mut R) -> Result<Self> {
692 Ok(OrderedF64(prot.read_double()?))
693 }
694}
695
696impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for &'a str {
697 fn read_thrift(prot: &mut R) -> Result<Self> {
698 Ok(prot.read_string()?)
699 }
700}
701
702impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for String {
703 fn read_thrift(prot: &mut R) -> Result<Self> {
704 Ok(String::from_utf8(prot.read_bytes_owned()?)?)
705 }
706}
707
708impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for &'a [u8] {
709 fn read_thrift(prot: &mut R) -> Result<Self> {
710 Ok(prot.read_bytes()?)
711 }
712}
713
714pub(crate) fn read_thrift_vec<'a, T, R>(prot: &mut R) -> Result<Vec<T>>
718where
719 R: ThriftCompactInputProtocol<'a>,
720 T: ReadThrift<'a, R> + WriteThrift,
721{
722 let list_ident = prot.read_list_begin()?;
723 validate_list_type(T::ELEMENT_TYPE, &list_ident)?;
724 let mut res = Vec::with_capacity(list_ident.size as usize);
725 for _ in 0..list_ident.size {
726 let val = T::read_thrift(prot)?;
727 res.push(val);
728 }
729 Ok(res)
730}
731
732pub(crate) fn validate_list_type(expected: ElementType, got: &ListIdentifier) -> Result<()> {
733 if got.element_type != expected {
734 return Err(general_err!(
735 "Expected list element type of {:?} but got {:?}",
736 expected,
737 got.element_type
738 ));
739 }
740 Ok(())
741}
742
/// Writer for the Thrift compact protocol, emitting encoded bytes into `W`.
pub(crate) struct ThriftCompactOutputProtocol<W: Write> {
    /// Flag exposed through `write_path_in_schema()`; this type only stores it.
    write_path_in_schema: bool,
    /// Destination for the encoded bytes.
    writer: W,
}
758
759impl<W: Write> ThriftCompactOutputProtocol<W> {
760 pub(crate) fn new(writer: W) -> Self {
762 Self {
763 writer,
764 write_path_in_schema: true,
765 }
766 }
767
768 pub(crate) fn set_write_path_in_schema(&mut self, val: bool) {
773 self.write_path_in_schema = val;
774 }
775
776 pub(crate) fn write_path_in_schema(&self) -> bool {
778 self.write_path_in_schema
779 }
780
781 fn write_byte(&mut self, b: u8) -> Result<()> {
783 self.writer.write_all(&[b])?;
784 Ok(())
785 }
786
787 fn write_vlq(&mut self, val: u64) -> Result<()> {
789 let mut v = val;
790 while v > 0x7f {
791 self.write_byte(v as u8 | 0x80)?;
792 v >>= 7;
793 }
794 self.write_byte(v as u8)
795 }
796
797 fn write_zig_zag(&mut self, val: i64) -> Result<()> {
799 let s = (val < 0) as i64;
800 self.write_vlq((((val ^ -s) << 1) + s) as u64)
801 }
802
803 pub(crate) fn write_field_begin(
808 &mut self,
809 field_type: FieldType,
810 field_id: i16,
811 last_field_id: i16,
812 ) -> Result<()> {
813 let delta = field_id.wrapping_sub(last_field_id);
814 if delta > 0 && delta <= 0xf {
815 self.write_byte((delta as u8) << 4 | field_type as u8)
816 } else {
817 self.write_byte(field_type as u8)?;
818 self.write_i16(field_id)
819 }
820 }
821
822 pub(crate) fn write_list_begin(&mut self, element_type: ElementType, len: usize) -> Result<()> {
824 if len < 15 {
825 self.write_byte((len as u8) << 4 | element_type as u8)
826 } else {
827 self.write_byte(0xf0u8 | element_type as u8)?;
828 self.write_vlq(len as _)
829 }
830 }
831
832 pub(crate) fn write_struct_end(&mut self) -> Result<()> {
835 self.write_byte(0)
836 }
837
838 pub(crate) fn write_bytes(&mut self, val: &[u8]) -> Result<()> {
841 self.write_vlq(val.len() as u64)?;
842 self.writer.write_all(val)?;
843 Ok(())
844 }
845
846 pub(crate) fn write_empty_struct(&mut self, field_id: i16, last_field_id: i16) -> Result<i16> {
849 self.write_field_begin(FieldType::Struct, field_id, last_field_id)?;
850 self.write_struct_end()?;
851 Ok(last_field_id)
852 }
853
854 pub(crate) fn write_bool(&mut self, val: bool) -> Result<()> {
856 match val {
857 true => self.write_byte(1),
858 false => self.write_byte(2),
859 }
860 }
861
862 pub(crate) fn write_i8(&mut self, val: i8) -> Result<()> {
864 self.write_byte(val as u8)
865 }
866
867 pub(crate) fn write_i16(&mut self, val: i16) -> Result<()> {
869 self.write_zig_zag(val as _)
870 }
871
872 pub(crate) fn write_i32(&mut self, val: i32) -> Result<()> {
874 self.write_zig_zag(val as _)
875 }
876
877 pub(crate) fn write_i64(&mut self, val: i64) -> Result<()> {
879 self.write_zig_zag(val as _)
880 }
881
882 pub(crate) fn write_double(&mut self, val: f64) -> Result<()> {
884 self.writer.write_all(&val.to_le_bytes())?;
885 Ok(())
886 }
887}
888
889pub(crate) trait WriteThrift {
894 const ELEMENT_TYPE: ElementType;
896
897 fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()>;
899}
900
901impl<T> WriteThrift for Vec<T>
904where
905 T: WriteThrift,
906{
907 const ELEMENT_TYPE: ElementType = ElementType::List;
908
909 fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
910 writer.write_list_begin(T::ELEMENT_TYPE, self.len())?;
911 for item in self {
912 item.write_thrift(writer)?;
913 }
914 Ok(())
915 }
916}
917
918impl WriteThrift for bool {
919 const ELEMENT_TYPE: ElementType = ElementType::Bool;
920
921 fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
922 writer.write_bool(*self)
923 }
924}
925
926impl WriteThrift for i8 {
927 const ELEMENT_TYPE: ElementType = ElementType::Byte;
928
929 fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
930 writer.write_i8(*self)
931 }
932}
933
934impl WriteThrift for i16 {
935 const ELEMENT_TYPE: ElementType = ElementType::I16;
936
937 fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
938 writer.write_i16(*self)
939 }
940}
941
942impl WriteThrift for i32 {
943 const ELEMENT_TYPE: ElementType = ElementType::I32;
944
945 fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
946 writer.write_i32(*self)
947 }
948}
949
950impl WriteThrift for i64 {
951 const ELEMENT_TYPE: ElementType = ElementType::I64;
952
953 fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
954 writer.write_i64(*self)
955 }
956}
957
958impl WriteThrift for OrderedF64 {
959 const ELEMENT_TYPE: ElementType = ElementType::Double;
960
961 fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
962 writer.write_double(self.0)
963 }
964}
965
966impl WriteThrift for f64 {
967 const ELEMENT_TYPE: ElementType = ElementType::Double;
968
969 fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
970 writer.write_double(*self)
971 }
972}
973
974impl WriteThrift for &[u8] {
975 const ELEMENT_TYPE: ElementType = ElementType::Binary;
976
977 fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
978 writer.write_bytes(self)
979 }
980}
981
982impl WriteThrift for &str {
983 const ELEMENT_TYPE: ElementType = ElementType::Binary;
984
985 fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
986 writer.write_bytes(self.as_bytes())
987 }
988}
989
990impl WriteThrift for String {
991 const ELEMENT_TYPE: ElementType = ElementType::Binary;
992
993 fn write_thrift<W: Write>(&self, writer: &mut ThriftCompactOutputProtocol<W>) -> Result<()> {
994 writer.write_bytes(self.as_bytes())
995 }
996}
997
998pub(crate) trait WriteThriftField {
1041 fn write_thrift_field<W: Write>(
1048 &self,
1049 writer: &mut ThriftCompactOutputProtocol<W>,
1050 field_id: i16,
1051 last_field_id: i16,
1052 ) -> Result<i16>;
1053}
1054
1055impl WriteThriftField for bool {
1057 fn write_thrift_field<W: Write>(
1058 &self,
1059 writer: &mut ThriftCompactOutputProtocol<W>,
1060 field_id: i16,
1061 last_field_id: i16,
1062 ) -> Result<i16> {
1063 match *self {
1065 true => writer.write_field_begin(FieldType::BooleanTrue, field_id, last_field_id)?,
1066 false => writer.write_field_begin(FieldType::BooleanFalse, field_id, last_field_id)?,
1067 }
1068 Ok(field_id)
1069 }
1070}
1071
1072write_thrift_field!(i8, FieldType::Byte);
1073write_thrift_field!(i16, FieldType::I16);
1074write_thrift_field!(i32, FieldType::I32);
1075write_thrift_field!(i64, FieldType::I64);
1076write_thrift_field!(OrderedF64, FieldType::Double);
1077write_thrift_field!(f64, FieldType::Double);
1078write_thrift_field!(String, FieldType::Binary);
1079
1080impl WriteThriftField for &[u8] {
1081 fn write_thrift_field<W: Write>(
1082 &self,
1083 writer: &mut ThriftCompactOutputProtocol<W>,
1084 field_id: i16,
1085 last_field_id: i16,
1086 ) -> Result<i16> {
1087 writer.write_field_begin(FieldType::Binary, field_id, last_field_id)?;
1088 writer.write_bytes(self)?;
1089 Ok(field_id)
1090 }
1091}
1092
1093impl WriteThriftField for &str {
1094 fn write_thrift_field<W: Write>(
1095 &self,
1096 writer: &mut ThriftCompactOutputProtocol<W>,
1097 field_id: i16,
1098 last_field_id: i16,
1099 ) -> Result<i16> {
1100 writer.write_field_begin(FieldType::Binary, field_id, last_field_id)?;
1101 writer.write_bytes(self.as_bytes())?;
1102 Ok(field_id)
1103 }
1104}
1105
1106impl<T> WriteThriftField for Vec<T>
1107where
1108 T: WriteThrift,
1109{
1110 fn write_thrift_field<W: Write>(
1111 &self,
1112 writer: &mut ThriftCompactOutputProtocol<W>,
1113 field_id: i16,
1114 last_field_id: i16,
1115 ) -> Result<i16> {
1116 writer.write_field_begin(FieldType::List, field_id, last_field_id)?;
1117 self.write_thrift(writer)?;
1118 Ok(field_id)
1119 }
1120}
1121
#[cfg(test)]
pub(crate) mod tests {
    use crate::basic::{TimeUnit, Type};

    use super::*;
    use std::fmt::Debug;

    /// Encode `val`, decode it back with [`ThriftSliceInputProtocol`], and assert the
    /// decoded value equals the original.
    pub(crate) fn test_roundtrip<T>(val: T)
    where
        T: for<'a> ReadThrift<'a, ThriftSliceInputProtocol<'a>> + WriteThrift + PartialEq + Debug,
    {
        let mut encoded = Vec::<u8>::new();
        val.write_thrift(&mut ThriftCompactOutputProtocol::new(&mut encoded))
            .unwrap();

        let decoded = T::read_thrift(&mut ThriftSliceInputProtocol::new(&encoded)).unwrap();
        assert_eq!(val, decoded);
    }

    #[test]
    fn test_enum_roundtrip() {
        for physical_type in [
            Type::BOOLEAN,
            Type::INT32,
            Type::INT64,
            Type::INT96,
            Type::FLOAT,
            Type::DOUBLE,
            Type::BYTE_ARRAY,
            Type::FIXED_LEN_BYTE_ARRAY,
        ] {
            test_roundtrip(physical_type);
        }
    }

    #[test]
    fn test_union_all_empty_roundtrip() {
        for unit in [TimeUnit::MILLIS, TimeUnit::MICROS, TimeUnit::NANOS] {
            test_roundtrip(unit);
        }
    }

    #[test]
    fn test_decode_empty_list() {
        // A lone zero header byte is accepted as an empty list.
        let data = [0u8];
        let mut prot = ThriftSliceInputProtocol::new(&data);
        let header = prot.read_list_begin().expect("error reading list header");
        assert_eq!(header.size, 0);
        assert_eq!(header.element_type, ElementType::Byte);
    }

    #[test]
    fn test_read_list_begin_size_above_i32_max_returns_err() {
        // 0xF8: size nibble 15 (count follows as a varint), element type 8 (Binary).
        // Varint 80 80 80 80 08 decodes to 8 << 28 == 2^31 == i32::MAX + 1.
        let data = [0xF8, 0x80, 0x80, 0x80, 0x80, 0x08];
        let mut prot = ThriftSliceInputProtocol::new(&data);
        let result = prot.read_list_begin();
        assert!(result.is_err(), "expected error, got {result:?}");
    }

    #[test]
    fn test_read_list_wrong_type() {
        // Header 0x42: four elements of type 2 (Bool), read as a list of i32.
        let data = [0x42, 0x01];
        let mut prot = ThriftSliceInputProtocol::new(&data);
        let result = read_thrift_vec::<i32, ThriftSliceInputProtocol>(&mut prot);
        println!("{result:?}");

        let message = result.unwrap_err().to_string();
        assert!(message.contains("Expected list element type of I32 but got Bool"));
    }
}