1use std::fmt::Display;
21use std::sync::Arc;
22
23use arrow_array::builder::{BufferBuilder, UInt32Builder};
24use arrow_array::cast::AsArray;
25use arrow_array::types::*;
26use arrow_array::*;
27use arrow_buffer::{
28 ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, ScalarBuffer,
29 bit_util,
30};
31use arrow_data::ArrayDataBuilder;
32use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
33
34use num_traits::{One, Zero};
35
36pub fn take(
88 values: &dyn Array,
89 indices: &dyn Array,
90 options: Option<TakeOptions>,
91) -> Result<ArrayRef, ArrowError> {
92 let options = options.unwrap_or_default();
93 downcast_integer_array!(
94 indices => {
95 if options.check_bounds {
96 check_bounds(values.len(), indices)?;
97 }
98 let indices = indices.to_indices();
99 take_impl(values, &indices)
100 },
101 d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
102 )
103}
104
105pub fn take_arrays(
154 arrays: &[ArrayRef],
155 indices: &dyn Array,
156 options: Option<TakeOptions>,
157) -> Result<Vec<ArrayRef>, ArrowError> {
158 arrays
159 .iter()
160 .map(|array| take(array.as_ref(), indices, options.clone()))
161 .collect()
162}
163
164fn check_bounds<T: ArrowPrimitiveType>(
166 len: usize,
167 indices: &PrimitiveArray<T>,
168) -> Result<(), ArrowError>
169where
170 T::Native: Display,
171{
172 let len = match T::Native::from_usize(len) {
173 Some(len) => len,
174 None => {
175 if T::DATA_TYPE.is_integer() {
176 return Ok(());
178 } else {
179 return Err(ArrowError::ComputeError("Cast to usize failed".to_string()));
180 }
181 }
182 };
183
184 if indices.null_count() > 0 {
185 indices.iter().flatten().try_for_each(|index| {
186 if index >= len {
187 return Err(ArrowError::ComputeError(format!(
188 "Array index out of bounds, cannot get item at index {index} from {len} entries"
189 )));
190 }
191 Ok(())
192 })
193 } else {
194 let in_bounds = indices.values().iter().fold(true, |in_bounds, &i| {
195 in_bounds & (i >= T::Native::ZERO) & (i < len)
196 });
197
198 if !in_bounds {
199 for &index in indices.values() {
200 if index < T::Native::ZERO || index >= len {
201 return Err(ArrowError::ComputeError(format!(
202 "Array index out of bounds, cannot get item at index {index} from {len} entries"
203 )));
204 }
205 }
206 }
207
208 Ok(())
209 }
210}
211
/// Gathers `values` at `indices` by dispatching on the data type of `values`.
///
/// Each supported data type is forwarded to a dedicated helper (or handled
/// inline for struct, map, null and union arrays). No bounds validation is
/// performed here; [`take`] runs `check_bounds` beforehand when requested.
///
/// # Panics
///
/// Panics via `unimplemented!` for data types that have no branch below.
#[inline(never)]
fn take_impl<IndexType: ArrowPrimitiveType>(
    values: &dyn Array,
    indices: &PrimitiveArray<IndexType>,
) -> Result<ArrayRef, ArrowError> {
    downcast_primitive_array! {
        // All primitive types share the generic primitive kernel.
        values => Ok(Arc::new(take_primitive(values, indices)?)),
        DataType::Boolean => {
            let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
            Ok(Arc::new(take_boolean(values, indices)))
        }
        DataType::Utf8 => {
            Ok(Arc::new(take_bytes(values.as_string::<i32>(), indices)?))
        }
        DataType::LargeUtf8 => {
            Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)?))
        }
        DataType::Utf8View => {
            Ok(Arc::new(take_byte_view(values.as_string_view(), indices)?))
        }
        DataType::List(_) => {
            Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?))
        }
        DataType::LargeList(_) => {
            Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?))
        }
        DataType::ListView(_) => {
            Ok(Arc::new(take_list_view::<_, Int32Type>(values.as_list_view(), indices)?))
        }
        DataType::LargeListView(_) => {
            Ok(Arc::new(take_list_view::<_, Int64Type>(values.as_list_view(), indices)?))
        }
        DataType::FixedSizeList(_, length) => {
            let values = values
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap();
            Ok(Arc::new(take_fixed_size_list(
                values,
                indices,
                *length as u32,
            )?))
        }
        DataType::Map(_, _) => {
            // A map is taken as a list of entries, then relabelled with the
            // original map data type.
            let list_arr = ListArray::from(values.as_map().clone());
            let list_data = take_list::<_, Int32Type>(&list_arr, indices)?;
            let builder = list_data.into_data().into_builder().data_type(values.data_type().clone());
            // NOTE(review): relies on the taken list data having the same
            // layout as a map of the original type; not validated here.
            Ok(Arc::new(MapArray::from(unsafe { builder.build_unchecked() })))
        }
        DataType::Struct(fields) => {
            let array: &StructArray = values.as_struct();
            // Take every child column with the same indices.
            let arrays = array
                .columns()
                .iter()
                .map(|a| take_impl(a.as_ref(), indices))
                .collect::<Result<Vec<ArrayRef>, _>>()?;
            let fields: Vec<(FieldRef, ArrayRef)> =
                fields.iter().cloned().zip(arrays).collect();

            // An output row is valid only when its index is non-null and the
            // addressed struct row is itself valid.
            let is_valid: Buffer = indices
                .iter()
                .map(|index| {
                    if let Some(index) = index {
                        array.is_valid(index.to_usize().unwrap())
                    } else {
                        false
                    }
                })
                .collect();

            if fields.is_empty() {
                // Without children the length must be supplied explicitly.
                let nulls = NullBuffer::new(BooleanBuffer::new(is_valid, 0, indices.len()));
                Ok(Arc::new(StructArray::new_empty_fields(indices.len(), Some(nulls))))
            } else {
                Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
            }
        }
        DataType::Dictionary(_, _) => downcast_dictionary_array! {
            values => Ok(Arc::new(take_dict(values, indices)?)),
            t => unimplemented!("Take not supported for dictionary type {:?}", t)
        }
        DataType::RunEndEncoded(_, _) => downcast_run_array! {
            values => Ok(Arc::new(take_run(values, indices)?)),
            t => unimplemented!("Take not supported for run type {:?}", t)
        }
        DataType::Binary => {
            Ok(Arc::new(take_bytes(values.as_binary::<i32>(), indices)?))
        }
        DataType::LargeBinary => {
            Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?))
        }
        DataType::BinaryView => {
            Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)?))
        }
        DataType::FixedSizeBinary(size) => {
            let values = values
                .as_any()
                .downcast_ref::<FixedSizeBinaryArray>()
                .unwrap();
            Ok(Arc::new(take_fixed_size_binary(values, indices, *size)?))
        }
        DataType::Null => {
            // Every element of a null array is identical, so only the output
            // length matters: reuse a slice when the input is long enough,
            // otherwise build a fresh null array.
            if values.len() >= indices.len() {
                Ok(values.slice(0, indices.len()))
            } else {
                Ok(new_null_array(&DataType::Null, indices.len()))
            }
        }
        DataType::Union(fields, UnionMode::Sparse) => {
            // Sparse union: type ids and every child are taken with the same
            // indices.
            let mut children = Vec::with_capacity(fields.len());
            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
            let type_ids = take_native(values.type_ids(), indices);
            for (type_id, _field) in fields.iter() {
                let values = values.child(type_id);
                let values = take_impl(values, indices)?;
                children.push(values);
            }
            let array = UnionArray::try_new(fields.clone(), type_ids, None, children)?;
            Ok(Arc::new(array))
        }
        DataType::Union(fields, UnionMode::Dense) => {
            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();

            // Gather the type id and child offset of every selected row.
            let type_ids = <PrimitiveArray<Int8Type>>::try_new(take_native(values.type_ids(), indices), None)?;
            let offsets = <PrimitiveArray<Int32Type>>::try_new(take_native(values.offsets().unwrap(), indices), None)?;

            // For each child, keep only the offsets of rows carrying that
            // child's type id and take those entries from the child.
            let children = fields.iter()
                .map(|(field_type_id, _)| {
                    let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);

                    let indices = crate::filter::filter(&offsets, &mask)?;

                    let values = values.child(field_type_id);

                    take_impl(values, indices.as_primitive::<Int32Type>())
                })
                .collect::<Result<_, _>>()?;

            // One running counter per type id, indexed by the id itself.
            // NOTE(review): assumes type ids are non-negative `i8` values
            // (hence 128 slots) — confirm against the union invariants.
            let mut child_offsets = [0; 128];

            // Each new child holds its rows in output order, so the new
            // offset of a row is the number of earlier rows with the same id.
            let offsets = type_ids.values()
                .iter()
                .map(|&i| {
                    let offset = child_offsets[i as usize];

                    child_offsets[i as usize] += 1;

                    offset
                })
                .collect();

            let (_, type_ids, _) = type_ids.into_parts();

            let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)?;

            Ok(Arc::new(array))
        }
        t => unimplemented!("Take not supported for data type {:?}", t)
    }
}
377
/// Options that control how [`take`] behaves.
#[derive(Clone, Debug, Default)]
pub struct TakeOptions {
    /// When `true`, [`take`] validates the indices against the length of the
    /// values array before gathering and returns an error for an
    /// out-of-bounds index. The derived default is `false`.
    pub check_bounds: bool,
}
386
387fn take_primitive<T, I>(
397 values: &PrimitiveArray<T>,
398 indices: &PrimitiveArray<I>,
399) -> Result<PrimitiveArray<T>, ArrowError>
400where
401 T: ArrowPrimitiveType,
402 I: ArrowPrimitiveType,
403{
404 let values_buf = take_native(values.values(), indices);
405 let nulls = take_nulls(values.nulls(), indices);
406 Ok(PrimitiveArray::try_new(values_buf, nulls)?.with_data_type(values.data_type().clone()))
407}
408
409#[inline(never)]
410fn take_nulls<I: ArrowPrimitiveType>(
411 values: Option<&NullBuffer>,
412 indices: &PrimitiveArray<I>,
413) -> Option<NullBuffer> {
414 match values.filter(|n| n.null_count() > 0) {
415 Some(n) => {
416 let buffer = take_bits(n.inner(), indices);
417 Some(NullBuffer::new(buffer)).filter(|n| n.null_count() > 0)
418 }
419 None => indices.nulls().cloned(),
420 }
421}
422
423#[inline(never)]
424fn take_native<T: ArrowNativeType, I: ArrowPrimitiveType>(
425 values: &[T],
426 indices: &PrimitiveArray<I>,
427) -> ScalarBuffer<T> {
428 match indices.nulls().filter(|n| n.null_count() > 0) {
429 Some(n) => indices
430 .values()
431 .iter()
432 .enumerate()
433 .map(|(idx, index)| match values.get(index.as_usize()) {
434 Some(v) => *v,
435 None => match unsafe { n.inner().value_unchecked(idx) } {
437 false => T::default(),
438 true => panic!("Out-of-bounds index {index:?}"),
439 },
440 })
441 .collect(),
442 None => indices
443 .values()
444 .iter()
445 .map(|index| values[index.as_usize()])
446 .collect(),
447 }
448}
449
450#[inline(never)]
451fn take_bits<I: ArrowPrimitiveType>(
452 values: &BooleanBuffer,
453 indices: &PrimitiveArray<I>,
454) -> BooleanBuffer {
455 let len = indices.len();
456
457 match indices.nulls().filter(|n| n.null_count() > 0) {
458 Some(nulls) => {
459 let mut output_buffer = MutableBuffer::new_null(len);
460 let output_slice = output_buffer.as_slice_mut();
461 nulls.valid_indices().for_each(|idx| {
462 if values.value(unsafe { indices.value_unchecked(idx).as_usize() }) {
464 unsafe { bit_util::set_bit_raw(output_slice.as_mut_ptr(), idx) };
466 }
467 });
468 BooleanBuffer::new(output_buffer.into(), 0, len)
469 }
470 None => {
471 BooleanBuffer::collect_bool(len, |idx: usize| {
472 values.value(unsafe { indices.value_unchecked(idx).as_usize() })
474 })
475 }
476 }
477}
478
479fn take_boolean<IndexType: ArrowPrimitiveType>(
481 values: &BooleanArray,
482 indices: &PrimitiveArray<IndexType>,
483) -> BooleanArray {
484 let val_buf = take_bits(values.values(), indices);
485 let null_buf = take_nulls(values.nulls(), indices);
486 BooleanArray::new(val_buf, null_buf)
487}
488
/// Gathers the elements of a variable-length byte array (string or binary) at
/// `indices`.
///
/// Each branch makes two passes: the first accumulates the output offsets and
/// the total byte length, the second copies the selected values into a buffer
/// pre-sized to that length. The four branches specialise on whether the
/// array and/or the indices contain nulls, so that validity is only consulted
/// where needed. Null output slots contribute zero bytes.
///
/// # Errors
///
/// Returns [`ArrowError::OffsetOverflowError`] when the accumulated byte
/// length does not fit in `T::Offset`.
fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
    array: &GenericByteArray<T>,
    indices: &PrimitiveArray<IndexType>,
) -> Result<GenericByteArray<T>, ArrowError> {
    let mut offsets = Vec::with_capacity(indices.len() + 1);
    offsets.push(T::Offset::default());

    let input_offsets = array.value_offsets();
    // Running total of bytes selected so far; doubles as the next offset.
    let mut capacity = 0;
    let nulls = take_nulls(array.nulls(), indices);

    let (offsets, values) = if array.null_count() == 0 && indices.null_count() == 0 {
        // No nulls anywhere: every index contributes its value.
        offsets.reserve(indices.len());
        for index in indices.values() {
            let index = index.as_usize();
            capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for index in indices.values() {
            values.extend_from_slice(array.value(index.as_usize()).as_ref());
        }
        (offsets, values)
    } else if indices.null_count() == 0 {
        // Only the array has nulls: skip the bytes of null source slots.
        offsets.reserve(indices.len());
        for index in indices.values() {
            let index = index.as_usize();
            if array.is_valid(index) {
                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            }
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for index in indices.values() {
            let index = index.as_usize();
            if array.is_valid(index) {
                values.extend_from_slice(array.value(index).as_ref());
            }
        }
        (offsets, values)
    } else if array.null_count() == 0 {
        // Only the indices have nulls: the value stored in a null index slot
        // is never used to address the array.
        offsets.reserve(indices.len());
        for (i, index) in indices.values().iter().enumerate() {
            let index = index.as_usize();
            if indices.is_valid(i) {
                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            }
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for (i, index) in indices.values().iter().enumerate() {
            if indices.is_valid(i) {
                values.extend_from_slice(array.value(index.as_usize()).as_ref());
            }
        }
        (offsets, values)
    } else {
        // Both have nulls: the combined output validity decides which slots
        // carry bytes.
        // NOTE(review): the `unwrap` assumes `take_nulls` returns `Some`
        // whenever the indices contain nulls — confirm against `take_nulls`.
        let nulls = nulls.as_ref().unwrap();
        offsets.reserve(indices.len());
        for (i, index) in indices.values().iter().enumerate() {
            let index = index.as_usize();
            if nulls.is_valid(i) {
                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            }
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for (i, index) in indices.values().iter().enumerate() {
            let index = index.as_usize();
            if nulls.is_valid(i) {
                values.extend_from_slice(array.value(index).as_ref());
            }
        }
        (offsets, values)
    };

    // Final guard that the total byte length is representable as an offset.
    T::Offset::from_usize(values.len())
        .ok_or_else(|| ArrowError::OffsetOverflowError(values.len()))?;

    // The offsets were accumulated from a running total starting at the
    // default offset, and the bytes were copied from existing values of
    // `array`, so the unchecked constructors are used to skip re-validation.
    let array = unsafe {
        let offsets = OffsetBuffer::new_unchecked(offsets.into());
        GenericByteArray::<T>::new_unchecked(offsets, values.into(), nulls)
    };

    Ok(array)
}
594
595fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>(
597 array: &GenericByteViewArray<T>,
598 indices: &PrimitiveArray<IndexType>,
599) -> Result<GenericByteViewArray<T>, ArrowError> {
600 let new_views = take_native(array.views(), indices);
601 let new_nulls = take_nulls(array.nulls(), indices);
602 Ok(unsafe {
604 GenericByteViewArray::new_unchecked(new_views, array.data_buffers().to_vec(), new_nulls)
605 })
606}
607
608fn take_list<IndexType, OffsetType>(
614 values: &GenericListArray<OffsetType::Native>,
615 indices: &PrimitiveArray<IndexType>,
616) -> Result<GenericListArray<OffsetType::Native>, ArrowError>
617where
618 IndexType: ArrowPrimitiveType,
619 OffsetType: ArrowPrimitiveType,
620 OffsetType::Native: OffsetSizeTrait,
621 PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
622{
623 let (list_indices, offsets, null_buf) =
626 take_value_indices_from_list::<IndexType, OffsetType>(values, indices)?;
627
628 let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices)?;
629 let value_offsets = Buffer::from_vec(offsets);
630 let list_data = ArrayDataBuilder::new(values.data_type().clone())
632 .len(indices.len())
633 .null_bit_buffer(Some(null_buf.into()))
634 .offset(0)
635 .add_child_data(taken.into_data())
636 .add_buffer(value_offsets);
637
638 let list_data = unsafe { list_data.build_unchecked() };
639
640 Ok(GenericListArray::<OffsetType::Native>::from(list_data))
641}
642
643fn take_list_view<IndexType, OffsetType>(
644 values: &GenericListViewArray<OffsetType::Native>,
645 indices: &PrimitiveArray<IndexType>,
646) -> Result<GenericListViewArray<OffsetType::Native>, ArrowError>
647where
648 IndexType: ArrowPrimitiveType,
649 OffsetType: ArrowPrimitiveType,
650 OffsetType::Native: OffsetSizeTrait,
651{
652 let taken_offsets = take_native(values.offsets(), indices);
653 let taken_sizes = take_native(values.sizes(), indices);
654 let nulls = take_nulls(values.nulls(), indices);
655
656 let list_view_data = ArrayDataBuilder::new(values.data_type().clone())
657 .len(indices.len())
658 .nulls(nulls)
659 .buffers(vec![taken_offsets.into(), taken_sizes.into()])
660 .child_data(vec![values.values().to_data()]);
661
662 let list_view_data = unsafe { list_view_data.build_unchecked() };
664
665 Ok(GenericListViewArray::<OffsetType::Native>::from(
666 list_view_data,
667 ))
668}
669
670fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
676 values: &FixedSizeListArray,
677 indices: &PrimitiveArray<IndexType>,
678 length: <UInt32Type as ArrowPrimitiveType>::Native,
679) -> Result<FixedSizeListArray, ArrowError> {
680 let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?;
681 let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)?;
682
683 let num_bytes = bit_util::ceil(indices.len(), 8);
685 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
686 let null_slice = null_buf.as_slice_mut();
687
688 for i in 0..indices.len() {
689 let index = indices
690 .value(i)
691 .to_usize()
692 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
693 if !indices.is_valid(i) || values.is_null(index) {
694 bit_util::unset_bit(null_slice, i);
695 }
696 }
697
698 let list_data = ArrayDataBuilder::new(values.data_type().clone())
699 .len(indices.len())
700 .null_bit_buffer(Some(null_buf.into()))
701 .offset(0)
702 .add_child_data(taken.into_data());
703
704 let list_data = unsafe { list_data.build_unchecked() };
705
706 Ok(FixedSizeListArray::from(list_data))
707}
708
709fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>(
715 values: &FixedSizeBinaryArray,
716 indices: &PrimitiveArray<IndexType>,
717 size: i32,
718) -> Result<FixedSizeBinaryArray, ArrowError> {
719 let size_usize = usize::try_from(size).map_err(|_| {
720 ArrowError::InvalidArgumentError(format!("Cannot convert size '{}' to usize", size))
721 })?;
722
723 let values_buffer = values.values().as_slice();
724 let mut values_buffer_builder = BufferBuilder::new(indices.len() * size_usize);
725
726 if indices.null_count() == 0 {
727 let array_iter = indices.values().iter().map(|idx| {
728 let offset = idx.as_usize() * size_usize;
729 &values_buffer[offset..offset + size_usize]
730 });
731 for slice in array_iter {
732 values_buffer_builder.append_slice(slice);
733 }
734 } else {
735 let array_iter = indices.iter().map(|idx| {
738 idx.map(|idx| {
739 let offset = idx.as_usize() * size_usize;
740 &values_buffer[offset..offset + size_usize]
741 })
742 });
743 for slice in array_iter {
744 match slice {
745 None => values_buffer_builder.append_n(size_usize, 0),
746 Some(slice) => values_buffer_builder.append_slice(slice),
747 }
748 }
749 }
750
751 let values_buffer = values_buffer_builder.finish();
752 let value_nulls = take_nulls(values.nulls(), indices);
753 let final_nulls = NullBuffer::union(value_nulls.as_ref(), indices.nulls());
754
755 let array_data = ArrayDataBuilder::new(DataType::FixedSizeBinary(size))
756 .len(indices.len())
757 .nulls(final_nulls)
758 .offset(0)
759 .add_buffer(values_buffer)
760 .build()?;
761
762 Ok(FixedSizeBinaryArray::from(array_data))
763}
764
765fn take_dict<T: ArrowDictionaryKeyType, I: ArrowPrimitiveType>(
770 values: &DictionaryArray<T>,
771 indices: &PrimitiveArray<I>,
772) -> Result<DictionaryArray<T>, ArrowError> {
773 let new_keys = take_primitive(values.keys(), indices)?;
774 Ok(unsafe { DictionaryArray::new_unchecked(new_keys, values.values().clone()) })
775}
776
777fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
786 run_array: &RunArray<T>,
787 logical_indices: &PrimitiveArray<I>,
788) -> Result<RunArray<T>, ArrowError> {
789 let physical_indices = run_array.get_physical_indices(logical_indices.values())?;
791
792 let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
796 let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
797 let mut new_physical_len = 1;
798 for ix in 1..physical_indices.len() {
799 if physical_indices[ix] != physical_indices[ix - 1] {
800 take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
801 new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
802 new_physical_len += 1;
803 }
804 }
805 take_value_indices
806 .append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap());
807 new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
808 let new_run_ends = unsafe {
809 ArrayDataBuilder::new(T::DATA_TYPE)
812 .len(new_physical_len)
813 .null_count(0)
814 .add_buffer(new_run_ends_builder.finish())
815 .build_unchecked()
816 };
817
818 let take_value_indices: PrimitiveArray<I> = unsafe {
819 ArrayDataBuilder::new(I::DATA_TYPE)
822 .len(new_physical_len)
823 .null_count(0)
824 .add_buffer(take_value_indices.finish())
825 .build_unchecked()
826 .into()
827 };
828
829 let new_values = take(run_array.values(), &take_value_indices, None)?;
830
831 let builder = ArrayDataBuilder::new(run_array.data_type().clone())
832 .len(physical_indices.len())
833 .add_child_data(new_run_ends)
834 .add_child_data(new_values.into_data());
835 let array_data = unsafe {
836 builder.build_unchecked()
839 };
840 Ok(array_data.into())
841}
842
843#[allow(clippy::type_complexity)]
849fn take_value_indices_from_list<IndexType, OffsetType>(
850 list: &GenericListArray<OffsetType::Native>,
851 indices: &PrimitiveArray<IndexType>,
852) -> Result<
853 (
854 PrimitiveArray<OffsetType>,
855 Vec<OffsetType::Native>,
856 MutableBuffer,
857 ),
858 ArrowError,
859>
860where
861 IndexType: ArrowPrimitiveType,
862 OffsetType: ArrowPrimitiveType,
863 OffsetType::Native: OffsetSizeTrait + std::ops::Add + Zero + One,
864 PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
865{
866 let offsets: &[OffsetType::Native] = list.value_offsets();
868
869 let mut new_offsets = Vec::with_capacity(indices.len());
870 let mut values = Vec::new();
871 let mut current_offset = OffsetType::Native::zero();
872 new_offsets.push(OffsetType::Native::zero());
874
875 let num_bytes = bit_util::ceil(indices.len(), 8);
877 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
878 let null_slice = null_buf.as_slice_mut();
879
880 for i in 0..indices.len() {
882 if indices.is_valid(i) {
883 let ix = indices
884 .value(i)
885 .to_usize()
886 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
887 let start = offsets[ix];
888 let end = offsets[ix + 1];
889 current_offset += end - start;
890 new_offsets.push(current_offset);
891
892 let mut curr = start;
893
894 while curr < end {
896 values.push(curr);
897 curr += One::one();
898 }
899 if !list.is_valid(ix) {
900 bit_util::unset_bit(null_slice, i);
901 }
902 } else {
903 bit_util::unset_bit(null_slice, i);
904 new_offsets.push(current_offset);
905 }
906 }
907
908 Ok((
909 PrimitiveArray::<OffsetType>::from(values),
910 new_offsets,
911 null_buf,
912 ))
913}
914
915fn take_value_indices_from_fixed_size_list<IndexType>(
917 list: &FixedSizeListArray,
918 indices: &PrimitiveArray<IndexType>,
919 length: <UInt32Type as ArrowPrimitiveType>::Native,
920) -> Result<PrimitiveArray<UInt32Type>, ArrowError>
921where
922 IndexType: ArrowPrimitiveType,
923{
924 let mut values = UInt32Builder::with_capacity(length as usize * indices.len());
925
926 for i in 0..indices.len() {
927 if indices.is_valid(i) {
928 let index = indices
929 .value(i)
930 .to_usize()
931 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
932 let start = list.value_offset(index) as <UInt32Type as ArrowPrimitiveType>::Native;
933
934 unsafe {
936 values.append_trusted_len_iter(start..start + length);
937 }
938 } else {
939 values.append_nulls(length as usize);
940 }
941 }
942
943 Ok(values.finish())
944}
945
/// Converts an integer array into the index array type that `take_impl` is
/// instantiated with.
///
/// The implementations in this file map every integer type to either
/// `UInt32Type` or `UInt64Type`.
trait ToIndices {
    /// The primitive type of the produced index array.
    type T: ArrowPrimitiveType;

