arrow_select/
take.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines take kernel for [Array]
19
20use std::sync::Arc;
21
22use arrow_array::builder::{BufferBuilder, UInt32Builder};
23use arrow_array::cast::AsArray;
24use arrow_array::types::*;
25use arrow_array::*;
26use arrow_buffer::{
27    bit_util, ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer,
28    ScalarBuffer,
29};
30use arrow_data::ArrayDataBuilder;
31use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
32
33use num::{One, Zero};
34
35/// Take elements by index from [Array], creating a new [Array] from those indexes.
36///
37/// ```text
38/// ┌─────────────────┐      ┌─────────┐                              ┌─────────────────┐
39/// │        A        │      │    0    │                              │        A        │
40/// ├─────────────────┤      ├─────────┤                              ├─────────────────┤
41/// │        D        │      │    2    │                              │        B        │
42/// ├─────────────────┤      ├─────────┤   take(values, indices)      ├─────────────────┤
43/// │        B        │      │    3    │ ─────────────────────────▶   │        C        │
44/// ├─────────────────┤      ├─────────┤                              ├─────────────────┤
45/// │        C        │      │    1    │                              │        D        │
46/// ├─────────────────┤      └─────────┘                              └─────────────────┘
47/// │        E        │
48/// └─────────────────┘
49///    values array          indices array                              result
50/// ```
51///
52/// For selecting values by index from multiple arrays see [`crate::interleave`]
53///
54/// Note that this kernel, similar to other kernels in this crate,
55/// will avoid allocating where not necessary. Consequently
56/// the returned array may share buffers with the inputs
57///
58/// # Errors
59/// This function errors whenever:
60/// * An index cannot be casted to `usize` (typically 32 bit architectures)
61/// * An index is out of bounds and `options` is set to check bounds.
62///
63/// # Safety
64///
65/// When `options` is not set to check bounds, taking indexes after `len` will panic.
66///
67/// # See also
68/// * [`BatchCoalescer`]: to filter multiple [`RecordBatch`] and coalesce
69///   the results into a single array.
70///
71/// [`BatchCoalescer`]: crate::coalesce::BatchCoalescer
72///
73/// # Examples
74/// ```
75/// # use arrow_array::{StringArray, UInt32Array, cast::AsArray};
76/// # use arrow_select::take::take;
77/// let values = StringArray::from(vec!["zero", "one", "two"]);
78///
79/// // Take items at index 2, and 1:
80/// let indices = UInt32Array::from(vec![2, 1]);
81/// let taken = take(&values, &indices, None).unwrap();
82/// let taken = taken.as_string::<i32>();
83///
84/// assert_eq!(*taken, StringArray::from(vec!["two", "one"]));
85/// ```
86pub fn take(
87    values: &dyn Array,
88    indices: &dyn Array,
89    options: Option<TakeOptions>,
90) -> Result<ArrayRef, ArrowError> {
91    let options = options.unwrap_or_default();
92    downcast_integer_array!(
93        indices => {
94            if options.check_bounds {
95                check_bounds(values.len(), indices)?;
96            }
97            let indices = indices.to_indices();
98            take_impl(values, &indices)
99        },
100        d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
101    )
102}
103
104/// For each [ArrayRef] in the [`Vec<ArrayRef>`], take elements by index and create a new
105/// [`Vec<ArrayRef>`] from those indices.
106///
107/// ```text
108/// ┌────────┬────────┐
109/// │        │        │           ┌────────┐                                ┌────────┬────────┐
110/// │   A    │   1    │           │        │                                │        │        │
111/// ├────────┼────────┤           │   0    │                                │   A    │   1    │
112/// │        │        │           ├────────┤                                ├────────┼────────┤
113/// │   D    │   4    │           │        │                                │        │        │
114/// ├────────┼────────┤           │   2    │  take_arrays(values,indices)   │   B    │   2    │
115/// │        │        │           ├────────┤                                ├────────┼────────┤
116/// │   B    │   2    │           │        │  ───────────────────────────►  │        │        │
117/// ├────────┼────────┤           │   3    │                                │   C    │   3    │
118/// │        │        │           ├────────┤                                ├────────┼────────┤
119/// │   C    │   3    │           │        │                                │        │        │
120/// ├────────┼────────┤           │   1    │                                │   D    │   4    │
121/// │        │        │           └────────┘                                └────────┼────────┘
122/// │   E    │   5    │
123/// └────────┴────────┘
124///    values arrays             indices array                                      result
125/// ```
126///
127/// # Errors
128/// This function errors whenever:
129/// * An index cannot be casted to `usize` (typically 32 bit architectures)
130/// * An index is out of bounds and `options` is set to check bounds.
131///
132/// # Safety
133///
134/// When `options` is not set to check bounds, taking indexes after `len` will panic.
135///
136/// # Examples
137/// ```
138/// # use std::sync::Arc;
139/// # use arrow_array::{StringArray, UInt32Array, cast::AsArray};
140/// # use arrow_select::take::{take, take_arrays};
141/// let string_values = Arc::new(StringArray::from(vec!["zero", "one", "two"]));
142/// let values = Arc::new(UInt32Array::from(vec![0, 1, 2]));
143///
144/// // Take items at index 2, and 1:
145/// let indices = UInt32Array::from(vec![2, 1]);
146/// let taken_arrays = take_arrays(&[string_values, values], &indices, None).unwrap();
147/// let taken_string = taken_arrays[0].as_string::<i32>();
148/// assert_eq!(*taken_string, StringArray::from(vec!["two", "one"]));
149/// let taken_values = taken_arrays[1].as_primitive();
150/// assert_eq!(*taken_values, UInt32Array::from(vec![2, 1]));
151/// ```
152pub fn take_arrays(
153    arrays: &[ArrayRef],
154    indices: &dyn Array,
155    options: Option<TakeOptions>,
156) -> Result<Vec<ArrayRef>, ArrowError> {
157    arrays
158        .iter()
159        .map(|array| take(array.as_ref(), indices, options.clone()))
160        .collect()
161}
162
163/// Verifies that the non-null values of `indices` are all `< len`
164fn check_bounds<T: ArrowPrimitiveType>(
165    len: usize,
166    indices: &PrimitiveArray<T>,
167) -> Result<(), ArrowError> {
168    if indices.null_count() > 0 {
169        indices.iter().flatten().try_for_each(|index| {
170            let ix = index
171                .to_usize()
172                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
173            if ix >= len {
174                return Err(ArrowError::ComputeError(format!(
175                    "Array index out of bounds, cannot get item at index {ix} from {len} entries"
176                )));
177            }
178            Ok(())
179        })
180    } else {
181        indices.values().iter().try_for_each(|index| {
182            let ix = index
183                .to_usize()
184                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
185            if ix >= len {
186                return Err(ArrowError::ComputeError(format!(
187                    "Array index out of bounds, cannot get item at index {ix} from {len} entries"
188                )));
189            }
190            Ok(())
191        })
192    }
193}
194
195#[inline(never)]
196fn take_impl<IndexType: ArrowPrimitiveType>(
197    values: &dyn Array,
198    indices: &PrimitiveArray<IndexType>,
199) -> Result<ArrayRef, ArrowError> {
200    downcast_primitive_array! {
201        values => Ok(Arc::new(take_primitive(values, indices)?)),
202        DataType::Boolean => {
203            let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
204            Ok(Arc::new(take_boolean(values, indices)))
205        }
206        DataType::Utf8 => {
207            Ok(Arc::new(take_bytes(values.as_string::<i32>(), indices)?))
208        }
209        DataType::LargeUtf8 => {
210            Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)?))
211        }
212        DataType::Utf8View => {
213            Ok(Arc::new(take_byte_view(values.as_string_view(), indices)?))
214        }
215        DataType::List(_) => {
216            Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?))
217        }
218        DataType::LargeList(_) => {
219            Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?))
220        }
221        DataType::FixedSizeList(_, length) => {
222            let values = values
223                .as_any()
224                .downcast_ref::<FixedSizeListArray>()
225                .unwrap();
226            Ok(Arc::new(take_fixed_size_list(
227                values,
228                indices,
229                *length as u32,
230            )?))
231        }
232        DataType::Map(_, _) => {
233            let list_arr = ListArray::from(values.as_map().clone());
234            let list_data = take_list::<_, Int32Type>(&list_arr, indices)?;
235            let builder = list_data.into_data().into_builder().data_type(values.data_type().clone());
236            Ok(Arc::new(MapArray::from(unsafe { builder.build_unchecked() })))
237        }
238        DataType::Struct(fields) => {
239            let array: &StructArray = values.as_struct();
240            let arrays  = array
241                .columns()
242                .iter()
243                .map(|a| take_impl(a.as_ref(), indices))
244                .collect::<Result<Vec<ArrayRef>, _>>()?;
245            let fields: Vec<(FieldRef, ArrayRef)> =
246                fields.iter().cloned().zip(arrays).collect();
247
248            // Create the null bit buffer.
249            let is_valid: Buffer = indices
250                .iter()
251                .map(|index| {
252                    if let Some(index) = index {
253                        array.is_valid(index.to_usize().unwrap())
254                    } else {
255                        false
256                    }
257                })
258                .collect();
259
260            if fields.is_empty() {
261                let nulls = NullBuffer::new(BooleanBuffer::new(is_valid, 0, indices.len()));
262                Ok(Arc::new(StructArray::new_empty_fields(indices.len(), Some(nulls))))
263            } else {
264                Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
265            }
266        }
267        DataType::Dictionary(_, _) => downcast_dictionary_array! {
268            values => Ok(Arc::new(take_dict(values, indices)?)),
269            t => unimplemented!("Take not supported for dictionary type {:?}", t)
270        }
271        DataType::RunEndEncoded(_, _) => downcast_run_array! {
272            values => Ok(Arc::new(take_run(values, indices)?)),
273            t => unimplemented!("Take not supported for run type {:?}", t)
274        }
275        DataType::Binary => {
276            Ok(Arc::new(take_bytes(values.as_binary::<i32>(), indices)?))
277        }
278        DataType::LargeBinary => {
279            Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?))
280        }
281        DataType::BinaryView => {
282            Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)?))
283        }
284        DataType::FixedSizeBinary(size) => {
285            let values = values
286                .as_any()
287                .downcast_ref::<FixedSizeBinaryArray>()
288                .unwrap();
289            Ok(Arc::new(take_fixed_size_binary(values, indices, *size)?))
290        }
291        DataType::Null => {
292            // Take applied to a null array produces a null array.
293            if values.len() >= indices.len() {
294                // If the existing null array is as big as the indices, we can use a slice of it
295                // to avoid allocating a new null array.
296                Ok(values.slice(0, indices.len()))
297            } else {
298                // If the existing null array isn't big enough, create a new one.
299                Ok(new_null_array(&DataType::Null, indices.len()))
300            }
301        }
302        DataType::Union(fields, UnionMode::Sparse) => {
303            let mut children = Vec::with_capacity(fields.len());
304            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
305            let type_ids = take_native(values.type_ids(), indices);
306            for (type_id, _field) in fields.iter() {
307                let values = values.child(type_id);
308                let values = take_impl(values, indices)?;
309                children.push(values);
310            }
311            let array = UnionArray::try_new(fields.clone(), type_ids, None, children)?;
312            Ok(Arc::new(array))
313        }
314        DataType::Union(fields, UnionMode::Dense) => {
315            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
316
317            let type_ids = <PrimitiveArray<Int8Type>>::new(take_native(values.type_ids(), indices), None);
318            let offsets = <PrimitiveArray<Int32Type>>::new(take_native(values.offsets().unwrap(), indices), None);
319
320            let children = fields.iter()
321                .map(|(field_type_id, _)| {
322                    let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);
323
324                    let indices = crate::filter::filter(&offsets, &mask)?;
325
326                    let values = values.child(field_type_id);
327
328                    take_impl(values, indices.as_primitive::<Int32Type>())
329                })
330                .collect::<Result<_, _>>()?;
331
332            let mut child_offsets = [0; 128];
333
334            let offsets = type_ids.values()
335                .iter()
336                .map(|&i| {
337                    let offset = child_offsets[i as usize];
338
339                    child_offsets[i as usize] += 1;
340
341                    offset
342                })
343                .collect();
344
345            let (_, type_ids, _) = type_ids.into_parts();
346
347            let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)?;
348
349            Ok(Arc::new(array))
350        }
351        t => unimplemented!("Take not supported for data type {:?}", t)
352    }
353}
354
/// Settings that control the behavior of the `take` kernels.
///
/// The default value disables bounds checking.
#[derive(Debug, Clone, Default)]
pub struct TakeOptions {
    /// Whether indices are validated against the length of the values before taking.
    ///
    /// When `true`, out-of-bounds indices produce an `ArrowError`.
    /// When `false`, out-of-bounds indices make the kernel panic.
    pub check_bounds: bool,
}
363
364#[inline(always)]
365fn maybe_usize<I: ArrowNativeType>(index: I) -> Result<usize, ArrowError> {
366    index
367        .to_usize()
368        .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))
369}
370
371/// `take` implementation for all primitive arrays
372///
373/// This checks if an `indices` slot is populated, and gets the value from `values`
374///  as the populated index.
375/// If the `indices` slot is null, a null value is returned.
376/// For example, given:
377///     values:  [1, 2, 3, null, 5]
378///     indices: [0, null, 4, 3]
379/// The result is: [1 (slot 0), null (null slot), 5 (slot 4), null (slot 3)]
380fn take_primitive<T, I>(
381    values: &PrimitiveArray<T>,
382    indices: &PrimitiveArray<I>,
383) -> Result<PrimitiveArray<T>, ArrowError>
384where
385    T: ArrowPrimitiveType,
386    I: ArrowPrimitiveType,
387{
388    let values_buf = take_native(values.values(), indices);
389    let nulls = take_nulls(values.nulls(), indices);
390    Ok(PrimitiveArray::new(values_buf, nulls).with_data_type(values.data_type().clone()))
391}
392
393#[inline(never)]
394fn take_nulls<I: ArrowPrimitiveType>(
395    values: Option<&NullBuffer>,
396    indices: &PrimitiveArray<I>,
397) -> Option<NullBuffer> {
398    match values.filter(|n| n.null_count() > 0) {
399        Some(n) => {
400            let buffer = take_bits(n.inner(), indices);
401            Some(NullBuffer::new(buffer)).filter(|n| n.null_count() > 0)
402        }
403        None => indices.nulls().cloned(),
404    }
405}
406
407#[inline(never)]
408fn take_native<T: ArrowNativeType, I: ArrowPrimitiveType>(
409    values: &[T],
410    indices: &PrimitiveArray<I>,
411) -> ScalarBuffer<T> {
412    match indices.nulls().filter(|n| n.null_count() > 0) {
413        Some(n) => indices
414            .values()
415            .iter()
416            .enumerate()
417            .map(|(idx, index)| match values.get(index.as_usize()) {
418                Some(v) => *v,
419                None => match n.is_null(idx) {
420                    true => T::default(),
421                    false => panic!("Out-of-bounds index {index:?}"),
422                },
423            })
424            .collect(),
425        None => indices
426            .values()
427            .iter()
428            .map(|index| values[index.as_usize()])
429            .collect(),
430    }
431}
432
433#[inline(never)]
434fn take_bits<I: ArrowPrimitiveType>(
435    values: &BooleanBuffer,
436    indices: &PrimitiveArray<I>,
437) -> BooleanBuffer {
438    let len = indices.len();
439
440    match indices.nulls().filter(|n| n.null_count() > 0) {
441        Some(nulls) => {
442            let mut output_buffer = MutableBuffer::new_null(len);
443            let output_slice = output_buffer.as_slice_mut();
444            nulls.valid_indices().for_each(|idx| {
445                if values.value(indices.value(idx).as_usize()) {
446                    bit_util::set_bit(output_slice, idx);
447                }
448            });
449            BooleanBuffer::new(output_buffer.into(), 0, len)
450        }
451        None => {
452            BooleanBuffer::collect_bool(len, |idx: usize| {
453                // SAFETY: idx<indices.len()
454                values.value(unsafe { indices.value_unchecked(idx).as_usize() })
455            })
456        }
457    }
458}
459
460/// `take` implementation for boolean arrays
461fn take_boolean<IndexType: ArrowPrimitiveType>(
462    values: &BooleanArray,
463    indices: &PrimitiveArray<IndexType>,
464) -> BooleanArray {
465    let val_buf = take_bits(values.values(), indices);
466    let null_buf = take_nulls(values.nulls(), indices);
467    BooleanArray::new(val_buf, null_buf)
468}
469
470/// `take` implementation for string arrays
471fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
472    array: &GenericByteArray<T>,
473    indices: &PrimitiveArray<IndexType>,
474) -> Result<GenericByteArray<T>, ArrowError> {
475    let mut offsets = Vec::with_capacity(indices.len() + 1);
476    offsets.push(T::Offset::default());
477
478    let input_offsets = array.value_offsets();
479    let mut capacity = 0;
480    let nulls = take_nulls(array.nulls(), indices);
481
482    let (offsets, values) = if array.null_count() == 0 && indices.null_count() == 0 {
483        offsets.extend(indices.values().iter().map(|index| {
484            let index = index.as_usize();
485            capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
486            T::Offset::from_usize(capacity).expect("overflow")
487        }));
488        let mut values = Vec::with_capacity(capacity);
489
490        for index in indices.values() {
491            values.extend_from_slice(array.value(index.as_usize()).as_ref());
492        }
493        (offsets, values)
494    } else if indices.null_count() == 0 {
495        offsets.extend(indices.values().iter().map(|index| {
496            let index = index.as_usize();
497            if array.is_valid(index) {
498                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
499            }
500            T::Offset::from_usize(capacity).expect("overflow")
501        }));
502        let mut values = Vec::with_capacity(capacity);
503
504        for index in indices.values() {
505            let index = index.as_usize();
506            if array.is_valid(index) {
507                values.extend_from_slice(array.value(index).as_ref());
508            }
509        }
510        (offsets, values)
511    } else if array.null_count() == 0 {
512        offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
513            let index = index.as_usize();
514            if indices.is_valid(i) {
515                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
516            }
517            T::Offset::from_usize(capacity).expect("overflow")
518        }));
519        let mut values = Vec::with_capacity(capacity);
520
521        for (i, index) in indices.values().iter().enumerate() {
522            if indices.is_valid(i) {
523                values.extend_from_slice(array.value(index.as_usize()).as_ref());
524            }
525        }
526        (offsets, values)
527    } else {
528        let nulls = nulls.as_ref().unwrap();
529        offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
530            let index = index.as_usize();
531            if nulls.is_valid(i) {
532                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
533            }
534            T::Offset::from_usize(capacity).expect("overflow")
535        }));
536        let mut values = Vec::with_capacity(capacity);
537
538        for (i, index) in indices.values().iter().enumerate() {
539            // check index is valid before using index. The value in
540            // NULL index slots may not be within bounds of array
541            let index = index.as_usize();
542            if nulls.is_valid(i) {
543                values.extend_from_slice(array.value(index).as_ref());
544            }
545        }
546        (offsets, values)
547    };
548
549    T::Offset::from_usize(values.len()).ok_or(ArrowError::ComputeError(format!(
550        "Offset overflow for {}BinaryArray: {}",
551        T::Offset::PREFIX,
552        values.len()
553    )))?;
554
555    let array = unsafe {
556        let offsets = OffsetBuffer::new_unchecked(offsets.into());
557        GenericByteArray::<T>::new_unchecked(offsets, values.into(), nulls)
558    };
559
560    Ok(array)
561}
562
563/// `take` implementation for byte view arrays
564fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>(
565    array: &GenericByteViewArray<T>,
566    indices: &PrimitiveArray<IndexType>,
567) -> Result<GenericByteViewArray<T>, ArrowError> {
568    let new_views = take_native(array.views(), indices);
569    let new_nulls = take_nulls(array.nulls(), indices);
570    // Safety:  array.views was valid, and take_native copies only valid values, and verifies bounds
571    Ok(unsafe {
572        GenericByteViewArray::new_unchecked(new_views, array.data_buffers().to_vec(), new_nulls)
573    })
574}
575
576/// `take` implementation for list arrays
577///
578/// Calculates the index and indexed offset for the inner array,
579/// applying `take` on the inner array, then reconstructing a list array
580/// with the indexed offsets
581fn take_list<IndexType, OffsetType>(
582    values: &GenericListArray<OffsetType::Native>,
583    indices: &PrimitiveArray<IndexType>,
584) -> Result<GenericListArray<OffsetType::Native>, ArrowError>
585where
586    IndexType: ArrowPrimitiveType,
587    OffsetType: ArrowPrimitiveType,
588    OffsetType::Native: OffsetSizeTrait,
589    PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
590{
591    // TODO: Some optimizations can be done here such as if it is
592    // taking the whole list or a contiguous sublist
593    let (list_indices, offsets, null_buf) =
594        take_value_indices_from_list::<IndexType, OffsetType>(values, indices)?;
595
596    let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices)?;
597    let value_offsets = Buffer::from_vec(offsets);
598    // create a new list with taken data and computed null information
599    let list_data = ArrayDataBuilder::new(values.data_type().clone())
600        .len(indices.len())
601        .null_bit_buffer(Some(null_buf.into()))
602        .offset(0)
603        .add_child_data(taken.into_data())
604        .add_buffer(value_offsets);
605
606    let list_data = unsafe { list_data.build_unchecked() };
607
608    Ok(GenericListArray::<OffsetType::Native>::from(list_data))
609}
610
611/// `take` implementation for `FixedSizeListArray`
612///
613/// Calculates the index and indexed offset for the inner array,
614/// applying `take` on the inner array, then reconstructing a list array
615/// with the indexed offsets
616fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
617    values: &FixedSizeListArray,
618    indices: &PrimitiveArray<IndexType>,
619    length: <UInt32Type as ArrowPrimitiveType>::Native,
620) -> Result<FixedSizeListArray, ArrowError> {
621    let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?;
622    let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)?;
623
624    // determine null count and null buffer, which are a function of `values` and `indices`
625    let num_bytes = bit_util::ceil(indices.len(), 8);
626    let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
627    let null_slice = null_buf.as_slice_mut();
628
629    for i in 0..indices.len() {
630        let index = indices
631            .value(i)
632            .to_usize()
633            .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
634        if !indices.is_valid(i) || values.is_null(index) {
635            bit_util::unset_bit(null_slice, i);
636        }
637    }
638
639    let list_data = ArrayDataBuilder::new(values.data_type().clone())
640        .len(indices.len())
641        .null_bit_buffer(Some(null_buf.into()))
642        .offset(0)
643        .add_child_data(taken.into_data());
644
645    let list_data = unsafe { list_data.build_unchecked() };
646
647    Ok(FixedSizeListArray::from(list_data))
648}
649
650fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>(
651    values: &FixedSizeBinaryArray,
652    indices: &PrimitiveArray<IndexType>,
653    size: i32,
654) -> Result<FixedSizeBinaryArray, ArrowError> {
655    let nulls = values.nulls();
656    let array_iter = indices
657        .values()
658        .iter()
659        .map(|idx| {
660            let idx = maybe_usize::<IndexType::Native>(*idx)?;
661            if nulls.map(|n| n.is_valid(idx)).unwrap_or(true) {
662                Ok(Some(values.value(idx)))
663            } else {
664                Ok(None)
665            }
666        })
667        .collect::<Result<Vec<_>, ArrowError>>()?
668        .into_iter();
669
670    FixedSizeBinaryArray::try_from_sparse_iter_with_size(array_iter, size)
671}
672
673/// `take` implementation for dictionary arrays
674///
675/// applies `take` to the keys of the dictionary array and returns a new dictionary array
676/// with the same dictionary values and reordered keys
677fn take_dict<T: ArrowDictionaryKeyType, I: ArrowPrimitiveType>(
678    values: &DictionaryArray<T>,
679    indices: &PrimitiveArray<I>,
680) -> Result<DictionaryArray<T>, ArrowError> {
681    let new_keys = take_primitive(values.keys(), indices)?;
682    Ok(unsafe { DictionaryArray::new_unchecked(new_keys, values.values().clone()) })
683}
684
685/// `take` implementation for run arrays
686///
687/// Finds physical indices for the given logical indices and builds output run array
688/// by taking values in the input run_array.values at the physical indices.
689/// The output run array will be run encoded on the physical indices and not on output values.
690/// For e.g. an input `RunArray{ run_ends = [2,4,6,8], values=[1,2,1,2] }` and `logical_indices=[2,3,6,7]`
691/// would be converted to `physical_indices=[1,1,3,3]` which will be used to build
692/// output `RunArray{ run_ends=[2,4], values=[2,2] }`.
693fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
694    run_array: &RunArray<T>,
695    logical_indices: &PrimitiveArray<I>,
696) -> Result<RunArray<T>, ArrowError> {
697    // get physical indices for the input logical indices
698    let physical_indices = run_array.get_physical_indices(logical_indices.values())?;
699
700    // Run encode the physical indices into new_run_ends_builder
701    // Keep track of the physical indices to take in take_value_indices
702    // `unwrap` is used in this function because the unwrapped values are bounded by the corresponding `::Native`.
703    let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
704    let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
705    let mut new_physical_len = 1;
706    for ix in 1..physical_indices.len() {
707        if physical_indices[ix] != physical_indices[ix - 1] {
708            take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
709            new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
710            new_physical_len += 1;
711        }
712    }
713    take_value_indices
714        .append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap());
715    new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
716    let new_run_ends = unsafe {
717        // Safety:
718        // The function builds a valid run_ends array and hence need not be validated.
719        ArrayDataBuilder::new(T::DATA_TYPE)
720            .len(new_physical_len)
721            .null_count(0)
722            .add_buffer(new_run_ends_builder.finish())
723            .build_unchecked()
724    };
725
726    let take_value_indices: PrimitiveArray<I> = unsafe {
727        // Safety:
728        // The function builds a valid take_value_indices array and hence need not be validated.
729        ArrayDataBuilder::new(I::DATA_TYPE)
730            .len(new_physical_len)
731            .null_count(0)
732            .add_buffer(take_value_indices.finish())
733            .build_unchecked()
734            .into()
735    };
736
737    let new_values = take(run_array.values(), &take_value_indices, None)?;
738
739    let builder = ArrayDataBuilder::new(run_array.data_type().clone())
740        .len(physical_indices.len())
741        .add_child_data(new_run_ends)
742        .add_child_data(new_values.into_data());
743    let array_data = unsafe {
744        // Safety:
745        //  This function builds a valid run array and hence can skip validation.
746        builder.build_unchecked()
747    };
748    Ok(array_data.into())
749}
750
751/// Takes/filters a list array's inner data using the offsets of the list array.
752///
753/// Where a list array has indices `[0,2,5,10]`, taking indices of `[2,0]` returns
754/// an array of the indices `[5..10, 0..2]` and offsets `[0,5,7]` (5 elements and 2
755/// elements)
756#[allow(clippy::type_complexity)]
757fn take_value_indices_from_list<IndexType, OffsetType>(
758    list: &GenericListArray<OffsetType::Native>,
759    indices: &PrimitiveArray<IndexType>,
760) -> Result<
761    (
762        PrimitiveArray<OffsetType>,
763        Vec<OffsetType::Native>,
764        MutableBuffer,
765    ),
766    ArrowError,
767>
768where
769    IndexType: ArrowPrimitiveType,
770    OffsetType: ArrowPrimitiveType,
771    OffsetType::Native: OffsetSizeTrait + std::ops::Add + Zero + One,
772    PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
773{
774    // TODO: benchmark this function, there might be a faster unsafe alternative
775    let offsets: &[OffsetType::Native] = list.value_offsets();
776
777    let mut new_offsets = Vec::with_capacity(indices.len());
778    let mut values = Vec::new();
779    let mut current_offset = OffsetType::Native::zero();
780    // add first offset
781    new_offsets.push(OffsetType::Native::zero());
782
783    // Initialize null buffer
784    let num_bytes = bit_util::ceil(indices.len(), 8);
785    let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
786    let null_slice = null_buf.as_slice_mut();
787
788    // compute the value indices, and set offsets accordingly
789    for i in 0..indices.len() {
790        if indices.is_valid(i) {
791            let ix = indices
792                .value(i)
793                .to_usize()
794                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
795            let start = offsets[ix];
796            let end = offsets[ix + 1];
797            current_offset += end - start;
798            new_offsets.push(current_offset);
799
800            let mut curr = start;
801
802            // if start == end, this slot is empty
803            while curr < end {
804                values.push(curr);
805                curr += One::one();
806            }
807            if !list.is_valid(ix) {
808                bit_util::unset_bit(null_slice, i);
809            }
810        } else {
811            bit_util::unset_bit(null_slice, i);
812            new_offsets.push(current_offset);
813        }
814    }
815
816    Ok((
817        PrimitiveArray::<OffsetType>::from(values),
818        new_offsets,
819        null_buf,
820    ))
821}
822
823/// Takes/filters a fixed size list array's inner data using the offsets of the list array.
824fn take_value_indices_from_fixed_size_list<IndexType>(
825    list: &FixedSizeListArray,
826    indices: &PrimitiveArray<IndexType>,
827    length: <UInt32Type as ArrowPrimitiveType>::Native,
828) -> Result<PrimitiveArray<UInt32Type>, ArrowError>
829where
830    IndexType: ArrowPrimitiveType,
831{
832    let mut values = UInt32Builder::with_capacity(length as usize * indices.len());
833
834    for i in 0..indices.len() {
835        if indices.is_valid(i) {
836            let index = indices
837                .value(i)
838                .to_usize()
839                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
840            let start = list.value_offset(index) as <UInt32Type as ArrowPrimitiveType>::Native;
841
842            // Safety: Range always has known length.
843            unsafe {
844                values.append_trusted_len_iter(start..start + length);
845            }
846        } else {
847            values.append_nulls(length as usize);
848        }
849    }
850
851    Ok(values.finish())
852}
853
854/// To avoid generating take implementations for every index type, instead we
855/// only generate for UInt32 and UInt64 and coerce inputs to these types
856trait ToIndices {
857    type T: ArrowPrimitiveType;
858
859    fn to_indices(&self) -> PrimitiveArray<Self::T>;
860}
861
/// Implements [`ToIndices`] for `PrimitiveArray<$t>` by re-typing its values buffer
/// as `$o`; the validity mask is carried over unchanged.
macro_rules! to_indices_reinterpret {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                let raw = self.values().inner().clone();
                // Same buffer, viewed from element 0 for `len` elements of the target type.
                let reinterpreted = ScalarBuffer::new(raw, 0, self.len());
                let validity = self.nulls().cloned();
                PrimitiveArray::new(reinterpreted, validity)
            }
        }
    };
}
874
/// Implements [`ToIndices`] for `PrimitiveArray<$t>` when `$t` is already the desired
/// index type: the conversion is a clone of the array.
macro_rules! to_indices_identity {
    ($t:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $t;

            fn to_indices(&self) -> PrimitiveArray<Self::T> {
                Clone::clone(self)
            }
        }
    };
}
886
/// Implements [`ToIndices`] for `PrimitiveArray<$t>` by converting every value to the
/// native type of `$o` with an `as` cast; the validity mask is carried over unchanged.
///
/// The associated type is `$o`, matching the return type of `to_indices`, so the macro
/// can be instantiated for any output index type rather than only `UInt32Type`.
macro_rules! to_indices_widening {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                // `as _` casts to the element type required by `PrimitiveArray<$o>`.
                let cast = self.values().iter().copied().map(|x| x as _).collect();
                PrimitiveArray::new(cast, self.nulls().cloned())
            }
        }
    };
}
899
900to_indices_widening!(UInt8Type, UInt32Type);
901to_indices_widening!(Int8Type, UInt32Type);
902
903to_indices_widening!(UInt16Type, UInt32Type);
904to_indices_widening!(Int16Type, UInt32Type);
905
906to_indices_identity!(UInt32Type);
907to_indices_reinterpret!(Int32Type, UInt32Type);
908
909to_indices_identity!(UInt64Type);
910to_indices_reinterpret!(Int64Type, UInt64Type);
911
912/// Take rows by index from [`RecordBatch`] and returns a new [`RecordBatch`] from those indexes.
913///
914/// This function will call [`take`] on each array of the [`RecordBatch`] and assemble a new [`RecordBatch`].
915///
916/// # Example
917/// ```
918/// # use std::sync::Arc;
919/// # use arrow_array::{StringArray, Int32Array, UInt32Array, RecordBatch};
920/// # use arrow_schema::{DataType, Field, Schema};
921/// # use arrow_select::take::take_record_batch;
922///
923/// let schema = Arc::new(Schema::new(vec![
924///     Field::new("a", DataType::Int32, true),
925///     Field::new("b", DataType::Utf8, true),
926/// ]));
927/// let batch = RecordBatch::try_new(
928///     schema.clone(),
929///     vec![
930///         Arc::new(Int32Array::from_iter_values(0..20)),
931///         Arc::new(StringArray::from_iter_values(
932///             (0..20).map(|i| format!("str-{}", i)),
933///         )),
934///     ],
935/// )
936/// .unwrap();
937///
938/// let indices = UInt32Array::from(vec![1, 5, 10]);
939/// let taken = take_record_batch(&batch, &indices).unwrap();
940///
941/// let expected = RecordBatch::try_new(
942///     schema,
943///     vec![
944///         Arc::new(Int32Array::from(vec![1, 5, 10])),
945///         Arc::new(StringArray::from(vec!["str-1", "str-5", "str-10"])),
946///     ],
947/// )
948/// .unwrap();
949/// assert_eq!(taken, expected);
950/// ```
951pub fn take_record_batch(
952    record_batch: &RecordBatch,
953    indices: &dyn Array,
954) -> Result<RecordBatch, ArrowError> {
955    let columns = record_batch
956        .columns()
957        .iter()
958        .map(|c| take(c, indices, None))
959        .collect::<Result<Vec<_>, _>>()?;
960    RecordBatch::try_new(record_batch.schema(), columns)
961}
962
963#[cfg(test)]
964mod tests {
965    use super::*;
966    use arrow_array::builder::*;
967    use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
968    use arrow_data::ArrayData;
969    use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
970
971    fn test_take_decimal_arrays(
972        data: Vec<Option<i128>>,
973        index: &UInt32Array,
974        options: Option<TakeOptions>,
975        expected_data: Vec<Option<i128>>,
976        precision: &u8,
977        scale: &i8,
978    ) -> Result<(), ArrowError> {
979        let output = data
980            .into_iter()
981            .collect::<Decimal128Array>()
982            .with_precision_and_scale(*precision, *scale)
983            .unwrap();
984
985        let expected = expected_data
986            .into_iter()
987            .collect::<Decimal128Array>()
988            .with_precision_and_scale(*precision, *scale)
989            .unwrap();
990
991        let expected = Arc::new(expected) as ArrayRef;
992        let output = take(&output, index, options).unwrap();
993        assert_eq!(&output, &expected);
994        Ok(())
995    }
996
997    fn test_take_boolean_arrays(
998        data: Vec<Option<bool>>,
999        index: &UInt32Array,
1000        options: Option<TakeOptions>,
1001        expected_data: Vec<Option<bool>>,
1002    ) {
1003        let output = BooleanArray::from(data);
1004        let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef;
1005        let output = take(&output, index, options).unwrap();
1006        assert_eq!(&output, &expected)
1007    }
1008
1009    fn test_take_primitive_arrays<T>(
1010        data: Vec<Option<T::Native>>,
1011        index: &UInt32Array,
1012        options: Option<TakeOptions>,
1013        expected_data: Vec<Option<T::Native>>,
1014    ) -> Result<(), ArrowError>
1015    where
1016        T: ArrowPrimitiveType,
1017        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1018    {
1019        let output = PrimitiveArray::<T>::from(data);
1020        let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1021        let output = take(&output, index, options)?;
1022        assert_eq!(&output, &expected);
1023        Ok(())
1024    }
1025
1026    fn test_take_primitive_arrays_non_null<T>(
1027        data: Vec<T::Native>,
1028        index: &UInt32Array,
1029        options: Option<TakeOptions>,
1030        expected_data: Vec<Option<T::Native>>,
1031    ) -> Result<(), ArrowError>
1032    where
1033        T: ArrowPrimitiveType,
1034        PrimitiveArray<T>: From<Vec<T::Native>>,
1035        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1036    {
1037        let output = PrimitiveArray::<T>::from(data);
1038        let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1039        let output = take(&output, index, options)?;
1040        assert_eq!(&output, &expected);
1041        Ok(())
1042    }
1043
1044    fn test_take_impl_primitive_arrays<T, I>(
1045        data: Vec<Option<T::Native>>,
1046        index: &PrimitiveArray<I>,
1047        options: Option<TakeOptions>,
1048        expected_data: Vec<Option<T::Native>>,
1049    ) where
1050        T: ArrowPrimitiveType,
1051        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1052        I: ArrowPrimitiveType,
1053    {
1054        let output = PrimitiveArray::<T>::from(data);
1055        let expected = PrimitiveArray::<T>::from(expected_data);
1056        let output = take(&output, index, options).unwrap();
1057        let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
1058        assert_eq!(output, &expected)
1059    }
1060
1061    // create a simple struct for testing purposes
1062    fn create_test_struct(values: Vec<Option<(Option<bool>, Option<i32>)>>) -> StructArray {
1063        let mut struct_builder = StructBuilder::new(
1064            Fields::from(vec![
1065                Field::new("a", DataType::Boolean, true),
1066                Field::new("b", DataType::Int32, true),
1067            ]),
1068            vec![
1069                Box::new(BooleanBuilder::with_capacity(values.len())),
1070                Box::new(Int32Builder::with_capacity(values.len())),
1071            ],
1072        );
1073
1074        for value in values {
1075            struct_builder
1076                .field_builder::<BooleanBuilder>(0)
1077                .unwrap()
1078                .append_option(value.and_then(|v| v.0));
1079            struct_builder
1080                .field_builder::<Int32Builder>(1)
1081                .unwrap()
1082                .append_option(value.and_then(|v| v.1));
1083            struct_builder.append(value.is_some());
1084        }
1085        struct_builder.finish()
1086    }
1087
1088    #[test]
1089    fn test_take_decimal128_non_null_indices() {
1090        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1091        let precision: u8 = 10;
1092        let scale: i8 = 5;
1093        test_take_decimal_arrays(
1094            vec![None, Some(3), Some(5), Some(2), Some(3), None],
1095            &index,
1096            None,
1097            vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1098            &precision,
1099            &scale,
1100        )
1101        .unwrap();
1102    }
1103
1104    #[test]
1105    fn test_take_decimal128() {
1106        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1107        let precision: u8 = 10;
1108        let scale: i8 = 5;
1109        test_take_decimal_arrays(
1110            vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1111            &index,
1112            None,
1113            vec![Some(3), None, Some(1), Some(3), Some(2)],
1114            &precision,
1115            &scale,
1116        )
1117        .unwrap();
1118    }
1119
1120    #[test]
1121    fn test_take_primitive_non_null_indices() {
1122        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1123        test_take_primitive_arrays::<Int8Type>(
1124            vec![None, Some(3), Some(5), Some(2), Some(3), None],
1125            &index,
1126            None,
1127            vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1128        )
1129        .unwrap();
1130    }
1131
1132    #[test]
1133    fn test_take_primitive_non_null_values() {
1134        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1135        test_take_primitive_arrays::<Int8Type>(
1136            vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1137            &index,
1138            None,
1139            vec![Some(3), None, Some(1), Some(3), Some(2)],
1140        )
1141        .unwrap();
1142    }
1143
1144    #[test]
1145    fn test_take_primitive_non_null() {
1146        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1147        test_take_primitive_arrays::<Int8Type>(
1148            vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)],
1149            &index,
1150            None,
1151            vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)],
1152        )
1153        .unwrap();
1154    }
1155
1156    #[test]
1157    fn test_take_primitive_nullable_indices_non_null_values_with_offset() {
1158        let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1159        let index = index.slice(2, 4);
1160        let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1161
1162        assert_eq!(
1163            index,
1164            &UInt32Array::from(vec![Some(2), Some(3), None, None])
1165        );
1166
1167        test_take_primitive_arrays_non_null::<Int64Type>(
1168            vec![0, 10, 20, 30, 40, 50],
1169            index,
1170            None,
1171            vec![Some(20), Some(30), None, None],
1172        )
1173        .unwrap();
1174    }
1175
1176    #[test]
1177    fn test_take_primitive_nullable_indices_nullable_values_with_offset() {
1178        let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1179        let index = index.slice(2, 4);
1180        let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1181
1182        assert_eq!(
1183            index,
1184            &UInt32Array::from(vec![Some(2), Some(3), None, None])
1185        );
1186
1187        test_take_primitive_arrays::<Int64Type>(
1188            vec![None, None, Some(20), Some(30), Some(40), Some(50)],
1189            index,
1190            None,
1191            vec![Some(20), Some(30), None, None],
1192        )
1193        .unwrap();
1194    }
1195
1196    #[test]
1197    fn test_take_primitive() {
1198        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1199
1200        // int8
1201        test_take_primitive_arrays::<Int8Type>(
1202            vec![Some(0), None, Some(2), Some(3), None],
1203            &index,
1204            None,
1205            vec![Some(3), None, None, Some(3), Some(2)],
1206        )
1207        .unwrap();
1208
1209        // int16
1210        test_take_primitive_arrays::<Int16Type>(
1211            vec![Some(0), None, Some(2), Some(3), None],
1212            &index,
1213            None,
1214            vec![Some(3), None, None, Some(3), Some(2)],
1215        )
1216        .unwrap();
1217
1218        // int32
1219        test_take_primitive_arrays::<Int32Type>(
1220            vec![Some(0), None, Some(2), Some(3), None],
1221            &index,
1222            None,
1223            vec![Some(3), None, None, Some(3), Some(2)],
1224        )
1225        .unwrap();
1226
1227        // int64
1228        test_take_primitive_arrays::<Int64Type>(
1229            vec![Some(0), None, Some(2), Some(3), None],
1230            &index,
1231            None,
1232            vec![Some(3), None, None, Some(3), Some(2)],
1233        )
1234        .unwrap();
1235
1236        // uint8
1237        test_take_primitive_arrays::<UInt8Type>(
1238            vec![Some(0), None, Some(2), Some(3), None],
1239            &index,
1240            None,
1241            vec![Some(3), None, None, Some(3), Some(2)],
1242        )
1243        .unwrap();
1244
1245        // uint16
1246        test_take_primitive_arrays::<UInt16Type>(
1247            vec![Some(0), None, Some(2), Some(3), None],
1248            &index,
1249            None,
1250            vec![Some(3), None, None, Some(3), Some(2)],
1251        )
1252        .unwrap();
1253
1254        // uint32
1255        test_take_primitive_arrays::<UInt32Type>(
1256            vec![Some(0), None, Some(2), Some(3), None],
1257            &index,
1258            None,
1259            vec![Some(3), None, None, Some(3), Some(2)],
1260        )
1261        .unwrap();
1262
1263        // int64
1264        test_take_primitive_arrays::<Int64Type>(
1265            vec![Some(0), None, Some(2), Some(-15), None],
1266            &index,
1267            None,
1268            vec![Some(-15), None, None, Some(-15), Some(2)],
1269        )
1270        .unwrap();
1271
1272        // interval_year_month
1273        test_take_primitive_arrays::<IntervalYearMonthType>(
1274            vec![Some(0), None, Some(2), Some(-15), None],
1275            &index,
1276            None,
1277            vec![Some(-15), None, None, Some(-15), Some(2)],
1278        )
1279        .unwrap();
1280
1281        // interval_day_time
1282        let v1 = IntervalDayTime::new(0, 0);
1283        let v2 = IntervalDayTime::new(2, 0);
1284        let v3 = IntervalDayTime::new(-15, 0);
1285        test_take_primitive_arrays::<IntervalDayTimeType>(
1286            vec![Some(v1), None, Some(v2), Some(v3), None],
1287            &index,
1288            None,
1289            vec![Some(v3), None, None, Some(v3), Some(v2)],
1290        )
1291        .unwrap();
1292
1293        // interval_month_day_nano
1294        let v1 = IntervalMonthDayNano::new(0, 0, 0);
1295        let v2 = IntervalMonthDayNano::new(2, 0, 0);
1296        let v3 = IntervalMonthDayNano::new(-15, 0, 0);
1297        test_take_primitive_arrays::<IntervalMonthDayNanoType>(
1298            vec![Some(v1), None, Some(v2), Some(v3), None],
1299            &index,
1300            None,
1301            vec![Some(v3), None, None, Some(v3), Some(v2)],
1302        )
1303        .unwrap();
1304
1305        // duration_second
1306        test_take_primitive_arrays::<DurationSecondType>(
1307            vec![Some(0), None, Some(2), Some(-15), None],
1308            &index,
1309            None,
1310            vec![Some(-15), None, None, Some(-15), Some(2)],
1311        )
1312        .unwrap();
1313
1314        // duration_millisecond
1315        test_take_primitive_arrays::<DurationMillisecondType>(
1316            vec![Some(0), None, Some(2), Some(-15), None],
1317            &index,
1318            None,
1319            vec![Some(-15), None, None, Some(-15), Some(2)],
1320        )
1321        .unwrap();
1322
1323        // duration_microsecond
1324        test_take_primitive_arrays::<DurationMicrosecondType>(
1325            vec![Some(0), None, Some(2), Some(-15), None],
1326            &index,
1327            None,
1328            vec![Some(-15), None, None, Some(-15), Some(2)],
1329        )
1330        .unwrap();
1331
1332        // duration_nanosecond
1333        test_take_primitive_arrays::<DurationNanosecondType>(
1334            vec![Some(0), None, Some(2), Some(-15), None],
1335            &index,
1336            None,
1337            vec![Some(-15), None, None, Some(-15), Some(2)],
1338        )
1339        .unwrap();
1340
1341        // float32
1342        test_take_primitive_arrays::<Float32Type>(
1343            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1344            &index,
1345            None,
1346            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1347        )
1348        .unwrap();
1349
1350        // float64
1351        test_take_primitive_arrays::<Float64Type>(
1352            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1353            &index,
1354            None,
1355            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1356        )
1357        .unwrap();
1358    }
1359
1360    #[test]
1361    fn test_take_preserve_timezone() {
1362        let index = Int64Array::from(vec![Some(0), None]);
1363
1364        let input = TimestampNanosecondArray::from(vec![
1365            1_639_715_368_000_000_000,
1366            1_639_715_368_000_000_000,
1367        ])
1368        .with_timezone("UTC".to_string());
1369        let result = take(&input, &index, None).unwrap();
1370        match result.data_type() {
1371            DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
1372                assert_eq!(tz.clone(), Some("UTC".into()))
1373            }
1374            _ => panic!(),
1375        }
1376    }
1377
1378    #[test]
1379    fn test_take_impl_primitive_with_int64_indices() {
1380        let index = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1381
1382        // int16
1383        test_take_impl_primitive_arrays::<Int16Type, Int64Type>(
1384            vec![Some(0), None, Some(2), Some(3), None],
1385            &index,
1386            None,
1387            vec![Some(3), None, None, Some(3), Some(2)],
1388        );
1389
1390        // int64
1391        test_take_impl_primitive_arrays::<Int64Type, Int64Type>(
1392            vec![Some(0), None, Some(2), Some(-15), None],
1393            &index,
1394            None,
1395            vec![Some(-15), None, None, Some(-15), Some(2)],
1396        );
1397
1398        // uint64
1399        test_take_impl_primitive_arrays::<UInt64Type, Int64Type>(
1400            vec![Some(0), None, Some(2), Some(3), None],
1401            &index,
1402            None,
1403            vec![Some(3), None, None, Some(3), Some(2)],
1404        );
1405
1406        // duration_millisecond
1407        test_take_impl_primitive_arrays::<DurationMillisecondType, Int64Type>(
1408            vec![Some(0), None, Some(2), Some(-15), None],
1409            &index,
1410            None,
1411            vec![Some(-15), None, None, Some(-15), Some(2)],
1412        );
1413
1414        // float32
1415        test_take_impl_primitive_arrays::<Float32Type, Int64Type>(
1416            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1417            &index,
1418            None,
1419            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1420        );
1421    }
1422
1423    #[test]
1424    fn test_take_impl_primitive_with_uint8_indices() {
1425        let index = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1426
1427        // int16
1428        test_take_impl_primitive_arrays::<Int16Type, UInt8Type>(
1429            vec![Some(0), None, Some(2), Some(3), None],
1430            &index,
1431            None,
1432            vec![Some(3), None, None, Some(3), Some(2)],
1433        );
1434
1435        // duration_millisecond
1436        test_take_impl_primitive_arrays::<DurationMillisecondType, UInt8Type>(
1437            vec![Some(0), None, Some(2), Some(-15), None],
1438            &index,
1439            None,
1440            vec![Some(-15), None, None, Some(-15), Some(2)],
1441        );
1442
1443        // float32
1444        test_take_impl_primitive_arrays::<Float32Type, UInt8Type>(
1445            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1446            &index,
1447            None,
1448            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1449        );
1450    }
1451
1452    #[test]
1453    fn test_take_bool() {
1454        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1455        // boolean
1456        test_take_boolean_arrays(
1457            vec![Some(false), None, Some(true), Some(false), None],
1458            &index,
1459            None,
1460            vec![Some(false), None, None, Some(false), Some(true)],
1461        );
1462    }
1463
1464    #[test]
1465    fn test_take_bool_nullable_index() {
1466        // indices where the masked invalid elements would be out of bounds
1467        let index_data = ArrayData::try_new(
1468            DataType::UInt32,
1469            6,
1470            Some(Buffer::from_iter(vec![
1471                false, true, false, true, false, true,
1472            ])),
1473            0,
1474            vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1475            vec![],
1476        )
1477        .unwrap();
1478        let index = UInt32Array::from(index_data);
1479        test_take_boolean_arrays(
1480            vec![Some(true), None, Some(false)],
1481            &index,
1482            None,
1483            vec![None, Some(true), None, None, None, Some(false)],
1484        );
1485    }
1486
1487    #[test]
1488    fn test_take_bool_nullable_index_nonnull_values() {
1489        // indices where the masked invalid elements would be out of bounds
1490        let index_data = ArrayData::try_new(
1491            DataType::UInt32,
1492            6,
1493            Some(Buffer::from_iter(vec![
1494                false, true, false, true, false, true,
1495            ])),
1496            0,
1497            vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1498            vec![],
1499        )
1500        .unwrap();
1501        let index = UInt32Array::from(index_data);
1502        test_take_boolean_arrays(
1503            vec![Some(true), Some(true), Some(false)],
1504            &index,
1505            None,
1506            vec![None, Some(true), None, Some(true), None, Some(false)],
1507        );
1508    }
1509
1510    #[test]
1511    fn test_take_bool_with_offset() {
1512        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]);
1513        let index = index.slice(2, 4);
1514        let index = index
1515            .as_any()
1516            .downcast_ref::<PrimitiveArray<UInt32Type>>()
1517            .unwrap();
1518
1519        // boolean
1520        test_take_boolean_arrays(
1521            vec![Some(false), None, Some(true), Some(false), None],
1522            index,
1523            None,
1524            vec![None, Some(false), Some(true), None],
1525        );
1526    }
1527
1528    fn _test_take_string<'a, K>()
1529    where
1530        K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static,
1531    {
1532        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]);
1533
1534        let array = K::from(vec![
1535            Some("one"),
1536            None,
1537            Some("three"),
1538            Some("four"),
1539            Some("five"),
1540        ]);
1541        let actual = take(&array, &index, None).unwrap();
1542        assert_eq!(actual.len(), index.len());
1543
1544        let actual = actual.as_any().downcast_ref::<K>().unwrap();
1545
1546        let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]);
1547
1548        assert_eq!(actual, &expected);
1549    }
1550
1551    #[test]
1552    fn test_take_string() {
1553        _test_take_string::<StringArray>()
1554    }
1555
1556    #[test]
1557    fn test_take_large_string() {
1558        _test_take_string::<LargeStringArray>()
1559    }
1560
1561    #[test]
1562    fn test_take_slice_string() {
1563        let strings = StringArray::from(vec![Some("hello"), None, Some("world"), None, Some("hi")]);
1564        let indices = Int32Array::from(vec![Some(0), Some(1), None, Some(0), Some(2)]);
1565        let indices_slice = indices.slice(1, 4);
1566        let expected = StringArray::from(vec![None, None, Some("hello"), Some("world")]);
1567        let result = take(&strings, &indices_slice, None).unwrap();
1568        assert_eq!(result.as_ref(), &expected);
1569    }
1570
1571    fn _test_byte_view<T>()
1572    where
1573        T: ByteViewType,
1574        str: AsRef<T::Native>,
1575        T::Native: PartialEq,
1576    {
1577        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]);
1578        let array = {
1579            // ["hello", "world", null, "large payload over 12 bytes", "lulu"]
1580            let mut builder = GenericByteViewBuilder::<T>::new();
1581            builder.append_value("hello");
1582            builder.append_value("world");
1583            builder.append_null();
1584            builder.append_value("large payload over 12 bytes");
1585            builder.append_value("lulu");
1586            builder.finish()
1587        };
1588
1589        let actual = take(&array, &index, None).unwrap();
1590
1591        assert_eq!(actual.len(), index.len());
1592
1593        let expected = {
1594            // ["large payload over 12 bytes", null, "world", "large payload over 12 bytes", "lulu", null]
1595            let mut builder = GenericByteViewBuilder::<T>::new();
1596            builder.append_value("large payload over 12 bytes");
1597            builder.append_null();
1598            builder.append_value("world");
1599            builder.append_value("large payload over 12 bytes");
1600            builder.append_value("lulu");
1601            builder.append_null();
1602            builder.finish()
1603        };
1604
1605        assert_eq!(actual.as_ref(), &expected);
1606    }
1607
1608    #[test]
1609    fn test_take_string_view() {
1610        _test_byte_view::<StringViewType>()
1611    }
1612
1613    #[test]
1614    fn test_take_binary_view() {
1615        _test_byte_view::<BinaryViewType>()
1616    }
1617
    // Shared test body: `take` on a variable-size list array whose lists and items are
    // all valid, using indices that contain one null.
    //
    // Arguments:
    // * `$offset_type` - native integer type of the list offsets (`i32` or `i64`)
    // * `$list_data_type` - `DataType` variant to build (`List` or `LargeList`)
    // * `$list_array_type` - concrete array type (`ListArray` or `LargeListArray`)
    macro_rules! test_take_list {
        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
            // Construct a value array, [[0,0,0], [-1,-2,-1], [], [2,3]]
            let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
            // Construct offsets; the repeated 6 encodes the empty list at position 2
            let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
            let value_offsets = Buffer::from_slice_ref(&value_offsets);
            // Construct a list array from the above two
            let list_data_type =
                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, false)));
            let list_data = ArrayData::builder(list_data_type.clone())
                .len(4)
                .add_buffer(value_offsets)
                .add_child_data(value_data)
                .build()
                .unwrap();
            let list_array = $list_array_type::from(list_data);

            // index returns: [[2,3], null, [-1,-2,-1], [], [0,0,0]]
            let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]);

            let a = take(&list_array, &index, None).unwrap();
            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();

            // construct a value array with expected results:
            // [[2,3], null, [-1,-2,-1], [], [0,0,0]]
            let expected_data = Int32Array::from(vec![
                Some(2),
                Some(3),
                Some(-1),
                Some(-2),
                Some(-1),
                Some(0),
                Some(0),
                Some(0),
            ])
            .into_data();
            // construct offsets; the null slot (1) and the empty list (3) have zero width
            let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8];
            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
            // construct list array from the two
            let expected_list_data = ArrayData::builder(list_data_type)
                .len(5)
                // null buffer remains the same as only the indices have nulls
                .nulls(index.nulls().cloned())
                .add_buffer(expected_offsets)
                .add_child_data(expected_data)
                .build()
                .unwrap();
            let expected_list_array = $list_array_type::from(expected_list_data);

            assert_eq!(a, &expected_list_array);
        }};
    }
