1use serde::{Deserialize, Serialize};
23
24use crate::{ArrowError, DataType, Field, extension::ExtensionType};
25
/// The extension type for a variable shape tensor.
///
/// Its storage type (checked by `supports_data_type`) is a struct with two
/// non-nullable fields: a `data` list whose items have type `value_type`, and a
/// `shape` fixed-size list of `dimensions` `Int32` values.
#[derive(Debug, Clone, PartialEq)]
pub struct VariableShapeTensor {
    /// The data type of the individual tensor elements (item type of `data`).
    value_type: DataType,

    /// The number of dimensions of every tensor (size of the `shape` list).
    dimensions: usize,

    /// Optional dimension names, permutation and uniform shape, validated
    /// against `dimensions` on construction.
    metadata: VariableShapeTensorMetadata,
}
83
84impl VariableShapeTensor {
85 pub fn try_new(
92 value_type: DataType,
93 dimensions: usize,
94 dimension_names: Option<Vec<String>>,
95 permutations: Option<Vec<usize>>,
96 uniform_shapes: Option<Vec<Option<i32>>>,
97 ) -> Result<Self, ArrowError> {
98 VariableShapeTensorMetadata::try_new(
100 dimensions,
101 dimension_names,
102 permutations,
103 uniform_shapes,
104 )
105 .map(|metadata| Self {
106 value_type,
107 dimensions,
108 metadata,
109 })
110 }
111
112 pub fn value_type(&self) -> &DataType {
114 &self.value_type
115 }
116
117 pub fn dimensions(&self) -> usize {
119 self.dimensions
120 }
121
122 pub fn dimension_names(&self) -> Option<&[String]> {
125 self.metadata.dimension_names()
126 }
127
128 pub fn permutations(&self) -> Option<&[usize]> {
131 self.metadata.permutations()
132 }
133
134 pub fn uniform_shapes(&self) -> Option<&[Option<i32>]> {
138 self.metadata.uniform_shapes()
139 }
140}
141
142#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
144pub struct VariableShapeTensorMetadata {
145 dim_names: Option<Vec<String>>,
147
148 permutations: Option<Vec<usize>>,
150
151 uniform_shape: Option<Vec<Option<i32>>>,
154}
155
156impl VariableShapeTensorMetadata {
157 pub fn try_new(
164 dimensions: usize,
165 dimension_names: Option<Vec<String>>,
166 permutations: Option<Vec<usize>>,
167 uniform_shapes: Option<Vec<Option<i32>>>,
168 ) -> Result<Self, ArrowError> {
169 let dim_names = dimension_names.map(|dimension_names| {
170 if dimension_names.len() != dimensions {
171 Err(ArrowError::InvalidArgumentError(format!(
172 "VariableShapeTensor dimension names size mismatch, expected {dimensions}, found {}", dimension_names.len()
173 )))
174 } else {
175 Ok(dimension_names)
176 }
177 }).transpose()?;
178
179 let permutations = permutations
180 .map(|permutations| {
181 if permutations.len() != dimensions {
182 Err(ArrowError::InvalidArgumentError(format!(
183 "VariableShapeTensor permutations size mismatch, expected {dimensions}, found {}",
184 permutations.len()
185 )))
186 } else {
187 let mut sorted_permutations = permutations.clone();
188 sorted_permutations.sort_unstable();
189 if (0..dimensions).zip(sorted_permutations).any(|(a, b)| a != b) {
190 Err(ArrowError::InvalidArgumentError(format!(
191 "VariableShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: {dimensions}"
192 )))
193 } else {
194 Ok(permutations)
195 }
196 }
197 })
198 .transpose()?;
199
200 let uniform_shape = uniform_shapes
201 .map(|uniform_shapes| {
202 if uniform_shapes.len() != dimensions {
203 Err(ArrowError::InvalidArgumentError(format!(
204 "VariableShapeTensor uniform shapes size mismatch, expected {dimensions}, found {}",
205 uniform_shapes.len()
206 )))
207 } else {
208 Ok(uniform_shapes)
209 }
210 })
211 .transpose()?;
212
213 Ok(Self {
214 dim_names,
215 permutations,
216 uniform_shape,
217 })
218 }
219
220 pub fn dimension_names(&self) -> Option<&[String]> {
223 self.dim_names.as_ref().map(AsRef::as_ref)
224 }
225
226 pub fn permutations(&self) -> Option<&[usize]> {
229 self.permutations.as_ref().map(AsRef::as_ref)
230 }
231
232 pub fn uniform_shapes(&self) -> Option<&[Option<i32>]> {
236 self.uniform_shape.as_ref().map(AsRef::as_ref)
237 }
238}
239
240impl ExtensionType for VariableShapeTensor {
241 const NAME: &'static str = "arrow.variable_shape_tensor";
242
243 type Metadata = VariableShapeTensorMetadata;
244
245 fn metadata(&self) -> &Self::Metadata {
246 &self.metadata
247 }
248
249 fn serialize_metadata(&self) -> Option<String> {
250 Some(serde_json::to_string(self.metadata()).expect("metadata serialization"))
251 }
252
253 fn deserialize_metadata(metadata: Option<&str>) -> Result<Self::Metadata, ArrowError> {
254 metadata.map_or_else(
255 || {
256 Err(ArrowError::InvalidArgumentError(
257 "VariableShapeTensor extension types requires metadata".to_owned(),
258 ))
259 },
260 |value| {
261 serde_json::from_str(value).map_err(|e| {
262 ArrowError::InvalidArgumentError(format!(
263 "VariableShapeTensor metadata deserialization failed: {e}"
264 ))
265 })
266 },
267 )
268 }
269
270 fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> {
271 let expected = DataType::Struct(
272 [
273 Field::new_list(
274 "data",
275 Field::new_list_field(self.value_type.clone(), false),
276 false,
277 ),
278 Field::new(
279 "shape",
280 DataType::new_fixed_size_list(
281 DataType::Int32,
282 i32::try_from(self.dimensions()).expect("overflow"),
283 false,
284 ),
285 false,
286 ),
287 ]
288 .into_iter()
289 .collect(),
290 );
291 data_type
292 .equals_datatype(&expected)
293 .then_some(())
294 .ok_or_else(|| {
295 ArrowError::InvalidArgumentError(format!(
296 "VariableShapeTensor data type mismatch, expected {expected}, found {data_type}"
297 ))
298 })
299 }
300
301 fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result<Self, ArrowError> {
302 match data_type {
303 DataType::Struct(fields)
304 if fields.len() == 2
305 && matches!(fields.find("data"), Some((0, _)))
306 && matches!(fields.find("shape"), Some((1, _))) =>
307 {
308 let shape_field = &fields[1];
309 match shape_field.data_type() {
310 DataType::FixedSizeList(_, list_size) => {
311 let dimensions = usize::try_from(*list_size).expect("conversion failed");
312 let metadata = VariableShapeTensorMetadata::try_new(
314 dimensions,
315 metadata.dim_names,
316 metadata.permutations,
317 metadata.uniform_shape,
318 )?;
319 let data_field = &fields[0];
320 match data_field.data_type() {
321 DataType::List(field) => Ok(Self {
322 value_type: field.data_type().clone(),
323 dimensions,
324 metadata,
325 }),
326 data_type => Err(ArrowError::InvalidArgumentError(format!(
327 "VariableShapeTensor data type mismatch, expected List for data field, found {data_type}"
328 ))),
329 }
330 }
331 data_type => Err(ArrowError::InvalidArgumentError(format!(
332 "VariableShapeTensor data type mismatch, expected FixedSizeList for shape field, found {data_type}"
333 ))),
334 }
335 }
336 data_type => Err(ArrowError::InvalidArgumentError(format!(
337 "VariableShapeTensor data type mismatch, expected Struct with 2 fields (data and shape), found {data_type}"
338 ))),
339 }
340 }
341}
342
#[cfg(test)]
mod tests {
    #[cfg(feature = "canonical_extension_types")]
    use crate::extension::CanonicalExtensionType;
    use crate::{
        Field,
        extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY},
    };

    use super::*;

    /// Storage field shared by the tests.
    ///
    /// It is a non-nullable struct containing:
    /// - a non-nullable `data` list of non-nullable `Float32` items;
    /// - a non-nullable `shape` fixed-size list of three non-nullable `Int32`
    ///   values.
    fn float32_storage_field() -> Field {
        let data = Field::new_list(
            "data",
            Field::new_list_field(DataType::Float32, false),
            false,
        );
        let shape = Field::new_fixed_size_list(
            "shape",
            Field::new("", DataType::Int32, false),
            3,
            false,
        );
        Field::new_struct("", vec![data, shape], false)
    }

    /// Replaces the metadata of `field` with the given key/value pairs.
    fn with_field_metadata<const N: usize>(field: Field, entries: [(&str, &str); N]) -> Field {
        field.with_metadata(
            entries
                .into_iter()
                .map(|(key, value)| (key.to_owned(), value.to_owned()))
                .collect(),
        )
    }

    #[test]
    fn valid() -> Result<(), ArrowError> {
        let variable_shape_tensor = VariableShapeTensor::try_new(
            DataType::Float32,
            3,
            Some(["C", "H", "W"].into_iter().map(String::from).collect()),
            Some(vec![2, 0, 1]),
            Some(vec![Some(400), None, Some(3)]),
        )?;
        let mut field = float32_storage_field();
        field.try_with_extension_type(variable_shape_tensor.clone())?;
        assert_eq!(
            field.try_extension_type::<VariableShapeTensor>()?,
            variable_shape_tensor
        );
        #[cfg(feature = "canonical_extension_types")]
        assert_eq!(
            field.try_canonical_extension_type()?,
            CanonicalExtensionType::VariableShapeTensor(variable_shape_tensor)
        );
        Ok(())
    }

    #[test]
    #[should_panic(expected = "Field extension type name missing")]
    fn missing_name() {
        let field = with_field_metadata(
            float32_storage_field(),
            [(EXTENSION_TYPE_METADATA_KEY, "{}")],
        );
        field.extension_type::<VariableShapeTensor>();
    }

    #[test]
    #[should_panic(expected = "VariableShapeTensor data type mismatch, expected Struct")]
    fn invalid_type() {
        // The extension's Int32 element type does not match the Float32 storage.
        let variable_shape_tensor =
            VariableShapeTensor::try_new(DataType::Int32, 3, None, None, None).unwrap();
        let field = float32_storage_field();
        field.with_extension_type(variable_shape_tensor);
    }

    #[test]
    #[should_panic(expected = "VariableShapeTensor extension types requires metadata")]
    fn missing_metadata() {
        let field = with_field_metadata(
            float32_storage_field(),
            [(EXTENSION_TYPE_NAME_KEY, VariableShapeTensor::NAME)],
        );
        field.extension_type::<VariableShapeTensor>();
    }

    #[test]
    #[should_panic(expected = "VariableShapeTensor metadata deserialization failed: invalid type:")]
    fn invalid_metadata() {
        let field = with_field_metadata(
            float32_storage_field(),
            [
                (EXTENSION_TYPE_NAME_KEY, VariableShapeTensor::NAME),
                (
                    EXTENSION_TYPE_METADATA_KEY,
                    r#"{ "dim_names": [1, null, 3, 4] }"#,
                ),
            ],
        );
        field.extension_type::<VariableShapeTensor>();
    }

    #[test]
    #[should_panic(
        expected = "VariableShapeTensor dimension names size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_dimension_names() {
        let names = Some(vec!["a".to_owned(), "b".to_owned()]);
        VariableShapeTensor::try_new(DataType::Float32, 3, names, None, None).unwrap();
    }

    #[test]
    #[should_panic(
        expected = "VariableShapeTensor permutations size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_permutations_len() {
        let permutations = Some(vec![1, 0]);
        VariableShapeTensor::try_new(DataType::Float32, 3, None, permutations, None).unwrap();
    }

    #[test]
    #[should_panic(
        expected = "VariableShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: 3"
    )]
    fn invalid_metadata_permutations_values() {
        let permutations = Some(vec![4, 3, 2]);
        VariableShapeTensor::try_new(DataType::Float32, 3, None, permutations, None).unwrap();
    }

    #[test]
    #[should_panic(
        expected = "VariableShapeTensor uniform shapes size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_uniform_shapes() {
        let uniform_shapes = Some(vec![None, Some(1)]);
        VariableShapeTensor::try_new(DataType::Float32, 3, None, None, uniform_shapes).unwrap();
    }
}