1use std::marker::PhantomData;
22use std::mem;
23
24use crate::buffer::Buffer;
25use crate::datatypes::*;
26
27use crate::error::{ArrowError, Result};
28
29fn compute_row_major_strides<T: ArrowPrimitiveType>(shape: &[usize]) -> Result<Vec<usize>> {
31 let mut remaining_bytes = mem::size_of::<T::Native>();
32
33 for i in shape {
34 if let Some(val) = remaining_bytes.checked_mul(*i) {
35 remaining_bytes = val;
36 } else {
37 return Err(ArrowError::ComputeError(
38 "overflow occurred when computing row major strides.".to_string(),
39 ));
40 }
41 }
42
43 let mut strides = Vec::<usize>::new();
44 for i in shape {
45 remaining_bytes /= *i;
46 strides.push(remaining_bytes);
47 }
48
49 Ok(strides)
50}
51
52fn compute_column_major_strides<T: ArrowPrimitiveType>(shape: &[usize]) -> Result<Vec<usize>> {
54 let mut remaining_bytes = mem::size_of::<T::Native>();
55 let mut strides = Vec::<usize>::new();
56
57 for i in shape {
58 strides.push(remaining_bytes);
59
60 if let Some(val) = remaining_bytes.checked_mul(*i) {
61 remaining_bytes = val;
62 } else {
63 return Err(ArrowError::ComputeError(
64 "overflow occurred when computing column major strides.".to_string(),
65 ));
66 }
67 }
68
69 Ok(strides)
70}
71
/// A multi-dimensional view over a [`Buffer`] of primitive values of type `T`.
///
/// The lifetime `'a` is the lifetime of the borrowed dimension names.
#[derive(Debug)]
pub struct Tensor<'a, T: ArrowPrimitiveType> {
    /// Data type of the elements; set from `T::DATA_TYPE` on construction.
    data_type: DataType,
    /// Underlying buffer holding the tensor elements.
    buffer: Buffer,
    /// Length of each dimension, or `None` for a tensor holding a single
    /// element without any shape.
    shape: Option<Vec<usize>>,
    /// Stride of each dimension, expressed in bytes (strides are derived from
    /// the byte size of `T::Native`); `None` when there is no shape.
    strides: Option<Vec<usize>>,
    /// Optional name of each dimension; when present there is one name per
    /// dimension.
    names: Option<Vec<&'a str>>,
    /// Ties the tensor to its element type without storing a `T`.
    _marker: PhantomData<T>,
}
82
/// [`Tensor`] of type [`BooleanType`]
pub type BooleanTensor<'a> = Tensor<'a, BooleanType>;
/// [`Tensor`] of type [`Date32Type`]
pub type Date32Tensor<'a> = Tensor<'a, Date32Type>;
/// [`Tensor`] of type [`Date64Type`]
pub type Date64Tensor<'a> = Tensor<'a, Date64Type>;
/// [`Tensor`] of type [`Decimal32Type`]
pub type Decimal32Tensor<'a> = Tensor<'a, Decimal32Type>;
/// [`Tensor`] of type [`Decimal64Type`]
pub type Decimal64Tensor<'a> = Tensor<'a, Decimal64Type>;
/// [`Tensor`] of type [`Decimal128Type`]
pub type Decimal128Tensor<'a> = Tensor<'a, Decimal128Type>;
/// [`Tensor`] of type [`Decimal256Type`]
pub type Decimal256Tensor<'a> = Tensor<'a, Decimal256Type>;
/// [`Tensor`] of type [`DurationMicrosecondType`]
pub type DurationMicrosecondTensor<'a> = Tensor<'a, DurationMicrosecondType>;
/// [`Tensor`] of type [`DurationMillisecondType`]
pub type DurationMillisecondTensor<'a> = Tensor<'a, DurationMillisecondType>;
/// [`Tensor`] of type [`DurationNanosecondType`]
pub type DurationNanosecondTensor<'a> = Tensor<'a, DurationNanosecondType>;
/// [`Tensor`] of type [`DurationSecondType`]
pub type DurationSecondTensor<'a> = Tensor<'a, DurationSecondType>;
/// [`Tensor`] of type [`Float16Type`]
pub type Float16Tensor<'a> = Tensor<'a, Float16Type>;
/// [`Tensor`] of type [`Float32Type`]
pub type Float32Tensor<'a> = Tensor<'a, Float32Type>;
/// [`Tensor`] of type [`Float64Type`]
pub type Float64Tensor<'a> = Tensor<'a, Float64Type>;
/// [`Tensor`] of type [`Int8Type`]
pub type Int8Tensor<'a> = Tensor<'a, Int8Type>;
/// [`Tensor`] of type [`Int16Type`]
pub type Int16Tensor<'a> = Tensor<'a, Int16Type>;
/// [`Tensor`] of type [`Int32Type`]
pub type Int32Tensor<'a> = Tensor<'a, Int32Type>;
/// [`Tensor`] of type [`Int64Type`]
pub type Int64Tensor<'a> = Tensor<'a, Int64Type>;
/// [`Tensor`] of type [`IntervalDayTimeType`]
pub type IntervalDayTimeTensor<'a> = Tensor<'a, IntervalDayTimeType>;
/// [`Tensor`] of type [`IntervalMonthDayNanoType`]
pub type IntervalMonthDayNanoTensor<'a> = Tensor<'a, IntervalMonthDayNanoType>;
/// [`Tensor`] of type [`IntervalYearMonthType`]
pub type IntervalYearMonthTensor<'a> = Tensor<'a, IntervalYearMonthType>;
/// [`Tensor`] of type [`Time32MillisecondType`]
pub type Time32MillisecondTensor<'a> = Tensor<'a, Time32MillisecondType>;
/// [`Tensor`] of type [`Time32SecondType`]
pub type Time32SecondTensor<'a> = Tensor<'a, Time32SecondType>;
/// [`Tensor`] of type [`Time64MicrosecondType`]
pub type Time64MicrosecondTensor<'a> = Tensor<'a, Time64MicrosecondType>;
/// [`Tensor`] of type [`Time64NanosecondType`]
pub type Time64NanosecondTensor<'a> = Tensor<'a, Time64NanosecondType>;
/// [`Tensor`] of type [`TimestampMicrosecondType`]
pub type TimestampMicrosecondTensor<'a> = Tensor<'a, TimestampMicrosecondType>;
/// [`Tensor`] of type [`TimestampMillisecondType`]
pub type TimestampMillisecondTensor<'a> = Tensor<'a, TimestampMillisecondType>;
/// [`Tensor`] of type [`TimestampNanosecondType`]
pub type TimestampNanosecondTensor<'a> = Tensor<'a, TimestampNanosecondType>;
/// [`Tensor`] of type [`TimestampSecondType`]
pub type TimestampSecondTensor<'a> = Tensor<'a, TimestampSecondType>;
/// [`Tensor`] of type [`UInt8Type`]
pub type UInt8Tensor<'a> = Tensor<'a, UInt8Type>;
/// [`Tensor`] of type [`UInt16Type`]
pub type UInt16Tensor<'a> = Tensor<'a, UInt16Type>;
/// [`Tensor`] of type [`UInt32Type`]
pub type UInt32Tensor<'a> = Tensor<'a, UInt32Type>;
/// [`Tensor`] of type [`UInt64Type`]
pub type UInt64Tensor<'a> = Tensor<'a, UInt64Type>;
149
150impl<'a, T: ArrowPrimitiveType> Tensor<'a, T> {
151 pub fn try_new(
153 buffer: Buffer,
154 shape: Option<Vec<usize>>,
155 strides: Option<Vec<usize>>,
156 names: Option<Vec<&'a str>>,
157 ) -> Result<Self> {
158 match shape {
159 None => {
160 if buffer.len() != mem::size_of::<T::Native>() {
161 return Err(ArrowError::InvalidArgumentError(
162 "underlying buffer should only contain a single tensor element".to_string(),
163 ));
164 }
165
166 if strides.is_some() {
167 return Err(ArrowError::InvalidArgumentError(
168 "expected None strides for tensor with no shape".to_string(),
169 ));
170 }
171
172 if names.is_some() {
173 return Err(ArrowError::InvalidArgumentError(
174 "expected None names for tensor with no shape".to_string(),
175 ));
176 }
177 }
178
179 Some(ref s) => {
180 if let Some(ref st) = strides {
181 if st.len() != s.len() {
182 return Err(ArrowError::InvalidArgumentError(
183 "shape and stride dimensions differ".to_string(),
184 ));
185 }
186 }
187
188 if let Some(ref n) = names {
189 if n.len() != s.len() {
190 return Err(ArrowError::InvalidArgumentError(
191 "number of dimensions and number of dimension names differ".to_string(),
192 ));
193 }
194 }
195
196 let total_elements: usize = s.iter().product();
197 if total_elements != (buffer.len() / mem::size_of::<T::Native>()) {
198 return Err(ArrowError::InvalidArgumentError(
199 "number of elements in buffer does not match dimensions".to_string(),
200 ));
201 }
202 }
203 };
204
205 let tensor_strides = {
208 if let Some(st) = strides {
209 if let Some(ref s) = shape {
210 if compute_row_major_strides::<T>(s)? == st
211 || compute_column_major_strides::<T>(s)? == st
212 {
213 Some(st)
214 } else {
215 return Err(ArrowError::InvalidArgumentError(
216 "the input stride does not match the selected shape".to_string(),
217 ));
218 }
219 } else {
220 Some(st)
221 }
222 } else if let Some(ref s) = shape {
223 Some(compute_row_major_strides::<T>(s)?)
224 } else {
225 None
226 }
227 };
228
229 Ok(Self {
230 data_type: T::DATA_TYPE,
231 buffer,
232 shape,
233 strides: tensor_strides,
234 names,
235 _marker: PhantomData,
236 })
237 }
238
239 pub fn new_row_major(
241 buffer: Buffer,
242 shape: Option<Vec<usize>>,
243 names: Option<Vec<&'a str>>,
244 ) -> Result<Self> {
245 if let Some(ref s) = shape {
246 let strides = Some(compute_row_major_strides::<T>(s)?);
247
248 Self::try_new(buffer, shape, strides, names)
249 } else {
250 Err(ArrowError::InvalidArgumentError(
251 "shape required to create row major tensor".to_string(),
252 ))
253 }
254 }
255
256 pub fn new_column_major(
258 buffer: Buffer,
259 shape: Option<Vec<usize>>,
260 names: Option<Vec<&'a str>>,
261 ) -> Result<Self> {
262 if let Some(ref s) = shape {
263 let strides = Some(compute_column_major_strides::<T>(s)?);
264
265 Self::try_new(buffer, shape, strides, names)
266 } else {
267 Err(ArrowError::InvalidArgumentError(
268 "shape required to create column major tensor".to_string(),
269 ))
270 }
271 }
272
273 pub fn data_type(&self) -> &DataType {
275 &self.data_type
276 }
277
278 pub fn shape(&self) -> Option<&Vec<usize>> {
280 self.shape.as_ref()
281 }
282
283 pub fn data(&self) -> &Buffer {
285 &self.buffer
286 }
287
288 pub fn strides(&self) -> Option<&Vec<usize>> {
290 self.strides.as_ref()
291 }
292
293 pub fn names(&self) -> Option<&Vec<&'a str>> {
295 self.names.as_ref()
296 }
297
298 pub fn ndim(&self) -> usize {
300 match &self.shape {
301 None => 0,
302 Some(v) => v.len(),
303 }
304 }
305
306 pub fn dim_name(&self, i: usize) -> Option<&'a str> {
308 self.names.as_ref().map(|names| names[i])
309 }
310
311 pub fn size(&self) -> usize {
313 match self.shape {
314 None => 0,
315 Some(ref s) => s.iter().product(),
316 }
317 }
318
319 pub fn is_contiguous(&self) -> Result<bool> {
321 Ok(self.is_row_major()? || self.is_column_major()?)
322 }
323
324 pub fn is_row_major(&self) -> Result<bool> {
326 match self.shape {
327 None => Ok(false),
328 Some(ref s) => Ok(Some(compute_row_major_strides::<T>(s)?) == self.strides),
329 }
330 }
331
332 pub fn is_column_major(&self) -> Result<bool> {
334 match self.shape {
335 None => Ok(false),
336 Some(ref s) => Ok(Some(compute_column_major_strides::<T>(s)?) == self.strides),
337 }
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    use crate::array::*;

    /// Builds a buffer holding the sixteen `i32` values `0..16`.
    fn int32_buffer() -> Buffer {
        let mut builder = Int32BufferBuilder::new(16);
        (0..16).for_each(|value| builder.append(value));
        builder.finish()
    }

    /// Asserts the properties shared by every tensor created without a shape.
    fn assert_shapeless<T: ArrowPrimitiveType>(tensor: &Tensor<'_, T>) {
        assert_eq!(tensor.size(), 0);
        assert_eq!(tensor.shape(), None);
        assert_eq!(tensor.names(), None);
        assert_eq!(tensor.ndim(), 0);
        assert!(!tensor.is_row_major().unwrap());
        assert!(!tensor.is_column_major().unwrap());
        assert!(!tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_compute_row_major_strides() {
        let shape = [4_usize, 6];

        let strides = compute_row_major_strides::<Int64Type>(&shape).unwrap();
        assert_eq!(strides, vec![48_usize, 8]);

        let strides = compute_row_major_strides::<Int32Type>(&shape).unwrap();
        assert_eq!(strides, vec![24_usize, 4]);

        let strides = compute_row_major_strides::<Int8Type>(&shape).unwrap();
        assert_eq!(strides, vec![6_usize, 1]);
    }

    #[test]
    fn test_compute_column_major_strides() {
        let shape = [4_usize, 6];

        let strides = compute_column_major_strides::<Int64Type>(&shape).unwrap();
        assert_eq!(strides, vec![8_usize, 32]);

        let strides = compute_column_major_strides::<Int32Type>(&shape).unwrap();
        assert_eq!(strides, vec![4_usize, 16]);

        let strides = compute_column_major_strides::<Int8Type>(&shape).unwrap();
        assert_eq!(strides, vec![1_usize, 4]);
    }

    #[test]
    fn test_zero_dim() {
        // One byte is exactly one `u8` element.
        let single_byte = Buffer::from(&[1]);
        let tensor = UInt8Tensor::try_new(single_byte, None, None, None).unwrap();
        assert_shapeless(&tensor);

        // Four bytes are exactly one `i32` element.
        let four_bytes = Buffer::from(&[1, 2, 2, 2]);
        let tensor = Int32Tensor::try_new(four_bytes, None, None, None).unwrap();
        assert_shapeless(&tensor);
    }

    #[test]
    fn test_tensor() {
        let shape = Some(vec![2, 8]);
        let tensor = Int32Tensor::try_new(int32_buffer(), shape, None, None).unwrap();

        assert_eq!(tensor.size(), 16);
        assert_eq!(tensor.shape(), Some(vec![2_usize, 8]).as_ref());
        assert_eq!(tensor.strides(), Some(vec![32_usize, 4]).as_ref());
        assert_eq!(tensor.ndim(), 2);
        assert_eq!(tensor.names(), None);
    }

    #[test]
    fn test_new_row_major() {
        let shape = Some(vec![2, 8]);
        let tensor = Int32Tensor::new_row_major(int32_buffer(), shape, None).unwrap();

        assert_eq!(tensor.size(), 16);
        assert_eq!(tensor.shape(), Some(vec![2_usize, 8]).as_ref());
        assert_eq!(tensor.strides(), Some(vec![32_usize, 4]).as_ref());
        assert_eq!(tensor.names(), None);
        assert_eq!(tensor.ndim(), 2);
        assert!(tensor.is_row_major().unwrap());
        assert!(!tensor.is_column_major().unwrap());
        assert!(tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_new_column_major() {
        let shape = Some(vec![2, 8]);
        let tensor = Int32Tensor::new_column_major(int32_buffer(), shape, None).unwrap();

        assert_eq!(tensor.size(), 16);
        assert_eq!(tensor.shape(), Some(vec![2_usize, 8]).as_ref());
        assert_eq!(tensor.strides(), Some(vec![4_usize, 8]).as_ref());
        assert_eq!(tensor.names(), None);
        assert_eq!(tensor.ndim(), 2);
        assert!(!tensor.is_row_major().unwrap());
        assert!(tensor.is_column_major().unwrap());
        assert!(tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_with_names() {
        let mut builder = Int64BufferBuilder::new(8);
        (0..8).for_each(|value| builder.append(value));
        let buf = builder.finish();

        let names = vec!["Dim 1", "Dim 2"];
        let tensor = Int64Tensor::new_column_major(buf, Some(vec![2, 4]), Some(names)).unwrap();

        assert_eq!(tensor.size(), 8);
        assert_eq!(tensor.shape(), Some(vec![2_usize, 4]).as_ref());
        assert_eq!(tensor.strides(), Some(vec![8_usize, 16]).as_ref());
        assert_eq!(tensor.dim_name(0).unwrap(), "Dim 1");
        assert_eq!(tensor.dim_name(1).unwrap(), "Dim 2");
        assert_eq!(tensor.ndim(), 2);
        assert!(!tensor.is_row_major().unwrap());
        assert!(tensor.is_column_major().unwrap());
        assert!(tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_inconsistent_strides() {
        // Three strides for a two-dimensional shape.
        let result = Int32Tensor::try_new(
            int32_buffer(),
            Some(vec![2, 8]),
            Some(vec![2, 8, 1]),
            None,
        );

        assert!(result.is_err(), "shape and stride dimensions are different");
    }

    #[test]
    fn test_inconsistent_names() {
        // Three names for a two-dimensional shape.
        let result = Int32Tensor::try_new(
            int32_buffer(),
            Some(vec![2, 8]),
            Some(vec![4, 8]),
            Some(vec!["1", "2", "3"]),
        );

        assert!(result.is_err(), "dimensions and names have different shape");
    }

    #[test]
    fn test_incorrect_shape() {
        // Twelve elements requested from a sixteen-element buffer.
        let result = Int32Tensor::try_new(int32_buffer(), Some(vec![2, 6]), None, None);

        assert!(
            result.is_err(),
            "number of elements does not match for the shape"
        );
    }

    #[test]
    fn test_incorrect_stride() {
        // Neither the row-major nor the column-major strides of the shape.
        let result = Int32Tensor::try_new(
            int32_buffer(),
            Some(vec![2, 8]),
            Some(vec![30, 4]),
            None,
        );

        assert!(
            result.is_err(),
            "the input stride does not match the selected shape"
        );
    }
}