arrow_schema/extension/canonical/
variable_shape_tensor.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! VariableShapeTensor
19//!
20//! <https://arrow.apache.org/docs/format/CanonicalExtensions.html#variable-shape-tensor>
21
22use serde::{Deserialize, Serialize};
23
24use crate::{ArrowError, DataType, Field, extension::ExtensionType};
25
/// The extension type for `VariableShapeTensor`.
///
/// Extension name: `arrow.variable_shape_tensor`.
///
/// The storage type of the extension is a `StructArray` composed of a `data`
/// and a `shape` field, describing a single tensor per row:
/// - `data` is a `List` holding the tensor elements (each list element is a
///   single tensor). The list's value type is the value type of the tensor,
///   such as an integer or floating-point type.
/// - `shape` is a `FixedSizeList<int32>[ndim]` holding the tensor shape,
///   where the list size `ndim` equals the number of dimensions of the
///   tensor.
///
/// Extension type parameters:
/// - `value_type`: the Arrow data type of individual tensor elements.
///
/// Optional parameters describing the logical layout:
/// - `dim_names`: explicit names of the tensor dimensions, as an array. Its
///   length should equal the shape length, i.e. the number of dimensions.
///   `dim_names` can be used if the dimensions have well-known names and they
///   map to the physical layout (row-major).
/// - `permutation`: indices of the desired ordering of the original
///   dimensions, defined as an array.
///   The indices contain a permutation of the values `[0, 1, .., N-1]` where
///   `N` is the number of dimensions. The permutation indicates which
///   dimension of the logical layout corresponds to which dimension of the
///   physical tensor (the i-th dimension of the logical view corresponds to
///   the dimension with number `permutation[i]` of the physical tensor).
///   A permutation is useful when the logical order of the tensor is a
///   permutation of the physical order (row-major).
///   When the logical and physical layouts are equal, the permutation is
///   always `[0, 1, .., N-1]` and can therefore be left out.
/// - `uniform_shape`: sizes of the individual tensors' dimensions, which are
///   guaranteed to stay constant in uniform dimensions and can vary in
///   non-uniform dimensions. This holds over all tensors in the array. Sizes
///   in uniform dimensions are represented with int32 values, while sizes of
///   the non-uniform dimensions are not known in advance and are represented
///   with null. If `uniform_shape` is not provided, all dimensions are
///   assumed to be non-uniform. An array containing a tensor with shape
///   (2, 3, 4) whose first and last dimensions are uniform would have
///   `uniform_shape` (2, null, 4). This allows for interpreting the tensor
///   correctly without accounting for uniform dimensions, while still
///   permitting optional optimizations that take advantage of the uniformity.
///
/// <https://arrow.apache.org/docs/format/CanonicalExtensions.html#variable-shape-tensor>
#[derive(Debug, Clone, PartialEq)]
pub struct VariableShapeTensor {
    /// The data type of individual tensor elements (the item type of the
    /// `data` list in the storage type).
    value_type: DataType,

    /// The number of dimensions of the tensor, i.e. the list size of the
    /// `shape` field in the storage type.
    dimensions: usize,

