1use std::fmt::Display;
21use std::mem::ManuallyDrop;
22use std::sync::Arc;
23
24use arrow_array::builder::{BufferBuilder, UInt32Builder};
25use arrow_array::cast::AsArray;
26use arrow_array::types::*;
27use arrow_array::*;
28use arrow_buffer::{
29 ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, RunEndBuffer,
30 ScalarBuffer, bit_util,
31};
32use arrow_data::transform::MutableArrayData;
33use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
34
35use num_traits::Zero;
36
37pub fn take(
89 values: &dyn Array,
90 indices: &dyn Array,
91 options: Option<TakeOptions>,
92) -> Result<ArrayRef, ArrowError> {
93 let options = options.unwrap_or_default();
94 downcast_integer_array!(
95 indices => {
96 if options.check_bounds {
97 check_bounds(values.len(), indices)?;
98 }
99 let indices = indices.to_indices();
100 take_impl(values, &indices)
101 },
102 d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
103 )
104}
105
106pub fn take_arrays(
155 arrays: &[ArrayRef],
156 indices: &dyn Array,
157 options: Option<TakeOptions>,
158) -> Result<Vec<ArrayRef>, ArrowError> {
159 arrays
160 .iter()
161 .map(|array| take(array.as_ref(), indices, options.clone()))
162 .collect()
163}
164
165fn check_bounds<T: ArrowPrimitiveType>(
167 len: usize,
168 indices: &PrimitiveArray<T>,
169) -> Result<(), ArrowError>
170where
171 T::Native: Display,
172{
173 let len = match T::Native::from_usize(len) {
174 Some(len) => len,
175 None => {
176 if T::DATA_TYPE.is_integer() {
177 return Ok(());
179 } else {
180 return Err(ArrowError::ComputeError("Cast to usize failed".to_string()));
181 }
182 }
183 };
184
185 if indices.null_count() > 0 {
186 indices.iter().flatten().try_for_each(|index| {
187 if index >= len {
188 return Err(ArrowError::ComputeError(format!(
189 "Array index out of bounds, cannot get item at index {index} from {len} entries"
190 )));
191 }
192 Ok(())
193 })
194 } else {
195 let in_bounds = indices.values().iter().fold(true, |in_bounds, &i| {
196 in_bounds & (i >= T::Native::ZERO) & (i < len)
197 });
198
199 if !in_bounds {
200 for &index in indices.values() {
201 if index < T::Native::ZERO || index >= len {
202 return Err(ArrowError::ComputeError(format!(
203 "Array index out of bounds, cannot get item at index {index} from {len} entries"
204 )));
205 }
206 }
207 }
208
209 Ok(())
210 }
211}
212
213#[inline(never)]
214fn take_impl<IndexType: ArrowPrimitiveType>(
215 values: &dyn Array,
216 indices: &PrimitiveArray<IndexType>,
217) -> Result<ArrayRef, ArrowError> {
218 if indices.is_empty() {
219 return Ok(new_empty_array(values.data_type()));
220 }
221 downcast_primitive_array! {
222 values => Ok(Arc::new(take_primitive(values, indices)?)),
223 DataType::Boolean => {
224 let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
225 Ok(Arc::new(take_boolean(values, indices)))
226 }
227 DataType::Utf8 => {
228 Ok(Arc::new(take_bytes(values.as_string::<i32>(), indices)?))
229 }
230 DataType::LargeUtf8 => {
231 Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)?))
232 }
233 DataType::Utf8View => {
234 Ok(Arc::new(take_byte_view(values.as_string_view(), indices)?))
235 }
236 DataType::List(_) => {
237 Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?))
238 }
239 DataType::LargeList(_) => {
240 Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?))
241 }
242 DataType::ListView(_) => {
243 Ok(Arc::new(take_list_view::<_, Int32Type>(values.as_list_view(), indices)?))
244 }
245 DataType::LargeListView(_) => {
246 Ok(Arc::new(take_list_view::<_, Int64Type>(values.as_list_view(), indices)?))
247 }
248 DataType::FixedSizeList(_, length) => {
249 let values = values
250 .as_any()
251 .downcast_ref::<FixedSizeListArray>()
252 .unwrap();
253 Ok(Arc::new(take_fixed_size_list(
254 values,
255 indices,
256 *length as u32,
257 )?))
258 }
259 DataType::Map(field, ordered) => {
260 let list_arr = ListArray::from(values.as_map().clone());
261 let list_data = take_list::<_, Int32Type>(&list_arr, indices)?;
262 let (_, offsets, entries, nulls) = list_data.into_parts();
263 let entries = entries.as_struct().clone();
264 Ok(Arc::new(MapArray::try_new(
265 field.clone(),
266 offsets,
267 entries,
268 nulls,
269 *ordered,
270 )?))
271 }
272 DataType::Struct(fields) => {
273 let array: &StructArray = values.as_struct();
274 let arrays = array
275 .columns()
276 .iter()
277 .map(|a| take_impl(a.as_ref(), indices))
278 .collect::<Result<Vec<ArrayRef>, _>>()?;
279 let fields: Vec<(FieldRef, ArrayRef)> =
280 fields.iter().cloned().zip(arrays).collect();
281
282 let is_valid: Buffer = indices
284 .iter()
285 .map(|index| {
286 if let Some(index) = index {
287 array.is_valid(index.to_usize().unwrap())
288 } else {
289 false
290 }
291 })
292 .collect();
293
294 if fields.is_empty() {
295 let nulls = NullBuffer::new(BooleanBuffer::new(is_valid, 0, indices.len()));
296 Ok(Arc::new(StructArray::new_empty_fields(indices.len(), Some(nulls))))
297 } else {
298 Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
299 }
300 }
301 DataType::Dictionary(_, _) => downcast_dictionary_array! {
302 values => Ok(Arc::new(take_dict(values, indices)?)),
303 t => unimplemented!("Take not supported for dictionary type {:?}", t)
304 }
305 DataType::RunEndEncoded(_, _) => downcast_run_array! {
306 values => Ok(Arc::new(take_run(values, indices)?)),
307 t => unimplemented!("Take not supported for run type {:?}", t)
308 }
309 DataType::Binary => {
310 Ok(Arc::new(take_bytes(values.as_binary::<i32>(), indices)?))
311 }
312 DataType::LargeBinary => {
313 Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?))
314 }
315 DataType::BinaryView => {
316 Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)?))
317 }
318 DataType::FixedSizeBinary(size) => {
319 let values = values
320 .as_any()
321 .downcast_ref::<FixedSizeBinaryArray>()
322 .unwrap();
323 Ok(Arc::new(take_fixed_size_binary(values, indices, *size)?))
324 }
325 DataType::Null => {
326 if values.len() >= indices.len() {
328 Ok(values.slice(0, indices.len()))
331 } else {
332 Ok(new_null_array(&DataType::Null, indices.len()))
334 }
335 }
336 DataType::Union(fields, UnionMode::Sparse) => {
337 let mut children = Vec::with_capacity(fields.len());
338 let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
339 let type_ids = take_native(values.type_ids(), indices);
340 for (type_id, _field) in fields.iter() {
341 let values = values.child(type_id);
342 let values = take_impl(values, indices)?;
343 children.push(values);
344 }
345 let array = UnionArray::try_new(fields.clone(), type_ids, None, children)?;
346 Ok(Arc::new(array))
347 }
348 DataType::Union(fields, UnionMode::Dense) => {
349 let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
350
351 let type_ids = <PrimitiveArray<Int8Type>>::try_new(take_native(values.type_ids(), indices), None)?;
352 let offsets = <PrimitiveArray<Int32Type>>::try_new(take_native(values.offsets().unwrap(), indices), None)?;
353
354 let children = fields.iter()
355 .map(|(field_type_id, _)| {
356 let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);
357
358 let indices = crate::filter::filter(&offsets, &mask)?;
359
360 let values = values.child(field_type_id);
361
362 take_impl(values, indices.as_primitive::<Int32Type>())
363 })
364 .collect::<Result<_, _>>()?;
365
366 let mut child_offsets = [0; 128];
367
368 let offsets = type_ids.values()
369 .iter()
370 .map(|&i| {
371 let offset = child_offsets[i as usize];
372
373 child_offsets[i as usize] += 1;
374
375 offset
376 })
377 .collect();
378
379 let (_, type_ids, _) = type_ids.into_parts();
380
381 let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)?;
382
383 Ok(Arc::new(array))
384 }
385 t => unimplemented!("Take not supported for data type {:?}", t)
386 }
387}
388
/// Configuration for the [`take`] kernel.
///
/// The derived default leaves bounds checking disabled.
#[derive(Clone, Debug, Default)]
pub struct TakeOptions {
    /// When `true`, [`take`] first validates every non-null index against the
    /// length of the values array and returns an `ArrowError` if any is out of
    /// range. When `false`, out-of-range indices may cause a panic instead.
    pub check_bounds: bool,
}
397
398fn take_primitive<T, I>(
408 values: &PrimitiveArray<T>,
409 indices: &PrimitiveArray<I>,
410) -> Result<PrimitiveArray<T>, ArrowError>
411where
412 T: ArrowPrimitiveType,
413 I: ArrowPrimitiveType,
414{
415 let values_buf = take_native(values.values(), indices);
416 let nulls = take_nulls(values.nulls(), indices);
417 Ok(PrimitiveArray::try_new(values_buf, nulls)?.with_data_type(values.data_type().clone()))
418}
419
420#[inline(never)]
421fn take_nulls<I: ArrowPrimitiveType>(
422 values: Option<&NullBuffer>,
423 indices: &PrimitiveArray<I>,
424) -> Option<NullBuffer> {
425 match values.filter(|n| n.null_count() > 0) {
426 Some(n) => NullBuffer::from_unsliced_buffer(
427 take_bits(n.inner(), indices).into_inner(),
428 indices.len(),
429 ),
430 None => indices.nulls().cloned(),
431 }
432}
433
434#[inline(never)]
435fn take_native<T: ArrowNativeType, I: ArrowPrimitiveType>(
436 values: &[T],
437 indices: &PrimitiveArray<I>,
438) -> ScalarBuffer<T> {
439 match indices.nulls().filter(|n| n.null_count() > 0) {
440 Some(n) => indices
441 .values()
442 .iter()
443 .enumerate()
444 .map(|(idx, index)| match values.get(index.as_usize()) {
445 Some(v) => *v,
446 None => match unsafe { n.inner().value_unchecked(idx) } {
448 false => T::default(),
449 true => panic!("Out-of-bounds index {index:?}"),
450 },
451 })
452 .collect(),
453 None => indices
454 .values()
455 .iter()
456 .map(|index| values[index.as_usize()])
457 .collect(),
458 }
459}
460
461#[inline(never)]
462fn take_bits<I: ArrowPrimitiveType>(
463 values: &BooleanBuffer,
464 indices: &PrimitiveArray<I>,
465) -> BooleanBuffer {
466 let len = indices.len();
467
468 match indices.nulls().filter(|n| n.null_count() > 0) {
469 Some(nulls) => {
470 let mut output_buffer = MutableBuffer::new_null(len);
471 let output_slice = output_buffer.as_slice_mut();
472 nulls.valid_indices().for_each(|idx| {
473 if values.value(unsafe { indices.value_unchecked(idx).as_usize() }) {
475 unsafe { bit_util::set_bit_raw(output_slice.as_mut_ptr(), idx) };
477 }
478 });
479 BooleanBuffer::new(output_buffer.into(), 0, len)
480 }
481 None => {
482 BooleanBuffer::collect_bool(len, |idx: usize| {
483 values.value(unsafe { indices.value_unchecked(idx).as_usize() })
485 })
486 }
487 }
488}
489
490fn take_boolean<IndexType: ArrowPrimitiveType>(
492 values: &BooleanArray,
493 indices: &PrimitiveArray<IndexType>,
494) -> BooleanArray {
495 let val_buf = take_bits(values.values(), indices);
496 let null_buf = take_nulls(values.nulls(), indices);
497 BooleanArray::new(val_buf, null_buf)
498}
499
500fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
502 array: &GenericByteArray<T>,
503 indices: &PrimitiveArray<IndexType>,
504) -> Result<GenericByteArray<T>, ArrowError> {
505 let mut values: Vec<u8> = Vec::new();
506 let mut offsets = Vec::with_capacity(indices.len() + 1);
507 offsets.push(T::Offset::default());
508
509 let input_offsets = array.value_offsets();
510 let mut capacity = 0;
511 let nulls = take_nulls(array.nulls(), indices);
512
513 match nulls.as_ref().filter(|n| n.null_count() > 0) {
515 None => {
517 for index in indices.values() {
518 let index = index.as_usize();
519 let start = input_offsets[index].as_usize();
520 let end = input_offsets[index + 1].as_usize();
521 capacity += end - start;
522 offsets.push(
523 T::Offset::from_usize(capacity)
524 .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
525 );
526 }
527
528 values.reserve(capacity);
529
530 let dst = values.spare_capacity_mut();
531 debug_assert!(dst.len() >= capacity);
532 let mut offset = 0;
533
534 for index in indices.values() {
535 unsafe {
538 let data: &[u8] = array.value_unchecked(index.as_usize()).as_ref();
539 std::ptr::copy_nonoverlapping(
540 data.as_ptr(),
541 dst.get_unchecked_mut(offset..).as_mut_ptr().cast::<u8>(),
542 data.len(),
543 );
544 offset += data.len();
545 }
546 }
547
548 unsafe {
550 values.set_len(capacity);
551 }
552 }
553 Some(output_nulls) => {
555 let mut source_ranges = Vec::with_capacity(indices.len() - output_nulls.null_count());
556 let mut last_filled = 0;
557
558 offsets.resize(indices.len() + 1, T::Offset::default());
560
561 for i in output_nulls.valid_indices() {
563 let current_offset = T::Offset::from_usize(capacity)
564 .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?;
565 if last_filled < i {
567 offsets[last_filled + 1..=i].fill(current_offset);
568 }
569
570 let index = unsafe { indices.value_unchecked(i) }.as_usize();
572 let start = input_offsets[index].as_usize();
573 let end = input_offsets[index + 1].as_usize();
574 capacity += end - start;
575 offsets[i + 1] = T::Offset::from_usize(capacity)
576 .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?;
577
578 source_ranges.push((start, end));
579 last_filled = i + 1;
580 }
581
582 let final_offset = T::Offset::from_usize(capacity)
584 .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?;
585 offsets[last_filled + 1..].fill(final_offset);
586 values.reserve(capacity);
588 debug_assert_eq!(
589 source_ranges.iter().map(|(s, e)| e - s).sum::<usize>(),
590 capacity,
591 "capacity must equal total bytes across all ranges"
592 );
593
594 let src = array.value_data();
595 let src = src.as_ptr();
596 let dst = values.spare_capacity_mut();
597 debug_assert!(dst.len() >= capacity);
598
599 let mut offset = 0;
600
601 for (start, end) in source_ranges.into_iter() {
602 let value_len = end - start;
603 unsafe {
607 std::ptr::copy_nonoverlapping(
608 src.add(start),
609 dst.get_unchecked_mut(offset..).as_mut_ptr().cast::<u8>(),
610 value_len,
611 );
612 offset += value_len;
613 }
614 }
615 unsafe { values.set_len(capacity) };
618 }
619 };
620
621 let array = unsafe {
624 let offsets = OffsetBuffer::new_unchecked(offsets.into());
625 GenericByteArray::<T>::new_unchecked(offsets, values.into(), nulls)
626 };
627
628 Ok(array)
629}
630
631fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>(
633 array: &GenericByteViewArray<T>,
634 indices: &PrimitiveArray<IndexType>,
635) -> Result<GenericByteViewArray<T>, ArrowError> {
636 let new_views = take_native(array.views(), indices);
637 let new_nulls = take_nulls(array.nulls(), indices);
638 Ok(unsafe {
640 GenericByteViewArray::new_unchecked(new_views, array.data_buffers().to_vec(), new_nulls)
641 })
642}
643
644fn take_list<IndexType, OffsetType>(
649 values: &GenericListArray<OffsetType::Native>,
650 indices: &PrimitiveArray<IndexType>,
651) -> Result<GenericListArray<OffsetType::Native>, ArrowError>
652where
653 IndexType: ArrowPrimitiveType,
654 OffsetType: ArrowPrimitiveType,
655 OffsetType::Native: OffsetSizeTrait,
656 PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
657{
658 let list_offsets = values.value_offsets();
659 let child_data = values.values().to_data();
660 let nulls = take_nulls(values.nulls(), indices);
661
662 let mut new_offsets = Vec::with_capacity(indices.len() + 1);
663 new_offsets.push(OffsetType::Native::zero());
664
665 let use_nulls = child_data.null_count() > 0;
666
667 let capacity = child_data
668 .len()
669 .checked_div(values.len())
670 .map(|v| v * indices.len())
671 .unwrap_or_default();
672
673 let mut array_data = MutableArrayData::new(vec![&child_data], use_nulls, capacity);
674
675 match nulls.as_ref().filter(|n| n.null_count() > 0) {
676 None => {
677 for index in indices.values() {
678 let ix = index.as_usize();
679 let start = list_offsets[ix].as_usize();
680 let end = list_offsets[ix + 1].as_usize();
681 array_data.try_extend(0, start, end)?;
682 new_offsets.push(OffsetType::Native::from_usize(array_data.len()).unwrap());
683 }
684 }
685 Some(output_nulls) => {
686 assert_eq!(output_nulls.len(), indices.len());
687
688 let mut last_filled = 0;
689 for i in output_nulls.valid_indices() {
690 let current = OffsetType::Native::from_usize(array_data.len()).unwrap();
691 if last_filled < i {
693 new_offsets.extend(std::iter::repeat_n(current, i - last_filled));
694 }
695
696 let ix = unsafe { indices.value_unchecked(i) }.as_usize();
698 let start = list_offsets[ix].as_usize();
699 let end = list_offsets[ix + 1].as_usize();
700 array_data.try_extend(0, start, end)?;
701 new_offsets.push(OffsetType::Native::from_usize(array_data.len()).unwrap());
702 last_filled = i + 1;
703 }
704
705 let final_offset = OffsetType::Native::from_usize(array_data.len()).unwrap();
707 new_offsets.extend(std::iter::repeat_n(
708 final_offset,
709 indices.len() - last_filled,
710 ));
711 }
712 };
713
714 assert_eq!(
715 new_offsets.len(),
716 indices.len() + 1,
717 "New offsets was filled under/over the expected capacity"
718 );
719
720 let field = match values.data_type() {
721 DataType::List(field) | DataType::LargeList(field) => field.clone(),
722 d => unreachable!("take_list called with non-list data type {d}"),
723 };
724 let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(new_offsets)) };
726 let child = make_array(array_data.freeze());
727
728 GenericListArray::<OffsetType::Native>::try_new(field, offsets, child, nulls)
729}
730
731fn take_list_view<IndexType, OffsetType>(
732 values: &GenericListViewArray<OffsetType::Native>,
733 indices: &PrimitiveArray<IndexType>,
734) -> Result<GenericListViewArray<OffsetType::Native>, ArrowError>
735where
736 IndexType: ArrowPrimitiveType,
737 OffsetType: ArrowPrimitiveType,
738 OffsetType::Native: OffsetSizeTrait,
739{
740 let taken_offsets = take_native(values.offsets(), indices);
741 let taken_sizes = take_native(values.sizes(), indices);
742 let nulls = take_nulls(values.nulls(), indices);
743
744 let field = match values.data_type() {
745 DataType::ListView(field) | DataType::LargeListView(field) => field.clone(),
746 d => unreachable!("take_list_view called with non-list-view data type {d}"),
747 };
748
749 Ok(unsafe {
752 GenericListViewArray::<OffsetType::Native>::new_unchecked(
753 field,
754 taken_offsets,
755 taken_sizes,
756 Arc::clone(values.values()),
757 nulls,
758 )
759 })
760}
761
762fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
768 values: &FixedSizeListArray,
769 indices: &PrimitiveArray<IndexType>,
770 length: <UInt32Type as ArrowPrimitiveType>::Native,
771) -> Result<FixedSizeListArray, ArrowError> {
772 let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?;
773 let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)?;
774
775 let num_bytes = bit_util::ceil(indices.len(), 8);
777 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
778 let null_slice = null_buf.as_slice_mut();
779
780 for i in 0..indices.len() {
781 let index = indices
782 .value(i)
783 .to_usize()
784 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
785 if !indices.is_valid(i) || values.is_null(index) {
786 bit_util::unset_bit(null_slice, i);
787 }
788 }
789
790 let field = match values.data_type() {
791 DataType::FixedSizeList(field, _) => field.clone(),
792 d => unreachable!("take_fixed_size_list called with non-fixed-size-list data type {d}"),
793 };
794 let nulls = NullBuffer::from_unsliced_buffer(null_buf, indices.len());
795
796 FixedSizeListArray::try_new(field, length as i32, taken, nulls)
797}
798
799fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>(
805 values: &FixedSizeBinaryArray,
806 indices: &PrimitiveArray<IndexType>,
807 size: i32,
808) -> Result<FixedSizeBinaryArray, ArrowError> {
809 let size_usize = usize::try_from(size).map_err(|_| {
810 ArrowError::InvalidArgumentError(format!("Cannot convert size '{}' to usize", size))
811 })?;
812
813 let result_buffer = match size_usize {
814 1 => take_fixed_size::<IndexType, 1>(values.values(), indices),
815 2 => take_fixed_size::<IndexType, 2>(values.values(), indices),
816 4 => take_fixed_size::<IndexType, 4>(values.values(), indices),
817 8 => take_fixed_size::<IndexType, 8>(values.values(), indices),
818 16 => take_fixed_size::<IndexType, 16>(values.values(), indices),
819 _ => take_fixed_size_binary_buffer_dynamic_length(values, indices, size_usize),
820 };
821
822 let value_nulls = take_nulls(values.nulls(), indices);
823 let final_nulls = NullBuffer::union(value_nulls.as_ref(), indices.nulls());
824
825 return FixedSizeBinaryArray::try_new(size, result_buffer, final_nulls);
826
827 #[inline(never)]
829 fn take_fixed_size_binary_buffer_dynamic_length<IndexType: ArrowPrimitiveType>(
830 values: &FixedSizeBinaryArray,
831 indices: &PrimitiveArray<IndexType>,
832 size_usize: usize,
833 ) -> Buffer {
834 let values_buffer = values.values().as_slice();
835 let mut values_buffer_builder = BufferBuilder::new(indices.len() * size_usize);
836
837 if indices.null_count() == 0 {
838 let array_iter = indices.values().iter().map(|idx| {
839 let offset = idx.as_usize() * size_usize;
840 &values_buffer[offset..offset + size_usize]
841 });
842 for slice in array_iter {
843 values_buffer_builder.append_slice(slice);
844 }
845 } else {
846 let array_iter = indices.iter().map(|idx| {
849 idx.map(|idx| {
850 let offset = idx.as_usize() * size_usize;
851 &values_buffer[offset..offset + size_usize]
852 })
853 });
854 for slice in array_iter {
855 match slice {
856 None => values_buffer_builder.append_n(size_usize, 0),
857 Some(slice) => values_buffer_builder.append_slice(slice),
858 }
859 }
860 }
861
862 values_buffer_builder.finish()
863 }
864}
865
866fn take_fixed_size<IndexType: ArrowPrimitiveType, const N: usize>(
879 buffer: &Buffer,
880 indices: &PrimitiveArray<IndexType>,
881) -> Buffer {
882 assert_eq!(
883 buffer.len() % N,
884 0,
885 "Invalid array length in take_fixed_size"
886 );
887
888 let ptr = buffer.as_ptr();
889 let chunk_ptr = ptr.cast::<[u8; N]>();
890 let chunk_len = buffer.len() / N;
891 let buffer: &[[u8; N]] = unsafe {
892 std::slice::from_raw_parts(chunk_ptr, chunk_len)
895 };
896
897 let result_buffer = match indices.nulls().filter(|n| n.null_count() > 0) {
898 Some(n) => indices
899 .values()
900 .iter()
901 .enumerate()
902 .map(|(idx, index)| match buffer.get(index.as_usize()) {
903 Some(v) => *v,
904 None => match unsafe { n.inner().value_unchecked(idx) } {
906 false => [0u8; N],
907 true => panic!("Out-of-bounds index {index:?}"),
908 },
909 })
910 .collect::<Vec<_>>(),
911 None => indices
912 .values()
913 .iter()
914 .map(|index| buffer[index.as_usize()])
915 .collect::<Vec<_>>(),
916 };
917
918 let mut vec = ManuallyDrop::new(result_buffer); let ptr = vec.as_mut_ptr();
920 let len = vec.len();
921 let cap = vec.capacity();
922 let result_buffer = unsafe {
923 Vec::from_raw_parts(ptr.cast::<u8>(), len * N, cap * N)
925 };
926
927 Buffer::from_vec(result_buffer)
928}
929
930fn take_dict<T: ArrowDictionaryKeyType, I: ArrowPrimitiveType>(
935 values: &DictionaryArray<T>,
936 indices: &PrimitiveArray<I>,
937) -> Result<DictionaryArray<T>, ArrowError> {
938 let new_keys = take_primitive(values.keys(), indices)?;
939 Ok(unsafe { DictionaryArray::new_unchecked(new_keys, values.values().clone()) })
940}
941
942fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
951 run_array: &RunArray<T>,
952 logical_indices: &PrimitiveArray<I>,
953) -> Result<RunArray<T>, ArrowError> {
954 let physical_indices = run_array.get_physical_indices(logical_indices.values())?;
956
957 let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
961 let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
962 for ix in 1..physical_indices.len() {
963 if physical_indices[ix] != physical_indices[ix - 1] {
964 take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
965 new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
966 }
967 }
968 take_value_indices
969 .append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap());
970 new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
971
972 let run_ends = unsafe {
974 RunEndBuffer::new_unchecked(
975 ScalarBuffer::from(new_run_ends_builder.finish()),
976 0,
977 physical_indices.len(),
978 )
979 };
980
981 let take_value_indices =
982 PrimitiveArray::<I>::new(ScalarBuffer::from(take_value_indices.finish()), None);
983
984 let new_values = take(run_array.values(), &take_value_indices, None)?;
985
986 Ok(
988 unsafe {
989 RunArray::<T>::new_unchecked(run_array.data_type().clone(), run_ends, new_values)
990 },
991 )
992}
993
994fn take_value_indices_from_fixed_size_list<IndexType>(
996 list: &FixedSizeListArray,
997 indices: &PrimitiveArray<IndexType>,
998 length: <UInt32Type as ArrowPrimitiveType>::Native,
999) -> Result<PrimitiveArray<UInt32Type>, ArrowError>
1000where
1001 IndexType: ArrowPrimitiveType,
1002{
1003 let mut values = UInt32Builder::with_capacity(length as usize * indices.len());
1004
1005 for i in 0..indices.len() {
1006 if indices.is_valid(i) {
1007 let index = indices
1008 .value(i)
1009 .to_usize()
1010 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
1011 let start = list.value_offset(index) as <UInt32Type as ArrowPrimitiveType>::Native;
1012
1013 unsafe {
1015 values.append_trusted_len_iter(start..start + length);
1016 }
1017 } else {
1018 values.append_nulls(length as usize);
1019 }
1020 }
1021
1022 Ok(values.finish())
1023}
1024
1025trait ToIndices {
1028 type T: ArrowPrimitiveType;
1029
1030 fn to_indices(&self) -> PrimitiveArray<Self::T>;
1031}
1032
1033macro_rules! to_indices_reinterpret {
1034 ($t:ty, $o:ty) => {
1035 impl ToIndices for PrimitiveArray<$t> {
1036 type T = $o;
1037
1038 fn to_indices(&self) -> PrimitiveArray<$o> {
1039 let cast = ScalarBuffer::new(self.values().inner().clone(), 0, self.len());
1040 PrimitiveArray::new(cast, self.nulls().cloned())
1041 }
1042 }
1043 };
1044}
1045
1046macro_rules! to_indices_identity {
1047 ($t:ty) => {
1048 impl ToIndices for PrimitiveArray<$t> {
1049 type T = $t;
1050
1051 fn to_indices(&self) -> PrimitiveArray<$t> {
1052 self.clone()
1053 }
1054 }
1055 };
1056}
1057
1058macro_rules! to_indices_widening {
1059 ($t:ty, $o:ty) => {
1060 impl ToIndices for PrimitiveArray<$t> {
1061 type T = UInt32Type;
1062
1063 fn to_indices(&self) -> PrimitiveArray<$o> {
1064 let cast = self.values().iter().copied().map(|x| x as _).collect();
1065 PrimitiveArray::new(cast, self.nulls().cloned())
1066 }
1067 }
1068 };
1069}
1070
1071to_indices_widening!(UInt8Type, UInt32Type);
1072to_indices_widening!(Int8Type, UInt32Type);
1073
1074to_indices_widening!(UInt16Type, UInt32Type);
1075to_indices_widening!(Int16Type, UInt32Type);
1076
1077to_indices_identity!(UInt32Type);
1078to_indices_reinterpret!(Int32Type, UInt32Type);
1079
1080to_indices_identity!(UInt64Type);
1081to_indices_reinterpret!(Int64Type, UInt64Type);
1082
1083pub fn take_record_batch(
1122 record_batch: &RecordBatch,
1123 indices: &dyn Array,
1124) -> Result<RecordBatch, ArrowError> {
1125 let columns = record_batch
1126 .columns()
1127 .iter()
1128 .map(|c| take(c, indices, None))
1129 .collect::<Result<Vec<_>, _>>()?;
1130 RecordBatch::try_new(record_batch.schema(), columns)
1131}
1132
1133#[cfg(test)]
1134mod tests {
1135 use super::*;
1136 use arrow_array::builder::*;
1137 use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
1138 use arrow_data::ArrayData;
1139 use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
1140 use num_traits::ToPrimitive;
1141
1142 fn test_take_decimal_arrays(
1143 data: Vec<Option<i128>>,
1144 index: &UInt32Array,
1145 options: Option<TakeOptions>,
1146 expected_data: Vec<Option<i128>>,
1147 precision: &u8,
1148 scale: &i8,
1149 ) -> Result<(), ArrowError> {
1150 let output = data
1151 .into_iter()
1152 .collect::<Decimal128Array>()
1153 .with_precision_and_scale(*precision, *scale)
1154 .unwrap();
1155
1156 let expected = expected_data
1157 .into_iter()
1158 .collect::<Decimal128Array>()
1159 .with_precision_and_scale(*precision, *scale)
1160 .unwrap();
1161
1162 let expected = Arc::new(expected) as ArrayRef;
1163 let output = take(&output, index, options).unwrap();
1164 assert_eq!(&output, &expected);
1165 Ok(())
1166 }
1167
1168 fn test_take_boolean_arrays(
1169 data: Vec<Option<bool>>,
1170 index: &UInt32Array,
1171 options: Option<TakeOptions>,
1172 expected_data: Vec<Option<bool>>,
1173 ) {
1174 let output = BooleanArray::from(data);
1175 let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef;
1176 let output = take(&output, index, options).unwrap();
1177 assert_eq!(&output, &expected)
1178 }
1179
1180 fn test_take_primitive_arrays<T>(
1181 data: Vec<Option<T::Native>>,
1182 index: &UInt32Array,
1183 options: Option<TakeOptions>,
1184 expected_data: Vec<Option<T::Native>>,
1185 ) -> Result<(), ArrowError>
1186 where
1187 T: ArrowPrimitiveType,
1188 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1189 {
1190 let output = PrimitiveArray::<T>::from(data);
1191 let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1192 let output = take(&output, index, options)?;
1193 assert_eq!(&output, &expected);
1194 Ok(())
1195 }
1196
1197 fn test_take_primitive_arrays_non_null<T>(
1198 data: Vec<T::Native>,
1199 index: &UInt32Array,
1200 options: Option<TakeOptions>,
1201 expected_data: Vec<Option<T::Native>>,
1202 ) -> Result<(), ArrowError>
1203 where
1204 T: ArrowPrimitiveType,
1205 PrimitiveArray<T>: From<Vec<T::Native>>,
1206 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1207 {
1208 let output = PrimitiveArray::<T>::from(data);
1209 let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1210 let output = take(&output, index, options)?;
1211 assert_eq!(&output, &expected);
1212 Ok(())
1213 }
1214
1215 fn test_take_impl_primitive_arrays<T, I>(
1216 data: Vec<Option<T::Native>>,
1217 index: &PrimitiveArray<I>,
1218 options: Option<TakeOptions>,
1219 expected_data: Vec<Option<T::Native>>,
1220 ) where
1221 T: ArrowPrimitiveType,
1222 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1223 I: ArrowPrimitiveType,
1224 {
1225 let output = PrimitiveArray::<T>::from(data);
1226 let expected = PrimitiveArray::<T>::from(expected_data);
1227 let output = take(&output, index, options).unwrap();
1228 let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
1229 assert_eq!(output, &expected)
1230 }
1231
1232 fn create_test_struct(values: Vec<Option<(Option<bool>, Option<i32>)>>) -> StructArray {
1234 let mut struct_builder = StructBuilder::new(
1235 Fields::from(vec![
1236 Field::new("a", DataType::Boolean, true),
1237 Field::new("b", DataType::Int32, true),
1238 ]),
1239 vec![
1240 Box::new(BooleanBuilder::with_capacity(values.len())),
1241 Box::new(Int32Builder::with_capacity(values.len())),
1242 ],
1243 );
1244
1245 for value in values {
1246 struct_builder
1247 .field_builder::<BooleanBuilder>(0)
1248 .unwrap()
1249 .append_option(value.and_then(|v| v.0));
1250 struct_builder
1251 .field_builder::<Int32Builder>(1)
1252 .unwrap()
1253 .append_option(value.and_then(|v| v.1));
1254 struct_builder.append(value.is_some());
1255 }
1256 struct_builder.finish()
1257 }
1258
/// Non-null indices over a Decimal128 array containing nulls.
#[test]
fn test_take_decimal128_non_null_indices() {
    let (precision, scale) = (10u8, 5i8);
    let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
    let data = vec![None, Some(3), Some(5), Some(2), Some(3), None];
    let expected = vec![None, None, Some(2), Some(3), Some(3), Some(5)];
    test_take_decimal_arrays(data, &index, None, expected, &precision, &scale).unwrap();
}
1274
/// Nullable indices over a Decimal128 array without nulls.
#[test]
fn test_take_decimal128() {
    let (precision, scale) = (10u8, 5i8);
    let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
    let data = vec![Some(0), Some(1), Some(2), Some(3), Some(4)];
    let expected = vec![Some(3), None, Some(1), Some(3), Some(2)];
    test_take_decimal_arrays(data, &index, None, expected, &precision, &scale).unwrap();
}
1290
/// Non-null indices over an Int8 array containing nulls.
#[test]
fn test_take_primitive_non_null_indices() {
    let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
    let data = vec![None, Some(3), Some(5), Some(2), Some(3), None];
    let expected = vec![None, None, Some(2), Some(3), Some(3), Some(5)];
    test_take_primitive_arrays::<Int8Type>(data, &index, None, expected).unwrap();
}
1302
#[test]
fn test_take_primitive_non_null_values() {
    // Non-null Int8 values gathered through indices that contain one null.
    let indices = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
    let values = vec![Some(0), Some(1), Some(2), Some(3), Some(4)];
    let expected = vec![Some(3), None, Some(1), Some(3), Some(2)];
    test_take_primitive_arrays::<Int8Type>(values, &indices, None, expected).unwrap();
}
1314
#[test]
fn test_take_primitive_non_null() {
    // Neither the values nor the indices contain nulls.
    let indices = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
    let values = vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)];
    let expected = vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)];
    test_take_primitive_arrays::<Int8Type>(values, &indices, None, expected).unwrap();
}
1326
#[test]
fn test_take_primitive_nullable_indices_non_null_values_with_offset() {
    // Slicing gives the indices a non-zero offset; the two trailing null indices
    // must yield nulls even though the values themselves have no nulls.
    let all_indices = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
    let sliced = all_indices.slice(2, 4);
    let indices = sliced.as_any().downcast_ref::<UInt32Array>().unwrap();
    assert_eq!(
        indices,
        &UInt32Array::from(vec![Some(2), Some(3), None, None])
    );

    let values = vec![0, 10, 20, 30, 40, 50];
    let expected = vec![Some(20), Some(30), None, None];
    test_take_primitive_arrays_non_null::<Int64Type>(values, indices, None, expected).unwrap();
}
1346
#[test]
fn test_take_primitive_nullable_indices_nullable_values_with_offset() {
    // Offset indices [Some(2), Some(3), None, None] over values whose first two
    // entries are null; the selected values are valid, so only the null index
    // slots become nulls.
    let all_indices = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
    let sliced = all_indices.slice(2, 4);
    let indices = sliced.as_any().downcast_ref::<UInt32Array>().unwrap();
    assert_eq!(
        indices,
        &UInt32Array::from(vec![Some(2), Some(3), None, None])
    );

    let values = vec![None, None, Some(20), Some(30), Some(40), Some(50)];
    let expected = vec![Some(20), Some(30), None, None];
    test_take_primitive_arrays::<Int64Type>(values, indices, None, expected).unwrap();
}
1366
#[test]
fn test_take_primitive() {
    // One shared index with a null slot; index 1 also points at a null value.
    let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    // Gathers `$values` (of primitive type `$ty`) through `index` and asserts
    // the result equals `$expected`.
    macro_rules! check {
        ($ty:ty, $values:expr, $expected:expr) => {
            test_take_primitive_arrays::<$ty>($values, &index, None, $expected).unwrap()
        };
    }

    check!(
        Int8Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );
    check!(
        Int16Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );
    check!(
        Int32Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );
    check!(
        Int64Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );
    check!(
        UInt8Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );
    check!(
        UInt16Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );
    check!(
        UInt32Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );

    // Signed and interval/duration types, including a negative value.
    check!(
        Int64Type,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );
    check!(
        IntervalYearMonthType,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );

    let (dt_zero, dt_two, dt_neg) = (
        IntervalDayTime::new(0, 0),
        IntervalDayTime::new(2, 0),
        IntervalDayTime::new(-15, 0),
    );
    check!(
        IntervalDayTimeType,
        vec![Some(dt_zero), None, Some(dt_two), Some(dt_neg), None],
        vec![Some(dt_neg), None, None, Some(dt_neg), Some(dt_two)]
    );

    let (mdn_zero, mdn_two, mdn_neg) = (
        IntervalMonthDayNano::new(0, 0, 0),
        IntervalMonthDayNano::new(2, 0, 0),
        IntervalMonthDayNano::new(-15, 0, 0),
    );
    check!(
        IntervalMonthDayNanoType,
        vec![Some(mdn_zero), None, Some(mdn_two), Some(mdn_neg), None],
        vec![Some(mdn_neg), None, None, Some(mdn_neg), Some(mdn_two)]
    );

    check!(
        DurationSecondType,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );
    check!(
        DurationMillisecondType,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );
    check!(
        DurationMicrosecondType,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );
    check!(
        DurationNanosecondType,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );

    // Floating point.
    check!(
        Float32Type,
        vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
        vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)]
    );
    check!(
        Float64Type,
        vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
        vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)]
    );
}
1530
#[test]
fn test_take_preserve_timezone() {
    // Taking from a timezone-aware timestamp array must keep the timezone in the
    // output data type, including when one of the indices is null.
    let index = Int64Array::from(vec![Some(0), None]);

    let input = TimestampNanosecondArray::from(vec![
        1_639_715_368_000_000_000,
        1_639_715_368_000_000_000,
    ])
    .with_timezone("UTC".to_string());
    let result = take(&input, &index, None).unwrap();
    match result.data_type() {
        DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
            assert_eq!(tz.as_deref(), Some("UTC"))
        }
        // Report what was produced instead of a bare, uninformative panic.
        other => panic!("expected Timestamp(Nanosecond, _), got {other:?}"),
    }
}
1548
#[test]
fn test_take_impl_primitive_with_int64_indices() {
    // The same gathers as `test_take_primitive`, driven by signed 64-bit indices.
    let index = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    // Gathers `$values` (of primitive type `$ty`) through the Int64 `index`.
    macro_rules! check {
        ($ty:ty, $values:expr, $expected:expr) => {
            test_take_impl_primitive_arrays::<$ty, Int64Type>($values, &index, None, $expected)
        };
    }

    check!(
        Int16Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );
    check!(
        Int64Type,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );
    check!(
        UInt64Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );
    check!(
        DurationMillisecondType,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );
    check!(
        Float32Type,
        vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
        vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)]
    );
}
1593
#[test]
fn test_take_impl_primitive_with_uint8_indices() {
    // The narrowest unsigned index type must behave like the others.
    let index = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    // Gathers `$values` (of primitive type `$ty`) through the UInt8 `index`.
    macro_rules! check {
        ($ty:ty, $values:expr, $expected:expr) => {
            test_take_impl_primitive_arrays::<$ty, UInt8Type>($values, &index, None, $expected)
        };
    }

    check!(
        Int16Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );
    check!(
        DurationMillisecondType,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );
    check!(
        Float32Type,
        vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
        vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)]
    );
}
1622
#[test]
fn test_take_bool() {
    // Null values and null indices must both surface as nulls in the output.
    let values = vec![Some(false), None, Some(true), Some(false), None];
    let indices = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
    let expected = vec![Some(false), None, None, Some(false), Some(true)];
    test_take_boolean_arrays(values, &indices, None, expected);
}
1634
#[test]
fn test_take_bool_nullable_index() {
    // Only slots 1, 3 and 5 are valid (pointing at 0, 1, 2). The null slots
    // hold values that would be out of range for the 3-element input if read.
    let validity = Buffer::from_iter(vec![false, true, false, true, false, true]);
    let raw_indices = Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2]);
    let index_data =
        ArrayData::try_new(DataType::UInt32, 6, Some(validity), 0, vec![raw_indices], vec![])
            .unwrap();
    let index = UInt32Array::from(index_data);

    let values = vec![Some(true), None, Some(false)];
    let expected = vec![None, Some(true), None, None, None, Some(false)];
    test_take_boolean_arrays(values, &index, None, expected);
}
1657
#[test]
fn test_take_bool_nullable_index_nonnull_values() {
    // Same garbage-under-null index layout as `test_take_bool_nullable_index`,
    // but every value is valid, so nulls come only from the index bitmap.
    let validity = Buffer::from_iter(vec![false, true, false, true, false, true]);
    let raw_indices = Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2]);
    let index_data =
        ArrayData::try_new(DataType::UInt32, 6, Some(validity), 0, vec![raw_indices], vec![])
            .unwrap();
    let index = UInt32Array::from(index_data);

    let values = vec![Some(true), Some(true), Some(false)];
    let expected = vec![None, Some(true), None, Some(true), None, Some(false)];
    test_take_boolean_arrays(values, &index, None, expected);
}
1680
#[test]
fn test_take_bool_with_offset() {
    // After slicing, the indices are [Some(1), Some(3), Some(2), None] at offset 2.
    let all_indices = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]);
    let sliced = all_indices.slice(2, 4);
    let index: &UInt32Array = sliced.as_any().downcast_ref().unwrap();

    let values = vec![Some(false), None, Some(true), Some(false), None];
    let expected = vec![None, Some(false), Some(true), None];
    test_take_boolean_arrays(values, index, None, expected);
}
1698
1699 fn _test_take_string<'a, K>()
1700 where
1701 K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static,
1702 {
1703 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]);
1704
1705 let array = K::from(vec![
1706 Some("one"),
1707 None,
1708 Some("three"),
1709 Some("four"),
1710 Some("five"),
1711 ]);
1712 let actual = take(&array, &index, None).unwrap();
1713 assert_eq!(actual.len(), index.len());
1714
1715 let actual = actual.as_any().downcast_ref::<K>().unwrap();
1716
1717 let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]);
1718
1719 assert_eq!(actual, &expected);
1720 }
1721
#[test]
fn test_take_string() {
    // 32-bit offsets.
    _test_take_string::<StringArray>();
}
1726
#[test]
fn test_take_large_string() {
    // 64-bit offsets.
    _test_take_string::<LargeStringArray>();
}
1731
#[test]
fn test_take_slice_string() {
    let strings = StringArray::from(vec![Some("hello"), None, Some("world"), None, Some("hi")]);
    // After slicing, the indices are [Some(1), None, Some(0), Some(2)].
    let all_indices = Int32Array::from(vec![Some(0), Some(1), None, Some(0), Some(2)]);
    let sliced_indices = all_indices.slice(1, 4);

    let result = take(&strings, &sliced_indices, None).unwrap();
    assert_eq!(
        result.as_ref(),
        &StringArray::from(vec![None, None, Some("hello"), Some("world")])
    );
}
1741
#[test]
fn test_take_bytes_sliced_values() {
    let values = StringArray::from(vec![
        Some("aaa"),
        Some("bbb"),
        None,
        Some("ccccc"),
        Some("dd"),
        None,
        Some("eeee"),
    ]);
    // The slice is [None, "ccccc", "dd", None, "eeee"]; indices below are relative to it.
    let sliced = values.slice(2, 5);

    let check = |indices: Int32Array, expected: StringArray| {
        let result = take(&sliced, &indices, None).unwrap();
        assert_eq!(result.as_string::<i32>(), &expected);
    };

    // Non-null indices.
    check(
        Int32Array::from(vec![1, 2, 4, 1]),
        StringArray::from(vec![Some("ccccc"), Some("dd"), Some("eeee"), Some("ccccc")]),
    );
    // Null indices mixed with indices that hit null values.
    check(
        Int32Array::from(vec![Some(1), None, Some(0), Some(4), Some(3)]),
        StringArray::from(vec![Some("ccccc"), None, None, Some("eeee"), None]),
    );
}
1776
1777 fn _test_byte_view<T>()
1778 where
1779 T: ByteViewType,
1780 str: AsRef<T::Native>,
1781 T::Native: PartialEq,
1782 {
1783 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]);
1784 let array = {
1785 let mut builder = GenericByteViewBuilder::<T>::new();
1787 builder.append_value("hello");
1788 builder.append_value("world");
1789 builder.append_null();
1790 builder.append_value("large payload over 12 bytes");
1791 builder.append_value("lulu");
1792 builder.finish()
1793 };
1794
1795 let actual = take(&array, &index, None).unwrap();
1796
1797 assert_eq!(actual.len(), index.len());
1798
1799 let expected = {
1800 let mut builder = GenericByteViewBuilder::<T>::new();
1802 builder.append_value("large payload over 12 bytes");
1803 builder.append_null();
1804 builder.append_value("world");
1805 builder.append_value("large payload over 12 bytes");
1806 builder.append_value("lulu");
1807 builder.append_null();
1808 builder.finish()
1809 };
1810
1811 assert_eq!(actual.as_ref(), &expected);
1812 }
1813
#[test]
fn test_take_string_view() {
    // UTF-8 views.
    _test_byte_view::<StringViewType>();
}
1818
#[test]
fn test_take_binary_view() {
    // Binary views.
    _test_byte_view::<BinaryViewType>();
}
1823
/// Checks `take` on a non-nullable list of Int32 with offsets of type
/// `$offset_type`, using list type `$list_data_type` / `$list_array_type`.
macro_rules! test_take_list {
    ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
        let list_data_type =
            DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, false)));

        // Source lists: [[0, 0, 0], [-1, -2, -1], [], [2, 3]].
        let source_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
        let source_values = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
        let list_array = $list_array_type::from(
            ArrayData::builder(list_data_type.clone())
                .len(4)
                .add_buffer(Buffer::from_slice_ref(&source_offsets))
                .add_child_data(source_values)
                .build()
                .unwrap(),
        );

        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]);
        let taken = take(&list_array, &index, None).unwrap();
        let taken: &$list_array_type = taken.as_any().downcast_ref::<$list_array_type>().unwrap();

        // Expected: [[2, 3], null, [-1, -2, -1], [], [0, 0, 0]]. The null index
        // contributes an empty slot, so its offset repeats.
        let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8];
        let expected_values = Int32Array::from(vec![
            Some(2),
            Some(3),
            Some(-1),
            Some(-2),
            Some(-1),
            Some(0),
            Some(0),
            Some(0),
        ])
        .into_data();
        let expected = $list_array_type::from(
            ArrayData::builder(list_data_type)
                .len(5)
                .nulls(index.nulls().cloned())
                .add_buffer(Buffer::from_slice_ref(&expected_offsets))
                .add_child_data(expected_values)
                .build()
                .unwrap(),
        );

        assert_eq!(taken, &expected);
    }};
}
1878
/// Checks `take` on a list whose child values contain nulls (all list slots valid).
macro_rules! test_take_list_with_value_nulls {
    ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
        let list_data_type =
            DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));

        // Source lists: [[0, null, 0], [-1, -2, 3], [null], [5, null]].
        let source_offsets: [$offset_type; 5] = [0, 3, 6, 7, 9];
        let source_values = Int32Array::from(vec![
            Some(0),
            None,
            Some(0),
            Some(-1),
            Some(-2),
            Some(3),
            None,
            Some(5),
            None,
        ])
        .into_data();
        let list_array = $list_array_type::from(
            ArrayData::builder(list_data_type.clone())
                .len(4)
                .add_buffer(Buffer::from_slice_ref(&source_offsets))
                .null_bit_buffer(Some(Buffer::from([0b11111111])))
                .add_child_data(source_values)
                .build()
                .unwrap(),
        );

        let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
        let taken = take(&list_array, &index, None).unwrap();
        let taken: &$list_array_type = taken.as_any().downcast_ref::<$list_array_type>().unwrap();

        // Expected: [[null], null, [-1, -2, 3], [5, null], [0, null, 0]].
        let expected_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9];
        let expected_values = Int32Array::from(vec![
            None,
            Some(-1),
            Some(-2),
            Some(3),
            Some(5),
            None,
            Some(0),
            None,
            Some(0),
        ])
        .into_data();
        let expected = $list_array_type::from(
            ArrayData::builder(list_data_type)
                .len(5)
                .nulls(index.nulls().cloned())
                .add_buffer(Buffer::from_slice_ref(&expected_offsets))
                .add_child_data(expected_values)
                .build()
                .unwrap(),
        );

        assert_eq!(taken, &expected);
    }};
}
1946
/// Checks `take` on a list that has null list slots as well as null child values.
macro_rules! test_take_list_with_nulls {
    ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
        let list_data_type =
            DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));

        // Source lists: [[0, null, 0], [-1, -2, 3], null, [5, null]]
        // (bit 2 is cleared in the validity bitmap).
        let source_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
        let source_values = Int32Array::from(vec![
            Some(0),
            None,
            Some(0),
            Some(-1),
            Some(-2),
            Some(3),
            Some(5),
            None,
        ])
        .into_data();
        let list_array = $list_array_type::from(
            ArrayData::builder(list_data_type.clone())
                .len(4)
                .add_buffer(Buffer::from_slice_ref(&source_offsets))
                .null_bit_buffer(Some(Buffer::from([0b11111011])))
                .add_child_data(source_values)
                .build()
                .unwrap(),
        );

        let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
        let taken = take(&list_array, &index, None).unwrap();
        let taken: &$list_array_type = taken.as_any().downcast_ref::<$list_array_type>().unwrap();

        // Expected: [null, null, [-1, -2, 3], [5, null], [0, null, 0]]; the first
        // null comes from a null list, the second from a null index.
        let expected_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8];
        let expected_values = Int32Array::from(vec![
            Some(-1),
            Some(-2),
            Some(3),
            Some(5),
            None,
            Some(0),
            None,
            Some(0),
        ])
        .into_data();
        let mut null_bits = [0u8; 1];
        for valid_slot in 2..5 {
            bit_util::set_bit(&mut null_bits, valid_slot);
        }
        let expected = $list_array_type::from(
            ArrayData::builder(list_data_type)
                .len(5)
                .null_bit_buffer(Some(Buffer::from(null_bits)))
                .add_buffer(Buffer::from_slice_ref(&expected_offsets))
                .add_child_data(expected_values)
                .build()
                .unwrap(),
        );

        assert_eq!(taken, &expected);
    }};
}
2016
2017 fn test_take_list_view_generic<OffsetType: OffsetSizeTrait, ValuesType: ArrowPrimitiveType, F>(
2018 values: Vec<Option<Vec<Option<ValuesType::Native>>>>,
2019 take_indices: Vec<Option<usize>>,
2020 expected: Vec<Option<Vec<Option<ValuesType::Native>>>>,
2021 mapper: F,
2022 ) where
2023 F: Fn(GenericListViewArray<OffsetType>) -> GenericListViewArray<OffsetType>,
2024 {
2025 let mut list_view_array =
2026 GenericListViewBuilder::<OffsetType, _>::new(PrimitiveBuilder::<ValuesType>::new());
2027
2028 for value in values {
2029 list_view_array.append_option(value);
2030 }
2031 let list_view_array = list_view_array.finish();
2032 let list_view_array = mapper(list_view_array);
2033
2034 let mut indices = UInt64Builder::new();
2035 for idx in take_indices {
2036 indices.append_option(idx.map(|i| i.to_u64().unwrap()));
2037 }
2038 let indices = indices.finish();
2039
2040 let taken = take(&list_view_array, &indices, None)
2041 .unwrap()
2042 .as_list_view()
2043 .clone();
2044
2045 let mut expected_array =
2046 GenericListViewBuilder::<OffsetType, _>::new(PrimitiveBuilder::<ValuesType>::new());
2047 for value in expected {
2048 expected_array.append_option(value);
2049 }
2050 let expected_array = expected_array.finish();
2051
2052 assert_eq!(taken, expected_array);
2053 }
2054
/// Runs `test_take_list_view_generic` with Int8 values for both i32 and i64
/// offsets. The `transform` closure, when given, reshapes the input first; the
/// short form uses the identity transform.
macro_rules! list_view_test_case {
    (values: $values:expr, indices: $indices:expr, expected: $expected: expr) => {{
        list_view_test_case!(
            values: $values,
            transform: |x| x,
            indices: $indices,
            expected: $expected
        )
    }};
    (values: $values:expr, transform: $fn:expr, indices: $indices:expr, expected: $expected: expr) => {{
        test_take_list_view_generic::<i32, Int8Type, _>($values, $indices, $expected, $fn);
        test_take_list_view_generic::<i64, Int8Type, _>($values, $indices, $expected, $fn);
    }};
}
2065
2066 fn do_take_fixed_size_list_test<T>(
2067 length: <Int32Type as ArrowPrimitiveType>::Native,
2068 input_data: Vec<Option<Vec<Option<T::Native>>>>,
2069 indices: Vec<<UInt32Type as ArrowPrimitiveType>::Native>,
2070 expected_data: Vec<Option<Vec<Option<T::Native>>>>,
2071 ) where
2072 T: ArrowPrimitiveType,
2073 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2074 {
2075 let indices = UInt32Array::from(indices);
2076
2077 let input_array = FixedSizeListArray::from_iter_primitive::<T, _, _>(input_data, length);
2078
2079 let output = take_fixed_size_list(&input_array, &indices, length as u32).unwrap();
2080
2081 let expected = FixedSizeListArray::from_iter_primitive::<T, _, _>(expected_data, length);
2082
2083 assert_eq!(&output, &expected)
2084 }
2085
#[test]
fn test_take_list() {
    // i32 offsets.
    test_take_list!(i32, List, ListArray)
}
2090
#[test]
fn test_take_large_list() {
    // i64 offsets.
    test_take_list!(i64, LargeList, LargeListArray)
}
2095
#[test]
fn test_take_list_with_value_nulls() {
    // i32 offsets, nulls inside the child values.
    test_take_list_with_value_nulls!(i32, List, ListArray)
}
2100
#[test]
fn test_take_large_list_with_value_nulls() {
    // i64 offsets, nulls inside the child values.
    test_take_list_with_value_nulls!(i64, LargeList, LargeListArray)
}
2105
#[test]
fn test_test_take_list_with_nulls() {
    // i32 offsets, null list slots.
    test_take_list_with_nulls!(i32, List, ListArray)
}
2110
#[test]
fn test_test_take_large_list_with_nulls() {
    // i64 offsets, null list slots.
    test_take_list_with_nulls!(i64, LargeList, LargeListArray)
}
2115
#[test]
fn test_test_take_list_view_reversed() {
    // Reversing the rows must keep the null row null and the inner nulls intact.
    list_view_test_case! {
        values: vec![
            Some(vec![Some(1), None, Some(3)]),
            None,
            Some(vec![Some(7), Some(8), None]),
        ],
        transform: |list| list,
        indices: vec![Some(2), Some(1), Some(0)],
        expected: vec![
            Some(vec![Some(7), Some(8), None]),
            None,
            Some(vec![Some(1), None, Some(3)]),
        ]
    }
}
2133
#[test]
fn test_take_list_view_null_indices() {
    // Null indices produce null rows regardless of the source rows.
    list_view_test_case! {
        values: vec![
            Some(vec![Some(1), None, Some(3)]),
            None,
            Some(vec![Some(7), Some(8), None]),
        ],
        transform: |list| list,
        indices: vec![None, Some(0), None],
        expected: vec![
            None,
            Some(vec![Some(1), None, Some(3)]),
            None,
        ]
    }
}
2147
#[test]
fn test_take_list_view_null_values() {
    // Repeatedly selecting the null row, or using null indices, yields only nulls.
    list_view_test_case! {
        values: vec![
            Some(vec![Some(1), None, Some(3)]),
            None,
            Some(vec![Some(7), Some(8), None]),
        ],
        transform: |list| list,
        indices: vec![Some(1), Some(1), Some(1), None, None],
        expected: vec![None; 5]
    }
}
2161
#[test]
fn test_take_list_view_sliced() {
    // After slicing, the input is [null, [2, 3], [4, 5], null]; indices are
    // relative to that slice.
    list_view_test_case! {
        values: vec![
            Some(vec![Some(1)]),
            None,
            None,
            Some(vec![Some(2), Some(3)]),
            Some(vec![Some(4), Some(5)]),
            None,
        ],
        transform: |list| list.slice(2, 4),
        indices: vec![Some(0), Some(3), None, Some(1), Some(2)],
        expected: vec![
            None,
            None,
            None,
            Some(vec![Some(2), Some(3)]),
            Some(vec![Some(4), Some(5)]),
        ]
    }
}
2181
#[test]
fn test_take_fixed_size_list() {
    // Width 3, nulls inside the lists, full reversal.
    do_take_fixed_size_list_test::<Int32Type>(
        3,
        vec![
            Some(vec![None, Some(1), Some(2)]),
            Some(vec![Some(3), Some(4), None]),
            Some(vec![Some(6), Some(7), Some(8)]),
        ],
        vec![2, 1, 0],
        vec![
            Some(vec![Some(6), Some(7), Some(8)]),
            Some(vec![Some(3), Some(4), None]),
            Some(vec![None, Some(1), Some(2)]),
        ],
    );

    // Width 1: the rows are [1], [2], ..., [8].
    do_take_fixed_size_list_test::<UInt8Type>(
        1,
        (1..=8).map(|v| Some(vec![Some(v)])).collect(),
        vec![2, 7, 0],
        [3, 8, 1].into_iter().map(|v| Some(vec![Some(v)])).collect(),
    );

    // Width 3 with a null list slot selected twice.
    do_take_fixed_size_list_test::<UInt64Type>(
        3,
        vec![
            Some(vec![Some(10), Some(11), Some(12)]),
            Some(vec![Some(13), Some(14), Some(15)]),
            None,
            Some(vec![Some(16), Some(17), Some(18)]),
        ],
        vec![3, 2, 1, 2, 0],
        vec![
            Some(vec![Some(16), Some(17), Some(18)]),
            None,
            Some(vec![Some(13), Some(14), Some(15)]),
            None,
            Some(vec![Some(10), Some(11), Some(12)]),
        ],
    );
}
2237
#[test]
fn test_take_fixed_size_binary_with_nulls_indices() {
    // Four 4-byte values: [1; 4], [2; 4], [3; 4], [4; 4].
    let fsb = FixedSizeBinaryArray::try_from_sparse_iter_with_size(
        (1u8..=4).map(|byte| Some(vec![byte; 4])),
        4,
    )
    .unwrap();

    let indices = UInt32Array::from(vec![Some(0), None, None, Some(3)]);
    let result = take_fixed_size_binary(&fsb, &indices, 4).unwrap();

    // The two null indices must be the only nulls in the output.
    assert_eq!(result.len(), 4);
    assert_eq!(result.null_count(), 2);
    let validity: Vec<bool> = result.nulls().unwrap().iter().collect();
    assert_eq!(validity, vec![true, false, false, true]);
}
2263
#[test]
fn test_take_fixed_size_binary_with_nulls_indices_not_optimized_length() {
    // Same as the 4-byte case but with 5-byte values: [b, b, b, b, 0x01] for b in 1..=4.
    let fsb = FixedSizeBinaryArray::try_from_sparse_iter_with_size(
        (1u8..=4).map(|byte| Some(vec![byte, byte, byte, byte, 0x01])),
        5,
    )
    .unwrap();

    let indices = UInt32Array::from(vec![Some(0), None, None, Some(3)]);
    let result = take_fixed_size_binary(&fsb, &indices, 5).unwrap();

    assert_eq!(result.len(), 4);
    assert_eq!(result.null_count(), 2);
    let validity: Vec<bool> = result.nulls().unwrap().iter().collect();
    assert_eq!(validity, vec![true, false, false, true]);
}
2292
#[test]
#[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
fn test_take_list_out_of_bounds() {
    // Three lists over four offsets. Without bounds checking, index 1000 must panic.
    // The "len is 4" in the expected message presumably refers to the 4-entry
    // offsets buffer; confirm against the kernel.
    let list_data = ArrayData::builder(DataType::List(Arc::new(Field::new_list_field(
        DataType::Int32,
        false,
    ))))
    .len(3)
    .add_buffer(Buffer::from_slice_ref([0, 3, 6, 8]))
    .add_child_data(Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data())
    .build()
    .unwrap();
    let list_array = ListArray::from(list_data);

    let out_of_range = UInt32Array::from(vec![1000]);
    take(&list_array, &out_of_range, None).unwrap();
}
2317
#[test]
fn test_take_map() {
    // Two maps, {a: 1, b: 2, c: 3} and {a: 4}; taking index 0 keeps only the first.
    let values = Int32Array::from(vec![1, 2, 3, 4]);
    let map =
        MapArray::new_from_strings(vec!["a", "b", "c", "a"].into_iter(), &values, &[0, 3, 4])
            .unwrap();

    let taken = take(&map, &UInt32Array::from(vec![0]), None).unwrap();

    let first_map_values = values.slice(0, 3);
    let expected: ArrayRef = Arc::new(
        MapArray::new_from_strings(vec!["a", "b", "c"].into_iter(), &first_map_values, &[0, 3])
            .unwrap(),
    );
    assert_eq!(&expected, &taken);
}
2338
#[test]
fn test_take_struct() {
    let source = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(false), Some(28))),
        Some((Some(false), Some(19))),
        Some((Some(true), Some(31))),
        None,
    ]);

    // Only the last index (4) hits the null struct slot.
    let indices = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
    let taken = take(&source, &indices, None).unwrap();
    let taken = taken.as_struct();
    assert_eq!(taken.len(), indices.len());
    assert_eq!(taken.null_count(), 1);

    let expected = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(true), Some(31))),
        Some((Some(false), Some(28))),
        Some((Some(true), Some(42))),
        Some((Some(false), Some(19))),
        None,
    ]);
    assert_eq!(taken, &expected);

    // A struct with no child fields: only the validity bitmap is carried through.
    let empty = StructArray::new_empty_fields(
        6,
        Some(NullBuffer::from(&[false, true, false, true, false, true])),
    );
    let taken_empty = take(&empty, &UInt32Array::from(vec![0, 2, 1, 4]), None).unwrap();
    let expected_empty =
        StructArray::new_empty_fields(4, Some(NullBuffer::from(&[false, false, true, false])));
    assert_eq!(&expected_empty, taken_empty.as_struct());
}
2375
#[test]
fn test_take_struct_with_null_indices() {
    let source = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(false), Some(28))),
        Some((Some(false), Some(19))),
        Some((Some(true), Some(31))),
        None,
    ]);

    // Two null indices plus index 4 (a null struct slot) give three output nulls.
    let indices = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);
    let taken = take(&source, &indices, None).unwrap();
    let taken = taken.as_struct();
    assert_eq!(taken.len(), indices.len());
    assert_eq!(taken.null_count(), 3);

    let expected = create_test_struct(vec![
        None,
        Some((Some(true), Some(31))),
        Some((Some(false), Some(28))),
        None,
        Some((Some(true), Some(42))),
        None,
    ]);
    assert_eq!(taken, &expected);
}
2403
#[test]
fn test_take_out_of_bounds() {
    // Index 6 is past the end of the 5-element values. With `check_bounds`
    // enabled the take must return an error rather than panic.
    let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]);
    let outcome = test_take_primitive_arrays::<Int64Type>(
        vec![Some(0), None, Some(2), Some(3), None],
        &index,
        Some(TakeOptions { check_bounds: true }),
        vec![None],
    );
    assert!(outcome.is_err());
}
2418
#[test]
#[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
fn test_take_out_of_bounds_panic() {
    // Without bounds checking, an out-of-range index panics instead of erroring.
    let values = vec![Some(0), Some(1), Some(2), Some(3)];
    let index = UInt32Array::from(vec![Some(1000)]);
    test_take_primitive_arrays::<Int64Type>(values, &index, None, vec![None]).unwrap();
}
2432
#[test]
fn test_null_array_smaller_than_indices() {
    // Index 15 exceeds the NullArray's length, but every output slot is null
    // anyway, so the result is three nulls.
    let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
    let result = take(&NullArray::new(2), &indices, None).unwrap();
    let expected = Arc::new(NullArray::new(3)) as ArrayRef;
    assert_eq!(&result, &expected);
}
2442
#[test]
fn test_null_array_larger_than_indices() {
    // The output length follows the indices (3), not the values (5).
    let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
    let result = take(&NullArray::new(5), &indices, None).unwrap();
    let expected = Arc::new(NullArray::new(3)) as ArrayRef;
    assert_eq!(&result, &expected);
}
2452
#[test]
fn test_null_array_indices_out_of_bounds() {
    // With bounds checking on, even a NullArray rejects index 15.
    let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
    let options = Some(TakeOptions { check_bounds: true });
    let error = take(&NullArray::new(5), &indices, options).unwrap_err();
    assert_eq!(
        error.to_string(),
        "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries"
    );
}
2464
2465 #[test]
2466 fn test_take_dict() {
2467 let mut dict_builder = StringDictionaryBuilder::<Int16Type>::new();
2468
2469 dict_builder.append("foo").unwrap();
2470 dict_builder.append("bar").unwrap();
2471 dict_builder.append("").unwrap();
2472 dict_builder.append_null();
2473 dict_builder.append("foo").unwrap();
2474 dict_builder.append("bar").unwrap();
2475 dict_builder.append("bar").unwrap();
2476 dict_builder.append("foo").unwrap();
2477
2478 let array = dict_builder.finish();
2479 let dict_values = array.values().clone();
2480 let dict_values = dict_values.as_any().downcast_ref::<StringArray>().unwrap();
2481
2482 let indices = UInt32Array::from(vec![
2483 Some(0), Some(7), None, Some(5), Some(6), Some(2), Some(3), ]);
2491
2492 let result = take(&array, &indices, None).unwrap();
2493 let result = result
2494 .as_any()
2495 .downcast_ref::<DictionaryArray<Int16Type>>()
2496 .unwrap();
2497
2498 let result_values: StringArray = result.values().to_data().into();
2499
2500 let expected_values = StringArray::from(vec!["foo", "bar", ""]);
2502 assert_eq!(&expected_values, dict_values);
2503 assert_eq!(&expected_values, &result_values);
2504
2505 let expected_keys = Int16Array::from(vec![
2506 Some(0),
2507 Some(0),
2508 None,
2509 Some(1),
2510 Some(1),
2511 Some(2),
2512 None,
2513 ]);
2514 assert_eq!(result.keys(), &expected_keys);
2515 }
2516
2517 fn build_generic_list<S, T>(data: Vec<Option<Vec<T::Native>>>) -> GenericListArray<S>
2518 where
2519 S: OffsetSizeTrait + 'static,
2520 T: ArrowPrimitiveType,
2521 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2522 {
2523 GenericListArray::from_iter_primitive::<T, _, _>(
2524 data.iter()
2525 .map(|x| x.as_ref().map(|x| x.iter().map(|x| Some(*x)))),
2526 )
2527 }
2528
2529 fn test_take_sliced_list_generic<S: OffsetSizeTrait + 'static>() {
2530 let list = build_generic_list::<S, Int32Type>(vec![
2531 Some(vec![0, 1]),
2532 Some(vec![2, 3, 4]),
2533 None,
2534 Some(vec![]),
2535 Some(vec![5, 6]),
2536 Some(vec![7]),
2537 ]);
2538 let sliced = list.slice(1, 4);
2539 let indices = UInt32Array::from(vec![Some(3), Some(0), None, Some(2), Some(1)]);
2540
2541 let taken = take(&sliced, &indices, None).unwrap();
2542 let taken = taken.as_list::<S>();
2543
2544 let expected = build_generic_list::<S, Int32Type>(vec![
2545 Some(vec![5, 6]),
2546 Some(vec![2, 3, 4]),
2547 None,
2548 Some(vec![]),
2549 None,
2550 ]);
2551
2552 assert_eq!(taken, &expected);
2553 }
2554
2555 fn test_take_sliced_list_with_value_nulls_generic<S: OffsetSizeTrait + 'static>() {
2556 let list = GenericListArray::<S>::from_iter_primitive::<Int32Type, _, _>(vec![
2557 Some(vec![Some(10)]),
2558 Some(vec![None, Some(1)]),
2559 None,
2560 Some(vec![Some(2), None]),
2561 Some(vec![]),
2562 Some(vec![Some(3)]),
2563 ]);
2564 let sliced = list.slice(1, 4);
2565 let indices = UInt32Array::from(vec![Some(2), Some(0), None, Some(3), Some(1)]);
2566
2567 let taken = take(&sliced, &indices, None).unwrap();
2568 let taken = taken.as_list::<S>();
2569
2570 let expected = GenericListArray::<S>::from_iter_primitive::<Int32Type, _, _>(vec![
2571 Some(vec![Some(2), None]),
2572 Some(vec![None, Some(1)]),
2573 None,
2574 Some(vec![]),
2575 None,
2576 ]);
2577
2578 assert_eq!(taken, &expected);
2579 }
2580
2581 #[test]
2582 fn test_take_sliced_list() {
2583 test_take_sliced_list_generic::<i32>();
2584 }
2585
2586 #[test]
2587 fn test_take_sliced_large_list() {
2588 test_take_sliced_list_generic::<i64>();
2589 }
2590
2591 #[test]
2592 fn test_take_sliced_list_with_value_nulls() {
2593 test_take_sliced_list_with_value_nulls_generic::<i32>();
2594 }
2595
2596 #[test]
2597 fn test_take_sliced_large_list_with_value_nulls() {
2598 test_take_sliced_list_with_value_nulls_generic::<i64>();
2599 }
2600
2601 #[test]
2602 fn test_take_runs() {
2603 let logical_array: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];
2604
2605 let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
2606 builder.extend(logical_array.into_iter().map(Some));
2607 let run_array = builder.finish();
2608
2609 let take_indices: PrimitiveArray<Int32Type> =
2610 vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect();
2611
2612 let take_out = take_run(&run_array, &take_indices).unwrap();
2613
2614 assert_eq!(take_out.len(), 7);
2615 assert_eq!(take_out.run_ends().len(), 7);
2616 assert_eq!(take_out.run_ends().values(), &[1_i32, 3, 4, 5, 7]);
2617
2618 let take_out_values = take_out.values().as_primitive::<Int32Type>();
2619 assert_eq!(take_out_values.values(), &[2, 2, 2, 2, 1]);
2620 }
2621
2622 #[test]
2623 fn test_take_runs_sliced() {
2624 let logical_array: Vec<i32> = vec![1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6];
2625
2626 let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
2627 builder.extend(logical_array.into_iter().map(Some));
2628 let run_array = builder.finish();
2629
2630 let run_array = run_array.slice(4, 6); let take_indices: PrimitiveArray<Int32Type> = vec![0, 5, 5, 1, 4].into_iter().collect();
2633
2634 let result = take_run(&run_array, &take_indices).unwrap();
2635 let result = result.downcast::<Int32Array>().unwrap();
2636
2637 let expected = vec![3, 5, 5, 3, 4];
2638 let actual = result.into_iter().flatten().collect::<Vec<_>>();
2639
2640 assert_eq!(expected, actual);
2641 }
2642
2643 #[test]
2644 fn test_take_value_index_from_fixed_list() {
2645 let list = FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
2646 vec![
2647 Some(vec![Some(1), Some(2), None]),
2648 Some(vec![Some(4), None, Some(6)]),
2649 None,
2650 Some(vec![None, Some(8), Some(9)]),
2651 ],
2652 3,
2653 );
2654
2655 let indices = UInt32Array::from(vec![2, 1, 0]);
2656 let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2657
2658 assert_eq!(indexed, UInt32Array::from(vec![6, 7, 8, 3, 4, 5, 0, 1, 2]));
2659
2660 let indices = UInt32Array::from(vec![3, 2, 1, 2, 0]);
2661 let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2662
2663 assert_eq!(
2664 indexed,
2665 UInt32Array::from(vec![9, 10, 11, 6, 7, 8, 3, 4, 5, 6, 7, 8, 0, 1, 2])
2666 );
2667 }
2668
2669 #[test]
2670 fn test_take_null_indices() {
2671 let indices = Int32Array::new(
2673 vec![1, 2, 400, 400].into(),
2674 Some(NullBuffer::from(vec![true, true, false, false])),
2675 );
2676 let values = Int32Array::from(vec![1, 23, 4, 5]);
2677 let r = take(&values, &indices, None).unwrap();
2678 let values = r
2679 .as_primitive::<Int32Type>()
2680 .into_iter()
2681 .collect::<Vec<_>>();
2682 assert_eq!(&values, &[Some(23), Some(4), None, None])
2683 }
2684
2685 #[test]
2686 fn test_take_fixed_size_list_null_indices() {
2687 let indices = Int32Array::from_iter([Some(0), None]);
2688 let values = Arc::new(Int32Array::from(vec![0, 1, 2, 3]));
2689 let arr_field = Arc::new(Field::new_list_field(values.data_type().clone(), true));
2690 let values = FixedSizeListArray::try_new(arr_field, 2, values, None).unwrap();
2691
2692 let r = take(&values, &indices, None).unwrap();
2693 let values = r
2694 .as_fixed_size_list()
2695 .values()
2696 .as_primitive::<Int32Type>()
2697 .into_iter()
2698 .collect::<Vec<_>>();
2699 assert_eq!(values, &[Some(0), Some(1), None, None])
2700 }
2701
2702 #[test]
2703 fn test_take_bytes_null_indices() {
2704 let indices = Int32Array::new(
2705 vec![0, 1, 400, 400].into(),
2706 Some(NullBuffer::from_iter(vec![true, true, false, false])),
2707 );
2708 let values = StringArray::from(vec![Some("foo"), None]);
2709 let r = take(&values, &indices, None).unwrap();
2710 let values = r.as_string::<i32>().iter().collect::<Vec<_>>();
2711 assert_eq!(&values, &[Some("foo"), None, None, None])
2712 }
2713
2714 #[test]
2715 fn test_take_union_sparse() {
2716 let structs = create_test_struct(vec![
2717 Some((Some(true), Some(42))),
2718 Some((Some(false), Some(28))),
2719 Some((Some(false), Some(19))),
2720 Some((Some(true), Some(31))),
2721 None,
2722 ]);
2723 let strings = StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);
2724 let type_ids = [1; 5].into_iter().collect::<ScalarBuffer<i8>>();
2725
2726 let union_fields = [
2727 (
2728 0,
2729 Arc::new(Field::new("f1", structs.data_type().clone(), true)),
2730 ),
2731 (
2732 1,
2733 Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2734 ),
2735 ]
2736 .into_iter()
2737 .collect();
2738 let children = vec![Arc::new(structs) as Arc<dyn Array>, Arc::new(strings)];
2739 let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();
2740
2741 let indices = vec![0, 3, 1, 0, 2, 4];
2742 let index = UInt32Array::from(indices.clone());
2743 let actual = take(&array, &index, None).unwrap();
2744 let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2745 let strings = actual.child(1);
2746 let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();
2747
2748 let actual = strings.iter().collect::<Vec<_>>();
2749 let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
2750 assert_eq!(expected, actual);
2751 }
2752
2753 #[test]
2754 fn test_take_union_dense() {
2755 let type_ids = vec![0, 1, 1, 0, 0, 1, 0];
2756 let offsets = vec![0, 0, 1, 1, 2, 2, 3];
2757 let ints = vec![10, 20, 30, 40];
2758 let strings = vec![Some("a"), None, Some("c"), Some("d")];
2759
2760 let indices = vec![0, 3, 1, 0, 2, 4];
2761
2762 let taken_type_ids = vec![0, 0, 1, 0, 1, 0];
2763 let taken_offsets = vec![0, 1, 0, 2, 1, 3];
2764 let taken_ints = vec![10, 20, 10, 30];
2765 let taken_strings = vec![Some("a"), None];
2766
2767 let type_ids = <ScalarBuffer<i8>>::from(type_ids);
2768 let offsets = <ScalarBuffer<i32>>::from(offsets);
2769 let ints = UInt32Array::from(ints);
2770 let strings = StringArray::from(strings);
2771
2772 let union_fields = [
2773 (
2774 0,
2775 Arc::new(Field::new("f1", ints.data_type().clone(), true)),
2776 ),
2777 (
2778 1,
2779 Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2780 ),
2781 ]
2782 .into_iter()
2783 .collect();
2784
2785 let array = UnionArray::try_new(
2786 union_fields,
2787 type_ids,
2788 Some(offsets),
2789 vec![Arc::new(ints), Arc::new(strings)],
2790 )
2791 .unwrap();
2792
2793 let index = UInt32Array::from(indices);
2794
2795 let actual = take(&array, &index, None).unwrap();
2796 let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2797
2798 assert_eq!(actual.offsets(), Some(&ScalarBuffer::from(taken_offsets)));
2799 assert_eq!(actual.type_ids(), &ScalarBuffer::from(taken_type_ids));
2800 assert_eq!(
2801 UInt32Array::from(actual.child(0).to_data()),
2802 UInt32Array::from(taken_ints)
2803 );
2804 assert_eq!(
2805 StringArray::from(actual.child(1).to_data()),
2806 StringArray::from(taken_strings)
2807 );
2808 }
2809
2810 #[test]
2811 fn test_take_union_dense_using_builder() {
2812 let mut builder = UnionBuilder::new_dense();
2813
2814 builder.append::<Int32Type>("a", 1).unwrap();
2815 builder.append::<Float64Type>("b", 3.0).unwrap();
2816 builder.append::<Int32Type>("a", 4).unwrap();
2817 builder.append::<Int32Type>("a", 5).unwrap();
2818 builder.append::<Float64Type>("b", 2.0).unwrap();
2819
2820 let union = builder.build().unwrap();
2821
2822 let indices = UInt32Array::from(vec![2, 0, 1, 2]);
2823
2824 let mut builder = UnionBuilder::new_dense();
2825
2826 builder.append::<Int32Type>("a", 4).unwrap();
2827 builder.append::<Int32Type>("a", 1).unwrap();
2828 builder.append::<Float64Type>("b", 3.0).unwrap();
2829 builder.append::<Int32Type>("a", 4).unwrap();
2830
2831 let taken = builder.build().unwrap();
2832
2833 assert_eq!(
2834 taken.to_data(),
2835 take(&union, &indices, None).unwrap().to_data()
2836 );
2837 }
2838
2839 #[test]
2840 fn test_take_union_dense_all_match_issue_6206() {
2841 let fields = UnionFields::from_fields(vec![Field::new("a", DataType::Int64, false)]);
2842 let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
2843
2844 let array = UnionArray::try_new(
2845 fields,
2846 ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]),
2847 Some(ScalarBuffer::from_iter(0_i32..5)),
2848 vec![ints],
2849 )
2850 .unwrap();
2851
2852 let indicies = Int64Array::from(vec![0, 2, 4]);
2853 let array = take(&array, &indicies, None).unwrap();
2854 assert_eq!(array.len(), 3);
2855 }
2856
2857 fn offset_overflow_fixture() -> (StringArray, usize) {
2862 let value_len = 1_000_000;
2863 let values = StringArray::from(vec![Some("a".repeat(value_len))]);
2864 let n = i32::MAX as usize / value_len + 1;
2865 (values, n)
2866 }
2867
2868 #[test]
2869 fn test_take_bytes_offset_overflow() {
2870 let (values, n) = offset_overflow_fixture();
2871 let indices = Int32Array::from(vec![0; n]);
2872 assert!(matches!(
2873 take(&values, &indices, None),
2874 Err(ArrowError::OffsetOverflowError(_))
2875 ));
2876 }
2877
2878 #[test]
2881 fn test_take_bytes_offset_overflow_nullable() {
2882 let (values, n) = offset_overflow_fixture();
2883 let validity =
2886 NullBuffer::from_iter(std::iter::once(false).chain(std::iter::repeat_n(true, n)));
2887 let indices = Int32Array::new(vec![0i32; n + 1].into(), Some(validity));
2888
2889 assert!(matches!(
2890 take(&values, &indices, None),
2891 Err(ArrowError::OffsetOverflowError(_))
2892 ));
2893 }
2894
2895 #[test]
2896 fn test_take_run_empty_indices() {
2897 let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
2898 builder.extend([Some(1), Some(1), Some(2), Some(2)]);
2899 let run_array = builder.finish();
2900
2901 let logical_indices: PrimitiveArray<Int32Type> = PrimitiveArray::from(Vec::<i32>::new());
2902
2903 let result = take_impl(&run_array, &logical_indices).expect("take_run with empty indices");
2904
2905 assert_eq!(result.len(), 0);
2907 assert_eq!(result.null_count(), 0);
2908
2909 let run_result = result
2912 .as_any()
2913 .downcast_ref::<RunArray<Int32Type>>()
2914 .expect("result should be a RunArray");
2915 assert_eq!(run_result.run_ends().len(), 0);
2916 assert_eq!(run_result.values().len(), 0);
2917 }
2918}