1use crate::array::print_long_array;
19use crate::{make_array, new_null_array, Array, ArrayRef, RecordBatch};
20use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer};
21use arrow_data::{ArrayData, ArrayDataBuilder};
22use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields};
23use std::sync::Arc;
24use std::{any::Any, ops::Index};
25
/// An array of struct values, stored as one child array per struct field.
///
/// The validating constructor (`try_new`) requires that every child array
/// has the same length and that its data type matches the corresponding field.
#[derive(Clone)]
pub struct StructArray {
    /// Number of rows (struct values) in the array.
    len: usize,
    /// Data type of the array. Expected to be `DataType::Struct`; the field
    /// accessors treat any other variant as unreachable.
    data_type: DataType,
    /// Row-level validity mask; `None` when no mask is stored.
    nulls: Option<NullBuffer>,
    /// Child arrays, one per struct field, in field order.
    fields: Vec<ArrayRef>,
}
83
84impl StructArray {
85 pub fn new(fields: Fields, arrays: Vec<ArrayRef>, nulls: Option<NullBuffer>) -> Self {
91 Self::try_new(fields, arrays, nulls).unwrap()
92 }
93
94 pub fn try_new(
106 fields: Fields,
107 arrays: Vec<ArrayRef>,
108 nulls: Option<NullBuffer>,
109 ) -> Result<Self, ArrowError> {
110 if fields.len() != arrays.len() {
111 return Err(ArrowError::InvalidArgumentError(format!(
112 "Incorrect number of arrays for StructArray fields, expected {} got {}",
113 fields.len(),
114 arrays.len()
115 )));
116 }
117 let len = arrays.first().map(|x| x.len()).unwrap_or_default();
118
119 if let Some(n) = nulls.as_ref() {
120 if n.len() != len {
121 return Err(ArrowError::InvalidArgumentError(format!(
122 "Incorrect number of nulls for StructArray, expected {len} got {}",
123 n.len(),
124 )));
125 }
126 }
127
128 for (f, a) in fields.iter().zip(&arrays) {
129 if f.data_type() != a.data_type() {
130 return Err(ArrowError::InvalidArgumentError(format!(
131 "Incorrect datatype for StructArray field {:?}, expected {} got {}",
132 f.name(),
133 f.data_type(),
134 a.data_type()
135 )));
136 }
137
138 if a.len() != len {
139 return Err(ArrowError::InvalidArgumentError(format!(
140 "Incorrect array length for StructArray field {:?}, expected {} got {}",
141 f.name(),
142 len,
143 a.len()
144 )));
145 }
146
147 if !f.is_nullable() {
148 if let Some(a) = a.logical_nulls() {
149 if !nulls.as_ref().map(|n| n.contains(&a)).unwrap_or_default() {
150 return Err(ArrowError::InvalidArgumentError(format!(
151 "Found unmasked nulls for non-nullable StructArray field {:?}",
152 f.name()
153 )));
154 }
155 }
156 }
157 }
158
159 Ok(Self {
160 len,
161 data_type: DataType::Struct(fields),
162 nulls: nulls.filter(|n| n.null_count() > 0),
163 fields: arrays,
164 })
165 }
166
167 pub fn new_null(fields: Fields, len: usize) -> Self {
169 let arrays = fields
170 .iter()
171 .map(|f| new_null_array(f.data_type(), len))
172 .collect();
173
174 Self {
175 len,
176 data_type: DataType::Struct(fields),
177 nulls: Some(NullBuffer::new_null(len)),
178 fields: arrays,
179 }
180 }
181
182 pub unsafe fn new_unchecked(
188 fields: Fields,
189 arrays: Vec<ArrayRef>,
190 nulls: Option<NullBuffer>,
191 ) -> Self {
192 if cfg!(feature = "force_validate") {
193 return Self::new(fields, arrays, nulls);
194 }
195
196 let len = arrays.first().map(|x| x.len()).unwrap_or_default();
197 Self {
198 len,
199 data_type: DataType::Struct(fields),
200 nulls,
201 fields: arrays,
202 }
203 }
204
205 pub fn new_empty_fields(len: usize, nulls: Option<NullBuffer>) -> Self {
211 if let Some(n) = &nulls {
212 assert_eq!(len, n.len())
213 }
214 Self {
215 len,
216 data_type: DataType::Struct(Fields::empty()),
217 fields: vec![],
218 nulls,
219 }
220 }
221
222 pub fn into_parts(self) -> (Fields, Vec<ArrayRef>, Option<NullBuffer>) {
224 let f = match self.data_type {
225 DataType::Struct(f) => f,
226 _ => unreachable!(),
227 };
228 (f, self.fields, self.nulls)
229 }
230
231 pub fn column(&self, pos: usize) -> &ArrayRef {
233 &self.fields[pos]
234 }
235
236 pub fn num_columns(&self) -> usize {
238 self.fields.len()
239 }
240
241 pub fn columns(&self) -> &[ArrayRef] {
243 &self.fields
244 }
245
246 pub fn column_names(&self) -> Vec<&str> {
248 match self.data_type() {
249 DataType::Struct(fields) => fields
250 .iter()
251 .map(|f| f.name().as_str())
252 .collect::<Vec<&str>>(),
253 _ => unreachable!("Struct array's data type is not struct!"),
254 }
255 }
256
257 pub fn fields(&self) -> &Fields {
259 match self.data_type() {
260 DataType::Struct(f) => f,
261 _ => unreachable!(),
262 }
263 }
264
265 pub fn column_by_name(&self, column_name: &str) -> Option<&ArrayRef> {
271 self.column_names()
272 .iter()
273 .position(|c| c == &column_name)
274 .map(|pos| self.column(pos))
275 }
276
277 pub fn slice(&self, offset: usize, len: usize) -> Self {
279 assert!(
280 offset.saturating_add(len) <= self.len,
281 "the length + offset of the sliced StructArray cannot exceed the existing length"
282 );
283
284 let fields = self.fields.iter().map(|a| a.slice(offset, len)).collect();
285
286 Self {
287 len,
288 data_type: self.data_type.clone(),
289 nulls: self.nulls.as_ref().map(|n| n.slice(offset, len)),
290 fields,
291 }
292 }
293}
294
295impl From<ArrayData> for StructArray {
296 fn from(data: ArrayData) -> Self {
297 let fields = data
298 .child_data()
299 .iter()
300 .map(|cd| make_array(cd.clone()))
301 .collect();
302
303 Self {
304 len: data.len(),
305 data_type: data.data_type().clone(),
306 nulls: data.nulls().cloned(),
307 fields,
308 }
309 }
310}
311
312impl From<StructArray> for ArrayData {
313 fn from(array: StructArray) -> Self {
314 let builder = ArrayDataBuilder::new(array.data_type)
315 .len(array.len)
316 .nulls(array.nulls)
317 .child_data(array.fields.iter().map(|x| x.to_data()).collect());
318
319 unsafe { builder.build_unchecked() }
320 }
321}
322
323impl TryFrom<Vec<(&str, ArrayRef)>> for StructArray {
324 type Error = ArrowError;
325
326 fn try_from(values: Vec<(&str, ArrayRef)>) -> Result<Self, ArrowError> {
328 let (fields, arrays): (Vec<_>, _) = values
329 .into_iter()
330 .map(|(name, array)| {
331 (
332 Field::new(name, array.data_type().clone(), array.is_nullable()),
333 array,
334 )
335 })
336 .unzip();
337
338 StructArray::try_new(fields.into(), arrays, None)
339 }
340}
341
342impl Array for StructArray {
343 fn as_any(&self) -> &dyn Any {
344 self
345 }
346
347 fn to_data(&self) -> ArrayData {
348 self.clone().into()
349 }
350
351 fn into_data(self) -> ArrayData {
352 self.into()
353 }
354
355 fn data_type(&self) -> &DataType {
356 &self.data_type
357 }
358
359 fn slice(&self, offset: usize, length: usize) -> ArrayRef {
360 Arc::new(self.slice(offset, length))
361 }
362
363 fn len(&self) -> usize {
364 self.len
365 }
366
367 fn is_empty(&self) -> bool {
368 self.len == 0
369 }
370
371 fn shrink_to_fit(&mut self) {
372 if let Some(nulls) = &mut self.nulls {
373 nulls.shrink_to_fit();
374 }
375 self.fields.iter_mut().for_each(|n| n.shrink_to_fit());
376 }
377
378 fn offset(&self) -> usize {
379 0
380 }
381
382 fn nulls(&self) -> Option<&NullBuffer> {
383 self.nulls.as_ref()
384 }
385
386 fn logical_null_count(&self) -> usize {
387 self.null_count()
389 }
390
391 fn get_buffer_memory_size(&self) -> usize {
392 let mut size = self.fields.iter().map(|a| a.get_buffer_memory_size()).sum();
393 if let Some(n) = self.nulls.as_ref() {
394 size += n.buffer().capacity();
395 }
396 size
397 }
398
399 fn get_array_memory_size(&self) -> usize {
400 let mut size = self.fields.iter().map(|a| a.get_array_memory_size()).sum();
401 size += std::mem::size_of::<Self>();
402 if let Some(n) = self.nulls.as_ref() {
403 size += n.buffer().capacity();
404 }
405 size
406 }
407}
408
409impl From<Vec<(FieldRef, ArrayRef)>> for StructArray {
410 fn from(v: Vec<(FieldRef, ArrayRef)>) -> Self {
411 let (fields, arrays): (Vec<_>, _) = v.into_iter().unzip();
412 StructArray::new(fields.into(), arrays, None)
413 }
414}
415
416impl std::fmt::Debug for StructArray {
417 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
418 writeln!(f, "StructArray")?;
419 writeln!(f, "-- validity:")?;
420 writeln!(f, "[")?;
421 print_long_array(self, f, |_array, _index, f| write!(f, "valid"))?;
422 writeln!(f, "]\n[")?;
423 for (child_index, name) in self.column_names().iter().enumerate() {
424 let column = self.column(child_index);
425 writeln!(
426 f,
427 "-- child {}: \"{}\" ({:?})",
428 child_index,
429 name,
430 column.data_type()
431 )?;
432 std::fmt::Debug::fmt(column, f)?;
433 writeln!(f)?;
434 }
435 write!(f, "]")
436 }
437}
438
439impl From<(Vec<(FieldRef, ArrayRef)>, Buffer)> for StructArray {
440 fn from(pair: (Vec<(FieldRef, ArrayRef)>, Buffer)) -> Self {
441 let len = pair.0.first().map(|x| x.1.len()).unwrap_or_default();
442 let (fields, arrays): (Vec<_>, Vec<_>) = pair.0.into_iter().unzip();
443 let nulls = NullBuffer::new(BooleanBuffer::new(pair.1, 0, len));
444 Self::new(fields.into(), arrays, Some(nulls))
445 }
446}
447
448impl From<RecordBatch> for StructArray {
449 fn from(value: RecordBatch) -> Self {
450 Self {
451 len: value.num_rows(),
452 data_type: DataType::Struct(value.schema().fields().clone()),
453 nulls: None,
454 fields: value.columns().to_vec(),
455 }
456 }
457}
458
459impl Index<&str> for StructArray {
460 type Output = ArrayRef;
461
462 fn index(&self, name: &str) -> &Self::Output {
472 self.column_by_name(name).unwrap()
473 }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, StringArray};
    use arrow_buffer::ToByteSlice;

    /// A `StructArray` built from `ArrayData` exposes the child data as columns.
    #[test]
    fn test_struct_array_builder() {
        let boolean_array = BooleanArray::from(vec![false, false, true, true]);
        let int_array = Int64Array::from(vec![42, 28, 19, 31]);

        let fields = vec![
            Field::new("a", DataType::Boolean, false),
            Field::new("b", DataType::Int64, false),
        ];
        let struct_array_data = ArrayData::builder(DataType::Struct(fields.into()))
            .len(4)
            .add_child_data(boolean_array.to_data())
            .add_child_data(int_array.to_data())
            .build()
            .unwrap();
        let struct_array = StructArray::from(struct_array_data);

        assert_eq!(struct_array.column(0).as_ref(), &boolean_array);
        assert_eq!(struct_array.column(1).as_ref(), &int_array);
    }

    /// Converting `(field, array)` pairs keeps the children, the length, and
    /// reports no nulls and a zero offset.
    #[test]
    fn test_struct_array_from() {
        let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
        let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));

        let struct_array = StructArray::from(vec![
            (
                Arc::new(Field::new("b", DataType::Boolean, false)),
                boolean.clone() as ArrayRef,
            ),
            (
                Arc::new(Field::new("c", DataType::Int32, false)),
                int.clone() as ArrayRef,
            ),
        ]);
        assert_eq!(struct_array.column(0).as_ref(), boolean.as_ref());
        assert_eq!(struct_array.column(1).as_ref(), int.as_ref());
        assert_eq!(4, struct_array.len());
        assert_eq!(0, struct_array.null_count());
        assert_eq!(0, struct_array.offset());
    }

    /// Children can be looked up by field name through the `Index<&str>` impl.
    #[test]
    fn test_struct_array_index_access() {
        let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
        let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));

        let struct_array = StructArray::from(vec![
            (
                Arc::new(Field::new("b", DataType::Boolean, false)),
                boolean.clone() as ArrayRef,
            ),
            (
                Arc::new(Field::new("c", DataType::Int32, false)),
                int.clone() as ArrayRef,
            ),
        ]);
        assert_eq!(struct_array["b"].as_ref(), boolean.as_ref());
        assert_eq!(struct_array["c"].as_ref(), int.as_ref());
    }

    /// `try_from` on `(name, array)` pairs keeps the children's own null
    /// bitmaps while the struct itself reports no nulls.
    #[test]
    fn test_struct_array_from_vec() {
        let strings: ArrayRef = Arc::new(StringArray::from(vec![
            Some("joe"),
            None,
            None,
            Some("mark"),
        ]));
        let ints: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4)]));

        let arr =
            StructArray::try_from(vec![("f1", strings.clone()), ("f2", ints.clone())]).unwrap();

        let struct_data = arr.into_data();
        assert_eq!(4, struct_data.len());
        assert_eq!(0, struct_data.null_count());

        // Validity 9 = 0b1001: rows 0 and 3 ("joe", "mark") are valid.
        let expected_string_data = ArrayData::builder(DataType::Utf8)
            .len(4)
            .null_bit_buffer(Some(Buffer::from(&[9_u8])))
            .add_buffer(Buffer::from([0, 3, 3, 3, 7].to_byte_slice()))
            .add_buffer(Buffer::from(b"joemark"))
            .build()
            .unwrap();

        // Validity 11 = 0b1011: rows 0, 1 and 3 are valid.
        let expected_int_data = ArrayData::builder(DataType::Int32)
            .len(4)
            .null_bit_buffer(Some(Buffer::from(&[11_u8])))
            .add_buffer(Buffer::from([1, 2, 0, 4].to_byte_slice()))
            .build()
            .unwrap();

        assert_eq!(expected_string_data, struct_data.child_data()[0]);
        assert_eq!(expected_int_data, struct_data.child_data()[1]);
    }

    /// `try_from` reports an error when the children have different lengths.
    #[test]
    fn test_struct_array_from_vec_error() {
        // Only 3 elements here, while `ints` below has 4.
        let strings: ArrayRef = Arc::new(StringArray::from(vec![
            Some("joe"),
            None,
            None,
        ]));
        let ints: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4)]));

        let err = StructArray::try_from(vec![("f1", strings.clone()), ("f2", ints.clone())])
            .unwrap_err()
            .to_string();

        assert_eq!(
            err,
            "Invalid argument error: Incorrect array length for StructArray field \"f2\", expected 3 got 4"
        )
    }

    /// A single child whose data type differs from its field makes `from` panic.
    #[test]
    #[should_panic(
        expected = "Incorrect datatype for StructArray field \\\"b\\\", expected Int16 got Boolean"
    )]
    fn test_struct_array_from_mismatched_types_single() {
        drop(StructArray::from(vec![(
            Arc::new(Field::new("b", DataType::Int16, false)),
            Arc::new(BooleanArray::from(vec![false, false, true, true])) as Arc<dyn Array>,
        )]));
    }

    /// With several mismatched children, the panic names the first one ("b").
    #[test]
    #[should_panic(
        expected = "Incorrect datatype for StructArray field \\\"b\\\", expected Int16 got Boolean"
    )]
    fn test_struct_array_from_mismatched_types_multiple() {
        drop(StructArray::from(vec![
            (
                Arc::new(Field::new("b", DataType::Int16, false)),
                Arc::new(BooleanArray::from(vec![false, false, true, true])) as Arc<dyn Array>,
            ),
            (
                Arc::new(Field::new("c", DataType::Utf8, false)),
                Arc::new(Int32Array::from(vec![42, 28, 19, 31])),
            ),
        ]));
    }

    /// Slicing a struct array slices its validity and every child consistently.
    #[test]
    fn test_struct_array_slice() {
        // Boolean child: only rows 0 and 4 are valid; row 4 holds `true`.
        let boolean_data = ArrayData::builder(DataType::Boolean)
            .len(5)
            .add_buffer(Buffer::from([0b00010000]))
            .null_bit_buffer(Some(Buffer::from([0b00010001])))
            .build()
            .unwrap();
        // Int child: only rows 1 and 2 are valid.
        let int_data = ArrayData::builder(DataType::Int32)
            .len(5)
            .add_buffer(Buffer::from([0, 28, 42, 0, 0].to_byte_slice()))
            .null_bit_buffer(Some(Buffer::from([0b00000110])))
            .build()
            .unwrap();

        let field_types = vec![
            Field::new("a", DataType::Boolean, true),
            Field::new("b", DataType::Int32, true),
        ];
        // Struct validity: row 3 is the only null row.
        let struct_array_data = ArrayData::builder(DataType::Struct(field_types.into()))
            .len(5)
            .add_child_data(boolean_data.clone())
            .add_child_data(int_data.clone())
            .null_bit_buffer(Some(Buffer::from([0b00010111])))
            .build()
            .unwrap();
        let struct_array = StructArray::from(struct_array_data);

        assert_eq!(5, struct_array.len());
        assert_eq!(1, struct_array.null_count());
        assert!(struct_array.is_valid(0));
        assert!(struct_array.is_valid(1));
        assert!(struct_array.is_valid(2));
        assert!(struct_array.is_null(3));
        assert!(struct_array.is_valid(4));
        assert_eq!(boolean_data, struct_array.column(0).to_data());
        assert_eq!(int_data, struct_array.column(1).to_data());

        let c0 = struct_array.column(0);
        let c0 = c0.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(5, c0.len());
        assert_eq!(3, c0.null_count());
        assert!(c0.is_valid(0));
        assert!(!c0.value(0));
        assert!(c0.is_null(1));
        assert!(c0.is_null(2));
        assert!(c0.is_null(3));
        assert!(c0.is_valid(4));
        assert!(c0.value(4));

        let c1 = struct_array.column(1);
        let c1 = c1.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(5, c1.len());
        assert_eq!(3, c1.null_count());
        assert!(c1.is_null(0));
        assert!(c1.is_valid(1));
        assert_eq!(28, c1.value(1));
        assert!(c1.is_valid(2));
        assert_eq!(42, c1.value(2));
        assert!(c1.is_null(3));
        assert!(c1.is_null(4));

        // The slice covers original rows 2, 3 and 4.
        let sliced_array = struct_array.slice(2, 3);
        let sliced_array = sliced_array.as_any().downcast_ref::<StructArray>().unwrap();
        assert_eq!(3, sliced_array.len());
        assert_eq!(1, sliced_array.null_count());
        assert!(sliced_array.is_valid(0));
        assert!(sliced_array.is_null(1));
        assert!(sliced_array.is_valid(2));

        let sliced_c0 = sliced_array.column(0);
        let sliced_c0 = sliced_c0.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(3, sliced_c0.len());
        assert!(sliced_c0.is_null(0));
        assert!(sliced_c0.is_null(1));
        assert!(sliced_c0.is_valid(2));
        assert!(sliced_c0.value(2));

        let sliced_c1 = sliced_array.column(1);
        let sliced_c1 = sliced_c1.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(3, sliced_c1.len());
        assert!(sliced_c1.is_valid(0));
        assert_eq!(42, sliced_c1.value(0));
        assert!(sliced_c1.is_null(1));
        assert!(sliced_c1.is_null(2));
    }

    /// Children of different lengths make `from` panic, naming the child that
    /// does not match the first child's length.
    #[test]
    #[should_panic(
        expected = "Incorrect array length for StructArray field \\\"c\\\", expected 1 got 2"
    )]
    fn test_invalid_struct_child_array_lengths() {
        drop(StructArray::from(vec![
            (
                Arc::new(Field::new("b", DataType::Float32, false)),
                Arc::new(Float32Array::from(vec![1.1])) as Arc<dyn Array>,
            ),
            (
                Arc::new(Field::new("c", DataType::Float64, false)),
                Arc::new(Float64Array::from(vec![2.2, 3.3])),
            ),
        ]));
    }

    /// An empty list of columns produces an empty struct array.
    #[test]
    fn test_struct_array_from_empty() {
        let sa = StructArray::from(vec![]);
        assert!(sa.is_empty())
    }

    /// A non-nullable field whose child contains a null is rejected when the
    /// struct has no null buffer to mask it.
    #[test]
    #[should_panic(expected = "Found unmasked nulls for non-nullable StructArray field \\\"c\\\"")]
    fn test_struct_array_from_mismatched_nullability() {
        drop(StructArray::from(vec![(
            Arc::new(Field::new("c", DataType::Int32, false)),
            Arc::new(Int32Array::from(vec![Some(42), None, Some(19)])) as ArrayRef,
        )]));
    }

    /// Pins the exact `Debug` output for a 30-row array in which every
    /// odd-indexed row is null at the struct level.
    #[test]
    fn test_struct_array_fmt_debug() {
        let arr: StructArray = StructArray::new(
            vec![Arc::new(Field::new("c", DataType::Int32, true))].into(),
            vec![Arc::new(Int32Array::from((0..30).collect::<Vec<_>>())) as ArrayRef],
            Some(NullBuffer::new(BooleanBuffer::from(
                (0..30).map(|i| i % 2 == 0).collect::<Vec<_>>(),
            ))),
        );
        assert_eq!(format!("{arr:?}"), "StructArray\n-- validity:\n[\n valid,\n null,\n valid,\n null,\n valid,\n null,\n valid,\n null,\n valid,\n null,\n ...10 elements...,\n valid,\n null,\n valid,\n null,\n valid,\n null,\n valid,\n null,\n valid,\n null,\n]\n[\n-- child 0: \"c\" (Int32)\nPrimitiveArray<Int32>\n[\n 0,\n 1,\n 2,\n 3,\n 4,\n 5,\n 6,\n 7,\n 8,\n 9,\n ...10 elements...,\n 20,\n 21,\n 22,\n 23,\n 24,\n 25,\n 26,\n 27,\n 28,\n 29,\n]\n]")
    }
}