    /// The optional layout metadata (dimension names, permutation, uniform
    /// shape) of this extension type.
    metadata: VariableShapeTensorMetadata,
}
83
84impl VariableShapeTensor {
85    /// Returns a new variable shape tensor extension type.
86    ///
87    /// # Error
88    ///
89    /// Return an error if the provided dimension names, permutations or
90    /// uniform shapes are invalid.
91    pub fn try_new(
92        value_type: DataType,
93        dimensions: usize,
94        dimension_names: Option<Vec<String>>,
95        permutations: Option<Vec<usize>>,
96        uniform_shapes: Option<Vec<Option<i32>>>,
97    ) -> Result<Self, ArrowError> {
98        // TODO: are all data types are suitable as value type?
99        VariableShapeTensorMetadata::try_new(
100            dimensions,
101            dimension_names,
102            permutations,
103            uniform_shapes,
104        )
105        .map(|metadata| Self {
106            value_type,
107            dimensions,
108            metadata,
109        })
110    }
111
112    /// Returns the value type of the individual tensor elements.
113    pub fn value_type(&self) -> &DataType {
114        &self.value_type
115    }
116
117    /// Returns the number of dimensions  in this variable shape tensor.
118    pub fn dimensions(&self) -> usize {
119        self.dimensions
120    }
121
122    /// Returns the names of the dimensions in this variable shape tensor, if
123    /// set.
124    pub fn dimension_names(&self) -> Option<&[String]> {
125        self.metadata.dimension_names()
126    }
127
128    /// Returns the indices of the desired ordering of the original
129    /// dimensions, if set.
130    pub fn permutations(&self) -> Option<&[usize]> {
131        self.metadata.permutations()
132    }
133
134    /// Returns sizes of individual tensor’s dimensions which are guaranteed
135    /// to stay constant in uniform dimensions and can vary in non-uniform
136    /// dimensions.
137    pub fn uniform_shapes(&self) -> Option<&[Option<i32>]> {
138        self.metadata.uniform_shapes()
139    }
140}
141
142/// Extension type metadata for [`VariableShapeTensor`].
143#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
144pub struct VariableShapeTensorMetadata {
145    /// Explicit names to tensor dimensions.
146    dim_names: Option<Vec<String>>,
147
148    /// Indices of the desired ordering of the original dimensions.
149    permutations: Option<Vec<usize>>,
150
151    /// Sizes of individual tensor’s dimensions which are guaranteed to stay
152    /// constant in uniform dimensions and can vary in non-uniform dimensions.
153    uniform_shape: Option<Vec<Option<i32>>>,
154}
155
156impl VariableShapeTensorMetadata {
157    /// Returns metadata for a variable shape tensor extension type.
158    ///
159    /// # Error
160    ///
161    /// Return an error if the provided dimension names, permutations or
162    /// uniform shapes are invalid.
163    pub fn try_new(
164        dimensions: usize,
165        dimension_names: Option<Vec<String>>,
166        permutations: Option<Vec<usize>>,
167        uniform_shapes: Option<Vec<Option<i32>>>,
168    ) -> Result<Self, ArrowError> {
169        let dim_names = dimension_names.map(|dimension_names| {
170            if dimension_names.len() != dimensions {
171                Err(ArrowError::InvalidArgumentError(format!(
172                    "VariableShapeTensor dimension names size mismatch, expected {dimensions}, found {}", dimension_names.len()
173                )))
174            } else {
175                Ok(dimension_names)
176            }
177        }).transpose()?;
178
179        let permutations = permutations
180            .map(|permutations| {
181                if permutations.len() != dimensions {
182                    Err(ArrowError::InvalidArgumentError(format!(
183                        "VariableShapeTensor permutations size mismatch, expected {dimensions}, found {}",
184                        permutations.len()
185                    )))
186                } else {
187                    let mut sorted_permutations = permutations.clone();
188                    sorted_permutations.sort_unstable();
189                    if (0..dimensions).zip(sorted_permutations).any(|(a, b)| a != b) {
190                        Err(ArrowError::InvalidArgumentError(format!(
191                            "VariableShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: {dimensions}"
192                        )))
193                    } else {
194                        Ok(permutations)
195                    }
196                }
197            })
198            .transpose()?;
199
200        let uniform_shape = uniform_shapes
201            .map(|uniform_shapes| {
202                if uniform_shapes.len() != dimensions {
203                    Err(ArrowError::InvalidArgumentError(format!(
204                        "VariableShapeTensor uniform shapes size mismatch, expected {dimensions}, found {}",
205                        uniform_shapes.len()
206                    )))
207                } else {
208                    Ok(uniform_shapes)
209                }
210            })
211            .transpose()?;
212
213        Ok(Self {
214            dim_names,
215            permutations,
216            uniform_shape,
217        })
218    }
219
220    /// Returns the names of the dimensions in this variable shape tensor, if
221    /// set.
222    pub fn dimension_names(&self) -> Option<&[String]> {
223        self.dim_names.as_ref().map(AsRef::as_ref)
224    }
225
226    /// Returns the indices of the desired ordering of the original dimensions,
227    /// if set.
228    pub fn permutations(&self) -> Option<&[usize]> {
229        self.permutations.as_ref().map(AsRef::as_ref)
230    }
231
232    /// Returns sizes of individual tensor’s dimensions which are guaranteed
233    /// to stay constant in uniform dimensions and can vary in non-uniform
234    /// dimensions.
235    pub fn uniform_shapes(&self) -> Option<&[Option<i32>]> {
236        self.uniform_shape.as_ref().map(AsRef::as_ref)
237    }
238}
239
240impl ExtensionType for VariableShapeTensor {
241    const NAME: &'static str = "arrow.variable_shape_tensor";
242
243    type Metadata = VariableShapeTensorMetadata;
244
245    fn metadata(&self) -> &Self::Metadata {
246        &self.metadata
247    }
248
249    fn serialize_metadata(&self) -> Option<String> {
250        Some(serde_json::to_string(self.metadata()).expect("metadata serialization"))
251    }
252
253    fn deserialize_metadata(metadata: Option<&str>) -> Result<Self::Metadata, ArrowError> {
254        metadata.map_or_else(
255            || {
256                Err(ArrowError::InvalidArgumentError(
257                    "VariableShapeTensor extension types requires metadata".to_owned(),
258                ))
259            },
260            |value| {
261                serde_json::from_str(value).map_err(|e| {
262                    ArrowError::InvalidArgumentError(format!(
263                        "VariableShapeTensor metadata deserialization failed: {e}"
264                    ))
265                })
266            },
267        )
268    }
269
270    fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> {
271        let expected = DataType::Struct(
272            [
273                Field::new_list(
274                    "data",
275                    Field::new_list_field(self.value_type.clone(), false),
276                    false,
277                ),
278                Field::new(
279                    "shape",
280                    DataType::new_fixed_size_list(
281                        DataType::Int32,
282                        i32::try_from(self.dimensions()).expect("overflow"),
283                        false,
284                    ),
285                    false,
286                ),
287            ]
288            .into_iter()
289            .collect(),
290        );
291        data_type
292            .equals_datatype(&expected)
293            .then_some(())
294            .ok_or_else(|| {
295                ArrowError::InvalidArgumentError(format!(
296                    "VariableShapeTensor data type mismatch, expected {expected}, found {data_type}"
297                ))
298            })
299    }
300
301    fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result<Self, ArrowError> {
302        match data_type {
303            DataType::Struct(fields)
304                if fields.len() == 2
305                    && matches!(fields.find("data"), Some((0, _)))
306                    && matches!(fields.find("shape"), Some((1, _))) =>
307            {
308                let shape_field = &fields[1];
309                match shape_field.data_type() {
310                    DataType::FixedSizeList(_, list_size) => {
311                        let dimensions = usize::try_from(*list_size).expect("conversion failed");
312                        // Make sure the metadata is valid.
313                        let metadata = VariableShapeTensorMetadata::try_new(
314                            dimensions,
315                            metadata.dim_names,
316                            metadata.permutations,
317                            metadata.uniform_shape,
318                        )?;
319                        let data_field = &fields[0];
320                        match data_field.data_type() {
321                            DataType::List(field) => Ok(Self {
322                                value_type: field.data_type().clone(),
323                                dimensions,
324                                metadata,
325                            }),
326                            data_type => Err(ArrowError::InvalidArgumentError(format!(
327                                "VariableShapeTensor data type mismatch, expected List for data field, found {data_type}"
328                            ))),
329                        }
330                    }
331                    data_type => Err(ArrowError::InvalidArgumentError(format!(
332                        "VariableShapeTensor data type mismatch, expected FixedSizeList for shape field, found {data_type}"
333                    ))),
334                }
335            }
336            data_type => Err(ArrowError::InvalidArgumentError(format!(
337                "VariableShapeTensor data type mismatch, expected Struct with 2 fields (data and shape), found {data_type}"
338            ))),
339        }
340    }
341}
342
#[cfg(test)]
mod tests {
    #[cfg(feature = "canonical_extension_types")]
    use crate::extension::CanonicalExtensionType;
    use crate::{
        Field,
        extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY},
    };

    use super::*;

    /// Returns the storage field of a 3-dimensional `Float32` variable shape
    /// tensor (non-nullable `data` list and `shape` list), without any
    /// extension metadata.
    fn float32_storage_field() -> Field {
        Field::new_struct(
            "",
            vec![
                Field::new_list(
                    "data",
                    Field::new_list_field(DataType::Float32, false),
                    false,
                ),
                Field::new_fixed_size_list(
                    "shape",
                    Field::new("", DataType::Int32, false),
                    3,
                    false,
                ),
            ],
            false,
        )
    }

    #[test]
    fn valid() -> Result<(), ArrowError> {
        let variable_shape_tensor = VariableShapeTensor::try_new(
            DataType::Float32,
            3,
            Some(vec!["C".to_owned(), "H".to_owned(), "W".to_owned()]),
            Some(vec![2, 0, 1]),
            Some(vec![Some(400), None, Some(3)]),
        )?;
        let mut field = float32_storage_field();
        field.try_with_extension_type(variable_shape_tensor.clone())?;
        assert_eq!(
            field.try_extension_type::<VariableShapeTensor>()?,
            variable_shape_tensor
        );
        #[cfg(feature = "canonical_extension_types")]
        assert_eq!(
            field.try_canonical_extension_type()?,
            CanonicalExtensionType::VariableShapeTensor(variable_shape_tensor)
        );
        Ok(())
    }

    #[test]
    fn metadata_round_trip_without_optional_parameters() -> Result<(), ArrowError> {
        let variable_shape_tensor =
            VariableShapeTensor::try_new(DataType::Float32, 3, None, None, None)?;
        let serialized = variable_shape_tensor.serialize_metadata();
        let metadata = VariableShapeTensor::deserialize_metadata(serialized.as_deref())?;
        assert_eq!(&metadata, variable_shape_tensor.metadata());
        Ok(())
    }

    #[test]
    fn minimal_metadata() -> Result<(), ArrowError> {
        let metadata = VariableShapeTensor::deserialize_metadata(Some("{}"))?;
        assert_eq!(metadata.dimension_names(), None);
        assert_eq!(metadata.permutations(), None);
        assert_eq!(metadata.uniform_shapes(), None);
        Ok(())
    }

    #[test]
    #[should_panic(expected = "Field extension type name missing")]
    fn missing_name() {
        let field = float32_storage_field().with_metadata(
            [(EXTENSION_TYPE_METADATA_KEY.to_owned(), "{}".to_owned())]
                .into_iter()
                .collect(),
        );
        field.extension_type::<VariableShapeTensor>();
    }

    #[test]
    #[should_panic(expected = "VariableShapeTensor data type mismatch, expected Struct")]
    fn invalid_type() {
        let variable_shape_tensor =
            VariableShapeTensor::try_new(DataType::Int32, 3, None, None, None).unwrap();
        let field = float32_storage_field();
        field.with_extension_type(variable_shape_tensor);
    }

    #[test]
    #[should_panic(expected = "VariableShapeTensor extension types requires metadata")]
    fn missing_metadata() {
        let field = float32_storage_field().with_metadata(
            [(
                EXTENSION_TYPE_NAME_KEY.to_owned(),
                VariableShapeTensor::NAME.to_owned(),
            )]
            .into_iter()
            .collect(),
        );
        field.extension_type::<VariableShapeTensor>();
    }

    #[test]
    #[should_panic(expected = "VariableShapeTensor metadata deserialization failed: invalid type:")]
    fn invalid_metadata() {
        let field = float32_storage_field().with_metadata(
            [
                (
                    EXTENSION_TYPE_NAME_KEY.to_owned(),
                    VariableShapeTensor::NAME.to_owned(),
                ),
                (
                    EXTENSION_TYPE_METADATA_KEY.to_owned(),
                    r#"{ "dim_names": [1, null, 3, 4] }"#.to_owned(),
                ),
            ]
            .into_iter()
            .collect(),
        );
        field.extension_type::<VariableShapeTensor>();
    }

    #[test]
    #[should_panic(
        expected = "VariableShapeTensor dimension names size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_dimension_names() {
        VariableShapeTensor::try_new(
            DataType::Float32,
            3,
            Some(vec!["a".to_owned(), "b".to_owned()]),
            None,
            None,
        )
        .unwrap();
    }

    #[test]
    #[should_panic(
        expected = "VariableShapeTensor permutations size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_permutations_len() {
        VariableShapeTensor::try_new(DataType::Float32, 3, None, Some(vec![1, 0]), None).unwrap();
    }

    #[test]
    #[should_panic(
        expected = "VariableShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: 3"
    )]
    fn invalid_metadata_permutations_values() {
        VariableShapeTensor::try_new(DataType::Float32, 3, None, Some(vec![4, 3, 2]), None)
            .unwrap();
    }

    #[test]
    #[should_panic(
        expected = "VariableShapeTensor uniform shapes size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_uniform_shapes() {
        VariableShapeTensor::try_new(DataType::Float32, 3, None, None, Some(vec![None, Some(1)]))
            .unwrap();
    }
}