Skip to main content

arrow_ipc/
compression.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::CompressionType;
19use arrow_buffer::Buffer;
20use arrow_schema::ArrowError;
21
/// Length-prefix sentinel marking a buffer whose bytes after the prefix are stored uncompressed.
const LENGTH_NO_COMPRESSED_DATA: i64 = -1;
/// Size in bytes of the little-endian `i64` length prefix in front of a compressed buffer.
const LENGTH_OF_PREFIX_DATA: i64 = std::mem::size_of::<i64>() as i64;
24
/// Reusable state for compressing IPC buffers.
///
/// With the `zstd` feature enabled this holds a zstd compression context that is created on first
/// use and then reused by later compression calls, avoiding the overhead of initialising a fresh
/// context for every compression.
#[derive(Default)]
pub struct CompressionContext {
    /// zstd compressor, created lazily by `zstd_compressor`.
    #[cfg(feature = "zstd")]
    compressor: Option<zstd::bulk::Compressor<'static>>,
}

impl CompressionContext {
    /// Returns the cached zstd compressor, building one at zstd's default compression level the
    /// first time it is requested.
    ///
    /// # Panics
    ///
    /// Panics if zstd cannot create a compressor at the default compression level.
    #[cfg(feature = "zstd")]
    fn zstd_compressor(&mut self) -> &mut zstd::bulk::Compressor<'static> {
        let build = || {
            zstd::bulk::Compressor::new(zstd::DEFAULT_COMPRESSION_LEVEL)
                .expect("can use default compression level")
        };
        self.compressor.get_or_insert_with(build)
    }
}

impl std::fmt::Debug for CompressionContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("CompressionContext");

        // Report only whether a compressor has been created, by its type name.
        #[cfg(feature = "zstd")]
        {
            let compressor = self.compressor.as_ref().map(|_| "zstd::bulk::Compressor");
            builder.field("compressor", &compressor);
        }

        builder.finish()
    }
}
59
/// Additional context that may be needed for decompression.
///
/// In the case of zstd, this will contain the zstd decompression context, which can be reused
/// between subsequent decompression calls to avoid the performance overhead of initialising a new
/// context for every decompression.
///
/// A new context is empty; the zstd decompressor is only created the first time it is needed.
// `Default` is derived (matching `CompressionContext`): the feature-gated field defaults to `None`.
#[derive(Default)]
pub struct DecompressionContext {
    /// zstd decompressor, created lazily by `zstd_decompressor`.
    #[cfg(feature = "zstd")]
    decompressor: Option<zstd::bulk::Decompressor<'static>>,
}

impl DecompressionContext {
    /// Creates an empty context; equivalent to [`Default::default`].
    pub(crate) fn new() -> Self {
        Self::default()
    }

    /// Returns the cached zstd decompressor, creating it on first call.
    ///
    /// # Panics
    ///
    /// Panics if zstd fails to create a decompressor.
    #[cfg(feature = "zstd")]
    fn zstd_decompressor(&mut self) -> &mut zstd::bulk::Decompressor<'static> {
        self.decompressor.get_or_insert_with(|| {
            zstd::bulk::Decompressor::new().expect("can create zstd decompressor")
        })
    }
}

