arrow/
tensor.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Arrow Tensor Type, defined in
19//! [`format/Tensor.fbs`](https://github.com/apache/arrow/blob/master/format/Tensor.fbs).
20
21use std::marker::PhantomData;
22use std::mem;
23
24use crate::buffer::Buffer;
25use crate::datatypes::*;
26
27use crate::error::{ArrowError, Result};
28
29/// Computes the strides required assuming a row major memory layout
30fn compute_row_major_strides<T: ArrowPrimitiveType>(shape: &[usize]) -> Result<Vec<usize>> {
31    let mut remaining_bytes = mem::size_of::<T::Native>();
32
33    for i in shape {
34        if let Some(val) = remaining_bytes.checked_mul(*i) {
35            remaining_bytes = val;
36        } else {
37            return Err(ArrowError::ComputeError(
38                "overflow occurred when computing row major strides.".to_string(),
39            ));
40        }
41    }
42
43    let mut strides = Vec::<usize>::new();
44    for i in shape {
45        remaining_bytes /= *i;
46        strides.push(remaining_bytes);
47    }
48
49    Ok(strides)
50}
51
52/// Computes the strides required assuming a column major memory layout
53fn compute_column_major_strides<T: ArrowPrimitiveType>(shape: &[usize]) -> Result<Vec<usize>> {
54    let mut remaining_bytes = mem::size_of::<T::Native>();
55    let mut strides = Vec::<usize>::new();
56
57    for i in shape {
58        strides.push(remaining_bytes);
59
60        if let Some(val) = remaining_bytes.checked_mul(*i) {
61            remaining_bytes = val;
62        } else {
63            return Err(ArrowError::ComputeError(
64                "overflow occurred when computing column major strides.".to_string(),
65            ));
66        }
67    }
68
69    Ok(strides)
70}
71
/// Tensor of primitive types
///
/// A fixed-shape, multi-dimensional container backed by a single [`Buffer`].
/// The element type is fixed at compile time via the `T: ArrowPrimitiveType`
/// parameter; the lifetime `'a` ties the tensor to its dimension-name strings.
#[derive(Debug)]
pub struct Tensor<'a, T: ArrowPrimitiveType> {
    /// Arrow data type of the elements, derived from `T` in `try_new`
    data_type: DataType,
    /// Raw memory holding the tensor values
    buffer: Buffer,
    /// Size of each dimension; `None` for a zero-dimensional (scalar) tensor
    shape: Option<Vec<usize>>,
    /// Number of bytes between elements in each dimension; `None` when `shape` is `None`
    strides: Option<Vec<usize>>,
    /// Optional human-readable names for the dimensions
    names: Option<Vec<&'a str>>,
    /// Carries the element type `T` without storing a value of that type
    _marker: PhantomData<T>,
}
82
83/// [Tensor] of type [BooleanType]
84pub type BooleanTensor<'a> = Tensor<'a, BooleanType>;
85/// [Tensor] of type [Int8Type]
86pub type Date32Tensor<'a> = Tensor<'a, Date32Type>;
87/// [Tensor] of type [Int16Type]
88pub type Date64Tensor<'a> = Tensor<'a, Date64Type>;
89/// [Tensor] of type [Decimal32Type]
90pub type Decimal32Tensor<'a> = Tensor<'a, Decimal32Type>;
91/// [Tensor] of type [Decimal64Type]
92pub type Decimal64Tensor<'a> = Tensor<'a, Decimal64Type>;
93/// [Tensor] of type [Decimal128Type]
94pub type Decimal128Tensor<'a> = Tensor<'a, Decimal128Type>;
95/// [Tensor] of type [Decimal256Type]
96pub type Decimal256Tensor<'a> = Tensor<'a, Decimal256Type>;
97/// [Tensor] of type [DurationMicrosecondType]
98pub type DurationMicrosecondTensor<'a> = Tensor<'a, DurationMicrosecondType>;
99/// [Tensor] of type [DurationMillisecondType]
100pub type DurationMillisecondTensor<'a> = Tensor<'a, DurationMillisecondType>;
101/// [Tensor] of type [DurationNanosecondType]
102pub type DurationNanosecondTensor<'a> = Tensor<'a, DurationNanosecondType>;
103/// [Tensor] of type [DurationSecondType]
104pub type DurationSecondTensor<'a> = Tensor<'a, DurationSecondType>;
105/// [Tensor] of type [Float16Type]
106pub type Float16Tensor<'a> = Tensor<'a, Float16Type>;
107/// [Tensor] of type [Float32Type]
108pub type Float32Tensor<'a> = Tensor<'a, Float32Type>;
109/// [Tensor] of type [Float64Type]
110pub type Float64Tensor<'a> = Tensor<'a, Float64Type>;
111/// [Tensor] of type [Int8Type]
112pub type Int8Tensor<'a> = Tensor<'a, Int8Type>;
113/// [Tensor] of type [Int16Type]
114pub type Int16Tensor<'a> = Tensor<'a, Int16Type>;
115/// [Tensor] of type [Int32Type]
116pub type Int32Tensor<'a> = Tensor<'a, Int32Type>;
117/// [Tensor] of type [Int64Type]
118pub type Int64Tensor<'a> = Tensor<'a, Int64Type>;
119/// [Tensor] of type [IntervalDayTimeType]
120pub type IntervalDayTimeTensor<'a> = Tensor<'a, IntervalDayTimeType>;
121/// [Tensor] of type [IntervalMonthDayNanoType]
122pub type IntervalMonthDayNanoTensor<'a> = Tensor<'a, IntervalMonthDayNanoType>;
123/// [Tensor] of type [IntervalYearMonthType]
124pub type IntervalYearMonthTensor<'a> = Tensor<'a, IntervalYearMonthType>;
125/// [Tensor] of type [Time32MillisecondType]
126pub type Time32MillisecondTensor<'a> = Tensor<'a, Time32MillisecondType>;
127/// [Tensor] of type [Time32SecondType]
128pub type Time32SecondTensor<'a> = Tensor<'a, Time32SecondType>;
129/// [Tensor] of type [Time64MicrosecondType]
130pub type Time64MicrosecondTensor<'a> = Tensor<'a, Time64MicrosecondType>;
131/// [Tensor] of type [Time64NanosecondType]
132pub type Time64NanosecondTensor<'a> = Tensor<'a, Time64NanosecondType>;
133/// [Tensor] of type [TimestampMicrosecondType]
134pub type TimestampMicrosecondTensor<'a> = Tensor<'a, TimestampMicrosecondType>;
135/// [Tensor] of type [TimestampMillisecondType]
136pub type TimestampMillisecondTensor<'a> = Tensor<'a, TimestampMillisecondType>;
137/// [Tensor] of type [TimestampNanosecondType]
138pub type TimestampNanosecondTensor<'a> = Tensor<'a, TimestampNanosecondType>;
139/// [Tensor] of type [TimestampSecondType]
140pub type TimestampSecondTensor<'a> = Tensor<'a, TimestampSecondType>;
141/// [Tensor] of type [UInt8Type]
142pub type UInt8Tensor<'a> = Tensor<'a, UInt8Type>;
143/// [Tensor] of type [UInt16Type]
144pub type UInt16Tensor<'a> = Tensor<'a, UInt16Type>;
145/// [Tensor] of type [UInt32Type]
146pub type UInt32Tensor<'a> = Tensor<'a, UInt32Type>;
147/// [Tensor] of type [UInt64Type]
148pub type UInt64Tensor<'a> = Tensor<'a, UInt64Type>;
149
impl<'a, T: ArrowPrimitiveType> Tensor<'a, T> {
    /// Creates a new `Tensor`
    ///
    /// Validates the combination of arguments before construction:
    /// - With no `shape`, the buffer must hold exactly one element, and both
    ///   `strides` and `names` must be `None` (zero-dimensional tensor).
    /// - With a `shape`, the dimension counts of `strides` and `names`
    ///   (if supplied) must match `shape.len()`, and the buffer must hold
    ///   exactly `shape.iter().product()` elements.
    /// - Supplied `strides` must equal either the row-major or column-major
    ///   strides computed from `shape`; when `strides` is `None`, row-major
    ///   strides are derived from `shape` as the default layout.
    ///
    /// # Errors
    ///
    /// Returns `ArrowError::InvalidArgumentError` when any of the checks
    /// above fails, or `ArrowError::ComputeError` if stride computation
    /// overflows `usize`.
    pub fn try_new(
        buffer: Buffer,
        shape: Option<Vec<usize>>,
        strides: Option<Vec<usize>>,
        names: Option<Vec<&'a str>>,
    ) -> Result<Self> {
        match shape {
            None => {
                // Zero-dimensional tensor: exactly one element, no strides, no names.
                if buffer.len() != mem::size_of::<T::Native>() {
                    return Err(ArrowError::InvalidArgumentError(
                        "underlying buffer should only contain a single tensor element".to_string(),
                    ));
                }

                if strides.is_some() {
                    return Err(ArrowError::InvalidArgumentError(
                        "expected None strides for tensor with no shape".to_string(),
                    ));
                }

                if names.is_some() {
                    return Err(ArrowError::InvalidArgumentError(
                        "expected None names for tensor with no shape".to_string(),
                    ));
                }
            }

            Some(ref s) => {
                // Strides and names, when given, must have one entry per dimension.
                if let Some(ref st) = strides {
                    if st.len() != s.len() {
                        return Err(ArrowError::InvalidArgumentError(
                            "shape and stride dimensions differ".to_string(),
                        ));
                    }
                }

                if let Some(ref n) = names {
                    if n.len() != s.len() {
                        return Err(ArrowError::InvalidArgumentError(
                            "number of dimensions and number of dimension names differ".to_string(),
                        ));
                    }
                }

                // The buffer must contain exactly as many elements as the shape implies.
                // Note: integer division truncates, so a buffer with trailing partial
                // bytes would compare against the truncated element count.
                let total_elements: usize = s.iter().product();
                if total_elements != (buffer.len() / mem::size_of::<T::Native>()) {
                    return Err(ArrowError::InvalidArgumentError(
                        "number of elements in buffer does not match dimensions".to_string(),
                    ));
                }
            }
        };

        // Checking that the tensor strides used for construction are correct
        // otherwise a row major stride is calculated and used as value for the tensor
        let tensor_strides = {
            if let Some(st) = strides {
                if let Some(ref s) = shape {
                    // Only the two canonical contiguous layouts are accepted;
                    // arbitrary (e.g. sliced/transposed) strides are rejected.
                    if compute_row_major_strides::<T>(s)? == st
                        || compute_column_major_strides::<T>(s)? == st
                    {
                        Some(st)
                    } else {
                        return Err(ArrowError::InvalidArgumentError(
                            "the input stride does not match the selected shape".to_string(),
                        ));
                    }
                } else {
                    // Unreachable in practice: the earlier match already rejected
                    // Some(strides) with a None shape.
                    Some(st)
                }
            } else if let Some(ref s) = shape {
                // Default layout when no strides are supplied is row major.
                Some(compute_row_major_strides::<T>(s)?)
            } else {
                None
            }
        };

        Ok(Self {
            data_type: T::DATA_TYPE,
            buffer,
            shape,
            strides: tensor_strides,
            names,
            _marker: PhantomData,
        })
    }

