arrow_select/
take.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines take kernel for [Array]
19
20use std::fmt::Display;
21use std::sync::Arc;
22
23use arrow_array::builder::{BufferBuilder, UInt32Builder};
24use arrow_array::cast::AsArray;
25use arrow_array::types::*;
26use arrow_array::*;
27use arrow_buffer::{
28    ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, ScalarBuffer,
29    bit_util,
30};
31use arrow_data::ArrayDataBuilder;
32use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
33
34use num_traits::{One, Zero};
35
36/// Take elements by index from [Array], creating a new [Array] from those indexes.
37///
38/// ```text
39/// ┌─────────────────┐      ┌─────────┐                              ┌─────────────────┐
40/// │        A        │      │    0    │                              │        A        │
41/// ├─────────────────┤      ├─────────┤                              ├─────────────────┤
42/// │        D        │      │    2    │                              │        B        │
43/// ├─────────────────┤      ├─────────┤   take(values, indices)      ├─────────────────┤
44/// │        B        │      │    3    │ ─────────────────────────▶   │        C        │
45/// ├─────────────────┤      ├─────────┤                              ├─────────────────┤
46/// │        C        │      │    1    │                              │        D        │
47/// ├─────────────────┤      └─────────┘                              └─────────────────┘
48/// │        E        │
49/// └─────────────────┘
50///    values array          indices array                              result
51/// ```
52///
53/// For selecting values by index from multiple arrays see [`crate::interleave`]
54///
55/// Note that this kernel, similar to other kernels in this crate,
56/// will avoid allocating where not necessary. Consequently
57/// the returned array may share buffers with the inputs
58///
59/// # Errors
60/// This function errors whenever:
61/// * An index cannot be casted to `usize` (typically 32 bit architectures)
62/// * An index is out of bounds and `options` is set to check bounds.
63///
64/// # Safety
65///
66/// When `options` is not set to check bounds, taking indexes after `len` will panic.
67///
68/// # See also
69/// * [`BatchCoalescer`]: to filter multiple [`RecordBatch`] and coalesce
70///   the results into a single array.
71///
72/// [`BatchCoalescer`]: crate::coalesce::BatchCoalescer
73///
74/// # Examples
75/// ```
76/// # use arrow_array::{StringArray, UInt32Array, cast::AsArray};
77/// # use arrow_select::take::take;
78/// let values = StringArray::from(vec!["zero", "one", "two"]);
79///
80/// // Take items at index 2, and 1:
81/// let indices = UInt32Array::from(vec![2, 1]);
82/// let taken = take(&values, &indices, None).unwrap();
83/// let taken = taken.as_string::<i32>();
84///
85/// assert_eq!(*taken, StringArray::from(vec!["two", "one"]));
86/// ```
87pub fn take(
88    values: &dyn Array,
89    indices: &dyn Array,
90    options: Option<TakeOptions>,
91) -> Result<ArrayRef, ArrowError> {
92    let options = options.unwrap_or_default();
93    downcast_integer_array!(
94        indices => {
95            if options.check_bounds {
96                check_bounds(values.len(), indices)?;
97            }
98            let indices = indices.to_indices();
99            take_impl(values, &indices)
100        },
101        d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
102    )
103}
104
105/// For each [ArrayRef] in the [`Vec<ArrayRef>`], take elements by index and create a new
106/// [`Vec<ArrayRef>`] from those indices.
107///
108/// ```text
109/// ┌────────┬────────┐
110/// │        │        │           ┌────────┐                                ┌────────┬────────┐
111/// │   A    │   1    │           │        │                                │        │        │
112/// ├────────┼────────┤           │   0    │                                │   A    │   1    │
113/// │        │        │           ├────────┤                                ├────────┼────────┤
114/// │   D    │   4    │           │        │                                │        │        │
115/// ├────────┼────────┤           │   2    │  take_arrays(values,indices)   │   B    │   2    │
116/// │        │        │           ├────────┤                                ├────────┼────────┤
117/// │   B    │   2    │           │        │  ───────────────────────────►  │        │        │
118/// ├────────┼────────┤           │   3    │                                │   C    │   3    │
119/// │        │        │           ├────────┤                                ├────────┼────────┤
120/// │   C    │   3    │           │        │                                │        │        │
121/// ├────────┼────────┤           │   1    │                                │   D    │   4    │
122/// │        │        │           └────────┘                                └────────┼────────┘
123/// │   E    │   5    │
124/// └────────┴────────┘
125///    values arrays             indices array                                      result
126/// ```
127///
128/// # Errors
129/// This function errors whenever:
130/// * An index cannot be casted to `usize` (typically 32 bit architectures)
131/// * An index is out of bounds and `options` is set to check bounds.
132///
133/// # Safety
134///
135/// When `options` is not set to check bounds, taking indexes after `len` will panic.
136///
137/// # Examples
138/// ```
139/// # use std::sync::Arc;
140/// # use arrow_array::{StringArray, UInt32Array, cast::AsArray};
141/// # use arrow_select::take::{take, take_arrays};
142/// let string_values = Arc::new(StringArray::from(vec!["zero", "one", "two"]));
143/// let values = Arc::new(UInt32Array::from(vec![0, 1, 2]));
144///
145/// // Take items at index 2, and 1:
146/// let indices = UInt32Array::from(vec![2, 1]);
147/// let taken_arrays = take_arrays(&[string_values, values], &indices, None).unwrap();
148/// let taken_string = taken_arrays[0].as_string::<i32>();
149/// assert_eq!(*taken_string, StringArray::from(vec!["two", "one"]));
150/// let taken_values = taken_arrays[1].as_primitive();
151/// assert_eq!(*taken_values, UInt32Array::from(vec![2, 1]));
152/// ```
153pub fn take_arrays(
154    arrays: &[ArrayRef],
155    indices: &dyn Array,
156    options: Option<TakeOptions>,
157) -> Result<Vec<ArrayRef>, ArrowError> {
158    arrays
159        .iter()
160        .map(|array| take(array.as_ref(), indices, options.clone()))
161        .collect()
162}
163
164/// Verifies that the non-null values of `indices` are all `< len`
165fn check_bounds<T: ArrowPrimitiveType>(
166    len: usize,
167    indices: &PrimitiveArray<T>,
168) -> Result<(), ArrowError>
169where
170    T::Native: Display,
171{
172    let len = match T::Native::from_usize(len) {
173        Some(len) => len,
174        None => {
175            if T::DATA_TYPE.is_integer() {
176                // the biggest representable value for T::Native is lower than len, e.g: u8::MAX < 512, no need to check bounds
177                return Ok(());
178            } else {
179                return Err(ArrowError::ComputeError("Cast to usize failed".to_string()));
180            }
181        }
182    };
183
184    if indices.null_count() > 0 {
185        indices.iter().flatten().try_for_each(|index| {
186            if index >= len {
187                return Err(ArrowError::ComputeError(format!(
188                    "Array index out of bounds, cannot get item at index {index} from {len} entries"
189                )));
190            }
191            Ok(())
192        })
193    } else {
194        let in_bounds = indices.values().iter().fold(true, |in_bounds, &i| {
195            in_bounds & (i >= T::Native::ZERO) & (i < len)
196        });
197
198        if !in_bounds {
199            for &index in indices.values() {
200                if index < T::Native::ZERO || index >= len {
201                    return Err(ArrowError::ComputeError(format!(
202                        "Array index out of bounds, cannot get item at index {index} from {len} entries"
203                    )));
204                }
205            }
206        }
207
208        Ok(())
209    }
210}
211
212#[inline(never)]
213fn take_impl<IndexType: ArrowPrimitiveType>(
214    values: &dyn Array,
215    indices: &PrimitiveArray<IndexType>,
216) -> Result<ArrayRef, ArrowError> {
217    downcast_primitive_array! {
218        values => Ok(Arc::new(take_primitive(values, indices)?)),
219        DataType::Boolean => {
220            let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
221            Ok(Arc::new(take_boolean(values, indices)))
222        }
223        DataType::Utf8 => {
224            Ok(Arc::new(take_bytes(values.as_string::<i32>(), indices)?))
225        }
226        DataType::LargeUtf8 => {
227            Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)?))
228        }
229        DataType::Utf8View => {
230            Ok(Arc::new(take_byte_view(values.as_string_view(), indices)?))
231        }
232        DataType::List(_) => {
233            Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?))
234        }
235        DataType::LargeList(_) => {
236            Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?))
237        }
238        DataType::ListView(_) => {
239            Ok(Arc::new(take_list_view::<_, Int32Type>(values.as_list_view(), indices)?))
240        }
241        DataType::LargeListView(_) => {
242            Ok(Arc::new(take_list_view::<_, Int64Type>(values.as_list_view(), indices)?))
243        }
244        DataType::FixedSizeList(_, length) => {
245            let values = values
246                .as_any()
247                .downcast_ref::<FixedSizeListArray>()
248                .unwrap();
249            Ok(Arc::new(take_fixed_size_list(
250                values,
251                indices,
252                *length as u32,
253            )?))
254        }
255        DataType::Map(_, _) => {
256            let list_arr = ListArray::from(values.as_map().clone());
257            let list_data = take_list::<_, Int32Type>(&list_arr, indices)?;
258            let builder = list_data.into_data().into_builder().data_type(values.data_type().clone());
259            Ok(Arc::new(MapArray::from(unsafe { builder.build_unchecked() })))
260        }
261        DataType::Struct(fields) => {
262            let array: &StructArray = values.as_struct();
263            let arrays  = array
264                .columns()
265                .iter()
266                .map(|a| take_impl(a.as_ref(), indices))
267                .collect::<Result<Vec<ArrayRef>, _>>()?;
268            let fields: Vec<(FieldRef, ArrayRef)> =
269                fields.iter().cloned().zip(arrays).collect();
270
271            // Create the null bit buffer.
272            let is_valid: Buffer = indices
273                .iter()
274                .map(|index| {
275                    if let Some(index) = index {
276                        array.is_valid(index.to_usize().unwrap())
277                    } else {
278                        false
279                    }
280                })
281                .collect();
282
283            if fields.is_empty() {
284                let nulls = NullBuffer::new(BooleanBuffer::new(is_valid, 0, indices.len()));
285                Ok(Arc::new(StructArray::new_empty_fields(indices.len(), Some(nulls))))
286            } else {
287                Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
288            }
289        }
290        DataType::Dictionary(_, _) => downcast_dictionary_array! {
291            values => Ok(Arc::new(take_dict(values, indices)?)),
292            t => unimplemented!("Take not supported for dictionary type {:?}", t)
293        }
294        DataType::RunEndEncoded(_, _) => downcast_run_array! {
295            values => Ok(Arc::new(take_run(values, indices)?)),
296            t => unimplemented!("Take not supported for run type {:?}", t)
297        }
298        DataType::Binary => {
299            Ok(Arc::new(take_bytes(values.as_binary::<i32>(), indices)?))
300        }
301        DataType::LargeBinary => {
302            Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?))
303        }
304        DataType::BinaryView => {
305            Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)?))
306        }
307        DataType::FixedSizeBinary(size) => {
308            let values = values
309                .as_any()
310                .downcast_ref::<FixedSizeBinaryArray>()
311                .unwrap();
312            Ok(Arc::new(take_fixed_size_binary(values, indices, *size)?))
313        }
314        DataType::Null => {
315            // Take applied to a null array produces a null array.
316            if values.len() >= indices.len() {
317                // If the existing null array is as big as the indices, we can use a slice of it
318                // to avoid allocating a new null array.
319                Ok(values.slice(0, indices.len()))
320            } else {
321                // If the existing null array isn't big enough, create a new one.
322                Ok(new_null_array(&DataType::Null, indices.len()))
323            }
324        }
325        DataType::Union(fields, UnionMode::Sparse) => {
326            let mut children = Vec::with_capacity(fields.len());
327            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
328            let type_ids = take_native(values.type_ids(), indices);
329            for (type_id, _field) in fields.iter() {
330                let values = values.child(type_id);
331                let values = take_impl(values, indices)?;
332                children.push(values);
333            }
334            let array = UnionArray::try_new(fields.clone(), type_ids, None, children)?;
335            Ok(Arc::new(array))
336        }
337        DataType::Union(fields, UnionMode::Dense) => {
338            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
339
340            let type_ids = <PrimitiveArray<Int8Type>>::try_new(take_native(values.type_ids(), indices), None)?;
341            let offsets = <PrimitiveArray<Int32Type>>::try_new(take_native(values.offsets().unwrap(), indices), None)?;
342
343            let children = fields.iter()
344                .map(|(field_type_id, _)| {
345                    let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);
346
347                    let indices = crate::filter::filter(&offsets, &mask)?;
348
349                    let values = values.child(field_type_id);
350
351                    take_impl(values, indices.as_primitive::<Int32Type>())
352                })
353                .collect::<Result<_, _>>()?;
354
355            let mut child_offsets = [0; 128];
356
357            let offsets = type_ids.values()
358                .iter()
359                .map(|&i| {
360                    let offset = child_offsets[i as usize];
361
362                    child_offsets[i as usize] += 1;
363
364                    offset
365                })
366                .collect();
367
368            let (_, type_ids, _) = type_ids.into_parts();
369
370            let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)?;
371
372            Ok(Arc::new(array))
373        }
374        t => unimplemented!("Take not supported for data type {:?}", t)
375    }
376}
377
/// Configuration for the [`take`] family of kernels.
///
/// The [`Default`] value disables bounds checking.
#[derive(Debug, Default, Clone)]
pub struct TakeOptions {
    /// Whether to validate `indices` against the length of `values` before taking.
    ///
    /// When `true`, an out-of-bounds index yields an `ArrowError`; when `false`,
    /// an out-of-bounds index makes the kernel panic.
    pub check_bounds: bool,
}
386
387/// `take` implementation for all primitive arrays
388///
389/// This checks if an `indices` slot is populated, and gets the value from `values`
390///  as the populated index.
391/// If the `indices` slot is null, a null value is returned.
392/// For example, given:
393///     values:  [1, 2, 3, null, 5]
394///     indices: [0, null, 4, 3]
395/// The result is: [1 (slot 0), null (null slot), 5 (slot 4), null (slot 3)]
396fn take_primitive<T, I>(
397    values: &PrimitiveArray<T>,
398    indices: &PrimitiveArray<I>,
399) -> Result<PrimitiveArray<T>, ArrowError>
400where
401    T: ArrowPrimitiveType,
402    I: ArrowPrimitiveType,
403{
404    let values_buf = take_native(values.values(), indices);
405    let nulls = take_nulls(values.nulls(), indices);
406    Ok(PrimitiveArray::try_new(values_buf, nulls)?.with_data_type(values.data_type().clone()))
407}
408
409#[inline(never)]
410fn take_nulls<I: ArrowPrimitiveType>(
411    values: Option<&NullBuffer>,
412    indices: &PrimitiveArray<I>,
413) -> Option<NullBuffer> {
414    match values.filter(|n| n.null_count() > 0) {
415        Some(n) => {
416            let buffer = take_bits(n.inner(), indices);
417            Some(NullBuffer::new(buffer)).filter(|n| n.null_count() > 0)
418        }
419        None => indices.nulls().cloned(),
420    }
421}
422
423#[inline(never)]
424fn take_native<T: ArrowNativeType, I: ArrowPrimitiveType>(
425    values: &[T],
426    indices: &PrimitiveArray<I>,
427) -> ScalarBuffer<T> {
428    match indices.nulls().filter(|n| n.null_count() > 0) {
429        Some(n) => indices
430            .values()
431            .iter()
432            .enumerate()
433            .map(|(idx, index)| match values.get(index.as_usize()) {
434                Some(v) => *v,
435                // SAFETY: idx<indices.len()
436                None => match unsafe { n.inner().value_unchecked(idx) } {
437                    false => T::default(),
438                    true => panic!("Out-of-bounds index {index:?}"),
439                },
440            })
441            .collect(),
442        None => indices
443            .values()
444            .iter()
445            .map(|index| values[index.as_usize()])
446            .collect(),
447    }
448}
449
450#[inline(never)]
451fn take_bits<I: ArrowPrimitiveType>(
452    values: &BooleanBuffer,
453    indices: &PrimitiveArray<I>,
454) -> BooleanBuffer {
455    let len = indices.len();
456
457    match indices.nulls().filter(|n| n.null_count() > 0) {
458        Some(nulls) => {
459            let mut output_buffer = MutableBuffer::new_null(len);
460            let output_slice = output_buffer.as_slice_mut();
461            nulls.valid_indices().for_each(|idx| {
462                // SAFETY: idx is a valid index in indices.nulls() --> idx<indices.len()
463                if values.value(unsafe { indices.value_unchecked(idx).as_usize() }) {
464                    // SAFETY: MutableBuffer was created with space for indices.len() bit, and idx < indices.len()
465                    unsafe { bit_util::set_bit_raw(output_slice.as_mut_ptr(), idx) };
466                }
467            });
468            BooleanBuffer::new(output_buffer.into(), 0, len)
469        }
470        None => {
471            BooleanBuffer::collect_bool(len, |idx: usize| {
472                // SAFETY: idx<indices.len()
473                values.value(unsafe { indices.value_unchecked(idx).as_usize() })
474            })
475        }
476    }
477}
478
479/// `take` implementation for boolean arrays
480fn take_boolean<IndexType: ArrowPrimitiveType>(
481    values: &BooleanArray,
482    indices: &PrimitiveArray<IndexType>,
483) -> BooleanArray {
484    let val_buf = take_bits(values.values(), indices);
485    let null_buf = take_nulls(values.nulls(), indices);
486    BooleanArray::new(val_buf, null_buf)
487}
488
489/// `take` implementation for string arrays
490fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
491    array: &GenericByteArray<T>,
492    indices: &PrimitiveArray<IndexType>,
493) -> Result<GenericByteArray<T>, ArrowError> {
494    let mut offsets = Vec::with_capacity(indices.len() + 1);
495    offsets.push(T::Offset::default());
496
497    let input_offsets = array.value_offsets();
498    let mut capacity = 0;
499    let nulls = take_nulls(array.nulls(), indices);
500
501    let (offsets, values) = if array.null_count() == 0 && indices.null_count() == 0 {
502        offsets.reserve(indices.len());
503        for index in indices.values() {
504            let index = index.as_usize();
505            capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
506            offsets.push(
507                T::Offset::from_usize(capacity)
508                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
509            );
510        }
511        let mut values = Vec::with_capacity(capacity);
512
513        for index in indices.values() {
514            values.extend_from_slice(array.value(index.as_usize()).as_ref());
515        }
516        (offsets, values)
517    } else if indices.null_count() == 0 {
518        offsets.reserve(indices.len());
519        for index in indices.values() {
520            let index = index.as_usize();
521            if array.is_valid(index) {
522                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
523            }
524            offsets.push(
525                T::Offset::from_usize(capacity)
526                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
527            );
528        }
529        let mut values = Vec::with_capacity(capacity);
530
531        for index in indices.values() {
532            let index = index.as_usize();
533            if array.is_valid(index) {
534                values.extend_from_slice(array.value(index).as_ref());
535            }
536        }
537        (offsets, values)
538    } else if array.null_count() == 0 {
539        offsets.reserve(indices.len());
540        for (i, index) in indices.values().iter().enumerate() {
541            let index = index.as_usize();
542            if indices.is_valid(i) {
543                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
544            }
545            offsets.push(
546                T::Offset::from_usize(capacity)
547                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
548            );
549        }
550        let mut values = Vec::with_capacity(capacity);
551
552        for (i, index) in indices.values().iter().enumerate() {
553            if indices.is_valid(i) {
554                values.extend_from_slice(array.value(index.as_usize()).as_ref());
555            }
556        }
557        (offsets, values)
558    } else {
559        let nulls = nulls.as_ref().unwrap();
560        offsets.reserve(indices.len());
561        for (i, index) in indices.values().iter().enumerate() {
562            let index = index.as_usize();
563            if nulls.is_valid(i) {
564                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
565            }
566            offsets.push(
567                T::Offset::from_usize(capacity)
568                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
569            );
570        }
571        let mut values = Vec::with_capacity(capacity);
572
573        for (i, index) in indices.values().iter().enumerate() {
574            // check index is valid before using index. The value in
575            // NULL index slots may not be within bounds of array
576            let index = index.as_usize();
577            if nulls.is_valid(i) {
578                values.extend_from_slice(array.value(index).as_ref());
579            }
580        }
581        (offsets, values)
582    };
583
584    T::Offset::from_usize(values.len())
585        .ok_or_else(|| ArrowError::OffsetOverflowError(values.len()))?;
586
587    let array = unsafe {
588        let offsets = OffsetBuffer::new_unchecked(offsets.into());
589        GenericByteArray::<T>::new_unchecked(offsets, values.into(), nulls)
590    };
591
592    Ok(array)
593}
594
595/// `take` implementation for byte view arrays
596fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>(
597    array: &GenericByteViewArray<T>,
598    indices: &PrimitiveArray<IndexType>,
599) -> Result<GenericByteViewArray<T>, ArrowError> {
600    let new_views = take_native(array.views(), indices);
601    let new_nulls = take_nulls(array.nulls(), indices);
602    // Safety:  array.views was valid, and take_native copies only valid values, and verifies bounds
603    Ok(unsafe {
604        GenericByteViewArray::new_unchecked(new_views, array.data_buffers().to_vec(), new_nulls)
605    })
606}
607
608/// `take` implementation for list arrays
609///
610/// Calculates the index and indexed offset for the inner array,
611/// applying `take` on the inner array, then reconstructing a list array
612/// with the indexed offsets
613fn take_list<IndexType, OffsetType>(
614    values: &GenericListArray<OffsetType::Native>,
615    indices: &PrimitiveArray<IndexType>,
616) -> Result<GenericListArray<OffsetType::Native>, ArrowError>
617where
618    IndexType: ArrowPrimitiveType,
619    OffsetType: ArrowPrimitiveType,
620    OffsetType::Native: OffsetSizeTrait,
621    PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
622{
623    // TODO: Some optimizations can be done here such as if it is
624    // taking the whole list or a contiguous sublist
625    let (list_indices, offsets, null_buf) =
626        take_value_indices_from_list::<IndexType, OffsetType>(values, indices)?;
627
628    let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices)?;
629    let value_offsets = Buffer::from_vec(offsets);
630    // create a new list with taken data and computed null information
631    let list_data = ArrayDataBuilder::new(values.data_type().clone())
632        .len(indices.len())
633        .null_bit_buffer(Some(null_buf.into()))
634        .offset(0)
635        .add_child_data(taken.into_data())
636        .add_buffer(value_offsets);
637
638    let list_data = unsafe { list_data.build_unchecked() };
639
640    Ok(GenericListArray::<OffsetType::Native>::from(list_data))
641}
642
643fn take_list_view<IndexType, OffsetType>(
644    values: &GenericListViewArray<OffsetType::Native>,
645    indices: &PrimitiveArray<IndexType>,
646) -> Result<GenericListViewArray<OffsetType::Native>, ArrowError>
647where
648    IndexType: ArrowPrimitiveType,
649    OffsetType: ArrowPrimitiveType,
650    OffsetType::Native: OffsetSizeTrait,
651{
652    let taken_offsets = take_native(values.offsets(), indices);
653    let taken_sizes = take_native(values.sizes(), indices);
654    let nulls = take_nulls(values.nulls(), indices);
655
656    let list_view_data = ArrayDataBuilder::new(values.data_type().clone())
657        .len(indices.len())
658        .nulls(nulls)
659        .buffers(vec![taken_offsets.into(), taken_sizes.into()])
660        .child_data(vec![values.values().to_data()]);
661
662    // SAFETY: all buffers and child nodes for ListView added in constructor
663    let list_view_data = unsafe { list_view_data.build_unchecked() };
664
665    Ok(GenericListViewArray::<OffsetType::Native>::from(
666        list_view_data,
667    ))
668}
669
670/// `take` implementation for `FixedSizeListArray`
671///
672/// Calculates the index and indexed offset for the inner array,
673/// applying `take` on the inner array, then reconstructing a list array
674/// with the indexed offsets
675fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
676    values: &FixedSizeListArray,
677    indices: &PrimitiveArray<IndexType>,
678    length: <UInt32Type as ArrowPrimitiveType>::Native,
679) -> Result<FixedSizeListArray, ArrowError> {
680    let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?;
681    let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)?;
682
683    // determine null count and null buffer, which are a function of `values` and `indices`
684    let num_bytes = bit_util::ceil(indices.len(), 8);
685    let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
686    let null_slice = null_buf.as_slice_mut();
687
688    for i in 0..indices.len() {
689        let index = indices
690            .value(i)
691            .to_usize()
692            .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
693        if !indices.is_valid(i) || values.is_null(index) {
694            bit_util::unset_bit(null_slice, i);
695        }
696    }
697
698    let list_data = ArrayDataBuilder::new(values.data_type().clone())
699        .len(indices.len())
700        .null_bit_buffer(Some(null_buf.into()))
701        .offset(0)
702        .add_child_data(taken.into_data());
703
704    let list_data = unsafe { list_data.build_unchecked() };
705
706    Ok(FixedSizeListArray::from(list_data))
707}
708
709/// The take kernel implementation for `FixedSizeBinaryArray`.
710///
711/// The computation is done in two steps:
712/// - Compute the values buffer
713/// - Compute the null buffer
714fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>(
715    values: &FixedSizeBinaryArray,
716    indices: &PrimitiveArray<IndexType>,
717    size: i32,
718) -> Result<FixedSizeBinaryArray, ArrowError> {
719    let size_usize = usize::try_from(size).map_err(|_| {
720        ArrowError::InvalidArgumentError(format!("Cannot convert size '{}' to usize", size))
721    })?;
722
723    let values_buffer = values.values().as_slice();
724    let mut values_buffer_builder = BufferBuilder::new(indices.len() * size_usize);
725
726    if indices.null_count() == 0 {
727        let array_iter = indices.values().iter().map(|idx| {
728            let offset = idx.as_usize() * size_usize;
729            &values_buffer[offset..offset + size_usize]
730        });
731        for slice in array_iter {
732            values_buffer_builder.append_slice(slice);
733        }
734    } else {
735        // The indices nullability cannot be ignored here because the values buffer may contain
736        // nulls which should not cause a panic.
737        let array_iter = indices.iter().map(|idx| {
738            idx.map(|idx| {
739                let offset = idx.as_usize() * size_usize;
740                &values_buffer[offset..offset + size_usize]
741            })
742        });
743        for slice in array_iter {
744            match slice {
745                None => values_buffer_builder.append_n(size_usize, 0),
746                Some(slice) => values_buffer_builder.append_slice(slice),
747            }
748        }
749    }
750
751    let values_buffer = values_buffer_builder.finish();
752    let value_nulls = take_nulls(values.nulls(), indices);
753    let final_nulls = NullBuffer::union(value_nulls.as_ref(), indices.nulls());
754
755    let array_data = ArrayDataBuilder::new(DataType::FixedSizeBinary(size))
756        .len(indices.len())
757        .nulls(final_nulls)
758        .offset(0)
759        .add_buffer(values_buffer)
760        .build()?;
761
762    Ok(FixedSizeBinaryArray::from(array_data))
763}
764
765/// `take` implementation for dictionary arrays
766///
767/// applies `take` to the keys of the dictionary array and returns a new dictionary array
768/// with the same dictionary values and reordered keys
769fn take_dict<T: ArrowDictionaryKeyType, I: ArrowPrimitiveType>(
770    values: &DictionaryArray<T>,
771    indices: &PrimitiveArray<I>,
772) -> Result<DictionaryArray<T>, ArrowError> {
773    let new_keys = take_primitive(values.keys(), indices)?;
774    Ok(unsafe { DictionaryArray::new_unchecked(new_keys, values.values().clone()) })
775}
776
777/// `take` implementation for run arrays
778///
779/// Finds physical indices for the given logical indices and builds output run array
780/// by taking values in the input run_array.values at the physical indices.
781/// The output run array will be run encoded on the physical indices and not on output values.
782/// For e.g. an input `RunArray{ run_ends = [2,4,6,8], values=[1,2,1,2] }` and `logical_indices=[2,3,6,7]`
783/// would be converted to `physical_indices=[1,1,3,3]` which will be used to build
784/// output `RunArray{ run_ends=[2,4], values=[2,2] }`.
785fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
786    run_array: &RunArray<T>,
787    logical_indices: &PrimitiveArray<I>,
788) -> Result<RunArray<T>, ArrowError> {
789    // get physical indices for the input logical indices
790    let physical_indices = run_array.get_physical_indices(logical_indices.values())?;
791
792    // Run encode the physical indices into new_run_ends_builder
793    // Keep track of the physical indices to take in take_value_indices
794    // `unwrap` is used in this function because the unwrapped values are bounded by the corresponding `::Native`.
795    let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
796    let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
797    let mut new_physical_len = 1;
798    for ix in 1..physical_indices.len() {
799        if physical_indices[ix] != physical_indices[ix - 1] {
800            take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
801            new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
802            new_physical_len += 1;
803        }
804    }
805    take_value_indices
806        .append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap());
807    new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
808    let new_run_ends = unsafe {
809        // Safety:
810        // The function builds a valid run_ends array and hence need not be validated.
811        ArrayDataBuilder::new(T::DATA_TYPE)
812            .len(new_physical_len)
813            .null_count(0)
814            .add_buffer(new_run_ends_builder.finish())
815            .build_unchecked()
816    };
817
818    let take_value_indices: PrimitiveArray<I> = unsafe {
819        // Safety:
820        // The function builds a valid take_value_indices array and hence need not be validated.
821        ArrayDataBuilder::new(I::DATA_TYPE)
822            .len(new_physical_len)
823            .null_count(0)
824            .add_buffer(take_value_indices.finish())
825            .build_unchecked()
826            .into()
827    };
828
829    let new_values = take(run_array.values(), &take_value_indices, None)?;
830
831    let builder = ArrayDataBuilder::new(run_array.data_type().clone())
832        .len(physical_indices.len())
833        .add_child_data(new_run_ends)
834        .add_child_data(new_values.into_data());
835    let array_data = unsafe {
836        // Safety:
837        //  This function builds a valid run array and hence can skip validation.
838        builder.build_unchecked()
839    };
840    Ok(array_data.into())
841}
842
843/// Takes/filters a list array's inner data using the offsets of the list array.
844///
845/// Where a list array has indices `[0,2,5,10]`, taking indices of `[2,0]` returns
846/// an array of the indices `[5..10, 0..2]` and offsets `[0,5,7]` (5 elements and 2
847/// elements)
848#[allow(clippy::type_complexity)]
849fn take_value_indices_from_list<IndexType, OffsetType>(
850    list: &GenericListArray<OffsetType::Native>,
851    indices: &PrimitiveArray<IndexType>,
852) -> Result<
853    (
854        PrimitiveArray<OffsetType>,
855        Vec<OffsetType::Native>,
856        MutableBuffer,
857    ),
858    ArrowError,
859>
860where
861    IndexType: ArrowPrimitiveType,
862    OffsetType: ArrowPrimitiveType,
863    OffsetType::Native: OffsetSizeTrait + std::ops::Add + Zero + One,
864    PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
865{
866    // TODO: benchmark this function, there might be a faster unsafe alternative
867    let offsets: &[OffsetType::Native] = list.value_offsets();
868
869    let mut new_offsets = Vec::with_capacity(indices.len());
870    let mut values = Vec::new();
871    let mut current_offset = OffsetType::Native::zero();
872    // add first offset
873    new_offsets.push(OffsetType::Native::zero());
874
875    // Initialize null buffer
876    let num_bytes = bit_util::ceil(indices.len(), 8);
877    let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
878    let null_slice = null_buf.as_slice_mut();
879
880    // compute the value indices, and set offsets accordingly
881    for i in 0..indices.len() {
882        if indices.is_valid(i) {
883            let ix = indices
884                .value(i)
885                .to_usize()
886                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
887            let start = offsets[ix];
888            let end = offsets[ix + 1];
889            current_offset += end - start;
890            new_offsets.push(current_offset);
891
892            let mut curr = start;
893
894            // if start == end, this slot is empty
895            while curr < end {
896                values.push(curr);
897                curr += One::one();
898            }
899            if !list.is_valid(ix) {
900                bit_util::unset_bit(null_slice, i);
901            }
902        } else {
903            bit_util::unset_bit(null_slice, i);
904            new_offsets.push(current_offset);
905        }
906    }
907
908    Ok((
909        PrimitiveArray::<OffsetType>::from(values),
910        new_offsets,
911        null_buf,
912    ))
913}
914
915/// Takes/filters a fixed size list array's inner data using the offsets of the list array.
916fn take_value_indices_from_fixed_size_list<IndexType>(
917    list: &FixedSizeListArray,
918    indices: &PrimitiveArray<IndexType>,
919    length: <UInt32Type as ArrowPrimitiveType>::Native,
920) -> Result<PrimitiveArray<UInt32Type>, ArrowError>
921where
922    IndexType: ArrowPrimitiveType,
923{
924    let mut values = UInt32Builder::with_capacity(length as usize * indices.len());
925
926    for i in 0..indices.len() {
927        if indices.is_valid(i) {
928            let index = indices
929                .value(i)
930                .to_usize()
931                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
932            let start = list.value_offset(index) as <UInt32Type as ArrowPrimitiveType>::Native;
933
934            // Safety: Range always has known length.
935            unsafe {
936                values.append_trusted_len_iter(start..start + length);
937            }
938        } else {
939            values.append_nulls(length as usize);
940        }
941    }
942
943    Ok(values.finish())
944}
945
/// To avoid generating take implementations for every index type, the kernels are only
/// instantiated for UInt32 and UInt64, and other index types are coerced to one of these.
trait ToIndices {
    /// Index type of the coerced array (the macro-generated impls use `UInt32Type` or
    /// `UInt64Type`).
    type T: ArrowPrimitiveType;

