1use serde_core::de::{self, MapAccess, Visitor};
23use serde_core::{Deserialize, Deserializer, Serialize, Serializer};
24use std::fmt;
25
26use crate::{ArrowError, DataType, Field, extension::ExtensionType};
27
/// The canonical extension type for tensors whose shape may vary per value.
///
/// The storage type is a struct with a `data` list holding the tensor
/// elements and a `shape` fixed-size list of `Int32` with one entry per
/// dimension (see `supports_data_type`).
#[derive(Debug, Clone, PartialEq)]
pub struct VariableShapeTensor {
    /// The data type of individual tensor elements.
    value_type: DataType,

    /// The number of dimensions of every tensor in the column.
    dimensions: usize,

    /// Optional layout metadata, validated against `dimensions` on construction.
    metadata: VariableShapeTensorMetadata,
}
85
86impl VariableShapeTensor {
87 pub fn try_new(
94 value_type: DataType,
95 dimensions: usize,
96 dimension_names: Option<Vec<String>>,
97 permutations: Option<Vec<usize>>,
98 uniform_shapes: Option<Vec<Option<i32>>>,
99 ) -> Result<Self, ArrowError> {
100 VariableShapeTensorMetadata::try_new(
102 dimensions,
103 dimension_names,
104 permutations,
105 uniform_shapes,
106 )
107 .map(|metadata| Self {
108 value_type,
109 dimensions,
110 metadata,
111 })
112 }
113
114 pub fn value_type(&self) -> &DataType {
116 &self.value_type
117 }
118
119 pub fn dimensions(&self) -> usize {
121 self.dimensions
122 }
123
124 pub fn dimension_names(&self) -> Option<&[String]> {
127 self.metadata.dimension_names()
128 }
129
130 pub fn permutations(&self) -> Option<&[usize]> {
133 self.metadata.permutations()
134 }
135
136 pub fn uniform_shapes(&self) -> Option<&[Option<i32>]> {
140 self.metadata.uniform_shapes()
141 }
142}
143
/// Extension type metadata for [`VariableShapeTensor`].
///
/// Every field is optional. When present, each list has exactly one entry per
/// tensor dimension; this is enforced by `VariableShapeTensorMetadata::try_new`.
#[derive(Debug, Clone, PartialEq)]
pub struct VariableShapeTensorMetadata {
    /// Explicit names of the tensor dimensions.
    dim_names: Option<Vec<String>>,

    /// Indices of the desired ordering of the original dimensions; validated to
    /// be a permutation of `0..dimensions`.
    permutations: Option<Vec<usize>>,

