arrow_ipc/
compression.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::CompressionType;
19use arrow_buffer::Buffer;
20use arrow_schema::ArrowError;
21
/// Sentinel written into the length prefix when the bytes that follow are stored
/// uncompressed.
const LENGTH_NO_COMPRESSED_DATA: i64 = -1;
/// Width in bytes of the little-endian `i64` length prefix preceding a non-empty body.
const LENGTH_OF_PREFIX_DATA: i64 = std::mem::size_of::<i64>() as i64;
24
/// Reusable state shared across successive compression calls.
///
/// With the `zstd` feature enabled this holds a zstd bulk compressor, so that repeated
/// compression calls can reuse it instead of paying to initialise a new one every time.
/// Without that feature the struct carries no state at all.
pub struct CompressionContext {
    #[cfg(feature = "zstd")]
    compressor: zstd::bulk::Compressor<'static>,
}

// Without the `zstd` feature this impl is derivable, which would trip
// `clippy::derivable_impls`. It is written by hand so that, with the feature on, the zstd
// compression level is chosen explicitly.
#[allow(clippy::derivable_impls)]
impl Default for CompressionContext {
    fn default() -> Self {
        Self {
            // `Compressor::new` only fails for an invalid compression level, and the
            // library's default level is always valid.
            #[cfg(feature = "zstd")]
            compressor: zstd::bulk::Compressor::new(zstd::DEFAULT_COMPRESSION_LEVEL)
                .expect("can use default compression level"),
        }
    }
}

