1use serde::{Deserialize, Serialize};
23
24use crate::{extension::ExtensionType, ArrowError, DataType};
25
26#[derive(Debug, Clone, PartialEq)]
73pub struct FixedShapeTensor {
74 value_type: DataType,
76
77 metadata: FixedShapeTensorMetadata,
79}
80
81impl FixedShapeTensor {
82 pub fn try_new(
89 value_type: DataType,
90 shape: impl IntoIterator<Item = usize>,
91 dimension_names: Option<Vec<String>>,
92 permutations: Option<Vec<usize>>,
93 ) -> Result<Self, ArrowError> {
94 FixedShapeTensorMetadata::try_new(shape, dimension_names, permutations).map(|metadata| {
96 Self {
97 value_type,
98 metadata,
99 }
100 })
101 }
102
103 pub fn value_type(&self) -> &DataType {
105 &self.value_type
106 }
107
108 pub fn list_size(&self) -> usize {
110 self.metadata.list_size()
111 }
112
113 pub fn dimensions(&self) -> usize {
115 self.metadata.dimensions()
116 }
117
118 pub fn dimension_names(&self) -> Option<&[String]> {
121 self.metadata.dimension_names()
122 }
123
124 pub fn permutations(&self) -> Option<&[usize]> {
127 self.metadata.permutations()
128 }
129}
130
131#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
133pub struct FixedShapeTensorMetadata {
134 shape: Vec<usize>,
136
137 dim_names: Option<Vec<String>>,
139
140 permutations: Option<Vec<usize>>,
142}
143
144impl FixedShapeTensorMetadata {
145 pub fn try_new(
152 shape: impl IntoIterator<Item = usize>,
153 dimension_names: Option<Vec<String>>,
154 permutations: Option<Vec<usize>>,
155 ) -> Result<Self, ArrowError> {
156 let shape = shape.into_iter().collect::<Vec<_>>();
157 let dimensions = shape.len();
158
159 let dim_names = dimension_names.map(|dimension_names| {
160 if dimension_names.len() != dimensions {
161 Err(ArrowError::InvalidArgumentError(format!(
162 "FixedShapeTensor dimension names size mismatch, expected {dimensions}, found {}", dimension_names.len()
163 )))
164 } else {
165 Ok(dimension_names)
166 }
167 }).transpose()?;
168
169 let permutations = permutations
170 .map(|permutations| {
171 if permutations.len() != dimensions {
172 Err(ArrowError::InvalidArgumentError(format!(
173 "FixedShapeTensor permutations size mismatch, expected {dimensions}, found {}",
174 permutations.len()
175 )))
176 } else {
177 let mut sorted_permutations = permutations.clone();
178 sorted_permutations.sort_unstable();
179 if (0..dimensions).zip(sorted_permutations).any(|(a, b)| a != b) {
180 Err(ArrowError::InvalidArgumentError(format!(
181 "FixedShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: {dimensions}"
182 )))
183 } else {
184 Ok(permutations)
185 }
186 }
187 })
188 .transpose()?;
189
190 Ok(Self {
191 shape,
192 dim_names,
193 permutations,
194 })
195 }
196
197 pub fn list_size(&self) -> usize {
199 self.shape.iter().product()
200 }
201
202 pub fn dimensions(&self) -> usize {
204 self.shape.len()
205 }
206
207 pub fn dimension_names(&self) -> Option<&[String]> {
210 self.dim_names.as_ref().map(AsRef::as_ref)
211 }
212
213 pub fn permutations(&self) -> Option<&[usize]> {
216 self.permutations.as_ref().map(AsRef::as_ref)
217 }
218}
219
220impl ExtensionType for FixedShapeTensor {
221 const NAME: &'static str = "arrow.fixed_shape_tensor";
222
223 type Metadata = FixedShapeTensorMetadata;
224
225 fn metadata(&self) -> &Self::Metadata {
226 &self.metadata
227 }
228
229 fn serialize_metadata(&self) -> Option<String> {
230 Some(serde_json::to_string(&self.metadata).expect("metadata serialization"))
231 }
232
233 fn deserialize_metadata(metadata: Option<&str>) -> Result<Self::Metadata, ArrowError> {
234 metadata.map_or_else(
235 || {
236 Err(ArrowError::InvalidArgumentError(
237 "FixedShapeTensor extension types requires metadata".to_owned(),
238 ))
239 },
240 |value| {
241 serde_json::from_str(value).map_err(|e| {
242 ArrowError::InvalidArgumentError(format!(
243 "FixedShapeTensor metadata deserialization failed: {e}"
244 ))
245 })
246 },
247 )
248 }
249
250 fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> {
251 let expected = DataType::new_fixed_size_list(
252 self.value_type.clone(),
253 i32::try_from(self.list_size()).expect("overflow"),
254 false,
255 );
256 data_type
257 .equals_datatype(&expected)
258 .then_some(())
259 .ok_or_else(|| {
260 ArrowError::InvalidArgumentError(format!(
261 "FixedShapeTensor data type mismatch, expected {expected}, found {data_type}"
262 ))
263 })
264 }
265
266 fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result<Self, ArrowError> {
267 match data_type {
268 DataType::FixedSizeList(field, list_size) if !field.is_nullable() => {
269 let metadata = FixedShapeTensorMetadata::try_new(
271 metadata.shape,
272 metadata.dim_names,
273 metadata.permutations,
274 )?;
275 let expected_size = i32::try_from(metadata.list_size()).expect("overflow");
277 if *list_size != expected_size {
278 Err(ArrowError::InvalidArgumentError(format!(
279 "FixedShapeTensor list size mismatch, expected {expected_size} (metadata), found {list_size} (data type)"
280 )))
281 } else {
282 Ok(Self {
283 value_type: field.data_type().clone(),
284 metadata,
285 })
286 }
287 }
288 data_type => Err(ArrowError::InvalidArgumentError(format!(
289 "FixedShapeTensor data type mismatch, expected FixedSizeList with non-nullable field, found {data_type}"
290 ))),
291 }
292 }
293}
294
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    #[cfg(feature = "canonical_extension_types")]
    use crate::extension::CanonicalExtensionType;
    use crate::{
        extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY},
        Field,
    };

    use super::*;

    /// List size of `tensor` as the `i32` size used by fixed size list fields.
    fn storage_size(tensor: &FixedShapeTensor) -> i32 {
        i32::try_from(tensor.list_size()).expect("overflow")
    }

    /// Unnamed, non-nullable fixed size list field of non-nullable `Float32`
    /// values with the given `list_size`.
    fn float32_list_field(list_size: i32) -> Field {
        let values = Field::new("", DataType::Float32, false);
        Field::new_fixed_size_list("", values, list_size, false)
    }

    // Round trip: attach the extension type to a field and read it back.
    #[test]
    fn valid() -> Result<(), ArrowError> {
        let dimension_names = ["C", "H", "W"]
            .iter()
            .map(|name| name.to_string())
            .collect::<Vec<_>>();
        let fixed_shape_tensor = FixedShapeTensor::try_new(
            DataType::Float32,
            [100, 200, 500],
            Some(dimension_names),
            Some(vec![2, 0, 1]),
        )?;
        let mut field = float32_list_field(storage_size(&fixed_shape_tensor));
        field.try_with_extension_type(fixed_shape_tensor.clone())?;
        assert_eq!(
            field.try_extension_type::<FixedShapeTensor>()?,
            fixed_shape_tensor
        );
        #[cfg(feature = "canonical_extension_types")]
        assert_eq!(
            field.try_canonical_extension_type()?,
            CanonicalExtensionType::FixedShapeTensor(fixed_shape_tensor)
        );
        Ok(())
    }

    // Only the metadata key is present, the extension name key is not.
    #[test]
    #[should_panic(expected = "Field extension type name missing")]
    fn missing_name() {
        let metadata = HashMap::from([(
            EXTENSION_TYPE_METADATA_KEY.to_owned(),
            r#"{ "shape": [100, 200, 500], }"#.to_owned(),
        )]);
        let field = float32_list_field(3).with_metadata(metadata);
        field.extension_type::<FixedShapeTensor>();
    }

    // The tensor holds `Int32` values while the field stores `Float32`.
    #[test]
    #[should_panic(expected = "FixedShapeTensor data type mismatch, expected FixedSizeList")]
    fn invalid_type() {
        let fixed_shape_tensor =
            FixedShapeTensor::try_new(DataType::Int32, [100, 200, 500], None, None).unwrap();
        let field = float32_list_field(storage_size(&fixed_shape_tensor));
        field.with_extension_type(fixed_shape_tensor);
    }

    // Only the extension name key is present, the metadata key is not.
    #[test]
    #[should_panic(expected = "FixedShapeTensor extension types requires metadata")]
    fn missing_metadata() {
        let metadata = HashMap::from([(
            EXTENSION_TYPE_NAME_KEY.to_owned(),
            FixedShapeTensor::NAME.to_owned(),
        )]);
        let field = float32_list_field(3).with_metadata(metadata);
        field.extension_type::<FixedShapeTensor>();
    }

    // The metadata JSON lacks the required `shape` key.
    #[test]
    #[should_panic(
        expected = "FixedShapeTensor metadata deserialization failed: missing field `shape`"
    )]
    fn invalid_metadata() {
        let fixed_shape_tensor =
            FixedShapeTensor::try_new(DataType::Float32, [100, 200, 500], None, None).unwrap();
        let metadata = HashMap::from([
            (
                EXTENSION_TYPE_NAME_KEY.to_owned(),
                FixedShapeTensor::NAME.to_owned(),
            ),
            (
                EXTENSION_TYPE_METADATA_KEY.to_owned(),
                r#"{ "not-shape": [] }"#.to_owned(),
            ),
        ]);
        let field = float32_list_field(storage_size(&fixed_shape_tensor)).with_metadata(metadata);
        field.extension_type::<FixedShapeTensor>();
    }

    // Two names for three dimensions.
    #[test]
    #[should_panic(
        expected = "FixedShapeTensor dimension names size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_dimension_names() {
        let names = vec!["a".to_owned(), "b".to_owned()];
        FixedShapeTensor::try_new(DataType::Float32, [100, 200, 500], Some(names), None).unwrap();
    }

    // Two permutation entries for three dimensions.
    #[test]
    #[should_panic(expected = "FixedShapeTensor permutations size mismatch, expected 3, found 2")]
    fn invalid_metadata_permutations_len() {
        let permutations = vec![1, 0];
        FixedShapeTensor::try_new(DataType::Float32, [100, 200, 500], None, Some(permutations))
            .unwrap();
    }

    // Right length, but the values are not a permutation of 0, 1, 2.
    #[test]
    #[should_panic(
        expected = "FixedShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: 3"
    )]
    fn invalid_metadata_permutations_values() {
        let permutations = vec![4, 3, 2];
        FixedShapeTensor::try_new(DataType::Float32, [100, 200, 500], None, Some(permutations))
            .unwrap();
    }
}