1use std::sync::Arc;
21
22use arrow_array::builder::{BufferBuilder, UInt32Builder};
23use arrow_array::cast::AsArray;
24use arrow_array::types::*;
25use arrow_array::*;
26use arrow_buffer::{
27 ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, ScalarBuffer,
28 bit_util,
29};
30use arrow_data::ArrayDataBuilder;
31use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
32
33use num_traits::{One, Zero};
34
35pub fn take(
87 values: &dyn Array,
88 indices: &dyn Array,
89 options: Option<TakeOptions>,
90) -> Result<ArrayRef, ArrowError> {
91 let options = options.unwrap_or_default();
92 downcast_integer_array!(
93 indices => {
94 if options.check_bounds {
95 check_bounds(values.len(), indices)?;
96 }
97 let indices = indices.to_indices();
98 take_impl(values, &indices)
99 },
100 d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
101 )
102}
103
104pub fn take_arrays(
153 arrays: &[ArrayRef],
154 indices: &dyn Array,
155 options: Option<TakeOptions>,
156) -> Result<Vec<ArrayRef>, ArrowError> {
157 arrays
158 .iter()
159 .map(|array| take(array.as_ref(), indices, options.clone()))
160 .collect()
161}
162
163fn check_bounds<T: ArrowPrimitiveType>(
165 len: usize,
166 indices: &PrimitiveArray<T>,
167) -> Result<(), ArrowError> {
168 if indices.null_count() > 0 {
169 indices.iter().flatten().try_for_each(|index| {
170 let ix = index
171 .to_usize()
172 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
173 if ix >= len {
174 return Err(ArrowError::ComputeError(format!(
175 "Array index out of bounds, cannot get item at index {ix} from {len} entries"
176 )));
177 }
178 Ok(())
179 })
180 } else {
181 indices.values().iter().try_for_each(|index| {
182 let ix = index
183 .to_usize()
184 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
185 if ix >= len {
186 return Err(ArrowError::ComputeError(format!(
187 "Array index out of bounds, cannot get item at index {ix} from {len} entries"
188 )));
189 }
190 Ok(())
191 })
192 }
193}
194
/// Dispatches a take on `values` according to its concrete data type.
///
/// Called by [`take`] with normalized (and optionally bounds-checked) indices,
/// and recursively by the nested-type branches and helpers in this file
/// (struct, union, list, fixed-size list, map) to take their child arrays.
///
/// # Errors
/// Propagates errors from the type-specific helpers and from array
/// constructors such as `UnionArray::try_new`.
///
/// # Panics
/// Panics via `unimplemented!` for data types without a branch below, and
/// may panic on out-of-bounds indices (no bounds check is done here).
#[inline(never)]
fn take_impl<IndexType: ArrowPrimitiveType>(
    values: &dyn Array,
    indices: &PrimitiveArray<IndexType>,
) -> Result<ArrayRef, ArrowError> {
    downcast_primitive_array! {
        values => Ok(Arc::new(take_primitive(values, indices)?)),
        DataType::Boolean => {
            let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
            Ok(Arc::new(take_boolean(values, indices)))
        }
        DataType::Utf8 => {
            Ok(Arc::new(take_bytes(values.as_string::<i32>(), indices)?))
        }
        DataType::LargeUtf8 => {
            Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)?))
        }
        DataType::Utf8View => {
            Ok(Arc::new(take_byte_view(values.as_string_view(), indices)?))
        }
        DataType::List(_) => {
            Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?))
        }
        DataType::LargeList(_) => {
            Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?))
        }
        DataType::ListView(_) => {
            Ok(Arc::new(take_list_view::<_, Int32Type>(values.as_list_view(), indices)?))
        }
        DataType::LargeListView(_) => {
            Ok(Arc::new(take_list_view::<_, Int64Type>(values.as_list_view(), indices)?))
        }
        DataType::FixedSizeList(_, length) => {
            let values = values
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap();
            Ok(Arc::new(take_fixed_size_list(
                values,
                indices,
                *length as u32,
            )?))
        }
        DataType::Map(_, _) => {
            // Take through the equivalent ListArray view of the map, then
            // rebuild the result with the original Map data type.
            let list_arr = ListArray::from(values.as_map().clone());
            let list_data = take_list::<_, Int32Type>(&list_arr, indices)?;
            let builder = list_data.into_data().into_builder().data_type(values.data_type().clone());
            Ok(Arc::new(MapArray::from(unsafe { builder.build_unchecked() })))
        }
        DataType::Struct(fields) => {
            let array: &StructArray = values.as_struct();
            // Take every child column with the same indices.
            let arrays = array
                .columns()
                .iter()
                .map(|a| take_impl(a.as_ref(), indices))
                .collect::<Result<Vec<ArrayRef>, _>>()?;
            let fields: Vec<(FieldRef, ArrayRef)> =
                fields.iter().cloned().zip(arrays).collect();

            // An output slot is valid only if its index is non-null and the
            // referenced struct slot is valid.
            let is_valid: Buffer = indices
                .iter()
                .map(|index| {
                    if let Some(index) = index {
                        array.is_valid(index.to_usize().unwrap())
                    } else {
                        false
                    }
                })
                .collect();

            // With no child columns, the length cannot come from the children,
            // so the empty-fields constructor is given it explicitly.
            if fields.is_empty() {
                let nulls = NullBuffer::new(BooleanBuffer::new(is_valid, 0, indices.len()));
                Ok(Arc::new(StructArray::new_empty_fields(indices.len(), Some(nulls))))
            } else {
                Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
            }
        }
        DataType::Dictionary(_, _) => downcast_dictionary_array! {
            values => Ok(Arc::new(take_dict(values, indices)?)),
            t => unimplemented!("Take not supported for dictionary type {:?}", t)
        }
        DataType::RunEndEncoded(_, _) => downcast_run_array! {
            values => Ok(Arc::new(take_run(values, indices)?)),
            t => unimplemented!("Take not supported for run type {:?}", t)
        }
        DataType::Binary => {
            Ok(Arc::new(take_bytes(values.as_binary::<i32>(), indices)?))
        }
        DataType::LargeBinary => {
            Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?))
        }
        DataType::BinaryView => {
            Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)?))
        }
        DataType::FixedSizeBinary(size) => {
            let values = values
                .as_any()
                .downcast_ref::<FixedSizeBinaryArray>()
                .unwrap();
            Ok(Arc::new(take_fixed_size_binary(values, indices, *size)?))
        }
        DataType::Null => {
            // Every element of a Null array is null, so any Null array of the
            // requested length is a valid result. A slice of the input is
            // reused when it is long enough.
            if values.len() >= indices.len() {
                Ok(values.slice(0, indices.len()))
            } else {
                Ok(new_null_array(&DataType::Null, indices.len()))
            }
        }
        DataType::Union(fields, UnionMode::Sparse) => {
            // Sparse union: every child is taken with the same indices as the
            // type ids.
            let mut children = Vec::with_capacity(fields.len());
            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
            // NOTE(review): a null index yields type id 0 here (the default
            // produced by `take_native`); confirm that is acceptable.
            let type_ids = take_native(values.type_ids(), indices);
            for (type_id, _field) in fields.iter() {
                let values = values.child(type_id);
                let values = take_impl(values, indices)?;
                children.push(values);
            }
            let array = UnionArray::try_new(fields.clone(), type_ids, None, children)?;
            Ok(Arc::new(array))
        }
        DataType::Union(fields, UnionMode::Dense) => {
            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();

            // Gather the type id and the child offset of every selected slot.
            let type_ids = <PrimitiveArray<Int8Type>>::try_new(take_native(values.type_ids(), indices), None)?;
            let offsets = <PrimitiveArray<Int32Type>>::try_new(take_native(values.offsets().unwrap(), indices), None)?;

            // For each child, keep the offsets whose type id selects that child
            // and take those rows from it, preserving output order.
            let children = fields.iter()
                .map(|(field_type_id, _)| {
                    let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);

                    let indices = crate::filter::filter(&offsets, &mask)?;

                    let values = values.child(field_type_id);

                    take_impl(values, indices.as_primitive::<Int32Type>())
                })
                .collect::<Result<_, _>>()?;

            // Running count per type id: each output slot's new offset is the
            // number of earlier slots that selected the same child. The i8 type
            // id is used as an index, so this assumes ids are in 0..128.
            let mut child_offsets = [0; 128];

            let offsets = type_ids.values()
                .iter()
                .map(|&i| {
                    let offset = child_offsets[i as usize];

                    child_offsets[i as usize] += 1;

                    offset
                })
                .collect();

            let (_, type_ids, _) = type_ids.into_parts();

            let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)?;

            Ok(Arc::new(array))
        }
        t => unimplemented!("Take not supported for data type {:?}", t)
    }
}
360
/// Options controlling the behavior of [`take`].
#[derive(Clone, Debug, Default)]
pub struct TakeOptions {
    /// When `true`, [`take`] validates every non-null index against the length
    /// of `values` first and returns an `ArrowError::ComputeError` for an
    /// out-of-bounds index. When `false` (the default), no such check is done
    /// and an out-of-bounds index may panic inside the kernels.
    pub check_bounds: bool,
}
369
370#[inline(always)]
371fn maybe_usize<I: ArrowNativeType>(index: I) -> Result<usize, ArrowError> {
372 index
373 .to_usize()
374 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))
375}
376
377fn take_primitive<T, I>(
387 values: &PrimitiveArray<T>,
388 indices: &PrimitiveArray<I>,
389) -> Result<PrimitiveArray<T>, ArrowError>
390where
391 T: ArrowPrimitiveType,
392 I: ArrowPrimitiveType,
393{
394 let values_buf = take_native(values.values(), indices);
395 let nulls = take_nulls(values.nulls(), indices);
396 Ok(PrimitiveArray::try_new(values_buf, nulls)?.with_data_type(values.data_type().clone()))
397}
398
399#[inline(never)]
400fn take_nulls<I: ArrowPrimitiveType>(
401 values: Option<&NullBuffer>,
402 indices: &PrimitiveArray<I>,
403) -> Option<NullBuffer> {
404 match values.filter(|n| n.null_count() > 0) {
405 Some(n) => {
406 let buffer = take_bits(n.inner(), indices);
407 Some(NullBuffer::new(buffer)).filter(|n| n.null_count() > 0)
408 }
409 None => indices.nulls().cloned(),
410 }
411}
412
413#[inline(never)]
414fn take_native<T: ArrowNativeType, I: ArrowPrimitiveType>(
415 values: &[T],
416 indices: &PrimitiveArray<I>,
417) -> ScalarBuffer<T> {
418 match indices.nulls().filter(|n| n.null_count() > 0) {
419 Some(n) => indices
420 .values()
421 .iter()
422 .enumerate()
423 .map(|(idx, index)| match values.get(index.as_usize()) {
424 Some(v) => *v,
425 None => match n.is_null(idx) {
426 true => T::default(),
427 false => panic!("Out-of-bounds index {index:?}"),
428 },
429 })
430 .collect(),
431 None => indices
432 .values()
433 .iter()
434 .map(|index| values[index.as_usize()])
435 .collect(),
436 }
437}
438
439#[inline(never)]
440fn take_bits<I: ArrowPrimitiveType>(
441 values: &BooleanBuffer,
442 indices: &PrimitiveArray<I>,
443) -> BooleanBuffer {
444 let len = indices.len();
445
446 match indices.nulls().filter(|n| n.null_count() > 0) {
447 Some(nulls) => {
448 let mut output_buffer = MutableBuffer::new_null(len);
449 let output_slice = output_buffer.as_slice_mut();
450 nulls.valid_indices().for_each(|idx| {
451 if values.value(indices.value(idx).as_usize()) {
452 bit_util::set_bit(output_slice, idx);
453 }
454 });
455 BooleanBuffer::new(output_buffer.into(), 0, len)
456 }
457 None => {
458 BooleanBuffer::collect_bool(len, |idx: usize| {
459 values.value(unsafe { indices.value_unchecked(idx).as_usize() })
461 })
462 }
463 }
464}
465
466fn take_boolean<IndexType: ArrowPrimitiveType>(
468 values: &BooleanArray,
469 indices: &PrimitiveArray<IndexType>,
470) -> BooleanArray {
471 let val_buf = take_bits(values.values(), indices);
472 let null_buf = take_nulls(values.nulls(), indices);
473 BooleanArray::new(val_buf, null_buf)
474}
475
/// Takes rows `indices` from a variable-length byte array (used for `Utf8`,
/// `LargeUtf8`, `Binary` and `LargeBinary`).
///
/// Each branch makes two passes. The first computes the output offsets and the
/// total byte capacity; the second copies the selected bytes. The four branches
/// specialize on which of `array` / `indices` contain nulls, and null output
/// slots get zero-length values.
///
/// # Errors
/// Returns `ArrowError::OffsetOverflowError` if the accumulated byte length
/// does not fit the offset type `T::Offset`.
///
/// # Panics
/// Panics if a non-null index is out of bounds for `array`.
fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
    array: &GenericByteArray<T>,
    indices: &PrimitiveArray<IndexType>,
) -> Result<GenericByteArray<T>, ArrowError> {
    let mut offsets = Vec::with_capacity(indices.len() + 1);
    offsets.push(T::Offset::default());

    let input_offsets = array.value_offsets();
    let mut capacity = 0;
    let nulls = take_nulls(array.nulls(), indices);

    let (offsets, values) = if array.null_count() == 0 && indices.null_count() == 0 {
        // Fast path: no nulls anywhere, every index contributes its bytes.
        offsets.reserve(indices.len());
        for index in indices.values() {
            let index = index.as_usize();
            capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for index in indices.values() {
            values.extend_from_slice(array.value(index.as_usize()).as_ref());
        }
        (offsets, values)
    } else if indices.null_count() == 0 {
        // Only `array` has nulls: skip bytes of null source values.
        offsets.reserve(indices.len());
        for index in indices.values() {
            let index = index.as_usize();
            if array.is_valid(index) {
                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            }
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for index in indices.values() {
            let index = index.as_usize();
            if array.is_valid(index) {
                values.extend_from_slice(array.value(index).as_ref());
            }
        }
        (offsets, values)
    } else if array.null_count() == 0 {
        // Only `indices` has nulls: skip null index slots, whose stored value
        // is unspecified.
        offsets.reserve(indices.len());
        for (i, index) in indices.values().iter().enumerate() {
            let index = index.as_usize();
            if indices.is_valid(i) {
                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            }
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for (i, index) in indices.values().iter().enumerate() {
            if indices.is_valid(i) {
                values.extend_from_slice(array.value(index.as_usize()).as_ref());
            }
        }
        (offsets, values)
    } else {
        // Both have nulls: use the combined output validity. `take_nulls`
        // returns `Some` here because `array` has nulls and every null index
        // leaves its output bit unset (see `take_bits`).
        let nulls = nulls.as_ref().unwrap();
        offsets.reserve(indices.len());
        for (i, index) in indices.values().iter().enumerate() {
            let index = index.as_usize();
            if nulls.is_valid(i) {
                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            }
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for (i, index) in indices.values().iter().enumerate() {
            let index = index.as_usize();
            if nulls.is_valid(i) {
                values.extend_from_slice(array.value(index).as_ref());
            }
        }
        (offsets, values)
    };

    // Defensive check that the total byte length is representable as an offset.
    T::Offset::from_usize(values.len())
        .ok_or_else(|| ArrowError::OffsetOverflowError(values.len()))?;

    // SAFETY (as relied on here): offsets start at zero, are non-decreasing and
    // end at `values.len()`. The bytes are whole values copied from a valid
    // array of the same type.
    let array = unsafe {
        let offsets = OffsetBuffer::new_unchecked(offsets.into());
        GenericByteArray::<T>::new_unchecked(offsets, values.into(), nulls)
    };

    Ok(array)
}
581
582fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>(
584 array: &GenericByteViewArray<T>,
585 indices: &PrimitiveArray<IndexType>,
586) -> Result<GenericByteViewArray<T>, ArrowError> {
587 let new_views = take_native(array.views(), indices);
588 let new_nulls = take_nulls(array.nulls(), indices);
589 Ok(unsafe {
591 GenericByteViewArray::new_unchecked(new_views, array.data_buffers().to_vec(), new_nulls)
592 })
593}
594
595fn take_list<IndexType, OffsetType>(
601 values: &GenericListArray<OffsetType::Native>,
602 indices: &PrimitiveArray<IndexType>,
603) -> Result<GenericListArray<OffsetType::Native>, ArrowError>
604where
605 IndexType: ArrowPrimitiveType,
606 OffsetType: ArrowPrimitiveType,
607 OffsetType::Native: OffsetSizeTrait,
608 PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
609{
610 let (list_indices, offsets, null_buf) =
613 take_value_indices_from_list::<IndexType, OffsetType>(values, indices)?;
614
615 let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices)?;
616 let value_offsets = Buffer::from_vec(offsets);
617 let list_data = ArrayDataBuilder::new(values.data_type().clone())
619 .len(indices.len())
620 .null_bit_buffer(Some(null_buf.into()))
621 .offset(0)
622 .add_child_data(taken.into_data())
623 .add_buffer(value_offsets);
624
625 let list_data = unsafe { list_data.build_unchecked() };
626
627 Ok(GenericListArray::<OffsetType::Native>::from(list_data))
628}
629
630fn take_list_view<IndexType, OffsetType>(
631 values: &GenericListViewArray<OffsetType::Native>,
632 indices: &PrimitiveArray<IndexType>,
633) -> Result<GenericListViewArray<OffsetType::Native>, ArrowError>
634where
635 IndexType: ArrowPrimitiveType,
636 OffsetType: ArrowPrimitiveType,
637 OffsetType::Native: OffsetSizeTrait,
638{
639 let taken_offsets = take_native(values.offsets(), indices);
640 let taken_sizes = take_native(values.sizes(), indices);
641 let nulls = take_nulls(values.nulls(), indices);
642
643 let list_view_data = ArrayDataBuilder::new(values.data_type().clone())
644 .len(indices.len())
645 .nulls(nulls)
646 .buffers(vec![taken_offsets.into(), taken_sizes.into()])
647 .child_data(vec![values.values().to_data()]);
648
649 let list_view_data = unsafe { list_view_data.build_unchecked() };
651
652 Ok(GenericListViewArray::<OffsetType::Native>::from(
653 list_view_data,
654 ))
655}
656
657fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
663 values: &FixedSizeListArray,
664 indices: &PrimitiveArray<IndexType>,
665 length: <UInt32Type as ArrowPrimitiveType>::Native,
666) -> Result<FixedSizeListArray, ArrowError> {
667 let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?;
668 let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)?;
669
670 let num_bytes = bit_util::ceil(indices.len(), 8);
672 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
673 let null_slice = null_buf.as_slice_mut();
674
675 for i in 0..indices.len() {
676 let index = indices
677 .value(i)
678 .to_usize()
679 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
680 if !indices.is_valid(i) || values.is_null(index) {
681 bit_util::unset_bit(null_slice, i);
682 }
683 }
684
685 let list_data = ArrayDataBuilder::new(values.data_type().clone())
686 .len(indices.len())
687 .null_bit_buffer(Some(null_buf.into()))
688 .offset(0)
689 .add_child_data(taken.into_data());
690
691 let list_data = unsafe { list_data.build_unchecked() };
692
693 Ok(FixedSizeListArray::from(list_data))
694}
695
696fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>(
697 values: &FixedSizeBinaryArray,
698 indices: &PrimitiveArray<IndexType>,
699 size: i32,
700) -> Result<FixedSizeBinaryArray, ArrowError> {
701 let nulls = values.nulls();
702 let array_iter = indices
703 .values()
704 .iter()
705 .map(|idx| {
706 let idx = maybe_usize::<IndexType::Native>(*idx)?;
707 if nulls.map(|n| n.is_valid(idx)).unwrap_or(true) {
708 Ok(Some(values.value(idx)))
709 } else {
710 Ok(None)
711 }
712 })
713 .collect::<Result<Vec<_>, ArrowError>>()?
714 .into_iter();
715
716 FixedSizeBinaryArray::try_from_sparse_iter_with_size(array_iter, size)
717}
718
719fn take_dict<T: ArrowDictionaryKeyType, I: ArrowPrimitiveType>(
724 values: &DictionaryArray<T>,
725 indices: &PrimitiveArray<I>,
726) -> Result<DictionaryArray<T>, ArrowError> {
727 let new_keys = take_primitive(values.keys(), indices)?;
728 Ok(unsafe { DictionaryArray::new_unchecked(new_keys, values.values().clone()) })
729}
730
731fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
740 run_array: &RunArray<T>,
741 logical_indices: &PrimitiveArray<I>,
742) -> Result<RunArray<T>, ArrowError> {
743 let physical_indices = run_array.get_physical_indices(logical_indices.values())?;
745
746 let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
750 let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
751 let mut new_physical_len = 1;
752 for ix in 1..physical_indices.len() {
753 if physical_indices[ix] != physical_indices[ix - 1] {
754 take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
755 new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
756 new_physical_len += 1;
757 }
758 }
759 take_value_indices
760 .append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap());
761 new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
762 let new_run_ends = unsafe {
763 ArrayDataBuilder::new(T::DATA_TYPE)
766 .len(new_physical_len)
767 .null_count(0)
768 .add_buffer(new_run_ends_builder.finish())
769 .build_unchecked()
770 };
771
772 let take_value_indices: PrimitiveArray<I> = unsafe {
773 ArrayDataBuilder::new(I::DATA_TYPE)
776 .len(new_physical_len)
777 .null_count(0)
778 .add_buffer(take_value_indices.finish())
779 .build_unchecked()
780 .into()
781 };
782
783 let new_values = take(run_array.values(), &take_value_indices, None)?;
784
785 let builder = ArrayDataBuilder::new(run_array.data_type().clone())
786 .len(physical_indices.len())
787 .add_child_data(new_run_ends)
788 .add_child_data(new_values.into_data());
789 let array_data = unsafe {
790 builder.build_unchecked()
793 };
794 Ok(array_data.into())
795}
796
797#[allow(clippy::type_complexity)]
803fn take_value_indices_from_list<IndexType, OffsetType>(
804 list: &GenericListArray<OffsetType::Native>,
805 indices: &PrimitiveArray<IndexType>,
806) -> Result<
807 (
808 PrimitiveArray<OffsetType>,
809 Vec<OffsetType::Native>,
810 MutableBuffer,
811 ),
812 ArrowError,
813>
814where
815 IndexType: ArrowPrimitiveType,
816 OffsetType: ArrowPrimitiveType,
817 OffsetType::Native: OffsetSizeTrait + std::ops::Add + Zero + One,
818 PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
819{
820 let offsets: &[OffsetType::Native] = list.value_offsets();
822
823 let mut new_offsets = Vec::with_capacity(indices.len());
824 let mut values = Vec::new();
825 let mut current_offset = OffsetType::Native::zero();
826 new_offsets.push(OffsetType::Native::zero());
828
829 let num_bytes = bit_util::ceil(indices.len(), 8);
831 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
832 let null_slice = null_buf.as_slice_mut();
833
834 for i in 0..indices.len() {
836 if indices.is_valid(i) {
837 let ix = indices
838 .value(i)
839 .to_usize()
840 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
841 let start = offsets[ix];
842 let end = offsets[ix + 1];
843 current_offset += end - start;
844 new_offsets.push(current_offset);
845
846 let mut curr = start;
847
848 while curr < end {
850 values.push(curr);
851 curr += One::one();
852 }
853 if !list.is_valid(ix) {
854 bit_util::unset_bit(null_slice, i);
855 }
856 } else {
857 bit_util::unset_bit(null_slice, i);
858 new_offsets.push(current_offset);
859 }
860 }
861
862 Ok((
863 PrimitiveArray::<OffsetType>::from(values),
864 new_offsets,
865 null_buf,
866 ))
867}
868
869fn take_value_indices_from_fixed_size_list<IndexType>(
871 list: &FixedSizeListArray,
872 indices: &PrimitiveArray<IndexType>,
873 length: <UInt32Type as ArrowPrimitiveType>::Native,
874) -> Result<PrimitiveArray<UInt32Type>, ArrowError>
875where
876 IndexType: ArrowPrimitiveType,
877{
878 let mut values = UInt32Builder::with_capacity(length as usize * indices.len());
879
880 for i in 0..indices.len() {
881 if indices.is_valid(i) {
882 let index = indices
883 .value(i)
884 .to_usize()
885 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
886 let start = list.value_offset(index) as <UInt32Type as ArrowPrimitiveType>::Native;
887
888 unsafe {
890 values.append_trusted_len_iter(start..start + length);
891 }
892 } else {
893 values.append_nulls(length as usize);
894 }
895 }
896
897 Ok(values.finish())
898}
899
/// Converts an integer index array into the index type passed to `take_impl`.
///
/// `take` calls this on every integer index array before dispatching. The
/// implementations in this file map each integer type to `UInt32Type` or
/// `UInt64Type`, presumably to limit how many index types the kernels are
/// instantiated for.
trait ToIndices {
    /// The index type produced by `to_indices`.
    type T: ArrowPrimitiveType;