    /// Returns the indices of `self` as a `PrimitiveArray<Self::T>`.
    fn to_indices(&self) -> PrimitiveArray<Self::T>;
}
953
/// Implements [`ToIndices`] for `PrimitiveArray<$t>` by viewing its value
/// buffer as `$o` without copying; the null buffer is carried over unchanged.
macro_rules! to_indices_reinterpret {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                // Share the underlying bytes and re-type them as `$o`.
                let shared_bytes = self.values().inner().clone();
                let reinterpreted = ScalarBuffer::new(shared_bytes, 0, self.len());
                PrimitiveArray::new(reinterpreted, self.nulls().cloned())
            }
        }
    };
}
966
/// Implements [`ToIndices`] for `PrimitiveArray<$t>` where the array already
/// has the target index type: the conversion is a plain clone.
macro_rules! to_indices_identity {
    ($t:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $t;

            fn to_indices(&self) -> PrimitiveArray<$t> {
                Clone::clone(self)
            }
        }
    };
}
978
/// Implements [`ToIndices`] for `PrimitiveArray<$t>` by converting every
/// value to the wider output type `$o` with an `as` cast; the null buffer is
/// carried over unchanged.
///
/// The associated type follows `$o`, so the macro stays correct for any
/// output type rather than only `UInt32Type`.
macro_rules! to_indices_widening {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                // `as` sign-extends signed sources, so a negative index turns
                // into a very large unsigned one.
                let cast = self.values().iter().copied().map(|x| x as _).collect();
                PrimitiveArray::new(cast, self.nulls().cloned())
            }
        }
    };
}
991
// 8-bit indices are widened to `UInt32Type`.
to_indices_widening!(UInt8Type, UInt32Type);
to_indices_widening!(Int8Type, UInt32Type);