    /// Creates a new Tensor using row major memory layout
    ///
    /// Convenience wrapper around [`Tensor::try_new`] that derives row-major
    /// strides from `shape`. Unlike `try_new`, a `shape` is required here.
    pub fn new_row_major(
        buffer: Buffer,
        shape: Option<Vec<usize>>,
        names: Option<Vec<&'a str>>,
    ) -> Result<Self> {
        if let Some(ref s) = shape {
            let strides = Some(compute_row_major_strides::<T>(s)?);

            Self::try_new(buffer, shape, strides, names)
        } else {
            Err(ArrowError::InvalidArgumentError(
                "shape required to create row major tensor".to_string(),
            ))
        }
    }

    /// Creates a new Tensor using column major memory layout
    ///
    /// Convenience wrapper around [`Tensor::try_new`] that derives column-major
    /// strides from `shape`. Unlike `try_new`, a `shape` is required here.
    pub fn new_column_major(
        buffer: Buffer,
        shape: Option<Vec<usize>>,
        names: Option<Vec<&'a str>>,
    ) -> Result<Self> {
        if let Some(ref s) = shape {
            let strides = Some(compute_column_major_strides::<T>(s)?);

            Self::try_new(buffer, shape, strides, names)
        } else {
            Err(ArrowError::InvalidArgumentError(
                "shape required to create column major tensor".to_string(),
            ))
        }
    }

