1use std::fmt::Display;
21use std::mem::ManuallyDrop;
22use std::sync::Arc;
23
24use arrow_array::builder::{BufferBuilder, UInt32Builder};
25use arrow_array::cast::AsArray;
26use arrow_array::types::*;
27use arrow_array::*;
28use arrow_buffer::{
29 ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, ScalarBuffer,
30 bit_util,
31};
32use arrow_data::ArrayDataBuilder;
33use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
34
35use num_traits::{One, Zero};
36
37pub fn take(
89 values: &dyn Array,
90 indices: &dyn Array,
91 options: Option<TakeOptions>,
92) -> Result<ArrayRef, ArrowError> {
93 let options = options.unwrap_or_default();
94 downcast_integer_array!(
95 indices => {
96 if options.check_bounds {
97 check_bounds(values.len(), indices)?;
98 }
99 let indices = indices.to_indices();
100 take_impl(values, &indices)
101 },
102 d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
103 )
104}
105
106pub fn take_arrays(
155 arrays: &[ArrayRef],
156 indices: &dyn Array,
157 options: Option<TakeOptions>,
158) -> Result<Vec<ArrayRef>, ArrowError> {
159 arrays
160 .iter()
161 .map(|array| take(array.as_ref(), indices, options.clone()))
162 .collect()
163}
164
165fn check_bounds<T: ArrowPrimitiveType>(
167 len: usize,
168 indices: &PrimitiveArray<T>,
169) -> Result<(), ArrowError>
170where
171 T::Native: Display,
172{
173 let len = match T::Native::from_usize(len) {
174 Some(len) => len,
175 None => {
176 if T::DATA_TYPE.is_integer() {
177 return Ok(());
179 } else {
180 return Err(ArrowError::ComputeError("Cast to usize failed".to_string()));
181 }
182 }
183 };
184
185 if indices.null_count() > 0 {
186 indices.iter().flatten().try_for_each(|index| {
187 if index >= len {
188 return Err(ArrowError::ComputeError(format!(
189 "Array index out of bounds, cannot get item at index {index} from {len} entries"
190 )));
191 }
192 Ok(())
193 })
194 } else {
195 let in_bounds = indices.values().iter().fold(true, |in_bounds, &i| {
196 in_bounds & (i >= T::Native::ZERO) & (i < len)
197 });
198
199 if !in_bounds {
200 for &index in indices.values() {
201 if index < T::Native::ZERO || index >= len {
202 return Err(ArrowError::ComputeError(format!(
203 "Array index out of bounds, cannot get item at index {index} from {len} entries"
204 )));
205 }
206 }
207 }
208
209 Ok(())
210 }
211}
212
/// Type-dispatching core of [`take`]. It gathers `values` at `indices` and
/// returns a new array of the same data type with `indices.len()` slots.
///
/// Indices are NOT bounds-checked here. Out-of-range indices are left to the
/// per-type kernels, several of which panic on them.
///
/// # Errors
///
/// Propagates errors from the per-type kernels (e.g. offset overflow).
///
/// # Panics
///
/// Panics via `unimplemented!` for data types, or dictionary / run-end key types,
/// that have no take kernel.
#[inline(never)]
fn take_impl<IndexType: ArrowPrimitiveType>(
    values: &dyn Array,
    indices: &PrimitiveArray<IndexType>,
) -> Result<ArrayRef, ArrowError> {
    // An empty selection is simply an empty array of the same type.
    if indices.is_empty() {
        return Ok(new_empty_array(values.data_type()));
    }
    downcast_primitive_array! {
        values => Ok(Arc::new(take_primitive(values, indices)?)),
        DataType::Boolean => {
            let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
            Ok(Arc::new(take_boolean(values, indices)))
        }
        DataType::Utf8 => {
            Ok(Arc::new(take_bytes(values.as_string::<i32>(), indices)?))
        }
        DataType::LargeUtf8 => {
            Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)?))
        }
        DataType::Utf8View => {
            Ok(Arc::new(take_byte_view(values.as_string_view(), indices)?))
        }
        DataType::List(_) => {
            Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?))
        }
        DataType::LargeList(_) => {
            Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?))
        }
        DataType::ListView(_) => {
            Ok(Arc::new(take_list_view::<_, Int32Type>(values.as_list_view(), indices)?))
        }
        DataType::LargeListView(_) => {
            Ok(Arc::new(take_list_view::<_, Int64Type>(values.as_list_view(), indices)?))
        }
        DataType::FixedSizeList(_, length) => {
            let values = values
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap();
            Ok(Arc::new(take_fixed_size_list(
                values,
                indices,
                *length as u32,
            )?))
        }
        DataType::Map(_, _) => {
            // Take through the equivalent ListArray, then re-tag the result
            // with the original map data type.
            let list_arr = ListArray::from(values.as_map().clone());
            let list_data = take_list::<_, Int32Type>(&list_arr, indices)?;
            let builder = list_data.into_data().into_builder().data_type(values.data_type().clone());
            // SAFETY: only the data type changes; the buffers and child are those
            // of the list produced by `take_list` from the map's own layout.
            Ok(Arc::new(MapArray::from(unsafe { builder.build_unchecked() })))
        }
        DataType::Struct(fields) => {
            let array: &StructArray = values.as_struct();
            // Take every child column with the same indices.
            let arrays = array
                .columns()
                .iter()
                .map(|a| take_impl(a.as_ref(), indices))
                .collect::<Result<Vec<ArrayRef>, _>>()?;
            let fields: Vec<(FieldRef, ArrayRef)> =
                fields.iter().cloned().zip(arrays).collect();

            // An output slot is valid iff its index is non-null and the source
            // struct slot it points at is valid.
            let is_valid: Buffer = indices
                .iter()
                .map(|index| {
                    if let Some(index) = index {
                        array.is_valid(index.to_usize().unwrap())
                    } else {
                        false
                    }
                })
                .collect();

            // A struct without fields cannot derive its length from children.
            if fields.is_empty() {
                let nulls = NullBuffer::new(BooleanBuffer::new(is_valid, 0, indices.len()));
                Ok(Arc::new(StructArray::new_empty_fields(indices.len(), Some(nulls))))
            } else {
                Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
            }
        }
        DataType::Dictionary(_, _) => downcast_dictionary_array! {
            values => Ok(Arc::new(take_dict(values, indices)?)),
            t => unimplemented!("Take not supported for dictionary type {:?}", t)
        }
        DataType::RunEndEncoded(_, _) => downcast_run_array! {
            values => Ok(Arc::new(take_run(values, indices)?)),
            t => unimplemented!("Take not supported for run type {:?}", t)
        }
        DataType::Binary => {
            Ok(Arc::new(take_bytes(values.as_binary::<i32>(), indices)?))
        }
        DataType::LargeBinary => {
            Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?))
        }
        DataType::BinaryView => {
            Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)?))
        }
        DataType::FixedSizeBinary(size) => {
            let values = values
                .as_any()
                .downcast_ref::<FixedSizeBinaryArray>()
                .unwrap();
            Ok(Arc::new(take_fixed_size_binary(values, indices, *size)?))
        }
        DataType::Null => {
            // Every slot of a Null array is null, so the result only needs the
            // requested length. Reuse a slice of the input when it is long enough.
            if values.len() >= indices.len() {
                Ok(values.slice(0, indices.len()))
            } else {
                Ok(new_null_array(&DataType::Null, indices.len()))
            }
        }
        DataType::Union(fields, UnionMode::Sparse) => {
            // Sparse children have the union's length, so each child is taken
            // with the same indices as the type ids.
            let mut children = Vec::with_capacity(fields.len());
            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
            let type_ids = take_native(values.type_ids(), indices);
            for (type_id, _field) in fields.iter() {
                let values = values.child(type_id);
                let values = take_impl(values, indices)?;
                children.push(values);
            }
            let array = UnionArray::try_new(fields.clone(), type_ids, None, children)?;
            Ok(Arc::new(array))
        }
        DataType::Union(fields, UnionMode::Dense) => {
            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();

            let type_ids = <PrimitiveArray<Int8Type>>::try_new(take_native(values.type_ids(), indices), None)?;
            let offsets = <PrimitiveArray<Int32Type>>::try_new(take_native(values.offsets().unwrap(), indices), None)?;

            // For each child, take the child offsets of the output slots whose
            // type id selects that child.
            let children = fields.iter()
                .map(|(field_type_id, _)| {
                    let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);

                    let indices = crate::filter::filter(&offsets, &mask)?;

                    let values = values.child(field_type_id);

                    take_impl(values, indices.as_primitive::<Int32Type>())
                })
                .collect::<Result<_, _>>()?;

            // Renumber offsets densely per child (0, 1, 2, ... in slot order).
            // Type ids are used as array indices, so this assumes they fall in
            // 0..128 (the range Arrow permits for union type ids).
            let mut child_offsets = [0; 128];

            let offsets = type_ids.values()
                .iter()
                .map(|&i| {
                    let offset = child_offsets[i as usize];

                    child_offsets[i as usize] += 1;

                    offset
                })
                .collect();

            let (_, type_ids, _) = type_ids.into_parts();

            let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)?;

            Ok(Arc::new(array))
        }
        t => unimplemented!("Take not supported for data type {:?}", t)
    }
}
381
/// Options controlling the behaviour of [`take`].
#[derive(Clone, Debug, Default)]
pub struct TakeOptions {
    /// When `true`, [`take`] validates every non-null index against the length of
    /// `values` up front and returns an error for an out-of-range index. When
    /// `false` (the default), no up-front check is made and the per-type kernels,
    /// several of which panic on out-of-range indices, handle it.
    pub check_bounds: bool,
}
390
391fn take_primitive<T, I>(
401 values: &PrimitiveArray<T>,
402 indices: &PrimitiveArray<I>,
403) -> Result<PrimitiveArray<T>, ArrowError>
404where
405 T: ArrowPrimitiveType,
406 I: ArrowPrimitiveType,
407{
408 let values_buf = take_native(values.values(), indices);
409 let nulls = take_nulls(values.nulls(), indices);
410 Ok(PrimitiveArray::try_new(values_buf, nulls)?.with_data_type(values.data_type().clone()))
411}
412
413#[inline(never)]
414fn take_nulls<I: ArrowPrimitiveType>(
415 values: Option<&NullBuffer>,
416 indices: &PrimitiveArray<I>,
417) -> Option<NullBuffer> {
418 match values.filter(|n| n.null_count() > 0) {
419 Some(n) => NullBuffer::from_unsliced_buffer(
420 take_bits(n.inner(), indices).into_inner(),
421 indices.len(),
422 ),
423 None => indices.nulls().cloned(),
424 }
425}
426
427#[inline(never)]
428fn take_native<T: ArrowNativeType, I: ArrowPrimitiveType>(
429 values: &[T],
430 indices: &PrimitiveArray<I>,
431) -> ScalarBuffer<T> {
432 match indices.nulls().filter(|n| n.null_count() > 0) {
433 Some(n) => indices
434 .values()
435 .iter()
436 .enumerate()
437 .map(|(idx, index)| match values.get(index.as_usize()) {
438 Some(v) => *v,
439 None => match unsafe { n.inner().value_unchecked(idx) } {
441 false => T::default(),
442 true => panic!("Out-of-bounds index {index:?}"),
443 },
444 })
445 .collect(),
446 None => indices
447 .values()
448 .iter()
449 .map(|index| values[index.as_usize()])
450 .collect(),
451 }
452}
453
454#[inline(never)]
455fn take_bits<I: ArrowPrimitiveType>(
456 values: &BooleanBuffer,
457 indices: &PrimitiveArray<I>,
458) -> BooleanBuffer {
459 let len = indices.len();
460
461 match indices.nulls().filter(|n| n.null_count() > 0) {
462 Some(nulls) => {
463 let mut output_buffer = MutableBuffer::new_null(len);
464 let output_slice = output_buffer.as_slice_mut();
465 nulls.valid_indices().for_each(|idx| {
466 if values.value(unsafe { indices.value_unchecked(idx).as_usize() }) {
468 unsafe { bit_util::set_bit_raw(output_slice.as_mut_ptr(), idx) };
470 }
471 });
472 BooleanBuffer::new(output_buffer.into(), 0, len)
473 }
474 None => {
475 BooleanBuffer::collect_bool(len, |idx: usize| {
476 values.value(unsafe { indices.value_unchecked(idx).as_usize() })
478 })
479 }
480 }
481}
482
483fn take_boolean<IndexType: ArrowPrimitiveType>(
485 values: &BooleanArray,
486 indices: &PrimitiveArray<IndexType>,
487) -> BooleanArray {
488 let val_buf = take_bits(values.values(), indices);
489 let null_buf = take_nulls(values.nulls(), indices);
490 BooleanArray::new(val_buf, null_buf)
491}
492
493fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
495 array: &GenericByteArray<T>,
496 indices: &PrimitiveArray<IndexType>,
497) -> Result<GenericByteArray<T>, ArrowError> {
498 let mut offsets = Vec::with_capacity(indices.len() + 1);
499 offsets.push(T::Offset::default());
500
501 let input_offsets = array.value_offsets();
502 let mut capacity = 0;
503 let nulls = take_nulls(array.nulls(), indices);
504
505 let (offsets, values) = if array.null_count() == 0 && indices.null_count() == 0 {
506 offsets.reserve(indices.len());
507 for index in indices.values() {
508 let index = index.as_usize();
509 capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
510 offsets.push(
511 T::Offset::from_usize(capacity)
512 .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
513 );
514 }
515 let mut values = Vec::with_capacity(capacity);
516
517 for index in indices.values() {
518 values.extend_from_slice(array.value(index.as_usize()).as_ref());
519 }
520 (offsets, values)
521 } else if indices.null_count() == 0 {
522 offsets.reserve(indices.len());
523 for index in indices.values() {
524 let index = index.as_usize();
525 if array.is_valid(index) {
526 capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
527 }
528 offsets.push(
529 T::Offset::from_usize(capacity)
530 .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
531 );
532 }
533 let mut values = Vec::with_capacity(capacity);
534
535 for index in indices.values() {
536 let index = index.as_usize();
537 if array.is_valid(index) {
538 values.extend_from_slice(array.value(index).as_ref());
539 }
540 }
541 (offsets, values)
542 } else if array.null_count() == 0 {
543 offsets.reserve(indices.len());
544 for (i, index) in indices.values().iter().enumerate() {
545 let index = index.as_usize();
546 if indices.is_valid(i) {
547 capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
548 }
549 offsets.push(
550 T::Offset::from_usize(capacity)
551 .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
552 );
553 }
554 let mut values = Vec::with_capacity(capacity);
555
556 for (i, index) in indices.values().iter().enumerate() {
557 if indices.is_valid(i) {
558 values.extend_from_slice(array.value(index.as_usize()).as_ref());
559 }
560 }
561 (offsets, values)
562 } else {
563 let nulls = nulls.as_ref().unwrap();
564 offsets.reserve(indices.len());
565 for (i, index) in indices.values().iter().enumerate() {
566 let index = index.as_usize();
567 if nulls.is_valid(i) {
568 capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
569 }
570 offsets.push(
571 T::Offset::from_usize(capacity)
572 .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
573 );
574 }
575 let mut values = Vec::with_capacity(capacity);
576
577 for (i, index) in indices.values().iter().enumerate() {
578 let index = index.as_usize();
581 if nulls.is_valid(i) {
582 values.extend_from_slice(array.value(index).as_ref());
583 }
584 }
585 (offsets, values)
586 };
587
588 T::Offset::from_usize(values.len())
589 .ok_or_else(|| ArrowError::OffsetOverflowError(values.len()))?;
590
591 let array = unsafe {
592 let offsets = OffsetBuffer::new_unchecked(offsets.into());
593 GenericByteArray::<T>::new_unchecked(offsets, values.into(), nulls)
594 };
595
596 Ok(array)
597}
598
599fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>(
601 array: &GenericByteViewArray<T>,
602 indices: &PrimitiveArray<IndexType>,
603) -> Result<GenericByteViewArray<T>, ArrowError> {
604 let new_views = take_native(array.views(), indices);
605 let new_nulls = take_nulls(array.nulls(), indices);
606 Ok(unsafe {
608 GenericByteViewArray::new_unchecked(new_views, array.data_buffers().to_vec(), new_nulls)
609 })
610}
611
612fn take_list<IndexType, OffsetType>(
618 values: &GenericListArray<OffsetType::Native>,
619 indices: &PrimitiveArray<IndexType>,
620) -> Result<GenericListArray<OffsetType::Native>, ArrowError>
621where
622 IndexType: ArrowPrimitiveType,
623 OffsetType: ArrowPrimitiveType,
624 OffsetType::Native: OffsetSizeTrait,
625 PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
626{
627 let (list_indices, offsets, null_buf) =
630 take_value_indices_from_list::<IndexType, OffsetType>(values, indices)?;
631
632 let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices)?;
633 let value_offsets = Buffer::from_vec(offsets);
634 let list_data = ArrayDataBuilder::new(values.data_type().clone())
636 .len(indices.len())
637 .null_bit_buffer(Some(null_buf.into()))
638 .offset(0)
639 .add_child_data(taken.into_data())
640 .add_buffer(value_offsets);
641
642 let list_data = unsafe { list_data.build_unchecked() };
643
644 Ok(GenericListArray::<OffsetType::Native>::from(list_data))
645}
646
647fn take_list_view<IndexType, OffsetType>(
648 values: &GenericListViewArray<OffsetType::Native>,
649 indices: &PrimitiveArray<IndexType>,
650) -> Result<GenericListViewArray<OffsetType::Native>, ArrowError>
651where
652 IndexType: ArrowPrimitiveType,
653 OffsetType: ArrowPrimitiveType,
654 OffsetType::Native: OffsetSizeTrait,
655{
656 let taken_offsets = take_native(values.offsets(), indices);
657 let taken_sizes = take_native(values.sizes(), indices);
658 let nulls = take_nulls(values.nulls(), indices);
659
660 let list_view_data = ArrayDataBuilder::new(values.data_type().clone())
661 .len(indices.len())
662 .nulls(nulls)
663 .buffers(vec![taken_offsets.into(), taken_sizes.into()])
664 .child_data(vec![values.values().to_data()]);
665
666 let list_view_data = unsafe { list_view_data.build_unchecked() };
668
669 Ok(GenericListViewArray::<OffsetType::Native>::from(
670 list_view_data,
671 ))
672}
673
674fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
680 values: &FixedSizeListArray,
681 indices: &PrimitiveArray<IndexType>,
682 length: <UInt32Type as ArrowPrimitiveType>::Native,
683) -> Result<FixedSizeListArray, ArrowError> {
684 let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?;
685 let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)?;
686
687 let num_bytes = bit_util::ceil(indices.len(), 8);
689 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
690 let null_slice = null_buf.as_slice_mut();
691
692 for i in 0..indices.len() {
693 let index = indices
694 .value(i)
695 .to_usize()
696 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
697 if !indices.is_valid(i) || values.is_null(index) {
698 bit_util::unset_bit(null_slice, i);
699 }
700 }
701
702 let list_data = ArrayDataBuilder::new(values.data_type().clone())
703 .len(indices.len())
704 .null_bit_buffer(Some(null_buf.into()))
705 .offset(0)
706 .add_child_data(taken.into_data());
707
708 let list_data = unsafe { list_data.build_unchecked() };
709
710 Ok(FixedSizeListArray::from(list_data))
711}
712
713fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>(
719 values: &FixedSizeBinaryArray,
720 indices: &PrimitiveArray<IndexType>,
721 size: i32,
722) -> Result<FixedSizeBinaryArray, ArrowError> {
723 let size_usize = usize::try_from(size).map_err(|_| {
724 ArrowError::InvalidArgumentError(format!("Cannot convert size '{}' to usize", size))
725 })?;
726
727 let result_buffer = match size_usize {
728 1 => take_fixed_size::<IndexType, 1>(values.values(), indices),
729 2 => take_fixed_size::<IndexType, 2>(values.values(), indices),
730 4 => take_fixed_size::<IndexType, 4>(values.values(), indices),
731 8 => take_fixed_size::<IndexType, 8>(values.values(), indices),
732 16 => take_fixed_size::<IndexType, 16>(values.values(), indices),
733 _ => take_fixed_size_binary_buffer_dynamic_length(values, indices, size_usize),
734 };
735
736 let value_nulls = take_nulls(values.nulls(), indices);
737 let final_nulls = NullBuffer::union(value_nulls.as_ref(), indices.nulls());
738 let array_data = ArrayDataBuilder::new(DataType::FixedSizeBinary(size))
739 .len(indices.len())
740 .nulls(final_nulls)
741 .offset(0)
742 .add_buffer(result_buffer)
743 .build()?;
744
745 return Ok(FixedSizeBinaryArray::from(array_data));
746
747 #[inline(never)]
749 fn take_fixed_size_binary_buffer_dynamic_length<IndexType: ArrowPrimitiveType>(
750 values: &FixedSizeBinaryArray,
751 indices: &PrimitiveArray<IndexType>,
752 size_usize: usize,
753 ) -> Buffer {
754 let values_buffer = values.values().as_slice();
755 let mut values_buffer_builder = BufferBuilder::new(indices.len() * size_usize);
756
757 if indices.null_count() == 0 {
758 let array_iter = indices.values().iter().map(|idx| {
759 let offset = idx.as_usize() * size_usize;
760 &values_buffer[offset..offset + size_usize]
761 });
762 for slice in array_iter {
763 values_buffer_builder.append_slice(slice);
764 }
765 } else {
766 let array_iter = indices.iter().map(|idx| {
769 idx.map(|idx| {
770 let offset = idx.as_usize() * size_usize;
771 &values_buffer[offset..offset + size_usize]
772 })
773 });
774 for slice in array_iter {
775 match slice {
776 None => values_buffer_builder.append_n(size_usize, 0),
777 Some(slice) => values_buffer_builder.append_slice(slice),
778 }
779 }
780 }
781
782 values_buffer_builder.finish()
783 }
784}
785
/// Fast path for fixed-width binary data of width `N`.
///
/// It views `buffer` as `[u8; N]` chunks, gathers the chunks at `indices`, and
/// returns their concatenated bytes. A null index slot whose stored value is out
/// of range produces `N` zero bytes.
///
/// # Panics
///
/// Panics if `buffer.len()` is not a multiple of `N`, or if a non-null index is
/// out of range.
fn take_fixed_size<IndexType: ArrowPrimitiveType, const N: usize>(
    buffer: &Buffer,
    indices: &PrimitiveArray<IndexType>,
) -> Buffer {
    assert_eq!(
        buffer.len() % N,
        0,
        "Invalid array length in take_fixed_size"
    );

    let ptr = buffer.as_ptr();
    let chunk_ptr = ptr.cast::<[u8; N]>();
    let chunk_len = buffer.len() / N;
    // SAFETY: `[u8; N]` has alignment 1, and the assertion above guarantees the
    // buffer holds exactly `chunk_len * N` initialised bytes.
    let buffer: &[[u8; N]] = unsafe {
        std::slice::from_raw_parts(chunk_ptr, chunk_len)
    };

    let result_buffer = match indices.nulls().filter(|n| n.null_count() > 0) {
        Some(n) => indices
            .values()
            .iter()
            .enumerate()
            .map(|(idx, index)| match buffer.get(index.as_usize()) {
                Some(v) => *v,
                // SAFETY: `idx < indices.len()`, the length of the null buffer.
                None => match unsafe { n.inner().value_unchecked(idx) } {
                    // Out of range but null: the content is irrelevant.
                    false => [0u8; N],
                    true => panic!("Out-of-bounds index {index:?}"),
                },
            })
            .collect::<Vec<_>>(),
        None => indices
            .values()
            .iter()
            .map(|index| buffer[index.as_usize()])
            .collect::<Vec<_>>(),
    };

    // Reinterpret the Vec<[u8; N]> as a Vec<u8> without copying. ManuallyDrop
    // stops the original Vec from freeing the allocation being handed over.
    let mut vec = ManuallyDrop::new(result_buffer);
    let ptr = vec.as_mut_ptr();
    let len = vec.len();
    let cap = vec.capacity();
    // SAFETY: `[u8; N]` is exactly `N` bytes with alignment 1, like `u8`, so the
    // allocation has the same layout when described as `cap * N` bytes, of which
    // `len * N` are initialised.
    let result_buffer = unsafe {
        Vec::from_raw_parts(ptr.cast::<u8>(), len * N, cap * N)
    };

    Buffer::from_vec(result_buffer)
}
849
850fn take_dict<T: ArrowDictionaryKeyType, I: ArrowPrimitiveType>(
855 values: &DictionaryArray<T>,
856 indices: &PrimitiveArray<I>,
857) -> Result<DictionaryArray<T>, ArrowError> {
858 let new_keys = take_primitive(values.keys(), indices)?;
859 Ok(unsafe { DictionaryArray::new_unchecked(new_keys, values.values().clone()) })
860}
861
862fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
871 run_array: &RunArray<T>,
872 logical_indices: &PrimitiveArray<I>,
873) -> Result<RunArray<T>, ArrowError> {
874 let physical_indices = run_array.get_physical_indices(logical_indices.values())?;
876
877 let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
881 let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
882 let mut new_physical_len = 1;
883 for ix in 1..physical_indices.len() {
884 if physical_indices[ix] != physical_indices[ix - 1] {
885 take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
886 new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
887 new_physical_len += 1;
888 }
889 }
890 take_value_indices
891 .append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap());
892 new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
893 let new_run_ends = unsafe {
894 ArrayDataBuilder::new(T::DATA_TYPE)
897 .len(new_physical_len)
898 .null_count(0)
899 .add_buffer(new_run_ends_builder.finish())
900 .build_unchecked()
901 };
902
903 let take_value_indices: PrimitiveArray<I> = unsafe {
904 ArrayDataBuilder::new(I::DATA_TYPE)
907 .len(new_physical_len)
908 .null_count(0)
909 .add_buffer(take_value_indices.finish())
910 .build_unchecked()
911 .into()
912 };
913
914 let new_values = take(run_array.values(), &take_value_indices, None)?;
915
916 let builder = ArrayDataBuilder::new(run_array.data_type().clone())
917 .len(physical_indices.len())
918 .add_child_data(new_run_ends)
919 .add_child_data(new_values.into_data());
920 let array_data = unsafe {
921 builder.build_unchecked()
924 };
925 Ok(array_data.into())
926}
927
928#[allow(clippy::type_complexity)]
934fn take_value_indices_from_list<IndexType, OffsetType>(
935 list: &GenericListArray<OffsetType::Native>,
936 indices: &PrimitiveArray<IndexType>,
937) -> Result<
938 (
939 PrimitiveArray<OffsetType>,
940 Vec<OffsetType::Native>,
941 MutableBuffer,
942 ),
943 ArrowError,
944>
945where
946 IndexType: ArrowPrimitiveType,
947 OffsetType: ArrowPrimitiveType,
948 OffsetType::Native: OffsetSizeTrait + std::ops::Add + Zero + One,
949 PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
950{
951 let offsets: &[OffsetType::Native] = list.value_offsets();
953
954 let mut new_offsets = Vec::with_capacity(indices.len());
955 let mut values = Vec::new();
956 let mut current_offset = OffsetType::Native::zero();
957 new_offsets.push(OffsetType::Native::zero());
959
960 let num_bytes = bit_util::ceil(indices.len(), 8);
962 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
963 let null_slice = null_buf.as_slice_mut();
964
965 for i in 0..indices.len() {
967 if indices.is_valid(i) {
968 let ix = indices
969 .value(i)
970 .to_usize()
971 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
972 let start = offsets[ix];
973 let end = offsets[ix + 1];
974 current_offset += end - start;
975 new_offsets.push(current_offset);
976
977 let mut curr = start;
978
979 while curr < end {
981 values.push(curr);
982 curr += One::one();
983 }
984 if !list.is_valid(ix) {
985 bit_util::unset_bit(null_slice, i);
986 }
987 } else {
988 bit_util::unset_bit(null_slice, i);
989 new_offsets.push(current_offset);
990 }
991 }
992
993 Ok((
994 PrimitiveArray::<OffsetType>::from(values),
995 new_offsets,
996 null_buf,
997 ))
998}
999
1000fn take_value_indices_from_fixed_size_list<IndexType>(
1002 list: &FixedSizeListArray,
1003 indices: &PrimitiveArray<IndexType>,
1004 length: <UInt32Type as ArrowPrimitiveType>::Native,
1005) -> Result<PrimitiveArray<UInt32Type>, ArrowError>
1006where
1007 IndexType: ArrowPrimitiveType,
1008{
1009 let mut values = UInt32Builder::with_capacity(length as usize * indices.len());
1010
1011 for i in 0..indices.len() {
1012 if indices.is_valid(i) {
1013 let index = indices
1014 .value(i)
1015 .to_usize()
1016 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
1017 let start = list.value_offset(index) as <UInt32Type as ArrowPrimitiveType>::Native;
1018
1019 unsafe {
1021 values.append_trusted_len_iter(start..start + length);
1022 }
1023 } else {
1024 values.append_nulls(length as usize);
1025 }
1026 }
1027
1028 Ok(values.finish())
1029}
1030
1031trait ToIndices {
1034 type T: ArrowPrimitiveType;
1035
1036 fn to_indices(&self) -> PrimitiveArray<Self::T>;
1037}
1038
1039macro_rules! to_indices_reinterpret {
1040 ($t:ty, $o:ty) => {
1041 impl ToIndices for PrimitiveArray<$t> {
1042 type T = $o;
1043
1044 fn to_indices(&self) -> PrimitiveArray<$o> {
1045 let cast = ScalarBuffer::new(self.values().inner().clone(), 0, self.len());
1046 PrimitiveArray::new(cast, self.nulls().cloned())
1047 }
1048 }
1049 };
1050}
1051
1052macro_rules! to_indices_identity {
1053 ($t:ty) => {
1054 impl ToIndices for PrimitiveArray<$t> {
1055 type T = $t;
1056
1057 fn to_indices(&self) -> PrimitiveArray<$t> {
1058 self.clone()
1059 }
1060 }
1061 };
1062}
1063
1064macro_rules! to_indices_widening {
1065 ($t:ty, $o:ty) => {
1066 impl ToIndices for PrimitiveArray<$t> {
1067 type T = UInt32Type;
1068
1069 fn to_indices(&self) -> PrimitiveArray<$o> {
1070 let cast = self.values().iter().copied().map(|x| x as _).collect();
1071 PrimitiveArray::new(cast, self.nulls().cloned())
1072 }
1073 }
1074 };
1075}
1076
1077to_indices_widening!(UInt8Type, UInt32Type);
1078to_indices_widening!(Int8Type, UInt32Type);
1079
1080to_indices_widening!(UInt16Type, UInt32Type);
1081to_indices_widening!(Int16Type, UInt32Type);
1082
1083to_indices_identity!(UInt32Type);
1084to_indices_reinterpret!(Int32Type, UInt32Type);
1085
1086to_indices_identity!(UInt64Type);
1087to_indices_reinterpret!(Int64Type, UInt64Type);
1088
1089pub fn take_record_batch(
1128 record_batch: &RecordBatch,
1129 indices: &dyn Array,
1130) -> Result<RecordBatch, ArrowError> {
1131 let columns = record_batch
1132 .columns()
1133 .iter()
1134 .map(|c| take(c, indices, None))
1135 .collect::<Result<Vec<_>, _>>()?;
1136 RecordBatch::try_new(record_batch.schema(), columns)
1137}
1138
1139#[cfg(test)]
1140mod tests {
1141 use super::*;
1142 use arrow_array::builder::*;
1143 use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
1144 use arrow_data::ArrayData;
1145 use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
1146 use num_traits::ToPrimitive;
1147
1148 fn test_take_decimal_arrays(
1149 data: Vec<Option<i128>>,
1150 index: &UInt32Array,
1151 options: Option<TakeOptions>,
1152 expected_data: Vec<Option<i128>>,
1153 precision: &u8,
1154 scale: &i8,
1155 ) -> Result<(), ArrowError> {
1156 let output = data
1157 .into_iter()
1158 .collect::<Decimal128Array>()
1159 .with_precision_and_scale(*precision, *scale)
1160 .unwrap();
1161
1162 let expected = expected_data
1163 .into_iter()
1164 .collect::<Decimal128Array>()
1165 .with_precision_and_scale(*precision, *scale)
1166 .unwrap();
1167
1168 let expected = Arc::new(expected) as ArrayRef;
1169 let output = take(&output, index, options).unwrap();
1170 assert_eq!(&output, &expected);
1171 Ok(())
1172 }
1173
1174 fn test_take_boolean_arrays(
1175 data: Vec<Option<bool>>,
1176 index: &UInt32Array,
1177 options: Option<TakeOptions>,
1178 expected_data: Vec<Option<bool>>,
1179 ) {
1180 let output = BooleanArray::from(data);
1181 let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef;
1182 let output = take(&output, index, options).unwrap();
1183 assert_eq!(&output, &expected)
1184 }
1185
1186 fn test_take_primitive_arrays<T>(
1187 data: Vec<Option<T::Native>>,
1188 index: &UInt32Array,
1189 options: Option<TakeOptions>,
1190 expected_data: Vec<Option<T::Native>>,
1191 ) -> Result<(), ArrowError>
1192 where
1193 T: ArrowPrimitiveType,
1194 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1195 {
1196 let output = PrimitiveArray::<T>::from(data);
1197 let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1198 let output = take(&output, index, options)?;
1199 assert_eq!(&output, &expected);
1200 Ok(())
1201 }
1202
1203 fn test_take_primitive_arrays_non_null<T>(
1204 data: Vec<T::Native>,
1205 index: &UInt32Array,
1206 options: Option<TakeOptions>,
1207 expected_data: Vec<Option<T::Native>>,
1208 ) -> Result<(), ArrowError>
1209 where
1210 T: ArrowPrimitiveType,
1211 PrimitiveArray<T>: From<Vec<T::Native>>,
1212 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1213 {
1214 let output = PrimitiveArray::<T>::from(data);
1215 let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1216 let output = take(&output, index, options)?;
1217 assert_eq!(&output, &expected);
1218 Ok(())
1219 }
1220
1221 fn test_take_impl_primitive_arrays<T, I>(
1222 data: Vec<Option<T::Native>>,
1223 index: &PrimitiveArray<I>,
1224 options: Option<TakeOptions>,
1225 expected_data: Vec<Option<T::Native>>,
1226 ) where
1227 T: ArrowPrimitiveType,
1228 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1229 I: ArrowPrimitiveType,
1230 {
1231 let output = PrimitiveArray::<T>::from(data);
1232 let expected = PrimitiveArray::<T>::from(expected_data);
1233 let output = take(&output, index, options).unwrap();
1234 let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
1235 assert_eq!(output, &expected)
1236 }
1237
1238 fn create_test_struct(values: Vec<Option<(Option<bool>, Option<i32>)>>) -> StructArray {
1240 let mut struct_builder = StructBuilder::new(
1241 Fields::from(vec![
1242 Field::new("a", DataType::Boolean, true),
1243 Field::new("b", DataType::Int32, true),
1244 ]),
1245 vec![
1246 Box::new(BooleanBuilder::with_capacity(values.len())),
1247 Box::new(Int32Builder::with_capacity(values.len())),
1248 ],
1249 );
1250
1251 for value in values {
1252 struct_builder
1253 .field_builder::<BooleanBuilder>(0)
1254 .unwrap()
1255 .append_option(value.and_then(|v| v.0));
1256 struct_builder
1257 .field_builder::<Int32Builder>(1)
1258 .unwrap()
1259 .append_option(value.and_then(|v| v.1));
1260 struct_builder.append(value.is_some());
1261 }
1262 struct_builder.finish()
1263 }
1264
1265 #[test]
1266 fn test_take_decimal128_non_null_indices() {
1267 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1268 let precision: u8 = 10;
1269 let scale: i8 = 5;
1270 test_take_decimal_arrays(
1271 vec![None, Some(3), Some(5), Some(2), Some(3), None],
1272 &index,
1273 None,
1274 vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1275 &precision,
1276 &scale,
1277 )
1278 .unwrap();
1279 }
1280
1281 #[test]
1282 fn test_take_decimal128() {
1283 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1284 let precision: u8 = 10;
1285 let scale: i8 = 5;
1286 test_take_decimal_arrays(
1287 vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1288 &index,
1289 None,
1290 vec![Some(3), None, Some(1), Some(3), Some(2)],
1291 &precision,
1292 &scale,
1293 )
1294 .unwrap();
1295 }
1296
1297 #[test]
1298 fn test_take_primitive_non_null_indices() {
1299 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1300 test_take_primitive_arrays::<Int8Type>(
1301 vec![None, Some(3), Some(5), Some(2), Some(3), None],
1302 &index,
1303 None,
1304 vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1305 )
1306 .unwrap();
1307 }
1308
1309 #[test]
1310 fn test_take_primitive_non_null_values() {
1311 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1312 test_take_primitive_arrays::<Int8Type>(
1313 vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1314 &index,
1315 None,
1316 vec![Some(3), None, Some(1), Some(3), Some(2)],
1317 )
1318 .unwrap();
1319 }
1320
1321 #[test]
1322 fn test_take_primitive_non_null() {
1323 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1324 test_take_primitive_arrays::<Int8Type>(
1325 vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)],
1326 &index,
1327 None,
1328 vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)],
1329 )
1330 .unwrap();
1331 }
1332
1333 #[test]
1334 fn test_take_primitive_nullable_indices_non_null_values_with_offset() {
1335 let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1336 let index = index.slice(2, 4);
1337 let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1338
1339 assert_eq!(
1340 index,
1341 &UInt32Array::from(vec![Some(2), Some(3), None, None])
1342 );
1343
1344 test_take_primitive_arrays_non_null::<Int64Type>(
1345 vec![0, 10, 20, 30, 40, 50],
1346 index,
1347 None,
1348 vec![Some(20), Some(30), None, None],
1349 )
1350 .unwrap();
1351 }
1352
1353 #[test]
1354 fn test_take_primitive_nullable_indices_nullable_values_with_offset() {
1355 let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1356 let index = index.slice(2, 4);
1357 let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1358
1359 assert_eq!(
1360 index,
1361 &UInt32Array::from(vec![Some(2), Some(3), None, None])
1362 );
1363
1364 test_take_primitive_arrays::<Int64Type>(
1365 vec![None, None, Some(20), Some(30), Some(40), Some(50)],
1366 index,
1367 None,
1368 vec![Some(20), Some(30), None, None],
1369 )
1370 .unwrap();
1371 }
1372
1373 #[test]
1374 fn test_take_primitive() {
1375 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1376
1377 test_take_primitive_arrays::<Int8Type>(
1379 vec![Some(0), None, Some(2), Some(3), None],
1380 &index,
1381 None,
1382 vec![Some(3), None, None, Some(3), Some(2)],
1383 )
1384 .unwrap();
1385
1386 test_take_primitive_arrays::<Int16Type>(
1388 vec![Some(0), None, Some(2), Some(3), None],
1389 &index,
1390 None,
1391 vec![Some(3), None, None, Some(3), Some(2)],
1392 )
1393 .unwrap();
1394
1395 test_take_primitive_arrays::<Int32Type>(
1397 vec![Some(0), None, Some(2), Some(3), None],
1398 &index,
1399 None,
1400 vec![Some(3), None, None, Some(3), Some(2)],
1401 )
1402 .unwrap();
1403
1404 test_take_primitive_arrays::<Int64Type>(
1406 vec![Some(0), None, Some(2), Some(3), None],
1407 &index,
1408 None,
1409 vec![Some(3), None, None, Some(3), Some(2)],
1410 )
1411 .unwrap();
1412
1413 test_take_primitive_arrays::<UInt8Type>(
1415 vec![Some(0), None, Some(2), Some(3), None],
1416 &index,
1417 None,
1418 vec![Some(3), None, None, Some(3), Some(2)],
1419 )
1420 .unwrap();
1421
1422 test_take_primitive_arrays::<UInt16Type>(
1424 vec![Some(0), None, Some(2), Some(3), None],
1425 &index,
1426 None,
1427 vec![Some(3), None, None, Some(3), Some(2)],
1428 )
1429 .unwrap();
1430
1431 test_take_primitive_arrays::<UInt32Type>(
1433 vec![Some(0), None, Some(2), Some(3), None],
1434 &index,
1435 None,
1436 vec![Some(3), None, None, Some(3), Some(2)],
1437 )
1438 .unwrap();
1439
1440 test_take_primitive_arrays::<Int64Type>(
1442 vec![Some(0), None, Some(2), Some(-15), None],
1443 &index,
1444 None,
1445 vec![Some(-15), None, None, Some(-15), Some(2)],
1446 )
1447 .unwrap();
1448
1449 test_take_primitive_arrays::<IntervalYearMonthType>(
1451 vec![Some(0), None, Some(2), Some(-15), None],
1452 &index,
1453 None,
1454 vec![Some(-15), None, None, Some(-15), Some(2)],
1455 )
1456 .unwrap();
1457
1458 let v1 = IntervalDayTime::new(0, 0);
1460 let v2 = IntervalDayTime::new(2, 0);
1461 let v3 = IntervalDayTime::new(-15, 0);
1462 test_take_primitive_arrays::<IntervalDayTimeType>(
1463 vec![Some(v1), None, Some(v2), Some(v3), None],
1464 &index,
1465 None,
1466 vec![Some(v3), None, None, Some(v3), Some(v2)],
1467 )
1468 .unwrap();
1469
1470 let v1 = IntervalMonthDayNano::new(0, 0, 0);
1472 let v2 = IntervalMonthDayNano::new(2, 0, 0);
1473 let v3 = IntervalMonthDayNano::new(-15, 0, 0);
1474 test_take_primitive_arrays::<IntervalMonthDayNanoType>(
1475 vec![Some(v1), None, Some(v2), Some(v3), None],
1476 &index,
1477 None,
1478 vec![Some(v3), None, None, Some(v3), Some(v2)],
1479 )
1480 .unwrap();
1481
1482 test_take_primitive_arrays::<DurationSecondType>(
1484 vec![Some(0), None, Some(2), Some(-15), None],
1485 &index,
1486 None,
1487 vec![Some(-15), None, None, Some(-15), Some(2)],
1488 )
1489 .unwrap();
1490
1491 test_take_primitive_arrays::<DurationMillisecondType>(
1493 vec![Some(0), None, Some(2), Some(-15), None],
1494 &index,
1495 None,
1496 vec![Some(-15), None, None, Some(-15), Some(2)],
1497 )
1498 .unwrap();
1499
1500 test_take_primitive_arrays::<DurationMicrosecondType>(
1502 vec![Some(0), None, Some(2), Some(-15), None],
1503 &index,
1504 None,
1505 vec![Some(-15), None, None, Some(-15), Some(2)],
1506 )
1507 .unwrap();
1508
1509 test_take_primitive_arrays::<DurationNanosecondType>(
1511 vec![Some(0), None, Some(2), Some(-15), None],
1512 &index,
1513 None,
1514 vec![Some(-15), None, None, Some(-15), Some(2)],
1515 )
1516 .unwrap();
1517
1518 test_take_primitive_arrays::<Float32Type>(
1520 vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1521 &index,
1522 None,
1523 vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1524 )
1525 .unwrap();
1526
1527 test_take_primitive_arrays::<Float64Type>(
1529 vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1530 &index,
1531 None,
1532 vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1533 )
1534 .unwrap();
1535 }
1536
1537 #[test]
1538 fn test_take_preserve_timezone() {
1539 let index = Int64Array::from(vec![Some(0), None]);
1540
1541 let input = TimestampNanosecondArray::from(vec![
1542 1_639_715_368_000_000_000,
1543 1_639_715_368_000_000_000,
1544 ])
1545 .with_timezone("UTC".to_string());
1546 let result = take(&input, &index, None).unwrap();
1547 match result.data_type() {
1548 DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
1549 assert_eq!(tz.clone(), Some("UTC".into()))
1550 }
1551 _ => panic!(),
1552 }
1553 }
1554
1555 #[test]
1556 fn test_take_impl_primitive_with_int64_indices() {
1557 let index = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1558
1559 test_take_impl_primitive_arrays::<Int16Type, Int64Type>(
1561 vec![Some(0), None, Some(2), Some(3), None],
1562 &index,
1563 None,
1564 vec![Some(3), None, None, Some(3), Some(2)],
1565 );
1566
1567 test_take_impl_primitive_arrays::<Int64Type, Int64Type>(
1569 vec![Some(0), None, Some(2), Some(-15), None],
1570 &index,
1571 None,
1572 vec![Some(-15), None, None, Some(-15), Some(2)],
1573 );
1574
1575 test_take_impl_primitive_arrays::<UInt64Type, Int64Type>(
1577 vec![Some(0), None, Some(2), Some(3), None],
1578 &index,
1579 None,
1580 vec![Some(3), None, None, Some(3), Some(2)],
1581 );
1582
1583 test_take_impl_primitive_arrays::<DurationMillisecondType, Int64Type>(
1585 vec![Some(0), None, Some(2), Some(-15), None],
1586 &index,
1587 None,
1588 vec![Some(-15), None, None, Some(-15), Some(2)],
1589 );
1590
1591 test_take_impl_primitive_arrays::<Float32Type, Int64Type>(
1593 vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1594 &index,
1595 None,
1596 vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1597 );
1598 }
1599
1600 #[test]
1601 fn test_take_impl_primitive_with_uint8_indices() {
1602 let index = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1603
1604 test_take_impl_primitive_arrays::<Int16Type, UInt8Type>(
1606 vec![Some(0), None, Some(2), Some(3), None],
1607 &index,
1608 None,
1609 vec![Some(3), None, None, Some(3), Some(2)],
1610 );
1611
1612 test_take_impl_primitive_arrays::<DurationMillisecondType, UInt8Type>(
1614 vec![Some(0), None, Some(2), Some(-15), None],
1615 &index,
1616 None,
1617 vec![Some(-15), None, None, Some(-15), Some(2)],
1618 );
1619
1620 test_take_impl_primitive_arrays::<Float32Type, UInt8Type>(
1622 vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1623 &index,
1624 None,
1625 vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1626 );
1627 }
1628
1629 #[test]
1630 fn test_take_bool() {
1631 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1632 test_take_boolean_arrays(
1634 vec![Some(false), None, Some(true), Some(false), None],
1635 &index,
1636 None,
1637 vec![Some(false), None, None, Some(false), Some(true)],
1638 );
1639 }
1640
1641 #[test]
1642 fn test_take_bool_nullable_index() {
1643 let index_data = ArrayData::try_new(
1645 DataType::UInt32,
1646 6,
1647 Some(Buffer::from_iter(vec![
1648 false, true, false, true, false, true,
1649 ])),
1650 0,
1651 vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1652 vec![],
1653 )
1654 .unwrap();
1655 let index = UInt32Array::from(index_data);
1656 test_take_boolean_arrays(
1657 vec![Some(true), None, Some(false)],
1658 &index,
1659 None,
1660 vec![None, Some(true), None, None, None, Some(false)],
1661 );
1662 }
1663
1664 #[test]
1665 fn test_take_bool_nullable_index_nonnull_values() {
1666 let index_data = ArrayData::try_new(
1668 DataType::UInt32,
1669 6,
1670 Some(Buffer::from_iter(vec![
1671 false, true, false, true, false, true,
1672 ])),
1673 0,
1674 vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1675 vec![],
1676 )
1677 .unwrap();
1678 let index = UInt32Array::from(index_data);
1679 test_take_boolean_arrays(
1680 vec![Some(true), Some(true), Some(false)],
1681 &index,
1682 None,
1683 vec![None, Some(true), None, Some(true), None, Some(false)],
1684 );
1685 }
1686
1687 #[test]
1688 fn test_take_bool_with_offset() {
1689 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]);
1690 let index = index.slice(2, 4);
1691 let index = index
1692 .as_any()
1693 .downcast_ref::<PrimitiveArray<UInt32Type>>()
1694 .unwrap();
1695
1696 test_take_boolean_arrays(
1698 vec![Some(false), None, Some(true), Some(false), None],
1699 index,
1700 None,
1701 vec![None, Some(false), Some(true), None],
1702 );
1703 }
1704
1705 fn _test_take_string<'a, K>()
1706 where
1707 K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static,
1708 {
1709 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]);
1710
1711 let array = K::from(vec![
1712 Some("one"),
1713 None,
1714 Some("three"),
1715 Some("four"),
1716 Some("five"),
1717 ]);
1718 let actual = take(&array, &index, None).unwrap();
1719 assert_eq!(actual.len(), index.len());
1720
1721 let actual = actual.as_any().downcast_ref::<K>().unwrap();
1722
1723 let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]);
1724
1725 assert_eq!(actual, &expected);
1726 }
1727
1728 #[test]
1729 fn test_take_string() {
1730 _test_take_string::<StringArray>()
1731 }
1732
1733 #[test]
1734 fn test_take_large_string() {
1735 _test_take_string::<LargeStringArray>()
1736 }
1737
1738 #[test]
1739 fn test_take_slice_string() {
1740 let strings = StringArray::from(vec![Some("hello"), None, Some("world"), None, Some("hi")]);
1741 let indices = Int32Array::from(vec![Some(0), Some(1), None, Some(0), Some(2)]);
1742 let indices_slice = indices.slice(1, 4);
1743 let expected = StringArray::from(vec![None, None, Some("hello"), Some("world")]);
1744 let result = take(&strings, &indices_slice, None).unwrap();
1745 assert_eq!(result.as_ref(), &expected);
1746 }
1747
1748 fn _test_byte_view<T>()
1749 where
1750 T: ByteViewType,
1751 str: AsRef<T::Native>,
1752 T::Native: PartialEq,
1753 {
1754 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]);
1755 let array = {
1756 let mut builder = GenericByteViewBuilder::<T>::new();
1758 builder.append_value("hello");
1759 builder.append_value("world");
1760 builder.append_null();
1761 builder.append_value("large payload over 12 bytes");
1762 builder.append_value("lulu");
1763 builder.finish()
1764 };
1765
1766 let actual = take(&array, &index, None).unwrap();
1767
1768 assert_eq!(actual.len(), index.len());
1769
1770 let expected = {
1771 let mut builder = GenericByteViewBuilder::<T>::new();
1773 builder.append_value("large payload over 12 bytes");
1774 builder.append_null();
1775 builder.append_value("world");
1776 builder.append_value("large payload over 12 bytes");
1777 builder.append_value("lulu");
1778 builder.append_null();
1779 builder.finish()
1780 };
1781
1782 assert_eq!(actual.as_ref(), &expected);
1783 }
1784
1785 #[test]
1786 fn test_take_string_view() {
1787 _test_byte_view::<StringViewType>()
1788 }
1789
1790 #[test]
1791 fn test_take_binary_view() {
1792 _test_byte_view::<BinaryViewType>()
1793 }
1794
    // Body of a `take` test for `List` / `LargeList` arrays whose values and
    // list slots are all non-null; the only output null comes from the index.
    //
    // * `$offset_type`     - native offset type (`i32` or `i64`)
    // * `$list_data_type`  - `DataType` variant name (`List` or `LargeList`)
    // * `$list_array_type` - matching array type (`ListArray` or `LargeListArray`)
    macro_rules! test_take_list {
        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
            // Child values; together with the offsets below this gives
            // [[0, 0, 0], [-1, -2, -1], [], [2, 3]].
            let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
            let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
            let value_offsets = Buffer::from_slice_ref(&value_offsets);
            let list_data_type =
                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, false)));
            let list_data = ArrayData::builder(list_data_type.clone())
                .len(4)
                .add_buffer(value_offsets)
                .add_child_data(value_data)
                .build()
                .unwrap();
            let list_array = $list_array_type::from(list_data);

            // Null index at position 1; index 2 selects the empty list.
            let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]);

            let a = take(&list_array, &index, None).unwrap();
            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();

            // Expected: [[2, 3], null, [-1, -2, -1], [], [0, 0, 0]].
            let expected_data = Int32Array::from(vec![
                Some(2),
                Some(3),
                Some(-1),
                Some(-2),
                Some(-1),
                Some(0),
                Some(0),
                Some(0),
            ])
            .into_data();
            // The null index produces a zero-length slot (offsets 2..2).
            let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8];
            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
            let expected_list_data = ArrayData::builder(list_data_type)
                .len(5)
                // Output validity is exactly the index validity here.
                .nulls(index.nulls().cloned())
                .add_buffer(expected_offsets)
                .add_child_data(expected_data)
                .build()
                .unwrap();
            let expected_list_array = $list_array_type::from(expected_list_data);

            assert_eq!(a, &expected_list_array);
        }};
    }