impl std::fmt::Debug for CompressionContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CompressionContext");
        // The compressor is rendered as a fixed placeholder string rather than its own
        // debug output.
        #[cfg(feature = "zstd")]
        {
            out.field("compressor", &"zstd::bulk::Compressor");
        }
        out.finish()
    }
}
59
/// Compression codec used to compress and decompress data in an IPC stream.
///
/// Built from a [`CompressionType`] via [`TryFrom`]; only the codecs listed here are
/// supported.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionCodec {
    /// LZ4 frame format. Actual (de)compression requires the `lz4` cargo feature.
    Lz4Frame,
    /// Zstandard. Actual (de)compression requires the `zstd` cargo feature.
    Zstd,
}
66
67impl TryFrom<CompressionType> for CompressionCodec {
68    type Error = ArrowError;
69
70    fn try_from(compression_type: CompressionType) -> Result<Self, ArrowError> {
71        match compression_type {
72            CompressionType::ZSTD => Ok(CompressionCodec::Zstd),
73            CompressionType::LZ4_FRAME => Ok(CompressionCodec::Lz4Frame),
74            other_type => Err(ArrowError::NotYetImplemented(format!(
75                "compression type {other_type:?} not supported "
76            ))),
77        }
78    }
79}
80
81impl CompressionCodec {
82    /// Compresses the data in `input` to `output` and appends the
83    /// data using the specified compression mechanism.
84    ///
85    /// returns the number of bytes written to the stream
86    ///
87    /// Writes this format to output:
88    /// ```text
89    /// [8 bytes]:         uncompressed length
90    /// [remaining bytes]: compressed data stream
91    /// ```
92    pub(crate) fn compress_to_vec(
93        &self,
94        input: &[u8],
95        output: &mut Vec<u8>,
96        context: &mut CompressionContext,
97    ) -> Result<usize, ArrowError> {
98        let uncompressed_data_len = input.len();
99        let original_output_len = output.len();
100
101        if input.is_empty() {
102            // empty input, nothing to do
103        } else {
104            // write compressed data directly into the output buffer
105            output.extend_from_slice(&uncompressed_data_len.to_le_bytes());
106            self.compress(input, output, context)?;
107
108            let compression_len = output.len() - original_output_len;
109            if compression_len > uncompressed_data_len {
110                // length of compressed data was larger than
111                // uncompressed data, use the uncompressed data with
112                // length -1 to indicate that we don't compress the
113                // data
114                output.truncate(original_output_len);
115                output.extend_from_slice(&LENGTH_NO_COMPRESSED_DATA.to_le_bytes());
116                output.extend_from_slice(input);
117            }
118        }
119        Ok(output.len() - original_output_len)
120    }
121
122    /// Decompresses the input into a [`Buffer`]
123    ///
124    /// The input should look like:
125    /// ```text
126    /// [8 bytes]:         uncompressed length
127    /// [remaining bytes]: compressed data stream
128    /// ```
129    pub(crate) fn decompress_to_buffer(&self, input: &Buffer) -> Result<Buffer, ArrowError> {
130        // read the first 8 bytes to determine if the data is
131        // compressed
132        let decompressed_length = read_uncompressed_size(input);
133        let buffer = if decompressed_length == 0 {
134            // empty
135            Buffer::from([])
136        } else if decompressed_length == LENGTH_NO_COMPRESSED_DATA {
137            // no compression
138            input.slice(LENGTH_OF_PREFIX_DATA as usize)
139        } else if let Ok(decompressed_length) = usize::try_from(decompressed_length) {
140            // decompress data using the codec
141            let input_data = &input[(LENGTH_OF_PREFIX_DATA as usize)..];
142            let v = self.decompress(input_data, decompressed_length as _)?;
143            Buffer::from_vec(v)
144        } else {
145            return Err(ArrowError::IpcError(format!(
146                "Invalid uncompressed length: {decompressed_length}"
147            )));
148        };
149        Ok(buffer)
150    }
151
152    /// Compress the data in input buffer and write to output buffer
153    /// using the specified compression
154    fn compress(
155        &self,
156        input: &[u8],
157        output: &mut Vec<u8>,
158        context: &mut CompressionContext,
159    ) -> Result<(), ArrowError> {
160        match self {
161            CompressionCodec::Lz4Frame => compress_lz4(input, output),
162            CompressionCodec::Zstd => compress_zstd(input, output, context),
163        }
164    }
165
166    /// Decompress the data in input buffer and write to output buffer
167    /// using the specified compression
168    fn decompress(&self, input: &[u8], decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
169        let ret = match self {
170            CompressionCodec::Lz4Frame => decompress_lz4(input, decompressed_size)?,
171            CompressionCodec::Zstd => decompress_zstd(input, decompressed_size)?,
172        };
173        if ret.len() != decompressed_size {
174            return Err(ArrowError::IpcError(format!(
175                "Expected compressed length of {decompressed_size} got {}",
176                ret.len()
177            )));
178        }
179        Ok(ret)
180    }
181}
182
183#[cfg(feature = "lz4")]
184fn compress_lz4(input: &[u8], output: &mut Vec<u8>) -> Result<(), ArrowError> {
185    use std::io::Write;
186    let mut encoder = lz4_flex::frame::FrameEncoder::new(output);
187    encoder.write_all(input)?;
188    encoder
189        .finish()
190        .map_err(|e| ArrowError::ExternalError(Box::new(e)))?;
191    Ok(())
192}
193
194#[cfg(not(feature = "lz4"))]
195#[allow(clippy::ptr_arg)]
196fn compress_lz4(_input: &[u8], _output: &mut Vec<u8>) -> Result<(), ArrowError> {
197    Err(ArrowError::InvalidArgumentError(
198        "lz4 IPC compression requires the lz4 feature".to_string(),
199    ))
200}
201
202#[cfg(feature = "lz4")]
203fn decompress_lz4(input: &[u8], decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
204    use std::io::Read;
205    let mut output = Vec::with_capacity(decompressed_size);
206    lz4_flex::frame::FrameDecoder::new(input).read_to_end(&mut output)?;
207    Ok(output)
208}
209
210#[cfg(not(feature = "lz4"))]
211#[allow(clippy::ptr_arg)]
212fn decompress_lz4(_input: &[u8], _decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
213    Err(ArrowError::InvalidArgumentError(
214        "lz4 IPC decompression requires the lz4 feature".to_string(),
215    ))
216}
217
218#[cfg(feature = "zstd")]
219fn compress_zstd(
220    input: &[u8],
221    output: &mut Vec<u8>,
222    context: &mut CompressionContext,
223) -> Result<(), ArrowError> {
224    let result = context.compressor.compress(input)?;
225    output.extend_from_slice(&result);
226    Ok(())
227}
228
229#[cfg(not(feature = "zstd"))]
230#[allow(clippy::ptr_arg)]
231fn compress_zstd(
232    _input: &[u8],
233    _output: &mut Vec<u8>,
234    _context: &mut CompressionContext,
235) -> Result<(), ArrowError> {
236    Err(ArrowError::InvalidArgumentError(
237        "zstd IPC compression requires the zstd feature".to_string(),
238    ))
239}
240
241#[cfg(feature = "zstd")]
242fn decompress_zstd(input: &[u8], decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
243    use std::io::Read;
244    let mut output = Vec::with_capacity(decompressed_size);
245    zstd::Decoder::with_buffer(input)?.read_to_end(&mut output)?;
246    Ok(output)
247}
248
249#[cfg(not(feature = "zstd"))]
250#[allow(clippy::ptr_arg)]
251fn decompress_zstd(_input: &[u8], _decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
252    Err(ArrowError::InvalidArgumentError(
253        "zstd IPC decompression requires the zstd feature".to_string(),
254    ))
255}
256
/// Reads the little-endian `i64` length prefix stored in the first 8 bytes of `buffer`.
///
/// Meaning of the returned value:
/// * `LENGTH_NO_COMPRESSED_DATA` (`-1`): the data after the prefix is not compressed
/// * `0`: there is no data
/// * a positive number: the uncompressed length of the data after the prefix
///
/// # Panics
///
/// Panics if `buffer` is shorter than 8 bytes.
#[inline]
fn read_uncompressed_size(buffer: &[u8]) -> i64 {
    let mut prefix = [0u8; 8];
    prefix.copy_from_slice(&buffer[..8]);
    i64::from_le_bytes(prefix)
}
268
#[cfg(test)]
mod tests {
    use super::{CompressionCodec, LENGTH_NO_COMPRESSED_DATA};
    use arrow_buffer::Buffer;

