arrow/util/
data_gen.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Utilities to generate random arrays and batches
19
20use std::sync::Arc;
21
22use rand::{
23    distr::uniform::{SampleRange, SampleUniform},
24    Rng,
25};
26
27use crate::array::*;
28use crate::error::{ArrowError, Result};
29use crate::{
30    buffer::{Buffer, MutableBuffer},
31    datatypes::*,
32};
33
34use super::{bench_util::*, bit_util, test_util::seedable_rng};
35
36/// Create a random [RecordBatch] from a schema
37pub fn create_random_batch(
38    schema: SchemaRef,
39    size: usize,
40    null_density: f32,
41    true_density: f32,
42) -> Result<RecordBatch> {
43    let columns = schema
44        .fields()
45        .iter()
46        .map(|field| create_random_array(field, size, null_density, true_density))
47        .collect::<Result<Vec<ArrayRef>>>()?;
48
49    RecordBatch::try_new_with_options(
50        schema,
51        columns,
52        &RecordBatchOptions::new().with_match_field_names(false),
53    )
54}
55
56/// Create a random [ArrayRef] from a [DataType] with a length,
57/// null density and true density (for [BooleanArray]).
58pub fn create_random_array(
59    field: &Field,
60    size: usize,
61    null_density: f32,
62    true_density: f32,
63) -> Result<ArrayRef> {
64    // Override null density with 0.0 if the array is non-nullable
65    // and a primitive type in case a nested field is nullable
66    let primitive_null_density = match field.is_nullable() {
67        true => null_density,
68        false => 0.0,
69    };
70    use DataType::*;
71    Ok(match field.data_type() {
72        Null => Arc::new(NullArray::new(size)) as ArrayRef,
73        Boolean => Arc::new(create_boolean_array(
74            size,
75            primitive_null_density,
76            true_density,
77        )),
78        Int8 => Arc::new(create_primitive_array::<Int8Type>(
79            size,
80            primitive_null_density,
81        )),
82        Int16 => Arc::new(create_primitive_array::<Int16Type>(
83            size,
84            primitive_null_density,
85        )),
86        Int32 => Arc::new(create_primitive_array::<Int32Type>(
87            size,
88            primitive_null_density,
89        )),
90        Int64 => Arc::new(create_primitive_array::<Int64Type>(
91            size,
92            primitive_null_density,
93        )),
94        UInt8 => Arc::new(create_primitive_array::<UInt8Type>(
95            size,
96            primitive_null_density,
97        )),
98        UInt16 => Arc::new(create_primitive_array::<UInt16Type>(
99            size,
100            primitive_null_density,
101        )),
102        UInt32 => Arc::new(create_primitive_array::<UInt32Type>(
103            size,
104            primitive_null_density,
105        )),
106        UInt64 => Arc::new(create_primitive_array::<UInt64Type>(
107            size,
108            primitive_null_density,
109        )),
110        Float16 => {
111            return Err(ArrowError::NotYetImplemented(
112                "Float16 is not implemented".to_string(),
113            ))
114        }
115        Float32 => Arc::new(create_primitive_array::<Float32Type>(
116            size,
117            primitive_null_density,
118        )),
119        Float64 => Arc::new(create_primitive_array::<Float64Type>(
120            size,
121            primitive_null_density,
122        )),
123        Timestamp(unit, tz) => match unit {
124            TimeUnit::Second => Arc::new(
125                create_random_temporal_array::<TimestampSecondType>(size, primitive_null_density)
126                    .with_timezone_opt(tz.clone()),
127            ),
128            TimeUnit::Millisecond => Arc::new(
129                create_random_temporal_array::<TimestampMillisecondType>(
130                    size,
131                    primitive_null_density,
132                )
133                .with_timezone_opt(tz.clone()),
134            ),
135            TimeUnit::Microsecond => Arc::new(
136                create_random_temporal_array::<TimestampMicrosecondType>(
137                    size,
138                    primitive_null_density,
139                )
140                .with_timezone_opt(tz.clone()),
141            ),
142            TimeUnit::Nanosecond => Arc::new(
143                create_random_temporal_array::<TimestampNanosecondType>(
144                    size,
145                    primitive_null_density,
146                )
147                .with_timezone_opt(tz.clone()),
148            ),
149        },
150        Date32 => Arc::new(create_random_temporal_array::<Date32Type>(
151            size,
152            primitive_null_density,
153        )),
154        Date64 => Arc::new(create_random_temporal_array::<Date64Type>(
155            size,
156            primitive_null_density,
157        )),
158        Time32(unit) => match unit {
159            TimeUnit::Second => Arc::new(create_random_temporal_array::<Time32SecondType>(
160                size,
161                primitive_null_density,
162            )) as ArrayRef,
163            TimeUnit::Millisecond => Arc::new(
164                create_random_temporal_array::<Time32MillisecondType>(size, primitive_null_density),
165            ),
166            _ => {
167                return Err(ArrowError::InvalidArgumentError(format!(
168                    "Unsupported unit {unit:?} for Time32"
169                )))
170            }
171        },
172        Time64(unit) => match unit {
173            TimeUnit::Microsecond => Arc::new(
174                create_random_temporal_array::<Time64MicrosecondType>(size, primitive_null_density),
175            ) as ArrayRef,
176            TimeUnit::Nanosecond => Arc::new(create_random_temporal_array::<Time64NanosecondType>(
177                size,
178                primitive_null_density,
179            )),
180            _ => {
181                return Err(ArrowError::InvalidArgumentError(format!(
182                    "Unsupported unit {unit:?} for Time64"
183                )))
184            }
185        },
186        Utf8 => Arc::new(create_string_array::<i32>(size, primitive_null_density)),
187        LargeUtf8 => Arc::new(create_string_array::<i64>(size, primitive_null_density)),
188        Utf8View => Arc::new(create_string_view_array_with_len(
189            size,
190            primitive_null_density,
191            4,
192            false,
193        )),
194        Binary => Arc::new(create_binary_array::<i32>(size, primitive_null_density)),
195        LargeBinary => Arc::new(create_binary_array::<i64>(size, primitive_null_density)),
196        FixedSizeBinary(len) => Arc::new(create_fsb_array(
197            size,
198            primitive_null_density,
199            *len as usize,
200        )),
201        BinaryView => Arc::new(
202            create_string_view_array_with_len(size, primitive_null_density, 4, false)
203                .to_binary_view(),
204        ),
205        List(_) => create_random_list_array(field, size, null_density, true_density)?,
206        LargeList(_) => create_random_list_array(field, size, null_density, true_density)?,
207        Struct(_) => create_random_struct_array(field, size, null_density, true_density)?,
208        d @ Dictionary(_, value_type) if crate::compute::can_cast_types(value_type, d) => {
209            let f = Field::new(
210                field.name(),
211                value_type.as_ref().clone(),
212                field.is_nullable(),
213            );
214            let v = create_random_array(&f, size, null_density, true_density)?;
215            crate::compute::cast(&v, d)?
216        }
217        Map(_, _) => create_random_map_array(field, size, null_density, true_density)?,
218        other => {
219            return Err(ArrowError::NotYetImplemented(format!(
220                "Generating random arrays not yet implemented for {other:?}"
221            )))
222        }
223    })
224}
225
226#[inline]
227fn create_random_list_array(
228    field: &Field,
229    size: usize,
230    null_density: f32,
231    true_density: f32,
232) -> Result<ArrayRef> {
233    // Override null density with 0.0 if the array is non-nullable
234    let list_null_density = match field.is_nullable() {
235        true => null_density,
236        false => 0.0,
237    };
238    let list_field;
239    let (offsets, child_len) = match field.data_type() {
240        DataType::List(f) => {
241            let (offsets, child_len) = create_random_offsets::<i32>(size, 0, 5);
242            list_field = f;
243            (Buffer::from(offsets.to_byte_slice()), child_len as usize)
244        }
245        DataType::LargeList(f) => {
246            let (offsets, child_len) = create_random_offsets::<i64>(size, 0, 5);
247            list_field = f;
248            (Buffer::from(offsets.to_byte_slice()), child_len as usize)
249        }
250        _ => {
251            return Err(ArrowError::InvalidArgumentError(format!(
252                "Cannot create list array for field {field:?}"
253            )))
254        }
255    };
256
257    // Create list's child data
258    let child_array = create_random_array(list_field, child_len, null_density, true_density)?;
259    let child_data = child_array.to_data();
260    // Create list's null buffers, if it is nullable
261    let null_buffer = match field.is_nullable() {
262        true => Some(create_random_null_buffer(size, list_null_density)),
263        false => None,
264    };
265    let list_data = unsafe {
266        ArrayData::new_unchecked(
267            field.data_type().clone(),
268            size,
269            None,
270            null_buffer,
271            0,
272            vec![offsets],
273            vec![child_data],
274        )
275    };
276    Ok(make_array(list_data))
277}
278
279#[inline]
280fn create_random_struct_array(
281    field: &Field,
282    size: usize,
283    null_density: f32,
284    true_density: f32,
285) -> Result<ArrayRef> {
286    let struct_fields = match field.data_type() {
287        DataType::Struct(fields) => fields,
288        _ => {
289            return Err(ArrowError::InvalidArgumentError(format!(
290                "Cannot create struct array for field {field:?}"
291            )))
292        }
293    };
294
295    let child_arrays = struct_fields
296        .iter()
297        .map(|struct_field| create_random_array(struct_field, size, null_density, true_density))
298        .collect::<Result<Vec<_>>>()?;
299
300    let null_buffer = match field.is_nullable() {
301        true => {
302            let nulls = arrow_buffer::BooleanBuffer::new(
303                create_random_null_buffer(size, null_density),
304                0,
305                size,
306            );
307            Some(nulls.into())
308        }
309        false => None,
310    };
311
312    Ok(Arc::new(StructArray::try_new(
313        struct_fields.clone(),
314        child_arrays,
315        null_buffer,
316    )?))
317}
318
319#[inline]
320fn create_random_map_array(
321    field: &Field,
322    size: usize,
323    null_density: f32,
324    true_density: f32,
325) -> Result<ArrayRef> {
326    // Override null density with 0.0 if the array is non-nullable
327    let map_null_density = match field.is_nullable() {
328        true => null_density,
329        false => 0.0,
330    };
331
332    let entries_field = match field.data_type() {
333        DataType::Map(f, _) => f,
334        _ => {
335            return Err(ArrowError::InvalidArgumentError(format!(
336                "Cannot create map array for field {field:?}"
337            )))
338        }
339    };
340
341    let (offsets, child_len) = create_random_offsets::<i32>(size, 0, 5);
342    let offsets = Buffer::from(offsets.to_byte_slice());
343
344    let entries = create_random_array(
345        entries_field,
346        child_len as usize,
347        null_density,
348        true_density,
349    )?
350    .to_data();
351
352    let null_buffer = match field.is_nullable() {
353        true => Some(create_random_null_buffer(size, map_null_density)),
354        false => None,
355    };
356
357    let map_data = unsafe {
358        ArrayData::new_unchecked(
359            field.data_type().clone(),
360            size,
361            None,
362            null_buffer,
363            0,
364            vec![offsets],
365            vec![entries],
366        )
367    };
368    Ok(make_array(map_data))
369}
370
371/// Generate random offsets for list arrays
372fn create_random_offsets<T: OffsetSizeTrait + SampleUniform>(
373    size: usize,
374    min: T,
375    max: T,
376) -> (Vec<T>, T) {
377    let rng = &mut seedable_rng();
378
379    let mut current_offset = T::zero();
380
381    let mut offsets = Vec::with_capacity(size + 1);
382    offsets.push(current_offset);
383
384    (0..size).for_each(|_| {
385        current_offset += rng.random_range(min..max);
386        offsets.push(current_offset);
387    });
388
389    (offsets, current_offset)
390}
391
392fn create_random_null_buffer(size: usize, null_density: f32) -> Buffer {
393    let mut rng = seedable_rng();
394    let mut mut_buf = MutableBuffer::new_null(size);
395    {
396        let mut_slice = mut_buf.as_slice_mut();
397        (0..size).for_each(|i| {
398            if rng.random::<f32>() >= null_density {
399                bit_util::set_bit(mut_slice, i)
400            }
401        })
402    };
403    mut_buf.into()
404}
405
406/// Useful for testing. The range of values are not likely to be representative of the
407/// actual bounds.
408pub trait RandomTemporalValue: ArrowTemporalType {
409    /// Returns the range of values for `impl`'d type
410    fn value_range() -> impl SampleRange<Self::Native>;
411
412    /// Generate a random value within the range of the type
413    fn gen_range<R: Rng>(rng: &mut R) -> Self::Native
414    where
415        Self::Native: SampleUniform,
416    {
417        rng.random_range(Self::value_range())
418    }
419
420    /// Generate a random value of the type
421    fn random<R: Rng>(rng: &mut R) -> Self::Native
422    where
423        Self::Native: SampleUniform,
424    {
425        Self::gen_range(rng)
426    }
427}
428
429impl RandomTemporalValue for TimestampSecondType {
430    /// Range of values for a timestamp in seconds. The range begins at the start
431    /// of the unix epoch and continues for 100 years.
432    fn value_range() -> impl SampleRange<Self::Native> {
433        0..60 * 60 * 24 * 365 * 100
434    }
435}
436
437impl RandomTemporalValue for TimestampMillisecondType {
438    /// Range of values for a timestamp in milliseconds. The range begins at the start
439    /// of the unix epoch and continues for 100 years.
440    fn value_range() -> impl SampleRange<Self::Native> {
441        0..1_000 * 60 * 60 * 24 * 365 * 100
442    }
443}
444
445impl RandomTemporalValue for TimestampMicrosecondType {
446    /// Range of values for a timestamp in microseconds. The range begins at the start
447    /// of the unix epoch and continues for 100 years.
448    fn value_range() -> impl SampleRange<Self::Native> {
449        0..1_000 * 1_000 * 60 * 60 * 24 * 365 * 100
450    }
451}
452
453impl RandomTemporalValue for TimestampNanosecondType {
454    /// Range of values for a timestamp in nanoseconds. The range begins at the start
455    /// of the unix epoch and continues for 100 years.
456    fn value_range() -> impl SampleRange<Self::Native> {
457        0..1_000 * 1_000 * 1_000 * 60 * 60 * 24 * 365 * 100
458    }
459}
460
461impl RandomTemporalValue for Date32Type {
462    /// Range of values representing the elapsed time since UNIX epoch in days. The
463    /// range begins at the start of the unix epoch and continues for 100 years.
464    fn value_range() -> impl SampleRange<Self::Native> {
465        0..365 * 100
466    }
467}
468
469impl RandomTemporalValue for Date64Type {
470    /// Range of values  representing the elapsed time since UNIX epoch in milliseconds.
471    /// The range begins at the start of the unix epoch and continues for 100 years.
472    fn value_range() -> impl SampleRange<Self::Native> {
473        0..1_000 * 60 * 60 * 24 * 365 * 100
474    }
475}
476
477impl RandomTemporalValue for Time32SecondType {
478    /// Range of values representing the elapsed time since midnight in seconds. The
479    /// range is from 0 to 24 hours.
480    fn value_range() -> impl SampleRange<Self::Native> {
481        0..60 * 60 * 24
482    }
483}
484
485impl RandomTemporalValue for Time32MillisecondType {
486    /// Range of values representing the elapsed time since midnight in milliseconds. The
487    /// range is from 0 to 24 hours.
488    fn value_range() -> impl SampleRange<Self::Native> {
489        0..1_000 * 60 * 60 * 24
490    }
491}
492
493impl RandomTemporalValue for Time64MicrosecondType {
494    /// Range of values representing the elapsed time since midnight in microseconds. The
495    /// range is from 0 to 24 hours.
496    fn value_range() -> impl SampleRange<Self::Native> {
497        0..1_000 * 1_000 * 60 * 60 * 24
498    }
499}
500
501impl RandomTemporalValue for Time64NanosecondType {
502    /// Range of values representing the elapsed time since midnight in nanoseconds. The
503    /// range is from 0 to 24 hours.
504    fn value_range() -> impl SampleRange<Self::Native> {
505        0..1_000 * 1_000 * 1_000 * 60 * 60 * 24
506    }
507}
508
509fn create_random_temporal_array<T>(size: usize, null_density: f32) -> PrimitiveArray<T>
510where
511    T: RandomTemporalValue,
512    <T as ArrowPrimitiveType>::Native: SampleUniform,
513{
514    let mut rng = seedable_rng();
515
516    (0..size)
517        .map(|_| {
518            if rng.random::<f32>() < null_density {
519                None
520            } else {
521                Some(T::random(&mut rng))
522            }
523        })
524        .collect()
525}
526
#[cfg(test)]
mod tests {
    use super::*;

