arrow/tensor.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Arrow Tensor Type, defined in
//! [`format/Tensor.fbs`](https://github.com/apache/arrow/blob/master/format/Tensor.fbs).

use std::marker::PhantomData;
use std::mem;

use crate::buffer::Buffer;
use crate::datatypes::*;

use crate::error::{ArrowError, Result};

/// Computes the strides required assuming a row major memory layout
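///
/// Strides are expressed in bytes and returned in the same order as the
/// dimensions in `shape`. For example, for an `Int64Type` tensor of shape
/// `[4, 6]` the element size is 8 bytes, so the row major strides are
/// `[6 * 8, 8] = [48, 8]`.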
fn compute_row_major_strides<T: ArrowPrimitiveType>(shape: &[usize]) -> Result<Vec<usize>> {
    let mut remaining_bytes = mem::size_of::<T::Native>();

    for i in shape {
        if let Some(val) = remaining_bytes.checked_mul(*i) {
            remaining_bytes = val;
        } else {
            return Err(ArrowError::ComputeError(
                "overflow occurred when computing row major strides.".to_string(),
            ));
        }
    }

    let mut strides = Vec::<usize>::new();
    for i in shape {
        remaining_bytes /= *i;
        strides.push(remaining_bytes);
    }

    Ok(strides)
}

/// Computes the strides required assuming a column major memory layout
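///
/// As in the row major case, strides are expressed in bytes. For example, for
/// an `Int64Type` tensor of shape `[4, 6]` the element size is 8 bytes, so the
/// column major strides are `[8, 4 * 8] = [8, 32]`.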
fn compute_column_major_strides<T: ArrowPrimitiveType>(shape: &[usize]) -> Result<Vec<usize>> {
    let mut remaining_bytes = mem::size_of::<T::Native>();
    let mut strides = Vec::<usize>::new();

    for i in shape {
        strides.push(remaining_bytes);

        if let Some(val) = remaining_bytes.checked_mul(*i) {
            remaining_bytes = val;
        } else {
            return Err(ArrowError::ComputeError(
                "overflow occurred when computing column major strides.".to_string(),
            ));
        }
    }

    Ok(strides)
}

/// Tensor of primitive types
#[derive(Debug)]
pub struct Tensor<'a, T: ArrowPrimitiveType> {
    data_type: DataType,
    buffer: Buffer,
    shape: Option<Vec<usize>>,
    strides: Option<Vec<usize>>,
    names: Option<Vec<&'a str>>,
    _marker: PhantomData<T>,
}

/// [Tensor] of type [BooleanType]
pub type BooleanTensor<'a> = Tensor<'a, BooleanType>;
/// [Tensor] of type [Date32Type]
pub type Date32Tensor<'a> = Tensor<'a, Date32Type>;
/// [Tensor] of type [Date64Type]
pub type Date64Tensor<'a> = Tensor<'a, Date64Type>;
/// [Tensor] of type [Decimal128Type]
pub type Decimal128Tensor<'a> = Tensor<'a, Decimal128Type>;
/// [Tensor] of type [Decimal256Type]
pub type Decimal256Tensor<'a> = Tensor<'a, Decimal256Type>;
/// [Tensor] of type [DurationMicrosecondType]
pub type DurationMicrosecondTensor<'a> = Tensor<'a, DurationMicrosecondType>;
/// [Tensor] of type [DurationMillisecondType]
pub type DurationMillisecondTensor<'a> = Tensor<'a, DurationMillisecondType>;
/// [Tensor] of type [DurationNanosecondType]
pub type DurationNanosecondTensor<'a> = Tensor<'a, DurationNanosecondType>;
/// [Tensor] of type [DurationSecondType]
pub type DurationSecondTensor<'a> = Tensor<'a, DurationSecondType>;
/// [Tensor] of type [Float16Type]
pub type Float16Tensor<'a> = Tensor<'a, Float16Type>;
/// [Tensor] of type [Float32Type]
pub type Float32Tensor<'a> = Tensor<'a, Float32Type>;
/// [Tensor] of type [Float64Type]
pub type Float64Tensor<'a> = Tensor<'a, Float64Type>;
/// [Tensor] of type [Int8Type]
pub type Int8Tensor<'a> = Tensor<'a, Int8Type>;
/// [Tensor] of type [Int16Type]
pub type Int16Tensor<'a> = Tensor<'a, Int16Type>;
/// [Tensor] of type [Int32Type]
pub type Int32Tensor<'a> = Tensor<'a, Int32Type>;
/// [Tensor] of type [Int64Type]
pub type Int64Tensor<'a> = Tensor<'a, Int64Type>;
/// [Tensor] of type [IntervalDayTimeType]
pub type IntervalDayTimeTensor<'a> = Tensor<'a, IntervalDayTimeType>;
/// [Tensor] of type [IntervalMonthDayNanoType]
pub type IntervalMonthDayNanoTensor<'a> = Tensor<'a, IntervalMonthDayNanoType>;
/// [Tensor] of type [IntervalYearMonthType]
pub type IntervalYearMonthTensor<'a> = Tensor<'a, IntervalYearMonthType>;
/// [Tensor] of type [Time32MillisecondType]
pub type Time32MillisecondTensor<'a> = Tensor<'a, Time32MillisecondType>;
/// [Tensor] of type [Time32SecondType]
pub type Time32SecondTensor<'a> = Tensor<'a, Time32SecondType>;
/// [Tensor] of type [Time64MicrosecondType]
pub type Time64MicrosecondTensor<'a> = Tensor<'a, Time64MicrosecondType>;
/// [Tensor] of type [Time64NanosecondType]
pub type Time64NanosecondTensor<'a> = Tensor<'a, Time64NanosecondType>;
/// [Tensor] of type [TimestampMicrosecondType]
pub type TimestampMicrosecondTensor<'a> = Tensor<'a, TimestampMicrosecondType>;
/// [Tensor] of type [TimestampMillisecondType]
pub type TimestampMillisecondTensor<'a> = Tensor<'a, TimestampMillisecondType>;
/// [Tensor] of type [TimestampNanosecondType]
pub type TimestampNanosecondTensor<'a> = Tensor<'a, TimestampNanosecondType>;
/// [Tensor] of type [TimestampSecondType]
pub type TimestampSecondTensor<'a> = Tensor<'a, TimestampSecondType>;
/// [Tensor] of type [UInt8Type]
pub type UInt8Tensor<'a> = Tensor<'a, UInt8Type>;
/// [Tensor] of type [UInt16Type]
pub type UInt16Tensor<'a> = Tensor<'a, UInt16Type>;
/// [Tensor] of type [UInt32Type]
pub type UInt32Tensor<'a> = Tensor<'a, UInt32Type>;
/// [Tensor] of type [UInt64Type]
pub type UInt64Tensor<'a> = Tensor<'a, UInt64Type>;

impl<'a, T: ArrowPrimitiveType> Tensor<'a, T> {
    /// Creates a new `Tensor`
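    ///
    /// # Example
    ///
    /// A minimal sketch of constructing a `2 x 8` tensor of `i32` values
    /// (illustrative only; assumes the crate paths
    /// `arrow::array::Int32BufferBuilder` and `arrow::tensor::Int32Tensor`):
    ///
    /// ```ignore
    /// use arrow::array::Int32BufferBuilder;
    /// use arrow::tensor::Int32Tensor;
    ///
    /// // Fill a buffer with 16 i32 values (64 bytes).
    /// let mut builder = Int32BufferBuilder::new(16);
    /// for i in 0..16 {
    ///     builder.append(i);
    /// }
    /// let buf = builder.finish();
    ///
    /// // Strides may be omitted; row major strides are then computed from the shape.
    /// let tensor = Int32Tensor::try_new(buf, Some(vec![2, 8]), None, None).unwrap();
    /// assert_eq!(16, tensor.size());
    /// assert_eq!(Some(vec![32_usize, 4]).as_ref(), tensor.strides());
    /// ```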
    pub fn try_new(
        buffer: Buffer,
        shape: Option<Vec<usize>>,
        strides: Option<Vec<usize>>,
        names: Option<Vec<&'a str>>,
    ) -> Result<Self> {
        match shape {
            None => {
                if buffer.len() != mem::size_of::<T::Native>() {
                    return Err(ArrowError::InvalidArgumentError(
                        "underlying buffer should only contain a single tensor element".to_string(),
                    ));
                }

                if strides.is_some() {
                    return Err(ArrowError::InvalidArgumentError(
                        "expected None strides for tensor with no shape".to_string(),
                    ));
                }

                if names.is_some() {
                    return Err(ArrowError::InvalidArgumentError(
                        "expected None names for tensor with no shape".to_string(),
                    ));
                }
            }

            Some(ref s) => {
                if let Some(ref st) = strides {
                    if st.len() != s.len() {
                        return Err(ArrowError::InvalidArgumentError(
                            "shape and stride dimensions differ".to_string(),
                        ));
                    }
                }

                if let Some(ref n) = names {
                    if n.len() != s.len() {
                        return Err(ArrowError::InvalidArgumentError(
                            "number of dimensions and number of dimension names differ".to_string(),
                        ));
                    }
                }

                let total_elements: usize = s.iter().product();
                if total_elements != (buffer.len() / mem::size_of::<T::Native>()) {
                    return Err(ArrowError::InvalidArgumentError(
                        "number of elements in buffer does not match dimensions".to_string(),
                    ));
                }
            }
        };

        // Validate any strides supplied at construction against the shape;
        // if no strides were supplied, fall back to row major strides.
        let tensor_strides = {
            if let Some(st) = strides {
                if let Some(ref s) = shape {
                    if compute_row_major_strides::<T>(s)? == st
                        || compute_column_major_strides::<T>(s)? == st
                    {
                        Some(st)
                    } else {
                        return Err(ArrowError::InvalidArgumentError(
                            "the input stride does not match the selected shape".to_string(),
                        ));
                    }
                } else {
                    Some(st)
                }
            } else if let Some(ref s) = shape {
                Some(compute_row_major_strides::<T>(s)?)
            } else {
                None
            }
        };

        Ok(Self {
            data_type: T::DATA_TYPE,
            buffer,
            shape,
            strides: tensor_strides,
            names,
            _marker: PhantomData,
        })
    }

    /// Creates a new `Tensor` using row major memory layout
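    ///
    /// # Example
    ///
    /// A short sketch, reusing the `buf` built in the `try_new` example above
    /// (paths assumed as in that example):
    ///
    /// ```ignore
    /// // A 2 x 8 tensor of i32 laid out row by row: strides are [8 * 4, 4] bytes.
    /// let tensor = Int32Tensor::new_row_major(buf, Some(vec![2, 8]), None).unwrap();
    /// assert_eq!(Some(vec![32_usize, 4]).as_ref(), tensor.strides());
    /// assert!(tensor.is_row_major().unwrap());
    /// ```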
    pub fn new_row_major(
        buffer: Buffer,
        shape: Option<Vec<usize>>,
        names: Option<Vec<&'a str>>,
    ) -> Result<Self> {
        if let Some(ref s) = shape {
            let strides = Some(compute_row_major_strides::<T>(s)?);

            Self::try_new(buffer, shape, strides, names)
        } else {
            Err(ArrowError::InvalidArgumentError(
                "shape required to create row major tensor".to_string(),
            ))
        }
    }

    /// Creates a new `Tensor` using column major memory layout
    pub fn new_column_major(
        buffer: Buffer,
        shape: Option<Vec<usize>>,
        names: Option<Vec<&'a str>>,
    ) -> Result<Self> {
        if let Some(ref s) = shape {
            let strides = Some(compute_column_major_strides::<T>(s)?);

            Self::try_new(buffer, shape, strides, names)
        } else {
            Err(ArrowError::InvalidArgumentError(
                "shape required to create column major tensor".to_string(),
            ))
        }
    }

    /// The data type of the `Tensor`
    pub fn data_type(&self) -> &DataType {
        &self.data_type
    }

    /// The sizes of the dimensions
    pub fn shape(&self) -> Option<&Vec<usize>> {
        self.shape.as_ref()
    }

    /// Returns a reference to the underlying `Buffer`
    pub fn data(&self) -> &Buffer {
        &self.buffer
    }

    /// The number of bytes between elements in each dimension
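    ///
    /// For example, an `Int32Tensor` with shape `[2, 8]` has row major strides
    /// of `[32, 4]` and column major strides of `[4, 8]`.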
    pub fn strides(&self) -> Option<&Vec<usize>> {
        self.strides.as_ref()
    }

    /// The names of the dimensions
    pub fn names(&self) -> Option<&Vec<&'a str>> {
        self.names.as_ref()
    }

    /// The number of dimensions
    pub fn ndim(&self) -> usize {
        match &self.shape {
            None => 0,
            Some(v) => v.len(),
        }
    }

    /// The name of dimension `i`
    pub fn dim_name(&self, i: usize) -> Option<&'a str> {
        self.names.as_ref().map(|names| names[i])
    }

    /// The total number of elements in the `Tensor`
    pub fn size(&self) -> usize {
        match self.shape {
            None => 0,
            Some(ref s) => s.iter().product(),
        }
    }

    /// Indicates if the data is laid out contiguously in memory
    pub fn is_contiguous(&self) -> Result<bool> {
        Ok(self.is_row_major()? || self.is_column_major()?)
    }

    /// Indicates if the memory layout is row major
    pub fn is_row_major(&self) -> Result<bool> {
        match self.shape {
            None => Ok(false),
            Some(ref s) => Ok(Some(compute_row_major_strides::<T>(s)?) == self.strides),
        }
    }

    /// Indicates if the memory layout is column major
    pub fn is_column_major(&self) -> Result<bool> {
        match self.shape {
            None => Ok(false),
            Some(ref s) => Ok(Some(compute_column_major_strides::<T>(s)?) == self.strides),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::array::*;

    #[test]
    fn test_compute_row_major_strides() {
        assert_eq!(
            vec![48_usize, 8],
            compute_row_major_strides::<Int64Type>(&[4_usize, 6]).unwrap()
        );
        assert_eq!(
            vec![24_usize, 4],
            compute_row_major_strides::<Int32Type>(&[4_usize, 6]).unwrap()
        );
        assert_eq!(
            vec![6_usize, 1],
            compute_row_major_strides::<Int8Type>(&[4_usize, 6]).unwrap()
        );
    }

    #[test]
    fn test_compute_column_major_strides() {
        assert_eq!(
            vec![8_usize, 32],
            compute_column_major_strides::<Int64Type>(&[4_usize, 6]).unwrap()
        );
        assert_eq!(
            vec![4_usize, 16],
            compute_column_major_strides::<Int32Type>(&[4_usize, 6]).unwrap()
        );
        assert_eq!(
            vec![1_usize, 4],
            compute_column_major_strides::<Int8Type>(&[4_usize, 6]).unwrap()
        );
    }

    #[test]
    fn test_zero_dim() {
        let buf = Buffer::from(&[1]);
        let tensor = UInt8Tensor::try_new(buf, None, None, None).unwrap();
        assert_eq!(0, tensor.size());
        assert_eq!(None, tensor.shape());
        assert_eq!(None, tensor.names());
        assert_eq!(0, tensor.ndim());
        assert!(!tensor.is_row_major().unwrap());
        assert!(!tensor.is_column_major().unwrap());
        assert!(!tensor.is_contiguous().unwrap());

        let buf = Buffer::from(&[1, 2, 2, 2]);
        let tensor = Int32Tensor::try_new(buf, None, None, None).unwrap();
        assert_eq!(0, tensor.size());
        assert_eq!(None, tensor.shape());
        assert_eq!(None, tensor.names());
        assert_eq!(0, tensor.ndim());
        assert!(!tensor.is_row_major().unwrap());
        assert!(!tensor.is_column_major().unwrap());
        assert!(!tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_tensor() {
        let mut builder = Int32BufferBuilder::new(16);
        for i in 0..16 {
            builder.append(i);
        }
        let buf = builder.finish();
        let tensor = Int32Tensor::try_new(buf, Some(vec![2, 8]), None, None).unwrap();
        assert_eq!(16, tensor.size());
        assert_eq!(Some(vec![2_usize, 8]).as_ref(), tensor.shape());
        assert_eq!(Some(vec![32_usize, 4]).as_ref(), tensor.strides());
        assert_eq!(2, tensor.ndim());
        assert_eq!(None, tensor.names());
    }

    #[test]
    fn test_new_row_major() {
        let mut builder = Int32BufferBuilder::new(16);
        for i in 0..16 {
            builder.append(i);
        }
        let buf = builder.finish();
        let tensor = Int32Tensor::new_row_major(buf, Some(vec![2, 8]), None).unwrap();
        assert_eq!(16, tensor.size());
        assert_eq!(Some(vec![2_usize, 8]).as_ref(), tensor.shape());
        assert_eq!(Some(vec![32_usize, 4]).as_ref(), tensor.strides());
        assert_eq!(None, tensor.names());
        assert_eq!(2, tensor.ndim());
        assert!(tensor.is_row_major().unwrap());
        assert!(!tensor.is_column_major().unwrap());
        assert!(tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_new_column_major() {
        let mut builder = Int32BufferBuilder::new(16);
        for i in 0..16 {
            builder.append(i);
        }
        let buf = builder.finish();
        let tensor = Int32Tensor::new_column_major(buf, Some(vec![2, 8]), None).unwrap();
        assert_eq!(16, tensor.size());
        assert_eq!(Some(vec![2_usize, 8]).as_ref(), tensor.shape());
        assert_eq!(Some(vec![4_usize, 8]).as_ref(), tensor.strides());
        assert_eq!(None, tensor.names());
        assert_eq!(2, tensor.ndim());
        assert!(!tensor.is_row_major().unwrap());
        assert!(tensor.is_column_major().unwrap());
        assert!(tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_with_names() {
        let mut builder = Int64BufferBuilder::new(8);
        for i in 0..8 {
            builder.append(i);
        }
        let buf = builder.finish();
        let names = vec!["Dim 1", "Dim 2"];
        let tensor = Int64Tensor::new_column_major(buf, Some(vec![2, 4]), Some(names)).unwrap();
        assert_eq!(8, tensor.size());
        assert_eq!(Some(vec![2_usize, 4]).as_ref(), tensor.shape());
        assert_eq!(Some(vec![8_usize, 16]).as_ref(), tensor.strides());
        assert_eq!("Dim 1", tensor.dim_name(0).unwrap());
        assert_eq!("Dim 2", tensor.dim_name(1).unwrap());
        assert_eq!(2, tensor.ndim());
        assert!(!tensor.is_row_major().unwrap());
        assert!(tensor.is_column_major().unwrap());
        assert!(tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_inconsistent_strides() {
        let mut builder = Int32BufferBuilder::new(16);
        for i in 0..16 {
            builder.append(i);
        }
        let buf = builder.finish();

        let result = Int32Tensor::try_new(buf, Some(vec![2, 8]), Some(vec![2, 8, 1]), None);

        assert!(result.is_err(), "expected error: shape and stride dimensions differ");
    }

    #[test]
    fn test_inconsistent_names() {
        let mut builder = Int32BufferBuilder::new(16);
        for i in 0..16 {
            builder.append(i);
        }
        let buf = builder.finish();

        let result = Int32Tensor::try_new(
            buf,
            Some(vec![2, 8]),
            Some(vec![4, 8]),
            Some(vec!["1", "2", "3"]),
        );

        assert!(result.is_err(), "expected error: number of dimensions and names differ");
    }

    #[test]
    fn test_incorrect_shape() {
        let mut builder = Int32BufferBuilder::new(16);
        for i in 0..16 {
            builder.append(i);
        }
        let buf = builder.finish();

        let result = Int32Tensor::try_new(buf, Some(vec![2, 6]), None, None);

        assert!(result.is_err(), "expected error: number of elements does not match the shape");
    }

    #[test]
    fn test_incorrect_stride() {
        let mut builder = Int32BufferBuilder::new(16);
        for i in 0..16 {
            builder.append(i);
        }
        let buf = builder.finish();

        let result = Int32Tensor::try_new(buf, Some(vec![2, 8]), Some(vec![30, 4]), None);

        assert!(result.is_err(), "expected error: the input stride does not match the shape");
    }
}