1use serde_core::de::{self, MapAccess, Visitor};
23use serde_core::ser::SerializeStruct;
24use serde_core::{Deserialize, Deserializer, Serialize, Serializer};
25use std::fmt;
26
27use crate::{ArrowError, DataType, extension::ExtensionType};
28
/// The fixed shape tensor extension type (`arrow.fixed_shape_tensor`).
///
/// Pairs the data type of the tensor elements with the
/// [`FixedShapeTensorMetadata`] describing the tensor's shape.
#[derive(Debug, Clone, PartialEq)]
pub struct FixedShapeTensor {
    /// The data type of the individual tensor elements.
    value_type: DataType,

    /// The shape, optional dimension names and optional permutation.
    metadata: FixedShapeTensorMetadata,
}
83
84impl FixedShapeTensor {
85 pub fn try_new(
92 value_type: DataType,
93 shape: impl IntoIterator<Item = usize>,
94 dimension_names: Option<Vec<String>>,
95 permutations: Option<Vec<usize>>,
96 ) -> Result<Self, ArrowError> {
97 FixedShapeTensorMetadata::try_new(shape, dimension_names, permutations).map(|metadata| {
99 Self {
100 value_type,
101 metadata,
102 }
103 })
104 }
105
106 pub fn value_type(&self) -> &DataType {
108 &self.value_type
109 }
110
111 pub fn list_size(&self) -> usize {
113 self.metadata.list_size()
114 }
115
116 pub fn dimensions(&self) -> usize {
118 self.metadata.dimensions()
119 }
120
121 pub fn dimension_names(&self) -> Option<&[String]> {
124 self.metadata.dimension_names()
125 }
126
127 pub fn permutations(&self) -> Option<&[usize]> {
130 self.metadata.permutations()
131 }
132}
133
/// The metadata of the [`FixedShapeTensor`] extension type.
///
/// Serialized with the field names `shape`, `dim_names` and `permutations`.
#[derive(Debug, Clone, PartialEq)]
pub struct FixedShapeTensorMetadata {
    /// The size of each dimension of the tensor.
    shape: Vec<usize>,

    /// Optional names of the dimensions; when built through `try_new` there is one per
    /// entry of `shape`.
    dim_names: Option<Vec<String>>,