1672
    // Shared test body: `take` on a variable-size list array whose lists are all valid but
    // whose items include nulls, using indices that contain one null.
    //
    // Arguments:
    // * `$offset_type` - native integer type of the list offsets (`i32` or `i64`)
    // * `$list_data_type` - `DataType` variant to build (`List` or `LargeList`)
    // * `$list_array_type` - concrete array type (`ListArray` or `LargeListArray`)
    macro_rules! test_take_list_with_value_nulls {
        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
            // Construct a value array, [[0,null,0], [-1,-2,3], [null], [5,null]]
            let value_data = Int32Array::from(vec![
                Some(0),
                None,
                Some(0),
                Some(-1),
                Some(-2),
                Some(3),
                None,
                Some(5),
                None,
            ])
            .into_data();
            // Construct offsets
            let value_offsets: [$offset_type; 5] = [0, 3, 6, 7, 9];
            let value_offsets = Buffer::from_slice_ref(&value_offsets);
            // Construct a list array from the above two
            let list_data_type =
                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
            let list_data = ArrayData::builder(list_data_type.clone())
                .len(4)
                .add_buffer(value_offsets)
                // every validity bit is set: none of the lists themselves is null
                .null_bit_buffer(Some(Buffer::from([0b11111111])))
                .add_child_data(value_data)
                .build()
                .unwrap();
            let list_array = $list_array_type::from(list_data);

            // index returns: [[null], null, [-1,-2,3], [5,null], [0,null,0]]
            let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);

            let a = take(&list_array, &index, None).unwrap();
            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();

            // construct a value array with expected results:
            // [[null], null, [-1,-2,3], [5,null], [0,null,0]]
            let expected_data = Int32Array::from(vec![
                None,
                Some(-1),
                Some(-2),
                Some(3),
                Some(5),
                None,
                Some(0),
                None,
                Some(0),
            ])
            .into_data();
            // construct offsets; the null slot at position 1 has zero width
            let expected_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9];
            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
            // construct list array from the two
            let expected_list_data = ArrayData::builder(list_data_type)
                .len(5)
                // null buffer remains the same as only the indices have nulls
                .nulls(index.nulls().cloned())
                .add_buffer(expected_offsets)
                .add_child_data(expected_data)
                .build()
                .unwrap();
            let expected_list_array = $list_array_type::from(expected_list_data);

            assert_eq!(a, &expected_list_array);
        }};
    }
