1use std::fmt::Display;
21use std::sync::Arc;
22
23use arrow_array::builder::{BufferBuilder, UInt32Builder};
24use arrow_array::cast::AsArray;
25use arrow_array::types::*;
26use arrow_array::*;
27use arrow_buffer::{
28 ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, ScalarBuffer,
29 bit_util,
30};
31use arrow_data::ArrayDataBuilder;
32use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
33
34use num_traits::{One, Zero};
35
36pub fn take(
88 values: &dyn Array,
89 indices: &dyn Array,
90 options: Option<TakeOptions>,
91) -> Result<ArrayRef, ArrowError> {
92 let options = options.unwrap_or_default();
93 downcast_integer_array!(
94 indices => {
95 if options.check_bounds {
96 check_bounds(values.len(), indices)?;
97 }
98 let indices = indices.to_indices();
99 take_impl(values, &indices)
100 },
101 d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
102 )
103}
104
105pub fn take_arrays(
154 arrays: &[ArrayRef],
155 indices: &dyn Array,
156 options: Option<TakeOptions>,
157) -> Result<Vec<ArrayRef>, ArrowError> {
158 arrays
159 .iter()
160 .map(|array| take(array.as_ref(), indices, options.clone()))
161 .collect()
162}
163
164fn check_bounds<T: ArrowPrimitiveType>(
166 len: usize,
167 indices: &PrimitiveArray<T>,
168) -> Result<(), ArrowError>
169where
170 T::Native: Display,
171{
172 let len = match T::Native::from_usize(len) {
173 Some(len) => len,
174 None => {
175 if T::DATA_TYPE.is_integer() {
176 return Ok(());
178 } else {
179 return Err(ArrowError::ComputeError("Cast to usize failed".to_string()));
180 }
181 }
182 };
183
184 if indices.null_count() > 0 {
185 indices.iter().flatten().try_for_each(|index| {
186 if index >= len {
187 return Err(ArrowError::ComputeError(format!(
188 "Array index out of bounds, cannot get item at index {index} from {len} entries"
189 )));
190 }
191 Ok(())
192 })
193 } else {
194 let in_bounds = indices.values().iter().fold(true, |in_bounds, &i| {
195 in_bounds & (i >= T::Native::ZERO) & (i < len)
196 });
197
198 if !in_bounds {
199 for &index in indices.values() {
200 if index < T::Native::ZERO || index >= len {
201 return Err(ArrowError::ComputeError(format!(
202 "Array index out of bounds, cannot get item at index {index} from {len} entries"
203 )));
204 }
205 }
206 }
207
208 Ok(())
209 }
210}
211
/// Dispatches on the data type of `values` and gathers the elements selected by `indices`.
///
/// No bounds validation happens here; `take` runs `check_bounds` beforehand when asked to.
/// Nested types recurse into this function for their children.
///
/// # Panics
///
/// Reaches `unimplemented!` for any data type that has no branch below.
#[inline(never)]
fn take_impl<IndexType: ArrowPrimitiveType>(
    values: &dyn Array,
    indices: &PrimitiveArray<IndexType>,
) -> Result<ArrayRef, ArrowError> {
    downcast_primitive_array! {
        // All primitive types share one kernel.
        values => Ok(Arc::new(take_primitive(values, indices)?)),
        DataType::Boolean => {
            let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
            Ok(Arc::new(take_boolean(values, indices)))
        }
        DataType::Utf8 => {
            Ok(Arc::new(take_bytes(values.as_string::<i32>(), indices)?))
        }
        DataType::LargeUtf8 => {
            Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)?))
        }
        DataType::Utf8View => {
            Ok(Arc::new(take_byte_view(values.as_string_view(), indices)?))
        }
        DataType::List(_) => {
            Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?))
        }
        DataType::LargeList(_) => {
            Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?))
        }
        DataType::ListView(_) => {
            Ok(Arc::new(take_list_view::<_, Int32Type>(values.as_list_view(), indices)?))
        }
        DataType::LargeListView(_) => {
            Ok(Arc::new(take_list_view::<_, Int64Type>(values.as_list_view(), indices)?))
        }
        DataType::FixedSizeList(_, length) => {
            let values = values
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap();
            Ok(Arc::new(take_fixed_size_list(
                values,
                indices,
                *length as u32,
            )?))
        }
        DataType::Map(_, _) => {
            // A map is taken as a list of entries, then re-labelled with the map data type.
            let list_arr = ListArray::from(values.as_map().clone());
            let list_data = take_list::<_, Int32Type>(&list_arr, indices)?;
            let builder = list_data.into_data().into_builder().data_type(values.data_type().clone());
            // NOTE(review): skips validation; presumably sound because only the data type
            // label changes on data produced by `take_list` -- confirm.
            Ok(Arc::new(MapArray::from(unsafe { builder.build_unchecked() })))
        }
        DataType::Struct(fields) => {
            let array: &StructArray = values.as_struct();
            // Take every child column with the same indices.
            let arrays = array
                .columns()
                .iter()
                .map(|a| take_impl(a.as_ref(), indices))
                .collect::<Result<Vec<ArrayRef>, _>>()?;
            let fields: Vec<(FieldRef, ArrayRef)> =
                fields.iter().cloned().zip(arrays).collect();

            // Struct-level validity: a slot is valid only if the index is non-null and
            // the struct row it points at is valid.
            let is_valid: Buffer = indices
                .iter()
                .map(|index| {
                    if let Some(index) = index {
                        array.is_valid(index.to_usize().unwrap())
                    } else {
                        false
                    }
                })
                .collect();

            if fields.is_empty() {
                // With no child columns the length is passed explicitly.
                let nulls = NullBuffer::new(BooleanBuffer::new(is_valid, 0, indices.len()));
                Ok(Arc::new(StructArray::new_empty_fields(indices.len(), Some(nulls))))
            } else {
                Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
            }
        }
        DataType::Dictionary(_, _) => downcast_dictionary_array! {
            values => Ok(Arc::new(take_dict(values, indices)?)),
            t => unimplemented!("Take not supported for dictionary type {:?}", t)
        }
        DataType::RunEndEncoded(_, _) => downcast_run_array! {
            values => Ok(Arc::new(take_run(values, indices)?)),
            t => unimplemented!("Take not supported for run type {:?}", t)
        }
        DataType::Binary => {
            Ok(Arc::new(take_bytes(values.as_binary::<i32>(), indices)?))
        }
        DataType::LargeBinary => {
            Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?))
        }
        DataType::BinaryView => {
            Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)?))
        }
        DataType::FixedSizeBinary(size) => {
            let values = values
                .as_any()
                .downcast_ref::<FixedSizeBinaryArray>()
                .unwrap();
            Ok(Arc::new(take_fixed_size_binary(values, indices, *size)?))
        }
        DataType::Null => {
            // The result is a null array of `indices.len()` elements; the index values
            // themselves are not consulted.
            if values.len() >= indices.len() {
                // The input is long enough: reuse a slice of it.
                Ok(values.slice(0, indices.len()))
            } else {
                // Otherwise build a fresh null array of the required length.
                Ok(new_null_array(&DataType::Null, indices.len()))
            }
        }
        DataType::Union(fields, UnionMode::Sparse) => {
            // Sparse union: type ids and every child are taken with the same indices.
            let mut children = Vec::with_capacity(fields.len());
            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
            let type_ids = take_native(values.type_ids(), indices);
            for (type_id, _field) in fields.iter() {
                let values = values.child(type_id);
                let values = take_impl(values, indices)?;
                children.push(values);
            }
            let array = UnionArray::try_new(fields.clone(), type_ids, None, children)?;
            Ok(Arc::new(array))
        }
        DataType::Union(fields, UnionMode::Dense) => {
            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();

            // Type id and child offset of every selected slot.
            let type_ids = <PrimitiveArray<Int8Type>>::try_new(take_native(values.type_ids(), indices), None)?;
            let offsets = <PrimitiveArray<Int32Type>>::try_new(take_native(values.offsets().unwrap(), indices), None)?;

            // For each child, keep only the offsets of slots of that type and take
            // those rows from the child.
            let children = fields.iter()
                .map(|(field_type_id, _)| {
                    let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);

                    let indices = crate::filter::filter(&offsets, &mask)?;

                    let values = values.child(field_type_id);

                    take_impl(values, indices.as_primitive::<Int32Type>())
                })
                .collect::<Result<_, _>>()?;

            // Running count of emitted slots per type id; 128 entries cover every
            // non-negative `i8`.
            // NOTE(review): a negative type id would index out of range -- presumably
            // excluded by the union invariants; confirm.
            let mut child_offsets = [0; 128];

            // Each slot's new offset is the number of earlier slots with the same type id.
            let offsets = type_ids.values()
                .iter()
                .map(|&i| {
                    let offset = child_offsets[i as usize];

                    child_offsets[i as usize] += 1;

                    offset
                })
                .collect();

            let (_, type_ids, _) = type_ids.into_parts();

            let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)?;

            Ok(Arc::new(array))
        }
        t => unimplemented!("Take not supported for data type {:?}", t)
    }
}
377
/// Options controlling how `take` behaves.
#[derive(Clone, Debug, Default)]
pub struct TakeOptions {
    /// When `true`, `take` validates the indices against the length of the values
    /// array before taking and returns an error for an out-of-range index.
    ///
    /// Defaults to `false` (the derived `Default`), in which case an out-of-range
    /// index may panic inside the kernels.
    pub check_bounds: bool,
}
386
387fn take_primitive<T, I>(
397 values: &PrimitiveArray<T>,
398 indices: &PrimitiveArray<I>,
399) -> Result<PrimitiveArray<T>, ArrowError>
400where
401 T: ArrowPrimitiveType,
402 I: ArrowPrimitiveType,
403{
404 let values_buf = take_native(values.values(), indices);
405 let nulls = take_nulls(values.nulls(), indices);
406 Ok(PrimitiveArray::try_new(values_buf, nulls)?.with_data_type(values.data_type().clone()))
407}
408
409#[inline(never)]
410fn take_nulls<I: ArrowPrimitiveType>(
411 values: Option<&NullBuffer>,
412 indices: &PrimitiveArray<I>,
413) -> Option<NullBuffer> {
414 match values.filter(|n| n.null_count() > 0) {
415 Some(n) => {
416 let buffer = take_bits(n.inner(), indices);
417 Some(NullBuffer::new(buffer)).filter(|n| n.null_count() > 0)
418 }
419 None => indices.nulls().cloned(),
420 }
421}
422
423#[inline(never)]
424fn take_native<T: ArrowNativeType, I: ArrowPrimitiveType>(
425 values: &[T],
426 indices: &PrimitiveArray<I>,
427) -> ScalarBuffer<T> {
428 match indices.nulls().filter(|n| n.null_count() > 0) {
429 Some(n) => indices
430 .values()
431 .iter()
432 .enumerate()
433 .map(|(idx, index)| match values.get(index.as_usize()) {
434 Some(v) => *v,
435 None => match unsafe { n.inner().value_unchecked(idx) } {
437 false => T::default(),
438 true => panic!("Out-of-bounds index {index:?}"),
439 },
440 })
441 .collect(),
442 None => indices
443 .values()
444 .iter()
445 .map(|index| values[index.as_usize()])
446 .collect(),
447 }
448}
449
/// Gathers the bits of `values` at the positions in `indices` into a new [`BooleanBuffer`]
/// of `indices.len()` bits.
///
/// When `indices` contains nulls, only valid index slots are read; the output bit of a
/// null slot is never set.
#[inline(never)]
fn take_bits<I: ArrowPrimitiveType>(
    values: &BooleanBuffer,
    indices: &PrimitiveArray<I>,
) -> BooleanBuffer {
    let len = indices.len();

    match indices.nulls().filter(|n| n.null_count() > 0) {
        Some(nulls) => {
            // Start from a buffer created by `MutableBuffer::new_null(len)` and set only
            // the bits whose source bit is true.
            let mut output_buffer = MutableBuffer::new_null(len);
            let output_slice = output_buffer.as_slice_mut();
            nulls.valid_indices().for_each(|idx| {
                // SAFETY: idx comes from the validity mask of `indices`, so idx < indices.len()
                if values.value(unsafe { indices.value_unchecked(idx).as_usize() }) {
                    // SAFETY: the buffer was created for `len` bits and idx < len
                    unsafe { bit_util::set_bit_raw(output_slice.as_mut_ptr(), idx) };
                }
            });
            BooleanBuffer::new(output_buffer.into(), 0, len)
        }
        None => {
            // No null indices: every output bit is computed from its source bit.
            BooleanBuffer::collect_bool(len, |idx: usize| {
                // SAFETY: presumably `collect_bool` only calls this with idx < len == indices.len()
                values.value(unsafe { indices.value_unchecked(idx).as_usize() })
            })
        }
    }
}
478
479fn take_boolean<IndexType: ArrowPrimitiveType>(
481 values: &BooleanArray,
482 indices: &PrimitiveArray<IndexType>,
483) -> BooleanArray {
484 let val_buf = take_bits(values.values(), indices);
485 let null_buf = take_nulls(values.nulls(), indices);
486 BooleanArray::new(val_buf, null_buf)
487}
488
/// Takes the elements of a variable-length byte array (string or binary) at `indices`.
///
/// Each of the four null-ness combinations of `array` and `indices` has its own branch.
/// Every branch makes two passes: the first computes the output offsets and the total
/// number of value bytes, the second copies the bytes into a buffer of that capacity.
///
/// # Errors
///
/// Returns [`ArrowError::OffsetOverflowError`] when the accumulated byte length does not
/// fit in `T::Offset`.
fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
    array: &GenericByteArray<T>,
    indices: &PrimitiveArray<IndexType>,
) -> Result<GenericByteArray<T>, ArrowError> {
    // The output has one offset per element plus the leading zero.
    let mut offsets = Vec::with_capacity(indices.len() + 1);
    offsets.push(T::Offset::default());

    let input_offsets = array.value_offsets();
    // Running total of value bytes selected so far.
    let mut capacity = 0;
    let nulls = take_nulls(array.nulls(), indices);

    let (offsets, values) = if array.null_count() == 0 && indices.null_count() == 0 {
        // No nulls anywhere: every index contributes its value.
        offsets.reserve(indices.len());
        for index in indices.values() {
            let index = index.as_usize();
            capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for index in indices.values() {
            values.extend_from_slice(array.value(index.as_usize()).as_ref());
        }
        (offsets, values)
    } else if indices.null_count() == 0 {
        // Only the values contain nulls: a null value contributes no bytes.
        offsets.reserve(indices.len());
        for index in indices.values() {
            let index = index.as_usize();
            if array.is_valid(index) {
                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            }
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for index in indices.values() {
            let index = index.as_usize();
            if array.is_valid(index) {
                values.extend_from_slice(array.value(index).as_ref());
            }
        }
        (offsets, values)
    } else if array.null_count() == 0 {
        // Only the indices contain nulls: a null index contributes no bytes.
        offsets.reserve(indices.len());
        for (i, index) in indices.values().iter().enumerate() {
            let index = index.as_usize();
            if indices.is_valid(i) {
                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            }
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for (i, index) in indices.values().iter().enumerate() {
            if indices.is_valid(i) {
                values.extend_from_slice(array.value(index.as_usize()).as_ref());
            }
        }
        (offsets, values)
    } else {
        // Both sides contain nulls: consult the combined output validity instead.
        // NOTE(review): the unwrap assumes `take_nulls` returns `Some` whenever the
        // indices contain nulls -- confirm.
        let nulls = nulls.as_ref().unwrap();
        offsets.reserve(indices.len());
        for (i, index) in indices.values().iter().enumerate() {
            let index = index.as_usize();
            if nulls.is_valid(i) {
                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            }
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for (i, index) in indices.values().iter().enumerate() {
            // Check validity before using the index: the value stored in a null
            // index slot may not be within the bounds of `array`.
            let index = index.as_usize();
            if nulls.is_valid(i) {
                values.extend_from_slice(array.value(index).as_ref());
            }
        }
        (offsets, values)
    };

    // Final check that the total byte length is representable as an offset.
    T::Offset::from_usize(values.len())
        .ok_or_else(|| ArrowError::OffsetOverflowError(values.len()))?;

    // NOTE(review): built without validation; relies on the offsets above being
    // non-decreasing and ending at `values.len()`, and on the copied bytes being
    // valid for `T` -- confirm.
    let array = unsafe {
        let offsets = OffsetBuffer::new_unchecked(offsets.into());
        GenericByteArray::<T>::new_unchecked(offsets, values.into(), nulls)
    };

    Ok(array)
}
594
595fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>(
597 array: &GenericByteViewArray<T>,
598 indices: &PrimitiveArray<IndexType>,
599) -> Result<GenericByteViewArray<T>, ArrowError> {
600 let new_views = take_native(array.views(), indices);
601 let new_nulls = take_nulls(array.nulls(), indices);
602 Ok(unsafe {
604 GenericByteViewArray::new_unchecked(new_views, array.data_buffers().to_vec(), new_nulls)
605 })
606}
607
608fn take_list<IndexType, OffsetType>(
614 values: &GenericListArray<OffsetType::Native>,
615 indices: &PrimitiveArray<IndexType>,
616) -> Result<GenericListArray<OffsetType::Native>, ArrowError>
617where
618 IndexType: ArrowPrimitiveType,
619 OffsetType: ArrowPrimitiveType,
620 OffsetType::Native: OffsetSizeTrait,
621 PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
622{
623 let (list_indices, offsets, null_buf) =
626 take_value_indices_from_list::<IndexType, OffsetType>(values, indices)?;
627
628 let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices)?;
629 let value_offsets = Buffer::from_vec(offsets);
630 let list_data = ArrayDataBuilder::new(values.data_type().clone())
632 .len(indices.len())
633 .null_bit_buffer(Some(null_buf.into()))
634 .offset(0)
635 .add_child_data(taken.into_data())
636 .add_buffer(value_offsets);
637
638 let list_data = unsafe { list_data.build_unchecked() };
639
640 Ok(GenericListArray::<OffsetType::Native>::from(list_data))
641}
642
643fn take_list_view<IndexType, OffsetType>(
644 values: &GenericListViewArray<OffsetType::Native>,
645 indices: &PrimitiveArray<IndexType>,
646) -> Result<GenericListViewArray<OffsetType::Native>, ArrowError>
647where
648 IndexType: ArrowPrimitiveType,
649 OffsetType: ArrowPrimitiveType,
650 OffsetType::Native: OffsetSizeTrait,
651{
652 let taken_offsets = take_native(values.offsets(), indices);
653 let taken_sizes = take_native(values.sizes(), indices);
654 let nulls = take_nulls(values.nulls(), indices);
655
656 let list_view_data = ArrayDataBuilder::new(values.data_type().clone())
657 .len(indices.len())
658 .nulls(nulls)
659 .buffers(vec![taken_offsets.into(), taken_sizes.into()])
660 .child_data(vec![values.values().to_data()]);
661
662 let list_view_data = unsafe { list_view_data.build_unchecked() };
664
665 Ok(GenericListViewArray::<OffsetType::Native>::from(
666 list_view_data,
667 ))
668}
669
670fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
676 values: &FixedSizeListArray,
677 indices: &PrimitiveArray<IndexType>,
678 length: <UInt32Type as ArrowPrimitiveType>::Native,
679) -> Result<FixedSizeListArray, ArrowError> {
680 let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?;
681 let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)?;
682
683 let num_bytes = bit_util::ceil(indices.len(), 8);
685 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
686 let null_slice = null_buf.as_slice_mut();
687
688 for i in 0..indices.len() {
689 let index = indices
690 .value(i)
691 .to_usize()
692 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
693 if !indices.is_valid(i) || values.is_null(index) {
694 bit_util::unset_bit(null_slice, i);
695 }
696 }
697
698 let list_data = ArrayDataBuilder::new(values.data_type().clone())
699 .len(indices.len())
700 .null_bit_buffer(Some(null_buf.into()))
701 .offset(0)
702 .add_child_data(taken.into_data());
703
704 let list_data = unsafe { list_data.build_unchecked() };
705
706 Ok(FixedSizeListArray::from(list_data))
707}
708
709fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>(
715 values: &FixedSizeBinaryArray,
716 indices: &PrimitiveArray<IndexType>,
717 size: i32,
718) -> Result<FixedSizeBinaryArray, ArrowError> {
719 let size_usize = usize::try_from(size).map_err(|_| {
720 ArrowError::InvalidArgumentError(format!("Cannot convert size '{}' to usize", size))
721 })?;
722
723 let values_buffer = values.values().as_slice();
724 let mut values_buffer_builder = BufferBuilder::new(indices.len() * size_usize);
725
726 if indices.null_count() == 0 {
727 let array_iter = indices.values().iter().map(|idx| {
728 let offset = idx.as_usize() * size_usize;
729 &values_buffer[offset..offset + size_usize]
730 });
731 for slice in array_iter {
732 values_buffer_builder.append_slice(slice);
733 }
734 } else {
735 let array_iter = indices.iter().map(|idx| {
738 idx.map(|idx| {
739 let offset = idx.as_usize() * size_usize;
740 &values_buffer[offset..offset + size_usize]
741 })
742 });
743 for slice in array_iter {
744 match slice {
745 None => values_buffer_builder.append_n(size_usize, 0),
746 Some(slice) => values_buffer_builder.append_slice(slice),
747 }
748 }
749 }
750
751 let values_buffer = values_buffer_builder.finish();
752 let value_nulls = take_nulls(values.nulls(), indices);
753 let final_nulls = NullBuffer::union(value_nulls.as_ref(), indices.nulls());
754
755 let array_data = ArrayDataBuilder::new(DataType::FixedSizeBinary(size))
756 .len(indices.len())
757 .nulls(final_nulls)
758 .offset(0)
759 .add_buffer(values_buffer)
760 .build()?;
761
762 Ok(FixedSizeBinaryArray::from(array_data))
763}
764
765fn take_dict<T: ArrowDictionaryKeyType, I: ArrowPrimitiveType>(
770 values: &DictionaryArray<T>,
771 indices: &PrimitiveArray<I>,
772) -> Result<DictionaryArray<T>, ArrowError> {
773 let new_keys = take_primitive(values.keys(), indices)?;
774 Ok(unsafe { DictionaryArray::new_unchecked(new_keys, values.values().clone()) })
775}
776
777fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
786 run_array: &RunArray<T>,
787 logical_indices: &PrimitiveArray<I>,
788) -> Result<RunArray<T>, ArrowError> {
789 let physical_indices = run_array.get_physical_indices(logical_indices.values())?;
791
792 let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
796 let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
797 let mut new_physical_len = 1;
798 for ix in 1..physical_indices.len() {
799 if physical_indices[ix] != physical_indices[ix - 1] {
800 take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
801 new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
802 new_physical_len += 1;
803 }
804 }
805 take_value_indices
806 .append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap());
807 new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
808 let new_run_ends = unsafe {
809 ArrayDataBuilder::new(T::DATA_TYPE)
812 .len(new_physical_len)
813 .null_count(0)
814 .add_buffer(new_run_ends_builder.finish())
815 .build_unchecked()
816 };
817
818 let take_value_indices: PrimitiveArray<I> = unsafe {
819 ArrayDataBuilder::new(I::DATA_TYPE)
822 .len(new_physical_len)
823 .null_count(0)
824 .add_buffer(take_value_indices.finish())
825 .build_unchecked()
826 .into()
827 };
828
829 let new_values = take(run_array.values(), &take_value_indices, None)?;
830
831 let builder = ArrayDataBuilder::new(run_array.data_type().clone())
832 .len(physical_indices.len())
833 .add_child_data(new_run_ends)
834 .add_child_data(new_values.into_data());
835 let array_data = unsafe {
836 builder.build_unchecked()
839 };
840 Ok(array_data.into())
841}
842
843#[allow(clippy::type_complexity)]
849fn take_value_indices_from_list<IndexType, OffsetType>(
850 list: &GenericListArray<OffsetType::Native>,
851 indices: &PrimitiveArray<IndexType>,
852) -> Result<
853 (
854 PrimitiveArray<OffsetType>,
855 Vec<OffsetType::Native>,
856 MutableBuffer,
857 ),
858 ArrowError,
859>
860where
861 IndexType: ArrowPrimitiveType,
862 OffsetType: ArrowPrimitiveType,
863 OffsetType::Native: OffsetSizeTrait + std::ops::Add + Zero + One,
864 PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
865{
866 let offsets: &[OffsetType::Native] = list.value_offsets();
868
869 let mut new_offsets = Vec::with_capacity(indices.len());
870 let mut values = Vec::new();
871 let mut current_offset = OffsetType::Native::zero();
872 new_offsets.push(OffsetType::Native::zero());
874
875 let num_bytes = bit_util::ceil(indices.len(), 8);
877 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
878 let null_slice = null_buf.as_slice_mut();
879
880 for i in 0..indices.len() {
882 if indices.is_valid(i) {
883 let ix = indices
884 .value(i)
885 .to_usize()
886 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
887 let start = offsets[ix];
888 let end = offsets[ix + 1];
889 current_offset += end - start;
890 new_offsets.push(current_offset);
891
892 let mut curr = start;
893
894 while curr < end {
896 values.push(curr);
897 curr += One::one();
898 }
899 if !list.is_valid(ix) {
900 bit_util::unset_bit(null_slice, i);
901 }
902 } else {
903 bit_util::unset_bit(null_slice, i);
904 new_offsets.push(current_offset);
905 }
906 }
907
908 Ok((
909 PrimitiveArray::<OffsetType>::from(values),
910 new_offsets,
911 null_buf,
912 ))
913}
914
915fn take_value_indices_from_fixed_size_list<IndexType>(
917 list: &FixedSizeListArray,
918 indices: &PrimitiveArray<IndexType>,
919 length: <UInt32Type as ArrowPrimitiveType>::Native,
920) -> Result<PrimitiveArray<UInt32Type>, ArrowError>
921where
922 IndexType: ArrowPrimitiveType,
923{
924 let mut values = UInt32Builder::with_capacity(length as usize * indices.len());
925
926 for i in 0..indices.len() {
927 if indices.is_valid(i) {
928 let index = indices
929 .value(i)
930 .to_usize()
931 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
932 let start = list.value_offset(index) as <UInt32Type as ArrowPrimitiveType>::Native;
933
934 unsafe {
936 values.append_trusted_len_iter(start..start + length);
937 }
938 } else {
939 values.append_nulls(length as usize);
940 }
941 }
942
943 Ok(values.finish())
944}
945
/// Conversion of an integer index array into the index array type used by the kernels.
///
/// The implementations in this file all target `UInt32Type` or `UInt64Type`.
trait ToIndices {
    /// The primitive type of the converted index array.
    type T: ArrowPrimitiveType;