1849
    // Body of a `take` test for `List` / `LargeList` arrays whose child values
    // contain nulls, while every list slot is valid. Output list nulls come
    // only from the index; child nulls must be carried through unchanged.
    //
    // * `$offset_type`     - native offset type (`i32` or `i64`)
    // * `$list_data_type`  - `DataType` variant name (`List` or `LargeList`)
    // * `$list_array_type` - matching array type (`ListArray` or `LargeListArray`)
    macro_rules! test_take_list_with_value_nulls {
        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
            // Lists: [[0, null, 0], [-1, -2, 3], [null], [5, null]].
            let value_data = Int32Array::from(vec![
                Some(0),
                None,
                Some(0),
                Some(-1),
                Some(-2),
                Some(3),
                None,
                Some(5),
                None,
            ])
            .into_data();
            let value_offsets: [$offset_type; 5] = [0, 3, 6, 7, 9];
            let value_offsets = Buffer::from_slice_ref(&value_offsets);
            let list_data_type =
                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
            let list_data = ArrayData::builder(list_data_type.clone())
                .len(4)
                .add_buffer(value_offsets)
                // All validity bits set: every list slot is valid.
                .null_bit_buffer(Some(Buffer::from([0b11111111])))
                .add_child_data(value_data)
                .build()
                .unwrap();
            let list_array = $list_array_type::from(list_data);

            let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);

            let a = take(&list_array, &index, None).unwrap();
            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();

            // Expected: [[null], null, [-1, -2, 3], [5, null], [0, null, 0]].
            let expected_data = Int32Array::from(vec![
                None,
                Some(-1),
                Some(-2),
                Some(3),
                Some(5),
                None,
                Some(0),
                None,
                Some(0),
            ])
            .into_data();
            let expected_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9];
            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
            let expected_list_data = ArrayData::builder(list_data_type)
                .len(5)
                // Output validity is exactly the index validity here.
                .nulls(index.nulls().cloned())
                .add_buffer(expected_offsets)
                .add_child_data(expected_data)
                .build()
                .unwrap();
            let expected_list_array = $list_array_type::from(expected_list_data);

            assert_eq!(a, &expected_list_array);
        }};
    }
