1use std::sync::Arc;
21
22use arrow_array::builder::{BufferBuilder, UInt32Builder};
23use arrow_array::cast::AsArray;
24use arrow_array::types::*;
25use arrow_array::*;
26use arrow_buffer::{
27 bit_util, ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer,
28 ScalarBuffer,
29};
30use arrow_data::ArrayDataBuilder;
31use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
32
33use num::{One, Zero};
34
35pub fn take(
87 values: &dyn Array,
88 indices: &dyn Array,
89 options: Option<TakeOptions>,
90) -> Result<ArrayRef, ArrowError> {
91 let options = options.unwrap_or_default();
92 downcast_integer_array!(
93 indices => {
94 if options.check_bounds {
95 check_bounds(values.len(), indices)?;
96 }
97 let indices = indices.to_indices();
98 take_impl(values, &indices)
99 },
100 d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
101 )
102}
103
104pub fn take_arrays(
153 arrays: &[ArrayRef],
154 indices: &dyn Array,
155 options: Option<TakeOptions>,
156) -> Result<Vec<ArrayRef>, ArrowError> {
157 arrays
158 .iter()
159 .map(|array| take(array.as_ref(), indices, options.clone()))
160 .collect()
161}
162
163fn check_bounds<T: ArrowPrimitiveType>(
165 len: usize,
166 indices: &PrimitiveArray<T>,
167) -> Result<(), ArrowError> {
168 if indices.null_count() > 0 {
169 indices.iter().flatten().try_for_each(|index| {
170 let ix = index
171 .to_usize()
172 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
173 if ix >= len {
174 return Err(ArrowError::ComputeError(format!(
175 "Array index out of bounds, cannot get item at index {ix} from {len} entries"
176 )));
177 }
178 Ok(())
179 })
180 } else {
181 indices.values().iter().try_for_each(|index| {
182 let ix = index
183 .to_usize()
184 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
185 if ix >= len {
186 return Err(ArrowError::ComputeError(format!(
187 "Array index out of bounds, cannot get item at index {ix} from {len} entries"
188 )));
189 }
190 Ok(())
191 })
192 }
193}
194
/// Dispatches a take to the kernel matching `values.data_type()`.
///
/// Nested types (struct, list, map, union) recurse into this function for
/// their children. Unsupported data types panic via `unimplemented!`.
#[inline(never)]
fn take_impl<IndexType: ArrowPrimitiveType>(
    values: &dyn Array,
    indices: &PrimitiveArray<IndexType>,
) -> Result<ArrayRef, ArrowError> {
    downcast_primitive_array! {
        values => Ok(Arc::new(take_primitive(values, indices)?)),
        DataType::Boolean => {
            let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
            Ok(Arc::new(take_boolean(values, indices)))
        }
        DataType::Utf8 => {
            Ok(Arc::new(take_bytes(values.as_string::<i32>(), indices)?))
        }
        DataType::LargeUtf8 => {
            Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)?))
        }
        DataType::Utf8View => {
            Ok(Arc::new(take_byte_view(values.as_string_view(), indices)?))
        }
        DataType::List(_) => {
            Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?))
        }
        DataType::LargeList(_) => {
            Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?))
        }
        DataType::FixedSizeList(_, length) => {
            let values = values
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap();
            Ok(Arc::new(take_fixed_size_list(
                values,
                indices,
                *length as u32,
            )?))
        }
        DataType::Map(_, _) => {
            // Take through the equivalent ListArray, then restore the original
            // Map data type on the result.
            let list_arr = ListArray::from(values.as_map().clone());
            let list_data = take_list::<_, Int32Type>(&list_arr, indices)?;
            let builder = list_data.into_data().into_builder().data_type(values.data_type().clone());
            // SAFETY: the data was produced from the map's own list view; only
            // the data type is switched back to the input's Map type.
            Ok(Arc::new(MapArray::from(unsafe { builder.build_unchecked() })))
        }
        DataType::Struct(fields) => {
            let array: &StructArray = values.as_struct();
            // Every child column is taken with the same indices.
            let arrays = array
                .columns()
                .iter()
                .map(|a| take_impl(a.as_ref(), indices))
                .collect::<Result<Vec<ArrayRef>, _>>()?;
            let fields: Vec<(FieldRef, ArrayRef)> =
                fields.iter().cloned().zip(arrays).collect();

            // An output row is valid only if its index is non-null and the
            // referenced struct row is valid.
            let is_valid: Buffer = indices
                .iter()
                .map(|index| {
                    if let Some(index) = index {
                        array.is_valid(index.to_usize().unwrap())
                    } else {
                        false
                    }
                })
                .collect();

            // A struct with no fields has no child to carry the length, so it
            // is built explicitly from the row count and validity.
            if fields.is_empty() {
                let nulls = NullBuffer::new(BooleanBuffer::new(is_valid, 0, indices.len()));
                Ok(Arc::new(StructArray::new_empty_fields(indices.len(), Some(nulls))))
            } else {
                Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
            }
        }
        DataType::Dictionary(_, _) => downcast_dictionary_array! {
            values => Ok(Arc::new(take_dict(values, indices)?)),
            t => unimplemented!("Take not supported for dictionary type {:?}", t)
        }
        DataType::RunEndEncoded(_, _) => downcast_run_array! {
            values => Ok(Arc::new(take_run(values, indices)?)),
            t => unimplemented!("Take not supported for run type {:?}", t)
        }
        DataType::Binary => {
            Ok(Arc::new(take_bytes(values.as_binary::<i32>(), indices)?))
        }
        DataType::LargeBinary => {
            Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?))
        }
        DataType::BinaryView => {
            Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)?))
        }
        DataType::FixedSizeBinary(size) => {
            let values = values
                .as_any()
                .downcast_ref::<FixedSizeBinaryArray>()
                .unwrap();
            Ok(Arc::new(take_fixed_size_binary(values, indices, *size)?))
        }
        DataType::Null => {
            // Every slot of a Null array is null, so the result depends only on
            // its length: reuse a slice when the input is long enough.
            if values.len() >= indices.len() {
                Ok(values.slice(0, indices.len()))
            } else {
                Ok(new_null_array(&DataType::Null, indices.len()))
            }
        }
        DataType::Union(fields, UnionMode::Sparse) => {
            // Sparse children are taken with the same indices as the type ids.
            let mut children = Vec::with_capacity(fields.len());
            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
            let type_ids = take_native(values.type_ids(), indices);
            for (type_id, _field) in fields.iter() {
                let values = values.child(type_id);
                let values = take_impl(values, indices)?;
                children.push(values);
            }
            let array = UnionArray::try_new(fields.clone(), type_ids, None, children)?;
            Ok(Arc::new(array))
        }
        DataType::Union(fields, UnionMode::Dense) => {
            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();

            // Gather the type id and child offset of every selected row.
            let type_ids = <PrimitiveArray<Int8Type>>::new(take_native(values.type_ids(), indices), None);
            let offsets = <PrimitiveArray<Int32Type>>::new(take_native(values.offsets().unwrap(), indices), None);

            // For each child, keep the offsets of rows tagged with its type id
            // and take those rows from the child.
            let children = fields.iter()
                .map(|(field_type_id, _)| {
                    let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);

                    let indices = crate::filter::filter(&offsets, &mask)?;

                    let values = values.child(field_type_id);

                    take_impl(values, indices.as_primitive::<Int32Type>())
                })
                .collect::<Result<_, _>>()?;

            // New offsets are a running per-type-id counter. Indexed by type id;
            // assumes ids are non-negative (valid union type ids are 0..=127).
            let mut child_offsets = [0; 128];

            let offsets = type_ids.values()
                .iter()
                .map(|&i| {
                    let offset = child_offsets[i as usize];

                    child_offsets[i as usize] += 1;

                    offset
                })
                .collect();

            let (_, type_ids, _) = type_ids.into_parts();

            let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)?;

            Ok(Arc::new(array))
        }
        t => unimplemented!("Take not supported for data type {:?}", t)
    }
}
354
/// Options controlling the behavior of [`take`].
#[derive(Clone, Debug, Default)]
pub struct TakeOptions {
    /// When `true`, every non-null index is validated before taking and an
    /// [`ArrowError`] is returned for out-of-bounds indices. When `false`
    /// (the default), out-of-bounds indices may cause a panic.
    pub check_bounds: bool,
}
363
364#[inline(always)]
365fn maybe_usize<I: ArrowNativeType>(index: I) -> Result<usize, ArrowError> {
366 index
367 .to_usize()
368 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))
369}
370
371fn take_primitive<T, I>(
381 values: &PrimitiveArray<T>,
382 indices: &PrimitiveArray<I>,
383) -> Result<PrimitiveArray<T>, ArrowError>
384where
385 T: ArrowPrimitiveType,
386 I: ArrowPrimitiveType,
387{
388 let values_buf = take_native(values.values(), indices);
389 let nulls = take_nulls(values.nulls(), indices);
390 Ok(PrimitiveArray::new(values_buf, nulls).with_data_type(values.data_type().clone()))
391}
392
393#[inline(never)]
394fn take_nulls<I: ArrowPrimitiveType>(
395 values: Option<&NullBuffer>,
396 indices: &PrimitiveArray<I>,
397) -> Option<NullBuffer> {
398 match values.filter(|n| n.null_count() > 0) {
399 Some(n) => {
400 let buffer = take_bits(n.inner(), indices);
401 Some(NullBuffer::new(buffer)).filter(|n| n.null_count() > 0)
402 }
403 None => indices.nulls().cloned(),
404 }
405}
406
407#[inline(never)]
408fn take_native<T: ArrowNativeType, I: ArrowPrimitiveType>(
409 values: &[T],
410 indices: &PrimitiveArray<I>,
411) -> ScalarBuffer<T> {
412 match indices.nulls().filter(|n| n.null_count() > 0) {
413 Some(n) => indices
414 .values()
415 .iter()
416 .enumerate()
417 .map(|(idx, index)| match values.get(index.as_usize()) {
418 Some(v) => *v,
419 None => match n.is_null(idx) {
420 true => T::default(),
421 false => panic!("Out-of-bounds index {index:?}"),
422 },
423 })
424 .collect(),
425 None => indices
426 .values()
427 .iter()
428 .map(|index| values[index.as_usize()])
429 .collect(),
430 }
431}
432
433#[inline(never)]
434fn take_bits<I: ArrowPrimitiveType>(
435 values: &BooleanBuffer,
436 indices: &PrimitiveArray<I>,
437) -> BooleanBuffer {
438 let len = indices.len();
439
440 match indices.nulls().filter(|n| n.null_count() > 0) {
441 Some(nulls) => {
442 let mut output_buffer = MutableBuffer::new_null(len);
443 let output_slice = output_buffer.as_slice_mut();
444 nulls.valid_indices().for_each(|idx| {
445 if values.value(indices.value(idx).as_usize()) {
446 bit_util::set_bit(output_slice, idx);
447 }
448 });
449 BooleanBuffer::new(output_buffer.into(), 0, len)
450 }
451 None => {
452 BooleanBuffer::collect_bool(len, |idx: usize| {
453 values.value(unsafe { indices.value_unchecked(idx).as_usize() })
455 })
456 }
457 }
458}
459
460fn take_boolean<IndexType: ArrowPrimitiveType>(
462 values: &BooleanArray,
463 indices: &PrimitiveArray<IndexType>,
464) -> BooleanArray {
465 let val_buf = take_bits(values.values(), indices);
466 let null_buf = take_nulls(values.nulls(), indices);
467 BooleanArray::new(val_buf, null_buf)
468}
469
470fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
472 array: &GenericByteArray<T>,
473 indices: &PrimitiveArray<IndexType>,
474) -> Result<GenericByteArray<T>, ArrowError> {
475 let mut offsets = Vec::with_capacity(indices.len() + 1);
476 offsets.push(T::Offset::default());
477
478 let input_offsets = array.value_offsets();
479 let mut capacity = 0;
480 let nulls = take_nulls(array.nulls(), indices);
481
482 let (offsets, values) = if array.null_count() == 0 && indices.null_count() == 0 {
483 offsets.extend(indices.values().iter().map(|index| {
484 let index = index.as_usize();
485 capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
486 T::Offset::from_usize(capacity).expect("overflow")
487 }));
488 let mut values = Vec::with_capacity(capacity);
489
490 for index in indices.values() {
491 values.extend_from_slice(array.value(index.as_usize()).as_ref());
492 }
493 (offsets, values)
494 } else if indices.null_count() == 0 {
495 offsets.extend(indices.values().iter().map(|index| {
496 let index = index.as_usize();
497 if array.is_valid(index) {
498 capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
499 }
500 T::Offset::from_usize(capacity).expect("overflow")
501 }));
502 let mut values = Vec::with_capacity(capacity);
503
504 for index in indices.values() {
505 let index = index.as_usize();
506 if array.is_valid(index) {
507 values.extend_from_slice(array.value(index).as_ref());
508 }
509 }
510 (offsets, values)
511 } else if array.null_count() == 0 {
512 offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
513 let index = index.as_usize();
514 if indices.is_valid(i) {
515 capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
516 }
517 T::Offset::from_usize(capacity).expect("overflow")
518 }));
519 let mut values = Vec::with_capacity(capacity);
520
521 for (i, index) in indices.values().iter().enumerate() {
522 if indices.is_valid(i) {
523 values.extend_from_slice(array.value(index.as_usize()).as_ref());
524 }
525 }
526 (offsets, values)
527 } else {
528 let nulls = nulls.as_ref().unwrap();
529 offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
530 let index = index.as_usize();
531 if nulls.is_valid(i) {
532 capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
533 }
534 T::Offset::from_usize(capacity).expect("overflow")
535 }));
536 let mut values = Vec::with_capacity(capacity);
537
538 for (i, index) in indices.values().iter().enumerate() {
539 let index = index.as_usize();
542 if nulls.is_valid(i) {
543 values.extend_from_slice(array.value(index).as_ref());
544 }
545 }
546 (offsets, values)
547 };
548
549 T::Offset::from_usize(values.len()).ok_or(ArrowError::ComputeError(format!(
550 "Offset overflow for {}BinaryArray: {}",
551 T::Offset::PREFIX,
552 values.len()
553 )))?;
554
555 let array = unsafe {
556 let offsets = OffsetBuffer::new_unchecked(offsets.into());
557 GenericByteArray::<T>::new_unchecked(offsets, values.into(), nulls)
558 };
559
560 Ok(array)
561}
562
563fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>(
565 array: &GenericByteViewArray<T>,
566 indices: &PrimitiveArray<IndexType>,
567) -> Result<GenericByteViewArray<T>, ArrowError> {
568 let new_views = take_native(array.views(), indices);
569 let new_nulls = take_nulls(array.nulls(), indices);
570 Ok(unsafe {
572 GenericByteViewArray::new_unchecked(new_views, array.data_buffers().to_vec(), new_nulls)
573 })
574}
575
576fn take_list<IndexType, OffsetType>(
582 values: &GenericListArray<OffsetType::Native>,
583 indices: &PrimitiveArray<IndexType>,
584) -> Result<GenericListArray<OffsetType::Native>, ArrowError>
585where
586 IndexType: ArrowPrimitiveType,
587 OffsetType: ArrowPrimitiveType,
588 OffsetType::Native: OffsetSizeTrait,
589 PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
590{
591 let (list_indices, offsets, null_buf) =
594 take_value_indices_from_list::<IndexType, OffsetType>(values, indices)?;
595
596 let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices)?;
597 let value_offsets = Buffer::from_vec(offsets);
598 let list_data = ArrayDataBuilder::new(values.data_type().clone())
600 .len(indices.len())
601 .null_bit_buffer(Some(null_buf.into()))
602 .offset(0)
603 .add_child_data(taken.into_data())
604 .add_buffer(value_offsets);
605
606 let list_data = unsafe { list_data.build_unchecked() };
607
608 Ok(GenericListArray::<OffsetType::Native>::from(list_data))
609}
610
611fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
617 values: &FixedSizeListArray,
618 indices: &PrimitiveArray<IndexType>,
619 length: <UInt32Type as ArrowPrimitiveType>::Native,
620) -> Result<FixedSizeListArray, ArrowError> {
621 let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?;
622 let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)?;
623
624 let num_bytes = bit_util::ceil(indices.len(), 8);
626 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
627 let null_slice = null_buf.as_slice_mut();
628
629 for i in 0..indices.len() {
630 let index = indices
631 .value(i)
632 .to_usize()
633 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
634 if !indices.is_valid(i) || values.is_null(index) {
635 bit_util::unset_bit(null_slice, i);
636 }
637 }
638
639 let list_data = ArrayDataBuilder::new(values.data_type().clone())
640 .len(indices.len())
641 .null_bit_buffer(Some(null_buf.into()))
642 .offset(0)
643 .add_child_data(taken.into_data());
644
645 let list_data = unsafe { list_data.build_unchecked() };
646
647 Ok(FixedSizeListArray::from(list_data))
648}
649
650fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>(
651 values: &FixedSizeBinaryArray,
652 indices: &PrimitiveArray<IndexType>,
653 size: i32,
654) -> Result<FixedSizeBinaryArray, ArrowError> {
655 let nulls = values.nulls();
656 let array_iter = indices
657 .values()
658 .iter()
659 .map(|idx| {
660 let idx = maybe_usize::<IndexType::Native>(*idx)?;
661 if nulls.map(|n| n.is_valid(idx)).unwrap_or(true) {
662 Ok(Some(values.value(idx)))
663 } else {
664 Ok(None)
665 }
666 })
667 .collect::<Result<Vec<_>, ArrowError>>()?
668 .into_iter();
669
670 FixedSizeBinaryArray::try_from_sparse_iter_with_size(array_iter, size)
671}
672
673fn take_dict<T: ArrowDictionaryKeyType, I: ArrowPrimitiveType>(
678 values: &DictionaryArray<T>,
679 indices: &PrimitiveArray<I>,
680) -> Result<DictionaryArray<T>, ArrowError> {
681 let new_keys = take_primitive(values.keys(), indices)?;
682 Ok(unsafe { DictionaryArray::new_unchecked(new_keys, values.values().clone()) })
683}
684
685fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
694 run_array: &RunArray<T>,
695 logical_indices: &PrimitiveArray<I>,
696) -> Result<RunArray<T>, ArrowError> {
697 let physical_indices = run_array.get_physical_indices(logical_indices.values())?;
699
700 let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
704 let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
705 let mut new_physical_len = 1;
706 for ix in 1..physical_indices.len() {
707 if physical_indices[ix] != physical_indices[ix - 1] {
708 take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
709 new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
710 new_physical_len += 1;
711 }
712 }
713 take_value_indices
714 .append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap());
715 new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
716 let new_run_ends = unsafe {
717 ArrayDataBuilder::new(T::DATA_TYPE)
720 .len(new_physical_len)
721 .null_count(0)
722 .add_buffer(new_run_ends_builder.finish())
723 .build_unchecked()
724 };
725
726 let take_value_indices: PrimitiveArray<I> = unsafe {
727 ArrayDataBuilder::new(I::DATA_TYPE)
730 .len(new_physical_len)
731 .null_count(0)
732 .add_buffer(take_value_indices.finish())
733 .build_unchecked()
734 .into()
735 };
736
737 let new_values = take(run_array.values(), &take_value_indices, None)?;
738
739 let builder = ArrayDataBuilder::new(run_array.data_type().clone())
740 .len(physical_indices.len())
741 .add_child_data(new_run_ends)
742 .add_child_data(new_values.into_data());
743 let array_data = unsafe {
744 builder.build_unchecked()
747 };
748 Ok(array_data.into())
749}
750
751#[allow(clippy::type_complexity)]
757fn take_value_indices_from_list<IndexType, OffsetType>(
758 list: &GenericListArray<OffsetType::Native>,
759 indices: &PrimitiveArray<IndexType>,
760) -> Result<
761 (
762 PrimitiveArray<OffsetType>,
763 Vec<OffsetType::Native>,
764 MutableBuffer,
765 ),
766 ArrowError,
767>
768where
769 IndexType: ArrowPrimitiveType,
770 OffsetType: ArrowPrimitiveType,
771 OffsetType::Native: OffsetSizeTrait + std::ops::Add + Zero + One,
772 PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
773{
774 let offsets: &[OffsetType::Native] = list.value_offsets();
776
777 let mut new_offsets = Vec::with_capacity(indices.len());
778 let mut values = Vec::new();
779 let mut current_offset = OffsetType::Native::zero();
780 new_offsets.push(OffsetType::Native::zero());
782
783 let num_bytes = bit_util::ceil(indices.len(), 8);
785 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
786 let null_slice = null_buf.as_slice_mut();
787
788 for i in 0..indices.len() {
790 if indices.is_valid(i) {
791 let ix = indices
792 .value(i)
793 .to_usize()
794 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
795 let start = offsets[ix];
796 let end = offsets[ix + 1];
797 current_offset += end - start;
798 new_offsets.push(current_offset);
799
800 let mut curr = start;
801
802 while curr < end {
804 values.push(curr);
805 curr += One::one();
806 }
807 if !list.is_valid(ix) {
808 bit_util::unset_bit(null_slice, i);
809 }
810 } else {
811 bit_util::unset_bit(null_slice, i);
812 new_offsets.push(current_offset);
813 }
814 }
815
816 Ok((
817 PrimitiveArray::<OffsetType>::from(values),
818 new_offsets,
819 null_buf,
820 ))
821}
822
823fn take_value_indices_from_fixed_size_list<IndexType>(
825 list: &FixedSizeListArray,
826 indices: &PrimitiveArray<IndexType>,
827 length: <UInt32Type as ArrowPrimitiveType>::Native,
828) -> Result<PrimitiveArray<UInt32Type>, ArrowError>
829where
830 IndexType: ArrowPrimitiveType,
831{
832 let mut values = UInt32Builder::with_capacity(length as usize * indices.len());
833
834 for i in 0..indices.len() {
835 if indices.is_valid(i) {
836 let index = indices
837 .value(i)
838 .to_usize()
839 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
840 let start = list.value_offset(index) as <UInt32Type as ArrowPrimitiveType>::Native;
841
842 unsafe {
844 values.append_trusted_len_iter(start..start + length);
845 }
846 } else {
847 values.append_nulls(length as usize);
848 }
849 }
850
851 Ok(values.finish())
852}
853
854trait ToIndices {
857 type T: ArrowPrimitiveType;
858
859 fn to_indices(&self) -> PrimitiveArray<Self::T>;
860}
861
/// Implements [`ToIndices`] for a signed type `$t` by reinterpreting its bytes as
/// the unsigned type `$o`, without copying. Negative values therefore become
/// large unsigned values.
macro_rules! to_indices_reinterpret {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                let bytes = self.values().inner().clone();
                let reinterpreted = ScalarBuffer::new(bytes, 0, self.len());
                PrimitiveArray::new(reinterpreted, self.nulls().cloned())
            }
        }
    };
}
874
/// Implements [`ToIndices`] for a type that is already a valid index type; the
/// array is returned as a (cheap, reference-counted) clone.
macro_rules! to_indices_identity {
    ($t:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $t;

            fn to_indices(&self) -> PrimitiveArray<$t> {
                PrimitiveArray::clone(self)
            }
        }
    };
}
886
/// Implements [`ToIndices`] for a narrow type `$t` by converting every value to
/// the wider type `$o` with an `as` cast. Negative signed values therefore become
/// large unsigned values.
///
/// The associated type is now `$o`. It was previously hard-coded to `UInt32Type`,
/// which contradicted the method's `PrimitiveArray<$o>` return type for any other `$o`.
macro_rules! to_indices_widening {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                let cast = self.values().iter().copied().map(|x| x as _).collect();
                PrimitiveArray::new(cast, self.nulls().cloned())
            }
        }
    };
}
899
900to_indices_widening!(UInt8Type, UInt32Type);
901to_indices_widening!(Int8Type, UInt32Type);
902
903to_indices_widening!(UInt16Type, UInt32Type);
904to_indices_widening!(Int16Type, UInt32Type);
905
906to_indices_identity!(UInt32Type);
907to_indices_reinterpret!(Int32Type, UInt32Type);
908
909to_indices_identity!(UInt64Type);
910to_indices_reinterpret!(Int64Type, UInt64Type);
911
912pub fn take_record_batch(
952 record_batch: &RecordBatch,
953 indices: &dyn Array,
954) -> Result<RecordBatch, ArrowError> {
955 let columns = record_batch
956 .columns()
957 .iter()
958 .map(|c| take(c, indices, None))
959 .collect::<Result<Vec<_>, _>>()?;
960 RecordBatch::try_new(record_batch.schema(), columns)
961}
962
963#[cfg(test)]
964mod tests {
965 use super::*;
966 use arrow_array::builder::*;
967 use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
968 use arrow_data::ArrayData;
969 use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
970
971 fn test_take_decimal_arrays(
972 data: Vec<Option<i128>>,
973 index: &UInt32Array,
974 options: Option<TakeOptions>,
975 expected_data: Vec<Option<i128>>,
976 precision: &u8,
977 scale: &i8,
978 ) -> Result<(), ArrowError> {
979 let output = data
980 .into_iter()
981 .collect::<Decimal128Array>()
982 .with_precision_and_scale(*precision, *scale)
983 .unwrap();
984
985 let expected = expected_data
986 .into_iter()
987 .collect::<Decimal128Array>()
988 .with_precision_and_scale(*precision, *scale)
989 .unwrap();
990
991 let expected = Arc::new(expected) as ArrayRef;
992 let output = take(&output, index, options).unwrap();
993 assert_eq!(&output, &expected);
994 Ok(())
995 }
996
997 fn test_take_boolean_arrays(
998 data: Vec<Option<bool>>,
999 index: &UInt32Array,
1000 options: Option<TakeOptions>,
1001 expected_data: Vec<Option<bool>>,
1002 ) {
1003 let output = BooleanArray::from(data);
1004 let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef;
1005 let output = take(&output, index, options).unwrap();
1006 assert_eq!(&output, &expected)
1007 }
1008
1009 fn test_take_primitive_arrays<T>(
1010 data: Vec<Option<T::Native>>,
1011 index: &UInt32Array,
1012 options: Option<TakeOptions>,
1013 expected_data: Vec<Option<T::Native>>,
1014 ) -> Result<(), ArrowError>
1015 where
1016 T: ArrowPrimitiveType,
1017 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1018 {
1019 let output = PrimitiveArray::<T>::from(data);
1020 let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1021 let output = take(&output, index, options)?;
1022 assert_eq!(&output, &expected);
1023 Ok(())
1024 }
1025
1026 fn test_take_primitive_arrays_non_null<T>(
1027 data: Vec<T::Native>,
1028 index: &UInt32Array,
1029 options: Option<TakeOptions>,
1030 expected_data: Vec<Option<T::Native>>,
1031 ) -> Result<(), ArrowError>
1032 where
1033 T: ArrowPrimitiveType,
1034 PrimitiveArray<T>: From<Vec<T::Native>>,
1035 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1036 {
1037 let output = PrimitiveArray::<T>::from(data);
1038 let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1039 let output = take(&output, index, options)?;
1040 assert_eq!(&output, &expected);
1041 Ok(())
1042 }
1043
1044 fn test_take_impl_primitive_arrays<T, I>(
1045 data: Vec<Option<T::Native>>,
1046 index: &PrimitiveArray<I>,
1047 options: Option<TakeOptions>,
1048 expected_data: Vec<Option<T::Native>>,
1049 ) where
1050 T: ArrowPrimitiveType,
1051 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1052 I: ArrowPrimitiveType,
1053 {
1054 let output = PrimitiveArray::<T>::from(data);
1055 let expected = PrimitiveArray::<T>::from(expected_data);
1056 let output = take(&output, index, options).unwrap();
1057 let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
1058 assert_eq!(output, &expected)
1059 }
1060
1061 fn create_test_struct(values: Vec<Option<(Option<bool>, Option<i32>)>>) -> StructArray {
1063 let mut struct_builder = StructBuilder::new(
1064 Fields::from(vec![
1065 Field::new("a", DataType::Boolean, true),
1066 Field::new("b", DataType::Int32, true),
1067 ]),
1068 vec![
1069 Box::new(BooleanBuilder::with_capacity(values.len())),
1070 Box::new(Int32Builder::with_capacity(values.len())),
1071 ],
1072 );
1073
1074 for value in values {
1075 struct_builder
1076 .field_builder::<BooleanBuilder>(0)
1077 .unwrap()
1078 .append_option(value.and_then(|v| v.0));
1079 struct_builder
1080 .field_builder::<Int32Builder>(1)
1081 .unwrap()
1082 .append_option(value.and_then(|v| v.1));
1083 struct_builder.append(value.is_some());
1084 }
1085 struct_builder.finish()
1086 }
1087
1088 #[test]
1089 fn test_take_decimal128_non_null_indices() {
1090 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1091 let precision: u8 = 10;
1092 let scale: i8 = 5;
1093 test_take_decimal_arrays(
1094 vec![None, Some(3), Some(5), Some(2), Some(3), None],
1095 &index,
1096 None,
1097 vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1098 &precision,
1099 &scale,
1100 )
1101 .unwrap();
1102 }
1103
1104 #[test]
1105 fn test_take_decimal128() {
1106 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1107 let precision: u8 = 10;
1108 let scale: i8 = 5;
1109 test_take_decimal_arrays(
1110 vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1111 &index,
1112 None,
1113 vec![Some(3), None, Some(1), Some(3), Some(2)],
1114 &precision,
1115 &scale,
1116 )
1117 .unwrap();
1118 }
1119
1120 #[test]
1121 fn test_take_primitive_non_null_indices() {
1122 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1123 test_take_primitive_arrays::<Int8Type>(
1124 vec![None, Some(3), Some(5), Some(2), Some(3), None],
1125 &index,
1126 None,
1127 vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1128 )
1129 .unwrap();
1130 }
1131
1132 #[test]
1133 fn test_take_primitive_non_null_values() {
1134 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1135 test_take_primitive_arrays::<Int8Type>(
1136 vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1137 &index,
1138 None,
1139 vec![Some(3), None, Some(1), Some(3), Some(2)],
1140 )
1141 .unwrap();
1142 }
1143
1144 #[test]
1145 fn test_take_primitive_non_null() {
1146 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1147 test_take_primitive_arrays::<Int8Type>(
1148 vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)],
1149 &index,
1150 None,
1151 vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)],
1152 )
1153 .unwrap();
1154 }
1155
1156 #[test]
1157 fn test_take_primitive_nullable_indices_non_null_values_with_offset() {
1158 let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1159 let index = index.slice(2, 4);
1160 let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1161
1162 assert_eq!(
1163 index,
1164 &UInt32Array::from(vec![Some(2), Some(3), None, None])
1165 );
1166
1167 test_take_primitive_arrays_non_null::<Int64Type>(
1168 vec![0, 10, 20, 30, 40, 50],
1169 index,
1170 None,
1171 vec![Some(20), Some(30), None, None],
1172 )
1173 .unwrap();
1174 }
1175
1176 #[test]
1177 fn test_take_primitive_nullable_indices_nullable_values_with_offset() {
1178 let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1179 let index = index.slice(2, 4);
1180 let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1181
1182 assert_eq!(
1183 index,
1184 &UInt32Array::from(vec![Some(2), Some(3), None, None])
1185 );
1186
1187 test_take_primitive_arrays::<Int64Type>(
1188 vec![None, None, Some(20), Some(30), Some(40), Some(50)],
1189 index,
1190 None,
1191 vec![Some(20), Some(30), None, None],
1192 )
1193 .unwrap();
1194 }
1195
1196 #[test]
1197 fn test_take_primitive() {
1198 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1199
1200 test_take_primitive_arrays::<Int8Type>(
1202 vec![Some(0), None, Some(2), Some(3), None],
1203 &index,
1204 None,
1205 vec![Some(3), None, None, Some(3), Some(2)],
1206 )
1207 .unwrap();
1208
1209 test_take_primitive_arrays::<Int16Type>(
1211 vec![Some(0), None, Some(2), Some(3), None],
1212 &index,
1213 None,
1214 vec![Some(3), None, None, Some(3), Some(2)],
1215 )
1216 .unwrap();
1217
1218 test_take_primitive_arrays::<Int32Type>(
1220 vec![Some(0), None, Some(2), Some(3), None],
1221 &index,
1222 None,
1223 vec![Some(3), None, None, Some(3), Some(2)],
1224 )
1225 .unwrap();
1226
1227 test_take_primitive_arrays::<Int64Type>(
1229 vec![Some(0), None, Some(2), Some(3), None],
1230 &index,
1231 None,
1232 vec![Some(3), None, None, Some(3), Some(2)],
1233 )
1234 .unwrap();
1235
1236 test_take_primitive_arrays::<UInt8Type>(
1238 vec![Some(0), None, Some(2), Some(3), None],
1239 &index,
1240 None,
1241 vec![Some(3), None, None, Some(3), Some(2)],
1242 )
1243 .unwrap();
1244
1245 test_take_primitive_arrays::<UInt16Type>(
1247 vec![Some(0), None, Some(2), Some(3), None],
1248 &index,
1249 None,
1250 vec![Some(3), None, None, Some(3), Some(2)],
1251 )
1252 .unwrap();
1253
1254 test_take_primitive_arrays::<UInt32Type>(
1256 vec![Some(0), None, Some(2), Some(3), None],
1257 &index,
1258 None,
1259 vec![Some(3), None, None, Some(3), Some(2)],
1260 )
1261 .unwrap();
1262
1263 test_take_primitive_arrays::<Int64Type>(
1265 vec![Some(0), None, Some(2), Some(-15), None],
1266 &index,
1267 None,
1268 vec![Some(-15), None, None, Some(-15), Some(2)],
1269 )
1270 .unwrap();
1271
1272 test_take_primitive_arrays::<IntervalYearMonthType>(
1274 vec![Some(0), None, Some(2), Some(-15), None],
1275 &index,
1276 None,
1277 vec![Some(-15), None, None, Some(-15), Some(2)],
1278 )
1279 .unwrap();
1280
1281 let v1 = IntervalDayTime::new(0, 0);
1283 let v2 = IntervalDayTime::new(2, 0);
1284 let v3 = IntervalDayTime::new(-15, 0);
1285 test_take_primitive_arrays::<IntervalDayTimeType>(
1286 vec![Some(v1), None, Some(v2), Some(v3), None],
1287 &index,
1288 None,
1289 vec![Some(v3), None, None, Some(v3), Some(v2)],
1290 )
1291 .unwrap();
1292
1293 let v1 = IntervalMonthDayNano::new(0, 0, 0);
1295 let v2 = IntervalMonthDayNano::new(2, 0, 0);
1296 let v3 = IntervalMonthDayNano::new(-15, 0, 0);
1297 test_take_primitive_arrays::<IntervalMonthDayNanoType>(
1298 vec![Some(v1), None, Some(v2), Some(v3), None],
1299 &index,
1300 None,
1301 vec![Some(v3), None, None, Some(v3), Some(v2)],
1302 )
1303 .unwrap();
1304
1305 test_take_primitive_arrays::<DurationSecondType>(
1307 vec![Some(0), None, Some(2), Some(-15), None],
1308 &index,
1309 None,
1310 vec![Some(-15), None, None, Some(-15), Some(2)],
1311 )
1312 .unwrap();
1313
1314 test_take_primitive_arrays::<DurationMillisecondType>(
1316 vec![Some(0), None, Some(2), Some(-15), None],
1317 &index,
1318 None,
1319 vec![Some(-15), None, None, Some(-15), Some(2)],
1320 )
1321 .unwrap();
1322
1323 test_take_primitive_arrays::<DurationMicrosecondType>(
1325 vec![Some(0), None, Some(2), Some(-15), None],
1326 &index,
1327 None,
1328 vec![Some(-15), None, None, Some(-15), Some(2)],
1329 )
1330 .unwrap();
1331
1332 test_take_primitive_arrays::<DurationNanosecondType>(
1334 vec![Some(0), None, Some(2), Some(-15), None],
1335 &index,
1336 None,
1337 vec![Some(-15), None, None, Some(-15), Some(2)],
1338 )
1339 .unwrap();
1340
1341 test_take_primitive_arrays::<Float32Type>(
1343 vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1344 &index,
1345 None,
1346 vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1347 )
1348 .unwrap();
1349
1350 test_take_primitive_arrays::<Float64Type>(
1352 vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1353 &index,
1354 None,
1355 vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1356 )
1357 .unwrap();
1358 }
1359
1360 #[test]
1361 fn test_take_preserve_timezone() {
1362 let index = Int64Array::from(vec![Some(0), None]);
1363
1364 let input = TimestampNanosecondArray::from(vec![
1365 1_639_715_368_000_000_000,
1366 1_639_715_368_000_000_000,
1367 ])
1368 .with_timezone("UTC".to_string());
1369 let result = take(&input, &index, None).unwrap();
1370 match result.data_type() {
1371 DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
1372 assert_eq!(tz.clone(), Some("UTC".into()))
1373 }
1374 _ => panic!(),
1375 }
1376 }
1377
    /// Runs the `test_take_impl_primitive_arrays` helper with `Int64` indices over
    /// several value types (signed, unsigned, duration and float).
    ///
    /// The index array `[3, null, 1, 3, 2]` covers both a null index (slot 1) and
    /// an index that points at a null input value (slot 2 selects position 1).
    /// Both must produce nulls in the output.
    #[test]
    fn test_take_impl_primitive_with_int64_indices() {
        let index = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

        test_take_impl_primitive_arrays::<Int16Type, Int64Type>(
            vec![Some(0), None, Some(2), Some(3), None],
            &index,
            None,
            vec![Some(3), None, None, Some(3), Some(2)],
        );

        test_take_impl_primitive_arrays::<Int64Type, Int64Type>(
            vec![Some(0), None, Some(2), Some(-15), None],
            &index,
            None,
            vec![Some(-15), None, None, Some(-15), Some(2)],
        );

        test_take_impl_primitive_arrays::<UInt64Type, Int64Type>(
            vec![Some(0), None, Some(2), Some(3), None],
            &index,
            None,
            vec![Some(3), None, None, Some(3), Some(2)],
        );

        test_take_impl_primitive_arrays::<DurationMillisecondType, Int64Type>(
            vec![Some(0), None, Some(2), Some(-15), None],
            &index,
            None,
            vec![Some(-15), None, None, Some(-15), Some(2)],
        );

        test_take_impl_primitive_arrays::<Float32Type, Int64Type>(
            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
            &index,
            None,
            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
        );
    }
