arrow/util/
data_gen.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Utilities to generate random arrays and batches
19
20use std::sync::Arc;
21
22use rand::distributions::uniform::SampleRange;
23use rand::{distributions::uniform::SampleUniform, Rng};
24
25use crate::array::*;
26use crate::error::{ArrowError, Result};
27use crate::{
28    buffer::{Buffer, MutableBuffer},
29    datatypes::*,
30};
31
32use super::{bench_util::*, bit_util, test_util::seedable_rng};
33
34/// Create a random [RecordBatch] from a schema
35pub fn create_random_batch(
36    schema: SchemaRef,
37    size: usize,
38    null_density: f32,
39    true_density: f32,
40) -> Result<RecordBatch> {
41    let columns = schema
42        .fields()
43        .iter()
44        .map(|field| create_random_array(field, size, null_density, true_density))
45        .collect::<Result<Vec<ArrayRef>>>()?;
46
47    RecordBatch::try_new_with_options(
48        schema,
49        columns,
50        &RecordBatchOptions::new().with_match_field_names(false),
51    )
52}
53
54/// Create a random [ArrayRef] from a [DataType] with a length,
55/// null density and true density (for [BooleanArray]).
56pub fn create_random_array(
57    field: &Field,
58    size: usize,
59    null_density: f32,
60    true_density: f32,
61) -> Result<ArrayRef> {
62    // Override null density with 0.0 if the array is non-nullable
63    // and a primitive type in case a nested field is nullable
64    let primitive_null_density = match field.is_nullable() {
65        true => null_density,
66        false => 0.0,
67    };
68    use DataType::*;
69    Ok(match field.data_type() {
70        Null => Arc::new(NullArray::new(size)) as ArrayRef,
71        Boolean => Arc::new(create_boolean_array(
72            size,
73            primitive_null_density,
74            true_density,
75        )),
76        Int8 => Arc::new(create_primitive_array::<Int8Type>(
77            size,
78            primitive_null_density,
79        )),
80        Int16 => Arc::new(create_primitive_array::<Int16Type>(
81            size,
82            primitive_null_density,
83        )),
84        Int32 => Arc::new(create_primitive_array::<Int32Type>(
85            size,
86            primitive_null_density,
87        )),
88        Int64 => Arc::new(create_primitive_array::<Int64Type>(
89            size,
90            primitive_null_density,
91        )),
92        UInt8 => Arc::new(create_primitive_array::<UInt8Type>(
93            size,
94            primitive_null_density,
95        )),
96        UInt16 => Arc::new(create_primitive_array::<UInt16Type>(
97            size,
98            primitive_null_density,
99        )),
100        UInt32 => Arc::new(create_primitive_array::<UInt32Type>(
101            size,
102            primitive_null_density,
103        )),
104        UInt64 => Arc::new(create_primitive_array::<UInt64Type>(
105            size,
106            primitive_null_density,
107        )),
108        Float16 => {
109            return Err(ArrowError::NotYetImplemented(
110                "Float16 is not implemented".to_string(),
111            ))
112        }
113        Float32 => Arc::new(create_primitive_array::<Float32Type>(
114            size,
115            primitive_null_density,
116        )),
117        Float64 => Arc::new(create_primitive_array::<Float64Type>(
118            size,
119            primitive_null_density,
120        )),
121        Timestamp(unit, tz) => match unit {
122            TimeUnit::Second => Arc::new(
123                create_random_temporal_array::<TimestampSecondType>(size, primitive_null_density)
124                    .with_timezone_opt(tz.clone()),
125            ),
126            TimeUnit::Millisecond => Arc::new(
127                create_random_temporal_array::<TimestampMillisecondType>(
128                    size,
129                    primitive_null_density,
130                )
131                .with_timezone_opt(tz.clone()),
132            ),
133            TimeUnit::Microsecond => Arc::new(
134                create_random_temporal_array::<TimestampMicrosecondType>(
135                    size,
136                    primitive_null_density,
137                )
138                .with_timezone_opt(tz.clone()),
139            ),
140            TimeUnit::Nanosecond => Arc::new(
141                create_random_temporal_array::<TimestampNanosecondType>(
142                    size,
143                    primitive_null_density,
144                )
145                .with_timezone_opt(tz.clone()),
146            ),
147        },
148        Date32 => Arc::new(create_random_temporal_array::<Date32Type>(
149            size,
150            primitive_null_density,
151        )),
152        Date64 => Arc::new(create_random_temporal_array::<Date64Type>(
153            size,
154            primitive_null_density,
155        )),
156        Time32(unit) => match unit {
157            TimeUnit::Second => Arc::new(create_random_temporal_array::<Time32SecondType>(
158                size,
159                primitive_null_density,
160            )) as ArrayRef,
161            TimeUnit::Millisecond => Arc::new(
162                create_random_temporal_array::<Time32MillisecondType>(size, primitive_null_density),
163            ),
164            _ => {
165                return Err(ArrowError::InvalidArgumentError(format!(
166                    "Unsupported unit {unit:?} for Time32"
167                )))
168            }
169        },
170        Time64(unit) => match unit {
171            TimeUnit::Microsecond => Arc::new(
172                create_random_temporal_array::<Time64MicrosecondType>(size, primitive_null_density),
173            ) as ArrayRef,
174            TimeUnit::Nanosecond => Arc::new(create_random_temporal_array::<Time64NanosecondType>(
175                size,
176                primitive_null_density,
177            )),
178            _ => {
179                return Err(ArrowError::InvalidArgumentError(format!(
180                    "Unsupported unit {unit:?} for Time64"
181                )))
182            }
183        },
184        Utf8 => Arc::new(create_string_array::<i32>(size, primitive_null_density)),
185        LargeUtf8 => Arc::new(create_string_array::<i64>(size, primitive_null_density)),
186        Utf8View => Arc::new(create_string_view_array_with_len(
187            size,
188            primitive_null_density,
189            4,
190            false,
191        )),
192        Binary => Arc::new(create_binary_array::<i32>(size, primitive_null_density)),
193        LargeBinary => Arc::new(create_binary_array::<i64>(size, primitive_null_density)),
194        FixedSizeBinary(len) => Arc::new(create_fsb_array(
195            size,
196            primitive_null_density,
197            *len as usize,
198        )),
199        BinaryView => Arc::new(
200            create_string_view_array_with_len(size, primitive_null_density, 4, false)
201                .to_binary_view(),
202        ),
203        List(_) => create_random_list_array(field, size, null_density, true_density)?,
204        LargeList(_) => create_random_list_array(field, size, null_density, true_density)?,
205        Struct(_) => create_random_struct_array(field, size, null_density, true_density)?,
206        d @ Dictionary(_, value_type) if crate::compute::can_cast_types(value_type, d) => {
207            let f = Field::new(
208                field.name(),
209                value_type.as_ref().clone(),
210                field.is_nullable(),
211            );
212            let v = create_random_array(&f, size, null_density, true_density)?;
213            crate::compute::cast(&v, d)?
214        }
215        Map(_, _) => create_random_map_array(field, size, null_density, true_density)?,
216        other => {
217            return Err(ArrowError::NotYetImplemented(format!(
218                "Generating random arrays not yet implemented for {other:?}"
219            )))
220        }
221    })
222}
223
224#[inline]
225fn create_random_list_array(
226    field: &Field,
227    size: usize,
228    null_density: f32,
229    true_density: f32,
230) -> Result<ArrayRef> {
231    // Override null density with 0.0 if the array is non-nullable
232    let list_null_density = match field.is_nullable() {
233        true => null_density,
234        false => 0.0,
235    };
236    let list_field;
237    let (offsets, child_len) = match field.data_type() {
238        DataType::List(f) => {
239            let (offsets, child_len) = create_random_offsets::<i32>(size, 0, 5);
240            list_field = f;
241            (Buffer::from(offsets.to_byte_slice()), child_len as usize)
242        }
243        DataType::LargeList(f) => {
244            let (offsets, child_len) = create_random_offsets::<i64>(size, 0, 5);
245            list_field = f;
246            (Buffer::from(offsets.to_byte_slice()), child_len as usize)
247        }
248        _ => {
249            return Err(ArrowError::InvalidArgumentError(format!(
250                "Cannot create list array for field {field:?}"
251            )))
252        }
253    };
254
255    // Create list's child data
256    let child_array = create_random_array(list_field, child_len, null_density, true_density)?;
257    let child_data = child_array.to_data();
258    // Create list's null buffers, if it is nullable
259    let null_buffer = match field.is_nullable() {
260        true => Some(create_random_null_buffer(size, list_null_density)),
261        false => None,
262    };
263    let list_data = unsafe {
264        ArrayData::new_unchecked(
265            field.data_type().clone(),
266            size,
267            None,
268            null_buffer,
269            0,
270            vec![offsets],
271            vec![child_data],
272        )
273    };
274    Ok(make_array(list_data))
275}
276
277#[inline]
278fn create_random_struct_array(
279    field: &Field,
280    size: usize,
281    null_density: f32,
282    true_density: f32,
283) -> Result<ArrayRef> {
284    let struct_fields = match field.data_type() {
285        DataType::Struct(fields) => fields,
286        _ => {
287            return Err(ArrowError::InvalidArgumentError(format!(
288                "Cannot create struct array for field {field:?}"
289            )))
290        }
291    };
292
293    let child_arrays = struct_fields
294        .iter()
295        .map(|struct_field| create_random_array(struct_field, size, null_density, true_density))
296        .collect::<Result<Vec<_>>>()?;
297
298    let null_buffer = match field.is_nullable() {
299        true => {
300            let nulls = arrow_buffer::BooleanBuffer::new(
301                create_random_null_buffer(size, null_density),
302                0,
303                size,
304            );
305            Some(nulls.into())
306        }
307        false => None,
308    };
309
310    Ok(Arc::new(StructArray::try_new(
311        struct_fields.clone(),
312        child_arrays,
313        null_buffer,
314    )?))
315}
316
317#[inline]
318fn create_random_map_array(
319    field: &Field,
320    size: usize,
321    null_density: f32,
322    true_density: f32,
323) -> Result<ArrayRef> {
324    // Override null density with 0.0 if the array is non-nullable
325    let map_null_density = match field.is_nullable() {
326        true => null_density,
327        false => 0.0,
328    };
329
330    let entries_field = match field.data_type() {
331        DataType::Map(f, _) => f,
332        _ => {
333            return Err(ArrowError::InvalidArgumentError(format!(
334                "Cannot create map array for field {field:?}"
335            )))
336        }
337    };
338
339    let (offsets, child_len) = create_random_offsets::<i32>(size, 0, 5);
340    let offsets = Buffer::from(offsets.to_byte_slice());
341
342    let entries = create_random_array(
343        entries_field,
344        child_len as usize,
345        null_density,
346        true_density,
347    )?
348    .to_data();
349
350    let null_buffer = match field.is_nullable() {
351        true => Some(create_random_null_buffer(size, map_null_density)),
352        false => None,
353    };
354
355    let map_data = unsafe {
356        ArrayData::new_unchecked(
357            field.data_type().clone(),
358            size,
359            None,
360            null_buffer,
361            0,
362            vec![offsets],
363            vec![entries],
364        )
365    };
366    Ok(make_array(map_data))
367}
368
369/// Generate random offsets for list arrays
370fn create_random_offsets<T: OffsetSizeTrait + SampleUniform>(
371    size: usize,
372    min: T,
373    max: T,
374) -> (Vec<T>, T) {
375    let rng = &mut seedable_rng();
376
377    let mut current_offset = T::zero();
378
379    let mut offsets = Vec::with_capacity(size + 1);
380    offsets.push(current_offset);
381
382    (0..size).for_each(|_| {
383        current_offset += rng.gen_range(min..max);
384        offsets.push(current_offset);
385    });
386
387    (offsets, current_offset)
388}
389
390fn create_random_null_buffer(size: usize, null_density: f32) -> Buffer {
391    let mut rng = seedable_rng();
392    let mut mut_buf = MutableBuffer::new_null(size);
393    {
394        let mut_slice = mut_buf.as_slice_mut();
395        (0..size).for_each(|i| {
396            if rng.gen::<f32>() >= null_density {
397                bit_util::set_bit(mut_slice, i)
398            }
399        })
400    };
401    mut_buf.into()
402}
403
404/// Useful for testing. The range of values are not likely to be representative of the
405/// actual bounds.
406pub trait RandomTemporalValue: ArrowTemporalType {
407    /// Returns the range of values for `impl`'d type
408    fn value_range() -> impl SampleRange<Self::Native>;
409
410    /// Generate a random value within the range of the type
411    fn gen_range<R: Rng>(rng: &mut R) -> Self::Native
412    where
413        Self::Native: SampleUniform,
414    {
415        rng.gen_range(Self::value_range())
416    }
417
418    /// Generate a random value of the type
419    fn random<R: Rng>(rng: &mut R) -> Self::Native
420    where
421        Self::Native: SampleUniform,
422    {
423        Self::gen_range(rng)
424    }
425}
426
427impl RandomTemporalValue for TimestampSecondType {
428    /// Range of values for a timestamp in seconds. The range begins at the start
429    /// of the unix epoch and continues for 100 years.
430    fn value_range() -> impl SampleRange<Self::Native> {
431        0..60 * 60 * 24 * 365 * 100
432    }
433}
434
435impl RandomTemporalValue for TimestampMillisecondType {
436    /// Range of values for a timestamp in milliseconds. The range begins at the start
437    /// of the unix epoch and continues for 100 years.
438    fn value_range() -> impl SampleRange<Self::Native> {
439        0..1_000 * 60 * 60 * 24 * 365 * 100
440    }
441}
442
443impl RandomTemporalValue for TimestampMicrosecondType {
444    /// Range of values for a timestamp in microseconds. The range begins at the start
445    /// of the unix epoch and continues for 100 years.
446    fn value_range() -> impl SampleRange<Self::Native> {
447        0..1_000 * 1_000 * 60 * 60 * 24 * 365 * 100
448    }
449}
450
451impl RandomTemporalValue for TimestampNanosecondType {
452    /// Range of values for a timestamp in nanoseconds. The range begins at the start
453    /// of the unix epoch and continues for 100 years.
454    fn value_range() -> impl SampleRange<Self::Native> {
455        0..1_000 * 1_000 * 1_000 * 60 * 60 * 24 * 365 * 100
456    }
457}
458
459impl RandomTemporalValue for Date32Type {
460    /// Range of values representing the elapsed time since UNIX epoch in days. The
461    /// range begins at the start of the unix epoch and continues for 100 years.
462    fn value_range() -> impl SampleRange<Self::Native> {
463        0..365 * 100
464    }
465}
466
467impl RandomTemporalValue for Date64Type {
468    /// Range of values  representing the elapsed time since UNIX epoch in milliseconds.
469    /// The range begins at the start of the unix epoch and continues for 100 years.
470    fn value_range() -> impl SampleRange<Self::Native> {
471        0..1_000 * 60 * 60 * 24 * 365 * 100
472    }
473}
474
475impl RandomTemporalValue for Time32SecondType {
476    /// Range of values representing the elapsed time since midnight in seconds. The
477    /// range is from 0 to 24 hours.
478    fn value_range() -> impl SampleRange<Self::Native> {
479        0..60 * 60 * 24
480    }
481}
482
483impl RandomTemporalValue for Time32MillisecondType {
484    /// Range of values representing the elapsed time since midnight in milliseconds. The
485    /// range is from 0 to 24 hours.
486    fn value_range() -> impl SampleRange<Self::Native> {
487        0..1_000 * 60 * 60 * 24
488    }
489}
490
491impl RandomTemporalValue for Time64MicrosecondType {
492    /// Range of values representing the elapsed time since midnight in microseconds. The
493    /// range is from 0 to 24 hours.
494    fn value_range() -> impl SampleRange<Self::Native> {
495        0..1_000 * 1_000 * 60 * 60 * 24
496    }
497}
498
499impl RandomTemporalValue for Time64NanosecondType {
500    /// Range of values representing the elapsed time since midnight in nanoseconds. The
501    /// range is from 0 to 24 hours.
502    fn value_range() -> impl SampleRange<Self::Native> {
503        0..1_000 * 1_000 * 1_000 * 60 * 60 * 24
504    }
505}
506
507fn create_random_temporal_array<T>(size: usize, null_density: f32) -> PrimitiveArray<T>
508where
509    T: RandomTemporalValue,
510    <T as ArrowPrimitiveType>::Native: SampleUniform,
511{
512    let mut rng = seedable_rng();
513
514    (0..size)
515        .map(|_| {
516            if rng.gen::<f32>() < null_density {
517                None
518            } else {
519                Some(T::random(&mut rng))
520            }
521        })
522        .collect()
523}
524
#[cfg(test)]
mod tests {
    use super::*;