1740
    // Shared test body: `take` on a variable-size list array that contains a null list and
    // null items, using indices that also contain one null.
    //
    // Arguments:
    // * `$offset_type` - native integer type of the list offsets (`i32` or `i64`)
    // * `$list_data_type` - `DataType` variant to build (`List` or `LargeList`)
    // * `$list_array_type` - concrete array type (`ListArray` or `LargeListArray`)
    macro_rules! test_take_list_with_nulls {
        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
            // Construct a value array, [[0,null,0], [-1,-2,3], null, [5,null]]
            let value_data = Int32Array::from(vec![
                Some(0),
                None,
                Some(0),
                Some(-1),
                Some(-2),
                Some(3),
                Some(5),
                None,
            ])
            .into_data();
            // Construct offsets
            let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
            let value_offsets = Buffer::from_slice_ref(&value_offsets);
            // Construct a list array from the above two
            let list_data_type =
                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
            let list_data = ArrayData::builder(list_data_type.clone())
                .len(4)
                .add_buffer(value_offsets)
                // validity bit 2 is cleared: the list at position 2 is null
                .null_bit_buffer(Some(Buffer::from([0b11111011])))
                .add_child_data(value_data)
                .build()
                .unwrap();
            let list_array = $list_array_type::from(list_data);

            // index returns: [null, null, [-1,-2,3], [5,null], [0,null,0]]
            let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);

            let a = take(&list_array, &index, None).unwrap();
            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();

            // construct a value array with expected results:
            // [null, null, [-1,-2,3], [5,null], [0,null,0]]
            let expected_data = Int32Array::from(vec![
                Some(-1),
                Some(-2),
                Some(3),
                Some(5),
                None,
                Some(0),
                None,
                Some(0),
            ])
            .into_data();
            // construct offsets; the two leading null slots have zero width
            let expected_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8];
            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
            // construct list array from the two
            // Only output slots 2, 3 and 4 are valid: slot 0 selects the null list and
            // slot 1 comes from a null index.
            let mut null_bits: [u8; 1] = [0; 1];
            bit_util::set_bit(&mut null_bits, 2);
            bit_util::set_bit(&mut null_bits, 3);
            bit_util::set_bit(&mut null_bits, 4);
            let expected_list_data = ArrayData::builder(list_data_type)
                .len(5)
                // null buffer must be recalculated as both values and indices have nulls
                .null_bit_buffer(Some(Buffer::from(null_bits)))
                .add_buffer(expected_offsets)
                .add_child_data(expected_data)
                .build()
                .unwrap();
            let expected_list_array = $list_array_type::from(expected_list_data);

            assert_eq!(a, &expected_list_array);
        }};
    }
