arrow_schema/extension/canonical/
fixed_shape_tensor.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! FixedShapeTensor
19//!
20//! <https://arrow.apache.org/docs/format/CanonicalExtensions.html#fixed-shape-tensor>
21
22use serde_core::de::{self, MapAccess, Visitor};
23use serde_core::ser::SerializeStruct;
24use serde_core::{Deserialize, Deserializer, Serialize, Serializer};
25use std::fmt;
26
27use crate::{ArrowError, DataType, extension::ExtensionType};
28
/// The extension type for fixed shape tensors.
///
/// Extension name: `arrow.fixed_shape_tensor`.
///
/// The storage type of the extension is `FixedSizeList`, where:
/// - `value_type` is the data type of individual tensor elements.
/// - `list_size` is the product of all the elements in the tensor shape.
///
/// Extension type parameters:
/// - `value_type`: the Arrow data type of individual tensor elements.
/// - `shape`: the physical shape of the contained tensors as an array.
///
/// Optional parameters describing the logical layout:
/// - `dim_names`: explicit names of the tensor dimensions as an array. Its
///   length must be equal to the length of `shape`, i.e. the number of
///   dimensions.
///   `dim_names` can be used if the dimensions have well-known names and
///   they map to the physical layout (row-major).
/// - `permutation`: indices of the desired ordering of the original
///   dimensions, defined as an array.
///   The indices contain a permutation of the values `[0, 1, .., N-1]`
///   where `N` is the number of dimensions. The permutation indicates
///   which dimension of the logical layout corresponds to which dimension
///   of the physical tensor (the i-th dimension of the logical view
///   corresponds to the dimension with number `permutation[i]` of the
///   physical tensor).
///   A permutation is useful when the logical order of the tensor is a
///   permutation of the physical order (row-major).
///   When the logical and physical layouts are equal, the permutation is
///   always `[0, 1, .., N-1]` and can therefore be left out.
///
/// Description of the serialization:
/// The metadata must be a valid JSON object that includes the shape of the
/// contained tensors as an array with key `shape`, plus optional
/// dimension names with key `dim_names` and an optional ordering of the
/// dimensions with key `permutation`.
/// Example: `{ "shape": [2, 5]}`
/// Example with `dim_names` metadata for NCHW ordered data:
/// `{ "shape": [100, 200, 500], "dim_names": ["C", "H", "W"]}`
/// Example of permuted 3-dimensional tensor:
/// `{ "shape": [100, 200, 500], "permutation": [2, 0, 1]}`
///
/// This is the physical layout shape; the shape of the logical layout
/// would in this case be `[500, 100, 200]`.
///
/// <https://arrow.apache.org/docs/format/CanonicalExtensions.html#fixed-shape-tensor>
// NOTE(review): the `Serialize` implementation for `FixedShapeTensorMetadata`
// in this module writes the ordering under the key `permutations` (plural),
// which differs from the `permutation` key described above.
#[derive(Debug, Clone, PartialEq)]
pub struct FixedShapeTensor {
    /// The data type of individual tensor elements.
    value_type: DataType,

    /// The metadata of this extension type (shape, optional dimension
    /// names and optional permutation).
    metadata: FixedShapeTensorMetadata,
}
83
84impl FixedShapeTensor {
85    /// Returns a new fixed shape tensor extension type.
86    ///
87    /// # Error
88    ///
89    /// Return an error if the provided dimension names or permutations are
90    /// invalid.
91    pub fn try_new(
92        value_type: DataType,
93        shape: impl IntoIterator<Item = usize>,
94        dimension_names: Option<Vec<String>>,
95        permutations: Option<Vec<usize>>,
96    ) -> Result<Self, ArrowError> {
97        // TODO: are all data types are suitable as value type?
98        FixedShapeTensorMetadata::try_new(shape, dimension_names, permutations).map(|metadata| {
99            Self {
100                value_type,
101                metadata,
102            }
103        })
104    }
105
106    /// Returns the value type of the individual tensor elements.
107    pub fn value_type(&self) -> &DataType {
108        &self.value_type
109    }
110
111    /// Returns the product of all the elements in tensor shape.
112    pub fn list_size(&self) -> usize {
113        self.metadata.list_size()
114    }
115
116    /// Returns the number of dimensions in this fixed shape tensor.
117    pub fn dimensions(&self) -> usize {
118        self.metadata.dimensions()
119    }
120
121    /// Returns the names of the dimensions in this fixed shape tensor, if
122    /// set.
123    pub fn dimension_names(&self) -> Option<&[String]> {
124        self.metadata.dimension_names()
125    }
126
127    /// Returns the indices of the desired ordering of the original
128    /// dimensions, if set.
129    pub fn permutations(&self) -> Option<&[usize]> {
130        self.metadata.permutations()
131    }
132}
133
/// Extension type metadata for [`FixedShapeTensor`].
///
/// Holds the parameters that are carried in the extension metadata: the
/// required `shape` and the optional dimension names and permutation.
#[derive(Debug, Clone, PartialEq)]
pub struct FixedShapeTensorMetadata {
    /// The physical shape of the contained tensors, one entry per dimension.
    shape: Vec<usize>,