    /// Returns `self` as an array of [`Self::T`] indices carrying the same null buffer.
    fn to_indices(&self) -> PrimitiveArray<Self::T>;
}
953
/// Implements `ToIndices` for a signed index type by re-typing its existing values buffer
/// (`inner().clone()`) as the same-width unsigned type, rather than converting element by
/// element. Negative indices therefore show up as very large unsigned values.
macro_rules! to_indices_reinterpret {
    ($src:ty, $dst:ty) => {
        impl ToIndices for PrimitiveArray<$src> {
            type T = $dst;

            fn to_indices(&self) -> PrimitiveArray<$dst> {
                let len = self.len();
                let bytes = self.values().inner().clone();
                PrimitiveArray::new(ScalarBuffer::new(bytes, 0, len), self.nulls().cloned())
            }
        }
    };
}
966
/// Implements `ToIndices` for index types the kernels are already instantiated for; the
/// array is returned as a clone of itself.
macro_rules! to_indices_identity {
    ($native_index:ty) => {
        impl ToIndices for PrimitiveArray<$native_index> {
            type T = $native_index;

            fn to_indices(&self) -> PrimitiveArray<$native_index> {
                PrimitiveArray::clone(self)
            }
        }
    };
}
978
/// Implements `ToIndices` for a narrow index type `$t` by converting every value with `as`
/// into the wider index type `$o`, keeping the null buffer.
///
/// The associated type `T` is `$o`, matching the method's return type; previously it was
/// hard-coded to `UInt32Type`, so the macro only compiled when `$o` was `UInt32Type`.
/// Note that `as` sign-extends negative signed values (e.g. `-1i8 as u32 == u32::MAX`), so
/// negative indices become very large unsigned indices.
macro_rules! to_indices_widening {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                let cast = self.values().iter().copied().map(|x| x as _).collect();
                PrimitiveArray::new(cast, self.nulls().cloned())
            }
        }
    };
}
991
992to_indices_widening!(UInt8Type, UInt32Type);
993to_indices_widening!(Int8Type, UInt32Type);
994
995to_indices_widening!(UInt16Type, UInt32Type);
996to_indices_widening!(Int16Type, UInt32Type);
997
998to_indices_identity!(UInt32Type);
999to_indices_reinterpret!(Int32Type, UInt32Type);
1000
1001to_indices_identity!(UInt64Type);
1002to_indices_reinterpret!(Int64Type, UInt64Type);
1003
1004/// Take rows by index from [`RecordBatch`] and returns a new [`RecordBatch`] from those indexes.
1005///
1006/// This function will call [`take`] on each array of the [`RecordBatch`] and assemble a new [`RecordBatch`].
1007///
1008/// # Example
1009/// ```
1010/// # use std::sync::Arc;
1011/// # use arrow_array::{StringArray, Int32Array, UInt32Array, RecordBatch};
1012/// # use arrow_schema::{DataType, Field, Schema};
1013/// # use arrow_select::take::take_record_batch;
1014/// let schema = Arc::new(Schema::new(vec![
1015///     Field::new("a", DataType::Int32, true),
1016///     Field::new("b", DataType::Utf8, true),
1017/// ]));
1018/// let batch = RecordBatch::try_new(
1019///     schema.clone(),
1020///     vec![
1021///         Arc::new(Int32Array::from_iter_values(0..20)),
1022///         Arc::new(StringArray::from_iter_values(
1023///             (0..20).map(|i| format!("str-{}", i)),
1024///         )),
1025///     ],
1026/// )
1027/// .unwrap();
1028///
1029/// let indices = UInt32Array::from(vec![1, 5, 10]);
1030/// let taken = take_record_batch(&batch, &indices).unwrap();
1031///
1032/// let expected = RecordBatch::try_new(
1033///     schema,
1034///     vec![
1035///         Arc::new(Int32Array::from(vec![1, 5, 10])),
1036///         Arc::new(StringArray::from(vec!["str-1", "str-5", "str-10"])),
1037///     ],
1038/// )
1039/// .unwrap();
1040/// assert_eq!(taken, expected);
1041/// ```
1042pub fn take_record_batch(
1043    record_batch: &RecordBatch,
1044    indices: &dyn Array,
1045) -> Result<RecordBatch, ArrowError> {
1046    let columns = record_batch
1047        .columns()
1048        .iter()
1049        .map(|c| take(c, indices, None))
1050        .collect::<Result<Vec<_>, _>>()?;
1051    RecordBatch::try_new(record_batch.schema(), columns)
1052}
1053
1054#[cfg(test)]
1055mod tests {
1056    use super::*;
1057    use arrow_array::builder::*;
1058    use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
1059    use arrow_data::ArrayData;
1060    use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
1061    use num_traits::ToPrimitive;
1062
1063    fn test_take_decimal_arrays(
1064        data: Vec<Option<i128>>,
1065        index: &UInt32Array,
1066        options: Option<TakeOptions>,
1067        expected_data: Vec<Option<i128>>,
1068        precision: &u8,
1069        scale: &i8,
1070    ) -> Result<(), ArrowError> {
1071        let output = data
1072            .into_iter()
1073            .collect::<Decimal128Array>()
1074            .with_precision_and_scale(*precision, *scale)
1075            .unwrap();
1076
1077        let expected = expected_data
1078            .into_iter()
1079            .collect::<Decimal128Array>()
1080            .with_precision_and_scale(*precision, *scale)
1081            .unwrap();
1082
1083        let expected = Arc::new(expected) as ArrayRef;
1084        let output = take(&output, index, options).unwrap();
1085        assert_eq!(&output, &expected);
1086        Ok(())
1087    }
1088
1089    fn test_take_boolean_arrays(
1090        data: Vec<Option<bool>>,
1091        index: &UInt32Array,
1092        options: Option<TakeOptions>,
1093        expected_data: Vec<Option<bool>>,
1094    ) {
1095        let output = BooleanArray::from(data);
1096        let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef;
1097        let output = take(&output, index, options).unwrap();
1098        assert_eq!(&output, &expected)
1099    }
1100
1101    fn test_take_primitive_arrays<T>(
1102        data: Vec<Option<T::Native>>,
1103        index: &UInt32Array,
1104        options: Option<TakeOptions>,
1105        expected_data: Vec<Option<T::Native>>,
1106    ) -> Result<(), ArrowError>
1107    where
1108        T: ArrowPrimitiveType,
1109        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1110    {
1111        let output = PrimitiveArray::<T>::from(data);
1112        let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1113        let output = take(&output, index, options)?;
1114        assert_eq!(&output, &expected);
1115        Ok(())
1116    }
1117
1118    fn test_take_primitive_arrays_non_null<T>(
1119        data: Vec<T::Native>,
1120        index: &UInt32Array,
1121        options: Option<TakeOptions>,
1122        expected_data: Vec<Option<T::Native>>,
1123    ) -> Result<(), ArrowError>
1124    where
1125        T: ArrowPrimitiveType,
1126        PrimitiveArray<T>: From<Vec<T::Native>>,
1127        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1128    {
1129        let output = PrimitiveArray::<T>::from(data);
1130        let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1131        let output = take(&output, index, options)?;
1132        assert_eq!(&output, &expected);
1133        Ok(())
1134    }
1135
1136    fn test_take_impl_primitive_arrays<T, I>(
1137        data: Vec<Option<T::Native>>,
1138        index: &PrimitiveArray<I>,
1139        options: Option<TakeOptions>,
1140        expected_data: Vec<Option<T::Native>>,
1141    ) where
1142        T: ArrowPrimitiveType,
1143        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1144        I: ArrowPrimitiveType,
1145    {
1146        let output = PrimitiveArray::<T>::from(data);
1147        let expected = PrimitiveArray::<T>::from(expected_data);
1148        let output = take(&output, index, options).unwrap();
1149        let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
1150        assert_eq!(output, &expected)
1151    }
1152
1153    // create a simple struct for testing purposes
1154    fn create_test_struct(values: Vec<Option<(Option<bool>, Option<i32>)>>) -> StructArray {
1155        let mut struct_builder = StructBuilder::new(
1156            Fields::from(vec![
1157                Field::new("a", DataType::Boolean, true),
1158                Field::new("b", DataType::Int32, true),
1159            ]),
1160            vec![
1161                Box::new(BooleanBuilder::with_capacity(values.len())),
1162                Box::new(Int32Builder::with_capacity(values.len())),
1163            ],
1164        );
1165
1166        for value in values {
1167            struct_builder
1168                .field_builder::<BooleanBuilder>(0)
1169                .unwrap()
1170                .append_option(value.and_then(|v| v.0));
1171            struct_builder
1172                .field_builder::<Int32Builder>(1)
1173                .unwrap()
1174                .append_option(value.and_then(|v| v.1));
1175            struct_builder.append(value.is_some());
1176        }
1177        struct_builder.finish()
1178    }
1179
1180    #[test]
1181    fn test_take_decimal128_non_null_indices() {
1182        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1183        let precision: u8 = 10;
1184        let scale: i8 = 5;
1185        test_take_decimal_arrays(
1186            vec![None, Some(3), Some(5), Some(2), Some(3), None],
1187            &index,
1188            None,
1189            vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1190            &precision,
1191            &scale,
1192        )
1193        .unwrap();
1194    }
1195
1196    #[test]
1197    fn test_take_decimal128() {
1198        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1199        let precision: u8 = 10;
1200        let scale: i8 = 5;
1201        test_take_decimal_arrays(
1202            vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1203            &index,
1204            None,
1205            vec![Some(3), None, Some(1), Some(3), Some(2)],
1206            &precision,
1207            &scale,
1208        )
1209        .unwrap();
1210    }
1211
1212    #[test]
1213    fn test_take_primitive_non_null_indices() {
1214        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1215        test_take_primitive_arrays::<Int8Type>(
1216            vec![None, Some(3), Some(5), Some(2), Some(3), None],
1217            &index,
1218            None,
1219            vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1220        )
1221        .unwrap();
1222    }
1223
1224    #[test]
1225    fn test_take_primitive_non_null_values() {
1226        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1227        test_take_primitive_arrays::<Int8Type>(
1228            vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1229            &index,
1230            None,
1231            vec![Some(3), None, Some(1), Some(3), Some(2)],
1232        )
1233        .unwrap();
1234    }
1235
1236    #[test]
1237    fn test_take_primitive_non_null() {
1238        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1239        test_take_primitive_arrays::<Int8Type>(
1240            vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)],
1241            &index,
1242            None,
1243            vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)],
1244        )
1245        .unwrap();
1246    }
1247
1248    #[test]
1249    fn test_take_primitive_nullable_indices_non_null_values_with_offset() {
1250        let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1251        let index = index.slice(2, 4);
1252        let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1253
1254        assert_eq!(
1255            index,
1256            &UInt32Array::from(vec![Some(2), Some(3), None, None])
1257        );
1258
1259        test_take_primitive_arrays_non_null::<Int64Type>(
1260            vec![0, 10, 20, 30, 40, 50],
1261            index,
1262            None,
1263            vec![Some(20), Some(30), None, None],
1264        )
1265        .unwrap();
1266    }
1267
1268    #[test]
1269    fn test_take_primitive_nullable_indices_nullable_values_with_offset() {
1270        let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1271        let index = index.slice(2, 4);
1272        let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1273
1274        assert_eq!(
1275            index,
1276            &UInt32Array::from(vec![Some(2), Some(3), None, None])
1277        );
1278
1279        test_take_primitive_arrays::<Int64Type>(
1280            vec![None, None, Some(20), Some(30), Some(40), Some(50)],
1281            index,
1282            None,
1283            vec![Some(20), Some(30), None, None],
1284        )
1285        .unwrap();
1286    }
1287
1288    #[test]
1289    fn test_take_primitive() {
1290        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1291
1292        // int8
1293        test_take_primitive_arrays::<Int8Type>(
1294            vec![Some(0), None, Some(2), Some(3), None],
1295            &index,
1296            None,
1297            vec![Some(3), None, None, Some(3), Some(2)],
1298        )
1299        .unwrap();
1300
1301        // int16
1302        test_take_primitive_arrays::<Int16Type>(
1303            vec![Some(0), None, Some(2), Some(3), None],
1304            &index,
1305            None,
1306            vec![Some(3), None, None, Some(3), Some(2)],
1307        )
1308        .unwrap();
1309
1310        // int32
1311        test_take_primitive_arrays::<Int32Type>(
1312            vec![Some(0), None, Some(2), Some(3), None],
1313            &index,
1314            None,
1315            vec![Some(3), None, None, Some(3), Some(2)],
1316        )
1317        .unwrap();
1318
1319        // int64
1320        test_take_primitive_arrays::<Int64Type>(
1321            vec![Some(0), None, Some(2), Some(3), None],
1322            &index,
1323            None,
1324            vec![Some(3), None, None, Some(3), Some(2)],
1325        )
1326        .unwrap();
1327
1328        // uint8
1329        test_take_primitive_arrays::<UInt8Type>(
1330            vec![Some(0), None, Some(2), Some(3), None],
1331            &index,
1332            None,
1333            vec![Some(3), None, None, Some(3), Some(2)],
1334        )
1335        .unwrap();
1336
1337        // uint16
1338        test_take_primitive_arrays::<UInt16Type>(
1339            vec![Some(0), None, Some(2), Some(3), None],
1340            &index,
1341            None,
1342            vec![Some(3), None, None, Some(3), Some(2)],
1343        )
1344        .unwrap();
1345
1346        // uint32
1347        test_take_primitive_arrays::<UInt32Type>(
1348            vec![Some(0), None, Some(2), Some(3), None],
1349            &index,
1350            None,
1351            vec![Some(3), None, None, Some(3), Some(2)],
1352        )
1353        .unwrap();
1354
1355        // int64
1356        test_take_primitive_arrays::<Int64Type>(
1357            vec![Some(0), None, Some(2), Some(-15), None],
1358            &index,
1359            None,
1360            vec![Some(-15), None, None, Some(-15), Some(2)],
1361        )
1362        .unwrap();
1363
1364        // interval_year_month
1365        test_take_primitive_arrays::<IntervalYearMonthType>(
1366            vec![Some(0), None, Some(2), Some(-15), None],
1367            &index,
1368            None,
1369            vec![Some(-15), None, None, Some(-15), Some(2)],
1370        )
1371        .unwrap();
1372
1373        // interval_day_time
1374        let v1 = IntervalDayTime::new(0, 0);
1375        let v2 = IntervalDayTime::new(2, 0);
1376        let v3 = IntervalDayTime::new(-15, 0);
1377        test_take_primitive_arrays::<IntervalDayTimeType>(
1378            vec![Some(v1), None, Some(v2), Some(v3), None],
1379            &index,
1380            None,
1381            vec![Some(v3), None, None, Some(v3), Some(v2)],
1382        )
1383        .unwrap();
1384
1385        // interval_month_day_nano
1386        let v1 = IntervalMonthDayNano::new(0, 0, 0);
1387        let v2 = IntervalMonthDayNano::new(2, 0, 0);
1388        let v3 = IntervalMonthDayNano::new(-15, 0, 0);
1389        test_take_primitive_arrays::<IntervalMonthDayNanoType>(
1390            vec![Some(v1), None, Some(v2), Some(v3), None],
1391            &index,
1392            None,
1393            vec![Some(v3), None, None, Some(v3), Some(v2)],
1394        )
1395        .unwrap();
1396
1397        // duration_second
1398        test_take_primitive_arrays::<DurationSecondType>(
1399            vec![Some(0), None, Some(2), Some(-15), None],
1400            &index,
1401            None,
1402            vec![Some(-15), None, None, Some(-15), Some(2)],
1403        )
1404        .unwrap();
1405
1406        // duration_millisecond
1407        test_take_primitive_arrays::<DurationMillisecondType>(
1408            vec![Some(0), None, Some(2), Some(-15), None],
1409            &index,
1410            None,
1411            vec![Some(-15), None, None, Some(-15), Some(2)],
1412        )
1413        .unwrap();
1414
1415        // duration_microsecond
1416        test_take_primitive_arrays::<DurationMicrosecondType>(
1417            vec![Some(0), None, Some(2), Some(-15), None],
1418            &index,
1419            None,
1420            vec![Some(-15), None, None, Some(-15), Some(2)],
1421        )
1422        .unwrap();
1423
1424        // duration_nanosecond
1425        test_take_primitive_arrays::<DurationNanosecondType>(
1426            vec![Some(0), None, Some(2), Some(-15), None],
1427            &index,
1428            None,
1429            vec![Some(-15), None, None, Some(-15), Some(2)],
1430        )
1431        .unwrap();
1432
1433        // float32
1434        test_take_primitive_arrays::<Float32Type>(
1435            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1436            &index,
1437            None,
1438            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1439        )
1440        .unwrap();
1441
1442        // float64
1443        test_take_primitive_arrays::<Float64Type>(
1444            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1445            &index,
1446            None,
1447            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1448        )
1449        .unwrap();
1450    }
1451
1452    #[test]
1453    fn test_take_preserve_timezone() {
1454        let index = Int64Array::from(vec![Some(0), None]);
1455
1456        let input = TimestampNanosecondArray::from(vec![
1457            1_639_715_368_000_000_000,
1458            1_639_715_368_000_000_000,
1459        ])
1460        .with_timezone("UTC".to_string());
1461        let result = take(&input, &index, None).unwrap();
1462        match result.data_type() {
1463            DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
1464                assert_eq!(tz.clone(), Some("UTC".into()))
1465            }
1466            _ => panic!(),
1467        }
1468    }
1469
1470    #[test]
1471    fn test_take_impl_primitive_with_int64_indices() {
1472        let index = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1473
1474        // int16
1475        test_take_impl_primitive_arrays::<Int16Type, Int64Type>(
1476            vec![Some(0), None, Some(2), Some(3), None],
1477            &index,
1478            None,
1479            vec![Some(3), None, None, Some(3), Some(2)],
1480        );
1481
1482        // int64
1483        test_take_impl_primitive_arrays::<Int64Type, Int64Type>(
1484            vec![Some(0), None, Some(2), Some(-15), None],
1485            &index,
1486            None,
1487            vec![Some(-15), None, None, Some(-15), Some(2)],
1488        );
1489
1490        // uint64
1491        test_take_impl_primitive_arrays::<UInt64Type, Int64Type>(
1492            vec![Some(0), None, Some(2), Some(3), None],
1493            &index,
1494            None,
1495            vec![Some(3), None, None, Some(3), Some(2)],
1496        );
1497
1498        // duration_millisecond
1499        test_take_impl_primitive_arrays::<DurationMillisecondType, Int64Type>(
1500            vec![Some(0), None, Some(2), Some(-15), None],
1501            &index,
1502            None,
1503            vec![Some(-15), None, None, Some(-15), Some(2)],
1504        );
1505
1506        // float32
1507        test_take_impl_primitive_arrays::<Float32Type, Int64Type>(
1508            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1509            &index,
1510            None,
1511            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1512        );
1513    }
1514
1515    #[test]
1516    fn test_take_impl_primitive_with_uint8_indices() {
1517        let index = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1518
1519        // int16
1520        test_take_impl_primitive_arrays::<Int16Type, UInt8Type>(
1521            vec![Some(0), None, Some(2), Some(3), None],
1522            &index,
1523            None,
1524            vec![Some(3), None, None, Some(3), Some(2)],
1525        );
1526
1527        // duration_millisecond
1528        test_take_impl_primitive_arrays::<DurationMillisecondType, UInt8Type>(
1529            vec![Some(0), None, Some(2), Some(-15), None],
1530            &index,
1531            None,
1532            vec![Some(-15), None, None, Some(-15), Some(2)],
1533        );
1534
1535        // float32
1536        test_take_impl_primitive_arrays::<Float32Type, UInt8Type>(
1537            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1538            &index,
1539            None,
1540            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1541        );
1542    }
1543
1544    #[test]
1545    fn test_take_bool() {
1546        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1547        // boolean
1548        test_take_boolean_arrays(
1549            vec![Some(false), None, Some(true), Some(false), None],
1550            &index,
1551            None,
1552            vec![Some(false), None, None, Some(false), Some(true)],
1553        );
1554    }
1555
1556    #[test]
1557    fn test_take_bool_nullable_index() {
1558        // indices where the masked invalid elements would be out of bounds
1559        let index_data = ArrayData::try_new(
1560            DataType::UInt32,
1561            6,
1562            Some(Buffer::from_iter(vec![
1563                false, true, false, true, false, true,
1564            ])),
1565            0,
1566            vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1567            vec![],
1568        )
1569        .unwrap();
1570        let index = UInt32Array::from(index_data);
1571        test_take_boolean_arrays(
1572            vec![Some(true), None, Some(false)],
1573            &index,
1574            None,
1575            vec![None, Some(true), None, None, None, Some(false)],
1576        );
1577    }
1578
1579    #[test]
1580    fn test_take_bool_nullable_index_nonnull_values() {
1581        // indices where the masked invalid elements would be out of bounds
1582        let index_data = ArrayData::try_new(
1583            DataType::UInt32,
1584            6,
1585            Some(Buffer::from_iter(vec![
1586                false, true, false, true, false, true,
1587            ])),
1588            0,
1589            vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1590            vec![],
1591        )
1592        .unwrap();
1593        let index = UInt32Array::from(index_data);
1594        test_take_boolean_arrays(
1595            vec![Some(true), Some(true), Some(false)],
1596            &index,
1597            None,
1598            vec![None, Some(true), None, Some(true), None, Some(false)],
1599        );
1600    }
1601
1602    #[test]
1603    fn test_take_bool_with_offset() {
1604        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]);
1605        let index = index.slice(2, 4);
1606        let index = index
1607            .as_any()
1608            .downcast_ref::<PrimitiveArray<UInt32Type>>()
1609            .unwrap();
1610
1611        // boolean
1612        test_take_boolean_arrays(
1613            vec![Some(false), None, Some(true), Some(false), None],
1614            index,
1615            None,
1616            vec![None, Some(false), Some(true), None],
1617        );
1618    }
1619
1620    fn _test_take_string<'a, K>()
1621    where
1622        K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static,
1623    {
1624        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]);
1625
1626        let array = K::from(vec![
1627            Some("one"),
1628            None,
1629            Some("three"),
1630            Some("four"),
1631            Some("five"),
1632        ]);
1633        let actual = take(&array, &index, None).unwrap();
1634        assert_eq!(actual.len(), index.len());
1635
1636        let actual = actual.as_any().downcast_ref::<K>().unwrap();
1637
1638        let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]);
1639
1640        assert_eq!(actual, &expected);
1641    }
1642
1643    #[test]
1644    fn test_take_string() {
1645        _test_take_string::<StringArray>()
1646    }
1647
1648    #[test]
1649    fn test_take_large_string() {
1650        _test_take_string::<LargeStringArray>()
1651    }
1652
1653    #[test]
1654    fn test_take_slice_string() {
1655        let strings = StringArray::from(vec![Some("hello"), None, Some("world"), None, Some("hi")]);
1656        let indices = Int32Array::from(vec![Some(0), Some(1), None, Some(0), Some(2)]);
1657        let indices_slice = indices.slice(1, 4);
1658        let expected = StringArray::from(vec![None, None, Some("hello"), Some("world")]);
1659        let result = take(&strings, &indices_slice, None).unwrap();
1660        assert_eq!(result.as_ref(), &expected);
1661    }
1662
1663    fn _test_byte_view<T>()
1664    where
1665        T: ByteViewType,
1666        str: AsRef<T::Native>,
1667        T::Native: PartialEq,
1668    {
1669        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]);
1670        let array = {
1671            // ["hello", "world", null, "large payload over 12 bytes", "lulu"]
1672            let mut builder = GenericByteViewBuilder::<T>::new();
1673            builder.append_value("hello");
1674            builder.append_value("world");
1675            builder.append_null();
1676            builder.append_value("large payload over 12 bytes");
1677            builder.append_value("lulu");
1678            builder.finish()
1679        };
1680
1681        let actual = take(&array, &index, None).unwrap();
1682
1683        assert_eq!(actual.len(), index.len());
1684
1685        let expected = {
1686            // ["large payload over 12 bytes", null, "world", "large payload over 12 bytes", "lulu", null]
1687            let mut builder = GenericByteViewBuilder::<T>::new();
1688            builder.append_value("large payload over 12 bytes");
1689            builder.append_null();
1690            builder.append_value("world");
1691            builder.append_value("large payload over 12 bytes");
1692            builder.append_value("lulu");
1693            builder.append_null();
1694            builder.finish()
1695        };
1696
1697        assert_eq!(actual.as_ref(), &expected);
1698    }
1699
1700    #[test]
1701    fn test_take_string_view() {
1702        _test_byte_view::<StringViewType>()
1703    }
1704
1705    #[test]
1706    fn test_take_binary_view() {
1707        _test_byte_view::<BinaryViewType>()
1708    }
1709
    // Shared body for the `List` / `LargeList` take tests.
    //
    // `$offset_type` is the native offset type (`i32` / `i64`), `$list_data_type` the
    // `DataType` variant name and `$list_array_type` the concrete list array type. Builds a
    // four-row list (one row is an empty list), takes it with an index that contains a null,
    // and compares against a hand-built expected array whose validity is copied from the index.
    macro_rules! test_take_list {
        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
            // Construct a value array, [[0,0,0], [-1,-2,-1], [], [2,3]]
            let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
            // Construct offsets; the repeated 6 makes row 2 an empty list
            let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
            let value_offsets = Buffer::from_slice_ref(&value_offsets);
            // Construct a list array from the above two
            let list_data_type =
                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, false)));
            let list_data = ArrayData::builder(list_data_type.clone())
                .len(4)
                .add_buffer(value_offsets)
                .add_child_data(value_data)
                .build()
                .unwrap();
            let list_array = $list_array_type::from(list_data);

            // index returns: [[2,3], null, [-1,-2,-1], [], [0,0,0]]
            let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]);

            let a = take(&list_array, &index, None).unwrap();
            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();

            // construct a value array with expected results:
            // [[2,3], null, [-1,-2,-1], [], [0,0,0]]
            let expected_data = Int32Array::from(vec![
                Some(2),
                Some(3),
                Some(-1),
                Some(-2),
                Some(-1),
                Some(0),
                Some(0),
                Some(0),
            ])
            .into_data();
            // construct offsets; the null row (1) and the empty row (3) both span zero values
            let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8];
            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
            // construct list array from the two
            let expected_list_data = ArrayData::builder(list_data_type)
                .len(5)
                // null buffer remains the same as only the indices have nulls
                .nulls(index.nulls().cloned())
                .add_buffer(expected_offsets)
                .add_child_data(expected_data)
                .build()
                .unwrap();
            let expected_list_array = $list_array_type::from(expected_list_data);

            assert_eq!(a, &expected_list_array);
        }};
    }