    /// The data type of the `Tensor`
    pub fn data_type(&self) -> &DataType {
        &self.data_type
    }

    /// The sizes of the dimensions, or `None` for a zero-dimensional tensor
    pub fn shape(&self) -> Option<&Vec<usize>> {
        self.shape.as_ref()
    }

    /// Returns a reference to the underlying `Buffer`
    pub fn data(&self) -> &Buffer {
        &self.buffer
    }

    /// The number of bytes between elements in each dimension
    pub fn strides(&self) -> Option<&Vec<usize>> {
        self.strides.as_ref()
    }

    /// The names of the dimensions
    pub fn names(&self) -> Option<&Vec<&'a str>> {
        self.names.as_ref()
    }

    /// The number of dimensions (0 for a tensor constructed without a shape)
    pub fn ndim(&self) -> usize {
        match &self.shape {
            None => 0,
            Some(v) => v.len(),
        }
    }

    /// The name of dimension i
    ///
    /// Returns `None` when the tensor has no dimension names.
    /// NOTE(review): `names[i]` indexes unchecked — this panics if
    /// `i >= ndim()` while names are present; callers must pass a valid index.
    pub fn dim_name(&self, i: usize) -> Option<&'a str> {
        self.names.as_ref().map(|names| names[i])
    }

    /// The total number of elements in the `Tensor`
    ///
    /// A zero-dimensional tensor reports a size of 0 even though its buffer
    /// holds one element.
    pub fn size(&self) -> usize {
        match self.shape {
            None => 0,
            Some(ref s) => s.iter().product(),
        }
    }

    /// Indicates if the data is laid out contiguously in memory
    /// (i.e. either row major or column major)
    pub fn is_contiguous(&self) -> Result<bool> {
        Ok(self.is_row_major()? || self.is_column_major()?)
    }

    /// Indicates if the memory layout row major
    ///
    /// Always `false` for a tensor without a shape.
    pub fn is_row_major(&self) -> Result<bool> {
        match self.shape {
            None => Ok(false),
            Some(ref s) => Ok(Some(compute_row_major_strides::<T>(s)?) == self.strides),
        }
    }

    /// Indicates if the memory layout column major
    ///
    /// Always `false` for a tensor without a shape.
    pub fn is_column_major(&self) -> Result<bool> {
        match self.shape {
            None => Ok(false),
            Some(ref s) => Ok(Some(compute_column_major_strides::<T>(s)?) == self.strides),
        }
    }
}
340
#[cfg(test)]
mod tests {
    use super::*;