1917
    // Body of a `take` test for `List` / `LargeList` arrays that contain a null
    // list slot as well as null child values. Output list nulls come from
    // both the index and the null source slot.
    //
    // * `$offset_type`     - native offset type (`i32` or `i64`)
    // * `$list_data_type`  - `DataType` variant name (`List` or `LargeList`)
    // * `$list_array_type` - matching array type (`ListArray` or `LargeListArray`)
    macro_rules! test_take_list_with_nulls {
        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
            // Lists: [[0, null, 0], [-1, -2, 3], null, [5, null]].
            let value_data = Int32Array::from(vec![
                Some(0),
                None,
                Some(0),
                Some(-1),
                Some(-2),
                Some(3),
                Some(5),
                None,
            ])
            .into_data();
            let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
            let value_offsets = Buffer::from_slice_ref(&value_offsets);
            let list_data_type =
                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
            let list_data = ArrayData::builder(list_data_type.clone())
                .len(4)
                .add_buffer(value_offsets)
                // Bit 2 cleared: list slot 2 is null.
                .null_bit_buffer(Some(Buffer::from([0b11111011])))
                .add_child_data(value_data)
                .build()
                .unwrap();
            let list_array = $list_array_type::from(list_data);

            // Index 2 hits the null list; position 1 is a null index.
            let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);

            let a = take(&list_array, &index, None).unwrap();
            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();

            // Expected: [null, null, [-1, -2, 3], [5, null], [0, null, 0]].
            let expected_data = Int32Array::from(vec![
                Some(-1),
                Some(-2),
                Some(3),
                Some(5),
                None,
                Some(0),
                None,
                Some(0),
            ])
            .into_data();
            let expected_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8];
            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
            // Only output slots 2, 3 and 4 are valid.
            let mut null_bits: [u8; 1] = [0; 1];
            bit_util::set_bit(&mut null_bits, 2);
            bit_util::set_bit(&mut null_bits, 3);
            bit_util::set_bit(&mut null_bits, 4);
            let expected_list_data = ArrayData::builder(list_data_type)
                .len(5)
                .null_bit_buffer(Some(Buffer::from(null_bits)))
                .add_buffer(expected_offsets)
                .add_child_data(expected_data)
                .build()
                .unwrap();
            let expected_list_array = $list_array_type::from(expected_list_data);

            assert_eq!(a, &expected_list_array);
        }};
    }