1810
1811    fn do_take_fixed_size_list_test<T>(
1812        length: <Int32Type as ArrowPrimitiveType>::Native,
1813        input_data: Vec<Option<Vec<Option<T::Native>>>>,
1814        indices: Vec<<UInt32Type as ArrowPrimitiveType>::Native>,
1815        expected_data: Vec<Option<Vec<Option<T::Native>>>>,
1816    ) where
1817        T: ArrowPrimitiveType,
1818        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1819    {
1820        let indices = UInt32Array::from(indices);
1821
1822        let input_array = FixedSizeListArray::from_iter_primitive::<T, _, _>(input_data, length);
1823
1824        let output = take_fixed_size_list(&input_array, &indices, length as u32).unwrap();
1825
1826        let expected = FixedSizeListArray::from_iter_primitive::<T, _, _>(expected_data, length);
1827
1828        assert_eq!(&output, &expected)
1829    }
1830
    /// Runs the `test_take_list!` body for `ListArray` (`i32` offsets).
    #[test]
    fn test_take_list() {
        test_take_list!(i32, List, ListArray);
    }
1835
    /// Runs the `test_take_list!` body for `LargeListArray` (`i64` offsets).
    #[test]
    fn test_take_large_list() {
        test_take_list!(i64, LargeList, LargeListArray);
    }
