HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
data_types.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #include <cstdint>
7 #include <cstring>
8 #include <string>
9 #include <type_traits>
10 #include <map>
11 #include <unordered_map>
12 #include "core/common/gsl.h"
13 #include "core/common/common.h"
14 #include "core/common/exceptions.h"
15 #include "core/framework/endian.h"
16 #include "core/framework/float8.h"
17 #include "core/framework/float16.h"
18 #include "core/graph/onnx_protobuf.h"
20 
21 struct OrtValue;
22 
23 namespace ONNX_NAMESPACE {
24 class TypeProto;
25 } // namespace ONNX_NAMESPACE
26 
27 namespace onnxruntime {
28 /// Predefined registered types
29 
#if !defined(DISABLE_ML_OPS)

// Associative containers; only the ML ops use these.
using MapStringToString = std::map<std::string, std::string>;
using MapStringToInt64 = std::map<std::string, int64_t>;
using MapStringToFloat = std::map<std::string, float>;
using MapStringToDouble = std::map<std::string, double>;

using MapInt64ToString = std::map<int64_t, std::string>;
using MapInt64ToInt64 = std::map<int64_t, int64_t>;
using MapInt64ToFloat = std::map<int64_t, float>;
using MapInt64ToDouble = std::map<int64_t, double>;

// Sequences whose elements are maps (ML ops only as well).
using VectorMapStringToFloat = std::vector<std::map<std::string, float>>;
using VectorMapInt64ToFloat = std::vector<std::map<int64_t, float>>;

#endif  // !defined(DISABLE_ML_OPS)

// Sequence aliases that exist in every build configuration.
using VectorString = std::vector<std::string>;
using VectorInt64 = std::vector<int64_t>;
50 
51 // Forward declarations
52 class DataTypeImpl;
53 class TensorTypeBase;
54 #if !defined(DISABLE_SPARSE_TENSORS)
56 #endif
58 class NonTensorTypeBase;
59 #if !defined(DISABLE_OPTIONAL_TYPE)
60 class OptionalTypeBase;
61 #endif
63 class Tensor;
64 class TensorSeq;
65 
66 // DataTypeImpl pointer as unique DataTypeImpl identifier.
67 using MLDataType = const DataTypeImpl*;
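68 // Function-pointer types used to destroy / create the objects an MLDataType describes (returned by GetDeleteFunc() / GetCreateFunc())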
69 using DeleteFunc = void (*)(void*);
70 using CreateFunc = void* (*)();
71 
72 /**
73  * \brief Base class for MLDataType
74  *
75  */
76 class DataTypeImpl {
77  public:
78  enum class GeneralType {
79  kInvalid = 0,
80  kNonTensor = 1,
81  kTensor = 2,
82  kTensorSequence = 3,
83  kSparseTensor = 4,
84  kOptional = 5,
85  kPrimitive = 6,
86  };
87 
89  const size_t size_;
90 
91  protected:
93 
94  public:
95  virtual ~DataTypeImpl() = default;
96 
97  /**
98  * \brief this API will be used to check type compatibility at runtime
99  *
100  * \param type_proto a TypeProto instance that is constructed for a specific type
101  * will be checked against a TypeProto instance contained within a corresponding
102  * MLDataType instance.
103  */
104  virtual bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const = 0;
105 
106  size_t Size() const { return size_; }
107 
108  virtual DeleteFunc GetDeleteFunc() const = 0;
109 
110  /**
111  * \brief Retrieves an instance of TypeProto for
112  * a given MLDataType
113  * \returns optional TypeProto. Only ONNX types
114  has type proto, non-ONNX types will return nullptr.
115  */
116  virtual const ONNX_NAMESPACE::TypeProto* GetTypeProto() const = 0;
117 
118  bool IsTensorType() const {
119  return type_ == GeneralType::kTensor;
120  }
121 
122  bool IsTensorSequenceType() const {
124  }
125 
126  bool IsSparseTensorType() const {
128  }
129 
130  bool IsOptionalType() const {
131  return type_ == GeneralType::kOptional;
132  }
133 
134  bool IsNonTensorType() const {
135  return type_ == GeneralType::kNonTensor;
136  }
137 
138  bool IsPrimitiveDataType() const {
139  return type_ == GeneralType::kPrimitive;
140  }
141 
142  // Returns this if this is of tensor-type and null otherwise
143  const TensorTypeBase* AsTensorType() const;
144 
146 
147 #if !defined(DISABLE_SPARSE_TENSORS)
148  // Returns this if this is of sparse-tensor-type and null otherwise
150 #endif
151 
152 #if !defined(DISABLE_OPTIONAL_TYPE)
153  const OptionalTypeBase* AsOptionalType() const;
154 #endif
155 
156  const NonTensorTypeBase* AsNonTensorType() const;
157 
158  // Returns this if this is one of the primitive data types (specialization of PrimitiveDataTypeBase)
159  // and null otherwise
161 
162  // Return the type meta that we are using in the runtime.
163  template <typename T>
164  static MLDataType GetType();
165 
166  // Return the types for a concrete tensor type, like Tensor_Float
167  template <typename elemT>
168  static MLDataType GetTensorType();
169 
170  template <typename elemT>
172 
173 #if !defined(DISABLE_SPARSE_TENSORS)
174  // Return the MLDataType for a concrete sparse tensor type.
175  template <typename elemT>
177 #endif
178 
179  template <typename T, typename elemT>
180  static MLDataType GetOptionalType();
181 
182  /**
183  * Convert an ONNX TypeProto to onnxruntime DataTypeImpl.
184  * However, this conversion is lossy. Don't try to use 'this->GetTypeProto()' converting it back.
185  * Even though GetTypeProto() will not have the original information, it will still have enough to correctly
186  * map to MLDataType.
187  * \param proto
188  */
189  static MLDataType TypeFromProto(const ONNX_NAMESPACE::TypeProto& proto);
190 
191  static const TensorTypeBase* TensorTypeFromONNXEnum(int type);
193 #if !defined(DISABLE_SPARSE_TENSORS)
195 #endif
196 
197  static const char* ToString(MLDataType type);
198  static std::vector<std::string> ToString(const std::vector<MLDataType>& types);
199  // Registers ONNX_NAMESPACE::DataType (internalized string) with
200  // MLDataType. DataType is produced by internalizing an instance of
201  // TypeProto contained within MLDataType
202  static void RegisterDataType(MLDataType);
203  static MLDataType GetDataType(const std::string&);
204 
205  // IR4: includes all float types, includes float16, bfloat16
206  // IR9: includes float 8 types as well
207  static const std::vector<MLDataType>& AllTensorTypes(); // up to IR4 (no float 8), deprecated
208  static const std::vector<MLDataType>& AllTensorTypesIRv4();
209  static const std::vector<MLDataType>& AllTensorTypesIRv9();
210 
211  static const std::vector<MLDataType>& AllFixedSizeTensorTypes(); // up to IR4 (no float 8), deprecated
212  static const std::vector<MLDataType>& AllFixedSizeTensorTypesIRv4();
213  static const std::vector<MLDataType>& AllFixedSizeTensorTypesIRv9();
214 
215  static const std::vector<MLDataType>& AllSequenceTensorTypes(); // up to IR4 (no float 8), deprecated
216  static const std::vector<MLDataType>& AllSequenceTensorTypesIRv4();
217  static const std::vector<MLDataType>& AllSequenceTensorTypesIRv9();
218 
219  static const std::vector<MLDataType>& AllFixedSizeSequenceTensorTypes(); // up to IR4 (no float 8), deprecated
220  static const std::vector<MLDataType>& AllFixedSizeSequenceTensorTypesIRv4();
221  static const std::vector<MLDataType>& AllFixedSizeSequenceTensorTypesIRv9();
222 
223  static const std::vector<MLDataType>& AllNumericTensorTypes(); // up to IR4 (no float 8), deprecated
224  static const std::vector<MLDataType>& AllNumericTensorTypesIRv4();
225  static const std::vector<MLDataType>& AllNumericTensorTypesIRv9();
226 
227  static const std::vector<MLDataType>& AllIEEEFloatTensorTypes(); // float16, float, double
228 
229  static const std::vector<MLDataType>& AllTensorAndSequenceTensorTypes(); // up to IR4 (no float 8), deprecated
230  static const std::vector<MLDataType>& AllTensorAndSequenceTensorTypesIRv4();
231  static const std::vector<MLDataType>& AllTensorAndSequenceTensorTypesIRv9();
232 
233  static const std::vector<MLDataType>& AllOptionalAndTensorAndSequenceTensorTypes(); // up to IR4 (no float 8), deprecated
234  static const std::vector<MLDataType>& AllOptionalAndTensorAndSequenceTensorTypesIRv4();
235  static const std::vector<MLDataType>& AllOptionalAndTensorAndSequenceTensorTypesIRv9();
236 
237  static const std::vector<MLDataType>& AllFixedSizeTensorAndSequenceTensorTypes(); // up to IR4 (no float 8), deprecated
238  static const std::vector<MLDataType>& AllFixedSizeTensorAndSequenceTensorTypesIRv4();
239  static const std::vector<MLDataType>& AllFixedSizeTensorAndSequenceTensorTypesIRv9();
240 
241  static const std::vector<MLDataType>& AllOptionalTypes(); // up to IR4 (no float 8), deprecated
242  static const std::vector<MLDataType>& AllOptionalTypesIRv4();
243  static const std::vector<MLDataType>& AllOptionalTypesIRv9();
244 
245  static const std::vector<MLDataType>& AllTensorAndSequenceTensorAndOptionalTypes(); // up to IR4 (no float 8), deprecated
246  static const std::vector<MLDataType>& AllTensorAndSequenceTensorAndOptionalTypesIRv4();
247  static const std::vector<MLDataType>& AllTensorAndSequenceTensorAndOptionalTypesIRv9();
248 };
249 
250 std::ostream& operator<<(std::ostream& out, MLDataType data_type);
251 
252 /*
253  * Type registration helpers
254  */
255 namespace data_types_internal {
256 /// TensorType helpers
257 ///
258 
/// Is a given type on the list of types?
/// IsAnyOf<T, Types...>::value is true iff T is exactly the same type
/// (std::is_same) as at least one entry of Types. An empty Types list yields
/// false instead of failing to compile.
///
/// Built on std::disjunction, which stops instantiating std::is_same<> after the
/// first match and exposes the usual integral_constant interface
/// (::value, ::type, conversion to bool).
template <typename T, typename... Types>
struct IsAnyOf : public std::disjunction<std::is_same<T, Types>...> {
};
275 
276 /// Tells if the specified type is one of fundamental types
277 /// that can be contained within a tensor.
278 /// We do not have raw fundamental types, rather a subset
279 /// of fundamental types is contained within tensors.
280 template <typename T>
281 struct IsTensorContainedType : public IsAnyOf<T, float, uint8_t, int8_t, uint16_t, int16_t,
282  int32_t, int64_t, std::string, bool, MLFloat16,
283  double, uint32_t, uint64_t, BFloat16
284 #if !defined(DISABLE_FLOAT8_TYPES)
285  ,
286  Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
287 #endif
288  > {
289 };
290 
291 #if !defined(DISABLE_SPARSE_TENSORS)
292 /// Use "IsSparseTensorContainedType<T>::value" to test if a type T
293 /// is permitted as the element-type of a sparse-tensor.
294 
295 template <typename T>
296 struct IsSparseTensorContainedType : public IsAnyOf<T, float, uint8_t, int8_t, uint16_t, int16_t,
297  int32_t, int64_t, std::string, bool, MLFloat16,
298  double, uint32_t, uint64_t, BFloat16
299 #if !defined(DISABLE_FLOAT8_TYPES)
300  ,
301  Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
302 #endif
303  > {
304 };
305 #endif
306 
307 #if !defined(DISABLE_OPTIONAL_TYPE)
308 /// Tells if the specified type is one of ORT types
309 /// that can be contained within an optional struct.
310 template <typename T>
311 struct IsOptionalOrtType : public IsAnyOf<T, Tensor, TensorSeq> {
312 };
313 #endif
314 
315 /// This template's Get() returns a corresponding MLDataType
316 /// It dispatches the call to either GetTensorType<>() or
317 /// GetType<>()
318 template <typename T, bool TensorContainedType>
320 
321 template <typename T>
322 struct GetMLDataType<T, true> {
323  static MLDataType Get() {
324  return DataTypeImpl::GetTensorType<T>();
325  }
326 };
327 
328 template <typename T>
329 struct GetMLDataType<T, false> {
330  static MLDataType Get() {
331  return DataTypeImpl::GetType<T>();
332  }
333 };
334 
336  static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type,
337  ONNX_NAMESPACE::TypeProto& proto) {
338  proto.mutable_tensor_type()->set_elem_type(element_type);
339  }
340 };
341 
342 #if !defined(DISABLE_SPARSE_TENSORS)
344  static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type,
345  ONNX_NAMESPACE::TypeProto& proto) {
346  proto.mutable_sparse_tensor_type()->set_elem_type(element_type);
347  }
348 };
349 #endif // !defined(DISABLE_SPARSE_TENSORS)
350 
351 #if !defined(DISABLE_ML_OPS)
352 /// Map helpers
353 
354 void CopyMutableMapValue(const ONNX_NAMESPACE::TypeProto&,
355  ONNX_NAMESPACE::TypeProto&);
356 
358  // V can be either a primitive type (in which case it is a tensor)
359  // or other preregistered types
360  template <typename V>
363  }
364 
365  static void Set(ONNX_NAMESPACE::TensorProto_DataType key_type, const ONNX_NAMESPACE::TypeProto* value_proto,
366  ONNX_NAMESPACE::TypeProto& proto) {
367  ORT_ENFORCE(value_proto != nullptr, "expected a registered ONNX type");
368  proto.mutable_map_type()->set_key_type(key_type);
369  CopyMutableMapValue(*value_proto, proto);
370  }
371 };
372 #endif
373 
374 /// Sequence helpers
375 
376 // Element type is a primitive type so we set it to a tensor<elemT>
377 void CopyMutableSeqElement(const ONNX_NAMESPACE::TypeProto&,
378  ONNX_NAMESPACE::TypeProto&);
379 
380 // helper to create TypeProto with minimal binary size impact
382  template <typename T>
385  }
386 
387  static void Set(const ONNX_NAMESPACE::TypeProto* elem_proto,
388  ONNX_NAMESPACE::TypeProto& proto) {
389  ORT_ENFORCE(elem_proto != nullptr, "expected a registered ONNX type");
390  CopyMutableSeqElement(*elem_proto, proto);
391  }
392 };
393 
394 /// Optional helpers
395 
396 void CopyMutableOptionalElement(const ONNX_NAMESPACE::TypeProto&,
397  ONNX_NAMESPACE::TypeProto&);
398 
399 // helper to create TypeProto with minimal binary size impact
401  template <typename T, typename elemT>
403  if constexpr (std::is_same<T, Tensor>::value) {
404  return DataTypeImpl::GetTensorType<elemT>();
405  } else {
406  static_assert(std::is_same<T, TensorSeq>::value, "Unsupported element type for optional type");
407  return DataTypeImpl::GetSequenceTensorType<elemT>();
408  }
409  }
410 
411  static void Set(const onnx::TypeProto* elem_proto, ONNX_NAMESPACE::TypeProto& proto) {
412  ORT_ENFORCE(elem_proto != nullptr, "expected a registered ONNX type");
413  CopyMutableOptionalElement(*elem_proto, proto);
414  }
415 };
416 
417 /// OpaqueTypes helpers
418 
419 void AssignOpaqueDomainName(const char* domain, const char* name,
420  ONNX_NAMESPACE::TypeProto& proto);
421 
422 } // namespace data_types_internal
423 
424 // The suppressed warning is: "The type with a virtual function needs either public virtual or protected nonvirtual destructor."
425 // However, we do not allocate this type on heap.
426 #if defined(_MSC_VER) && !defined(__clang__)
427 #pragma warning(push)
428 #pragma warning(disable : 26436)
429 #endif
430 /// All tensors base
431 class TensorTypeBase : public DataTypeImpl {
432  public:
433  static MLDataType Type();
434 
435  /// We first compare type_proto pointers and then
436  /// if they do not match try to account for the case
437  /// where TypeProto was created ad-hoc and not queried from MLDataType
438  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
439 
440  DeleteFunc GetDeleteFunc() const override;
441 
442  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
443 
444  virtual MLDataType GetElementType() const {
445  // should never reach here.
446  ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
447  }
448 
450 
451  protected:
452  ONNX_NAMESPACE::TypeProto& MutableTypeProto();
453 
454  TensorTypeBase();
455  ~TensorTypeBase() override;
456 
457  private:
458  struct Impl;
459  Impl* impl_;
460 };
461 
462 /**
463  * \brief Tensor type. This type does not have a C++ type associated with
464  * it at registration time except the element type. One of the types mentioned
465  * above at IsTensorContainedType<> list is acceptable.
466  *
467  * \details
468  * Usage:
469  * ORT_REGISTER_TENSOR_TYPE(ELEMENT_TYPE)
470  * Currently all of the Tensors irrespective of the dimensions are mapped to Tensor<type>
471  * type. IsCompatible() currently ignores shape.
472  */
473 
474 template <typename elemT>
475 class TensorType : public TensorTypeBase {
476  public:
478  "Requires one of the tensor fundamental types");
479 
480  static MLDataType Type();
481 
482  /// Tensors only can contain basic data types
483  /// that have been previously registered with ONNXRuntime
484  MLDataType GetElementType() const override {
485  return DataTypeImpl::GetType<elemT>();
486  }
487 
488  private:
489  TensorType() {
490  using namespace data_types_internal;
491  TensorTypeHelper::Set(utils::ToTensorProtoElementType<elemT>(), MutableTypeProto());
492  }
493 };
494 
#if defined(DISABLE_OPTIONAL_TYPE)

// TODO: is this still needed after removing kernel def hashes?
/// Shared base class for every type that is compiled out of this build.
/// DataTypeImpl::ToString has to keep working for disabled types in a minimal
/// build so that ORT format model kernel hashes stay stable; that is why
/// GetTypeProto() and MutableTypeProto() must remain functional here.
class DisabledTypeBase : public DataTypeImpl {
 public:
  static MLDataType Type();

  /// Always false. A disabled type never matches any TypeProto, so no kernel
  /// registered for it can be bound to a model node requiring that type, and
  /// loading such a model fails.
  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& /*type_proto*/) const override { return false; }

  /// Always throws; a disabled type provides no deleter.
  DeleteFunc GetDeleteFunc() const override { ORT_THROW("Type is disabled in this build."); }

  /// Required to work even though the type is disabled.
  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;

  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DisabledTypeBase);

 protected:
  /// Required to work even though the type is disabled.
  ONNX_NAMESPACE::TypeProto& MutableTypeProto();

  DisabledTypeBase(DataTypeImpl::GeneralType type, size_t size);
  ~DisabledTypeBase() override;

 private:
  struct Impl;  // defined out of line
  Impl* impl_;
};

#endif  // defined(DISABLE_OPTIONAL_TYPE)
534 
535 #if !defined(DISABLE_SPARSE_TENSORS)
536 /// Common base-class for all sparse-tensors (with different element types).
538  public:
539  static MLDataType Type();
540 
541  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
542 
543  DeleteFunc GetDeleteFunc() const override;
544 
545  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
546 
547  virtual MLDataType GetElementType() const {
548  // should never reach here.
549  ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
550  }
551 
553 
554  protected:
555  ONNX_NAMESPACE::TypeProto& MutableTypeProto();
556 
558  ~SparseTensorTypeBase() override;
559 
560  private:
561  struct Impl;
562  Impl* impl_;
563 };
564 
565 template <typename elemT>
567  public:
569  "Requires one of the sparse-tensor fundamental types");
570 
571  static MLDataType Type();
572 
573  /// Return a MLDataType representing the element-type
574  MLDataType GetElementType() const override {
575  return DataTypeImpl::GetType<elemT>();
576  }
577 
578  private:
579  SparseTensorType() {
580  using namespace data_types_internal;
581  SparseTensorTypeHelper::Set(utils::ToTensorProtoElementType<elemT>(), MutableTypeProto());
582  }
583 };
584 
585 #endif // !defined(DISABLE_SPARSE_TENSORS)
586 
587 /// Common base-class for all optional types.
588 
589 #if !defined(DISABLE_OPTIONAL_TYPE)
591  public:
592  static MLDataType Type();
593 
594  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
595 
596  DeleteFunc GetDeleteFunc() const override {
597  // should never reach here.
598  ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
599  }
600 
601  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
602 
603  virtual MLDataType GetElementType() const {
604  // should never reach here.
605  ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
606  }
607 
608  OptionalTypeBase(const OptionalTypeBase&) = delete;
609  OptionalTypeBase& operator=(const OptionalTypeBase&) = delete;
610 
611  protected:
612  ONNX_NAMESPACE::TypeProto& MutableTypeProto();
613 
615  ~OptionalTypeBase() override;
616 
617  private:
618  struct Impl;
619  Impl* impl_;
620 };
621 #endif
622 
623 // Derive from OptionalTypeBase if the Optional type support is enabled,
624 // else derive from DisabledTypeBase
625 template <typename T, typename elemT>
627 #if !defined(DISABLE_OPTIONAL_TYPE)
628  public OptionalTypeBase
629 #else
630  public DisabledTypeBase
631 #endif
632 {
633  public:
634  static MLDataType Type();
635 
636 #if !defined(DISABLE_OPTIONAL_TYPE)
638  "Requires one of the supported types: Tensor or TensorSeq");
639 
641  "Requires one of the tensor fundamental types");
642 
643  MLDataType GetElementType() const override {
644  return data_types_internal::OptionalTypeHelper::GetElemType<T, elemT>();
645  }
646 #endif
647 
648  private:
649 #if !defined(DISABLE_OPTIONAL_TYPE)
650  OptionalType()
651 #else
652  OptionalType() : DisabledTypeBase{DataTypeImpl::GeneralType::kOptional, 0}
653 #endif
654  {
655  using namespace data_types_internal;
656  OptionalTypeHelper::Set(OptionalTypeHelper::GetElemType<T, elemT>()->GetTypeProto(), MutableTypeProto());
657  }
658 }; // namespace onnxruntime
659 
660 /**
661  * \brief Provide a specialization for your C++ Non-tensor type
662  * so that its FromContainer()/ToContainer() functions work correctly.
663  * Otherwise you get the default implementation, which throws
664  * "Not implemented" and may not be what you need/want.
665  *
666  * This class is used to create OrtValue, fetch data from OrtValue via
667  * C/C++ APIs
668  */
669 template <class T>
671  static void FromContainer(MLDataType /*dtype*/, const void* /*data*/, size_t /*data_size*/, OrtValue& /*output*/) {
672  ORT_THROW("Not implemented");
673  }
674  static void ToContainer(const OrtValue& /*input*/, size_t /*data_size*/, void* /*data*/) {
675  ORT_THROW("Not implemented");
676  }
677 };
678 
679 /**
680  * \brief Base type for all non-tensors, maps, sequences and opaques
681  */
683  public:
684  DeleteFunc GetDeleteFunc() const override = 0;
685 
686  virtual CreateFunc GetCreateFunc() const = 0;
687 
688  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
689 
690  // \brief Override for Non-tensor types to initialize non-tensor CPP
691  // data representation from data. The caller of the interface
692  // should have a shared definition of the data which is used to initialize
693  // CPP data representation. This is used from C API.
694  //
695  // \param data - pointer to a data container structure non_tensor type specific
696  // \param data_size - size of the data container structure, used for rudimentary checks
697  // \param output - reference to a default constructed non-tensor type
698  // \returns OrtValue
699  // \throw if there is an error
700  virtual void FromDataContainer(const void* data, size_t data_size, OrtValue& output) const;
701 
702  // \brief Override for Non-tensor types to fetch data from the internal CPP data representation
703  // The caller of the interface should have a shared definition of the data which is used to initialize
704  // CPP data representation. This is used from C API.
705  //
706  // \param input - OrtValue containing data
707  // \param data_size - size of the structure that is being passed for receiving data, used for
708  // validation
709  // \param data - pointer to receiving data structure
710  virtual void ToDataContainer(const OrtValue& input, size_t data_size, void* data) const;
711 
712  NonTensorTypeBase(const NonTensorTypeBase&) = delete;
714 
715  protected:
716  NonTensorTypeBase(size_t size);
717  ~NonTensorTypeBase() override;
718 
719  ONNX_NAMESPACE::TypeProto& MutableTypeProto();
720 
721  bool IsMapCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const;
722 
723  bool IsSequenceCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const;
724 
725  bool IsOpaqueCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const;
726 
727  private:
728  struct Impl;
729  Impl* impl_;
730 };
731 
732 // This is where T is the actual CPPRuntimeType
733 template <typename T>
735  private:
736  static void Delete(void* p) {
737  delete static_cast<T*>(p);
738  }
739 
740  public:
741  DeleteFunc GetDeleteFunc() const override {
742  return &Delete;
743  }
744 
745  CreateFunc GetCreateFunc() const override {
746  return []() -> void* { return new T(); };
747  }
748 
749  protected:
751 };
752 
753 #if !defined(DISABLE_ML_OPS)
754 /**
755  * \brief MapType. Use this type to register
756  * mapping types.
757  *
758  * \tparam CPPType - C++ type that you wish to register as runtime MapType
759  *
760  * \details Usage: ORT_REGISTER_MAP(C++Type)
761  * The type is required to have mapped_type and
762  * key_type defined
763  */
764 template <typename CPPType>
765 class MapType : public NonTensorType<CPPType> {
766  public:
768  "Requires one of the tensor fundamental types as key");
769 
770  static MLDataType Type();
771 
772  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override {
773  return this->IsMapCompatible(type_proto);
774  }
775 
776  private:
777  MapType() {
778  using namespace data_types_internal;
779  MapTypeHelper::Set(utils::ToTensorProtoElementType<typename CPPType::key_type>(),
780  MapTypeHelper::GetValueType<typename CPPType::mapped_type>()->GetTypeProto(),
781  this->MutableTypeProto());
782  }
783 };
784 #endif
785 
786 /**
787  * \brief SequenceType. Use to register sequence for non-tensor types.
788  *
789  * \param T - CPP type that you wish to register as Sequence
790  * runtime type.
791  *
792  * \details Usage: ORT_REGISTER_SEQ(C++Type)
793  * The type is required to have value_type defined
794  */
795 template <typename CPPType>
796 class SequenceType : public NonTensorType<CPPType> {
797  public:
798  static MLDataType Type();
799 
800  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override {
801  return this->IsSequenceCompatible(type_proto);
802  }
803 
804  private:
805  SequenceType() {
806  using namespace data_types_internal;
807  SequenceTypeHelper::Set(SequenceTypeHelper::GetElemType<typename CPPType::value_type>()->GetTypeProto(),
808  this->MutableTypeProto());
809  }
810 };
811 
812 /**
813  * \brief SequenceTensorTypeBase serves as a base type class for
814  * Tensor sequences. Akin to TensorTypeBase.
815  * Runtime representation is always TensorSeq.
816  */
818  public:
819  static MLDataType Type();
820 
821  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
822 
823  virtual MLDataType GetElementType() const {
824  // should never reach here.
825  ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
826  }
827 
828  DeleteFunc GetDeleteFunc() const override;
829 
830  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
831 
834 
835  protected:
838 
839  ONNX_NAMESPACE::TypeProto& MutableTypeProto();
840 
841  private:
842  struct Impl;
843  Impl* impl_;
844 };
845 #if defined(_MSC_VER) && !defined(__clang__)
846 #pragma warning(pop)
847 #endif
848 /**
849  * \brief SequenceTensorType. Use to register a sequence of tensors with a given element type.
850  *
851  * \param CPPRuntime - We always use TensorSeq
852  *
853  * \param TensorElemType - one of the primitive types
854  *
855  * \details Usage: ORT_REGISTER_SEQ_TENSOR_TYPE(ELEM_TYPE)
856  * TensorElemType must be one of the tensor fundamental types
857  */
858 template <typename TensorElemType>
860  public:
862  "Requires one of the tensor fundamental types");
863 
864  static MLDataType Type();
865 
866  /// Return a MLDataType representing the element-type
867  MLDataType GetElementType() const override {
868  return DataTypeImpl::GetType<TensorElemType>();
869  }
870 
871  private:
873  using namespace data_types_internal;
874  SequenceTypeHelper::Set(SequenceTypeHelper::GetElemType<TensorElemType>()->GetTypeProto(),
875  MutableTypeProto());
876  }
877 };
878 
879 /**
880  * \brief OpaqueType
881  *
882  * \tparam T - C++ runtime type that implements the Opaque type
883  *
884  * \tparam const char D[] - domain must be extern to be unique
885  *
886  * \tparam const char N[] - name must be extern to be unique
887  *
888  * \details Only one CPP type can be associated with a particular
889  * OpaqueType registration
890  *
891  */
892 template <typename T, const char D[], const char N[]>
893 class OpaqueType : public NonTensorType<T> {
894  public:
895  static MLDataType Type();
896 
897  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override {
898  return this->IsOpaqueCompatible(type_proto);
899  }
900 
901  void FromDataContainer(const void* data, size_t data_size, OrtValue& output) const override {
902  NonTensorTypeConverter<T>::FromContainer(this, data, data_size, output);
903  }
904 
905  void ToDataContainer(const OrtValue& input, size_t data_size, void* data) const override {
906  NonTensorTypeConverter<T>::ToContainer(input, data_size, data);
907  }
908 
909  private:
910  OpaqueType() {
912  }
913 };
914 
915 /**
916  * \brief PrimitiveDataTypeBase
917  * Base class for primitive Tensor contained types
918  *
919  * \details This class contains an integer constant that can be
920  * used for input data type dispatching
921  *
922  */
924  public:
925  bool IsCompatible(const ONNX_NAMESPACE::TypeProto&) const override {
926  return false;
927  }
928 
929  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const final {
930  return nullptr;
931  }
932 
933  int32_t GetDataType() const {
934  return data_type_;
935  }
936 
937  protected:
938  PrimitiveDataTypeBase(size_t size, int32_t data_type)
939  : DataTypeImpl{GeneralType::kPrimitive, size}, data_type_{data_type} {}
940 
941  private:
942  const int32_t data_type_;
943 };
944 
945 /**
946  * \brief PrimitiveDataType
947  * Typed specialization for primitive types.
948  * Concrete instances of this class are used by Tensor.
949  *
950  * \tparam T - primitive data type
951  *
952  */
953 template <typename T>
955  private:
956  static void Delete(void* p) {
957  delete static_cast<T*>(p);
958  }
959 
960  public:
961  static MLDataType Type();
962 
963  DeleteFunc GetDeleteFunc() const override {
964  return &Delete;
965  }
966 
967  private:
969  : PrimitiveDataTypeBase{sizeof(T),
970  utils::ToTensorProtoElementType<T>()} {
971  }
972 };
973 
975  return IsTensorType() ? static_cast<const TensorTypeBase*>(this) : nullptr;
976 }
977 
979  return IsTensorSequenceType() ? static_cast<const SequenceTensorTypeBase*>(this) : nullptr;
980 }
981 
982 #if !defined(DISABLE_SPARSE_TENSORS)
984  return IsSparseTensorType() ? static_cast<const SparseTensorTypeBase*>(this) : nullptr;
985 }
986 #endif
987 
988 #if !defined(DISABLE_OPTIONAL_TYPE)
990  return IsOptionalType() ? static_cast<const OptionalTypeBase*>(this) : nullptr;
991 }
992 #endif
993 
995  return IsNonTensorType() ? static_cast<const NonTensorTypeBase*>(this) : nullptr;
996 }
997 
999  return IsPrimitiveDataType() ? static_cast<const PrimitiveDataTypeBase*>(this) : nullptr;
1000 }
1001 
// A member of a class template can only be explicitly specialized at the
// enclosing namespace scope, so a template cannot simply be pre-instantiated at
// registration time. The macros below emit the needed specializations instead.
// In each macro the Type() specialization is emitted before the Get*Type()
// specialization that calls it.

/// Defines TensorType<ELEM_TYPE>::Type() (returning a function-local singleton)
/// and routes DataTypeImpl::GetTensorType<ELEM_TYPE>() to it.
#define ORT_REGISTER_TENSOR_TYPE(ELEM_TYPE) \
  template <>                               \
  MLDataType TensorType<ELEM_TYPE>::Type() { \
    static TensorType<ELEM_TYPE> instance;  \
    return &instance;                       \
  }                                         \
  template <>                               \
  MLDataType DataTypeImpl::GetTensorType<ELEM_TYPE>() { \
    return TensorType<ELEM_TYPE>::Type();   \
  }
1017 
#if !defined(DISABLE_SPARSE_TENSORS)
/// Defines SparseTensorType<ELEM_TYPE>::Type() (function-local singleton) and
/// routes DataTypeImpl::GetSparseTensorType<ELEM_TYPE>() to it.
#define ORT_REGISTER_SPARSE_TENSOR_TYPE(ELEM_TYPE) \
  template <>                                      \
  MLDataType SparseTensorType<ELEM_TYPE>::Type() { \
    static SparseTensorType<ELEM_TYPE> instance;   \
    return &instance;                              \
  }                                                \
  template <>                                      \
  MLDataType DataTypeImpl::GetSparseTensorType<ELEM_TYPE>() { \
    return SparseTensorType<ELEM_TYPE>::Type();    \
  }
#endif  // !defined(DISABLE_SPARSE_TENSORS)
1030 
/// Defines OptionalType<ORT_TYPE, TYPE>::Type() (function-local singleton) and
/// routes DataTypeImpl::GetOptionalType<ORT_TYPE, TYPE>() to it.
#define ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, TYPE)     \
  template <>                                          \
  MLDataType OptionalType<ORT_TYPE, TYPE>::Type() {    \
    static OptionalType<ORT_TYPE, TYPE> instance;      \
    return &instance;                                  \
  }                                                    \
  template <>                                          \
  MLDataType DataTypeImpl::GetOptionalType<ORT_TYPE, TYPE>() { \
    return OptionalType<ORT_TYPE, TYPE>::Type();       \
  }
1041 
#if !defined(DISABLE_ML_OPS)
/// Defines MapType<TYPE>::Type() (function-local singleton) and routes the
/// generic DataTypeImpl::GetType<TYPE>() to it.
#define ORT_REGISTER_MAP(TYPE)            \
  template <>                             \
  MLDataType MapType<TYPE>::Type() {      \
    static MapType<TYPE> instance;        \
    return &instance;                     \
  }                                       \
  template <>                             \
  MLDataType DataTypeImpl::GetType<TYPE>() { \
    return MapType<TYPE>::Type();         \
  }
#endif  // !defined(DISABLE_ML_OPS)
1054 
/// Defines SequenceType<TYPE>::Type() (function-local singleton) and routes the
/// generic DataTypeImpl::GetType<TYPE>() to it.
#define ORT_REGISTER_SEQ(TYPE)              \
  template <>                               \
  MLDataType SequenceType<TYPE>::Type() {   \
    static SequenceType<TYPE> instance;     \
    return &instance;                       \
  }                                         \
  template <>                               \
  MLDataType DataTypeImpl::GetType<TYPE>() { \
    return SequenceType<TYPE>::Type();      \
  }
1065 
/// Defines SequenceTensorType<ELEM_TYPE>::Type() (function-local singleton) and
/// routes DataTypeImpl::GetSequenceTensorType<ELEM_TYPE>() to it.
#define ORT_REGISTER_SEQ_TENSOR_TYPE(ELEM_TYPE)       \
  template <>                                         \
  MLDataType SequenceTensorType<ELEM_TYPE>::Type() {  \
    static SequenceTensorType<ELEM_TYPE> instance;    \
    return &instance;                                 \
  }                                                   \
  template <>                                         \
  MLDataType DataTypeImpl::GetSequenceTensorType<ELEM_TYPE>() { \
    return SequenceTensorType<ELEM_TYPE>::Type();     \
  }
1076 
/// Defines PrimitiveDataType<TYPE>::Type() (function-local singleton) and routes
/// the generic DataTypeImpl::GetType<TYPE>() to it.
#define ORT_REGISTER_PRIM_TYPE(TYPE)            \
  template <>                                   \
  MLDataType PrimitiveDataType<TYPE>::Type() {  \
    static PrimitiveDataType<TYPE> instance;    \
    return &instance;                           \
  }                                             \
  template <>                                   \
  MLDataType DataTypeImpl::GetType<TYPE>() {    \
    return PrimitiveDataType<TYPE>::Type();     \
  }
1087 
/// Defines OpaqueType<CPPType, Domain, Name>::Type() (function-local singleton)
/// and routes the generic DataTypeImpl::GetType<CPPType>() to it.
#define ORT_REGISTER_OPAQUE_TYPE(CPPType, Domain, Name)      \
  template <>                                               \
  MLDataType OpaqueType<CPPType, Domain, Name>::Type() {    \
    static OpaqueType<CPPType, Domain, Name> instance;      \
    return &instance;                                       \
  }                                                         \
  template <>                                               \
  MLDataType DataTypeImpl::GetType<CPPType>() {             \
    return OpaqueType<CPPType, Domain, Name>::Type();       \
  }
1098 } // namespace onnxruntime
void AssignOpaqueDomainName(const char *domain, const char *name, ONNX_NAMESPACE::TypeProto &proto)
OpaqueTypes helpers.
std::vector< int64_t > VectorInt64
Definition: data_types.h:49
static void RegisterDataType(MLDataType)
static const TensorTypeBase * TensorTypeFromONNXEnum(int type)
virtual const ONNX_NAMESPACE::TypeProto * GetTypeProto() const =0
Retrieves an instance of TypeProto for a given MLDataType.
static const char * ToString(MLDataType type)
bool IsTensorSequenceType() const
Definition: data_types.h:122
Base class for MLDataType.
Definition: data_types.h:76
static const std::vector< MLDataType > & AllNumericTensorTypesIRv9()
size_t Size() const
Definition: data_types.h:106
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const override
this API will be used to check type compatibility at runtime
Definition: data_types.h:772
virtual MLDataType GetElementType() const
Definition: data_types.h:547
static void ToContainer(const OrtValue &, size_t, void *)
Definition: data_types.h:674
static const std::vector< MLDataType > & AllNumericTensorTypesIRv4()
void
Definition: png.h:1083
void FromDataContainer(const void *data, size_t data_size, OrtValue &output) const override
Definition: data_types.h:901
static const std::vector< MLDataType > & AllOptionalTypesIRv9()
SequenceTensorTypeBase serves as a base type class for Tensor sequences. Akin to TensorTypeBase. Runtime representation is always TensorSeq.
Definition: data_types.h:817
ONNX_NAMESPACE::TypeProto & MutableTypeProto()
static void FromContainer(MLDataType, const void *, size_t, OrtValue &)
Definition: data_types.h:671
static const std::vector< MLDataType > & AllOptionalAndTensorAndSequenceTensorTypes()
GLsizei const GLchar *const * string
Definition: glcorearb.h:814
void ToDataContainer(const OrtValue &input, size_t data_size, void *data) const override
Definition: data_types.h:905
GLsizei const GLfloat * value
Definition: glcorearb.h:824
std::map< std::string, float > MapStringToFloat
Definition: data_types.h:35
virtual bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const =0
this API will be used to check type compatibility at runtime
static const std::vector< MLDataType > & AllTensorTypesIRv9()
static MLDataType GetTensorType()
std::map< int64_t, int64_t > MapInt64ToInt64
Definition: data_types.h:38
#define ORT_NOT_IMPLEMENTED(...)
Definition: common.h:166
MapType. Use this type to register mapping types.
Definition: data_types.h:765
void CopyMutableSeqElement(const ONNX_NAMESPACE::TypeProto &, ONNX_NAMESPACE::TypeProto &)
Sequence helpers.
static const std::vector< MLDataType > & AllTensorAndSequenceTensorAndOptionalTypesIRv4()
#define ORT_ENFORCE(condition,...)
Definition: common.h:172
MLDataType GetElementType() const override
Return a MLDataType representing the element-type.
Definition: data_types.h:574
CreateFunc GetCreateFunc() const override
Definition: data_types.h:745
static const std::vector< MLDataType > & AllOptionalTypes()
const ONNX_NAMESPACE::TypeProto * GetTypeProto() const final
Retrieves an instance of TypeProto for a given MLDataType.
Definition: data_types.h:929
DeleteFunc GetDeleteFunc() const override
Definition: data_types.h:596
virtual MLDataType GetElementType() const
Definition: data_types.h:823
Common base-class for all sparse-tensors (with different element types).
Definition: data_types.h:537
static const std::vector< MLDataType > & AllTensorTypesIRv4()
static const std::vector< MLDataType > & AllOptionalTypesIRv4()
static MLDataType GetSequenceTensorType()
void CopyMutableOptionalElement(const ONNX_NAMESPACE::TypeProto &, ONNX_NAMESPACE::TypeProto &)
Optional helpers.
const ONNX_NAMESPACE::TypeProto * GetTypeProto() const override
Retrieves an instance of TypeProto for a given MLDataType.
static const std::vector< MLDataType > & AllTensorAndSequenceTensorTypesIRv9()
static MLDataType Type()
static const std::vector< MLDataType > & AllIEEEFloatTensorTypes()
static const std::vector< MLDataType > & AllFixedSizeTensorTypesIRv4()
static const std::vector< MLDataType > & AllFixedSizeTensorTypes()
static void Set(const onnx::TypeProto *elem_proto, ONNX_NAMESPACE::TypeProto &proto)
Definition: data_types.h:411
static const std::vector< MLDataType > & AllTensorTypes()
const SparseTensorTypeBase * AsSparseTensorType() const
Definition: data_types.h:983
All tensors base.
Definition: data_types.h:431
static MLDataType GetSparseTensorType()
PrimitiveDataTypeBase(size_t size, int32_t data_type)
Definition: data_types.h:938
static const std::vector< MLDataType > & AllSequenceTensorTypesIRv4()
static const std::vector< MLDataType > & AllTensorAndSequenceTensorTypes()
static const std::vector< MLDataType > & AllSequenceTensorTypes()
std::map< int64_t, float > MapInt64ToFloat
Definition: data_types.h:39
void(*)(void *) DeleteFunc
Definition: data_types.h:69
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const override
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorTypeBase)
static void Set(const ONNX_NAMESPACE::TypeProto *elem_proto, ONNX_NAMESPACE::TypeProto &proto)
Definition: data_types.h:387
virtual MLDataType GetElementType() const
Definition: data_types.h:444
static const std::vector< MLDataType > & AllFixedSizeSequenceTensorTypes()
std::map< std::string, int64_t > MapStringToInt64
Definition: data_types.h:34
const GeneralType type_
Definition: data_types.h:88
const NonTensorTypeBase * AsNonTensorType() const
Definition: data_types.h:994
Common base-class for all optional types.
Definition: data_types.h:590
static const std::vector< MLDataType > & AllTensorAndSequenceTensorTypesIRv4()
PrimitiveDataType Typed specialization for primitive types. Concrete instances of this class are used...
Definition: data_types.h:954
bool IsNonTensorType() const
Definition: data_types.h:134
const TensorTypeBase * AsTensorType() const
Definition: data_types.h:974
const PrimitiveDataTypeBase * AsPrimitiveDataType() const
Definition: data_types.h:998
bool IsTensorType() const
Definition: data_types.h:118
MLDataType GetElementType() const override
Definition: data_types.h:643
static MLDataType GetDataType(const std::string &)
DeleteFunc GetDeleteFunc() const override
STATIC_INLINE uint64_t H(uint64_t x, uint64_t y, uint64_t mul, int r)
Definition: farmhash.h:701
PrimitiveDataTypeBase Base class for primitive Tensor contained types.
Definition: data_types.h:923
GLuint const GLchar * name
Definition: glcorearb.h:786
std::map< std::string, std::string > MapStringToString
Predefined registered types.
Definition: data_types.h:33
static const SequenceTensorTypeBase * SequenceTensorTypeFromONNXEnum(int type)
void CopyMutableMapValue(const ONNX_NAMESPACE::TypeProto &, ONNX_NAMESPACE::TypeProto &)
Map helpers.
static MLDataType GetOptionalType()
void *(*)() CreateFunc
Definition: data_types.h:70
const DataTypeImpl * MLDataType
Definition: data_types.h:67
static const std::vector< MLDataType > & AllSequenceTensorTypesIRv9()
static const std::vector< MLDataType > & AllTensorAndSequenceTensorAndOptionalTypes()
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const override
this API will be used to check type compatibility at runtime
Definition: data_types.h:897
std::map< int64_t, double > MapInt64ToDouble
Definition: data_types.h:40
#define ORT_THROW(...)
Definition: common.h:162
static const std::vector< MLDataType > & AllFixedSizeSequenceTensorTypesIRv9()
static const std::vector< MLDataType > & AllNumericTensorTypes()
GLsizeiptr size
Definition: glcorearb.h:664
virtual MLDataType GetElementType() const
Definition: data_types.h:603
std::map< int64_t, std::string > MapInt64ToString
Definition: data_types.h:37
SequenceType. Use to register sequence for non-tensor types.
Definition: data_types.h:796
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const override
this API will be used to check type compatibility at runtime
Definition: data_types.h:800
const OptionalTypeBase * AsOptionalType() const
Definition: data_types.h:989
SequenceTensorType. Use to register sequences of tensors with a given element type.
Definition: data_types.h:859
LeafData & operator=(const LeafData &)=delete
DataTypeImpl(GeneralType type, size_t size)
Definition: data_types.h:92
bool IsSparseTensorType() const
Definition: data_types.h:126
MLDataType GetElementType() const override
Definition: data_types.h:484
virtual ~DataTypeImpl()=default
std::map< std::string, double > MapStringToDouble
Definition: data_types.h:36
GA_API const UT_StringHolder N
static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type, ONNX_NAMESPACE::TypeProto &proto)
Definition: data_types.h:336
static const std::vector< MLDataType > & AllFixedSizeTensorTypesIRv9()
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &) const override
this API will be used to check type compatibility at runtime
Definition: data_types.h:925
static const std::vector< MLDataType > & AllFixedSizeTensorAndSequenceTensorTypes()
Tensor type. This type does not have a C++ type associated with it at registration time except the el...
Definition: data_types.h:475
static const std::vector< MLDataType > & AllOptionalAndTensorAndSequenceTensorTypesIRv4()
static const std::vector< MLDataType > & AllFixedSizeTensorAndSequenceTensorTypesIRv4()
DeleteFunc GetDeleteFunc() const override
Definition: data_types.h:741
DeleteFunc GetDeleteFunc() const override
Definition: data_types.h:963
bool IsPrimitiveDataType() const
Definition: data_types.h:138
Definition: core.h:1131
static const std::vector< MLDataType > & AllOptionalAndTensorAndSequenceTensorTypesIRv9()
static const std::vector< MLDataType > & AllTensorAndSequenceTensorAndOptionalTypesIRv9()
static const std::vector< MLDataType > & AllFixedSizeSequenceTensorTypesIRv4()
GLsizei GLenum GLenum * types
Definition: glcorearb.h:2542
static MLDataType Type()
std::ostream & operator<<(std::ostream &out, AllocKind alloc_kind)
Base type for all non-tensors, maps, sequences and opaques.
Definition: data_types.h:682
const SequenceTensorTypeBase * AsSequenceTensorType() const
Definition: data_types.h:978
std::vector< std::string > VectorString
Definition: data_types.h:48
static MLDataType TypeFromProto(const ONNX_NAMESPACE::TypeProto &proto)
static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type, ONNX_NAMESPACE::TypeProto &proto)
Definition: data_types.h:344
static void Set(ONNX_NAMESPACE::TensorProto_DataType key_type, const ONNX_NAMESPACE::TypeProto *value_proto, ONNX_NAMESPACE::TypeProto &proto)
Definition: data_types.h:365
type
Definition: core.h:1059
std::vector< MapInt64ToFloat > VectorMapInt64ToFloat
Definition: data_types.h:44
static const SparseTensorTypeBase * SparseTensorTypeFromONNXEnum(int type)
std::vector< MapStringToFloat > VectorMapStringToFloat
Definition: data_types.h:43
static const std::vector< MLDataType > & AllFixedSizeTensorAndSequenceTensorTypesIRv9()
virtual DeleteFunc GetDeleteFunc() const =0
Provide a specialization for your C++ Non-tensor type so your implementation FromDataTypeContainer/To...
Definition: data_types.h:670
Definition: format.h:895
bool IsOptionalType() const
Definition: data_types.h:130
MLDataType GetElementType() const override
Return a MLDataType representing the element-type.
Definition: data_types.h:867
static MLDataType GetType()