1422
    /// Runs the `test_take_impl_primitive_arrays` helper with `UInt8` indices.
    ///
    /// The index array `[3, null, 1, 3, 2]` covers a null index (slot 1) and an
    /// index into a null input value (slot 2 selects position 1). Both must
    /// produce nulls.
    #[test]
    fn test_take_impl_primitive_with_uint8_indices() {
        let index = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

        test_take_impl_primitive_arrays::<Int16Type, UInt8Type>(
            vec![Some(0), None, Some(2), Some(3), None],
            &index,
            None,
            vec![Some(3), None, None, Some(3), Some(2)],
        );

        test_take_impl_primitive_arrays::<DurationMillisecondType, UInt8Type>(
            vec![Some(0), None, Some(2), Some(-15), None],
            &index,
            None,
            vec![Some(-15), None, None, Some(-15), Some(2)],
        );

        test_take_impl_primitive_arrays::<Float32Type, UInt8Type>(
            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
            &index,
            None,
            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
        );
    }
1451
1452 #[test]
1453 fn test_take_bool() {
1454 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1455 test_take_boolean_arrays(
1457 vec![Some(false), None, Some(true), Some(false), None],
1458 &index,
1459 None,
1460 vec![Some(false), None, None, Some(false), Some(true)],
1461 );
1462 }
1463
1464 #[test]
1465 fn test_take_bool_nullable_index() {
1466 let index_data = ArrayData::try_new(
1468 DataType::UInt32,
1469 6,
1470 Some(Buffer::from_iter(vec![
1471 false, true, false, true, false, true,
1472 ])),
1473 0,
1474 vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1475 vec![],
1476 )
1477 .unwrap();
1478 let index = UInt32Array::from(index_data);
1479 test_take_boolean_arrays(
1480 vec![Some(true), None, Some(false)],
1481 &index,
1482 None,
1483 vec![None, Some(true), None, None, None, Some(false)],
1484 );
1485 }
1486
1487 #[test]
1488 fn test_take_bool_nullable_index_nonnull_values() {
1489 let index_data = ArrayData::try_new(
1491 DataType::UInt32,
1492 6,
1493 Some(Buffer::from_iter(vec![
1494 false, true, false, true, false, true,
1495 ])),
1496 0,
1497 vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1498 vec![],
1499 )
1500 .unwrap();
1501 let index = UInt32Array::from(index_data);
1502 test_take_boolean_arrays(
1503 vec![Some(true), Some(true), Some(false)],
1504 &index,
1505 None,
1506 vec![None, Some(true), None, Some(true), None, Some(false)],
1507 );
1508 }
1509
1510 #[test]
1511 fn test_take_bool_with_offset() {
1512 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]);
1513 let index = index.slice(2, 4);
1514 let index = index
1515 .as_any()
1516 .downcast_ref::<PrimitiveArray<UInt32Type>>()
1517 .unwrap();
1518
1519 test_take_boolean_arrays(
1521 vec![Some(false), None, Some(true), Some(false), None],
1522 index,
1523 None,
1524 vec![None, Some(false), Some(true), None],
1525 );
1526 }
1527
1528 fn _test_take_string<'a, K>()
1529 where
1530 K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static,
1531 {
1532 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]);
1533
1534 let array = K::from(vec![
1535 Some("one"),
1536 None,
1537 Some("three"),
1538 Some("four"),
1539 Some("five"),
1540 ]);
1541 let actual = take(&array, &index, None).unwrap();
1542 assert_eq!(actual.len(), index.len());
1543
1544 let actual = actual.as_any().downcast_ref::<K>().unwrap();
1545
1546 let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]);
1547
1548 assert_eq!(actual, &expected);
1549 }
1550
    /// Runs the shared string `take` test against `StringArray`.
    #[test]
    fn test_take_string() {
        _test_take_string::<StringArray>()
    }