impl std::fmt::Debug for DecompressionContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut ds = f.debug_struct("DecompressionContext");

        // Report only whether a decompressor has been created, by its type name.
        #[cfg(feature = "zstd")]
        ds.field(
            "decompressor",
            &self
                .decompressor
                .as_ref()
                .map(|_| "zstd::bulk::Decompressor"),
        );

        ds.finish()
    }
}
109
/// A compression algorithm that can be applied to the buffers of an IPC stream.
///
/// Using a codec requires its matching cargo feature (`lz4` / `zstd`); without it, compression
/// and decompression return an error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionCodec {
    /// LZ4 frame format (`CompressionType::LZ4_FRAME`).
    Lz4Frame,
    /// Zstandard (`CompressionType::ZSTD`).
    Zstd,
}
116
117impl TryFrom<CompressionType> for CompressionCodec {
118    type Error = ArrowError;
119
120    fn try_from(compression_type: CompressionType) -> Result<Self, ArrowError> {
121        match compression_type {
122            CompressionType::ZSTD => Ok(CompressionCodec::Zstd),
123            CompressionType::LZ4_FRAME => Ok(CompressionCodec::Lz4Frame),
124            other_type => Err(ArrowError::NotYetImplemented(format!(
125                "compression type {other_type:?} not supported "
126            ))),
127        }
128    }
129}
130
131impl CompressionCodec {
132    /// Compresses the data in `input` to `output` and appends the
133    /// data using the specified compression mechanism.
134    ///
135    /// returns the number of bytes written to the stream
136    ///
137    /// Writes this format to output:
138    /// ```text
139    /// [8 bytes]:         uncompressed length
140    /// [remaining bytes]: compressed data stream
141    /// ```
142    pub(crate) fn compress_to_vec(
143        &self,
144        input: &[u8],
145        output: &mut Vec<u8>,
146        context: &mut CompressionContext,
147    ) -> Result<usize, ArrowError> {
148        let uncompressed_data_len = input.len();
149        let original_output_len = output.len();
150
151        if input.is_empty() {
152            // empty input, nothing to do
153        } else {
154            // write compressed data directly into the output buffer
155            output.extend_from_slice(&uncompressed_data_len.to_le_bytes());
156            self.compress(input, output, context)?;
157
158            let compression_len = output.len() - original_output_len;
159            if compression_len > uncompressed_data_len {
160                // length of compressed data was larger than
161                // uncompressed data, use the uncompressed data with
162                // length -1 to indicate that we don't compress the
163                // data
164                output.truncate(original_output_len);
165                output.extend_from_slice(&LENGTH_NO_COMPRESSED_DATA.to_le_bytes());
166                output.extend_from_slice(input);
167            }
168        }
169        Ok(output.len() - original_output_len)
170    }
171
172    /// Decompresses the input into a [`Buffer`]
173    ///
174    /// The input should look like:
175    /// ```text
176    /// [8 bytes]:         uncompressed length
177    /// [remaining bytes]: compressed data stream
178    /// ```
179    pub(crate) fn decompress_to_buffer(
180        &self,
181        input: &Buffer,
182        context: &mut DecompressionContext,
183    ) -> Result<Buffer, ArrowError> {
184        // read the first 8 bytes to determine if the data is
185        // compressed
186        let decompressed_length = read_uncompressed_size(input)?;
187        let buffer = if decompressed_length == 0 {
188            // empty
189            Buffer::from([])
190        } else if decompressed_length == LENGTH_NO_COMPRESSED_DATA {
191            // no compression
192            input.slice(LENGTH_OF_PREFIX_DATA as usize)
193        } else if let Ok(decompressed_length) = usize::try_from(decompressed_length) {
194            // decompress data using the codec
195            let input_data = &input[(LENGTH_OF_PREFIX_DATA as usize)..];
196            let v = self.decompress(input_data, decompressed_length as _, context)?;
197            Buffer::from_vec(v)
198        } else {
199            return Err(ArrowError::IpcError(format!(
200                "Invalid uncompressed length: {decompressed_length}"
201            )));
202        };
203        Ok(buffer)
204    }
205
206    /// Compress the data in input buffer and write to output buffer
207    /// using the specified compression
208    fn compress(
209        &self,
210        input: &[u8],
211        output: &mut Vec<u8>,
212        context: &mut CompressionContext,
213    ) -> Result<(), ArrowError> {
214        match self {
215            CompressionCodec::Lz4Frame => compress_lz4(input, output),
216            CompressionCodec::Zstd => compress_zstd(input, output, context),
217        }
218    }
219
220    /// Decompress the data in input buffer and write to output buffer
221    /// using the specified compression
222    fn decompress(
223        &self,
224        input: &[u8],
225        decompressed_size: usize,
226        context: &mut DecompressionContext,
227    ) -> Result<Vec<u8>, ArrowError> {
228        let ret = match self {
229            CompressionCodec::Lz4Frame => decompress_lz4(input, decompressed_size)?,
230            CompressionCodec::Zstd => decompress_zstd(input, decompressed_size, context)?,
231        };
232        if ret.len() != decompressed_size {
233            return Err(ArrowError::IpcError(format!(
234                "Expected compressed length of {decompressed_size} got {}",
235                ret.len()
236            )));
237        }
238        Ok(ret)
239    }
240}
241
/// Appends `input`, encoded as an LZ4 frame, to `output`.
#[cfg(feature = "lz4")]
fn compress_lz4(input: &[u8], output: &mut Vec<u8>) -> Result<(), ArrowError> {
    use std::io::Write;

    let mut frame = lz4_flex::frame::FrameEncoder::new(output);
    frame.write_all(input)?;
    // The value returned by a successful `finish` is not needed here.
    match frame.finish() {
        Ok(_) => Ok(()),
        Err(err) => Err(ArrowError::ExternalError(Box::new(err))),
    }
}
252
253#[cfg(not(feature = "lz4"))]
254#[allow(clippy::ptr_arg)]
255fn compress_lz4(_input: &[u8], _output: &mut Vec<u8>) -> Result<(), ArrowError> {
256    Err(ArrowError::InvalidArgumentError(
257        "lz4 IPC compression requires the lz4 feature".to_string(),
258    ))
259}
260
/// Decodes the LZ4 frame(s) in `input`, expecting `decompressed_size` bytes of output.
///
/// `decompressed_size` comes from the IPC length prefix, i.e. from the input
/// data itself, so it is not trusted:
/// * the output is reserved fallibly, turning an absurd size into an error
///   instead of a capacity-overflow panic or allocation abort, and
/// * decoding stops after `decompressed_size + 1` bytes. That is enough for the
///   caller's exact-length check to reject oversized data, without first
///   materialising an arbitrarily large decompressed payload.
///
/// # Errors
///
/// Returns `ArrowError::MemoryError` if the output cannot be reserved, or the
/// I/O error raised by the LZ4 decoder for malformed input.
#[cfg(feature = "lz4")]
fn decompress_lz4(input: &[u8], decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
    use std::io::Read;

    let mut output = Vec::new();
    output.try_reserve_exact(decompressed_size).map_err(|e| {
        ArrowError::MemoryError(format!(
            "Failed to allocate {decompressed_size} bytes for lz4 decompression: {e}"
        ))
    })?;

    // usize -> u64 is lossless on all supported targets; +1 lets overflow be detected.
    let read_limit = (decompressed_size as u64).saturating_add(1);
    lz4_flex::frame::FrameDecoder::new(input)
        .take(read_limit)
        .read_to_end(&mut output)?;
    Ok(output)
}
268
269#[cfg(not(feature = "lz4"))]
270#[allow(clippy::ptr_arg)]
271fn decompress_lz4(_input: &[u8], _decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
272    Err(ArrowError::InvalidArgumentError(
273        "lz4 IPC decompression requires the lz4 feature".to_string(),
274    ))
275}
276
/// Compresses `input` with the context's reusable zstd compressor and appends the compressed
/// bytes to `output`.
#[cfg(feature = "zstd")]
fn compress_zstd(
    input: &[u8],
    output: &mut Vec<u8>,
    context: &mut CompressionContext,
) -> Result<(), ArrowError> {
    let compressor = context.zstd_compressor();
    output.extend_from_slice(&compressor.compress(input)?);
    Ok(())
}
287
288#[cfg(not(feature = "zstd"))]
289#[allow(clippy::ptr_arg)]
290fn compress_zstd(
291    _input: &[u8],
292    _output: &mut Vec<u8>,
293    _context: &mut CompressionContext,
294) -> Result<(), ArrowError> {
295    Err(ArrowError::InvalidArgumentError(
296        "zstd IPC compression requires the zstd feature".to_string(),
297    ))
298}
299
/// Decompresses the zstd data in `input` with the context's reusable decompressor.
///
/// `decompressed_size` is passed to zstd's `decompress` as the output capacity.
#[cfg(feature = "zstd")]
fn decompress_zstd(
    input: &[u8],
    decompressed_size: usize,
    context: &mut DecompressionContext,
) -> Result<Vec<u8>, ArrowError> {
    let decompressor = context.zstd_decompressor();
    decompressor
        .decompress(input, decompressed_size)
        .map_err(ArrowError::from)
}
311
312#[cfg(not(feature = "zstd"))]
313#[allow(clippy::ptr_arg)]
314fn decompress_zstd(
315    _input: &[u8],
316    _decompressed_size: usize,
317    _context: &mut DecompressionContext,
318) -> Result<Vec<u8>, ArrowError> {
319    Err(ArrowError::InvalidArgumentError(
320        "zstd IPC decompression requires the zstd feature".to_string(),
321    ))
322}
323
324/// Get the uncompressed length
325/// Notes:
326///   LENGTH_NO_COMPRESSED_DATA: indicate that the data that follows is not compressed
327///    0: indicate that there is no data
328///   positive number: indicate the uncompressed length for the following data
329/// Returns an error if the input buffer is shorter than 8 bytes
330#[inline]
331fn read_uncompressed_size(buffer: &[u8]) -> Result<i64, ArrowError> {
332    let len_buffer = buffer.get(..LENGTH_OF_PREFIX_DATA as usize).ok_or_else(|| {
333        ArrowError::IpcError(format!(
334            "Compressed IPC buffer is too short: expected at least {LENGTH_OF_PREFIX_DATA} bytes, got {}",
335            buffer.len()
336        ))
337    })?;
338    Ok(i64::from_le_bytes(len_buffer.try_into().unwrap()))
339}
340
#[cfg(test)]
mod tests {
    /// Round-trips `input` through `compress_to_vec` and `decompress_to_buffer`,
    /// asserting the data survives, and returns the framed bytes that were written.
    #[cfg(any(feature = "lz4", feature = "zstd"))]
    fn assert_round_trip(codec: super::CompressionCodec, input: &[u8]) -> Vec<u8> {
        let mut encoded = Vec::new();
        let written = codec
            .compress_to_vec(input, &mut encoded, &mut Default::default())
            .unwrap();
        assert_eq!(written, encoded.len());

        let decoded = codec
            .decompress_to_buffer(
                &arrow_buffer::Buffer::from_vec(encoded.clone()),
                &mut super::DecompressionContext::new(),
            )
            .unwrap();
        assert_eq!(decoded.as_slice(), input);
        encoded
    }

