arrow_schema/extension/canonical/
opaque.rs1use serde_core::ser::SerializeStruct;
23use serde_core::{
24 Deserialize, Deserializer, Serialize, Serializer,
25 de::{MapAccess, Visitor},
26};
27
28use crate::{ArrowError, DataType, extension::ExtensionType};
29
30#[derive(Debug, Clone, PartialEq)]
42pub struct Opaque(OpaqueMetadata);
43
44impl Opaque {
45 pub fn new(type_name: impl Into<String>, vendor_name: impl Into<String>) -> Self {
47 Self(OpaqueMetadata::new(type_name, vendor_name))
48 }
49
50 pub fn type_name(&self) -> &str {
52 self.0.type_name()
53 }
54
55 pub fn vendor_name(&self) -> &str {
57 self.0.vendor_name()
58 }
59}
60
61impl From<OpaqueMetadata> for Opaque {
62 fn from(value: OpaqueMetadata) -> Self {
63 Self(value)
64 }
65}
66
/// Metadata carried by the [`Opaque`] extension type.
///
/// Stored on a field as a JSON object with the string keys `type_name` and
/// `vendor_name`.
#[derive(Clone, Debug, PartialEq)]
pub struct OpaqueMetadata {
    /// Name of the type in the external system (per the Arrow Opaque spec).
    type_name: String,

    /// Name of the external system (per the Arrow Opaque spec).
    vendor_name: String,
}

77impl Serialize for OpaqueMetadata {
78 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
79 where
80 S: Serializer,
81 {
82 let mut state = serializer.serialize_struct("OpaqueMetadata", 2)?;
83 state.serialize_field("type_name", &self.type_name)?;
84 state.serialize_field("vendor_name", &self.vendor_name)?;
85 state.end()
86 }
87}
88
/// The keys recognised when deserializing [`OpaqueMetadata`] from a map.
#[derive(Debug)]
enum MetadataField {
    /// The `type_name` key.
    TypeName,
    /// The `vendor_name` key.
    VendorName,
}