1555
    /// Runs the shared string `take` test against `LargeStringArray`.
    #[test]
    fn test_take_large_string() {
        _test_take_string::<LargeStringArray>()
    }
1560
1561 #[test]
1562 fn test_take_slice_string() {
1563 let strings = StringArray::from(vec![Some("hello"), None, Some("world"), None, Some("hi")]);
1564 let indices = Int32Array::from(vec![Some(0), Some(1), None, Some(0), Some(2)]);
1565 let indices_slice = indices.slice(1, 4);
1566 let expected = StringArray::from(vec![None, None, Some("hello"), Some("world")]);
1567 let result = take(&strings, &indices_slice, None).unwrap();
1568 assert_eq!(result.as_ref(), &expected);
1569 }
1570
1571 fn _test_byte_view<T>()
1572 where
1573 T: ByteViewType,
1574 str: AsRef<T::Native>,
1575 T::Native: PartialEq,
1576 {
1577 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]);
1578 let array = {
1579 let mut builder = GenericByteViewBuilder::<T>::new();
1581 builder.append_value("hello");
1582 builder.append_value("world");
1583 builder.append_null();
1584 builder.append_value("large payload over 12 bytes");
1585 builder.append_value("lulu");
1586 builder.finish()
1587 };
1588
1589 let actual = take(&array, &index, None).unwrap();
1590
1591 assert_eq!(actual.len(), index.len());
1592
1593 let expected = {
1594 let mut builder = GenericByteViewBuilder::<T>::new();
1596 builder.append_value("large payload over 12 bytes");
1597 builder.append_null();
1598 builder.append_value("world");
1599 builder.append_value("large payload over 12 bytes");
1600 builder.append_value("lulu");
1601 builder.append_null();
1602 builder.finish()
1603 };
1604
1605 assert_eq!(actual.as_ref(), &expected);
1606 }
1607
    /// Runs the shared byte-view `take` test for `StringViewType`.
    #[test]
    fn test_take_string_view() {
        _test_byte_view::<StringViewType>()
    }