    /// Returns the indices as a `PrimitiveArray<Self::T>`. The implementations
    /// in this file carry the null buffer over unchanged.
    fn to_indices(&self) -> PrimitiveArray<Self::T>;
}
907
/// Implements `ToIndices` for `$t` by reinterpreting its value buffer as the
/// native type of `$o`, without copying the values.
///
/// Assumes `$t` and `$o` have the same width; the invocations pair i32/u32 and
/// i64/u64. Negative signed indices therefore become very large unsigned
/// values.
macro_rules! to_indices_reinterpret {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                let backing = self.values().inner().clone();
                let reinterpreted = ScalarBuffer::new(backing, 0, self.len());
                PrimitiveArray::new(reinterpreted, self.nulls().cloned())
            }
        }
    };
}
920
/// Implements `ToIndices` for an index type that is already a kernel index
/// type, returning a clone of the array.
macro_rules! to_indices_identity {
    ($t:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $t;

            fn to_indices(&self) -> PrimitiveArray<$t> {
                PrimitiveArray::<$t>::clone(self)
            }
        }
    };
}
932
/// Implements `ToIndices` for a narrow integer type `$t` by widening every
/// value to the native type of `$o` with an `as` cast.
///
/// For signed inputs, the `as` cast sign-extends, so a negative index becomes
/// a very large unsigned index. Nulls are carried over unchanged.
macro_rules! to_indices_widening {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            // Must match the return type of `to_indices`; a hard-coded type
            // here would break any invocation whose output type differs.
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                let cast = self.values().iter().copied().map(|x| x as _).collect();
                PrimitiveArray::new(cast, self.nulls().cloned())
            }
        }
    };
}
945
// 8-bit indices are widened to u32.
to_indices_widening!(UInt8Type, UInt32Type);
to_indices_widening!(Int8Type, UInt32Type);

