1use std::ops::Deref;
19use std::sync::Arc;
20
21use crate::{ArrowError, DataType, Field, FieldRef};
22
23#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
58#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
59#[cfg_attr(feature = "serde", serde(transparent))]
60pub struct Fields(Arc<[FieldRef]>);
61
62impl std::fmt::Debug for Fields {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 self.0.as_ref().fmt(f)
65 }
66}
67
68impl Fields {
69 pub fn empty() -> Self {
71 Self(Arc::new([]))
72 }
73
74 pub fn size(&self) -> usize {
76 self.iter()
77 .map(|field| field.size() + std::mem::size_of::<FieldRef>())
78 .sum()
79 }
80
81 pub fn find(&self, name: &str) -> Option<(usize, &FieldRef)> {
83 self.0.iter().enumerate().find(|(_, b)| b.name() == name)
84 }
85
86 pub fn contains(&self, other: &Fields) -> bool {
93 if Arc::ptr_eq(&self.0, &other.0) {
94 return true;
95 }
96 self.len() == other.len()
97 && self
98 .iter()
99 .zip(other.iter())
100 .all(|(a, b)| Arc::ptr_eq(a, b) || a.contains(b))
101 }
102
103 pub fn filter_leaves<F: FnMut(usize, &FieldRef) -> bool>(&self, mut filter: F) -> Self {
140 self.try_filter_leaves(|idx, field| Ok(filter(idx, field)))
141 .unwrap()
142 }
143
144 pub fn try_filter_leaves<F: FnMut(usize, &FieldRef) -> Result<bool, ArrowError>>(
149 &self,
150 mut filter: F,
151 ) -> Result<Self, ArrowError> {
152 fn filter_field<F: FnMut(&FieldRef) -> Result<bool, ArrowError>>(
153 f: &FieldRef,
154 filter: &mut F,
155 ) -> Result<Option<FieldRef>, ArrowError> {
156 use DataType::*;
157
158 let v = match f.data_type() {
159 Dictionary(_, v) => v.as_ref(), RunEndEncoded(_, v) => v.data_type(), d => d,
162 };
163 let d = match v {
164 List(child) => {
165 let fields = filter_field(child, filter)?;
166 if let Some(fields) = fields {
167 List(fields)
168 } else {
169 return Ok(None);
170 }
171 }
172 LargeList(child) => {
173 let fields = filter_field(child, filter)?;
174 if let Some(fields) = fields {
175 LargeList(fields)
176 } else {
177 return Ok(None);
178 }
179 }
180 Map(child, ordered) => {
181 let fields = filter_field(child, filter)?;
182 if let Some(fields) = fields {
183 Map(fields, *ordered)
184 } else {
185 return Ok(None);
186 }
187 }
188 FixedSizeList(child, size) => {
189 let fields = filter_field(child, filter)?;
190 if let Some(fields) = fields {
191 FixedSizeList(fields, *size)
192 } else {
193 return Ok(None);
194 }
195 }
196 Struct(fields) => {
197 let filtered: Result<Vec<_>, _> =
198 fields.iter().map(|f| filter_field(f, filter)).collect();
199 let filtered: Fields = filtered?
200 .iter()
201 .filter_map(|f| f.as_ref().cloned())
202 .collect();
203
204 if filtered.is_empty() {
205 return Ok(None);
206 }
207
208 Struct(filtered)
209 }
210 Union(fields, mode) => {
211 let filtered: Result<Vec<_>, _> = fields
212 .iter()
213 .map(|(id, f)| filter_field(f, filter).map(|f| f.map(|f| (id, f))))
214 .collect();
215 let filtered: UnionFields = filtered?
216 .iter()
217 .filter_map(|f| f.as_ref().cloned())
218 .collect();
219
220 if filtered.is_empty() {
221 return Ok(None);
222 }
223
224 Union(filtered, *mode)
225 }
226 _ => {
227 let filtered = filter(f)?;
228 return Ok(filtered.then(|| f.clone()));
229 }
230 };
231 let d = match f.data_type() {
232 Dictionary(k, _) => Dictionary(k.clone(), Box::new(d)),
233 RunEndEncoded(v, f) => {
234 RunEndEncoded(v.clone(), Arc::new(f.as_ref().clone().with_data_type(d)))
235 }
236 _ => d,
237 };
238 Ok(Some(Arc::new(f.as_ref().clone().with_data_type(d))))
239 }
240
241 let mut leaf_idx = 0;
242 let mut filter = |f: &FieldRef| {
243 let t = filter(leaf_idx, f)?;
244 leaf_idx += 1;
245 Ok(t)
246 };
247
248 let filtered: Result<Vec<_>, _> = self
249 .0
250 .iter()
251 .map(|f| filter_field(f, &mut filter))
252 .collect();
253 let filtered = filtered?
254 .iter()
255 .filter_map(|f| f.as_ref().cloned())
256 .collect();
257 Ok(filtered)
258 }
259}
260
261impl Default for Fields {
262 fn default() -> Self {
263 Self::empty()
264 }
265}
266
267impl FromIterator<Field> for Fields {
268 fn from_iter<T: IntoIterator<Item = Field>>(iter: T) -> Self {
269 iter.into_iter().map(Arc::new).collect()
270 }
271}
272
273impl FromIterator<FieldRef> for Fields {
274 fn from_iter<T: IntoIterator<Item = FieldRef>>(iter: T) -> Self {
275 Self(iter.into_iter().collect())
276 }
277}
278
279impl From<Vec<Field>> for Fields {
280 fn from(value: Vec<Field>) -> Self {
281 value.into_iter().collect()
282 }
283}
284
285impl From<Vec<FieldRef>> for Fields {
286 fn from(value: Vec<FieldRef>) -> Self {
287 Self(value.into())
288 }
289}
290
291impl From<&[FieldRef]> for Fields {
292 fn from(value: &[FieldRef]) -> Self {
293 Self(value.into())
294 }
295}
296
297impl<const N: usize> From<[FieldRef; N]> for Fields {
298 fn from(value: [FieldRef; N]) -> Self {
299 Self(Arc::new(value))
300 }
301}
302
303impl Deref for Fields {
304 type Target = [FieldRef];
305
306 fn deref(&self) -> &Self::Target {
307 self.0.as_ref()
308 }
309}
310
311impl<'a> IntoIterator for &'a Fields {
312 type Item = &'a FieldRef;
313 type IntoIter = std::slice::Iter<'a, FieldRef>;
314
315 fn into_iter(self) -> Self::IntoIter {
316 self.0.iter()
317 }
318}
319
320#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
322#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
323#[cfg_attr(feature = "serde", serde(transparent))]
324pub struct UnionFields(Arc<[(i8, FieldRef)]>);
325
326impl std::fmt::Debug for UnionFields {
327 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
328 self.0.as_ref().fmt(f)
329 }
330}
331
332impl std::ops::Index<usize> for UnionFields {
341 type Output = (i8, FieldRef);
342
343 fn index(&self, index: usize) -> &Self::Output {
344 &self.0[index]
345 }
346}
347
348impl UnionFields {
349 pub fn empty() -> Self {
351 Self(Arc::from([]))
352 }
353
354 pub fn try_new<F, T>(type_ids: T, fields: F) -> Result<Self, ArrowError>
391 where
392 F: IntoIterator,
393 F::Item: Into<FieldRef>,
394 T: IntoIterator<Item = i8>,
395 {
396 let mut type_ids_iter = type_ids.into_iter();
397 let mut fields_iter = fields.into_iter().map(Into::into);
398
399 let mut seen_type_ids = 0u128;
400
401 let mut out = Vec::new();
402
403 loop {
404 match (type_ids_iter.next(), fields_iter.next()) {
405 (None, None) => return Ok(Self(out.into())),
406 (Some(type_id), Some(field)) => {
407 if type_id < 0 {
409 return Err(ArrowError::InvalidArgumentError(format!(
410 "type ids must be non-negative: {type_id}"
411 )));
412 }
413
414 let mask = 1_u128 << type_id;
416 if (seen_type_ids & mask) != 0 {
417 return Err(ArrowError::InvalidArgumentError(format!(
418 "duplicate type id: {type_id}"
419 )));
420 }
421
422 seen_type_ids |= mask;
423
424 out.push((type_id, field));
425 }
426 (None, Some(_)) => {
427 return Err(ArrowError::InvalidArgumentError(
428 "fields iterator has more elements than type_ids iterator".to_string(),
429 ));
430 }
431 (Some(_), None) => {
432 return Err(ArrowError::InvalidArgumentError(
433 "type_ids iterator has more elements than fields iterator".to_string(),
434 ));
435 }
436 }
437 }
438 }
439
440 pub fn from_fields<F>(fields: F) -> Self
468 where
469 F: IntoIterator,
470 F::Item: Into<FieldRef>,
471 {
472 fields
473 .into_iter()
474 .enumerate()
475 .map(|(i, field)| {
476 let id = i8::try_from(i).expect("UnionFields cannot contain more than 128 fields");
477
478 (id, field.into())
479 })
480 .collect()
481 }
482
483 pub fn try_from_fields<F>(fields: F) -> Result<Self, ArrowError>
518 where
519 F: IntoIterator,
520 F::Item: Into<FieldRef>,
521 {
522 let mut out = Vec::with_capacity(i8::MAX as usize + 1);
523
524 for (i, field) in fields.into_iter().enumerate() {
525 let id = i8::try_from(i).map_err(|_| {
526 ArrowError::InvalidArgumentError(
527 "UnionFields cannot contain more than 128 fields".into(),
528 )
529 })?;
530
531 out.push((id, field.into()));
532 }
533
534 Ok(Self(out.into()))
535 }
536
537 #[deprecated(since = "57.0.0", note = "Use `try_new` instead")]
564 pub fn new<F, T>(type_ids: T, fields: F) -> Self
565 where
566 F: IntoIterator,
567 F::Item: Into<FieldRef>,
568 T: IntoIterator<Item = i8>,
569 {
570 let fields = fields.into_iter().map(Into::into);
571 let mut set = 0_u128;
572 type_ids
573 .into_iter()
574 .inspect(|&idx| {
575 let mask = 1_u128 << idx;
576 if (set & mask) != 0 {
577 panic!("duplicate type id: {idx}");
578 } else {
579 set |= mask;
580 }
581 })
582 .zip(fields)
583 .collect()
584 }
585
586 pub fn size(&self) -> usize {
588 self.iter()
589 .map(|(_, field)| field.size() + std::mem::size_of::<(i8, FieldRef)>())
590 .sum()
591 }
592
593 pub fn len(&self) -> usize {
595 self.0.len()
596 }
597
598 pub fn is_empty(&self) -> bool {
600 self.0.is_empty()
601 }
602
603 pub fn iter(&self) -> impl Iterator<Item = (i8, &FieldRef)> + '_ {
605 self.0.iter().map(|(id, f)| (*id, f))
606 }
607
608 pub fn get(&self, index: usize) -> Option<&(i8, FieldRef)> {
630 self.0.get(index)
631 }
632
633 pub fn find_by_type_id(&self, type_id: i8) -> Option<(i8, &FieldRef)> {
636 self.iter().find(|&(i, _)| i == type_id)
637 }
638
639 pub fn find_by_field(&self, field: &Field) -> Option<(i8, &FieldRef)> {
642 self.iter().find(|&(_, f)| f.as_ref() == field)
643 }
644
645 pub(crate) fn try_merge(&mut self, other: &Self) -> Result<(), ArrowError> {
649 let mut output: Vec<_> = self.iter().map(|(id, f)| (id, f.clone())).collect();
651 for (field_type_id, from_field) in other.iter() {
652 let mut is_new_field = true;
653 for (self_type_id, self_field) in output.iter_mut() {
654 if from_field == self_field {
655 if *self_type_id != field_type_id {
658 return Err(ArrowError::SchemaError(format!(
659 "Fail to merge schema field '{}' because the self_type_id = {} does not equal field_type_id = {}",
660 self_field.name(),
661 self_type_id,
662 field_type_id
663 )));
664 }
665
666 is_new_field = false;
667 break;
668 }
669 }
670
671 if is_new_field {
672 output.push((field_type_id, from_field.clone()))
673 }
674 }
675 *self = output.into_iter().collect();
676 Ok(())
677 }
678}
679
680impl FromIterator<(i8, FieldRef)> for UnionFields {
681 fn from_iter<T: IntoIterator<Item = (i8, FieldRef)>>(iter: T) -> Self {
682 Self(iter.into_iter().collect())
683 }
684}
685
#[cfg(test)]
mod tests {
    use super::*;
    use crate::UnionMode;