    /// A batch generated from a nullable schema keeps the schema and the requested length.
    #[test]
    fn test_create_batch() {
        let size = 32;
        let schema_ref = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new(
                "timestamp_without_timezone",
                DataType::Timestamp(TimeUnit::Nanosecond, None),
                true,
            ),
            Field::new(
                "timestamp_with_timezone",
                DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
                true,
            ),
        ]));

        let batch = create_random_batch(schema_ref.clone(), size, 0.35, 0.7).unwrap();

        assert_eq!(batch.schema(), schema_ref);
        assert_eq!(batch.num_columns(), schema_ref.fields().len());
        batch
            .columns()
            .iter()
            .for_each(|column| assert_eq!(column.len(), size));
    }

    /// Non-nullable fields never produce nulls, including inside a list's child.
    #[test]
    fn test_create_batch_non_null() {
        let size = 32;
        let list_item = Field::new_list_field(DataType::LargeUtf8, false);
        let schema_ref = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::List(Arc::new(list_item)), false),
            Field::new("a", DataType::Int32, false),
        ]));

        let batch = create_random_batch(schema_ref.clone(), size, 0.35, 0.7).unwrap();

        assert_eq!(batch.schema(), schema_ref);
        assert_eq!(batch.num_columns(), schema_ref.fields().len());
        for column in batch.columns() {
            assert_eq!(column.null_count(), 0);
            assert_eq!(column.logical_null_count(), 0);
        }

        // The list's child values must be non-null as well
        let list_array = batch.column(1).as_list::<i32>();
        let child_array = list_array.values();
        assert_eq!(child_array.null_count(), 0);
        // More child values than list slots shows that it really is a list
        assert!(child_array.len() > list_array.len());
    }

    /// Nested lists and structs inside a struct are generated with sensible shapes.
    #[test]
    fn test_create_struct_array() {
        let size = 32;

        let fsb_item = Field::new_list_field(DataType::FixedSizeBinary(6), true);
        let inner_list = Field::new_list_field(DataType::List(Arc::new(fsb_item)), false);
        let nested_struct = Fields::from(vec![
            Field::new("d_x", DataType::Int32, true),
            Field::new("d_y", DataType::Float32, false),
            Field::new("d_z", DataType::Binary, true),
        ]);
        let struct_fields = Fields::from(vec![
            Field::new("b", DataType::Boolean, true),
            Field::new("c", DataType::LargeList(Arc::new(inner_list)), true),
            Field::new("d", DataType::Struct(nested_struct), true),
        ]);
        let field = Field::new("struct", DataType::Struct(struct_fields), true);

        let array = create_random_array(&field, size, 0.2, 0.5).unwrap();

        assert_eq!(array.len(), 32);
        let struct_array = array.as_struct();
        assert_eq!(struct_array.columns().len(), 3);

        // The nested list should repeat: each level has more values than its parent
        let outer = struct_array
            .column_by_name("c")
            .unwrap()
            .as_list::<i64>();
        assert_eq!(outer.len(), size);
        let inner = outer.values().as_list::<i32>();
        assert!(inner.len() > size);
        // The innermost values should be FixedSizeBinary(6)
        let fsb = inner.values();
        assert_eq!(fsb.data_type(), &DataType::FixedSizeBinary(6));
        assert!(fsb.len() > inner.len());

        // The nested struct keeps its non-nullable child free of nulls
        let d_y = struct_array
            .column_by_name("d")
            .unwrap()
            .as_struct()
            .column_by_name("d_y")
            .unwrap();
        assert_eq!(d_y.data_type(), &DataType::Float32);
        assert_eq!(d_y.null_count(), 0);
    }

    /// A non-nullable list may still hold nulls in its nullable child.
    #[test]
    fn test_create_list_array_nested_nullability() {
        let item = Field::new_list_field(DataType::Boolean, true);
        let list_field = Field::new_list("not_null_list", item, false);

        let list_array = create_random_array(&list_field, 100, 0.95, 0.5).unwrap();

        assert_eq!(list_array.null_count(), 0);
        let values = list_array.as_list::<i32>().values();
        assert!(values.null_count() > 0);
    }

    /// A non-nullable struct honours the nullability of each child separately.
    #[test]
    fn test_create_struct_array_nested_nullability() {
        let children = vec![
            Field::new("null_int", DataType::Int32, true),
            Field::new("int", DataType::Int32, false),
        ];
        let struct_field = Field::new_struct("not_null_struct", children, false);

        let struct_array = create_random_array(&struct_field, 100, 0.95, 0.5).unwrap();

        assert_eq!(struct_array.null_count(), 0);

        let columns = struct_array.as_struct();
        let nullable_child = columns.column_by_name("null_int").unwrap();
        assert!(nullable_child.null_count() > 0);
        let required_child = columns.column_by_name("int").unwrap();
        assert_eq!(required_child.null_count(), 0);
    }

    /// Nullability is honoured at every level of a list of structs.
    #[test]
    fn test_create_list_array_nested_struct_nullability() {
        let children = vec![
            Field::new("null_int", DataType::Int32, true),
            Field::new("int", DataType::Int32, false),
        ];
        let item = Field::new_list_field(DataType::Struct(children.into()), true);
        let list_field = Field::new_list("not_null_list", item, false);

        let list_array = create_random_array(&list_field, 100, 0.95, 0.5).unwrap();

        assert_eq!(list_array.null_count(), 0);

        let values = list_array.as_list::<i32>().values();
        assert!(values.null_count() > 0);

        let entries = values.as_struct();
        let nullable_child = entries.column_by_name("null_int").unwrap();
        assert!(nullable_child.null_count() > 0);
        let required_child = entries.column_by_name("int").unwrap();
        assert_eq!(required_child.null_count(), 0);
    }

    /// A non-nullable map has non-null keys, nullable values and repeated entries.
    #[test]
    fn test_create_map_array() {
        let key_field = Field::new("key", DataType::Utf8, false);
        let value_field = Field::new("value", DataType::Utf8, true);
        let map_field = Field::new_map("map", "entries", key_field, value_field, false, false);

        let array = create_random_array(&map_field, 100, 0.8, 0.5).unwrap();

        assert_eq!(array.len(), 100);
        // The map field itself is not nullable
        assert_eq!(array.null_count(), 0);
        assert_eq!(array.logical_null_count(), 0);

        let map = array.as_map();
        let (keys, values) = (map.keys(), map.values());
        // Like a list, a map holds several entries per slot, so the inner arrays are longer
        assert!(keys.len() > array.len());
        assert!(values.len() > array.len());
        // Keys are not nullable
        assert_eq!(keys.null_count(), 0);
        // Values are nullable
        assert!(values.null_count() > 0);

        assert_eq!(keys.data_type(), &DataType::Utf8);
        assert_eq!(values.data_type(), &DataType::Utf8);
    }
}