    /// Builds a framed IPC buffer: an 8-byte little-endian `prefix` followed by `body`.
    fn prefixed(prefix: i64, body: &[u8]) -> Buffer {
        let mut bytes = prefix.to_le_bytes().to_vec();
        bytes.extend_from_slice(body);
        Buffer::from_vec(bytes)
    }

    /// Round-trips compressible and incompressible data through the framed IPC format.
    #[cfg(any(feature = "lz4", feature = "zstd"))]
    fn roundtrip_prefixed(codec: CompressionCodec) {
        // Highly repetitive input, so prefix + compressed stream is smaller than the input.
        let compressible = vec![7u8; 4096];
        let mut output = Vec::new();
        let written = codec
            .compress_to_vec(&compressible, &mut output, &mut Default::default())
            .unwrap();
        assert_eq!(written, output.len());
        assert_eq!(output[..8], (compressible.len() as i64).to_le_bytes());
        let decoded = codec
            .decompress_to_buffer(&Buffer::from_vec(output))
            .unwrap();
        assert_eq!(decoded.as_slice(), compressible.as_slice());

        // The 8-byte prefix alone exceeds this input, so it is stored raw behind `-1`.
        let tiny = b"ab";
        let mut output = Vec::new();
        codec
            .compress_to_vec(tiny, &mut output, &mut Default::default())
            .unwrap();
        assert_eq!(output[..8], LENGTH_NO_COMPRESSED_DATA.to_le_bytes());
        assert_eq!(&output[8..], tiny);
        let decoded = codec
            .decompress_to_buffer(&Buffer::from_vec(output))
            .unwrap();
        assert_eq!(decoded.as_slice(), tiny);
    }

    #[test]
    #[cfg(feature = "lz4")]
    fn test_lz4_compression() {
        let input_bytes = b"hello lz4";
        let codec = CompressionCodec::Lz4Frame;
        let mut output_bytes: Vec<u8> = Vec::new();
        codec
            .compress(input_bytes, &mut output_bytes, &mut Default::default())
            .unwrap();
        let result = codec
            .decompress(output_bytes.as_slice(), input_bytes.len())
            .unwrap();
        assert_eq!(input_bytes, result.as_slice());
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn test_zstd_compression() {
        let input_bytes = b"hello zstd";
        let codec = CompressionCodec::Zstd;
        let mut output_bytes: Vec<u8> = Vec::new();
        codec
            .compress(input_bytes, &mut output_bytes, &mut Default::default())
            .unwrap();
        let result = codec
            .decompress(output_bytes.as_slice(), input_bytes.len())
            .unwrap();
        assert_eq!(input_bytes, result.as_slice());
    }

    #[test]
    #[cfg(feature = "lz4")]
    fn test_lz4_prefixed_roundtrip() {
        roundtrip_prefixed(CompressionCodec::Lz4Frame);
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn test_zstd_prefixed_roundtrip() {
        roundtrip_prefixed(CompressionCodec::Zstd);
    }

    #[test]
    fn test_compress_to_vec_empty_input_writes_nothing() {
        // Empty input returns before any codec runs, so no codec feature is needed.
        for codec in [CompressionCodec::Lz4Frame, CompressionCodec::Zstd] {
            let mut output = vec![0xAA];
            let written = codec
                .compress_to_vec(&[], &mut output, &mut Default::default())
                .unwrap();
            assert_eq!(written, 0);
            assert_eq!(output, vec![0xAA]);
        }
    }

    #[test]
    fn test_decompress_to_buffer_zero_length_is_empty() {
        let decoded = CompressionCodec::Zstd
            .decompress_to_buffer(&prefixed(0, &[]))
            .unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_decompress_to_buffer_uncompressed_passthrough() {
        let decoded = CompressionCodec::Lz4Frame
            .decompress_to_buffer(&prefixed(LENGTH_NO_COMPRESSED_DATA, b"raw bytes"))
            .unwrap();
        assert_eq!(decoded.as_slice(), b"raw bytes");
    }

    #[test]
    fn test_decompress_to_buffer_rejects_negative_length() {
        let err = CompressionCodec::Zstd
            .decompress_to_buffer(&prefixed(-2, b"xx"))
            .unwrap_err();
        assert!(err.to_string().contains("Invalid uncompressed length: -2"));
    }
}