1612
    /// Runs the shared byte-view `take` test for `BinaryViewType`.
    #[test]
    fn test_take_binary_view() {
        _test_byte_view::<BinaryViewType>()
    }
1617
    /// Generates a `take` test for a variable-size list type.
    ///
    /// Arguments: the offset integer type, the `DataType` variant name, and the
    /// concrete array type (e.g. `i32, List, ListArray`).
    ///
    /// Input (offsets `[0, 3, 6, 6, 8]`): `[[0, 0, 0], [-1, -2, -1], [], [2, 3]]`.
    /// Taking `[3, null, 1, 2, 0]` must give
    /// `[[2, 3], null, [-1, -2, -1], [], [0, 0, 0]]`. The expected validity is
    /// copied from the index array.
    macro_rules! test_take_list {
        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
            let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
            let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
            let value_offsets = Buffer::from_slice_ref(&value_offsets);
            let list_data_type =
                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, false)));
            let list_data = ArrayData::builder(list_data_type.clone())
                .len(4)
                .add_buffer(value_offsets)
                .add_child_data(value_data)
                .build()
                .unwrap();
            let list_array = $list_array_type::from(list_data);

            let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]);

            let a = take(&list_array, &index, None).unwrap();
            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();

            // Child values of the selected lists, concatenated in output order.
            let expected_data = Int32Array::from(vec![
                Some(2),
                Some(3),
                Some(-1),
                Some(-2),
                Some(-1),
                Some(0),
                Some(0),
                Some(0),
            ])
            .into_data();
            // The null slot (index 1) and the empty list (index 3) span zero values.
            let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8];
            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
            let expected_list_data = ArrayData::builder(list_data_type)
                .len(5)
                .nulls(index.nulls().cloned())
                .add_buffer(expected_offsets)
                .add_child_data(expected_data)
                .build()
                .unwrap();
            let expected_list_array = $list_array_type::from(expected_list_data);

            assert_eq!(a, &expected_list_array);
        }};
    }