    #[test]
    #[cfg(feature = "lz4")]
    fn test_lz4_compression() {
        let input_bytes = b"hello lz4";
        let codec = super::CompressionCodec::Lz4Frame;
        let mut output_bytes: Vec<u8> = Vec::new();
        codec
            .compress(input_bytes, &mut output_bytes, &mut Default::default())
            .unwrap();
        let result = codec
            .decompress(
                output_bytes.as_slice(),
                input_bytes.len(),
                &mut Default::default(),
            )
            .unwrap();
        assert_eq!(input_bytes, result.as_slice());
    }

    #[test]
    #[cfg(feature = "lz4")]
    fn test_lz4_decompress_rejects_length_mismatch() {
        let codec = super::CompressionCodec::Lz4Frame;
        let mut compressed: Vec<u8> = Vec::new();
        codec
            .compress(b"hello lz4", &mut compressed, &mut Default::default())
            .unwrap();
        // the frame holds 9 bytes; claiming fewer or more must be rejected
        assert!(codec
            .decompress(&compressed, 8, &mut Default::default())
            .is_err());
        assert!(codec
            .decompress(&compressed, 10, &mut Default::default())
            .is_err());
    }

    #[test]
    #[cfg(feature = "lz4")]
    fn test_lz4_compress_to_vec_round_trip() {
        let input = vec![42u8; 4096];
        let encoded = assert_round_trip(super::CompressionCodec::Lz4Frame, &input);
        // highly repetitive data compresses, so the prefix holds the real length
        assert_eq!(&encoded[..8], &(input.len() as i64).to_le_bytes());
        assert!(encoded.len() < input.len());
    }

