arrow_select/
take.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines take kernel for [Array]
19
20use std::sync::Arc;
21
22use arrow_array::builder::{BufferBuilder, UInt32Builder};
23use arrow_array::cast::AsArray;
24use arrow_array::types::*;
25use arrow_array::*;
26use arrow_buffer::{
27    bit_util, ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer,
28    ScalarBuffer,
29};
30use arrow_data::ArrayDataBuilder;
31use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
32
33use num::{One, Zero};
34
35/// Take elements by index from [Array], creating a new [Array] from those indexes.
36///
37/// ```text
38/// ┌─────────────────┐      ┌─────────┐                              ┌─────────────────┐
39/// │        A        │      │    0    │                              │        A        │
40/// ├─────────────────┤      ├─────────┤                              ├─────────────────┤
41/// │        D        │      │    2    │                              │        B        │
42/// ├─────────────────┤      ├─────────┤   take(values, indices)      ├─────────────────┤
43/// │        B        │      │    3    │ ─────────────────────────▶   │        C        │
44/// ├─────────────────┤      ├─────────┤                              ├─────────────────┤
45/// │        C        │      │    1    │                              │        D        │
46/// ├─────────────────┤      └─────────┘                              └─────────────────┘
47/// │        E        │
48/// └─────────────────┘
49///    values array          indices array                              result
50/// ```
51///
52/// For selecting values by index from multiple arrays see [`crate::interleave`]
53///
54/// Note that this kernel, similar to other kernels in this crate,
55/// will avoid allocating where not necessary. Consequently
56/// the returned array may share buffers with the inputs
57///
58/// # Errors
59/// This function errors whenever:
60/// * An index cannot be casted to `usize` (typically 32 bit architectures)
61/// * An index is out of bounds and `options` is set to check bounds.
62///
63/// # Safety
64///
65/// When `options` is not set to check bounds, taking indexes after `len` will panic.
66///
67/// # Examples
68/// ```
69/// # use arrow_array::{StringArray, UInt32Array, cast::AsArray};
70/// # use arrow_select::take::take;
71/// let values = StringArray::from(vec!["zero", "one", "two"]);
72///
73/// // Take items at index 2, and 1:
74/// let indices = UInt32Array::from(vec![2, 1]);
75/// let taken = take(&values, &indices, None).unwrap();
76/// let taken = taken.as_string::<i32>();
77///
78/// assert_eq!(*taken, StringArray::from(vec!["two", "one"]));
79/// ```
80pub fn take(
81    values: &dyn Array,
82    indices: &dyn Array,
83    options: Option<TakeOptions>,
84) -> Result<ArrayRef, ArrowError> {
85    let options = options.unwrap_or_default();
86    downcast_integer_array!(
87        indices => {
88            if options.check_bounds {
89                check_bounds(values.len(), indices)?;
90            }
91            let indices = indices.to_indices();
92            take_impl(values, &indices)
93        },
94        d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
95    )
96}
97
98/// For each [ArrayRef] in the [`Vec<ArrayRef>`], take elements by index and create a new
99/// [`Vec<ArrayRef>`] from those indices.
100///
101/// ```text
102/// ┌────────┬────────┐
103/// │        │        │           ┌────────┐                                ┌────────┬────────┐
104/// │   A    │   1    │           │        │                                │        │        │
105/// ├────────┼────────┤           │   0    │                                │   A    │   1    │
106/// │        │        │           ├────────┤                                ├────────┼────────┤
107/// │   D    │   4    │           │        │                                │        │        │
108/// ├────────┼────────┤           │   2    │  take_arrays(values,indices)   │   B    │   2    │
109/// │        │        │           ├────────┤                                ├────────┼────────┤
110/// │   B    │   2    │           │        │  ───────────────────────────►  │        │        │
111/// ├────────┼────────┤           │   3    │                                │   C    │   3    │
112/// │        │        │           ├────────┤                                ├────────┼────────┤
113/// │   C    │   3    │           │        │                                │        │        │
114/// ├────────┼────────┤           │   1    │                                │   D    │   4    │
115/// │        │        │           └────────┘                                └────────┼────────┘
116/// │   E    │   5    │
117/// └────────┴────────┘
118///    values arrays             indices array                                      result
119/// ```
120///
121/// # Errors
122/// This function errors whenever:
123/// * An index cannot be casted to `usize` (typically 32 bit architectures)
124/// * An index is out of bounds and `options` is set to check bounds.
125///
126/// # Safety
127///
128/// When `options` is not set to check bounds, taking indexes after `len` will panic.
129///
130/// # Examples
131/// ```
132/// # use std::sync::Arc;
133/// # use arrow_array::{StringArray, UInt32Array, cast::AsArray};
134/// # use arrow_select::take::{take, take_arrays};
135/// let string_values = Arc::new(StringArray::from(vec!["zero", "one", "two"]));
136/// let values = Arc::new(UInt32Array::from(vec![0, 1, 2]));
137///
138/// // Take items at index 2, and 1:
139/// let indices = UInt32Array::from(vec![2, 1]);
140/// let taken_arrays = take_arrays(&[string_values, values], &indices, None).unwrap();
141/// let taken_string = taken_arrays[0].as_string::<i32>();
142/// assert_eq!(*taken_string, StringArray::from(vec!["two", "one"]));
143/// let taken_values = taken_arrays[1].as_primitive();
144/// assert_eq!(*taken_values, UInt32Array::from(vec![2, 1]));
145/// ```
146pub fn take_arrays(
147    arrays: &[ArrayRef],
148    indices: &dyn Array,
149    options: Option<TakeOptions>,
150) -> Result<Vec<ArrayRef>, ArrowError> {
151    arrays
152        .iter()
153        .map(|array| take(array.as_ref(), indices, options.clone()))
154        .collect()
155}
156
157/// Verifies that the non-null values of `indices` are all `< len`
158fn check_bounds<T: ArrowPrimitiveType>(
159    len: usize,
160    indices: &PrimitiveArray<T>,
161) -> Result<(), ArrowError> {
162    if indices.null_count() > 0 {
163        indices.iter().flatten().try_for_each(|index| {
164            let ix = index
165                .to_usize()
166                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
167            if ix >= len {
168                return Err(ArrowError::ComputeError(format!(
169                    "Array index out of bounds, cannot get item at index {ix} from {len} entries"
170                )));
171            }
172            Ok(())
173        })
174    } else {
175        indices.values().iter().try_for_each(|index| {
176            let ix = index
177                .to_usize()
178                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
179            if ix >= len {
180                return Err(ArrowError::ComputeError(format!(
181                    "Array index out of bounds, cannot get item at index {ix} from {len} entries"
182                )));
183            }
184            Ok(())
185        })
186    }
187}
188
189#[inline(never)]
190fn take_impl<IndexType: ArrowPrimitiveType>(
191    values: &dyn Array,
192    indices: &PrimitiveArray<IndexType>,
193) -> Result<ArrayRef, ArrowError> {
194    downcast_primitive_array! {
195        values => Ok(Arc::new(take_primitive(values, indices)?)),
196        DataType::Boolean => {
197            let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
198            Ok(Arc::new(take_boolean(values, indices)))
199        }
200        DataType::Utf8 => {
201            Ok(Arc::new(take_bytes(values.as_string::<i32>(), indices)?))
202        }
203        DataType::LargeUtf8 => {
204            Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)?))
205        }
206        DataType::Utf8View => {
207            Ok(Arc::new(take_byte_view(values.as_string_view(), indices)?))
208        }
209        DataType::List(_) => {
210            Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?))
211        }
212        DataType::LargeList(_) => {
213            Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?))
214        }
215        DataType::FixedSizeList(_, length) => {
216            let values = values
217                .as_any()
218                .downcast_ref::<FixedSizeListArray>()
219                .unwrap();
220            Ok(Arc::new(take_fixed_size_list(
221                values,
222                indices,
223                *length as u32,
224            )?))
225        }
226        DataType::Map(_, _) => {
227            let list_arr = ListArray::from(values.as_map().clone());
228            let list_data = take_list::<_, Int32Type>(&list_arr, indices)?;
229            let builder = list_data.into_data().into_builder().data_type(values.data_type().clone());
230            Ok(Arc::new(MapArray::from(unsafe { builder.build_unchecked() })))
231        }
232        DataType::Struct(fields) => {
233            let array: &StructArray = values.as_struct();
234            let arrays  = array
235                .columns()
236                .iter()
237                .map(|a| take_impl(a.as_ref(), indices))
238                .collect::<Result<Vec<ArrayRef>, _>>()?;
239            let fields: Vec<(FieldRef, ArrayRef)> =
240                fields.iter().cloned().zip(arrays).collect();
241
242            // Create the null bit buffer.
243            let is_valid: Buffer = indices
244                .iter()
245                .map(|index| {
246                    if let Some(index) = index {
247                        array.is_valid(index.to_usize().unwrap())
248                    } else {
249                        false
250                    }
251                })
252                .collect();
253
254            if fields.is_empty() {
255                let nulls = NullBuffer::new(BooleanBuffer::new(is_valid, 0, indices.len()));
256                Ok(Arc::new(StructArray::new_empty_fields(indices.len(), Some(nulls))))
257            } else {
258                Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
259            }
260        }
261        DataType::Dictionary(_, _) => downcast_dictionary_array! {
262            values => Ok(Arc::new(take_dict(values, indices)?)),
263            t => unimplemented!("Take not supported for dictionary type {:?}", t)
264        }
265        DataType::RunEndEncoded(_, _) => downcast_run_array! {
266            values => Ok(Arc::new(take_run(values, indices)?)),
267            t => unimplemented!("Take not supported for run type {:?}", t)
268        }
269        DataType::Binary => {
270            Ok(Arc::new(take_bytes(values.as_binary::<i32>(), indices)?))
271        }
272        DataType::LargeBinary => {
273            Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?))
274        }
275        DataType::BinaryView => {
276            Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)?))
277        }
278        DataType::FixedSizeBinary(size) => {
279            let values = values
280                .as_any()
281                .downcast_ref::<FixedSizeBinaryArray>()
282                .unwrap();
283            Ok(Arc::new(take_fixed_size_binary(values, indices, *size)?))
284        }
285        DataType::Null => {
286            // Take applied to a null array produces a null array.
287            if values.len() >= indices.len() {
288                // If the existing null array is as big as the indices, we can use a slice of it
289                // to avoid allocating a new null array.
290                Ok(values.slice(0, indices.len()))
291            } else {
292                // If the existing null array isn't big enough, create a new one.
293                Ok(new_null_array(&DataType::Null, indices.len()))
294            }
295        }
296        DataType::Union(fields, UnionMode::Sparse) => {
297            let mut children = Vec::with_capacity(fields.len());
298            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
299            let type_ids = take_native(values.type_ids(), indices);
300            for (type_id, _field) in fields.iter() {
301                let values = values.child(type_id);
302                let values = take_impl(values, indices)?;
303                children.push(values);
304            }
305            let array = UnionArray::try_new(fields.clone(), type_ids, None, children)?;
306            Ok(Arc::new(array))
307        }
308        DataType::Union(fields, UnionMode::Dense) => {
309            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
310
311            let type_ids = <PrimitiveArray<Int8Type>>::new(take_native(values.type_ids(), indices), None);
312            let offsets = <PrimitiveArray<Int32Type>>::new(take_native(values.offsets().unwrap(), indices), None);
313
314            let children = fields.iter()
315                .map(|(field_type_id, _)| {
316                    let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);
317
318                    let indices = crate::filter::filter(&offsets, &mask)?;
319
320                    let values = values.child(field_type_id);
321
322                    take_impl(values, indices.as_primitive::<Int32Type>())
323                })
324                .collect::<Result<_, _>>()?;
325
326            let mut child_offsets = [0; 128];
327
328            let offsets = type_ids.values()
329                .iter()
330                .map(|&i| {
331                    let offset = child_offsets[i as usize];
332
333                    child_offsets[i as usize] += 1;
334
335                    offset
336                })
337                .collect();
338
339            let (_, type_ids, _) = type_ids.into_parts();
340
341            let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)?;
342
343            Ok(Arc::new(array))
344        }
345        t => unimplemented!("Take not supported for data type {:?}", t)
346    }
347}
348
/// Configures the behavior of `take`
#[derive(Debug, Clone, Default)]
pub struct TakeOptions {
    /// Whether the indices are validated against the bounds of the values before taking.
    ///
    /// When `true`, out of bounds indices are reported as an `ArrowError`.
    /// When `false` (the derived default), out of bounds indices make the kernel panic.
    pub check_bounds: bool,
}
357
358#[inline(always)]
359fn maybe_usize<I: ArrowNativeType>(index: I) -> Result<usize, ArrowError> {
360    index
361        .to_usize()
362        .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))
363}
364
365/// `take` implementation for all primitive arrays
366///
367/// This checks if an `indices` slot is populated, and gets the value from `values`
368///  as the populated index.
369/// If the `indices` slot is null, a null value is returned.
370/// For example, given:
371///     values:  [1, 2, 3, null, 5]
372///     indices: [0, null, 4, 3]
373/// The result is: [1 (slot 0), null (null slot), 5 (slot 4), null (slot 3)]
374fn take_primitive<T, I>(
375    values: &PrimitiveArray<T>,
376    indices: &PrimitiveArray<I>,
377) -> Result<PrimitiveArray<T>, ArrowError>
378where
379    T: ArrowPrimitiveType,
380    I: ArrowPrimitiveType,
381{
382    let values_buf = take_native(values.values(), indices);
383    let nulls = take_nulls(values.nulls(), indices);
384    Ok(PrimitiveArray::new(values_buf, nulls).with_data_type(values.data_type().clone()))
385}
386
387#[inline(never)]
388fn take_nulls<I: ArrowPrimitiveType>(
389    values: Option<&NullBuffer>,
390    indices: &PrimitiveArray<I>,
391) -> Option<NullBuffer> {
392    match values.filter(|n| n.null_count() > 0) {
393        Some(n) => {
394            let buffer = take_bits(n.inner(), indices);
395            Some(NullBuffer::new(buffer)).filter(|n| n.null_count() > 0)
396        }
397        None => indices.nulls().cloned(),
398    }
399}
400
401#[inline(never)]
402fn take_native<T: ArrowNativeType, I: ArrowPrimitiveType>(
403    values: &[T],
404    indices: &PrimitiveArray<I>,
405) -> ScalarBuffer<T> {
406    match indices.nulls().filter(|n| n.null_count() > 0) {
407        Some(n) => indices
408            .values()
409            .iter()
410            .enumerate()
411            .map(|(idx, index)| match values.get(index.as_usize()) {
412                Some(v) => *v,
413                None => match n.is_null(idx) {
414                    true => T::default(),
415                    false => panic!("Out-of-bounds index {index:?}"),
416                },
417            })
418            .collect(),
419        None => indices
420            .values()
421            .iter()
422            .map(|index| values[index.as_usize()])
423            .collect(),
424    }
425}
426
427#[inline(never)]
428fn take_bits<I: ArrowPrimitiveType>(
429    values: &BooleanBuffer,
430    indices: &PrimitiveArray<I>,
431) -> BooleanBuffer {
432    let len = indices.len();
433
434    match indices.nulls().filter(|n| n.null_count() > 0) {
435        Some(nulls) => {
436            let mut output_buffer = MutableBuffer::new_null(len);
437            let output_slice = output_buffer.as_slice_mut();
438            nulls.valid_indices().for_each(|idx| {
439                if values.value(indices.value(idx).as_usize()) {
440                    bit_util::set_bit(output_slice, idx);
441                }
442            });
443            BooleanBuffer::new(output_buffer.into(), 0, len)
444        }
445        None => {
446            BooleanBuffer::collect_bool(len, |idx: usize| {
447                // SAFETY: idx<indices.len()
448                values.value(unsafe { indices.value_unchecked(idx).as_usize() })
449            })
450        }
451    }
452}
453
454/// `take` implementation for boolean arrays
455fn take_boolean<IndexType: ArrowPrimitiveType>(
456    values: &BooleanArray,
457    indices: &PrimitiveArray<IndexType>,
458) -> BooleanArray {
459    let val_buf = take_bits(values.values(), indices);
460    let null_buf = take_nulls(values.nulls(), indices);
461    BooleanArray::new(val_buf, null_buf)
462}
463
464/// `take` implementation for string arrays
465fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
466    array: &GenericByteArray<T>,
467    indices: &PrimitiveArray<IndexType>,
468) -> Result<GenericByteArray<T>, ArrowError> {
469    let mut offsets = Vec::with_capacity(indices.len() + 1);
470    offsets.push(T::Offset::default());
471
472    let input_offsets = array.value_offsets();
473    let mut capacity = 0;
474    let nulls = take_nulls(array.nulls(), indices);
475
476    let (offsets, values) = if array.null_count() == 0 && indices.null_count() == 0 {
477        offsets.extend(indices.values().iter().map(|index| {
478            let index = index.as_usize();
479            capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
480            T::Offset::from_usize(capacity).expect("overflow")
481        }));
482        let mut values = Vec::with_capacity(capacity);
483
484        for index in indices.values() {
485            values.extend_from_slice(array.value(index.as_usize()).as_ref());
486        }
487        (offsets, values)
488    } else if indices.null_count() == 0 {
489        offsets.extend(indices.values().iter().map(|index| {
490            let index = index.as_usize();
491            if array.is_valid(index) {
492                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
493            }
494            T::Offset::from_usize(capacity).expect("overflow")
495        }));
496        let mut values = Vec::with_capacity(capacity);
497
498        for index in indices.values() {
499            let index = index.as_usize();
500            if array.is_valid(index) {
501                values.extend_from_slice(array.value(index).as_ref());
502            }
503        }
504        (offsets, values)
505    } else if array.null_count() == 0 {
506        offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
507            let index = index.as_usize();
508            if indices.is_valid(i) {
509                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
510            }
511            T::Offset::from_usize(capacity).expect("overflow")
512        }));
513        let mut values = Vec::with_capacity(capacity);
514
515        for (i, index) in indices.values().iter().enumerate() {
516            if indices.is_valid(i) {
517                values.extend_from_slice(array.value(index.as_usize()).as_ref());
518            }
519        }
520        (offsets, values)
521    } else {
522        let nulls = nulls.as_ref().unwrap();
523        offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
524            let index = index.as_usize();
525            if nulls.is_valid(i) {
526                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
527            }
528            T::Offset::from_usize(capacity).expect("overflow")
529        }));
530        let mut values = Vec::with_capacity(capacity);
531
532        for (i, index) in indices.values().iter().enumerate() {
533            // check index is valid before using index. The value in
534            // NULL index slots may not be within bounds of array
535            let index = index.as_usize();
536            if nulls.is_valid(i) {
537                values.extend_from_slice(array.value(index).as_ref());
538            }
539        }
540        (offsets, values)
541    };
542
543    T::Offset::from_usize(values.len()).ok_or(ArrowError::ComputeError(format!(
544        "Offset overflow for {}BinaryArray: {}",
545        T::Offset::PREFIX,
546        values.len()
547    )))?;
548
549    let array = unsafe {
550        let offsets = OffsetBuffer::new_unchecked(offsets.into());
551        GenericByteArray::<T>::new_unchecked(offsets, values.into(), nulls)
552    };
553
554    Ok(array)
555}
556
557/// `take` implementation for byte view arrays
558fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>(
559    array: &GenericByteViewArray<T>,
560    indices: &PrimitiveArray<IndexType>,
561) -> Result<GenericByteViewArray<T>, ArrowError> {
562    let new_views = take_native(array.views(), indices);
563    let new_nulls = take_nulls(array.nulls(), indices);
564    // Safety:  array.views was valid, and take_native copies only valid values, and verifies bounds
565    Ok(unsafe {
566        GenericByteViewArray::new_unchecked(new_views, array.data_buffers().to_vec(), new_nulls)
567    })
568}
569
570/// `take` implementation for list arrays
571///
572/// Calculates the index and indexed offset for the inner array,
573/// applying `take` on the inner array, then reconstructing a list array
574/// with the indexed offsets
575fn take_list<IndexType, OffsetType>(
576    values: &GenericListArray<OffsetType::Native>,
577    indices: &PrimitiveArray<IndexType>,
578) -> Result<GenericListArray<OffsetType::Native>, ArrowError>
579where
580    IndexType: ArrowPrimitiveType,
581    OffsetType: ArrowPrimitiveType,
582    OffsetType::Native: OffsetSizeTrait,
583    PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
584{
585    // TODO: Some optimizations can be done here such as if it is
586    // taking the whole list or a contiguous sublist
587    let (list_indices, offsets, null_buf) =
588        take_value_indices_from_list::<IndexType, OffsetType>(values, indices)?;
589
590    let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices)?;
591    let value_offsets = Buffer::from_vec(offsets);
592    // create a new list with taken data and computed null information
593    let list_data = ArrayDataBuilder::new(values.data_type().clone())
594        .len(indices.len())
595        .null_bit_buffer(Some(null_buf.into()))
596        .offset(0)
597        .add_child_data(taken.into_data())
598        .add_buffer(value_offsets);
599
600    let list_data = unsafe { list_data.build_unchecked() };
601
602    Ok(GenericListArray::<OffsetType::Native>::from(list_data))
603}
604
605/// `take` implementation for `FixedSizeListArray`
606///
607/// Calculates the index and indexed offset for the inner array,
608/// applying `take` on the inner array, then reconstructing a list array
609/// with the indexed offsets
610fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
611    values: &FixedSizeListArray,
612    indices: &PrimitiveArray<IndexType>,
613    length: <UInt32Type as ArrowPrimitiveType>::Native,
614) -> Result<FixedSizeListArray, ArrowError> {
615    let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?;
616    let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)?;
617
618    // determine null count and null buffer, which are a function of `values` and `indices`
619    let num_bytes = bit_util::ceil(indices.len(), 8);
620    let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
621    let null_slice = null_buf.as_slice_mut();
622
623    for i in 0..indices.len() {
624        let index = indices
625            .value(i)
626            .to_usize()
627            .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
628        if !indices.is_valid(i) || values.is_null(index) {
629            bit_util::unset_bit(null_slice, i);
630        }
631    }
632
633    let list_data = ArrayDataBuilder::new(values.data_type().clone())
634        .len(indices.len())
635        .null_bit_buffer(Some(null_buf.into()))
636        .offset(0)
637        .add_child_data(taken.into_data());
638
639    let list_data = unsafe { list_data.build_unchecked() };
640
641    Ok(FixedSizeListArray::from(list_data))
642}
643
644fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>(
645    values: &FixedSizeBinaryArray,
646    indices: &PrimitiveArray<IndexType>,
647    size: i32,
648) -> Result<FixedSizeBinaryArray, ArrowError> {
649    let nulls = values.nulls();
650    let array_iter = indices
651        .values()
652        .iter()
653        .map(|idx| {
654            let idx = maybe_usize::<IndexType::Native>(*idx)?;
655            if nulls.map(|n| n.is_valid(idx)).unwrap_or(true) {
656                Ok(Some(values.value(idx)))
657            } else {
658                Ok(None)
659            }
660        })
661        .collect::<Result<Vec<_>, ArrowError>>()?
662        .into_iter();
663
664    FixedSizeBinaryArray::try_from_sparse_iter_with_size(array_iter, size)
665}
666
667/// `take` implementation for dictionary arrays
668///
669/// applies `take` to the keys of the dictionary array and returns a new dictionary array
670/// with the same dictionary values and reordered keys
671fn take_dict<T: ArrowDictionaryKeyType, I: ArrowPrimitiveType>(
672    values: &DictionaryArray<T>,
673    indices: &PrimitiveArray<I>,
674) -> Result<DictionaryArray<T>, ArrowError> {
675    let new_keys = take_primitive(values.keys(), indices)?;
676    Ok(unsafe { DictionaryArray::new_unchecked(new_keys, values.values().clone()) })
677}
678
679/// `take` implementation for run arrays
680///
681/// Finds physical indices for the given logical indices and builds output run array
682/// by taking values in the input run_array.values at the physical indices.
683/// The output run array will be run encoded on the physical indices and not on output values.
684/// For e.g. an input `RunArray{ run_ends = [2,4,6,8], values=[1,2,1,2] }` and `logical_indices=[2,3,6,7]`
685/// would be converted to `physical_indices=[1,1,3,3]` which will be used to build
686/// output `RunArray{ run_ends=[2,4], values=[2,2] }`.
687fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
688    run_array: &RunArray<T>,
689    logical_indices: &PrimitiveArray<I>,
690) -> Result<RunArray<T>, ArrowError> {
691    // get physical indices for the input logical indices
692    let physical_indices = run_array.get_physical_indices(logical_indices.values())?;
693
694    // Run encode the physical indices into new_run_ends_builder
695    // Keep track of the physical indices to take in take_value_indices
696    // `unwrap` is used in this function because the unwrapped values are bounded by the corresponding `::Native`.
697    let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
698    let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
699    let mut new_physical_len = 1;
700    for ix in 1..physical_indices.len() {
701        if physical_indices[ix] != physical_indices[ix - 1] {
702            take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
703            new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
704            new_physical_len += 1;
705        }
706    }
707    take_value_indices
708        .append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap());
709    new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
710    let new_run_ends = unsafe {
711        // Safety:
712        // The function builds a valid run_ends array and hence need not be validated.
713        ArrayDataBuilder::new(T::DATA_TYPE)
714            .len(new_physical_len)
715            .null_count(0)
716            .add_buffer(new_run_ends_builder.finish())
717            .build_unchecked()
718    };
719
720    let take_value_indices: PrimitiveArray<I> = unsafe {
721        // Safety:
722        // The function builds a valid take_value_indices array and hence need not be validated.
723        ArrayDataBuilder::new(I::DATA_TYPE)
724            .len(new_physical_len)
725            .null_count(0)
726            .add_buffer(take_value_indices.finish())
727            .build_unchecked()
728            .into()
729    };
730
731    let new_values = take(run_array.values(), &take_value_indices, None)?;
732
733    let builder = ArrayDataBuilder::new(run_array.data_type().clone())
734        .len(physical_indices.len())
735        .add_child_data(new_run_ends)
736        .add_child_data(new_values.into_data());
737    let array_data = unsafe {
738        // Safety:
739        //  This function builds a valid run array and hence can skip validation.
740        builder.build_unchecked()
741    };
742    Ok(array_data.into())
743}
744
745/// Takes/filters a list array's inner data using the offsets of the list array.
746///
747/// Where a list array has indices `[0,2,5,10]`, taking indices of `[2,0]` returns
748/// an array of the indices `[5..10, 0..2]` and offsets `[0,5,7]` (5 elements and 2
749/// elements)
750#[allow(clippy::type_complexity)]
751fn take_value_indices_from_list<IndexType, OffsetType>(
752    list: &GenericListArray<OffsetType::Native>,
753    indices: &PrimitiveArray<IndexType>,
754) -> Result<
755    (
756        PrimitiveArray<OffsetType>,
757        Vec<OffsetType::Native>,
758        MutableBuffer,
759    ),
760    ArrowError,
761>
762where
763    IndexType: ArrowPrimitiveType,
764    OffsetType: ArrowPrimitiveType,
765    OffsetType::Native: OffsetSizeTrait + std::ops::Add + Zero + One,
766    PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
767{
768    // TODO: benchmark this function, there might be a faster unsafe alternative
769    let offsets: &[OffsetType::Native] = list.value_offsets();
770
771    let mut new_offsets = Vec::with_capacity(indices.len());
772    let mut values = Vec::new();
773    let mut current_offset = OffsetType::Native::zero();
774    // add first offset
775    new_offsets.push(OffsetType::Native::zero());
776
777    // Initialize null buffer
778    let num_bytes = bit_util::ceil(indices.len(), 8);
779    let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
780    let null_slice = null_buf.as_slice_mut();
781
782    // compute the value indices, and set offsets accordingly
783    for i in 0..indices.len() {
784        if indices.is_valid(i) {
785            let ix = indices
786                .value(i)
787                .to_usize()
788                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
789            let start = offsets[ix];
790            let end = offsets[ix + 1];
791            current_offset += end - start;
792            new_offsets.push(current_offset);
793
794            let mut curr = start;
795
796            // if start == end, this slot is empty
797            while curr < end {
798                values.push(curr);
799                curr += One::one();
800            }
801            if !list.is_valid(ix) {
802                bit_util::unset_bit(null_slice, i);
803            }
804        } else {
805            bit_util::unset_bit(null_slice, i);
806            new_offsets.push(current_offset);
807        }
808    }
809
810    Ok((
811        PrimitiveArray::<OffsetType>::from(values),
812        new_offsets,
813        null_buf,
814    ))
815}
816
817/// Takes/filters a fixed size list array's inner data using the offsets of the list array.
818fn take_value_indices_from_fixed_size_list<IndexType>(
819    list: &FixedSizeListArray,
820    indices: &PrimitiveArray<IndexType>,
821    length: <UInt32Type as ArrowPrimitiveType>::Native,
822) -> Result<PrimitiveArray<UInt32Type>, ArrowError>
823where
824    IndexType: ArrowPrimitiveType,
825{
826    let mut values = UInt32Builder::with_capacity(length as usize * indices.len());
827
828    for i in 0..indices.len() {
829        if indices.is_valid(i) {
830            let index = indices
831                .value(i)
832                .to_usize()
833                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
834            let start = list.value_offset(index) as <UInt32Type as ArrowPrimitiveType>::Native;
835
836            // Safety: Range always has known length.
837            unsafe {
838                values.append_trusted_len_iter(start..start + length);
839            }
840        } else {
841            values.append_nulls(length as usize);
842        }
843    }
844
845    Ok(values.finish())
846}
847
/// To avoid generating take implementations for every index type, instead we
/// only generate for UInt32 and UInt64 and coerce inputs to these types
trait ToIndices {
    /// The primitive type of the array returned by [`Self::to_indices`].
    type T: ArrowPrimitiveType;