1672
    /// Generates a `take` test for a list type whose child values contain nulls.
    /// Every list slot is valid (validity byte `0b11111111`).
    ///
    /// Input (offsets `[0, 3, 6, 7, 9]`):
    /// `[[0, null, 0], [-1, -2, 3], [null], [5, null]]`.
    /// Taking `[2, null, 1, 3, 0]` must give
    /// `[[null], null, [-1, -2, 3], [5, null], [0, null, 0]]`.
    macro_rules! test_take_list_with_value_nulls {
        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
            let value_data = Int32Array::from(vec![
                Some(0),
                None,
                Some(0),
                Some(-1),
                Some(-2),
                Some(3),
                None,
                Some(5),
                None,
            ])
            .into_data();
            let value_offsets: [$offset_type; 5] = [0, 3, 6, 7, 9];
            let value_offsets = Buffer::from_slice_ref(&value_offsets);
            let list_data_type =
                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
            let list_data = ArrayData::builder(list_data_type.clone())
                .len(4)
                .add_buffer(value_offsets)
                .null_bit_buffer(Some(Buffer::from([0b11111111])))
                .add_child_data(value_data)
                .build()
                .unwrap();
            let list_array = $list_array_type::from(list_data);

            let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);

            let a = take(&list_array, &index, None).unwrap();
            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();

            // Child values of the selected lists, with their inner nulls preserved.
            let expected_data = Int32Array::from(vec![
                None,
                Some(-1),
                Some(-2),
                Some(3),
                Some(5),
                None,
                Some(0),
                None,
                Some(0),
            ])
            .into_data();
            let expected_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9];
            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
            let expected_list_data = ArrayData::builder(list_data_type)
                .len(5)
                .nulls(index.nulls().cloned())
                .add_buffer(expected_offsets)
                .add_child_data(expected_data)
                .build()
                .unwrap();
            let expected_list_array = $list_array_type::from(expected_list_data);

            assert_eq!(a, &expected_list_array);
        }};
    }