1840
    /// Runs the `test_take_list_with_value_nulls!` body for `ListArray` (`i32` offsets).
    #[test]
    fn test_take_list_with_value_nulls() {
        test_take_list_with_value_nulls!(i32, List, ListArray);
    }
1845
    /// Runs the `test_take_list_with_value_nulls!` body for `LargeListArray` (`i64` offsets).
    #[test]
    fn test_take_large_list_with_value_nulls() {
        test_take_list_with_value_nulls!(i64, LargeList, LargeListArray);
    }
1850
    /// Runs the `test_take_list_with_nulls!` body for `ListArray` (`i32` offsets).
    #[test]
    fn test_test_take_list_with_nulls() {
        test_take_list_with_nulls!(i32, List, ListArray);
    }
1855
    /// Runs the `test_take_list_with_nulls!` body for `LargeListArray` (`i64` offsets).
    #[test]
    fn test_test_take_large_list_with_nulls() {
        test_take_list_with_nulls!(i64, LargeList, LargeListArray);
    }
1860
1861    #[test]
1862    fn test_take_fixed_size_list() {
1863        do_take_fixed_size_list_test::<Int32Type>(
1864            3,
1865            vec![
1866                Some(vec![None, Some(1), Some(2)]),
1867                Some(vec![Some(3), Some(4), None]),
1868                Some(vec![Some(6), Some(7), Some(8)]),
1869            ],
1870            vec![2, 1, 0],
1871            vec![
1872                Some(vec![Some(6), Some(7), Some(8)]),
1873                Some(vec![Some(3), Some(4), None]),
1874                Some(vec![None, Some(1), Some(2)]),
1875            ],
1876        );
1877
1878        do_take_fixed_size_list_test::<UInt8Type>(
1879            1,
1880            vec![
1881                Some(vec![Some(1)]),
1882                Some(vec![Some(2)]),
1883                Some(vec![Some(3)]),
1884                Some(vec![Some(4)]),
1885                Some(vec![Some(5)]),
1886                Some(vec![Some(6)]),
1887                Some(vec![Some(7)]),
1888                Some(vec![Some(8)]),
1889            ],
1890            vec![2, 7, 0],
1891            vec![
1892                Some(vec![Some(3)]),
1893                Some(vec![Some(8)]),
1894                Some(vec![Some(1)]),
1895            ],
1896        );
1897
1898        do_take_fixed_size_list_test::<UInt64Type>(
1899            3,
1900            vec![
1901                Some(vec![Some(10), Some(11), Some(12)]),
1902                Some(vec![Some(13), Some(14), Some(15)]),
1903                None,
1904                Some(vec![Some(16), Some(17), Some(18)]),
1905            ],
1906            vec![3, 2, 1, 2, 0],
1907            vec![
1908                Some(vec![Some(16), Some(17), Some(18)]),
1909                None,
1910                Some(vec![Some(13), Some(14), Some(15)]),
1911                None,
1912                Some(vec![Some(10), Some(11), Some(12)]),
1913            ],
1914        );
1915    }
1916
1917    #[test]
1918    #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
1919    fn test_take_list_out_of_bounds() {
1920        // Construct a value array, [[0,0,0], [-1,-2,-1], [2,3]]
1921        let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
1922        // Construct offsets
1923        let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]);
1924        // Construct a list array from the above two
1925        let list_data_type =
1926            DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false)));
1927        let list_data = ArrayData::builder(list_data_type)
1928            .len(3)
1929            .add_buffer(value_offsets)
1930            .add_child_data(value_data)
1931            .build()
1932            .unwrap();
1933        let list_array = ListArray::from(list_data);
1934
1935        let index = UInt32Array::from(vec![1000]);
1936
1937        // A panic is expected here since we have not supplied the check_bounds
1938        // option.
1939        take(&list_array, &index, None).unwrap();
1940    }
1941
1942    #[test]
1943    fn test_take_map() {
1944        let values = Int32Array::from(vec![1, 2, 3, 4]);
1945        let array =
1946            MapArray::new_from_strings(vec!["a", "b", "c", "a"].into_iter(), &values, &[0, 3, 4])
1947                .unwrap();
1948
1949        let index = UInt32Array::from(vec![0]);
1950
1951        let result = take(&array, &index, None).unwrap();
1952        let expected: ArrayRef = Arc::new(
1953            MapArray::new_from_strings(
1954                vec!["a", "b", "c"].into_iter(),
1955                &values.slice(0, 3),
1956                &[0, 3],
1957            )
1958            .unwrap(),
1959        );
1960        assert_eq!(&expected, &result);
1961    }
1962
1963    #[test]
1964    fn test_take_struct() {
1965        let array = create_test_struct(vec![
1966            Some((Some(true), Some(42))),
1967            Some((Some(false), Some(28))),
1968            Some((Some(false), Some(19))),
1969            Some((Some(true), Some(31))),
1970            None,
1971        ]);
1972
1973        let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
1974        let actual = take(&array, &index, None).unwrap();
1975        let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
1976        assert_eq!(index.len(), actual.len());
1977        assert_eq!(1, actual.null_count());
1978
1979        let expected = create_test_struct(vec![
1980            Some((Some(true), Some(42))),
1981            Some((Some(true), Some(31))),
1982            Some((Some(false), Some(28))),
1983            Some((Some(true), Some(42))),
1984            Some((Some(false), Some(19))),
1985            None,
1986        ]);
1987
1988        assert_eq!(&expected, actual);
1989
1990        let nulls = NullBuffer::from(&[false, true, false, true, false, true]);
1991        let empty_struct_arr = StructArray::new_empty_fields(6, Some(nulls));
1992        let index = UInt32Array::from(vec![0, 2, 1, 4]);
1993        let actual = take(&empty_struct_arr, &index, None).unwrap();
1994
1995        let expected_nulls = NullBuffer::from(&[false, false, true, false]);
1996        let expected_struct_arr = StructArray::new_empty_fields(4, Some(expected_nulls));
1997        assert_eq!(&expected_struct_arr, actual.as_struct());
1998    }
1999
2000    #[test]
2001    fn test_take_struct_with_null_indices() {
2002        let array = create_test_struct(vec![
2003            Some((Some(true), Some(42))),
2004            Some((Some(false), Some(28))),
2005            Some((Some(false), Some(19))),
2006            Some((Some(true), Some(31))),
2007            None,
2008        ]);
2009
2010        let index = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);
2011        let actual = take(&array, &index, None).unwrap();
2012        let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
2013        assert_eq!(index.len(), actual.len());
2014        assert_eq!(3, actual.null_count()); // 2 because of indices, 1 because of struct array
2015
2016        let expected = create_test_struct(vec![
2017            None,
2018            Some((Some(true), Some(31))),
2019            Some((Some(false), Some(28))),
2020            None,
2021            Some((Some(true), Some(42))),
2022            None,
2023        ]);
2024
2025        assert_eq!(&expected, actual);
2026    }
2027
2028    #[test]
2029    fn test_take_out_of_bounds() {
2030        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]);
2031        let take_opt = TakeOptions { check_bounds: true };
2032
2033        // int64
2034        let result = test_take_primitive_arrays::<Int64Type>(
2035            vec![Some(0), None, Some(2), Some(3), None],
2036            &index,
2037            Some(take_opt),
2038            vec![None],
2039        );
2040        assert!(result.is_err());
2041    }
2042
2043    #[test]
2044    #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
2045    fn test_take_out_of_bounds_panic() {
2046        let index = UInt32Array::from(vec![Some(1000)]);
2047
2048        test_take_primitive_arrays::<Int64Type>(
2049            vec![Some(0), Some(1), Some(2), Some(3)],
2050            &index,
2051            None,
2052            vec![None],
2053        )
2054        .unwrap();
2055    }
2056
2057    #[test]
2058    fn test_null_array_smaller_than_indices() {
2059        let values = NullArray::new(2);
2060        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2061
2062        let result = take(&values, &indices, None).unwrap();
2063        let expected: ArrayRef = Arc::new(NullArray::new(3));
2064        assert_eq!(&result, &expected);
2065    }
2066
2067    #[test]
2068    fn test_null_array_larger_than_indices() {
2069        let values = NullArray::new(5);
2070        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2071
2072        let result = take(&values, &indices, None).unwrap();
2073        let expected: ArrayRef = Arc::new(NullArray::new(3));
2074        assert_eq!(&result, &expected);
2075    }
2076
2077    #[test]
2078    fn test_null_array_indices_out_of_bounds() {
2079        let values = NullArray::new(5);
2080        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2081
2082        let result = take(&values, &indices, Some(TakeOptions { check_bounds: true }));
2083        assert_eq!(
2084            result.unwrap_err().to_string(),
2085            "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries"
2086        );
2087    }
2088
2089    #[test]
2090    fn test_take_dict() {
2091        let mut dict_builder = StringDictionaryBuilder::<Int16Type>::new();
2092
2093        dict_builder.append("foo").unwrap();
2094        dict_builder.append("bar").unwrap();
2095        dict_builder.append("").unwrap();
2096        dict_builder.append_null();
2097        dict_builder.append("foo").unwrap();
2098        dict_builder.append("bar").unwrap();
2099        dict_builder.append("bar").unwrap();
2100        dict_builder.append("foo").unwrap();
2101
2102        let array = dict_builder.finish();
2103        let dict_values = array.values().clone();
2104        let dict_values = dict_values.as_any().downcast_ref::<StringArray>().unwrap();
2105
2106        let indices = UInt32Array::from(vec![
2107            Some(0), // first "foo"
2108            Some(7), // last "foo"
2109            None,    // null index should return null
2110            Some(5), // second "bar"
2111            Some(6), // another "bar"
2112            Some(2), // empty string
2113            Some(3), // input is null at this index
2114        ]);
2115
2116        let result = take(&array, &indices, None).unwrap();
2117        let result = result
2118            .as_any()
2119            .downcast_ref::<DictionaryArray<Int16Type>>()
2120            .unwrap();
2121
2122        let result_values: StringArray = result.values().to_data().into();
2123
2124        // dictionary values should stay the same
2125        let expected_values = StringArray::from(vec!["foo", "bar", ""]);
2126        assert_eq!(&expected_values, dict_values);
2127        assert_eq!(&expected_values, &result_values);
2128
2129        let expected_keys = Int16Array::from(vec![
2130            Some(0),
2131            Some(0),
2132            None,
2133            Some(1),
2134            Some(1),
2135            Some(2),
2136            None,
2137        ]);
2138        assert_eq!(result.keys(), &expected_keys);
2139    }
2140
2141    fn build_generic_list<S, T>(data: Vec<Option<Vec<T::Native>>>) -> GenericListArray<S>
2142    where
2143        S: OffsetSizeTrait + 'static,
2144        T: ArrowPrimitiveType,
2145        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2146    {
2147        GenericListArray::from_iter_primitive::<T, _, _>(
2148            data.iter()
2149                .map(|x| x.as_ref().map(|x| x.iter().map(|x| Some(*x)))),
2150        )
2151    }
2152
2153    #[test]
2154    fn test_take_value_index_from_list() {
2155        let list = build_generic_list::<i32, Int32Type>(vec![
2156            Some(vec![0, 1]),
2157            Some(vec![2, 3, 4]),
2158            Some(vec![5, 6, 7, 8, 9]),
2159        ]);
2160        let indices = UInt32Array::from(vec![2, 0]);
2161
2162        let (indexed, offsets, null_buf) = take_value_indices_from_list(&list, &indices).unwrap();
2163
2164        assert_eq!(indexed, Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
2165        assert_eq!(offsets, vec![0, 5, 7]);
2166        assert_eq!(null_buf.as_slice(), &[0b11111111]);
2167    }
2168
2169    #[test]
2170    fn test_take_value_index_from_large_list() {
2171        let list = build_generic_list::<i64, Int32Type>(vec![
2172            Some(vec![0, 1]),
2173            Some(vec![2, 3, 4]),
2174            Some(vec![5, 6, 7, 8, 9]),
2175        ]);
2176        let indices = UInt32Array::from(vec![2, 0]);
2177
2178        let (indexed, offsets, null_buf) =
2179            take_value_indices_from_list::<_, Int64Type>(&list, &indices).unwrap();
2180
2181        assert_eq!(indexed, Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
2182        assert_eq!(offsets, vec![0, 5, 7]);
2183        assert_eq!(null_buf.as_slice(), &[0b11111111]);
2184    }
2185
2186    #[test]
2187    fn test_take_runs() {
2188        let logical_array: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];
2189
2190        let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
2191        builder.extend(logical_array.into_iter().map(Some));
2192        let run_array = builder.finish();
2193
2194        let take_indices: PrimitiveArray<Int32Type> =
2195            vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect();
2196
2197        let take_out = take_run(&run_array, &take_indices).unwrap();
2198
2199        assert_eq!(take_out.len(), 7);
2200        assert_eq!(take_out.run_ends().len(), 7);
2201        assert_eq!(take_out.run_ends().values(), &[1_i32, 3, 4, 5, 7]);
2202
2203        let take_out_values = take_out.values().as_primitive::<Int32Type>();
2204        assert_eq!(take_out_values.values(), &[2, 2, 2, 2, 1]);
2205    }
2206
2207    #[test]
2208    fn test_take_value_index_from_fixed_list() {
2209        let list = FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
2210            vec![
2211                Some(vec![Some(1), Some(2), None]),
2212                Some(vec![Some(4), None, Some(6)]),
2213                None,
2214                Some(vec![None, Some(8), Some(9)]),
2215            ],
2216            3,
2217        );
2218
2219        let indices = UInt32Array::from(vec![2, 1, 0]);
2220        let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2221
2222        assert_eq!(indexed, UInt32Array::from(vec![6, 7, 8, 3, 4, 5, 0, 1, 2]));
2223
2224        let indices = UInt32Array::from(vec![3, 2, 1, 2, 0]);
2225        let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2226
2227        assert_eq!(
2228            indexed,
2229            UInt32Array::from(vec![9, 10, 11, 6, 7, 8, 3, 4, 5, 6, 7, 8, 0, 1, 2])
2230        );
2231    }
2232
2233    #[test]
2234    fn test_take_null_indices() {
2235        // Build indices with values that are out of bounds, but masked by null mask
2236        let indices = Int32Array::new(
2237            vec![1, 2, 400, 400].into(),
2238            Some(NullBuffer::from(vec![true, true, false, false])),
2239        );
2240        let values = Int32Array::from(vec![1, 23, 4, 5]);
2241        let r = take(&values, &indices, None).unwrap();
2242        let values = r
2243            .as_primitive::<Int32Type>()
2244            .into_iter()
2245            .collect::<Vec<_>>();
2246        assert_eq!(&values, &[Some(23), Some(4), None, None])
2247    }
2248
2249    #[test]
2250    fn test_take_fixed_size_list_null_indices() {
2251        let indices = Int32Array::from_iter([Some(0), None]);
2252        let values = Arc::new(Int32Array::from(vec![0, 1, 2, 3]));
2253        let arr_field = Arc::new(Field::new_list_field(values.data_type().clone(), true));
2254        let values = FixedSizeListArray::try_new(arr_field, 2, values, None).unwrap();
2255
2256        let r = take(&values, &indices, None).unwrap();
2257        let values = r
2258            .as_fixed_size_list()
2259            .values()
2260            .as_primitive::<Int32Type>()
2261            .into_iter()
2262            .collect::<Vec<_>>();
2263        assert_eq!(values, &[Some(0), Some(1), None, None])
2264    }
2265
2266    #[test]
2267    fn test_take_bytes_null_indices() {
2268        let indices = Int32Array::new(
2269            vec![0, 1, 400, 400].into(),
2270            Some(NullBuffer::from_iter(vec![true, true, false, false])),
2271        );
2272        let values = StringArray::from(vec![Some("foo"), None]);
2273        let r = take(&values, &indices, None).unwrap();
2274        let values = r.as_string::<i32>().iter().collect::<Vec<_>>();
2275        assert_eq!(&values, &[Some("foo"), None, None, None])
2276    }
2277
2278    #[test]
2279    fn test_take_union_sparse() {
2280        let structs = create_test_struct(vec![
2281            Some((Some(true), Some(42))),
2282            Some((Some(false), Some(28))),
2283            Some((Some(false), Some(19))),
2284            Some((Some(true), Some(31))),
2285            None,
2286        ]);
2287        let strings = StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);
2288        let type_ids = [1; 5].into_iter().collect::<ScalarBuffer<i8>>();
2289
2290        let union_fields = [
2291            (
2292                0,
2293                Arc::new(Field::new("f1", structs.data_type().clone(), true)),
2294            ),
2295            (
2296                1,
2297                Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2298            ),
2299        ]
2300        .into_iter()
2301        .collect();
2302        let children = vec![Arc::new(structs) as Arc<dyn Array>, Arc::new(strings)];
2303        let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();
2304
2305        let indices = vec![0, 3, 1, 0, 2, 4];
2306        let index = UInt32Array::from(indices.clone());
2307        let actual = take(&array, &index, None).unwrap();
2308        let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2309        let strings = actual.child(1);
2310        let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();
2311
2312        let actual = strings.iter().collect::<Vec<_>>();
2313        let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
2314        assert_eq!(expected, actual);
2315    }
2316
2317    #[test]
2318    fn test_take_union_dense() {
2319        let type_ids = vec![0, 1, 1, 0, 0, 1, 0];
2320        let offsets = vec![0, 0, 1, 1, 2, 2, 3];
2321        let ints = vec![10, 20, 30, 40];
2322        let strings = vec![Some("a"), None, Some("c"), Some("d")];
2323
2324        let indices = vec![0, 3, 1, 0, 2, 4];
2325
2326        let taken_type_ids = vec![0, 0, 1, 0, 1, 0];
2327        let taken_offsets = vec![0, 1, 0, 2, 1, 3];
2328        let taken_ints = vec![10, 20, 10, 30];
2329        let taken_strings = vec![Some("a"), None];
2330
2331        let type_ids = <ScalarBuffer<i8>>::from(type_ids);
2332        let offsets = <ScalarBuffer<i32>>::from(offsets);
2333        let ints = UInt32Array::from(ints);
2334        let strings = StringArray::from(strings);
2335
2336        let union_fields = [
2337            (
2338                0,
2339                Arc::new(Field::new("f1", ints.data_type().clone(), true)),
2340            ),
2341            (
2342                1,
2343                Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2344            ),
2345        ]
2346        .into_iter()
2347        .collect();
2348
2349        let array = UnionArray::try_new(
2350            union_fields,
2351            type_ids,
2352            Some(offsets),
2353            vec![Arc::new(ints), Arc::new(strings)],
2354        )
2355        .unwrap();
2356
2357        let index = UInt32Array::from(indices);
2358
2359        let actual = take(&array, &index, None).unwrap();
2360        let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2361
2362        assert_eq!(actual.offsets(), Some(&ScalarBuffer::from(taken_offsets)));
2363        assert_eq!(actual.type_ids(), &ScalarBuffer::from(taken_type_ids));
2364        assert_eq!(
2365            UInt32Array::from(actual.child(0).to_data()),
2366            UInt32Array::from(taken_ints)
2367        );
2368        assert_eq!(
2369            StringArray::from(actual.child(1).to_data()),
2370            StringArray::from(taken_strings)
2371        );
2372    }
2373
2374    #[test]
2375    fn test_take_union_dense_using_builder() {
2376        let mut builder = UnionBuilder::new_dense();
2377
2378        builder.append::<Int32Type>("a", 1).unwrap();
2379        builder.append::<Float64Type>("b", 3.0).unwrap();
2380        builder.append::<Int32Type>("a", 4).unwrap();
2381        builder.append::<Int32Type>("a", 5).unwrap();
2382        builder.append::<Float64Type>("b", 2.0).unwrap();
2383
2384        let union = builder.build().unwrap();
2385
2386        let indices = UInt32Array::from(vec![2, 0, 1, 2]);
2387
2388        let mut builder = UnionBuilder::new_dense();
2389
2390        builder.append::<Int32Type>("a", 4).unwrap();
2391        builder.append::<Int32Type>("a", 1).unwrap();
2392        builder.append::<Float64Type>("b", 3.0).unwrap();
2393        builder.append::<Int32Type>("a", 4).unwrap();
2394
2395        let taken = builder.build().unwrap();
2396
2397        assert_eq!(
2398            taken.to_data(),
2399            take(&union, &indices, None).unwrap().to_data()
2400        );
2401    }
2402
2403    #[test]
2404    fn test_take_union_dense_all_match_issue_6206() {
2405        let fields = UnionFields::new(vec![0], vec![Field::new("a", DataType::Int64, false)]);
2406        let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
2407
2408        let array = UnionArray::try_new(
2409            fields,
2410            ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]),
2411            Some(ScalarBuffer::from_iter(0_i32..5)),
2412            vec![ints],
2413        )
2414        .unwrap();
2415
2416        let indicies = Int64Array::from(vec![0, 2, 4]);
2417        let array = take(&array, &indicies, None).unwrap();
2418        assert_eq!(array.len(), 3);
2419    }
2420}