arrow_ipc/
compression.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::CompressionType;
19use arrow_buffer::Buffer;
20use arrow_schema::ArrowError;
21
/// Value stored in the 8 byte length prefix to signal that the bytes
/// following the prefix are the raw, uncompressed data.
const LENGTH_NO_COMPRESSED_DATA: i64 = -1;
/// Size in bytes of the length prefix that precedes every (de)compressed
/// buffer body.
const LENGTH_OF_PREFIX_DATA: i64 = 8;
24
/// Represents compressing an IPC stream using a particular compression algorithm
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionCodec {
    /// LZ4 frame format; handled by `compress_lz4` / `decompress_lz4`
    Lz4Frame,
    /// Zstandard; handled by `compress_zstd` / `decompress_zstd`
    Zstd,
}
31
32impl TryFrom<CompressionType> for CompressionCodec {
33    type Error = ArrowError;
34
35    fn try_from(compression_type: CompressionType) -> Result<Self, ArrowError> {
36        match compression_type {
37            CompressionType::ZSTD => Ok(CompressionCodec::Zstd),
38            CompressionType::LZ4_FRAME => Ok(CompressionCodec::Lz4Frame),
39            other_type => Err(ArrowError::NotYetImplemented(format!(
40                "compression type {other_type:?} not supported "
41            ))),
42        }
43    }
44}
45
46impl CompressionCodec {
47    /// Compresses the data in `input` to `output` and appends the
48    /// data using the specified compression mechanism.
49    ///
50    /// returns the number of bytes written to the stream
51    ///
52    /// Writes this format to output:
53    /// ```text
54    /// [8 bytes]:         uncompressed length
55    /// [remaining bytes]: compressed data stream
56    /// ```
57    pub(crate) fn compress_to_vec(
58        &self,
59        input: &[u8],
60        output: &mut Vec<u8>,
61    ) -> Result<usize, ArrowError> {
62        let uncompressed_data_len = input.len();
63        let original_output_len = output.len();
64
65        if input.is_empty() {
66            // empty input, nothing to do
67        } else {
68            // write compressed data directly into the output buffer
69            output.extend_from_slice(&uncompressed_data_len.to_le_bytes());
70            self.compress(input, output)?;
71
72            let compression_len = output.len() - original_output_len;
73            if compression_len > uncompressed_data_len {
74                // length of compressed data was larger than
75                // uncompressed data, use the uncompressed data with
76                // length -1 to indicate that we don't compress the
77                // data
78                output.truncate(original_output_len);
79                output.extend_from_slice(&LENGTH_NO_COMPRESSED_DATA.to_le_bytes());
80                output.extend_from_slice(input);
81            }
82        }
83        Ok(output.len() - original_output_len)
84    }
85
86    /// Decompresses the input into a [`Buffer`]
87    ///
88    /// The input should look like:
89    /// ```text
90    /// [8 bytes]:         uncompressed length
91    /// [remaining bytes]: compressed data stream
92    /// ```
93    pub(crate) fn decompress_to_buffer(&self, input: &Buffer) -> Result<Buffer, ArrowError> {
94        // read the first 8 bytes to determine if the data is
95        // compressed
96        let decompressed_length = read_uncompressed_size(input);
97        let buffer = if decompressed_length == 0 {
98            // empty
99            Buffer::from([])
100        } else if decompressed_length == LENGTH_NO_COMPRESSED_DATA {
101            // no compression
102            input.slice(LENGTH_OF_PREFIX_DATA as usize)
103        } else if let Ok(decompressed_length) = usize::try_from(decompressed_length) {
104            // decompress data using the codec
105            let input_data = &input[(LENGTH_OF_PREFIX_DATA as usize)..];
106            let v = self.decompress(input_data, decompressed_length as _)?;
107            Buffer::from_vec(v)
108        } else {
109            return Err(ArrowError::IpcError(format!(
110                "Invalid uncompressed length: {decompressed_length}"
111            )));
112        };
113        Ok(buffer)
114    }
115
116    /// Compress the data in input buffer and write to output buffer
117    /// using the specified compression
118    fn compress(&self, input: &[u8], output: &mut Vec<u8>) -> Result<(), ArrowError> {
119        match self {
120            CompressionCodec::Lz4Frame => compress_lz4(input, output),
121            CompressionCodec::Zstd => compress_zstd(input, output),
122        }
123    }
124
125    /// Decompress the data in input buffer and write to output buffer
126    /// using the specified compression
127    fn decompress(&self, input: &[u8], decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
128        let ret = match self {
129            CompressionCodec::Lz4Frame => decompress_lz4(input, decompressed_size)?,
130            CompressionCodec::Zstd => decompress_zstd(input, decompressed_size)?,
131        };
132        if ret.len() != decompressed_size {
133            return Err(ArrowError::IpcError(format!(
134                "Expected compressed length of {decompressed_size} got {}",
135                ret.len()
136            )));
137        }
138        Ok(ret)
139    }
140}
141
/// Encodes `input` as an LZ4 frame, appending the encoded bytes to `output`.
///
/// I/O errors from writing are converted with `?`; an error from finishing
/// the frame is wrapped in [`ArrowError::ExternalError`].
#[cfg(feature = "lz4")]
fn compress_lz4(input: &[u8], output: &mut Vec<u8>) -> Result<(), ArrowError> {
    use std::io::Write;
    let mut frame_writer = lz4_flex::frame::FrameEncoder::new(output);
    frame_writer.write_all(input)?;
    // `finish` flushes the frame footer and hands back the underlying writer,
    // which is not needed here.
    match frame_writer.finish() {
        Ok(_) => Ok(()),
        Err(err) => Err(ArrowError::ExternalError(Box::new(err))),
    }
}
152
153#[cfg(not(feature = "lz4"))]
154#[allow(clippy::ptr_arg)]
155fn compress_lz4(_input: &[u8], _output: &mut Vec<u8>) -> Result<(), ArrowError> {
156    Err(ArrowError::InvalidArgumentError(
157        "lz4 IPC compression requires the lz4 feature".to_string(),
158    ))
159}
160
/// Decodes an LZ4 frame from `input` and returns the decoded bytes.
///
/// `decompressed_size` is only used to pre-size the result vector; the
/// decoder reads until the end of the stream regardless.
#[cfg(feature = "lz4")]
fn decompress_lz4(input: &[u8], decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
    use std::io::Read;
    let mut decoded = Vec::with_capacity(decompressed_size);
    let mut frame_reader = lz4_flex::frame::FrameDecoder::new(input);
    frame_reader.read_to_end(&mut decoded)?;
    Ok(decoded)
}
168
169#[cfg(not(feature = "lz4"))]
170#[allow(clippy::ptr_arg)]
171fn decompress_lz4(_input: &[u8], _decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
172    Err(ArrowError::InvalidArgumentError(
173        "lz4 IPC decompression requires the lz4 feature".to_string(),
174    ))
175}
176
/// Encodes `input` as a zstd stream, appending the encoded bytes to `output`.
///
/// The encoder is created with a compression level argument of `0`
/// (presumably the library default -- confirm against the zstd crate docs).
#[cfg(feature = "zstd")]
fn compress_zstd(input: &[u8], output: &mut Vec<u8>) -> Result<(), ArrowError> {
    use std::io::Write;
    let mut stream = zstd::Encoder::new(output, 0)?;
    stream.write_all(input)?;
    // Finishing writes the end of the stream; the returned writer is unused.
    let _sink = stream.finish()?;
    Ok(())
}
185
186#[cfg(not(feature = "zstd"))]
187#[allow(clippy::ptr_arg)]
188fn compress_zstd(_input: &[u8], _output: &mut Vec<u8>) -> Result<(), ArrowError> {
189    Err(ArrowError::InvalidArgumentError(
190        "zstd IPC compression requires the zstd feature".to_string(),
191    ))
192}
193
/// Decodes a zstd stream from `input` and returns the decoded bytes.
///
/// `decompressed_size` is only used to pre-size the result vector; the
/// decoder reads until the end of the stream regardless.
#[cfg(feature = "zstd")]
fn decompress_zstd(input: &[u8], decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
    use std::io::Read;
    let mut decoded = Vec::with_capacity(decompressed_size);
    let mut stream = zstd::Decoder::with_buffer(input)?;
    stream.read_to_end(&mut decoded)?;
    Ok(decoded)
}
201
202#[cfg(not(feature = "zstd"))]
203#[allow(clippy::ptr_arg)]
204fn decompress_zstd(_input: &[u8], _decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
205    Err(ArrowError::InvalidArgumentError(
206        "zstd IPC decompression requires the zstd feature".to_string(),
207    ))
208}
209
/// Reads the length prefix stored in the first 8 bytes of `buffer`.
///
/// The prefix is a 64-bit little-endian signed integer, interpreted as:
///   * `LENGTH_NO_COMPRESSED_DATA`: the data that follows is not compressed
///   * `0`: there is no data
///   * positive number: the uncompressed length of the data that follows
///
/// # Panics
///
/// Panics if `buffer` holds fewer than 8 bytes.
#[inline]
fn read_uncompressed_size(buffer: &[u8]) -> i64 {
    let mut prefix = [0u8; 8];
    prefix.copy_from_slice(&buffer[..8]);
    i64::from_le_bytes(prefix)
}
221
#[cfg(test)]
mod tests {
    /// Round-trips a short payload through the raw LZ4 frame codec.
    #[test]
    #[cfg(feature = "lz4")]
    fn test_lz4_compression() {
        let codec = super::CompressionCodec::Lz4Frame;
        let payload = b"hello lz4";

        let mut compressed = Vec::<u8>::new();
        codec.compress(payload, &mut compressed).unwrap();

        let round_tripped = codec.decompress(&compressed, payload.len()).unwrap();
        assert_eq!(payload, round_tripped.as_slice());
    }

    /// Round-trips a short payload through the raw zstd codec.
    #[test]
    #[cfg(feature = "zstd")]
    fn test_zstd_compression() {
        let codec = super::CompressionCodec::Zstd;
        let payload = b"hello zstd";

        let mut compressed = Vec::<u8>::new();
        codec.compress(payload, &mut compressed).unwrap();

        let round_tripped = codec.decompress(&compressed, payload.len()).unwrap();
        assert_eq!(payload, round_tripped.as_slice());
    }
}
249}