1740
    /// Generates a `take` test for a list type where one list slot is null.
    ///
    /// Input (offsets `[0, 3, 6, 6, 8]`, validity `0b11111011` so slot 2 is null):
    /// `[[0, null, 0], [-1, -2, 3], null, [5, null]]`.
    /// Taking `[2, null, 1, 3, 0]` must give
    /// `[null, null, [-1, -2, 3], [5, null], [0, null, 0]]`. Here output slot 0 is
    /// null because it selects a null list, not because the index is null.
    macro_rules! test_take_list_with_nulls {
        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
            let value_data = Int32Array::from(vec![
                Some(0),
                None,
                Some(0),
                Some(-1),
                Some(-2),
                Some(3),
                Some(5),
                None,
            ])
            .into_data();
            let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
            let value_offsets = Buffer::from_slice_ref(&value_offsets);
            let list_data_type =
                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
            let list_data = ArrayData::builder(list_data_type.clone())
                .len(4)
                .add_buffer(value_offsets)
                .null_bit_buffer(Some(Buffer::from([0b11111011])))
                .add_child_data(value_data)
                .build()
                .unwrap();
            let list_array = $list_array_type::from(list_data);

            let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);

            let a = take(&list_array, &index, None).unwrap();
            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();

            let expected_data = Int32Array::from(vec![
                Some(-1),
                Some(-2),
                Some(3),
                Some(5),
                None,
                Some(0),
                None,
                Some(0),
            ])
            .into_data();
            // Both leading null slots span zero child values.
            let expected_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8];
            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
            // Only output slots 2, 3 and 4 are valid.
            let mut null_bits: [u8; 1] = [0; 1];
            bit_util::set_bit(&mut null_bits, 2);
            bit_util::set_bit(&mut null_bits, 3);
            bit_util::set_bit(&mut null_bits, 4);
            let expected_list_data = ArrayData::builder(list_data_type)
                .len(5)
                .null_bit_buffer(Some(Buffer::from(null_bits)))
                .add_buffer(expected_offsets)
                .add_child_data(expected_data)
                .build()
                .unwrap();
            let expected_list_array = $list_array_type::from(expected_list_data);

            assert_eq!(a, &expected_list_array);
        }};
    }
1810
1811 fn do_take_fixed_size_list_test<T>(
1812 length: <Int32Type as ArrowPrimitiveType>::Native,
1813 input_data: Vec<Option<Vec<Option<T::Native>>>>,
1814 indices: Vec<<UInt32Type as ArrowPrimitiveType>::Native>,
1815 expected_data: Vec<Option<Vec<Option<T::Native>>>>,
1816 ) where
1817 T: ArrowPrimitiveType,
1818 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1819 {
1820 let indices = UInt32Array::from(indices);
1821
1822 let input_array = FixedSizeListArray::from_iter_primitive::<T, _, _>(input_data, length);
1823
1824 let output = take_fixed_size_list(&input_array, &indices, length as u32).unwrap();
1825
1826 let expected = FixedSizeListArray::from_iter_primitive::<T, _, _>(expected_data, length);
1827
1828 assert_eq!(&output, &expected)
1829 }
1830
    /// `ListArray` (i32 offsets) instance of the `test_take_list!` scenario.
    #[test]
    fn test_take_list() {
        test_take_list!(i32, List, ListArray);
    }
1835
    /// `LargeListArray` (i64 offsets) instance of the `test_take_list!` scenario.
    #[test]
    fn test_take_large_list() {
        test_take_list!(i64, LargeList, LargeListArray);
    }
1840
    /// `ListArray` (i32 offsets) instance of `test_take_list_with_value_nulls!`.
    #[test]
    fn test_take_list_with_value_nulls() {
        test_take_list_with_value_nulls!(i32, List, ListArray);
    }
1845
    /// `LargeListArray` (i64 offsets) instance of `test_take_list_with_value_nulls!`.
    #[test]
    fn test_take_large_list_with_value_nulls() {
        test_take_list_with_value_nulls!(i64, LargeList, LargeListArray);
    }
1850
    /// `ListArray` (i32 offsets) instance of `test_take_list_with_nulls!`.
    // NOTE(review): the doubled `test_test_` prefix looks accidental; left as-is
    // in case CI filters reference this name.
    #[test]
    fn test_test_take_list_with_nulls() {
        test_take_list_with_nulls!(i32, List, ListArray);
    }
1855
    /// `LargeListArray` (i64 offsets) instance of `test_take_list_with_nulls!`.
    // NOTE(review): the doubled `test_test_` prefix looks accidental; left as-is
    // in case CI filters reference this name.
    #[test]
    fn test_test_take_large_list_with_nulls() {
        test_take_list_with_nulls!(i64, LargeList, LargeListArray);
    }
1860
    /// `take_fixed_size_list` over several element types and list widths.
    #[test]
    fn test_take_fixed_size_list() {
        // Width 3 with inner nulls: reverse the three lists.
        do_take_fixed_size_list_test::<Int32Type>(
            3,
            vec![
                Some(vec![None, Some(1), Some(2)]),
                Some(vec![Some(3), Some(4), None]),
                Some(vec![Some(6), Some(7), Some(8)]),
            ],
            vec![2, 1, 0],
            vec![
                Some(vec![Some(6), Some(7), Some(8)]),
                Some(vec![Some(3), Some(4), None]),
                Some(vec![None, Some(1), Some(2)]),
            ],
        );

        // Width 1: pick positions 2, 7 and 0 out of eight single-element lists.
        do_take_fixed_size_list_test::<UInt8Type>(
            1,
            vec![
                Some(vec![Some(1)]),
                Some(vec![Some(2)]),
                Some(vec![Some(3)]),
                Some(vec![Some(4)]),
                Some(vec![Some(5)]),
                Some(vec![Some(6)]),
                Some(vec![Some(7)]),
                Some(vec![Some(8)]),
            ],
            vec![2, 7, 0],
            vec![
                Some(vec![Some(3)]),
                Some(vec![Some(8)]),
                Some(vec![Some(1)]),
            ],
        );

        // Width 3 with a null list at position 2. Selecting it (twice) yields null.
        do_take_fixed_size_list_test::<UInt64Type>(
            3,
            vec![
                Some(vec![Some(10), Some(11), Some(12)]),
                Some(vec![Some(13), Some(14), Some(15)]),
                None,
                Some(vec![Some(16), Some(17), Some(18)]),
            ],
            vec![3, 2, 1, 2, 0],
            vec![
                Some(vec![Some(16), Some(17), Some(18)]),
                None,
                Some(vec![Some(13), Some(14), Some(15)]),
                None,
                Some(vec![Some(10), Some(11), Some(12)]),
            ],
        );
    }