    /// Per-dimension sizes. NOTE(review): per the Arrow spec, `None` presumably
    /// marks a dimension whose size varies between tensors — confirm.
    uniform_shape: Option<Vec<Option<i32>>>,
}
157
158impl Serialize for VariableShapeTensorMetadata {
159 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
160 where
161 S: Serializer,
162 {
163 use serde_core::ser::SerializeStruct;
164 let mut state = serializer.serialize_struct("VariableShapeTensorMetadata", 3)?;
165 state.serialize_field("dim_names", &self.dim_names)?;
166 state.serialize_field("permutations", &self.permutations)?;
167 state.serialize_field("uniform_shape", &self.uniform_shape)?;
168 state.end()
169 }
170}
171
/// Keys recognised in the serialized [`VariableShapeTensorMetadata`] object.
#[derive(Debug)]
enum MetadataField {
    /// The `dim_names` key.
    DimNames,
    /// The `permutations` key.
    Permutations,
    /// The `uniform_shape` key.
    UniformShape,
}
178
179struct MetadataFieldVisitor;
180
181impl<'de> Visitor<'de> for MetadataFieldVisitor {
182 type Value = MetadataField;
183
184 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
185 formatter.write_str("`dim_names`, `permutations`, or `uniform_shape`")
186 }
187
188 fn visit_str<E>(self, value: &str) -> Result<MetadataField, E>
189 where
190 E: de::Error,
191 {
192 match value {
193 "dim_names" => Ok(MetadataField::DimNames),
194 "permutations" => Ok(MetadataField::Permutations),
195 "uniform_shape" => Ok(MetadataField::UniformShape),
196 _ => Err(de::Error::unknown_field(
197 value,
198 &["dim_names", "permutations", "uniform_shape"],
199 )),
200 }
201 }
202}
203
204impl<'de> Deserialize<'de> for MetadataField {
205 fn deserialize<D>(deserializer: D) -> Result<MetadataField, D::Error>
206 where
207 D: Deserializer<'de>,
208 {
209 deserializer.deserialize_identifier(MetadataFieldVisitor)
210 }
211}
212
213struct VariableShapeTensorMetadataVisitor;
214
215impl<'de> Visitor<'de> for VariableShapeTensorMetadataVisitor {
216 type Value = VariableShapeTensorMetadata;
217
218 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
219 formatter.write_str("struct VariableShapeTensorMetadata")
220 }
221
222 fn visit_seq<V>(self, mut seq: V) -> Result<VariableShapeTensorMetadata, V::Error>
223 where
224 V: de::SeqAccess<'de>,
225 {
226 let dim_names = seq
227 .next_element()?
228 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
229 let permutations = seq
230 .next_element()?
231 .ok_or_else(|| de::Error::invalid_length(1, &self))?;
232 let uniform_shape = seq
233 .next_element()?
234 .ok_or_else(|| de::Error::invalid_length(2, &self))?;
235 Ok(VariableShapeTensorMetadata {
236 dim_names,
237 permutations,
238 uniform_shape,
239 })
240 }
241
242 fn visit_map<V>(self, mut map: V) -> Result<VariableShapeTensorMetadata, V::Error>
243 where
244 V: MapAccess<'de>,
245 {
246 let mut dim_names = None;
247 let mut permutations = None;
248 let mut uniform_shape = None;
249
250 while let Some(key) = map.next_key()? {
251 match key {
252 MetadataField::DimNames => {
253 if dim_names.is_some() {
254 return Err(de::Error::duplicate_field("dim_names"));
255 }
256 dim_names = Some(map.next_value()?);
257 }
258 MetadataField::Permutations => {
259 if permutations.is_some() {
260 return Err(de::Error::duplicate_field("permutations"));
261 }
262 permutations = Some(map.next_value()?);
263 }
264 MetadataField::UniformShape => {
265 if uniform_shape.is_some() {
266 return Err(de::Error::duplicate_field("uniform_shape"));
267 }
268 uniform_shape = Some(map.next_value()?);
269 }
270 }
271 }
272
273 Ok(VariableShapeTensorMetadata {
274 dim_names,
275 permutations,
276 uniform_shape,
277 })
278 }
279}
280
281impl<'de> Deserialize<'de> for VariableShapeTensorMetadata {
282 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
283 where
284 D: Deserializer<'de>,
285 {
286 deserializer.deserialize_struct(
287 "VariableShapeTensorMetadata",
288 &["dim_names", "permutations", "uniform_shape"],
289 VariableShapeTensorMetadataVisitor,
290 )
291 }
292}
293
294impl VariableShapeTensorMetadata {
295 pub fn try_new(
302 dimensions: usize,
303 dimension_names: Option<Vec<String>>,
304 permutations: Option<Vec<usize>>,
305 uniform_shapes: Option<Vec<Option<i32>>>,
306 ) -> Result<Self, ArrowError> {
307 let dim_names = dimension_names.map(|dimension_names| {
308 if dimension_names.len() != dimensions {
309 Err(ArrowError::InvalidArgumentError(format!(
310 "VariableShapeTensor dimension names size mismatch, expected {dimensions}, found {}", dimension_names.len()
311 )))
312 } else {
313 Ok(dimension_names)
314 }
315 }).transpose()?;
316
317 let permutations = permutations
318 .map(|permutations| {
319 if permutations.len() != dimensions {
320 Err(ArrowError::InvalidArgumentError(format!(
321 "VariableShapeTensor permutations size mismatch, expected {dimensions}, found {}",
322 permutations.len()
323 )))
324 } else {
325 let mut sorted_permutations = permutations.clone();
326 sorted_permutations.sort_unstable();
327 if (0..dimensions).zip(sorted_permutations).any(|(a, b)| a != b) {
328 Err(ArrowError::InvalidArgumentError(format!(
329 "VariableShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: {dimensions}"
330 )))
331 } else {
332 Ok(permutations)
333 }
334 }
335 })
336 .transpose()?;
337
338 let uniform_shape = uniform_shapes
339 .map(|uniform_shapes| {
340 if uniform_shapes.len() != dimensions {
341 Err(ArrowError::InvalidArgumentError(format!(
342 "VariableShapeTensor uniform shapes size mismatch, expected {dimensions}, found {}",
343 uniform_shapes.len()
344 )))
345 } else {
346 Ok(uniform_shapes)
347 }
348 })
349 .transpose()?;
350
351 Ok(Self {
352 dim_names,
353 permutations,
354 uniform_shape,
355 })
356 }
357
358 pub fn dimension_names(&self) -> Option<&[String]> {
361 self.dim_names.as_ref().map(AsRef::as_ref)
362 }
363
364 pub fn permutations(&self) -> Option<&[usize]> {
367 self.permutations.as_ref().map(AsRef::as_ref)
368 }
369
370 pub fn uniform_shapes(&self) -> Option<&[Option<i32>]> {
374 self.uniform_shape.as_ref().map(AsRef::as_ref)
375 }
376}
377
378impl ExtensionType for VariableShapeTensor {
379 const NAME: &'static str = "arrow.variable_shape_tensor";
380
381 type Metadata = VariableShapeTensorMetadata;
382
383 fn metadata(&self) -> &Self::Metadata {
384 &self.metadata
385 }
386
387 fn serialize_metadata(&self) -> Option<String> {
388 Some(serde_json::to_string(self.metadata()).expect("metadata serialization"))
389 }
390
391 fn deserialize_metadata(metadata: Option<&str>) -> Result<Self::Metadata, ArrowError> {
392 metadata.map_or_else(
393 || {
394 Err(ArrowError::InvalidArgumentError(
395 "VariableShapeTensor extension types requires metadata".to_owned(),
396 ))
397 },
398 |value| {
399 serde_json::from_str(value).map_err(|e| {
400 ArrowError::InvalidArgumentError(format!(
401 "VariableShapeTensor metadata deserialization failed: {e}"
402 ))
403 })
404 },
405 )
406 }
407
408 fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> {
409 let expected = DataType::Struct(
410 [
411 Field::new_list(
412 "data",
413 Field::new_list_field(self.value_type.clone(), false),
414 false,
415 ),
416 Field::new(
417 "shape",
418 DataType::new_fixed_size_list(
419 DataType::Int32,
420 i32::try_from(self.dimensions()).expect("overflow"),
421 false,
422 ),
423 false,
424 ),
425 ]
426 .into_iter()
427 .collect(),
428 );
429 data_type
430 .equals_datatype(&expected)
431 .then_some(())
432 .ok_or_else(|| {
433 ArrowError::InvalidArgumentError(format!(
434 "VariableShapeTensor data type mismatch, expected {expected}, found {data_type}"
435 ))
436 })
437 }
438
439 fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result<Self, ArrowError> {
440 match data_type {
441 DataType::Struct(fields)
442 if fields.len() == 2
443 && matches!(fields.find("data"), Some((0, _)))
444 && matches!(fields.find("shape"), Some((1, _))) =>
445 {
446 let shape_field = &fields[1];
447 match shape_field.data_type() {
448 DataType::FixedSizeList(_, list_size) => {
449 let dimensions = usize::try_from(*list_size).expect("conversion failed");
450 let metadata = VariableShapeTensorMetadata::try_new(
452 dimensions,
453 metadata.dim_names,
454 metadata.permutations,
455 metadata.uniform_shape,
456 )?;
457 let data_field = &fields[0];
458 match data_field.data_type() {
459 DataType::List(field) => Ok(Self {
460 value_type: field.data_type().clone(),
461 dimensions,
462 metadata,
463 }),
464 data_type => Err(ArrowError::InvalidArgumentError(format!(
465 "VariableShapeTensor data type mismatch, expected List for data field, found {data_type}"
466 ))),
467 }
468 }
469 data_type => Err(ArrowError::InvalidArgumentError(format!(
470 "VariableShapeTensor data type mismatch, expected FixedSizeList for shape field, found {data_type}"
471 ))),
472 }
473 }
474 data_type => Err(ArrowError::InvalidArgumentError(format!(
475 "VariableShapeTensor data type mismatch, expected Struct with 2 fields (data and shape), found {data_type}"
476 ))),
477 }
478 }
479}
480
#[cfg(test)]
mod tests {
    #[cfg(feature = "canonical_extension_types")]
    use crate::extension::CanonicalExtensionType;
    use crate::{
        Field,
        extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY},
    };

    use super::*;

    /// The storage field shared by these tests, with no field metadata:
    /// `Struct<data: List<Float32>, shape: FixedSizeList<Int32>[3]>`, all
    /// non-nullable.
    fn float32_storage_field() -> Field {
        Field::new_struct(
            "",
            vec![
                Field::new_list(
                    "data",
                    Field::new_list_field(DataType::Float32, false),
                    false,
                ),
                Field::new_fixed_size_list(
                    "shape",
                    Field::new("", DataType::Int32, false),
                    3,
                    false,
                ),
            ],
            false,
        )
    }

    /// The shared storage field with the given `(key, value)` metadata entries.
    fn float32_storage_field_with(entries: &[(&str, &str)]) -> Field {
        float32_storage_field().with_metadata(
            entries
                .iter()
                .map(|&(key, value)| (key.to_owned(), value.to_owned()))
                .collect(),
        )
    }

    #[test]
    fn valid() -> Result<(), ArrowError> {
        let variable_shape_tensor = VariableShapeTensor::try_new(
            DataType::Float32,
            3,
            Some(vec!["C".to_owned(), "H".to_owned(), "W".to_owned()]),
            Some(vec![2, 0, 1]),
            Some(vec![Some(400), None, Some(3)]),
        )?;
        let mut field = float32_storage_field();
        field.try_with_extension_type(variable_shape_tensor.clone())?;
        assert_eq!(
            field.try_extension_type::<VariableShapeTensor>()?,
            variable_shape_tensor
        );
        #[cfg(feature = "canonical_extension_types")]
        assert_eq!(
            field.try_canonical_extension_type()?,
            CanonicalExtensionType::VariableShapeTensor(variable_shape_tensor)
        );
        Ok(())
    }

    #[test]
    #[should_panic(expected = "Field extension type name missing")]
    fn missing_name() {
        let field = float32_storage_field_with(&[(EXTENSION_TYPE_METADATA_KEY, "{}")]);
        field.extension_type::<VariableShapeTensor>();
    }

    #[test]
    #[should_panic(expected = "VariableShapeTensor data type mismatch, expected Struct")]
    fn invalid_type() {
        // The tensor declares Int32 elements while the storage holds Float32.
        let variable_shape_tensor =
            VariableShapeTensor::try_new(DataType::Int32, 3, None, None, None).unwrap();
        let field = float32_storage_field();
        field.with_extension_type(variable_shape_tensor);
    }

    #[test]
    #[should_panic(expected = "VariableShapeTensor extension types requires metadata")]
    fn missing_metadata() {
        let field = float32_storage_field_with(&[(
            EXTENSION_TYPE_NAME_KEY,
            VariableShapeTensor::NAME,
        )]);
        field.extension_type::<VariableShapeTensor>();
    }

    #[test]
    #[should_panic(expected = "VariableShapeTensor metadata deserialization failed: invalid type:")]
    fn invalid_metadata() {
        let field = float32_storage_field_with(&[
            (EXTENSION_TYPE_NAME_KEY, VariableShapeTensor::NAME),
            (
                EXTENSION_TYPE_METADATA_KEY,
                r#"{ "dim_names": [1, null, 3, 4] }"#,
            ),
        ]);
        field.extension_type::<VariableShapeTensor>();
    }

    #[test]
    #[should_panic(
        expected = "VariableShapeTensor dimension names size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_dimension_names() {
        VariableShapeTensor::try_new(
            DataType::Float32,
            3,
            Some(vec!["a".to_owned(), "b".to_owned()]),
            None,
            None,
        )
        .unwrap();
    }

    #[test]
    #[should_panic(
        expected = "VariableShapeTensor permutations size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_permutations_len() {
        VariableShapeTensor::try_new(DataType::Float32, 3, None, Some(vec![1, 0]), None).unwrap();
    }

    #[test]
    #[should_panic(
        expected = "VariableShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: 3"
    )]
    fn invalid_metadata_permutations_values() {
        VariableShapeTensor::try_new(DataType::Float32, 3, None, Some(vec![4, 3, 2]), None)
            .unwrap();
    }

    #[test]
    #[should_panic(
        expected = "VariableShapeTensor uniform shapes size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_uniform_shapes() {
        VariableShapeTensor::try_new(DataType::Float32, 3, None, None, Some(vec![None, Some(1)]))
            .unwrap();
    }
}