1764
1765    macro_rules! test_take_list_with_value_nulls {
1766        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1767            // Construct a value array, [[0,null,0], [-1,-2,3], [null], [5,null]]
1768            let value_data = Int32Array::from(vec![
1769                Some(0),
1770                None,
1771                Some(0),
1772                Some(-1),
1773                Some(-2),
1774                Some(3),
1775                None,
1776                Some(5),
1777                None,
1778            ])
1779            .into_data();
1780            // Construct offsets
1781            let value_offsets: [$offset_type; 5] = [0, 3, 6, 7, 9];
1782            let value_offsets = Buffer::from_slice_ref(&value_offsets);
1783            // Construct a list array from the above two
1784            let list_data_type =
1785                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
1786            let list_data = ArrayData::builder(list_data_type.clone())
1787                .len(4)
1788                .add_buffer(value_offsets)
1789                .null_bit_buffer(Some(Buffer::from([0b11111111])))
1790                .add_child_data(value_data)
1791                .build()
1792                .unwrap();
1793            let list_array = $list_array_type::from(list_data);
1794
1795            // index returns: [[null], null, [-1,-2,3], [2,null], [0,null,0]]
1796            let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
1797
1798            let a = take(&list_array, &index, None).unwrap();
1799            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1800
1801            // construct a value array with expected results:
1802            // [[null], null, [-1,-2,3], [5,null], [0,null,0]]
1803            let expected_data = Int32Array::from(vec![
1804                None,
1805                Some(-1),
1806                Some(-2),
1807                Some(3),
1808                Some(5),
1809                None,
1810                Some(0),
1811                None,
1812                Some(0),
1813            ])
1814            .into_data();
1815            // construct offsets
1816            let expected_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9];
1817            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1818            // construct list array from the two
1819            let expected_list_data = ArrayData::builder(list_data_type)
1820                .len(5)
1821                // null buffer remains the same as only the indices have nulls
1822                .nulls(index.nulls().cloned())
1823                .add_buffer(expected_offsets)
1824                .add_child_data(expected_data)
1825                .build()
1826                .unwrap();
1827            let expected_list_array = $list_array_type::from(expected_list_data);
1828
1829            assert_eq!(a, &expected_list_array);
1830        }};
1831    }
1832
    // Shared body for the `List` / `LargeList` take tests where both the list slots and the
    // index array contain nulls.
    //
    // `$offset_type` is the native offset type (`i32` / `i64`), `$list_data_type` the
    // `DataType` variant name and `$list_array_type` the concrete list array type. The expected
    // validity is rebuilt by hand because output nulls come from two sources.
    macro_rules! test_take_list_with_nulls {
        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
            // Construct a value array, [[0,null,0], [-1,-2,3], null, [5,null]]
            let value_data = Int32Array::from(vec![
                Some(0),
                None,
                Some(0),
                Some(-1),
                Some(-2),
                Some(3),
                Some(5),
                None,
            ])
            .into_data();
            // Construct offsets
            let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
            let value_offsets = Buffer::from_slice_ref(&value_offsets);
            // Construct a list array from the above two
            let list_data_type =
                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
            let list_data = ArrayData::builder(list_data_type.clone())
                .len(4)
                .add_buffer(value_offsets)
                // bit 2 is cleared: list slot 2 is null
                .null_bit_buffer(Some(Buffer::from([0b11111011])))
                .add_child_data(value_data)
                .build()
                .unwrap();
            let list_array = $list_array_type::from(list_data);

            // index returns: [null, null, [-1,-2,3], [5,null], [0,null,0]]
            let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);

            let a = take(&list_array, &index, None).unwrap();
            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();

            // construct a value array with expected results:
            // [null, null, [-1,-2,3], [5,null], [0,null,0]]
            let expected_data = Int32Array::from(vec![
                Some(-1),
                Some(-2),
                Some(3),
                Some(5),
                None,
                Some(0),
                None,
                Some(0),
            ])
            .into_data();
            // construct offsets
            let expected_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8];
            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
            // construct list array from the two
            // slot 0 is null because it selects the null list, slot 1 because its index is null;
            // only slots 2..=4 are valid
            let mut null_bits: [u8; 1] = [0; 1];
            bit_util::set_bit(&mut null_bits, 2);
            bit_util::set_bit(&mut null_bits, 3);
            bit_util::set_bit(&mut null_bits, 4);
            let expected_list_data = ArrayData::builder(list_data_type)
                .len(5)
                // null buffer must be recalculated as both values and indices have nulls
                .null_bit_buffer(Some(Buffer::from(null_bits)))
                .add_buffer(expected_offsets)
                .add_child_data(expected_data)
                .build()
                .unwrap();
            let expected_list_array = $list_array_type::from(expected_list_data);

            assert_eq!(a, &expected_list_array);
        }};
    }
