1use crate::CompressionType;
19use arrow_buffer::Buffer;
20use arrow_schema::ArrowError;
21
/// Length-prefix sentinel marking a buffer whose body is stored uncompressed.
const LENGTH_NO_COMPRESSED_DATA: i64 = -1;
/// Size in bytes of the little-endian `i64` length prefix that starts every
/// encoded buffer.
const LENGTH_OF_PREFIX_DATA: i64 = 8;
24
/// Reusable state for compressing buffers.
///
/// With the `zstd` feature enabled this holds an optional
/// `zstd::bulk::Compressor` so that one compressor can be shared by several
/// compression calls. Without that feature the struct has no fields.
#[derive(Default)]
pub struct CompressionContext {
    /// Cached zstd compressor; `None` until one has been created.
    #[cfg(feature = "zstd")]
    compressor: Option<zstd::bulk::Compressor<'static>>,
}
35
36impl CompressionContext {
37 #[cfg(feature = "zstd")]
38 fn zstd_compressor(&mut self) -> &mut zstd::bulk::Compressor<'static> {
39 self.compressor.get_or_insert_with(|| {
40 zstd::bulk::Compressor::new(zstd::DEFAULT_COMPRESSION_LEVEL)
41 .expect("can use default compression level")
42 })
43 }
44}
45
46impl std::fmt::Debug for CompressionContext {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 let mut ds = f.debug_struct("CompressionContext");
49
50 #[cfg(feature = "zstd")]
51 ds.field(
52 "compressor",
53 &self.compressor.as_ref().map(|_| "zstd::bulk::Compressor"),
54 );
55
56 ds.finish()
57 }
58}
59
/// Reusable state for decompressing buffers.
///
/// With the `zstd` feature enabled this holds an optional
/// `zstd::bulk::Decompressor` so that one decompressor can be shared by
/// several decompression calls. Without that feature the struct has no
/// fields.
///
/// `Default` is derived (matching `CompressionContext`) and yields a context
/// with no decompressor created yet.
#[derive(Default)]
pub struct DecompressionContext {
    /// Cached zstd decompressor; `None` until one has been created.
    #[cfg(feature = "zstd")]
    decompressor: Option<zstd::bulk::Decompressor<'static>>,
}

impl DecompressionContext {
    /// Creates a context with no decompressor created yet.
    pub(crate) fn new() -> Self {
        Self::default()
    }

