1use serde::{Deserialize, Serialize};
23
24use crate::{extension::ExtensionType, ArrowError, DataType, Field};
25
/// The `arrow.variable_shape_tensor` extension type.
///
/// The storage type checked by `supports_data_type` is a non-nullable `Struct`
/// with two fields: `data`, a `List` whose items have the tensor value type,
/// and `shape`, a `FixedSizeList<Int32>` whose length equals the number of
/// dimensions.
#[derive(Debug, Clone, PartialEq)]
pub struct VariableShapeTensor {
    /// The data type of the individual tensor elements.
    value_type: DataType,

    /// The number of dimensions of every tensor (the length of the `shape` list).
    dimensions: usize,

    /// Optional logical-layout parameters (dimension names, permutation,
    /// uniform shape), validated against `dimensions` on construction.
    metadata: VariableShapeTensorMetadata,
}
83
84impl VariableShapeTensor {
85 pub fn try_new(
92 value_type: DataType,
93 dimensions: usize,
94 dimension_names: Option<Vec<String>>,
95 permutations: Option<Vec<usize>>,
96 uniform_shapes: Option<Vec<Option<i32>>>,
97 ) -> Result<Self, ArrowError> {
98 VariableShapeTensorMetadata::try_new(
100 dimensions,
101 dimension_names,
102 permutations,
103 uniform_shapes,
104 )
105 .map(|metadata| Self {
106 value_type,
107 dimensions,
108 metadata,
109 })
110 }
111
112 pub fn value_type(&self) -> &DataType {
114 &self.value_type
115 }
116
117 pub fn dimensions(&self) -> usize {
119 self.dimensions
120 }
121
122 pub fn dimension_names(&self) -> Option<&[String]> {
125 self.metadata.dimension_names()
126 }
127
128 pub fn permutations(&self) -> Option<&[usize]> {
131 self.metadata.permutations()
132 }
133
134 pub fn uniform_shapes(&self) -> Option<&[Option<i32>]> {
138 self.metadata.uniform_shapes()
139 }
140}
141
142#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
144pub struct VariableShapeTensorMetadata {
145 dim_names: Option<Vec<String>>,
147
148 permutations: Option<Vec<usize>>,
150
151 uniform_shape: Option<Vec<Option<i32>>>,
154}
155
156impl VariableShapeTensorMetadata {
157 pub fn try_new(
164 dimensions: usize,
165 dimension_names: Option<Vec<String>>,
166 permutations: Option<Vec<usize>>,
167 uniform_shapes: Option<Vec<Option<i32>>>,
168 ) -> Result<Self, ArrowError> {
169 let dim_names = dimension_names.map(|dimension_names| {
170 if dimension_names.len() != dimensions {
171 Err(ArrowError::InvalidArgumentError(format!(
172 "VariableShapeTensor dimension names size mismatch, expected {dimensions}, found {}", dimension_names.len()
173 )))
174 } else {
175 Ok(dimension_names)
176 }
177 }).transpose()?;
178
179 let permutations = permutations
180 .map(|permutations| {
181 if permutations.len() != dimensions {
182 Err(ArrowError::InvalidArgumentError(format!(
183 "VariableShapeTensor permutations size mismatch, expected {dimensions}, found {}",
184 permutations.len()
185 )))
186 } else {
187 let mut sorted_permutations = permutations.clone();
188 sorted_permutations.sort_unstable();
189 if (0..dimensions).zip(sorted_permutations).any(|(a, b)| a != b) {
190 Err(ArrowError::InvalidArgumentError(format!(
191 "VariableShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: {dimensions}"
192 )))
193 } else {
194 Ok(permutations)
195 }
196 }
197 })
198 .transpose()?;
199
200 let uniform_shape = uniform_shapes
201 .map(|uniform_shapes| {
202 if uniform_shapes.len() != dimensions {
203 Err(ArrowError::InvalidArgumentError(format!(
204 "VariableShapeTensor uniform shapes size mismatch, expected {dimensions}, found {}",
205 uniform_shapes.len()
206 )))
207 } else {
208 Ok(uniform_shapes)
209 }
210 })
211 .transpose()?;
212
213 Ok(Self {
214 dim_names,
215 permutations,
216 uniform_shape,
217 })
218 }
219
220 pub fn dimension_names(&self) -> Option<&[String]> {
223 self.dim_names.as_ref().map(AsRef::as_ref)
224 }
225
226 pub fn permutations(&self) -> Option<&[usize]> {
229 self.permutations.as_ref().map(AsRef::as_ref)
230 }
231
232 pub fn uniform_shapes(&self) -> Option<&[Option<i32>]> {
236 self.uniform_shape.as_ref().map(AsRef::as_ref)
237 }
238}
239
240impl ExtensionType for VariableShapeTensor {
241 const NAME: &'static str = "arrow.variable_shape_tensor";
242
243 type Metadata = VariableShapeTensorMetadata;
244
245 fn metadata(&self) -> &Self::Metadata {
246 &self.metadata
247 }
248
249 fn serialize_metadata(&self) -> Option<String> {
250 Some(serde_json::to_string(self.metadata()).expect("metadata serialization"))
251 }
252
253 fn deserialize_metadata(metadata: Option<&str>) -> Result<Self::Metadata, ArrowError> {
254 metadata.map_or_else(
255 || {
256 Err(ArrowError::InvalidArgumentError(
257 "VariableShapeTensor extension types requires metadata".to_owned(),
258 ))
259 },
260 |value| {
261 serde_json::from_str(value).map_err(|e| {
262 ArrowError::InvalidArgumentError(format!(
263 "VariableShapeTensor metadata deserialization failed: {e}"
264 ))
265 })
266 },
267 )
268 }
269
270 fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> {
271 let expected = DataType::Struct(
272 [
273 Field::new_list(
274 "data",
275 Field::new_list_field(self.value_type.clone(), false),
276 false,
277 ),
278 Field::new(
279 "shape",
280 DataType::new_fixed_size_list(
281 DataType::Int32,
282 i32::try_from(self.dimensions()).expect("overflow"),
283 false,
284 ),
285 false,
286 ),
287 ]
288 .into_iter()
289 .collect(),
290 );
291 data_type
292 .equals_datatype(&expected)
293 .then_some(())
294 .ok_or_else(|| {
295 ArrowError::InvalidArgumentError(format!(
296 "VariableShapeTensor data type mismatch, expected {expected}, found {data_type}"
297 ))
298 })
299 }
300
301 fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result<Self, ArrowError> {
302 match data_type {
303 DataType::Struct(fields)
304 if fields.len() == 2
305 && matches!(fields.find("data"), Some((0, _)))
306 && matches!(fields.find("shape"), Some((1, _))) =>
307 {
308 let shape_field = &fields[1];
309 match shape_field.data_type() {
310 DataType::FixedSizeList(_, list_size) => {
311 let dimensions = usize::try_from(*list_size).expect("conversion failed");
312 let metadata = VariableShapeTensorMetadata::try_new(dimensions, metadata.dim_names, metadata.permutations, metadata.uniform_shape)?;
314 let data_field = &fields[0];
315 match data_field.data_type() {
316 DataType::List(field) => {
317 Ok(Self {
318 value_type: field.data_type().clone(),
319 dimensions,
320 metadata
321 })
322 }
323 data_type => Err(ArrowError::InvalidArgumentError(format!(
324 "VariableShapeTensor data type mismatch, expected List for data field, found {data_type}"
325 ))),
326 }
327 }
328 data_type => Err(ArrowError::InvalidArgumentError(format!(
329 "VariableShapeTensor data type mismatch, expected FixedSizeList for shape field, found {data_type}"
330 ))),
331 }
332 }
333 data_type => Err(ArrowError::InvalidArgumentError(format!(
334 "VariableShapeTensor data type mismatch, expected Struct with 2 fields (data and shape), found {data_type}"
335 ))),
336 }
337 }
338}
339
#[cfg(test)]
mod tests {
    #[cfg(feature = "canonical_extension_types")]
    use crate::extension::CanonicalExtensionType;
    use crate::{
        extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY},
        Field,
    };

    use super::*;

    /// Storage field shared by the tests: a non-nullable struct with a
    /// `data: List<Float32>` field and a `shape: FixedSizeList<Int32>[3]` field.
    fn storage_field() -> Field {
        let data = Field::new_list(
            "data",
            Field::new_list_field(DataType::Float32, false),
            false,
        );
        let shape = Field::new_fixed_size_list(
            "shape",
            Field::new("", DataType::Int32, false),
            3,
            false,
        );
        Field::new_struct("", vec![data, shape], false)
    }

    /// The shared storage field with the given `(key, value)` pairs as field metadata.
    fn storage_field_with_metadata(entries: &[(&str, &str)]) -> Field {
        let metadata = entries
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect();
        storage_field().with_metadata(metadata)
    }

    #[test]
    fn valid() -> Result<(), ArrowError> {
        let tensor = VariableShapeTensor::try_new(
            DataType::Float32,
            3,
            Some(vec!["C".to_owned(), "H".to_owned(), "W".to_owned()]),
            Some(vec![2, 0, 1]),
            Some(vec![Some(400), None, Some(3)]),
        )?;
        let mut field = storage_field();
        field.try_with_extension_type(tensor.clone())?;
        assert_eq!(field.try_extension_type::<VariableShapeTensor>()?, tensor);
        #[cfg(feature = "canonical_extension_types")]
        assert_eq!(
            field.try_canonical_extension_type()?,
            CanonicalExtensionType::VariableShapeTensor(tensor)
        );
        Ok(())
    }

    #[test]
    #[should_panic(expected = "Field extension type name missing")]
    fn missing_name() {
        let field = storage_field_with_metadata(&[(EXTENSION_TYPE_METADATA_KEY, "{}")]);
        field.extension_type::<VariableShapeTensor>();
    }

    #[test]
    #[should_panic(expected = "VariableShapeTensor data type mismatch, expected Struct")]
    fn invalid_type() {
        // The extension expects Int32 elements, but the storage holds Float32.
        let tensor = VariableShapeTensor::try_new(DataType::Int32, 3, None, None, None).unwrap();
        storage_field().with_extension_type(tensor);
    }

    #[test]
    #[should_panic(expected = "VariableShapeTensor extension types requires metadata")]
    fn missing_metadata() {
        let field =
            storage_field_with_metadata(&[(EXTENSION_TYPE_NAME_KEY, VariableShapeTensor::NAME)]);
        field.extension_type::<VariableShapeTensor>();
    }

    #[test]
    #[should_panic(expected = "VariableShapeTensor metadata deserialization failed: invalid type:")]
    fn invalid_metadata() {
        let field = storage_field_with_metadata(&[
            (EXTENSION_TYPE_NAME_KEY, VariableShapeTensor::NAME),
            (
                EXTENSION_TYPE_METADATA_KEY,
                r#"{ "dim_names": [1, null, 3, 4] }"#,
            ),
        ]);
        field.extension_type::<VariableShapeTensor>();
    }

    #[test]
    #[should_panic(
        expected = "VariableShapeTensor dimension names size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_dimension_names() {
        let names = vec!["a".to_owned(), "b".to_owned()];
        VariableShapeTensor::try_new(DataType::Float32, 3, Some(names), None, None).unwrap();
    }

    #[test]
    #[should_panic(
        expected = "VariableShapeTensor permutations size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_permutations_len() {
        let permutations = vec![1, 0];
        VariableShapeTensor::try_new(DataType::Float32, 3, None, Some(permutations), None)
            .unwrap();
    }

    #[test]
    #[should_panic(
        expected = "VariableShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: 3"
    )]
    fn invalid_metadata_permutations_values() {
        let permutations = vec![4, 3, 2];
        VariableShapeTensor::try_new(DataType::Float32, 3, None, Some(permutations), None)
            .unwrap();
    }

    #[test]
    #[should_panic(
        expected = "VariableShapeTensor uniform shapes size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_uniform_shapes() {
        let uniform_shapes = vec![None, Some(1)];
        VariableShapeTensor::try_new(DataType::Float32, 3, None, None, Some(uniform_shapes))
            .unwrap();
    }
}
551}