    use crate::array::*;

    #[test]
    fn test_compute_row_major_strides() {
        assert_eq!(
            vec![48_usize, 8],
            compute_row_major_strides::<Int64Type>(&[4_usize, 6]).unwrap()
        );
        assert_eq!(
            vec![24_usize, 4],
            compute_row_major_strides::<Int32Type>(&[4_usize, 6]).unwrap()
        );
        assert_eq!(
            vec![6_usize, 1],
            compute_row_major_strides::<Int8Type>(&[4_usize, 6]).unwrap()
        );
    }

    #[test]
    fn test_compute_column_major_strides() {
        assert_eq!(
            vec![8_usize, 32],
            compute_column_major_strides::<Int64Type>(&[4_usize, 6]).unwrap()
        );
        assert_eq!(
            vec![4_usize, 16],
            compute_column_major_strides::<Int32Type>(&[4_usize, 6]).unwrap()
        );
        assert_eq!(
            vec![1_usize, 4],
            compute_column_major_strides::<Int8Type>(&[4_usize, 6]).unwrap()
        );
    }

    #[test]
    fn test_zero_dim() {
        let buf = Buffer::from(&[1]);
        let tensor = UInt8Tensor::try_new(buf, None, None, None).unwrap();
        assert_eq!(0, tensor.size());
        assert_eq!(None, tensor.shape());
        assert_eq!(None, tensor.names());
        assert_eq!(0, tensor.ndim());
        assert!(!tensor.is_row_major().unwrap());
        assert!(!tensor.is_column_major().unwrap());
        assert!(!tensor.is_contiguous().unwrap());

        let buf = Buffer::from(&[1, 2, 2, 2]);
        let tensor = Int32Tensor::try_new(buf, None, None, None).unwrap();
        assert_eq!(0, tensor.size());
        assert_eq!(None, tensor.shape());
        assert_eq!(None, tensor.names());
        assert_eq!(0, tensor.ndim());
        assert!(!tensor.is_row_major().unwrap());
        assert!(!tensor.is_column_major().unwrap());
        assert!(!tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_tensor() {
        let mut builder = Int32BufferBuilder::new(16);
        for i in 0..16 {
            builder.append(i);
        }
        let buf = builder.finish();
        let tensor = Int32Tensor::try_new(buf, Some(vec![2, 8]), None, None).unwrap();
        assert_eq!(16, tensor.size());
        assert_eq!(Some(vec![2_usize, 8]).as_ref(), tensor.shape());
        // Default layout is row major when no strides are supplied.
        assert_eq!(Some(vec![32_usize, 4]).as_ref(), tensor.strides());
        assert_eq!(2, tensor.ndim());
        assert_eq!(None, tensor.names());
    }

    #[test]
    fn test_new_row_major() {
        let mut builder = Int32BufferBuilder::new(16);
        for i in 0..16 {
            builder.append(i);
        }
        let buf = builder.finish();
        let tensor = Int32Tensor::new_row_major(buf, Some(vec![2, 8]), None).unwrap();
        assert_eq!(16, tensor.size());
        assert_eq!(Some(vec![2_usize, 8]).as_ref(), tensor.shape());
        assert_eq!(Some(vec![32_usize, 4]).as_ref(), tensor.strides());
        assert_eq!(None, tensor.names());
        assert_eq!(2, tensor.ndim());
        assert!(tensor.is_row_major().unwrap());
        assert!(!tensor.is_column_major().unwrap());
        assert!(tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_new_column_major() {
        let mut builder = Int32BufferBuilder::new(16);
        for i in 0..16 {
            builder.append(i);
        }
        let buf = builder.finish();
        let tensor = Int32Tensor::new_column_major(buf, Some(vec![2, 8]), None).unwrap();
        assert_eq!(16, tensor.size());
        assert_eq!(Some(vec![2_usize, 8]).as_ref(), tensor.shape());
        assert_eq!(Some(vec![4_usize, 8]).as_ref(), tensor.strides());
        assert_eq!(None, tensor.names());
        assert_eq!(2, tensor.ndim());
        assert!(!tensor.is_row_major().unwrap());
        assert!(tensor.is_column_major().unwrap());
        assert!(tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_with_names() {
        let mut builder = Int64BufferBuilder::new(8);
        for i in 0..8 {
            builder.append(i);
        }
        let buf = builder.finish();
        let names = vec!["Dim 1", "Dim 2"];
        let tensor = Int64Tensor::new_column_major(buf, Some(vec![2, 4]), Some(names)).unwrap();
        assert_eq!(8, tensor.size());
        assert_eq!(Some(vec![2_usize, 4]).as_ref(), tensor.shape());
        assert_eq!(Some(vec![8_usize, 16]).as_ref(), tensor.strides());
        assert_eq!("Dim 1", tensor.dim_name(0).unwrap());
        assert_eq!("Dim 2", tensor.dim_name(1).unwrap());
        assert_eq!(2, tensor.ndim());
        assert!(!tensor.is_row_major().unwrap());
        assert!(tensor.is_column_major().unwrap());
        assert!(tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_inconsistent_strides() {
        let mut builder = Int32BufferBuilder::new(16);
        for i in 0..16 {
            builder.append(i);
        }
        let buf = builder.finish();

        let result = Int32Tensor::try_new(buf, Some(vec![2, 8]), Some(vec![2, 8, 1]), None);

        assert!(
            result.is_err(),
            "creating a tensor whose stride count differs from its shape should fail"
        );
    }

    #[test]
    fn test_inconsistent_names() {
        let mut builder = Int32BufferBuilder::new(16);
        for i in 0..16 {
            builder.append(i);
        }
        let buf = builder.finish();

        let result = Int32Tensor::try_new(
            buf,
            Some(vec![2, 8]),
            Some(vec![4, 8]),
            Some(vec!["1", "2", "3"]),
        );

        assert!(
            result.is_err(),
            "creating a tensor whose name count differs from its shape should fail"
        );
    }

    #[test]
    fn test_incorrect_shape() {
        let mut builder = Int32BufferBuilder::new(16);
        for i in 0..16 {
            builder.append(i);
        }
        let buf = builder.finish();

        let result = Int32Tensor::try_new(buf, Some(vec![2, 6]), None, None);

        assert!(
            result.is_err(),
            "creating a tensor whose shape does not match the buffer element count should fail"
        );
    }

    #[test]
    fn test_incorrect_stride() {
        let mut builder = Int32BufferBuilder::new(16);
        for i in 0..16 {
            builder.append(i);
        }
        let buf = builder.finish();

        let result = Int32Tensor::try_new(buf, Some(vec![2, 8]), Some(vec![30, 4]), None);

        assert!(
            result.is_err(),
            "creating a tensor with strides inconsistent with its shape should fail"
        );
    }
}