// 16-bit indices are widened to u32.
to_indices_widening!(UInt16Type, UInt32Type);
to_indices_widening!(Int16Type, UInt32Type);

// 32-bit indices: u32 is used as-is; i32 is reinterpreted as u32.
to_indices_identity!(UInt32Type);
to_indices_reinterpret!(Int32Type, UInt32Type);

// 64-bit indices: u64 is used as-is; i64 is reinterpreted as u64.
to_indices_identity!(UInt64Type);
to_indices_reinterpret!(Int64Type, UInt64Type);
957
958pub fn take_record_batch(
998 record_batch: &RecordBatch,
999 indices: &dyn Array,
1000) -> Result<RecordBatch, ArrowError> {
1001 let columns = record_batch
1002 .columns()
1003 .iter()
1004 .map(|c| take(c, indices, None))
1005 .collect::<Result<Vec<_>, _>>()?;
1006 RecordBatch::try_new(record_batch.schema(), columns)
1007}
1008
1009#[cfg(test)]
1010mod tests {
1011 use super::*;
1012 use arrow_array::builder::*;
1013 use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
1014 use arrow_data::ArrayData;
1015 use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
1016 use num_traits::ToPrimitive;
1017
1018 fn test_take_decimal_arrays(
1019 data: Vec<Option<i128>>,
1020 index: &UInt32Array,
1021 options: Option<TakeOptions>,
1022 expected_data: Vec<Option<i128>>,
1023 precision: &u8,
1024 scale: &i8,
1025 ) -> Result<(), ArrowError> {
1026 let output = data
1027 .into_iter()
1028 .collect::<Decimal128Array>()
1029 .with_precision_and_scale(*precision, *scale)
1030 .unwrap();
1031
1032 let expected = expected_data
1033 .into_iter()
1034 .collect::<Decimal128Array>()
1035 .with_precision_and_scale(*precision, *scale)
1036 .unwrap();
1037
1038 let expected = Arc::new(expected) as ArrayRef;
1039 let output = take(&output, index, options).unwrap();
1040 assert_eq!(&output, &expected);
1041 Ok(())
1042 }
1043
1044 fn test_take_boolean_arrays(
1045 data: Vec<Option<bool>>,
1046 index: &UInt32Array,
1047 options: Option<TakeOptions>,
1048 expected_data: Vec<Option<bool>>,
1049 ) {
1050 let output = BooleanArray::from(data);
1051 let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef;
1052 let output = take(&output, index, options).unwrap();
1053 assert_eq!(&output, &expected)
1054 }
1055
1056 fn test_take_primitive_arrays<T>(
1057 data: Vec<Option<T::Native>>,
1058 index: &UInt32Array,
1059 options: Option<TakeOptions>,
1060 expected_data: Vec<Option<T::Native>>,
1061 ) -> Result<(), ArrowError>
1062 where
1063 T: ArrowPrimitiveType,
1064 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1065 {
1066 let output = PrimitiveArray::<T>::from(data);
1067 let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1068 let output = take(&output, index, options)?;
1069 assert_eq!(&output, &expected);
1070 Ok(())
1071 }
1072
1073 fn test_take_primitive_arrays_non_null<T>(
1074 data: Vec<T::Native>,
1075 index: &UInt32Array,
1076 options: Option<TakeOptions>,
1077 expected_data: Vec<Option<T::Native>>,
1078 ) -> Result<(), ArrowError>
1079 where
1080 T: ArrowPrimitiveType,
1081 PrimitiveArray<T>: From<Vec<T::Native>>,
1082 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1083 {
1084 let output = PrimitiveArray::<T>::from(data);
1085 let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1086 let output = take(&output, index, options)?;
1087 assert_eq!(&output, &expected);
1088 Ok(())
1089 }
1090
1091 fn test_take_impl_primitive_arrays<T, I>(
1092 data: Vec<Option<T::Native>>,
1093 index: &PrimitiveArray<I>,
1094 options: Option<TakeOptions>,
1095 expected_data: Vec<Option<T::Native>>,
1096 ) where
1097 T: ArrowPrimitiveType,
1098 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1099 I: ArrowPrimitiveType,
1100 {
1101 let output = PrimitiveArray::<T>::from(data);
1102 let expected = PrimitiveArray::<T>::from(expected_data);
1103 let output = take(&output, index, options).unwrap();
1104 let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
1105 assert_eq!(output, &expected)
1106 }
1107
1108 fn create_test_struct(values: Vec<Option<(Option<bool>, Option<i32>)>>) -> StructArray {
1110 let mut struct_builder = StructBuilder::new(
1111 Fields::from(vec![
1112 Field::new("a", DataType::Boolean, true),
1113 Field::new("b", DataType::Int32, true),
1114 ]),
1115 vec![
1116 Box::new(BooleanBuilder::with_capacity(values.len())),
1117 Box::new(Int32Builder::with_capacity(values.len())),
1118 ],
1119 );
1120
1121 for value in values {
1122 struct_builder
1123 .field_builder::<BooleanBuilder>(0)
1124 .unwrap()
1125 .append_option(value.and_then(|v| v.0));
1126 struct_builder
1127 .field_builder::<Int32Builder>(1)
1128 .unwrap()
1129 .append_option(value.and_then(|v| v.1));
1130 struct_builder.append(value.is_some());
1131 }
1132 struct_builder.finish()
1133 }
1134
1135 #[test]
1136 fn test_take_decimal128_non_null_indices() {
1137 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1138 let precision: u8 = 10;
1139 let scale: i8 = 5;
1140 test_take_decimal_arrays(
1141 vec![None, Some(3), Some(5), Some(2), Some(3), None],
1142 &index,
1143 None,
1144 vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1145 &precision,
1146 &scale,
1147 )
1148 .unwrap();
1149 }
1150
1151 #[test]
1152 fn test_take_decimal128() {
1153 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1154 let precision: u8 = 10;
1155 let scale: i8 = 5;
1156 test_take_decimal_arrays(
1157 vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1158 &index,
1159 None,
1160 vec![Some(3), None, Some(1), Some(3), Some(2)],
1161 &precision,
1162 &scale,
1163 )
1164 .unwrap();
1165 }
1166
1167 #[test]
1168 fn test_take_primitive_non_null_indices() {
1169 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1170 test_take_primitive_arrays::<Int8Type>(
1171 vec![None, Some(3), Some(5), Some(2), Some(3), None],
1172 &index,
1173 None,
1174 vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1175 )
1176 .unwrap();
1177 }
1178
1179 #[test]
1180 fn test_take_primitive_non_null_values() {
1181 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1182 test_take_primitive_arrays::<Int8Type>(
1183 vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1184 &index,
1185 None,
1186 vec![Some(3), None, Some(1), Some(3), Some(2)],
1187 )
1188 .unwrap();
1189 }
1190
1191 #[test]
1192 fn test_take_primitive_non_null() {
1193 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1194 test_take_primitive_arrays::<Int8Type>(
1195 vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)],
1196 &index,
1197 None,
1198 vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)],
1199 )
1200 .unwrap();
1201 }
1202
1203 #[test]
1204 fn test_take_primitive_nullable_indices_non_null_values_with_offset() {
1205 let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1206 let index = index.slice(2, 4);
1207 let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1208
1209 assert_eq!(
1210 index,
1211 &UInt32Array::from(vec![Some(2), Some(3), None, None])
1212 );
1213
1214 test_take_primitive_arrays_non_null::<Int64Type>(
1215 vec![0, 10, 20, 30, 40, 50],
1216 index,
1217 None,
1218 vec![Some(20), Some(30), None, None],
1219 )
1220 .unwrap();
1221 }
1222
1223 #[test]
1224 fn test_take_primitive_nullable_indices_nullable_values_with_offset() {
1225 let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1226 let index = index.slice(2, 4);
1227 let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1228
1229 assert_eq!(
1230 index,
1231 &UInt32Array::from(vec![Some(2), Some(3), None, None])
1232 );
1233
1234 test_take_primitive_arrays::<Int64Type>(
1235 vec![None, None, Some(20), Some(30), Some(40), Some(50)],
1236 index,
1237 None,
1238 vec![Some(20), Some(30), None, None],
1239 )
1240 .unwrap();
1241 }
1242
/// Runs `take` over many primitive value types with one shared nullable index.
/// Index slot 1 is null, and slot 2 (index 1) selects a null value.
#[test]
fn test_take_primitive() {
    let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    // Takes the shared `index` from `$values` typed as `$t` and unwraps the result.
    macro_rules! check_take {
        ($t:ty, $values:expr, $expected:expr) => {
            test_take_primitive_arrays::<$t>($values, &index, None, $expected).unwrap()
        };
    }

    // Integer types with small non-negative values.
    check_take!(
        Int8Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );
    check_take!(
        Int16Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );
    check_take!(
        Int32Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );
    check_take!(
        Int64Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );
    check_take!(
        UInt8Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );
    check_take!(
        UInt16Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );
    check_take!(
        UInt32Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );

    // Signed types with a negative value.
    check_take!(
        Int64Type,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );
    check_take!(
        IntervalYearMonthType,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );

    let (d0, d2, d_neg) = (
        IntervalDayTime::new(0, 0),
        IntervalDayTime::new(2, 0),
        IntervalDayTime::new(-15, 0),
    );
    check_take!(
        IntervalDayTimeType,
        vec![Some(d0), None, Some(d2), Some(d_neg), None],
        vec![Some(d_neg), None, None, Some(d_neg), Some(d2)]
    );

    let (m0, m2, m_neg) = (
        IntervalMonthDayNano::new(0, 0, 0),
        IntervalMonthDayNano::new(2, 0, 0),
        IntervalMonthDayNano::new(-15, 0, 0),
    );
    check_take!(
        IntervalMonthDayNanoType,
        vec![Some(m0), None, Some(m2), Some(m_neg), None],
        vec![Some(m_neg), None, None, Some(m_neg), Some(m2)]
    );

    // Duration types at every unit.
    check_take!(
        DurationSecondType,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );
    check_take!(
        DurationMillisecondType,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );
    check_take!(
        DurationMicrosecondType,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );
    check_take!(
        DurationNanosecondType,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );

    // Floating-point types.
    check_take!(
        Float32Type,
        vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
        vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)]
    );
    check_take!(
        Float64Type,
        vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
        vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)]
    );
}
1406
/// `take` must carry the input's timestamp timezone through to the output type.
#[test]
fn test_take_preserve_timezone() {
    let index = Int64Array::from(vec![Some(0), None]);

    let input = TimestampNanosecondArray::from(vec![
        1_639_715_368_000_000_000,
        1_639_715_368_000_000_000,
    ])
    .with_timezone("UTC".to_string());
    let result = take(&input, &index, None).unwrap();

    // Compare the whole data type. On mismatch this reports both sides
    // instead of hitting an unexplained `panic!()`.
    assert_eq!(
        result.data_type(),
        &DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into()))
    );
}
1424
/// Runs `test_take_impl_primitive_arrays` with Int64 indices, one of them null,
/// across several value types.
#[test]
fn test_take_impl_primitive_with_int64_indices() {
    let index = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    // Takes the shared Int64 `index` from `$values` typed as `$t`.
    macro_rules! check_take_impl {
        ($t:ty, $values:expr, $expected:expr) => {
            test_take_impl_primitive_arrays::<$t, Int64Type>($values, &index, None, $expected)
        };
    }

    check_take_impl!(
        Int16Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );
    check_take_impl!(
        Int64Type,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );
    check_take_impl!(
        UInt64Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );
    check_take_impl!(
        DurationMillisecondType,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );
    check_take_impl!(
        Float32Type,
        vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
        vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)]
    );
}
1469
/// Runs `test_take_impl_primitive_arrays` with UInt8 indices, one of them null,
/// across several value types.
#[test]
fn test_take_impl_primitive_with_uint8_indices() {
    let index = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    // Takes the shared UInt8 `index` from `$values` typed as `$t`.
    macro_rules! check_take_impl {
        ($t:ty, $values:expr, $expected:expr) => {
            test_take_impl_primitive_arrays::<$t, UInt8Type>($values, &index, None, $expected)
        };
    }

    check_take_impl!(
        Int16Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );
    check_take_impl!(
        DurationMillisecondType,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );
    check_take_impl!(
        Float32Type,
        vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
        vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)]
    );
}
1498
/// Boolean `take` with a null index slot and a slot that selects a null value.
#[test]
fn test_take_bool() {
    let indices = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
    let values = vec![Some(false), None, Some(true), Some(false), None];
    let expected = vec![Some(false), None, None, Some(false), Some(true)];

    test_take_boolean_arrays(values, &indices, None, expected);
}
1510
/// Boolean `take` with a raw-built index whose null slots hold out-of-range
/// values (99, 999, 9999) for the 3-element input.
#[test]
fn test_take_bool_nullable_index() {
    // Only odd slots are valid.
    let validity = Buffer::from_iter(vec![false, true, false, true, false, true]);
    let raw_indices = Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2]);
    let index_data =
        ArrayData::try_new(DataType::UInt32, 6, Some(validity), 0, vec![raw_indices], vec![])
            .unwrap();
    let index = UInt32Array::from(index_data);

    let values = vec![Some(true), None, Some(false)];
    let expected = vec![None, Some(true), None, None, None, Some(false)];
    test_take_boolean_arrays(values, &index, None, expected);
}
1533
/// Same raw nullable index as `test_take_bool_nullable_index`, but every
/// value is non-null, so output nulls come only from the index.
#[test]
fn test_take_bool_nullable_index_nonnull_values() {
    // Only odd slots are valid; even slots hold out-of-range values.
    let validity = Buffer::from_iter(vec![false, true, false, true, false, true]);
    let raw_indices = Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2]);
    let index_data =
        ArrayData::try_new(DataType::UInt32, 6, Some(validity), 0, vec![raw_indices], vec![])
            .unwrap();
    let index = UInt32Array::from(index_data);

    let values = vec![Some(true), Some(true), Some(false)];
    let expected = vec![None, Some(true), None, Some(true), None, Some(false)];
    test_take_boolean_arrays(values, &index, None, expected);
}
1556
/// Boolean `take` through a sliced (offset) index.
#[test]
fn test_take_bool_with_offset() {
    let indices = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]);
    // The slice views [1, 3, 2, null].
    let sliced = indices.slice(2, 4);
    let index = sliced
        .as_any()
        .downcast_ref::<PrimitiveArray<UInt32Type>>()
        .unwrap();

    test_take_boolean_arrays(
        vec![Some(false), None, Some(true), Some(false), None],
        index,
        None,
        vec![None, Some(false), Some(true), None],
    );
}
1574
1575 fn _test_take_string<'a, K>()
1576 where
1577 K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static,
1578 {
1579 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]);
1580
1581 let array = K::from(vec![
1582 Some("one"),
1583 None,
1584 Some("three"),
1585 Some("four"),
1586 Some("five"),
1587 ]);
1588 let actual = take(&array, &index, None).unwrap();
1589 assert_eq!(actual.len(), index.len());
1590
1591 let actual = actual.as_any().downcast_ref::<K>().unwrap();
1592
1593 let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]);
1594
1595 assert_eq!(actual, &expected);
1596 }
1597
/// `StringArray` (i32 offsets) instantiation of `_test_take_string`.
#[test]
fn test_take_string() {
    _test_take_string::<StringArray>();
}
1602
/// `LargeStringArray` (i64 offsets) instantiation of `_test_take_string`.
#[test]
fn test_take_large_string() {
    _test_take_string::<LargeStringArray>();
}
1607
/// String `take` where the index itself is a slice with a non-zero offset.
#[test]
fn test_take_slice_string() {
    let strings = StringArray::from(vec![Some("hello"), None, Some("world"), None, Some("hi")]);
    // After slicing, the effective indices are [1, null, 0, 2].
    let indices = Int32Array::from(vec![Some(0), Some(1), None, Some(0), Some(2)]).slice(1, 4);

    let taken = take(&strings, &indices, None).unwrap();
    let expected = StringArray::from(vec![None, None, Some("hello"), Some("world")]);
    assert_eq!(taken.as_ref(), &expected);
}
1617
1618 fn _test_byte_view<T>()
1619 where
1620 T: ByteViewType,
1621 str: AsRef<T::Native>,
1622 T::Native: PartialEq,
1623 {
1624 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]);
1625 let array = {
1626 let mut builder = GenericByteViewBuilder::<T>::new();
1628 builder.append_value("hello");
1629 builder.append_value("world");
1630 builder.append_null();
1631 builder.append_value("large payload over 12 bytes");
1632 builder.append_value("lulu");
1633 builder.finish()
1634 };
1635
1636 let actual = take(&array, &index, None).unwrap();
1637
1638 assert_eq!(actual.len(), index.len());
1639
1640 let expected = {
1641 let mut builder = GenericByteViewBuilder::<T>::new();
1643 builder.append_value("large payload over 12 bytes");
1644 builder.append_null();
1645 builder.append_value("world");
1646 builder.append_value("large payload over 12 bytes");
1647 builder.append_value("lulu");
1648 builder.append_null();
1649 builder.finish()
1650 };
1651
1652 assert_eq!(actual.as_ref(), &expected);
1653 }
1654
/// `StringViewType` instantiation of `_test_byte_view`.
#[test]
fn test_take_string_view() {
    _test_byte_view::<StringViewType>();
}
1659
/// `BinaryViewType` instantiation of `_test_byte_view`.
#[test]
fn test_take_binary_view() {
    _test_byte_view::<BinaryViewType>();
}
1664
/// Takes from a list array (`$list_array_type`, `$offset_type` offsets,
/// non-nullable Int32 child) using an index with one null slot. The result is
/// compared with a hand-built array whose validity copies the index's.
macro_rules! test_take_list {
    ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
        // Child values split into lists [0, 0, 0], [-1, -2, -1], [], [2, 3].
        let child = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
        let offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
        let field = Arc::new(Field::new_list_field(DataType::Int32, false));
        let list_type = DataType::$list_data_type(field);
        let source = $list_array_type::from(
            ArrayData::builder(list_type.clone())
                .len(4)
                .add_buffer(Buffer::from_slice_ref(&offsets))
                .add_child_data(child)
                .build()
                .unwrap(),
        );

        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]);

        let taken = take(&source, &index, None).unwrap();
        let taken: &$list_array_type = taken.as_any().downcast_ref::<$list_array_type>().unwrap();

        // Expected lists: [2, 3], null, [-1, -2, -1], [], [0, 0, 0].
        let expected_child = Int32Array::from(vec![
            Some(2),
            Some(3),
            Some(-1),
            Some(-2),
            Some(-1),
            Some(0),
            Some(0),
            Some(0),
        ])
        .into_data();
        let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8];
        let expected = $list_array_type::from(
            ArrayData::builder(list_type)
                .len(5)
                .nulls(index.nulls().cloned())
                .add_buffer(Buffer::from_slice_ref(&expected_offsets))
                .add_child_data(expected_child)
                .build()
                .unwrap(),
        );

        assert_eq!(taken, &expected);
    }};
}
1719
/// Takes from a list array whose lists are all valid but whose child values
/// contain nulls. The index has one null slot.
macro_rules! test_take_list_with_value_nulls {
    ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
        // Lists: [0, null, 0], [-1, -2, 3], [null], [5, null]. All four are valid.
        let child = Int32Array::from(vec![
            Some(0),
            None,
            Some(0),
            Some(-1),
            Some(-2),
            Some(3),
            None,
            Some(5),
            None,
        ])
        .into_data();
        let offsets: [$offset_type; 5] = [0, 3, 6, 7, 9];
        let field = Arc::new(Field::new_list_field(DataType::Int32, true));
        let list_type = DataType::$list_data_type(field);
        let source = $list_array_type::from(
            ArrayData::builder(list_type.clone())
                .len(4)
                .add_buffer(Buffer::from_slice_ref(&offsets))
                .null_bit_buffer(Some(Buffer::from([0b11111111])))
                .add_child_data(child)
                .build()
                .unwrap(),
        );

        let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);

        let taken = take(&source, &index, None).unwrap();
        let taken: &$list_array_type = taken.as_any().downcast_ref::<$list_array_type>().unwrap();

        // Expected lists: [null], null, [-1, -2, 3], [5, null], [0, null, 0].
        let expected_child = Int32Array::from(vec![
            None,
            Some(-1),
            Some(-2),
            Some(3),
            Some(5),
            None,
            Some(0),
            None,
            Some(0),
        ])
        .into_data();
        let expected_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9];
        let expected = $list_array_type::from(
            ArrayData::builder(list_type)
                .len(5)
                .nulls(index.nulls().cloned())
                .add_buffer(Buffer::from_slice_ref(&expected_offsets))
                .add_child_data(expected_child)
                .build()
                .unwrap(),
        );

        assert_eq!(taken, &expected);
    }};
}
1787
/// Takes from a list array that has a null list (position 2) and nulls in its
/// child values. The index also has a null slot.
macro_rules! test_take_list_with_nulls {
    ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
        // Lists: [0, null, 0], [-1, -2, 3], null (empty span), [5, null].
        let child = Int32Array::from(vec![
            Some(0),
            None,
            Some(0),
            Some(-1),
            Some(-2),
            Some(3),
            Some(5),
            None,
        ])
        .into_data();
        let offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
        let field = Arc::new(Field::new_list_field(DataType::Int32, true));
        let list_type = DataType::$list_data_type(field);
        let source = $list_array_type::from(
            ArrayData::builder(list_type.clone())
                .len(4)
                .add_buffer(Buffer::from_slice_ref(&offsets))
                .null_bit_buffer(Some(Buffer::from([0b11111011])))
                .add_child_data(child)
                .build()
                .unwrap(),
        );

        let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);

        let taken = take(&source, &index, None).unwrap();
        let taken: &$list_array_type = taken.as_any().downcast_ref::<$list_array_type>().unwrap();

        // Expected lists: null (null list taken), null (null index),
        // [-1, -2, 3], [5, null], [0, null, 0].
        let expected_child = Int32Array::from(vec![
            Some(-1),
            Some(-2),
            Some(3),
            Some(5),
            None,
            Some(0),
            None,
            Some(0),
        ])
        .into_data();
        let expected_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8];
        // Only output slots 2, 3 and 4 are valid.
        let mut validity: [u8; 1] = [0; 1];
        for bit in 2..=4 {
            bit_util::set_bit(&mut validity, bit);
        }
        let expected = $list_array_type::from(
            ArrayData::builder(list_type)
                .len(5)
                .null_bit_buffer(Some(Buffer::from(validity)))
                .add_buffer(Buffer::from_slice_ref(&expected_offsets))
                .add_child_data(expected_child)
                .build()
                .unwrap(),
        );

        assert_eq!(taken, &expected);
    }};
}
1857
1858 fn test_take_list_view_generic<OffsetType: OffsetSizeTrait, ValuesType: ArrowPrimitiveType, F>(
1859 values: Vec<Option<Vec<Option<ValuesType::Native>>>>,
1860 take_indices: Vec<Option<usize>>,
1861 expected: Vec<Option<Vec<Option<ValuesType::Native>>>>,
1862 mapper: F,
1863 ) where
1864 F: Fn(GenericListViewArray<OffsetType>) -> GenericListViewArray<OffsetType>,
1865 {
1866 let mut list_view_array =
1867 GenericListViewBuilder::<OffsetType, _>::new(PrimitiveBuilder::<ValuesType>::new());
1868
1869 for value in values {
1870 list_view_array.append_option(value);
1871 }
1872 let list_view_array = list_view_array.finish();
1873 let list_view_array = mapper(list_view_array);
1874
1875 let mut indices = UInt64Builder::new();
1876 for idx in take_indices {
1877 indices.append_option(idx.map(|i| i.to_u64().unwrap()));
1878 }
1879 let indices = indices.finish();
1880
1881 let taken = take(&list_view_array, &indices, None)
1882 .unwrap()
1883 .as_list_view()
1884 .clone();
1885
1886 let mut expected_array =
1887 GenericListViewBuilder::<OffsetType, _>::new(PrimitiveBuilder::<ValuesType>::new());
1888 for value in expected {
1889 expected_array.append_option(value);
1890 }
1891 let expected_array = expected_array.finish();
1892
1893 assert_eq!(taken, expected_array);
1894 }
1895
/// Runs `test_take_list_view_generic` with Int8 values for both `i32` and
/// `i64` offsets. The optional `transform` is applied to the source array
/// before taking. Without it, the identity closure is used.
macro_rules! list_view_test_case {
    (values: $values:expr, indices: $indices:expr, expected: $expected: expr) => {{
        list_view_test_case!(
            values: $values,
            transform: |x| x,
            indices: $indices,
            expected: $expected
        );
    }};
    (values: $values:expr, transform: $fn:expr, indices: $indices:expr, expected: $expected: expr) => {{
        test_take_list_view_generic::<i32, Int8Type, _>($values, $indices, $expected, $fn);
        test_take_list_view_generic::<i64, Int8Type, _>($values, $indices, $expected, $fn);
    }};
}
1906
1907 fn do_take_fixed_size_list_test<T>(
1908 length: <Int32Type as ArrowPrimitiveType>::Native,
1909 input_data: Vec<Option<Vec<Option<T::Native>>>>,
1910 indices: Vec<<UInt32Type as ArrowPrimitiveType>::Native>,
1911 expected_data: Vec<Option<Vec<Option<T::Native>>>>,
1912 ) where
1913 T: ArrowPrimitiveType,
1914 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1915 {
1916 let indices = UInt32Array::from(indices);
1917
1918 let input_array = FixedSizeListArray::from_iter_primitive::<T, _, _>(input_data, length);
1919
1920 let output = take_fixed_size_list(&input_array, &indices, length as u32).unwrap();
1921
1922 let expected = FixedSizeListArray::from_iter_primitive::<T, _, _>(expected_data, length);
1923
1924 assert_eq!(&output, &expected)
1925 }
1926
/// `ListArray` (i32 offsets) instantiation of `test_take_list!`.
#[test]
fn test_take_list() {
    test_take_list!(i32, List, ListArray);
}
1931
/// `LargeListArray` (i64 offsets) instantiation of `test_take_list!`.
#[test]
fn test_take_large_list() {
    test_take_list!(i64, LargeList, LargeListArray);
}
1936
/// `ListArray` instantiation of `test_take_list_with_value_nulls!`.
#[test]
fn test_take_list_with_value_nulls() {
    test_take_list_with_value_nulls!(i32, List, ListArray);
}
1941
/// `LargeListArray` instantiation of `test_take_list_with_value_nulls!`.
#[test]
fn test_take_large_list_with_value_nulls() {
    test_take_list_with_value_nulls!(i64, LargeList, LargeListArray);
}
1946
/// `ListArray` instantiation of `test_take_list_with_nulls!`.
#[test]
fn test_test_take_list_with_nulls() {
    test_take_list_with_nulls!(i32, List, ListArray);
}
1951
/// `LargeListArray` instantiation of `test_take_list_with_nulls!`.
#[test]
fn test_test_take_large_list_with_nulls() {
    test_take_list_with_nulls!(i64, LargeList, LargeListArray);
}
1956
/// Reversed indices reverse the list views, including the null one.
#[test]
fn test_test_take_list_view_reversed() {
    list_view_test_case!(
        values: vec![
            Some(vec![Some(1), None, Some(3)]),
            None,
            Some(vec![Some(7), Some(8), None]),
        ],
        indices: vec![Some(2), Some(1), Some(0)],
        expected: vec![
            Some(vec![Some(7), Some(8), None]),
            None,
            Some(vec![Some(1), None, Some(3)]),
        ]
    );
}
1974
/// Null indices produce null list views in the output.
#[test]
fn test_take_list_view_null_indices() {
    list_view_test_case!(
        values: vec![
            Some(vec![Some(1), None, Some(3)]),
            None,
            Some(vec![Some(7), Some(8), None]),
        ],
        indices: vec![None, Some(0), None],
        expected: vec![None, Some(vec![Some(1), None, Some(3)]), None]
    );
}
1988
/// Repeatedly selecting the null list view, mixed with null indices, yields
/// only nulls.
#[test]
fn test_take_list_view_null_values() {
    list_view_test_case!(
        values: vec![
            Some(vec![Some(1), None, Some(3)]),
            None,
            Some(vec![Some(7), Some(8), None]),
        ],
        indices: vec![Some(1), Some(1), Some(1), None, None],
        expected: vec![None; 5]
    );
}
2002
/// The source is sliced to [null, [2, 3], [4, 5], null] before taking, so the
/// indices are relative to that slice.
#[test]
fn test_take_list_view_sliced() {
    list_view_test_case!(
        values: vec![
            Some(vec![Some(1)]),
            None,
            None,
            Some(vec![Some(2), Some(3)]),
            Some(vec![Some(4), Some(5)]),
            None,
        ],
        transform: |l| l.slice(2, 4),
        indices: vec![Some(0), Some(3), None, Some(1), Some(2)],
        expected: vec![
            None,
            None,
            None,
            Some(vec![Some(2), Some(3)]),
            Some(vec![Some(4), Some(5)]),
        ]
    );
}
2022
/// `take_fixed_size_list` across list widths 3 and 1, with nulls inside lists
/// and a null list that is selected twice.
#[test]
fn test_take_fixed_size_list() {
    // Width-3 Int32 lists with element nulls, reversed.
    do_take_fixed_size_list_test::<Int32Type>(
        3,
        vec![
            Some(vec![None, Some(1), Some(2)]),
            Some(vec![Some(3), Some(4), None]),
            Some(vec![Some(6), Some(7), Some(8)]),
        ],
        vec![2, 1, 0],
        vec![
            Some(vec![Some(6), Some(7), Some(8)]),
            Some(vec![Some(3), Some(4), None]),
            Some(vec![None, Some(1), Some(2)]),
        ],
    );

    // Width-1 UInt8 lists [1] ..= [8]; pick positions 2, 7 and 0.
    let singles: Vec<Option<Vec<Option<u8>>>> = (1..=8).map(|v| Some(vec![Some(v)])).collect();
    do_take_fixed_size_list_test::<UInt8Type>(
        1,
        singles,
        vec![2, 7, 0],
        vec![Some(vec![Some(3)]), Some(vec![Some(8)]), Some(vec![Some(1)])],
    );

    // Width-3 UInt64 lists where position 2 is a null list.
    do_take_fixed_size_list_test::<UInt64Type>(
        3,
        vec![
            Some(vec![Some(10), Some(11), Some(12)]),
            Some(vec![Some(13), Some(14), Some(15)]),
            None,
            Some(vec![Some(16), Some(17), Some(18)]),
        ],
        vec![3, 2, 1, 2, 0],
        vec![
            Some(vec![Some(16), Some(17), Some(18)]),
            None,
            Some(vec![Some(13), Some(14), Some(15)]),
            None,
            Some(vec![Some(10), Some(11), Some(12)]),
        ],
    );
}
2078
/// An index far past the end of a 3-list array, with no `TakeOptions`, is
/// expected to panic (see `should_panic`) rather than return an error.
#[test]
#[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
fn test_take_list_out_of_bounds() {
    let child = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
    let offsets = Buffer::from_slice_ref([0, 3, 6, 8]);
    let list_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false)));
    let list_array = ListArray::from(
        ArrayData::builder(list_type)
            .len(3)
            .add_buffer(offsets)
            .add_child_data(child)
            .build()
            .unwrap(),
    );

    let too_far = UInt32Array::from(vec![1000]);
    take(&list_array, &too_far, None).unwrap();
}
2103
/// Taking the first of two maps keeps only that map's entries and values.
#[test]
fn test_take_map() {
    let values = Int32Array::from(vec![1, 2, 3, 4]);
    // Two maps: {a: 1, b: 2, c: 3} and {a: 4}.
    let maps =
        MapArray::new_from_strings(vec!["a", "b", "c", "a"].into_iter(), &values, &[0, 3, 4])
            .unwrap();

    let taken = take(&maps, &UInt32Array::from(vec![0]), None).unwrap();

    let first_map =
        MapArray::new_from_strings(vec!["a", "b", "c"].into_iter(), &values.slice(0, 3), &[0, 3])
            .unwrap();
    let expected: ArrayRef = Arc::new(first_map);
    assert_eq!(&expected, &taken);
}
2124
/// Takes rows, including repeats and a null row, from a struct array. Then
/// checks a struct with no fields, where only validity is carried through.
#[test]
fn test_take_struct() {
    let source = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(false), Some(28))),
        Some((Some(false), Some(19))),
        Some((Some(true), Some(31))),
        None,
    ]);

    let indices = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
    let taken = take(&source, &indices, None).unwrap();
    let taken = taken.as_struct();
    assert_eq!(indices.len(), taken.len());
    assert_eq!(1, taken.null_count());

    let expected = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(true), Some(31))),
        Some((Some(false), Some(28))),
        Some((Some(true), Some(42))),
        Some((Some(false), Some(19))),
        None,
    ]);
    assert_eq!(&expected, taken);

    // Field-less struct: the output validity follows the selected rows.
    let validity = NullBuffer::from(&[false, true, false, true, false, true]);
    let empty_struct = StructArray::new_empty_fields(6, Some(validity));
    let taken = take(&empty_struct, &UInt32Array::from(vec![0, 2, 1, 4]), None).unwrap();

    let expected_validity = NullBuffer::from(&[false, false, true, false]);
    let expected_struct = StructArray::new_empty_fields(4, Some(expected_validity));
    assert_eq!(&expected_struct, taken.as_struct());
}
2161
/// Null index slots and the null source row (index 4) both produce null rows.
#[test]
fn test_take_struct_with_null_indices() {
    let source = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(false), Some(28))),
        Some((Some(false), Some(19))),
        Some((Some(true), Some(31))),
        None,
    ]);

    let indices = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);
    let taken = take(&source, &indices, None).unwrap();
    let taken = taken.as_struct();
    assert_eq!(indices.len(), taken.len());
    assert_eq!(3, taken.null_count());

    let expected = create_test_struct(vec![
        None,
        Some((Some(true), Some(31))),
        Some((Some(false), Some(28))),
        None,
        Some((Some(true), Some(42))),
        None,
    ]);
    assert_eq!(&expected, taken);
}
2189
/// With `check_bounds` enabled, index 6 into 5 values must come back as an
/// error rather than a panic.
#[test]
fn test_take_out_of_bounds() {
    let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]);
    let take_opt = TakeOptions { check_bounds: true };

    let result = test_take_primitive_arrays::<Int64Type>(
        vec![Some(0), None, Some(2), Some(3), None],
        &index,
        Some(take_opt),
        vec![None],
    );

    // Pin the specific bounds error, so an unrelated failure cannot satisfy the test.
    let err = result.expect_err("take with check_bounds should reject index 6");
    assert_eq!(
        err.to_string(),
        "Compute error: Array index out of bounds, cannot get item at index 6 from 5 entries"
    );
}
2204
/// With no `TakeOptions`, an index of 1000 into 4 values is expected to panic
/// (see `should_panic`) instead of returning an error.
#[test]
#[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
fn test_take_out_of_bounds_panic() {
    let too_far = UInt32Array::from(vec![Some(1000)]);
    let values = vec![Some(0), Some(1), Some(2), Some(3)];

    test_take_primitive_arrays::<Int64Type>(values, &too_far, None, vec![None]).unwrap();
}
2218
/// Taking from a 2-element `NullArray` yields a `NullArray` as long as the
/// index. Index 15 is accepted because no bounds check is requested.
#[test]
fn test_null_array_smaller_than_indices() {
    let source = NullArray::new(2);
    let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);

    let taken = take(&source, &indices, None).unwrap();
    let expected: ArrayRef = Arc::new(NullArray::new(indices.len()));
    assert_eq!(&taken, &expected);
}
2228
/// Taking three slots from a 5-element `NullArray` yields a 3-element
/// `NullArray`. No bounds check is requested.
#[test]
fn test_null_array_larger_than_indices() {
    let source = NullArray::new(5);
    let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);

    let taken = take(&source, &indices, None).unwrap();
    let expected: ArrayRef = Arc::new(NullArray::new(indices.len()));
    assert_eq!(&taken, &expected);
}
2238
/// With `check_bounds`, even a `NullArray` source rejects an out-of-range index.
#[test]
fn test_null_array_indices_out_of_bounds() {
    let source = NullArray::new(5);
    let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
    let options = TakeOptions { check_bounds: true };

    let message = take(&source, &indices, Some(options))
        .unwrap_err()
        .to_string();
    assert_eq!(
        message,
        "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries"
    );
}
2250
/// `take` on a string dictionary remaps the keys; the dictionary values are
/// expected to stay as they were.
#[test]
fn test_take_dict() {
    let mut builder = StringDictionaryBuilder::<Int16Type>::new();
    for entry in [
        Some("foo"),
        Some("bar"),
        Some(""),
        None,
        Some("foo"),
        Some("bar"),
        Some("bar"),
        Some("foo"),
    ] {
        match entry {
            Some(s) => {
                builder.append(s).unwrap();
            }
            None => builder.append_null(),
        }
    }

    let array = builder.finish();
    let dict_values = array.values().as_string::<i32>();

    let indices = UInt32Array::from(vec![
        Some(0),
        Some(7),
        None,
        Some(5),
        Some(6),
        Some(2),
        Some(3),
    ]);

    let taken = take(&array, &indices, None).unwrap();
    let taken = taken.as_dictionary::<Int16Type>();
    let taken_values = taken.values().as_string::<i32>();

    let expected_values = StringArray::from(vec!["foo", "bar", ""]);
    assert_eq!(&expected_values, dict_values);
    assert_eq!(&expected_values, taken_values);

    // Source keys are [0, 1, 2, null, 0, 1, 1, 0]; index 3 hits the null key.
    let expected_keys = Int16Array::from(vec![
        Some(0),
        Some(0),
        None,
        Some(1),
        Some(1),
        Some(2),
        None,
    ]);
    assert_eq!(taken.keys(), &expected_keys);
}
2302
2303 fn build_generic_list<S, T>(data: Vec<Option<Vec<T::Native>>>) -> GenericListArray<S>
2304 where
2305 S: OffsetSizeTrait + 'static,
2306 T: ArrowPrimitiveType,
2307 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2308 {
2309 GenericListArray::from_iter_primitive::<T, _, _>(
2310 data.iter()
2311 .map(|x| x.as_ref().map(|x| x.iter().map(|x| Some(*x)))),
2312 )
2313 }
2314
/// Child value indices for a 32-bit-offset list: list 2 covers values 5..=9
/// and list 0 covers values 0..=1.
#[test]
fn test_take_value_index_from_list() {
    let list = build_generic_list::<i32, Int32Type>(vec![
        Some(vec![0, 1]),
        Some(vec![2, 3, 4]),
        Some(vec![5, 6, 7, 8, 9]),
    ]);
    let indices = UInt32Array::from(vec![2, 0]);

    let (value_indices, offsets, validity) =
        take_value_indices_from_list(&list, &indices).unwrap();

    assert_eq!(value_indices, Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
    assert_eq!(offsets, vec![0, 5, 7]);
    assert_eq!(validity.as_slice(), &[0b11111111]);
}
2330
/// 64-bit-offset counterpart of `test_take_value_index_from_list`.
#[test]
fn test_take_value_index_from_large_list() {
    let list = build_generic_list::<i64, Int32Type>(vec![
        Some(vec![0, 1]),
        Some(vec![2, 3, 4]),
        Some(vec![5, 6, 7, 8, 9]),
    ]);
    let indices = UInt32Array::from(vec![2, 0]);

    let (value_indices, offsets, validity) =
        take_value_indices_from_list::<_, Int64Type>(&list, &indices).unwrap();

    assert_eq!(value_indices, Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
    assert_eq!(offsets, vec![0, 5, 7]);
    assert_eq!(validity.as_slice(), &[0b11111111]);
}
2347
/// `take_run` on a run-end encoded Int32 array. Per the expected run ends
/// [1, 3, 4, 5, 7], picks 2,3 share a run and so do picks 4,6.
#[test]
fn test_take_runs() {
    let logical: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];

    let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
    builder.extend(logical.into_iter().map(Some));
    let runs = builder.finish();

    let picks: PrimitiveArray<Int32Type> = vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect();
    let taken = take_run(&runs, &picks).unwrap();

    assert_eq!(taken.len(), 7);
    assert_eq!(taken.run_ends().len(), 7);
    assert_eq!(taken.run_ends().values(), &[1_i32, 3, 4, 5, 7]);

    let run_values = taken.values().as_primitive::<Int32Type>();
    assert_eq!(run_values.values(), &[2, 2, 2, 2, 1]);
}
2368
/// For width-3 fixed-size lists, taking list `i` expands to child indices
/// `3*i .. 3*i + 3`. This also holds for the null list at position 2.
#[test]
fn test_take_value_index_from_fixed_list() {
    let list = FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
        vec![
            Some(vec![Some(1), Some(2), None]),
            Some(vec![Some(4), None, Some(6)]),
            None,
            Some(vec![None, Some(8), Some(9)]),
        ],
        3,
    );

    let reversed =
        take_value_indices_from_fixed_size_list(&list, &UInt32Array::from(vec![2, 1, 0]), 3)
            .unwrap();
    assert_eq!(reversed, UInt32Array::from(vec![6, 7, 8, 3, 4, 5, 0, 1, 2]));

    let with_repeats = take_value_indices_from_fixed_size_list(
        &list,
        &UInt32Array::from(vec![3, 2, 1, 2, 0]),
        3,
    )
    .unwrap();
    assert_eq!(
        with_repeats,
        UInt32Array::from(vec![9, 10, 11, 6, 7, 8, 3, 4, 5, 6, 7, 8, 0, 1, 2])
    );
}
2394
/// The index values under null slots (400) are out of range for the 4-element
/// input; the result is expected to be null there.
#[test]
fn test_take_null_indices() {
    let indices = Int32Array::new(
        vec![1, 2, 400, 400].into(),
        Some(NullBuffer::from(vec![true, true, false, false])),
    );
    let source = Int32Array::from(vec![1, 23, 4, 5]);

    let taken = take(&source, &indices, None).unwrap();
    let collected: Vec<_> = taken.as_primitive::<Int32Type>().iter().collect();
    assert_eq!(&collected, &[Some(23), Some(4), None, None])
}
2410
#[test]
fn test_take_fixed_size_list_null_indices() {
    // Two fixed-size lists of width 2: [0, 1] and [2, 3].
    let values = Arc::new(Int32Array::from(vec![0, 1, 2, 3]));
    let arr_field = Arc::new(Field::new_list_field(values.data_type().clone(), true));
    let values = FixedSizeListArray::try_new(arr_field, 2, values, None).unwrap();

    // The second index is null.
    let indices = Int32Array::from_iter([Some(0), None]);
    let r = take(&values, &indices, None).unwrap();

    // The null index must show up in the outer list validity, not only in the
    // flattened child values; otherwise a dropped null bitmap would go unnoticed.
    let taken = r.as_fixed_size_list();
    assert_eq!(taken.len(), 2);
    assert!(taken.is_valid(0));
    assert!(taken.is_null(1));

    // Child values gathered for the null index are emitted as nulls.
    let values = taken
        .values()
        .as_primitive::<Int32Type>()
        .into_iter()
        .collect::<Vec<_>>();
    assert_eq!(values, &[Some(0), Some(1), None, None])
}
2427
#[test]
fn test_take_bytes_null_indices() {
    let source = StringArray::from(vec![Some("foo"), None]);
    // Only the first two indices are valid; the masked 400s are out of range
    // and must not be looked up.
    let mask = NullBuffer::from_iter(vec![true, true, false, false]);
    let indices = Int32Array::new(vec![0, 1, 400, 400].into(), Some(mask));

    let taken = take(&source, &indices, None).unwrap();
    let strings: Vec<Option<&str>> = taken.as_string::<i32>().iter().collect();
    assert_eq!(&strings, &[Some("foo"), None, None, None])
}
2439
#[test]
fn test_take_union_sparse() {
    let structs = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(false), Some(28))),
        Some((Some(false), Some(19))),
        Some((Some(true), Some(31))),
        None,
    ]);
    let strings = StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);
    // Every slot selects type id 1, i.e. the string child.
    let type_ids = [1; 5].into_iter().collect::<ScalarBuffer<i8>>();

    let union_fields = [
        (
            0,
            Arc::new(Field::new("f1", structs.data_type().clone(), true)),
        ),
        (
            1,
            Arc::new(Field::new("f2", strings.data_type().clone(), true)),
        ),
    ]
    .into_iter()
    .collect();
    let children = vec![Arc::new(structs) as Arc<dyn Array>, Arc::new(strings)];
    let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();

    let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
    let actual = take(&array, &index, None).unwrap();
    let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();

    // In a sparse union every child has the union's length, and each output
    // slot keeps its type id, which here is always the string child.
    assert_eq!(actual.len(), 6);
    assert_eq!(actual.child(0).len(), 6);
    assert_eq!(actual.child(1).len(), 6);
    assert!(actual.type_ids().iter().all(|&type_id| type_id == 1));

    let strings = actual.child(1);
    let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();

    let actual = strings.iter().collect::<Vec<_>>();
    let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
    assert_eq!(expected, actual);
}
2478
#[test]
fn test_take_union_dense() {
    // Source union, slot -> (type id, child[offset]):
    //   0: ints[0]=10     1: strings[0]="a"  2: strings[1]=None  3: ints[1]=20
    //   4: ints[2]=30     5: strings[2]="c"  6: ints[3]=40
    let ints = UInt32Array::from(vec![10, 20, 30, 40]);
    let strings = StringArray::from(vec![Some("a"), None, Some("c"), Some("d")]);
    let type_ids = ScalarBuffer::<i8>::from(vec![0, 1, 1, 0, 0, 1, 0]);
    let offsets = ScalarBuffer::<i32>::from(vec![0, 0, 1, 1, 2, 2, 3]);

    let union_fields = [
        (0, Arc::new(Field::new("f1", ints.data_type().clone(), true))),
        (1, Arc::new(Field::new("f2", strings.data_type().clone(), true))),
    ]
    .into_iter()
    .collect();

    let array = UnionArray::try_new(
        union_fields,
        type_ids,
        Some(offsets),
        vec![Arc::new(ints), Arc::new(strings)],
    )
    .unwrap();

    // Gather slots 0, 3, 1, 0, 2, 4 -> 10, 20, "a", 10, None, 30.
    let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
    let taken = take(&array, &index, None).unwrap();
    let taken = taken.as_any().downcast_ref::<UnionArray>().unwrap();

    // The output children hold only the referenced values, and the offsets are
    // renumbered per child in order of appearance.
    assert_eq!(
        taken.offsets(),
        Some(&ScalarBuffer::<i32>::from(vec![0, 1, 0, 2, 1, 3]))
    );
    assert_eq!(
        taken.type_ids(),
        &ScalarBuffer::<i8>::from(vec![0, 0, 1, 0, 1, 0])
    );
    assert_eq!(
        UInt32Array::from(taken.child(0).to_data()),
        UInt32Array::from(vec![10, 20, 10, 30])
    );
    assert_eq!(
        StringArray::from(taken.child(1).to_data()),
        StringArray::from(vec![Some("a"), None])
    );
}
2535
#[test]
fn test_take_union_dense_using_builder() {
    // Source union: [a=1, b=3.0, a=4, a=5, b=2.0]
    let mut source = UnionBuilder::new_dense();
    source.append::<Int32Type>("a", 1).unwrap();
    source.append::<Float64Type>("b", 3.0).unwrap();
    source.append::<Int32Type>("a", 4).unwrap();
    source.append::<Int32Type>("a", 5).unwrap();
    source.append::<Float64Type>("b", 2.0).unwrap();
    let union = source.build().unwrap();

    // Taking rows 2, 0, 1, 2 must equal a union built directly from
    // [a=4, a=1, b=3.0, a=4].
    let mut reference = UnionBuilder::new_dense();
    reference.append::<Int32Type>("a", 4).unwrap();
    reference.append::<Int32Type>("a", 1).unwrap();
    reference.append::<Float64Type>("b", 3.0).unwrap();
    reference.append::<Int32Type>("a", 4).unwrap();
    let expected = reference.build().unwrap();

    let indices = UInt32Array::from(vec![2, 0, 1, 2]);
    let actual = take(&union, &indices, None).unwrap();

    assert_eq!(expected.to_data(), actual.to_data());
}
2564
#[test]
fn test_take_union_dense_all_match_issue_6206() {
    // Regression test for issue 6206: a dense union whose slots all belong to
    // the same (only) child.
    let fields = UnionFields::new(vec![0], vec![Field::new("a", DataType::Int64, false)]);
    let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));

    let array = UnionArray::try_new(
        fields,
        ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]),
        Some(ScalarBuffer::from_iter(0_i32..5)),
        vec![ints],
    )
    .unwrap();

    let indices = Int64Array::from(vec![0, 2, 4]);
    let taken = take(&array, &indices, None).unwrap();
    assert_eq!(taken.len(), 3);

    // Besides succeeding, the result must select the right rows. The child
    // holds only the gathered values and the offsets are renumbered from zero.
    let taken = taken.as_any().downcast_ref::<UnionArray>().unwrap();
    assert_eq!(taken.type_ids(), &ScalarBuffer::from(vec![0_i8, 0, 0]));
    assert_eq!(taken.offsets(), Some(&ScalarBuffer::from(vec![0_i32, 1, 2])));
    assert_eq!(
        taken.child(0).as_primitive::<Int64Type>().values(),
        &[1_i64, 3, 5]
    );
}
2582
#[test]
fn test_take_bytes_offset_overflow() {
    // One 1 MiB string taken 2048 times needs 2^31 bytes of value data, one
    // more than `i32::MAX`, so the i32 offsets of the output must overflow.
    // Using a large value instead of ~134M indices keeps the index array (and
    // the offsets buffer sized from it) at a few KiB instead of ~512 MiB each.
    const VALUE_LEN: usize = 1 << 20;
    const NUM_INDICES: usize = 1 << 11;

    let text = "a".repeat(VALUE_LEN);
    let indices = Int32Array::from(vec![0; NUM_INDICES]);
    let values = StringArray::from(vec![Some(text)]);
    assert!(matches!(
        take(&values, &indices, None),
        Err(ArrowError::OffsetOverflowError(_))
    ));
}
2593}