arrow_cast/
base64.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Functions for converting data in [`GenericBinaryArray`] such as [`BinaryArray`] to/from base64 encoded strings
//!
//! [`BinaryArray`]: arrow_array::BinaryArray
21
22use arrow_array::{Array, GenericBinaryArray, GenericStringArray, OffsetSizeTrait};
23use arrow_buffer::{Buffer, OffsetBuffer};
24use arrow_schema::ArrowError;
25use base64::encoded_len;
26use base64::engine::Config;
27
28pub use base64::prelude::*;
29
30/// Bas64 encode each element of `array` with the provided [`Engine`]
31pub fn b64_encode<E: Engine, O: OffsetSizeTrait>(
32    engine: &E,
33    array: &GenericBinaryArray<O>,
34) -> GenericStringArray<O> {
35    let lengths = array.offsets().windows(2).map(|w| {
36        let len = w[1].as_usize() - w[0].as_usize();
37        encoded_len(len, engine.config().encode_padding()).unwrap()
38    });
39    let offsets = OffsetBuffer::<O>::from_lengths(lengths);
40    let buffer_len = offsets.last().unwrap().as_usize();
41    let mut buffer = vec![0_u8; buffer_len];
42    let mut offset = 0;
43
44    for i in 0..array.len() {
45        let len = engine
46            .encode_slice(array.value(i), &mut buffer[offset..])
47            .unwrap();
48        offset += len;
49    }
50    assert_eq!(offset, buffer_len);
51
52    // Safety: Base64 is valid UTF-8
53    unsafe {
54        GenericStringArray::new_unchecked(offsets, Buffer::from_vec(buffer), array.nulls().cloned())
55    }
56}
57
58/// Base64 decode each element of `array` with the provided [`Engine`]
59pub fn b64_decode<E: Engine, O: OffsetSizeTrait>(
60    engine: &E,
61    array: &GenericBinaryArray<O>,
62) -> Result<GenericBinaryArray<O>, ArrowError> {
63    let estimated_len = array.values().len(); // This is an overestimate
64    let mut buffer = vec![0; estimated_len];
65
66    let mut offsets = Vec::with_capacity(array.len() + 1);
67    offsets.push(O::usize_as(0));
68    let mut offset = 0;
69
70    for v in array.iter() {
71        if let Some(v) = v {
72            let len = engine.decode_slice(v, &mut buffer[offset..]).unwrap();
73            // This cannot overflow as `len` is less than `v.len()` and `a` is valid
74            offset += len;
75        }
76        offsets.push(O::usize_as(offset));
77    }
78
79    // Safety: offsets monotonically increasing by construction
80    let offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) };
81
82    GenericBinaryArray::try_new(offsets, Buffer::from_vec(buffer), array.nulls().cloned())
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{BinaryArray, LargeBinaryArray};
    use rand::{Rng, rng};

    /// Round-trips `a` through [`b64_encode`] and [`b64_decode`] with engine `e`.
    ///
    /// Validates both intermediate arrays, checks each encoded value against
    /// `Engine::encode`, and asserts that the decoded array equals `a`.
    fn test_engine<E: Engine, O: OffsetSizeTrait>(e: &E, a: &GenericBinaryArray<O>) {
        let encoded = b64_encode(e, a);
        encoded.to_data().validate_full().unwrap();

        for (i, s) in encoded.iter().enumerate() {
            if let Some(s) = s {
                assert_eq!(s, e.encode(a.value(i)));
            }
        }

        let to_decode: GenericBinaryArray<O> = encoded.into();
        let decoded = b64_decode(e, &to_decode).unwrap();
        decoded.to_data().validate_full().unwrap();

        assert_eq!(&decoded, a);
    }

    /// Runs [`test_engine`] with every standard engine from the prelude.
    fn test_all_engines<O: OffsetSizeTrait>(a: &GenericBinaryArray<O>) {
        test_engine(&BASE64_STANDARD, a);
        test_engine(&BASE64_STANDARD_NO_PAD, a);
        test_engine(&BASE64_URL_SAFE, a);
        test_engine(&BASE64_URL_SAFE_NO_PAD, a);
    }

    /// Builds a random array of 1024..1050 values, each 0..16 bytes long.
    /// Each slot is null with probability `null_prob`.
    fn random_binary(null_prob: f64) -> BinaryArray {
        let mut rng = rng();
        let len = rng.random_range(1024..1050);
        (0..len)
            .map(|_| {
                if rng.random_bool(null_prob) {
                    return None;
                }
                let len = rng.random_range(0..16);
                Some((0..len).map(|_| rng.random()).collect::<Vec<u8>>())
            })
            .collect()
    }

    #[test]
    fn test_b64() {
        test_all_engines(&random_binary(0.0));
    }

    #[test]
    fn test_b64_nulls() {
        test_all_engines(&random_binary(0.1));
    }

    #[test]
    fn test_b64_sliced() {
        let data = random_binary(0.1);
        // Non-zero first offset and a values buffer larger than the slice.
        test_all_engines(&data.slice(17, 500));
    }

    #[test]
    fn test_b64_large_offsets() {
        let data: LargeBinaryArray = random_binary(0.1).iter().collect();
        test_all_engines(&data);
    }

    #[test]
    fn test_b64_empty() {
        test_all_engines(&BinaryArray::from(Vec::<Option<&[u8]>>::new()));
    }
}