    /// Returns the cached zstd decompressor, building one the first time it
    /// is requested.
    ///
    /// # Panics
    ///
    /// Panics if zstd fails to create a decompressor.
    #[cfg(feature = "zstd")]
    fn zstd_decompressor(&mut self) -> &mut zstd::bulk::Decompressor<'static> {
        self.decompressor.get_or_insert_with(|| {
            zstd::bulk::Decompressor::new().expect("can create zstd decompressor")
        })
    }
}
92
93impl std::fmt::Debug for DecompressionContext {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 let mut ds = f.debug_struct("DecompressionContext");
96
97 #[cfg(feature = "zstd")]
98 ds.field(
99 "decompressor",
100 &self
101 .decompressor
102 .as_ref()
103 .map(|_| "zstd::bulk::Decompressor"),
104 );
105
106 ds.finish()
107 }
108}
109
/// Compression algorithms this module can apply to a buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionCodec {
    /// LZ4 frame format; handled via `lz4_flex::frame` when the `lz4`
    /// feature is enabled.
    Lz4Frame,
    /// Zstandard; handled via `zstd::bulk` when the `zstd` feature is
    /// enabled.
    Zstd,
}
116
117impl TryFrom<CompressionType> for CompressionCodec {
118 type Error = ArrowError;
119
120 fn try_from(compression_type: CompressionType) -> Result<Self, ArrowError> {
121 match compression_type {
122 CompressionType::ZSTD => Ok(CompressionCodec::Zstd),
123 CompressionType::LZ4_FRAME => Ok(CompressionCodec::Lz4Frame),
124 other_type => Err(ArrowError::NotYetImplemented(format!(
125 "compression type {other_type:?} not supported "
126 ))),
127 }
128 }
129}
130
131impl CompressionCodec {
132 pub(crate) fn compress_to_vec(
143 &self,
144 input: &[u8],
145 output: &mut Vec<u8>,
146 context: &mut CompressionContext,
147 ) -> Result<usize, ArrowError> {
148 let uncompressed_data_len = input.len();
149 let original_output_len = output.len();
150
151 if input.is_empty() {
152 } else {
154 output.extend_from_slice(&uncompressed_data_len.to_le_bytes());
156 self.compress(input, output, context)?;
157
158 let compression_len = output.len() - original_output_len;
159 if compression_len > uncompressed_data_len {
160 output.truncate(original_output_len);
165 output.extend_from_slice(&LENGTH_NO_COMPRESSED_DATA.to_le_bytes());
166 output.extend_from_slice(input);
167 }
168 }
169 Ok(output.len() - original_output_len)
170 }
171
172 pub(crate) fn decompress_to_buffer(
180 &self,
181 input: &Buffer,
182 context: &mut DecompressionContext,
183 ) -> Result<Buffer, ArrowError> {
184 let decompressed_length = read_uncompressed_size(input)?;
187 let buffer = if decompressed_length == 0 {
188 Buffer::from([])
190 } else if decompressed_length == LENGTH_NO_COMPRESSED_DATA {
191 input.slice(LENGTH_OF_PREFIX_DATA as usize)
193 } else if let Ok(decompressed_length) = usize::try_from(decompressed_length) {
194 let input_data = &input[(LENGTH_OF_PREFIX_DATA as usize)..];
196 let v = self.decompress(input_data, decompressed_length as _, context)?;
197 Buffer::from_vec(v)
198 } else {
199 return Err(ArrowError::IpcError(format!(
200 "Invalid uncompressed length: {decompressed_length}"
201 )));
202 };
203 Ok(buffer)
204 }
205
206 fn compress(
209 &self,
210 input: &[u8],
211 output: &mut Vec<u8>,
212 context: &mut CompressionContext,
213 ) -> Result<(), ArrowError> {
214 match self {
215 CompressionCodec::Lz4Frame => compress_lz4(input, output),
216 CompressionCodec::Zstd => compress_zstd(input, output, context),
217 }
218 }
219
220 fn decompress(
223 &self,
224 input: &[u8],
225 decompressed_size: usize,
226 context: &mut DecompressionContext,
227 ) -> Result<Vec<u8>, ArrowError> {
228 let ret = match self {
229 CompressionCodec::Lz4Frame => decompress_lz4(input, decompressed_size)?,
230 CompressionCodec::Zstd => decompress_zstd(input, decompressed_size, context)?,
231 };
232 if ret.len() != decompressed_size {
233 return Err(ArrowError::IpcError(format!(
234 "Expected compressed length of {decompressed_size} got {}",
235 ret.len()
236 )));
237 }
238 Ok(ret)
239 }
240}
241
/// Appends `input`, encoded as an LZ4 frame, to `output`.
///
/// # Errors
///
/// Returns an error if writing to the frame encoder fails, or wraps the
/// encoder's error in `ArrowError::ExternalError` if finishing the frame
/// fails.
#[cfg(feature = "lz4")]
fn compress_lz4(input: &[u8], output: &mut Vec<u8>) -> Result<(), ArrowError> {
    use std::io::Write;

    let mut frame_writer = lz4_flex::frame::FrameEncoder::new(output);
    frame_writer.write_all(input)?;
    match frame_writer.finish() {
        Ok(_) => Ok(()),
        Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
    }
}
252
253#[cfg(not(feature = "lz4"))]
254#[allow(clippy::ptr_arg)]
255fn compress_lz4(_input: &[u8], _output: &mut Vec<u8>) -> Result<(), ArrowError> {
256 Err(ArrowError::InvalidArgumentError(
257 "lz4 IPC compression requires the lz4 feature".to_string(),
258 ))
259}
260
/// Decodes an LZ4 frame from `input` into a new vector.
///
/// `decompressed_size` is used only to reserve capacity up front; the caller
/// is responsible for checking the length of the returned data.
///
/// # Errors
///
/// Returns `ArrowError::IpcError` if `decompressed_size` bytes cannot be
/// reserved, or the decoder's I/O error if the frame cannot be read.
#[cfg(feature = "lz4")]
fn decompress_lz4(input: &[u8], decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
    use std::io::Read;

    // `decompressed_size` originates from the buffer's length prefix, i.e.
    // from the input itself. Reserve fallibly so an absurd value surfaces as
    // an error instead of aborting the process on allocation failure.
    let mut output: Vec<u8> = Vec::new();
    output.try_reserve_exact(decompressed_size).map_err(|e| {
        ArrowError::IpcError(format!(
            "Unable to reserve {decompressed_size} bytes for decompressed IPC buffer: {e}"
        ))
    })?;
    lz4_flex::frame::FrameDecoder::new(input).read_to_end(&mut output)?;
    Ok(output)
}
268
269#[cfg(not(feature = "lz4"))]
270#[allow(clippy::ptr_arg)]
271fn decompress_lz4(_input: &[u8], _decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
272 Err(ArrowError::InvalidArgumentError(
273 "lz4 IPC decompression requires the lz4 feature".to_string(),
274 ))
275}
276
/// Compresses `input` with the context's zstd compressor and appends the
/// result to `output`.
///
/// # Errors
///
/// Returns the compressor's error, converted into `ArrowError`.
#[cfg(feature = "zstd")]
fn compress_zstd(
    input: &[u8],
    output: &mut Vec<u8>,
    context: &mut CompressionContext,
) -> Result<(), ArrowError> {
    let compressor = context.zstd_compressor();
    let frame = compressor.compress(input)?;
    output.extend_from_slice(frame.as_slice());
    Ok(())
}
287
288#[cfg(not(feature = "zstd"))]
289#[allow(clippy::ptr_arg)]
290fn compress_zstd(
291 _input: &[u8],
292 _output: &mut Vec<u8>,
293 _context: &mut CompressionContext,
294) -> Result<(), ArrowError> {
295 Err(ArrowError::InvalidArgumentError(
296 "zstd IPC compression requires the zstd feature".to_string(),
297 ))
298}
299
/// Decompresses `input` with the context's zstd decompressor, passing
/// `decompressed_size` as the output capacity.
///
/// # Errors
///
/// Returns the decompressor's error, converted into `ArrowError`.
#[cfg(feature = "zstd")]
fn decompress_zstd(
    input: &[u8],
    decompressed_size: usize,
    context: &mut DecompressionContext,
) -> Result<Vec<u8>, ArrowError> {
    let decompressor = context.zstd_decompressor();
    decompressor
        .decompress(input, decompressed_size)
        .map_err(ArrowError::from)
}
311
312#[cfg(not(feature = "zstd"))]
313#[allow(clippy::ptr_arg)]
314fn decompress_zstd(
315 _input: &[u8],
316 _decompressed_size: usize,
317 _context: &mut DecompressionContext,
318) -> Result<Vec<u8>, ArrowError> {
319 Err(ArrowError::InvalidArgumentError(
320 "zstd IPC decompression requires the zstd feature".to_string(),
321 ))
322}
323
324#[inline]
331fn read_uncompressed_size(buffer: &[u8]) -> Result<i64, ArrowError> {
332 let len_buffer = buffer.get(..LENGTH_OF_PREFIX_DATA as usize).ok_or_else(|| {
333 ArrowError::IpcError(format!(
334 "Compressed IPC buffer is too short: expected at least {LENGTH_OF_PREFIX_DATA} bytes, got {}",
335 buffer.len()
336 ))
337 })?;
338 Ok(i64::from_le_bytes(len_buffer.try_into().unwrap()))
339}
340
#[cfg(test)]
mod tests {
    /// Compresses `payload` with `codec`, decompresses the result using fresh
    /// default contexts, and returns the recovered bytes.
    #[cfg(any(feature = "lz4", feature = "zstd"))]
    fn round_trip(codec: super::CompressionCodec, payload: &[u8]) -> Vec<u8> {
        let mut compressed: Vec<u8> = Vec::new();
        codec
            .compress(payload, &mut compressed, &mut Default::default())
            .unwrap();
        codec
            .decompress(
                compressed.as_slice(),
                payload.len(),
                &mut Default::default(),
            )
            .unwrap()
    }

    /// LZ4 frame compression followed by decompression returns the input.
    #[test]
    #[cfg(feature = "lz4")]
    fn test_lz4_compression() {
        let payload: &[u8] = b"hello lz4";
        let recovered = round_trip(super::CompressionCodec::Lz4Frame, payload);
        assert_eq!(payload, recovered.as_slice());
    }

    /// Zstd compression followed by decompression returns the input.
    #[test]
    #[cfg(feature = "zstd")]
    fn test_zstd_compression() {
        let payload: &[u8] = b"hello zstd";
        let recovered = round_trip(super::CompressionCodec::Zstd, payload);
        assert_eq!(payload, recovered.as_slice());
    }

    /// A buffer shorter than the 8-byte prefix is reported as an error
    /// instead of panicking.
    #[test]
    fn test_read_uncompressed_size_rejects_short_prefix() {
        let truncated_prefix = [1u8, 2, 3, 4, 5, 6, 7];
        let message = super::read_uncompressed_size(&truncated_prefix)
            .expect_err("short compressed IPC prefix should return an error")
            .to_string();

        assert!(
            message.contains("Compressed IPC buffer is too short"),
            "unexpected error: {message}"
        );
    }
}