1use crate::CompressionType;
19use arrow_buffer::Buffer;
20use arrow_schema::ArrowError;
21
/// Length-prefix value that marks a body stored as raw, uncompressed bytes.
///
/// `compress_to_vec` writes this marker instead of the real length when it
/// decides to keep the input uncompressed, and `decompress_to_buffer` returns
/// everything after the prefix unchanged when it reads it.
const LENGTH_NO_COMPRESSED_DATA: i64 = -1;
/// Number of bytes occupied by the little-endian `i64` length prefix that
/// precedes the body of every non-empty encoded buffer.
const LENGTH_OF_PREFIX_DATA: i64 = 8;
24
/// Compression codecs this module can apply to buffer bodies.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionCodec {
    /// Chosen for `CompressionType::LZ4_FRAME`; the real implementation is
    /// only compiled in when the `lz4` cargo feature is enabled.
    Lz4Frame,
    /// Chosen for `CompressionType::ZSTD`; the real implementation is only
    /// compiled in when the `zstd` cargo feature is enabled.
    Zstd,
}
31
32impl TryFrom<CompressionType> for CompressionCodec {
33 type Error = ArrowError;
34
35 fn try_from(compression_type: CompressionType) -> Result<Self, ArrowError> {
36 match compression_type {
37 CompressionType::ZSTD => Ok(CompressionCodec::Zstd),
38 CompressionType::LZ4_FRAME => Ok(CompressionCodec::Lz4Frame),
39 other_type => Err(ArrowError::NotYetImplemented(format!(
40 "compression type {other_type:?} not supported "
41 ))),
42 }
43 }
44}
45
46impl CompressionCodec {
47 pub(crate) fn compress_to_vec(
58 &self,
59 input: &[u8],
60 output: &mut Vec<u8>,
61 ) -> Result<usize, ArrowError> {
62 let uncompressed_data_len = input.len();
63 let original_output_len = output.len();
64
65 if input.is_empty() {
66 } else {
68 output.extend_from_slice(&uncompressed_data_len.to_le_bytes());
70 self.compress(input, output)?;
71
72 let compression_len = output.len() - original_output_len;
73 if compression_len > uncompressed_data_len {
74 output.truncate(original_output_len);
79 output.extend_from_slice(&LENGTH_NO_COMPRESSED_DATA.to_le_bytes());
80 output.extend_from_slice(input);
81 }
82 }
83 Ok(output.len() - original_output_len)
84 }
85
86 pub(crate) fn decompress_to_buffer(&self, input: &Buffer) -> Result<Buffer, ArrowError> {
94 let decompressed_length = read_uncompressed_size(input);
97 let buffer = if decompressed_length == 0 {
98 Buffer::from([])
100 } else if decompressed_length == LENGTH_NO_COMPRESSED_DATA {
101 input.slice(LENGTH_OF_PREFIX_DATA as usize)
103 } else if let Ok(decompressed_length) = usize::try_from(decompressed_length) {
104 let input_data = &input[(LENGTH_OF_PREFIX_DATA as usize)..];
106 let v = self.decompress(input_data, decompressed_length as _)?;
107 Buffer::from_vec(v)
108 } else {
109 return Err(ArrowError::IpcError(format!(
110 "Invalid uncompressed length: {decompressed_length}"
111 )));
112 };
113 Ok(buffer)
114 }
115
116 fn compress(&self, input: &[u8], output: &mut Vec<u8>) -> Result<(), ArrowError> {
119 match self {
120 CompressionCodec::Lz4Frame => compress_lz4(input, output),
121 CompressionCodec::Zstd => compress_zstd(input, output),
122 }
123 }
124
125 fn decompress(&self, input: &[u8], decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
128 let ret = match self {
129 CompressionCodec::Lz4Frame => decompress_lz4(input, decompressed_size)?,
130 CompressionCodec::Zstd => decompress_zstd(input, decompressed_size)?,
131 };
132 if ret.len() != decompressed_size {
133 return Err(ArrowError::IpcError(format!(
134 "Expected compressed length of {decompressed_size} got {}",
135 ret.len()
136 )));
137 }
138 Ok(ret)
139 }
140}
141
/// Appends `input`, encoded as a single LZ4 frame, to `output`.
///
/// # Errors
///
/// I/O errors from writing are converted with `?`; a failure while finishing
/// the frame is wrapped in `ArrowError::ExternalError`.
#[cfg(feature = "lz4")]
fn compress_lz4(input: &[u8], output: &mut Vec<u8>) -> Result<(), ArrowError> {
    use std::io::Write;

    let mut frame_writer = lz4_flex::frame::FrameEncoder::new(output);
    frame_writer.write_all(input)?;
    match frame_writer.finish() {
        Ok(_) => Ok(()),
        Err(err) => Err(ArrowError::ExternalError(Box::new(err))),
    }
}
152
153#[cfg(not(feature = "lz4"))]
154#[allow(clippy::ptr_arg)]
155fn compress_lz4(_input: &[u8], _output: &mut Vec<u8>) -> Result<(), ArrowError> {
156 Err(ArrowError::InvalidArgumentError(
157 "lz4 IPC compression requires the lz4 feature".to_string(),
158 ))
159}
160
/// Decodes the LZ4 frame in `input` into a freshly allocated vector.
///
/// `decompressed_size` is only used to pre-size the output; the caller is
/// responsible for checking the length of the returned data.
///
/// # Errors
///
/// Returns `ArrowError::IpcError` if `decompressed_size` bytes cannot be
/// reserved, and converts decoder I/O errors with `?`.
#[cfg(feature = "lz4")]
fn decompress_lz4(input: &[u8], decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
    use std::io::Read;

    // `decompressed_size` comes from the buffer's length prefix, i.e. from the
    // data being read. Reserve fallibly so an absurd value becomes an error
    // instead of an allocation failure that takes the process down.
    let mut output: Vec<u8> = Vec::new();
    output.try_reserve_exact(decompressed_size).map_err(|e| {
        ArrowError::IpcError(format!(
            "Unable to reserve {decompressed_size} bytes for decompressed data: {e}"
        ))
    })?;
    lz4_flex::frame::FrameDecoder::new(input).read_to_end(&mut output)?;
    Ok(output)
}
168
169#[cfg(not(feature = "lz4"))]
170#[allow(clippy::ptr_arg)]
171fn decompress_lz4(_input: &[u8], _decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
172 Err(ArrowError::InvalidArgumentError(
173 "lz4 IPC decompression requires the lz4 feature".to_string(),
174 ))
175}
176
/// Appends the zstd-compressed form of `input` to `output`.
///
/// # Errors
///
/// Errors from creating, writing to, or finishing the encoder are converted
/// with `?`.
#[cfg(feature = "zstd")]
fn compress_zstd(input: &[u8], output: &mut Vec<u8>) -> Result<(), ArrowError> {
    use std::io::Write;

    // NOTE(review): level 0 is presumably the zstd crate's "use the default
    // level" value — confirm against the crate documentation.
    const LEVEL: i32 = 0;

    let mut stream = zstd::Encoder::new(output, LEVEL)?;
    stream.write_all(input)?;
    let _writer = stream.finish()?;
    Ok(())
}
185
186#[cfg(not(feature = "zstd"))]
187#[allow(clippy::ptr_arg)]
188fn compress_zstd(_input: &[u8], _output: &mut Vec<u8>) -> Result<(), ArrowError> {
189 Err(ArrowError::InvalidArgumentError(
190 "zstd IPC compression requires the zstd feature".to_string(),
191 ))
192}
193
/// Decodes the zstd stream in `input` into a freshly allocated vector.
///
/// `decompressed_size` is only used to pre-size the output; the caller is
/// responsible for checking the length of the returned data.
///
/// # Errors
///
/// Returns `ArrowError::IpcError` if `decompressed_size` bytes cannot be
/// reserved, and converts decoder errors with `?`.
#[cfg(feature = "zstd")]
fn decompress_zstd(input: &[u8], decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
    use std::io::Read;

    // `decompressed_size` comes from the buffer's length prefix, i.e. from the
    // data being read. Reserve fallibly so an absurd value becomes an error
    // instead of an allocation failure that takes the process down.
    let mut output: Vec<u8> = Vec::new();
    output.try_reserve_exact(decompressed_size).map_err(|e| {
        ArrowError::IpcError(format!(
            "Unable to reserve {decompressed_size} bytes for decompressed data: {e}"
        ))
    })?;
    zstd::Decoder::with_buffer(input)?.read_to_end(&mut output)?;
    Ok(output)
}
201
202#[cfg(not(feature = "zstd"))]
203#[allow(clippy::ptr_arg)]
204fn decompress_zstd(_input: &[u8], _decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
205 Err(ArrowError::InvalidArgumentError(
206 "zstd IPC decompression requires the zstd feature".to_string(),
207 ))
208}
209
/// Reads the length prefix: the first 8 bytes of `buffer` interpreted as a
/// little-endian `i64`. Bytes after the prefix are ignored.
///
/// # Panics
///
/// Panics if `buffer` holds fewer than 8 bytes.
#[inline]
fn read_uncompressed_size(buffer: &[u8]) -> i64 {
    let mut prefix = [0u8; 8];
    prefix.copy_from_slice(&buffer[..8]);
    i64::from_le_bytes(prefix)
}
221
#[cfg(test)]
mod tests {
    /// Compresses `input` with `codec`, decompresses the result and checks that
    /// the original bytes come back.
    ///
    /// Gated on the codec features so it is not dead code when neither test
    /// below is compiled.
    #[cfg(any(feature = "lz4", feature = "zstd"))]
    fn assert_round_trip(codec: super::CompressionCodec, input: &[u8]) {
        let mut compressed: Vec<u8> = Vec::new();
        codec.compress(input, &mut compressed).unwrap();
        let restored = codec.decompress(&compressed, input.len()).unwrap();
        assert_eq!(input, restored.as_slice());
    }

    #[test]
    #[cfg(feature = "lz4")]
    fn test_lz4_compression() {
        assert_round_trip(super::CompressionCodec::Lz4Frame, b"hello lz4");
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn test_zstd_compression() {
        assert_round_trip(super::CompressionCodec::Zstd, b"hello zstd");
    }
}