arrow_schema/extension/canonical/
variable_shape_tensor.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! VariableShapeTensor
19//!
20//! <https://arrow.apache.org/docs/format/CanonicalExtensions.html#variable-shape-tensor>
21
22use serde_core::de::{self, MapAccess, Visitor};
23use serde_core::{Deserialize, Deserializer, Serialize, Serializer};
24use std::fmt;
25
26use crate::{ArrowError, DataType, Field, extension::ExtensionType};
27
28/// The extension type for `VariableShapeTensor`.
29///
30/// Extension name: `arrow.variable_shape_tensor`.
31///
32/// The storage type of the extension is: StructArray where struct is composed
33/// of data and shape fields describing a single tensor per row:
34/// - `data` is a List holding tensor elements (each list element is a single
35///   tensor). The List’s value type is the value type of the tensor, such as
36///   an integer or floating-point type.
37/// - `shape` is a `FixedSizeList<int32>[ndim]` of the tensor shape where the
38///   size of the list `ndim` is equal to the number of dimensions of the
39///   tensor.
40///
41/// Extension type parameters:
42/// `value_type`: the Arrow data type of individual tensor elements.
43///
44/// Optional parameters describing the logical layout:
45/// - `dim_names`: explicit names to tensor dimensions as an array. The length
46///   of it should be equal to the shape length and equal to the number of
47///   dimensions.
48///   `dim_names` can be used if the dimensions have well-known names and they
49///   map to the physical layout (row-major).
50/// - `permutation`: indices of the desired ordering of the original
51///   dimensions, defined as an array.
52///   The indices contain a permutation of the values `[0, 1, .., N-1]` where
53///   `N` is the number of dimensions. The permutation indicates which
54///   dimension of the logical layout corresponds to which dimension of the
55///   physical tensor (the i-th dimension of the logical view corresponds to
56///   the dimension with number `permutations[i]` of the physical tensor).
57///   Permutation can be useful in case the logical order of the tensor is a
58///   permutation of the physical order (row-major).
59///   When logical and physical layout are equal, the permutation will always
60///   be (`[0, 1, .., N-1]`) and can therefore be left out.
61/// - `uniform_shape`: sizes of individual tensor’s dimensions which are
62///   guaranteed to stay constant in uniform dimensions and can vary in non-
63///   uniform dimensions. This holds over all tensors in the array. Sizes in
64///   uniform dimensions are represented with int32 values, while sizes of the
65///   non-uniform dimensions are not known in advance and are represented with
66///   null. If `uniform_shape` is not provided it is assumed that all
67///   dimensions are non-uniform. An array containing a tensor with shape (2,
68///   3, 4) and whose first and last dimensions are uniform would have
69///   `uniform_shape` (2, null, 4). This allows for interpreting the tensor
70///   correctly without accounting for uniform dimensions while still
71///   permitting optional optimizations that take advantage of the uniformity.
72///
73/// <https://arrow.apache.org/docs/format/CanonicalExtensions.html#variable-shape-tensor>
#[derive(Debug, Clone, PartialEq)]
pub struct VariableShapeTensor {
    /// The data type of individual tensor elements, i.e. the value type of
    /// the `data` list in the storage type.
    value_type: DataType,

    /// The number of dimensions of the tensor, i.e. the size of the `shape`
    /// fixed-size list in the storage type.
    dimensions: usize,

