arrow/util/
data_gen.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Utilities to generate random arrays and batches
19
20use std::sync::Arc;
21
22use rand::distributions::uniform::SampleRange;
23use rand::{distributions::uniform::SampleUniform, Rng};
24
25use crate::array::*;
26use crate::error::{ArrowError, Result};
27use crate::{
28    buffer::{Buffer, MutableBuffer},
29    datatypes::*,
30};
31
32use super::{bench_util::*, bit_util, test_util::seedable_rng};
33
34/// Create a random [RecordBatch] from a schema
35pub fn create_random_batch(
36    schema: SchemaRef,
37    size: usize,
38    null_density: f32,
39    true_density: f32,
40) -> Result<RecordBatch> {
41    let columns = schema
42        .fields()
43        .iter()
44        .map(|field| create_random_array(field, size, null_density, true_density))
45        .collect::<Result<Vec<ArrayRef>>>()?;
46
47    RecordBatch::try_new_with_options(
48        schema,
49        columns,
50        &RecordBatchOptions::new().with_match_field_names(false),
51    )
52}
53
54/// Create a random [ArrayRef] from a [DataType] with a length,
55/// null density and true density (for [BooleanArray]).
56pub fn create_random_array(
57    field: &Field,
58    size: usize,
59    null_density: f32,
60    true_density: f32,
61) -> Result<ArrayRef> {
62    // Override null density with 0.0 if the array is non-nullable
63    // and a primitive type in case a nested field is nullable
64    let primitive_null_density = match field.is_nullable() {
65        true => null_density,
66        false => 0.0,
67    };
68    use DataType::*;
69    Ok(match field.data_type() {
70        Null => Arc::new(NullArray::new(size)) as ArrayRef,
71        Boolean => Arc::new(create_boolean_array(
72            size,
73            primitive_null_density,
74            true_density,
75        )),
76        Int8 => Arc::new(create_primitive_array::<Int8Type>(
77            size,
78            primitive_null_density,
79        )),
80        Int16 => Arc::new(create_primitive_array::<Int16Type>(
81            size,
82            primitive_null_density,
83        )),
84        Int32 => Arc::new(create_primitive_array::<Int32Type>(
85            size,
86            primitive_null_density,
87        )),
88        Int64 => Arc::new(create_primitive_array::<Int64Type>(
89            size,
90            primitive_null_density,
91        )),
92        UInt8 => Arc::new(create_primitive_array::<UInt8Type>(
93            size,
94            primitive_null_density,
95        )),
96        UInt16 => Arc::new(create_primitive_array::<UInt16Type>(
97            size,
98            primitive_null_density,
99        )),
100        UInt32 => Arc::new(create_primitive_array::<UInt32Type>(
101            size,
102            primitive_null_density,
103        )),
104        UInt64 => Arc::new(create_primitive_array::<UInt64Type>(
105            size,
106            primitive_null_density,
107        )),
108        Float16 => {
109            return Err(ArrowError::NotYetImplemented(
110                "Float16 is not implemented".to_string(),
111            ))
112        }
113        Float32 => Arc::new(create_primitive_array::<Float32Type>(
114            size,
115            primitive_null_density,
116        )),
117        Float64 => Arc::new(create_primitive_array::<Float64Type>(
118            size,
119            primitive_null_density,
120        )),
121        Timestamp(unit, _) => {
122            match unit {
123                TimeUnit::Second => Arc::new(create_random_temporal_array::<TimestampSecondType>(
124                    size,
125                    primitive_null_density,
126                )),
127                TimeUnit::Millisecond => Arc::new(create_random_temporal_array::<
128                    TimestampMillisecondType,
129                >(size, primitive_null_density)),
130                TimeUnit::Microsecond => Arc::new(create_random_temporal_array::<
131                    TimestampMicrosecondType,
132                >(size, primitive_null_density)),
133                TimeUnit::Nanosecond => Arc::new(create_random_temporal_array::<
134                    TimestampNanosecondType,
135                >(size, primitive_null_density)),
136            }
137        }
138        Date32 => Arc::new(create_random_temporal_array::<Date32Type>(
139            size,
140            primitive_null_density,
141        )),
142        Date64 => Arc::new(create_random_temporal_array::<Date64Type>(
143            size,
144            primitive_null_density,
145        )),
146        Time32(unit) => match unit {
147            TimeUnit::Second => Arc::new(create_random_temporal_array::<Time32SecondType>(
148                size,
149                primitive_null_density,
150            )) as ArrayRef,
151            TimeUnit::Millisecond => Arc::new(
152                create_random_temporal_array::<Time32MillisecondType>(size, primitive_null_density),
153            ),
154            _ => {
155                return Err(ArrowError::InvalidArgumentError(format!(
156                    "Unsupported unit {unit:?} for Time32"
157                )))
158            }
159        },
160        Time64(unit) => match unit {
161            TimeUnit::Microsecond => Arc::new(
162                create_random_temporal_array::<Time64MicrosecondType>(size, primitive_null_density),
163            ) as ArrayRef,
164            TimeUnit::Nanosecond => Arc::new(create_random_temporal_array::<Time64NanosecondType>(
165                size,
166                primitive_null_density,
167            )),
168            _ => {
169                return Err(ArrowError::InvalidArgumentError(format!(
170                    "Unsupported unit {unit:?} for Time64"
171                )))
172            }
173        },
174        Utf8 => Arc::new(create_string_array::<i32>(size, primitive_null_density)),
175        LargeUtf8 => Arc::new(create_string_array::<i64>(size, primitive_null_density)),
176        Utf8View => Arc::new(create_string_view_array_with_len(
177            size,
178            primitive_null_density,
179            4,
180            false,
181        )),
182        Binary => Arc::new(create_binary_array::<i32>(size, primitive_null_density)),
183        LargeBinary => Arc::new(create_binary_array::<i64>(size, primitive_null_density)),
184        FixedSizeBinary(len) => Arc::new(create_fsb_array(
185            size,
186            primitive_null_density,
187            *len as usize,
188        )),
189        BinaryView => Arc::new(
190            create_string_view_array_with_len(size, primitive_null_density, 4, false)
191                .to_binary_view(),
192        ),
193        List(_) => create_random_list_array(field, size, null_density, true_density)?,
194        LargeList(_) => create_random_list_array(field, size, null_density, true_density)?,
195        Struct(_) => create_random_struct_array(field, size, null_density, true_density)?,
196        d @ Dictionary(_, value_type) if crate::compute::can_cast_types(value_type, d) => {
197            let f = Field::new(
198                field.name(),
199                value_type.as_ref().clone(),
200                field.is_nullable(),
201            );
202            let v = create_random_array(&f, size, null_density, true_density)?;
203            crate::compute::cast(&v, d)?
204        }
205        Map(_, _) => create_random_map_array(field, size, null_density, true_density)?,
206        other => {
207            return Err(ArrowError::NotYetImplemented(format!(
208                "Generating random arrays not yet implemented for {other:?}"
209            )))
210        }
211    })
212}
213
214#[inline]
215fn create_random_list_array(
216    field: &Field,
217    size: usize,
218    null_density: f32,
219    true_density: f32,
220) -> Result<ArrayRef> {
221    // Override null density with 0.0 if the array is non-nullable
222    let list_null_density = match field.is_nullable() {
223        true => null_density,
224        false => 0.0,
225    };
226    let list_field;
227    let (offsets, child_len) = match field.data_type() {
228        DataType::List(f) => {
229            let (offsets, child_len) = create_random_offsets::<i32>(size, 0, 5);
230            list_field = f;
231            (Buffer::from(offsets.to_byte_slice()), child_len as usize)
232        }
233        DataType::LargeList(f) => {
234            let (offsets, child_len) = create_random_offsets::<i64>(size, 0, 5);
235            list_field = f;
236            (Buffer::from(offsets.to_byte_slice()), child_len as usize)
237        }
238        _ => {
239            return Err(ArrowError::InvalidArgumentError(format!(
240                "Cannot create list array for field {field:?}"
241            )))
242        }
243    };
244
245    // Create list's child data
246    let child_array = create_random_array(list_field, child_len, null_density, true_density)?;
247    let child_data = child_array.to_data();
248    // Create list's null buffers, if it is nullable
249    let null_buffer = match field.is_nullable() {
250        true => Some(create_random_null_buffer(size, list_null_density)),
251        false => None,
252    };
253    let list_data = unsafe {
254        ArrayData::new_unchecked(
255            field.data_type().clone(),
256            size,
257            None,
258            null_buffer,
259            0,
260            vec![offsets],
261            vec![child_data],
262        )
263    };
264    Ok(make_array(list_data))
265}
266
267#[inline]
268fn create_random_struct_array(
269    field: &Field,
270    size: usize,
271    null_density: f32,
272    true_density: f32,
273) -> Result<ArrayRef> {
274    let struct_fields = match field.data_type() {
275        DataType::Struct(fields) => fields,
276        _ => {
277            return Err(ArrowError::InvalidArgumentError(format!(
278                "Cannot create struct array for field {field:?}"
279            )))
280        }
281    };
282
283    let child_arrays = struct_fields
284        .iter()
285        .map(|struct_field| create_random_array(struct_field, size, null_density, true_density))
286        .collect::<Result<Vec<_>>>()?;
287
288    let null_buffer = match field.is_nullable() {
289        true => {
290            let nulls = arrow_buffer::BooleanBuffer::new(
291                create_random_null_buffer(size, null_density),
292                0,
293                size,
294            );
295            Some(nulls.into())
296        }
297        false => None,
298    };
299
300    Ok(Arc::new(StructArray::try_new(
301        struct_fields.clone(),
302        child_arrays,
303        null_buffer,
304    )?))
305}
306
307#[inline]
308fn create_random_map_array(
309    field: &Field,
310    size: usize,
311    null_density: f32,
312    true_density: f32,
313) -> Result<ArrayRef> {
314    // Override null density with 0.0 if the array is non-nullable
315    let map_null_density = match field.is_nullable() {
316        true => null_density,
317        false => 0.0,
318    };
319
320    let entries_field = match field.data_type() {
321        DataType::Map(f, _) => f,
322        _ => {
323            return Err(ArrowError::InvalidArgumentError(format!(
324                "Cannot create map array for field {field:?}"
325            )))
326        }
327    };
328
329    let (offsets, child_len) = create_random_offsets::<i32>(size, 0, 5);
330    let offsets = Buffer::from(offsets.to_byte_slice());
331
332    let entries = create_random_array(
333        entries_field,
334        child_len as usize,
335        null_density,
336        true_density,
337    )?
338    .to_data();
339
340    let null_buffer = match field.is_nullable() {
341        true => Some(create_random_null_buffer(size, map_null_density)),
342        false => None,
343    };
344
345    let map_data = unsafe {
346        ArrayData::new_unchecked(
347            field.data_type().clone(),
348            size,
349            None,
350            null_buffer,
351            0,
352            vec![offsets],
353            vec![entries],
354        )
355    };
356    Ok(make_array(map_data))
357}
358
359/// Generate random offsets for list arrays
360fn create_random_offsets<T: OffsetSizeTrait + SampleUniform>(
361    size: usize,
362    min: T,
363    max: T,
364) -> (Vec<T>, T) {
365    let rng = &mut seedable_rng();
366
367    let mut current_offset = T::zero();
368
369    let mut offsets = Vec::with_capacity(size + 1);
370    offsets.push(current_offset);
371
372    (0..size).for_each(|_| {
373        current_offset += rng.gen_range(min..max);
374        offsets.push(current_offset);
375    });
376
377    (offsets, current_offset)
378}
379
380fn create_random_null_buffer(size: usize, null_density: f32) -> Buffer {
381    let mut rng = seedable_rng();
382    let mut mut_buf = MutableBuffer::new_null(size);
383    {
384        let mut_slice = mut_buf.as_slice_mut();
385        (0..size).for_each(|i| {
386            if rng.gen::<f32>() >= null_density {
387                bit_util::set_bit(mut_slice, i)
388            }
389        })
390    };
391    mut_buf.into()
392}
393
394/// Useful for testing. The range of values are not likely to be representative of the
395/// actual bounds.
396pub trait RandomTemporalValue: ArrowTemporalType {
397    /// Returns the range of values for `impl`'d type
398    fn value_range() -> impl SampleRange<Self::Native>;
399
400    /// Generate a random value within the range of the type
401    fn gen_range<R: Rng>(rng: &mut R) -> Self::Native
402    where
403        Self::Native: SampleUniform,
404    {
405        rng.gen_range(Self::value_range())
406    }
407
408    /// Generate a random value of the type
409    fn random<R: Rng>(rng: &mut R) -> Self::Native
410    where
411        Self::Native: SampleUniform,
412    {
413        Self::gen_range(rng)
414    }
415}
416
417impl RandomTemporalValue for TimestampSecondType {
418    /// Range of values for a timestamp in seconds. The range begins at the start
419    /// of the unix epoch and continues for 100 years.
420    fn value_range() -> impl SampleRange<Self::Native> {
421        0..60 * 60 * 24 * 365 * 100
422    }
423}
424
425impl RandomTemporalValue for TimestampMillisecondType {
426    /// Range of values for a timestamp in milliseconds. The range begins at the start
427    /// of the unix epoch and continues for 100 years.
428    fn value_range() -> impl SampleRange<Self::Native> {
429        0..1_000 * 60 * 60 * 24 * 365 * 100
430    }
431}
432
433impl RandomTemporalValue for TimestampMicrosecondType {
434    /// Range of values for a timestamp in microseconds. The range begins at the start
435    /// of the unix epoch and continues for 100 years.
436    fn value_range() -> impl SampleRange<Self::Native> {
437        0..1_000 * 1_000 * 60 * 60 * 24 * 365 * 100
438    }
439}
440
441impl RandomTemporalValue for TimestampNanosecondType {
442    /// Range of values for a timestamp in nanoseconds. The range begins at the start
443    /// of the unix epoch and continues for 100 years.
444    fn value_range() -> impl SampleRange<Self::Native> {
445        0..1_000 * 1_000 * 1_000 * 60 * 60 * 24 * 365 * 100
446    }
447}
448
449impl RandomTemporalValue for Date32Type {
450    /// Range of values representing the elapsed time since UNIX epoch in days. The
451    /// range begins at the start of the unix epoch and continues for 100 years.
452    fn value_range() -> impl SampleRange<Self::Native> {
453        0..365 * 100
454    }
455}
456
457impl RandomTemporalValue for Date64Type {
458    /// Range of values  representing the elapsed time since UNIX epoch in milliseconds.
459    /// The range begins at the start of the unix epoch and continues for 100 years.
460    fn value_range() -> impl SampleRange<Self::Native> {
461        0..1_000 * 60 * 60 * 24 * 365 * 100
462    }
463}
464
465impl RandomTemporalValue for Time32SecondType {
466    /// Range of values representing the elapsed time since midnight in seconds. The
467    /// range is from 0 to 24 hours.
468    fn value_range() -> impl SampleRange<Self::Native> {
469        0..60 * 60 * 24
470    }
471}
472
473impl RandomTemporalValue for Time32MillisecondType {
474    /// Range of values representing the elapsed time since midnight in milliseconds. The
475    /// range is from 0 to 24 hours.
476    fn value_range() -> impl SampleRange<Self::Native> {
477        0..1_000 * 60 * 60 * 24
478    }
479}
480
481impl RandomTemporalValue for Time64MicrosecondType {
482    /// Range of values representing the elapsed time since midnight in microseconds. The
483    /// range is from 0 to 24 hours.
484    fn value_range() -> impl SampleRange<Self::Native> {
485        0..1_000 * 1_000 * 60 * 60 * 24
486    }
487}
488
489impl RandomTemporalValue for Time64NanosecondType {
490    /// Range of values representing the elapsed time since midnight in nanoseconds. The
491    /// range is from 0 to 24 hours.
492    fn value_range() -> impl SampleRange<Self::Native> {
493        0..1_000 * 1_000 * 1_000 * 60 * 60 * 24
494    }
495}
496
497fn create_random_temporal_array<T>(size: usize, null_density: f32) -> PrimitiveArray<T>
498where
499    T: RandomTemporalValue,
500    <T as ArrowPrimitiveType>::Native: SampleUniform,
501{
502    let mut rng = seedable_rng();
503
504    (0..size)
505        .map(|_| {
506            if rng.gen::<f32>() < null_density {
507                None
508            } else {
509                Some(T::random(&mut rng))
510            }
511        })
512        .collect()
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    /// A single nullable column produces a batch matching the schema and size.
    #[test]
    fn test_create_batch() {
        let size = 32;
        let schema_ref = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
        let batch = create_random_batch(schema_ref.clone(), size, 0.35, 0.7).unwrap();

        assert_eq!(batch.schema(), schema_ref);
        assert_eq!(batch.num_columns(), schema_ref.fields().len());
        for column in batch.columns() {
            assert_eq!(column.len(), size);
        }
    }

    /// Non-nullable fields never receive nulls, including list children.
    #[test]
    fn test_create_batch_non_null() {
        let size = 32;
        let list_type =
            DataType::List(Arc::new(Field::new_list_field(DataType::LargeUtf8, false)));
        let schema_ref = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", list_type, false),
            Field::new("a", DataType::Int32, false),
        ]));
        let batch = create_random_batch(schema_ref.clone(), size, 0.35, 0.7).unwrap();

        assert_eq!(batch.schema(), schema_ref);
        assert_eq!(batch.num_columns(), schema_ref.fields().len());
        for column in batch.columns() {
            assert_eq!(column.null_count(), 0);
            assert_eq!(column.logical_null_count(), 0);
        }
        // Test that the list's child values are non-null
        let list_array = batch.column(1).as_list::<i32>();
        let child_array = list_array.values();
        assert_eq!(child_array.null_count(), 0);
        // There should be more values than the list, to show that it's a list
        assert!(child_array.len() > list_array.len());
    }

    /// Nested lists and structs inside a struct are generated consistently.
    #[test]
    fn test_create_struct_array() {
        let size = 32;

        let inner_list = DataType::List(Arc::new(Field::new_list_field(
            DataType::FixedSizeBinary(6),
            true,
        )));
        let outer_list = DataType::LargeList(Arc::new(Field::new_list_field(inner_list, false)));
        let nested_struct = DataType::Struct(Fields::from(vec![
            Field::new("d_x", DataType::Int32, true),
            Field::new("d_y", DataType::Float32, false),
            Field::new("d_z", DataType::Binary, true),
        ]));
        let struct_fields = Fields::from(vec![
            Field::new("b", DataType::Boolean, true),
            Field::new("c", outer_list, true),
            Field::new("d", nested_struct, true),
        ]);

        let field = Field::new("struct", DataType::Struct(struct_fields), true);
        let array = create_random_array(&field, size, 0.2, 0.5).unwrap();

        assert_eq!(array.len(), 32);
        let struct_array = array.as_struct();
        assert_eq!(struct_array.columns().len(), 3);

        // Test that the nested list makes sense,
        // i.e. its children's values are more than the parent, to show repetition
        let col_c = struct_array.column_by_name("c").unwrap().as_list::<i64>();
        assert_eq!(col_c.len(), size);
        let col_c_list = col_c.values().as_list::<i32>();
        assert!(col_c_list.len() > size);
        // Its values should be FixedSizeBinary(6)
        let fsb = col_c_list.values();
        assert_eq!(fsb.data_type(), &DataType::FixedSizeBinary(6));
        assert!(fsb.len() > col_c_list.len());

        // Test nested struct
        let col_d = struct_array.column_by_name("d").unwrap().as_struct();
        let col_d_y = col_d.column_by_name("d_y").unwrap();
        assert_eq!(col_d_y.data_type(), &DataType::Float32);
        assert_eq!(col_d_y.null_count(), 0);
    }

    /// A non-nullable list may still hold nullable children with nulls.
    #[test]
    fn test_create_list_array_nested_nullability() {
        let item_field = Field::new_list_field(DataType::Boolean, true);
        let list_field = Field::new_list("not_null_list", item_field, false);

        let list_array = create_random_array(&list_field, 100, 0.95, 0.5).unwrap();

        assert_eq!(list_array.null_count(), 0);
        let values = list_array.as_list::<i32>().values();
        assert!(values.null_count() > 0);
    }

    /// A non-nullable struct respects the nullability of each child.
    #[test]
    fn test_create_struct_array_nested_nullability() {
        let struct_field = Field::new_struct(
            "not_null_struct",
            vec![
                Field::new("null_int", DataType::Int32, true),
                Field::new("int", DataType::Int32, false),
            ],
            false,
        );

        let struct_array = create_random_array(&struct_field, 100, 0.95, 0.5).unwrap();
        assert_eq!(struct_array.null_count(), 0);

        let columns = struct_array.as_struct();
        let nullable_child = columns.column_by_name("null_int").unwrap();
        assert!(nullable_child.null_count() > 0);
        let required_child = columns.column_by_name("int").unwrap();
        assert_eq!(required_child.null_count(), 0);
    }

    /// Nullability is honoured at every level of a list-of-struct.
    #[test]
    fn test_create_list_array_nested_struct_nullability() {
        let struct_type = DataType::Struct(
            vec![
                Field::new("null_int", DataType::Int32, true),
                Field::new("int", DataType::Int32, false),
            ]
            .into(),
        );
        let list_item_field = Field::new_list_field(struct_type, true);
        let list_field = Field::new_list("not_null_list", list_item_field, false);

        let list_array = create_random_array(&list_field, 100, 0.95, 0.5).unwrap();
        assert_eq!(list_array.null_count(), 0);

        let values = list_array.as_list::<i32>().values();
        assert!(values.null_count() > 0);

        let struct_values = values.as_struct();
        let nullable_child = struct_values.column_by_name("null_int").unwrap();
        assert!(nullable_child.null_count() > 0);
        let required_child = struct_values.column_by_name("int").unwrap();
        assert_eq!(required_child.null_count(), 0);
    }

    /// Map arrays have non-null keys, nullable values and list-like lengths.
    #[test]
    fn test_create_map_array() {
        let key_field = Field::new("key", DataType::Utf8, false);
        let value_field = Field::new("value", DataType::Utf8, true);
        let map_field = Field::new_map("map", "entries", key_field, value_field, false, false);

        let array = create_random_array(&map_field, 100, 0.8, 0.5).unwrap();

        assert_eq!(array.len(), 100);
        // Map field is not null
        assert_eq!(array.null_count(), 0);
        assert_eq!(array.logical_null_count(), 0);

        let map = array.as_map();
        // Maps have multiple values like a list, so internal arrays are longer
        assert!(map.keys().len() > array.len());
        assert!(map.values().len() > array.len());
        // Keys are not nullable
        assert_eq!(map.keys().null_count(), 0);
        // Values are nullable
        assert!(map.values().null_count() > 0);

        assert_eq!(map.keys().data_type(), &DataType::Utf8);
        assert_eq!(map.values().data_type(), &DataType::Utf8);
    }
}