1902
1903    fn test_take_list_view_generic<OffsetType: OffsetSizeTrait, ValuesType: ArrowPrimitiveType, F>(
1904        values: Vec<Option<Vec<Option<ValuesType::Native>>>>,
1905        take_indices: Vec<Option<usize>>,
1906        expected: Vec<Option<Vec<Option<ValuesType::Native>>>>,
1907        mapper: F,
1908    ) where
1909        F: Fn(GenericListViewArray<OffsetType>) -> GenericListViewArray<OffsetType>,
1910    {
1911        let mut list_view_array =
1912            GenericListViewBuilder::<OffsetType, _>::new(PrimitiveBuilder::<ValuesType>::new());
1913
1914        for value in values {
1915            list_view_array.append_option(value);
1916        }
1917        let list_view_array = list_view_array.finish();
1918        let list_view_array = mapper(list_view_array);
1919
1920        let mut indices = UInt64Builder::new();
1921        for idx in take_indices {
1922            indices.append_option(idx.map(|i| i.to_u64().unwrap()));
1923        }
1924        let indices = indices.finish();
1925
1926        let taken = take(&list_view_array, &indices, None)
1927            .unwrap()
1928            .as_list_view()
1929            .clone();
1930
1931        let mut expected_array =
1932            GenericListViewBuilder::<OffsetType, _>::new(PrimitiveBuilder::<ValuesType>::new());
1933        for value in expected {
1934            expected_array.append_option(value);
1935        }
1936        let expected_array = expected_array.finish();
1937
1938        assert_eq!(taken, expected_array);
1939    }
1940
    // Runs `test_take_list_view_generic` once with i32 offsets and once with i64 offsets,
    // always using `Int8Type` values. The second arm also applies `transform` to the input
    // list-view before taking.
    //
    // Every metavariable is expanded twice, so each argument expression is evaluated once per
    // offset width. Pass expressions that can be evaluated twice (e.g. `vec![..]` literals or
    // `.clone()` calls), not bindings that would be moved.
    macro_rules! list_view_test_case {
        (values: $values:expr, indices: $indices:expr, expected: $expected: expr) => {{
            test_take_list_view_generic::<i32, Int8Type, _>($values, $indices, $expected, |x| x);
            test_take_list_view_generic::<i64, Int8Type, _>($values, $indices, $expected, |x| x);
        }};
        (values: $values:expr, transform: $fn:expr, indices: $indices:expr, expected: $expected: expr) => {{
            test_take_list_view_generic::<i32, Int8Type, _>($values, $indices, $expected, $fn);
            test_take_list_view_generic::<i64, Int8Type, _>($values, $indices, $expected, $fn);
        }};
    }
