arrow_string/
concat_elements.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Provides utility functions for concatenation of elements in arrays.
19use std::sync::Arc;
20
21use arrow_array::builder::BufferBuilder;
22use arrow_array::types::ByteArrayType;
23use arrow_array::*;
24use arrow_buffer::{ArrowNativeType, NullBuffer};
25use arrow_data::ArrayDataBuilder;
26use arrow_schema::{ArrowError, DataType};
27
28/// Returns the elementwise concatenation of a [`GenericByteArray`].
29pub fn concat_elements_bytes<T: ByteArrayType>(
30    left: &GenericByteArray<T>,
31    right: &GenericByteArray<T>,
32) -> Result<GenericByteArray<T>, ArrowError> {
33    if left.len() != right.len() {
34        return Err(ArrowError::ComputeError(format!(
35            "Arrays must have the same length: {} != {}",
36            left.len(),
37            right.len()
38        )));
39    }
40
41    let nulls = NullBuffer::union(left.nulls(), right.nulls());
42
43    let left_offsets = left.value_offsets();
44    let right_offsets = right.value_offsets();
45
46    let left_values = left.value_data();
47    let right_values = right.value_data();
48
49    let mut output_values = BufferBuilder::<u8>::new(
50        left_values.len() + right_values.len()
51            - left_offsets[0].as_usize()
52            - right_offsets[0].as_usize(),
53    );
54
55    let mut output_offsets = BufferBuilder::<T::Offset>::new(left_offsets.len());
56    output_offsets.append(T::Offset::usize_as(0));
57    for (left_idx, right_idx) in left_offsets.windows(2).zip(right_offsets.windows(2)) {
58        output_values.append_slice(&left_values[left_idx[0].as_usize()..left_idx[1].as_usize()]);
59        output_values.append_slice(&right_values[right_idx[0].as_usize()..right_idx[1].as_usize()]);
60        output_offsets.append(T::Offset::from_usize(output_values.len()).unwrap());
61    }
62
63    let builder = ArrayDataBuilder::new(T::DATA_TYPE)
64        .len(left.len())
65        .add_buffer(output_offsets.finish())
66        .add_buffer(output_values.finish())
67        .nulls(nulls);
68
69    // SAFETY - offsets valid by construction
70    Ok(unsafe { builder.build_unchecked() }.into())
71}
72
73/// Returns the elementwise concatenation of a [`GenericStringArray`].
74///
75/// An index of the resulting [`GenericStringArray`] is null if any of
76/// `StringArray` are null at that location.
77///
78/// ```text
79/// e.g:
80///
81///   ["Hello"] + ["World"] = ["HelloWorld"]
82///
83///   ["a", "b"] + [None, "c"] = [None, "bc"]
84/// ```
85///
86/// An error will be returned if `left` and `right` have different lengths
87pub fn concat_elements_utf8<Offset: OffsetSizeTrait>(
88    left: &GenericStringArray<Offset>,
89    right: &GenericStringArray<Offset>,
90) -> Result<GenericStringArray<Offset>, ArrowError> {
91    concat_elements_bytes(left, right)
92}
93
94/// Returns the elementwise concatenation of a [`GenericBinaryArray`].
95pub fn concat_element_binary<Offset: OffsetSizeTrait>(
96    left: &GenericBinaryArray<Offset>,
97    right: &GenericBinaryArray<Offset>,
98) -> Result<GenericBinaryArray<Offset>, ArrowError> {
99    concat_elements_bytes(left, right)
100}
101
102/// Returns the elementwise concatenation of [`StringArray`].
103/// ```text
104/// e.g:
105///   ["a", "b"] + [None, "c"] + [None, "d"] = [None, "bcd"]
106/// ```
107///
108/// An error will be returned if the [`StringArray`] are of different lengths
109pub fn concat_elements_utf8_many<Offset: OffsetSizeTrait>(
110    arrays: &[&GenericStringArray<Offset>],
111) -> Result<GenericStringArray<Offset>, ArrowError> {
112    if arrays.is_empty() {
113        return Err(ArrowError::ComputeError(
114            "concat requires input of at least one array".to_string(),
115        ));
116    }
117
118    let size = arrays[0].len();
119    if !arrays.iter().all(|array| array.len() == size) {
120        return Err(ArrowError::ComputeError(format!(
121            "Arrays must have the same length of {size}",
122        )));
123    }
124
125    let nulls = arrays
126        .iter()
127        .fold(None, |acc, a| NullBuffer::union(acc.as_ref(), a.nulls()));
128
129    let data_values = arrays
130        .iter()
131        .map(|array| array.value_data())
132        .collect::<Vec<_>>();
133
134    let mut offsets = arrays
135        .iter()
136        .map(|a| a.value_offsets().iter().peekable())
137        .collect::<Vec<_>>();
138
139    let mut output_values = BufferBuilder::<u8>::new(
140        data_values
141            .iter()
142            .zip(offsets.iter_mut())
143            .map(|(data, offset)| data.len() - offset.peek().unwrap().as_usize())
144            .sum(),
145    );
146
147    let mut output_offsets = BufferBuilder::<Offset>::new(size + 1);
148    output_offsets.append(Offset::zero());
149    for _ in 0..size {
150        data_values
151            .iter()
152            .zip(offsets.iter_mut())
153            .for_each(|(values, offset)| {
154                let index_start = offset.next().unwrap().as_usize();
155                let index_end = offset.peek().unwrap().as_usize();
156                output_values.append_slice(&values[index_start..index_end]);
157            });
158        output_offsets.append(Offset::from_usize(output_values.len()).unwrap());
159    }
160
161    let builder = ArrayDataBuilder::new(GenericStringArray::<Offset>::DATA_TYPE)
162        .len(size)
163        .add_buffer(output_offsets.finish())
164        .add_buffer(output_values.finish())
165        .nulls(nulls);
166
167    // SAFETY - offsets valid by construction
168    Ok(unsafe { builder.build_unchecked() }.into())
169}
170
171/// Returns the elementwise concatenation of [`Array`]s.
172///
173/// # Errors
174///
175/// This function errors if the arrays are of different types.
176pub fn concat_elements_dyn(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef, ArrowError> {
177    if left.data_type() != right.data_type() {
178        return Err(ArrowError::ComputeError(format!(
179            "Cannot concat arrays of different types: {} != {}",
180            left.data_type(),
181            right.data_type()
182        )));
183    }
184    match (left.data_type(), right.data_type()) {
185        (DataType::Utf8, DataType::Utf8) => {
186            let left = left.as_any().downcast_ref::<StringArray>().unwrap();
187            let right = right.as_any().downcast_ref::<StringArray>().unwrap();
188            Ok(Arc::new(concat_elements_utf8(left, right).unwrap()))
189        }
190        (DataType::LargeUtf8, DataType::LargeUtf8) => {
191            let left = left.as_any().downcast_ref::<LargeStringArray>().unwrap();
192            let right = right.as_any().downcast_ref::<LargeStringArray>().unwrap();
193            Ok(Arc::new(concat_elements_utf8(left, right).unwrap()))
194        }
195        (DataType::Binary, DataType::Binary) => {
196            let left = left.as_any().downcast_ref::<BinaryArray>().unwrap();
197            let right = right.as_any().downcast_ref::<BinaryArray>().unwrap();
198            Ok(Arc::new(concat_element_binary(left, right).unwrap()))
199        }
200        (DataType::LargeBinary, DataType::LargeBinary) => {
201            let left = left.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
202            let right = right.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
203            Ok(Arc::new(concat_element_binary(left, right).unwrap()))
204        }
205        // unimplemented
206        _ => Err(ArrowError::NotYetImplemented(format!(
207            "concat not supported for {}",
208            left.data_type()
209        ))),
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    // A slot that is null on either side must be null in the output.
    #[test]
    fn test_string_concat() {
        let left = [Some("foo"), Some("bar"), None]
            .into_iter()
            .collect::<StringArray>();
        let right = [None, Some("yyy"), Some("zzz")]
            .into_iter()
            .collect::<StringArray>();

        let output = concat_elements_utf8(&left, &right).unwrap();

        let expected = [None, Some("baryyy"), None]
            .into_iter()
            .collect::<StringArray>();

        assert_eq!(output, expected);
    }

    // Empty strings are valid values, not nulls.
    #[test]
    fn test_string_concat_empty_string() {
        let left = [Some("foo"), Some(""), Some("bar")]
            .into_iter()
            .collect::<StringArray>();
        let right = [Some("baz"), Some(""), Some("")]
            .into_iter()
            .collect::<StringArray>();

        let output = concat_elements_utf8(&left, &right).unwrap();

        let expected = [Some("foobaz"), Some(""), Some("bar")]
            .into_iter()
            .collect::<StringArray>();

        assert_eq!(output, expected);
    }

    #[test]
    fn test_string_concat_no_null() {
        let left = StringArray::from(vec!["foo", "bar"]);
        let right = StringArray::from(vec!["bar", "baz"]);

        let output = concat_elements_utf8(&left, &right).unwrap();

        let expected = StringArray::from(vec!["foobar", "barbaz"]);

        assert_eq!(output, expected);
    }

    #[test]
    fn test_string_concat_error() {
        let left = StringArray::from(vec!["foo", "bar"]);
        let right = StringArray::from(vec!["baz"]);

        let output = concat_elements_utf8(&left, &right);

        assert_eq!(
            output.unwrap_err().to_string(),
            "Compute error: Arrays must have the same length: 2 != 1".to_string()
        );
    }

    // Sliced inputs have a non-zero first offset; the kernel must honour it.
    #[test]
    fn test_string_concat_slice() {
        let left = &StringArray::from(vec![None, Some("foo"), Some("bar"), Some("baz")]);
        let right = &StringArray::from(vec![Some("boo"), None, Some("far"), Some("faz")]);

        let left_slice = left.slice(0, 3);
        let right_slice = right.slice(1, 3);
        let output = concat_elements_utf8(
            left_slice
                .as_any()
                .downcast_ref::<GenericStringArray<i32>>()
                .unwrap(),
            right_slice
                .as_any()
                .downcast_ref::<GenericStringArray<i32>>()
                .unwrap(),
        )
        .unwrap();

        let expected = [None, Some("foofar"), Some("barfaz")]
            .into_iter()
            .collect::<StringArray>();

        assert_eq!(output, expected);

        let left_slice = left.slice(2, 2);
        let right_slice = right.slice(1, 2);

        let output = concat_elements_utf8(
            left_slice
                .as_any()
                .downcast_ref::<GenericStringArray<i32>>()
                .unwrap(),
            right_slice
                .as_any()
                .downcast_ref::<GenericStringArray<i32>>()
                .unwrap(),
        )
        .unwrap();

        let expected = [None, Some("bazfar")].into_iter().collect::<StringArray>();

        assert_eq!(output, expected);
    }

    #[test]
    fn test_string_concat_error_empty() {
        assert_eq!(
            concat_elements_utf8_many::<i32>(&[])
                .unwrap_err()
                .to_string(),
            "Compute error: concat requires input of at least one array".to_string()
        );
    }

    // Inputs of different lengths are rejected by the many-array kernel too.
    #[test]
    fn test_string_concat_many_error_length() {
        let left = StringArray::from(vec!["foo", "bar"]);
        let right = StringArray::from(vec!["baz"]);

        let output = concat_elements_utf8_many(&[&left, &right]);

        assert_eq!(
            output.unwrap_err().to_string(),
            "Compute error: Arrays must have the same length of 2".to_string()
        );
    }

    #[test]
    fn test_string_concat_one() {
        let expected = [None, Some("baryyy"), None]
            .into_iter()
            .collect::<StringArray>();

        let output = concat_elements_utf8_many(&[&expected]).unwrap();

        assert_eq!(output, expected);
    }

    #[test]
    fn test_string_concat_many() {
        let foo = StringArray::from(vec![Some("f"), Some("o"), Some("o"), None]);
        let bar = StringArray::from(vec![None, Some("b"), Some("a"), Some("r")]);
        let baz = StringArray::from(vec![Some("b"), None, Some("a"), Some("z")]);

        let output = concat_elements_utf8_many(&[&foo, &bar, &baz]).unwrap();

        let expected = [None, None, Some("oaa"), None]
            .into_iter()
            .collect::<StringArray>();

        assert_eq!(output, expected);
    }

    // Exercises the binary kernel directly rather than through the dyn dispatch.
    #[test]
    fn test_binary_concat() {
        let left = BinaryArray::from_opt_vec(vec![Some(b"foo"), Some(b"bar"), None]);
        let right = BinaryArray::from_opt_vec(vec![None, Some(b"yyy"), Some(b"zzz")]);

        let output = concat_element_binary(&left, &right).unwrap();

        let expected = BinaryArray::from_opt_vec(vec![None, Some(b"baryyy"), None]);
        assert_eq!(output, expected);
    }

    #[test]
    fn test_concat_dyn_same_type() {
        // test for StringArray
        let left = StringArray::from(vec![Some("foo"), Some("bar"), None]);
        let right = StringArray::from(vec![None, Some("yyy"), Some("zzz")]);

        let output: StringArray = concat_elements_dyn(&left, &right)
            .unwrap()
            .into_data()
            .into();
        let expected = StringArray::from(vec![None, Some("baryyy"), None]);
        assert_eq!(output, expected);

        // test for LargeStringArray
        let left = LargeStringArray::from(vec![Some("foo"), Some("bar"), None]);
        let right = LargeStringArray::from(vec![None, Some("yyy"), Some("zzz")]);

        let output: LargeStringArray = concat_elements_dyn(&left, &right)
            .unwrap()
            .into_data()
            .into();
        let expected = LargeStringArray::from(vec![None, Some("baryyy"), None]);
        assert_eq!(output, expected);

        // test for BinaryArray
        let left = BinaryArray::from_opt_vec(vec![Some(b"foo"), Some(b"bar"), None]);
        let right = BinaryArray::from_opt_vec(vec![None, Some(b"yyy"), Some(b"zzz")]);
        let output: BinaryArray = concat_elements_dyn(&left, &right)
            .unwrap()
            .into_data()
            .into();
        let expected = BinaryArray::from_opt_vec(vec![None, Some(b"baryyy"), None]);
        assert_eq!(output, expected);

        // test for LargeBinaryArray
        let left = LargeBinaryArray::from_opt_vec(vec![Some(b"foo"), Some(b"bar"), None]);
        let right = LargeBinaryArray::from_opt_vec(vec![None, Some(b"yyy"), Some(b"zzz")]);
        let output: LargeBinaryArray = concat_elements_dyn(&left, &right)
            .unwrap()
            .into_data()
            .into();
        let expected = LargeBinaryArray::from_opt_vec(vec![None, Some(b"baryyy"), None]);
        assert_eq!(output, expected);
    }

    #[test]
    fn test_concat_dyn_different_type() {
        let left = StringArray::from(vec![Some("foo"), Some("bar"), None]);
        let right = LargeStringArray::from(vec![None, Some("1"), Some("2")]);

        let output = concat_elements_dyn(&left, &right);
        assert_eq!(
            output.unwrap_err().to_string(),
            "Compute error: Cannot concat arrays of different types: Utf8 != LargeUtf8".to_string()
        );
    }

    // Matching types that have no concat kernel must be reported, not panic.
    #[test]
    fn test_concat_dyn_unsupported_type() {
        let left = Int32Array::from(vec![1, 2, 3]);
        let right = Int32Array::from(vec![4, 5, 6]);

        let output = concat_elements_dyn(&left, &right);
        assert!(matches!(
            output.unwrap_err(),
            ArrowError::NotYetImplemented(_)
        ));
    }
}