    /// Returns the contents of `self` as a `PrimitiveArray<Self::T>`.
    fn to_indices(&self) -> PrimitiveArray<Self::T>;
}
855
// Implements `ToIndices` for `PrimitiveArray<$t>` by viewing its existing values
// buffer as values of `$o`, keeping the same length and null buffer.
macro_rules! to_indices_reinterpret {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                // Clone the underlying buffer handle and re-type it from element 0
                let buffer = self.values().inner().clone();
                let values = ScalarBuffer::new(buffer, 0, self.len());
                let nulls = self.nulls().cloned();
                PrimitiveArray::new(values, nulls)
            }
        }
    };
}
868
// Implements `ToIndices` for `PrimitiveArray<$t>` where the array's own type is
// already the target index type: `to_indices` simply returns a clone of `self`.
macro_rules! to_indices_identity {
    ($t:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $t;

            fn to_indices(&self) -> PrimitiveArray<$t> {
                self.clone()
            }
        }
    };
}
880
// Implements `ToIndices` for `PrimitiveArray<$t>` by converting every value to the
// native type of `$o` with an `as` cast, collecting the results into a new values
// buffer. The null buffer is cloned unchanged.
macro_rules! to_indices_widening {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            // Must agree with the return type of `to_indices`, so follow `$o`
            // instead of hard-coding one particular index type.
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                // The cast target is inferred from the native type of `$o`
                let cast = self.values().iter().copied().map(|x| x as _).collect();
                PrimitiveArray::new(cast, self.nulls().cloned())
            }
        }
    };
}
893
// 8-bit and 16-bit index arrays are converted value by value into `UInt32Type`.
to_indices_widening!(UInt8Type, UInt32Type);
to_indices_widening!(Int8Type, UInt32Type);

