arrow_select/
take.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines take kernel for [Array]
19
20use std::sync::Arc;
21
22use arrow_array::builder::{BufferBuilder, UInt32Builder};
23use arrow_array::cast::AsArray;
24use arrow_array::types::*;
25use arrow_array::*;
26use arrow_buffer::{
27    bit_util, ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer,
28    ScalarBuffer,
29};
30use arrow_data::ArrayDataBuilder;
31use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
32
33use num::{One, Zero};
34
35/// Take elements by index from [Array], creating a new [Array] from those indexes.
36///
37/// ```text
38/// ┌─────────────────┐      ┌─────────┐                              ┌─────────────────┐
39/// │        A        │      │    0    │                              │        A        │
40/// ├─────────────────┤      ├─────────┤                              ├─────────────────┤
41/// │        D        │      │    2    │                              │        B        │
42/// ├─────────────────┤      ├─────────┤   take(values, indices)      ├─────────────────┤
43/// │        B        │      │    3    │ ─────────────────────────▶   │        C        │
44/// ├─────────────────┤      ├─────────┤                              ├─────────────────┤
45/// │        C        │      │    1    │                              │        D        │
46/// ├─────────────────┤      └─────────┘                              └─────────────────┘
47/// │        E        │
48/// └─────────────────┘
49///    values array          indices array                              result
50/// ```
51///
52/// For selecting values by index from multiple arrays see [`crate::interleave`]
53///
54/// Note that this kernel, similar to other kernels in this crate,
55/// will avoid allocating where not necessary. Consequently
56/// the returned array may share buffers with the inputs
57///
58/// # Errors
59/// This function errors whenever:
60/// * An index cannot be casted to `usize` (typically 32 bit architectures)
61/// * An index is out of bounds and `options` is set to check bounds.
62///
63/// # Safety
64///
65/// When `options` is not set to check bounds, taking indexes after `len` will panic.
66///
67/// # See also
68/// * [`BatchCoalescer`]: to filter multiple [`RecordBatch`] and coalesce
69///   the results into a single array.
70///
71/// [`BatchCoalescer`]: crate::coalesce::BatchCoalescer
72///
73/// # Examples
74/// ```
75/// # use arrow_array::{StringArray, UInt32Array, cast::AsArray};
76/// # use arrow_select::take::take;
77/// let values = StringArray::from(vec!["zero", "one", "two"]);
78///
79/// // Take items at index 2, and 1:
80/// let indices = UInt32Array::from(vec![2, 1]);
81/// let taken = take(&values, &indices, None).unwrap();
82/// let taken = taken.as_string::<i32>();
83///
84/// assert_eq!(*taken, StringArray::from(vec!["two", "one"]));
85/// ```
86pub fn take(
87    values: &dyn Array,
88    indices: &dyn Array,
89    options: Option<TakeOptions>,
90) -> Result<ArrayRef, ArrowError> {
91    let options = options.unwrap_or_default();
92    downcast_integer_array!(
93        indices => {
94            if options.check_bounds {
95                check_bounds(values.len(), indices)?;
96            }
97            let indices = indices.to_indices();
98            take_impl(values, &indices)
99        },
100        d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
101    )
102}
103
104/// For each [ArrayRef] in the [`Vec<ArrayRef>`], take elements by index and create a new
105/// [`Vec<ArrayRef>`] from those indices.
106///
107/// ```text
108/// ┌────────┬────────┐
109/// │        │        │           ┌────────┐                                ┌────────┬────────┐
110/// │   A    │   1    │           │        │                                │        │        │
111/// ├────────┼────────┤           │   0    │                                │   A    │   1    │
112/// │        │        │           ├────────┤                                ├────────┼────────┤
113/// │   D    │   4    │           │        │                                │        │        │
114/// ├────────┼────────┤           │   2    │  take_arrays(values,indices)   │   B    │   2    │
115/// │        │        │           ├────────┤                                ├────────┼────────┤
116/// │   B    │   2    │           │        │  ───────────────────────────►  │        │        │
117/// ├────────┼────────┤           │   3    │                                │   C    │   3    │
118/// │        │        │           ├────────┤                                ├────────┼────────┤
119/// │   C    │   3    │           │        │                                │        │        │
120/// ├────────┼────────┤           │   1    │                                │   D    │   4    │
121/// │        │        │           └────────┘                                └────────┼────────┘
122/// │   E    │   5    │
123/// └────────┴────────┘
124///    values arrays             indices array                                      result
125/// ```
126///
127/// # Errors
128/// This function errors whenever:
129/// * An index cannot be casted to `usize` (typically 32 bit architectures)
130/// * An index is out of bounds and `options` is set to check bounds.
131///
132/// # Safety
133///
134/// When `options` is not set to check bounds, taking indexes after `len` will panic.
135///
136/// # Examples
137/// ```
138/// # use std::sync::Arc;
139/// # use arrow_array::{StringArray, UInt32Array, cast::AsArray};
140/// # use arrow_select::take::{take, take_arrays};
141/// let string_values = Arc::new(StringArray::from(vec!["zero", "one", "two"]));
142/// let values = Arc::new(UInt32Array::from(vec![0, 1, 2]));
143///
144/// // Take items at index 2, and 1:
145/// let indices = UInt32Array::from(vec![2, 1]);
146/// let taken_arrays = take_arrays(&[string_values, values], &indices, None).unwrap();
147/// let taken_string = taken_arrays[0].as_string::<i32>();
148/// assert_eq!(*taken_string, StringArray::from(vec!["two", "one"]));
149/// let taken_values = taken_arrays[1].as_primitive();
150/// assert_eq!(*taken_values, UInt32Array::from(vec![2, 1]));
151/// ```
152pub fn take_arrays(
153    arrays: &[ArrayRef],
154    indices: &dyn Array,
155    options: Option<TakeOptions>,
156) -> Result<Vec<ArrayRef>, ArrowError> {
157    arrays
158        .iter()
159        .map(|array| take(array.as_ref(), indices, options.clone()))
160        .collect()
161}
162
163/// Verifies that the non-null values of `indices` are all `< len`
164fn check_bounds<T: ArrowPrimitiveType>(
165    len: usize,
166    indices: &PrimitiveArray<T>,
167) -> Result<(), ArrowError> {
168    if indices.null_count() > 0 {
169        indices.iter().flatten().try_for_each(|index| {
170            let ix = index
171                .to_usize()
172                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
173            if ix >= len {
174                return Err(ArrowError::ComputeError(format!(
175                    "Array index out of bounds, cannot get item at index {ix} from {len} entries"
176                )));
177            }
178            Ok(())
179        })
180    } else {
181        indices.values().iter().try_for_each(|index| {
182            let ix = index
183                .to_usize()
184                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
185            if ix >= len {
186                return Err(ArrowError::ComputeError(format!(
187                    "Array index out of bounds, cannot get item at index {ix} from {len} entries"
188                )));
189            }
190            Ok(())
191        })
192    }
193}
194
195#[inline(never)]
196fn take_impl<IndexType: ArrowPrimitiveType>(
197    values: &dyn Array,
198    indices: &PrimitiveArray<IndexType>,
199) -> Result<ArrayRef, ArrowError> {
200    downcast_primitive_array! {
201        values => Ok(Arc::new(take_primitive(values, indices)?)),
202        DataType::Boolean => {
203            let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
204            Ok(Arc::new(take_boolean(values, indices)))
205        }
206        DataType::Utf8 => {
207            Ok(Arc::new(take_bytes(values.as_string::<i32>(), indices)?))
208        }
209        DataType::LargeUtf8 => {
210            Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)?))
211        }
212        DataType::Utf8View => {
213            Ok(Arc::new(take_byte_view(values.as_string_view(), indices)?))
214        }
215        DataType::List(_) => {
216            Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?))
217        }
218        DataType::LargeList(_) => {
219            Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?))
220        }
221        DataType::FixedSizeList(_, length) => {
222            let values = values
223                .as_any()
224                .downcast_ref::<FixedSizeListArray>()
225                .unwrap();
226            Ok(Arc::new(take_fixed_size_list(
227                values,
228                indices,
229                *length as u32,
230            )?))
231        }
232        DataType::Map(_, _) => {
233            let list_arr = ListArray::from(values.as_map().clone());
234            let list_data = take_list::<_, Int32Type>(&list_arr, indices)?;
235            let builder = list_data.into_data().into_builder().data_type(values.data_type().clone());
236            Ok(Arc::new(MapArray::from(unsafe { builder.build_unchecked() })))
237        }
238        DataType::Struct(fields) => {
239            let array: &StructArray = values.as_struct();
240            let arrays  = array
241                .columns()
242                .iter()
243                .map(|a| take_impl(a.as_ref(), indices))
244                .collect::<Result<Vec<ArrayRef>, _>>()?;
245            let fields: Vec<(FieldRef, ArrayRef)> =
246                fields.iter().cloned().zip(arrays).collect();
247
248            // Create the null bit buffer.
249            let is_valid: Buffer = indices
250                .iter()
251                .map(|index| {
252                    if let Some(index) = index {
253                        array.is_valid(index.to_usize().unwrap())
254                    } else {
255                        false
256                    }
257                })
258                .collect();
259
260            if fields.is_empty() {
261                let nulls = NullBuffer::new(BooleanBuffer::new(is_valid, 0, indices.len()));
262                Ok(Arc::new(StructArray::new_empty_fields(indices.len(), Some(nulls))))
263            } else {
264                Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
265            }
266        }
267        DataType::Dictionary(_, _) => downcast_dictionary_array! {
268            values => Ok(Arc::new(take_dict(values, indices)?)),
269            t => unimplemented!("Take not supported for dictionary type {:?}", t)
270        }
271        DataType::RunEndEncoded(_, _) => downcast_run_array! {
272            values => Ok(Arc::new(take_run(values, indices)?)),
273            t => unimplemented!("Take not supported for run type {:?}", t)
274        }
275        DataType::Binary => {
276            Ok(Arc::new(take_bytes(values.as_binary::<i32>(), indices)?))
277        }
278        DataType::LargeBinary => {
279            Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?))
280        }
281        DataType::BinaryView => {
282            Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)?))
283        }
284        DataType::FixedSizeBinary(size) => {
285            let values = values
286                .as_any()
287                .downcast_ref::<FixedSizeBinaryArray>()
288                .unwrap();
289            Ok(Arc::new(take_fixed_size_binary(values, indices, *size)?))
290        }
291        DataType::Null => {
292            // Take applied to a null array produces a null array.
293            if values.len() >= indices.len() {
294                // If the existing null array is as big as the indices, we can use a slice of it
295                // to avoid allocating a new null array.
296                Ok(values.slice(0, indices.len()))
297            } else {
298                // If the existing null array isn't big enough, create a new one.
299                Ok(new_null_array(&DataType::Null, indices.len()))
300            }
301        }
302        DataType::Union(fields, UnionMode::Sparse) => {
303            let mut children = Vec::with_capacity(fields.len());
304            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
305            let type_ids = take_native(values.type_ids(), indices);
306            for (type_id, _field) in fields.iter() {
307                let values = values.child(type_id);
308                let values = take_impl(values, indices)?;
309                children.push(values);
310            }
311            let array = UnionArray::try_new(fields.clone(), type_ids, None, children)?;
312            Ok(Arc::new(array))
313        }
314        DataType::Union(fields, UnionMode::Dense) => {
315            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
316
317            let type_ids = <PrimitiveArray<Int8Type>>::new(take_native(values.type_ids(), indices), None);
318            let offsets = <PrimitiveArray<Int32Type>>::new(take_native(values.offsets().unwrap(), indices), None);
319
320            let children = fields.iter()
321                .map(|(field_type_id, _)| {
322                    let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);
323
324                    let indices = crate::filter::filter(&offsets, &mask)?;
325
326                    let values = values.child(field_type_id);
327
328                    take_impl(values, indices.as_primitive::<Int32Type>())
329                })
330                .collect::<Result<_, _>>()?;
331
332            let mut child_offsets = [0; 128];
333
334            let offsets = type_ids.values()
335                .iter()
336                .map(|&i| {
337                    let offset = child_offsets[i as usize];
338
339                    child_offsets[i as usize] += 1;
340
341                    offset
342                })
343                .collect();
344
345            let (_, type_ids, _) = type_ids.into_parts();
346
347            let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)?;
348
349            Ok(Arc::new(array))
350        }
351        t => unimplemented!("Take not supported for data type {:?}", t)
352    }
353}
354
/// Options that define how `take` should behave
///
/// `TakeOptions::default()` leaves `check_bounds` as `false`, i.e. no up-front
/// bounds checking.
#[derive(Clone, Debug, Default)]
pub struct TakeOptions {
    /// Perform bounds check before taking indices from values.
    /// If enabled, an `ArrowError` is returned if the indices are out of bounds.
    /// Only non-null indices are checked; null index slots are ignored.
    /// If not enabled, and indices exceed bounds, the kernel will panic.
    pub check_bounds: bool,
}
363
364#[inline(always)]
365fn maybe_usize<I: ArrowNativeType>(index: I) -> Result<usize, ArrowError> {
366    index
367        .to_usize()
368        .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))
369}
370
371/// `take` implementation for all primitive arrays
372///
373/// This checks if an `indices` slot is populated, and gets the value from `values`
374///  as the populated index.
375/// If the `indices` slot is null, a null value is returned.
376/// For example, given:
377///     values:  [1, 2, 3, null, 5]
378///     indices: [0, null, 4, 3]
379/// The result is: [1 (slot 0), null (null slot), 5 (slot 4), null (slot 3)]
380fn take_primitive<T, I>(
381    values: &PrimitiveArray<T>,
382    indices: &PrimitiveArray<I>,
383) -> Result<PrimitiveArray<T>, ArrowError>
384where
385    T: ArrowPrimitiveType,
386    I: ArrowPrimitiveType,
387{
388    let values_buf = take_native(values.values(), indices);
389    let nulls = take_nulls(values.nulls(), indices);
390    Ok(PrimitiveArray::new(values_buf, nulls).with_data_type(values.data_type().clone()))
391}
392
393#[inline(never)]
394fn take_nulls<I: ArrowPrimitiveType>(
395    values: Option<&NullBuffer>,
396    indices: &PrimitiveArray<I>,
397) -> Option<NullBuffer> {
398    match values.filter(|n| n.null_count() > 0) {
399        Some(n) => {
400            let buffer = take_bits(n.inner(), indices);
401            Some(NullBuffer::new(buffer)).filter(|n| n.null_count() > 0)
402        }
403        None => indices.nulls().cloned(),
404    }
405}
406
407#[inline(never)]
408fn take_native<T: ArrowNativeType, I: ArrowPrimitiveType>(
409    values: &[T],
410    indices: &PrimitiveArray<I>,
411) -> ScalarBuffer<T> {
412    match indices.nulls().filter(|n| n.null_count() > 0) {
413        Some(n) => indices
414            .values()
415            .iter()
416            .enumerate()
417            .map(|(idx, index)| match values.get(index.as_usize()) {
418                Some(v) => *v,
419                None => match n.is_null(idx) {
420                    true => T::default(),
421                    false => panic!("Out-of-bounds index {index:?}"),
422                },
423            })
424            .collect(),
425        None => indices
426            .values()
427            .iter()
428            .map(|index| values[index.as_usize()])
429            .collect(),
430    }
431}
432
433#[inline(never)]
434fn take_bits<I: ArrowPrimitiveType>(
435    values: &BooleanBuffer,
436    indices: &PrimitiveArray<I>,
437) -> BooleanBuffer {
438    let len = indices.len();
439
440    match indices.nulls().filter(|n| n.null_count() > 0) {
441        Some(nulls) => {
442            let mut output_buffer = MutableBuffer::new_null(len);
443            let output_slice = output_buffer.as_slice_mut();
444            nulls.valid_indices().for_each(|idx| {
445                if values.value(indices.value(idx).as_usize()) {
446                    bit_util::set_bit(output_slice, idx);
447                }
448            });
449            BooleanBuffer::new(output_buffer.into(), 0, len)
450        }
451        None => {
452            BooleanBuffer::collect_bool(len, |idx: usize| {
453                // SAFETY: idx<indices.len()
454                values.value(unsafe { indices.value_unchecked(idx).as_usize() })
455            })
456        }
457    }
458}
459
460/// `take` implementation for boolean arrays
461fn take_boolean<IndexType: ArrowPrimitiveType>(
462    values: &BooleanArray,
463    indices: &PrimitiveArray<IndexType>,
464) -> BooleanArray {
465    let val_buf = take_bits(values.values(), indices);
466    let null_buf = take_nulls(values.nulls(), indices);
467    BooleanArray::new(val_buf, null_buf)
468}
469
/// `take` implementation for variable-width byte arrays (`GenericByteArray`,
/// i.e. the string and binary arrays with `i32` or `i64` offsets).
///
/// Each branch makes two passes. The first computes the output offsets,
/// accumulating into `capacity` the byte length of every selected non-null value.
/// The second copies the value bytes into a buffer pre-sized to exactly `capacity`.
/// Null output slots get a zero-length entry (the previous offset is repeated).
///
/// The four branches specialise on whether `array` and/or `indices` contain nulls,
/// so the null-free case avoids per-element validity checks.
///
/// # Errors
/// Returns [`ArrowError::OffsetOverflowError`] if the total selected byte length
/// does not fit in `T::Offset`.
///
/// # Panics
/// Panics if a non-null index is out of bounds for `array`.
fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
    array: &GenericByteArray<T>,
    indices: &PrimitiveArray<IndexType>,
) -> Result<GenericByteArray<T>, ArrowError> {
    let mut offsets = Vec::with_capacity(indices.len() + 1);
    offsets.push(T::Offset::default());

    let input_offsets = array.value_offsets();
    // Running total of output bytes; doubles as the next offset to push.
    let mut capacity = 0;
    // Output validity: null where the index is null or the referenced value is null.
    let nulls = take_nulls(array.nulls(), indices);

    let (offsets, values) = if array.null_count() == 0 && indices.null_count() == 0 {
        // Fast path: no nulls anywhere, every index selects a real value.
        offsets.reserve(indices.len());
        for index in indices.values() {
            let index = index.as_usize();
            capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for index in indices.values() {
            values.extend_from_slice(array.value(index.as_usize()).as_ref());
        }
        (offsets, values)
    } else if indices.null_count() == 0 {
        // Only `array` has nulls: null source values contribute no bytes.
        offsets.reserve(indices.len());
        for index in indices.values() {
            let index = index.as_usize();
            if array.is_valid(index) {
                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            }
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for index in indices.values() {
            let index = index.as_usize();
            if array.is_valid(index) {
                values.extend_from_slice(array.value(index).as_ref());
            }
        }
        (offsets, values)
    } else if array.null_count() == 0 {
        // Only `indices` has nulls: skip null index slots, whose stored value may be
        // arbitrary and must not be used to index `input_offsets`.
        offsets.reserve(indices.len());
        for (i, index) in indices.values().iter().enumerate() {
            let index = index.as_usize();
            if indices.is_valid(i) {
                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            }
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for (i, index) in indices.values().iter().enumerate() {
            if indices.is_valid(i) {
                values.extend_from_slice(array.value(index.as_usize()).as_ref());
            }
        }
        (offsets, values)
    } else {
        // Both sides have nulls. `nulls` (from `take_nulls`) is valid exactly where the
        // index is valid and the referenced value is valid. It is `Some` here: `array`
        // has nulls and every null index slot yields a null output bit.
        let nulls = nulls.as_ref().unwrap();
        offsets.reserve(indices.len());
        for (i, index) in indices.values().iter().enumerate() {
            let index = index.as_usize();
            if nulls.is_valid(i) {
                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            }
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for (i, index) in indices.values().iter().enumerate() {
            // check index is valid before using index. The value in
            // NULL index slots may not be within bounds of array
            let index = index.as_usize();
            if nulls.is_valid(i) {
                values.extend_from_slice(array.value(index).as_ref());
            }
        }
        (offsets, values)
    };

    // Defensive re-check: `values.len()` equals the final `capacity`, which was
    // already verified to fit in `T::Offset` above.
    T::Offset::from_usize(values.len())
        .ok_or_else(|| ArrowError::OffsetOverflowError(values.len()))?;

    // SAFETY: offsets start at 0 and never decrease (capacity only grows), the last
    // offset equals `values.len()`, `nulls` has length `indices.len()`, and every
    // value was copied whole from `array` (so UTF-8 validity is preserved for strings).
    let array = unsafe {
        let offsets = OffsetBuffer::new_unchecked(offsets.into());
        GenericByteArray::<T>::new_unchecked(offsets, values.into(), nulls)
    };

    Ok(array)
}
575
576/// `take` implementation for byte view arrays
577fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>(
578    array: &GenericByteViewArray<T>,
579    indices: &PrimitiveArray<IndexType>,
580) -> Result<GenericByteViewArray<T>, ArrowError> {
581    let new_views = take_native(array.views(), indices);
582    let new_nulls = take_nulls(array.nulls(), indices);
583    // Safety:  array.views was valid, and take_native copies only valid values, and verifies bounds
584    Ok(unsafe {
585        GenericByteViewArray::new_unchecked(new_views, array.data_buffers().to_vec(), new_nulls)
586    })
587}
588
589/// `take` implementation for list arrays
590///
591/// Calculates the index and indexed offset for the inner array,
592/// applying `take` on the inner array, then reconstructing a list array
593/// with the indexed offsets
594fn take_list<IndexType, OffsetType>(
595    values: &GenericListArray<OffsetType::Native>,
596    indices: &PrimitiveArray<IndexType>,
597) -> Result<GenericListArray<OffsetType::Native>, ArrowError>
598where
599    IndexType: ArrowPrimitiveType,
600    OffsetType: ArrowPrimitiveType,
601    OffsetType::Native: OffsetSizeTrait,
602    PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
603{
604    // TODO: Some optimizations can be done here such as if it is
605    // taking the whole list or a contiguous sublist
606    let (list_indices, offsets, null_buf) =
607        take_value_indices_from_list::<IndexType, OffsetType>(values, indices)?;
608
609    let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices)?;
610    let value_offsets = Buffer::from_vec(offsets);
611    // create a new list with taken data and computed null information
612    let list_data = ArrayDataBuilder::new(values.data_type().clone())
613        .len(indices.len())
614        .null_bit_buffer(Some(null_buf.into()))
615        .offset(0)
616        .add_child_data(taken.into_data())
617        .add_buffer(value_offsets);
618
619    let list_data = unsafe { list_data.build_unchecked() };
620
621    Ok(GenericListArray::<OffsetType::Native>::from(list_data))
622}
623
624/// `take` implementation for `FixedSizeListArray`
625///
626/// Calculates the index and indexed offset for the inner array,
627/// applying `take` on the inner array, then reconstructing a list array
628/// with the indexed offsets
629fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
630    values: &FixedSizeListArray,
631    indices: &PrimitiveArray<IndexType>,
632    length: <UInt32Type as ArrowPrimitiveType>::Native,
633) -> Result<FixedSizeListArray, ArrowError> {
634    let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?;
635    let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)?;
636
637    // determine null count and null buffer, which are a function of `values` and `indices`
638    let num_bytes = bit_util::ceil(indices.len(), 8);
639    let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
640    let null_slice = null_buf.as_slice_mut();
641
642    for i in 0..indices.len() {
643        let index = indices
644            .value(i)
645            .to_usize()
646            .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
647        if !indices.is_valid(i) || values.is_null(index) {
648            bit_util::unset_bit(null_slice, i);
649        }
650    }
651
652    let list_data = ArrayDataBuilder::new(values.data_type().clone())
653        .len(indices.len())
654        .null_bit_buffer(Some(null_buf.into()))
655        .offset(0)
656        .add_child_data(taken.into_data());
657
658    let list_data = unsafe { list_data.build_unchecked() };
659
660    Ok(FixedSizeListArray::from(list_data))
661}
662
663fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>(
664    values: &FixedSizeBinaryArray,
665    indices: &PrimitiveArray<IndexType>,
666    size: i32,
667) -> Result<FixedSizeBinaryArray, ArrowError> {
668    let nulls = values.nulls();
669    let array_iter = indices
670        .values()
671        .iter()
672        .map(|idx| {
673            let idx = maybe_usize::<IndexType::Native>(*idx)?;
674            if nulls.map(|n| n.is_valid(idx)).unwrap_or(true) {
675                Ok(Some(values.value(idx)))
676            } else {
677                Ok(None)
678            }
679        })
680        .collect::<Result<Vec<_>, ArrowError>>()?
681        .into_iter();
682
683    FixedSizeBinaryArray::try_from_sparse_iter_with_size(array_iter, size)
684}
685
686/// `take` implementation for dictionary arrays
687///
688/// applies `take` to the keys of the dictionary array and returns a new dictionary array
689/// with the same dictionary values and reordered keys
690fn take_dict<T: ArrowDictionaryKeyType, I: ArrowPrimitiveType>(
691    values: &DictionaryArray<T>,
692    indices: &PrimitiveArray<I>,
693) -> Result<DictionaryArray<T>, ArrowError> {
694    let new_keys = take_primitive(values.keys(), indices)?;
695    Ok(unsafe { DictionaryArray::new_unchecked(new_keys, values.values().clone()) })
696}
697
698/// `take` implementation for run arrays
699///
700/// Finds physical indices for the given logical indices and builds output run array
701/// by taking values in the input run_array.values at the physical indices.
702/// The output run array will be run encoded on the physical indices and not on output values.
703/// For e.g. an input `RunArray{ run_ends = [2,4,6,8], values=[1,2,1,2] }` and `logical_indices=[2,3,6,7]`
704/// would be converted to `physical_indices=[1,1,3,3]` which will be used to build
705/// output `RunArray{ run_ends=[2,4], values=[2,2] }`.
706fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
707    run_array: &RunArray<T>,
708    logical_indices: &PrimitiveArray<I>,
709) -> Result<RunArray<T>, ArrowError> {
710    // get physical indices for the input logical indices
711    let physical_indices = run_array.get_physical_indices(logical_indices.values())?;
712
713    // Run encode the physical indices into new_run_ends_builder
714    // Keep track of the physical indices to take in take_value_indices
715    // `unwrap` is used in this function because the unwrapped values are bounded by the corresponding `::Native`.
716    let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
717    let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
718    let mut new_physical_len = 1;
719    for ix in 1..physical_indices.len() {
720        if physical_indices[ix] != physical_indices[ix - 1] {
721            take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
722            new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
723            new_physical_len += 1;
724        }
725    }
726    take_value_indices
727        .append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap());
728    new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
729    let new_run_ends = unsafe {
730        // Safety:
731        // The function builds a valid run_ends array and hence need not be validated.
732        ArrayDataBuilder::new(T::DATA_TYPE)
733            .len(new_physical_len)
734            .null_count(0)
735            .add_buffer(new_run_ends_builder.finish())
736            .build_unchecked()
737    };
738
739    let take_value_indices: PrimitiveArray<I> = unsafe {
740        // Safety:
741        // The function builds a valid take_value_indices array and hence need not be validated.
742        ArrayDataBuilder::new(I::DATA_TYPE)
743            .len(new_physical_len)
744            .null_count(0)
745            .add_buffer(take_value_indices.finish())
746            .build_unchecked()
747            .into()
748    };
749
750    let new_values = take(run_array.values(), &take_value_indices, None)?;
751
752    let builder = ArrayDataBuilder::new(run_array.data_type().clone())
753        .len(physical_indices.len())
754        .add_child_data(new_run_ends)
755        .add_child_data(new_values.into_data());
756    let array_data = unsafe {
757        // Safety:
758        //  This function builds a valid run array and hence can skip validation.
759        builder.build_unchecked()
760    };
761    Ok(array_data.into())
762}
763
764/// Takes/filters a list array's inner data using the offsets of the list array.
765///
766/// Where a list array has indices `[0,2,5,10]`, taking indices of `[2,0]` returns
767/// an array of the indices `[5..10, 0..2]` and offsets `[0,5,7]` (5 elements and 2
768/// elements)
769#[allow(clippy::type_complexity)]
770fn take_value_indices_from_list<IndexType, OffsetType>(
771    list: &GenericListArray<OffsetType::Native>,
772    indices: &PrimitiveArray<IndexType>,
773) -> Result<
774    (
775        PrimitiveArray<OffsetType>,
776        Vec<OffsetType::Native>,
777        MutableBuffer,
778    ),
779    ArrowError,
780>
781where
782    IndexType: ArrowPrimitiveType,
783    OffsetType: ArrowPrimitiveType,
784    OffsetType::Native: OffsetSizeTrait + std::ops::Add + Zero + One,
785    PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
786{
787    // TODO: benchmark this function, there might be a faster unsafe alternative
788    let offsets: &[OffsetType::Native] = list.value_offsets();
789
790    let mut new_offsets = Vec::with_capacity(indices.len());
791    let mut values = Vec::new();
792    let mut current_offset = OffsetType::Native::zero();
793    // add first offset
794    new_offsets.push(OffsetType::Native::zero());
795
796    // Initialize null buffer
797    let num_bytes = bit_util::ceil(indices.len(), 8);
798    let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
799    let null_slice = null_buf.as_slice_mut();
800
801    // compute the value indices, and set offsets accordingly
802    for i in 0..indices.len() {
803        if indices.is_valid(i) {
804            let ix = indices
805                .value(i)
806                .to_usize()
807                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
808            let start = offsets[ix];
809            let end = offsets[ix + 1];
810            current_offset += end - start;
811            new_offsets.push(current_offset);
812
813            let mut curr = start;
814
815            // if start == end, this slot is empty
816            while curr < end {
817                values.push(curr);
818                curr += One::one();
819            }
820            if !list.is_valid(ix) {
821                bit_util::unset_bit(null_slice, i);
822            }
823        } else {
824            bit_util::unset_bit(null_slice, i);
825            new_offsets.push(current_offset);
826        }
827    }
828
829    Ok((
830        PrimitiveArray::<OffsetType>::from(values),
831        new_offsets,
832        null_buf,
833    ))
834}
835
836/// Takes/filters a fixed size list array's inner data using the offsets of the list array.
837fn take_value_indices_from_fixed_size_list<IndexType>(
838    list: &FixedSizeListArray,
839    indices: &PrimitiveArray<IndexType>,
840    length: <UInt32Type as ArrowPrimitiveType>::Native,
841) -> Result<PrimitiveArray<UInt32Type>, ArrowError>
842where
843    IndexType: ArrowPrimitiveType,
844{
845    let mut values = UInt32Builder::with_capacity(length as usize * indices.len());
846
847    for i in 0..indices.len() {
848        if indices.is_valid(i) {
849            let index = indices
850                .value(i)
851                .to_usize()
852                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
853            let start = list.value_offset(index) as <UInt32Type as ArrowPrimitiveType>::Native;
854
855            // Safety: Range always has known length.
856            unsafe {
857                values.append_trusted_len_iter(start..start + length);
858            }
859        } else {
860            values.append_nulls(length as usize);
861        }
862    }
863
864    Ok(values.finish())
865}
866
/// To avoid generating take implementations for every index type, take is only
/// instantiated for `UInt32Type` and `UInt64Type`; other index types are converted
/// to one of these through this trait.
trait ToIndices {
    /// The index type produced by [`ToIndices::to_indices`].
    type T: ArrowPrimitiveType;