    /// Returns the indices of `self` as a `PrimitiveArray<Self::T>`.
    fn to_indices(&self) -> PrimitiveArray<Self::T>;
}
953
/// Implements `ToIndices` for `PrimitiveArray<$t>` by viewing the existing value buffer
/// as `$o`, without copying the values. The validity is cloned as-is.
macro_rules! to_indices_reinterpret {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                // Same bytes, new element type: wrap the underlying buffer again.
                let reinterpreted = ScalarBuffer::new(self.values().inner().clone(), 0, self.len());
                let validity = self.nulls().cloned();
                PrimitiveArray::new(reinterpreted, validity)
            }
        }
    };
}
966
/// Implements `ToIndices` for `PrimitiveArray<$t>` when `$t` is already the target index
/// type: the array is returned as a clone of itself.
macro_rules! to_indices_identity {
    ($t:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $t;

            fn to_indices(&self) -> PrimitiveArray<$t> {
                self.clone()
            }
        }
    };
}
978
/// Implements `ToIndices` for `PrimitiveArray<$t>` by converting every value to the
/// native type of `$o` with an `as` cast into a freshly collected buffer. The validity
/// is cloned as-is.
///
/// For a signed `$t` the `as` cast sign-extends, so a negative index becomes a very
/// large unsigned one.
macro_rules! to_indices_widening {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            // Tied to `$o` so the associated type always matches the return type below.
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                let cast = self.values().iter().copied().map(|x| x as _).collect();
                PrimitiveArray::new(cast, self.nulls().cloned())
            }
        }
    };
}
991
// 8-bit and 16-bit index types are converted value by value to `UInt32Type`.
to_indices_widening!(UInt8Type, UInt32Type);
to_indices_widening!(Int8Type, UInt32Type);