    /// A nullable schema yields a batch matching the schema, with every column
    /// at the requested length.
    #[test]
    fn test_create_batch() {
        let num_rows = 32;
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new(
                "timestamp_without_timezone",
                DataType::Timestamp(TimeUnit::Nanosecond, None),
                true,
            ),
            Field::new(
                "timestamp_with_timezone",
                DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
                true,
            ),
        ]));

        let batch = create_random_batch(schema.clone(), num_rows, 0.35, 0.7).unwrap();

        assert_eq!(batch.schema(), schema);
        assert_eq!(batch.num_columns(), schema.fields().len());
        for column in batch.columns() {
            assert_eq!(column.len(), num_rows);
        }
    }

    /// Non-nullable fields never produce nulls, including inside list children.
    #[test]
    fn test_create_batch_non_null() {
        let num_rows = 32;
        let list_item = Arc::new(Field::new_list_field(DataType::LargeUtf8, false));
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::List(list_item), false),
            Field::new("a", DataType::Int32, false),
        ]));

        let batch = create_random_batch(schema.clone(), num_rows, 0.35, 0.7).unwrap();

        assert_eq!(batch.schema(), schema);
        assert_eq!(batch.num_columns(), schema.fields().len());
        for column in batch.columns() {
            assert_eq!(column.null_count(), 0);
            assert_eq!(column.logical_null_count(), 0);
        }

        // The list's child values must be non-null as well
        let lists = batch.column(1).as_list::<i32>();
        let items = lists.values();
        assert_eq!(items.null_count(), 0);
        // More child values than list slots shows that it really is a list
        assert!(items.len() > lists.len());
    }

    /// Nested lists and structs inside a struct are generated with sensible
    /// shapes and types.
    #[test]
    fn test_create_struct_array() {
        let num_rows = 32;

        let fsb_item = Field::new_list_field(DataType::FixedSizeBinary(6), true);
        let inner_list = Field::new_list_field(DataType::List(Arc::new(fsb_item)), false);
        let nested_struct = DataType::Struct(Fields::from(vec![
            Field::new("d_x", DataType::Int32, true),
            Field::new("d_y", DataType::Float32, false),
            Field::new("d_z", DataType::Binary, true),
        ]));
        let struct_fields = Fields::from(vec![
            Field::new("b", DataType::Boolean, true),
            Field::new("c", DataType::LargeList(Arc::new(inner_list)), true),
            Field::new("d", nested_struct, true),
        ]);
        let field = Field::new("struct", DataType::Struct(struct_fields), true);

        let array = create_random_array(&field, num_rows, 0.2, 0.5).unwrap();

        assert_eq!(array.len(), 32);
        let struct_array = array.as_struct();
        assert_eq!(struct_array.columns().len(), 3);

        // The nested list should show repetition: each level holds more values
        // than its parent.
        let outer_lists = struct_array
            .column_by_name("c")
            .unwrap()
            .as_list::<i64>();
        assert_eq!(outer_lists.len(), num_rows);
        let inner_lists = outer_lists.values().as_list::<i32>();
        assert!(inner_lists.len() > num_rows);
        // The innermost values should be FixedSizeBinary(6)
        let fsb = inner_lists.values();
        assert_eq!(fsb.data_type(), &DataType::FixedSizeBinary(6));
        assert!(fsb.len() > inner_lists.len());

        // Nested struct: the non-nullable float child has no nulls
        let nested = struct_array.column_by_name("d").unwrap().as_struct();
        let d_y = nested.column_by_name("d_y").unwrap();
        assert_eq!(d_y.data_type(), &DataType::Float32);
        assert_eq!(d_y.null_count(), 0);
    }

    /// A non-nullable list with a nullable item field has nulls only in the
    /// child values.
    #[test]
    fn test_create_list_array_nested_nullability() {
        let item_field = Field::new_list_field(DataType::Boolean, true);
        let list_field = Field::new_list("not_null_list", item_field, false);

        let list_array = create_random_array(&list_field, 100, 0.95, 0.5).unwrap();

        assert_eq!(list_array.null_count(), 0);
        let items = list_array.as_list::<i32>().values();
        assert!(items.null_count() > 0);
    }

    /// A non-nullable struct has no nulls itself; only its nullable child does.
    #[test]
    fn test_create_struct_array_nested_nullability() {
        let child_fields = vec![
            Field::new("null_int", DataType::Int32, true),
            Field::new("int", DataType::Int32, false),
        ];
        let struct_field = Field::new_struct("not_null_struct", child_fields, false);

        let array = create_random_array(&struct_field, 100, 0.95, 0.5).unwrap();

        assert_eq!(array.null_count(), 0);
        let struct_array = array.as_struct();
        let nullable_child = struct_array.column_by_name("null_int").unwrap();
        assert!(nullable_child.null_count() > 0);
        let required_child = struct_array.column_by_name("int").unwrap();
        assert_eq!(required_child.null_count(), 0);
    }

    /// Nullability is respected at every level of a list of structs.
    #[test]
    fn test_create_list_array_nested_struct_nullability() {
        let child_fields = vec![
            Field::new("null_int", DataType::Int32, true),
            Field::new("int", DataType::Int32, false),
        ];
        let item_field = Field::new_list_field(DataType::Struct(child_fields.into()), true);
        let list_field = Field::new_list("not_null_list", item_field, false);

        let list_array = create_random_array(&list_field, 100, 0.95, 0.5).unwrap();

        assert_eq!(list_array.null_count(), 0);
        let items = list_array.as_list::<i32>().values();
        assert!(items.null_count() > 0);

        let item_structs = items.as_struct();
        let nullable_child = item_structs.column_by_name("null_int").unwrap();
        assert!(nullable_child.null_count() > 0);
        let required_child = item_structs.column_by_name("int").unwrap();
        assert_eq!(required_child.null_count(), 0);
    }

    /// A non-nullable map has non-null keys, nullable values and more entries
    /// than map slots.
    #[test]
    fn test_create_map_array() {
        let map_field = Field::new_map(
            "map",
            "entries",
            Field::new("key", DataType::Utf8, false),
            Field::new("value", DataType::Utf8, true),
            false,
            false,
        );

        let array = create_random_array(&map_field, 100, 0.8, 0.5).unwrap();
        let map = array.as_map();

        assert_eq!(array.len(), 100);
        // The map field itself is not nullable
        assert_eq!(array.null_count(), 0);
        assert_eq!(array.logical_null_count(), 0);
        // Like a list, a map holds several entries per slot, so the inner
        // arrays are longer
        assert!(map.keys().len() > array.len());
        assert!(map.values().len() > array.len());
        // Keys are not nullable
        assert_eq!(map.keys().null_count(), 0);
        // Values are nullable
        assert!(map.values().null_count() > 0);

        assert_eq!(map.keys().data_type(), &DataType::Utf8);
        assert_eq!(map.values().data_type(), &DataType::Utf8);
    }
}