    /// Converts `self` into an index array of type [`Self::T`].
    ///
    /// The implementations generated in this file carry over the null buffer of `self`.
    fn to_indices(&self) -> PrimitiveArray<Self::T>;
}
874
/// Implements [`ToIndices`] for a signed index type `$t` by viewing its existing value
/// buffer as the same-width unsigned type `$o`, keeping the null buffer.
macro_rules! to_indices_reinterpret {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                let validity = self.nulls().cloned();
                let raw = self.values().inner().clone();
                PrimitiveArray::new(ScalarBuffer::new(raw, 0, self.len()), validity)
            }
        }
    };
}
887
/// Implements [`ToIndices`] for an index type `$t` that the take kernels accept
/// directly; the conversion is a plain clone of the array.
macro_rules! to_indices_identity {
    ($t:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $t;

            fn to_indices(&self) -> PrimitiveArray<$t> {
                Clone::clone(self)
            }
        }
    };
}
899
/// Implements [`ToIndices`] for a narrow index type `$t` by widening every value to
/// `$o` with an `as` cast, keeping the null buffer.
///
/// Signed inputs are sign-extended and then reinterpreted as unsigned, so a negative
/// index becomes a very large unsigned index rather than being rejected here.
macro_rules! to_indices_widening {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            // Must agree with the return type of `to_indices`; previously hard-coded to
            // `UInt32Type`, which broke any target other than `UInt32Type`.
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                let cast = self.values().iter().copied().map(|x| x as _).collect();
                PrimitiveArray::new(cast, self.nulls().cloned())
            }
        }
    };
}
912
913to_indices_widening!(UInt8Type, UInt32Type);
914to_indices_widening!(Int8Type, UInt32Type);
915
916to_indices_widening!(UInt16Type, UInt32Type);
917to_indices_widening!(Int16Type, UInt32Type);
918
919to_indices_identity!(UInt32Type);
920to_indices_reinterpret!(Int32Type, UInt32Type);
921
922to_indices_identity!(UInt64Type);
923to_indices_reinterpret!(Int64Type, UInt64Type);
924
925/// Take rows by index from [`RecordBatch`] and returns a new [`RecordBatch`] from those indexes.
926///
927/// This function will call [`take`] on each array of the [`RecordBatch`] and assemble a new [`RecordBatch`].
928///
929/// # Example
930/// ```
931/// # use std::sync::Arc;
932/// # use arrow_array::{StringArray, Int32Array, UInt32Array, RecordBatch};
933/// # use arrow_schema::{DataType, Field, Schema};
934/// # use arrow_select::take::take_record_batch;
935///
936/// let schema = Arc::new(Schema::new(vec![
937///     Field::new("a", DataType::Int32, true),
938///     Field::new("b", DataType::Utf8, true),
939/// ]));
940/// let batch = RecordBatch::try_new(
941///     schema.clone(),
942///     vec![
943///         Arc::new(Int32Array::from_iter_values(0..20)),
944///         Arc::new(StringArray::from_iter_values(
945///             (0..20).map(|i| format!("str-{}", i)),
946///         )),
947///     ],
948/// )
949/// .unwrap();
950///
951/// let indices = UInt32Array::from(vec![1, 5, 10]);
952/// let taken = take_record_batch(&batch, &indices).unwrap();
953///
954/// let expected = RecordBatch::try_new(
955///     schema,
956///     vec![
957///         Arc::new(Int32Array::from(vec![1, 5, 10])),
958///         Arc::new(StringArray::from(vec!["str-1", "str-5", "str-10"])),
959///     ],
960/// )
961/// .unwrap();
962/// assert_eq!(taken, expected);
963/// ```
964pub fn take_record_batch(
965    record_batch: &RecordBatch,
966    indices: &dyn Array,
967) -> Result<RecordBatch, ArrowError> {
968    let columns = record_batch
969        .columns()
970        .iter()
971        .map(|c| take(c, indices, None))
972        .collect::<Result<Vec<_>, _>>()?;
973    RecordBatch::try_new(record_batch.schema(), columns)
974}
975
976#[cfg(test)]
977mod tests {
978    use super::*;
979    use arrow_array::builder::*;
980    use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
981    use arrow_data::ArrayData;
982    use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
983
984    fn test_take_decimal_arrays(
985        data: Vec<Option<i128>>,
986        index: &UInt32Array,
987        options: Option<TakeOptions>,
988        expected_data: Vec<Option<i128>>,
989        precision: &u8,
990        scale: &i8,
991    ) -> Result<(), ArrowError> {
992        let output = data
993            .into_iter()
994            .collect::<Decimal128Array>()
995            .with_precision_and_scale(*precision, *scale)
996            .unwrap();
997
998        let expected = expected_data
999            .into_iter()
1000            .collect::<Decimal128Array>()
1001            .with_precision_and_scale(*precision, *scale)
1002            .unwrap();
1003
1004        let expected = Arc::new(expected) as ArrayRef;
1005        let output = take(&output, index, options).unwrap();
1006        assert_eq!(&output, &expected);
1007        Ok(())
1008    }
1009
1010    fn test_take_boolean_arrays(
1011        data: Vec<Option<bool>>,
1012        index: &UInt32Array,
1013        options: Option<TakeOptions>,
1014        expected_data: Vec<Option<bool>>,
1015    ) {
1016        let output = BooleanArray::from(data);
1017        let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef;
1018        let output = take(&output, index, options).unwrap();
1019        assert_eq!(&output, &expected)
1020    }
1021
1022    fn test_take_primitive_arrays<T>(
1023        data: Vec<Option<T::Native>>,
1024        index: &UInt32Array,
1025        options: Option<TakeOptions>,
1026        expected_data: Vec<Option<T::Native>>,
1027    ) -> Result<(), ArrowError>
1028    where
1029        T: ArrowPrimitiveType,
1030        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1031    {
1032        let output = PrimitiveArray::<T>::from(data);
1033        let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1034        let output = take(&output, index, options)?;
1035        assert_eq!(&output, &expected);
1036        Ok(())
1037    }
1038
1039    fn test_take_primitive_arrays_non_null<T>(
1040        data: Vec<T::Native>,
1041        index: &UInt32Array,
1042        options: Option<TakeOptions>,
1043        expected_data: Vec<Option<T::Native>>,
1044    ) -> Result<(), ArrowError>
1045    where
1046        T: ArrowPrimitiveType,
1047        PrimitiveArray<T>: From<Vec<T::Native>>,
1048        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1049    {
1050        let output = PrimitiveArray::<T>::from(data);
1051        let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1052        let output = take(&output, index, options)?;
1053        assert_eq!(&output, &expected);
1054        Ok(())
1055    }
1056
1057    fn test_take_impl_primitive_arrays<T, I>(
1058        data: Vec<Option<T::Native>>,
1059        index: &PrimitiveArray<I>,
1060        options: Option<TakeOptions>,
1061        expected_data: Vec<Option<T::Native>>,
1062    ) where
1063        T: ArrowPrimitiveType,
1064        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1065        I: ArrowPrimitiveType,
1066    {
1067        let output = PrimitiveArray::<T>::from(data);
1068        let expected = PrimitiveArray::<T>::from(expected_data);
1069        let output = take(&output, index, options).unwrap();
1070        let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
1071        assert_eq!(output, &expected)
1072    }
1073
1074    // create a simple struct for testing purposes
1075    fn create_test_struct(values: Vec<Option<(Option<bool>, Option<i32>)>>) -> StructArray {
1076        let mut struct_builder = StructBuilder::new(
1077            Fields::from(vec![
1078                Field::new("a", DataType::Boolean, true),
1079                Field::new("b", DataType::Int32, true),
1080            ]),
1081            vec![
1082                Box::new(BooleanBuilder::with_capacity(values.len())),
1083                Box::new(Int32Builder::with_capacity(values.len())),
1084            ],
1085        );
1086
1087        for value in values {
1088            struct_builder
1089                .field_builder::<BooleanBuilder>(0)
1090                .unwrap()
1091                .append_option(value.and_then(|v| v.0));
1092            struct_builder
1093                .field_builder::<Int32Builder>(1)
1094                .unwrap()
1095                .append_option(value.and_then(|v| v.1));
1096            struct_builder.append(value.is_some());
1097        }
1098        struct_builder.finish()
1099    }
1100
1101    #[test]
1102    fn test_take_decimal128_non_null_indices() {
1103        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1104        let precision: u8 = 10;
1105        let scale: i8 = 5;
1106        test_take_decimal_arrays(
1107            vec![None, Some(3), Some(5), Some(2), Some(3), None],
1108            &index,
1109            None,
1110            vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1111            &precision,
1112            &scale,
1113        )
1114        .unwrap();
1115    }
1116
1117    #[test]
1118    fn test_take_decimal128() {
1119        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1120        let precision: u8 = 10;
1121        let scale: i8 = 5;
1122        test_take_decimal_arrays(
1123            vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1124            &index,
1125            None,
1126            vec![Some(3), None, Some(1), Some(3), Some(2)],
1127            &precision,
1128            &scale,
1129        )
1130        .unwrap();
1131    }
1132
1133    #[test]
1134    fn test_take_primitive_non_null_indices() {
1135        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1136        test_take_primitive_arrays::<Int8Type>(
1137            vec![None, Some(3), Some(5), Some(2), Some(3), None],
1138            &index,
1139            None,
1140            vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1141        )
1142        .unwrap();
1143    }
1144
1145    #[test]
1146    fn test_take_primitive_non_null_values() {
1147        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1148        test_take_primitive_arrays::<Int8Type>(
1149            vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1150            &index,
1151            None,
1152            vec![Some(3), None, Some(1), Some(3), Some(2)],
1153        )
1154        .unwrap();
1155    }
1156
1157    #[test]
1158    fn test_take_primitive_non_null() {
1159        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1160        test_take_primitive_arrays::<Int8Type>(
1161            vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)],
1162            &index,
1163            None,
1164            vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)],
1165        )
1166        .unwrap();
1167    }
1168
1169    #[test]
1170    fn test_take_primitive_nullable_indices_non_null_values_with_offset() {
1171        let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1172        let index = index.slice(2, 4);
1173        let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1174
1175        assert_eq!(
1176            index,
1177            &UInt32Array::from(vec![Some(2), Some(3), None, None])
1178        );
1179
1180        test_take_primitive_arrays_non_null::<Int64Type>(
1181            vec![0, 10, 20, 30, 40, 50],
1182            index,
1183            None,
1184            vec![Some(20), Some(30), None, None],
1185        )
1186        .unwrap();
1187    }
1188
1189    #[test]
1190    fn test_take_primitive_nullable_indices_nullable_values_with_offset() {
1191        let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1192        let index = index.slice(2, 4);
1193        let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1194
1195        assert_eq!(
1196            index,
1197            &UInt32Array::from(vec![Some(2), Some(3), None, None])
1198        );
1199
1200        test_take_primitive_arrays::<Int64Type>(
1201            vec![None, None, Some(20), Some(30), Some(40), Some(50)],
1202            index,
1203            None,
1204            vec![Some(20), Some(30), None, None],
1205        )
1206        .unwrap();
1207    }
1208
1209    #[test]
1210    fn test_take_primitive() {
1211        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1212
1213        // int8
1214        test_take_primitive_arrays::<Int8Type>(
1215            vec![Some(0), None, Some(2), Some(3), None],
1216            &index,
1217            None,
1218            vec![Some(3), None, None, Some(3), Some(2)],
1219        )
1220        .unwrap();
1221
1222        // int16
1223        test_take_primitive_arrays::<Int16Type>(
1224            vec![Some(0), None, Some(2), Some(3), None],
1225            &index,
1226            None,
1227            vec![Some(3), None, None, Some(3), Some(2)],
1228        )
1229        .unwrap();
1230
1231        // int32
1232        test_take_primitive_arrays::<Int32Type>(
1233            vec![Some(0), None, Some(2), Some(3), None],
1234            &index,
1235            None,
1236            vec![Some(3), None, None, Some(3), Some(2)],
1237        )
1238        .unwrap();
1239
1240        // int64
1241        test_take_primitive_arrays::<Int64Type>(
1242            vec![Some(0), None, Some(2), Some(3), None],
1243            &index,
1244            None,
1245            vec![Some(3), None, None, Some(3), Some(2)],
1246        )
1247        .unwrap();
1248
1249        // uint8
1250        test_take_primitive_arrays::<UInt8Type>(
1251            vec![Some(0), None, Some(2), Some(3), None],
1252            &index,
1253            None,
1254            vec![Some(3), None, None, Some(3), Some(2)],
1255        )
1256        .unwrap();
1257
1258        // uint16
1259        test_take_primitive_arrays::<UInt16Type>(
1260            vec![Some(0), None, Some(2), Some(3), None],
1261            &index,
1262            None,
1263            vec![Some(3), None, None, Some(3), Some(2)],
1264        )
1265        .unwrap();
1266
1267        // uint32
1268        test_take_primitive_arrays::<UInt32Type>(
1269            vec![Some(0), None, Some(2), Some(3), None],
1270            &index,
1271            None,
1272            vec![Some(3), None, None, Some(3), Some(2)],
1273        )
1274        .unwrap();
1275
1276        // int64
1277        test_take_primitive_arrays::<Int64Type>(
1278            vec![Some(0), None, Some(2), Some(-15), None],
1279            &index,
1280            None,
1281            vec![Some(-15), None, None, Some(-15), Some(2)],
1282        )
1283        .unwrap();
1284
1285        // interval_year_month
1286        test_take_primitive_arrays::<IntervalYearMonthType>(
1287            vec![Some(0), None, Some(2), Some(-15), None],
1288            &index,
1289            None,
1290            vec![Some(-15), None, None, Some(-15), Some(2)],
1291        )
1292        .unwrap();
1293
1294        // interval_day_time
1295        let v1 = IntervalDayTime::new(0, 0);
1296        let v2 = IntervalDayTime::new(2, 0);
1297        let v3 = IntervalDayTime::new(-15, 0);
1298        test_take_primitive_arrays::<IntervalDayTimeType>(
1299            vec![Some(v1), None, Some(v2), Some(v3), None],
1300            &index,
1301            None,
1302            vec![Some(v3), None, None, Some(v3), Some(v2)],
1303        )
1304        .unwrap();
1305
1306        // interval_month_day_nano
1307        let v1 = IntervalMonthDayNano::new(0, 0, 0);
1308        let v2 = IntervalMonthDayNano::new(2, 0, 0);
1309        let v3 = IntervalMonthDayNano::new(-15, 0, 0);
1310        test_take_primitive_arrays::<IntervalMonthDayNanoType>(
1311            vec![Some(v1), None, Some(v2), Some(v3), None],
1312            &index,
1313            None,
1314            vec![Some(v3), None, None, Some(v3), Some(v2)],
1315        )
1316        .unwrap();
1317
1318        // duration_second
1319        test_take_primitive_arrays::<DurationSecondType>(
1320            vec![Some(0), None, Some(2), Some(-15), None],
1321            &index,
1322            None,
1323            vec![Some(-15), None, None, Some(-15), Some(2)],
1324        )
1325        .unwrap();
1326
1327        // duration_millisecond
1328        test_take_primitive_arrays::<DurationMillisecondType>(
1329            vec![Some(0), None, Some(2), Some(-15), None],
1330            &index,
1331            None,
1332            vec![Some(-15), None, None, Some(-15), Some(2)],
1333        )
1334        .unwrap();
1335
1336        // duration_microsecond
1337        test_take_primitive_arrays::<DurationMicrosecondType>(
1338            vec![Some(0), None, Some(2), Some(-15), None],
1339            &index,
1340            None,
1341            vec![Some(-15), None, None, Some(-15), Some(2)],
1342        )
1343        .unwrap();
1344
1345        // duration_nanosecond
1346        test_take_primitive_arrays::<DurationNanosecondType>(
1347            vec![Some(0), None, Some(2), Some(-15), None],
1348            &index,
1349            None,
1350            vec![Some(-15), None, None, Some(-15), Some(2)],
1351        )
1352        .unwrap();
1353
1354        // float32
1355        test_take_primitive_arrays::<Float32Type>(
1356            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1357            &index,
1358            None,
1359            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1360        )
1361        .unwrap();
1362
1363        // float64
1364        test_take_primitive_arrays::<Float64Type>(
1365            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1366            &index,
1367            None,
1368            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1369        )
1370        .unwrap();
1371    }
1372
1373    #[test]
1374    fn test_take_preserve_timezone() {
1375        let index = Int64Array::from(vec![Some(0), None]);
1376
1377        let input = TimestampNanosecondArray::from(vec![
1378            1_639_715_368_000_000_000,
1379            1_639_715_368_000_000_000,
1380        ])
1381        .with_timezone("UTC".to_string());
1382        let result = take(&input, &index, None).unwrap();
1383        match result.data_type() {
1384            DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
1385                assert_eq!(tz.clone(), Some("UTC".into()))
1386            }
1387            _ => panic!(),
1388        }
1389    }
1390
1391    #[test]
1392    fn test_take_impl_primitive_with_int64_indices() {
1393        let index = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1394
1395        // int16
1396        test_take_impl_primitive_arrays::<Int16Type, Int64Type>(
1397            vec![Some(0), None, Some(2), Some(3), None],
1398            &index,
1399            None,
1400            vec![Some(3), None, None, Some(3), Some(2)],
1401        );
1402
1403        // int64
1404        test_take_impl_primitive_arrays::<Int64Type, Int64Type>(
1405            vec![Some(0), None, Some(2), Some(-15), None],
1406            &index,
1407            None,
1408            vec![Some(-15), None, None, Some(-15), Some(2)],
1409        );
1410
1411        // uint64
1412        test_take_impl_primitive_arrays::<UInt64Type, Int64Type>(
1413            vec![Some(0), None, Some(2), Some(3), None],
1414            &index,
1415            None,
1416            vec![Some(3), None, None, Some(3), Some(2)],
1417        );
1418
1419        // duration_millisecond
1420        test_take_impl_primitive_arrays::<DurationMillisecondType, Int64Type>(
1421            vec![Some(0), None, Some(2), Some(-15), None],
1422            &index,
1423            None,
1424            vec![Some(-15), None, None, Some(-15), Some(2)],
1425        );
1426
1427        // float32
1428        test_take_impl_primitive_arrays::<Float32Type, Int64Type>(
1429            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1430            &index,
1431            None,
1432            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1433        );
1434    }
1435
1436    #[test]
1437    fn test_take_impl_primitive_with_uint8_indices() {
1438        let index = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1439
1440        // int16
1441        test_take_impl_primitive_arrays::<Int16Type, UInt8Type>(
1442            vec![Some(0), None, Some(2), Some(3), None],
1443            &index,
1444            None,
1445            vec![Some(3), None, None, Some(3), Some(2)],
1446        );
1447
1448        // duration_millisecond
1449        test_take_impl_primitive_arrays::<DurationMillisecondType, UInt8Type>(
1450            vec![Some(0), None, Some(2), Some(-15), None],
1451            &index,
1452            None,
1453            vec![Some(-15), None, None, Some(-15), Some(2)],
1454        );
1455
1456        // float32
1457        test_take_impl_primitive_arrays::<Float32Type, UInt8Type>(
1458            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1459            &index,
1460            None,
1461            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1462        );
1463    }
1464
1465    #[test]
1466    fn test_take_bool() {
1467        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1468        // boolean
1469        test_take_boolean_arrays(
1470            vec![Some(false), None, Some(true), Some(false), None],
1471            &index,
1472            None,
1473            vec![Some(false), None, None, Some(false), Some(true)],
1474        );
1475    }
1476
1477    #[test]
1478    fn test_take_bool_nullable_index() {
1479        // indices where the masked invalid elements would be out of bounds
1480        let index_data = ArrayData::try_new(
1481            DataType::UInt32,
1482            6,
1483            Some(Buffer::from_iter(vec![
1484                false, true, false, true, false, true,
1485            ])),
1486            0,
1487            vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1488            vec![],
1489        )
1490        .unwrap();
1491        let index = UInt32Array::from(index_data);
1492        test_take_boolean_arrays(
1493            vec![Some(true), None, Some(false)],
1494            &index,
1495            None,
1496            vec![None, Some(true), None, None, None, Some(false)],
1497        );
1498    }
1499
1500    #[test]
1501    fn test_take_bool_nullable_index_nonnull_values() {
1502        // indices where the masked invalid elements would be out of bounds
1503        let index_data = ArrayData::try_new(
1504            DataType::UInt32,
1505            6,
1506            Some(Buffer::from_iter(vec![
1507                false, true, false, true, false, true,
1508            ])),
1509            0,
1510            vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1511            vec![],
1512        )
1513        .unwrap();
1514        let index = UInt32Array::from(index_data);
1515        test_take_boolean_arrays(
1516            vec![Some(true), Some(true), Some(false)],
1517            &index,
1518            None,
1519            vec![None, Some(true), None, Some(true), None, Some(false)],
1520        );
1521    }
1522
1523    #[test]
1524    fn test_take_bool_with_offset() {
1525        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]);
1526        let index = index.slice(2, 4);
1527        let index = index
1528            .as_any()
1529            .downcast_ref::<PrimitiveArray<UInt32Type>>()
1530            .unwrap();
1531
1532        // boolean
1533        test_take_boolean_arrays(
1534            vec![Some(false), None, Some(true), Some(false), None],
1535            index,
1536            None,
1537            vec![None, Some(false), Some(true), None],
1538        );
1539    }
1540
1541    fn _test_take_string<'a, K>()
1542    where
1543        K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static,
1544    {
1545        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]);
1546
1547        let array = K::from(vec![
1548            Some("one"),
1549            None,
1550            Some("three"),
1551            Some("four"),
1552            Some("five"),
1553        ]);
1554        let actual = take(&array, &index, None).unwrap();
1555        assert_eq!(actual.len(), index.len());
1556
1557        let actual = actual.as_any().downcast_ref::<K>().unwrap();
1558
1559        let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]);
1560
1561        assert_eq!(actual, &expected);
1562    }
1563
1564    #[test]
1565    fn test_take_string() {
1566        _test_take_string::<StringArray>()
1567    }
1568
1569    #[test]
1570    fn test_take_large_string() {
1571        _test_take_string::<LargeStringArray>()
1572    }
1573
1574    #[test]
1575    fn test_take_slice_string() {
1576        let strings = StringArray::from(vec![Some("hello"), None, Some("world"), None, Some("hi")]);
1577        let indices = Int32Array::from(vec![Some(0), Some(1), None, Some(0), Some(2)]);
1578        let indices_slice = indices.slice(1, 4);
1579        let expected = StringArray::from(vec![None, None, Some("hello"), Some("world")]);
1580        let result = take(&strings, &indices_slice, None).unwrap();
1581        assert_eq!(result.as_ref(), &expected);
1582    }
1583
1584    fn _test_byte_view<T>()
1585    where
1586        T: ByteViewType,
1587        str: AsRef<T::Native>,
1588        T::Native: PartialEq,
1589    {
1590        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]);
1591        let array = {
1592            // ["hello", "world", null, "large payload over 12 bytes", "lulu"]
1593            let mut builder = GenericByteViewBuilder::<T>::new();
1594            builder.append_value("hello");
1595            builder.append_value("world");
1596            builder.append_null();
1597            builder.append_value("large payload over 12 bytes");
1598            builder.append_value("lulu");
1599            builder.finish()
1600        };
1601
1602        let actual = take(&array, &index, None).unwrap();
1603
1604        assert_eq!(actual.len(), index.len());
1605
1606        let expected = {
1607            // ["large payload over 12 bytes", null, "world", "large payload over 12 bytes", "lulu", null]
1608            let mut builder = GenericByteViewBuilder::<T>::new();
1609            builder.append_value("large payload over 12 bytes");
1610            builder.append_null();
1611            builder.append_value("world");
1612            builder.append_value("large payload over 12 bytes");
1613            builder.append_value("lulu");
1614            builder.append_null();
1615            builder.finish()
1616        };
1617
1618        assert_eq!(actual.as_ref(), &expected);
1619    }
1620
1621    #[test]
1622    fn test_take_string_view() {
1623        _test_byte_view::<StringViewType>()
1624    }
1625
1626    #[test]
1627    fn test_take_binary_view() {
1628        _test_byte_view::<BinaryViewType>()
1629    }
1630
1631    macro_rules! test_take_list {
1632        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1633            // Construct a value array, [[0,0,0], [-1,-2,-1], [], [2,3]]
1634            let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
1635            // Construct offsets
1636            let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
1637            let value_offsets = Buffer::from_slice_ref(&value_offsets);
1638            // Construct a list array from the above two
1639            let list_data_type =
1640                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, false)));
1641            let list_data = ArrayData::builder(list_data_type.clone())
1642                .len(4)
1643                .add_buffer(value_offsets)
1644                .add_child_data(value_data)
1645                .build()
1646                .unwrap();
1647            let list_array = $list_array_type::from(list_data);
1648
1649            // index returns: [[2,3], null, [-1,-2,-1], [], [0,0,0]]
1650            let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]);
1651
1652            let a = take(&list_array, &index, None).unwrap();
1653            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1654
1655            // construct a value array with expected results:
1656            // [[2,3], null, [-1,-2,-1], [], [0,0,0]]
1657            let expected_data = Int32Array::from(vec![
1658                Some(2),
1659                Some(3),
1660                Some(-1),
1661                Some(-2),
1662                Some(-1),
1663                Some(0),
1664                Some(0),
1665                Some(0),
1666            ])
1667            .into_data();
1668            // construct offsets
1669            let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8];
1670            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1671            // construct list array from the two
1672            let expected_list_data = ArrayData::builder(list_data_type)
1673                .len(5)
1674                // null buffer remains the same as only the indices have nulls
1675                .nulls(index.nulls().cloned())
1676                .add_buffer(expected_offsets)
1677                .add_child_data(expected_data)
1678                .build()
1679                .unwrap();
1680            let expected_list_array = $list_array_type::from(expected_list_data);
1681
1682            assert_eq!(a, &expected_list_array);
1683        }};
1684    }
1685
1686    macro_rules! test_take_list_with_value_nulls {
1687        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1688            // Construct a value array, [[0,null,0], [-1,-2,3], [null], [5,null]]
1689            let value_data = Int32Array::from(vec![
1690                Some(0),
1691                None,
1692                Some(0),
1693                Some(-1),
1694                Some(-2),
1695                Some(3),
1696                None,
1697                Some(5),
1698                None,
1699            ])
1700            .into_data();
1701            // Construct offsets
1702            let value_offsets: [$offset_type; 5] = [0, 3, 6, 7, 9];
1703            let value_offsets = Buffer::from_slice_ref(&value_offsets);
1704            // Construct a list array from the above two
1705            let list_data_type =
1706                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
1707            let list_data = ArrayData::builder(list_data_type.clone())
1708                .len(4)
1709                .add_buffer(value_offsets)
1710                .null_bit_buffer(Some(Buffer::from([0b11111111])))
1711                .add_child_data(value_data)
1712                .build()
1713                .unwrap();
1714            let list_array = $list_array_type::from(list_data);
1715
1716            // index returns: [[null], null, [-1,-2,3], [2,null], [0,null,0]]
1717            let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
1718
1719            let a = take(&list_array, &index, None).unwrap();
1720            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1721
1722            // construct a value array with expected results:
1723            // [[null], null, [-1,-2,3], [5,null], [0,null,0]]
1724            let expected_data = Int32Array::from(vec![
1725                None,
1726                Some(-1),
1727                Some(-2),
1728                Some(3),
1729                Some(5),
1730                None,
1731                Some(0),
1732                None,
1733                Some(0),
1734            ])
1735            .into_data();
1736            // construct offsets
1737            let expected_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9];
1738            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1739            // construct list array from the two
1740            let expected_list_data = ArrayData::builder(list_data_type)
1741                .len(5)
1742                // null buffer remains the same as only the indices have nulls
1743                .nulls(index.nulls().cloned())
1744                .add_buffer(expected_offsets)
1745                .add_child_data(expected_data)
1746                .build()
1747                .unwrap();
1748            let expected_list_array = $list_array_type::from(expected_list_data);
1749
1750            assert_eq!(a, &expected_list_array);
1751        }};
1752    }
1753
1754    macro_rules! test_take_list_with_nulls {
1755        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1756            // Construct a value array, [[0,null,0], [-1,-2,3], null, [5,null]]
1757            let value_data = Int32Array::from(vec![
1758                Some(0),
1759                None,
1760                Some(0),
1761                Some(-1),
1762                Some(-2),
1763                Some(3),
1764                Some(5),
1765                None,
1766            ])
1767            .into_data();
1768            // Construct offsets
1769            let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
1770            let value_offsets = Buffer::from_slice_ref(&value_offsets);
1771            // Construct a list array from the above two
1772            let list_data_type =
1773                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
1774            let list_data = ArrayData::builder(list_data_type.clone())
1775                .len(4)
1776                .add_buffer(value_offsets)
1777                .null_bit_buffer(Some(Buffer::from([0b11111011])))
1778                .add_child_data(value_data)
1779                .build()
1780                .unwrap();
1781            let list_array = $list_array_type::from(list_data);
1782
1783            // index returns: [null, null, [-1,-2,3], [5,null], [0,null,0]]
1784            let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
1785
1786            let a = take(&list_array, &index, None).unwrap();
1787            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1788
1789            // construct a value array with expected results:
1790            // [null, null, [-1,-2,3], [5,null], [0,null,0]]
1791            let expected_data = Int32Array::from(vec![
1792                Some(-1),
1793                Some(-2),
1794                Some(3),
1795                Some(5),
1796                None,
1797                Some(0),
1798                None,
1799                Some(0),
1800            ])
1801            .into_data();
1802            // construct offsets
1803            let expected_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8];
1804            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1805            // construct list array from the two
1806            let mut null_bits: [u8; 1] = [0; 1];
1807            bit_util::set_bit(&mut null_bits, 2);
1808            bit_util::set_bit(&mut null_bits, 3);
1809            bit_util::set_bit(&mut null_bits, 4);
1810            let expected_list_data = ArrayData::builder(list_data_type)
1811                .len(5)
1812                // null buffer must be recalculated as both values and indices have nulls
1813                .null_bit_buffer(Some(Buffer::from(null_bits)))
1814                .add_buffer(expected_offsets)
1815                .add_child_data(expected_data)
1816                .build()
1817                .unwrap();
1818            let expected_list_array = $list_array_type::from(expected_list_data);
1819
1820            assert_eq!(a, &expected_list_array);
1821        }};
1822    }
1823
1824    fn do_take_fixed_size_list_test<T>(
1825        length: <Int32Type as ArrowPrimitiveType>::Native,
1826        input_data: Vec<Option<Vec<Option<T::Native>>>>,
1827        indices: Vec<<UInt32Type as ArrowPrimitiveType>::Native>,
1828        expected_data: Vec<Option<Vec<Option<T::Native>>>>,
1829    ) where
1830        T: ArrowPrimitiveType,
1831        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1832    {
1833        let indices = UInt32Array::from(indices);
1834
1835        let input_array = FixedSizeListArray::from_iter_primitive::<T, _, _>(input_data, length);
1836
1837        let output = take_fixed_size_list(&input_array, &indices, length as u32).unwrap();
1838
1839        let expected = FixedSizeListArray::from_iter_primitive::<T, _, _>(expected_data, length);
1840
1841        assert_eq!(&output, &expected)
1842    }
1843
1844    #[test]
1845    fn test_take_list() {
1846        test_take_list!(i32, List, ListArray);
1847    }
1848
1849    #[test]
1850    fn test_take_large_list() {
1851        test_take_list!(i64, LargeList, LargeListArray);
1852    }
1853
1854    #[test]
1855    fn test_take_list_with_value_nulls() {
1856        test_take_list_with_value_nulls!(i32, List, ListArray);
1857    }
1858
1859    #[test]
1860    fn test_take_large_list_with_value_nulls() {
1861        test_take_list_with_value_nulls!(i64, LargeList, LargeListArray);
1862    }
1863
1864    #[test]
1865    fn test_test_take_list_with_nulls() {
1866        test_take_list_with_nulls!(i32, List, ListArray);
1867    }
1868
1869    #[test]
1870    fn test_test_take_large_list_with_nulls() {
1871        test_take_list_with_nulls!(i64, LargeList, LargeListArray);
1872    }
1873
1874    #[test]
1875    fn test_take_fixed_size_list() {
1876        do_take_fixed_size_list_test::<Int32Type>(
1877            3,
1878            vec![
1879                Some(vec![None, Some(1), Some(2)]),
1880                Some(vec![Some(3), Some(4), None]),
1881                Some(vec![Some(6), Some(7), Some(8)]),
1882            ],
1883            vec![2, 1, 0],
1884            vec![
1885                Some(vec![Some(6), Some(7), Some(8)]),
1886                Some(vec![Some(3), Some(4), None]),
1887                Some(vec![None, Some(1), Some(2)]),
1888            ],
1889        );
1890
1891        do_take_fixed_size_list_test::<UInt8Type>(
1892            1,
1893            vec![
1894                Some(vec![Some(1)]),
1895                Some(vec![Some(2)]),
1896                Some(vec![Some(3)]),
1897                Some(vec![Some(4)]),
1898                Some(vec![Some(5)]),
1899                Some(vec![Some(6)]),
1900                Some(vec![Some(7)]),
1901                Some(vec![Some(8)]),
1902            ],
1903            vec![2, 7, 0],
1904            vec![
1905                Some(vec![Some(3)]),
1906                Some(vec![Some(8)]),
1907                Some(vec![Some(1)]),
1908            ],
1909        );
1910
1911        do_take_fixed_size_list_test::<UInt64Type>(
1912            3,
1913            vec![
1914                Some(vec![Some(10), Some(11), Some(12)]),
1915                Some(vec![Some(13), Some(14), Some(15)]),
1916                None,
1917                Some(vec![Some(16), Some(17), Some(18)]),
1918            ],
1919            vec![3, 2, 1, 2, 0],
1920            vec![
1921                Some(vec![Some(16), Some(17), Some(18)]),
1922                None,
1923                Some(vec![Some(13), Some(14), Some(15)]),
1924                None,
1925                Some(vec![Some(10), Some(11), Some(12)]),
1926            ],
1927        );
1928    }
1929
1930    #[test]
1931    #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
1932    fn test_take_list_out_of_bounds() {
1933        // Construct a value array, [[0,0,0], [-1,-2,-1], [2,3]]
1934        let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
1935        // Construct offsets
1936        let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]);
1937        // Construct a list array from the above two
1938        let list_data_type =
1939            DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false)));
1940        let list_data = ArrayData::builder(list_data_type)
1941            .len(3)
1942            .add_buffer(value_offsets)
1943            .add_child_data(value_data)
1944            .build()
1945            .unwrap();
1946        let list_array = ListArray::from(list_data);
1947
1948        let index = UInt32Array::from(vec![1000]);
1949
1950        // A panic is expected here since we have not supplied the check_bounds
1951        // option.
1952        take(&list_array, &index, None).unwrap();
1953    }
1954
1955    #[test]
1956    fn test_take_map() {
1957        let values = Int32Array::from(vec![1, 2, 3, 4]);
1958        let array =
1959            MapArray::new_from_strings(vec!["a", "b", "c", "a"].into_iter(), &values, &[0, 3, 4])
1960                .unwrap();
1961
1962        let index = UInt32Array::from(vec![0]);
1963
1964        let result = take(&array, &index, None).unwrap();
1965        let expected: ArrayRef = Arc::new(
1966            MapArray::new_from_strings(
1967                vec!["a", "b", "c"].into_iter(),
1968                &values.slice(0, 3),
1969                &[0, 3],
1970            )
1971            .unwrap(),
1972        );
1973        assert_eq!(&expected, &result);
1974    }
1975
1976    #[test]
1977    fn test_take_struct() {
1978        let array = create_test_struct(vec![
1979            Some((Some(true), Some(42))),
1980            Some((Some(false), Some(28))),
1981            Some((Some(false), Some(19))),
1982            Some((Some(true), Some(31))),
1983            None,
1984        ]);
1985
1986        let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
1987        let actual = take(&array, &index, None).unwrap();
1988        let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
1989        assert_eq!(index.len(), actual.len());
1990        assert_eq!(1, actual.null_count());
1991
1992        let expected = create_test_struct(vec![
1993            Some((Some(true), Some(42))),
1994            Some((Some(true), Some(31))),
1995            Some((Some(false), Some(28))),
1996            Some((Some(true), Some(42))),
1997            Some((Some(false), Some(19))),
1998            None,
1999        ]);
2000
2001        assert_eq!(&expected, actual);
2002
2003        let nulls = NullBuffer::from(&[false, true, false, true, false, true]);
2004        let empty_struct_arr = StructArray::new_empty_fields(6, Some(nulls));
2005        let index = UInt32Array::from(vec![0, 2, 1, 4]);
2006        let actual = take(&empty_struct_arr, &index, None).unwrap();
2007
2008        let expected_nulls = NullBuffer::from(&[false, false, true, false]);
2009        let expected_struct_arr = StructArray::new_empty_fields(4, Some(expected_nulls));
2010        assert_eq!(&expected_struct_arr, actual.as_struct());
2011    }
2012
2013    #[test]
2014    fn test_take_struct_with_null_indices() {
2015        let array = create_test_struct(vec![
2016            Some((Some(true), Some(42))),
2017            Some((Some(false), Some(28))),
2018            Some((Some(false), Some(19))),
2019            Some((Some(true), Some(31))),
2020            None,
2021        ]);
2022
2023        let index = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);
2024        let actual = take(&array, &index, None).unwrap();
2025        let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
2026        assert_eq!(index.len(), actual.len());
2027        assert_eq!(3, actual.null_count()); // 2 because of indices, 1 because of struct array
2028
2029        let expected = create_test_struct(vec![
2030            None,
2031            Some((Some(true), Some(31))),
2032            Some((Some(false), Some(28))),
2033            None,
2034            Some((Some(true), Some(42))),
2035            None,
2036        ]);
2037
2038        assert_eq!(&expected, actual);
2039    }
2040
2041    #[test]
2042    fn test_take_out_of_bounds() {
2043        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]);
2044        let take_opt = TakeOptions { check_bounds: true };
2045
2046        // int64
2047        let result = test_take_primitive_arrays::<Int64Type>(
2048            vec![Some(0), None, Some(2), Some(3), None],
2049            &index,
2050            Some(take_opt),
2051            vec![None],
2052        );
2053        assert!(result.is_err());
2054    }
2055
2056    #[test]
2057    #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
2058    fn test_take_out_of_bounds_panic() {
2059        let index = UInt32Array::from(vec![Some(1000)]);
2060
2061        test_take_primitive_arrays::<Int64Type>(
2062            vec![Some(0), Some(1), Some(2), Some(3)],
2063            &index,
2064            None,
2065            vec![None],
2066        )
2067        .unwrap();
2068    }
2069
2070    #[test]
2071    fn test_null_array_smaller_than_indices() {
2072        let values = NullArray::new(2);
2073        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2074
2075        let result = take(&values, &indices, None).unwrap();
2076        let expected: ArrayRef = Arc::new(NullArray::new(3));
2077        assert_eq!(&result, &expected);
2078    }
2079
2080    #[test]
2081    fn test_null_array_larger_than_indices() {
2082        let values = NullArray::new(5);
2083        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2084
2085        let result = take(&values, &indices, None).unwrap();
2086        let expected: ArrayRef = Arc::new(NullArray::new(3));
2087        assert_eq!(&result, &expected);
2088    }
2089
2090    #[test]
2091    fn test_null_array_indices_out_of_bounds() {
2092        let values = NullArray::new(5);
2093        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2094
2095        let result = take(&values, &indices, Some(TakeOptions { check_bounds: true }));
2096        assert_eq!(
2097            result.unwrap_err().to_string(),
2098            "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries"
2099        );
2100    }
2101
2102    #[test]
2103    fn test_take_dict() {
2104        let mut dict_builder = StringDictionaryBuilder::<Int16Type>::new();
2105
2106        dict_builder.append("foo").unwrap();
2107        dict_builder.append("bar").unwrap();
2108        dict_builder.append("").unwrap();
2109        dict_builder.append_null();
2110        dict_builder.append("foo").unwrap();
2111        dict_builder.append("bar").unwrap();
2112        dict_builder.append("bar").unwrap();
2113        dict_builder.append("foo").unwrap();
2114
2115        let array = dict_builder.finish();
2116        let dict_values = array.values().clone();
2117        let dict_values = dict_values.as_any().downcast_ref::<StringArray>().unwrap();
2118
2119        let indices = UInt32Array::from(vec![
2120            Some(0), // first "foo"
2121            Some(7), // last "foo"
2122            None,    // null index should return null
2123            Some(5), // second "bar"
2124            Some(6), // another "bar"
2125            Some(2), // empty string
2126            Some(3), // input is null at this index
2127        ]);
2128
2129        let result = take(&array, &indices, None).unwrap();
2130        let result = result
2131            .as_any()
2132            .downcast_ref::<DictionaryArray<Int16Type>>()
2133            .unwrap();
2134
2135        let result_values: StringArray = result.values().to_data().into();
2136
2137        // dictionary values should stay the same
2138        let expected_values = StringArray::from(vec!["foo", "bar", ""]);
2139        assert_eq!(&expected_values, dict_values);
2140        assert_eq!(&expected_values, &result_values);
2141
2142        let expected_keys = Int16Array::from(vec![
2143            Some(0),
2144            Some(0),
2145            None,
2146            Some(1),
2147            Some(1),
2148            Some(2),
2149            None,
2150        ]);
2151        assert_eq!(result.keys(), &expected_keys);
2152    }
2153
2154    fn build_generic_list<S, T>(data: Vec<Option<Vec<T::Native>>>) -> GenericListArray<S>
2155    where
2156        S: OffsetSizeTrait + 'static,
2157        T: ArrowPrimitiveType,
2158        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2159    {
2160        GenericListArray::from_iter_primitive::<T, _, _>(
2161            data.iter()
2162                .map(|x| x.as_ref().map(|x| x.iter().map(|x| Some(*x)))),
2163        )
2164    }
2165
2166    #[test]
2167    fn test_take_value_index_from_list() {
2168        let list = build_generic_list::<i32, Int32Type>(vec![
2169            Some(vec![0, 1]),
2170            Some(vec![2, 3, 4]),
2171            Some(vec![5, 6, 7, 8, 9]),
2172        ]);
2173        let indices = UInt32Array::from(vec![2, 0]);
2174
2175        let (indexed, offsets, null_buf) = take_value_indices_from_list(&list, &indices).unwrap();
2176
2177        assert_eq!(indexed, Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
2178        assert_eq!(offsets, vec![0, 5, 7]);
2179        assert_eq!(null_buf.as_slice(), &[0b11111111]);
2180    }
2181
2182    #[test]
2183    fn test_take_value_index_from_large_list() {
2184        let list = build_generic_list::<i64, Int32Type>(vec![
2185            Some(vec![0, 1]),
2186            Some(vec![2, 3, 4]),
2187            Some(vec![5, 6, 7, 8, 9]),
2188        ]);
2189        let indices = UInt32Array::from(vec![2, 0]);
2190
2191        let (indexed, offsets, null_buf) =
2192            take_value_indices_from_list::<_, Int64Type>(&list, &indices).unwrap();
2193
2194        assert_eq!(indexed, Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
2195        assert_eq!(offsets, vec![0, 5, 7]);
2196        assert_eq!(null_buf.as_slice(), &[0b11111111]);
2197    }
2198
2199    #[test]
2200    fn test_take_runs() {
2201        let logical_array: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];
2202
2203        let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
2204        builder.extend(logical_array.into_iter().map(Some));
2205        let run_array = builder.finish();
2206
2207        let take_indices: PrimitiveArray<Int32Type> =
2208            vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect();
2209
2210        let take_out = take_run(&run_array, &take_indices).unwrap();
2211
2212        assert_eq!(take_out.len(), 7);
2213        assert_eq!(take_out.run_ends().len(), 7);
2214        assert_eq!(take_out.run_ends().values(), &[1_i32, 3, 4, 5, 7]);
2215
2216        let take_out_values = take_out.values().as_primitive::<Int32Type>();
2217        assert_eq!(take_out_values.values(), &[2, 2, 2, 2, 1]);
2218    }
2219
2220    #[test]
2221    fn test_take_value_index_from_fixed_list() {
2222        let list = FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
2223            vec![
2224                Some(vec![Some(1), Some(2), None]),
2225                Some(vec![Some(4), None, Some(6)]),
2226                None,
2227                Some(vec![None, Some(8), Some(9)]),
2228            ],
2229            3,
2230        );
2231
2232        let indices = UInt32Array::from(vec![2, 1, 0]);
2233        let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2234
2235        assert_eq!(indexed, UInt32Array::from(vec![6, 7, 8, 3, 4, 5, 0, 1, 2]));
2236
2237        let indices = UInt32Array::from(vec![3, 2, 1, 2, 0]);
2238        let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2239
2240        assert_eq!(
2241            indexed,
2242            UInt32Array::from(vec![9, 10, 11, 6, 7, 8, 3, 4, 5, 6, 7, 8, 0, 1, 2])
2243        );
2244    }
2245
2246    #[test]
2247    fn test_take_null_indices() {
2248        // Build indices with values that are out of bounds, but masked by null mask
2249        let indices = Int32Array::new(
2250            vec![1, 2, 400, 400].into(),
2251            Some(NullBuffer::from(vec![true, true, false, false])),
2252        );
2253        let values = Int32Array::from(vec![1, 23, 4, 5]);
2254        let r = take(&values, &indices, None).unwrap();
2255        let values = r
2256            .as_primitive::<Int32Type>()
2257            .into_iter()
2258            .collect::<Vec<_>>();
2259        assert_eq!(&values, &[Some(23), Some(4), None, None])
2260    }
2261
2262    #[test]
2263    fn test_take_fixed_size_list_null_indices() {
2264        let indices = Int32Array::from_iter([Some(0), None]);
2265        let values = Arc::new(Int32Array::from(vec![0, 1, 2, 3]));
2266        let arr_field = Arc::new(Field::new_list_field(values.data_type().clone(), true));
2267        let values = FixedSizeListArray::try_new(arr_field, 2, values, None).unwrap();
2268
2269        let r = take(&values, &indices, None).unwrap();
2270        let values = r
2271            .as_fixed_size_list()
2272            .values()
2273            .as_primitive::<Int32Type>()
2274            .into_iter()
2275            .collect::<Vec<_>>();
2276        assert_eq!(values, &[Some(0), Some(1), None, None])
2277    }
2278
2279    #[test]
2280    fn test_take_bytes_null_indices() {
2281        let indices = Int32Array::new(
2282            vec![0, 1, 400, 400].into(),
2283            Some(NullBuffer::from_iter(vec![true, true, false, false])),
2284        );
2285        let values = StringArray::from(vec![Some("foo"), None]);
2286        let r = take(&values, &indices, None).unwrap();
2287        let values = r.as_string::<i32>().iter().collect::<Vec<_>>();
2288        assert_eq!(&values, &[Some("foo"), None, None, None])
2289    }
2290
2291    #[test]
2292    fn test_take_union_sparse() {
2293        let structs = create_test_struct(vec![
2294            Some((Some(true), Some(42))),
2295            Some((Some(false), Some(28))),
2296            Some((Some(false), Some(19))),
2297            Some((Some(true), Some(31))),
2298            None,
2299        ]);
2300        let strings = StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);
2301        let type_ids = [1; 5].into_iter().collect::<ScalarBuffer<i8>>();
2302
2303        let union_fields = [
2304            (
2305                0,
2306                Arc::new(Field::new("f1", structs.data_type().clone(), true)),
2307            ),
2308            (
2309                1,
2310                Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2311            ),
2312        ]
2313        .into_iter()
2314        .collect();
2315        let children = vec![Arc::new(structs) as Arc<dyn Array>, Arc::new(strings)];
2316        let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();
2317
2318        let indices = vec![0, 3, 1, 0, 2, 4];
2319        let index = UInt32Array::from(indices.clone());
2320        let actual = take(&array, &index, None).unwrap();
2321        let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2322        let strings = actual.child(1);
2323        let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();
2324
2325        let actual = strings.iter().collect::<Vec<_>>();
2326        let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
2327        assert_eq!(expected, actual);
2328    }
2329
2330    #[test]
2331    fn test_take_union_dense() {
2332        let type_ids = vec![0, 1, 1, 0, 0, 1, 0];
2333        let offsets = vec![0, 0, 1, 1, 2, 2, 3];
2334        let ints = vec![10, 20, 30, 40];
2335        let strings = vec![Some("a"), None, Some("c"), Some("d")];
2336
2337        let indices = vec![0, 3, 1, 0, 2, 4];
2338
2339        let taken_type_ids = vec![0, 0, 1, 0, 1, 0];
2340        let taken_offsets = vec![0, 1, 0, 2, 1, 3];
2341        let taken_ints = vec![10, 20, 10, 30];
2342        let taken_strings = vec![Some("a"), None];
2343
2344        let type_ids = <ScalarBuffer<i8>>::from(type_ids);
2345        let offsets = <ScalarBuffer<i32>>::from(offsets);
2346        let ints = UInt32Array::from(ints);
2347        let strings = StringArray::from(strings);
2348
2349        let union_fields = [
2350            (
2351                0,
2352                Arc::new(Field::new("f1", ints.data_type().clone(), true)),
2353            ),
2354            (
2355                1,
2356                Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2357            ),
2358        ]
2359        .into_iter()
2360        .collect();
2361
2362        let array = UnionArray::try_new(
2363            union_fields,
2364            type_ids,
2365            Some(offsets),
2366            vec![Arc::new(ints), Arc::new(strings)],
2367        )
2368        .unwrap();
2369
2370        let index = UInt32Array::from(indices);
2371
2372        let actual = take(&array, &index, None).unwrap();
2373        let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2374
2375        assert_eq!(actual.offsets(), Some(&ScalarBuffer::from(taken_offsets)));
2376        assert_eq!(actual.type_ids(), &ScalarBuffer::from(taken_type_ids));
2377        assert_eq!(
2378            UInt32Array::from(actual.child(0).to_data()),
2379            UInt32Array::from(taken_ints)
2380        );
2381        assert_eq!(
2382            StringArray::from(actual.child(1).to_data()),
2383            StringArray::from(taken_strings)
2384        );
2385    }
2386
2387    #[test]
2388    fn test_take_union_dense_using_builder() {
2389        let mut builder = UnionBuilder::new_dense();
2390
2391        builder.append::<Int32Type>("a", 1).unwrap();
2392        builder.append::<Float64Type>("b", 3.0).unwrap();
2393        builder.append::<Int32Type>("a", 4).unwrap();
2394        builder.append::<Int32Type>("a", 5).unwrap();
2395        builder.append::<Float64Type>("b", 2.0).unwrap();
2396
2397        let union = builder.build().unwrap();
2398
2399        let indices = UInt32Array::from(vec![2, 0, 1, 2]);
2400
2401        let mut builder = UnionBuilder::new_dense();
2402
2403        builder.append::<Int32Type>("a", 4).unwrap();
2404        builder.append::<Int32Type>("a", 1).unwrap();
2405        builder.append::<Float64Type>("b", 3.0).unwrap();
2406        builder.append::<Int32Type>("a", 4).unwrap();
2407
2408        let taken = builder.build().unwrap();
2409
2410        assert_eq!(
2411            taken.to_data(),
2412            take(&union, &indices, None).unwrap().to_data()
2413        );
2414    }
2415
2416    #[test]
2417    fn test_take_union_dense_all_match_issue_6206() {
2418        let fields = UnionFields::new(vec![0], vec![Field::new("a", DataType::Int64, false)]);
2419        let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
2420
2421        let array = UnionArray::try_new(
2422            fields,
2423            ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]),
2424            Some(ScalarBuffer::from_iter(0_i32..5)),
2425            vec![ints],
2426        )
2427        .unwrap();
2428
2429        let indicies = Int64Array::from(vec![0, 2, 4]);
2430        let array = take(&array, &indicies, None).unwrap();
2431        assert_eq!(array.len(), 3);
2432    }
2433
2434    #[test]
2435    fn test_take_bytes_offset_overflow() {
2436        let indices = Int32Array::from(vec![0; (i32::MAX >> 4) as usize]);
2437        let text = ('a'..='z').collect::<String>();
2438        let values = StringArray::from(vec![Some(text.clone())]);
2439        assert!(matches!(
2440            take(&values, &indices, None),
2441            Err(ArrowError::OffsetOverflowError(_))
2442        ));
2443    }
2444}