1use crate::CompressionType;
19use arrow_buffer::Buffer;
20use arrow_schema::ArrowError;
21
/// Length-prefix value written in place of the uncompressed length when the
/// payload that follows the prefix is stored verbatim (not compressed).
const LENGTH_NO_COMPRESSED_DATA: i64 = -1;
/// Size in bytes of the little-endian `i64` length prefix that precedes the
/// payload of every non-empty buffer handled in this file.
const LENGTH_OF_PREFIX_DATA: i64 = 8;
24
/// Reusable state passed to the compression routines in this file.
///
/// With the `zstd` feature enabled it owns a `zstd::bulk::Compressor`;
/// without that feature the struct has no fields.
pub struct CompressionContext {
    /// Bulk zstd compressor used by `compress_zstd`.
    #[cfg(feature = "zstd")]
    compressor: zstd::bulk::Compressor<'static>,
}
34
35#[allow(clippy::derivable_impls)]
38impl Default for CompressionContext {
39 fn default() -> Self {
40 CompressionContext {
41 #[cfg(feature = "zstd")]
43 compressor: zstd::bulk::Compressor::new(zstd::DEFAULT_COMPRESSION_LEVEL)
44 .expect("can use default compression level"),
45 }
46 }
47}
48
49impl std::fmt::Debug for CompressionContext {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 let mut ds = f.debug_struct("CompressionContext");
52
53 #[cfg(feature = "zstd")]
54 ds.field("compressor", &"zstd::bulk::Compressor");
55
56 ds.finish()
57 }
58}
59
/// Compression codecs that can be applied to buffers in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionCodec {
    /// LZ4 frame format; dispatched to `compress_lz4` / `decompress_lz4`.
    Lz4Frame,
    /// Zstandard; dispatched to `compress_zstd` / `decompress_zstd`.
    Zstd,
}
66
67impl TryFrom<CompressionType> for CompressionCodec {
68 type Error = ArrowError;
69
70 fn try_from(compression_type: CompressionType) -> Result<Self, ArrowError> {
71 match compression_type {
72 CompressionType::ZSTD => Ok(CompressionCodec::Zstd),
73 CompressionType::LZ4_FRAME => Ok(CompressionCodec::Lz4Frame),
74 other_type => Err(ArrowError::NotYetImplemented(format!(
75 "compression type {other_type:?} not supported "
76 ))),
77 }
78 }
79}
80
81impl CompressionCodec {
82 pub(crate) fn compress_to_vec(
93 &self,
94 input: &[u8],
95 output: &mut Vec<u8>,
96 context: &mut CompressionContext,
97 ) -> Result<usize, ArrowError> {
98 let uncompressed_data_len = input.len();
99 let original_output_len = output.len();
100
101 if input.is_empty() {
102 } else {
104 output.extend_from_slice(&uncompressed_data_len.to_le_bytes());
106 self.compress(input, output, context)?;
107
108 let compression_len = output.len() - original_output_len;
109 if compression_len > uncompressed_data_len {
110 output.truncate(original_output_len);
115 output.extend_from_slice(&LENGTH_NO_COMPRESSED_DATA.to_le_bytes());
116 output.extend_from_slice(input);
117 }
118 }
119 Ok(output.len() - original_output_len)
120 }
121
122 pub(crate) fn decompress_to_buffer(&self, input: &Buffer) -> Result<Buffer, ArrowError> {
130 let decompressed_length = read_uncompressed_size(input);
133 let buffer = if decompressed_length == 0 {
134 Buffer::from([])
136 } else if decompressed_length == LENGTH_NO_COMPRESSED_DATA {
137 input.slice(LENGTH_OF_PREFIX_DATA as usize)
139 } else if let Ok(decompressed_length) = usize::try_from(decompressed_length) {
140 let input_data = &input[(LENGTH_OF_PREFIX_DATA as usize)..];
142 let v = self.decompress(input_data, decompressed_length as _)?;
143 Buffer::from_vec(v)
144 } else {
145 return Err(ArrowError::IpcError(format!(
146 "Invalid uncompressed length: {decompressed_length}"
147 )));
148 };
149 Ok(buffer)
150 }
151
152 fn compress(
155 &self,
156 input: &[u8],
157 output: &mut Vec<u8>,
158 context: &mut CompressionContext,
159 ) -> Result<(), ArrowError> {
160 match self {
161 CompressionCodec::Lz4Frame => compress_lz4(input, output),
162 CompressionCodec::Zstd => compress_zstd(input, output, context),
163 }
164 }
165
166 fn decompress(&self, input: &[u8], decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
169 let ret = match self {
170 CompressionCodec::Lz4Frame => decompress_lz4(input, decompressed_size)?,
171 CompressionCodec::Zstd => decompress_zstd(input, decompressed_size)?,
172 };
173 if ret.len() != decompressed_size {
174 return Err(ArrowError::IpcError(format!(
175 "Expected compressed length of {decompressed_size} got {}",
176 ret.len()
177 )));
178 }
179 Ok(ret)
180 }
181}
182
/// Appends `input`, encoded as an LZ4 frame, to `output`.
///
/// # Errors
/// Propagates I/O errors from writing into the encoder and wraps an error
/// from finishing the frame in `ArrowError::ExternalError`.
#[cfg(feature = "lz4")]
fn compress_lz4(input: &[u8], output: &mut Vec<u8>) -> Result<(), ArrowError> {
    use std::io::Write;

    let mut frame_writer = lz4_flex::frame::FrameEncoder::new(output);
    frame_writer.write_all(input)?;

    // `finish` hands the underlying writer back; it is not needed here.
    match frame_writer.finish() {
        Ok(_) => Ok(()),
        Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
    }
}
193
194#[cfg(not(feature = "lz4"))]
195#[allow(clippy::ptr_arg)]
196fn compress_lz4(_input: &[u8], _output: &mut Vec<u8>) -> Result<(), ArrowError> {
197 Err(ArrowError::InvalidArgumentError(
198 "lz4 IPC compression requires the lz4 feature".to_string(),
199 ))
200}
201
/// Decodes an LZ4 frame from `input` into a new vector.
///
/// `decompressed_size` is used only to pre-size the output; the caller is
/// responsible for checking the produced length.
///
/// # Errors
/// Returns `ArrowError::IpcError` if `decompressed_size` bytes cannot be
/// reserved, and propagates I/O errors from the decoder.
#[cfg(feature = "lz4")]
fn decompress_lz4(input: &[u8], decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
    use std::io::Read;

    // `decompressed_size` comes from the buffer's length prefix and may be
    // arbitrarily large for malformed input. Reserve fallibly so a bogus
    // prefix produces an error instead of an allocation panic/abort.
    let mut output: Vec<u8> = Vec::new();
    output.try_reserve_exact(decompressed_size).map_err(|e| {
        ArrowError::IpcError(format!(
            "Unable to allocate {decompressed_size} bytes for lz4 decompression: {e}"
        ))
    })?;

    lz4_flex::frame::FrameDecoder::new(input).read_to_end(&mut output)?;
    Ok(output)
}
209
210#[cfg(not(feature = "lz4"))]
211#[allow(clippy::ptr_arg)]
212fn decompress_lz4(_input: &[u8], _decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
213 Err(ArrowError::InvalidArgumentError(
214 "lz4 IPC decompression requires the lz4 feature".to_string(),
215 ))
216}
217
/// Compresses `input` with the zstd compressor held by `context` and appends
/// the result to `output`.
///
/// # Errors
/// Propagates the error returned by the compressor.
#[cfg(feature = "zstd")]
fn compress_zstd(
    input: &[u8],
    output: &mut Vec<u8>,
    context: &mut CompressionContext,
) -> Result<(), ArrowError> {
    let mut compressed = context.compressor.compress(input)?;
    output.append(&mut compressed);
    Ok(())
}
228
229#[cfg(not(feature = "zstd"))]
230#[allow(clippy::ptr_arg)]
231fn compress_zstd(
232 _input: &[u8],
233 _output: &mut Vec<u8>,
234 _context: &mut CompressionContext,
235) -> Result<(), ArrowError> {
236 Err(ArrowError::InvalidArgumentError(
237 "zstd IPC compression requires the zstd feature".to_string(),
238 ))
239}
240
/// Decodes a zstd stream from `input` into a new vector.
///
/// `decompressed_size` is used only to pre-size the output; the caller is
/// responsible for checking the produced length.
///
/// # Errors
/// Returns `ArrowError::IpcError` if `decompressed_size` bytes cannot be
/// reserved, and propagates I/O errors from creating or reading the decoder.
#[cfg(feature = "zstd")]
fn decompress_zstd(input: &[u8], decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
    use std::io::Read;

    // `decompressed_size` comes from the buffer's length prefix and may be
    // arbitrarily large for malformed input. Reserve fallibly so a bogus
    // prefix produces an error instead of an allocation panic/abort.
    let mut output: Vec<u8> = Vec::new();
    output.try_reserve_exact(decompressed_size).map_err(|e| {
        ArrowError::IpcError(format!(
            "Unable to allocate {decompressed_size} bytes for zstd decompression: {e}"
        ))
    })?;

    zstd::Decoder::with_buffer(input)?.read_to_end(&mut output)?;
    Ok(output)
}
248
249#[cfg(not(feature = "zstd"))]
250#[allow(clippy::ptr_arg)]
251fn decompress_zstd(_input: &[u8], _decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
252 Err(ArrowError::InvalidArgumentError(
253 "zstd IPC decompression requires the zstd feature".to_string(),
254 ))
255}
256
/// Reads the first 8 bytes of `buffer` as a little-endian `i64`.
///
/// # Panics
/// Panics if `buffer` holds fewer than 8 bytes.
#[inline]
fn read_uncompressed_size(buffer: &[u8]) -> i64 {
    let mut prefix = [0u8; 8];
    prefix.copy_from_slice(&buffer[..8]);
    i64::from_le_bytes(prefix)
}
268
#[cfg(test)]
mod tests {
    // Items are referenced through `super::` rather than imported so that
    // builds with neither codec feature enabled have no unused imports.

    /// Round-trips a short payload through the raw LZ4 frame codec.
    #[test]
    #[cfg(feature = "lz4")]
    fn test_lz4_compression() {
        let codec = super::CompressionCodec::Lz4Frame;
        let payload = b"hello lz4";

        let mut compressed = Vec::<u8>::new();
        let mut context = super::CompressionContext::default();
        codec
            .compress(payload, &mut compressed, &mut context)
            .unwrap();

        let restored = codec.decompress(&compressed, payload.len()).unwrap();
        assert_eq!(payload, restored.as_slice());
    }

    /// Round-trips a short payload through the raw zstd codec.
    #[test]
    #[cfg(feature = "zstd")]
    fn test_zstd_compression() {
        let codec = super::CompressionCodec::Zstd;
        let payload = b"hello zstd";

        let mut compressed = Vec::<u8>::new();
        let mut context = super::CompressionContext::default();
        codec
            .compress(payload, &mut compressed, &mut context)
            .unwrap();

        let restored = codec.decompress(&compressed, payload.len()).unwrap();
        assert_eq!(payload, restored.as_slice());
    }
}