    #[test]
    #[cfg(feature = "lz4")]
    fn test_lz4_compress_to_vec_falls_back_to_uncompressed() {
        // a tiny input cannot amortise the lz4 frame overhead, so it is stored raw
        let input = b"ab";
        let encoded = assert_round_trip(super::CompressionCodec::Lz4Frame, input);
        assert_eq!(&encoded[..8], &(-1i64).to_le_bytes());
        assert_eq!(&encoded[8..], input);
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn test_zstd_compression() {
        let input_bytes = b"hello zstd";
        let codec = super::CompressionCodec::Zstd;
        let mut output_bytes: Vec<u8> = Vec::new();
        codec
            .compress(input_bytes, &mut output_bytes, &mut Default::default())
            .unwrap();
        let result = codec
            .decompress(
                output_bytes.as_slice(),
                input_bytes.len(),
                &mut Default::default(),
            )
            .unwrap();
        assert_eq!(input_bytes, result.as_slice());
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn test_zstd_compress_to_vec_round_trip() {
        let input = vec![42u8; 4096];
        let encoded = assert_round_trip(super::CompressionCodec::Zstd, &input);
        // highly repetitive data compresses, so the prefix holds the real length
        assert_eq!(&encoded[..8], &(input.len() as i64).to_le_bytes());
        assert!(encoded.len() < input.len());
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn test_zstd_compress_to_vec_falls_back_to_uncompressed() {
        // a tiny input cannot amortise the zstd frame overhead, so it is stored raw
        let input = b"ab";
        let encoded = assert_round_trip(super::CompressionCodec::Zstd, input);
        assert_eq!(&encoded[..8], &(-1i64).to_le_bytes());
        assert_eq!(&encoded[8..], input);
    }

    #[test]
    fn test_decompress_to_buffer_passes_through_uncompressed_data() {
        let mut framed = (-1i64).to_le_bytes().to_vec();
        framed.extend_from_slice(b"raw bytes");
        let decoded = super::CompressionCodec::Lz4Frame
            .decompress_to_buffer(
                &arrow_buffer::Buffer::from_vec(framed),
                &mut super::DecompressionContext::new(),
            )
            .unwrap();
        assert_eq!(decoded.as_slice(), b"raw bytes");
    }

    #[test]
    fn test_decompress_to_buffer_rejects_negative_length() {
        let input = arrow_buffer::Buffer::from_vec((-2i64).to_le_bytes().to_vec());
        let err = super::CompressionCodec::Zstd
            .decompress_to_buffer(&input, &mut super::DecompressionContext::new())
            .expect_err("negative lengths other than -1 are invalid");
        assert!(
            err.to_string()
                .contains("Invalid uncompressed length: -2"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_read_uncompressed_size_rejects_short_prefix() {
        let err = super::read_uncompressed_size(&[1, 2, 3, 4, 5, 6, 7])
            .expect_err("short compressed IPC prefix should return an error");

        assert!(
            err.to_string()
                .contains("Compressed IPC buffer is too short"),
            "unexpected error: {err}"
        );
    }
}