1use std::sync::Arc;
21
22use arrow_array::builder::{BufferBuilder, UInt32Builder};
23use arrow_array::cast::AsArray;
24use arrow_array::types::*;
25use arrow_array::*;
26use arrow_buffer::{
27 bit_util, ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer,
28 ScalarBuffer,
29};
30use arrow_data::ArrayDataBuilder;
31use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
32
33use num::{One, Zero};
34
35pub fn take(
81 values: &dyn Array,
82 indices: &dyn Array,
83 options: Option<TakeOptions>,
84) -> Result<ArrayRef, ArrowError> {
85 let options = options.unwrap_or_default();
86 downcast_integer_array!(
87 indices => {
88 if options.check_bounds {
89 check_bounds(values.len(), indices)?;
90 }
91 let indices = indices.to_indices();
92 take_impl(values, &indices)
93 },
94 d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
95 )
96}
97
98pub fn take_arrays(
147 arrays: &[ArrayRef],
148 indices: &dyn Array,
149 options: Option<TakeOptions>,
150) -> Result<Vec<ArrayRef>, ArrowError> {
151 arrays
152 .iter()
153 .map(|array| take(array.as_ref(), indices, options.clone()))
154 .collect()
155}
156
157fn check_bounds<T: ArrowPrimitiveType>(
159 len: usize,
160 indices: &PrimitiveArray<T>,
161) -> Result<(), ArrowError> {
162 if indices.null_count() > 0 {
163 indices.iter().flatten().try_for_each(|index| {
164 let ix = index
165 .to_usize()
166 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
167 if ix >= len {
168 return Err(ArrowError::ComputeError(format!(
169 "Array index out of bounds, cannot get item at index {ix} from {len} entries"
170 )));
171 }
172 Ok(())
173 })
174 } else {
175 indices.values().iter().try_for_each(|index| {
176 let ix = index
177 .to_usize()
178 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
179 if ix >= len {
180 return Err(ArrowError::ComputeError(format!(
181 "Array index out of bounds, cannot get item at index {ix} from {len} entries"
182 )));
183 }
184 Ok(())
185 })
186 }
187}
188
189#[inline(never)]
190fn take_impl<IndexType: ArrowPrimitiveType>(
191 values: &dyn Array,
192 indices: &PrimitiveArray<IndexType>,
193) -> Result<ArrayRef, ArrowError> {
194 downcast_primitive_array! {
195 values => Ok(Arc::new(take_primitive(values, indices)?)),
196 DataType::Boolean => {
197 let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
198 Ok(Arc::new(take_boolean(values, indices)))
199 }
200 DataType::Utf8 => {
201 Ok(Arc::new(take_bytes(values.as_string::<i32>(), indices)?))
202 }
203 DataType::LargeUtf8 => {
204 Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)?))
205 }
206 DataType::Utf8View => {
207 Ok(Arc::new(take_byte_view(values.as_string_view(), indices)?))
208 }
209 DataType::List(_) => {
210 Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?))
211 }
212 DataType::LargeList(_) => {
213 Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?))
214 }
215 DataType::FixedSizeList(_, length) => {
216 let values = values
217 .as_any()
218 .downcast_ref::<FixedSizeListArray>()
219 .unwrap();
220 Ok(Arc::new(take_fixed_size_list(
221 values,
222 indices,
223 *length as u32,
224 )?))
225 }
226 DataType::Map(_, _) => {
227 let list_arr = ListArray::from(values.as_map().clone());
228 let list_data = take_list::<_, Int32Type>(&list_arr, indices)?;
229 let builder = list_data.into_data().into_builder().data_type(values.data_type().clone());
230 Ok(Arc::new(MapArray::from(unsafe { builder.build_unchecked() })))
231 }
232 DataType::Struct(fields) => {
233 let array: &StructArray = values.as_struct();
234 let arrays = array
235 .columns()
236 .iter()
237 .map(|a| take_impl(a.as_ref(), indices))
238 .collect::<Result<Vec<ArrayRef>, _>>()?;
239 let fields: Vec<(FieldRef, ArrayRef)> =
240 fields.iter().cloned().zip(arrays).collect();
241
242 let is_valid: Buffer = indices
244 .iter()
245 .map(|index| {
246 if let Some(index) = index {
247 array.is_valid(index.to_usize().unwrap())
248 } else {
249 false
250 }
251 })
252 .collect();
253
254 if fields.is_empty() {
255 let nulls = NullBuffer::new(BooleanBuffer::new(is_valid, 0, indices.len()));
256 Ok(Arc::new(StructArray::new_empty_fields(indices.len(), Some(nulls))))
257 } else {
258 Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
259 }
260 }
261 DataType::Dictionary(_, _) => downcast_dictionary_array! {
262 values => Ok(Arc::new(take_dict(values, indices)?)),
263 t => unimplemented!("Take not supported for dictionary type {:?}", t)
264 }
265 DataType::RunEndEncoded(_, _) => downcast_run_array! {
266 values => Ok(Arc::new(take_run(values, indices)?)),
267 t => unimplemented!("Take not supported for run type {:?}", t)
268 }
269 DataType::Binary => {
270 Ok(Arc::new(take_bytes(values.as_binary::<i32>(), indices)?))
271 }
272 DataType::LargeBinary => {
273 Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?))
274 }
275 DataType::BinaryView => {
276 Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)?))
277 }
278 DataType::FixedSizeBinary(size) => {
279 let values = values
280 .as_any()
281 .downcast_ref::<FixedSizeBinaryArray>()
282 .unwrap();
283 Ok(Arc::new(take_fixed_size_binary(values, indices, *size)?))
284 }
285 DataType::Null => {
286 if values.len() >= indices.len() {
288 Ok(values.slice(0, indices.len()))
291 } else {
292 Ok(new_null_array(&DataType::Null, indices.len()))
294 }
295 }
296 DataType::Union(fields, UnionMode::Sparse) => {
297 let mut children = Vec::with_capacity(fields.len());
298 let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
299 let type_ids = take_native(values.type_ids(), indices);
300 for (type_id, _field) in fields.iter() {
301 let values = values.child(type_id);
302 let values = take_impl(values, indices)?;
303 children.push(values);
304 }
305 let array = UnionArray::try_new(fields.clone(), type_ids, None, children)?;
306 Ok(Arc::new(array))
307 }
308 DataType::Union(fields, UnionMode::Dense) => {
309 let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
310
311 let type_ids = <PrimitiveArray<Int8Type>>::new(take_native(values.type_ids(), indices), None);
312 let offsets = <PrimitiveArray<Int32Type>>::new(take_native(values.offsets().unwrap(), indices), None);
313
314 let children = fields.iter()
315 .map(|(field_type_id, _)| {
316 let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);
317
318 let indices = crate::filter::filter(&offsets, &mask)?;
319
320 let values = values.child(field_type_id);
321
322 take_impl(values, indices.as_primitive::<Int32Type>())
323 })
324 .collect::<Result<_, _>>()?;
325
326 let mut child_offsets = [0; 128];
327
328 let offsets = type_ids.values()
329 .iter()
330 .map(|&i| {
331 let offset = child_offsets[i as usize];
332
333 child_offsets[i as usize] += 1;
334
335 offset
336 })
337 .collect();
338
339 let (_, type_ids, _) = type_ids.into_parts();
340
341 let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)?;
342
343 Ok(Arc::new(array))
344 }
345 t => unimplemented!("Take not supported for data type {:?}", t)
346 }
347}
348
/// Options that control the behavior of the `take` kernel.
#[derive(Debug, Clone, Default)]
pub struct TakeOptions {
    /// When `true`, `take` validates every non-null index against the length of the values
    /// array before gathering. Defaults to `false`.
    pub check_bounds: bool,
}
357
358#[inline(always)]
359fn maybe_usize<I: ArrowNativeType>(index: I) -> Result<usize, ArrowError> {
360 index
361 .to_usize()
362 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))
363}
364
365fn take_primitive<T, I>(
375 values: &PrimitiveArray<T>,
376 indices: &PrimitiveArray<I>,
377) -> Result<PrimitiveArray<T>, ArrowError>
378where
379 T: ArrowPrimitiveType,
380 I: ArrowPrimitiveType,
381{
382 let values_buf = take_native(values.values(), indices);
383 let nulls = take_nulls(values.nulls(), indices);
384 Ok(PrimitiveArray::new(values_buf, nulls).with_data_type(values.data_type().clone()))
385}
386
387#[inline(never)]
388fn take_nulls<I: ArrowPrimitiveType>(
389 values: Option<&NullBuffer>,
390 indices: &PrimitiveArray<I>,
391) -> Option<NullBuffer> {
392 match values.filter(|n| n.null_count() > 0) {
393 Some(n) => {
394 let buffer = take_bits(n.inner(), indices);
395 Some(NullBuffer::new(buffer)).filter(|n| n.null_count() > 0)
396 }
397 None => indices.nulls().cloned(),
398 }
399}
400
401#[inline(never)]
402fn take_native<T: ArrowNativeType, I: ArrowPrimitiveType>(
403 values: &[T],
404 indices: &PrimitiveArray<I>,
405) -> ScalarBuffer<T> {
406 match indices.nulls().filter(|n| n.null_count() > 0) {
407 Some(n) => indices
408 .values()
409 .iter()
410 .enumerate()
411 .map(|(idx, index)| match values.get(index.as_usize()) {
412 Some(v) => *v,
413 None => match n.is_null(idx) {
414 true => T::default(),
415 false => panic!("Out-of-bounds index {index:?}"),
416 },
417 })
418 .collect(),
419 None => indices
420 .values()
421 .iter()
422 .map(|index| values[index.as_usize()])
423 .collect(),
424 }
425}
426
427#[inline(never)]
428fn take_bits<I: ArrowPrimitiveType>(
429 values: &BooleanBuffer,
430 indices: &PrimitiveArray<I>,
431) -> BooleanBuffer {
432 let len = indices.len();
433
434 match indices.nulls().filter(|n| n.null_count() > 0) {
435 Some(nulls) => {
436 let mut output_buffer = MutableBuffer::new_null(len);
437 let output_slice = output_buffer.as_slice_mut();
438 nulls.valid_indices().for_each(|idx| {
439 if values.value(indices.value(idx).as_usize()) {
440 bit_util::set_bit(output_slice, idx);
441 }
442 });
443 BooleanBuffer::new(output_buffer.into(), 0, len)
444 }
445 None => {
446 BooleanBuffer::collect_bool(len, |idx: usize| {
447 values.value(unsafe { indices.value_unchecked(idx).as_usize() })
449 })
450 }
451 }
452}
453
454fn take_boolean<IndexType: ArrowPrimitiveType>(
456 values: &BooleanArray,
457 indices: &PrimitiveArray<IndexType>,
458) -> BooleanArray {
459 let val_buf = take_bits(values.values(), indices);
460 let null_buf = take_nulls(values.nulls(), indices);
461 BooleanArray::new(val_buf, null_buf)
462}
463
464fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
466 array: &GenericByteArray<T>,
467 indices: &PrimitiveArray<IndexType>,
468) -> Result<GenericByteArray<T>, ArrowError> {
469 let mut offsets = Vec::with_capacity(indices.len() + 1);
470 offsets.push(T::Offset::default());
471
472 let input_offsets = array.value_offsets();
473 let mut capacity = 0;
474 let nulls = take_nulls(array.nulls(), indices);
475
476 let (offsets, values) = if array.null_count() == 0 && indices.null_count() == 0 {
477 offsets.extend(indices.values().iter().map(|index| {
478 let index = index.as_usize();
479 capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
480 T::Offset::from_usize(capacity).expect("overflow")
481 }));
482 let mut values = Vec::with_capacity(capacity);
483
484 for index in indices.values() {
485 values.extend_from_slice(array.value(index.as_usize()).as_ref());
486 }
487 (offsets, values)
488 } else if indices.null_count() == 0 {
489 offsets.extend(indices.values().iter().map(|index| {
490 let index = index.as_usize();
491 if array.is_valid(index) {
492 capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
493 }
494 T::Offset::from_usize(capacity).expect("overflow")
495 }));
496 let mut values = Vec::with_capacity(capacity);
497
498 for index in indices.values() {
499 let index = index.as_usize();
500 if array.is_valid(index) {
501 values.extend_from_slice(array.value(index).as_ref());
502 }
503 }
504 (offsets, values)
505 } else if array.null_count() == 0 {
506 offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
507 let index = index.as_usize();
508 if indices.is_valid(i) {
509 capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
510 }
511 T::Offset::from_usize(capacity).expect("overflow")
512 }));
513 let mut values = Vec::with_capacity(capacity);
514
515 for (i, index) in indices.values().iter().enumerate() {
516 if indices.is_valid(i) {
517 values.extend_from_slice(array.value(index.as_usize()).as_ref());
518 }
519 }
520 (offsets, values)
521 } else {
522 let nulls = nulls.as_ref().unwrap();
523 offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
524 let index = index.as_usize();
525 if nulls.is_valid(i) {
526 capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
527 }
528 T::Offset::from_usize(capacity).expect("overflow")
529 }));
530 let mut values = Vec::with_capacity(capacity);
531
532 for (i, index) in indices.values().iter().enumerate() {
533 let index = index.as_usize();
536 if nulls.is_valid(i) {
537 values.extend_from_slice(array.value(index).as_ref());
538 }
539 }
540 (offsets, values)
541 };
542
543 T::Offset::from_usize(values.len()).ok_or(ArrowError::ComputeError(format!(
544 "Offset overflow for {}BinaryArray: {}",
545 T::Offset::PREFIX,
546 values.len()
547 )))?;
548
549 let array = unsafe {
550 let offsets = OffsetBuffer::new_unchecked(offsets.into());
551 GenericByteArray::<T>::new_unchecked(offsets, values.into(), nulls)
552 };
553
554 Ok(array)
555}
556
557fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>(
559 array: &GenericByteViewArray<T>,
560 indices: &PrimitiveArray<IndexType>,
561) -> Result<GenericByteViewArray<T>, ArrowError> {
562 let new_views = take_native(array.views(), indices);
563 let new_nulls = take_nulls(array.nulls(), indices);
564 Ok(unsafe {
566 GenericByteViewArray::new_unchecked(new_views, array.data_buffers().to_vec(), new_nulls)
567 })
568}
569
570fn take_list<IndexType, OffsetType>(
576 values: &GenericListArray<OffsetType::Native>,
577 indices: &PrimitiveArray<IndexType>,
578) -> Result<GenericListArray<OffsetType::Native>, ArrowError>
579where
580 IndexType: ArrowPrimitiveType,
581 OffsetType: ArrowPrimitiveType,
582 OffsetType::Native: OffsetSizeTrait,
583 PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
584{
585 let (list_indices, offsets, null_buf) =
588 take_value_indices_from_list::<IndexType, OffsetType>(values, indices)?;
589
590 let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices)?;
591 let value_offsets = Buffer::from_vec(offsets);
592 let list_data = ArrayDataBuilder::new(values.data_type().clone())
594 .len(indices.len())
595 .null_bit_buffer(Some(null_buf.into()))
596 .offset(0)
597 .add_child_data(taken.into_data())
598 .add_buffer(value_offsets);
599
600 let list_data = unsafe { list_data.build_unchecked() };
601
602 Ok(GenericListArray::<OffsetType::Native>::from(list_data))
603}
604
605fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
611 values: &FixedSizeListArray,
612 indices: &PrimitiveArray<IndexType>,
613 length: <UInt32Type as ArrowPrimitiveType>::Native,
614) -> Result<FixedSizeListArray, ArrowError> {
615 let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?;
616 let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)?;
617
618 let num_bytes = bit_util::ceil(indices.len(), 8);
620 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
621 let null_slice = null_buf.as_slice_mut();
622
623 for i in 0..indices.len() {
624 let index = indices
625 .value(i)
626 .to_usize()
627 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
628 if !indices.is_valid(i) || values.is_null(index) {
629 bit_util::unset_bit(null_slice, i);
630 }
631 }
632
633 let list_data = ArrayDataBuilder::new(values.data_type().clone())
634 .len(indices.len())
635 .null_bit_buffer(Some(null_buf.into()))
636 .offset(0)
637 .add_child_data(taken.into_data());
638
639 let list_data = unsafe { list_data.build_unchecked() };
640
641 Ok(FixedSizeListArray::from(list_data))
642}
643
644fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>(
645 values: &FixedSizeBinaryArray,
646 indices: &PrimitiveArray<IndexType>,
647 size: i32,
648) -> Result<FixedSizeBinaryArray, ArrowError> {
649 let nulls = values.nulls();
650 let array_iter = indices
651 .values()
652 .iter()
653 .map(|idx| {
654 let idx = maybe_usize::<IndexType::Native>(*idx)?;
655 if nulls.map(|n| n.is_valid(idx)).unwrap_or(true) {
656 Ok(Some(values.value(idx)))
657 } else {
658 Ok(None)
659 }
660 })
661 .collect::<Result<Vec<_>, ArrowError>>()?
662 .into_iter();
663
664 FixedSizeBinaryArray::try_from_sparse_iter_with_size(array_iter, size)
665}
666
667fn take_dict<T: ArrowDictionaryKeyType, I: ArrowPrimitiveType>(
672 values: &DictionaryArray<T>,
673 indices: &PrimitiveArray<I>,
674) -> Result<DictionaryArray<T>, ArrowError> {
675 let new_keys = take_primitive(values.keys(), indices)?;
676 Ok(unsafe { DictionaryArray::new_unchecked(new_keys, values.values().clone()) })
677}
678
679fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
688 run_array: &RunArray<T>,
689 logical_indices: &PrimitiveArray<I>,
690) -> Result<RunArray<T>, ArrowError> {
691 let physical_indices = run_array.get_physical_indices(logical_indices.values())?;
693
694 let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
698 let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
699 let mut new_physical_len = 1;
700 for ix in 1..physical_indices.len() {
701 if physical_indices[ix] != physical_indices[ix - 1] {
702 take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
703 new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
704 new_physical_len += 1;
705 }
706 }
707 take_value_indices
708 .append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap());
709 new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
710 let new_run_ends = unsafe {
711 ArrayDataBuilder::new(T::DATA_TYPE)
714 .len(new_physical_len)
715 .null_count(0)
716 .add_buffer(new_run_ends_builder.finish())
717 .build_unchecked()
718 };
719
720 let take_value_indices: PrimitiveArray<I> = unsafe {
721 ArrayDataBuilder::new(I::DATA_TYPE)
724 .len(new_physical_len)
725 .null_count(0)
726 .add_buffer(take_value_indices.finish())
727 .build_unchecked()
728 .into()
729 };
730
731 let new_values = take(run_array.values(), &take_value_indices, None)?;
732
733 let builder = ArrayDataBuilder::new(run_array.data_type().clone())
734 .len(physical_indices.len())
735 .add_child_data(new_run_ends)
736 .add_child_data(new_values.into_data());
737 let array_data = unsafe {
738 builder.build_unchecked()
741 };
742 Ok(array_data.into())
743}
744
745#[allow(clippy::type_complexity)]
751fn take_value_indices_from_list<IndexType, OffsetType>(
752 list: &GenericListArray<OffsetType::Native>,
753 indices: &PrimitiveArray<IndexType>,
754) -> Result<
755 (
756 PrimitiveArray<OffsetType>,
757 Vec<OffsetType::Native>,
758 MutableBuffer,
759 ),
760 ArrowError,
761>
762where
763 IndexType: ArrowPrimitiveType,
764 OffsetType: ArrowPrimitiveType,
765 OffsetType::Native: OffsetSizeTrait + std::ops::Add + Zero + One,
766 PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
767{
768 let offsets: &[OffsetType::Native] = list.value_offsets();
770
771 let mut new_offsets = Vec::with_capacity(indices.len());
772 let mut values = Vec::new();
773 let mut current_offset = OffsetType::Native::zero();
774 new_offsets.push(OffsetType::Native::zero());
776
777 let num_bytes = bit_util::ceil(indices.len(), 8);
779 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
780 let null_slice = null_buf.as_slice_mut();
781
782 for i in 0..indices.len() {
784 if indices.is_valid(i) {
785 let ix = indices
786 .value(i)
787 .to_usize()
788 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
789 let start = offsets[ix];
790 let end = offsets[ix + 1];
791 current_offset += end - start;
792 new_offsets.push(current_offset);
793
794 let mut curr = start;
795
796 while curr < end {
798 values.push(curr);
799 curr += One::one();
800 }
801 if !list.is_valid(ix) {
802 bit_util::unset_bit(null_slice, i);
803 }
804 } else {
805 bit_util::unset_bit(null_slice, i);
806 new_offsets.push(current_offset);
807 }
808 }
809
810 Ok((
811 PrimitiveArray::<OffsetType>::from(values),
812 new_offsets,
813 null_buf,
814 ))
815}
816
817fn take_value_indices_from_fixed_size_list<IndexType>(
819 list: &FixedSizeListArray,
820 indices: &PrimitiveArray<IndexType>,
821 length: <UInt32Type as ArrowPrimitiveType>::Native,
822) -> Result<PrimitiveArray<UInt32Type>, ArrowError>
823where
824 IndexType: ArrowPrimitiveType,
825{
826 let mut values = UInt32Builder::with_capacity(length as usize * indices.len());
827
828 for i in 0..indices.len() {
829 if indices.is_valid(i) {
830 let index = indices
831 .value(i)
832 .to_usize()
833 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
834 let start = list.value_offset(index) as <UInt32Type as ArrowPrimitiveType>::Native;
835
836 unsafe {
838 values.append_trusted_len_iter(start..start + length);
839 }
840 } else {
841 values.append_nulls(length as usize);
842 }
843 }
844
845 Ok(values.finish())
846}
847
848trait ToIndices {
851 type T: ArrowPrimitiveType;
852
853 fn to_indices(&self) -> PrimitiveArray<Self::T>;
854}
855
/// Implements `ToIndices` for `PrimitiveArray<$source>` by re-wrapping its existing values
/// buffer as `$target` values; the null buffer is carried over unchanged.
macro_rules! to_indices_reinterpret {
    ($source:ty, $target:ty) => {
        impl ToIndices for PrimitiveArray<$source> {
            type T = $target;

            fn to_indices(&self) -> PrimitiveArray<$target> {
                let reinterpreted =
                    ScalarBuffer::new(self.values().inner().clone(), 0, self.len());
                PrimitiveArray::new(reinterpreted, self.nulls().cloned())
            }
        }
    };
}
868
/// Implements `ToIndices` for `PrimitiveArray<$index>` where the array already has the
/// index representation: the conversion is a plain clone.
macro_rules! to_indices_identity {
    ($index:ty) => {
        impl ToIndices for PrimitiveArray<$index> {
            type T = $index;

            fn to_indices(&self) -> PrimitiveArray<$index> {
                Clone::clone(self)
            }
        }
    };
}
880
/// Implements `ToIndices` for `PrimitiveArray<$t>` by casting every value to the native type
/// of `$o` with `as`; the null buffer is carried over unchanged.
///
/// The associated type follows the `$o` argument, so it always agrees with the return type
/// of `to_indices`.
macro_rules! to_indices_widening {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                let cast = self.values().iter().copied().map(|x| x as _).collect();
                PrimitiveArray::new(cast, self.nulls().cloned())
            }
        }
    };
}
893
894to_indices_widening!(UInt8Type, UInt32Type);
895to_indices_widening!(Int8Type, UInt32Type);
896
897to_indices_widening!(UInt16Type, UInt32Type);
898to_indices_widening!(Int16Type, UInt32Type);
899
900to_indices_identity!(UInt32Type);
901to_indices_reinterpret!(Int32Type, UInt32Type);
902
903to_indices_identity!(UInt64Type);
904to_indices_reinterpret!(Int64Type, UInt64Type);
905
906pub fn take_record_batch(
946 record_batch: &RecordBatch,
947 indices: &dyn Array,
948) -> Result<RecordBatch, ArrowError> {
949 let columns = record_batch
950 .columns()
951 .iter()
952 .map(|c| take(c, indices, None))
953 .collect::<Result<Vec<_>, _>>()?;
954 RecordBatch::try_new(record_batch.schema(), columns)
955}
956
957#[cfg(test)]
958mod tests {
959 use super::*;
960 use arrow_array::builder::*;
961 use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
962 use arrow_data::ArrayData;
963 use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
964
965 fn test_take_decimal_arrays(
966 data: Vec<Option<i128>>,
967 index: &UInt32Array,
968 options: Option<TakeOptions>,
969 expected_data: Vec<Option<i128>>,
970 precision: &u8,
971 scale: &i8,
972 ) -> Result<(), ArrowError> {
973 let output = data
974 .into_iter()
975 .collect::<Decimal128Array>()
976 .with_precision_and_scale(*precision, *scale)
977 .unwrap();
978
979 let expected = expected_data
980 .into_iter()
981 .collect::<Decimal128Array>()
982 .with_precision_and_scale(*precision, *scale)
983 .unwrap();
984
985 let expected = Arc::new(expected) as ArrayRef;
986 let output = take(&output, index, options).unwrap();
987 assert_eq!(&output, &expected);
988 Ok(())
989 }
990
991 fn test_take_boolean_arrays(
992 data: Vec<Option<bool>>,
993 index: &UInt32Array,
994 options: Option<TakeOptions>,
995 expected_data: Vec<Option<bool>>,
996 ) {
997 let output = BooleanArray::from(data);
998 let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef;
999 let output = take(&output, index, options).unwrap();
1000 assert_eq!(&output, &expected)
1001 }
1002
1003 fn test_take_primitive_arrays<T>(
1004 data: Vec<Option<T::Native>>,
1005 index: &UInt32Array,
1006 options: Option<TakeOptions>,
1007 expected_data: Vec<Option<T::Native>>,
1008 ) -> Result<(), ArrowError>
1009 where
1010 T: ArrowPrimitiveType,
1011 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1012 {
1013 let output = PrimitiveArray::<T>::from(data);
1014 let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1015 let output = take(&output, index, options)?;
1016 assert_eq!(&output, &expected);
1017 Ok(())
1018 }
1019
1020 fn test_take_primitive_arrays_non_null<T>(
1021 data: Vec<T::Native>,
1022 index: &UInt32Array,
1023 options: Option<TakeOptions>,
1024 expected_data: Vec<Option<T::Native>>,
1025 ) -> Result<(), ArrowError>
1026 where
1027 T: ArrowPrimitiveType,
1028 PrimitiveArray<T>: From<Vec<T::Native>>,
1029 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1030 {
1031 let output = PrimitiveArray::<T>::from(data);
1032 let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1033 let output = take(&output, index, options)?;
1034 assert_eq!(&output, &expected);
1035 Ok(())
1036 }
1037
1038 fn test_take_impl_primitive_arrays<T, I>(
1039 data: Vec<Option<T::Native>>,
1040 index: &PrimitiveArray<I>,
1041 options: Option<TakeOptions>,
1042 expected_data: Vec<Option<T::Native>>,
1043 ) where
1044 T: ArrowPrimitiveType,
1045 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1046 I: ArrowPrimitiveType,
1047 {
1048 let output = PrimitiveArray::<T>::from(data);
1049 let expected = PrimitiveArray::<T>::from(expected_data);
1050 let output = take(&output, index, options).unwrap();
1051 let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
1052 assert_eq!(output, &expected)
1053 }
1054
1055 fn create_test_struct(values: Vec<Option<(Option<bool>, Option<i32>)>>) -> StructArray {
1057 let mut struct_builder = StructBuilder::new(
1058 Fields::from(vec![
1059 Field::new("a", DataType::Boolean, true),
1060 Field::new("b", DataType::Int32, true),
1061 ]),
1062 vec![
1063 Box::new(BooleanBuilder::with_capacity(values.len())),
1064 Box::new(Int32Builder::with_capacity(values.len())),
1065 ],
1066 );
1067
1068 for value in values {
1069 struct_builder
1070 .field_builder::<BooleanBuilder>(0)
1071 .unwrap()
1072 .append_option(value.and_then(|v| v.0));
1073 struct_builder
1074 .field_builder::<Int32Builder>(1)
1075 .unwrap()
1076 .append_option(value.and_then(|v| v.1));
1077 struct_builder.append(value.is_some());
1078 }
1079 struct_builder.finish()
1080 }
1081
// Nullable decimal values taken with indices that contain no nulls.
#[test]
fn test_take_decimal128_non_null_indices() {
    let (precision, scale) = (10_u8, 5_i8);
    let positions = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
    let source = vec![None, Some(3), Some(5), Some(2), Some(3), None];
    let expected = vec![None, None, Some(2), Some(3), Some(3), Some(5)];

    test_take_decimal_arrays(source, &positions, None, expected, &precision, &scale).unwrap();
}
1097
// Non-null decimal values taken with indices that contain a null.
#[test]
fn test_take_decimal128() {
    let (precision, scale) = (10_u8, 5_i8);
    let positions = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
    let source = vec![Some(0), Some(1), Some(2), Some(3), Some(4)];
    let expected = vec![Some(3), None, Some(1), Some(3), Some(2)];

    test_take_decimal_arrays(source, &positions, None, expected, &precision, &scale).unwrap();
}
1113
// Nullable Int8 values taken with indices that contain no nulls.
#[test]
fn test_take_primitive_non_null_indices() {
    let positions = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
    let source = vec![None, Some(3), Some(5), Some(2), Some(3), None];
    let expected = vec![None, None, Some(2), Some(3), Some(3), Some(5)];

    test_take_primitive_arrays::<Int8Type>(source, &positions, None, expected).unwrap();
}
1125
// Non-null Int8 values taken with indices that contain a null.
#[test]
fn test_take_primitive_non_null_values() {
    let positions = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
    let source = vec![Some(0), Some(1), Some(2), Some(3), Some(4)];
    let expected = vec![Some(3), None, Some(1), Some(3), Some(2)];

    test_take_primitive_arrays::<Int8Type>(source, &positions, None, expected).unwrap();
}
1137
// Neither the Int8 values nor the indices contain nulls.
#[test]
fn test_take_primitive_non_null() {
    let positions = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
    let source = vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)];
    let expected = vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)];

    test_take_primitive_arrays::<Int8Type>(source, &positions, None, expected).unwrap();
}
1149
// Sliced (offset) nullable indices applied to values without nulls.
#[test]
fn test_take_primitive_nullable_indices_non_null_values_with_offset() {
    let unsliced = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
    let sliced = unsliced.slice(2, 4);
    let index = sliced.as_any().downcast_ref::<UInt32Array>().unwrap();

    // Sanity check: the slice holds the last four entries.
    let expected_index = UInt32Array::from(vec![Some(2), Some(3), None, None]);
    assert_eq!(index, &expected_index);

    let source = vec![0, 10, 20, 30, 40, 50];
    let expected = vec![Some(20), Some(30), None, None];
    test_take_primitive_arrays_non_null::<Int64Type>(source, index, None, expected).unwrap();
}
1169
// Sliced (offset) nullable indices applied to values that contain nulls.
#[test]
fn test_take_primitive_nullable_indices_nullable_values_with_offset() {
    let unsliced = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
    let sliced = unsliced.slice(2, 4);
    let index = sliced.as_any().downcast_ref::<UInt32Array>().unwrap();

    // Sanity check: the slice holds the last four entries.
    let expected_index = UInt32Array::from(vec![Some(2), Some(3), None, None]);
    assert_eq!(index, &expected_index);

    let source = vec![None, None, Some(20), Some(30), Some(40), Some(50)];
    let expected = vec![Some(20), Some(30), None, None];
    test_take_primitive_arrays::<Int64Type>(source, index, None, expected).unwrap();
}
1189
// Exercises `take` over every primitive type family with nullable values and a nullable
// index array.
#[test]
fn test_take_primitive() {
    let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    // Runs the same input/expected pair against each listed primitive type, in order.
    macro_rules! check_take {
        ($($ty:ty),+ => $input:expr, $expected:expr) => {
            $(
                test_take_primitive_arrays::<$ty>($input, &index, None, $expected).unwrap();
            )+
        };
    }

    // Signed and unsigned integers with small non-negative values.
    check_take!(
        Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type =>
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );

    // Types whose native value is a signed integer, including a negative value.
    check_take!(
        Int64Type, IntervalYearMonthType =>
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );

    // Day-time intervals.
    let v1 = IntervalDayTime::new(0, 0);
    let v2 = IntervalDayTime::new(2, 0);
    let v3 = IntervalDayTime::new(-15, 0);
    check_take!(
        IntervalDayTimeType =>
        vec![Some(v1), None, Some(v2), Some(v3), None],
        vec![Some(v3), None, None, Some(v3), Some(v2)]
    );

    // Month-day-nanosecond intervals.
    let v1 = IntervalMonthDayNano::new(0, 0, 0);
    let v2 = IntervalMonthDayNano::new(2, 0, 0);
    let v3 = IntervalMonthDayNano::new(-15, 0, 0);
    check_take!(
        IntervalMonthDayNanoType =>
        vec![Some(v1), None, Some(v2), Some(v3), None],
        vec![Some(v3), None, None, Some(v3), Some(v2)]
    );

    // Durations at every resolution.
    check_take!(
        DurationSecondType,
        DurationMillisecondType,
        DurationMicrosecondType,
        DurationNanosecondType =>
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );

    // Floating point values.
    check_take!(
        Float32Type, Float64Type =>
        vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
        vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)]
    );
}
1353
/// Taking from a timezone-aware timestamp array must keep the timezone in the
/// output data type, even when one of the indices is null.
#[test]
fn test_take_preserve_timezone() {
    let index = Int64Array::from(vec![Some(0), None]);

    let input = TimestampNanosecondArray::from(vec![
        1_639_715_368_000_000_000,
        1_639_715_368_000_000_000,
    ])
    .with_timezone("UTC".to_string());
    let result = take(&input, &index, None).unwrap();

    // Compare the complete data type instead of matching and falling back to a
    // bare `panic!()`, so that a failure prints the type that was produced.
    assert_eq!(
        result.data_type(),
        &DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into()))
    );
}
1371
/// Runs the `test_take_impl_primitive_arrays` helper for several value types
/// using a nullable `Int64Array` as the indices.
#[test]
fn test_take_impl_primitive_with_int64_indices() {
    // Selects positions 3, 1, 3 and 2; the second slot is a null index.
    let picks = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    // int16 values
    test_take_impl_primitive_arrays::<Int16Type, Int64Type>(
        vec![Some(0), None, Some(2), Some(3), None],
        &picks,
        None,
        vec![Some(3), None, None, Some(3), Some(2)],
    );

    // int64 values
    test_take_impl_primitive_arrays::<Int64Type, Int64Type>(
        vec![Some(0), None, Some(2), Some(-15), None],
        &picks,
        None,
        vec![Some(-15), None, None, Some(-15), Some(2)],
    );

    // uint64 values
    test_take_impl_primitive_arrays::<UInt64Type, Int64Type>(
        vec![Some(0), None, Some(2), Some(3), None],
        &picks,
        None,
        vec![Some(3), None, None, Some(3), Some(2)],
    );

    // millisecond durations
    test_take_impl_primitive_arrays::<DurationMillisecondType, Int64Type>(
        vec![Some(0), None, Some(2), Some(-15), None],
        &picks,
        None,
        vec![Some(-15), None, None, Some(-15), Some(2)],
    );

    // float32 values
    test_take_impl_primitive_arrays::<Float32Type, Int64Type>(
        vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
        &picks,
        None,
        vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
    );
}
1416
/// Runs the `test_take_impl_primitive_arrays` helper for several value types
/// using a nullable `UInt8Array` as the indices.
#[test]
fn test_take_impl_primitive_with_uint8_indices() {
    // Selects positions 3, 1, 3 and 2; the second slot is a null index.
    let picks = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    // int16 values
    test_take_impl_primitive_arrays::<Int16Type, UInt8Type>(
        vec![Some(0), None, Some(2), Some(3), None],
        &picks,
        None,
        vec![Some(3), None, None, Some(3), Some(2)],
    );

    // millisecond durations
    test_take_impl_primitive_arrays::<DurationMillisecondType, UInt8Type>(
        vec![Some(0), None, Some(2), Some(-15), None],
        &picks,
        None,
        vec![Some(-15), None, None, Some(-15), Some(2)],
    );

    // float32 values
    test_take_impl_primitive_arrays::<Float32Type, UInt8Type>(
        vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
        &picks,
        None,
        vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
    );
}
1445
/// Boolean values picked with nullable `UInt32` indices.
#[test]
fn test_take_bool() {
    let picks = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    let values = vec![Some(false), None, Some(true), Some(false), None];
    // Slot 1 is null because the index is null, slot 2 because values[1] is null.
    let expected = vec![Some(false), None, None, Some(false), Some(true)];

    test_take_boolean_arrays(values, &picks, None, expected);
}
1457
/// Null index slots may hold arbitrary (here: out-of-range) raw values; they
/// are expected to yield nulls in the output.
#[test]
fn test_take_bool_nullable_index() {
    // Only the odd slots are valid; the even slots carry junk raw values.
    let validity = Buffer::from_iter(vec![false, true, false, true, false, true]);
    let raw_indices = Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2]);

    let picks = UInt32Array::from(
        ArrayData::try_new(
            DataType::UInt32,
            6,
            Some(validity),
            0,
            vec![raw_indices],
            vec![],
        )
        .unwrap(),
    );

    test_take_boolean_arrays(
        vec![Some(true), None, Some(false)],
        &picks,
        None,
        vec![None, Some(true), None, None, None, Some(false)],
    );
}
1480
/// Same nullable indices as `test_take_bool_nullable_index`, but applied to
/// boolean values that contain no nulls.
#[test]
fn test_take_bool_nullable_index_nonnull_values() {
    // Only the odd slots are valid; the even slots carry junk raw values.
    let validity = Buffer::from_iter(vec![false, true, false, true, false, true]);
    let raw_indices = Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2]);

    let picks = UInt32Array::from(
        ArrayData::try_new(
            DataType::UInt32,
            6,
            Some(validity),
            0,
            vec![raw_indices],
            vec![],
        )
        .unwrap(),
    );

    test_take_boolean_arrays(
        vec![Some(true), Some(true), Some(false)],
        &picks,
        None,
        vec![None, Some(true), None, Some(true), None, Some(false)],
    );
}
1503
/// Boolean values picked with an index array that has been sliced, i.e. one
/// that starts at a non-zero offset.
#[test]
fn test_take_bool_with_offset() {
    let all_picks = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]);
    // Keep only [1, 3, 2, null]
    let sliced = all_picks.slice(2, 4);
    let picks = sliced
        .as_any()
        .downcast_ref::<PrimitiveArray<UInt32Type>>()
        .unwrap();

    let values = vec![Some(false), None, Some(true), Some(false), None];
    let expected = vec![None, Some(false), Some(true), None];

    test_take_boolean_arrays(values, picks, None, expected);
}
1521
1522 fn _test_take_string<'a, K>()
1523 where
1524 K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static,
1525 {
1526 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]);
1527
1528 let array = K::from(vec![
1529 Some("one"),
1530 None,
1531 Some("three"),
1532 Some("four"),
1533 Some("five"),
1534 ]);
1535 let actual = take(&array, &index, None).unwrap();
1536 assert_eq!(actual.len(), index.len());
1537
1538 let actual = actual.as_any().downcast_ref::<K>().unwrap();
1539
1540 let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]);
1541
1542 assert_eq!(actual, &expected);
1543 }
1544
/// String `take` test instantiated for `StringArray`.
#[test]
fn test_take_string() {
    _test_take_string::<StringArray>();
}
1549
/// String `take` test instantiated for `LargeStringArray`.
#[test]
fn test_take_large_string() {
    _test_take_string::<LargeStringArray>();
}
1554
/// `take` on strings where the indices are a slice of a longer index array.
#[test]
fn test_take_slice_string() {
    let source = StringArray::from(vec![Some("hello"), None, Some("world"), None, Some("hi")]);

    let all_picks = Int32Array::from(vec![Some(0), Some(1), None, Some(0), Some(2)]);
    // Drops the leading index, leaving [1, null, 0, 2]
    let picks = all_picks.slice(1, 4);

    let taken = take(&source, &picks, None).unwrap();

    let expected = StringArray::from(vec![None, None, Some("hello"), Some("world")]);
    assert_eq!(taken.as_ref(), &expected);
}
1564
1565 fn _test_byte_view<T>()
1566 where
1567 T: ByteViewType,
1568 str: AsRef<T::Native>,
1569 T::Native: PartialEq,
1570 {
1571 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]);
1572 let array = {
1573 let mut builder = GenericByteViewBuilder::<T>::new();
1575 builder.append_value("hello");
1576 builder.append_value("world");
1577 builder.append_null();
1578 builder.append_value("large payload over 12 bytes");
1579 builder.append_value("lulu");
1580 builder.finish()
1581 };
1582
1583 let actual = take(&array, &index, None).unwrap();
1584
1585 assert_eq!(actual.len(), index.len());
1586
1587 let expected = {
1588 let mut builder = GenericByteViewBuilder::<T>::new();
1590 builder.append_value("large payload over 12 bytes");
1591 builder.append_null();
1592 builder.append_value("world");
1593 builder.append_value("large payload over 12 bytes");
1594 builder.append_value("lulu");
1595 builder.append_null();
1596 builder.finish()
1597 };
1598
1599 assert_eq!(actual.as_ref(), &expected);
1600 }
1601
/// Byte-view `take` test instantiated for `StringViewType`.
#[test]
fn test_take_string_view() {
    _test_byte_view::<StringViewType>();
}
1606
/// Byte-view `take` test instantiated for `BinaryViewType`.
#[test]
fn test_take_binary_view() {
    _test_byte_view::<BinaryViewType>();
}
1611
// Builds a list array without nulls, takes from it with nullable indices and
// compares against a hand-built expectation. Parameters: the offset integer
// type, the `DataType` list variant, and the concrete list array type.
macro_rules! test_take_list {
    ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
        // Source lists: [[0, 0, 0], [-1, -2, -1], [], [2, 3]]
        let child = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
        let source_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
        let source_offsets = Buffer::from_slice_ref(&source_offsets);
        let list_type =
            DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, false)));
        let source = $list_array_type::from(
            ArrayData::builder(list_type.clone())
                .len(4)
                .add_buffer(source_offsets)
                .add_child_data(child)
                .build()
                .unwrap(),
        );

        // Expected lists: [[2, 3], null, [-1, -2, -1], [], [0, 0, 0]]
        let picks = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]);

        let taken = take(&source, &picks, None).unwrap();
        let taken: &$list_array_type = taken.as_any().downcast_ref::<$list_array_type>().unwrap();

        let want_child = Int32Array::from(vec![
            Some(2),
            Some(3),
            Some(-1),
            Some(-2),
            Some(-1),
            Some(0),
            Some(0),
            Some(0),
        ])
        .into_data();
        let want_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8];
        let want_offsets = Buffer::from_slice_ref(&want_offsets);
        // The output validity is exactly the validity of the indices
        let want = $list_array_type::from(
            ArrayData::builder(list_type)
                .len(5)
                .nulls(picks.nulls().cloned())
                .add_buffer(want_offsets)
                .add_child_data(want_child)
                .build()
                .unwrap(),
        );

        assert_eq!(taken, &want);
    }};
}
1666
// Like `test_take_list!`, but the child values contain nulls while every list
// slot itself is valid. Parameters: the offset integer type, the `DataType`
// list variant, and the concrete list array type.
macro_rules! test_take_list_with_value_nulls {
    ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
        // Source lists: [[0, null, 0], [-1, -2, 3], [null], [5, null]]
        let child = Int32Array::from(vec![
            Some(0),
            None,
            Some(0),
            Some(-1),
            Some(-2),
            Some(3),
            None,
            Some(5),
            None,
        ])
        .into_data();
        let source_offsets: [$offset_type; 5] = [0, 3, 6, 7, 9];
        let source_offsets = Buffer::from_slice_ref(&source_offsets);
        let list_type =
            DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
        let source = $list_array_type::from(
            ArrayData::builder(list_type.clone())
                .len(4)
                .add_buffer(source_offsets)
                // every validity bit set: no list slot is null
                .null_bit_buffer(Some(Buffer::from([0b11111111])))
                .add_child_data(child)
                .build()
                .unwrap(),
        );

        // Expected lists: [[null], null, [-1, -2, 3], [5, null], [0, null, 0]]
        let picks = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);

        let taken = take(&source, &picks, None).unwrap();
        let taken: &$list_array_type = taken.as_any().downcast_ref::<$list_array_type>().unwrap();

        let want_child = Int32Array::from(vec![
            None,
            Some(-1),
            Some(-2),
            Some(3),
            Some(5),
            None,
            Some(0),
            None,
            Some(0),
        ])
        .into_data();
        let want_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9];
        let want_offsets = Buffer::from_slice_ref(&want_offsets);
        // The output validity is exactly the validity of the indices
        let want = $list_array_type::from(
            ArrayData::builder(list_type)
                .len(5)
                .nulls(picks.nulls().cloned())
                .add_buffer(want_offsets)
                .add_child_data(want_child)
                .build()
                .unwrap(),
        );

        assert_eq!(taken, &want);
    }};
}
1734
// Like `test_take_list!`, but one list slot is null in addition to null child
// values. Parameters: the offset integer type, the `DataType` list variant,
// and the concrete list array type.
macro_rules! test_take_list_with_nulls {
    ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
        // Source lists: [[0, null, 0], [-1, -2, 3], null, [5, null]]
        let child = Int32Array::from(vec![
            Some(0),
            None,
            Some(0),
            Some(-1),
            Some(-2),
            Some(3),
            Some(5),
            None,
        ])
        .into_data();
        let source_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
        let source_offsets = Buffer::from_slice_ref(&source_offsets);
        let list_type =
            DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
        let source = $list_array_type::from(
            ArrayData::builder(list_type.clone())
                .len(4)
                .add_buffer(source_offsets)
                // bit 2 cleared: the third list slot is null
                .null_bit_buffer(Some(Buffer::from([0b11111011])))
                .add_child_data(child)
                .build()
                .unwrap(),
        );

        // Expected lists: [null, null, [-1, -2, 3], [5, null], [0, null, 0]]
        let picks = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);

        let taken = take(&source, &picks, None).unwrap();
        let taken: &$list_array_type = taken.as_any().downcast_ref::<$list_array_type>().unwrap();

        let want_child = Int32Array::from(vec![
            Some(-1),
            Some(-2),
            Some(3),
            Some(5),
            None,
            Some(0),
            None,
            Some(0),
        ])
        .into_data();
        let want_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8];
        let want_offsets = Buffer::from_slice_ref(&want_offsets);
        // Output slots 2, 3 and 4 are valid; slots 0 and 1 are null
        let mut want_validity: [u8; 1] = [0; 1];
        for slot in 2..=4 {
            bit_util::set_bit(&mut want_validity, slot);
        }
        let want = $list_array_type::from(
            ArrayData::builder(list_type)
                .len(5)
                .null_bit_buffer(Some(Buffer::from(want_validity)))
                .add_buffer(want_offsets)
                .add_child_data(want_child)
                .build()
                .unwrap(),
        );

        assert_eq!(taken, &want);
    }};
}
1804
1805 fn do_take_fixed_size_list_test<T>(
1806 length: <Int32Type as ArrowPrimitiveType>::Native,
1807 input_data: Vec<Option<Vec<Option<T::Native>>>>,
1808 indices: Vec<<UInt32Type as ArrowPrimitiveType>::Native>,
1809 expected_data: Vec<Option<Vec<Option<T::Native>>>>,
1810 ) where
1811 T: ArrowPrimitiveType,
1812 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1813 {
1814 let indices = UInt32Array::from(indices);
1815
1816 let input_array = FixedSizeListArray::from_iter_primitive::<T, _, _>(input_data, length);
1817
1818 let output = take_fixed_size_list(&input_array, &indices, length as u32).unwrap();
1819
1820 let expected = FixedSizeListArray::from_iter_primitive::<T, _, _>(expected_data, length);
1821
1822 assert_eq!(&output, &expected)
1823 }
1824
/// `test_take_list!` instantiated for `ListArray` (i32 offsets).
#[test]
fn test_take_list() {
    test_take_list!(i32, List, ListArray);
}
1829
/// `test_take_list!` instantiated for `LargeListArray` (i64 offsets).
#[test]
fn test_take_large_list() {
    test_take_list!(i64, LargeList, LargeListArray);
}
1834
/// `test_take_list_with_value_nulls!` instantiated for `ListArray` (i32 offsets).
#[test]
fn test_take_list_with_value_nulls() {
    test_take_list_with_value_nulls!(i32, List, ListArray);
}
1839
/// `test_take_list_with_value_nulls!` instantiated for `LargeListArray` (i64 offsets).
#[test]
fn test_take_large_list_with_value_nulls() {
    test_take_list_with_value_nulls!(i64, LargeList, LargeListArray);
}
1844
/// `test_take_list_with_nulls!` instantiated for `ListArray` (i32 offsets).
#[test]
fn test_test_take_list_with_nulls() {
    test_take_list_with_nulls!(i32, List, ListArray);
}
1849
/// `test_take_list_with_nulls!` instantiated for `LargeListArray` (i64 offsets).
#[test]
fn test_test_take_large_list_with_nulls() {
    test_take_list_with_nulls!(i64, LargeList, LargeListArray);
}
1854
/// Fixed-size list cases: null child values, width-1 lists, and null list slots.
#[test]
fn test_take_fixed_size_list() {
    // Width 3 with null child values; the indices reverse the order.
    let rows = vec![
        Some(vec![None, Some(1), Some(2)]),
        Some(vec![Some(3), Some(4), None]),
        Some(vec![Some(6), Some(7), Some(8)]),
    ];
    let want = vec![
        Some(vec![Some(6), Some(7), Some(8)]),
        Some(vec![Some(3), Some(4), None]),
        Some(vec![None, Some(1), Some(2)]),
    ];
    do_take_fixed_size_list_test::<Int32Type>(3, rows, vec![2, 1, 0], want);

    // Width 1; only a subset of the rows is selected.
    let rows = vec![
        Some(vec![Some(1)]),
        Some(vec![Some(2)]),
        Some(vec![Some(3)]),
        Some(vec![Some(4)]),
        Some(vec![Some(5)]),
        Some(vec![Some(6)]),
        Some(vec![Some(7)]),
        Some(vec![Some(8)]),
    ];
    let want = vec![
        Some(vec![Some(3)]),
        Some(vec![Some(8)]),
        Some(vec![Some(1)]),
    ];
    do_take_fixed_size_list_test::<UInt8Type>(1, rows, vec![2, 7, 0], want);

    // Width 3 with a null list slot that is selected twice.
    let rows = vec![
        Some(vec![Some(10), Some(11), Some(12)]),
        Some(vec![Some(13), Some(14), Some(15)]),
        None,
        Some(vec![Some(16), Some(17), Some(18)]),
    ];
    let want = vec![
        Some(vec![Some(16), Some(17), Some(18)]),
        None,
        Some(vec![Some(13), Some(14), Some(15)]),
        None,
        Some(vec![Some(10), Some(11), Some(12)]),
    ];
    do_take_fixed_size_list_test::<UInt64Type>(3, rows, vec![3, 2, 1, 2, 0], want);
}
1910
/// Without bounds checking requested, an out-of-range index into a list array
/// is expected to panic with the message below.
#[test]
#[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
fn test_take_list_out_of_bounds() {
    // Source lists: [[0, 0, 0], [-1, -2, -1], [2, 3]]
    let child = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
    let list_offsets = Buffer::from_slice_ref([0, 3, 6, 8]);
    let list_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false)));
    let source = ListArray::from(
        ArrayData::builder(list_type)
            .len(3)
            .add_buffer(list_offsets)
            .add_child_data(child)
            .build()
            .unwrap(),
    );

    let picks = UInt32Array::from(vec![1000]);

    // No `TakeOptions` are supplied here, so the call panics rather than
    // returning an error.
    take(&source, &picks, None).unwrap();
}
1935
/// Taking the first entry of a two-entry map array.
#[test]
fn test_take_map() {
    let map_values = Int32Array::from(vec![1, 2, 3, 4]);
    // Entry 0 spans the first three key/value pairs, entry 1 the last pair
    let keys = vec!["a", "b", "c", "a"];
    let source = MapArray::new_from_strings(keys.into_iter(), &map_values, &[0, 3, 4]).unwrap();

    let picks = UInt32Array::from(vec![0]);
    let taken = take(&source, &picks, None).unwrap();

    let want_keys = vec!["a", "b", "c"];
    let want = MapArray::new_from_strings(
        want_keys.into_iter(),
        &map_values.slice(0, 3),
        &[0, 3],
    )
    .unwrap();
    let want: ArrayRef = Arc::new(want);

    assert_eq!(&want, &taken);
}
1956
/// `take` on struct arrays: first one built by `create_test_struct`, then a
/// struct array with no fields that only carries a null buffer.
#[test]
fn test_take_struct() {
    let source = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(false), Some(28))),
        Some((Some(false), Some(19))),
        Some((Some(true), Some(31))),
        None,
    ]);

    let picks = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
    let taken = take(&source, &picks, None).unwrap();
    let taken_struct: &StructArray = taken.as_any().downcast_ref::<StructArray>().unwrap();
    assert_eq!(picks.len(), taken_struct.len());
    // Only the last pick (source row 4) is a null struct
    assert_eq!(1, taken_struct.null_count());

    let want = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(true), Some(31))),
        Some((Some(false), Some(28))),
        Some((Some(true), Some(42))),
        Some((Some(false), Some(19))),
        None,
    ]);
    assert_eq!(&want, taken_struct);

    // Struct without any fields: only the validity can be taken
    let source_validity = NullBuffer::from(&[false, true, false, true, false, true]);
    let fieldless = StructArray::new_empty_fields(6, Some(source_validity));
    let picks = UInt32Array::from(vec![0, 2, 1, 4]);
    let taken = take(&fieldless, &picks, None).unwrap();

    let want_validity = NullBuffer::from(&[false, false, true, false]);
    let want = StructArray::new_empty_fields(4, Some(want_validity));
    assert_eq!(&want, taken.as_struct());
}
1993
/// `take` on a struct array where some of the indices are null.
#[test]
fn test_take_struct_with_null_indices() {
    let source = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(false), Some(28))),
        Some((Some(false), Some(19))),
        Some((Some(true), Some(31))),
        None,
    ]);

    let picks = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);
    let taken = take(&source, &picks, None).unwrap();
    let taken_struct: &StructArray = taken.as_any().downcast_ref::<StructArray>().unwrap();
    assert_eq!(picks.len(), taken_struct.len());
    // Two null indices plus one pick of the null source row 4
    assert_eq!(3, taken_struct.null_count());

    let want = create_test_struct(vec![
        None,
        Some((Some(true), Some(31))),
        Some((Some(false), Some(28))),
        None,
        Some((Some(true), Some(42))),
        None,
    ]);

    assert_eq!(&want, taken_struct);
}
2021
/// With `check_bounds` enabled, an index past the end of the values must be
/// reported as an error.
#[test]
fn test_take_out_of_bounds() {
    // The last index (6) is beyond the five values supplied below
    let picks = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]);

    let outcome = test_take_primitive_arrays::<Int64Type>(
        vec![Some(0), None, Some(2), Some(3), None],
        &picks,
        Some(TakeOptions { check_bounds: true }),
        vec![None],
    );

    assert!(outcome.is_err());
}
2036
/// Without bounds checking requested, an out-of-range index into a primitive
/// array is expected to panic with the message below.
#[test]
#[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
fn test_take_out_of_bounds_panic() {
    let picks = UInt32Array::from(vec![Some(1000)]);

    let values = vec![Some(0), Some(1), Some(2), Some(3)];
    test_take_primitive_arrays::<Int64Type>(values, &picks, None, vec![None]).unwrap();
}
2050
/// `take` on a `NullArray` that has fewer elements than there are indices;
/// the result is a `NullArray` with one slot per index.
#[test]
fn test_null_array_smaller_than_indices() {
    let source = NullArray::new(2);
    let picks = UInt32Array::from(vec![Some(0), None, Some(15)]);

    let taken = take(&source, &picks, None).unwrap();

    let want: ArrayRef = Arc::new(NullArray::new(3));
    assert_eq!(&taken, &want);
}
2060
/// `take` on a `NullArray` that has more elements than there are indices;
/// the result is a `NullArray` with one slot per index.
#[test]
fn test_null_array_larger_than_indices() {
    let source = NullArray::new(5);
    let picks = UInt32Array::from(vec![Some(0), None, Some(15)]);

    let taken = take(&source, &picks, None).unwrap();

    let want: ArrayRef = Arc::new(NullArray::new(3));
    assert_eq!(&taken, &want);
}
2070
/// With `check_bounds` enabled, an out-of-range index into a `NullArray` is
/// reported as a compute error with the exact message asserted below.
#[test]
fn test_null_array_indices_out_of_bounds() {
    let source = NullArray::new(5);
    let picks = UInt32Array::from(vec![Some(0), None, Some(15)]);
    let checked = TakeOptions { check_bounds: true };

    let failure = take(&source, &picks, Some(checked)).unwrap_err();

    assert_eq!(
        failure.to_string(),
        "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries"
    );
}
2082
/// `take` on a dictionary array: the dictionary values stay as they are and
/// only the keys are selected.
#[test]
fn test_take_dict() {
    let mut builder = StringDictionaryBuilder::<Int16Type>::new();

    // Logical content: foo, bar, "", null, foo, bar, bar, foo
    builder.append("foo").unwrap();
    builder.append("bar").unwrap();
    builder.append("").unwrap();
    builder.append_null();
    builder.append("foo").unwrap();
    builder.append("bar").unwrap();
    builder.append("bar").unwrap();
    builder.append("foo").unwrap();

    let source = builder.finish();
    let source_values = source.values().clone();
    let source_values = source_values
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();

    let picks = UInt32Array::from(vec![
        Some(0),
        Some(7),
        None,
        Some(5),
        Some(6),
        Some(2),
        Some(3),
    ]);

    let taken = take(&source, &picks, None).unwrap();
    let taken = taken
        .as_any()
        .downcast_ref::<DictionaryArray<Int16Type>>()
        .unwrap();

    let taken_values: StringArray = taken.values().to_data().into();

    // Both the source and the result hold the same three distinct values
    let want_values = StringArray::from(vec!["foo", "bar", ""]);
    assert_eq!(&want_values, source_values);
    assert_eq!(&want_values, &taken_values);

    // Picks 0 and 7 are "foo", 5 and 6 are "bar", 2 is "", 3 is the null entry
    let want_keys = Int16Array::from(vec![
        Some(0),
        Some(0),
        None,
        Some(1),
        Some(1),
        Some(2),
        None,
    ]);
    assert_eq!(taken.keys(), &want_keys);
}
2134
2135 fn build_generic_list<S, T>(data: Vec<Option<Vec<T::Native>>>) -> GenericListArray<S>
2136 where
2137 S: OffsetSizeTrait + 'static,
2138 T: ArrowPrimitiveType,
2139 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2140 {
2141 GenericListArray::from_iter_primitive::<T, _, _>(
2142 data.iter()
2143 .map(|x| x.as_ref().map(|x| x.iter().map(|x| Some(*x)))),
2144 )
2145 }
2146
/// Checks the child indices, offsets and validity bytes returned by
/// `take_value_indices_from_list` for a list array with i32 offsets.
#[test]
fn test_take_value_index_from_list() {
    let source = build_generic_list::<i32, Int32Type>(vec![
        Some(vec![0, 1]),
        Some(vec![2, 3, 4]),
        Some(vec![5, 6, 7, 8, 9]),
    ]);
    let picks = UInt32Array::from(vec![2, 0]);

    let (value_indices, new_offsets, validity) =
        take_value_indices_from_list(&source, &picks).unwrap();

    // Child positions of list 2 followed by those of list 0
    assert_eq!(value_indices, Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
    assert_eq!(new_offsets, vec![0, 5, 7]);
    assert_eq!(validity.as_slice(), &[0b11111111]);
}
2162
/// Checks the child indices, offsets and validity bytes returned by
/// `take_value_indices_from_list` for a list array with i64 offsets.
#[test]
fn test_take_value_index_from_large_list() {
    let source = build_generic_list::<i64, Int32Type>(vec![
        Some(vec![0, 1]),
        Some(vec![2, 3, 4]),
        Some(vec![5, 6, 7, 8, 9]),
    ]);
    let picks = UInt32Array::from(vec![2, 0]);

    let (value_indices, new_offsets, validity) =
        take_value_indices_from_list::<_, Int64Type>(&source, &picks).unwrap();

    // Child positions of list 2 followed by those of list 0
    assert_eq!(value_indices, Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
    assert_eq!(new_offsets, vec![0, 5, 7]);
    assert_eq!(validity.as_slice(), &[0b11111111]);
}
2179
/// `take_run` on a run-end encoded array: checks the logical length, the run
/// ends and the run values of the result.
#[test]
fn test_take_runs() {
    let logical: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];

    let mut run_builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
    run_builder.extend(logical.into_iter().map(Some));
    let source = run_builder.finish();

    let picks: PrimitiveArray<Int32Type> = vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect();

    let taken = take_run(&source, &picks).unwrap();

    assert_eq!(taken.len(), 7);
    assert_eq!(taken.run_ends().len(), 7);
    // Seven logical elements are expected to be stored as five runs
    assert_eq!(taken.run_ends().values(), &[1_i32, 3, 4, 5, 7]);

    let run_values = taken.values().as_primitive::<Int32Type>();
    assert_eq!(run_values.values(), &[2, 2, 2, 2, 1]);
}
2200
/// Checks the child indices produced by
/// `take_value_indices_from_fixed_size_list` for lists of width 3.
#[test]
fn test_take_value_index_from_fixed_list() {
    let rows = vec![
        Some(vec![Some(1), Some(2), None]),
        Some(vec![Some(4), None, Some(6)]),
        None,
        Some(vec![None, Some(8), Some(9)]),
    ];
    let source = FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(rows, 3);

    // Each picked list i contributes child positions 3*i, 3*i + 1, 3*i + 2
    let picks = UInt32Array::from(vec![2, 1, 0]);
    let value_indices = take_value_indices_from_fixed_size_list(&source, &picks, 3).unwrap();
    assert_eq!(
        value_indices,
        UInt32Array::from(vec![6, 7, 8, 3, 4, 5, 0, 1, 2])
    );

    let picks = UInt32Array::from(vec![3, 2, 1, 2, 0]);
    let value_indices = take_value_indices_from_fixed_size_list(&source, &picks, 3).unwrap();
    assert_eq!(
        value_indices,
        UInt32Array::from(vec![9, 10, 11, 6, 7, 8, 3, 4, 5, 6, 7, 8, 0, 1, 2])
    );
}
2226
/// Null index slots holding out-of-range raw values (400) must produce nulls
/// for primitive values.
#[test]
fn test_take_null_indices() {
    // The last two slots are null, so their raw value 400 is never a real index
    let picks = Int32Array::new(
        vec![1, 2, 400, 400].into(),
        Some(NullBuffer::from(vec![true, true, false, false])),
    );
    let source = Int32Array::from(vec![1, 23, 4, 5]);

    let taken = take(&source, &picks, None).unwrap();
    let collected = taken
        .as_primitive::<Int32Type>()
        .into_iter()
        .collect::<Vec<_>>();

    assert_eq!(&collected, &[Some(23), Some(4), None, None])
}
2242
/// A null index into a fixed-size list array yields null child values for
/// that slot.
#[test]
fn test_take_fixed_size_list_null_indices() {
    let picks = Int32Array::from_iter([Some(0), None]);

    // Two lists of width 2: [0, 1] and [2, 3]
    let child = Arc::new(Int32Array::from(vec![0, 1, 2, 3]));
    let item_field = Arc::new(Field::new_list_field(child.data_type().clone(), true));
    let source = FixedSizeListArray::try_new(item_field, 2, child, None).unwrap();

    let taken = take(&source, &picks, None).unwrap();
    let collected = taken
        .as_fixed_size_list()
        .values()
        .as_primitive::<Int32Type>()
        .into_iter()
        .collect::<Vec<_>>();

    assert_eq!(collected, &[Some(0), Some(1), None, None])
}
2259
/// Null index slots holding out-of-range raw values (400) must produce nulls
/// for string values.
#[test]
fn test_take_bytes_null_indices() {
    // The last two slots are null, so their raw value 400 is never a real index
    let picks = Int32Array::new(
        vec![0, 1, 400, 400].into(),
        Some(NullBuffer::from_iter(vec![true, true, false, false])),
    );
    let source = StringArray::from(vec![Some("foo"), None]);

    let taken = take(&source, &picks, None).unwrap();
    let collected = taken.as_string::<i32>().iter().collect::<Vec<_>>();

    assert_eq!(&collected, &[Some("foo"), None, None, None])
}
2271
/// `take` on a union array built without an offsets buffer; every slot has
/// type id 1 (the string child), and that child is checked after the take.
#[test]
fn test_take_union_sparse() {
    let struct_child = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(false), Some(28))),
        Some((Some(false), Some(19))),
        Some((Some(true), Some(31))),
        None,
    ]);
    let string_child = StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);
    let type_ids = [1; 5].into_iter().collect::<ScalarBuffer<i8>>();

    let union_fields = [
        (
            0,
            Arc::new(Field::new("f1", struct_child.data_type().clone(), true)),
        ),
        (
            1,
            Arc::new(Field::new("f2", string_child.data_type().clone(), true)),
        ),
    ]
    .into_iter()
    .collect();
    let children = vec![
        Arc::new(struct_child) as Arc<dyn Array>,
        Arc::new(string_child),
    ];
    let source = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();

    let picks = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
    let taken = take(&source, &picks, None).unwrap();
    let taken = taken.as_any().downcast_ref::<UnionArray>().unwrap();
    let taken_strings = taken.child(1);
    let taken_strings = taken_strings
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();

    let got = taken_strings.iter().collect::<Vec<_>>();
    let want = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
    assert_eq!(want, got);
}
2310
/// `take` on a union array built with an offsets buffer: checks the resulting
/// offsets, type ids and both children.
#[test]
fn test_take_union_dense() {
    // Source union
    let int_child = UInt32Array::from(vec![10, 20, 30, 40]);
    let string_child = StringArray::from(vec![Some("a"), None, Some("c"), Some("d")]);
    let source_type_ids = <ScalarBuffer<i8>>::from(vec![0, 1, 1, 0, 0, 1, 0]);
    let source_offsets = <ScalarBuffer<i32>>::from(vec![0, 0, 1, 1, 2, 2, 3]);

    // Expectations for the picks [0, 3, 1, 0, 2, 4]
    let want_type_ids = vec![0, 0, 1, 0, 1, 0];
    let want_offsets = vec![0, 1, 0, 2, 1, 3];
    let want_ints = vec![10, 20, 10, 30];
    let want_strings = vec![Some("a"), None];

    let union_fields = [
        (
            0,
            Arc::new(Field::new("f1", int_child.data_type().clone(), true)),
        ),
        (
            1,
            Arc::new(Field::new("f2", string_child.data_type().clone(), true)),
        ),
    ]
    .into_iter()
    .collect();

    let source = UnionArray::try_new(
        union_fields,
        source_type_ids,
        Some(source_offsets),
        vec![Arc::new(int_child), Arc::new(string_child)],
    )
    .unwrap();

    let picks = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);

    let taken = take(&source, &picks, None).unwrap();
    let taken = taken.as_any().downcast_ref::<UnionArray>().unwrap();

    assert_eq!(taken.offsets(), Some(&ScalarBuffer::from(want_offsets)));
    assert_eq!(taken.type_ids(), &ScalarBuffer::from(want_type_ids));
    assert_eq!(
        UInt32Array::from(taken.child(0).to_data()),
        UInt32Array::from(want_ints)
    );
    assert_eq!(
        StringArray::from(taken.child(1).to_data()),
        StringArray::from(want_strings)
    );
}
2367
/// `take` on a union produced by `UnionBuilder::new_dense`, compared against
/// a second union built directly with the expected contents.
#[test]
fn test_take_union_dense_using_builder() {
    let mut source_builder = UnionBuilder::new_dense();
    source_builder.append::<Int32Type>("a", 1).unwrap();
    source_builder.append::<Float64Type>("b", 3.0).unwrap();
    source_builder.append::<Int32Type>("a", 4).unwrap();
    source_builder.append::<Int32Type>("a", 5).unwrap();
    source_builder.append::<Float64Type>("b", 2.0).unwrap();
    let source = source_builder.build().unwrap();

    let picks = UInt32Array::from(vec![2, 0, 1, 2]);

    // Rows 2, 0, 1, 2 of the source
    let mut want_builder = UnionBuilder::new_dense();
    want_builder.append::<Int32Type>("a", 4).unwrap();
    want_builder.append::<Int32Type>("a", 1).unwrap();
    want_builder.append::<Float64Type>("b", 3.0).unwrap();
    want_builder.append::<Int32Type>("a", 4).unwrap();
    let want = want_builder.build().unwrap();

    assert_eq!(
        want.to_data(),
        take(&source, &picks, None).unwrap().to_data()
    );
}
2396
/// `take` on a union with offsets whose slots all share the single type id 0;
/// only the length of the result is checked.
#[test]
fn test_take_union_dense_all_match_issue_6206() {
    let union_fields = UnionFields::new(vec![0], vec![Field::new("a", DataType::Int64, false)]);
    let int_child = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));

    let source = UnionArray::try_new(
        union_fields,
        ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]),
        Some(ScalarBuffer::from_iter(0_i32..5)),
        vec![int_child],
    )
    .unwrap();

    let picks = Int64Array::from(vec![0, 2, 4]);
    let taken = take(&source, &picks, None).unwrap();

    assert_eq!(taken.len(), 3);
}
2414}