1916
1917 #[test]
1918 #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
1919 fn test_take_list_out_of_bounds() {
1920 let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
1922 let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]);
1924 let list_data_type =
1926 DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false)));
1927 let list_data = ArrayData::builder(list_data_type)
1928 .len(3)
1929 .add_buffer(value_offsets)
1930 .add_child_data(value_data)
1931 .build()
1932 .unwrap();
1933 let list_array = ListArray::from(list_data);
1934
1935 let index = UInt32Array::from(vec![1000]);
1936
1937 take(&list_array, &index, None).unwrap();
1940 }
1941
1942 #[test]
1943 fn test_take_map() {
1944 let values = Int32Array::from(vec![1, 2, 3, 4]);
1945 let array =
1946 MapArray::new_from_strings(vec!["a", "b", "c", "a"].into_iter(), &values, &[0, 3, 4])
1947 .unwrap();
1948
1949 let index = UInt32Array::from(vec![0]);
1950
1951 let result = take(&array, &index, None).unwrap();
1952 let expected: ArrayRef = Arc::new(
1953 MapArray::new_from_strings(
1954 vec!["a", "b", "c"].into_iter(),
1955 &values.slice(0, 3),
1956 &[0, 3],
1957 )
1958 .unwrap(),
1959 );
1960 assert_eq!(&expected, &result);
1961 }
1962
1963 #[test]
1964 fn test_take_struct() {
1965 let array = create_test_struct(vec![
1966 Some((Some(true), Some(42))),
1967 Some((Some(false), Some(28))),
1968 Some((Some(false), Some(19))),
1969 Some((Some(true), Some(31))),
1970 None,
1971 ]);
1972
1973 let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
1974 let actual = take(&array, &index, None).unwrap();
1975 let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
1976 assert_eq!(index.len(), actual.len());
1977 assert_eq!(1, actual.null_count());
1978
1979 let expected = create_test_struct(vec![
1980 Some((Some(true), Some(42))),
1981 Some((Some(true), Some(31))),
1982 Some((Some(false), Some(28))),
1983 Some((Some(true), Some(42))),
1984 Some((Some(false), Some(19))),
1985 None,
1986 ]);
1987
1988 assert_eq!(&expected, actual);
1989
1990 let nulls = NullBuffer::from(&[false, true, false, true, false, true]);
1991 let empty_struct_arr = StructArray::new_empty_fields(6, Some(nulls));
1992 let index = UInt32Array::from(vec![0, 2, 1, 4]);
1993 let actual = take(&empty_struct_arr, &index, None).unwrap();
1994
1995 let expected_nulls = NullBuffer::from(&[false, false, true, false]);
1996 let expected_struct_arr = StructArray::new_empty_fields(4, Some(expected_nulls));
1997 assert_eq!(&expected_struct_arr, actual.as_struct());
1998 }
1999
2000 #[test]
2001 fn test_take_struct_with_null_indices() {
2002 let array = create_test_struct(vec![
2003 Some((Some(true), Some(42))),
2004 Some((Some(false), Some(28))),
2005 Some((Some(false), Some(19))),
2006 Some((Some(true), Some(31))),
2007 None,
2008 ]);
2009
2010 let index = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);
2011 let actual = take(&array, &index, None).unwrap();
2012 let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
2013 assert_eq!(index.len(), actual.len());
2014 assert_eq!(3, actual.null_count()); let expected = create_test_struct(vec![
2017 None,
2018 Some((Some(true), Some(31))),
2019 Some((Some(false), Some(28))),
2020 None,
2021 Some((Some(true), Some(42))),
2022 None,
2023 ]);
2024
2025 assert_eq!(&expected, actual);
2026 }
2027
2028 #[test]
2029 fn test_take_out_of_bounds() {
2030 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]);
2031 let take_opt = TakeOptions { check_bounds: true };
2032
2033 let result = test_take_primitive_arrays::<Int64Type>(
2035 vec![Some(0), None, Some(2), Some(3), None],
2036 &index,
2037 Some(take_opt),
2038 vec![None],
2039 );
2040 assert!(result.is_err());
2041 }
2042
2043 #[test]
2044 #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
2045 fn test_take_out_of_bounds_panic() {
2046 let index = UInt32Array::from(vec![Some(1000)]);
2047
2048 test_take_primitive_arrays::<Int64Type>(
2049 vec![Some(0), Some(1), Some(2), Some(3)],
2050 &index,
2051 None,
2052 vec![None],
2053 )
2054 .unwrap();
2055 }
2056
2057 #[test]
2058 fn test_null_array_smaller_than_indices() {
2059 let values = NullArray::new(2);
2060 let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2061
2062 let result = take(&values, &indices, None).unwrap();
2063 let expected: ArrayRef = Arc::new(NullArray::new(3));
2064 assert_eq!(&result, &expected);
2065 }
2066
2067 #[test]
2068 fn test_null_array_larger_than_indices() {
2069 let values = NullArray::new(5);
2070 let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2071
2072 let result = take(&values, &indices, None).unwrap();
2073 let expected: ArrayRef = Arc::new(NullArray::new(3));
2074 assert_eq!(&result, &expected);
2075 }
2076
2077 #[test]
2078 fn test_null_array_indices_out_of_bounds() {
2079 let values = NullArray::new(5);
2080 let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2081
2082 let result = take(&values, &indices, Some(TakeOptions { check_bounds: true }));
2083 assert_eq!(
2084 result.unwrap_err().to_string(),
2085 "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries"
2086 );
2087 }
2088
2089 #[test]
2090 fn test_take_dict() {
2091 let mut dict_builder = StringDictionaryBuilder::<Int16Type>::new();
2092
2093 dict_builder.append("foo").unwrap();
2094 dict_builder.append("bar").unwrap();
2095 dict_builder.append("").unwrap();
2096 dict_builder.append_null();
2097 dict_builder.append("foo").unwrap();
2098 dict_builder.append("bar").unwrap();
2099 dict_builder.append("bar").unwrap();
2100 dict_builder.append("foo").unwrap();
2101
2102 let array = dict_builder.finish();
2103 let dict_values = array.values().clone();
2104 let dict_values = dict_values.as_any().downcast_ref::<StringArray>().unwrap();
2105
2106 let indices = UInt32Array::from(vec![
2107 Some(0), Some(7), None, Some(5), Some(6), Some(2), Some(3), ]);
2115
2116 let result = take(&array, &indices, None).unwrap();
2117 let result = result
2118 .as_any()
2119 .downcast_ref::<DictionaryArray<Int16Type>>()
2120 .unwrap();
2121
2122 let result_values: StringArray = result.values().to_data().into();
2123
2124 let expected_values = StringArray::from(vec!["foo", "bar", ""]);
2126 assert_eq!(&expected_values, dict_values);
2127 assert_eq!(&expected_values, &result_values);
2128
2129 let expected_keys = Int16Array::from(vec![
2130 Some(0),
2131 Some(0),
2132 None,
2133 Some(1),
2134 Some(1),
2135 Some(2),
2136 None,
2137 ]);
2138 assert_eq!(result.keys(), &expected_keys);
2139 }
2140
2141 fn build_generic_list<S, T>(data: Vec<Option<Vec<T::Native>>>) -> GenericListArray<S>
2142 where
2143 S: OffsetSizeTrait + 'static,
2144 T: ArrowPrimitiveType,
2145 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2146 {
2147 GenericListArray::from_iter_primitive::<T, _, _>(
2148 data.iter()
2149 .map(|x| x.as_ref().map(|x| x.iter().map(|x| Some(*x)))),
2150 )
2151 }
2152
2153 #[test]
2154 fn test_take_value_index_from_list() {
2155 let list = build_generic_list::<i32, Int32Type>(vec![
2156 Some(vec![0, 1]),
2157 Some(vec![2, 3, 4]),
2158 Some(vec![5, 6, 7, 8, 9]),
2159 ]);
2160 let indices = UInt32Array::from(vec![2, 0]);
2161
2162 let (indexed, offsets, null_buf) = take_value_indices_from_list(&list, &indices).unwrap();
2163
2164 assert_eq!(indexed, Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
2165 assert_eq!(offsets, vec![0, 5, 7]);
2166 assert_eq!(null_buf.as_slice(), &[0b11111111]);
2167 }
2168
2169 #[test]
2170 fn test_take_value_index_from_large_list() {
2171 let list = build_generic_list::<i64, Int32Type>(vec![
2172 Some(vec![0, 1]),
2173 Some(vec![2, 3, 4]),
2174 Some(vec![5, 6, 7, 8, 9]),
2175 ]);
2176 let indices = UInt32Array::from(vec![2, 0]);
2177
2178 let (indexed, offsets, null_buf) =
2179 take_value_indices_from_list::<_, Int64Type>(&list, &indices).unwrap();
2180
2181 assert_eq!(indexed, Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
2182 assert_eq!(offsets, vec![0, 5, 7]);
2183 assert_eq!(null_buf.as_slice(), &[0b11111111]);
2184 }
2185
2186 #[test]
2187 fn test_take_runs() {
2188 let logical_array: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];
2189
2190 let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
2191 builder.extend(logical_array.into_iter().map(Some));
2192 let run_array = builder.finish();
2193
2194 let take_indices: PrimitiveArray<Int32Type> =
2195 vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect();
2196
2197 let take_out = take_run(&run_array, &take_indices).unwrap();
2198
2199 assert_eq!(take_out.len(), 7);
2200 assert_eq!(take_out.run_ends().len(), 7);
2201 assert_eq!(take_out.run_ends().values(), &[1_i32, 3, 4, 5, 7]);
2202
2203 let take_out_values = take_out.values().as_primitive::<Int32Type>();
2204 assert_eq!(take_out_values.values(), &[2, 2, 2, 2, 1]);
2205 }
2206
2207 #[test]
2208 fn test_take_value_index_from_fixed_list() {
2209 let list = FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
2210 vec![
2211 Some(vec![Some(1), Some(2), None]),
2212 Some(vec![Some(4), None, Some(6)]),
2213 None,
2214 Some(vec![None, Some(8), Some(9)]),
2215 ],
2216 3,
2217 );
2218
2219 let indices = UInt32Array::from(vec![2, 1, 0]);
2220 let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2221
2222 assert_eq!(indexed, UInt32Array::from(vec![6, 7, 8, 3, 4, 5, 0, 1, 2]));
2223
2224 let indices = UInt32Array::from(vec![3, 2, 1, 2, 0]);
2225 let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2226
2227 assert_eq!(
2228 indexed,
2229 UInt32Array::from(vec![9, 10, 11, 6, 7, 8, 3, 4, 5, 6, 7, 8, 0, 1, 2])
2230 );
2231 }
2232
2233 #[test]
2234 fn test_take_null_indices() {
2235 let indices = Int32Array::new(
2237 vec![1, 2, 400, 400].into(),
2238 Some(NullBuffer::from(vec![true, true, false, false])),
2239 );
2240 let values = Int32Array::from(vec![1, 23, 4, 5]);
2241 let r = take(&values, &indices, None).unwrap();
2242 let values = r
2243 .as_primitive::<Int32Type>()
2244 .into_iter()
2245 .collect::<Vec<_>>();
2246 assert_eq!(&values, &[Some(23), Some(4), None, None])
2247 }
2248
2249 #[test]
2250 fn test_take_fixed_size_list_null_indices() {
2251 let indices = Int32Array::from_iter([Some(0), None]);
2252 let values = Arc::new(Int32Array::from(vec![0, 1, 2, 3]));
2253 let arr_field = Arc::new(Field::new_list_field(values.data_type().clone(), true));
2254 let values = FixedSizeListArray::try_new(arr_field, 2, values, None).unwrap();
2255
2256 let r = take(&values, &indices, None).unwrap();
2257 let values = r
2258 .as_fixed_size_list()
2259 .values()
2260 .as_primitive::<Int32Type>()
2261 .into_iter()
2262 .collect::<Vec<_>>();
2263 assert_eq!(values, &[Some(0), Some(1), None, None])
2264 }
2265
2266 #[test]
2267 fn test_take_bytes_null_indices() {
2268 let indices = Int32Array::new(
2269 vec![0, 1, 400, 400].into(),
2270 Some(NullBuffer::from_iter(vec![true, true, false, false])),
2271 );
2272 let values = StringArray::from(vec![Some("foo"), None]);
2273 let r = take(&values, &indices, None).unwrap();
2274 let values = r.as_string::<i32>().iter().collect::<Vec<_>>();
2275 assert_eq!(&values, &[Some("foo"), None, None, None])
2276 }
2277
2278 #[test]
2279 fn test_take_union_sparse() {
2280 let structs = create_test_struct(vec![
2281 Some((Some(true), Some(42))),
2282 Some((Some(false), Some(28))),
2283 Some((Some(false), Some(19))),
2284 Some((Some(true), Some(31))),
2285 None,
2286 ]);
2287 let strings = StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);
2288 let type_ids = [1; 5].into_iter().collect::<ScalarBuffer<i8>>();
2289
2290 let union_fields = [
2291 (
2292 0,
2293 Arc::new(Field::new("f1", structs.data_type().clone(), true)),
2294 ),
2295 (
2296 1,
2297 Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2298 ),
2299 ]
2300 .into_iter()
2301 .collect();
2302 let children = vec![Arc::new(structs) as Arc<dyn Array>, Arc::new(strings)];
2303 let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();
2304
2305 let indices = vec![0, 3, 1, 0, 2, 4];
2306 let index = UInt32Array::from(indices.clone());
2307 let actual = take(&array, &index, None).unwrap();
2308 let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2309 let strings = actual.child(1);
2310 let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();
2311
2312 let actual = strings.iter().collect::<Vec<_>>();
2313 let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
2314 assert_eq!(expected, actual);
2315 }
2316
    /// `take` on a dense union must rebuild type ids, offsets and compacted
    /// children.
    ///
    /// Source slots, as (type id, offset):
    /// (0,0)=10, (1,0)="a", (1,1)=null, (0,1)=20, (0,2)=30, (1,2)="c", (0,3)=40.
    /// Taking `[0, 3, 1, 0, 2, 4]` selects 10, 20, "a", 10, null, 30. The int
    /// child becomes `[10, 20, 10, 30]` and the string child becomes `["a", null]`.
    #[test]
    fn test_take_union_dense() {
        let type_ids = vec![0, 1, 1, 0, 0, 1, 0];
        let offsets = vec![0, 0, 1, 1, 2, 2, 3];
        let ints = vec![10, 20, 30, 40];
        let strings = vec![Some("a"), None, Some("c"), Some("d")];

        let indices = vec![0, 3, 1, 0, 2, 4];

        let taken_type_ids = vec![0, 0, 1, 0, 1, 0];
        let taken_offsets = vec![0, 1, 0, 2, 1, 3];
        let taken_ints = vec![10, 20, 10, 30];
        let taken_strings = vec![Some("a"), None];

        let type_ids = <ScalarBuffer<i8>>::from(type_ids);
        let offsets = <ScalarBuffer<i32>>::from(offsets);
        let ints = UInt32Array::from(ints);
        let strings = StringArray::from(strings);

        let union_fields = [
            (
                0,
                Arc::new(Field::new("f1", ints.data_type().clone(), true)),
            ),
            (
                1,
                Arc::new(Field::new("f2", strings.data_type().clone(), true)),
            ),
        ]
        .into_iter()
        .collect();

        let array = UnionArray::try_new(
            union_fields,
            type_ids,
            Some(offsets),
            vec![Arc::new(ints), Arc::new(strings)],
        )
        .unwrap();

        let index = UInt32Array::from(indices);

        let actual = take(&array, &index, None).unwrap();
        let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();

        assert_eq!(actual.offsets(), Some(&ScalarBuffer::from(taken_offsets)));
        assert_eq!(actual.type_ids(), &ScalarBuffer::from(taken_type_ids));
        assert_eq!(
            UInt32Array::from(actual.child(0).to_data()),
            UInt32Array::from(taken_ints)
        );
        assert_eq!(
            StringArray::from(actual.child(1).to_data()),
            StringArray::from(taken_strings)
        );
    }