1987
1988 fn test_take_list_view_generic<OffsetType: OffsetSizeTrait, ValuesType: ArrowPrimitiveType, F>(
1989 values: Vec<Option<Vec<Option<ValuesType::Native>>>>,
1990 take_indices: Vec<Option<usize>>,
1991 expected: Vec<Option<Vec<Option<ValuesType::Native>>>>,
1992 mapper: F,
1993 ) where
1994 F: Fn(GenericListViewArray<OffsetType>) -> GenericListViewArray<OffsetType>,
1995 {
1996 let mut list_view_array =
1997 GenericListViewBuilder::<OffsetType, _>::new(PrimitiveBuilder::<ValuesType>::new());
1998
1999 for value in values {
2000 list_view_array.append_option(value);
2001 }
2002 let list_view_array = list_view_array.finish();
2003 let list_view_array = mapper(list_view_array);
2004
2005 let mut indices = UInt64Builder::new();
2006 for idx in take_indices {
2007 indices.append_option(idx.map(|i| i.to_u64().unwrap()));
2008 }
2009 let indices = indices.finish();
2010
2011 let taken = take(&list_view_array, &indices, None)
2012 .unwrap()
2013 .as_list_view()
2014 .clone();
2015
2016 let mut expected_array =
2017 GenericListViewBuilder::<OffsetType, _>::new(PrimitiveBuilder::<ValuesType>::new());
2018 for value in expected {
2019 expected_array.append_option(value);
2020 }
2021 let expected_array = expected_array.finish();
2022
2023 assert_eq!(taken, expected_array);
2024 }
2025
2026 macro_rules! list_view_test_case {
2027 (values: $values:expr, indices: $indices:expr, expected: $expected: expr) => {{
2028 test_take_list_view_generic::<i32, Int8Type, _>($values, $indices, $expected, |x| x);
2029 test_take_list_view_generic::<i64, Int8Type, _>($values, $indices, $expected, |x| x);
2030 }};
2031 (values: $values:expr, transform: $fn:expr, indices: $indices:expr, expected: $expected: expr) => {{
2032 test_take_list_view_generic::<i32, Int8Type, _>($values, $indices, $expected, $fn);
2033 test_take_list_view_generic::<i64, Int8Type, _>($values, $indices, $expected, $fn);
2034 }};
2035 }
2036
2037 fn do_take_fixed_size_list_test<T>(
2038 length: <Int32Type as ArrowPrimitiveType>::Native,
2039 input_data: Vec<Option<Vec<Option<T::Native>>>>,
2040 indices: Vec<<UInt32Type as ArrowPrimitiveType>::Native>,
2041 expected_data: Vec<Option<Vec<Option<T::Native>>>>,
2042 ) where
2043 T: ArrowPrimitiveType,
2044 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2045 {
2046 let indices = UInt32Array::from(indices);
2047
2048 let input_array = FixedSizeListArray::from_iter_primitive::<T, _, _>(input_data, length);
2049
2050 let output = take_fixed_size_list(&input_array, &indices, length as u32).unwrap();
2051
2052 let expected = FixedSizeListArray::from_iter_primitive::<T, _, _>(expected_data, length);
2053
2054 assert_eq!(&output, &expected)
2055 }
2056
2057 #[test]
2058 fn test_take_list() {
2059 test_take_list!(i32, List, ListArray);
2060 }
2061
2062 #[test]
2063 fn test_take_large_list() {
2064 test_take_list!(i64, LargeList, LargeListArray);
2065 }
2066
2067 #[test]
2068 fn test_take_list_with_value_nulls() {
2069 test_take_list_with_value_nulls!(i32, List, ListArray);
2070 }
2071
2072 #[test]
2073 fn test_take_large_list_with_value_nulls() {
2074 test_take_list_with_value_nulls!(i64, LargeList, LargeListArray);
2075 }
2076
2077 #[test]
2078 fn test_test_take_list_with_nulls() {
2079 test_take_list_with_nulls!(i32, List, ListArray);
2080 }
2081
2082 #[test]
2083 fn test_test_take_large_list_with_nulls() {
2084 test_take_list_with_nulls!(i64, LargeList, LargeListArray);
2085 }
2086
2087 #[test]
2088 fn test_test_take_list_view_reversed() {
2089 list_view_test_case! {
2091 values: vec![
2092 Some(vec![Some(1), None, Some(3)]),
2093 None,
2094 Some(vec![Some(7), Some(8), None]),
2095 ],
2096 indices: vec![Some(2), Some(1), Some(0)],
2097 expected: vec![
2098 Some(vec![Some(7), Some(8), None]),
2099 None,
2100 Some(vec![Some(1), None, Some(3)]),
2101 ]
2102 }
2103 }
2104
2105 #[test]
2106 fn test_take_list_view_null_indices() {
2107 list_view_test_case! {
2109 values: vec![
2110 Some(vec![Some(1), None, Some(3)]),
2111 None,
2112 Some(vec![Some(7), Some(8), None]),
2113 ],
2114 indices: vec![None, Some(0), None],
2115 expected: vec![None, Some(vec![Some(1), None, Some(3)]), None]
2116 }
2117 }
2118
2119 #[test]
2120 fn test_take_list_view_null_values() {
2121 list_view_test_case! {
2123 values: vec![
2124 Some(vec![Some(1), None, Some(3)]),
2125 None,
2126 Some(vec![Some(7), Some(8), None]),
2127 ],
2128 indices: vec![Some(1), Some(1), Some(1), None, None],
2129 expected: vec![None; 5]
2130 }
2131 }
2132
2133 #[test]
2134 fn test_take_list_view_sliced() {
2135 list_view_test_case! {
2137 values: vec![
2138 Some(vec![Some(1)]),
2139 None,
2140 None,
2141 Some(vec![Some(2), Some(3)]),
2142 Some(vec![Some(4), Some(5)]),
2143 None,
2144 ],
2145 transform: |l| l.slice(2, 4),
2146 indices: vec![Some(0), Some(3), None, Some(1), Some(2)],
2147 expected: vec![
2148 None, None, None, Some(vec![Some(2), Some(3)]), Some(vec![Some(4), Some(5)])
2149 ]
2150 }
2151 }
2152
2153 #[test]
2154 fn test_take_fixed_size_list() {
2155 do_take_fixed_size_list_test::<Int32Type>(
2156 3,
2157 vec![
2158 Some(vec![None, Some(1), Some(2)]),
2159 Some(vec![Some(3), Some(4), None]),
2160 Some(vec![Some(6), Some(7), Some(8)]),
2161 ],
2162 vec![2, 1, 0],
2163 vec![
2164 Some(vec![Some(6), Some(7), Some(8)]),
2165 Some(vec![Some(3), Some(4), None]),
2166 Some(vec![None, Some(1), Some(2)]),
2167 ],
2168 );
2169
2170 do_take_fixed_size_list_test::<UInt8Type>(
2171 1,
2172 vec![
2173 Some(vec![Some(1)]),
2174 Some(vec![Some(2)]),
2175 Some(vec![Some(3)]),
2176 Some(vec![Some(4)]),
2177 Some(vec![Some(5)]),
2178 Some(vec![Some(6)]),
2179 Some(vec![Some(7)]),
2180 Some(vec![Some(8)]),
2181 ],
2182 vec![2, 7, 0],
2183 vec![
2184 Some(vec![Some(3)]),
2185 Some(vec![Some(8)]),
2186 Some(vec![Some(1)]),
2187 ],
2188 );
2189
2190 do_take_fixed_size_list_test::<UInt64Type>(
2191 3,
2192 vec![
2193 Some(vec![Some(10), Some(11), Some(12)]),
2194 Some(vec![Some(13), Some(14), Some(15)]),
2195 None,
2196 Some(vec![Some(16), Some(17), Some(18)]),
2197 ],
2198 vec![3, 2, 1, 2, 0],
2199 vec![
2200 Some(vec![Some(16), Some(17), Some(18)]),
2201 None,
2202 Some(vec![Some(13), Some(14), Some(15)]),
2203 None,
2204 Some(vec![Some(10), Some(11), Some(12)]),
2205 ],
2206 );
2207 }
2208
2209 #[test]
2210 fn test_take_fixed_size_binary_with_nulls_indices() {
2211 let fsb = FixedSizeBinaryArray::try_from_sparse_iter_with_size(
2212 [
2213 Some(vec![0x01, 0x01, 0x01, 0x01]),
2214 Some(vec![0x02, 0x02, 0x02, 0x02]),
2215 Some(vec![0x03, 0x03, 0x03, 0x03]),
2216 Some(vec![0x04, 0x04, 0x04, 0x04]),
2217 ]
2218 .into_iter(),
2219 4,
2220 )
2221 .unwrap();
2222
2223 let indices = UInt32Array::from(vec![Some(0), None, None, Some(3)]);
2225
2226 let result = take_fixed_size_binary(&fsb, &indices, 4).unwrap();
2227 assert_eq!(result.len(), 4);
2228 assert_eq!(result.null_count(), 2);
2229 assert_eq!(
2230 result.nulls().unwrap().iter().collect::<Vec<_>>(),
2231 vec![true, false, false, true]
2232 );
2233 }
2234
2235 #[test]
2239 fn test_take_fixed_size_binary_with_nulls_indices_not_optimized_length() {
2240 let fsb = FixedSizeBinaryArray::try_from_sparse_iter_with_size(
2241 [
2242 Some(vec![0x01, 0x01, 0x01, 0x01, 0x01]),
2243 Some(vec![0x02, 0x02, 0x02, 0x02, 0x01]),
2244 Some(vec![0x03, 0x03, 0x03, 0x03, 0x01]),
2245 Some(vec![0x04, 0x04, 0x04, 0x04, 0x01]),
2246 ]
2247 .into_iter(),
2248 5,
2249 )
2250 .unwrap();
2251
2252 let indices = UInt32Array::from(vec![Some(0), None, None, Some(3)]);
2254
2255 let result = take_fixed_size_binary(&fsb, &indices, 5).unwrap();
2256 assert_eq!(result.len(), 4);
2257 assert_eq!(result.null_count(), 2);
2258 assert_eq!(
2259 result.nulls().unwrap().iter().collect::<Vec<_>>(),
2260 vec![true, false, false, true]
2261 );
2262 }
2263
2264 #[test]
2265 #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
2266 fn test_take_list_out_of_bounds() {
2267 let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
2269 let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]);
2271 let list_data_type =
2273 DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false)));
2274 let list_data = ArrayData::builder(list_data_type)
2275 .len(3)
2276 .add_buffer(value_offsets)
2277 .add_child_data(value_data)
2278 .build()
2279 .unwrap();
2280 let list_array = ListArray::from(list_data);
2281
2282 let index = UInt32Array::from(vec![1000]);
2283
2284 take(&list_array, &index, None).unwrap();
2287 }
2288
2289 #[test]
2290 fn test_take_map() {
2291 let values = Int32Array::from(vec![1, 2, 3, 4]);
2292 let array =
2293 MapArray::new_from_strings(vec!["a", "b", "c", "a"].into_iter(), &values, &[0, 3, 4])
2294 .unwrap();
2295
2296 let index = UInt32Array::from(vec![0]);
2297
2298 let result = take(&array, &index, None).unwrap();
2299 let expected: ArrayRef = Arc::new(
2300 MapArray::new_from_strings(
2301 vec!["a", "b", "c"].into_iter(),
2302 &values.slice(0, 3),
2303 &[0, 3],
2304 )
2305 .unwrap(),
2306 );
2307 assert_eq!(&expected, &result);
2308 }
2309
2310 #[test]
2311 fn test_take_struct() {
2312 let array = create_test_struct(vec![
2313 Some((Some(true), Some(42))),
2314 Some((Some(false), Some(28))),
2315 Some((Some(false), Some(19))),
2316 Some((Some(true), Some(31))),
2317 None,
2318 ]);
2319
2320 let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
2321 let actual = take(&array, &index, None).unwrap();
2322 let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
2323 assert_eq!(index.len(), actual.len());
2324 assert_eq!(1, actual.null_count());
2325
2326 let expected = create_test_struct(vec![
2327 Some((Some(true), Some(42))),
2328 Some((Some(true), Some(31))),
2329 Some((Some(false), Some(28))),
2330 Some((Some(true), Some(42))),
2331 Some((Some(false), Some(19))),
2332 None,
2333 ]);
2334
2335 assert_eq!(&expected, actual);
2336
2337 let nulls = NullBuffer::from(&[false, true, false, true, false, true]);
2338 let empty_struct_arr = StructArray::new_empty_fields(6, Some(nulls));
2339 let index = UInt32Array::from(vec![0, 2, 1, 4]);
2340 let actual = take(&empty_struct_arr, &index, None).unwrap();
2341
2342 let expected_nulls = NullBuffer::from(&[false, false, true, false]);
2343 let expected_struct_arr = StructArray::new_empty_fields(4, Some(expected_nulls));
2344 assert_eq!(&expected_struct_arr, actual.as_struct());
2345 }
2346
2347 #[test]
2348 fn test_take_struct_with_null_indices() {
2349 let array = create_test_struct(vec![
2350 Some((Some(true), Some(42))),
2351 Some((Some(false), Some(28))),
2352 Some((Some(false), Some(19))),
2353 Some((Some(true), Some(31))),
2354 None,
2355 ]);
2356
2357 let index = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);
2358 let actual = take(&array, &index, None).unwrap();
2359 let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
2360 assert_eq!(index.len(), actual.len());
2361 assert_eq!(3, actual.null_count()); let expected = create_test_struct(vec![
2364 None,
2365 Some((Some(true), Some(31))),
2366 Some((Some(false), Some(28))),
2367 None,
2368 Some((Some(true), Some(42))),
2369 None,
2370 ]);
2371
2372 assert_eq!(&expected, actual);
2373 }
2374
2375 #[test]
2376 fn test_take_out_of_bounds() {
2377 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]);
2378 let take_opt = TakeOptions { check_bounds: true };
2379
2380 let result = test_take_primitive_arrays::<Int64Type>(
2382 vec![Some(0), None, Some(2), Some(3), None],
2383 &index,
2384 Some(take_opt),
2385 vec![None],
2386 );
2387 assert!(result.is_err());
2388 }
2389
2390 #[test]
2391 #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
2392 fn test_take_out_of_bounds_panic() {
2393 let index = UInt32Array::from(vec![Some(1000)]);
2394
2395 test_take_primitive_arrays::<Int64Type>(
2396 vec![Some(0), Some(1), Some(2), Some(3)],
2397 &index,
2398 None,
2399 vec![None],
2400 )
2401 .unwrap();
2402 }
2403
2404 #[test]
2405 fn test_null_array_smaller_than_indices() {
2406 let values = NullArray::new(2);
2407 let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2408
2409 let result = take(&values, &indices, None).unwrap();
2410 let expected: ArrayRef = Arc::new(NullArray::new(3));
2411 assert_eq!(&result, &expected);
2412 }
2413
2414 #[test]
2415 fn test_null_array_larger_than_indices() {
2416 let values = NullArray::new(5);
2417 let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2418
2419 let result = take(&values, &indices, None).unwrap();
2420 let expected: ArrayRef = Arc::new(NullArray::new(3));
2421 assert_eq!(&result, &expected);
2422 }
2423
2424 #[test]
2425 fn test_null_array_indices_out_of_bounds() {
2426 let values = NullArray::new(5);
2427 let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2428
2429 let result = take(&values, &indices, Some(TakeOptions { check_bounds: true }));
2430 assert_eq!(
2431 result.unwrap_err().to_string(),
2432 "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries"
2433 );
2434 }
2435
#[test]
fn test_take_dict() {
    let mut builder = StringDictionaryBuilder::<Int16Type>::new();

    // Resulting keys: [0, 1, 2, null, 0, 1, 1, 0] over values ["foo", "bar", ""].
    builder.append("foo").unwrap();
    builder.append("bar").unwrap();
    builder.append("").unwrap();
    builder.append_null();
    for value in ["foo", "bar", "bar", "foo"] {
        builder.append(value).unwrap();
    }

    let array = builder.finish();
    let dict_values = array.values().clone();
    let dict_values = dict_values.as_any().downcast_ref::<StringArray>().unwrap();

    let indices = UInt32Array::from(vec![
        Some(0), // "foo"
        Some(7), // "foo"
        None,    // null index
        Some(5), // "bar"
        Some(6), // "bar"
        Some(2), // ""
        Some(3), // null value
    ]);

    let taken = take(&array, &indices, None).unwrap();
    let taken = taken
        .as_any()
        .downcast_ref::<DictionaryArray<Int16Type>>()
        .unwrap();

    // The dictionary values are carried over unchanged; only the keys are remapped.
    let taken_values: StringArray = taken.values().to_data().into();
    let expected_values = StringArray::from(vec!["foo", "bar", ""]);
    assert_eq!(&expected_values, dict_values);
    assert_eq!(&expected_values, &taken_values);

    let expected_keys =
        Int16Array::from(vec![Some(0), Some(0), None, Some(1), Some(1), Some(2), None]);
    assert_eq!(taken.keys(), &expected_keys);
}
2487
2488 fn build_generic_list<S, T>(data: Vec<Option<Vec<T::Native>>>) -> GenericListArray<S>
2489 where
2490 S: OffsetSizeTrait + 'static,
2491 T: ArrowPrimitiveType,
2492 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2493 {
2494 GenericListArray::from_iter_primitive::<T, _, _>(
2495 data.iter()
2496 .map(|x| x.as_ref().map(|x| x.iter().map(|x| Some(*x)))),
2497 )
2498 }
2499
#[test]
fn test_take_value_index_from_list() {
    let list = build_generic_list::<i32, Int32Type>(vec![
        Some(vec![0, 1]),
        Some(vec![2, 3, 4]),
        Some(vec![5, 6, 7, 8, 9]),
    ]);
    // Row 2 covers child positions 5..10 and row 0 covers 0..2.
    let rows = UInt32Array::from(vec![2, 0]);

    let (child_positions, new_offsets, validity) =
        take_value_indices_from_list(&list, &rows).unwrap();

    assert_eq!(child_positions, Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
    assert_eq!(new_offsets, vec![0, 5, 7]);
    assert_eq!(validity.as_slice(), &[0b11111111]);
}
2515
#[test]
fn test_take_value_index_from_large_list() {
    // Same data as the i32-offset case, but stored with i64 offsets.
    let list = build_generic_list::<i64, Int32Type>(vec![
        Some(vec![0, 1]),
        Some(vec![2, 3, 4]),
        Some(vec![5, 6, 7, 8, 9]),
    ]);
    let rows = UInt32Array::from(vec![2, 0]);

    let (child_positions, new_offsets, validity) =
        take_value_indices_from_list::<_, Int64Type>(&list, &rows).unwrap();

    assert_eq!(child_positions, Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
    assert_eq!(new_offsets, vec![0, 5, 7]);
    assert_eq!(validity.as_slice(), &[0b11111111]);
}
2532
#[test]
fn test_take_runs() {
    let logical_array: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];

    let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
    builder.extend(logical_array.into_iter().map(Some));
    let run_array = builder.finish();

    let take_indices = Int32Array::from(vec![7, 2, 3, 7, 11, 4, 6]);
    let take_out = take_run(&run_array, &take_indices).unwrap();

    // Consecutive picks from the same physical run are merged. Adjacent picks
    // from different runs stay separate, even when the values are equal.
    assert_eq!(take_out.len(), 7);
    assert_eq!(take_out.run_ends().len(), 7);
    assert_eq!(take_out.run_ends().values(), &[1_i32, 3, 4, 5, 7]);

    let run_values = take_out.values().as_primitive::<Int32Type>();
    assert_eq!(run_values.values(), &[2, 2, 2, 2, 1]);
}
2553
#[test]
fn test_take_runs_sliced() {
    let logical_array: Vec<i32> = vec![1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6];

    let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
    builder.extend(logical_array.into_iter().map(Some));
    let run_array = builder.finish();

    // The slice covers logical positions 4..10, i.e. [3, 3, 3, 4, 4, 5].
    let run_array = run_array.slice(4, 6);
    let take_indices = Int32Array::from(vec![0, 5, 5, 1, 4]);

    let taken = take_run(&run_array, &take_indices).unwrap();
    let taken = taken.downcast::<Int32Array>().unwrap();

    let actual: Vec<i32> = taken.into_iter().flatten().collect();
    assert_eq!(vec![3, 5, 5, 3, 4], actual);
}
2574
#[test]
fn test_take_value_index_from_fixed_list() {
    let list = FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
        vec![
            Some(vec![Some(1), Some(2), None]),
            Some(vec![Some(4), None, Some(6)]),
            None,
            Some(vec![None, Some(8), Some(9)]),
        ],
        3,
    );

    // Each selected row `r` expands to child positions 3*r, 3*r + 1, 3*r + 2.
    let rows = UInt32Array::from(vec![2, 1, 0]);
    let child_positions = take_value_indices_from_fixed_size_list(&list, &rows, 3).unwrap();
    assert_eq!(
        child_positions,
        UInt32Array::from(vec![6, 7, 8, 3, 4, 5, 0, 1, 2])
    );

    // A repeated row is expanded every time it is selected.
    let rows = UInt32Array::from(vec![3, 2, 1, 2, 0]);
    let child_positions = take_value_indices_from_fixed_size_list(&list, &rows, 3).unwrap();
    assert_eq!(
        child_positions,
        UInt32Array::from(vec![9, 10, 11, 6, 7, 8, 3, 4, 5, 6, 7, 8, 0, 1, 2])
    );
}
2600
#[test]
fn test_take_null_indices() {
    // The two null index slots hold 400, far past the end of `values`; the
    // output must be null at those positions.
    let indices = Int32Array::new(
        vec![1, 2, 400, 400].into(),
        Some(NullBuffer::from(vec![true, true, false, false])),
    );
    let values = Int32Array::from(vec![1, 23, 4, 5]);

    let taken = take(&values, &indices, None).unwrap();
    let taken: Vec<Option<i32>> = taken.as_primitive::<Int32Type>().iter().collect();
    assert_eq!(&taken, &[Some(23), Some(4), None, None])
}
2616
#[test]
fn test_take_fixed_size_list_null_indices() {
    let indices = Int32Array::from_iter([Some(0), None]);
    let child = Arc::new(Int32Array::from(vec![0, 1, 2, 3]));
    let item_field = Arc::new(Field::new_list_field(child.data_type().clone(), true));
    // Two rows of width 2: [0, 1] and [2, 3].
    let list = FixedSizeListArray::try_new(item_field, 2, child, None).unwrap();

    let taken = take(&list, &indices, None).unwrap();
    // The null index produces a row whose child slots are both null.
    let child_values: Vec<_> = taken
        .as_fixed_size_list()
        .values()
        .as_primitive::<Int32Type>()
        .iter()
        .collect();
    assert_eq!(child_values, &[Some(0), Some(1), None, None])
}
2633
#[test]
fn test_take_bytes_null_indices() {
    // The null index slots carry an out-of-range 400; the output must be
    // null there.
    let indices = Int32Array::new(
        vec![0, 1, 400, 400].into(),
        Some(NullBuffer::from_iter(vec![true, true, false, false])),
    );
    let values = StringArray::from(vec![Some("foo"), None]);

    let taken = take(&values, &indices, None).unwrap();
    let strings: Vec<Option<&str>> = taken.as_string::<i32>().iter().collect();
    assert_eq!(&strings, &[Some("foo"), None, None, None])
}
2645
#[test]
fn test_take_union_sparse() {
    let structs = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(false), Some(28))),
        Some((Some(false), Some(19))),
        Some((Some(true), Some(31))),
        None,
    ]);
    let strings = StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);
    // Every slot selects type id 1, so all logical values come from `strings`.
    let type_ids = [1; 5].into_iter().collect::<ScalarBuffer<i8>>();

    let struct_field = Arc::new(Field::new("f1", structs.data_type().clone(), true));
    let string_field = Arc::new(Field::new("f2", strings.data_type().clone(), true));
    let union_fields = [(0, struct_field), (1, string_field)].into_iter().collect();
    let children = vec![Arc::new(structs) as Arc<dyn Array>, Arc::new(strings)];
    let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();

    let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
    let taken = take(&array, &index, None).unwrap();
    let taken = taken.as_any().downcast_ref::<UnionArray>().unwrap();
    let taken_strings = taken.child(1);
    let taken_strings = taken_strings.as_any().downcast_ref::<StringArray>().unwrap();

    let actual: Vec<_> = taken_strings.iter().collect();
    let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
    assert_eq!(expected, actual);
}
2684
#[test]
fn test_take_union_dense() {
    // Input: slot i reads child `type_ids[i]` at position `offsets[i]`.
    let type_ids = ScalarBuffer::<i8>::from(vec![0, 1, 1, 0, 0, 1, 0]);
    let offsets = ScalarBuffer::<i32>::from(vec![0, 0, 1, 1, 2, 2, 3]);
    let ints = UInt32Array::from(vec![10, 20, 30, 40]);
    let strings = StringArray::from(vec![Some("a"), None, Some("c"), Some("d")]);

    let int_field = Arc::new(Field::new("f1", ints.data_type().clone(), true));
    let string_field = Arc::new(Field::new("f2", strings.data_type().clone(), true));
    let union_fields = [(0, int_field), (1, string_field)].into_iter().collect();

    let array = UnionArray::try_new(
        union_fields,
        type_ids,
        Some(offsets),
        vec![Arc::new(ints), Arc::new(strings)],
    )
    .unwrap();

    let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
    let taken = take(&array, &index, None).unwrap();
    let taken = taken.as_any().downcast_ref::<UnionArray>().unwrap();

    // Each output child holds only the taken values in order, and the new
    // offsets index into those compacted children.
    let expected_offsets = ScalarBuffer::<i32>::from(vec![0, 1, 0, 2, 1, 3]);
    let expected_type_ids = ScalarBuffer::<i8>::from(vec![0, 0, 1, 0, 1, 0]);
    assert_eq!(taken.offsets(), Some(&expected_offsets));
    assert_eq!(taken.type_ids(), &expected_type_ids);
    assert_eq!(
        UInt32Array::from(taken.child(0).to_data()),
        UInt32Array::from(vec![10, 20, 10, 30])
    );
    assert_eq!(
        StringArray::from(taken.child(1).to_data()),
        StringArray::from(vec![Some("a"), None])
    );
}
2741
#[test]
fn test_take_union_dense_using_builder() {
    let mut input = UnionBuilder::new_dense();
    input.append::<Int32Type>("a", 1).unwrap();
    input.append::<Float64Type>("b", 3.0).unwrap();
    input.append::<Int32Type>("a", 4).unwrap();
    input.append::<Int32Type>("a", 5).unwrap();
    input.append::<Float64Type>("b", 2.0).unwrap();
    let union = input.build().unwrap();

    let indices = UInt32Array::from(vec![2, 0, 1, 2]);

    // Build the expected union directly from input positions 2, 0, 1, 2.
    let mut expected = UnionBuilder::new_dense();
    expected.append::<Int32Type>("a", 4).unwrap();
    expected.append::<Int32Type>("a", 1).unwrap();
    expected.append::<Float64Type>("b", 3.0).unwrap();
    expected.append::<Int32Type>("a", 4).unwrap();
    let expected = expected.build().unwrap();

    let actual = take(&union, &indices, None).unwrap();
    assert_eq!(expected.to_data(), actual.to_data());
}
2770
#[test]
fn test_take_union_dense_all_match_issue_6206() {
    // Regression for issue 6206: a dense union whose every slot selects its
    // single child.
    let fields = UnionFields::from_fields(vec![Field::new("a", DataType::Int64, false)]);
    let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
    let type_ids: ScalarBuffer<i8> = ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]);
    let offsets: ScalarBuffer<i32> = ScalarBuffer::from_iter(0_i32..5);

    let array = UnionArray::try_new(fields, type_ids, Some(offsets), vec![ints]).unwrap();

    let indices = Int64Array::from(vec![0, 2, 4]);
    let taken = take(&array, &indices, None).unwrap();
    assert_eq!(taken.len(), 3);
}
2788
#[test]
fn test_take_bytes_offset_overflow() {
    // Repeating a 26-byte string (i32::MAX >> 4) times needs more than
    // i32::MAX bytes of value data, which i32 offsets cannot address.
    let repeat = (i32::MAX >> 4) as usize;
    let indices = Int32Array::from(vec![0; repeat]);
    let alphabet: String = ('a'..='z').collect();
    let values = StringArray::from(vec![Some(alphabet)]);

    let outcome = take(&values, &indices, None);
    assert!(matches!(outcome, Err(ArrowError::OffsetOverflowError(_))));
}
2799
#[test]
fn test_take_run_empty_indices() {
    let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
    builder.extend([Some(1), Some(1), Some(2), Some(2)]);
    let run_array = builder.finish();

    // An empty index array must produce an empty result, not an error.
    let logical_indices = Int32Array::from(Vec::<i32>::new());
    let result = take_impl(&run_array, &logical_indices).expect("take_run with empty indices");

    assert_eq!(result.len(), 0);
    assert_eq!(result.null_count(), 0);

    // The output is still run-end encoded, with zero logical length and no values.
    let run_result = result
        .as_any()
        .downcast_ref::<RunArray<Int32Type>>()
        .expect("result should be a RunArray");
    assert_eq!(run_result.run_ends().len(), 0);
    assert_eq!(run_result.values().len(), 0);
}
2823}