parquet_derive/lib.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! This crate provides procedural macros to derive
//! implementations of the `RecordWriter` and `RecordReader`
//! traits from the `parquet` crate.
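//!
//! A minimal sketch of the intended usage (the struct and field names below are
//! illustrative only):
//!
//! ```rust
//! use parquet_derive::{ParquetRecordReader, ParquetRecordWriter};
//!
//! #[derive(ParquetRecordWriter, ParquetRecordReader)]
//! struct MyRow {
//!     id: i64,
//!     name: String,
//! }
//! ```
//!
//! See the documentation of `ParquetRecordWriter` and `ParquetRecordReader`
//! for complete write and read examples.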

#![doc(
    html_logo_url = "https://raw.githubusercontent.com/apache/parquet-format/25f05e73d8cd7f5c83532ce51cb4f4de8ba5f2a2/logo/parquet-logos_1.svg",
    html_favicon_url = "https://raw.githubusercontent.com/apache/parquet-format/25f05e73d8cd7f5c83532ce51cb4f4de8ba5f2a2/logo/parquet-logos_1.svg"
)]
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
#![warn(missing_docs)]
#![recursion_limit = "128"]

extern crate proc_macro;
extern crate proc_macro2;
extern crate syn;
#[macro_use]
extern crate quote;

extern crate parquet;

use ::syn::{parse_macro_input, Data, DataStruct, DeriveInput};

mod parquet_field;

/// Derive flat, simple RecordWriter implementations.
///
/// Works by parsing a struct tagged with `#[derive(ParquetRecordWriter)]` and emitting
/// the correct writing code for each field of the struct. Column writers
/// are generated in the order they are defined.
///
/// It is up to the programmer to keep the order of the struct
/// fields lined up with the schema.
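///
/// For a struct `T`, the macro expands to an implementation of
/// `parquet::record::RecordWriter<T>` for `&[T]`, writing one column per field in
/// declaration order. A simplified sketch of the shape of the generated code (the
/// struct name is illustrative, and the bodies are elided):
///
/// ```ignore
/// impl parquet::record::RecordWriter<MyStruct> for &[MyStruct] {
///     fn write_to_row_group<W: std::io::Write + Send>(
///         &self,
///         row_group_writer: &mut parquet::file::writer::SerializedRowGroupWriter<'_, W>,
///     ) -> Result<(), parquet::errors::ParquetError> {
///         // writes each field to the next column writer, in declaration order
///     }
///
///     fn schema(&self) -> Result<parquet::schema::types::TypePtr, parquet::errors::ParquetError> {
///         // a group schema assembled from each field's Parquet type
///     }
/// }
/// ```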
///
/// Example:
///
/// ```rust
/// use parquet::file::properties::WriterProperties;
/// use parquet::file::writer::SerializedFileWriter;
/// use parquet::record::RecordWriter;
/// use parquet_derive::ParquetRecordWriter;
/// use std::fs::File;
/// use std::sync::Arc;
///
/// // For reader
/// use parquet::file::reader::{FileReader, SerializedFileReader};
/// use parquet::record::RecordReader;
/// use parquet_derive::ParquetRecordReader;
///
/// #[derive(Debug, ParquetRecordWriter, ParquetRecordReader)]
/// struct ACompleteRecord {
///     pub a_bool: bool,
///     pub a_string: String,
/// }
///
/// fn write_some_records() {
///     let samples = vec![
///         ACompleteRecord {
///             a_bool: true,
///             a_string: "I'm true".into(),
///         },
///         ACompleteRecord {
///             a_bool: false,
///             a_string: "I'm false".into(),
///         },
///     ];
///
///     let schema = samples.as_slice().schema().unwrap();
///
///     let props = Arc::new(WriterProperties::builder().build());
///
///     let file = File::create("example.parquet").unwrap();
///
///     let mut writer = SerializedFileWriter::new(file, schema, props).unwrap();
///
///     let mut row_group = writer.next_row_group().unwrap();
///
///     samples
///         .as_slice()
///         .write_to_row_group(&mut row_group)
///         .unwrap();
///
///     row_group.close().unwrap();
///
///     writer.close().unwrap();
/// }
///
/// fn read_some_records() -> Vec<ACompleteRecord> {
///     let mut samples: Vec<ACompleteRecord> = Vec::new();
///     let file = File::open("example.parquet").unwrap();
///
///     let reader = SerializedFileReader::new(file).unwrap();
///     let mut row_group = reader.get_row_group(0).unwrap();
///     samples.read_from_row_group(&mut *row_group, 2).unwrap();
///
///     samples
/// }
///
/// pub fn main() {
///     write_some_records();
///
///     let records = read_some_records();
///
///     std::fs::remove_file("example.parquet").unwrap();
///
///     assert_eq!(
///         format!("{:?}", records),
///         "[ACompleteRecord { a_bool: true, a_string: \"I'm true\" }, ACompleteRecord { a_bool: false, a_string: \"I'm false\" }]"
///     );
/// }
/// ```
///
#[proc_macro_derive(ParquetRecordWriter)]
pub fn parquet_record_writer(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input: DeriveInput = parse_macro_input!(input as DeriveInput);
    let fields = match input.data {
        Data::Struct(DataStruct { fields, .. }) => fields,
        Data::Enum(_) => unimplemented!("Enum currently is not supported"),
        Data::Union(_) => unimplemented!("Union currently is not supported"),
    };

    let field_infos: Vec<_> = fields.iter().map(parquet_field::Field::from).collect();

    let writer_snippets: Vec<proc_macro2::TokenStream> =
        field_infos.iter().map(|x| x.writer_snippet()).collect();

    let derived_for = input.ident;
    let generics = input.generics;

    let field_types: Vec<proc_macro2::TokenStream> =
        field_infos.iter().map(|x| x.parquet_type()).collect();

    (quote! {
        impl #generics ::parquet::record::RecordWriter<#derived_for #generics> for &[#derived_for #generics] {
            fn write_to_row_group<W: ::std::io::Write + Send>(
                &self,
                row_group_writer: &mut ::parquet::file::writer::SerializedRowGroupWriter<'_, W>
            ) -> ::std::result::Result<(), ::parquet::errors::ParquetError> {
                use ::parquet::column::writer::ColumnWriter;

                let mut row_group_writer = row_group_writer;
                let records = &self; // Used by all the writer snippets to be more clear

                #(
                    {
                        let mut some_column_writer = row_group_writer.next_column().unwrap();
                        if let Some(mut column_writer) = some_column_writer {
                            #writer_snippets
                            column_writer.close()?;
                        } else {
                            return Err(::parquet::errors::ParquetError::General("Failed to get next column".into()))
                        }
                    }
                );*

                Ok(())
            }

            fn schema(&self) -> ::std::result::Result<::parquet::schema::types::TypePtr, ::parquet::errors::ParquetError> {
                use ::parquet::schema::types::Type as ParquetType;
                use ::parquet::schema::types::TypePtr;
                use ::parquet::basic::LogicalType;

                let mut fields: ::std::vec::Vec<TypePtr> = ::std::vec::Vec::new();
                #(
                    #field_types
                );*;
                let group = ParquetType::group_type_builder("rust_schema")
                    .with_fields(fields)
                    .build()?;
                Ok(group.into())
            }
        }
    }).into()
}

/// Derive flat, simple RecordReader implementations.
///
/// Works by parsing a struct tagged with `#[derive(ParquetRecordReader)]` and emitting
/// the correct reading code for each field of the struct. Column readers
/// are generated by matching names in the schema to the names in the struct.
///
/// It is up to the programmer to ensure the names in the struct
/// fields line up with the schema.
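///
/// For a struct `T`, the macro expands to an implementation of
/// `parquet::record::RecordReader<T>` for `Vec<T>`, locating each column by the
/// field's name. A simplified sketch of the shape of the generated code (the struct
/// name is illustrative, and the body is elided):
///
/// ```ignore
/// impl parquet::record::RecordReader<MyStruct> for Vec<MyStruct> {
///     fn read_from_row_group(
///         &mut self,
///         row_group_reader: &mut dyn parquet::file::reader::RowGroupReader,
///         num_records: usize,
///     ) -> Result<(), parquet::errors::ParquetError> {
///         // looks up each field's column index by name, then appends
///         // `num_records` rows filled from the corresponding column readers
///     }
/// }
/// ```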
///
/// Example:
///
/// ```rust
/// use parquet::record::RecordReader;
/// use parquet::file::{serialized_reader::SerializedFileReader, reader::FileReader};
/// use parquet_derive::{ParquetRecordReader};
/// use std::fs::File;
///
/// #[derive(ParquetRecordReader)]
/// struct ACompleteRecord {
///     pub a_bool: bool,
///     pub a_string: String,
/// }
///
/// pub fn read_some_records() -> Vec<ACompleteRecord> {
///     let mut samples: Vec<ACompleteRecord> = Vec::new();
///     let file = File::open("some_file.parquet").unwrap();
///
///     let reader = SerializedFileReader::new(file).unwrap();
///     let mut row_group = reader.get_row_group(0).unwrap();
///     samples.read_from_row_group(&mut *row_group, 1).unwrap();
///     samples
/// }
/// ```
///
#[proc_macro_derive(ParquetRecordReader)]
pub fn parquet_record_reader(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input: DeriveInput = parse_macro_input!(input as DeriveInput);
    let fields = match input.data {
        Data::Struct(DataStruct { fields, .. }) => fields,
        Data::Enum(_) => unimplemented!("Enum currently is not supported"),
        Data::Union(_) => unimplemented!("Union currently is not supported"),
    };

    let field_infos: Vec<_> = fields.iter().map(parquet_field::Field::from).collect();
    let field_names: Vec<_> = fields.iter().map(|f| f.ident.clone()).collect();
    let reader_snippets: Vec<proc_macro2::TokenStream> =
        field_infos.iter().map(|x| x.reader_snippet()).collect();

    let derived_for = input.ident;
    let generics = input.generics;

    (quote! {
        impl #generics ::parquet::record::RecordReader<#derived_for #generics> for Vec<#derived_for #generics> {
            fn read_from_row_group(
                &mut self,
                row_group_reader: &mut dyn ::parquet::file::reader::RowGroupReader,
                num_records: usize,
            ) -> ::std::result::Result<(), ::parquet::errors::ParquetError> {
                use ::parquet::column::reader::ColumnReader;

                let mut row_group_reader = row_group_reader;

                // key: parquet file column name, value: column index
                let mut name_to_index = std::collections::HashMap::new();
                for (idx, col) in row_group_reader.metadata().schema_descr().columns().iter().enumerate() {
                    name_to_index.insert(col.name().to_string(), idx);
                }

                for _ in 0..num_records {
                    self.push(#derived_for {
                        #(
                            #field_names: Default::default()
                        ),*
                    })
                }

                let records = self; // Used by all the reader snippets to be more clear

                #(
                    {
                        let idx: usize = match name_to_index.get(stringify!(#field_names)) {
                            Some(&col_idx) => col_idx,
                            None => {
                                let error_msg = format!("column name '{}' is not found in parquet file!", stringify!(#field_names));
                                return Err(::parquet::errors::ParquetError::General(error_msg));
                            }
                        };
                        if let Ok(mut column_reader) = row_group_reader.get_column_reader(idx) {
                            #reader_snippets
                        } else {
                            return Err(::parquet::errors::ParquetError::General("Failed to get next column".into()))
                        }
                    }
                );*

                Ok(())
            }
        }
    }).into()
}