parquet_derive/
lib.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! This crate provides procedural macros to derive implementations of
//! `RecordWriter` and `RecordReader`.
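//!
//! For example, deriving both traits on a small struct (the struct and
//! field names here are illustrative, not part of the API):
//!
//! ```rust
//! use parquet_derive::{ParquetRecordReader, ParquetRecordWriter};
//!
//! #[derive(ParquetRecordWriter, ParquetRecordReader)]
//! struct MyRow {
//!     id: i64,
//!     name: String,
//! }
//! ```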
20
#![doc(
    html_logo_url = "https://raw.githubusercontent.com/apache/parquet-format/25f05e73d8cd7f5c83532ce51cb4f4de8ba5f2a2/logo/parquet-logos_1.svg",
    html_favicon_url = "https://raw.githubusercontent.com/apache/parquet-format/25f05e73d8cd7f5c83532ce51cb4f4de8ba5f2a2/logo/parquet-logos_1.svg"
)]
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
#![warn(missing_docs)]
#![recursion_limit = "128"]

extern crate proc_macro;
extern crate proc_macro2;
extern crate syn;
#[macro_use]
extern crate quote;

extern crate parquet;

use ::syn::{parse_macro_input, Data, DataStruct, DeriveInput};

mod parquet_field;

/// Derive flat, simple `RecordWriter` implementations.
///
/// Works by parsing a struct tagged with `#[derive(ParquetRecordWriter)]` and emitting
/// the correct writing code for each field of the struct. Column writers
/// are generated in the order the fields are defined.
///
/// It is up to the programmer to keep the order of the struct
/// fields lined up with the schema.
///
/// Example:
///
/// ```rust
/// use parquet::file::properties::WriterProperties;
/// use parquet::file::writer::SerializedFileWriter;
/// use parquet::record::RecordWriter;
/// use parquet_derive::ParquetRecordWriter;
/// use std::fs::File;
/// use std::sync::Arc;
///
/// // For reader
/// use parquet::file::reader::{FileReader, SerializedFileReader};
/// use parquet::record::RecordReader;
/// use parquet_derive::ParquetRecordReader;
///
/// #[derive(Debug, ParquetRecordWriter, ParquetRecordReader)]
/// struct ACompleteRecord {
///     pub a_bool: bool,
///     pub a_string: String,
/// }
///
/// fn write_some_records() {
///     let samples = vec![
///         ACompleteRecord {
///             a_bool: true,
///             a_string: "I'm true".into(),
///         },
///         ACompleteRecord {
///             a_bool: false,
///             a_string: "I'm false".into(),
///         },
///     ];
///
///     let schema = samples.as_slice().schema().unwrap();
///
///     let props = Arc::new(WriterProperties::builder().build());
///
///     let file = File::create("example.parquet").unwrap();
///
///     let mut writer = SerializedFileWriter::new(file, schema, props).unwrap();
///
///     let mut row_group = writer.next_row_group().unwrap();
///
///     samples
///         .as_slice()
///         .write_to_row_group(&mut row_group)
///         .unwrap();
///
///     row_group.close().unwrap();
///
///     writer.close().unwrap();
/// }
///
/// fn read_some_records() -> Vec<ACompleteRecord> {
///     let mut samples: Vec<ACompleteRecord> = Vec::new();
///     let file = File::open("example.parquet").unwrap();
///
///     let reader = SerializedFileReader::new(file).unwrap();
///     let mut row_group = reader.get_row_group(0).unwrap();
///     samples.read_from_row_group(&mut *row_group, 2).unwrap();
///
///     samples
/// }
///
/// pub fn main() {
///     write_some_records();
///
///     let records = read_some_records();
///
///     std::fs::remove_file("example.parquet").unwrap();
///
///     assert_eq!(
///         format!("{:?}", records),
///         "[ACompleteRecord { a_bool: true, a_string: \"I'm true\" }, ACompleteRecord { a_bool: false, a_string: \"I'm false\" }]"
///     );
/// }
/// ```
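///
/// The derive also accepts `Option` fields, which map to optional (nullable)
/// columns. A minimal compile-only sketch; the struct and field names are
/// illustrative:
///
/// ```rust
/// use parquet_derive::ParquetRecordWriter;
///
/// #[derive(ParquetRecordWriter)]
/// struct MaybeRecord {
///     // An Option field is written as a nullable column.
///     pub maybe_count: Option<i64>,
/// }
/// ```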
///
#[proc_macro_derive(ParquetRecordWriter)]
pub fn parquet_record_writer(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input: DeriveInput = parse_macro_input!(input as DeriveInput);
    let fields = match input.data {
        Data::Struct(DataStruct { fields, .. }) => fields,
        Data::Enum(_) => unimplemented!("Enums are not currently supported"),
        Data::Union(_) => unimplemented!("Unions are not currently supported"),
    };

    let field_infos: Vec<_> = fields.iter().map(parquet_field::Field::from).collect();

    let writer_snippets: Vec<proc_macro2::TokenStream> =
        field_infos.iter().map(|x| x.writer_snippet()).collect();

    let derived_for = input.ident;
    let generics = input.generics;

    let field_types: Vec<proc_macro2::TokenStream> =
        field_infos.iter().map(|x| x.parquet_type()).collect();

    (quote! {
    impl #generics ::parquet::record::RecordWriter<#derived_for #generics> for &[#derived_for #generics] {
      fn write_to_row_group<W: ::std::io::Write + Send>(
        &self,
        row_group_writer: &mut ::parquet::file::writer::SerializedRowGroupWriter<'_, W>
      ) -> ::std::result::Result<(), ::parquet::errors::ParquetError> {
        use ::parquet::column::writer::ColumnWriter;

        let mut row_group_writer = row_group_writer;
        let records = &self; // Used by the writer snippets for clarity

        // One block per field, in declaration order: fetch the next column
        // writer, run the field's writer snippet, then close the column.
        #(
          {
              let some_column_writer = row_group_writer.next_column()?;
              if let Some(mut column_writer) = some_column_writer {
                  #writer_snippets
                  column_writer.close()?;
              } else {
                  return Err(::parquet::errors::ParquetError::General("Failed to get next column".into()))
              }
          }
        );*

        Ok(())
      }

      fn schema(&self) -> ::std::result::Result<::parquet::schema::types::TypePtr, ::parquet::errors::ParquetError> {
        use ::parquet::schema::types::Type as ParquetType;
        use ::parquet::schema::types::TypePtr;
        use ::parquet::basic::LogicalType;

        // Each field-type snippet pushes its parquet type onto `fields`.
        let mut fields: ::std::vec::Vec<TypePtr> = ::std::vec::Vec::new();
        #(
          #field_types
        );*;
        let group = ParquetType::group_type_builder("rust_schema")
          .with_fields(fields)
          .build()?;
        Ok(group.into())
      }
    }
  }).into()
}

/// Derive flat, simple `RecordReader` implementations.
///
/// Works by parsing a struct tagged with `#[derive(ParquetRecordReader)]` and emitting
/// the correct reading code for each field of the struct. Column readers
/// are generated by matching the column names in the schema to the field
/// names in the struct.
///
/// It is up to the programmer to ensure the names of the struct
/// fields line up with the schema.
///
/// Example:
///
/// ```rust
/// use parquet::record::RecordReader;
/// use parquet::file::{serialized_reader::SerializedFileReader, reader::FileReader};
/// use parquet_derive::ParquetRecordReader;
/// use std::fs::File;
///
/// #[derive(ParquetRecordReader)]
/// struct ACompleteRecord {
///     pub a_bool: bool,
///     pub a_string: String,
/// }
///
/// pub fn read_some_records() -> Vec<ACompleteRecord> {
///     let mut samples: Vec<ACompleteRecord> = Vec::new();
///     let file = File::open("some_file.parquet").unwrap();
///
///     let reader = SerializedFileReader::new(file).unwrap();
///     let mut row_group = reader.get_row_group(0).unwrap();
///     samples.read_from_row_group(&mut *row_group, 1).unwrap();
///     samples
/// }
/// ```
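///
/// Because columns are looked up by name, the struct's field order does not
/// have to match the column order in the file. A compile-only sketch; the
/// struct here is illustrative:
///
/// ```rust
/// use parquet_derive::ParquetRecordReader;
///
/// #[derive(ParquetRecordReader)]
/// struct ReorderedRecord {
///     // Fields may be declared in any order; each is matched to the
///     // parquet column with the same name.
///     pub a_string: String,
///     pub a_bool: bool,
/// }
/// ```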
///
#[proc_macro_derive(ParquetRecordReader)]
pub fn parquet_record_reader(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input: DeriveInput = parse_macro_input!(input as DeriveInput);
    let fields = match input.data {
        Data::Struct(DataStruct { fields, .. }) => fields,
        Data::Enum(_) => unimplemented!("Enums are not currently supported"),
        Data::Union(_) => unimplemented!("Unions are not currently supported"),
    };

    let field_infos: Vec<_> = fields.iter().map(parquet_field::Field::from).collect();
    let field_names: Vec<_> = fields.iter().map(|f| f.ident.clone()).collect();
    let reader_snippets: Vec<proc_macro2::TokenStream> =
        field_infos.iter().map(|x| x.reader_snippet()).collect();

    let derived_for = input.ident;
    let generics = input.generics;

    (quote! {
    impl #generics ::parquet::record::RecordReader<#derived_for #generics> for Vec<#derived_for #generics> {
      fn read_from_row_group(
        &mut self,
        row_group_reader: &mut dyn ::parquet::file::reader::RowGroupReader,
        num_records: usize,
      ) -> ::std::result::Result<(), ::parquet::errors::ParquetError> {
        use ::parquet::column::reader::ColumnReader;

        let mut row_group_reader = row_group_reader;

        // key: parquet column name, value: column index
        let mut name_to_index = std::collections::HashMap::new();
        for (idx, col) in row_group_reader.metadata().schema_descr().columns().iter().enumerate() {
            name_to_index.insert(col.name().to_string(), idx);
        }

        // Pre-fill with default values; the reader snippets below overwrite
        // each field, column by column.
        for _ in 0..num_records {
          self.push(#derived_for {
            #(
              #field_names: Default::default()
            ),*
          })
        }

        let records = self; // Used by the reader snippets for clarity

        // One block per field: look up the field's column by name, then run
        // that field's reader snippet against the column reader.
        #(
          {
              let idx: usize = match name_to_index.get(stringify!(#field_names)) {
                Some(&col_idx) => col_idx,
                None => {
                  let error_msg = format!("column name '{}' was not found in the parquet file", stringify!(#field_names));
                  return Err(::parquet::errors::ParquetError::General(error_msg));
                }
              };
              if let Ok(mut column_reader) = row_group_reader.get_column_reader(idx) {
                  #reader_snippets
              } else {
                  return Err(::parquet::errors::ParquetError::General("Failed to get column reader".into()))
              }
          }
        );*

        Ok(())
      }
    }
  }).into()
}