95struct MetadataFieldVisitor;
96
97impl<'de> Visitor<'de> for MetadataFieldVisitor {
98 type Value = MetadataField;
99
100 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
101 formatter.write_str("`type_name` or `vendor_name`")
102 }
103
104 fn visit_str<E>(self, value: &str) -> Result<MetadataField, E>
105 where
106 E: serde_core::de::Error,
107 {
108 match value {
109 "type_name" => Ok(MetadataField::TypeName),
110 "vendor_name" => Ok(MetadataField::VendorName),
111 _ => Err(serde_core::de::Error::unknown_field(
112 value,
113 &["type_name", "vendor_name"],
114 )),
115 }
116 }
117}
118
119impl<'de> Deserialize<'de> for MetadataField {
120 fn deserialize<D>(deserializer: D) -> Result<MetadataField, D::Error>
121 where
122 D: Deserializer<'de>,
123 {
124 deserializer.deserialize_identifier(MetadataFieldVisitor)
125 }
126}
127
128struct OpaqueMetadataVisitor;
129
130impl<'de> Visitor<'de> for OpaqueMetadataVisitor {
131 type Value = OpaqueMetadata;
132
133 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
134 formatter.write_str("struct OpaqueMetadata")
135 }
136
137 fn visit_seq<V>(self, mut seq: V) -> Result<OpaqueMetadata, V::Error>
138 where
139 V: serde_core::de::SeqAccess<'de>,
140 {
141 let type_name = seq
142 .next_element()?
143 .ok_or_else(|| serde_core::de::Error::invalid_length(0, &self))?;
144 let vendor_name = seq
145 .next_element()?
146 .ok_or_else(|| serde_core::de::Error::invalid_length(1, &self))?;
147 Ok(OpaqueMetadata {
148 type_name,
149 vendor_name,
150 })
151 }
152
153 fn visit_map<V>(self, mut map: V) -> Result<OpaqueMetadata, V::Error>
154 where
155 V: MapAccess<'de>,
156 {
157 let mut type_name = None;
158 let mut vendor_name = None;
159
160 while let Some(key) = map.next_key()? {
161 match key {
162 MetadataField::TypeName => {
163 if type_name.is_some() {
164 return Err(serde_core::de::Error::duplicate_field("type_name"));
165 }
166 type_name = Some(map.next_value()?);
167 }
168 MetadataField::VendorName => {
169 if vendor_name.is_some() {
170 return Err(serde_core::de::Error::duplicate_field("vendor_name"));
171 }
172 vendor_name = Some(map.next_value()?);
173 }
174 }
175 }
176
177 let type_name =
178 type_name.ok_or_else(|| serde_core::de::Error::missing_field("type_name"))?;
179 let vendor_name =
180 vendor_name.ok_or_else(|| serde_core::de::Error::missing_field("vendor_name"))?;
181
182 Ok(OpaqueMetadata {
183 type_name,
184 vendor_name,
185 })
186 }
187}
188
189impl<'de> Deserialize<'de> for OpaqueMetadata {
190 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
191 where
192 D: Deserializer<'de>,
193 {
194 deserializer.deserialize_struct(
195 "OpaqueMetadata",
196 &["type_name", "vendor_name"],
197 OpaqueMetadataVisitor,
198 )
199 }
200}
201
202impl OpaqueMetadata {
203 pub fn new(type_name: impl Into<String>, vendor_name: impl Into<String>) -> Self {
205 OpaqueMetadata {
206 type_name: type_name.into(),
207 vendor_name: vendor_name.into(),
208 }
209 }
210
211 pub fn type_name(&self) -> &str {
213 &self.type_name
214 }
215
216 pub fn vendor_name(&self) -> &str {
218 &self.vendor_name
219 }
220}
221
222impl ExtensionType for Opaque {
223 const NAME: &'static str = "arrow.opaque";
224
225 type Metadata = OpaqueMetadata;
226
227 fn metadata(&self) -> &Self::Metadata {
228 &self.0
229 }
230
231 fn serialize_metadata(&self) -> Option<String> {
232 Some(serde_json::to_string(self.metadata()).expect("metadata serialization"))
233 }
234
235 fn deserialize_metadata(metadata: Option<&str>) -> Result<Self::Metadata, ArrowError> {
236 metadata.map_or_else(
237 || {
238 Err(ArrowError::InvalidArgumentError(
239 "Opaque extension types requires metadata".to_owned(),
240 ))
241 },
242 |value| {
243 serde_json::from_str(value).map_err(|e| {
244 ArrowError::InvalidArgumentError(format!(
245 "Opaque metadata deserialization failed: {e}"
246 ))
247 })
248 },
249 )
250 }
251
252 fn supports_data_type(&self, _data_type: &DataType) -> Result<(), ArrowError> {
253 Ok(())
255 }
256
257 fn try_new(_data_type: &DataType, metadata: Self::Metadata) -> Result<Self, ArrowError> {
258 Ok(Self::from(metadata))
259 }
260
261 fn validate(_data_type: &DataType, _metadata: Self::Metadata) -> Result<(), ArrowError> {
262 Ok(())
263 }
264}
265
#[cfg(test)]
mod tests {
    #[cfg(feature = "canonical_extension_types")]
    use crate::extension::CanonicalExtensionType;
    use crate::{
        Field,
        extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY},
    };

    use super::*;

    /// Builds a non-nullable `Null` field whose metadata holds `entries`.
    fn null_field_with_metadata(entries: &[(&str, &str)]) -> Field {
        let metadata = entries
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect();
        Field::new("", DataType::Null, false).with_metadata(metadata)
    }

    #[test]
    fn valid() -> Result<(), ArrowError> {
        let opaque = Opaque::new("name", "vendor");
        let mut field = Field::new("", DataType::Null, false);
        field.try_with_extension_type(opaque.clone())?;
        assert_eq!(field.try_extension_type::<Opaque>()?, opaque);
        #[cfg(feature = "canonical_extension_types")]
        assert_eq!(
            field.try_canonical_extension_type()?,
            CanonicalExtensionType::Opaque(opaque)
        );
        Ok(())
    }

    #[test]
    #[should_panic(expected = "Extension type name missing")]
    fn missing_name() {
        let field = null_field_with_metadata(&[(
            EXTENSION_TYPE_METADATA_KEY,
            r#"{ "type_name": "type", "vendor_name": "vendor" }"#,
        )]);
        field.extension_type::<Opaque>();
    }

    #[test]
    #[should_panic(expected = "Opaque extension types requires metadata")]
    fn missing_metadata() {
        let field = null_field_with_metadata(&[(EXTENSION_TYPE_NAME_KEY, Opaque::NAME)]);
        field.extension_type::<Opaque>();
    }

    #[test]
    #[should_panic(
        expected = "Opaque metadata deserialization failed: missing field `vendor_name`"
    )]
    fn invalid_metadata() {
        let field = null_field_with_metadata(&[
            (EXTENSION_TYPE_NAME_KEY, Opaque::NAME),
            (EXTENSION_TYPE_METADATA_KEY, r#"{ "type_name": "no-vendor" }"#),
        ]);
        field.extension_type::<Opaque>();
    }

    /// Pins the exact JSON written by `serialize_metadata` and checks that it
    /// parses back to the same metadata.
    #[test]
    fn metadata_round_trip() -> Result<(), ArrowError> {
        let opaque = Opaque::new("type", "vendor");
        let json = opaque
            .serialize_metadata()
            .expect("opaque always has metadata");
        assert_eq!(json, r#"{"type_name":"type","vendor_name":"vendor"}"#);
        let parsed = Opaque::deserialize_metadata(Some(json.as_str()))?;
        assert_eq!(&parsed, opaque.metadata());
        Ok(())
    }

    /// The visitor also accepts the positional `[type_name, vendor_name]` form.
    #[test]
    fn metadata_from_sequence() -> Result<(), ArrowError> {
        let parsed = Opaque::deserialize_metadata(Some(r#"["type", "vendor"]"#))?;
        assert_eq!(parsed, OpaqueMetadata::new("type", "vendor"));
        Ok(())
    }

    /// A repeated key is rejected rather than silently overwritten.
    #[test]
    fn metadata_duplicate_field() {
        let err = Opaque::deserialize_metadata(Some(
            r#"{ "type_name": "a", "type_name": "b", "vendor_name": "vendor" }"#,
        ))
        .unwrap_err();
        assert!(
            err.to_string().contains("duplicate field `type_name`"),
            "{err}"
        );
    }
}