2373
2374 #[test]
2375 fn test_take_union_dense_using_builder() {
2376 let mut builder = UnionBuilder::new_dense();
2377
2378 builder.append::<Int32Type>("a", 1).unwrap();
2379 builder.append::<Float64Type>("b", 3.0).unwrap();
2380 builder.append::<Int32Type>("a", 4).unwrap();
2381 builder.append::<Int32Type>("a", 5).unwrap();
2382 builder.append::<Float64Type>("b", 2.0).unwrap();
2383
2384 let union = builder.build().unwrap();
2385
2386 let indices = UInt32Array::from(vec![2, 0, 1, 2]);
2387
2388 let mut builder = UnionBuilder::new_dense();
2389
2390 builder.append::<Int32Type>("a", 4).unwrap();
2391 builder.append::<Int32Type>("a", 1).unwrap();
2392 builder.append::<Float64Type>("b", 3.0).unwrap();
2393 builder.append::<Int32Type>("a", 4).unwrap();
2394
2395 let taken = builder.build().unwrap();
2396
2397 assert_eq!(
2398 taken.to_data(),
2399 take(&union, &indices, None).unwrap().to_data()
2400 );
2401 }
2402
2403 #[test]
2404 fn test_take_union_dense_all_match_issue_6206() {
2405 let fields = UnionFields::new(vec![0], vec![Field::new("a", DataType::Int64, false)]);
2406 let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
2407
2408 let array = UnionArray::try_new(
2409 fields,
2410 ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]),
2411 Some(ScalarBuffer::from_iter(0_i32..5)),
2412 vec![ints],
2413 )
2414 .unwrap();
2415
2416 let indicies = Int64Array::from(vec![0, 2, 4]);
2417 let array = take(&array, &indicies, None).unwrap();
2418 assert_eq!(array.len(), 3);
2419 }
2420}