    /// Asserts that `result` is an error whose message contains `needle`.
    fn assert_err_contains(result: Result<UnionFields, ArrowError>, needle: &str) {
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains(needle));
    }

    /// Builds `count` non-nullable `Int32` fields named `field0`, `field1`, ...
    fn int32_fields(count: usize) -> Vec<Field> {
        (0..count)
            .map(|i| Field::new(format!("field{}", i), DataType::Int32, false))
            .collect()
    }

    /// A struct, a list and a dictionary field, used to check nested types are accepted.
    fn complex_fields() -> Vec<Field> {
        let struct_field = Field::new(
            "struct_field",
            DataType::Struct(Fields::from(vec![
                Field::new("a", DataType::Int32, false),
                Field::new("b", DataType::Utf8, true),
            ])),
            false,
        );
        let list_field = Field::new_list(
            "list_field",
            Field::new("item", DataType::Float64, true),
            true,
        );
        let dict_field = Field::new(
            "dict_field",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            false,
        );
        vec![struct_field, list_field, dict_field]
    }

    #[test]
    fn test_filter() {
        let float_pair = Fields::from(vec![
            Field::new("a", DataType::Float32, false),
            Field::new("b", DataType::Float32, false),
        ]);
        let float_struct = DataType::Struct(float_pair.clone());

        // Leaf indices in visiting order:
        //   a=0, floats.{a,b}=1,2, b=3, c=4, d.{a,b}=5,6, e.floats.{a,b}=7,8,
        //   f.item=9, g.{keys,values}=10,11, h.{field1,field3}=12,13, i.{a,b}=14,15
        let schema = Fields::from(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("floats", float_struct.clone(), true),
            Field::new("b", DataType::Int16, true),
            Field::new(
                "c",
                DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                false,
            ),
            Field::new(
                "d",
                DataType::Dictionary(Box::new(DataType::Int32), Box::new(float_struct.clone())),
                false,
            ),
            Field::new_list(
                "e",
                Field::new("floats", float_struct.clone(), true),
                true,
            ),
            Field::new_fixed_size_list(
                "f",
                Field::new_list_field(DataType::Int32, false),
                3,
                false,
            ),
            Field::new_map(
                "g",
                "entries",
                Field::new("keys", DataType::LargeUtf8, false),
                Field::new("values", DataType::Int32, true),
                false,
                false,
            ),
            Field::new(
                "h",
                DataType::Union(
                    UnionFields::try_new(
                        vec![1, 3],
                        vec![
                            Field::new("field1", DataType::UInt8, false),
                            Field::new("field3", DataType::Utf8, false),
                        ],
                    )
                    .unwrap(),
                    UnionMode::Dense,
                ),
                true,
            ),
            Field::new(
                "i",
                DataType::RunEndEncoded(
                    Arc::new(Field::new("run_ends", DataType::Int32, false)),
                    Arc::new(Field::new("values", float_struct.clone(), true)),
                ),
                false,
            ),
        ]);

        // The struct of floats with only its first child left.
        let only_float_a = DataType::Struct(Fields::from(vec![float_pair[0].clone()]));

        // Select by leaf index: the top-level "a" and the first struct child.
        let kept = schema.filter_leaves(|idx, _| idx == 0 || idx == 1);
        assert_eq!(kept.len(), 2);
        assert_eq!(kept[0], schema[0]);
        assert_eq!(kept[1].data_type(), &only_float_a);

        // Select by leaf name: every leaf called "a", wherever it is nested.
        let kept = schema.filter_leaves(|_, leaf| leaf.name() == "a");
        assert_eq!(kept.len(), 5);
        assert_eq!(kept[0], schema[0]);
        assert_eq!(kept[1].data_type(), &only_float_a);
        assert_eq!(
            kept[2].data_type(),
            &DataType::Dictionary(Box::new(DataType::Int32), Box::new(only_float_a.clone()))
        );
        assert_eq!(
            kept[3].as_ref(),
            &Field::new_list("e", Field::new("floats", only_float_a.clone(), true), true)
        );
        assert_eq!(
            kept[4].as_ref(),
            &Field::new(
                "i",
                DataType::RunEndEncoded(
                    Arc::new(Field::new("run_ends", DataType::Int32, false)),
                    Arc::new(Field::new("values", only_float_a.clone(), true)),
                ),
                false,
            )
        );

        // "floats" only names nested fields, never a leaf, so nothing matches.
        let kept = schema.filter_leaves(|_, leaf| leaf.name() == "floats");
        assert_eq!(kept.len(), 0);

        // The single child of the fixed-size list.
        let kept = schema.filter_leaves(|idx, _| idx == 9);
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0], schema[6]);

        // Both children of the map.
        let kept = schema.filter_leaves(|idx, _| idx == 10 || idx == 11);
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0], schema[7]);

        // Only the first variant of the union.
        let first_variant_only = DataType::Union(
            UnionFields::try_new(vec![1], vec![Field::new("field1", DataType::UInt8, false)])
                .unwrap(),
            UnionMode::Dense,
        );
        let kept = schema.filter_leaves(|idx, _| idx == 12);
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0].data_type(), &first_variant_only);

        // Both children of the run-end encoded struct.
        let kept = schema.filter_leaves(|idx, _| idx == 14 || idx == 15);
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0], schema[9]);

        // Errors raised by the filter are propagated.
        let failed =
            schema.try_filter_leaves(|_, _| Err(ArrowError::SchemaError("error".to_string())));
        assert!(failed.is_err());
    }

    #[test]
    fn test_union_fields_try_new_valid() {
        let result = UnionFields::try_new(
            vec![1, 6, 7],
            vec![
                Field::new("f1", DataType::UInt8, false),
                Field::new("f6", DataType::Utf8, false),
                Field::new("f7", DataType::Int32, true),
            ],
        );
        assert!(result.is_ok());

        let union_fields = result.unwrap();
        assert_eq!(union_fields.len(), 3);

        let type_ids: Vec<i8> = union_fields.iter().map(|(id, _)| id).collect();
        assert_eq!(type_ids, vec![1, 6, 7]);
    }

    #[test]
    fn test_union_fields_try_new_empty() {
        let result = UnionFields::try_new(Vec::<i8>::new(), Vec::<Field>::new());
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_union_fields_try_new_duplicate_type_id() {
        let result = UnionFields::try_new(
            vec![1, 1],
            vec![
                Field::new("f1", DataType::UInt8, false),
                Field::new("f2", DataType::Utf8, false),
            ],
        );
        assert_err_contains(result, "duplicate type id: 1");
    }

    #[test]
    fn test_union_fields_try_new_duplicate_field() {
        // The same field under two different type ids is accepted.
        let field = Field::new("field", DataType::UInt8, false);
        let result = UnionFields::try_new(vec![1, 2], vec![field.clone(), field]);
        assert!(result.is_ok());
    }

    #[test]
    fn test_union_fields_try_new_more_type_ids() {
        let result = UnionFields::try_new(
            vec![1, 2, 3],
            vec![
                Field::new("f1", DataType::UInt8, false),
                Field::new("f2", DataType::Utf8, false),
            ],
        );
        assert_err_contains(result, "type_ids iterator has more elements");
    }

    #[test]
    fn test_union_fields_try_new_more_fields() {
        let result = UnionFields::try_new(
            vec![1, 2],
            vec![
                Field::new("f1", DataType::UInt8, false),
                Field::new("f2", DataType::Utf8, false),
                Field::new("f3", DataType::Int32, true),
            ],
        );
        assert_err_contains(result, "fields iterator has more elements");
    }

    #[test]
    fn test_union_fields_try_new_negative_type_ids() {
        let result = UnionFields::try_new(
            vec![-128, -1, 0, 127],
            vec![
                Field::new("field_min", DataType::UInt8, false),
                Field::new("field_neg", DataType::Utf8, false),
                Field::new("field_zero", DataType::Int32, true),
                Field::new("field_max", DataType::Boolean, false),
            ],
        );
        assert_err_contains(result, "type ids must be non-negative");
    }

    #[test]
    fn test_union_fields_try_new_complex_types() {
        let result = UnionFields::try_new(vec![0, 1, 2], complex_fields());
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 3);
    }

    #[test]
    fn test_union_fields_try_new_single_field() {
        let result = UnionFields::try_new(
            vec![42],
            vec![Field::new("only_field", DataType::Int64, false)],
        );
        assert!(result.is_ok());

        let union_fields = result.unwrap();
        assert_eq!(union_fields.len(), 1);

        let (first_id, _) = union_fields.iter().next().unwrap();
        assert_eq!(first_id, 42);
    }

    #[test]
    fn test_union_fields_try_from_fields_empty() {
        let result = UnionFields::try_from_fields(Vec::<Field>::new());
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_union_fields_try_from_fields_single() {
        let result =
            UnionFields::try_from_fields(vec![Field::new("only", DataType::Int64, false)]);
        assert!(result.is_ok());

        let union_fields = result.unwrap();
        assert_eq!(union_fields.len(), 1);

        let (first_id, _) = union_fields.iter().next().unwrap();
        assert_eq!(first_id, 0);
    }

    #[test]
    fn test_union_fields_try_from_fields_too_many() {
        let result = UnionFields::try_from_fields(int32_fields(200));
        assert_err_contains(result, "UnionFields cannot contain more than 128 fields");
    }

    #[test]
    fn test_union_fields_try_from_fields_max_valid() {
        // 128 fields use the type ids 0..=127, the largest set that fits in an i8.
        let result = UnionFields::try_from_fields(int32_fields(128));
        assert!(result.is_ok());

        let union_fields = result.unwrap();
        assert_eq!(union_fields.len(), 128);
        assert_eq!(union_fields.iter().map(|(id, _)| id).min().unwrap(), 0);
        assert_eq!(union_fields.iter().map(|(id, _)| id).max().unwrap(), 127);
    }

    #[test]
    fn test_union_fields_try_from_fields_over_max() {
        // One field more than fits.
        let result = UnionFields::try_from_fields(int32_fields(129));
        assert!(result.is_err());
    }

    #[test]
    fn test_union_fields_try_from_fields_complex_types() {
        let result = UnionFields::try_from_fields(complex_fields());
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 3);
    }
}