1use std::sync::Arc;
21
22use arrow_array::builder::{BufferBuilder, UInt32Builder};
23use arrow_array::cast::AsArray;
24use arrow_array::types::*;
25use arrow_array::*;
26use arrow_buffer::{
27 ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, ScalarBuffer,
28 bit_util,
29};
30use arrow_data::ArrayDataBuilder;
31use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
32
33use num_traits::{One, Zero};
34
35pub fn take(
87 values: &dyn Array,
88 indices: &dyn Array,
89 options: Option<TakeOptions>,
90) -> Result<ArrayRef, ArrowError> {
91 let options = options.unwrap_or_default();
92 downcast_integer_array!(
93 indices => {
94 if options.check_bounds {
95 check_bounds(values.len(), indices)?;
96 }
97 let indices = indices.to_indices();
98 take_impl(values, &indices)
99 },
100 d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
101 )
102}
103
104pub fn take_arrays(
153 arrays: &[ArrayRef],
154 indices: &dyn Array,
155 options: Option<TakeOptions>,
156) -> Result<Vec<ArrayRef>, ArrowError> {
157 arrays
158 .iter()
159 .map(|array| take(array.as_ref(), indices, options.clone()))
160 .collect()
161}
162
163fn check_bounds<T: ArrowPrimitiveType>(
165 len: usize,
166 indices: &PrimitiveArray<T>,
167) -> Result<(), ArrowError> {
168 if indices.null_count() > 0 {
169 indices.iter().flatten().try_for_each(|index| {
170 let ix = index
171 .to_usize()
172 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
173 if ix >= len {
174 return Err(ArrowError::ComputeError(format!(
175 "Array index out of bounds, cannot get item at index {ix} from {len} entries"
176 )));
177 }
178 Ok(())
179 })
180 } else {
181 indices.values().iter().try_for_each(|index| {
182 let ix = index
183 .to_usize()
184 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
185 if ix >= len {
186 return Err(ArrowError::ComputeError(format!(
187 "Array index out of bounds, cannot get item at index {ix} from {len} entries"
188 )));
189 }
190 Ok(())
191 })
192 }
193}
194
/// Type-dispatching core of [`take`]: selects a concrete kernel based on the data type of
/// `values` and returns the taken array.
///
/// `indices` is already a typed primitive array. No bounds checking is performed here.
///
/// # Panics
///
/// Reaches `unimplemented!` for data types that have no arm below.
#[inline(never)]
fn take_impl<IndexType: ArrowPrimitiveType>(
    values: &dyn Array,
    indices: &PrimitiveArray<IndexType>,
) -> Result<ArrayRef, ArrowError> {
    downcast_primitive_array! {
        // First arm: `values` has been downcast to a typed primitive array by the macro.
        values => Ok(Arc::new(take_primitive(values, indices)?)),
        DataType::Boolean => {
            let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
            Ok(Arc::new(take_boolean(values, indices)))
        }
        DataType::Utf8 => {
            Ok(Arc::new(take_bytes(values.as_string::<i32>(), indices)?))
        }
        DataType::LargeUtf8 => {
            Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)?))
        }
        DataType::Utf8View => {
            Ok(Arc::new(take_byte_view(values.as_string_view(), indices)?))
        }
        DataType::List(_) => {
            Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?))
        }
        DataType::LargeList(_) => {
            Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?))
        }
        DataType::ListView(_) => {
            Ok(Arc::new(take_list_view::<_, Int32Type>(values.as_list_view(), indices)?))
        }
        DataType::LargeListView(_) => {
            Ok(Arc::new(take_list_view::<_, Int64Type>(values.as_list_view(), indices)?))
        }
        DataType::FixedSizeList(_, length) => {
            let values = values
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap();
            Ok(Arc::new(take_fixed_size_list(
                values,
                indices,
                *length as u32,
            )?))
        }
        DataType::Map(_, _) => {
            // A map is taken by converting it to a list array, reusing the list kernel, and
            // then re-labelling the resulting data with the original map data type.
            let list_arr = ListArray::from(values.as_map().clone());
            let list_data = take_list::<_, Int32Type>(&list_arr, indices)?;
            let builder = list_data.into_data().into_builder().data_type(values.data_type().clone());
            // NOTE(review): validation is skipped here; presumably sound because the data
            // came from a valid list built from a valid map -- confirm.
            Ok(Arc::new(MapArray::from(unsafe { builder.build_unchecked() })))
        }
        DataType::Struct(fields) => {
            let array: &StructArray = values.as_struct();
            // Take every child column with the same indices.
            let arrays = array
                .columns()
                .iter()
                .map(|a| take_impl(a.as_ref(), indices))
                .collect::<Result<Vec<ArrayRef>, _>>()?;
            let fields: Vec<(FieldRef, ArrayRef)> =
                fields.iter().cloned().zip(arrays).collect();

            // Struct-level validity: an output slot is valid only when the index is non-null
            // and the struct slot it points at is valid.
            let is_valid: Buffer = indices
                .iter()
                .map(|index| {
                    if let Some(index) = index {
                        array.is_valid(index.to_usize().unwrap())
                    } else {
                        false
                    }
                })
                .collect();

            if fields.is_empty() {
                // With no child columns the length must be supplied explicitly.
                let nulls = NullBuffer::new(BooleanBuffer::new(is_valid, 0, indices.len()));
                Ok(Arc::new(StructArray::new_empty_fields(indices.len(), Some(nulls))))
            } else {
                Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
            }
        }
        DataType::Dictionary(_, _) => downcast_dictionary_array! {
            values => Ok(Arc::new(take_dict(values, indices)?)),
            t => unimplemented!("Take not supported for dictionary type {:?}", t)
        }
        DataType::RunEndEncoded(_, _) => downcast_run_array! {
            values => Ok(Arc::new(take_run(values, indices)?)),
            t => unimplemented!("Take not supported for run type {:?}", t)
        }
        DataType::Binary => {
            Ok(Arc::new(take_bytes(values.as_binary::<i32>(), indices)?))
        }
        DataType::LargeBinary => {
            Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?))
        }
        DataType::BinaryView => {
            Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)?))
        }
        DataType::FixedSizeBinary(size) => {
            let values = values
                .as_any()
                .downcast_ref::<FixedSizeBinaryArray>()
                .unwrap();
            Ok(Arc::new(take_fixed_size_binary(values, indices, *size)?))
        }
        DataType::Null => {
            // Taking from a null array yields a null array of `indices.len()` entries.
            if values.len() >= indices.len() {
                // The input is long enough: reuse it through a slice.
                Ok(values.slice(0, indices.len()))
            } else {
                // Otherwise build a fresh null array of the required length.
                Ok(new_null_array(&DataType::Null, indices.len()))
            }
        }
        DataType::Union(fields, UnionMode::Sparse) => {
            // Sparse union: the type ids and every child are taken with the same indices.
            let mut children = Vec::with_capacity(fields.len());
            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
            let type_ids = take_native(values.type_ids(), indices);
            for (type_id, _field) in fields.iter() {
                let values = values.child(type_id);
                let values = take_impl(values, indices)?;
                children.push(values);
            }
            let array = UnionArray::try_new(fields.clone(), type_ids, None, children)?;
            Ok(Arc::new(array))
        }
        DataType::Union(fields, UnionMode::Dense) => {
            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();

            // Type ids and child offsets selected by `indices`.
            let type_ids = <PrimitiveArray<Int8Type>>::try_new(take_native(values.type_ids(), indices), None)?;
            let offsets = <PrimitiveArray<Int32Type>>::try_new(take_native(values.offsets().unwrap(), indices), None)?;

            // For each child, keep only the offsets whose type id selects that child and
            // take those positions from it.
            let children = fields.iter()
                .map(|(field_type_id, _)| {
                    let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);

                    let indices = crate::filter::filter(&offsets, &mask)?;

                    let values = values.child(field_type_id);

                    take_impl(values, indices.as_primitive::<Int32Type>())
                })
                .collect::<Result<_, _>>()?;

            // One running counter per type id, indexed by `i as usize`.
            // NOTE(review): assumes type ids are in 0..128 -- a negative id would index out
            // of range; confirm the invariant.
            let mut child_offsets = [0; 128];

            // New offsets: position of each output slot within its (freshly taken) child.
            let offsets = type_ids.values()
                .iter()
                .map(|&i| {
                    let offset = child_offsets[i as usize];

                    child_offsets[i as usize] += 1;

                    offset
                })
                .collect();

            let (_, type_ids, _) = type_ids.into_parts();

            let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)?;

            Ok(Arc::new(array))
        }
        t => unimplemented!("Take not supported for data type {:?}", t)
    }
}
360
/// Options controlling the behaviour of [`take`].
#[derive(Clone, Debug, Default)]
pub struct TakeOptions {
    /// When `true`, `take` validates every non-null index against the length of the
    /// values array before taking and returns an error for an out-of-range index.
    /// Defaults to `false`, in which case no up-front validation is performed.
    pub check_bounds: bool,
}
369
370#[inline(always)]
371fn maybe_usize<I: ArrowNativeType>(index: I) -> Result<usize, ArrowError> {
372 index
373 .to_usize()
374 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))
375}
376
377fn take_primitive<T, I>(
387 values: &PrimitiveArray<T>,
388 indices: &PrimitiveArray<I>,
389) -> Result<PrimitiveArray<T>, ArrowError>
390where
391 T: ArrowPrimitiveType,
392 I: ArrowPrimitiveType,
393{
394 let values_buf = take_native(values.values(), indices);
395 let nulls = take_nulls(values.nulls(), indices);
396 Ok(PrimitiveArray::try_new(values_buf, nulls)?.with_data_type(values.data_type().clone()))
397}
398
399#[inline(never)]
400fn take_nulls<I: ArrowPrimitiveType>(
401 values: Option<&NullBuffer>,
402 indices: &PrimitiveArray<I>,
403) -> Option<NullBuffer> {
404 match values.filter(|n| n.null_count() > 0) {
405 Some(n) => {
406 let buffer = take_bits(n.inner(), indices);
407 Some(NullBuffer::new(buffer)).filter(|n| n.null_count() > 0)
408 }
409 None => indices.nulls().cloned(),
410 }
411}
412
413#[inline(never)]
414fn take_native<T: ArrowNativeType, I: ArrowPrimitiveType>(
415 values: &[T],
416 indices: &PrimitiveArray<I>,
417) -> ScalarBuffer<T> {
418 match indices.nulls().filter(|n| n.null_count() > 0) {
419 Some(n) => indices
420 .values()
421 .iter()
422 .enumerate()
423 .map(|(idx, index)| match values.get(index.as_usize()) {
424 Some(v) => *v,
425 None => match unsafe { n.inner().value_unchecked(idx) } {
427 false => T::default(),
428 true => panic!("Out-of-bounds index {index:?}"),
429 },
430 })
431 .collect(),
432 None => indices
433 .values()
434 .iter()
435 .map(|index| values[index.as_usize()])
436 .collect(),
437 }
438}
439
440#[inline(never)]
441fn take_bits<I: ArrowPrimitiveType>(
442 values: &BooleanBuffer,
443 indices: &PrimitiveArray<I>,
444) -> BooleanBuffer {
445 let len = indices.len();
446
447 match indices.nulls().filter(|n| n.null_count() > 0) {
448 Some(nulls) => {
449 let mut output_buffer = MutableBuffer::new_null(len);
450 let output_slice = output_buffer.as_slice_mut();
451 nulls.valid_indices().for_each(|idx| {
452 if values.value(unsafe { indices.value_unchecked(idx).as_usize() }) {
454 unsafe { bit_util::set_bit_raw(output_slice.as_mut_ptr(), idx) };
456 }
457 });
458 BooleanBuffer::new(output_buffer.into(), 0, len)
459 }
460 None => {
461 BooleanBuffer::collect_bool(len, |idx: usize| {
462 values.value(unsafe { indices.value_unchecked(idx).as_usize() })
464 })
465 }
466 }
467}
468
469fn take_boolean<IndexType: ArrowPrimitiveType>(
471 values: &BooleanArray,
472 indices: &PrimitiveArray<IndexType>,
473) -> BooleanArray {
474 let val_buf = take_bits(values.values(), indices);
475 let null_buf = take_nulls(values.nulls(), indices);
476 BooleanArray::new(val_buf, null_buf)
477}
478
/// `take` for variable-length byte arrays (string and binary).
///
/// Builds a new offsets vector and a new contiguous values vector containing the bytes of
/// every selected element. Output slots that are null contribute no bytes.
///
/// Returns [`ArrowError::OffsetOverflowError`] when the accumulated byte length does not
/// fit in the offset type of `T`.
fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
    array: &GenericByteArray<T>,
    indices: &PrimitiveArray<IndexType>,
) -> Result<GenericByteArray<T>, ArrowError> {
    // One offset per output element plus the leading zero offset.
    let mut offsets = Vec::with_capacity(indices.len() + 1);
    offsets.push(T::Offset::default());

    let input_offsets = array.value_offsets();
    // Running total of bytes selected so far; also the next offset to record.
    let mut capacity = 0;
    let nulls = take_nulls(array.nulls(), indices);

    // Each branch makes two passes: the first computes the offsets (and thereby the exact
    // number of bytes), the second copies the bytes into a vector of that capacity.
    let (offsets, values) = if array.null_count() == 0 && indices.null_count() == 0 {
        // Neither side has nulls: every index selects a value.
        offsets.reserve(indices.len());
        for index in indices.values() {
            let index = index.as_usize();
            capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for index in indices.values() {
            values.extend_from_slice(array.value(index.as_usize()).as_ref());
        }
        (offsets, values)
    } else if indices.null_count() == 0 {
        // Only the values have nulls: skip the bytes of null value slots.
        offsets.reserve(indices.len());
        for index in indices.values() {
            let index = index.as_usize();
            if array.is_valid(index) {
                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            }
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for index in indices.values() {
            let index = index.as_usize();
            if array.is_valid(index) {
                values.extend_from_slice(array.value(index).as_ref());
            }
        }
        (offsets, values)
    } else if array.null_count() == 0 {
        // Only the indices have nulls: the stored index is used only when its slot is valid.
        offsets.reserve(indices.len());
        for (i, index) in indices.values().iter().enumerate() {
            let index = index.as_usize();
            if indices.is_valid(i) {
                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            }
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for (i, index) in indices.values().iter().enumerate() {
            if indices.is_valid(i) {
                values.extend_from_slice(array.value(index.as_usize()).as_ref());
            }
        }
        (offsets, values)
    } else {
        // Both sides have nulls: the combined output validity decides which slots carry bytes.
        // NOTE(review): the unwrap relies on `take_nulls` returning `Some` whenever both
        // inputs contain nulls -- confirm.
        let nulls = nulls.as_ref().unwrap();
        offsets.reserve(indices.len());
        for (i, index) in indices.values().iter().enumerate() {
            // The validity check comes before any use of `index` as a position: a null
            // slot may hold a value that is out of range for `array`.
            let index = index.as_usize();
            if nulls.is_valid(i) {
                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            }
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for (i, index) in indices.values().iter().enumerate() {
            // Same ordering as above: only use `index` once the output slot is known valid.
            let index = index.as_usize();
            if nulls.is_valid(i) {
                values.extend_from_slice(array.value(index).as_ref());
            }
        }
        (offsets, values)
    };

    // Final guard that the total number of bytes is representable in the offset type.
    T::Offset::from_usize(values.len())
        .ok_or_else(|| ArrowError::OffsetOverflowError(values.len()))?;

    // NOTE(review): unchecked construction; presumably sound because the offsets were
    // accumulated monotonically from zero and the bytes were copied from valid elements of
    // `array` -- confirm against the constructors' safety contracts.
    let array = unsafe {
        let offsets = OffsetBuffer::new_unchecked(offsets.into());
        GenericByteArray::<T>::new_unchecked(offsets, values.into(), nulls)
    };

    Ok(array)
}
584
585fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>(
587 array: &GenericByteViewArray<T>,
588 indices: &PrimitiveArray<IndexType>,
589) -> Result<GenericByteViewArray<T>, ArrowError> {
590 let new_views = take_native(array.views(), indices);
591 let new_nulls = take_nulls(array.nulls(), indices);
592 Ok(unsafe {
594 GenericByteViewArray::new_unchecked(new_views, array.data_buffers().to_vec(), new_nulls)
595 })
596}
597
598fn take_list<IndexType, OffsetType>(
604 values: &GenericListArray<OffsetType::Native>,
605 indices: &PrimitiveArray<IndexType>,
606) -> Result<GenericListArray<OffsetType::Native>, ArrowError>
607where
608 IndexType: ArrowPrimitiveType,
609 OffsetType: ArrowPrimitiveType,
610 OffsetType::Native: OffsetSizeTrait,
611 PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
612{
613 let (list_indices, offsets, null_buf) =
616 take_value_indices_from_list::<IndexType, OffsetType>(values, indices)?;
617
618 let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices)?;
619 let value_offsets = Buffer::from_vec(offsets);
620 let list_data = ArrayDataBuilder::new(values.data_type().clone())
622 .len(indices.len())
623 .null_bit_buffer(Some(null_buf.into()))
624 .offset(0)
625 .add_child_data(taken.into_data())
626 .add_buffer(value_offsets);
627
628 let list_data = unsafe { list_data.build_unchecked() };
629
630 Ok(GenericListArray::<OffsetType::Native>::from(list_data))
631}
632
633fn take_list_view<IndexType, OffsetType>(
634 values: &GenericListViewArray<OffsetType::Native>,
635 indices: &PrimitiveArray<IndexType>,
636) -> Result<GenericListViewArray<OffsetType::Native>, ArrowError>
637where
638 IndexType: ArrowPrimitiveType,
639 OffsetType: ArrowPrimitiveType,
640 OffsetType::Native: OffsetSizeTrait,
641{
642 let taken_offsets = take_native(values.offsets(), indices);
643 let taken_sizes = take_native(values.sizes(), indices);
644 let nulls = take_nulls(values.nulls(), indices);
645
646 let list_view_data = ArrayDataBuilder::new(values.data_type().clone())
647 .len(indices.len())
648 .nulls(nulls)
649 .buffers(vec![taken_offsets.into(), taken_sizes.into()])
650 .child_data(vec![values.values().to_data()]);
651
652 let list_view_data = unsafe { list_view_data.build_unchecked() };
654
655 Ok(GenericListViewArray::<OffsetType::Native>::from(
656 list_view_data,
657 ))
658}
659
660fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
666 values: &FixedSizeListArray,
667 indices: &PrimitiveArray<IndexType>,
668 length: <UInt32Type as ArrowPrimitiveType>::Native,
669) -> Result<FixedSizeListArray, ArrowError> {
670 let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?;
671 let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)?;
672
673 let num_bytes = bit_util::ceil(indices.len(), 8);
675 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
676 let null_slice = null_buf.as_slice_mut();
677
678 for i in 0..indices.len() {
679 let index = indices
680 .value(i)
681 .to_usize()
682 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
683 if !indices.is_valid(i) || values.is_null(index) {
684 bit_util::unset_bit(null_slice, i);
685 }
686 }
687
688 let list_data = ArrayDataBuilder::new(values.data_type().clone())
689 .len(indices.len())
690 .null_bit_buffer(Some(null_buf.into()))
691 .offset(0)
692 .add_child_data(taken.into_data());
693
694 let list_data = unsafe { list_data.build_unchecked() };
695
696 Ok(FixedSizeListArray::from(list_data))
697}
698
699fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>(
700 values: &FixedSizeBinaryArray,
701 indices: &PrimitiveArray<IndexType>,
702 size: i32,
703) -> Result<FixedSizeBinaryArray, ArrowError> {
704 let nulls = values.nulls();
705 let array_iter = indices
706 .values()
707 .iter()
708 .map(|idx| {
709 let idx = maybe_usize::<IndexType::Native>(*idx)?;
710 if nulls.map(|n| n.is_valid(idx)).unwrap_or(true) {
711 Ok(Some(values.value(idx)))
712 } else {
713 Ok(None)
714 }
715 })
716 .collect::<Result<Vec<_>, ArrowError>>()?
717 .into_iter();
718
719 FixedSizeBinaryArray::try_from_sparse_iter_with_size(array_iter, size)
720}
721
722fn take_dict<T: ArrowDictionaryKeyType, I: ArrowPrimitiveType>(
727 values: &DictionaryArray<T>,
728 indices: &PrimitiveArray<I>,
729) -> Result<DictionaryArray<T>, ArrowError> {
730 let new_keys = take_primitive(values.keys(), indices)?;
731 Ok(unsafe { DictionaryArray::new_unchecked(new_keys, values.values().clone()) })
732}
733
/// `take` for run-end encoded arrays.
///
/// The logical indices are translated to physical indices into the run values, the
/// resulting sequence is run-length encoded again, and the distinct runs are taken from the
/// values child.
///
/// NOTE(review): when `logical_indices` is empty, `physical_indices.len() - 1` below
/// appears to underflow/panic (assuming `get_physical_indices` returns one entry per
/// input index) -- confirm whether callers can pass an empty index array.
fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
    run_array: &RunArray<T>,
    logical_indices: &PrimitiveArray<I>,
) -> Result<RunArray<T>, ArrowError> {
    // Physical (run) index for each logical index. Only the raw index values are used.
    let physical_indices = run_array.get_physical_indices(logical_indices.values())?;

    // Run-encode the physical indices: `new_run_ends_builder` receives the end position of
    // each new run and `take_value_indices` the physical index that run refers to.
    // The `unwrap`s on `from_usize` assume the converted values fit the native types.
    let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
    let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
    let mut new_physical_len = 1;
    for ix in 1..physical_indices.len() {
        // A change of physical index closes the previous run at position `ix`.
        if physical_indices[ix] != physical_indices[ix - 1] {
            take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
            new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
            new_physical_len += 1;
        }
    }
    // Close the final run, which ends at the total logical length.
    take_value_indices
        .append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap());
    new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
    // NOTE(review): built without validation; the buffer holds `new_physical_len` run ends
    // appended above, with no nulls.
    let new_run_ends = unsafe {
        ArrayDataBuilder::new(T::DATA_TYPE)
            .len(new_physical_len)
            .null_count(0)
            .add_buffer(new_run_ends_builder.finish())
            .build_unchecked()
    };

    // NOTE(review): built without validation; holds one physical index per new run.
    let take_value_indices: PrimitiveArray<I> = unsafe {
        ArrayDataBuilder::new(I::DATA_TYPE)
            .len(new_physical_len)
            .null_count(0)
            .add_buffer(take_value_indices.finish())
            .build_unchecked()
            .into()
    };

    // Take one value per new run from the original run values.
    let new_values = take(run_array.values(), &take_value_indices, None)?;

    let builder = ArrayDataBuilder::new(run_array.data_type().clone())
        .len(physical_indices.len())
        .add_child_data(new_run_ends)
        .add_child_data(new_values.into_data());
    // NOTE(review): built without validation; the logical length equals the last run end
    // appended above.
    let array_data = unsafe {
        builder.build_unchecked()
    };
    Ok(array_data.into())
}
799
800#[allow(clippy::type_complexity)]
806fn take_value_indices_from_list<IndexType, OffsetType>(
807 list: &GenericListArray<OffsetType::Native>,
808 indices: &PrimitiveArray<IndexType>,
809) -> Result<
810 (
811 PrimitiveArray<OffsetType>,
812 Vec<OffsetType::Native>,
813 MutableBuffer,
814 ),
815 ArrowError,
816>
817where
818 IndexType: ArrowPrimitiveType,
819 OffsetType: ArrowPrimitiveType,
820 OffsetType::Native: OffsetSizeTrait + std::ops::Add + Zero + One,
821 PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
822{
823 let offsets: &[OffsetType::Native] = list.value_offsets();
825
826 let mut new_offsets = Vec::with_capacity(indices.len());
827 let mut values = Vec::new();
828 let mut current_offset = OffsetType::Native::zero();
829 new_offsets.push(OffsetType::Native::zero());
831
832 let num_bytes = bit_util::ceil(indices.len(), 8);
834 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
835 let null_slice = null_buf.as_slice_mut();
836
837 for i in 0..indices.len() {
839 if indices.is_valid(i) {
840 let ix = indices
841 .value(i)
842 .to_usize()
843 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
844 let start = offsets[ix];
845 let end = offsets[ix + 1];
846 current_offset += end - start;
847 new_offsets.push(current_offset);
848
849 let mut curr = start;
850
851 while curr < end {
853 values.push(curr);
854 curr += One::one();
855 }
856 if !list.is_valid(ix) {
857 bit_util::unset_bit(null_slice, i);
858 }
859 } else {
860 bit_util::unset_bit(null_slice, i);
861 new_offsets.push(current_offset);
862 }
863 }
864
865 Ok((
866 PrimitiveArray::<OffsetType>::from(values),
867 new_offsets,
868 null_buf,
869 ))
870}
871
872fn take_value_indices_from_fixed_size_list<IndexType>(
874 list: &FixedSizeListArray,
875 indices: &PrimitiveArray<IndexType>,
876 length: <UInt32Type as ArrowPrimitiveType>::Native,
877) -> Result<PrimitiveArray<UInt32Type>, ArrowError>
878where
879 IndexType: ArrowPrimitiveType,
880{
881 let mut values = UInt32Builder::with_capacity(length as usize * indices.len());
882
883 for i in 0..indices.len() {
884 if indices.is_valid(i) {
885 let index = indices
886 .value(i)
887 .to_usize()
888 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
889 let start = list.value_offset(index) as <UInt32Type as ArrowPrimitiveType>::Native;
890
891 unsafe {
893 values.append_trusted_len_iter(start..start + length);
894 }
895 } else {
896 values.append_nulls(length as usize);
897 }
898 }
899
900 Ok(values.finish())
901}
902
/// Conversion of an integer array into the index array type consumed by the take kernels.
///
/// Implemented below, through macros, for the primitive integer array types.
trait ToIndices {
    /// Primitive type of the index array produced by [`ToIndices::to_indices`].
    type T: ArrowPrimitiveType;

