1use std::marker::PhantomData;
22use std::mem;
23
24use crate::buffer::Buffer;
25use crate::datatypes::*;
26
27use crate::error::{ArrowError, Result};
28
29fn compute_row_major_strides<T: ArrowPrimitiveType>(shape: &[usize]) -> Result<Vec<usize>> {
31 let mut remaining_bytes = mem::size_of::<T::Native>();
32
33 for i in shape {
34 if let Some(val) = remaining_bytes.checked_mul(*i) {
35 remaining_bytes = val;
36 } else {
37 return Err(ArrowError::ComputeError(
38 "overflow occurred when computing row major strides.".to_string(),
39 ));
40 }
41 }
42
43 let mut strides = Vec::<usize>::new();
44 for i in shape {
45 remaining_bytes /= *i;
46 strides.push(remaining_bytes);
47 }
48
49 Ok(strides)
50}
51
52fn compute_column_major_strides<T: ArrowPrimitiveType>(shape: &[usize]) -> Result<Vec<usize>> {
54 let mut remaining_bytes = mem::size_of::<T::Native>();
55 let mut strides = Vec::<usize>::new();
56
57 for i in shape {
58 strides.push(remaining_bytes);
59
60 if let Some(val) = remaining_bytes.checked_mul(*i) {
61 remaining_bytes = val;
62 } else {
63 return Err(ArrowError::ComputeError(
64 "overflow occurred when computing column major strides.".to_string(),
65 ));
66 }
67 }
68
69 Ok(strides)
70}
71
/// An n-dimensional array of `T` values whose elements live in a single [`Buffer`].
///
/// The shape, strides and dimension names are all optional. A tensor without a shape is
/// required by [`Tensor::try_new`] to hold exactly one element.
#[derive(Debug)]
pub struct Tensor<'a, T: ArrowPrimitiveType> {
    /// Logical data type of the elements; set to `T::DATA_TYPE` on construction.
    data_type: DataType,
    /// Bytes holding the tensor elements.
    buffer: Buffer,
    /// Length of each dimension, or `None` for a tensor without dimensions.
    shape: Option<Vec<usize>>,
    /// Byte stride of each dimension; `try_new` only accepts row- or column-major strides.
    strides: Option<Vec<usize>>,
    /// Optional name for each dimension, borrowed for the lifetime `'a`.
    names: Option<Vec<&'a str>>,
    /// Ties the tensor to its element type `T` without storing a `T`.
    _marker: PhantomData<T>,
}
82
/// Tensor of [`BooleanType`] elements.
pub type BooleanTensor<'a> = Tensor<'a, BooleanType>;
/// Tensor of [`Date32Type`] elements.
pub type Date32Tensor<'a> = Tensor<'a, Date32Type>;
/// Tensor of [`Date64Type`] elements.
pub type Date64Tensor<'a> = Tensor<'a, Date64Type>;
/// Tensor of [`Decimal128Type`] elements.
pub type Decimal128Tensor<'a> = Tensor<'a, Decimal128Type>;
/// Tensor of [`Decimal256Type`] elements.
pub type Decimal256Tensor<'a> = Tensor<'a, Decimal256Type>;
/// Tensor of [`DurationMicrosecondType`] elements.
pub type DurationMicrosecondTensor<'a> = Tensor<'a, DurationMicrosecondType>;
/// Tensor of [`DurationMillisecondType`] elements.
pub type DurationMillisecondTensor<'a> = Tensor<'a, DurationMillisecondType>;
/// Tensor of [`DurationNanosecondType`] elements.
pub type DurationNanosecondTensor<'a> = Tensor<'a, DurationNanosecondType>;
/// Tensor of [`DurationSecondType`] elements.
pub type DurationSecondTensor<'a> = Tensor<'a, DurationSecondType>;
/// Tensor of [`Float16Type`] elements.
pub type Float16Tensor<'a> = Tensor<'a, Float16Type>;
/// Tensor of [`Float32Type`] elements.
pub type Float32Tensor<'a> = Tensor<'a, Float32Type>;
/// Tensor of [`Float64Type`] elements.
pub type Float64Tensor<'a> = Tensor<'a, Float64Type>;
/// Tensor of [`Int8Type`] elements.
pub type Int8Tensor<'a> = Tensor<'a, Int8Type>;
/// Tensor of [`Int16Type`] elements.
pub type Int16Tensor<'a> = Tensor<'a, Int16Type>;
/// Tensor of [`Int32Type`] elements.
pub type Int32Tensor<'a> = Tensor<'a, Int32Type>;
/// Tensor of [`Int64Type`] elements.
pub type Int64Tensor<'a> = Tensor<'a, Int64Type>;
/// Tensor of [`IntervalDayTimeType`] elements.
pub type IntervalDayTimeTensor<'a> = Tensor<'a, IntervalDayTimeType>;
/// Tensor of [`IntervalMonthDayNanoType`] elements.
pub type IntervalMonthDayNanoTensor<'a> = Tensor<'a, IntervalMonthDayNanoType>;
/// Tensor of [`IntervalYearMonthType`] elements.
pub type IntervalYearMonthTensor<'a> = Tensor<'a, IntervalYearMonthType>;
/// Tensor of [`Time32MillisecondType`] elements.
pub type Time32MillisecondTensor<'a> = Tensor<'a, Time32MillisecondType>;
/// Tensor of [`Time32SecondType`] elements.
pub type Time32SecondTensor<'a> = Tensor<'a, Time32SecondType>;
/// Tensor of [`Time64MicrosecondType`] elements.
pub type Time64MicrosecondTensor<'a> = Tensor<'a, Time64MicrosecondType>;
/// Tensor of [`Time64NanosecondType`] elements.
pub type Time64NanosecondTensor<'a> = Tensor<'a, Time64NanosecondType>;
/// Tensor of [`TimestampMicrosecondType`] elements.
pub type TimestampMicrosecondTensor<'a> = Tensor<'a, TimestampMicrosecondType>;
/// Tensor of [`TimestampMillisecondType`] elements.
pub type TimestampMillisecondTensor<'a> = Tensor<'a, TimestampMillisecondType>;
/// Tensor of [`TimestampNanosecondType`] elements.
pub type TimestampNanosecondTensor<'a> = Tensor<'a, TimestampNanosecondType>;
/// Tensor of [`TimestampSecondType`] elements.
pub type TimestampSecondTensor<'a> = Tensor<'a, TimestampSecondType>;
/// Tensor of [`UInt8Type`] elements.
pub type UInt8Tensor<'a> = Tensor<'a, UInt8Type>;
/// Tensor of [`UInt16Type`] elements.
pub type UInt16Tensor<'a> = Tensor<'a, UInt16Type>;
/// Tensor of [`UInt32Type`] elements.
pub type UInt32Tensor<'a> = Tensor<'a, UInt32Type>;
/// Tensor of [`UInt64Type`] elements.
pub type UInt64Tensor<'a> = Tensor<'a, UInt64Type>;
145
146impl<'a, T: ArrowPrimitiveType> Tensor<'a, T> {
147 pub fn try_new(
149 buffer: Buffer,
150 shape: Option<Vec<usize>>,
151 strides: Option<Vec<usize>>,
152 names: Option<Vec<&'a str>>,
153 ) -> Result<Self> {
154 match shape {
155 None => {
156 if buffer.len() != mem::size_of::<T::Native>() {
157 return Err(ArrowError::InvalidArgumentError(
158 "underlying buffer should only contain a single tensor element".to_string(),
159 ));
160 }
161
162 if strides.is_some() {
163 return Err(ArrowError::InvalidArgumentError(
164 "expected None strides for tensor with no shape".to_string(),
165 ));
166 }
167
168 if names.is_some() {
169 return Err(ArrowError::InvalidArgumentError(
170 "expected None names for tensor with no shape".to_string(),
171 ));
172 }
173 }
174
175 Some(ref s) => {
176 if let Some(ref st) = strides {
177 if st.len() != s.len() {
178 return Err(ArrowError::InvalidArgumentError(
179 "shape and stride dimensions differ".to_string(),
180 ));
181 }
182 }
183
184 if let Some(ref n) = names {
185 if n.len() != s.len() {
186 return Err(ArrowError::InvalidArgumentError(
187 "number of dimensions and number of dimension names differ".to_string(),
188 ));
189 }
190 }
191
192 let total_elements: usize = s.iter().product();
193 if total_elements != (buffer.len() / mem::size_of::<T::Native>()) {
194 return Err(ArrowError::InvalidArgumentError(
195 "number of elements in buffer does not match dimensions".to_string(),
196 ));
197 }
198 }
199 };
200
201 let tensor_strides = {
204 if let Some(st) = strides {
205 if let Some(ref s) = shape {
206 if compute_row_major_strides::<T>(s)? == st
207 || compute_column_major_strides::<T>(s)? == st
208 {
209 Some(st)
210 } else {
211 return Err(ArrowError::InvalidArgumentError(
212 "the input stride does not match the selected shape".to_string(),
213 ));
214 }
215 } else {
216 Some(st)
217 }
218 } else if let Some(ref s) = shape {
219 Some(compute_row_major_strides::<T>(s)?)
220 } else {
221 None
222 }
223 };
224
225 Ok(Self {
226 data_type: T::DATA_TYPE,
227 buffer,
228 shape,
229 strides: tensor_strides,
230 names,
231 _marker: PhantomData,
232 })
233 }
234
235 pub fn new_row_major(
237 buffer: Buffer,
238 shape: Option<Vec<usize>>,
239 names: Option<Vec<&'a str>>,
240 ) -> Result<Self> {
241 if let Some(ref s) = shape {
242 let strides = Some(compute_row_major_strides::<T>(s)?);
243
244 Self::try_new(buffer, shape, strides, names)
245 } else {
246 Err(ArrowError::InvalidArgumentError(
247 "shape required to create row major tensor".to_string(),
248 ))
249 }
250 }
251
252 pub fn new_column_major(
254 buffer: Buffer,
255 shape: Option<Vec<usize>>,
256 names: Option<Vec<&'a str>>,
257 ) -> Result<Self> {
258 if let Some(ref s) = shape {
259 let strides = Some(compute_column_major_strides::<T>(s)?);
260
261 Self::try_new(buffer, shape, strides, names)
262 } else {
263 Err(ArrowError::InvalidArgumentError(
264 "shape required to create column major tensor".to_string(),
265 ))
266 }
267 }
268
269 pub fn data_type(&self) -> &DataType {
271 &self.data_type
272 }
273
274 pub fn shape(&self) -> Option<&Vec<usize>> {
276 self.shape.as_ref()
277 }
278
279 pub fn data(&self) -> &Buffer {
281 &self.buffer
282 }
283
284 pub fn strides(&self) -> Option<&Vec<usize>> {
286 self.strides.as_ref()
287 }
288
289 pub fn names(&self) -> Option<&Vec<&'a str>> {
291 self.names.as_ref()
292 }
293
294 pub fn ndim(&self) -> usize {
296 match &self.shape {
297 None => 0,
298 Some(v) => v.len(),
299 }
300 }
301
302 pub fn dim_name(&self, i: usize) -> Option<&'a str> {
304 self.names.as_ref().map(|names| names[i])
305 }
306
307 pub fn size(&self) -> usize {
309 match self.shape {
310 None => 0,
311 Some(ref s) => s.iter().product(),
312 }
313 }
314
315 pub fn is_contiguous(&self) -> Result<bool> {
317 Ok(self.is_row_major()? || self.is_column_major()?)
318 }
319
320 pub fn is_row_major(&self) -> Result<bool> {
322 match self.shape {
323 None => Ok(false),
324 Some(ref s) => Ok(Some(compute_row_major_strides::<T>(s)?) == self.strides),
325 }
326 }
327
328 pub fn is_column_major(&self) -> Result<bool> {
330 match self.shape {
331 None => Ok(false),
332 Some(ref s) => Ok(Some(compute_column_major_strides::<T>(s)?) == self.strides),
333 }
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    use crate::array::*;

    /// Returns a buffer containing the sixteen `i32` values `0..16`.
    fn buffer_of_sixteen_i32() -> Buffer {
        let mut builder = Int32BufferBuilder::new(16);
        for value in 0..16 {
            builder.append(value);
        }
        builder.finish()
    }

    /// Checks the properties every tensor created without a shape must report.
    fn assert_shapeless<T: ArrowPrimitiveType>(tensor: &Tensor<'_, T>) {
        assert_eq!(0, tensor.size());
        assert_eq!(None, tensor.shape());
        assert_eq!(None, tensor.names());
        assert_eq!(0, tensor.ndim());
        assert!(!tensor.is_row_major().unwrap());
        assert!(!tensor.is_column_major().unwrap());
        assert!(!tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_compute_row_major_strides() {
        let shape = [4_usize, 6];
        assert_eq!(
            vec![48_usize, 8],
            compute_row_major_strides::<Int64Type>(&shape).unwrap()
        );
        assert_eq!(
            vec![24_usize, 4],
            compute_row_major_strides::<Int32Type>(&shape).unwrap()
        );
        assert_eq!(
            vec![6_usize, 1],
            compute_row_major_strides::<Int8Type>(&shape).unwrap()
        );
    }

    #[test]
    fn test_compute_column_major_strides() {
        let shape = [4_usize, 6];
        assert_eq!(
            vec![8_usize, 32],
            compute_column_major_strides::<Int64Type>(&shape).unwrap()
        );
        assert_eq!(
            vec![4_usize, 16],
            compute_column_major_strides::<Int32Type>(&shape).unwrap()
        );
        assert_eq!(
            vec![1_usize, 4],
            compute_column_major_strides::<Int8Type>(&shape).unwrap()
        );
    }

    #[test]
    fn test_zero_dim() {
        let single_u8 = UInt8Tensor::try_new(Buffer::from(&[1]), None, None, None).unwrap();
        assert_shapeless(&single_u8);

        let single_i32 =
            Int32Tensor::try_new(Buffer::from(&[1, 2, 2, 2]), None, None, None).unwrap();
        assert_shapeless(&single_i32);
    }

    #[test]
    fn test_tensor() {
        let tensor =
            Int32Tensor::try_new(buffer_of_sixteen_i32(), Some(vec![2, 8]), None, None).unwrap();
        assert_eq!(16, tensor.size());
        assert_eq!(Some(&vec![2_usize, 8]), tensor.shape());
        assert_eq!(Some(&vec![32_usize, 4]), tensor.strides());
        assert_eq!(2, tensor.ndim());
        assert_eq!(None, tensor.names());
    }

    #[test]
    fn test_new_row_major() {
        let tensor =
            Int32Tensor::new_row_major(buffer_of_sixteen_i32(), Some(vec![2, 8]), None).unwrap();
        assert_eq!(16, tensor.size());
        assert_eq!(Some(&vec![2_usize, 8]), tensor.shape());
        assert_eq!(Some(&vec![32_usize, 4]), tensor.strides());
        assert_eq!(None, tensor.names());
        assert_eq!(2, tensor.ndim());
        assert!(tensor.is_row_major().unwrap());
        assert!(!tensor.is_column_major().unwrap());
        assert!(tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_new_column_major() {
        let tensor =
            Int32Tensor::new_column_major(buffer_of_sixteen_i32(), Some(vec![2, 8]), None)
                .unwrap();
        assert_eq!(16, tensor.size());
        assert_eq!(Some(&vec![2_usize, 8]), tensor.shape());
        assert_eq!(Some(&vec![4_usize, 8]), tensor.strides());
        assert_eq!(None, tensor.names());
        assert_eq!(2, tensor.ndim());
        assert!(!tensor.is_row_major().unwrap());
        assert!(tensor.is_column_major().unwrap());
        assert!(tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_with_names() {
        let mut builder = Int64BufferBuilder::new(8);
        for value in 0..8 {
            builder.append(value);
        }
        let dimension_names = vec!["Dim 1", "Dim 2"];
        let tensor = Int64Tensor::new_column_major(
            builder.finish(),
            Some(vec![2, 4]),
            Some(dimension_names),
        )
        .unwrap();
        assert_eq!(8, tensor.size());
        assert_eq!(Some(&vec![2_usize, 4]), tensor.shape());
        assert_eq!(Some(&vec![8_usize, 16]), tensor.strides());
        assert_eq!("Dim 1", tensor.dim_name(0).unwrap());
        assert_eq!("Dim 2", tensor.dim_name(1).unwrap());
        assert_eq!(2, tensor.ndim());
        assert!(!tensor.is_row_major().unwrap());
        assert!(tensor.is_column_major().unwrap());
        assert!(tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_inconsistent_strides() {
        let outcome = Int32Tensor::try_new(
            buffer_of_sixteen_i32(),
            Some(vec![2, 8]),
            Some(vec![2, 8, 1]),
            None,
        );
        assert!(outcome.is_err(), "shape and stride dimensions are different");
    }

    #[test]
    fn test_inconsistent_names() {
        let outcome = Int32Tensor::try_new(
            buffer_of_sixteen_i32(),
            Some(vec![2, 8]),
            Some(vec![4, 8]),
            Some(vec!["1", "2", "3"]),
        );
        assert!(outcome.is_err(), "dimensions and names have different shape");
    }

    #[test]
    fn test_incorrect_shape() {
        let outcome =
            Int32Tensor::try_new(buffer_of_sixteen_i32(), Some(vec![2, 6]), None, None);
        assert!(
            outcome.is_err(),
            "number of elements does not match for the shape"
        );
    }

    #[test]
    fn test_incorrect_stride() {
        let outcome = Int32Tensor::try_new(
            buffer_of_sixteen_i32(),
            Some(vec![2, 8]),
            Some(vec![30, 4]),
            None,
        );
        assert!(
            outcome.is_err(),
            "the input stride does not match the selected shape"
        );
    }
}