    /// The metadata of this extension type: the optional dimension names,
    /// permutation and uniform shape.
    metadata: VariableShapeTensorMetadata,
}
85
86impl VariableShapeTensor {
87    /// Returns a new variable shape tensor extension type.
88    ///
89    /// # Error
90    ///
91    /// Return an error if the provided dimension names, permutations or
92    /// uniform shapes are invalid.
93    pub fn try_new(
94        value_type: DataType,
95        dimensions: usize,
96        dimension_names: Option<Vec<String>>,
97        permutations: Option<Vec<usize>>,
98        uniform_shapes: Option<Vec<Option<i32>>>,
99    ) -> Result<Self, ArrowError> {
100        // TODO: are all data types are suitable as value type?
101        VariableShapeTensorMetadata::try_new(
102            dimensions,
103            dimension_names,
104            permutations,
105            uniform_shapes,
106        )
107        .map(|metadata| Self {
108            value_type,
109            dimensions,
110            metadata,
111        })
112    }
113
114    /// Returns the value type of the individual tensor elements.
115    pub fn value_type(&self) -> &DataType {
116        &self.value_type
117    }
118
119    /// Returns the number of dimensions  in this variable shape tensor.
120    pub fn dimensions(&self) -> usize {
121        self.dimensions
122    }
123
124    /// Returns the names of the dimensions in this variable shape tensor, if
125    /// set.
126    pub fn dimension_names(&self) -> Option<&[String]> {
127        self.metadata.dimension_names()
128    }
129
130    /// Returns the indices of the desired ordering of the original
131    /// dimensions, if set.
132    pub fn permutations(&self) -> Option<&[usize]> {
133        self.metadata.permutations()
134    }
135
136    /// Returns sizes of individual tensor’s dimensions which are guaranteed
137    /// to stay constant in uniform dimensions and can vary in non-uniform
138    /// dimensions.
139    pub fn uniform_shapes(&self) -> Option<&[Option<i32>]> {
140        self.metadata.uniform_shapes()
141    }
142}
143
/// Extension type metadata for [`VariableShapeTensor`].
///
/// All three parameters are optional. When constructed through
/// [`VariableShapeTensorMetadata::try_new`], every provided parameter has
/// exactly one entry per tensor dimension.
#[derive(Debug, Clone, PartialEq)]
pub struct VariableShapeTensorMetadata {
    /// Explicit names to tensor dimensions.
    dim_names: Option<Vec<String>>,

    /// Indices of the desired ordering of the original dimensions.
    permutations: Option<Vec<usize>>,

