parquet/schema/
visitor.rs1use crate::basic::{ConvertedType, Repetition};
21use crate::errors::ParquetError::General;
22use crate::errors::Result;
23use crate::schema::types::{Type, TypePtr};
24
25pub trait TypeVisitor<R, C> {
27 fn visit_primitive(&mut self, primitive_type: TypePtr, context: C) -> Result<R>;
29
30 fn visit_list(&mut self, list_type: TypePtr, context: C) -> Result<R> {
52 match list_type.as_ref() {
53 Type::PrimitiveType { .. } => {
54 panic!("{list_type:?} is a list type and must be a group type")
55 }
56 Type::GroupType {
57 basic_info: _,
58 fields,
59 } if fields.len() == 1 => {
60 let list_item = fields.first().unwrap();
61
62 match list_item.as_ref() {
63 Type::PrimitiveType { .. } => {
64 if list_item.get_basic_info().repetition() == Repetition::REPEATED {
65 self.visit_list_with_item(list_type.clone(), list_item.clone(), context)
66 } else {
67 Err(General(
68 "Primitive element type of list must be repeated.".to_string(),
69 ))
70 }
71 }
72 Type::GroupType {
73 basic_info: _,
74 fields,
75 } => {
76 if fields.len() == 1
77 && list_item.name() != "array"
78 && list_item.name() != format!("{}_tuple", list_type.name())
79 {
80 self.visit_list_with_item(
81 list_type.clone(),
82 fields.first().unwrap().clone(),
83 context,
84 )
85 } else {
86 self.visit_list_with_item(list_type.clone(), list_item.clone(), context)
87 }
88 }
89 }
90 }
91 _ => Err(General(
92 "Group element type of list can only contain one field.".to_string(),
93 )),
94 }
95 }
96
97 fn visit_struct(&mut self, struct_type: TypePtr, context: C) -> Result<R>;
99
100 fn visit_map(&mut self, map_type: TypePtr, context: C) -> Result<R>;
102
103 fn dispatch(&mut self, cur_type: TypePtr, context: C) -> Result<R> {
105 if cur_type.is_primitive() {
106 self.visit_primitive(cur_type, context)
107 } else {
108 match cur_type.get_basic_info().converted_type() {
109 ConvertedType::LIST => self.visit_list(cur_type, context),
110 ConvertedType::MAP | ConvertedType::MAP_KEY_VALUE => {
111 self.visit_map(cur_type, context)
112 }
113 _ => self.visit_struct(cur_type, context),
114 }
115 }
116 }
117
118 fn visit_list_with_item(
120 &mut self,
121 list_type: TypePtr,
122 item_type: TypePtr,
123 context: C,
124 ) -> Result<R>;
125}
126
#[cfg(test)]
mod tests {
    use super::TypeVisitor;
    use crate::basic::Type as PhysicalType;
    use crate::errors::Result;
    use crate::schema::parser::parse_message_type;
    use crate::schema::types::TypePtr;
    use std::sync::Arc;

    /// Context threaded through the visitor; these tests carry no state in it.
    struct TestVisitorContext;

    /// Visitor that records which callbacks fired and checks that each visited type is a
    /// direct child of `root_type`.
    struct TestVisitor {
        primitive_visited: bool,
        struct_visited: bool,
        list_visited: bool,
        root_type: TypePtr,
    }

    impl TypeVisitor<bool, TestVisitorContext> for TestVisitor {
        fn visit_primitive(
            &mut self,
            primitive_type: TypePtr,
            _context: TestVisitorContext,
        ) -> Result<bool> {
            assert_eq!(
                self.get_field_by_name(primitive_type.name()).as_ref(),
                primitive_type.as_ref()
            );
            self.primitive_visited = true;
            Ok(true)
        }

        fn visit_struct(
            &mut self,
            struct_type: TypePtr,
            _context: TestVisitorContext,
        ) -> Result<bool> {
            assert_eq!(
                self.get_field_by_name(struct_type.name()).as_ref(),
                struct_type.as_ref()
            );
            self.struct_visited = true;
            Ok(true)
        }

        fn visit_map(&mut self, _map_type: TypePtr, _context: TestVisitorContext) -> Result<bool> {
            unimplemented!()
        }

        fn visit_list_with_item(
            &mut self,
            list_type: TypePtr,
            item_type: TypePtr,
            _context: TestVisitorContext,
        ) -> Result<bool> {
            assert_eq!(
                self.get_field_by_name(list_type.name()).as_ref(),
                list_type.as_ref()
            );
            // Every list in the test schemas resolves to an INT32 element named `element`.
            assert_eq!("element", item_type.name());
            assert_eq!(PhysicalType::INT32, item_type.get_physical_type());
            self.list_visited = true;
            Ok(true)
        }
    }

    impl TestVisitor {
        fn new(root: TypePtr) -> Self {
            Self {
                primitive_visited: false,
                struct_visited: false,
                list_visited: false,
                root_type: root,
            }
        }

        /// Returns the top-level field of `root_type` called `name`; panics if absent.
        fn get_field_by_name(&self, name: &str) -> TypePtr {
            self.root_type
                .get_fields()
                .iter()
                .find(|t| t.name() == name)
                .cloned()
                .unwrap()
        }
    }

    #[test]
    fn test_visitor() {
        // `e` is a standard 3-level list; `f` is a legacy 2-level list whose repeated
        // primitive child is itself the element.
        let message_type = "
          message spark_schema {
            REQUIRED INT32 a;
            OPTIONAL group inner_schema {
              REQUIRED INT32 b;
              REQUIRED DOUBLE c;
            }

            OPTIONAL group e (LIST) {
              REPEATED group list {
                REQUIRED INT32 element;
              }
            }

            OPTIONAL group f (LIST) {
              REPEATED INT32 element;
            }
          }
        ";

        let parquet_type = Arc::new(parse_message_type(message_type).unwrap());

        let mut visitor = TestVisitor::new(parquet_type.clone());
        for f in parquet_type.get_fields() {
            assert!(visitor.dispatch(f.clone(), TestVisitorContext).unwrap());
        }

        assert!(visitor.struct_visited);
        assert!(visitor.primitive_visited);
        assert!(visitor.list_visited);
    }

    #[test]
    fn test_visitor_rejects_non_repeated_primitive_list_item() {
        let message_type = "
          message spark_schema {
            OPTIONAL group e (LIST) {
              REQUIRED INT32 element;
            }
          }
        ";

        let parquet_type = Arc::new(parse_message_type(message_type).unwrap());

        let mut visitor = TestVisitor::new(parquet_type.clone());
        for f in parquet_type.get_fields() {
            assert!(visitor.dispatch(f.clone(), TestVisitorContext).is_err());
        }

        // The malformed list must be rejected before the item callback runs.
        assert!(!visitor.list_visited);
    }
}