    /// Returns the indices as a `PrimitiveArray<Self::T>`.
    fn to_indices(&self) -> PrimitiveArray<Self::T>;
}
910
/// Implements [`ToIndices`] for `PrimitiveArray<$t>` by viewing the existing values buffer
/// as `$o` instead of converting element by element.
///
/// The underlying buffer handle and the validity buffer are cloned; the values themselves
/// are not transformed.
macro_rules! to_indices_reinterpret {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                // Wrap the same buffer in a `ScalarBuffer` typed as `$o`.
                let cast = ScalarBuffer::new(self.values().inner().clone(), 0, self.len());
                PrimitiveArray::new(cast, self.nulls().cloned())
            }
        }
    };
}
923
/// Implements [`ToIndices`] for `PrimitiveArray<$t>` where the array already has the
/// index type: the conversion is a plain clone of the array.
macro_rules! to_indices_identity {
    ($t:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $t;

            fn to_indices(&self) -> PrimitiveArray<$t> {
                self.clone()
            }
        }
    };
}
935
/// Implements [`ToIndices`] for `PrimitiveArray<$t>` by casting every value to the native
/// type of `$o` with `as`, producing a freshly allocated values buffer.
///
/// The associated type follows the `$o` argument, so the macro stays consistent with its
/// `to_indices` return type for any target type, not only `UInt32Type`.
macro_rules! to_indices_widening {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                // Element-wise `as` cast; the validity buffer is carried over unchanged.
                let cast = self.values().iter().copied().map(|x| x as _).collect();
                PrimitiveArray::new(cast, self.nulls().cloned())
            }
        }
    };
}
948
// 8-bit and 16-bit index arrays are cast element by element into `UInt32Type` indices.
to_indices_widening!(UInt8Type, UInt32Type);
to_indices_widening!(Int8Type, UInt32Type);