    /// Explicit names of the tensor dimensions, if set.
    dim_names: Option<Vec<String>>,

    /// Indices of the desired ordering of the original dimensions, if set.
    permutations: Option<Vec<usize>>,
}
146
147impl Serialize for FixedShapeTensorMetadata {
148    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
149    where
150        S: Serializer,
151    {
152        let mut state = serializer.serialize_struct("FixedShapeTensorMetadata", 3)?;
153        state.serialize_field("shape", &self.shape)?;
154        state.serialize_field("dim_names", &self.dim_names)?;
155        state.serialize_field("permutations", &self.permutations)?;
156        state.end()
157    }
158}
159
/// Identifies one of the fields of [`FixedShapeTensorMetadata`] while the
/// metadata is being deserialized from a map.
#[derive(Debug)]
enum MetadataField {
    /// The `shape` field.
    Shape,
    /// The `dim_names` field.
    DimNames,
    /// The field holding the ordering of the dimensions.
    Permutations,
}
166
167struct MetadataFieldVisitor;
168
169impl<'de> Visitor<'de> for MetadataFieldVisitor {
170    type Value = MetadataField;
171
172    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
173        formatter.write_str("`shape`, `dim_names`, or `permutations`")
174    }
175
176    fn visit_str<E>(self, value: &str) -> Result<MetadataField, E>
177    where
178        E: de::Error,
179    {
180        match value {
181            "shape" => Ok(MetadataField::Shape),
182            "dim_names" => Ok(MetadataField::DimNames),
183            "permutations" => Ok(MetadataField::Permutations),
184            _ => Err(de::Error::unknown_field(
185                value,
186                &["shape", "dim_names", "permutations"],
187            )),
188        }
189    }
190}
191
192impl<'de> Deserialize<'de> for MetadataField {
193    fn deserialize<D>(deserializer: D) -> Result<MetadataField, D::Error>
194    where
195        D: Deserializer<'de>,
196    {
197        deserializer.deserialize_identifier(MetadataFieldVisitor)
198    }
199}
200
201struct FixedShapeTensorMetadataVisitor;
202
203impl<'de> Visitor<'de> for FixedShapeTensorMetadataVisitor {
204    type Value = FixedShapeTensorMetadata;
205
206    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
207        formatter.write_str("struct FixedShapeTensorMetadata")
208    }
209
210    fn visit_seq<V>(self, mut seq: V) -> Result<FixedShapeTensorMetadata, V::Error>
211    where
212        V: de::SeqAccess<'de>,
213    {
214        let shape = seq
215            .next_element()?
216            .ok_or_else(|| de::Error::invalid_length(0, &self))?;
217        let dim_names = seq
218            .next_element()?
219            .ok_or_else(|| de::Error::invalid_length(1, &self))?;
220        let permutations = seq
221            .next_element()?
222            .ok_or_else(|| de::Error::invalid_length(2, &self))?;
223        Ok(FixedShapeTensorMetadata {
224            shape,
225            dim_names,
226            permutations,
227        })
228    }
229
230    fn visit_map<V>(self, mut map: V) -> Result<FixedShapeTensorMetadata, V::Error>
231    where
232        V: MapAccess<'de>,
233    {
234        let mut shape = None;
235        let mut dim_names = None;
236        let mut permutations = None;
237
238        while let Some(key) = map.next_key()? {
239            match key {
240                MetadataField::Shape => {
241                    if shape.is_some() {
242                        return Err(de::Error::duplicate_field("shape"));
243                    }
244                    shape = Some(map.next_value()?);
245                }
246                MetadataField::DimNames => {
247                    if dim_names.is_some() {
248                        return Err(de::Error::duplicate_field("dim_names"));
249                    }
250                    dim_names = Some(map.next_value()?);
251                }
252                MetadataField::Permutations => {
253                    if permutations.is_some() {
254                        return Err(de::Error::duplicate_field("permutations"));
255                    }
256                    permutations = Some(map.next_value()?);
257                }
258            }
259        }
260
261        let shape = shape.ok_or_else(|| de::Error::missing_field("shape"))?;
262
263        Ok(FixedShapeTensorMetadata {
264            shape,
265            dim_names,
266            permutations,
267        })
268    }
269}
270
271impl<'de> Deserialize<'de> for FixedShapeTensorMetadata {
272    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
273    where
274        D: Deserializer<'de>,
275    {
276        deserializer.deserialize_struct(
277            "FixedShapeTensorMetadata",
278            &["shape", "dim_names", "permutations"],
279            FixedShapeTensorMetadataVisitor,
280        )
281    }
282}
283
284impl FixedShapeTensorMetadata {
285    /// Returns metadata for a fixed shape tensor extension type.
286    ///
287    /// # Error
288    ///
289    /// Return an error if the provided dimension names or permutations are
290    /// invalid.
291    pub fn try_new(
292        shape: impl IntoIterator<Item = usize>,
293        dimension_names: Option<Vec<String>>,
294        permutations: Option<Vec<usize>>,
295    ) -> Result<Self, ArrowError> {
296        let shape = shape.into_iter().collect::<Vec<_>>();
297        let dimensions = shape.len();
298
299        let dim_names = dimension_names.map(|dimension_names| {
300            if dimension_names.len() != dimensions {
301                Err(ArrowError::InvalidArgumentError(format!(
302                    "FixedShapeTensor dimension names size mismatch, expected {dimensions}, found {}", dimension_names.len()
303                )))
304            } else {
305                Ok(dimension_names)
306            }
307        }).transpose()?;
308
309        let permutations = permutations
310            .map(|permutations| {
311                if permutations.len() != dimensions {
312                    Err(ArrowError::InvalidArgumentError(format!(
313                        "FixedShapeTensor permutations size mismatch, expected {dimensions}, found {}",
314                        permutations.len()
315                    )))
316                } else {
317                    let mut sorted_permutations = permutations.clone();
318                    sorted_permutations.sort_unstable();
319                    if (0..dimensions).zip(sorted_permutations).any(|(a, b)| a != b) {
320                        Err(ArrowError::InvalidArgumentError(format!(
321                            "FixedShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: {dimensions}"
322                        )))
323                    } else {
324                        Ok(permutations)
325                    }
326                }
327            })
328            .transpose()?;
329
330        Ok(Self {
331            shape,
332            dim_names,
333            permutations,
334        })
335    }
336
337    /// Returns the product of all the elements in tensor shape.
338    pub fn list_size(&self) -> usize {
339        self.shape.iter().product()
340    }
341
342    /// Returns the number of dimensions in this fixed shape tensor.
343    pub fn dimensions(&self) -> usize {
344        self.shape.len()
345    }
346
347    /// Returns the names of the dimensions in this fixed shape tensor, if
348    /// set.
349    pub fn dimension_names(&self) -> Option<&[String]> {
350        self.dim_names.as_ref().map(AsRef::as_ref)
351    }
352
353    /// Returns the indices of the desired ordering of the original
354    /// dimensions, if set.
355    pub fn permutations(&self) -> Option<&[usize]> {
356        self.permutations.as_ref().map(AsRef::as_ref)
357    }
358}
359
360impl ExtensionType for FixedShapeTensor {
361    const NAME: &'static str = "arrow.fixed_shape_tensor";
362
363    type Metadata = FixedShapeTensorMetadata;
364
365    fn metadata(&self) -> &Self::Metadata {
366        &self.metadata
367    }
368
369    fn serialize_metadata(&self) -> Option<String> {
370        Some(serde_json::to_string(&self.metadata).expect("metadata serialization"))
371    }
372
373    fn deserialize_metadata(metadata: Option<&str>) -> Result<Self::Metadata, ArrowError> {
374        metadata.map_or_else(
375            || {
376                Err(ArrowError::InvalidArgumentError(
377                    "FixedShapeTensor extension types requires metadata".to_owned(),
378                ))
379            },
380            |value| {
381                serde_json::from_str(value).map_err(|e| {
382                    ArrowError::InvalidArgumentError(format!(
383                        "FixedShapeTensor metadata deserialization failed: {e}"
384                    ))
385                })
386            },
387        )
388    }
389
390    fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> {
391        let expected = DataType::new_fixed_size_list(
392            self.value_type.clone(),
393            i32::try_from(self.list_size()).expect("overflow"),
394            false,
395        );
396        data_type
397            .equals_datatype(&expected)
398            .then_some(())
399            .ok_or_else(|| {
400                ArrowError::InvalidArgumentError(format!(
401                    "FixedShapeTensor data type mismatch, expected {expected}, found {data_type}"
402                ))
403            })
404    }
405
406    fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result<Self, ArrowError> {
407        match data_type {
408            DataType::FixedSizeList(field, list_size) if !field.is_nullable() => {
409                // Make sure the metadata is valid.
410                let metadata = FixedShapeTensorMetadata::try_new(
411                    metadata.shape,
412                    metadata.dim_names,
413                    metadata.permutations,
414                )?;
415                // Make sure it is compatible with this data type.
416                let expected_size = i32::try_from(metadata.list_size()).expect("overflow");
417                if *list_size != expected_size {
418                    Err(ArrowError::InvalidArgumentError(format!(
419                        "FixedShapeTensor list size mismatch, expected {expected_size} (metadata), found {list_size} (data type)"
420                    )))
421                } else {
422                    Ok(Self {
423                        value_type: field.data_type().clone(),
424                        metadata,
425                    })
426                }
427            }
428            data_type => Err(ArrowError::InvalidArgumentError(format!(
429                "FixedShapeTensor data type mismatch, expected FixedSizeList with non-nullable field, found {data_type}"
430            ))),
431        }
432    }
433}
434
#[cfg(test)]
mod tests {
    #[cfg(feature = "canonical_extension_types")]
    use crate::extension::CanonicalExtensionType;
    use crate::{
        Field,
        extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY},
    };

    use super::*;

    /// A tensor with dimension names and a permutation can be attached to a
    /// matching field and read back unchanged.
    #[test]
    fn valid() -> Result<(), ArrowError> {
        let fixed_shape_tensor = FixedShapeTensor::try_new(
            DataType::Float32,
            [100, 200, 500],
            Some(vec!["C".to_owned(), "H".to_owned(), "W".to_owned()]),
            Some(vec![2, 0, 1]),
        )?;
        let mut field = Field::new_fixed_size_list(
            "",
            Field::new("", DataType::Float32, false),
            i32::try_from(fixed_shape_tensor.list_size()).expect("overflow"),
            false,
        );
        field.try_with_extension_type(fixed_shape_tensor.clone())?;
        assert_eq!(
            field.try_extension_type::<FixedShapeTensor>()?,
            fixed_shape_tensor
        );
        #[cfg(feature = "canonical_extension_types")]
        assert_eq!(
            field.try_canonical_extension_type()?,
            CanonicalExtensionType::FixedShapeTensor(fixed_shape_tensor)
        );
        Ok(())
    }

    /// Metadata that only contains the required `shape` key deserializes
    /// with both optional fields unset.
    #[test]
    fn metadata_without_optional_keys() -> Result<(), ArrowError> {
        let metadata = FixedShapeTensor::deserialize_metadata(Some(r#"{ "shape": [2, 5] }"#))?;
        assert_eq!(
            metadata,
            FixedShapeTensorMetadata::try_new([2, 5], None, None)?
        );
        assert_eq!(metadata.list_size(), 10);
        assert_eq!(metadata.dimensions(), 2);
        Ok(())
    }

    /// A field that carries extension metadata but no extension name is
    /// rejected.
    #[test]
    #[should_panic(expected = "Field extension type name missing")]
    fn missing_name() {
        let field =
            Field::new_fixed_size_list("", Field::new("", DataType::Float32, false), 3, false)
                .with_metadata(
                    [(
                        EXTENSION_TYPE_METADATA_KEY.to_owned(),
                        // Valid JSON, so the missing name is the only defect
                        // in this field's extension metadata.
                        r#"{ "shape": [100, 200, 500] }"#.to_owned(),
                    )]
                    .into_iter()
                    .collect(),
                );
        field.extension_type::<FixedShapeTensor>();
    }

    /// The value type of the tensor (`Int32`) must match the element type of
    /// the storage field (`Float32`).
    #[test]
    #[should_panic(expected = "FixedShapeTensor data type mismatch, expected FixedSizeList")]
    fn invalid_type() {
        let fixed_shape_tensor =
            FixedShapeTensor::try_new(DataType::Int32, [100, 200, 500], None, None).unwrap();
        let field = Field::new_fixed_size_list(
            "",
            Field::new("", DataType::Float32, false),
            i32::try_from(fixed_shape_tensor.list_size()).expect("overflow"),
            false,
        );
        field.with_extension_type(fixed_shape_tensor);
    }

    /// A field that names the extension type but has no extension metadata
    /// is rejected.
    #[test]
    #[should_panic(expected = "FixedShapeTensor extension types requires metadata")]
    fn missing_metadata() {
        let field =
            Field::new_fixed_size_list("", Field::new("", DataType::Float32, false), 3, false)
                .with_metadata(
                    [(
                        EXTENSION_TYPE_NAME_KEY.to_owned(),
                        FixedShapeTensor::NAME.to_owned(),
                    )]
                    .into_iter()
                    .collect(),
                );
        field.extension_type::<FixedShapeTensor>();
    }

    /// Metadata with an unknown key fails to deserialize.
    #[test]
    #[should_panic(expected = "FixedShapeTensor metadata deserialization failed: \
        unknown field `not-shape`, expected one of `shape`, `dim_names`, `permutations`")]
    fn invalid_metadata() {
        let fixed_shape_tensor =
            FixedShapeTensor::try_new(DataType::Float32, [100, 200, 500], None, None).unwrap();
        let field = Field::new_fixed_size_list(
            "",
            Field::new("", DataType::Float32, false),
            i32::try_from(fixed_shape_tensor.list_size()).expect("overflow"),
            false,
        )
        .with_metadata(
            [
                (
                    EXTENSION_TYPE_NAME_KEY.to_owned(),
                    FixedShapeTensor::NAME.to_owned(),
                ),
                (
                    EXTENSION_TYPE_METADATA_KEY.to_owned(),
                    r#"{ "not-shape": [] }"#.to_owned(),
                ),
            ]
            .into_iter()
            .collect(),
        );
        field.extension_type::<FixedShapeTensor>();
    }

    /// The number of dimension names must equal the number of dimensions.
    #[test]
    #[should_panic(
        expected = "FixedShapeTensor dimension names size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_dimension_names() {
        FixedShapeTensor::try_new(
            DataType::Float32,
            [100, 200, 500],
            Some(vec!["a".to_owned(), "b".to_owned()]),
            None,
        )
        .unwrap();
    }

    /// The number of permutation indices must equal the number of
    /// dimensions.
    #[test]
    #[should_panic(expected = "FixedShapeTensor permutations size mismatch, expected 3, found 2")]
    fn invalid_metadata_permutations_len() {
        FixedShapeTensor::try_new(DataType::Float32, [100, 200, 500], None, Some(vec![1, 0]))
            .unwrap();
    }

    /// The permutation indices must be a permutation of `[0, 1, .., N-1]`.
    #[test]
    #[should_panic(
        expected = "FixedShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: 3"
    )]
    fn invalid_metadata_permutations_values() {
        FixedShapeTensor::try_new(
            DataType::Float32,
            [100, 200, 500],
            None,
            Some(vec![4, 3, 2]),
        )
        .unwrap();
    }
}
582}