1951
1952    fn do_take_fixed_size_list_test<T>(
1953        length: <Int32Type as ArrowPrimitiveType>::Native,
1954        input_data: Vec<Option<Vec<Option<T::Native>>>>,
1955        indices: Vec<<UInt32Type as ArrowPrimitiveType>::Native>,
1956        expected_data: Vec<Option<Vec<Option<T::Native>>>>,
1957    ) where
1958        T: ArrowPrimitiveType,
1959        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1960    {
1961        let indices = UInt32Array::from(indices);
1962
1963        let input_array = FixedSizeListArray::from_iter_primitive::<T, _, _>(input_data, length);
1964
1965        let output = take_fixed_size_list(&input_array, &indices, length as u32).unwrap();
1966
1967        let expected = FixedSizeListArray::from_iter_primitive::<T, _, _>(expected_data, length);
1968
1969        assert_eq!(&output, &expected)
1970    }
1971
1972    #[test]
1973    fn test_take_list() {
1974        test_take_list!(i32, List, ListArray);
1975    }
1976
1977    #[test]
1978    fn test_take_large_list() {
1979        test_take_list!(i64, LargeList, LargeListArray);
1980    }
1981
1982    #[test]
1983    fn test_take_list_with_value_nulls() {
1984        test_take_list_with_value_nulls!(i32, List, ListArray);
1985    }
1986
1987    #[test]
1988    fn test_take_large_list_with_value_nulls() {
1989        test_take_list_with_value_nulls!(i64, LargeList, LargeListArray);
1990    }
1991
1992    #[test]
1993    fn test_test_take_list_with_nulls() {
1994        test_take_list_with_nulls!(i32, List, ListArray);
1995    }
1996
1997    #[test]
1998    fn test_test_take_large_list_with_nulls() {
1999        test_take_list_with_nulls!(i64, LargeList, LargeListArray);
2000    }
2001
2002    #[test]
2003    fn test_test_take_list_view_reversed() {
2004        // Take reversed indices
2005        list_view_test_case! {
2006            values: vec![
2007                Some(vec![Some(1), None, Some(3)]),
2008                None,
2009                Some(vec![Some(7), Some(8), None]),
2010            ],
2011            indices: vec![Some(2), Some(1), Some(0)],
2012            expected: vec![
2013                Some(vec![Some(7), Some(8), None]),
2014                None,
2015                Some(vec![Some(1), None, Some(3)]),
2016            ]
2017        }
2018    }
2019
2020    #[test]
2021    fn test_take_list_view_null_indices() {
2022        // Take with null indices
2023        list_view_test_case! {
2024            values: vec![
2025                Some(vec![Some(1), None, Some(3)]),
2026                None,
2027                Some(vec![Some(7), Some(8), None]),
2028            ],
2029            indices: vec![None, Some(0), None],
2030            expected: vec![None, Some(vec![Some(1), None, Some(3)]), None]
2031        }
2032    }
2033
2034    #[test]
2035    fn test_take_list_view_null_values() {
2036        // Take at null values
2037        list_view_test_case! {
2038            values: vec![
2039                Some(vec![Some(1), None, Some(3)]),
2040                None,
2041                Some(vec![Some(7), Some(8), None]),
2042            ],
2043            indices: vec![Some(1), Some(1), Some(1), None, None],
2044            expected: vec![None; 5]
2045        }
2046    }
2047
2048    #[test]
2049    fn test_take_list_view_sliced() {
2050        // Take null indices/values, with slicing.
2051        list_view_test_case! {
2052            values: vec![
2053                Some(vec![Some(1)]),
2054                None,
2055                None,
2056                Some(vec![Some(2), Some(3)]),
2057                Some(vec![Some(4), Some(5)]),
2058                None,
2059            ],
2060            transform: |l| l.slice(2, 4),
2061            indices: vec![Some(0), Some(3), None, Some(1), Some(2)],
2062            expected: vec![
2063                None, None, None, Some(vec![Some(2), Some(3)]), Some(vec![Some(4), Some(5)])
2064            ]
2065        }
2066    }
2067
2068    #[test]
2069    fn test_take_fixed_size_list() {
2070        do_take_fixed_size_list_test::<Int32Type>(
2071            3,
2072            vec![
2073                Some(vec![None, Some(1), Some(2)]),
2074                Some(vec![Some(3), Some(4), None]),
2075                Some(vec![Some(6), Some(7), Some(8)]),
2076            ],
2077            vec![2, 1, 0],
2078            vec![
2079                Some(vec![Some(6), Some(7), Some(8)]),
2080                Some(vec![Some(3), Some(4), None]),
2081                Some(vec![None, Some(1), Some(2)]),
2082            ],
2083        );
2084
2085        do_take_fixed_size_list_test::<UInt8Type>(
2086            1,
2087            vec![
2088                Some(vec![Some(1)]),
2089                Some(vec![Some(2)]),
2090                Some(vec![Some(3)]),
2091                Some(vec![Some(4)]),
2092                Some(vec![Some(5)]),
2093                Some(vec![Some(6)]),
2094                Some(vec![Some(7)]),
2095                Some(vec![Some(8)]),
2096            ],
2097            vec![2, 7, 0],
2098            vec![
2099                Some(vec![Some(3)]),
2100                Some(vec![Some(8)]),
2101                Some(vec![Some(1)]),
2102            ],
2103        );
2104
2105        do_take_fixed_size_list_test::<UInt64Type>(
2106            3,
2107            vec![
2108                Some(vec![Some(10), Some(11), Some(12)]),
2109                Some(vec![Some(13), Some(14), Some(15)]),
2110                None,
2111                Some(vec![Some(16), Some(17), Some(18)]),
2112            ],
2113            vec![3, 2, 1, 2, 0],
2114            vec![
2115                Some(vec![Some(16), Some(17), Some(18)]),
2116                None,
2117                Some(vec![Some(13), Some(14), Some(15)]),
2118                None,
2119                Some(vec![Some(10), Some(11), Some(12)]),
2120            ],
2121        );
2122    }
2123
2124    #[test]
2125    fn test_take_fixed_size_binary_with_nulls_indices() {
2126        let fsb = FixedSizeBinaryArray::try_from_sparse_iter_with_size(
2127            [
2128                Some(vec![0x01, 0x01, 0x01, 0x01]),
2129                Some(vec![0x02, 0x02, 0x02, 0x02]),
2130                Some(vec![0x03, 0x03, 0x03, 0x03]),
2131                Some(vec![0x04, 0x04, 0x04, 0x04]),
2132            ]
2133            .into_iter(),
2134            4,
2135        )
2136        .unwrap();
2137
2138        // The two middle indices are null -> Should be null in the output.
2139        let indices = UInt32Array::from(vec![Some(0), None, None, Some(3)]);
2140
2141        let result = take_fixed_size_binary(&fsb, &indices, 4).unwrap();
2142        assert_eq!(result.len(), 4);
2143        assert_eq!(result.null_count(), 2);
2144        assert_eq!(
2145            result.nulls().unwrap().iter().collect::<Vec<_>>(),
2146            vec![true, false, false, true]
2147        );
2148    }
2149
2150    #[test]
2151    #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
2152    fn test_take_list_out_of_bounds() {
2153        // Construct a value array, [[0,0,0], [-1,-2,-1], [2,3]]
2154        let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
2155        // Construct offsets
2156        let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]);
2157        // Construct a list array from the above two
2158        let list_data_type =
2159            DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false)));
2160        let list_data = ArrayData::builder(list_data_type)
2161            .len(3)
2162            .add_buffer(value_offsets)
2163            .add_child_data(value_data)
2164            .build()
2165            .unwrap();
2166        let list_array = ListArray::from(list_data);
2167
2168        let index = UInt32Array::from(vec![1000]);
2169
2170        // A panic is expected here since we have not supplied the check_bounds
2171        // option.
2172        take(&list_array, &index, None).unwrap();
2173    }
2174
2175    #[test]
2176    fn test_take_map() {
2177        let values = Int32Array::from(vec![1, 2, 3, 4]);
2178        let array =
2179            MapArray::new_from_strings(vec!["a", "b", "c", "a"].into_iter(), &values, &[0, 3, 4])
2180                .unwrap();
2181
2182        let index = UInt32Array::from(vec![0]);
2183
2184        let result = take(&array, &index, None).unwrap();
2185        let expected: ArrayRef = Arc::new(
2186            MapArray::new_from_strings(
2187                vec!["a", "b", "c"].into_iter(),
2188                &values.slice(0, 3),
2189                &[0, 3],
2190            )
2191            .unwrap(),
2192        );
2193        assert_eq!(&expected, &result);
2194    }
2195
2196    #[test]
2197    fn test_take_struct() {
2198        let array = create_test_struct(vec![
2199            Some((Some(true), Some(42))),
2200            Some((Some(false), Some(28))),
2201            Some((Some(false), Some(19))),
2202            Some((Some(true), Some(31))),
2203            None,
2204        ]);
2205
2206        let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
2207        let actual = take(&array, &index, None).unwrap();
2208        let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
2209        assert_eq!(index.len(), actual.len());
2210        assert_eq!(1, actual.null_count());
2211
2212        let expected = create_test_struct(vec![
2213            Some((Some(true), Some(42))),
2214            Some((Some(true), Some(31))),
2215            Some((Some(false), Some(28))),
2216            Some((Some(true), Some(42))),
2217            Some((Some(false), Some(19))),
2218            None,
2219        ]);
2220
2221        assert_eq!(&expected, actual);
2222
2223        let nulls = NullBuffer::from(&[false, true, false, true, false, true]);
2224        let empty_struct_arr = StructArray::new_empty_fields(6, Some(nulls));
2225        let index = UInt32Array::from(vec![0, 2, 1, 4]);
2226        let actual = take(&empty_struct_arr, &index, None).unwrap();
2227
2228        let expected_nulls = NullBuffer::from(&[false, false, true, false]);
2229        let expected_struct_arr = StructArray::new_empty_fields(4, Some(expected_nulls));
2230        assert_eq!(&expected_struct_arr, actual.as_struct());
2231    }
2232
2233    #[test]
2234    fn test_take_struct_with_null_indices() {
2235        let array = create_test_struct(vec![
2236            Some((Some(true), Some(42))),
2237            Some((Some(false), Some(28))),
2238            Some((Some(false), Some(19))),
2239            Some((Some(true), Some(31))),
2240            None,
2241        ]);
2242
2243        let index = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);
2244        let actual = take(&array, &index, None).unwrap();
2245        let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
2246        assert_eq!(index.len(), actual.len());
2247        assert_eq!(3, actual.null_count()); // 2 because of indices, 1 because of struct array
2248
2249        let expected = create_test_struct(vec![
2250            None,
2251            Some((Some(true), Some(31))),
2252            Some((Some(false), Some(28))),
2253            None,
2254            Some((Some(true), Some(42))),
2255            None,
2256        ]);
2257
2258        assert_eq!(&expected, actual);
2259    }
2260
2261    #[test]
2262    fn test_take_out_of_bounds() {
2263        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]);
2264        let take_opt = TakeOptions { check_bounds: true };
2265
2266        // int64
2267        let result = test_take_primitive_arrays::<Int64Type>(
2268            vec![Some(0), None, Some(2), Some(3), None],
2269            &index,
2270            Some(take_opt),
2271            vec![None],
2272        );
2273        assert!(result.is_err());
2274    }
2275
2276    #[test]
2277    #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
2278    fn test_take_out_of_bounds_panic() {
2279        let index = UInt32Array::from(vec![Some(1000)]);
2280
2281        test_take_primitive_arrays::<Int64Type>(
2282            vec![Some(0), Some(1), Some(2), Some(3)],
2283            &index,
2284            None,
2285            vec![None],
2286        )
2287        .unwrap();
2288    }
2289
2290    #[test]
2291    fn test_null_array_smaller_than_indices() {
2292        let values = NullArray::new(2);
2293        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2294
2295        let result = take(&values, &indices, None).unwrap();
2296        let expected: ArrayRef = Arc::new(NullArray::new(3));
2297        assert_eq!(&result, &expected);
2298    }
2299
2300    #[test]
2301    fn test_null_array_larger_than_indices() {
2302        let values = NullArray::new(5);
2303        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2304
2305        let result = take(&values, &indices, None).unwrap();
2306        let expected: ArrayRef = Arc::new(NullArray::new(3));
2307        assert_eq!(&result, &expected);
2308    }
2309
2310    #[test]
2311    fn test_null_array_indices_out_of_bounds() {
2312        let values = NullArray::new(5);
2313        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2314
2315        let result = take(&values, &indices, Some(TakeOptions { check_bounds: true }));
2316        assert_eq!(
2317            result.unwrap_err().to_string(),
2318            "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries"
2319        );
2320    }
2321
2322    #[test]
2323    fn test_take_dict() {
2324        let mut dict_builder = StringDictionaryBuilder::<Int16Type>::new();
2325
2326        dict_builder.append("foo").unwrap();
2327        dict_builder.append("bar").unwrap();
2328        dict_builder.append("").unwrap();
2329        dict_builder.append_null();
2330        dict_builder.append("foo").unwrap();
2331        dict_builder.append("bar").unwrap();
2332        dict_builder.append("bar").unwrap();
2333        dict_builder.append("foo").unwrap();
2334
2335        let array = dict_builder.finish();
2336        let dict_values = array.values().clone();
2337        let dict_values = dict_values.as_any().downcast_ref::<StringArray>().unwrap();
2338
2339        let indices = UInt32Array::from(vec![
2340            Some(0), // first "foo"
2341            Some(7), // last "foo"
2342            None,    // null index should return null
2343            Some(5), // second "bar"
2344            Some(6), // another "bar"
2345            Some(2), // empty string
2346            Some(3), // input is null at this index
2347        ]);
2348
2349        let result = take(&array, &indices, None).unwrap();
2350        let result = result
2351            .as_any()
2352            .downcast_ref::<DictionaryArray<Int16Type>>()
2353            .unwrap();
2354
2355        let result_values: StringArray = result.values().to_data().into();
2356
2357        // dictionary values should stay the same
2358        let expected_values = StringArray::from(vec!["foo", "bar", ""]);
2359        assert_eq!(&expected_values, dict_values);
2360        assert_eq!(&expected_values, &result_values);
2361
2362        let expected_keys = Int16Array::from(vec![
2363            Some(0),
2364            Some(0),
2365            None,
2366            Some(1),
2367            Some(1),
2368            Some(2),
2369            None,
2370        ]);
2371        assert_eq!(result.keys(), &expected_keys);
2372    }
2373
2374    fn build_generic_list<S, T>(data: Vec<Option<Vec<T::Native>>>) -> GenericListArray<S>
2375    where
2376        S: OffsetSizeTrait + 'static,
2377        T: ArrowPrimitiveType,
2378        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2379    {
2380        GenericListArray::from_iter_primitive::<T, _, _>(
2381            data.iter()
2382                .map(|x| x.as_ref().map(|x| x.iter().map(|x| Some(*x)))),
2383        )
2384    }
2385
2386    #[test]
2387    fn test_take_value_index_from_list() {
2388        let list = build_generic_list::<i32, Int32Type>(vec![
2389            Some(vec![0, 1]),
2390            Some(vec![2, 3, 4]),
2391            Some(vec![5, 6, 7, 8, 9]),
2392        ]);
2393        let indices = UInt32Array::from(vec![2, 0]);
2394
2395        let (indexed, offsets, null_buf) = take_value_indices_from_list(&list, &indices).unwrap();
2396
2397        assert_eq!(indexed, Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
2398        assert_eq!(offsets, vec![0, 5, 7]);
2399        assert_eq!(null_buf.as_slice(), &[0b11111111]);
2400    }
2401
2402    #[test]
2403    fn test_take_value_index_from_large_list() {
2404        let list = build_generic_list::<i64, Int32Type>(vec![
2405            Some(vec![0, 1]),
2406            Some(vec![2, 3, 4]),
2407            Some(vec![5, 6, 7, 8, 9]),
2408        ]);
2409        let indices = UInt32Array::from(vec![2, 0]);
2410
2411        let (indexed, offsets, null_buf) =
2412            take_value_indices_from_list::<_, Int64Type>(&list, &indices).unwrap();
2413
2414        assert_eq!(indexed, Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
2415        assert_eq!(offsets, vec![0, 5, 7]);
2416        assert_eq!(null_buf.as_slice(), &[0b11111111]);
2417    }
2418
2419    #[test]
2420    fn test_take_runs() {
2421        let logical_array: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];
2422
2423        let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
2424        builder.extend(logical_array.into_iter().map(Some));
2425        let run_array = builder.finish();
2426
2427        let take_indices: PrimitiveArray<Int32Type> =
2428            vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect();
2429
2430        let take_out = take_run(&run_array, &take_indices).unwrap();
2431
2432        assert_eq!(take_out.len(), 7);
2433        assert_eq!(take_out.run_ends().len(), 7);
2434        assert_eq!(take_out.run_ends().values(), &[1_i32, 3, 4, 5, 7]);
2435
2436        let take_out_values = take_out.values().as_primitive::<Int32Type>();
2437        assert_eq!(take_out_values.values(), &[2, 2, 2, 2, 1]);
2438    }
2439
2440    #[test]
2441    fn test_take_runs_sliced() {
2442        let logical_array: Vec<i32> = vec![1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6];
2443
2444        let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
2445        builder.extend(logical_array.into_iter().map(Some));
2446        let run_array = builder.finish();
2447
2448        let run_array = run_array.slice(4, 6); // [3, 3, 3, 4, 4, 5]
2449
2450        let take_indices: PrimitiveArray<Int32Type> = vec![0, 5, 5, 1, 4].into_iter().collect();
2451
2452        let result = take_run(&run_array, &take_indices).unwrap();
2453        let result = result.downcast::<Int32Array>().unwrap();
2454
2455        let expected = vec![3, 5, 5, 3, 4];
2456        let actual = result.into_iter().flatten().collect::<Vec<_>>();
2457
2458        assert_eq!(expected, actual);
2459    }
2460
2461    #[test]
2462    fn test_take_value_index_from_fixed_list() {
2463        let list = FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
2464            vec![
2465                Some(vec![Some(1), Some(2), None]),
2466                Some(vec![Some(4), None, Some(6)]),
2467                None,
2468                Some(vec![None, Some(8), Some(9)]),
2469            ],
2470            3,
2471        );
2472
2473        let indices = UInt32Array::from(vec![2, 1, 0]);
2474        let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2475
2476        assert_eq!(indexed, UInt32Array::from(vec![6, 7, 8, 3, 4, 5, 0, 1, 2]));
2477
2478        let indices = UInt32Array::from(vec![3, 2, 1, 2, 0]);
2479        let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2480
2481        assert_eq!(
2482            indexed,
2483            UInt32Array::from(vec![9, 10, 11, 6, 7, 8, 3, 4, 5, 6, 7, 8, 0, 1, 2])
2484        );
2485    }
2486
2487    #[test]
2488    fn test_take_null_indices() {
2489        // Build indices with values that are out of bounds, but masked by null mask
2490        let indices = Int32Array::new(
2491            vec![1, 2, 400, 400].into(),
2492            Some(NullBuffer::from(vec![true, true, false, false])),
2493        );
2494        let values = Int32Array::from(vec![1, 23, 4, 5]);
2495        let r = take(&values, &indices, None).unwrap();
2496        let values = r
2497            .as_primitive::<Int32Type>()
2498            .into_iter()
2499            .collect::<Vec<_>>();
2500        assert_eq!(&values, &[Some(23), Some(4), None, None])
2501    }
2502
2503    #[test]
2504    fn test_take_fixed_size_list_null_indices() {
2505        let indices = Int32Array::from_iter([Some(0), None]);
2506        let values = Arc::new(Int32Array::from(vec![0, 1, 2, 3]));
2507        let arr_field = Arc::new(Field::new_list_field(values.data_type().clone(), true));
2508        let values = FixedSizeListArray::try_new(arr_field, 2, values, None).unwrap();
2509
2510        let r = take(&values, &indices, None).unwrap();
2511        let values = r
2512            .as_fixed_size_list()
2513            .values()
2514            .as_primitive::<Int32Type>()
2515            .into_iter()
2516            .collect::<Vec<_>>();
2517        assert_eq!(values, &[Some(0), Some(1), None, None])
2518    }
2519
2520    #[test]
2521    fn test_take_bytes_null_indices() {
2522        let indices = Int32Array::new(
2523            vec![0, 1, 400, 400].into(),
2524            Some(NullBuffer::from_iter(vec![true, true, false, false])),
2525        );
2526        let values = StringArray::from(vec![Some("foo"), None]);
2527        let r = take(&values, &indices, None).unwrap();
2528        let values = r.as_string::<i32>().iter().collect::<Vec<_>>();
2529        assert_eq!(&values, &[Some("foo"), None, None, None])
2530    }
2531
2532    #[test]
2533    fn test_take_union_sparse() {
2534        let structs = create_test_struct(vec![
2535            Some((Some(true), Some(42))),
2536            Some((Some(false), Some(28))),
2537            Some((Some(false), Some(19))),
2538            Some((Some(true), Some(31))),
2539            None,
2540        ]);
2541        let strings = StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);
2542        let type_ids = [1; 5].into_iter().collect::<ScalarBuffer<i8>>();
2543
2544        let union_fields = [
2545            (
2546                0,
2547                Arc::new(Field::new("f1", structs.data_type().clone(), true)),
2548            ),
2549            (
2550                1,
2551                Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2552            ),
2553        ]
2554        .into_iter()
2555        .collect();
2556        let children = vec![Arc::new(structs) as Arc<dyn Array>, Arc::new(strings)];
2557        let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();
2558
2559        let indices = vec![0, 3, 1, 0, 2, 4];
2560        let index = UInt32Array::from(indices.clone());
2561        let actual = take(&array, &index, None).unwrap();
2562        let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2563        let strings = actual.child(1);
2564        let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();
2565
2566        let actual = strings.iter().collect::<Vec<_>>();
2567        let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
2568        assert_eq!(expected, actual);
2569    }
2570
2571    #[test]
2572    fn test_take_union_dense() {
2573        let type_ids = vec![0, 1, 1, 0, 0, 1, 0];
2574        let offsets = vec![0, 0, 1, 1, 2, 2, 3];
2575        let ints = vec![10, 20, 30, 40];
2576        let strings = vec![Some("a"), None, Some("c"), Some("d")];
2577
2578        let indices = vec![0, 3, 1, 0, 2, 4];
2579
2580        let taken_type_ids = vec![0, 0, 1, 0, 1, 0];
2581        let taken_offsets = vec![0, 1, 0, 2, 1, 3];
2582        let taken_ints = vec![10, 20, 10, 30];
2583        let taken_strings = vec![Some("a"), None];
2584
2585        let type_ids = <ScalarBuffer<i8>>::from(type_ids);
2586        let offsets = <ScalarBuffer<i32>>::from(offsets);
2587        let ints = UInt32Array::from(ints);
2588        let strings = StringArray::from(strings);
2589
2590        let union_fields = [
2591            (
2592                0,
2593                Arc::new(Field::new("f1", ints.data_type().clone(), true)),
2594            ),
2595            (
2596                1,
2597                Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2598            ),
2599        ]
2600        .into_iter()
2601        .collect();
2602
2603        let array = UnionArray::try_new(
2604            union_fields,
2605            type_ids,
2606            Some(offsets),
2607            vec![Arc::new(ints), Arc::new(strings)],
2608        )
2609        .unwrap();
2610
2611        let index = UInt32Array::from(indices);
2612
2613        let actual = take(&array, &index, None).unwrap();
2614        let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2615
2616        assert_eq!(actual.offsets(), Some(&ScalarBuffer::from(taken_offsets)));
2617        assert_eq!(actual.type_ids(), &ScalarBuffer::from(taken_type_ids));
2618        assert_eq!(
2619            UInt32Array::from(actual.child(0).to_data()),
2620            UInt32Array::from(taken_ints)
2621        );
2622        assert_eq!(
2623            StringArray::from(actual.child(1).to_data()),
2624            StringArray::from(taken_strings)
2625        );
2626    }
2627
2628    #[test]
2629    fn test_take_union_dense_using_builder() {
2630        let mut builder = UnionBuilder::new_dense();
2631
2632        builder.append::<Int32Type>("a", 1).unwrap();
2633        builder.append::<Float64Type>("b", 3.0).unwrap();
2634        builder.append::<Int32Type>("a", 4).unwrap();
2635        builder.append::<Int32Type>("a", 5).unwrap();
2636        builder.append::<Float64Type>("b", 2.0).unwrap();
2637
2638        let union = builder.build().unwrap();
2639
2640        let indices = UInt32Array::from(vec![2, 0, 1, 2]);
2641
2642        let mut builder = UnionBuilder::new_dense();
2643
2644        builder.append::<Int32Type>("a", 4).unwrap();
2645        builder.append::<Int32Type>("a", 1).unwrap();
2646        builder.append::<Float64Type>("b", 3.0).unwrap();
2647        builder.append::<Int32Type>("a", 4).unwrap();
2648
2649        let taken = builder.build().unwrap();
2650
2651        assert_eq!(
2652            taken.to_data(),
2653            take(&union, &indices, None).unwrap().to_data()
2654        );
2655    }
2656
2657    #[test]
2658    fn test_take_union_dense_all_match_issue_6206() {
2659        let fields = UnionFields::from_fields(vec![Field::new("a", DataType::Int64, false)]);
2660        let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
2661
2662        let array = UnionArray::try_new(
2663            fields,
2664            ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]),
2665            Some(ScalarBuffer::from_iter(0_i32..5)),
2666            vec![ints],
2667        )
2668        .unwrap();
2669
2670        let indicies = Int64Array::from(vec![0, 2, 4]);
2671        let array = take(&array, &indicies, None).unwrap();
2672        assert_eq!(array.len(), 3);
2673    }
2674
2675    #[test]
2676    fn test_take_bytes_offset_overflow() {
2677        let indices = Int32Array::from(vec![0; (i32::MAX >> 4) as usize]);
2678        let text = ('a'..='z').collect::<String>();
2679        let values = StringArray::from(vec![Some(text.clone())]);
2680        assert!(matches!(
2681            take(&values, &indices, None),
2682            Err(ArrowError::OffsetOverflowError(_))
2683        ));
2684    }
2685}