    /// Sizes of individual tensor's dimensions which are guaranteed to stay
    /// constant in uniform dimensions and can vary in non-uniform dimensions.
    /// A `None` entry marks a non-uniform dimension.
    uniform_shape: Option<Vec<Option<i32>>>,
}
157
158impl Serialize for VariableShapeTensorMetadata {
159    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
160    where
161        S: Serializer,
162    {
163        use serde_core::ser::SerializeStruct;
164        let mut state = serializer.serialize_struct("VariableShapeTensorMetadata", 3)?;
165        state.serialize_field("dim_names", &self.dim_names)?;
166        state.serialize_field("permutations", &self.permutations)?;
167        state.serialize_field("uniform_shape", &self.uniform_shape)?;
168        state.end()
169    }
170}
171
/// Identifies one of the keys of a serialized
/// [`VariableShapeTensorMetadata`]; used as the key type while
/// deserializing the metadata from a map.
#[derive(Debug)]
enum MetadataField {
    /// The `dim_names` key.
    DimNames,
    /// The `permutations` key.
    Permutations,
    /// The `uniform_shape` key.
    UniformShape,
}
178
/// Visitor that maps a serialized key name to a [`MetadataField`].
struct MetadataFieldVisitor;
180
181impl<'de> Visitor<'de> for MetadataFieldVisitor {
182    type Value = MetadataField;
183
184    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
185        formatter.write_str("`dim_names`, `permutations`, or `uniform_shape`")
186    }
187
188    fn visit_str<E>(self, value: &str) -> Result<MetadataField, E>
189    where
190        E: de::Error,
191    {
192        match value {
193            "dim_names" => Ok(MetadataField::DimNames),
194            "permutations" => Ok(MetadataField::Permutations),
195            "uniform_shape" => Ok(MetadataField::UniformShape),
196            _ => Err(de::Error::unknown_field(
197                value,
198                &["dim_names", "permutations", "uniform_shape"],
199            )),
200        }
201    }
202}
203
204impl<'de> Deserialize<'de> for MetadataField {
205    fn deserialize<D>(deserializer: D) -> Result<MetadataField, D::Error>
206    where
207        D: Deserializer<'de>,
208    {
209        deserializer.deserialize_identifier(MetadataFieldVisitor)
210    }
211}
212
/// Visitor that builds a [`VariableShapeTensorMetadata`] from either a
/// sequence or a map representation.
struct VariableShapeTensorMetadataVisitor;
214
215impl<'de> Visitor<'de> for VariableShapeTensorMetadataVisitor {
216    type Value = VariableShapeTensorMetadata;
217
218    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
219        formatter.write_str("struct VariableShapeTensorMetadata")
220    }
221
222    fn visit_seq<V>(self, mut seq: V) -> Result<VariableShapeTensorMetadata, V::Error>
223    where
224        V: de::SeqAccess<'de>,
225    {
226        let dim_names = seq
227            .next_element()?
228            .ok_or_else(|| de::Error::invalid_length(0, &self))?;
229        let permutations = seq
230            .next_element()?
231            .ok_or_else(|| de::Error::invalid_length(1, &self))?;
232        let uniform_shape = seq
233            .next_element()?
234            .ok_or_else(|| de::Error::invalid_length(2, &self))?;
235        Ok(VariableShapeTensorMetadata {
236            dim_names,
237            permutations,
238            uniform_shape,
239        })
240    }
241
242    fn visit_map<V>(self, mut map: V) -> Result<VariableShapeTensorMetadata, V::Error>
243    where
244        V: MapAccess<'de>,
245    {
246        let mut dim_names = None;
247        let mut permutations = None;
248        let mut uniform_shape = None;
249
250        while let Some(key) = map.next_key()? {
251            match key {
252                MetadataField::DimNames => {
253                    if dim_names.is_some() {
254                        return Err(de::Error::duplicate_field("dim_names"));
255                    }
256                    dim_names = Some(map.next_value()?);
257                }
258                MetadataField::Permutations => {
259                    if permutations.is_some() {
260                        return Err(de::Error::duplicate_field("permutations"));
261                    }
262                    permutations = Some(map.next_value()?);
263                }
264                MetadataField::UniformShape => {
265                    if uniform_shape.is_some() {
266                        return Err(de::Error::duplicate_field("uniform_shape"));
267                    }
268                    uniform_shape = Some(map.next_value()?);
269                }
270            }
271        }
272
273        Ok(VariableShapeTensorMetadata {
274            dim_names,
275            permutations,
276            uniform_shape,
277        })
278    }
279}
280
281impl<'de> Deserialize<'de> for VariableShapeTensorMetadata {
282    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
283    where
284        D: Deserializer<'de>,
285    {
286        deserializer.deserialize_struct(
287            "VariableShapeTensorMetadata",
288            &["dim_names", "permutations", "uniform_shape"],
289            VariableShapeTensorMetadataVisitor,
290        )
291    }
292}
293
294impl VariableShapeTensorMetadata {
295    /// Returns metadata for a variable shape tensor extension type.
296    ///
297    /// # Error
298    ///
299    /// Return an error if the provided dimension names, permutations or
300    /// uniform shapes are invalid.
301    pub fn try_new(
302        dimensions: usize,
303        dimension_names: Option<Vec<String>>,
304        permutations: Option<Vec<usize>>,
305        uniform_shapes: Option<Vec<Option<i32>>>,
306    ) -> Result<Self, ArrowError> {
307        let dim_names = dimension_names.map(|dimension_names| {
308            if dimension_names.len() != dimensions {
309                Err(ArrowError::InvalidArgumentError(format!(
310                    "VariableShapeTensor dimension names size mismatch, expected {dimensions}, found {}", dimension_names.len()
311                )))
312            } else {
313                Ok(dimension_names)
314            }
315        }).transpose()?;
316
317        let permutations = permutations
318            .map(|permutations| {
319                if permutations.len() != dimensions {
320                    Err(ArrowError::InvalidArgumentError(format!(
321                        "VariableShapeTensor permutations size mismatch, expected {dimensions}, found {}",
322                        permutations.len()
323                    )))
324                } else {
325                    let mut sorted_permutations = permutations.clone();
326                    sorted_permutations.sort_unstable();
327                    if (0..dimensions).zip(sorted_permutations).any(|(a, b)| a != b) {
328                        Err(ArrowError::InvalidArgumentError(format!(
329                            "VariableShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: {dimensions}"
330                        )))
331                    } else {
332                        Ok(permutations)
333                    }
334                }
335            })
336            .transpose()?;
337
338        let uniform_shape = uniform_shapes
339            .map(|uniform_shapes| {
340                if uniform_shapes.len() != dimensions {
341                    Err(ArrowError::InvalidArgumentError(format!(
342                        "VariableShapeTensor uniform shapes size mismatch, expected {dimensions}, found {}",
343                        uniform_shapes.len()
344                    )))
345                } else {
346                    Ok(uniform_shapes)
347                }
348            })
349            .transpose()?;
350
351        Ok(Self {
352            dim_names,
353            permutations,
354            uniform_shape,
355        })
356    }
357
358    /// Returns the names of the dimensions in this variable shape tensor, if
359    /// set.
360    pub fn dimension_names(&self) -> Option<&[String]> {
361        self.dim_names.as_ref().map(AsRef::as_ref)
362    }
363
364    /// Returns the indices of the desired ordering of the original dimensions,
365    /// if set.
366    pub fn permutations(&self) -> Option<&[usize]> {
367        self.permutations.as_ref().map(AsRef::as_ref)
368    }
369
370    /// Returns sizes of individual tensor’s dimensions which are guaranteed
371    /// to stay constant in uniform dimensions and can vary in non-uniform
372    /// dimensions.
373    pub fn uniform_shapes(&self) -> Option<&[Option<i32>]> {
374        self.uniform_shape.as_ref().map(AsRef::as_ref)
375    }
376}
377
378impl ExtensionType for VariableShapeTensor {
379    const NAME: &'static str = "arrow.variable_shape_tensor";
380
381    type Metadata = VariableShapeTensorMetadata;
382
383    fn metadata(&self) -> &Self::Metadata {
384        &self.metadata
385    }
386
387    fn serialize_metadata(&self) -> Option<String> {
388        Some(serde_json::to_string(self.metadata()).expect("metadata serialization"))
389    }
390
391    fn deserialize_metadata(metadata: Option<&str>) -> Result<Self::Metadata, ArrowError> {
392        metadata.map_or_else(
393            || {
394                Err(ArrowError::InvalidArgumentError(
395                    "VariableShapeTensor extension types requires metadata".to_owned(),
396                ))
397            },
398            |value| {
399                serde_json::from_str(value).map_err(|e| {
400                    ArrowError::InvalidArgumentError(format!(
401                        "VariableShapeTensor metadata deserialization failed: {e}"
402                    ))
403                })
404            },
405        )
406    }
407
408    fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> {
409        let expected = DataType::Struct(
410            [
411                Field::new_list(
412                    "data",
413                    Field::new_list_field(self.value_type.clone(), false),
414                    false,
415                ),
416                Field::new(
417                    "shape",
418                    DataType::new_fixed_size_list(
419                        DataType::Int32,
420                        i32::try_from(self.dimensions()).expect("overflow"),
421                        false,
422                    ),
423                    false,
424                ),
425            ]
426            .into_iter()
427            .collect(),
428        );
429        data_type
430            .equals_datatype(&expected)
431            .then_some(())
432            .ok_or_else(|| {
433                ArrowError::InvalidArgumentError(format!(
434                    "VariableShapeTensor data type mismatch, expected {expected}, found {data_type}"
435                ))
436            })
437    }
438
439    fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result<Self, ArrowError> {
440        match data_type {
441            DataType::Struct(fields)
442                if fields.len() == 2
443                    && matches!(fields.find("data"), Some((0, _)))
444                    && matches!(fields.find("shape"), Some((1, _))) =>
445            {
446                let shape_field = &fields[1];
447                match shape_field.data_type() {
448                    DataType::FixedSizeList(_, list_size) => {
449                        let dimensions = usize::try_from(*list_size).expect("conversion failed");
450                        // Make sure the metadata is valid.
451                        let metadata = VariableShapeTensorMetadata::try_new(
452                            dimensions,
453                            metadata.dim_names,
454                            metadata.permutations,
455                            metadata.uniform_shape,
456                        )?;
457                        let data_field = &fields[0];
458                        match data_field.data_type() {
459                            DataType::List(field) => Ok(Self {
460                                value_type: field.data_type().clone(),
461                                dimensions,
462                                metadata,
463                            }),
464                            data_type => Err(ArrowError::InvalidArgumentError(format!(
465                                "VariableShapeTensor data type mismatch, expected List for data field, found {data_type}"
466                            ))),
467                        }
468                    }
469                    data_type => Err(ArrowError::InvalidArgumentError(format!(
470                        "VariableShapeTensor data type mismatch, expected FixedSizeList for shape field, found {data_type}"
471                    ))),
472                }
473            }
474            data_type => Err(ArrowError::InvalidArgumentError(format!(
475                "VariableShapeTensor data type mismatch, expected Struct with 2 fields (data and shape), found {data_type}"
476            ))),
477        }
478    }
479}
480
#[cfg(test)]
mod tests {
    #[cfg(feature = "canonical_extension_types")]
    use crate::extension::CanonicalExtensionType;
    use crate::{
        Field,
        extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY},
    };

    use super::*;

    /// Builds the storage field shared by the tests: a non-nullable struct
    /// with a `data` list of `Float32` and a `shape` fixed-size list of three
    /// `Int32` values.
    fn storage_field() -> Field {
        let data = Field::new_list(
            "data",
            Field::new_list_field(DataType::Float32, false),
            false,
        );
        let shape =
            Field::new_fixed_size_list("shape", Field::new("", DataType::Int32, false), 3, false);
        Field::new_struct("", vec![data, shape], false)
    }

    /// A fully specified extension type round-trips through a field.
    #[test]
    fn valid() -> Result<(), ArrowError> {
        let variable_shape_tensor = VariableShapeTensor::try_new(
            DataType::Float32,
            3,
            Some(vec!["C".to_owned(), "H".to_owned(), "W".to_owned()]),
            Some(vec![2, 0, 1]),
            Some(vec![Some(400), None, Some(3)]),
        )?;
        let mut field = storage_field();
        field.try_with_extension_type(variable_shape_tensor.clone())?;
        assert_eq!(
            field.try_extension_type::<VariableShapeTensor>()?,
            variable_shape_tensor
        );
        #[cfg(feature = "canonical_extension_types")]
        assert_eq!(
            field.try_canonical_extension_type()?,
            CanonicalExtensionType::VariableShapeTensor(variable_shape_tensor)
        );
        Ok(())
    }

    /// Extension metadata without an extension name is rejected.
    #[test]
    #[should_panic(expected = "Field extension type name missing")]
    fn missing_name() {
        let field = storage_field().with_metadata(
            [(EXTENSION_TYPE_METADATA_KEY.to_owned(), "{}".to_owned())]
                .into_iter()
                .collect(),
        );
        field.extension_type::<VariableShapeTensor>();
    }

    /// An `Int32` tensor type does not support a `Float32` storage field.
    #[test]
    #[should_panic(expected = "VariableShapeTensor data type mismatch, expected Struct")]
    fn invalid_type() {
        let variable_shape_tensor =
            VariableShapeTensor::try_new(DataType::Int32, 3, None, None, None).unwrap();
        let field = storage_field();
        field.with_extension_type(variable_shape_tensor);
    }

    /// An extension name without extension metadata is rejected.
    #[test]
    #[should_panic(expected = "VariableShapeTensor extension types requires metadata")]
    fn missing_metadata() {
        let field = storage_field().with_metadata(
            [(
                EXTENSION_TYPE_NAME_KEY.to_owned(),
                VariableShapeTensor::NAME.to_owned(),
            )]
            .into_iter()
            .collect(),
        );
        field.extension_type::<VariableShapeTensor>();
    }

    /// Metadata whose `dim_names` are not strings fails to deserialize.
    #[test]
    #[should_panic(expected = "VariableShapeTensor metadata deserialization failed: invalid type:")]
    fn invalid_metadata() {
        let field = storage_field().with_metadata(
            [
                (
                    EXTENSION_TYPE_NAME_KEY.to_owned(),
                    VariableShapeTensor::NAME.to_owned(),
                ),
                (
                    EXTENSION_TYPE_METADATA_KEY.to_owned(),
                    r#"{ "dim_names": [1, null, 3, 4] }"#.to_owned(),
                ),
            ]
            .into_iter()
            .collect(),
        );
        field.extension_type::<VariableShapeTensor>();
    }

    /// Two dimension names for three dimensions are rejected.
    #[test]
    #[should_panic(
        expected = "VariableShapeTensor dimension names size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_dimension_names() {
        let names = vec!["a".to_owned(), "b".to_owned()];
        VariableShapeTensor::try_new(DataType::Float32, 3, Some(names), None, None).unwrap();
    }

    /// Two permutation indices for three dimensions are rejected.
    #[test]
    #[should_panic(
        expected = "VariableShapeTensor permutations size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_permutations_len() {
        let permutations = vec![1, 0];
        VariableShapeTensor::try_new(DataType::Float32, 3, None, Some(permutations), None)
            .unwrap();
    }

    /// Indices that are not a permutation of `[0, 1, 2]` are rejected.
    #[test]
    #[should_panic(
        expected = "VariableShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: 3"
    )]
    fn invalid_metadata_permutations_values() {
        let permutations = vec![4, 3, 2];
        VariableShapeTensor::try_new(DataType::Float32, 3, None, Some(permutations), None)
            .unwrap();
    }

    /// Two uniform shape entries for three dimensions are rejected.
    #[test]
    #[should_panic(
        expected = "VariableShapeTensor uniform shapes size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_uniform_shapes() {
        let uniform_shapes = vec![None, Some(1)];
        VariableShapeTensor::try_new(DataType::Float32, 3, None, None, Some(uniform_shapes))
            .unwrap();
    }
}
690            .unwrap();
691    }
692}