1use crate::CompressionType;
19use arrow_buffer::Buffer;
20use arrow_schema::ArrowError;
21
/// Length-prefix sentinel meaning "the body that follows is stored uncompressed".
const LENGTH_NO_COMPRESSED_DATA: i64 = -1;
/// Size in bytes of the little-endian `i64` length prefix that precedes each buffer body.
const LENGTH_OF_PREFIX_DATA: i64 = std::mem::size_of::<i64>() as i64;
24
/// Reusable state for compressing IPC buffers.
///
/// With the `zstd` feature enabled this holds a zstd bulk compressor that is reused for
/// every zstd compression performed with this context. Without it the struct is empty.
pub struct CompressionContext {
    #[cfg(feature = "zstd")]
    compressor: zstd::bulk::Compressor<'static>,
}
34
35#[allow(clippy::derivable_impls)]
38impl Default for CompressionContext {
39 fn default() -> Self {
40 CompressionContext {
41 #[cfg(feature = "zstd")]
43 compressor: zstd::bulk::Compressor::new(zstd::DEFAULT_COMPRESSION_LEVEL)
44 .expect("can use default compression level"),
45 }
46 }
47}
48
49impl std::fmt::Debug for CompressionContext {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 let mut ds = f.debug_struct("CompressionContext");
52
53 #[cfg(feature = "zstd")]
54 ds.field("compressor", &"zstd::bulk::Compressor");
55
56 ds.finish()
57 }
58}
59
/// Reusable state for decompressing IPC buffers.
///
/// With the `zstd` feature enabled this holds a zstd bulk decompressor that is reused for
/// every zstd decompression performed with this context. Without it the struct is empty.
pub struct DecompressionContext {
    #[cfg(feature = "zstd")]
    decompressor: zstd::bulk::Decompressor<'static>,
}
69
70impl DecompressionContext {
71 pub(crate) fn new() -> Self {
72 Default::default()
73 }
74}
75
76#[allow(clippy::derivable_impls)]
77impl Default for DecompressionContext {
78 fn default() -> Self {
79 DecompressionContext {
80 #[cfg(feature = "zstd")]
81 decompressor: zstd::bulk::Decompressor::new().expect("can create zstd decompressor"),
82 }
83 }
84}
85
86impl std::fmt::Debug for DecompressionContext {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 let mut ds = f.debug_struct("DecompressionContext");
89
90 #[cfg(feature = "zstd")]
91 ds.field("decompressor", &"zstd::bulk::Decompressor");
92
93 ds.finish()
94 }
95}
96
/// The buffer compression codecs supported for Arrow IPC bodies.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionCodec {
    /// LZ4 frame format (requires the `lz4` feature at runtime).
    Lz4Frame,
    /// Zstandard (requires the `zstd` feature at runtime).
    Zstd,
}
103
104impl TryFrom<CompressionType> for CompressionCodec {
105 type Error = ArrowError;
106
107 fn try_from(compression_type: CompressionType) -> Result<Self, ArrowError> {
108 match compression_type {
109 CompressionType::ZSTD => Ok(CompressionCodec::Zstd),
110 CompressionType::LZ4_FRAME => Ok(CompressionCodec::Lz4Frame),
111 other_type => Err(ArrowError::NotYetImplemented(format!(
112 "compression type {other_type:?} not supported "
113 ))),
114 }
115 }
116}
117
118impl CompressionCodec {
119 pub(crate) fn compress_to_vec(
130 &self,
131 input: &[u8],
132 output: &mut Vec<u8>,
133 context: &mut CompressionContext,
134 ) -> Result<usize, ArrowError> {
135 let uncompressed_data_len = input.len();
136 let original_output_len = output.len();
137
138 if input.is_empty() {
139 } else {
141 output.extend_from_slice(&uncompressed_data_len.to_le_bytes());
143 self.compress(input, output, context)?;
144
145 let compression_len = output.len() - original_output_len;
146 if compression_len > uncompressed_data_len {
147 output.truncate(original_output_len);
152 output.extend_from_slice(&LENGTH_NO_COMPRESSED_DATA.to_le_bytes());
153 output.extend_from_slice(input);
154 }
155 }
156 Ok(output.len() - original_output_len)
157 }
158
159 pub(crate) fn decompress_to_buffer(
167 &self,
168 input: &Buffer,
169 context: &mut DecompressionContext,
170 ) -> Result<Buffer, ArrowError> {
171 let decompressed_length = read_uncompressed_size(input);
174 let buffer = if decompressed_length == 0 {
175 Buffer::from([])
177 } else if decompressed_length == LENGTH_NO_COMPRESSED_DATA {
178 input.slice(LENGTH_OF_PREFIX_DATA as usize)
180 } else if let Ok(decompressed_length) = usize::try_from(decompressed_length) {
181 let input_data = &input[(LENGTH_OF_PREFIX_DATA as usize)..];
183 let v = self.decompress(input_data, decompressed_length as _, context)?;
184 Buffer::from_vec(v)
185 } else {
186 return Err(ArrowError::IpcError(format!(
187 "Invalid uncompressed length: {decompressed_length}"
188 )));
189 };
190 Ok(buffer)
191 }
192
193 fn compress(
196 &self,
197 input: &[u8],
198 output: &mut Vec<u8>,
199 context: &mut CompressionContext,
200 ) -> Result<(), ArrowError> {
201 match self {
202 CompressionCodec::Lz4Frame => compress_lz4(input, output),
203 CompressionCodec::Zstd => compress_zstd(input, output, context),
204 }
205 }
206
207 fn decompress(
210 &self,
211 input: &[u8],
212 decompressed_size: usize,
213 context: &mut DecompressionContext,
214 ) -> Result<Vec<u8>, ArrowError> {
215 let ret = match self {
216 CompressionCodec::Lz4Frame => decompress_lz4(input, decompressed_size)?,
217 CompressionCodec::Zstd => decompress_zstd(input, decompressed_size, context)?,
218 };
219 if ret.len() != decompressed_size {
220 return Err(ArrowError::IpcError(format!(
221 "Expected compressed length of {decompressed_size} got {}",
222 ret.len()
223 )));
224 }
225 Ok(ret)
226 }
227}
228
/// Appends a complete LZ4 frame encoding `input` to `output`.
#[cfg(feature = "lz4")]
fn compress_lz4(input: &[u8], output: &mut Vec<u8>) -> Result<(), ArrowError> {
    use std::io::Write;
    let mut frame_writer = lz4_flex::frame::FrameEncoder::new(output);
    frame_writer.write_all(input)?;
    // `finish` flushes the frame footer; its returned writer is not needed.
    match frame_writer.finish() {
        Ok(_) => Ok(()),
        Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
    }
}
239
240#[cfg(not(feature = "lz4"))]
241#[allow(clippy::ptr_arg)]
242fn compress_lz4(_input: &[u8], _output: &mut Vec<u8>) -> Result<(), ArrowError> {
243 Err(ArrowError::InvalidArgumentError(
244 "lz4 IPC compression requires the lz4 feature".to_string(),
245 ))
246}
247
/// Decodes an LZ4 frame from `input`.
///
/// `decompressed_size` is the expected output length: it pre-sizes the output and bounds
/// how much is read. At most `decompressed_size + 1` bytes are produced. That is enough
/// for the caller's exact-length check to reject an oversized frame, while preventing a
/// crafted frame from expanding without bound.
#[cfg(feature = "lz4")]
fn decompress_lz4(input: &[u8], decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
    use std::io::Read;
    let mut output = Vec::with_capacity(decompressed_size);
    let read_limit = (decompressed_size as u64).saturating_add(1);
    lz4_flex::frame::FrameDecoder::new(input)
        .take(read_limit)
        .read_to_end(&mut output)?;
    Ok(output)
}
255
256#[cfg(not(feature = "lz4"))]
257#[allow(clippy::ptr_arg)]
258fn decompress_lz4(_input: &[u8], _decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
259 Err(ArrowError::InvalidArgumentError(
260 "lz4 IPC decompression requires the lz4 feature".to_string(),
261 ))
262}
263
/// Compresses `input` with the context's reusable zstd compressor and appends it to `output`.
#[cfg(feature = "zstd")]
fn compress_zstd(
    input: &[u8],
    output: &mut Vec<u8>,
    context: &mut CompressionContext,
) -> Result<(), ArrowError> {
    let compressed = context.compressor.compress(input)?;
    output.extend_from_slice(compressed.as_slice());
    Ok(())
}
274
275#[cfg(not(feature = "zstd"))]
276#[allow(clippy::ptr_arg)]
277fn compress_zstd(
278 _input: &[u8],
279 _output: &mut Vec<u8>,
280 _context: &mut CompressionContext,
281) -> Result<(), ArrowError> {
282 Err(ArrowError::InvalidArgumentError(
283 "zstd IPC compression requires the zstd feature".to_string(),
284 ))
285}
286
/// Decompresses `input` with the context's reusable zstd decompressor.
///
/// `decompressed_size` is passed to zstd as the output capacity.
#[cfg(feature = "zstd")]
fn decompress_zstd(
    input: &[u8],
    decompressed_size: usize,
    context: &mut DecompressionContext,
) -> Result<Vec<u8>, ArrowError> {
    Ok(context.decompressor.decompress(input, decompressed_size)?)
}
296
297#[cfg(not(feature = "zstd"))]
298#[allow(clippy::ptr_arg)]
299fn decompress_zstd(
300 _input: &[u8],
301 _decompressed_size: usize,
302 _context: &mut DecompressionContext,
303) -> Result<Vec<u8>, ArrowError> {
304 Err(ArrowError::InvalidArgumentError(
305 "zstd IPC decompression requires the zstd feature".to_string(),
306 ))
307}
308
/// Reads the 8-byte little-endian `i64` length prefix at the start of `buffer`.
///
/// Meaning of the value:
/// * `LENGTH_NO_COMPRESSED_DATA` (-1): the data that follows is not compressed;
/// * `0`: there is no data;
/// * positive: the uncompressed length of the data that follows.
///
/// # Panics
/// Panics if `buffer` is shorter than 8 bytes.
#[inline]
fn read_uncompressed_size(buffer: &[u8]) -> i64 {
    let mut prefix = [0u8; 8];
    prefix.copy_from_slice(&buffer[..8]);
    i64::from_le_bytes(prefix)
}
320
#[cfg(test)]
mod tests {
    /// Compresses `input` with `codec`, decompresses it again, and returns the result.
    #[cfg(any(feature = "lz4", feature = "zstd"))]
    fn round_trip(codec: super::CompressionCodec, input: &[u8]) -> Vec<u8> {
        let mut compressed: Vec<u8> = Vec::new();
        codec
            .compress(input, &mut compressed, &mut Default::default())
            .unwrap();
        codec
            .decompress(compressed.as_slice(), input.len(), &mut Default::default())
            .unwrap()
    }

    #[test]
    #[cfg(feature = "lz4")]
    fn test_lz4_compression() {
        let input_bytes = b"hello lz4";
        let result = round_trip(super::CompressionCodec::Lz4Frame, input_bytes);
        assert_eq!(input_bytes, result.as_slice());
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn test_zstd_compression() {
        let input_bytes = b"hello zstd";
        let result = round_trip(super::CompressionCodec::Zstd, input_bytes);
        assert_eq!(input_bytes, result.as_slice());
    }
}