to_indices_widening!(UInt16Type, UInt32Type);
to_indices_widening!(Int16Type, UInt32Type);

// 32-bit indices: the unsigned type is used as-is, the signed one is viewed as unsigned.
to_indices_identity!(UInt32Type);
to_indices_reinterpret!(Int32Type, UInt32Type);

// 64-bit indices: same scheme as the 32-bit pair.
to_indices_identity!(UInt64Type);
to_indices_reinterpret!(Int64Type, UInt64Type);
1003
1004pub fn take_record_batch(
1044 record_batch: &RecordBatch,
1045 indices: &dyn Array,
1046) -> Result<RecordBatch, ArrowError> {
1047 let columns = record_batch
1048 .columns()
1049 .iter()
1050 .map(|c| take(c, indices, None))
1051 .collect::<Result<Vec<_>, _>>()?;
1052 RecordBatch::try_new(record_batch.schema(), columns)
1053}
1054
1055#[cfg(test)]
1056mod tests {
1057 use super::*;
1058 use arrow_array::builder::*;
1059 use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
1060 use arrow_data::ArrayData;
1061 use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
1062 use num_traits::ToPrimitive;
1063
1064 fn test_take_decimal_arrays(
1065 data: Vec<Option<i128>>,
1066 index: &UInt32Array,
1067 options: Option<TakeOptions>,
1068 expected_data: Vec<Option<i128>>,
1069 precision: &u8,
1070 scale: &i8,
1071 ) -> Result<(), ArrowError> {
1072 let output = data
1073 .into_iter()
1074 .collect::<Decimal128Array>()
1075 .with_precision_and_scale(*precision, *scale)
1076 .unwrap();
1077
1078 let expected = expected_data
1079 .into_iter()
1080 .collect::<Decimal128Array>()
1081 .with_precision_and_scale(*precision, *scale)
1082 .unwrap();
1083
1084 let expected = Arc::new(expected) as ArrayRef;
1085 let output = take(&output, index, options).unwrap();
1086 assert_eq!(&output, &expected);
1087 Ok(())
1088 }
1089
1090 fn test_take_boolean_arrays(
1091 data: Vec<Option<bool>>,
1092 index: &UInt32Array,
1093 options: Option<TakeOptions>,
1094 expected_data: Vec<Option<bool>>,
1095 ) {
1096 let output = BooleanArray::from(data);
1097 let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef;
1098 let output = take(&output, index, options).unwrap();
1099 assert_eq!(&output, &expected)
1100 }
1101
1102 fn test_take_primitive_arrays<T>(
1103 data: Vec<Option<T::Native>>,
1104 index: &UInt32Array,
1105 options: Option<TakeOptions>,
1106 expected_data: Vec<Option<T::Native>>,
1107 ) -> Result<(), ArrowError>
1108 where
1109 T: ArrowPrimitiveType,
1110 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1111 {
1112 let output = PrimitiveArray::<T>::from(data);
1113 let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1114 let output = take(&output, index, options)?;
1115 assert_eq!(&output, &expected);
1116 Ok(())
1117 }
1118
1119 fn test_take_primitive_arrays_non_null<T>(
1120 data: Vec<T::Native>,
1121 index: &UInt32Array,
1122 options: Option<TakeOptions>,
1123 expected_data: Vec<Option<T::Native>>,
1124 ) -> Result<(), ArrowError>
1125 where
1126 T: ArrowPrimitiveType,
1127 PrimitiveArray<T>: From<Vec<T::Native>>,
1128 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1129 {
1130 let output = PrimitiveArray::<T>::from(data);
1131 let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1132 let output = take(&output, index, options)?;
1133 assert_eq!(&output, &expected);
1134 Ok(())
1135 }
1136
1137 fn test_take_impl_primitive_arrays<T, I>(
1138 data: Vec<Option<T::Native>>,
1139 index: &PrimitiveArray<I>,
1140 options: Option<TakeOptions>,
1141 expected_data: Vec<Option<T::Native>>,
1142 ) where
1143 T: ArrowPrimitiveType,
1144 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1145 I: ArrowPrimitiveType,
1146 {
1147 let output = PrimitiveArray::<T>::from(data);
1148 let expected = PrimitiveArray::<T>::from(expected_data);
1149 let output = take(&output, index, options).unwrap();
1150 let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
1151 assert_eq!(output, &expected)
1152 }
1153
1154 fn create_test_struct(values: Vec<Option<(Option<bool>, Option<i32>)>>) -> StructArray {
1156 let mut struct_builder = StructBuilder::new(
1157 Fields::from(vec![
1158 Field::new("a", DataType::Boolean, true),
1159 Field::new("b", DataType::Int32, true),
1160 ]),
1161 vec![
1162 Box::new(BooleanBuilder::with_capacity(values.len())),
1163 Box::new(Int32Builder::with_capacity(values.len())),
1164 ],
1165 );
1166
1167 for value in values {
1168 struct_builder
1169 .field_builder::<BooleanBuilder>(0)
1170 .unwrap()
1171 .append_option(value.and_then(|v| v.0));
1172 struct_builder
1173 .field_builder::<Int32Builder>(1)
1174 .unwrap()
1175 .append_option(value.and_then(|v| v.1));
1176 struct_builder.append(value.is_some());
1177 }
1178 struct_builder.finish()
1179 }
1180
1181 #[test]
1182 fn test_take_decimal128_non_null_indices() {
1183 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1184 let precision: u8 = 10;
1185 let scale: i8 = 5;
1186 test_take_decimal_arrays(
1187 vec![None, Some(3), Some(5), Some(2), Some(3), None],
1188 &index,
1189 None,
1190 vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1191 &precision,
1192 &scale,
1193 )
1194 .unwrap();
1195 }
1196
1197 #[test]
1198 fn test_take_decimal128() {
1199 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1200 let precision: u8 = 10;
1201 let scale: i8 = 5;
1202 test_take_decimal_arrays(
1203 vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1204 &index,
1205 None,
1206 vec![Some(3), None, Some(1), Some(3), Some(2)],
1207 &precision,
1208 &scale,
1209 )
1210 .unwrap();
1211 }
1212
1213 #[test]
1214 fn test_take_primitive_non_null_indices() {
1215 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1216 test_take_primitive_arrays::<Int8Type>(
1217 vec![None, Some(3), Some(5), Some(2), Some(3), None],
1218 &index,
1219 None,
1220 vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1221 )
1222 .unwrap();
1223 }
1224
1225 #[test]
1226 fn test_take_primitive_non_null_values() {
1227 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1228 test_take_primitive_arrays::<Int8Type>(
1229 vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1230 &index,
1231 None,
1232 vec![Some(3), None, Some(1), Some(3), Some(2)],
1233 )
1234 .unwrap();
1235 }
1236
1237 #[test]
1238 fn test_take_primitive_non_null() {
1239 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1240 test_take_primitive_arrays::<Int8Type>(
1241 vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)],
1242 &index,
1243 None,
1244 vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)],
1245 )
1246 .unwrap();
1247 }
1248
1249 #[test]
1250 fn test_take_primitive_nullable_indices_non_null_values_with_offset() {
1251 let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1252 let index = index.slice(2, 4);
1253 let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1254
1255 assert_eq!(
1256 index,
1257 &UInt32Array::from(vec![Some(2), Some(3), None, None])
1258 );
1259
1260 test_take_primitive_arrays_non_null::<Int64Type>(
1261 vec![0, 10, 20, 30, 40, 50],
1262 index,
1263 None,
1264 vec![Some(20), Some(30), None, None],
1265 )
1266 .unwrap();
1267 }
1268
/// Takes nullable `Int64` values through a sliced index array whose tail is null.
#[test]
fn test_take_primitive_nullable_indices_nullable_values_with_offset() {
    let sliced =
        UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]).slice(2, 4);
    let indices = sliced.as_any().downcast_ref::<UInt32Array>().unwrap();

    // Sanity-check the slice before it is used as the take indices.
    let expected_indices = UInt32Array::from(vec![Some(2), Some(3), None, None]);
    assert_eq!(indices, &expected_indices);

    let values = vec![None, None, Some(20), Some(30), Some(40), Some(50)];
    let expected = vec![Some(20), Some(30), None, None];
    test_take_primitive_arrays::<Int64Type>(values, indices, None, expected).unwrap();
}
1288
/// Runs the same nullable take over a range of primitive value types.
#[test]
fn test_take_primitive() {
    let indices = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    // Runs `test_take_primitive_arrays` once per listed type, in order, with the
    // given input and expected vectors, unwrapping each result.
    macro_rules! check_take {
        ($values:expr => $expected:expr; $($ty:ty),+ $(,)?) => {
            $(
                test_take_primitive_arrays::<$ty>($values, &indices, None, $expected).unwrap();
            )+
        };
    }

    // Signed and unsigned integers with non-negative values.
    check_take!(
        vec![Some(0), None, Some(2), Some(3), None]
            => vec![Some(3), None, None, Some(3), Some(2)];
        Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type
    );

    // Types exercised with a negative value.
    check_take!(
        vec![Some(0), None, Some(2), Some(-15), None]
            => vec![Some(-15), None, None, Some(-15), Some(2)];
        Int64Type, IntervalYearMonthType
    );

    let v1 = IntervalDayTime::new(0, 0);
    let v2 = IntervalDayTime::new(2, 0);
    let v3 = IntervalDayTime::new(-15, 0);
    check_take!(
        vec![Some(v1), None, Some(v2), Some(v3), None]
            => vec![Some(v3), None, None, Some(v3), Some(v2)];
        IntervalDayTimeType
    );

    let v1 = IntervalMonthDayNano::new(0, 0, 0);
    let v2 = IntervalMonthDayNano::new(2, 0, 0);
    let v3 = IntervalMonthDayNano::new(-15, 0, 0);
    check_take!(
        vec![Some(v1), None, Some(v2), Some(v3), None]
            => vec![Some(v3), None, None, Some(v3), Some(v2)];
        IntervalMonthDayNanoType
    );

    // Durations in every time unit.
    check_take!(
        vec![Some(0), None, Some(2), Some(-15), None]
            => vec![Some(-15), None, None, Some(-15), Some(2)];
        DurationSecondType,
        DurationMillisecondType,
        DurationMicrosecondType,
        DurationNanosecondType
    );

    // Floating point values.
    check_take!(
        vec![Some(0.0), None, Some(2.21), Some(-3.1), None]
            => vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)];
        Float32Type, Float64Type
    );
}
1452
/// Verifies that `take` on a timestamp array with a timezone yields a result whose
/// data type is still `Timestamp(Nanosecond, Some("UTC"))`.
#[test]
fn test_take_preserve_timezone() {
    let index = Int64Array::from(vec![Some(0), None]);

    let input = TimestampNanosecondArray::from(vec![
        1_639_715_368_000_000_000,
        1_639_715_368_000_000_000,
    ])
    .with_timezone("UTC".to_string());
    let result = take(&input, &index, None).unwrap();

    // Compare the complete data type so that a failure prints the actual type,
    // rather than hitting a message-less `panic!()` in a catch-all match arm.
    assert_eq!(
        result.data_type(),
        &DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into()))
    );
}
1470
/// Exercises the take implementation helper with `Int64` indices over several value types.
#[test]
fn test_take_impl_primitive_with_int64_indices() {
    let indices = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    // Int16 values.
    let values = vec![Some(0), None, Some(2), Some(3), None];
    let expected = vec![Some(3), None, None, Some(3), Some(2)];
    test_take_impl_primitive_arrays::<Int16Type, Int64Type>(values, &indices, None, expected);

    // Int64 values, including a negative one.
    let values = vec![Some(0), None, Some(2), Some(-15), None];
    let expected = vec![Some(-15), None, None, Some(-15), Some(2)];
    test_take_impl_primitive_arrays::<Int64Type, Int64Type>(values, &indices, None, expected);

    // UInt64 values.
    let values = vec![Some(0), None, Some(2), Some(3), None];
    let expected = vec![Some(3), None, None, Some(3), Some(2)];
    test_take_impl_primitive_arrays::<UInt64Type, Int64Type>(values, &indices, None, expected);

    // Millisecond durations.
    let values = vec![Some(0), None, Some(2), Some(-15), None];
    let expected = vec![Some(-15), None, None, Some(-15), Some(2)];
    test_take_impl_primitive_arrays::<DurationMillisecondType, Int64Type>(
        values, &indices, None, expected,
    );

    // Float32 values.
    let values = vec![Some(0.0), None, Some(2.21), Some(-3.1), None];
    let expected = vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)];
    test_take_impl_primitive_arrays::<Float32Type, Int64Type>(values, &indices, None, expected);
}
1515
/// Exercises the take implementation helper with `UInt8` indices over several value types.
#[test]
fn test_take_impl_primitive_with_uint8_indices() {
    let indices = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    // Int16 values.
    let values = vec![Some(0), None, Some(2), Some(3), None];
    let expected = vec![Some(3), None, None, Some(3), Some(2)];
    test_take_impl_primitive_arrays::<Int16Type, UInt8Type>(values, &indices, None, expected);

    // Millisecond durations.
    let values = vec![Some(0), None, Some(2), Some(-15), None];
    let expected = vec![Some(-15), None, None, Some(-15), Some(2)];
    test_take_impl_primitive_arrays::<DurationMillisecondType, UInt8Type>(
        values, &indices, None, expected,
    );

    // Float32 values.
    let values = vec![Some(0.0), None, Some(2.21), Some(-3.1), None];
    let expected = vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)];
    test_take_impl_primitive_arrays::<Float32Type, UInt8Type>(values, &indices, None, expected);
}
1544
/// Takes from a nullable boolean array using nullable indices.
#[test]
fn test_take_bool() {
    let values = vec![Some(false), None, Some(true), Some(false), None];
    let expected = vec![Some(false), None, None, Some(false), Some(true)];
    let indices = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    test_take_boolean_arrays(values, &indices, None, expected);
}
1556
/// Takes booleans with an index array whose null slots hold values (99, 999, 9999)
/// larger than the three-element values array.
#[test]
fn test_take_bool_nullable_index() {
    let validity = Buffer::from_iter(vec![false, true, false, true, false, true]);
    let raw_indices = Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2]);
    let index_data = ArrayData::try_new(
        DataType::UInt32,
        6,
        Some(validity),
        0,
        vec![raw_indices],
        vec![],
    )
    .unwrap();
    let indices = UInt32Array::from(index_data);

    let values = vec![Some(true), None, Some(false)];
    let expected = vec![None, Some(true), None, None, None, Some(false)];
    test_take_boolean_arrays(values, &indices, None, expected);
}
1579
/// Same as the nullable-index boolean test, but every value in the source array is valid.
#[test]
fn test_take_bool_nullable_index_nonnull_values() {
    let validity = Buffer::from_iter(vec![false, true, false, true, false, true]);
    let raw_indices = Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2]);
    let index_data = ArrayData::try_new(
        DataType::UInt32,
        6,
        Some(validity),
        0,
        vec![raw_indices],
        vec![],
    )
    .unwrap();
    let indices = UInt32Array::from(index_data);

    let values = vec![Some(true), Some(true), Some(false)];
    let expected = vec![None, Some(true), None, Some(true), None, Some(false)];
    test_take_boolean_arrays(values, &indices, None, expected);
}
1602
/// Takes booleans using a sliced (offset) index array.
#[test]
fn test_take_bool_with_offset() {
    let sliced =
        UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]).slice(2, 4);
    let indices = sliced.as_any().downcast_ref::<UInt32Array>().unwrap();

    let values = vec![Some(false), None, Some(true), Some(false), None];
    let expected = vec![None, Some(false), Some(true), None];
    test_take_boolean_arrays(values, indices, None, expected);
}
1620
1621 fn _test_take_string<'a, K>()
1622 where
1623 K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static,
1624 {
1625 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]);
1626
1627 let array = K::from(vec![
1628 Some("one"),
1629 None,
1630 Some("three"),
1631 Some("four"),
1632 Some("five"),
1633 ]);
1634 let actual = take(&array, &index, None).unwrap();
1635 assert_eq!(actual.len(), index.len());
1636
1637 let actual = actual.as_any().downcast_ref::<K>().unwrap();
1638
1639 let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]);
1640
1641 assert_eq!(actual, &expected);
1642 }
1643
/// Runs the shared string take check against `StringArray`.
#[test]
fn test_take_string() {
    _test_take_string::<StringArray>();
}
1648
/// Runs the shared string take check against `LargeStringArray`.
#[test]
fn test_take_large_string() {
    _test_take_string::<LargeStringArray>();
}
1653
/// Takes from a string array using a sliced index array.
#[test]
fn test_take_slice_string() {
    let source = StringArray::from(vec![Some("hello"), None, Some("world"), None, Some("hi")]);
    let expected = StringArray::from(vec![None, None, Some("hello"), Some("world")]);

    let all_indices = Int32Array::from(vec![Some(0), Some(1), None, Some(0), Some(2)]);
    let indices_slice = all_indices.slice(1, 4);

    let taken = take(&source, &indices_slice, None).unwrap();
    assert_eq!(taken.as_ref(), &expected);
}
1663
1664 fn _test_byte_view<T>()
1665 where
1666 T: ByteViewType,
1667 str: AsRef<T::Native>,
1668 T::Native: PartialEq,
1669 {
1670 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]);
1671 let array = {
1672 let mut builder = GenericByteViewBuilder::<T>::new();
1674 builder.append_value("hello");
1675 builder.append_value("world");
1676 builder.append_null();
1677 builder.append_value("large payload over 12 bytes");
1678 builder.append_value("lulu");
1679 builder.finish()
1680 };
1681
1682 let actual = take(&array, &index, None).unwrap();
1683
1684 assert_eq!(actual.len(), index.len());
1685
1686 let expected = {
1687 let mut builder = GenericByteViewBuilder::<T>::new();
1689 builder.append_value("large payload over 12 bytes");
1690 builder.append_null();
1691 builder.append_value("world");
1692 builder.append_value("large payload over 12 bytes");
1693 builder.append_value("lulu");
1694 builder.append_null();
1695 builder.finish()
1696 };
1697
1698 assert_eq!(actual.as_ref(), &expected);
1699 }
1700
/// Runs the shared byte-view take check for `StringViewType`.
#[test]
fn test_take_string_view() {
    _test_byte_view::<StringViewType>();
}
1705
/// Runs the shared byte-view take check for `BinaryViewType`.
#[test]
fn test_take_binary_view() {
    _test_byte_view::<BinaryViewType>();
}
1710
// Takes from a list array without list-level nulls and checks offsets, child values
// and validity of the result. Arguments: offset type, `DataType` variant, array type.
macro_rules! test_take_list {
    ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
        let data_type =
            DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, false)));

        // Source lists: [[0, 0, 0], [-1, -2, -1], [], [2, 3]]
        let source_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
        let source_values = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
        let source_data = ArrayData::builder(data_type.clone())
            .len(4)
            .add_buffer(Buffer::from_slice_ref(&source_offsets))
            .add_child_data(source_values)
            .build()
            .unwrap();
        let source = $list_array_type::from(source_data);

        let indices = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]);

        let taken = take(&source, &indices, None).unwrap();
        let taken: &$list_array_type = taken.as_any().downcast_ref::<$list_array_type>().unwrap();

        // Expected lists: [[2, 3], null, [-1, -2, -1], [], [0, 0, 0]]
        let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8];
        let expected_values = Int32Array::from(vec![
            Some(2),
            Some(3),
            Some(-1),
            Some(-2),
            Some(-1),
            Some(0),
            Some(0),
            Some(0),
        ])
        .into_data();
        // The result validity is exactly the validity of the index array.
        let expected_data = ArrayData::builder(data_type)
            .len(5)
            .nulls(indices.nulls().cloned())
            .add_buffer(Buffer::from_slice_ref(&expected_offsets))
            .add_child_data(expected_values)
            .build()
            .unwrap();
        let expected = $list_array_type::from(expected_data);

        assert_eq!(taken, &expected);
    }};
}
1765
// Takes from a list array whose child values contain nulls while every list slot is
// valid. Arguments: offset type, `DataType` variant, array type.
macro_rules! test_take_list_with_value_nulls {
    ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
        let data_type =
            DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));

        // Source lists: [[0, null, 0], [-1, -2, 3], [null], [5, null]]
        let source_offsets: [$offset_type; 5] = [0, 3, 6, 7, 9];
        let source_values = Int32Array::from(vec![
            Some(0),
            None,
            Some(0),
            Some(-1),
            Some(-2),
            Some(3),
            None,
            Some(5),
            None,
        ])
        .into_data();
        let source_data = ArrayData::builder(data_type.clone())
            .len(4)
            .add_buffer(Buffer::from_slice_ref(&source_offsets))
            .null_bit_buffer(Some(Buffer::from([0b11111111])))
            .add_child_data(source_values)
            .build()
            .unwrap();
        let source = $list_array_type::from(source_data);

        let indices = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);

        let taken = take(&source, &indices, None).unwrap();
        let taken: &$list_array_type = taken.as_any().downcast_ref::<$list_array_type>().unwrap();

        // Expected lists: [[null], null, [-1, -2, 3], [5, null], [0, null, 0]]
        let expected_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9];
        let expected_values = Int32Array::from(vec![
            None,
            Some(-1),
            Some(-2),
            Some(3),
            Some(5),
            None,
            Some(0),
            None,
            Some(0),
        ])
        .into_data();
        // The result validity is exactly the validity of the index array.
        let expected_data = ArrayData::builder(data_type)
            .len(5)
            .nulls(indices.nulls().cloned())
            .add_buffer(Buffer::from_slice_ref(&expected_offsets))
            .add_child_data(expected_values)
            .build()
            .unwrap();
        let expected = $list_array_type::from(expected_data);

        assert_eq!(taken, &expected);
    }};
}
1833
// Takes from a list array that has both a null list slot and null child values.
// Arguments: offset type, `DataType` variant, array type.
macro_rules! test_take_list_with_nulls {
    ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
        let data_type =
            DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));

        // Source lists: [[0, null, 0], [-1, -2, 3], null, [5, null]]
        let source_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
        let source_values = Int32Array::from(vec![
            Some(0),
            None,
            Some(0),
            Some(-1),
            Some(-2),
            Some(3),
            Some(5),
            None,
        ])
        .into_data();
        // Validity bit 2 is cleared, marking the third list as null.
        let source_data = ArrayData::builder(data_type.clone())
            .len(4)
            .add_buffer(Buffer::from_slice_ref(&source_offsets))
            .null_bit_buffer(Some(Buffer::from([0b11111011])))
            .add_child_data(source_values)
            .build()
            .unwrap();
        let source = $list_array_type::from(source_data);

        let indices = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);

        let taken = take(&source, &indices, None).unwrap();
        let taken: &$list_array_type = taken.as_any().downcast_ref::<$list_array_type>().unwrap();

        // Expected lists: [null, null, [-1, -2, 3], [5, null], [0, null, 0]]
        let expected_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8];
        let expected_values = Int32Array::from(vec![
            Some(-1),
            Some(-2),
            Some(3),
            Some(5),
            None,
            Some(0),
            None,
            Some(0),
        ])
        .into_data();
        // Only slots 2, 3 and 4 are valid in the expected output.
        let mut validity_bits: [u8; 1] = [0; 1];
        for slot in [2, 3, 4] {
            bit_util::set_bit(&mut validity_bits, slot);
        }
        let expected_data = ArrayData::builder(data_type)
            .len(5)
            .null_bit_buffer(Some(Buffer::from(validity_bits)))
            .add_buffer(Buffer::from_slice_ref(&expected_offsets))
            .add_child_data(expected_values)
            .build()
            .unwrap();
        let expected = $list_array_type::from(expected_data);

        assert_eq!(taken, &expected);
    }};
}
1903
1904 fn test_take_list_view_generic<OffsetType: OffsetSizeTrait, ValuesType: ArrowPrimitiveType, F>(
1905 values: Vec<Option<Vec<Option<ValuesType::Native>>>>,
1906 take_indices: Vec<Option<usize>>,
1907 expected: Vec<Option<Vec<Option<ValuesType::Native>>>>,
1908 mapper: F,
1909 ) where
1910 F: Fn(GenericListViewArray<OffsetType>) -> GenericListViewArray<OffsetType>,
1911 {
1912 let mut list_view_array =
1913 GenericListViewBuilder::<OffsetType, _>::new(PrimitiveBuilder::<ValuesType>::new());
1914
1915 for value in values {
1916 list_view_array.append_option(value);
1917 }
1918 let list_view_array = list_view_array.finish();
1919 let list_view_array = mapper(list_view_array);
1920
1921 let mut indices = UInt64Builder::new();
1922 for idx in take_indices {
1923 indices.append_option(idx.map(|i| i.to_u64().unwrap()));
1924 }
1925 let indices = indices.finish();
1926
1927 let taken = take(&list_view_array, &indices, None)
1928 .unwrap()
1929 .as_list_view()
1930 .clone();
1931
1932 let mut expected_array =
1933 GenericListViewBuilder::<OffsetType, _>::new(PrimitiveBuilder::<ValuesType>::new());
1934 for value in expected {
1935 expected_array.append_option(value);
1936 }
1937 let expected_array = expected_array.finish();
1938
1939 assert_eq!(taken, expected_array);
1940 }
1941
// Runs a list-view take case for both `i32` and `i64` offsets with `Int8` values.
// The `transform:` form applies the given closure to the source array before taking;
// the short form uses the identity transform.
macro_rules! list_view_test_case {
    (values: $values:expr, indices: $indices:expr, expected: $expected: expr) => {{
        list_view_test_case!(
            values: $values,
            transform: |x| x,
            indices: $indices,
            expected: $expected
        );
    }};
    (values: $values:expr, transform: $transform:expr, indices: $indices:expr, expected: $expected: expr) => {{
        test_take_list_view_generic::<i32, Int8Type, _>(
            $values, $indices, $expected, $transform,
        );
        test_take_list_view_generic::<i64, Int8Type, _>(
            $values, $indices, $expected, $transform,
        );
    }};
}
1952
1953 fn do_take_fixed_size_list_test<T>(
1954 length: <Int32Type as ArrowPrimitiveType>::Native,
1955 input_data: Vec<Option<Vec<Option<T::Native>>>>,
1956 indices: Vec<<UInt32Type as ArrowPrimitiveType>::Native>,
1957 expected_data: Vec<Option<Vec<Option<T::Native>>>>,
1958 ) where
1959 T: ArrowPrimitiveType,
1960 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1961 {
1962 let indices = UInt32Array::from(indices);
1963
1964 let input_array = FixedSizeListArray::from_iter_primitive::<T, _, _>(input_data, length);
1965
1966 let output = take_fixed_size_list(&input_array, &indices, length as u32).unwrap();
1967
1968 let expected = FixedSizeListArray::from_iter_primitive::<T, _, _>(expected_data, length);
1969
1970 assert_eq!(&output, &expected)
1971 }
1972
/// Runs the list take scenario with 32-bit offsets.
#[test]
fn test_take_list() {
    test_take_list! { i32, List, ListArray };
}
1977
/// Runs the list take scenario with 64-bit offsets.
#[test]
fn test_take_large_list() {
    test_take_list! { i64, LargeList, LargeListArray };
}
1982
/// Runs the null-child-values list scenario with 32-bit offsets.
#[test]
fn test_take_list_with_value_nulls() {
    test_take_list_with_value_nulls! { i32, List, ListArray };
}
1987
/// Runs the null-child-values list scenario with 64-bit offsets.
#[test]
fn test_take_large_list_with_value_nulls() {
    test_take_list_with_value_nulls! { i64, LargeList, LargeListArray };
}
1992
/// Runs the null-list-slot scenario with 32-bit offsets.
#[test]
fn test_test_take_list_with_nulls() {
    test_take_list_with_nulls! { i32, List, ListArray };
}
1997
/// Runs the null-list-slot scenario with 64-bit offsets.
#[test]
fn test_test_take_large_list_with_nulls() {
    test_take_list_with_nulls! { i64, LargeList, LargeListArray };
}
2002
/// Takes every list-view row in reverse order.
#[test]
fn test_test_take_list_view_reversed() {
    let first = Some(vec![Some(1), None, Some(3)]);
    let last = Some(vec![Some(7), Some(8), None]);

    list_view_test_case! {
        values: vec![first.clone(), None, last.clone()],
        indices: vec![Some(2), Some(1), Some(0)],
        expected: vec![last.clone(), None, first.clone()]
    }
}
2020
/// Takes from a list-view array with null indices surrounding a single valid index.
#[test]
fn test_take_list_view_null_indices() {
    let first = Some(vec![Some(1), None, Some(3)]);

    list_view_test_case! {
        values: vec![first.clone(), None, Some(vec![Some(7), Some(8), None])],
        indices: vec![None, Some(0), None],
        expected: vec![None, first.clone(), None]
    }
}
2034
/// Repeatedly takes the null row (plus null indices); every output row is null.
#[test]
fn test_take_list_view_null_values() {
    list_view_test_case! {
        values: vec![
            Some(vec![Some(1), None, Some(3)]),
            None,
            Some(vec![Some(7), Some(8), None]),
        ],
        indices: vec![Some(1), Some(1), Some(1), None, None],
        expected: vec![None; 5]
    };
}
2048
/// Takes from a list-view array that was sliced to rows 2..6 before the take.
#[test]
fn test_take_list_view_sliced() {
    let twos = Some(vec![Some(2), Some(3)]);
    let fours = Some(vec![Some(4), Some(5)]);

    list_view_test_case! {
        values: vec![Some(vec![Some(1)]), None, None, twos.clone(), fours.clone(), None],
        transform: |l| l.slice(2, 4),
        indices: vec![Some(0), Some(3), None, Some(1), Some(2)],
        expected: vec![None, None, None, twos.clone(), fours.clone()]
    }
}
2068
/// Takes from fixed-size list arrays of different widths and value types.
#[test]
fn test_take_fixed_size_list() {
    // Width 3, Int32 values with null elements, rows reversed.
    let input = vec![
        Some(vec![None, Some(1), Some(2)]),
        Some(vec![Some(3), Some(4), None]),
        Some(vec![Some(6), Some(7), Some(8)]),
    ];
    let expected = vec![
        Some(vec![Some(6), Some(7), Some(8)]),
        Some(vec![Some(3), Some(4), None]),
        Some(vec![None, Some(1), Some(2)]),
    ];
    do_take_fixed_size_list_test::<Int32Type>(3, input, vec![2, 1, 0], expected);

    // Width 1, UInt8 values, a subset of rows.
    let input = vec![
        Some(vec![Some(1)]),
        Some(vec![Some(2)]),
        Some(vec![Some(3)]),
        Some(vec![Some(4)]),
        Some(vec![Some(5)]),
        Some(vec![Some(6)]),
        Some(vec![Some(7)]),
        Some(vec![Some(8)]),
    ];
    let expected = vec![
        Some(vec![Some(3)]),
        Some(vec![Some(8)]),
        Some(vec![Some(1)]),
    ];
    do_take_fixed_size_list_test::<UInt8Type>(1, input, vec![2, 7, 0], expected);

    // Width 3, UInt64 values with a null row that is taken twice.
    let input = vec![
        Some(vec![Some(10), Some(11), Some(12)]),
        Some(vec![Some(13), Some(14), Some(15)]),
        None,
        Some(vec![Some(16), Some(17), Some(18)]),
    ];
    let expected = vec![
        Some(vec![Some(16), Some(17), Some(18)]),
        None,
        Some(vec![Some(13), Some(14), Some(15)]),
        None,
        Some(vec![Some(10), Some(11), Some(12)]),
    ];
    do_take_fixed_size_list_test::<UInt64Type>(3, input, vec![3, 2, 1, 2, 0], expected);
}
2124
/// Takes from a non-null fixed-size binary array with null indices and checks the
/// length, null count and validity of the result.
#[test]
fn test_take_fixed_size_binary_with_nulls_indices() {
    // Four 4-byte rows: 0x01 repeated, 0x02 repeated, 0x03 repeated, 0x04 repeated.
    let rows = (0x01u8..=0x04).map(|byte| Some(vec![byte; 4]));
    let fsb = FixedSizeBinaryArray::try_from_sparse_iter_with_size(rows, 4).unwrap();

    let indices = UInt32Array::from(vec![Some(0), None, None, Some(3)]);

    let taken = take_fixed_size_binary(&fsb, &indices, 4).unwrap();
    assert_eq!(taken.len(), 4);
    assert_eq!(taken.null_count(), 2);

    let validity = taken.nulls().unwrap().iter().collect::<Vec<_>>();
    assert_eq!(validity, vec![true, false, false, true]);
}
2150
/// Taking index 1000 from a three-element list array without bounds checking is
/// expected to panic with an index-out-of-bounds message.
#[test]
#[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
fn test_take_list_out_of_bounds() {
    let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false)));

    // Lists: [[0, 0, 0], [-1, -2, -1], [2, 3]]
    let child = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
    let offsets = Buffer::from_slice_ref([0, 3, 6, 8]);
    let list_array = ListArray::from(
        ArrayData::builder(data_type)
            .len(3)
            .add_buffer(offsets)
            .add_child_data(child)
            .build()
            .unwrap(),
    );

    let out_of_range = UInt32Array::from(vec![1000]);

    take(&list_array, &out_of_range, None).unwrap();
}
2175
/// Takes the first entry of a two-entry map array.
#[test]
fn test_take_map() {
    let item_values = Int32Array::from(vec![1, 2, 3, 4]);
    let source = MapArray::new_from_strings(
        vec!["a", "b", "c", "a"].into_iter(),
        &item_values,
        &[0, 3, 4],
    )
    .unwrap();

    let first_entry = UInt32Array::from(vec![0]);
    let taken = take(&source, &first_entry, None).unwrap();

    // The expected map holds only the first three key/value pairs.
    let first_three = item_values.slice(0, 3);
    let expected_map =
        MapArray::new_from_strings(vec!["a", "b", "c"].into_iter(), &first_three, &[0, 3])
            .unwrap();
    let expected: ArrayRef = Arc::new(expected_map);

    assert_eq!(&expected, &taken);
}
2196
/// Takes from a struct array with one null row, then from a field-less struct
/// array that carries only a null buffer.
#[test]
fn test_take_struct() {
    let source = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(false), Some(28))),
        Some((Some(false), Some(19))),
        Some((Some(true), Some(31))),
        None,
    ]);
    let indices = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);

    let taken = take(&source, &indices, None).unwrap();
    let taken: &StructArray = taken.as_any().downcast_ref::<StructArray>().unwrap();
    assert_eq!(indices.len(), taken.len());
    assert_eq!(1, taken.null_count());

    let expected = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(true), Some(31))),
        Some((Some(false), Some(28))),
        Some((Some(true), Some(42))),
        Some((Some(false), Some(19))),
        None,
    ]);
    assert_eq!(&expected, taken);

    // Struct without any fields: only validity is carried through the take.
    let source_validity = NullBuffer::from(&[false, true, false, true, false, true]);
    let empty_struct_arr = StructArray::new_empty_fields(6, Some(source_validity));
    let indices = UInt32Array::from(vec![0, 2, 1, 4]);
    let taken = take(&empty_struct_arr, &indices, None).unwrap();

    let expected_validity = NullBuffer::from(&[false, false, true, false]);
    let expected_struct_arr = StructArray::new_empty_fields(4, Some(expected_validity));
    assert_eq!(&expected_struct_arr, taken.as_struct());
}
2233
/// Takes from a struct array using indices that contain nulls.
#[test]
fn test_take_struct_with_null_indices() {
    let source = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(false), Some(28))),
        Some((Some(false), Some(19))),
        Some((Some(true), Some(31))),
        None,
    ]);
    let indices = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);

    let taken = take(&source, &indices, None).unwrap();
    let taken: &StructArray = taken.as_any().downcast_ref::<StructArray>().unwrap();
    assert_eq!(indices.len(), taken.len());
    // Two null indices plus one index pointing at the null row.
    assert_eq!(3, taken.null_count());

    let expected = create_test_struct(vec![
        None,
        Some((Some(true), Some(31))),
        Some((Some(false), Some(28))),
        None,
        Some((Some(true), Some(42))),
        None,
    ]);
    assert_eq!(&expected, taken);
}
2261
/// With `check_bounds` enabled, an index of 6 into five values yields an error.
#[test]
fn test_take_out_of_bounds() {
    let indices = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]);
    let bounds_checked = TakeOptions { check_bounds: true };

    let values = vec![Some(0), None, Some(2), Some(3), None];
    let outcome = test_take_primitive_arrays::<Int64Type>(
        values,
        &indices,
        Some(bounds_checked),
        vec![None],
    );

    assert!(outcome.is_err());
}
2276
/// Without options, an index of 1000 into four values is expected to panic.
#[test]
#[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
fn test_take_out_of_bounds_panic() {
    let out_of_range = UInt32Array::from(vec![Some(1000)]);
    let values = vec![Some(0), Some(1), Some(2), Some(3)];

    test_take_primitive_arrays::<Int64Type>(values, &out_of_range, None, vec![None]).unwrap();
}
2290
/// Taking three indices from a two-element `NullArray` yields a three-element `NullArray`.
#[test]
fn test_null_array_smaller_than_indices() {
    let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
    let source = NullArray::new(2);

    let expected: ArrayRef = Arc::new(NullArray::new(3));
    let taken = take(&source, &indices, None).unwrap();

    assert_eq!(&taken, &expected);
}
2300
/// Taking three indices from a five-element `NullArray` yields a three-element `NullArray`.
#[test]
fn test_null_array_larger_than_indices() {
    let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
    let source = NullArray::new(5);

    let expected: ArrayRef = Arc::new(NullArray::new(3));
    let taken = take(&source, &indices, None).unwrap();

    assert_eq!(&taken, &expected);
}
2310
/// With `check_bounds` enabled, index 15 into a five-element `NullArray` is reported
/// as a compute error with the out-of-bounds message.
#[test]
fn test_null_array_indices_out_of_bounds() {
    let source = NullArray::new(5);
    let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
    let bounds_checked = TakeOptions { check_bounds: true };

    let error = take(&source, &indices, Some(bounds_checked)).unwrap_err();

    assert_eq!(
        error.to_string(),
        "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries"
    );
}
2322
/// Takes from a string dictionary array and checks that the dictionary values are
/// unchanged while the keys follow the indices.
#[test]
fn test_take_dict() {
    let mut dict_builder = StringDictionaryBuilder::<Int16Type>::new();
    let entries = [
        Some("foo"),
        Some("bar"),
        Some(""),
        None,
        Some("foo"),
        Some("bar"),
        Some("bar"),
        Some("foo"),
    ];
    for entry in entries {
        if let Some(text) = entry {
            dict_builder.append(text).unwrap();
        } else {
            dict_builder.append_null();
        }
    }
    let source = dict_builder.finish();

    let source_values = source.values().clone();
    let source_values = source_values.as_any().downcast_ref::<StringArray>().unwrap();

    let indices = UInt32Array::from(vec![
        Some(0),
        Some(7),
        None,
        Some(5),
        Some(6),
        Some(2),
        Some(3),
    ]);

    let taken = take(&source, &indices, None).unwrap();
    let taken = taken
        .as_any()
        .downcast_ref::<DictionaryArray<Int16Type>>()
        .unwrap();
    let taken_values: StringArray = taken.values().to_data().into();

    // Both the source and the result carry the same three dictionary values.
    let expected_values = StringArray::from(vec!["foo", "bar", ""]);
    assert_eq!(&expected_values, source_values);
    assert_eq!(&expected_values, &taken_values);

    let expected_keys = Int16Array::from(vec![
        Some(0),
        Some(0),
        None,
        Some(1),
        Some(1),
        Some(2),
        None,
    ]);
    assert_eq!(taken.keys(), &expected_keys);
}
2374
2375 fn build_generic_list<S, T>(data: Vec<Option<Vec<T::Native>>>) -> GenericListArray<S>
2376 where
2377 S: OffsetSizeTrait + 'static,
2378 T: ArrowPrimitiveType,
2379 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2380 {
2381 GenericListArray::from_iter_primitive::<T, _, _>(
2382 data.iter()
2383 .map(|x| x.as_ref().map(|x| x.iter().map(|x| Some(*x)))),
2384 )
2385 }
2386
/// Checks the child indices, offsets and validity bytes that
/// `take_value_indices_from_list` returns for a list array with 32-bit offsets.
#[test]
fn test_take_value_index_from_list() {
    let rows = vec![
        Some(vec![0, 1]),
        Some(vec![2, 3, 4]),
        Some(vec![5, 6, 7, 8, 9]),
    ];
    let list = build_generic_list::<i32, Int32Type>(rows);
    let indices = UInt32Array::from(vec![2, 0]);

    let (child_indices, new_offsets, validity) =
        take_value_indices_from_list(&list, &indices).unwrap();

    assert_eq!(child_indices, Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
    assert_eq!(new_offsets, vec![0, 5, 7]);
    assert_eq!(validity.as_slice(), &[0b11111111]);
}
2402
/// Checks the child indices, offsets and validity bytes that
/// `take_value_indices_from_list` returns for a list array with 64-bit offsets.
#[test]
fn test_take_value_index_from_large_list() {
    let rows = vec![
        Some(vec![0, 1]),
        Some(vec![2, 3, 4]),
        Some(vec![5, 6, 7, 8, 9]),
    ];
    let list = build_generic_list::<i64, Int32Type>(rows);
    let indices = UInt32Array::from(vec![2, 0]);

    let (child_indices, new_offsets, validity) =
        take_value_indices_from_list::<_, Int64Type>(&list, &indices).unwrap();

    assert_eq!(child_indices, Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
    assert_eq!(new_offsets, vec![0, 5, 7]);
    assert_eq!(validity.as_slice(), &[0b11111111]);
}
2419
/// Takes from a run-end encoded array and checks the run ends and values of the result.
#[test]
fn test_take_runs() {
    let logical_values: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];

    let mut run_builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
    run_builder.extend(logical_values.into_iter().map(Some));
    let run_array = run_builder.finish();

    let take_indices: Int32Array = vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect();

    let taken = take_run(&run_array, &take_indices).unwrap();

    assert_eq!(taken.len(), 7);
    let run_ends = taken.run_ends();
    assert_eq!(run_ends.len(), 7);
    assert_eq!(run_ends.values(), &[1_i32, 3, 4, 5, 7]);

    let taken_values = taken.values().as_primitive::<Int32Type>();
    assert_eq!(taken_values.values(), &[2, 2, 2, 2, 1]);
}
2440
2441 #[test]
2442 fn test_take_runs_sliced() {
2443 let logical_array: Vec<i32> = vec![1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6];
2444
2445 let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
2446 builder.extend(logical_array.into_iter().map(Some));
2447 let run_array = builder.finish();
2448
2449 let run_array = run_array.slice(4, 6); let take_indices: PrimitiveArray<Int32Type> = vec![0, 5, 5, 1, 4].into_iter().collect();
2452
2453 let result = take_run(&run_array, &take_indices).unwrap();
2454 let result = result.downcast::<Int32Array>().unwrap();
2455
2456 let expected = vec![3, 5, 5, 3, 4];
2457 let actual = result.into_iter().flatten().collect::<Vec<_>>();
2458
2459 assert_eq!(expected, actual);
2460 }
2461
2462 #[test]
2463 fn test_take_value_index_from_fixed_list() {
2464 let list = FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
2465 vec![
2466 Some(vec![Some(1), Some(2), None]),
2467 Some(vec![Some(4), None, Some(6)]),
2468 None,
2469 Some(vec![None, Some(8), Some(9)]),
2470 ],
2471 3,
2472 );
2473
2474 let indices = UInt32Array::from(vec![2, 1, 0]);
2475 let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2476
2477 assert_eq!(indexed, UInt32Array::from(vec![6, 7, 8, 3, 4, 5, 0, 1, 2]));
2478
2479 let indices = UInt32Array::from(vec![3, 2, 1, 2, 0]);
2480 let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2481
2482 assert_eq!(
2483 indexed,
2484 UInt32Array::from(vec![9, 10, 11, 6, 7, 8, 3, 4, 5, 6, 7, 8, 0, 1, 2])
2485 );
2486 }
2487
2488 #[test]
2489 fn test_take_null_indices() {
2490 let indices = Int32Array::new(
2492 vec![1, 2, 400, 400].into(),
2493 Some(NullBuffer::from(vec![true, true, false, false])),
2494 );
2495 let values = Int32Array::from(vec![1, 23, 4, 5]);
2496 let r = take(&values, &indices, None).unwrap();
2497 let values = r
2498 .as_primitive::<Int32Type>()
2499 .into_iter()
2500 .collect::<Vec<_>>();
2501 assert_eq!(&values, &[Some(23), Some(4), None, None])
2502 }
2503
2504 #[test]
2505 fn test_take_fixed_size_list_null_indices() {
2506 let indices = Int32Array::from_iter([Some(0), None]);
2507 let values = Arc::new(Int32Array::from(vec![0, 1, 2, 3]));
2508 let arr_field = Arc::new(Field::new_list_field(values.data_type().clone(), true));
2509 let values = FixedSizeListArray::try_new(arr_field, 2, values, None).unwrap();
2510
2511 let r = take(&values, &indices, None).unwrap();
2512 let values = r
2513 .as_fixed_size_list()
2514 .values()
2515 .as_primitive::<Int32Type>()
2516 .into_iter()
2517 .collect::<Vec<_>>();
2518 assert_eq!(values, &[Some(0), Some(1), None, None])
2519 }
2520
2521 #[test]
2522 fn test_take_bytes_null_indices() {
2523 let indices = Int32Array::new(
2524 vec![0, 1, 400, 400].into(),
2525 Some(NullBuffer::from_iter(vec![true, true, false, false])),
2526 );
2527 let values = StringArray::from(vec![Some("foo"), None]);
2528 let r = take(&values, &indices, None).unwrap();
2529 let values = r.as_string::<i32>().iter().collect::<Vec<_>>();
2530 assert_eq!(&values, &[Some("foo"), None, None, None])
2531 }
2532
2533 #[test]
2534 fn test_take_union_sparse() {
2535 let structs = create_test_struct(vec![
2536 Some((Some(true), Some(42))),
2537 Some((Some(false), Some(28))),
2538 Some((Some(false), Some(19))),
2539 Some((Some(true), Some(31))),
2540 None,
2541 ]);
2542 let strings = StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);
2543 let type_ids = [1; 5].into_iter().collect::<ScalarBuffer<i8>>();
2544
2545 let union_fields = [
2546 (
2547 0,
2548 Arc::new(Field::new("f1", structs.data_type().clone(), true)),
2549 ),
2550 (
2551 1,
2552 Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2553 ),
2554 ]
2555 .into_iter()
2556 .collect();
2557 let children = vec![Arc::new(structs) as Arc<dyn Array>, Arc::new(strings)];
2558 let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();
2559
2560 let indices = vec![0, 3, 1, 0, 2, 4];
2561 let index = UInt32Array::from(indices.clone());
2562 let actual = take(&array, &index, None).unwrap();
2563 let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2564 let strings = actual.child(1);
2565 let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();
2566
2567 let actual = strings.iter().collect::<Vec<_>>();
2568 let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
2569 assert_eq!(expected, actual);
2570 }
2571
2572 #[test]
2573 fn test_take_union_dense() {
2574 let type_ids = vec![0, 1, 1, 0, 0, 1, 0];
2575 let offsets = vec![0, 0, 1, 1, 2, 2, 3];
2576 let ints = vec![10, 20, 30, 40];
2577 let strings = vec![Some("a"), None, Some("c"), Some("d")];
2578
2579 let indices = vec![0, 3, 1, 0, 2, 4];
2580
2581 let taken_type_ids = vec![0, 0, 1, 0, 1, 0];
2582 let taken_offsets = vec![0, 1, 0, 2, 1, 3];
2583 let taken_ints = vec![10, 20, 10, 30];
2584 let taken_strings = vec![Some("a"), None];
2585
2586 let type_ids = <ScalarBuffer<i8>>::from(type_ids);
2587 let offsets = <ScalarBuffer<i32>>::from(offsets);
2588 let ints = UInt32Array::from(ints);
2589 let strings = StringArray::from(strings);
2590
2591 let union_fields = [
2592 (
2593 0,
2594 Arc::new(Field::new("f1", ints.data_type().clone(), true)),
2595 ),
2596 (
2597 1,
2598 Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2599 ),
2600 ]
2601 .into_iter()
2602 .collect();
2603
2604 let array = UnionArray::try_new(
2605 union_fields,
2606 type_ids,
2607 Some(offsets),
2608 vec![Arc::new(ints), Arc::new(strings)],
2609 )
2610 .unwrap();
2611
2612 let index = UInt32Array::from(indices);
2613
2614 let actual = take(&array, &index, None).unwrap();
2615 let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2616
2617 assert_eq!(actual.offsets(), Some(&ScalarBuffer::from(taken_offsets)));
2618 assert_eq!(actual.type_ids(), &ScalarBuffer::from(taken_type_ids));
2619 assert_eq!(
2620 UInt32Array::from(actual.child(0).to_data()),
2621 UInt32Array::from(taken_ints)
2622 );
2623 assert_eq!(
2624 StringArray::from(actual.child(1).to_data()),
2625 StringArray::from(taken_strings)
2626 );
2627 }
2628
2629 #[test]
2630 fn test_take_union_dense_using_builder() {
2631 let mut builder = UnionBuilder::new_dense();
2632
2633 builder.append::<Int32Type>("a", 1).unwrap();
2634 builder.append::<Float64Type>("b", 3.0).unwrap();
2635 builder.append::<Int32Type>("a", 4).unwrap();
2636 builder.append::<Int32Type>("a", 5).unwrap();
2637 builder.append::<Float64Type>("b", 2.0).unwrap();
2638
2639 let union = builder.build().unwrap();
2640
2641 let indices = UInt32Array::from(vec![2, 0, 1, 2]);
2642
2643 let mut builder = UnionBuilder::new_dense();
2644
2645 builder.append::<Int32Type>("a", 4).unwrap();
2646 builder.append::<Int32Type>("a", 1).unwrap();
2647 builder.append::<Float64Type>("b", 3.0).unwrap();
2648 builder.append::<Int32Type>("a", 4).unwrap();
2649
2650 let taken = builder.build().unwrap();
2651
2652 assert_eq!(
2653 taken.to_data(),
2654 take(&union, &indices, None).unwrap().to_data()
2655 );
2656 }
2657
2658 #[test]
2659 fn test_take_union_dense_all_match_issue_6206() {
2660 let fields = UnionFields::from_fields(vec![Field::new("a", DataType::Int64, false)]);
2661 let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
2662
2663 let array = UnionArray::try_new(
2664 fields,
2665 ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]),
2666 Some(ScalarBuffer::from_iter(0_i32..5)),
2667 vec![ints],
2668 )
2669 .unwrap();
2670
2671 let indicies = Int64Array::from(vec![0, 2, 4]);
2672 let array = take(&array, &indicies, None).unwrap();
2673 assert_eq!(array.len(), 3);
2674 }
2675
2676 #[test]
2677 fn test_take_bytes_offset_overflow() {
2678 let indices = Int32Array::from(vec![0; (i32::MAX >> 4) as usize]);
2679 let text = ('a'..='z').collect::<String>();
2680 let values = StringArray::from(vec![Some(text.clone())]);
2681 assert!(matches!(
2682 take(&values, &indices, None),
2683 Err(ArrowError::OffsetOverflowError(_))
2684 ));
2685 }
2686}