arrow_schema/extension/canonical/
fixed_shape_tensor.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! FixedShapeTensor
19//!
20//! <https://arrow.apache.org/docs/format/CanonicalExtensions.html#fixed-shape-tensor>
21
22use serde::{Deserialize, Serialize};
23
24use crate::{extension::ExtensionType, ArrowError, DataType};
25
26/// The extension type for fixed shape tensor.
27///
28/// Extension name: `arrow.fixed_shape_tensor`.
29///
30/// The storage type of the extension: `FixedSizeList` where:
31/// - `value_type` is the data type of individual tensor elements.
32/// - `list_size` is the product of all the elements in tensor shape.
33///
34/// Extension type parameters:
35/// - `value_type`: the Arrow data type of individual tensor elements.
36/// - `shape`: the physical shape of the contained tensors as an array.
37///
38/// Optional parameters describing the logical layout:
39/// - `dim_names`: explicit names to tensor dimensions as an array. The
40///   length of it should be equal to the shape length and equal to the
41///   number of dimensions.
42///   `dim_names` can be used if the dimensions have
43///   well-known names and they map to the physical layout (row-major).
44/// - `permutation`: indices of the desired ordering of the original
45///   dimensions, defined as an array.
46///   The indices contain a permutation of the values `[0, 1, .., N-1]`
47///   where `N` is the number of dimensions. The permutation indicates
48///   which dimension of the logical layout corresponds to which dimension
49///   of the physical tensor (the i-th dimension of the logical view
50///   corresponds to the dimension with number `permutations[i]` of the
51///   physical tensor).
52///   Permutation can be useful in case the logical order of the tensor is
53///   a permutation of the physical order (row-major).
54///   When logical and physical layout are equal, the permutation will
55///   always be `([0, 1, .., N-1])` and can therefore be left out.
56///
57/// Description of the serialization:
58/// The metadata must be a valid JSON object including shape of the
59/// contained tensors as an array with key `shape` plus optional
60/// dimension names with keys `dim_names` and ordering of the
61/// dimensions with key `permutation`.
62/// Example: `{ "shape": [2, 5]}`
63/// Example with `dim_names` metadata for NCHW ordered data:
64/// `{ "shape": [100, 200, 500], "dim_names": ["C", "H", "W"]}`
65/// Example of permuted 3-dimensional tensor:
66/// `{ "shape": [100, 200, 500], "permutation": [2, 0, 1]}`
67///
68/// This is the physical layout shape and the shape of the logical layout
69/// would in this case be `[500, 100, 200]`.
70///
71/// <https://arrow.apache.org/docs/format/CanonicalExtensions.html#fixed-shape-tensor>
72#[derive(Debug, Clone, PartialEq)]
73pub struct FixedShapeTensor {
74    /// The data type of individual tensor elements.
75    value_type: DataType,
76
77    /// The metadata of this extension type.
78    metadata: FixedShapeTensorMetadata,
79}
80
81impl FixedShapeTensor {
82    /// Returns a new fixed shape tensor extension type.
83    ///
84    /// # Error
85    ///
86    /// Return an error if the provided dimension names or permutations are
87    /// invalid.
88    pub fn try_new(
89        value_type: DataType,
90        shape: impl IntoIterator<Item = usize>,
91        dimension_names: Option<Vec<String>>,
92        permutations: Option<Vec<usize>>,
93    ) -> Result<Self, ArrowError> {
94        // TODO: are all data types are suitable as value type?
95        FixedShapeTensorMetadata::try_new(shape, dimension_names, permutations).map(|metadata| {
96            Self {
97                value_type,
98                metadata,
99            }
100        })
101    }
102
103    /// Returns the value type of the individual tensor elements.
104    pub fn value_type(&self) -> &DataType {
105        &self.value_type
106    }
107
108    /// Returns the product of all the elements in tensor shape.
109    pub fn list_size(&self) -> usize {
110        self.metadata.list_size()
111    }
112
113    /// Returns the number of dimensions in this fixed shape tensor.
114    pub fn dimensions(&self) -> usize {
115        self.metadata.dimensions()
116    }
117
118    /// Returns the names of the dimensions in this fixed shape tensor, if
119    /// set.
120    pub fn dimension_names(&self) -> Option<&[String]> {
121        self.metadata.dimension_names()
122    }
123
124    /// Returns the indices of the desired ordering of the original
125    /// dimensions, if set.
126    pub fn permutations(&self) -> Option<&[usize]> {
127        self.metadata.permutations()
128    }
129}
130
131/// Extension type metadata for [`FixedShapeTensor`].
132#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
133pub struct FixedShapeTensorMetadata {
134    /// The physical shape of the contained tensors.
135    shape: Vec<usize>,
136
137    /// Explicit names to tensor dimensions.
138    dim_names: Option<Vec<String>>,
139
140    /// Indices of the desired ordering of the original dimensions.
141    permutations: Option<Vec<usize>>,
142}
143
144impl FixedShapeTensorMetadata {
145    /// Returns metadata for a fixed shape tensor extension type.
146    ///
147    /// # Error
148    ///
149    /// Return an error if the provided dimension names or permutations are
150    /// invalid.
151    pub fn try_new(
152        shape: impl IntoIterator<Item = usize>,
153        dimension_names: Option<Vec<String>>,
154        permutations: Option<Vec<usize>>,
155    ) -> Result<Self, ArrowError> {
156        let shape = shape.into_iter().collect::<Vec<_>>();
157        let dimensions = shape.len();
158
159        let dim_names = dimension_names.map(|dimension_names| {
160            if dimension_names.len() != dimensions {
161                Err(ArrowError::InvalidArgumentError(format!(
162                    "FixedShapeTensor dimension names size mismatch, expected {dimensions}, found {}", dimension_names.len()
163                )))
164            } else {
165                Ok(dimension_names)
166            }
167        }).transpose()?;
168
169        let permutations = permutations
170            .map(|permutations| {
171                if permutations.len() != dimensions {
172                    Err(ArrowError::InvalidArgumentError(format!(
173                        "FixedShapeTensor permutations size mismatch, expected {dimensions}, found {}",
174                        permutations.len()
175                    )))
176                } else {
177                    let mut sorted_permutations = permutations.clone();
178                    sorted_permutations.sort_unstable();
179                    if (0..dimensions).zip(sorted_permutations).any(|(a, b)| a != b) {
180                        Err(ArrowError::InvalidArgumentError(format!(
181                            "FixedShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: {dimensions}"
182                        )))
183                    } else {
184                        Ok(permutations)
185                    }
186                }
187            })
188            .transpose()?;
189
190        Ok(Self {
191            shape,
192            dim_names,
193            permutations,
194        })
195    }
196
197    /// Returns the product of all the elements in tensor shape.
198    pub fn list_size(&self) -> usize {
199        self.shape.iter().product()
200    }
201
202    /// Returns the number of dimensions in this fixed shape tensor.
203    pub fn dimensions(&self) -> usize {
204        self.shape.len()
205    }
206
207    /// Returns the names of the dimensions in this fixed shape tensor, if
208    /// set.
209    pub fn dimension_names(&self) -> Option<&[String]> {
210        self.dim_names.as_ref().map(AsRef::as_ref)
211    }
212
213    /// Returns the indices of the desired ordering of the original
214    /// dimensions, if set.
215    pub fn permutations(&self) -> Option<&[usize]> {
216        self.permutations.as_ref().map(AsRef::as_ref)
217    }
218}
219
220impl ExtensionType for FixedShapeTensor {
221    const NAME: &'static str = "arrow.fixed_shape_tensor";
222
223    type Metadata = FixedShapeTensorMetadata;
224
225    fn metadata(&self) -> &Self::Metadata {
226        &self.metadata
227    }
228
229    fn serialize_metadata(&self) -> Option<String> {
230        Some(serde_json::to_string(&self.metadata).expect("metadata serialization"))
231    }
232
233    fn deserialize_metadata(metadata: Option<&str>) -> Result<Self::Metadata, ArrowError> {
234        metadata.map_or_else(
235            || {
236                Err(ArrowError::InvalidArgumentError(
237                    "FixedShapeTensor extension types requires metadata".to_owned(),
238                ))
239            },
240            |value| {
241                serde_json::from_str(value).map_err(|e| {
242                    ArrowError::InvalidArgumentError(format!(
243                        "FixedShapeTensor metadata deserialization failed: {e}"
244                    ))
245                })
246            },
247        )
248    }
249
250    fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> {
251        let expected = DataType::new_fixed_size_list(
252            self.value_type.clone(),
253            i32::try_from(self.list_size()).expect("overflow"),
254            false,
255        );
256        data_type
257            .equals_datatype(&expected)
258            .then_some(())
259            .ok_or_else(|| {
260                ArrowError::InvalidArgumentError(format!(
261                    "FixedShapeTensor data type mismatch, expected {expected}, found {data_type}"
262                ))
263            })
264    }
265
266    fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result<Self, ArrowError> {
267        match data_type {
268            DataType::FixedSizeList(field, list_size) if !field.is_nullable() => {
269                // Make sure the metadata is valid.
270                let metadata = FixedShapeTensorMetadata::try_new(
271                    metadata.shape,
272                    metadata.dim_names,
273                    metadata.permutations,
274                )?;
275                // Make sure it is compatible with this data type.
276                let expected_size = i32::try_from(metadata.list_size()).expect("overflow");
277                if *list_size != expected_size {
278                    Err(ArrowError::InvalidArgumentError(format!(
279                        "FixedShapeTensor list size mismatch, expected {expected_size} (metadata), found {list_size} (data type)"
280                    )))
281                } else {
282                    Ok(Self {
283                        value_type: field.data_type().clone(),
284                        metadata,
285                    })
286                }
287            }
288            data_type => Err(ArrowError::InvalidArgumentError(format!(
289                "FixedShapeTensor data type mismatch, expected FixedSizeList with non-nullable field, found {data_type}"
290            ))),
291        }
292    }
293}
294
#[cfg(test)]
mod tests {
    #[cfg(feature = "canonical_extension_types")]
    use crate::extension::CanonicalExtensionType;
    use crate::{
        extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY},
        Field,
    };

    use super::*;

    #[test]
    fn valid() -> Result<(), ArrowError> {
        let fixed_shape_tensor = FixedShapeTensor::try_new(
            DataType::Float32,
            [100, 200, 500],
            Some(vec!["C".to_owned(), "H".to_owned(), "W".to_owned()]),
            Some(vec![2, 0, 1]),
        )?;
        let mut field = Field::new_fixed_size_list(
            "",
            Field::new("", DataType::Float32, false),
            i32::try_from(fixed_shape_tensor.list_size()).expect("overflow"),
            false,
        );
        field.try_with_extension_type(fixed_shape_tensor.clone())?;
        assert_eq!(
            field.try_extension_type::<FixedShapeTensor>()?,
            fixed_shape_tensor
        );
        #[cfg(feature = "canonical_extension_types")]
        assert_eq!(
            field.try_canonical_extension_type()?,
            CanonicalExtensionType::FixedShapeTensor(fixed_shape_tensor)
        );
        Ok(())
    }

    #[test]
    fn accessors() -> Result<(), ArrowError> {
        let names = vec!["a".to_owned(), "b".to_owned(), "c".to_owned()];
        let fixed_shape_tensor = FixedShapeTensor::try_new(
            DataType::Int8,
            [2, 3, 4],
            Some(names.clone()),
            Some(vec![1, 2, 0]),
        )?;
        assert_eq!(fixed_shape_tensor.value_type(), &DataType::Int8);
        assert_eq!(fixed_shape_tensor.list_size(), 24);
        assert_eq!(fixed_shape_tensor.dimensions(), 3);
        assert_eq!(
            fixed_shape_tensor.dimension_names(),
            Some(names.as_slice())
        );
        assert_eq!(
            fixed_shape_tensor.permutations(),
            Some([1, 2, 0].as_slice())
        );
        Ok(())
    }

    #[test]
    fn deserialize_metadata_without_optional_fields() -> Result<(), ArrowError> {
        let metadata = FixedShapeTensor::deserialize_metadata(Some(r#"{ "shape": [2, 5] }"#))?;
        assert_eq!(metadata.list_size(), 10);
        assert_eq!(metadata.dimensions(), 2);
        assert_eq!(metadata.dimension_names(), None);
        assert_eq!(metadata.permutations(), None);

        let fixed_shape_tensor = <FixedShapeTensor as ExtensionType>::try_new(
            &DataType::new_fixed_size_list(DataType::Float32, 10, false),
            metadata,
        )?;
        assert_eq!(fixed_shape_tensor.value_type(), &DataType::Float32);
        assert_eq!(fixed_shape_tensor.list_size(), 10);
        Ok(())
    }

    #[test]
    #[should_panic(expected = "Field extension type name missing")]
    fn missing_name() {
        let field =
            Field::new_fixed_size_list("", Field::new("", DataType::Float32, false), 3, false)
                .with_metadata(
                    [(
                        EXTENSION_TYPE_METADATA_KEY.to_owned(),
                        r#"{ "shape": [100, 200, 500], }"#.to_owned(),
                    )]
                    .into_iter()
                    .collect(),
                );
        field.extension_type::<FixedShapeTensor>();
    }

    #[test]
    #[should_panic(expected = "FixedShapeTensor data type mismatch, expected FixedSizeList")]
    fn invalid_type() {
        let fixed_shape_tensor =
            FixedShapeTensor::try_new(DataType::Int32, [100, 200, 500], None, None).unwrap();
        let field = Field::new_fixed_size_list(
            "",
            Field::new("", DataType::Float32, false),
            i32::try_from(fixed_shape_tensor.list_size()).expect("overflow"),
            false,
        );
        field.with_extension_type(fixed_shape_tensor);
    }

    #[test]
    #[should_panic(
        expected = "FixedShapeTensor list size mismatch, expected 10000000 (metadata), found 3 (data type)"
    )]
    fn list_size_mismatch() {
        let field =
            Field::new_fixed_size_list("", Field::new("", DataType::Float32, false), 3, false)
                .with_metadata(
                    [
                        (
                            EXTENSION_TYPE_NAME_KEY.to_owned(),
                            FixedShapeTensor::NAME.to_owned(),
                        ),
                        (
                            EXTENSION_TYPE_METADATA_KEY.to_owned(),
                            r#"{ "shape": [100, 200, 500] }"#.to_owned(),
                        ),
                    ]
                    .into_iter()
                    .collect(),
                );
        field.extension_type::<FixedShapeTensor>();
    }

    #[test]
    #[should_panic(expected = "FixedShapeTensor extension types requires metadata")]
    fn missing_metadata() {
        let field =
            Field::new_fixed_size_list("", Field::new("", DataType::Float32, false), 3, false)
                .with_metadata(
                    [(
                        EXTENSION_TYPE_NAME_KEY.to_owned(),
                        FixedShapeTensor::NAME.to_owned(),
                    )]
                    .into_iter()
                    .collect(),
                );
        field.extension_type::<FixedShapeTensor>();
    }

    #[test]
    #[should_panic(
        expected = "FixedShapeTensor metadata deserialization failed: missing field `shape`"
    )]
    fn invalid_metadata() {
        let fixed_shape_tensor =
            FixedShapeTensor::try_new(DataType::Float32, [100, 200, 500], None, None).unwrap();
        let field = Field::new_fixed_size_list(
            "",
            Field::new("", DataType::Float32, false),
            i32::try_from(fixed_shape_tensor.list_size()).expect("overflow"),
            false,
        )
        .with_metadata(
            [
                (
                    EXTENSION_TYPE_NAME_KEY.to_owned(),
                    FixedShapeTensor::NAME.to_owned(),
                ),
                (
                    EXTENSION_TYPE_METADATA_KEY.to_owned(),
                    r#"{ "not-shape": [] }"#.to_owned(),
                ),
            ]
            .into_iter()
            .collect(),
        );
        field.extension_type::<FixedShapeTensor>();
    }

    #[test]
    #[should_panic(
        expected = "FixedShapeTensor dimension names size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_dimension_names() {
        FixedShapeTensor::try_new(
            DataType::Float32,
            [100, 200, 500],
            Some(vec!["a".to_owned(), "b".to_owned()]),
            None,
        )
        .unwrap();
    }

    #[test]
    #[should_panic(expected = "FixedShapeTensor permutations size mismatch, expected 3, found 2")]
    fn invalid_metadata_permutations_len() {
        FixedShapeTensor::try_new(DataType::Float32, [100, 200, 500], None, Some(vec![1, 0]))
            .unwrap();
    }

    #[test]
    #[should_panic(
        expected = "FixedShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: 3"
    )]
    fn invalid_metadata_permutations_values() {
        FixedShapeTensor::try_new(
            DataType::Float32,
            [100, 200, 500],
            None,
            Some(vec![4, 3, 2]),
        )
        .unwrap();
    }

    #[test]
    #[should_panic(
        expected = "FixedShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: 3"
    )]
    fn invalid_metadata_permutations_duplicates() {
        // In range, but index 0 is repeated and index 2 is missing.
        FixedShapeTensor::try_new(
            DataType::Float32,
            [100, 200, 500],
            None,
            Some(vec![0, 0, 1]),
        )
        .unwrap();
    }
}