arrow_schema/extension/canonical/
variable_shape_tensor.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! VariableShapeTensor
19//!
20//! <https://arrow.apache.org/docs/format/CanonicalExtensions.html#variable-shape-tensor>
21
22use serde::{Deserialize, Serialize};
23
24use crate::{extension::ExtensionType, ArrowError, DataType, Field};
25
26/// The extension type for `VariableShapeTensor`.
27///
28/// Extension name: `arrow.variable_shape_tensor`.
29///
30/// The storage type of the extension is: StructArray where struct is composed
31/// of data and shape fields describing a single tensor per row:
32/// - `data` is a List holding tensor elements (each list element is a single
33///   tensor). The List’s value type is the value type of the tensor, such as
34///   an integer or floating-point type.
35/// - `shape` is a `FixedSizeList<int32>[ndim]` of the tensor shape where the
36///   size of the list `ndim` is equal to the number of dimensions of the
37///   tensor.
38///
39/// Extension type parameters:
40/// `value_type`: the Arrow data type of individual tensor elements.
41///
42/// Optional parameters describing the logical layout:
43/// - `dim_names`: explicit names to tensor dimensions as an array. The length
44///   of it should be equal to the shape length and equal to the number of
45///   dimensions.
46///   `dim_names` can be used if the dimensions have well-known names and they
47///   map to the physical layout (row-major).
48/// - `permutation`: indices of the desired ordering of the original
49///   dimensions, defined as an array.
50///   The indices contain a permutation of the values `[0, 1, .., N-1]` where
51///   `N` is the number of dimensions. The permutation indicates which
52///   dimension of the logical layout corresponds to which dimension of the
53///   physical tensor (the i-th dimension of the logical view corresponds to
54///   the dimension with number `permutations[i]` of the physical tensor).
55///   Permutation can be useful in case the logical order of the tensor is a
56///   permutation of the physical order (row-major).
57///   When logical and physical layout are equal, the permutation will always
58///   be (`[0, 1, .., N-1]`) and can therefore be left out.
59/// - `uniform_shape`: sizes of individual tensor’s dimensions which are
60///   guaranteed to stay constant in uniform dimensions and can vary in non-
61///   uniform dimensions. This holds over all tensors in the array. Sizes in
62///   uniform dimensions are represented with int32 values, while sizes of the
63///   non-uniform dimensions are not known in advance and are represented with
64///   null. If `uniform_shape` is not provided it is assumed that all
65///   dimensions are non-uniform. An array containing a tensor with shape (2,
66///   3, 4) and whose first and last dimensions are uniform would have
67///   `uniform_shape` (2, null, 4). This allows for interpreting the tensor
68///   correctly without accounting for uniform dimensions while still
69///   permitting optional optimizations that take advantage of the uniformity.
70///
71/// <https://arrow.apache.org/docs/format/CanonicalExtensions.html#variable-shape-tensor>
72#[derive(Debug, Clone, PartialEq)]
73pub struct VariableShapeTensor {
74    /// The data type of individual tensor elements.
75    value_type: DataType,
76
77    /// The number of dimensions of the tensor.
78    dimensions: usize,
79
80    /// The metadata of this extension type.
81    metadata: VariableShapeTensorMetadata,
82}
83
84impl VariableShapeTensor {
85    /// Returns a new variable shape tensor extension type.
86    ///
87    /// # Error
88    ///
89    /// Return an error if the provided dimension names, permutations or
90    /// uniform shapes are invalid.
91    pub fn try_new(
92        value_type: DataType,
93        dimensions: usize,
94        dimension_names: Option<Vec<String>>,
95        permutations: Option<Vec<usize>>,
96        uniform_shapes: Option<Vec<Option<i32>>>,
97    ) -> Result<Self, ArrowError> {
98        // TODO: are all data types are suitable as value type?
99        VariableShapeTensorMetadata::try_new(
100            dimensions,
101            dimension_names,
102            permutations,
103            uniform_shapes,
104        )
105        .map(|metadata| Self {
106            value_type,
107            dimensions,
108            metadata,
109        })
110    }
111
112    /// Returns the value type of the individual tensor elements.
113    pub fn value_type(&self) -> &DataType {
114        &self.value_type
115    }
116
117    /// Returns the number of dimensions  in this variable shape tensor.
118    pub fn dimensions(&self) -> usize {
119        self.dimensions
120    }
121
122    /// Returns the names of the dimensions in this variable shape tensor, if
123    /// set.
124    pub fn dimension_names(&self) -> Option<&[String]> {
125        self.metadata.dimension_names()
126    }
127
128    /// Returns the indices of the desired ordering of the original
129    /// dimensions, if set.
130    pub fn permutations(&self) -> Option<&[usize]> {
131        self.metadata.permutations()
132    }
133
134    /// Returns sizes of individual tensor’s dimensions which are guaranteed
135    /// to stay constant in uniform dimensions and can vary in non-uniform
136    /// dimensions.
137    pub fn uniform_shapes(&self) -> Option<&[Option<i32>]> {
138        self.metadata.uniform_shapes()
139    }
140}
141
142/// Extension type metadata for [`VariableShapeTensor`].
143#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
144pub struct VariableShapeTensorMetadata {
145    /// Explicit names to tensor dimensions.
146    dim_names: Option<Vec<String>>,
147
148    /// Indices of the desired ordering of the original dimensions.
149    permutations: Option<Vec<usize>>,
150
151    /// Sizes of individual tensor’s dimensions which are guaranteed to stay
152    /// constant in uniform dimensions and can vary in non-uniform dimensions.
153    uniform_shape: Option<Vec<Option<i32>>>,
154}
155
156impl VariableShapeTensorMetadata {
157    /// Returns metadata for a variable shape tensor extension type.
158    ///
159    /// # Error
160    ///
161    /// Return an error if the provided dimension names, permutations or
162    /// uniform shapes are invalid.
163    pub fn try_new(
164        dimensions: usize,
165        dimension_names: Option<Vec<String>>,
166        permutations: Option<Vec<usize>>,
167        uniform_shapes: Option<Vec<Option<i32>>>,
168    ) -> Result<Self, ArrowError> {
169        let dim_names = dimension_names.map(|dimension_names| {
170            if dimension_names.len() != dimensions {
171                Err(ArrowError::InvalidArgumentError(format!(
172                    "VariableShapeTensor dimension names size mismatch, expected {dimensions}, found {}", dimension_names.len()
173                )))
174            } else {
175                Ok(dimension_names)
176            }
177        }).transpose()?;
178
179        let permutations = permutations
180            .map(|permutations| {
181                if permutations.len() != dimensions {
182                    Err(ArrowError::InvalidArgumentError(format!(
183                        "VariableShapeTensor permutations size mismatch, expected {dimensions}, found {}",
184                        permutations.len()
185                    )))
186                } else {
187                    let mut sorted_permutations = permutations.clone();
188                    sorted_permutations.sort_unstable();
189                    if (0..dimensions).zip(sorted_permutations).any(|(a, b)| a != b) {
190                        Err(ArrowError::InvalidArgumentError(format!(
191                            "VariableShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: {dimensions}"
192                        )))
193                    } else {
194                        Ok(permutations)
195                    }
196                }
197            })
198            .transpose()?;
199
200        let uniform_shape = uniform_shapes
201            .map(|uniform_shapes| {
202                if uniform_shapes.len() != dimensions {
203                    Err(ArrowError::InvalidArgumentError(format!(
204                        "VariableShapeTensor uniform shapes size mismatch, expected {dimensions}, found {}",
205                        uniform_shapes.len()
206                    )))
207                } else {
208                    Ok(uniform_shapes)
209                }
210            })
211            .transpose()?;
212
213        Ok(Self {
214            dim_names,
215            permutations,
216            uniform_shape,
217        })
218    }
219
220    /// Returns the names of the dimensions in this variable shape tensor, if
221    /// set.
222    pub fn dimension_names(&self) -> Option<&[String]> {
223        self.dim_names.as_ref().map(AsRef::as_ref)
224    }
225
226    /// Returns the indices of the desired ordering of the original dimensions,
227    /// if set.
228    pub fn permutations(&self) -> Option<&[usize]> {
229        self.permutations.as_ref().map(AsRef::as_ref)
230    }
231
232    /// Returns sizes of individual tensor’s dimensions which are guaranteed
233    /// to stay constant in uniform dimensions and can vary in non-uniform
234    /// dimensions.
235    pub fn uniform_shapes(&self) -> Option<&[Option<i32>]> {
236        self.uniform_shape.as_ref().map(AsRef::as_ref)
237    }
238}
239
240impl ExtensionType for VariableShapeTensor {
241    const NAME: &'static str = "arrow.variable_shape_tensor";
242
243    type Metadata = VariableShapeTensorMetadata;
244
245    fn metadata(&self) -> &Self::Metadata {
246        &self.metadata
247    }
248
249    fn serialize_metadata(&self) -> Option<String> {
250        Some(serde_json::to_string(self.metadata()).expect("metadata serialization"))
251    }
252
253    fn deserialize_metadata(metadata: Option<&str>) -> Result<Self::Metadata, ArrowError> {
254        metadata.map_or_else(
255            || {
256                Err(ArrowError::InvalidArgumentError(
257                    "VariableShapeTensor extension types requires metadata".to_owned(),
258                ))
259            },
260            |value| {
261                serde_json::from_str(value).map_err(|e| {
262                    ArrowError::InvalidArgumentError(format!(
263                        "VariableShapeTensor metadata deserialization failed: {e}"
264                    ))
265                })
266            },
267        )
268    }
269
270    fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> {
271        let expected = DataType::Struct(
272            [
273                Field::new_list(
274                    "data",
275                    Field::new_list_field(self.value_type.clone(), false),
276                    false,
277                ),
278                Field::new(
279                    "shape",
280                    DataType::new_fixed_size_list(
281                        DataType::Int32,
282                        i32::try_from(self.dimensions()).expect("overflow"),
283                        false,
284                    ),
285                    false,
286                ),
287            ]
288            .into_iter()
289            .collect(),
290        );
291        data_type
292            .equals_datatype(&expected)
293            .then_some(())
294            .ok_or_else(|| {
295                ArrowError::InvalidArgumentError(format!(
296                    "VariableShapeTensor data type mismatch, expected {expected}, found {data_type}"
297                ))
298            })
299    }
300
301    fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result<Self, ArrowError> {
302        match data_type {
303            DataType::Struct(fields)
304                if fields.len() == 2
305                    && matches!(fields.find("data"), Some((0, _)))
306                    && matches!(fields.find("shape"), Some((1, _))) =>
307            {
308                let shape_field = &fields[1];
309                match shape_field.data_type() {
310                    DataType::FixedSizeList(_, list_size) => {
311                        let dimensions = usize::try_from(*list_size).expect("conversion failed");
312                        // Make sure the metadata is valid.
313                        let metadata = VariableShapeTensorMetadata::try_new(dimensions, metadata.dim_names, metadata.permutations, metadata.uniform_shape)?;
314                        let data_field = &fields[0];
315                        match data_field.data_type() {
316                            DataType::List(field) => {
317                                Ok(Self {
318                                    value_type: field.data_type().clone(),
319                                    dimensions,
320                                    metadata
321                                })
322                            }
323                            data_type => Err(ArrowError::InvalidArgumentError(format!(
324                                "VariableShapeTensor data type mismatch, expected List for data field, found {data_type}"
325                            ))),
326                        }
327                    }
328                    data_type => Err(ArrowError::InvalidArgumentError(format!(
329                        "VariableShapeTensor data type mismatch, expected FixedSizeList for shape field, found {data_type}"
330                    ))),
331                }
332            }
333            data_type => Err(ArrowError::InvalidArgumentError(format!(
334                "VariableShapeTensor data type mismatch, expected Struct with 2 fields (data and shape), found {data_type}"
335            ))),
336        }
337    }
338}
339
#[cfg(test)]
mod tests {
    #[cfg(feature = "canonical_extension_types")]
    use crate::extension::CanonicalExtensionType;
    use crate::{
        extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY},
        Field,
    };

    use super::*;

    /// Builds the storage field shared by the tests below: a non-nullable
    /// struct with a `data` list of `Float32` and a `shape` fixed size list
    /// of three `Int32` values.
    fn storage_field() -> Field {
        let data = Field::new_list(
            "data",
            Field::new_list_field(DataType::Float32, false),
            false,
        );
        let shape = Field::new_fixed_size_list(
            "shape",
            Field::new("", DataType::Int32, false),
            3,
            false,
        );
        Field::new_struct("", vec![data, shape], false)
    }

    /// A valid extension type can be attached to a matching field and read
    /// back unchanged.
    #[test]
    fn valid() -> Result<(), ArrowError> {
        let variable_shape_tensor = VariableShapeTensor::try_new(
            DataType::Float32,
            3,
            Some(vec!["C".to_owned(), "H".to_owned(), "W".to_owned()]),
            Some(vec![2, 0, 1]),
            Some(vec![Some(400), None, Some(3)]),
        )?;
        let mut field = storage_field();
        field.try_with_extension_type(variable_shape_tensor.clone())?;
        assert_eq!(
            field.try_extension_type::<VariableShapeTensor>()?,
            variable_shape_tensor
        );
        #[cfg(feature = "canonical_extension_types")]
        assert_eq!(
            field.try_canonical_extension_type()?,
            CanonicalExtensionType::VariableShapeTensor(variable_shape_tensor)
        );
        Ok(())
    }

    /// Extension metadata without an extension name is rejected.
    #[test]
    #[should_panic(expected = "Field extension type name missing")]
    fn missing_name() {
        let field = storage_field().with_metadata(
            [(EXTENSION_TYPE_METADATA_KEY.to_owned(), "{}".to_owned())]
                .into_iter()
                .collect(),
        );
        field.extension_type::<VariableShapeTensor>();
    }

    /// An `Int32` tensor type does not fit a field storing `Float32` data.
    #[test]
    #[should_panic(expected = "VariableShapeTensor data type mismatch, expected Struct")]
    fn invalid_type() {
        let variable_shape_tensor =
            VariableShapeTensor::try_new(DataType::Int32, 3, None, None, None).unwrap();
        let field = storage_field();
        field.with_extension_type(variable_shape_tensor);
    }

    /// An extension name without extension metadata is rejected.
    #[test]
    #[should_panic(expected = "VariableShapeTensor extension types requires metadata")]
    fn missing_metadata() {
        let field = storage_field().with_metadata(
            [(
                EXTENSION_TYPE_NAME_KEY.to_owned(),
                VariableShapeTensor::NAME.to_owned(),
            )]
            .into_iter()
            .collect(),
        );
        field.extension_type::<VariableShapeTensor>();
    }

    /// Metadata whose `dim_names` are not strings fails to deserialize.
    #[test]
    #[should_panic(expected = "VariableShapeTensor metadata deserialization failed: invalid type:")]
    fn invalid_metadata() {
        let name_entry = (
            EXTENSION_TYPE_NAME_KEY.to_owned(),
            VariableShapeTensor::NAME.to_owned(),
        );
        let metadata_entry = (
            EXTENSION_TYPE_METADATA_KEY.to_owned(),
            r#"{ "dim_names": [1, null, 3, 4] }"#.to_owned(),
        );
        let field =
            storage_field().with_metadata([name_entry, metadata_entry].into_iter().collect());
        field.extension_type::<VariableShapeTensor>();
    }

    /// The number of dimension names must match the number of dimensions.
    #[test]
    #[should_panic(
        expected = "VariableShapeTensor dimension names size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_dimension_names() {
        let names = vec!["a".to_owned(), "b".to_owned()];
        VariableShapeTensor::try_new(DataType::Float32, 3, Some(names), None, None).unwrap();
    }

    /// The number of permutation indices must match the number of dimensions.
    #[test]
    #[should_panic(
        expected = "VariableShapeTensor permutations size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_permutations_len() {
        let permutations = vec![1, 0];
        VariableShapeTensor::try_new(DataType::Float32, 3, None, Some(permutations), None)
            .unwrap();
    }

    /// Permutation indices must be a permutation of `[0, 1, .., N-1]`.
    #[test]
    #[should_panic(
        expected = "VariableShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: 3"
    )]
    fn invalid_metadata_permutations_values() {
        let permutations = vec![4, 3, 2];
        VariableShapeTensor::try_new(DataType::Float32, 3, None, Some(permutations), None)
            .unwrap();
    }

    /// The number of uniform shape entries must match the number of
    /// dimensions.
    #[test]
    #[should_panic(
        expected = "VariableShapeTensor uniform shapes size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_uniform_shapes() {
        let uniform_shapes = vec![None, Some(1)];
        VariableShapeTensor::try_new(DataType::Float32, 3, None, None, Some(uniform_shapes))
            .unwrap();
    }
}