to_indices_widening!(UInt16Type, UInt32Type);
to_indices_widening!(Int16Type, UInt32Type);

// 32-bit index arrays: `UInt32Type` is cloned as is, `Int32Type` has its buffer
// viewed as `UInt32Type`.
to_indices_identity!(UInt32Type);
to_indices_reinterpret!(Int32Type, UInt32Type);

// 64-bit index arrays: `UInt64Type` is cloned as is, `Int64Type` has its buffer
// viewed as `UInt64Type`.
to_indices_identity!(UInt64Type);
to_indices_reinterpret!(Int64Type, UInt64Type);
905
906/// Take rows by index from [`RecordBatch`] and returns a new [`RecordBatch`] from those indexes.
907///
908/// This function will call [`take`] on each array of the [`RecordBatch`] and assemble a new [`RecordBatch`].
909///
910/// # Example
911/// ```
912/// # use std::sync::Arc;
913/// # use arrow_array::{StringArray, Int32Array, UInt32Array, RecordBatch};
914/// # use arrow_schema::{DataType, Field, Schema};
915/// # use arrow_select::take::take_record_batch;
916///
917/// let schema = Arc::new(Schema::new(vec![
918///     Field::new("a", DataType::Int32, true),
919///     Field::new("b", DataType::Utf8, true),
920/// ]));
921/// let batch = RecordBatch::try_new(
922///     schema.clone(),
923///     vec![
924///         Arc::new(Int32Array::from_iter_values(0..20)),
925///         Arc::new(StringArray::from_iter_values(
926///             (0..20).map(|i| format!("str-{}", i)),
927///         )),
928///     ],
929/// )
930/// .unwrap();
931///
932/// let indices = UInt32Array::from(vec![1, 5, 10]);
933/// let taken = take_record_batch(&batch, &indices).unwrap();
934///
935/// let expected = RecordBatch::try_new(
936///     schema,
937///     vec![
938///         Arc::new(Int32Array::from(vec![1, 5, 10])),
939///         Arc::new(StringArray::from(vec!["str-1", "str-5", "str-10"])),
940///     ],
941/// )
942/// .unwrap();
943/// assert_eq!(taken, expected);
944/// ```
945pub fn take_record_batch(
946    record_batch: &RecordBatch,
947    indices: &dyn Array,
948) -> Result<RecordBatch, ArrowError> {
949    let columns = record_batch
950        .columns()
951        .iter()
952        .map(|c| take(c, indices, None))
953        .collect::<Result<Vec<_>, _>>()?;
954    RecordBatch::try_new(record_batch.schema(), columns)
955}
956
957#[cfg(test)]
958mod tests {
959    use super::*;
960    use arrow_array::builder::*;
961    use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
962    use arrow_data::ArrayData;
963    use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
964
965    fn test_take_decimal_arrays(
966        data: Vec<Option<i128>>,
967        index: &UInt32Array,
968        options: Option<TakeOptions>,
969        expected_data: Vec<Option<i128>>,
970        precision: &u8,
971        scale: &i8,
972    ) -> Result<(), ArrowError> {
973        let output = data
974            .into_iter()
975            .collect::<Decimal128Array>()
976            .with_precision_and_scale(*precision, *scale)
977            .unwrap();
978
979        let expected = expected_data
980            .into_iter()
981            .collect::<Decimal128Array>()
982            .with_precision_and_scale(*precision, *scale)
983            .unwrap();
984
985        let expected = Arc::new(expected) as ArrayRef;
986        let output = take(&output, index, options).unwrap();
987        assert_eq!(&output, &expected);
988        Ok(())
989    }
990
991    fn test_take_boolean_arrays(
992        data: Vec<Option<bool>>,
993        index: &UInt32Array,
994        options: Option<TakeOptions>,
995        expected_data: Vec<Option<bool>>,
996    ) {
997        let output = BooleanArray::from(data);
998        let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef;
999        let output = take(&output, index, options).unwrap();
1000        assert_eq!(&output, &expected)
1001    }
1002
1003    fn test_take_primitive_arrays<T>(
1004        data: Vec<Option<T::Native>>,
1005        index: &UInt32Array,
1006        options: Option<TakeOptions>,
1007        expected_data: Vec<Option<T::Native>>,
1008    ) -> Result<(), ArrowError>
1009    where
1010        T: ArrowPrimitiveType,
1011        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1012    {
1013        let output = PrimitiveArray::<T>::from(data);
1014        let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1015        let output = take(&output, index, options)?;
1016        assert_eq!(&output, &expected);
1017        Ok(())
1018    }
1019
1020    fn test_take_primitive_arrays_non_null<T>(
1021        data: Vec<T::Native>,
1022        index: &UInt32Array,
1023        options: Option<TakeOptions>,
1024        expected_data: Vec<Option<T::Native>>,
1025    ) -> Result<(), ArrowError>
1026    where
1027        T: ArrowPrimitiveType,
1028        PrimitiveArray<T>: From<Vec<T::Native>>,
1029        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1030    {
1031        let output = PrimitiveArray::<T>::from(data);
1032        let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1033        let output = take(&output, index, options)?;
1034        assert_eq!(&output, &expected);
1035        Ok(())
1036    }
1037
1038    fn test_take_impl_primitive_arrays<T, I>(
1039        data: Vec<Option<T::Native>>,
1040        index: &PrimitiveArray<I>,
1041        options: Option<TakeOptions>,
1042        expected_data: Vec<Option<T::Native>>,
1043    ) where
1044        T: ArrowPrimitiveType,
1045        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1046        I: ArrowPrimitiveType,
1047    {
1048        let output = PrimitiveArray::<T>::from(data);
1049        let expected = PrimitiveArray::<T>::from(expected_data);
1050        let output = take(&output, index, options).unwrap();
1051        let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
1052        assert_eq!(output, &expected)
1053    }
1054
1055    // create a simple struct for testing purposes
1056    fn create_test_struct(values: Vec<Option<(Option<bool>, Option<i32>)>>) -> StructArray {
1057        let mut struct_builder = StructBuilder::new(
1058            Fields::from(vec![
1059                Field::new("a", DataType::Boolean, true),
1060                Field::new("b", DataType::Int32, true),
1061            ]),
1062            vec![
1063                Box::new(BooleanBuilder::with_capacity(values.len())),
1064                Box::new(Int32Builder::with_capacity(values.len())),
1065            ],
1066        );
1067
1068        for value in values {
1069            struct_builder
1070                .field_builder::<BooleanBuilder>(0)
1071                .unwrap()
1072                .append_option(value.and_then(|v| v.0));
1073            struct_builder
1074                .field_builder::<Int32Builder>(1)
1075                .unwrap()
1076                .append_option(value.and_then(|v| v.1));
1077            struct_builder.append(value.is_some());
1078        }
1079        struct_builder.finish()
1080    }
1081
1082    #[test]
1083    fn test_take_decimal128_non_null_indices() {
1084        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1085        let precision: u8 = 10;
1086        let scale: i8 = 5;
1087        test_take_decimal_arrays(
1088            vec![None, Some(3), Some(5), Some(2), Some(3), None],
1089            &index,
1090            None,
1091            vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1092            &precision,
1093            &scale,
1094        )
1095        .unwrap();
1096    }
1097
1098    #[test]
1099    fn test_take_decimal128() {
1100        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1101        let precision: u8 = 10;
1102        let scale: i8 = 5;
1103        test_take_decimal_arrays(
1104            vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1105            &index,
1106            None,
1107            vec![Some(3), None, Some(1), Some(3), Some(2)],
1108            &precision,
1109            &scale,
1110        )
1111        .unwrap();
1112    }
1113
1114    #[test]
1115    fn test_take_primitive_non_null_indices() {
1116        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1117        test_take_primitive_arrays::<Int8Type>(
1118            vec![None, Some(3), Some(5), Some(2), Some(3), None],
1119            &index,
1120            None,
1121            vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1122        )
1123        .unwrap();
1124    }
1125
1126    #[test]
1127    fn test_take_primitive_non_null_values() {
1128        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1129        test_take_primitive_arrays::<Int8Type>(
1130            vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1131            &index,
1132            None,
1133            vec![Some(3), None, Some(1), Some(3), Some(2)],
1134        )
1135        .unwrap();
1136    }
1137
1138    #[test]
1139    fn test_take_primitive_non_null() {
1140        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1141        test_take_primitive_arrays::<Int8Type>(
1142            vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)],
1143            &index,
1144            None,
1145            vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)],
1146        )
1147        .unwrap();
1148    }
1149
1150    #[test]
1151    fn test_take_primitive_nullable_indices_non_null_values_with_offset() {
1152        let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1153        let index = index.slice(2, 4);
1154        let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1155
1156        assert_eq!(
1157            index,
1158            &UInt32Array::from(vec![Some(2), Some(3), None, None])
1159        );
1160
1161        test_take_primitive_arrays_non_null::<Int64Type>(
1162            vec![0, 10, 20, 30, 40, 50],
1163            index,
1164            None,
1165            vec![Some(20), Some(30), None, None],
1166        )
1167        .unwrap();
1168    }
1169
1170    #[test]
1171    fn test_take_primitive_nullable_indices_nullable_values_with_offset() {
1172        let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1173        let index = index.slice(2, 4);
1174        let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1175
1176        assert_eq!(
1177            index,
1178            &UInt32Array::from(vec![Some(2), Some(3), None, None])
1179        );
1180
1181        test_take_primitive_arrays::<Int64Type>(
1182            vec![None, None, Some(20), Some(30), Some(40), Some(50)],
1183            index,
1184            None,
1185            vec![Some(20), Some(30), None, None],
1186        )
1187        .unwrap();
1188    }
1189
1190    #[test]
1191    fn test_take_primitive() {
1192        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1193
1194        // int8
1195        test_take_primitive_arrays::<Int8Type>(
1196            vec![Some(0), None, Some(2), Some(3), None],
1197            &index,
1198            None,
1199            vec![Some(3), None, None, Some(3), Some(2)],
1200        )
1201        .unwrap();
1202
1203        // int16
1204        test_take_primitive_arrays::<Int16Type>(
1205            vec![Some(0), None, Some(2), Some(3), None],
1206            &index,
1207            None,
1208            vec![Some(3), None, None, Some(3), Some(2)],
1209        )
1210        .unwrap();
1211
1212        // int32
1213        test_take_primitive_arrays::<Int32Type>(
1214            vec![Some(0), None, Some(2), Some(3), None],
1215            &index,
1216            None,
1217            vec![Some(3), None, None, Some(3), Some(2)],
1218        )
1219        .unwrap();
1220
1221        // int64
1222        test_take_primitive_arrays::<Int64Type>(
1223            vec![Some(0), None, Some(2), Some(3), None],
1224            &index,
1225            None,
1226            vec![Some(3), None, None, Some(3), Some(2)],
1227        )
1228        .unwrap();
1229
1230        // uint8
1231        test_take_primitive_arrays::<UInt8Type>(
1232            vec![Some(0), None, Some(2), Some(3), None],
1233            &index,
1234            None,
1235            vec![Some(3), None, None, Some(3), Some(2)],
1236        )
1237        .unwrap();
1238
1239        // uint16
1240        test_take_primitive_arrays::<UInt16Type>(
1241            vec![Some(0), None, Some(2), Some(3), None],
1242            &index,
1243            None,
1244            vec![Some(3), None, None, Some(3), Some(2)],
1245        )
1246        .unwrap();
1247
1248        // uint32
1249        test_take_primitive_arrays::<UInt32Type>(
1250            vec![Some(0), None, Some(2), Some(3), None],
1251            &index,
1252            None,
1253            vec![Some(3), None, None, Some(3), Some(2)],
1254        )
1255        .unwrap();
1256
1257        // int64
1258        test_take_primitive_arrays::<Int64Type>(
1259            vec![Some(0), None, Some(2), Some(-15), None],
1260            &index,
1261            None,
1262            vec![Some(-15), None, None, Some(-15), Some(2)],
1263        )
1264        .unwrap();
1265
1266        // interval_year_month
1267        test_take_primitive_arrays::<IntervalYearMonthType>(
1268            vec![Some(0), None, Some(2), Some(-15), None],
1269            &index,
1270            None,
1271            vec![Some(-15), None, None, Some(-15), Some(2)],
1272        )
1273        .unwrap();
1274
1275        // interval_day_time
1276        let v1 = IntervalDayTime::new(0, 0);
1277        let v2 = IntervalDayTime::new(2, 0);
1278        let v3 = IntervalDayTime::new(-15, 0);
1279        test_take_primitive_arrays::<IntervalDayTimeType>(
1280            vec![Some(v1), None, Some(v2), Some(v3), None],
1281            &index,
1282            None,
1283            vec![Some(v3), None, None, Some(v3), Some(v2)],
1284        )
1285        .unwrap();
1286
1287        // interval_month_day_nano
1288        let v1 = IntervalMonthDayNano::new(0, 0, 0);
1289        let v2 = IntervalMonthDayNano::new(2, 0, 0);
1290        let v3 = IntervalMonthDayNano::new(-15, 0, 0);
1291        test_take_primitive_arrays::<IntervalMonthDayNanoType>(
1292            vec![Some(v1), None, Some(v2), Some(v3), None],
1293            &index,
1294            None,
1295            vec![Some(v3), None, None, Some(v3), Some(v2)],
1296        )
1297        .unwrap();
1298
1299        // duration_second
1300        test_take_primitive_arrays::<DurationSecondType>(
1301            vec![Some(0), None, Some(2), Some(-15), None],
1302            &index,
1303            None,
1304            vec![Some(-15), None, None, Some(-15), Some(2)],
1305        )
1306        .unwrap();
1307
1308        // duration_millisecond
1309        test_take_primitive_arrays::<DurationMillisecondType>(
1310            vec![Some(0), None, Some(2), Some(-15), None],
1311            &index,
1312            None,
1313            vec![Some(-15), None, None, Some(-15), Some(2)],
1314        )
1315        .unwrap();
1316
1317        // duration_microsecond
1318        test_take_primitive_arrays::<DurationMicrosecondType>(
1319            vec![Some(0), None, Some(2), Some(-15), None],
1320            &index,
1321            None,
1322            vec![Some(-15), None, None, Some(-15), Some(2)],
1323        )
1324        .unwrap();
1325
1326        // duration_nanosecond
1327        test_take_primitive_arrays::<DurationNanosecondType>(
1328            vec![Some(0), None, Some(2), Some(-15), None],
1329            &index,
1330            None,
1331            vec![Some(-15), None, None, Some(-15), Some(2)],
1332        )
1333        .unwrap();
1334
1335        // float32
1336        test_take_primitive_arrays::<Float32Type>(
1337            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1338            &index,
1339            None,
1340            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1341        )
1342        .unwrap();
1343
1344        // float64
1345        test_take_primitive_arrays::<Float64Type>(
1346            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1347            &index,
1348            None,
1349            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1350        )
1351        .unwrap();
1352    }
1353
1354    #[test]
1355    fn test_take_preserve_timezone() {
1356        let index = Int64Array::from(vec![Some(0), None]);
1357
1358        let input = TimestampNanosecondArray::from(vec![
1359            1_639_715_368_000_000_000,
1360            1_639_715_368_000_000_000,
1361        ])
1362        .with_timezone("UTC".to_string());
1363        let result = take(&input, &index, None).unwrap();
1364        match result.data_type() {
1365            DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
1366                assert_eq!(tz.clone(), Some("UTC".into()))
1367            }
1368            _ => panic!(),
1369        }
1370    }
1371
1372    #[test]
1373    fn test_take_impl_primitive_with_int64_indices() {
1374        let index = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1375
1376        // int16
1377        test_take_impl_primitive_arrays::<Int16Type, Int64Type>(
1378            vec![Some(0), None, Some(2), Some(3), None],
1379            &index,
1380            None,
1381            vec![Some(3), None, None, Some(3), Some(2)],
1382        );
1383
1384        // int64
1385        test_take_impl_primitive_arrays::<Int64Type, Int64Type>(
1386            vec![Some(0), None, Some(2), Some(-15), None],
1387            &index,
1388            None,
1389            vec![Some(-15), None, None, Some(-15), Some(2)],
1390        );
1391
1392        // uint64
1393        test_take_impl_primitive_arrays::<UInt64Type, Int64Type>(
1394            vec![Some(0), None, Some(2), Some(3), None],
1395            &index,
1396            None,
1397            vec![Some(3), None, None, Some(3), Some(2)],
1398        );
1399
1400        // duration_millisecond
1401        test_take_impl_primitive_arrays::<DurationMillisecondType, Int64Type>(
1402            vec![Some(0), None, Some(2), Some(-15), None],
1403            &index,
1404            None,
1405            vec![Some(-15), None, None, Some(-15), Some(2)],
1406        );
1407
1408        // float32
1409        test_take_impl_primitive_arrays::<Float32Type, Int64Type>(
1410            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1411            &index,
1412            None,
1413            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1414        );
1415    }
1416
1417    #[test]
1418    fn test_take_impl_primitive_with_uint8_indices() {
1419        let index = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1420
1421        // int16
1422        test_take_impl_primitive_arrays::<Int16Type, UInt8Type>(
1423            vec![Some(0), None, Some(2), Some(3), None],
1424            &index,
1425            None,
1426            vec![Some(3), None, None, Some(3), Some(2)],
1427        );
1428
1429        // duration_millisecond
1430        test_take_impl_primitive_arrays::<DurationMillisecondType, UInt8Type>(
1431            vec![Some(0), None, Some(2), Some(-15), None],
1432            &index,
1433            None,
1434            vec![Some(-15), None, None, Some(-15), Some(2)],
1435        );
1436
1437        // float32
1438        test_take_impl_primitive_arrays::<Float32Type, UInt8Type>(
1439            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1440            &index,
1441            None,
1442            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1443        );
1444    }
1445
1446    #[test]
1447    fn test_take_bool() {
1448        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1449        // boolean
1450        test_take_boolean_arrays(
1451            vec![Some(false), None, Some(true), Some(false), None],
1452            &index,
1453            None,
1454            vec![Some(false), None, None, Some(false), Some(true)],
1455        );
1456    }
1457
1458    #[test]
1459    fn test_take_bool_nullable_index() {
1460        // indices where the masked invalid elements would be out of bounds
1461        let index_data = ArrayData::try_new(
1462            DataType::UInt32,
1463            6,
1464            Some(Buffer::from_iter(vec![
1465                false, true, false, true, false, true,
1466            ])),
1467            0,
1468            vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1469            vec![],
1470        )
1471        .unwrap();
1472        let index = UInt32Array::from(index_data);
1473        test_take_boolean_arrays(
1474            vec![Some(true), None, Some(false)],
1475            &index,
1476            None,
1477            vec![None, Some(true), None, None, None, Some(false)],
1478        );
1479    }
1480
1481    #[test]
1482    fn test_take_bool_nullable_index_nonnull_values() {
1483        // indices where the masked invalid elements would be out of bounds
1484        let index_data = ArrayData::try_new(
1485            DataType::UInt32,
1486            6,
1487            Some(Buffer::from_iter(vec![
1488                false, true, false, true, false, true,
1489            ])),
1490            0,
1491            vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1492            vec![],
1493        )
1494        .unwrap();
1495        let index = UInt32Array::from(index_data);
1496        test_take_boolean_arrays(
1497            vec![Some(true), Some(true), Some(false)],
1498            &index,
1499            None,
1500            vec![None, Some(true), None, Some(true), None, Some(false)],
1501        );
1502    }
1503
1504    #[test]
1505    fn test_take_bool_with_offset() {
1506        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]);
1507        let index = index.slice(2, 4);
1508        let index = index
1509            .as_any()
1510            .downcast_ref::<PrimitiveArray<UInt32Type>>()
1511            .unwrap();
1512
1513        // boolean
1514        test_take_boolean_arrays(
1515            vec![Some(false), None, Some(true), Some(false), None],
1516            index,
1517            None,
1518            vec![None, Some(false), Some(true), None],
1519        );
1520    }
1521
1522    fn _test_take_string<'a, K>()
1523    where
1524        K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static,
1525    {
1526        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]);
1527
1528        let array = K::from(vec![
1529            Some("one"),
1530            None,
1531            Some("three"),
1532            Some("four"),
1533            Some("five"),
1534        ]);
1535        let actual = take(&array, &index, None).unwrap();
1536        assert_eq!(actual.len(), index.len());
1537
1538        let actual = actual.as_any().downcast_ref::<K>().unwrap();
1539
1540        let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]);
1541
1542        assert_eq!(actual, &expected);
1543    }
1544
1545    #[test]
1546    fn test_take_string() {
1547        _test_take_string::<StringArray>()
1548    }
1549
1550    #[test]
1551    fn test_take_large_string() {
1552        _test_take_string::<LargeStringArray>()
1553    }
1554
1555    #[test]
1556    fn test_take_slice_string() {
1557        let strings = StringArray::from(vec![Some("hello"), None, Some("world"), None, Some("hi")]);
1558        let indices = Int32Array::from(vec![Some(0), Some(1), None, Some(0), Some(2)]);
1559        let indices_slice = indices.slice(1, 4);
1560        let expected = StringArray::from(vec![None, None, Some("hello"), Some("world")]);
1561        let result = take(&strings, &indices_slice, None).unwrap();
1562        assert_eq!(result.as_ref(), &expected);
1563    }
1564
1565    fn _test_byte_view<T>()
1566    where
1567        T: ByteViewType,
1568        str: AsRef<T::Native>,
1569        T::Native: PartialEq,
1570    {
1571        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]);
1572        let array = {
1573            // ["hello", "world", null, "large payload over 12 bytes", "lulu"]
1574            let mut builder = GenericByteViewBuilder::<T>::new();
1575            builder.append_value("hello");
1576            builder.append_value("world");
1577            builder.append_null();
1578            builder.append_value("large payload over 12 bytes");
1579            builder.append_value("lulu");
1580            builder.finish()
1581        };
1582
1583        let actual = take(&array, &index, None).unwrap();
1584
1585        assert_eq!(actual.len(), index.len());
1586
1587        let expected = {
1588            // ["large payload over 12 bytes", null, "world", "large payload over 12 bytes", "lulu", null]
1589            let mut builder = GenericByteViewBuilder::<T>::new();
1590            builder.append_value("large payload over 12 bytes");
1591            builder.append_null();
1592            builder.append_value("world");
1593            builder.append_value("large payload over 12 bytes");
1594            builder.append_value("lulu");
1595            builder.append_null();
1596            builder.finish()
1597        };
1598
1599        assert_eq!(actual.as_ref(), &expected);
1600    }
1601
1602    #[test]
1603    fn test_take_string_view() {
1604        _test_byte_view::<StringViewType>()
1605    }
1606
1607    #[test]
1608    fn test_take_binary_view() {
1609        _test_byte_view::<BinaryViewType>()
1610    }
1611
1612    macro_rules! test_take_list {
1613        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1614            // Construct a value array, [[0,0,0], [-1,-2,-1], [], [2,3]]
1615            let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
1616            // Construct offsets
1617            let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
1618            let value_offsets = Buffer::from_slice_ref(&value_offsets);
1619            // Construct a list array from the above two
1620            let list_data_type =
1621                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, false)));
1622            let list_data = ArrayData::builder(list_data_type.clone())
1623                .len(4)
1624                .add_buffer(value_offsets)
1625                .add_child_data(value_data)
1626                .build()
1627                .unwrap();
1628            let list_array = $list_array_type::from(list_data);
1629
1630            // index returns: [[2,3], null, [-1,-2,-1], [], [0,0,0]]
1631            let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]);
1632
1633            let a = take(&list_array, &index, None).unwrap();
1634            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1635
1636            // construct a value array with expected results:
1637            // [[2,3], null, [-1,-2,-1], [], [0,0,0]]
1638            let expected_data = Int32Array::from(vec![
1639                Some(2),
1640                Some(3),
1641                Some(-1),
1642                Some(-2),
1643                Some(-1),
1644                Some(0),
1645                Some(0),
1646                Some(0),
1647            ])
1648            .into_data();
1649            // construct offsets
1650            let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8];
1651            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1652            // construct list array from the two
1653            let expected_list_data = ArrayData::builder(list_data_type)
1654                .len(5)
1655                // null buffer remains the same as only the indices have nulls
1656                .nulls(index.nulls().cloned())
1657                .add_buffer(expected_offsets)
1658                .add_child_data(expected_data)
1659                .build()
1660                .unwrap();
1661            let expected_list_array = $list_array_type::from(expected_list_data);
1662
1663            assert_eq!(a, &expected_list_array);
1664        }};
1665    }
1666
1667    macro_rules! test_take_list_with_value_nulls {
1668        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1669            // Construct a value array, [[0,null,0], [-1,-2,3], [null], [5,null]]
1670            let value_data = Int32Array::from(vec![
1671                Some(0),
1672                None,
1673                Some(0),
1674                Some(-1),
1675                Some(-2),
1676                Some(3),
1677                None,
1678                Some(5),
1679                None,
1680            ])
1681            .into_data();
1682            // Construct offsets
1683            let value_offsets: [$offset_type; 5] = [0, 3, 6, 7, 9];
1684            let value_offsets = Buffer::from_slice_ref(&value_offsets);
1685            // Construct a list array from the above two
1686            let list_data_type =
1687                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
1688            let list_data = ArrayData::builder(list_data_type.clone())
1689                .len(4)
1690                .add_buffer(value_offsets)
1691                .null_bit_buffer(Some(Buffer::from([0b11111111])))
1692                .add_child_data(value_data)
1693                .build()
1694                .unwrap();
1695            let list_array = $list_array_type::from(list_data);
1696
1697            // index returns: [[null], null, [-1,-2,3], [2,null], [0,null,0]]
1698            let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
1699
1700            let a = take(&list_array, &index, None).unwrap();
1701            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1702
1703            // construct a value array with expected results:
1704            // [[null], null, [-1,-2,3], [5,null], [0,null,0]]
1705            let expected_data = Int32Array::from(vec![
1706                None,
1707                Some(-1),
1708                Some(-2),
1709                Some(3),
1710                Some(5),
1711                None,
1712                Some(0),
1713                None,
1714                Some(0),
1715            ])
1716            .into_data();
1717            // construct offsets
1718            let expected_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9];
1719            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1720            // construct list array from the two
1721            let expected_list_data = ArrayData::builder(list_data_type)
1722                .len(5)
1723                // null buffer remains the same as only the indices have nulls
1724                .nulls(index.nulls().cloned())
1725                .add_buffer(expected_offsets)
1726                .add_child_data(expected_data)
1727                .build()
1728                .unwrap();
1729            let expected_list_array = $list_array_type::from(expected_list_data);
1730
1731            assert_eq!(a, &expected_list_array);
1732        }};
1733    }
1734
1735    macro_rules! test_take_list_with_nulls {
1736        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1737            // Construct a value array, [[0,null,0], [-1,-2,3], null, [5,null]]
1738            let value_data = Int32Array::from(vec![
1739                Some(0),
1740                None,
1741                Some(0),
1742                Some(-1),
1743                Some(-2),
1744                Some(3),
1745                Some(5),
1746                None,
1747            ])
1748            .into_data();
1749            // Construct offsets
1750            let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
1751            let value_offsets = Buffer::from_slice_ref(&value_offsets);
1752            // Construct a list array from the above two
1753            let list_data_type =
1754                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
1755            let list_data = ArrayData::builder(list_data_type.clone())
1756                .len(4)
1757                .add_buffer(value_offsets)
1758                .null_bit_buffer(Some(Buffer::from([0b11111011])))
1759                .add_child_data(value_data)
1760                .build()
1761                .unwrap();
1762            let list_array = $list_array_type::from(list_data);
1763
1764            // index returns: [null, null, [-1,-2,3], [5,null], [0,null,0]]
1765            let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
1766
1767            let a = take(&list_array, &index, None).unwrap();
1768            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1769
1770            // construct a value array with expected results:
1771            // [null, null, [-1,-2,3], [5,null], [0,null,0]]
1772            let expected_data = Int32Array::from(vec![
1773                Some(-1),
1774                Some(-2),
1775                Some(3),
1776                Some(5),
1777                None,
1778                Some(0),
1779                None,
1780                Some(0),
1781            ])
1782            .into_data();
1783            // construct offsets
1784            let expected_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8];
1785            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1786            // construct list array from the two
1787            let mut null_bits: [u8; 1] = [0; 1];
1788            bit_util::set_bit(&mut null_bits, 2);
1789            bit_util::set_bit(&mut null_bits, 3);
1790            bit_util::set_bit(&mut null_bits, 4);
1791            let expected_list_data = ArrayData::builder(list_data_type)
1792                .len(5)
1793                // null buffer must be recalculated as both values and indices have nulls
1794                .null_bit_buffer(Some(Buffer::from(null_bits)))
1795                .add_buffer(expected_offsets)
1796                .add_child_data(expected_data)
1797                .build()
1798                .unwrap();
1799            let expected_list_array = $list_array_type::from(expected_list_data);
1800
1801            assert_eq!(a, &expected_list_array);
1802        }};
1803    }
1804
1805    fn do_take_fixed_size_list_test<T>(
1806        length: <Int32Type as ArrowPrimitiveType>::Native,
1807        input_data: Vec<Option<Vec<Option<T::Native>>>>,
1808        indices: Vec<<UInt32Type as ArrowPrimitiveType>::Native>,
1809        expected_data: Vec<Option<Vec<Option<T::Native>>>>,
1810    ) where
1811        T: ArrowPrimitiveType,
1812        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1813    {
1814        let indices = UInt32Array::from(indices);
1815
1816        let input_array = FixedSizeListArray::from_iter_primitive::<T, _, _>(input_data, length);
1817
1818        let output = take_fixed_size_list(&input_array, &indices, length as u32).unwrap();
1819
1820        let expected = FixedSizeListArray::from_iter_primitive::<T, _, _>(expected_data, length);
1821
1822        assert_eq!(&output, &expected)
1823    }
1824
1825    #[test]
1826    fn test_take_list() {
1827        test_take_list!(i32, List, ListArray);
1828    }
1829
1830    #[test]
1831    fn test_take_large_list() {
1832        test_take_list!(i64, LargeList, LargeListArray);
1833    }
1834
1835    #[test]
1836    fn test_take_list_with_value_nulls() {
1837        test_take_list_with_value_nulls!(i32, List, ListArray);
1838    }
1839
1840    #[test]
1841    fn test_take_large_list_with_value_nulls() {
1842        test_take_list_with_value_nulls!(i64, LargeList, LargeListArray);
1843    }
1844
1845    #[test]
1846    fn test_test_take_list_with_nulls() {
1847        test_take_list_with_nulls!(i32, List, ListArray);
1848    }
1849
1850    #[test]
1851    fn test_test_take_large_list_with_nulls() {
1852        test_take_list_with_nulls!(i64, LargeList, LargeListArray);
1853    }
1854
1855    #[test]
1856    fn test_take_fixed_size_list() {
1857        do_take_fixed_size_list_test::<Int32Type>(
1858            3,
1859            vec![
1860                Some(vec![None, Some(1), Some(2)]),
1861                Some(vec![Some(3), Some(4), None]),
1862                Some(vec![Some(6), Some(7), Some(8)]),
1863            ],
1864            vec![2, 1, 0],
1865            vec![
1866                Some(vec![Some(6), Some(7), Some(8)]),
1867                Some(vec![Some(3), Some(4), None]),
1868                Some(vec![None, Some(1), Some(2)]),
1869            ],
1870        );
1871
1872        do_take_fixed_size_list_test::<UInt8Type>(
1873            1,
1874            vec![
1875                Some(vec![Some(1)]),
1876                Some(vec![Some(2)]),
1877                Some(vec![Some(3)]),
1878                Some(vec![Some(4)]),
1879                Some(vec![Some(5)]),
1880                Some(vec![Some(6)]),
1881                Some(vec![Some(7)]),
1882                Some(vec![Some(8)]),
1883            ],
1884            vec![2, 7, 0],
1885            vec![
1886                Some(vec![Some(3)]),
1887                Some(vec![Some(8)]),
1888                Some(vec![Some(1)]),
1889            ],
1890        );
1891
1892        do_take_fixed_size_list_test::<UInt64Type>(
1893            3,
1894            vec![
1895                Some(vec![Some(10), Some(11), Some(12)]),
1896                Some(vec![Some(13), Some(14), Some(15)]),
1897                None,
1898                Some(vec![Some(16), Some(17), Some(18)]),
1899            ],
1900            vec![3, 2, 1, 2, 0],
1901            vec![
1902                Some(vec![Some(16), Some(17), Some(18)]),
1903                None,
1904                Some(vec![Some(13), Some(14), Some(15)]),
1905                None,
1906                Some(vec![Some(10), Some(11), Some(12)]),
1907            ],
1908        );
1909    }
1910
1911    #[test]
1912    #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
1913    fn test_take_list_out_of_bounds() {
1914        // Construct a value array, [[0,0,0], [-1,-2,-1], [2,3]]
1915        let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
1916        // Construct offsets
1917        let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]);
1918        // Construct a list array from the above two
1919        let list_data_type =
1920            DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false)));
1921        let list_data = ArrayData::builder(list_data_type)
1922            .len(3)
1923            .add_buffer(value_offsets)
1924            .add_child_data(value_data)
1925            .build()
1926            .unwrap();
1927        let list_array = ListArray::from(list_data);
1928
1929        let index = UInt32Array::from(vec![1000]);
1930
1931        // A panic is expected here since we have not supplied the check_bounds
1932        // option.
1933        take(&list_array, &index, None).unwrap();
1934    }
1935
1936    #[test]
1937    fn test_take_map() {
1938        let values = Int32Array::from(vec![1, 2, 3, 4]);
1939        let array =
1940            MapArray::new_from_strings(vec!["a", "b", "c", "a"].into_iter(), &values, &[0, 3, 4])
1941                .unwrap();
1942
1943        let index = UInt32Array::from(vec![0]);
1944
1945        let result = take(&array, &index, None).unwrap();
1946        let expected: ArrayRef = Arc::new(
1947            MapArray::new_from_strings(
1948                vec!["a", "b", "c"].into_iter(),
1949                &values.slice(0, 3),
1950                &[0, 3],
1951            )
1952            .unwrap(),
1953        );
1954        assert_eq!(&expected, &result);
1955    }
1956
1957    #[test]
1958    fn test_take_struct() {
1959        let array = create_test_struct(vec![
1960            Some((Some(true), Some(42))),
1961            Some((Some(false), Some(28))),
1962            Some((Some(false), Some(19))),
1963            Some((Some(true), Some(31))),
1964            None,
1965        ]);
1966
1967        let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
1968        let actual = take(&array, &index, None).unwrap();
1969        let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
1970        assert_eq!(index.len(), actual.len());
1971        assert_eq!(1, actual.null_count());
1972
1973        let expected = create_test_struct(vec![
1974            Some((Some(true), Some(42))),
1975            Some((Some(true), Some(31))),
1976            Some((Some(false), Some(28))),
1977            Some((Some(true), Some(42))),
1978            Some((Some(false), Some(19))),
1979            None,
1980        ]);
1981
1982        assert_eq!(&expected, actual);
1983
1984        let nulls = NullBuffer::from(&[false, true, false, true, false, true]);
1985        let empty_struct_arr = StructArray::new_empty_fields(6, Some(nulls));
1986        let index = UInt32Array::from(vec![0, 2, 1, 4]);
1987        let actual = take(&empty_struct_arr, &index, None).unwrap();
1988
1989        let expected_nulls = NullBuffer::from(&[false, false, true, false]);
1990        let expected_struct_arr = StructArray::new_empty_fields(4, Some(expected_nulls));
1991        assert_eq!(&expected_struct_arr, actual.as_struct());
1992    }
1993
1994    #[test]
1995    fn test_take_struct_with_null_indices() {
1996        let array = create_test_struct(vec![
1997            Some((Some(true), Some(42))),
1998            Some((Some(false), Some(28))),
1999            Some((Some(false), Some(19))),
2000            Some((Some(true), Some(31))),
2001            None,
2002        ]);
2003
2004        let index = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);
2005        let actual = take(&array, &index, None).unwrap();
2006        let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
2007        assert_eq!(index.len(), actual.len());
2008        assert_eq!(3, actual.null_count()); // 2 because of indices, 1 because of struct array
2009
2010        let expected = create_test_struct(vec![
2011            None,
2012            Some((Some(true), Some(31))),
2013            Some((Some(false), Some(28))),
2014            None,
2015            Some((Some(true), Some(42))),
2016            None,
2017        ]);
2018
2019        assert_eq!(&expected, actual);
2020    }
2021
2022    #[test]
2023    fn test_take_out_of_bounds() {
2024        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]);
2025        let take_opt = TakeOptions { check_bounds: true };
2026
2027        // int64
2028        let result = test_take_primitive_arrays::<Int64Type>(
2029            vec![Some(0), None, Some(2), Some(3), None],
2030            &index,
2031            Some(take_opt),
2032            vec![None],
2033        );
2034        assert!(result.is_err());
2035    }
2036
2037    #[test]
2038    #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
2039    fn test_take_out_of_bounds_panic() {
2040        let index = UInt32Array::from(vec![Some(1000)]);
2041
2042        test_take_primitive_arrays::<Int64Type>(
2043            vec![Some(0), Some(1), Some(2), Some(3)],
2044            &index,
2045            None,
2046            vec![None],
2047        )
2048        .unwrap();
2049    }
2050
2051    #[test]
2052    fn test_null_array_smaller_than_indices() {
2053        let values = NullArray::new(2);
2054        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2055
2056        let result = take(&values, &indices, None).unwrap();
2057        let expected: ArrayRef = Arc::new(NullArray::new(3));
2058        assert_eq!(&result, &expected);
2059    }
2060
2061    #[test]
2062    fn test_null_array_larger_than_indices() {
2063        let values = NullArray::new(5);
2064        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2065
2066        let result = take(&values, &indices, None).unwrap();
2067        let expected: ArrayRef = Arc::new(NullArray::new(3));
2068        assert_eq!(&result, &expected);
2069    }
2070
2071    #[test]
2072    fn test_null_array_indices_out_of_bounds() {
2073        let values = NullArray::new(5);
2074        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2075
2076        let result = take(&values, &indices, Some(TakeOptions { check_bounds: true }));
2077        assert_eq!(
2078            result.unwrap_err().to_string(),
2079            "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries"
2080        );
2081    }
2082
2083    #[test]
2084    fn test_take_dict() {
2085        let mut dict_builder = StringDictionaryBuilder::<Int16Type>::new();
2086
2087        dict_builder.append("foo").unwrap();
2088        dict_builder.append("bar").unwrap();
2089        dict_builder.append("").unwrap();
2090        dict_builder.append_null();
2091        dict_builder.append("foo").unwrap();
2092        dict_builder.append("bar").unwrap();
2093        dict_builder.append("bar").unwrap();
2094        dict_builder.append("foo").unwrap();
2095
2096        let array = dict_builder.finish();
2097        let dict_values = array.values().clone();
2098        let dict_values = dict_values.as_any().downcast_ref::<StringArray>().unwrap();
2099
2100        let indices = UInt32Array::from(vec![
2101            Some(0), // first "foo"
2102            Some(7), // last "foo"
2103            None,    // null index should return null
2104            Some(5), // second "bar"
2105            Some(6), // another "bar"
2106            Some(2), // empty string
2107            Some(3), // input is null at this index
2108        ]);
2109
2110        let result = take(&array, &indices, None).unwrap();
2111        let result = result
2112            .as_any()
2113            .downcast_ref::<DictionaryArray<Int16Type>>()
2114            .unwrap();
2115
2116        let result_values: StringArray = result.values().to_data().into();
2117
2118        // dictionary values should stay the same
2119        let expected_values = StringArray::from(vec!["foo", "bar", ""]);
2120        assert_eq!(&expected_values, dict_values);
2121        assert_eq!(&expected_values, &result_values);
2122
2123        let expected_keys = Int16Array::from(vec![
2124            Some(0),
2125            Some(0),
2126            None,
2127            Some(1),
2128            Some(1),
2129            Some(2),
2130            None,
2131        ]);
2132        assert_eq!(result.keys(), &expected_keys);
2133    }
2134
2135    fn build_generic_list<S, T>(data: Vec<Option<Vec<T::Native>>>) -> GenericListArray<S>
2136    where
2137        S: OffsetSizeTrait + 'static,
2138        T: ArrowPrimitiveType,
2139        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2140    {
2141        GenericListArray::from_iter_primitive::<T, _, _>(
2142            data.iter()
2143                .map(|x| x.as_ref().map(|x| x.iter().map(|x| Some(*x)))),
2144        )
2145    }
2146
2147    #[test]
2148    fn test_take_value_index_from_list() {
2149        let list = build_generic_list::<i32, Int32Type>(vec![
2150            Some(vec![0, 1]),
2151            Some(vec![2, 3, 4]),
2152            Some(vec![5, 6, 7, 8, 9]),
2153        ]);
2154        let indices = UInt32Array::from(vec![2, 0]);
2155
2156        let (indexed, offsets, null_buf) = take_value_indices_from_list(&list, &indices).unwrap();
2157
2158        assert_eq!(indexed, Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
2159        assert_eq!(offsets, vec![0, 5, 7]);
2160        assert_eq!(null_buf.as_slice(), &[0b11111111]);
2161    }
2162
2163    #[test]
2164    fn test_take_value_index_from_large_list() {
2165        let list = build_generic_list::<i64, Int32Type>(vec![
2166            Some(vec![0, 1]),
2167            Some(vec![2, 3, 4]),
2168            Some(vec![5, 6, 7, 8, 9]),
2169        ]);
2170        let indices = UInt32Array::from(vec![2, 0]);
2171
2172        let (indexed, offsets, null_buf) =
2173            take_value_indices_from_list::<_, Int64Type>(&list, &indices).unwrap();
2174
2175        assert_eq!(indexed, Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
2176        assert_eq!(offsets, vec![0, 5, 7]);
2177        assert_eq!(null_buf.as_slice(), &[0b11111111]);
2178    }
2179
2180    #[test]
2181    fn test_take_runs() {
2182        let logical_array: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];
2183
2184        let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
2185        builder.extend(logical_array.into_iter().map(Some));
2186        let run_array = builder.finish();
2187
2188        let take_indices: PrimitiveArray<Int32Type> =
2189            vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect();
2190
2191        let take_out = take_run(&run_array, &take_indices).unwrap();
2192
2193        assert_eq!(take_out.len(), 7);
2194        assert_eq!(take_out.run_ends().len(), 7);
2195        assert_eq!(take_out.run_ends().values(), &[1_i32, 3, 4, 5, 7]);
2196
2197        let take_out_values = take_out.values().as_primitive::<Int32Type>();
2198        assert_eq!(take_out_values.values(), &[2, 2, 2, 2, 1]);
2199    }
2200
2201    #[test]
2202    fn test_take_value_index_from_fixed_list() {
2203        let list = FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
2204            vec![
2205                Some(vec![Some(1), Some(2), None]),
2206                Some(vec![Some(4), None, Some(6)]),
2207                None,
2208                Some(vec![None, Some(8), Some(9)]),
2209            ],
2210            3,
2211        );
2212
2213        let indices = UInt32Array::from(vec![2, 1, 0]);
2214        let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2215
2216        assert_eq!(indexed, UInt32Array::from(vec![6, 7, 8, 3, 4, 5, 0, 1, 2]));
2217
2218        let indices = UInt32Array::from(vec![3, 2, 1, 2, 0]);
2219        let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2220
2221        assert_eq!(
2222            indexed,
2223            UInt32Array::from(vec![9, 10, 11, 6, 7, 8, 3, 4, 5, 6, 7, 8, 0, 1, 2])
2224        );
2225    }
2226
2227    #[test]
2228    fn test_take_null_indices() {
2229        // Build indices with values that are out of bounds, but masked by null mask
2230        let indices = Int32Array::new(
2231            vec![1, 2, 400, 400].into(),
2232            Some(NullBuffer::from(vec![true, true, false, false])),
2233        );
2234        let values = Int32Array::from(vec![1, 23, 4, 5]);
2235        let r = take(&values, &indices, None).unwrap();
2236        let values = r
2237            .as_primitive::<Int32Type>()
2238            .into_iter()
2239            .collect::<Vec<_>>();
2240        assert_eq!(&values, &[Some(23), Some(4), None, None])
2241    }
2242
2243    #[test]
2244    fn test_take_fixed_size_list_null_indices() {
2245        let indices = Int32Array::from_iter([Some(0), None]);
2246        let values = Arc::new(Int32Array::from(vec![0, 1, 2, 3]));
2247        let arr_field = Arc::new(Field::new_list_field(values.data_type().clone(), true));
2248        let values = FixedSizeListArray::try_new(arr_field, 2, values, None).unwrap();
2249
2250        let r = take(&values, &indices, None).unwrap();
2251        let values = r
2252            .as_fixed_size_list()
2253            .values()
2254            .as_primitive::<Int32Type>()
2255            .into_iter()
2256            .collect::<Vec<_>>();
2257        assert_eq!(values, &[Some(0), Some(1), None, None])
2258    }
2259
2260    #[test]
2261    fn test_take_bytes_null_indices() {
2262        let indices = Int32Array::new(
2263            vec![0, 1, 400, 400].into(),
2264            Some(NullBuffer::from_iter(vec![true, true, false, false])),
2265        );
2266        let values = StringArray::from(vec![Some("foo"), None]);
2267        let r = take(&values, &indices, None).unwrap();
2268        let values = r.as_string::<i32>().iter().collect::<Vec<_>>();
2269        assert_eq!(&values, &[Some("foo"), None, None, None])
2270    }
2271
2272    #[test]
2273    fn test_take_union_sparse() {
2274        let structs = create_test_struct(vec![
2275            Some((Some(true), Some(42))),
2276            Some((Some(false), Some(28))),
2277            Some((Some(false), Some(19))),
2278            Some((Some(true), Some(31))),
2279            None,
2280        ]);
2281        let strings = StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);
2282        let type_ids = [1; 5].into_iter().collect::<ScalarBuffer<i8>>();
2283
2284        let union_fields = [
2285            (
2286                0,
2287                Arc::new(Field::new("f1", structs.data_type().clone(), true)),
2288            ),
2289            (
2290                1,
2291                Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2292            ),
2293        ]
2294        .into_iter()
2295        .collect();
2296        let children = vec![Arc::new(structs) as Arc<dyn Array>, Arc::new(strings)];
2297        let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();
2298
2299        let indices = vec![0, 3, 1, 0, 2, 4];
2300        let index = UInt32Array::from(indices.clone());
2301        let actual = take(&array, &index, None).unwrap();
2302        let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2303        let strings = actual.child(1);
2304        let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();
2305
2306        let actual = strings.iter().collect::<Vec<_>>();
2307        let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
2308        assert_eq!(expected, actual);
2309    }
2310
2311    #[test]
2312    fn test_take_union_dense() {
2313        let type_ids = vec![0, 1, 1, 0, 0, 1, 0];
2314        let offsets = vec![0, 0, 1, 1, 2, 2, 3];
2315        let ints = vec![10, 20, 30, 40];
2316        let strings = vec![Some("a"), None, Some("c"), Some("d")];
2317
2318        let indices = vec![0, 3, 1, 0, 2, 4];
2319
2320        let taken_type_ids = vec![0, 0, 1, 0, 1, 0];
2321        let taken_offsets = vec![0, 1, 0, 2, 1, 3];
2322        let taken_ints = vec![10, 20, 10, 30];
2323        let taken_strings = vec![Some("a"), None];
2324
2325        let type_ids = <ScalarBuffer<i8>>::from(type_ids);
2326        let offsets = <ScalarBuffer<i32>>::from(offsets);
2327        let ints = UInt32Array::from(ints);
2328        let strings = StringArray::from(strings);
2329
2330        let union_fields = [
2331            (
2332                0,
2333                Arc::new(Field::new("f1", ints.data_type().clone(), true)),
2334            ),
2335            (
2336                1,
2337                Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2338            ),
2339        ]
2340        .into_iter()
2341        .collect();
2342
2343        let array = UnionArray::try_new(
2344            union_fields,
2345            type_ids,
2346            Some(offsets),
2347            vec![Arc::new(ints), Arc::new(strings)],
2348        )
2349        .unwrap();
2350
2351        let index = UInt32Array::from(indices);
2352
2353        let actual = take(&array, &index, None).unwrap();
2354        let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2355
2356        assert_eq!(actual.offsets(), Some(&ScalarBuffer::from(taken_offsets)));
2357        assert_eq!(actual.type_ids(), &ScalarBuffer::from(taken_type_ids));
2358        assert_eq!(
2359            UInt32Array::from(actual.child(0).to_data()),
2360            UInt32Array::from(taken_ints)
2361        );
2362        assert_eq!(
2363            StringArray::from(actual.child(1).to_data()),
2364            StringArray::from(taken_strings)
2365        );
2366    }
2367
2368    #[test]
2369    fn test_take_union_dense_using_builder() {
2370        let mut builder = UnionBuilder::new_dense();
2371
2372        builder.append::<Int32Type>("a", 1).unwrap();
2373        builder.append::<Float64Type>("b", 3.0).unwrap();
2374        builder.append::<Int32Type>("a", 4).unwrap();
2375        builder.append::<Int32Type>("a", 5).unwrap();
2376        builder.append::<Float64Type>("b", 2.0).unwrap();
2377
2378        let union = builder.build().unwrap();
2379
2380        let indices = UInt32Array::from(vec![2, 0, 1, 2]);
2381
2382        let mut builder = UnionBuilder::new_dense();
2383
2384        builder.append::<Int32Type>("a", 4).unwrap();
2385        builder.append::<Int32Type>("a", 1).unwrap();
2386        builder.append::<Float64Type>("b", 3.0).unwrap();
2387        builder.append::<Int32Type>("a", 4).unwrap();
2388
2389        let taken = builder.build().unwrap();
2390
2391        assert_eq!(
2392            taken.to_data(),
2393            take(&union, &indices, None).unwrap().to_data()
2394        );
2395    }
2396
2397    #[test]
2398    fn test_take_union_dense_all_match_issue_6206() {
2399        let fields = UnionFields::new(vec![0], vec![Field::new("a", DataType::Int64, false)]);
2400        let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
2401
2402        let array = UnionArray::try_new(
2403            fields,
2404            ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]),
2405            Some(ScalarBuffer::from_iter(0_i32..5)),
2406            vec![ints],
2407        )
2408        .unwrap();
2409
2410        let indicies = Int64Array::from(vec![0, 2, 4]);
2411        let array = take(&array, &indicies, None).unwrap();
2412        assert_eq!(array.len(), 3);
2413    }
2414}