to_indices_widening!(UInt16Type, UInt32Type);
to_indices_widening!(Int16Type, UInt32Type);

// 32-bit index arrays: unsigned is used as-is, signed reuses its buffer viewed as `UInt32Type`.
to_indices_identity!(UInt32Type);
to_indices_reinterpret!(Int32Type, UInt32Type);

// 64-bit index arrays: unsigned is used as-is, signed reuses its buffer viewed as `UInt64Type`.
to_indices_identity!(UInt64Type);
to_indices_reinterpret!(Int64Type, UInt64Type);
960
961pub fn take_record_batch(
1001 record_batch: &RecordBatch,
1002 indices: &dyn Array,
1003) -> Result<RecordBatch, ArrowError> {
1004 let columns = record_batch
1005 .columns()
1006 .iter()
1007 .map(|c| take(c, indices, None))
1008 .collect::<Result<Vec<_>, _>>()?;
1009 RecordBatch::try_new(record_batch.schema(), columns)
1010}
1011
1012#[cfg(test)]
1013mod tests {
1014 use super::*;
1015 use arrow_array::builder::*;
1016 use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
1017 use arrow_data::ArrayData;
1018 use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
1019 use num_traits::ToPrimitive;
1020
1021 fn test_take_decimal_arrays(
1022 data: Vec<Option<i128>>,
1023 index: &UInt32Array,
1024 options: Option<TakeOptions>,
1025 expected_data: Vec<Option<i128>>,
1026 precision: &u8,
1027 scale: &i8,
1028 ) -> Result<(), ArrowError> {
1029 let output = data
1030 .into_iter()
1031 .collect::<Decimal128Array>()
1032 .with_precision_and_scale(*precision, *scale)
1033 .unwrap();
1034
1035 let expected = expected_data
1036 .into_iter()
1037 .collect::<Decimal128Array>()
1038 .with_precision_and_scale(*precision, *scale)
1039 .unwrap();
1040
1041 let expected = Arc::new(expected) as ArrayRef;
1042 let output = take(&output, index, options).unwrap();
1043 assert_eq!(&output, &expected);
1044 Ok(())
1045 }
1046
1047 fn test_take_boolean_arrays(
1048 data: Vec<Option<bool>>,
1049 index: &UInt32Array,
1050 options: Option<TakeOptions>,
1051 expected_data: Vec<Option<bool>>,
1052 ) {
1053 let output = BooleanArray::from(data);
1054 let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef;
1055 let output = take(&output, index, options).unwrap();
1056 assert_eq!(&output, &expected)
1057 }
1058
1059 fn test_take_primitive_arrays<T>(
1060 data: Vec<Option<T::Native>>,
1061 index: &UInt32Array,
1062 options: Option<TakeOptions>,
1063 expected_data: Vec<Option<T::Native>>,
1064 ) -> Result<(), ArrowError>
1065 where
1066 T: ArrowPrimitiveType,
1067 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1068 {
1069 let output = PrimitiveArray::<T>::from(data);
1070 let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1071 let output = take(&output, index, options)?;
1072 assert_eq!(&output, &expected);
1073 Ok(())
1074 }
1075
1076 fn test_take_primitive_arrays_non_null<T>(
1077 data: Vec<T::Native>,
1078 index: &UInt32Array,
1079 options: Option<TakeOptions>,
1080 expected_data: Vec<Option<T::Native>>,
1081 ) -> Result<(), ArrowError>
1082 where
1083 T: ArrowPrimitiveType,
1084 PrimitiveArray<T>: From<Vec<T::Native>>,
1085 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1086 {
1087 let output = PrimitiveArray::<T>::from(data);
1088 let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1089 let output = take(&output, index, options)?;
1090 assert_eq!(&output, &expected);
1091 Ok(())
1092 }
1093
1094 fn test_take_impl_primitive_arrays<T, I>(
1095 data: Vec<Option<T::Native>>,
1096 index: &PrimitiveArray<I>,
1097 options: Option<TakeOptions>,
1098 expected_data: Vec<Option<T::Native>>,
1099 ) where
1100 T: ArrowPrimitiveType,
1101 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1102 I: ArrowPrimitiveType,
1103 {
1104 let output = PrimitiveArray::<T>::from(data);
1105 let expected = PrimitiveArray::<T>::from(expected_data);
1106 let output = take(&output, index, options).unwrap();
1107 let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
1108 assert_eq!(output, &expected)
1109 }
1110
1111 fn create_test_struct(values: Vec<Option<(Option<bool>, Option<i32>)>>) -> StructArray {
1113 let mut struct_builder = StructBuilder::new(
1114 Fields::from(vec![
1115 Field::new("a", DataType::Boolean, true),
1116 Field::new("b", DataType::Int32, true),
1117 ]),
1118 vec![
1119 Box::new(BooleanBuilder::with_capacity(values.len())),
1120 Box::new(Int32Builder::with_capacity(values.len())),
1121 ],
1122 );
1123
1124 for value in values {
1125 struct_builder
1126 .field_builder::<BooleanBuilder>(0)
1127 .unwrap()
1128 .append_option(value.and_then(|v| v.0));
1129 struct_builder
1130 .field_builder::<Int32Builder>(1)
1131 .unwrap()
1132 .append_option(value.and_then(|v| v.1));
1133 struct_builder.append(value.is_some());
1134 }
1135 struct_builder.finish()
1136 }
1137
1138 #[test]
1139 fn test_take_decimal128_non_null_indices() {
1140 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1141 let precision: u8 = 10;
1142 let scale: i8 = 5;
1143 test_take_decimal_arrays(
1144 vec![None, Some(3), Some(5), Some(2), Some(3), None],
1145 &index,
1146 None,
1147 vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1148 &precision,
1149 &scale,
1150 )
1151 .unwrap();
1152 }
1153
1154 #[test]
1155 fn test_take_decimal128() {
1156 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1157 let precision: u8 = 10;
1158 let scale: i8 = 5;
1159 test_take_decimal_arrays(
1160 vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1161 &index,
1162 None,
1163 vec![Some(3), None, Some(1), Some(3), Some(2)],
1164 &precision,
1165 &scale,
1166 )
1167 .unwrap();
1168 }
1169
1170 #[test]
1171 fn test_take_primitive_non_null_indices() {
1172 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1173 test_take_primitive_arrays::<Int8Type>(
1174 vec![None, Some(3), Some(5), Some(2), Some(3), None],
1175 &index,
1176 None,
1177 vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1178 )
1179 .unwrap();
1180 }
1181
1182 #[test]
1183 fn test_take_primitive_non_null_values() {
1184 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1185 test_take_primitive_arrays::<Int8Type>(
1186 vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1187 &index,
1188 None,
1189 vec![Some(3), None, Some(1), Some(3), Some(2)],
1190 )
1191 .unwrap();
1192 }
1193
1194 #[test]
1195 fn test_take_primitive_non_null() {
1196 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1197 test_take_primitive_arrays::<Int8Type>(
1198 vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)],
1199 &index,
1200 None,
1201 vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)],
1202 )
1203 .unwrap();
1204 }
1205
1206 #[test]
1207 fn test_take_primitive_nullable_indices_non_null_values_with_offset() {
1208 let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1209 let index = index.slice(2, 4);
1210 let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1211
1212 assert_eq!(
1213 index,
1214 &UInt32Array::from(vec![Some(2), Some(3), None, None])
1215 );
1216
1217 test_take_primitive_arrays_non_null::<Int64Type>(
1218 vec![0, 10, 20, 30, 40, 50],
1219 index,
1220 None,
1221 vec![Some(20), Some(30), None, None],
1222 )
1223 .unwrap();
1224 }
1225
1226 #[test]
1227 fn test_take_primitive_nullable_indices_nullable_values_with_offset() {
1228 let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1229 let index = index.slice(2, 4);
1230 let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1231
1232 assert_eq!(
1233 index,
1234 &UInt32Array::from(vec![Some(2), Some(3), None, None])
1235 );
1236
1237 test_take_primitive_arrays::<Int64Type>(
1238 vec![None, None, Some(20), Some(30), Some(40), Some(50)],
1239 index,
1240 None,
1241 vec![Some(20), Some(30), None, None],
1242 )
1243 .unwrap();
1244 }
1245
/// Runs the same take pattern over nullable arrays of many primitive types:
/// integers, intervals, durations and floats.
#[test]
fn test_take_primitive() {
    // Shared selection: positions 3, <null>, 1, 3, 2. Position 1 is null in
    // every input below, so the third output slot is always null too.
    let indices = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    // Runs one take check per listed type, in the order given, with the same
    // input and expected literals.
    macro_rules! check_take {
        ($($ty:ty),+ => $input:expr, $want:expr) => {
            $(
                test_take_primitive_arrays::<$ty>($input, &indices, None, $want).unwrap();
            )+
        };
    }

    // Signed and unsigned integers with non-negative values.
    check_take!(
        Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type =>
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );

    // Signed 64-bit integers and year-month intervals with a negative value.
    check_take!(
        Int64Type, IntervalYearMonthType =>
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );

    // Day-time intervals.
    {
        let zero = IntervalDayTime::new(0, 0);
        let two = IntervalDayTime::new(2, 0);
        let minus_fifteen = IntervalDayTime::new(-15, 0);
        check_take!(
            IntervalDayTimeType =>
            vec![Some(zero), None, Some(two), Some(minus_fifteen), None],
            vec![Some(minus_fifteen), None, None, Some(minus_fifteen), Some(two)]
        );
    }

    // Month-day-nanosecond intervals.
    {
        let zero = IntervalMonthDayNano::new(0, 0, 0);
        let two = IntervalMonthDayNano::new(2, 0, 0);
        let minus_fifteen = IntervalMonthDayNano::new(-15, 0, 0);
        check_take!(
            IntervalMonthDayNanoType =>
            vec![Some(zero), None, Some(two), Some(minus_fifteen), None],
            vec![Some(minus_fifteen), None, None, Some(minus_fifteen), Some(two)]
        );
    }

    // Durations at every resolution.
    check_take!(
        DurationSecondType,
        DurationMillisecondType,
        DurationMicrosecondType,
        DurationNanosecondType =>
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );

    // Floating point.
    check_take!(
        Float32Type, Float64Type =>
        vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
        vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)]
    );
}
1409
/// The timezone attached to a timestamp array must still be present on the
/// data type of the array returned by `take`.
#[test]
fn test_take_preserve_timezone() {
    let selection = Int64Array::from(vec![Some(0), None]);

    let timestamps = TimestampNanosecondArray::from(vec![
        1_639_715_368_000_000_000,
        1_639_715_368_000_000_000,
    ])
    .with_timezone("UTC".to_string());

    let taken = take(&timestamps, &selection, None).unwrap();

    // Any data type other than a nanosecond timestamp is a failure.
    let DataType::Timestamp(TimeUnit::Nanosecond, tz) = taken.data_type() else {
        panic!()
    };
    assert_eq!(tz.clone(), Some("UTC".into()))
}
1427
/// Exercises the primitive take implementation with signed 64-bit indices
/// across several value types.
#[test]
fn test_take_impl_primitive_with_int64_indices() {
    // Positions 3, <null>, 1, 3, 2; position 1 of every input is null.
    let indices = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    // int16 values
    let source = vec![Some(0), None, Some(2), Some(3), None];
    let wanted = vec![Some(3), None, None, Some(3), Some(2)];
    test_take_impl_primitive_arrays::<Int16Type, Int64Type>(source, &indices, None, wanted);

    // int64 values
    let source = vec![Some(0), None, Some(2), Some(-15), None];
    let wanted = vec![Some(-15), None, None, Some(-15), Some(2)];
    test_take_impl_primitive_arrays::<Int64Type, Int64Type>(source, &indices, None, wanted);

    // uint64 values
    let source = vec![Some(0), None, Some(2), Some(3), None];
    let wanted = vec![Some(3), None, None, Some(3), Some(2)];
    test_take_impl_primitive_arrays::<UInt64Type, Int64Type>(source, &indices, None, wanted);

    // millisecond durations
    let source = vec![Some(0), None, Some(2), Some(-15), None];
    let wanted = vec![Some(-15), None, None, Some(-15), Some(2)];
    test_take_impl_primitive_arrays::<DurationMillisecondType, Int64Type>(
        source, &indices, None, wanted,
    );

    // float32 values
    let source = vec![Some(0.0), None, Some(2.21), Some(-3.1), None];
    let wanted = vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)];
    test_take_impl_primitive_arrays::<Float32Type, Int64Type>(source, &indices, None, wanted);
}
1472
/// Exercises the primitive take implementation with unsigned 8-bit indices
/// across several value types.
#[test]
fn test_take_impl_primitive_with_uint8_indices() {
    // Positions 3, <null>, 1, 3, 2; position 1 of every input is null.
    let indices = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    // int16 values
    let source = vec![Some(0), None, Some(2), Some(3), None];
    let wanted = vec![Some(3), None, None, Some(3), Some(2)];
    test_take_impl_primitive_arrays::<Int16Type, UInt8Type>(source, &indices, None, wanted);

    // millisecond durations
    let source = vec![Some(0), None, Some(2), Some(-15), None];
    let wanted = vec![Some(-15), None, None, Some(-15), Some(2)];
    test_take_impl_primitive_arrays::<DurationMillisecondType, UInt8Type>(
        source, &indices, None, wanted,
    );

    // float32 values
    let source = vec![Some(0.0), None, Some(2.21), Some(-3.1), None];
    let wanted = vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)];
    test_take_impl_primitive_arrays::<Float32Type, UInt8Type>(source, &indices, None, wanted);
}
1501
/// Takes from a nullable boolean array with a nullable index array.
#[test]
fn test_take_bool() {
    let indices = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    let source = vec![Some(false), None, Some(true), Some(false), None];
    // Slot 1 is a null index; slot 2 selects position 1, which holds a null value.
    let wanted = vec![Some(false), None, None, Some(false), Some(true)];

    test_take_boolean_arrays(source, &indices, None, wanted);
}
1513
/// Null index slots may hold arbitrary (here out-of-range) raw values; they
/// must come out as nulls when taking from a nullable boolean array.
#[test]
fn test_take_bool_nullable_index() {
    // Only the odd slots are valid; the even slots hold 99, 999 and 9999.
    let validity = Buffer::from_iter(vec![false, true, false, true, false, true]);
    let raw_indices = Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2]);
    let index_data = ArrayData::try_new(
        DataType::UInt32,
        6,
        Some(validity),
        0,
        vec![raw_indices],
        vec![],
    )
    .unwrap();
    let indices = UInt32Array::from(index_data);

    let source = vec![Some(true), None, Some(false)];
    // Valid indices 0, 1, 2 map to true, null, false respectively.
    let wanted = vec![None, Some(true), None, None, None, Some(false)];
    test_take_boolean_arrays(source, &indices, None, wanted);
}
1536
/// Same index layout as the nullable-index boolean test, but the boolean
/// values contain no nulls, so only null index slots produce null outputs.
#[test]
fn test_take_bool_nullable_index_nonnull_values() {
    // Only the odd slots are valid; the even slots hold 99, 999 and 9999.
    let validity = Buffer::from_iter(vec![false, true, false, true, false, true]);
    let raw_indices = Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2]);
    let index_data = ArrayData::try_new(
        DataType::UInt32,
        6,
        Some(validity),
        0,
        vec![raw_indices],
        vec![],
    )
    .unwrap();
    let indices = UInt32Array::from(index_data);

    let source = vec![Some(true), Some(true), Some(false)];
    let wanted = vec![None, Some(true), None, Some(true), None, Some(false)];
    test_take_boolean_arrays(source, &indices, None, wanted);
}
1559
/// Takes from a boolean array using an index array that was sliced and so
/// carries a non-zero offset.
#[test]
fn test_take_bool_with_offset() {
    let unsliced = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]);
    // After slicing the effective indices are 1, 3, 2, <null>.
    let sliced = unsliced.slice(2, 4);
    let indices = sliced.as_any().downcast_ref::<UInt32Array>().unwrap();

    let source = vec![Some(false), None, Some(true), Some(false), None];
    let wanted = vec![None, Some(false), Some(true), None];

    test_take_boolean_arrays(source, indices, None, wanted);
}
1577
1578 fn _test_take_string<'a, K>()
1579 where
1580 K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static,
1581 {
1582 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]);
1583
1584 let array = K::from(vec![
1585 Some("one"),
1586 None,
1587 Some("three"),
1588 Some("four"),
1589 Some("five"),
1590 ]);
1591 let actual = take(&array, &index, None).unwrap();
1592 assert_eq!(actual.len(), index.len());
1593
1594 let actual = actual.as_any().downcast_ref::<K>().unwrap();
1595
1596 let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]);
1597
1598 assert_eq!(actual, &expected);
1599 }
1600
/// Runs the shared string take check against `StringArray`.
#[test]
fn test_take_string() {
    _test_take_string::<StringArray>();
}
1605
/// Runs the shared string take check against `LargeStringArray`.
#[test]
fn test_take_large_string() {
    _test_take_string::<LargeStringArray>();
}
1610
/// Takes from a string array using a sliced (offset) index array.
#[test]
fn test_take_slice_string() {
    let source = StringArray::from(vec![Some("hello"), None, Some("world"), None, Some("hi")]);

    let all_indices = Int32Array::from(vec![Some(0), Some(1), None, Some(0), Some(2)]);
    // Effective indices after slicing: 1, <null>, 0, 2.
    let selection = all_indices.slice(1, 4);

    let taken = take(&source, &selection, None).unwrap();

    let wanted = StringArray::from(vec![None, None, Some("hello"), Some("world")]);
    assert_eq!(taken.as_ref(), &wanted);
}
1620
1621 fn _test_byte_view<T>()
1622 where
1623 T: ByteViewType,
1624 str: AsRef<T::Native>,
1625 T::Native: PartialEq,
1626 {
1627 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]);
1628 let array = {
1629 let mut builder = GenericByteViewBuilder::<T>::new();
1631 builder.append_value("hello");
1632 builder.append_value("world");
1633 builder.append_null();
1634 builder.append_value("large payload over 12 bytes");
1635 builder.append_value("lulu");
1636 builder.finish()
1637 };
1638
1639 let actual = take(&array, &index, None).unwrap();
1640
1641 assert_eq!(actual.len(), index.len());
1642
1643 let expected = {
1644 let mut builder = GenericByteViewBuilder::<T>::new();
1646 builder.append_value("large payload over 12 bytes");
1647 builder.append_null();
1648 builder.append_value("world");
1649 builder.append_value("large payload over 12 bytes");
1650 builder.append_value("lulu");
1651 builder.append_null();
1652 builder.finish()
1653 };
1654
1655 assert_eq!(actual.as_ref(), &expected);
1656 }
1657
/// Runs the shared byte-view take check for string views.
#[test]
fn test_take_string_view() {
    _test_byte_view::<StringViewType>();
}
1662
/// Runs the shared byte-view take check for binary views.
#[test]
fn test_take_binary_view() {
    _test_byte_view::<BinaryViewType>();
}
1667
// Test body shared by the `List` and `LargeList` take tests.
//
// Arguments: the offset integer type, the `DataType` variant name, and the
// concrete list array type. Builds a four-element list of non-nullable
// `Int32` items, takes from it with a nullable index array, and compares the
// result against a hand-built expected array.
macro_rules! test_take_list {
    ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
        // Child values for the lists [[0, 0, 0], [-1, -2, -1], [], [2, 3]]
        let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
        // Offsets delimiting the four lists; the third list is empty
        let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
        let value_offsets = Buffer::from_slice_ref(&value_offsets);
        // List data type whose item field is declared non-nullable
        let list_data_type =
            DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, false)));
        let list_data = ArrayData::builder(list_data_type.clone())
            .len(4)
            .add_buffer(value_offsets)
            .add_child_data(value_data)
            .build()
            .unwrap();
        let list_array = $list_array_type::from(list_data);

        // Select lists 3, <null>, 1, 2, 0
        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]);

        let a = take(&list_array, &index, None).unwrap();
        let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();

        // Expected child values for [[2, 3], null, [-1, -2, -1], [], [0, 0, 0]]
        let expected_data = Int32Array::from(vec![
            Some(2),
            Some(3),
            Some(-1),
            Some(-2),
            Some(-1),
            Some(0),
            Some(0),
            Some(0),
        ])
        .into_data();
        // The null slot and the empty list both span zero child values
        let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8];
        let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
        let expected_list_data = ArrayData::builder(list_data_type)
            .len(5)
            // The expected validity is taken directly from the index array's nulls
            .nulls(index.nulls().cloned())
            .add_buffer(expected_offsets)
            .add_child_data(expected_data)
            .build()
            .unwrap();
        let expected_list_array = $list_array_type::from(expected_list_data);

        assert_eq!(a, &expected_list_array);
    }};
}
1722
// Test body shared by the `List` and `LargeList` take tests where the child
// values contain nulls but every list slot itself is valid.
//
// Arguments: the offset integer type, the `DataType` variant name, and the
// concrete list array type.
macro_rules! test_take_list_with_value_nulls {
    ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
        // Child values for [[0, null, 0], [-1, -2, 3], [null], [5, null]]
        let value_data = Int32Array::from(vec![
            Some(0),
            None,
            Some(0),
            Some(-1),
            Some(-2),
            Some(3),
            None,
            Some(5),
            None,
        ])
        .into_data();
        // Offsets delimiting the four lists
        let value_offsets: [$offset_type; 5] = [0, 3, 6, 7, 9];
        let value_offsets = Buffer::from_slice_ref(&value_offsets);
        // List data type whose item field is declared nullable
        let list_data_type =
            DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
        let list_data = ArrayData::builder(list_data_type.clone())
            .len(4)
            .add_buffer(value_offsets)
            // All validity bits set: no list slot is null
            .null_bit_buffer(Some(Buffer::from([0b11111111])))
            .add_child_data(value_data)
            .build()
            .unwrap();
        let list_array = $list_array_type::from(list_data);

        // Select lists 2, <null>, 1, 3, 0
        let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);

        let a = take(&list_array, &index, None).unwrap();
        let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();

        // Expected child values for
        // [[null], null, [-1, -2, 3], [5, null], [0, null, 0]]
        let expected_data = Int32Array::from(vec![
            None,
            Some(-1),
            Some(-2),
            Some(3),
            Some(5),
            None,
            Some(0),
            None,
            Some(0),
        ])
        .into_data();
        // The null slot (second entry) spans zero child values
        let expected_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9];
        let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
        let expected_list_data = ArrayData::builder(list_data_type)
            .len(5)
            // The expected validity is taken directly from the index array's nulls
            .nulls(index.nulls().cloned())
            .add_buffer(expected_offsets)
            .add_child_data(expected_data)
            .build()
            .unwrap();
        let expected_list_array = $list_array_type::from(expected_list_data);

        assert_eq!(a, &expected_list_array);
    }};
}
1790
// Test body shared by the `List` and `LargeList` take tests where both the
// child values and one of the list slots are null.
//
// Arguments: the offset integer type, the `DataType` variant name, and the
// concrete list array type.
macro_rules! test_take_list_with_nulls {
    ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
        // Child values for [[0, null, 0], [-1, -2, 3], null, [5, null]]
        let value_data = Int32Array::from(vec![
            Some(0),
            None,
            Some(0),
            Some(-1),
            Some(-2),
            Some(3),
            Some(5),
            None,
        ])
        .into_data();
        // Offsets delimiting the four lists; the third slot spans no values
        let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
        let value_offsets = Buffer::from_slice_ref(&value_offsets);
        // List data type whose item field is declared nullable
        let list_data_type =
            DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
        let list_data = ArrayData::builder(list_data_type.clone())
            .len(4)
            .add_buffer(value_offsets)
            // Bit 2 is cleared: the third list slot is null
            .null_bit_buffer(Some(Buffer::from([0b11111011])))
            .add_child_data(value_data)
            .build()
            .unwrap();
        let list_array = $list_array_type::from(list_data);

        // Select lists 2, <null>, 1, 3, 0
        let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);

        let a = take(&list_array, &index, None).unwrap();
        let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();

        // Expected child values for
        // [null, null, [-1, -2, 3], [5, null], [0, null, 0]]
        let expected_data = Int32Array::from(vec![
            Some(-1),
            Some(-2),
            Some(3),
            Some(5),
            None,
            Some(0),
            None,
            Some(0),
        ])
        .into_data();
        // The first two (null) slots span zero child values
        let expected_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8];
        let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
        // Expected validity: slots 0 and 1 are null, slots 2, 3 and 4 are valid
        let mut null_bits: [u8; 1] = [0; 1];
        bit_util::set_bit(&mut null_bits, 2);
        bit_util::set_bit(&mut null_bits, 3);
        bit_util::set_bit(&mut null_bits, 4);
        let expected_list_data = ArrayData::builder(list_data_type)
            .len(5)
            .null_bit_buffer(Some(Buffer::from(null_bits)))
            .add_buffer(expected_offsets)
            .add_child_data(expected_data)
            .build()
            .unwrap();
        let expected_list_array = $list_array_type::from(expected_list_data);

        assert_eq!(a, &expected_list_array);
    }};
}
1860
1861 fn test_take_list_view_generic<OffsetType: OffsetSizeTrait, ValuesType: ArrowPrimitiveType, F>(
1862 values: Vec<Option<Vec<Option<ValuesType::Native>>>>,
1863 take_indices: Vec<Option<usize>>,
1864 expected: Vec<Option<Vec<Option<ValuesType::Native>>>>,
1865 mapper: F,
1866 ) where
1867 F: Fn(GenericListViewArray<OffsetType>) -> GenericListViewArray<OffsetType>,
1868 {
1869 let mut list_view_array =
1870 GenericListViewBuilder::<OffsetType, _>::new(PrimitiveBuilder::<ValuesType>::new());
1871
1872 for value in values {
1873 list_view_array.append_option(value);
1874 }
1875 let list_view_array = list_view_array.finish();
1876 let list_view_array = mapper(list_view_array);
1877
1878 let mut indices = UInt64Builder::new();
1879 for idx in take_indices {
1880 indices.append_option(idx.map(|i| i.to_u64().unwrap()));
1881 }
1882 let indices = indices.finish();
1883
1884 let taken = take(&list_view_array, &indices, None)
1885 .unwrap()
1886 .as_list_view()
1887 .clone();
1888
1889 let mut expected_array =
1890 GenericListViewBuilder::<OffsetType, _>::new(PrimitiveBuilder::<ValuesType>::new());
1891 for value in expected {
1892 expected_array.append_option(value);
1893 }
1894 let expected_array = expected_array.finish();
1895
1896 assert_eq!(taken, expected_array);
1897 }
1898
// Runs one list-view take scenario for both 32-bit and 64-bit offsets, always
// with `Int8Type` values. Each argument expression is expanded once per offset
// width.
macro_rules! list_view_test_case {
    // Plain form: the source array is used as built (identity mapper).
    (values: $values:expr, indices: $indices:expr, expected: $expected: expr) => {{
        test_take_list_view_generic::<i32, Int8Type, _>($values, $indices, $expected, |x| x);
        test_take_list_view_generic::<i64, Int8Type, _>($values, $indices, $expected, |x| x);
    }};
    // Transform form: `$fn` is applied to the source array before taking.
    (values: $values:expr, transform: $fn:expr, indices: $indices:expr, expected: $expected: expr) => {{
        test_take_list_view_generic::<i32, Int8Type, _>($values, $indices, $expected, $fn);
        test_take_list_view_generic::<i64, Int8Type, _>($values, $indices, $expected, $fn);
    }};
}
1909
1910 fn do_take_fixed_size_list_test<T>(
1911 length: <Int32Type as ArrowPrimitiveType>::Native,
1912 input_data: Vec<Option<Vec<Option<T::Native>>>>,
1913 indices: Vec<<UInt32Type as ArrowPrimitiveType>::Native>,
1914 expected_data: Vec<Option<Vec<Option<T::Native>>>>,
1915 ) where
1916 T: ArrowPrimitiveType,
1917 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1918 {
1919 let indices = UInt32Array::from(indices);
1920
1921 let input_array = FixedSizeListArray::from_iter_primitive::<T, _, _>(input_data, length);
1922
1923 let output = take_fixed_size_list(&input_array, &indices, length as u32).unwrap();
1924
1925 let expected = FixedSizeListArray::from_iter_primitive::<T, _, _>(expected_data, length);
1926
1927 assert_eq!(&output, &expected)
1928 }
1929
/// Basic list take scenario with 32-bit offsets.
#[test]
fn test_take_list() {
    test_take_list! { i32, List, ListArray }
}
1934
/// Basic list take scenario with 64-bit offsets.
#[test]
fn test_take_large_list() {
    test_take_list! { i64, LargeList, LargeListArray }
}
1939
/// List take scenario with null child values, 32-bit offsets.
#[test]
fn test_take_list_with_value_nulls() {
    test_take_list_with_value_nulls! { i32, List, ListArray }
}
1944
/// List take scenario with null child values, 64-bit offsets.
#[test]
fn test_take_large_list_with_value_nulls() {
    test_take_list_with_value_nulls! { i64, LargeList, LargeListArray }
}
1949
/// List take scenario with a null list slot, 32-bit offsets.
#[test]
fn test_test_take_list_with_nulls() {
    test_take_list_with_nulls! { i32, List, ListArray }
}
1954
/// List take scenario with a null list slot, 64-bit offsets.
#[test]
fn test_test_take_large_list_with_nulls() {
    test_take_list_with_nulls! { i64, LargeList, LargeListArray }
}
1959
/// Taking every row of a list-view array in reverse order reverses the rows.
#[test]
fn test_test_take_list_view_reversed() {
    // The test-case macro always uses `Int8Type` values, hence `i8` items.
    let source: Vec<Option<Vec<Option<i8>>>> = vec![
        Some(vec![Some(1), None, Some(3)]),
        None,
        Some(vec![Some(7), Some(8), None]),
    ];
    let selection: Vec<Option<usize>> = vec![Some(2), Some(1), Some(0)];
    let wanted: Vec<Option<Vec<Option<i8>>>> = vec![
        Some(vec![Some(7), Some(8), None]),
        None,
        Some(vec![Some(1), None, Some(3)]),
    ];

    // The macro expands each argument once per offset width, so hand it clones.
    list_view_test_case! {
        values: source.clone(),
        indices: selection.clone(),
        expected: wanted.clone()
    }
}
1977
/// Null indices produce null rows in the taken list-view array.
#[test]
fn test_take_list_view_null_indices() {
    // The test-case macro always uses `Int8Type` values, hence `i8` items.
    let source: Vec<Option<Vec<Option<i8>>>> = vec![
        Some(vec![Some(1), None, Some(3)]),
        None,
        Some(vec![Some(7), Some(8), None]),
    ];
    let selection: Vec<Option<usize>> = vec![None, Some(0), None];
    let wanted: Vec<Option<Vec<Option<i8>>>> =
        vec![None, Some(vec![Some(1), None, Some(3)]), None];

    // The macro expands each argument once per offset width, so hand it clones.
    list_view_test_case! {
        values: source.clone(),
        indices: selection.clone(),
        expected: wanted.clone()
    }
}
1991
/// Selecting only the null row (or a null index) yields an all-null result.
#[test]
fn test_take_list_view_null_values() {
    // The test-case macro always uses `Int8Type` values, hence `i8` items.
    let source: Vec<Option<Vec<Option<i8>>>> = vec![
        Some(vec![Some(1), None, Some(3)]),
        None,
        Some(vec![Some(7), Some(8), None]),
    ];
    // Row 1 is the null row; the remaining two indices are themselves null.
    let selection: Vec<Option<usize>> = vec![Some(1), Some(1), Some(1), None, None];
    let wanted: Vec<Option<Vec<Option<i8>>>> = vec![None; 5];

    // The macro expands each argument once per offset width, so hand it clones.
    list_view_test_case! {
        values: source.clone(),
        indices: selection.clone(),
        expected: wanted.clone()
    }
}
2005
/// Takes from a list-view array that was sliced before the take, so indices
/// are relative to the slice rather than to the original array.
#[test]
fn test_take_list_view_sliced() {
    // The test-case macro always uses `Int8Type` values, hence `i8` items.
    let source: Vec<Option<Vec<Option<i8>>>> = vec![
        Some(vec![Some(1)]),
        None,
        None,
        Some(vec![Some(2), Some(3)]),
        Some(vec![Some(4), Some(5)]),
        None,
    ];
    // After `slice(2, 4)` the rows are: null, [2, 3], [4, 5], null.
    let selection: Vec<Option<usize>> = vec![Some(0), Some(3), None, Some(1), Some(2)];
    let wanted: Vec<Option<Vec<Option<i8>>>> = vec![
        None,
        None,
        None,
        Some(vec![Some(2), Some(3)]),
        Some(vec![Some(4), Some(5)]),
    ];

    // The macro expands each argument once per offset width, so hand it clones.
    list_view_test_case! {
        values: source.clone(),
        transform: |l| l.slice(2, 4),
        indices: selection.clone(),
        expected: wanted.clone()
    }
}
2025
/// Takes from fixed-size list arrays of several widths and item types,
/// including one input that contains a null list slot.
#[test]
fn test_take_fixed_size_list() {
    // Width 3, Int32 items with null items: reverse the three rows.
    let source = vec![
        Some(vec![None, Some(1), Some(2)]),
        Some(vec![Some(3), Some(4), None]),
        Some(vec![Some(6), Some(7), Some(8)]),
    ];
    let wanted = vec![
        Some(vec![Some(6), Some(7), Some(8)]),
        Some(vec![Some(3), Some(4), None]),
        Some(vec![None, Some(1), Some(2)]),
    ];
    do_take_fixed_size_list_test::<Int32Type>(3, source, vec![2, 1, 0], wanted);

    // Width 1, UInt8 items: pick three of eight rows.
    let source = vec![
        Some(vec![Some(1)]),
        Some(vec![Some(2)]),
        Some(vec![Some(3)]),
        Some(vec![Some(4)]),
        Some(vec![Some(5)]),
        Some(vec![Some(6)]),
        Some(vec![Some(7)]),
        Some(vec![Some(8)]),
    ];
    let wanted = vec![
        Some(vec![Some(3)]),
        Some(vec![Some(8)]),
        Some(vec![Some(1)]),
    ];
    do_take_fixed_size_list_test::<UInt8Type>(1, source, vec![2, 7, 0], wanted);

    // Width 3, UInt64 items: row 2 is null and is selected twice.
    let source = vec![
        Some(vec![Some(10), Some(11), Some(12)]),
        Some(vec![Some(13), Some(14), Some(15)]),
        None,
        Some(vec![Some(16), Some(17), Some(18)]),
    ];
    let wanted = vec![
        Some(vec![Some(16), Some(17), Some(18)]),
        None,
        Some(vec![Some(13), Some(14), Some(15)]),
        None,
        Some(vec![Some(10), Some(11), Some(12)]),
    ];
    do_take_fixed_size_list_test::<UInt64Type>(3, source, vec![3, 2, 1, 2, 0], wanted);
}
2081
/// With no `TakeOptions` supplied, taking an out-of-range index from a list
/// array is expected to panic rather than return an error.
#[test]
#[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
fn test_take_list_out_of_bounds() {
    // Child values for the lists [[0, 0, 0], [-1, -2, -1], [2, 3]]
    let child_values = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
    let offsets = Buffer::from_slice_ref([0, 3, 6, 8]);

    let item_field = Arc::new(Field::new_list_field(DataType::Int32, false));
    let source_data = ArrayData::builder(DataType::List(item_field))
        .len(3)
        .add_buffer(offsets)
        .add_child_data(child_values)
        .build()
        .unwrap();
    let source = ListArray::from(source_data);

    let selection = UInt32Array::from(vec![1000]);

    // Expected to panic inside `take`; the `unwrap` is never reached on success.
    take(&source, &selection, None).unwrap();
}
2106
/// Taking the first entry of a two-entry map array yields a one-entry map
/// array holding only that entry's key/value pairs.
#[test]
fn test_take_map() {
    let map_values = Int32Array::from(vec![1, 2, 3, 4]);
    // Two map entries: {a: 1, b: 2, c: 3} and {a: 4}.
    let source =
        MapArray::new_from_strings(vec!["a", "b", "c", "a"].into_iter(), &map_values, &[0, 3, 4])
            .unwrap();

    let selection = UInt32Array::from(vec![0]);
    let taken = take(&source, &selection, None).unwrap();

    let first_entry = MapArray::new_from_strings(
        vec!["a", "b", "c"].into_iter(),
        &map_values.slice(0, 3),
        &[0, 3],
    )
    .unwrap();
    let wanted: ArrayRef = Arc::new(first_entry);

    assert_eq!(&wanted, &taken);
}
2127
/// Takes from a struct array with one null row, then from a field-less struct
/// array that only carries a validity buffer.
#[test]
fn test_take_struct() {
    let source = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(false), Some(28))),
        Some((Some(false), Some(19))),
        Some((Some(true), Some(31))),
        None,
    ]);

    let selection = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
    let taken = take(&source, &selection, None).unwrap();
    let taken: &StructArray = taken.as_any().downcast_ref::<StructArray>().unwrap();
    assert_eq!(selection.len(), taken.len());
    // Only the last index selects the null row.
    assert_eq!(1, taken.null_count());

    let wanted = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(true), Some(31))),
        Some((Some(false), Some(28))),
        Some((Some(true), Some(42))),
        Some((Some(false), Some(19))),
        None,
    ]);
    assert_eq!(&wanted, taken);

    // A struct with no fields still has a length and validity to take from.
    let validity = NullBuffer::from(&[false, true, false, true, false, true]);
    let fieldless = StructArray::new_empty_fields(6, Some(validity));
    let selection = UInt32Array::from(vec![0, 2, 1, 4]);
    let taken = take(&fieldless, &selection, None).unwrap();

    let wanted_validity = NullBuffer::from(&[false, false, true, false]);
    let wanted = StructArray::new_empty_fields(4, Some(wanted_validity));
    assert_eq!(&wanted, taken.as_struct());
}
2164
/// Takes from a struct array using indices that contain nulls; both null
/// indices and the null source row must come out as null rows.
#[test]
fn test_take_struct_with_null_indices() {
    let source = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(false), Some(28))),
        Some((Some(false), Some(19))),
        Some((Some(true), Some(31))),
        None,
    ]);

    let selection = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);
    let taken = take(&source, &selection, None).unwrap();
    let taken: &StructArray = taken.as_any().downcast_ref::<StructArray>().unwrap();
    assert_eq!(selection.len(), taken.len());
    // Two null indices plus one selection of the null row.
    assert_eq!(3, taken.null_count());

    let wanted = create_test_struct(vec![
        None,
        Some((Some(true), Some(31))),
        Some((Some(false), Some(28))),
        None,
        Some((Some(true), Some(42))),
        None,
    ]);

    assert_eq!(&wanted, taken);
}
2192
/// With bounds checking enabled, an out-of-range index is reported as an
/// error instead of panicking.
#[test]
fn test_take_out_of_bounds() {
    // Index 6 is past the end of the five-element input below.
    let selection = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]);
    let checked = TakeOptions { check_bounds: true };

    let source = vec![Some(0), None, Some(2), Some(3), None];
    let outcome =
        test_take_primitive_arrays::<Int64Type>(source, &selection, Some(checked), vec![None]);

    assert!(outcome.is_err());
}
2207
/// With no `TakeOptions` supplied, an out-of-range index on a primitive array
/// is expected to panic.
#[test]
#[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
fn test_take_out_of_bounds_panic() {
    let selection = UInt32Array::from(vec![Some(1000)]);

    let source = vec![Some(0), Some(1), Some(2), Some(3)];
    let outcome = test_take_primitive_arrays::<Int64Type>(source, &selection, None, vec![None]);
    outcome.unwrap();
}
2221
/// Taking from a `NullArray` shorter than the index array yields a
/// `NullArray` whose length matches the index array.
#[test]
fn test_null_array_smaller_than_indices() {
    let source = NullArray::new(2);
    let selection = UInt32Array::from(vec![Some(0), None, Some(15)]);

    let taken = take(&source, &selection, None).unwrap();

    let wanted: ArrayRef = Arc::new(NullArray::new(3));
    assert_eq!(&taken, &wanted);
}
2231
/// Taking from a `NullArray` longer than the index array yields a
/// `NullArray` whose length matches the index array.
#[test]
fn test_null_array_larger_than_indices() {
    let source = NullArray::new(5);
    let selection = UInt32Array::from(vec![Some(0), None, Some(15)]);

    let taken = take(&source, &selection, None).unwrap();

    let wanted: ArrayRef = Arc::new(NullArray::new(3));
    assert_eq!(&taken, &wanted);
}
2241
/// With bounds checking enabled, an out-of-range index into a `NullArray`
/// produces the standard out-of-bounds compute error.
#[test]
fn test_null_array_indices_out_of_bounds() {
    let source = NullArray::new(5);
    let selection = UInt32Array::from(vec![Some(0), None, Some(15)]);
    let checked = TakeOptions { check_bounds: true };

    let failure = take(&source, &selection, Some(checked)).unwrap_err();

    assert_eq!(
        failure.to_string(),
        "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries"
    );
}
2253
/// Taking from a dictionary array must leave the dictionary values untouched
/// and only rearrange the keys.
#[test]
fn test_take_dict() {
    let mut builder = StringDictionaryBuilder::<Int16Type>::new();

    // Logical contents: foo, bar, "", null, foo, bar, bar, foo
    let entries = [
        Some("foo"),
        Some("bar"),
        Some(""),
        None,
        Some("foo"),
        Some("bar"),
        Some("bar"),
        Some("foo"),
    ];
    for entry in entries {
        match entry {
            Some(text) => {
                builder.append(text).unwrap();
            }
            None => builder.append_null(),
        }
    }

    let source = builder.finish();
    let source_values = source.values().clone();
    let source_values = source_values
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();

    let selection = UInt32Array::from(vec![
        Some(0),
        Some(7),
        None,
        Some(5),
        Some(6),
        Some(2),
        Some(3),
    ]);

    let taken = take(&source, &selection, None).unwrap();
    let taken = taken
        .as_any()
        .downcast_ref::<DictionaryArray<Int16Type>>()
        .unwrap();

    let taken_values: StringArray = taken.values().to_data().into();

    // Both the input and the output dictionaries hold the same three values.
    let wanted_values = StringArray::from(vec!["foo", "bar", ""]);
    assert_eq!(&wanted_values, source_values);
    assert_eq!(&wanted_values, &taken_values);

    // Keys follow the selection; the null index and the null entry give null keys.
    let wanted_keys = Int16Array::from(vec![
        Some(0),
        Some(0),
        None,
        Some(1),
        Some(1),
        Some(2),
        None,
    ]);
    assert_eq!(taken.keys(), &wanted_keys);
}
2305
2306 fn build_generic_list<S, T>(data: Vec<Option<Vec<T::Native>>>) -> GenericListArray<S>
2307 where
2308 S: OffsetSizeTrait + 'static,
2309 T: ArrowPrimitiveType,
2310 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2311 {
2312 GenericListArray::from_iter_primitive::<T, _, _>(
2313 data.iter()
2314 .map(|x| x.as_ref().map(|x| x.iter().map(|x| Some(*x)))),
2315 )
2316 }
2317
/// Checks the child-value indices, new offsets and validity bitmap computed
/// when taking rows 2 and 0 from a list array with 32-bit offsets.
#[test]
fn test_take_value_index_from_list() {
    let source = build_generic_list::<i32, Int32Type>(vec![
        Some(vec![0, 1]),
        Some(vec![2, 3, 4]),
        Some(vec![5, 6, 7, 8, 9]),
    ]);
    let selection = UInt32Array::from(vec![2, 0]);

    let (child_indices, new_offsets, validity) =
        take_value_indices_from_list(&source, &selection).unwrap();

    // Row 2 contributes child positions 5..=9, then row 0 contributes 0..=1.
    assert_eq!(child_indices, Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
    assert_eq!(new_offsets, vec![0, 5, 7]);
    assert_eq!(validity.as_slice(), &[0b11111111]);
}
2333
/// Checks the child-value indices, new offsets and validity bitmap computed
/// when taking rows 2 and 0 from a list array with 64-bit offsets.
#[test]
fn test_take_value_index_from_large_list() {
    let source = build_generic_list::<i64, Int32Type>(vec![
        Some(vec![0, 1]),
        Some(vec![2, 3, 4]),
        Some(vec![5, 6, 7, 8, 9]),
    ]);
    let selection = UInt32Array::from(vec![2, 0]);

    let (child_indices, new_offsets, validity) =
        take_value_indices_from_list::<_, Int64Type>(&source, &selection).unwrap();

    // Row 2 contributes child positions 5..=9, then row 0 contributes 0..=1.
    assert_eq!(child_indices, Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
    assert_eq!(new_offsets, vec![0, 5, 7]);
    assert_eq!(validity.as_slice(), &[0b11111111]);
}
2350
/// Takes from a run-end encoded array and checks the run ends and run values
/// of the output.
#[test]
fn test_take_runs() {
    let logical_values: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];

    let mut run_builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
    run_builder.extend(logical_values.into_iter().map(Some));
    let source = run_builder.finish();

    let selection: PrimitiveArray<Int32Type> = vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect();

    let taken = take_run(&source, &selection).unwrap();

    assert_eq!(taken.len(), 7);
    assert_eq!(taken.run_ends().len(), 7);
    assert_eq!(taken.run_ends().values(), &[1_i32, 3, 4, 5, 7]);

    let run_values = taken.values().as_primitive::<Int32Type>();
    assert_eq!(run_values.values(), &[2, 2, 2, 2, 1]);
}
2371
/// Checks the child-value indices computed when taking rows from a
/// fixed-size list array of width 3.
#[test]
fn test_take_value_index_from_fixed_list() {
    let source = FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
        vec![
            Some(vec![Some(1), Some(2), None]),
            Some(vec![Some(4), None, Some(6)]),
            None,
            Some(vec![None, Some(8), Some(9)]),
        ],
        3,
    );

    // Each selected row `r` expands to child positions 3r, 3r + 1, 3r + 2.
    let selection = UInt32Array::from(vec![2, 1, 0]);
    let child_indices = take_value_indices_from_fixed_size_list(&source, &selection, 3).unwrap();
    let wanted = UInt32Array::from(vec![6, 7, 8, 3, 4, 5, 0, 1, 2]);
    assert_eq!(child_indices, wanted);

    let selection = UInt32Array::from(vec![3, 2, 1, 2, 0]);
    let child_indices = take_value_indices_from_fixed_size_list(&source, &selection, 3).unwrap();
    let wanted = UInt32Array::from(vec![9, 10, 11, 6, 7, 8, 3, 4, 5, 6, 7, 8, 0, 1, 2]);
    assert_eq!(child_indices, wanted);
}
2397
/// Null index slots may hold out-of-range raw values (400 here); taking with
/// them must produce nulls rather than fail.
#[test]
fn test_take_null_indices() {
    let selection = Int32Array::new(
        vec![1, 2, 400, 400].into(),
        Some(NullBuffer::from(vec![true, true, false, false])),
    );
    let source = Int32Array::from(vec![1, 23, 4, 5]);

    let taken = take(&source, &selection, None).unwrap();
    let collected = taken
        .as_primitive::<Int32Type>()
        .into_iter()
        .collect::<Vec<_>>();

    assert_eq!(&collected, &[Some(23), Some(4), None, None])
}
2413
/// A null index into a fixed-size list (width 2) must produce a row whose
/// child slots are all null.
#[test]
fn test_take_fixed_size_list_null_indices() {
    // Child values [0, 1, 2, 3] grouped into two rows of width 2.
    let child = Arc::new(Int32Array::from(vec![0, 1, 2, 3]));
    let item_field = Arc::new(Field::new_list_field(child.data_type().clone(), true));
    let list = FixedSizeListArray::try_new(item_field, 2, child, None).unwrap();

    let indices = Int32Array::from_iter([Some(0), None]);
    let taken = take(&list, &indices, None).unwrap();

    // Flatten the resulting child array: row 0 followed by two null slots.
    let flattened: Vec<_> = taken
        .as_fixed_size_list()
        .values()
        .as_primitive::<Int32Type>()
        .into_iter()
        .collect();
    assert_eq!(flattened, &[Some(0), Some(1), None, None])
}
2430
/// Same as the primitive null-index test, but for a string (byte) array:
/// masked-out index slots hold 400, far beyond the two-element source.
#[test]
fn test_take_bytes_null_indices() {
    let validity = NullBuffer::from_iter(vec![true, true, false, false]);
    let indices = Int32Array::new(vec![0, 1, 400, 400].into(), Some(validity));

    // The source itself contains a null at position 1.
    let source = StringArray::from(vec![Some("foo"), None]);
    let taken = take(&source, &indices, None).unwrap();

    let observed: Vec<_> = taken.as_string::<i32>().iter().collect();
    assert_eq!(&observed, &[Some("foo"), None, None, None])
}
2442
/// Takes from a sparse union (no offsets buffer) in which every slot selects
/// the string child, then inspects that child in the result.
#[test]
fn test_take_union_sparse() {
    let structs = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(false), Some(28))),
        Some((Some(false), Some(19))),
        Some((Some(true), Some(31))),
        None,
    ]);
    let strings = StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);

    // All five slots carry type id 1, i.e. the "f2" string child.
    let type_ids = [1; 5].into_iter().collect::<ScalarBuffer<i8>>();

    let struct_field = Arc::new(Field::new("f1", structs.data_type().clone(), true));
    let string_field = Arc::new(Field::new("f2", strings.data_type().clone(), true));
    let union_fields = [(0, struct_field), (1, string_field)]
        .into_iter()
        .collect();

    let children = vec![Arc::new(structs) as Arc<dyn Array>, Arc::new(strings)];
    let union = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();

    let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
    let taken = take(&union, &index, None).unwrap();
    let taken_union = taken.as_any().downcast_ref::<UnionArray>().unwrap();

    let taken_strings = taken_union.child(1);
    let taken_strings = taken_strings
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();

    // Source strings ["a", null, "c", null, "d"] read at 0, 3, 1, 0, 2, 4.
    let actual = taken_strings.iter().collect::<Vec<_>>();
    let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
    assert_eq!(expected, actual);
}
2481
/// Takes from a dense union and verifies the rebuilt type ids, offsets and
/// both compacted children.
#[test]
fn test_take_union_dense() {
    // ---- input union -----------------------------------------------------
    let type_ids = ScalarBuffer::<i8>::from(vec![0, 1, 1, 0, 0, 1, 0]);
    let offsets = ScalarBuffer::<i32>::from(vec![0, 0, 1, 1, 2, 2, 3]);
    let ints = UInt32Array::from(vec![10, 20, 30, 40]);
    let strings = StringArray::from(vec![Some("a"), None, Some("c"), Some("d")]);

    let int_field = Arc::new(Field::new("f1", ints.data_type().clone(), true));
    let string_field = Arc::new(Field::new("f2", strings.data_type().clone(), true));
    let union_fields = [(0, int_field), (1, string_field)]
        .into_iter()
        .collect();

    let union = UnionArray::try_new(
        union_fields,
        type_ids,
        Some(offsets),
        vec![Arc::new(ints), Arc::new(strings)],
    )
    .unwrap();

    // ---- expectations ----------------------------------------------------
    // The expected children contain only the selected values, with offsets
    // renumbered into them.
    let expected_type_ids = ScalarBuffer::<i8>::from(vec![0, 0, 1, 0, 1, 0]);
    let expected_offsets = ScalarBuffer::<i32>::from(vec![0, 1, 0, 2, 1, 3]);
    let expected_ints = UInt32Array::from(vec![10, 20, 10, 30]);
    let expected_strings = StringArray::from(vec![Some("a"), None]);

    // ---- exercise --------------------------------------------------------
    let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
    let taken = take(&union, &index, None).unwrap();
    let taken_union = taken.as_any().downcast_ref::<UnionArray>().unwrap();

    assert_eq!(taken_union.offsets(), Some(&expected_offsets));
    assert_eq!(taken_union.type_ids(), &expected_type_ids);
    assert_eq!(
        UInt32Array::from(taken_union.child(0).to_data()),
        expected_ints
    );
    assert_eq!(
        StringArray::from(taken_union.child(1).to_data()),
        expected_strings
    );
}
2538
/// Builds a dense union with `UnionBuilder`, takes from it, and compares the
/// result with a union built directly from the expected elements.
#[test]
fn test_take_union_dense_using_builder() {
    // Source: [a=1, b=3.0, a=4, a=5, b=2.0]
    let source = {
        let mut b = UnionBuilder::new_dense();
        b.append::<Int32Type>("a", 1).unwrap();
        b.append::<Float64Type>("b", 3.0).unwrap();
        b.append::<Int32Type>("a", 4).unwrap();
        b.append::<Int32Type>("a", 5).unwrap();
        b.append::<Float64Type>("b", 2.0).unwrap();
        b.build().unwrap()
    };

    // Expected: the source elements at positions 2, 0, 1, 2.
    let expected = {
        let mut b = UnionBuilder::new_dense();
        b.append::<Int32Type>("a", 4).unwrap();
        b.append::<Int32Type>("a", 1).unwrap();
        b.append::<Float64Type>("b", 3.0).unwrap();
        b.append::<Int32Type>("a", 4).unwrap();
        b.build().unwrap()
    };

    let indices = UInt32Array::from(vec![2, 0, 1, 2]);
    let actual = take(&source, &indices, None).unwrap();

    assert_eq!(expected.to_data(), actual.to_data());
}
2567
/// Regression test (issue 6206, per the test name): a dense union with a single
/// child, where every slot refers to that child, must be takeable.
#[test]
fn test_take_union_dense_all_match_issue_6206() {
    let fields = UnionFields::new(vec![0], vec![Field::new("a", DataType::Int64, false)]);
    let only_child = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));

    // Five slots, all of type id 0, with identity offsets 0..5.
    let type_ids = ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]);
    let offsets = ScalarBuffer::from_iter(0_i32..5);

    let union = UnionArray::try_new(fields, type_ids, Some(offsets), vec![only_child]).unwrap();

    let indices = Int64Array::from(vec![0, 2, 4]);
    let taken = take(&union, &indices, None).unwrap();
    assert_eq!(taken.len(), 3);
}
2585
/// Taking one 26-byte string `i32::MAX >> 4` times needs more bytes than an
/// `i32` offset can address, so `take` must report an offset overflow error.
#[test]
fn test_take_bytes_offset_overflow() {
    // 134_217_727 repetitions * 26 bytes is roughly 3.49e9, above i32::MAX.
    let repetitions = (i32::MAX >> 4) as usize;
    let indices = Int32Array::from(vec![0; repetitions]);

    let alphabet: String = ('a'..='z').collect();
    let values = StringArray::from(vec![Some(alphabet)]);

    let outcome = take(&values, &indices, None);
    assert!(matches!(
        outcome,
        Err(ArrowError::OffsetOverflowError(_))
    ));
}
2596}