// 16-bit indices are widened to `UInt32Type`.
to_indices_widening!(UInt16Type, UInt32Type);
to_indices_widening!(Int16Type, UInt32Type);

// 32-bit indices: unsigned is used as-is, signed is reinterpreted in place.
to_indices_identity!(UInt32Type);
to_indices_reinterpret!(Int32Type, UInt32Type);

// 64-bit indices: unsigned is used as-is, signed is reinterpreted in place.
to_indices_identity!(UInt64Type);
to_indices_reinterpret!(Int64Type, UInt64Type);
1003
1004pub fn take_record_batch(
1043 record_batch: &RecordBatch,
1044 indices: &dyn Array,
1045) -> Result<RecordBatch, ArrowError> {
1046 let columns = record_batch
1047 .columns()
1048 .iter()
1049 .map(|c| take(c, indices, None))
1050 .collect::<Result<Vec<_>, _>>()?;
1051 RecordBatch::try_new(record_batch.schema(), columns)
1052}
1053
1054#[cfg(test)]
1055mod tests {
1056 use super::*;
1057 use arrow_array::builder::*;
1058 use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
1059 use arrow_data::ArrayData;
1060 use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
1061 use num_traits::ToPrimitive;
1062
1063 fn test_take_decimal_arrays(
1064 data: Vec<Option<i128>>,
1065 index: &UInt32Array,
1066 options: Option<TakeOptions>,
1067 expected_data: Vec<Option<i128>>,
1068 precision: &u8,
1069 scale: &i8,
1070 ) -> Result<(), ArrowError> {
1071 let output = data
1072 .into_iter()
1073 .collect::<Decimal128Array>()
1074 .with_precision_and_scale(*precision, *scale)
1075 .unwrap();
1076
1077 let expected = expected_data
1078 .into_iter()
1079 .collect::<Decimal128Array>()
1080 .with_precision_and_scale(*precision, *scale)
1081 .unwrap();
1082
1083 let expected = Arc::new(expected) as ArrayRef;
1084 let output = take(&output, index, options).unwrap();
1085 assert_eq!(&output, &expected);
1086 Ok(())
1087 }
1088
1089 fn test_take_boolean_arrays(
1090 data: Vec<Option<bool>>,
1091 index: &UInt32Array,
1092 options: Option<TakeOptions>,
1093 expected_data: Vec<Option<bool>>,
1094 ) {
1095 let output = BooleanArray::from(data);
1096 let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef;
1097 let output = take(&output, index, options).unwrap();
1098 assert_eq!(&output, &expected)
1099 }
1100
1101 fn test_take_primitive_arrays<T>(
1102 data: Vec<Option<T::Native>>,
1103 index: &UInt32Array,
1104 options: Option<TakeOptions>,
1105 expected_data: Vec<Option<T::Native>>,
1106 ) -> Result<(), ArrowError>
1107 where
1108 T: ArrowPrimitiveType,
1109 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1110 {
1111 let output = PrimitiveArray::<T>::from(data);
1112 let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1113 let output = take(&output, index, options)?;
1114 assert_eq!(&output, &expected);
1115 Ok(())
1116 }
1117
1118 fn test_take_primitive_arrays_non_null<T>(
1119 data: Vec<T::Native>,
1120 index: &UInt32Array,
1121 options: Option<TakeOptions>,
1122 expected_data: Vec<Option<T::Native>>,
1123 ) -> Result<(), ArrowError>
1124 where
1125 T: ArrowPrimitiveType,
1126 PrimitiveArray<T>: From<Vec<T::Native>>,
1127 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1128 {
1129 let output = PrimitiveArray::<T>::from(data);
1130 let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1131 let output = take(&output, index, options)?;
1132 assert_eq!(&output, &expected);
1133 Ok(())
1134 }
1135
1136 fn test_take_impl_primitive_arrays<T, I>(
1137 data: Vec<Option<T::Native>>,
1138 index: &PrimitiveArray<I>,
1139 options: Option<TakeOptions>,
1140 expected_data: Vec<Option<T::Native>>,
1141 ) where
1142 T: ArrowPrimitiveType,
1143 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1144 I: ArrowPrimitiveType,
1145 {
1146 let output = PrimitiveArray::<T>::from(data);
1147 let expected = PrimitiveArray::<T>::from(expected_data);
1148 let output = take(&output, index, options).unwrap();
1149 let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
1150 assert_eq!(output, &expected)
1151 }
1152
1153 fn create_test_struct(values: Vec<Option<(Option<bool>, Option<i32>)>>) -> StructArray {
1155 let mut struct_builder = StructBuilder::new(
1156 Fields::from(vec![
1157 Field::new("a", DataType::Boolean, true),
1158 Field::new("b", DataType::Int32, true),
1159 ]),
1160 vec![
1161 Box::new(BooleanBuilder::with_capacity(values.len())),
1162 Box::new(Int32Builder::with_capacity(values.len())),
1163 ],
1164 );
1165
1166 for value in values {
1167 struct_builder
1168 .field_builder::<BooleanBuilder>(0)
1169 .unwrap()
1170 .append_option(value.and_then(|v| v.0));
1171 struct_builder
1172 .field_builder::<Int32Builder>(1)
1173 .unwrap()
1174 .append_option(value.and_then(|v| v.1));
1175 struct_builder.append(value.is_some());
1176 }
1177 struct_builder.finish()
1178 }
1179
1180 #[test]
1181 fn test_take_decimal128_non_null_indices() {
1182 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1183 let precision: u8 = 10;
1184 let scale: i8 = 5;
1185 test_take_decimal_arrays(
1186 vec![None, Some(3), Some(5), Some(2), Some(3), None],
1187 &index,
1188 None,
1189 vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1190 &precision,
1191 &scale,
1192 )
1193 .unwrap();
1194 }
1195
1196 #[test]
1197 fn test_take_decimal128() {
1198 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1199 let precision: u8 = 10;
1200 let scale: i8 = 5;
1201 test_take_decimal_arrays(
1202 vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1203 &index,
1204 None,
1205 vec![Some(3), None, Some(1), Some(3), Some(2)],
1206 &precision,
1207 &scale,
1208 )
1209 .unwrap();
1210 }
1211
1212 #[test]
1213 fn test_take_primitive_non_null_indices() {
1214 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1215 test_take_primitive_arrays::<Int8Type>(
1216 vec![None, Some(3), Some(5), Some(2), Some(3), None],
1217 &index,
1218 None,
1219 vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1220 )
1221 .unwrap();
1222 }
1223
1224 #[test]
1225 fn test_take_primitive_non_null_values() {
1226 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1227 test_take_primitive_arrays::<Int8Type>(
1228 vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1229 &index,
1230 None,
1231 vec![Some(3), None, Some(1), Some(3), Some(2)],
1232 )
1233 .unwrap();
1234 }
1235
1236 #[test]
1237 fn test_take_primitive_non_null() {
1238 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1239 test_take_primitive_arrays::<Int8Type>(
1240 vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)],
1241 &index,
1242 None,
1243 vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)],
1244 )
1245 .unwrap();
1246 }
1247
1248 #[test]
1249 fn test_take_primitive_nullable_indices_non_null_values_with_offset() {
1250 let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1251 let index = index.slice(2, 4);
1252 let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1253
1254 assert_eq!(
1255 index,
1256 &UInt32Array::from(vec![Some(2), Some(3), None, None])
1257 );
1258
1259 test_take_primitive_arrays_non_null::<Int64Type>(
1260 vec![0, 10, 20, 30, 40, 50],
1261 index,
1262 None,
1263 vec![Some(20), Some(30), None, None],
1264 )
1265 .unwrap();
1266 }
1267
1268 #[test]
1269 fn test_take_primitive_nullable_indices_nullable_values_with_offset() {
1270 let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1271 let index = index.slice(2, 4);
1272 let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1273
1274 assert_eq!(
1275 index,
1276 &UInt32Array::from(vec![Some(2), Some(3), None, None])
1277 );
1278
1279 test_take_primitive_arrays::<Int64Type>(
1280 vec![None, None, Some(20), Some(30), Some(40), Some(50)],
1281 index,
1282 None,
1283 vec![Some(20), Some(30), None, None],
1284 )
1285 .unwrap();
1286 }
1287
1288 #[test]
1289 fn test_take_primitive() {
1290 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1291
1292 test_take_primitive_arrays::<Int8Type>(
1294 vec![Some(0), None, Some(2), Some(3), None],
1295 &index,
1296 None,
1297 vec![Some(3), None, None, Some(3), Some(2)],
1298 )
1299 .unwrap();
1300
1301 test_take_primitive_arrays::<Int16Type>(
1303 vec![Some(0), None, Some(2), Some(3), None],
1304 &index,
1305 None,
1306 vec![Some(3), None, None, Some(3), Some(2)],
1307 )
1308 .unwrap();
1309
1310 test_take_primitive_arrays::<Int32Type>(
1312 vec![Some(0), None, Some(2), Some(3), None],
1313 &index,
1314 None,
1315 vec![Some(3), None, None, Some(3), Some(2)],
1316 )
1317 .unwrap();
1318
1319 test_take_primitive_arrays::<Int64Type>(
1321 vec![Some(0), None, Some(2), Some(3), None],
1322 &index,
1323 None,
1324 vec![Some(3), None, None, Some(3), Some(2)],
1325 )
1326 .unwrap();
1327
1328 test_take_primitive_arrays::<UInt8Type>(
1330 vec![Some(0), None, Some(2), Some(3), None],
1331 &index,
1332 None,
1333 vec![Some(3), None, None, Some(3), Some(2)],
1334 )
1335 .unwrap();
1336
1337 test_take_primitive_arrays::<UInt16Type>(
1339 vec![Some(0), None, Some(2), Some(3), None],
1340 &index,
1341 None,
1342 vec![Some(3), None, None, Some(3), Some(2)],
1343 )
1344 .unwrap();
1345
1346 test_take_primitive_arrays::<UInt32Type>(
1348 vec![Some(0), None, Some(2), Some(3), None],
1349 &index,
1350 None,
1351 vec![Some(3), None, None, Some(3), Some(2)],
1352 )
1353 .unwrap();
1354
1355 test_take_primitive_arrays::<Int64Type>(
1357 vec![Some(0), None, Some(2), Some(-15), None],
1358 &index,
1359 None,
1360 vec![Some(-15), None, None, Some(-15), Some(2)],
1361 )
1362 .unwrap();
1363
1364 test_take_primitive_arrays::<IntervalYearMonthType>(
1366 vec![Some(0), None, Some(2), Some(-15), None],
1367 &index,
1368 None,
1369 vec![Some(-15), None, None, Some(-15), Some(2)],
1370 )
1371 .unwrap();
1372
1373 let v1 = IntervalDayTime::new(0, 0);
1375 let v2 = IntervalDayTime::new(2, 0);
1376 let v3 = IntervalDayTime::new(-15, 0);
1377 test_take_primitive_arrays::<IntervalDayTimeType>(
1378 vec![Some(v1), None, Some(v2), Some(v3), None],
1379 &index,
1380 None,
1381 vec![Some(v3), None, None, Some(v3), Some(v2)],
1382 )
1383 .unwrap();
1384
1385 let v1 = IntervalMonthDayNano::new(0, 0, 0);
1387 let v2 = IntervalMonthDayNano::new(2, 0, 0);
1388 let v3 = IntervalMonthDayNano::new(-15, 0, 0);
1389 test_take_primitive_arrays::<IntervalMonthDayNanoType>(
1390 vec![Some(v1), None, Some(v2), Some(v3), None],
1391 &index,
1392 None,
1393 vec![Some(v3), None, None, Some(v3), Some(v2)],
1394 )
1395 .unwrap();
1396
1397 test_take_primitive_arrays::<DurationSecondType>(
1399 vec![Some(0), None, Some(2), Some(-15), None],
1400 &index,
1401 None,
1402 vec![Some(-15), None, None, Some(-15), Some(2)],
1403 )
1404 .unwrap();
1405
1406 test_take_primitive_arrays::<DurationMillisecondType>(
1408 vec![Some(0), None, Some(2), Some(-15), None],
1409 &index,
1410 None,
1411 vec![Some(-15), None, None, Some(-15), Some(2)],
1412 )
1413 .unwrap();
1414
1415 test_take_primitive_arrays::<DurationMicrosecondType>(
1417 vec![Some(0), None, Some(2), Some(-15), None],
1418 &index,
1419 None,
1420 vec![Some(-15), None, None, Some(-15), Some(2)],
1421 )
1422 .unwrap();
1423
1424 test_take_primitive_arrays::<DurationNanosecondType>(
1426 vec![Some(0), None, Some(2), Some(-15), None],
1427 &index,
1428 None,
1429 vec![Some(-15), None, None, Some(-15), Some(2)],
1430 )
1431 .unwrap();
1432
1433 test_take_primitive_arrays::<Float32Type>(
1435 vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1436 &index,
1437 None,
1438 vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1439 )
1440 .unwrap();
1441
1442 test_take_primitive_arrays::<Float64Type>(
1444 vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1445 &index,
1446 None,
1447 vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1448 )
1449 .unwrap();
1450 }
1451
1452 #[test]
1453 fn test_take_preserve_timezone() {
1454 let index = Int64Array::from(vec![Some(0), None]);
1455
1456 let input = TimestampNanosecondArray::from(vec![
1457 1_639_715_368_000_000_000,
1458 1_639_715_368_000_000_000,
1459 ])
1460 .with_timezone("UTC".to_string());
1461 let result = take(&input, &index, None).unwrap();
1462 match result.data_type() {
1463 DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
1464 assert_eq!(tz.clone(), Some("UTC".into()))
1465 }
1466 _ => panic!(),
1467 }
1468 }
1469
1470 #[test]
1471 fn test_take_impl_primitive_with_int64_indices() {
1472 let index = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1473
1474 test_take_impl_primitive_arrays::<Int16Type, Int64Type>(
1476 vec![Some(0), None, Some(2), Some(3), None],
1477 &index,
1478 None,
1479 vec![Some(3), None, None, Some(3), Some(2)],
1480 );
1481
1482 test_take_impl_primitive_arrays::<Int64Type, Int64Type>(
1484 vec![Some(0), None, Some(2), Some(-15), None],
1485 &index,
1486 None,
1487 vec![Some(-15), None, None, Some(-15), Some(2)],
1488 );
1489
1490 test_take_impl_primitive_arrays::<UInt64Type, Int64Type>(
1492 vec![Some(0), None, Some(2), Some(3), None],
1493 &index,
1494 None,
1495 vec![Some(3), None, None, Some(3), Some(2)],
1496 );
1497
1498 test_take_impl_primitive_arrays::<DurationMillisecondType, Int64Type>(
1500 vec![Some(0), None, Some(2), Some(-15), None],
1501 &index,
1502 None,
1503 vec![Some(-15), None, None, Some(-15), Some(2)],
1504 );
1505
1506 test_take_impl_primitive_arrays::<Float32Type, Int64Type>(
1508 vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1509 &index,
1510 None,
1511 vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1512 );
1513 }
1514
1515 #[test]
1516 fn test_take_impl_primitive_with_uint8_indices() {
1517 let index = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1518
1519 test_take_impl_primitive_arrays::<Int16Type, UInt8Type>(
1521 vec![Some(0), None, Some(2), Some(3), None],
1522 &index,
1523 None,
1524 vec![Some(3), None, None, Some(3), Some(2)],
1525 );
1526
1527 test_take_impl_primitive_arrays::<DurationMillisecondType, UInt8Type>(
1529 vec![Some(0), None, Some(2), Some(-15), None],
1530 &index,
1531 None,
1532 vec![Some(-15), None, None, Some(-15), Some(2)],
1533 );
1534
1535 test_take_impl_primitive_arrays::<Float32Type, UInt8Type>(
1537 vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1538 &index,
1539 None,
1540 vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1541 );
1542 }
1543
1544 #[test]
1545 fn test_take_bool() {
1546 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1547 test_take_boolean_arrays(
1549 vec![Some(false), None, Some(true), Some(false), None],
1550 &index,
1551 None,
1552 vec![Some(false), None, None, Some(false), Some(true)],
1553 );
1554 }
1555
1556 #[test]
1557 fn test_take_bool_nullable_index() {
1558 let index_data = ArrayData::try_new(
1560 DataType::UInt32,
1561 6,
1562 Some(Buffer::from_iter(vec![
1563 false, true, false, true, false, true,
1564 ])),
1565 0,
1566 vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1567 vec![],
1568 )
1569 .unwrap();
1570 let index = UInt32Array::from(index_data);
1571 test_take_boolean_arrays(
1572 vec![Some(true), None, Some(false)],
1573 &index,
1574 None,
1575 vec![None, Some(true), None, None, None, Some(false)],
1576 );
1577 }
1578
1579 #[test]
1580 fn test_take_bool_nullable_index_nonnull_values() {
1581 let index_data = ArrayData::try_new(
1583 DataType::UInt32,
1584 6,
1585 Some(Buffer::from_iter(vec![
1586 false, true, false, true, false, true,
1587 ])),
1588 0,
1589 vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1590 vec![],
1591 )
1592 .unwrap();
1593 let index = UInt32Array::from(index_data);
1594 test_take_boolean_arrays(
1595 vec![Some(true), Some(true), Some(false)],
1596 &index,
1597 None,
1598 vec![None, Some(true), None, Some(true), None, Some(false)],
1599 );
1600 }
1601
1602 #[test]
1603 fn test_take_bool_with_offset() {
1604 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]);
1605 let index = index.slice(2, 4);
1606 let index = index
1607 .as_any()
1608 .downcast_ref::<PrimitiveArray<UInt32Type>>()
1609 .unwrap();
1610
1611 test_take_boolean_arrays(
1613 vec![Some(false), None, Some(true), Some(false), None],
1614 index,
1615 None,
1616 vec![None, Some(false), Some(true), None],
1617 );
1618 }
1619
1620 fn _test_take_string<'a, K>()
1621 where
1622 K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static,
1623 {
1624 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]);
1625
1626 let array = K::from(vec![
1627 Some("one"),
1628 None,
1629 Some("three"),
1630 Some("four"),
1631 Some("five"),
1632 ]);
1633 let actual = take(&array, &index, None).unwrap();
1634 assert_eq!(actual.len(), index.len());
1635
1636 let actual = actual.as_any().downcast_ref::<K>().unwrap();
1637
1638 let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]);
1639
1640 assert_eq!(actual, &expected);
1641 }
1642
1643 #[test]
1644 fn test_take_string() {
1645 _test_take_string::<StringArray>()
1646 }
1647
1648 #[test]
1649 fn test_take_large_string() {
1650 _test_take_string::<LargeStringArray>()
1651 }
1652
1653 #[test]
1654 fn test_take_slice_string() {
1655 let strings = StringArray::from(vec![Some("hello"), None, Some("world"), None, Some("hi")]);
1656 let indices = Int32Array::from(vec![Some(0), Some(1), None, Some(0), Some(2)]);
1657 let indices_slice = indices.slice(1, 4);
1658 let expected = StringArray::from(vec![None, None, Some("hello"), Some("world")]);
1659 let result = take(&strings, &indices_slice, None).unwrap();
1660 assert_eq!(result.as_ref(), &expected);
1661 }
1662
1663 fn _test_byte_view<T>()
1664 where
1665 T: ByteViewType,
1666 str: AsRef<T::Native>,
1667 T::Native: PartialEq,
1668 {
1669 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]);
1670 let array = {
1671 let mut builder = GenericByteViewBuilder::<T>::new();
1673 builder.append_value("hello");
1674 builder.append_value("world");
1675 builder.append_null();
1676 builder.append_value("large payload over 12 bytes");
1677 builder.append_value("lulu");
1678 builder.finish()
1679 };
1680
1681 let actual = take(&array, &index, None).unwrap();
1682
1683 assert_eq!(actual.len(), index.len());
1684
1685 let expected = {
1686 let mut builder = GenericByteViewBuilder::<T>::new();
1688 builder.append_value("large payload over 12 bytes");
1689 builder.append_null();
1690 builder.append_value("world");
1691 builder.append_value("large payload over 12 bytes");
1692 builder.append_value("lulu");
1693 builder.append_null();
1694 builder.finish()
1695 };
1696
1697 assert_eq!(actual.as_ref(), &expected);
1698 }
1699
1700 #[test]
1701 fn test_take_string_view() {
1702 _test_byte_view::<StringViewType>()
1703 }
1704
1705 #[test]
1706 fn test_take_binary_view() {
1707 _test_byte_view::<BinaryViewType>()
1708 }
1709
1710 macro_rules! test_take_list {
1711 ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1712 let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
1714 let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
1716 let value_offsets = Buffer::from_slice_ref(&value_offsets);
1717 let list_data_type =
1719 DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, false)));
1720 let list_data = ArrayData::builder(list_data_type.clone())
1721 .len(4)
1722 .add_buffer(value_offsets)
1723 .add_child_data(value_data)
1724 .build()
1725 .unwrap();
1726 let list_array = $list_array_type::from(list_data);
1727
1728 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]);
1730
1731 let a = take(&list_array, &index, None).unwrap();
1732 let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1733
1734 let expected_data = Int32Array::from(vec![
1737 Some(2),
1738 Some(3),
1739 Some(-1),
1740 Some(-2),
1741 Some(-1),
1742 Some(0),
1743 Some(0),
1744 Some(0),
1745 ])
1746 .into_data();
1747 let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8];
1749 let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1750 let expected_list_data = ArrayData::builder(list_data_type)
1752 .len(5)
1753 .nulls(index.nulls().cloned())
1755 .add_buffer(expected_offsets)
1756 .add_child_data(expected_data)
1757 .build()
1758 .unwrap();
1759 let expected_list_array = $list_array_type::from(expected_list_data);
1760
1761 assert_eq!(a, &expected_list_array);
1762 }};
1763 }
1764
1765 macro_rules! test_take_list_with_value_nulls {
1766 ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1767 let value_data = Int32Array::from(vec![
1769 Some(0),
1770 None,
1771 Some(0),
1772 Some(-1),
1773 Some(-2),
1774 Some(3),
1775 None,
1776 Some(5),
1777 None,
1778 ])
1779 .into_data();
1780 let value_offsets: [$offset_type; 5] = [0, 3, 6, 7, 9];
1782 let value_offsets = Buffer::from_slice_ref(&value_offsets);
1783 let list_data_type =
1785 DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
1786 let list_data = ArrayData::builder(list_data_type.clone())
1787 .len(4)
1788 .add_buffer(value_offsets)
1789 .null_bit_buffer(Some(Buffer::from([0b11111111])))
1790 .add_child_data(value_data)
1791 .build()
1792 .unwrap();
1793 let list_array = $list_array_type::from(list_data);
1794
1795 let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
1797
1798 let a = take(&list_array, &index, None).unwrap();
1799 let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1800
1801 let expected_data = Int32Array::from(vec![
1804 None,
1805 Some(-1),
1806 Some(-2),
1807 Some(3),
1808 Some(5),
1809 None,
1810 Some(0),
1811 None,
1812 Some(0),
1813 ])
1814 .into_data();
1815 let expected_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9];
1817 let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1818 let expected_list_data = ArrayData::builder(list_data_type)
1820 .len(5)
1821 .nulls(index.nulls().cloned())
1823 .add_buffer(expected_offsets)
1824 .add_child_data(expected_data)
1825 .build()
1826 .unwrap();
1827 let expected_list_array = $list_array_type::from(expected_list_data);
1828
1829 assert_eq!(a, &expected_list_array);
1830 }};
1831 }
1832
1833 macro_rules! test_take_list_with_nulls {
1834 ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1835 let value_data = Int32Array::from(vec![
1837 Some(0),
1838 None,
1839 Some(0),
1840 Some(-1),
1841 Some(-2),
1842 Some(3),
1843 Some(5),
1844 None,
1845 ])
1846 .into_data();
1847 let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
1849 let value_offsets = Buffer::from_slice_ref(&value_offsets);
1850 let list_data_type =
1852 DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
1853 let list_data = ArrayData::builder(list_data_type.clone())
1854 .len(4)
1855 .add_buffer(value_offsets)
1856 .null_bit_buffer(Some(Buffer::from([0b11111011])))
1857 .add_child_data(value_data)
1858 .build()
1859 .unwrap();
1860 let list_array = $list_array_type::from(list_data);
1861
1862 let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
1864
1865 let a = take(&list_array, &index, None).unwrap();
1866 let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1867
1868 let expected_data = Int32Array::from(vec![
1871 Some(-1),
1872 Some(-2),
1873 Some(3),
1874 Some(5),
1875 None,
1876 Some(0),
1877 None,
1878 Some(0),
1879 ])
1880 .into_data();
1881 let expected_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8];
1883 let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1884 let mut null_bits: [u8; 1] = [0; 1];
1886 bit_util::set_bit(&mut null_bits, 2);
1887 bit_util::set_bit(&mut null_bits, 3);
1888 bit_util::set_bit(&mut null_bits, 4);
1889 let expected_list_data = ArrayData::builder(list_data_type)
1890 .len(5)
1891 .null_bit_buffer(Some(Buffer::from(null_bits)))
1893 .add_buffer(expected_offsets)
1894 .add_child_data(expected_data)
1895 .build()
1896 .unwrap();
1897 let expected_list_array = $list_array_type::from(expected_list_data);
1898
1899 assert_eq!(a, &expected_list_array);
1900 }};
1901 }
1902
1903 fn test_take_list_view_generic<OffsetType: OffsetSizeTrait, ValuesType: ArrowPrimitiveType, F>(
1904 values: Vec<Option<Vec<Option<ValuesType::Native>>>>,
1905 take_indices: Vec<Option<usize>>,
1906 expected: Vec<Option<Vec<Option<ValuesType::Native>>>>,
1907 mapper: F,
1908 ) where
1909 F: Fn(GenericListViewArray<OffsetType>) -> GenericListViewArray<OffsetType>,
1910 {
1911 let mut list_view_array =
1912 GenericListViewBuilder::<OffsetType, _>::new(PrimitiveBuilder::<ValuesType>::new());
1913
1914 for value in values {
1915 list_view_array.append_option(value);
1916 }
1917 let list_view_array = list_view_array.finish();
1918 let list_view_array = mapper(list_view_array);
1919
1920 let mut indices = UInt64Builder::new();
1921 for idx in take_indices {
1922 indices.append_option(idx.map(|i| i.to_u64().unwrap()));
1923 }
1924 let indices = indices.finish();
1925
1926 let taken = take(&list_view_array, &indices, None)
1927 .unwrap()
1928 .as_list_view()
1929 .clone();
1930
1931 let mut expected_array =
1932 GenericListViewBuilder::<OffsetType, _>::new(PrimitiveBuilder::<ValuesType>::new());
1933 for value in expected {
1934 expected_array.append_option(value);
1935 }
1936 let expected_array = expected_array.finish();
1937
1938 assert_eq!(taken, expected_array);
1939 }
1940
1941 macro_rules! list_view_test_case {
1942 (values: $values:expr, indices: $indices:expr, expected: $expected: expr) => {{
1943 test_take_list_view_generic::<i32, Int8Type, _>($values, $indices, $expected, |x| x);
1944 test_take_list_view_generic::<i64, Int8Type, _>($values, $indices, $expected, |x| x);
1945 }};
1946 (values: $values:expr, transform: $fn:expr, indices: $indices:expr, expected: $expected: expr) => {{
1947 test_take_list_view_generic::<i32, Int8Type, _>($values, $indices, $expected, $fn);
1948 test_take_list_view_generic::<i64, Int8Type, _>($values, $indices, $expected, $fn);
1949 }};
1950 }
1951
1952 fn do_take_fixed_size_list_test<T>(
1953 length: <Int32Type as ArrowPrimitiveType>::Native,
1954 input_data: Vec<Option<Vec<Option<T::Native>>>>,
1955 indices: Vec<<UInt32Type as ArrowPrimitiveType>::Native>,
1956 expected_data: Vec<Option<Vec<Option<T::Native>>>>,
1957 ) where
1958 T: ArrowPrimitiveType,
1959 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1960 {
1961 let indices = UInt32Array::from(indices);
1962
1963 let input_array = FixedSizeListArray::from_iter_primitive::<T, _, _>(input_data, length);
1964
1965 let output = take_fixed_size_list(&input_array, &indices, length as u32).unwrap();
1966
1967 let expected = FixedSizeListArray::from_iter_primitive::<T, _, _>(expected_data, length);
1968
1969 assert_eq!(&output, &expected)
1970 }
1971
1972 #[test]
1973 fn test_take_list() {
1974 test_take_list!(i32, List, ListArray);
1975 }
1976
1977 #[test]
1978 fn test_take_large_list() {
1979 test_take_list!(i64, LargeList, LargeListArray);
1980 }
1981
1982 #[test]
1983 fn test_take_list_with_value_nulls() {
1984 test_take_list_with_value_nulls!(i32, List, ListArray);
1985 }
1986
1987 #[test]
1988 fn test_take_large_list_with_value_nulls() {
1989 test_take_list_with_value_nulls!(i64, LargeList, LargeListArray);
1990 }
1991
1992 #[test]
1993 fn test_test_take_list_with_nulls() {
1994 test_take_list_with_nulls!(i32, List, ListArray);
1995 }
1996
1997 #[test]
1998 fn test_test_take_large_list_with_nulls() {
1999 test_take_list_with_nulls!(i64, LargeList, LargeListArray);
2000 }
2001
2002 #[test]
2003 fn test_test_take_list_view_reversed() {
2004 list_view_test_case! {
2006 values: vec![
2007 Some(vec![Some(1), None, Some(3)]),
2008 None,
2009 Some(vec![Some(7), Some(8), None]),
2010 ],
2011 indices: vec![Some(2), Some(1), Some(0)],
2012 expected: vec![
2013 Some(vec![Some(7), Some(8), None]),
2014 None,
2015 Some(vec![Some(1), None, Some(3)]),
2016 ]
2017 }
2018 }
2019
2020 #[test]
2021 fn test_take_list_view_null_indices() {
2022 list_view_test_case! {
2024 values: vec![
2025 Some(vec![Some(1), None, Some(3)]),
2026 None,
2027 Some(vec![Some(7), Some(8), None]),
2028 ],
2029 indices: vec![None, Some(0), None],
2030 expected: vec![None, Some(vec![Some(1), None, Some(3)]), None]
2031 }
2032 }
2033
2034 #[test]
2035 fn test_take_list_view_null_values() {
2036 list_view_test_case! {
2038 values: vec![
2039 Some(vec![Some(1), None, Some(3)]),
2040 None,
2041 Some(vec![Some(7), Some(8), None]),
2042 ],
2043 indices: vec![Some(1), Some(1), Some(1), None, None],
2044 expected: vec![None; 5]
2045 }
2046 }
2047
2048 #[test]
2049 fn test_take_list_view_sliced() {
2050 list_view_test_case! {
2052 values: vec![
2053 Some(vec![Some(1)]),
2054 None,
2055 None,
2056 Some(vec![Some(2), Some(3)]),
2057 Some(vec![Some(4), Some(5)]),
2058 None,
2059 ],
2060 transform: |l| l.slice(2, 4),
2061 indices: vec![Some(0), Some(3), None, Some(1), Some(2)],
2062 expected: vec![
2063 None, None, None, Some(vec![Some(2), Some(3)]), Some(vec![Some(4), Some(5)])
2064 ]
2065 }
2066 }
2067
2068 #[test]
2069 fn test_take_fixed_size_list() {
2070 do_take_fixed_size_list_test::<Int32Type>(
2071 3,
2072 vec![
2073 Some(vec![None, Some(1), Some(2)]),
2074 Some(vec![Some(3), Some(4), None]),
2075 Some(vec![Some(6), Some(7), Some(8)]),
2076 ],
2077 vec![2, 1, 0],
2078 vec![
2079 Some(vec![Some(6), Some(7), Some(8)]),
2080 Some(vec![Some(3), Some(4), None]),
2081 Some(vec![None, Some(1), Some(2)]),
2082 ],
2083 );
2084
2085 do_take_fixed_size_list_test::<UInt8Type>(
2086 1,
2087 vec![
2088 Some(vec![Some(1)]),
2089 Some(vec![Some(2)]),
2090 Some(vec![Some(3)]),
2091 Some(vec![Some(4)]),
2092 Some(vec![Some(5)]),
2093 Some(vec![Some(6)]),
2094 Some(vec![Some(7)]),
2095 Some(vec![Some(8)]),
2096 ],
2097 vec![2, 7, 0],
2098 vec![
2099 Some(vec![Some(3)]),
2100 Some(vec![Some(8)]),
2101 Some(vec![Some(1)]),
2102 ],
2103 );
2104
2105 do_take_fixed_size_list_test::<UInt64Type>(
2106 3,
2107 vec![
2108 Some(vec![Some(10), Some(11), Some(12)]),
2109 Some(vec![Some(13), Some(14), Some(15)]),
2110 None,
2111 Some(vec![Some(16), Some(17), Some(18)]),
2112 ],
2113 vec![3, 2, 1, 2, 0],
2114 vec![
2115 Some(vec![Some(16), Some(17), Some(18)]),
2116 None,
2117 Some(vec![Some(13), Some(14), Some(15)]),
2118 None,
2119 Some(vec![Some(10), Some(11), Some(12)]),
2120 ],
2121 );
2122 }
2123
2124 #[test]
2125 fn test_take_fixed_size_binary_with_nulls_indices() {
2126 let fsb = FixedSizeBinaryArray::try_from_sparse_iter_with_size(
2127 [
2128 Some(vec![0x01, 0x01, 0x01, 0x01]),
2129 Some(vec![0x02, 0x02, 0x02, 0x02]),
2130 Some(vec![0x03, 0x03, 0x03, 0x03]),
2131 Some(vec![0x04, 0x04, 0x04, 0x04]),
2132 ]
2133 .into_iter(),
2134 4,
2135 )
2136 .unwrap();
2137
2138 let indices = UInt32Array::from(vec![Some(0), None, None, Some(3)]);
2140
2141 let result = take_fixed_size_binary(&fsb, &indices, 4).unwrap();
2142 assert_eq!(result.len(), 4);
2143 assert_eq!(result.null_count(), 2);
2144 assert_eq!(
2145 result.nulls().unwrap().iter().collect::<Vec<_>>(),
2146 vec![true, false, false, true]
2147 );
2148 }
2149
2150 #[test]
2151 #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
2152 fn test_take_list_out_of_bounds() {
2153 let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
2155 let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]);
2157 let list_data_type =
2159 DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false)));
2160 let list_data = ArrayData::builder(list_data_type)
2161 .len(3)
2162 .add_buffer(value_offsets)
2163 .add_child_data(value_data)
2164 .build()
2165 .unwrap();
2166 let list_array = ListArray::from(list_data);
2167
2168 let index = UInt32Array::from(vec![1000]);
2169
2170 take(&list_array, &index, None).unwrap();
2173 }
2174
2175 #[test]
2176 fn test_take_map() {
2177 let values = Int32Array::from(vec![1, 2, 3, 4]);
2178 let array =
2179 MapArray::new_from_strings(vec!["a", "b", "c", "a"].into_iter(), &values, &[0, 3, 4])
2180 .unwrap();
2181
2182 let index = UInt32Array::from(vec![0]);
2183
2184 let result = take(&array, &index, None).unwrap();
2185 let expected: ArrayRef = Arc::new(
2186 MapArray::new_from_strings(
2187 vec!["a", "b", "c"].into_iter(),
2188 &values.slice(0, 3),
2189 &[0, 3],
2190 )
2191 .unwrap(),
2192 );
2193 assert_eq!(&expected, &result);
2194 }
2195
2196 #[test]
2197 fn test_take_struct() {
2198 let array = create_test_struct(vec![
2199 Some((Some(true), Some(42))),
2200 Some((Some(false), Some(28))),
2201 Some((Some(false), Some(19))),
2202 Some((Some(true), Some(31))),
2203 None,
2204 ]);
2205
2206 let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
2207 let actual = take(&array, &index, None).unwrap();
2208 let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
2209 assert_eq!(index.len(), actual.len());
2210 assert_eq!(1, actual.null_count());
2211
2212 let expected = create_test_struct(vec![
2213 Some((Some(true), Some(42))),
2214 Some((Some(true), Some(31))),
2215 Some((Some(false), Some(28))),
2216 Some((Some(true), Some(42))),
2217 Some((Some(false), Some(19))),
2218 None,
2219 ]);
2220
2221 assert_eq!(&expected, actual);
2222
2223 let nulls = NullBuffer::from(&[false, true, false, true, false, true]);
2224 let empty_struct_arr = StructArray::new_empty_fields(6, Some(nulls));
2225 let index = UInt32Array::from(vec![0, 2, 1, 4]);
2226 let actual = take(&empty_struct_arr, &index, None).unwrap();
2227
2228 let expected_nulls = NullBuffer::from(&[false, false, true, false]);
2229 let expected_struct_arr = StructArray::new_empty_fields(4, Some(expected_nulls));
2230 assert_eq!(&expected_struct_arr, actual.as_struct());
2231 }
2232
2233 #[test]
2234 fn test_take_struct_with_null_indices() {
2235 let array = create_test_struct(vec![
2236 Some((Some(true), Some(42))),
2237 Some((Some(false), Some(28))),
2238 Some((Some(false), Some(19))),
2239 Some((Some(true), Some(31))),
2240 None,
2241 ]);
2242
2243 let index = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);
2244 let actual = take(&array, &index, None).unwrap();
2245 let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
2246 assert_eq!(index.len(), actual.len());
2247 assert_eq!(3, actual.null_count()); let expected = create_test_struct(vec![
2250 None,
2251 Some((Some(true), Some(31))),
2252 Some((Some(false), Some(28))),
2253 None,
2254 Some((Some(true), Some(42))),
2255 None,
2256 ]);
2257
2258 assert_eq!(&expected, actual);
2259 }
2260
2261 #[test]
2262 fn test_take_out_of_bounds() {
2263 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]);
2264 let take_opt = TakeOptions { check_bounds: true };
2265
2266 let result = test_take_primitive_arrays::<Int64Type>(
2268 vec![Some(0), None, Some(2), Some(3), None],
2269 &index,
2270 Some(take_opt),
2271 vec![None],
2272 );
2273 assert!(result.is_err());
2274 }
2275
2276 #[test]
2277 #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
2278 fn test_take_out_of_bounds_panic() {
2279 let index = UInt32Array::from(vec![Some(1000)]);
2280
2281 test_take_primitive_arrays::<Int64Type>(
2282 vec![Some(0), Some(1), Some(2), Some(3)],
2283 &index,
2284 None,
2285 vec![None],
2286 )
2287 .unwrap();
2288 }
2289
2290 #[test]
2291 fn test_null_array_smaller_than_indices() {
2292 let values = NullArray::new(2);
2293 let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2294
2295 let result = take(&values, &indices, None).unwrap();
2296 let expected: ArrayRef = Arc::new(NullArray::new(3));
2297 assert_eq!(&result, &expected);
2298 }
2299
2300 #[test]
2301 fn test_null_array_larger_than_indices() {
2302 let values = NullArray::new(5);
2303 let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2304
2305 let result = take(&values, &indices, None).unwrap();
2306 let expected: ArrayRef = Arc::new(NullArray::new(3));
2307 assert_eq!(&result, &expected);
2308 }
2309
2310 #[test]
2311 fn test_null_array_indices_out_of_bounds() {
2312 let values = NullArray::new(5);
2313 let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2314
2315 let result = take(&values, &indices, Some(TakeOptions { check_bounds: true }));
2316 assert_eq!(
2317 result.unwrap_err().to_string(),
2318 "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries"
2319 );
2320 }
2321
2322 #[test]
2323 fn test_take_dict() {
2324 let mut dict_builder = StringDictionaryBuilder::<Int16Type>::new();
2325
2326 dict_builder.append("foo").unwrap();
2327 dict_builder.append("bar").unwrap();
2328 dict_builder.append("").unwrap();
2329 dict_builder.append_null();
2330 dict_builder.append("foo").unwrap();
2331 dict_builder.append("bar").unwrap();
2332 dict_builder.append("bar").unwrap();
2333 dict_builder.append("foo").unwrap();
2334
2335 let array = dict_builder.finish();
2336 let dict_values = array.values().clone();
2337 let dict_values = dict_values.as_any().downcast_ref::<StringArray>().unwrap();
2338
2339 let indices = UInt32Array::from(vec![
2340 Some(0), Some(7), None, Some(5), Some(6), Some(2), Some(3), ]);
2348
2349 let result = take(&array, &indices, None).unwrap();
2350 let result = result
2351 .as_any()
2352 .downcast_ref::<DictionaryArray<Int16Type>>()
2353 .unwrap();
2354
2355 let result_values: StringArray = result.values().to_data().into();
2356
2357 let expected_values = StringArray::from(vec!["foo", "bar", ""]);
2359 assert_eq!(&expected_values, dict_values);
2360 assert_eq!(&expected_values, &result_values);
2361
2362 let expected_keys = Int16Array::from(vec![
2363 Some(0),
2364 Some(0),
2365 None,
2366 Some(1),
2367 Some(1),
2368 Some(2),
2369 None,
2370 ]);
2371 assert_eq!(result.keys(), &expected_keys);
2372 }
2373
2374 fn build_generic_list<S, T>(data: Vec<Option<Vec<T::Native>>>) -> GenericListArray<S>
2375 where
2376 S: OffsetSizeTrait + 'static,
2377 T: ArrowPrimitiveType,
2378 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2379 {
2380 GenericListArray::from_iter_primitive::<T, _, _>(
2381 data.iter()
2382 .map(|x| x.as_ref().map(|x| x.iter().map(|x| Some(*x)))),
2383 )
2384 }
2385
2386 #[test]
2387 fn test_take_value_index_from_list() {
2388 let list = build_generic_list::<i32, Int32Type>(vec![
2389 Some(vec![0, 1]),
2390 Some(vec![2, 3, 4]),
2391 Some(vec![5, 6, 7, 8, 9]),
2392 ]);
2393 let indices = UInt32Array::from(vec![2, 0]);
2394
2395 let (indexed, offsets, null_buf) = take_value_indices_from_list(&list, &indices).unwrap();
2396
2397 assert_eq!(indexed, Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
2398 assert_eq!(offsets, vec![0, 5, 7]);
2399 assert_eq!(null_buf.as_slice(), &[0b11111111]);
2400 }
2401
2402 #[test]
2403 fn test_take_value_index_from_large_list() {
2404 let list = build_generic_list::<i64, Int32Type>(vec![
2405 Some(vec![0, 1]),
2406 Some(vec![2, 3, 4]),
2407 Some(vec![5, 6, 7, 8, 9]),
2408 ]);
2409 let indices = UInt32Array::from(vec![2, 0]);
2410
2411 let (indexed, offsets, null_buf) =
2412 take_value_indices_from_list::<_, Int64Type>(&list, &indices).unwrap();
2413
2414 assert_eq!(indexed, Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
2415 assert_eq!(offsets, vec![0, 5, 7]);
2416 assert_eq!(null_buf.as_slice(), &[0b11111111]);
2417 }
2418
/// Takes from a run-end encoded array and checks the run ends and run values
/// of the result.
#[test]
fn test_take_runs() {
    let logical_values = [1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];

    let mut run_builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
    run_builder.extend(logical_values.into_iter().map(Some));
    let run_array = run_builder.finish();

    let picks = Int32Array::from_iter([7, 2, 3, 7, 11, 4, 6]);

    let taken = take_run(&run_array, &picks).unwrap();

    // Seven indices were requested, so the logical length is seven.
    assert_eq!(taken.len(), 7);
    assert_eq!(taken.run_ends().len(), 7);
    // The logical result is [2, 2, 2, 2, 2, 1, 1]; the expected encoding keeps
    // five runs even though neighbouring runs carry the same value.
    assert_eq!(taken.run_ends().values(), &[1_i32, 3, 4, 5, 7]);

    let taken_values = taken.values().as_primitive::<Int32Type>();
    assert_eq!(taken_values.values(), &[2, 2, 2, 2, 1]);
}
2439
/// Takes from a *sliced* run-end encoded array, so indices must be resolved
/// relative to the slice rather than to the underlying buffers.
#[test]
fn test_take_runs_sliced() {
    let logical_values = [1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6];

    let mut run_builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
    run_builder.extend(logical_values.into_iter().map(Some));

    // Logical positions 4..10 of the input, i.e. [3, 3, 3, 4, 4, 5].
    let sliced = run_builder.finish().slice(4, 6);

    let picks = Int32Array::from_iter([0, 5, 5, 1, 4]);

    let taken = take_run(&sliced, &picks).unwrap();
    let typed = taken.downcast::<Int32Array>().unwrap();

    let expected = vec![3, 5, 5, 3, 4];
    let actual: Vec<_> = typed.into_iter().flatten().collect();

    assert_eq!(expected, actual);
}
2460
/// `take_value_indices_from_fixed_size_list` with a list size of 3: row `i`
/// is expected to map to child positions `3 * i .. 3 * i + 3`, including the
/// null row at index 2.
#[test]
fn test_take_value_index_from_fixed_list() {
    let rows = vec![
        Some(vec![Some(1), Some(2), None]),
        Some(vec![Some(4), None, Some(6)]),
        None,
        Some(vec![None, Some(8), Some(9)]),
    ];
    let list = FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(rows, 3);

    // Three rows, each taken once.
    let picks = UInt32Array::from(vec![2, 1, 0]);
    let child_indices = take_value_indices_from_fixed_size_list(&list, &picks, 3).unwrap();

    assert_eq!(
        child_indices,
        UInt32Array::from(vec![6, 7, 8, 3, 4, 5, 0, 1, 2])
    );

    // All four rows, with row 2 taken twice.
    let picks = UInt32Array::from(vec![3, 2, 1, 2, 0]);
    let child_indices = take_value_indices_from_fixed_size_list(&list, &picks, 3).unwrap();

    assert_eq!(
        child_indices,
        UInt32Array::from(vec![9, 10, 11, 6, 7, 8, 3, 4, 5, 6, 7, 8, 0, 1, 2])
    );
}
2486
/// Null index slots must yield null outputs, even when the raw value stored
/// in the null slot (400 here) is far outside the four-element values array.
#[test]
fn test_take_null_indices() {
    let index_values: ScalarBuffer<i32> = vec![1, 2, 400, 400].into();
    let index_validity = NullBuffer::from(vec![true, true, false, false]);
    let indices = Int32Array::new(index_values, Some(index_validity));

    let source = Int32Array::from(vec![1, 23, 4, 5]);
    let taken = take(&source, &indices, None).unwrap();

    let taken_values: Vec<_> = taken.as_primitive::<Int32Type>().iter().collect();
    assert_eq!(&taken_values, &[Some(23), Some(4), None, None]);
}
2502
/// Taking a fixed-size list (size 2) with a null index: the child values
/// belonging to the null output slot are expected to be null as well.
#[test]
fn test_take_fixed_size_list_null_indices() {
    let indices = Int32Array::from_iter([Some(0), None]);

    let child = Arc::new(Int32Array::from(vec![0, 1, 2, 3]));
    let item_field = Arc::new(Field::new_list_field(child.data_type().clone(), true));
    let list = FixedSizeListArray::try_new(item_field, 2, child, None).unwrap();

    let taken = take(&list, &indices, None).unwrap();
    let child_values: Vec<_> = taken
        .as_fixed_size_list()
        .values()
        .as_primitive::<Int32Type>()
        .iter()
        .collect();
    assert_eq!(child_values, &[Some(0), Some(1), None, None]);
}
2519
/// Byte (string) arrays: null index slots must produce nulls even though the
/// raw value stored in those slots (400) is out of range for the two-element
/// values array.
#[test]
fn test_take_bytes_null_indices() {
    let index_values: ScalarBuffer<i32> = vec![0, 1, 400, 400].into();
    let index_validity = NullBuffer::from_iter(vec![true, true, false, false]);
    let indices = Int32Array::new(index_values, Some(index_validity));

    let source = StringArray::from(vec![Some("foo"), None]);
    let taken = take(&source, &indices, None).unwrap();

    let taken_values: Vec<_> = taken.as_string::<i32>().iter().collect();
    assert_eq!(&taken_values, &[Some("foo"), None, None, None]);
}
2531
/// Takes from a sparse union whose slots all select the string child
/// (type id 1) and checks the length, the type ids and the string child of
/// the result.
#[test]
fn test_take_union_sparse() {
    let structs = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(false), Some(28))),
        Some((Some(false), Some(19))),
        Some((Some(true), Some(31))),
        None,
    ]);
    let strings = StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);
    // Every slot refers to child 1, the string child.
    let type_ids = [1; 5].into_iter().collect::<ScalarBuffer<i8>>();

    let union_fields = [
        (
            0,
            Arc::new(Field::new("f1", structs.data_type().clone(), true)),
        ),
        (
            1,
            Arc::new(Field::new("f2", strings.data_type().clone(), true)),
        ),
    ]
    .into_iter()
    .collect();
    let children = vec![Arc::new(structs) as Arc<dyn Array>, Arc::new(strings)];
    let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();

    // The index vector is only needed once, so it is moved rather than cloned.
    let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
    let actual = take(&array, &index, None).unwrap();
    let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();

    // One output slot per index, each still pointing at the string child.
    assert_eq!(actual.len(), 6);
    assert_eq!(actual.type_ids(), &ScalarBuffer::from(vec![1_i8; 6]));

    let strings = actual.child(1);
    let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();

    let actual = strings.iter().collect::<Vec<_>>();
    let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
    assert_eq!(expected, actual);
}
2570
/// Takes from a dense union with a `UInt32` child (type id 0) and a string
/// child (type id 1), then checks the type ids, offsets and both children of
/// the result.
#[test]
fn test_take_union_dense() {
    // Input union.
    let type_ids = ScalarBuffer::<i8>::from(vec![0, 1, 1, 0, 0, 1, 0]);
    let offsets = ScalarBuffer::<i32>::from(vec![0, 0, 1, 1, 2, 2, 3]);
    let int_child = UInt32Array::from(vec![10, 20, 30, 40]);
    let string_child = StringArray::from(vec![Some("a"), None, Some("c"), Some("d")]);

    let union_fields = [
        (
            0,
            Arc::new(Field::new("f1", int_child.data_type().clone(), true)),
        ),
        (
            1,
            Arc::new(Field::new("f2", string_child.data_type().clone(), true)),
        ),
    ]
    .into_iter()
    .collect();

    let array = UnionArray::try_new(
        union_fields,
        type_ids,
        Some(offsets),
        vec![Arc::new(int_child), Arc::new(string_child)],
    )
    .unwrap();

    let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);

    // Expected result: the children hold only the taken values and the
    // offsets point into those new children.
    let expected_type_ids: ScalarBuffer<i8> = ScalarBuffer::from(vec![0, 0, 1, 0, 1, 0]);
    let expected_offsets: ScalarBuffer<i32> = ScalarBuffer::from(vec![0, 1, 0, 2, 1, 3]);
    let expected_ints = UInt32Array::from(vec![10, 20, 10, 30]);
    let expected_strings = StringArray::from(vec![Some("a"), None]);

    let taken = take(&array, &index, None).unwrap();
    let taken = taken.as_any().downcast_ref::<UnionArray>().unwrap();

    assert_eq!(taken.offsets(), Some(&expected_offsets));
    assert_eq!(taken.type_ids(), &expected_type_ids);
    assert_eq!(UInt32Array::from(taken.child(0).to_data()), expected_ints);
    assert_eq!(
        StringArray::from(taken.child(1).to_data()),
        expected_strings
    );
}
2627
/// Builds a dense union with `UnionBuilder`, takes `[2, 0, 1, 2]` from it and
/// compares the result against a union built directly from the expected
/// values.
#[test]
fn test_take_union_dense_using_builder() {
    let mut source_builder = UnionBuilder::new_dense();
    source_builder.append::<Int32Type>("a", 1).unwrap();
    source_builder.append::<Float64Type>("b", 3.0).unwrap();
    source_builder.append::<Int32Type>("a", 4).unwrap();
    source_builder.append::<Int32Type>("a", 5).unwrap();
    source_builder.append::<Float64Type>("b", 2.0).unwrap();
    let source = source_builder.build().unwrap();

    // Source slots 2, 0, 1, 2 hold 4, 1, 3.0 and 4 respectively.
    let mut expected_builder = UnionBuilder::new_dense();
    expected_builder.append::<Int32Type>("a", 4).unwrap();
    expected_builder.append::<Int32Type>("a", 1).unwrap();
    expected_builder.append::<Float64Type>("b", 3.0).unwrap();
    expected_builder.append::<Int32Type>("a", 4).unwrap();
    let expected = expected_builder.build().unwrap();

    let picks = UInt32Array::from(vec![2, 0, 1, 2]);

    assert_eq!(
        expected.to_data(),
        take(&source, &picks, None).unwrap().to_data()
    );
}
2656
/// Regression test (issue 6206 per the test name): taking from a dense union
/// that has a single child, so every slot has the same type id, must succeed
/// and return one slot per index.
#[test]
fn test_take_union_dense_all_match_issue_6206() {
    let fields = UnionFields::from_fields(vec![Field::new("a", DataType::Int64, false)]);
    let int_child = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));

    // Five slots, all of type id 0, with offsets 0..5 into the only child.
    let type_ids = ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]);
    let offsets = ScalarBuffer::from_iter(0_i32..5);
    let union = UnionArray::try_new(fields, type_ids, Some(offsets), vec![int_child]).unwrap();

    let indices = Int64Array::from(vec![0, 2, 4]);
    let taken = take(&union, &indices, None).unwrap();
    assert_eq!(taken.len(), 3);
}
2674
/// Taking so many copies of a string that the total byte length no longer
/// fits in an `i32` offset must return `ArrowError::OffsetOverflowError`
/// rather than succeeding or panicking.
///
/// The index array holds `i32::MAX >> 4` (134,217,727) entries, each selecting
/// the same 26-byte string: 134,217,727 * 26 bytes exceeds `i32::MAX`.
/// NOTE: the index array alone is roughly 512 MiB, so this test is
/// memory-hungry.
#[test]
fn test_take_bytes_offset_overflow() {
    let indices = Int32Array::from(vec![0; (i32::MAX >> 4) as usize]);
    let text = ('a'..='z').collect::<String>();
    // `text` is not needed afterwards, so it is moved instead of cloned.
    let values = StringArray::from(vec![Some(text)]);
    assert!(matches!(
        take(&values, &indices, None),
        Err(ArrowError::OffsetOverflowError(_))
    ));
}
2685}