    /// Optional permutation of the dimension indices; when built through `try_new` it is
    /// a permutation of `0..shape.len()`.
    permutations: Option<Vec<usize>>,
}
146
147impl Serialize for FixedShapeTensorMetadata {
148 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
149 where
150 S: Serializer,
151 {
152 let mut state = serializer.serialize_struct("FixedShapeTensorMetadata", 3)?;
153 state.serialize_field("shape", &self.shape)?;
154 state.serialize_field("dim_names", &self.dim_names)?;
155 state.serialize_field("permutations", &self.permutations)?;
156 state.end()
157 }
158}
159
/// Identifies one of the serialized fields of [`FixedShapeTensorMetadata`].
#[derive(Debug)]
enum MetadataField {
    /// The `shape` field.
    Shape,
    /// The `dim_names` field.
    DimNames,
    /// The `permutations` field.
    Permutations,
}
166
167struct MetadataFieldVisitor;
168
169impl<'de> Visitor<'de> for MetadataFieldVisitor {
170 type Value = MetadataField;
171
172 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
173 formatter.write_str("`shape`, `dim_names`, or `permutations`")
174 }
175
176 fn visit_str<E>(self, value: &str) -> Result<MetadataField, E>
177 where
178 E: de::Error,
179 {
180 match value {
181 "shape" => Ok(MetadataField::Shape),
182 "dim_names" => Ok(MetadataField::DimNames),
183 "permutations" => Ok(MetadataField::Permutations),
184 _ => Err(de::Error::unknown_field(
185 value,
186 &["shape", "dim_names", "permutations"],
187 )),
188 }
189 }
190}
191
192impl<'de> Deserialize<'de> for MetadataField {
193 fn deserialize<D>(deserializer: D) -> Result<MetadataField, D::Error>
194 where
195 D: Deserializer<'de>,
196 {
197 deserializer.deserialize_identifier(MetadataFieldVisitor)
198 }
199}
200
201struct FixedShapeTensorMetadataVisitor;
202
203impl<'de> Visitor<'de> for FixedShapeTensorMetadataVisitor {
204 type Value = FixedShapeTensorMetadata;
205
206 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
207 formatter.write_str("struct FixedShapeTensorMetadata")
208 }
209
210 fn visit_seq<V>(self, mut seq: V) -> Result<FixedShapeTensorMetadata, V::Error>
211 where
212 V: de::SeqAccess<'de>,
213 {
214 let shape = seq
215 .next_element()?
216 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
217 let dim_names = seq
218 .next_element()?
219 .ok_or_else(|| de::Error::invalid_length(1, &self))?;
220 let permutations = seq
221 .next_element()?
222 .ok_or_else(|| de::Error::invalid_length(2, &self))?;
223 Ok(FixedShapeTensorMetadata {
224 shape,
225 dim_names,
226 permutations,
227 })
228 }
229
230 fn visit_map<V>(self, mut map: V) -> Result<FixedShapeTensorMetadata, V::Error>
231 where
232 V: MapAccess<'de>,
233 {
234 let mut shape = None;
235 let mut dim_names = None;
236 let mut permutations = None;
237
238 while let Some(key) = map.next_key()? {
239 match key {
240 MetadataField::Shape => {
241 if shape.is_some() {
242 return Err(de::Error::duplicate_field("shape"));
243 }
244 shape = Some(map.next_value()?);
245 }
246 MetadataField::DimNames => {
247 if dim_names.is_some() {
248 return Err(de::Error::duplicate_field("dim_names"));
249 }
250 dim_names = Some(map.next_value()?);
251 }
252 MetadataField::Permutations => {
253 if permutations.is_some() {
254 return Err(de::Error::duplicate_field("permutations"));
255 }
256 permutations = Some(map.next_value()?);
257 }
258 }
259 }
260
261 let shape = shape.ok_or_else(|| de::Error::missing_field("shape"))?;
262
263 Ok(FixedShapeTensorMetadata {
264 shape,
265 dim_names,
266 permutations,
267 })
268 }
269}
270
271impl<'de> Deserialize<'de> for FixedShapeTensorMetadata {
272 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
273 where
274 D: Deserializer<'de>,
275 {
276 deserializer.deserialize_struct(
277 "FixedShapeTensorMetadata",
278 &["shape", "dim_names", "permutations"],
279 FixedShapeTensorMetadataVisitor,
280 )
281 }
282}
283
284impl FixedShapeTensorMetadata {
285 pub fn try_new(
292 shape: impl IntoIterator<Item = usize>,
293 dimension_names: Option<Vec<String>>,
294 permutations: Option<Vec<usize>>,
295 ) -> Result<Self, ArrowError> {
296 let shape = shape.into_iter().collect::<Vec<_>>();
297 let dimensions = shape.len();
298
299 let dim_names = dimension_names.map(|dimension_names| {
300 if dimension_names.len() != dimensions {
301 Err(ArrowError::InvalidArgumentError(format!(
302 "FixedShapeTensor dimension names size mismatch, expected {dimensions}, found {}", dimension_names.len()
303 )))
304 } else {
305 Ok(dimension_names)
306 }
307 }).transpose()?;
308
309 let permutations = permutations
310 .map(|permutations| {
311 if permutations.len() != dimensions {
312 Err(ArrowError::InvalidArgumentError(format!(
313 "FixedShapeTensor permutations size mismatch, expected {dimensions}, found {}",
314 permutations.len()
315 )))
316 } else {
317 let mut sorted_permutations = permutations.clone();
318 sorted_permutations.sort_unstable();
319 if (0..dimensions).zip(sorted_permutations).any(|(a, b)| a != b) {
320 Err(ArrowError::InvalidArgumentError(format!(
321 "FixedShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: {dimensions}"
322 )))
323 } else {
324 Ok(permutations)
325 }
326 }
327 })
328 .transpose()?;
329
330 Ok(Self {
331 shape,
332 dim_names,
333 permutations,
334 })
335 }
336
337 pub fn list_size(&self) -> usize {
339 self.shape.iter().product()
340 }
341
342 pub fn dimensions(&self) -> usize {
344 self.shape.len()
345 }
346
347 pub fn dimension_names(&self) -> Option<&[String]> {
350 self.dim_names.as_ref().map(AsRef::as_ref)
351 }
352
353 pub fn permutations(&self) -> Option<&[usize]> {
356 self.permutations.as_ref().map(AsRef::as_ref)
357 }
358}
359
360impl ExtensionType for FixedShapeTensor {
361 const NAME: &'static str = "arrow.fixed_shape_tensor";
362
363 type Metadata = FixedShapeTensorMetadata;
364
365 fn metadata(&self) -> &Self::Metadata {
366 &self.metadata
367 }
368
369 fn serialize_metadata(&self) -> Option<String> {
370 Some(serde_json::to_string(&self.metadata).expect("metadata serialization"))
371 }
372
373 fn deserialize_metadata(metadata: Option<&str>) -> Result<Self::Metadata, ArrowError> {
374 metadata.map_or_else(
375 || {
376 Err(ArrowError::InvalidArgumentError(
377 "FixedShapeTensor extension types requires metadata".to_owned(),
378 ))
379 },
380 |value| {
381 serde_json::from_str(value).map_err(|e| {
382 ArrowError::InvalidArgumentError(format!(
383 "FixedShapeTensor metadata deserialization failed: {e}"
384 ))
385 })
386 },
387 )
388 }
389
390 fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> {
391 let expected = DataType::new_fixed_size_list(
392 self.value_type.clone(),
393 i32::try_from(self.list_size()).expect("overflow"),
394 false,
395 );
396 data_type
397 .equals_datatype(&expected)
398 .then_some(())
399 .ok_or_else(|| {
400 ArrowError::InvalidArgumentError(format!(
401 "FixedShapeTensor data type mismatch, expected {expected}, found {data_type}"
402 ))
403 })
404 }
405
406 fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result<Self, ArrowError> {
407 match data_type {
408 DataType::FixedSizeList(field, list_size) if !field.is_nullable() => {
409 let metadata = FixedShapeTensorMetadata::try_new(
411 metadata.shape,
412 metadata.dim_names,
413 metadata.permutations,
414 )?;
415 let expected_size = i32::try_from(metadata.list_size()).expect("overflow");
417 if *list_size != expected_size {
418 Err(ArrowError::InvalidArgumentError(format!(
419 "FixedShapeTensor list size mismatch, expected {expected_size} (metadata), found {list_size} (data type)"
420 )))
421 } else {
422 Ok(Self {
423 value_type: field.data_type().clone(),
424 metadata,
425 })
426 }
427 }
428 data_type => Err(ArrowError::InvalidArgumentError(format!(
429 "FixedShapeTensor data type mismatch, expected FixedSizeList with non-nullable field, found {data_type}"
430 ))),
431 }
432 }
433}
434
// Unit tests for attaching the `FixedShapeTensor` extension type to a `Field` and for
// the validation performed by `FixedShapeTensor::try_new`.
#[cfg(test)]
mod tests {
    #[cfg(feature = "canonical_extension_types")]
    use crate::extension::CanonicalExtensionType;
    use crate::{
        Field,
        extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY},
    };

    use super::*;

    // Attaches a tensor with names and a permutation to a matching fixed size list
    // field and reads the same extension type back from the field.
    #[test]
    fn valid() -> Result<(), ArrowError> {
        let fixed_shape_tensor = FixedShapeTensor::try_new(
            DataType::Float32,
            [100, 200, 500],
            Some(vec!["C".to_owned(), "H".to_owned(), "W".to_owned()]),
            Some(vec![2, 0, 1]),
        )?;
        let mut field = Field::new_fixed_size_list(
            "",
            Field::new("", DataType::Float32, false),
            i32::try_from(fixed_shape_tensor.list_size()).expect("overflow"),
            false,
        );
        field.try_with_extension_type(fixed_shape_tensor.clone())?;
        assert_eq!(
            field.try_extension_type::<FixedShapeTensor>()?,
            fixed_shape_tensor
        );
        #[cfg(feature = "canonical_extension_types")]
        assert_eq!(
            field.try_canonical_extension_type()?,
            CanonicalExtensionType::FixedShapeTensor(fixed_shape_tensor)
        );
        Ok(())
    }

    // Only the extension metadata key is set on the field; the extension name key is
    // absent, so looking up the extension type is expected to panic.
    #[test]
    #[should_panic(expected = "Field extension type name missing")]
    fn missing_name() {
        let field =
            Field::new_fixed_size_list("", Field::new("", DataType::Float32, false), 3, false)
                .with_metadata(
                    [(
                        EXTENSION_TYPE_METADATA_KEY.to_owned(),
                        r#"{ "shape": [100, 200, 500], }"#.to_owned(),
                    )]
                    .into_iter()
                    .collect(),
                );
        field.extension_type::<FixedShapeTensor>();
    }

    // The tensor's value type is `Int32` while the field's list element type is
    // `Float32`, so the field's data type does not support the extension type.
    #[test]
    #[should_panic(expected = "FixedShapeTensor data type mismatch, expected FixedSizeList")]
    fn invalid_type() {
        let fixed_shape_tensor =
            FixedShapeTensor::try_new(DataType::Int32, [100, 200, 500], None, None).unwrap();
        let field = Field::new_fixed_size_list(
            "",
            Field::new("", DataType::Float32, false),
            i32::try_from(fixed_shape_tensor.list_size()).expect("overflow"),
            false,
        );
        field.with_extension_type(fixed_shape_tensor);
    }

    // Only the extension name key is set on the field; the extension metadata key is
    // absent, which this extension type rejects.
    #[test]
    #[should_panic(expected = "FixedShapeTensor extension types requires metadata")]
    fn missing_metadata() {
        let field =
            Field::new_fixed_size_list("", Field::new("", DataType::Float32, false), 3, false)
                .with_metadata(
                    [(
                        EXTENSION_TYPE_NAME_KEY.to_owned(),
                        FixedShapeTensor::NAME.to_owned(),
                    )]
                    .into_iter()
                    .collect(),
                );
        field.extension_type::<FixedShapeTensor>();
    }

    // The metadata JSON contains a key that is not one of the known fields, so
    // deserialization fails with an unknown-field error.
    #[test]
    #[should_panic(expected = "FixedShapeTensor metadata deserialization failed: \
        unknown field `not-shape`, expected one of `shape`, `dim_names`, `permutations`")]
    fn invalid_metadata() {
        let fixed_shape_tensor =
            FixedShapeTensor::try_new(DataType::Float32, [100, 200, 500], None, None).unwrap();
        let field = Field::new_fixed_size_list(
            "",
            Field::new("", DataType::Float32, false),
            i32::try_from(fixed_shape_tensor.list_size()).expect("overflow"),
            false,
        )
        .with_metadata(
            [
                (
                    EXTENSION_TYPE_NAME_KEY.to_owned(),
                    FixedShapeTensor::NAME.to_owned(),
                ),
                (
                    EXTENSION_TYPE_METADATA_KEY.to_owned(),
                    r#"{ "not-shape": [] }"#.to_owned(),
                ),
            ]
            .into_iter()
            .collect(),
        );
        field.extension_type::<FixedShapeTensor>();
    }

    // Two dimension names are given for a three-dimensional shape.
    #[test]
    #[should_panic(
        expected = "FixedShapeTensor dimension names size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_dimension_names() {
        FixedShapeTensor::try_new(
            DataType::Float32,
            [100, 200, 500],
            Some(vec!["a".to_owned(), "b".to_owned()]),
            None,
        )
        .unwrap();
    }

    // Two permutation entries are given for a three-dimensional shape.
    #[test]
    #[should_panic(expected = "FixedShapeTensor permutations size mismatch, expected 3, found 2")]
    fn invalid_metadata_permutations_len() {
        FixedShapeTensor::try_new(DataType::Float32, [100, 200, 500], None, Some(vec![1, 0]))
            .unwrap();
    }

    // The permutation has the right length but its values are not a permutation of
    // `[0, 1, 2]`.
    #[test]
    #[should_panic(
        expected = "FixedShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: 3"
    )]
    fn invalid_metadata_permutations_values() {
        FixedShapeTensor::try_new(
            DataType::Float32,
            [100, 200, 500],
            None,
            Some(vec![4, 3, 2]),
        )
        .unwrap();
    }
}