parquet_derive/lib.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! This crate provides procedural macros to derive `RecordWriter` and
//! `RecordReader` implementations for simple, flat structs.
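//!
//! The derives work on flat structs whose fields are supported primitive
//! types; a single struct can derive both traits. A minimal sketch (the
//! struct and field names here are only illustrative):
//!
//! ```
//! use parquet_derive::{ParquetRecordReader, ParquetRecordWriter};
//!
//! #[derive(ParquetRecordWriter, ParquetRecordReader)]
//! struct MyRow {
//!     pub a_bool: bool,
//!     pub a_string: String,
//! }
//! ```
//!
//! See the `ParquetRecordWriter` and `ParquetRecordReader` documentation
//! below for complete write and read examples.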
20
#![doc(
    html_logo_url = "https://raw.githubusercontent.com/apache/parquet-format/25f05e73d8cd7f5c83532ce51cb4f4de8ba5f2a2/logo/parquet-logos_1.svg",
    html_favicon_url = "https://raw.githubusercontent.com/apache/parquet-format/25f05e73d8cd7f5c83532ce51cb4f4de8ba5f2a2/logo/parquet-logos_1.svg"
)]
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
#![warn(missing_docs)]
#![recursion_limit = "128"]

extern crate proc_macro;
extern crate proc_macro2;
extern crate syn;
#[macro_use]
extern crate quote;

extern crate parquet;

use ::syn::{parse_macro_input, Data, DataStruct, DeriveInput};

mod parquet_field;

/// Derive flat, simple RecordWriter implementations.
///
/// Works by parsing a struct tagged with `#[derive(ParquetRecordWriter)]` and emitting
/// the correct writing code for each field of the struct. Column writers
/// are generated in the order they are defined.
///
/// It is up to the programmer to keep the order of the struct
/// fields lined up with the schema.
///
/// Example:
///
/// ```no_run
/// use parquet_derive::ParquetRecordWriter;
/// use parquet::file::properties::WriterProperties;
/// use parquet::file::writer::SerializedFileWriter;
/// use parquet::record::RecordWriter;
/// use std::fs::File;
/// use std::sync::Arc;
///
/// #[derive(ParquetRecordWriter)]
/// struct ACompleteRecord<'a> {
///   pub a_bool: bool,
///   pub a_str: &'a str,
/// }
///
/// pub fn write_some_records() {
///   let samples = vec![
///     ACompleteRecord {
///       a_bool: true,
///       a_str: "I'm true"
///     },
///     ACompleteRecord {
///       a_bool: false,
///       a_str: "I'm false"
///     }
///   ];
///
///   // Create the output file; the schema is derived from the struct definition.
///   let file = File::create("some_file.parquet").unwrap();
///   let schema = samples.as_slice().schema().unwrap();
///
///   let props = Arc::new(WriterProperties::builder().build());
///   let mut writer = SerializedFileWriter::new(file, schema, props).unwrap();
///
///   let mut row_group = writer.next_row_group().unwrap();
///   samples.as_slice().write_to_row_group(&mut row_group).unwrap();
///   row_group.close().unwrap();
///   writer.close().unwrap();
/// }
/// ```
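///
/// In outline, the derive generates an implementation shaped roughly like the
/// following (a sketch, not the exact expansion):
///
/// ```text
/// impl<'a> RecordWriter<ACompleteRecord<'a>> for &[ACompleteRecord<'a>] {
///     fn write_to_row_group<W: Write + Send>(
///         &self,
///         row_group_writer: &mut SerializedRowGroupWriter<'_, W>,
///     ) -> Result<(), ParquetError> {
///         // one next_column()/write/close block per struct field,
///         // in declaration order
///     }
///
///     fn schema(&self) -> Result<TypePtr, ParquetError> {
///         // builds a "rust_schema" group with one column type per field
///     }
/// }
/// ```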
///
#[proc_macro_derive(ParquetRecordWriter)]
pub fn parquet_record_writer(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input: DeriveInput = parse_macro_input!(input as DeriveInput);

    // Only plain structs are supported; enums and unions are rejected.
    let fields = match input.data {
        Data::Struct(DataStruct { fields, .. }) => fields,
        Data::Enum(_) => unimplemented!("Enum currently is not supported"),
        Data::Union(_) => unimplemented!("Union currently is not supported"),
    };

    // Model each field, then generate its per-column writer code and the
    // matching parquet schema type.
    let field_infos: Vec<_> = fields.iter().map(parquet_field::Field::from).collect();

    let writer_snippets: Vec<proc_macro2::TokenStream> =
        field_infos.iter().map(|x| x.writer_snippet()).collect();

    let derived_for = input.ident;
    let generics = input.generics;

    let field_types: Vec<proc_macro2::TokenStream> =
        field_infos.iter().map(|x| x.parquet_type()).collect();

    (quote! {
    impl #generics ::parquet::record::RecordWriter<#derived_for #generics> for &[#derived_for #generics] {
      fn write_to_row_group<W: ::std::io::Write + Send>(
        &self,
        row_group_writer: &mut ::parquet::file::writer::SerializedRowGroupWriter<'_, W>
      ) -> Result<(), ::parquet::errors::ParquetError> {
        use ::parquet::column::writer::ColumnWriter;

        let mut row_group_writer = row_group_writer;
        let records = &self; // Used by all the writer snippets to be more clear

        #(
          {
              let mut some_column_writer = row_group_writer.next_column().unwrap();
              if let Some(mut column_writer) = some_column_writer {
                  #writer_snippets
                  column_writer.close()?;
              } else {
                  return Err(::parquet::errors::ParquetError::General("Failed to get next column".into()))
              }
          }
        );*

        Ok(())
      }

      fn schema(&self) -> Result<::parquet::schema::types::TypePtr, ::parquet::errors::ParquetError> {
        use ::parquet::schema::types::Type as ParquetType;
        use ::parquet::schema::types::TypePtr;
        use ::parquet::basic::LogicalType;

        let mut fields: ::std::vec::Vec<TypePtr> = ::std::vec::Vec::new();
        #(
          #field_types
        );*;
        let group = ParquetType::group_type_builder("rust_schema")
          .with_fields(fields)
          .build()?;
        Ok(group.into())
      }
    }
  }).into()
}

/// Derive flat, simple RecordReader implementations.
///
/// Works by parsing a struct tagged with `#[derive(ParquetRecordReader)]` and emitting
/// the correct reading code for each field of the struct. Column readers
/// are generated by matching names in the schema to the names in the struct.
///
/// It is up to the programmer to ensure the names of the struct
/// fields line up with the names in the schema.
///
/// Example:
///
/// ```no_run
/// use parquet::record::RecordReader;
/// use parquet::file::{serialized_reader::SerializedFileReader, reader::FileReader};
/// use parquet_derive::ParquetRecordReader;
/// use std::fs::File;
///
/// #[derive(ParquetRecordReader)]
/// struct ACompleteRecord {
///   pub a_bool: bool,
///   pub a_string: String,
/// }
///
/// pub fn read_some_records() -> Vec<ACompleteRecord> {
///   let mut samples: Vec<ACompleteRecord> = Vec::new();
///   let file = File::open("some_file.parquet").unwrap();
///
///   let reader = SerializedFileReader::new(file).unwrap();
///   let mut row_group = reader.get_row_group(0).unwrap();
///   samples.read_from_row_group(&mut *row_group, 1).unwrap();
///   samples
/// }
/// ```
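///
/// To read every record in a file rather than a fixed count, the row group
/// metadata can supply the row count. A minimal sketch (the `MyRow` struct and
/// its fields are only illustrative):
///
/// ```no_run
/// use parquet::file::{reader::FileReader, serialized_reader::SerializedFileReader};
/// use parquet::record::RecordReader;
/// use parquet_derive::ParquetRecordReader;
/// use std::fs::File;
///
/// #[derive(ParquetRecordReader)]
/// struct MyRow {
///   pub a_bool: bool,
///   pub a_string: String,
/// }
///
/// let file = File::open("some_file.parquet").unwrap();
/// let reader = SerializedFileReader::new(file).unwrap();
/// let mut rows: Vec<MyRow> = Vec::new();
///
/// // Read each row group in turn, using its metadata for the record count.
/// for i in 0..reader.num_row_groups() {
///   let mut row_group = reader.get_row_group(i).unwrap();
///   let num_rows = row_group.metadata().num_rows() as usize;
///   rows.read_from_row_group(&mut *row_group, num_rows).unwrap();
/// }
/// ```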
///
#[proc_macro_derive(ParquetRecordReader)]
pub fn parquet_record_reader(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input: DeriveInput = parse_macro_input!(input as DeriveInput);
    let fields = match input.data {
        Data::Struct(DataStruct { fields, .. }) => fields,
        Data::Enum(_) => unimplemented!("Enum currently is not supported"),
        Data::Union(_) => unimplemented!("Union currently is not supported"),
    };

    let field_infos: Vec<_> = fields.iter().map(parquet_field::Field::from).collect();
    let field_names: Vec<_> = fields.iter().map(|f| f.ident.clone()).collect();
    let reader_snippets: Vec<proc_macro2::TokenStream> =
        field_infos.iter().map(|x| x.reader_snippet()).collect();

    let derived_for = input.ident;
    let generics = input.generics;

    (quote! {

    impl #generics ::parquet::record::RecordReader<#derived_for #generics> for Vec<#derived_for #generics> {
      fn read_from_row_group(
        &mut self,
        row_group_reader: &mut dyn ::parquet::file::reader::RowGroupReader,
        num_records: usize,
      ) -> Result<(), ::parquet::errors::ParquetError> {
        use ::parquet::column::reader::ColumnReader;

        let mut row_group_reader = row_group_reader;

        // key: parquet file column name, value: column index
        let mut name_to_index = std::collections::HashMap::new();
        for (idx, col) in row_group_reader.metadata().schema_descr().columns().iter().enumerate() {
            name_to_index.insert(col.name().to_string(), idx);
        }

        for _ in 0..num_records {
          self.push(#derived_for {
            #(
              #field_names: Default::default()
            ),*
          })
        }

        let records = self; // Used by all the reader snippets to be more clear

        #(
          {
              let idx: usize = match name_to_index.get(stringify!(#field_names)) {
                Some(&col_idx) => col_idx,
                None => {
                  let error_msg = format!("column name '{}' is not found in parquet file!", stringify!(#field_names));
                  return Err(::parquet::errors::ParquetError::General(error_msg));
                }
              };
              if let Ok(mut column_reader) = row_group_reader.get_column_reader(idx) {
                  #reader_snippets
              } else {
                  return Err(::parquet::errors::ParquetError::General("Failed to get next column".into()))
              }
          }
        );*

        Ok(())
      }
    }
  }).into()
}