11 #include <unordered_map>
18 #include "core/graph/onnx_protobuf.h"
23 namespace ONNX_NAMESPACE {
27 namespace onnxruntime {
30 #if !defined(DISABLE_ML_OPS)
54 #if !defined(DISABLE_SPARSE_TENSORS)
59 #if !defined(DISABLE_OPTIONAL_TYPE)
104 virtual bool IsCompatible(
const ONNX_NAMESPACE::TypeProto& type_proto)
const = 0;
116 virtual const ONNX_NAMESPACE::TypeProto*
GetTypeProto()
const = 0;
147 #if !defined(DISABLE_SPARSE_TENSORS)
152 #if !defined(DISABLE_OPTIONAL_TYPE)
163 template <
typename T>
167 template <
typename elemT>
170 template <
typename elemT>
173 #if !defined(DISABLE_SPARSE_TENSORS)
175 template <
typename elemT>
179 template <
typename T,
typename elemT>
193 #if !defined(DISABLE_SPARSE_TENSORS)
198 static std::vector<std::string>
ToString(
const std::vector<MLDataType>&
types);
255 namespace data_types_internal {
262 template <
typename T,
typename... Types>
266 template <
typename T,
typename Tail>
267 struct IsAnyOf<
T, Tail> :
public std::is_same<T, Tail> {
270 template <
typename T,
typename H,
typename... Tail>
280 template <
typename T>
282 int32_t, int64_t, std::string, bool, MLFloat16,
283 double, uint32_t, uint64_t, BFloat16
284 #if !defined(DISABLE_FLOAT8_TYPES)
286 Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
291 #if !defined(DISABLE_SPARSE_TENSORS)
295 template <
typename T>
297 int32_t, int64_t, std::string, bool, MLFloat16,
298 double, uint32_t, uint64_t, BFloat16
299 #if !defined(DISABLE_FLOAT8_TYPES)
301 Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
307 #if !defined(DISABLE_OPTIONAL_TYPE)
310 template <
typename T>
318 template <
typename T,
bool TensorContainedType>
321 template <
typename T>
324 return DataTypeImpl::GetTensorType<T>();
328 template <
typename T>
331 return DataTypeImpl::GetType<T>();
336 static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type,
337 ONNX_NAMESPACE::TypeProto& proto) {
338 proto.mutable_tensor_type()->set_elem_type(element_type);
342 #if !defined(DISABLE_SPARSE_TENSORS)
344 static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type,
345 ONNX_NAMESPACE::TypeProto& proto) {
346 proto.mutable_sparse_tensor_type()->set_elem_type(element_type);
349 #endif // !defined(DISABLE_SPARSE_TENSORS)
351 #if !defined(DISABLE_ML_OPS)
355 ONNX_NAMESPACE::TypeProto&);
360 template <
typename V>
365 static void Set(ONNX_NAMESPACE::TensorProto_DataType key_type,
const ONNX_NAMESPACE::TypeProto* value_proto,
366 ONNX_NAMESPACE::TypeProto& proto) {
367 ORT_ENFORCE(value_proto !=
nullptr,
"expected a registered ONNX type");
368 proto.mutable_map_type()->set_key_type(key_type);
378 ONNX_NAMESPACE::TypeProto&);
382 template <
typename T>
387 static void Set(
const ONNX_NAMESPACE::TypeProto* elem_proto,
388 ONNX_NAMESPACE::TypeProto& proto) {
389 ORT_ENFORCE(elem_proto !=
nullptr,
"expected a registered ONNX type");
397 ONNX_NAMESPACE::TypeProto&);
401 template <
typename T,
typename elemT>
404 return DataTypeImpl::GetTensorType<elemT>();
407 return DataTypeImpl::GetSequenceTensorType<elemT>();
411 static void Set(
const onnx::TypeProto* elem_proto, ONNX_NAMESPACE::TypeProto& proto) {
412 ORT_ENFORCE(elem_proto !=
nullptr,
"expected a registered ONNX type");
420 ONNX_NAMESPACE::TypeProto& proto);
426 #if defined(_MSC_VER) && !defined(__clang__)
427 #pragma warning(push)
428 #pragma warning(disable : 26436)
438 bool IsCompatible(
const ONNX_NAMESPACE::TypeProto& type_proto)
const override;
442 const ONNX_NAMESPACE::TypeProto*
GetTypeProto()
const override;
474 template <
typename elemT>
478 "Requires one of the tensor fundamental types");
485 return DataTypeImpl::GetType<elemT>();
490 using namespace data_types_internal;
491 TensorTypeHelper::Set(utils::ToTensorProtoElementType<elemT>(),
MutableTypeProto());
495 #if defined(DISABLE_OPTIONAL_TYPE)
504 bool IsCompatible(
const ONNX_NAMESPACE::TypeProto&)
const override {
513 ORT_THROW(
"Type is disabled in this build.");
517 const ONNX_NAMESPACE::TypeProto*
GetTypeProto()
const override;
526 ~DisabledTypeBase()
override;
535 #if !defined(DISABLE_SPARSE_TENSORS)
541 bool IsCompatible(
const ONNX_NAMESPACE::TypeProto& type_proto)
const override;
545 const ONNX_NAMESPACE::TypeProto*
GetTypeProto()
const override;
565 template <
typename elemT>
569 "Requires one of the sparse-tensor fundamental types");
575 return DataTypeImpl::GetType<elemT>();
580 using namespace data_types_internal;
581 SparseTensorTypeHelper::Set(utils::ToTensorProtoElementType<elemT>(),
MutableTypeProto());
585 #endif // !defined(DISABLE_SPARSE_TENSORS)
589 #if !defined(DISABLE_OPTIONAL_TYPE)
594 bool IsCompatible(
const ONNX_NAMESPACE::TypeProto& type_proto)
const override;
601 const ONNX_NAMESPACE::TypeProto*
GetTypeProto()
const override;
625 template <
typename T,
typename elemT>
627 #if !defined(DISABLE_OPTIONAL_TYPE)
630 public DisabledTypeBase
636 #if !defined(DISABLE_OPTIONAL_TYPE)
638 "Requires one of the supported types: Tensor or TensorSeq");
641 "Requires one of the tensor fundamental types");
644 return data_types_internal::OptionalTypeHelper::GetElemType<T, elemT>();
649 #if !defined(DISABLE_OPTIONAL_TYPE)
655 using namespace data_types_internal;
688 const ONNX_NAMESPACE::TypeProto*
GetTypeProto()
const override;
700 virtual void FromDataContainer(
const void*
data,
size_t data_size,
OrtValue& output)
const;
710 virtual void ToDataContainer(
const OrtValue& input,
size_t data_size,
void* data)
const;
721 bool IsMapCompatible(
const ONNX_NAMESPACE::TypeProto& type_proto)
const;
723 bool IsSequenceCompatible(
const ONNX_NAMESPACE::TypeProto& type_proto)
const;
725 bool IsOpaqueCompatible(
const ONNX_NAMESPACE::TypeProto& type_proto)
const;
733 template <
typename T>
736 static void Delete(
void* p) {
737 delete static_cast<T*
>(p);
746 return []() ->
void* {
return new T(); };
753 #if !defined(DISABLE_ML_OPS)
764 template <
typename CPPType>
768 "Requires one of the tensor fundamental types as key");
772 bool IsCompatible(
const ONNX_NAMESPACE::TypeProto& type_proto)
const override {
773 return this->IsMapCompatible(type_proto);
778 using namespace data_types_internal;
779 MapTypeHelper::Set(utils::ToTensorProtoElementType<typename CPPType::key_type>(),
780 MapTypeHelper::GetValueType<typename CPPType::mapped_type>()->
GetTypeProto(),
795 template <
typename CPPType>
800 bool IsCompatible(
const ONNX_NAMESPACE::TypeProto& type_proto)
const override {
801 return this->IsSequenceCompatible(type_proto);
806 using namespace data_types_internal;
807 SequenceTypeHelper::Set(SequenceTypeHelper::GetElemType<typename CPPType::value_type>()->
GetTypeProto(),
821 bool IsCompatible(
const ONNX_NAMESPACE::TypeProto& type_proto)
const override;
830 const ONNX_NAMESPACE::TypeProto*
GetTypeProto()
const override;
845 #if defined(_MSC_VER) && !defined(__clang__)
858 template <
typename TensorElemType>
862 "Requires one of the tensor fundamental types");
868 return DataTypeImpl::GetType<TensorElemType>();
873 using namespace data_types_internal;
874 SequenceTypeHelper::Set(SequenceTypeHelper::GetElemType<TensorElemType>()->
GetTypeProto(),
892 template <
typename T, const
char D[], const
char N[]>
897 bool IsCompatible(
const ONNX_NAMESPACE::TypeProto& type_proto)
const override {
898 return this->IsOpaqueCompatible(type_proto);
942 const int32_t data_type_;
953 template <
typename T>
956 static void Delete(
void* p) {
957 delete static_cast<T*
>(p);
970 utils::ToTensorProtoElementType<T>()} {
982 #if !defined(DISABLE_SPARSE_TENSORS)
988 #if !defined(DISABLE_OPTIONAL_TYPE)
1007 #define ORT_REGISTER_TENSOR_TYPE(ELEM_TYPE) \
1009 MLDataType TensorType<ELEM_TYPE>::Type() { \
1010 static TensorType<ELEM_TYPE> tensor_type; \
1011 return &tensor_type; \
1014 MLDataType DataTypeImpl::GetTensorType<ELEM_TYPE>() { \
1015 return TensorType<ELEM_TYPE>::Type(); \
1018 #if !defined(DISABLE_SPARSE_TENSORS)
1019 #define ORT_REGISTER_SPARSE_TENSOR_TYPE(ELEM_TYPE) \
1021 MLDataType SparseTensorType<ELEM_TYPE>::Type() { \
1022 static SparseTensorType<ELEM_TYPE> tensor_type; \
1023 return &tensor_type; \
1026 MLDataType DataTypeImpl::GetSparseTensorType<ELEM_TYPE>() { \
1027 return SparseTensorType<ELEM_TYPE>::Type(); \
1031 #define ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, TYPE) \
1033 MLDataType OptionalType<ORT_TYPE, TYPE>::Type() { \
1034 static OptionalType<ORT_TYPE, TYPE> optional_type; \
1035 return &optional_type; \
1038 MLDataType DataTypeImpl::GetOptionalType<ORT_TYPE, TYPE>() { \
1039 return OptionalType<ORT_TYPE, TYPE>::Type(); \
1042 #if !defined(DISABLE_ML_OPS)
1043 #define ORT_REGISTER_MAP(TYPE) \
1045 MLDataType MapType<TYPE>::Type() { \
1046 static MapType<TYPE> map_type; \
1050 MLDataType DataTypeImpl::GetType<TYPE>() { \
1051 return MapType<TYPE>::Type(); \
1055 #define ORT_REGISTER_SEQ(TYPE) \
1057 MLDataType SequenceType<TYPE>::Type() { \
1058 static SequenceType<TYPE> sequence_type; \
1059 return &sequence_type; \
1062 MLDataType DataTypeImpl::GetType<TYPE>() { \
1063 return SequenceType<TYPE>::Type(); \
1066 #define ORT_REGISTER_SEQ_TENSOR_TYPE(ELEM_TYPE) \
1068 MLDataType SequenceTensorType<ELEM_TYPE>::Type() { \
1069 static SequenceTensorType<ELEM_TYPE> sequence_tensor_type; \
1070 return &sequence_tensor_type; \
1073 MLDataType DataTypeImpl::GetSequenceTensorType<ELEM_TYPE>() { \
1074 return SequenceTensorType<ELEM_TYPE>::Type(); \
1077 #define ORT_REGISTER_PRIM_TYPE(TYPE) \
1079 MLDataType PrimitiveDataType<TYPE>::Type() { \
1080 static PrimitiveDataType<TYPE> prim_data_type; \
1081 return &prim_data_type; \
1084 MLDataType DataTypeImpl::GetType<TYPE>() { \
1085 return PrimitiveDataType<TYPE>::Type(); \
1088 #define ORT_REGISTER_OPAQUE_TYPE(CPPType, Domain, Name) \
1090 MLDataType OpaqueType<CPPType, Domain, Name>::Type() { \
1091 static OpaqueType<CPPType, Domain, Name> opaque_type; \
1092 return &opaque_type; \
1095 MLDataType DataTypeImpl::GetType<CPPType>() { \
1096 return OpaqueType<CPPType, Domain, Name>::Type(); \
void AssignOpaqueDomainName(const char *domain, const char *name, ONNX_NAMESPACE::TypeProto &proto)
OpaqueTypes helpers.
std::vector< int64_t > VectorInt64
static void RegisterDataType(MLDataType)
static const TensorTypeBase * TensorTypeFromONNXEnum(int type)
virtual const ONNX_NAMESPACE::TypeProto * GetTypeProto() const =0
Retrieves an instance of TypeProto for a given MLDataType.
static const char * ToString(MLDataType type)
bool IsTensorSequenceType() const
Base class for MLDataType.
static const std::vector< MLDataType > & AllNumericTensorTypesIRv9()
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const override
this API will be used to check type compatibility at runtime
virtual MLDataType GetElementType() const
static void ToContainer(const OrtValue &, size_t, void *)
static const std::vector< MLDataType > & AllNumericTensorTypesIRv4()
void FromDataContainer(const void *data, size_t data_size, OrtValue &output) const override
static const std::vector< MLDataType > & AllOptionalTypesIRv9()
SequenceTensorTypeBase serves as a base type class for Tensor sequences. Akin to TensorTypeBase. Runtime representation is always TensorSeq.
ONNX_NAMESPACE::TypeProto & MutableTypeProto()
static void FromContainer(MLDataType, const void *, size_t, OrtValue &)
static const std::vector< MLDataType > & AllOptionalAndTensorAndSequenceTensorTypes()
GLsizei const GLchar *const * string
void ToDataContainer(const OrtValue &input, size_t data_size, void *data) const override
GLsizei const GLfloat * value
std::map< std::string, float > MapStringToFloat
virtual bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const =0
this API will be used to check type compatibility at runtime
static const std::vector< MLDataType > & AllTensorTypesIRv9()
static MLDataType GetTensorType()
std::map< int64_t, int64_t > MapInt64ToInt64
#define ORT_NOT_IMPLEMENTED(...)
MapType. Use this type to register mapping types.
void CopyMutableSeqElement(const ONNX_NAMESPACE::TypeProto &, ONNX_NAMESPACE::TypeProto &)
Sequence helpers.
static const std::vector< MLDataType > & AllTensorAndSequenceTensorAndOptionalTypesIRv4()
#define ORT_ENFORCE(condition,...)
MLDataType GetElementType() const override
Return a MLDataType representing the element-type.
static MLDataType GetElemType()
CreateFunc GetCreateFunc() const override
static const std::vector< MLDataType > & AllOptionalTypes()
const ONNX_NAMESPACE::TypeProto * GetTypeProto() const final
Retrieves an instance of TypeProto for a given MLDataType.
DeleteFunc GetDeleteFunc() const override
virtual MLDataType GetElementType() const
Common base-class for all sparse-tensors (with different element types).
static const std::vector< MLDataType > & AllTensorTypesIRv4()
static const std::vector< MLDataType > & AllOptionalTypesIRv4()
static MLDataType GetSequenceTensorType()
int32_t GetDataType() const
void CopyMutableOptionalElement(const ONNX_NAMESPACE::TypeProto &, ONNX_NAMESPACE::TypeProto &)
Optional helpers.
const ONNX_NAMESPACE::TypeProto * GetTypeProto() const override
Retrieves an instance of TypeProto for a given MLDataType.
static const std::vector< MLDataType > & AllTensorAndSequenceTensorTypesIRv9()
static const std::vector< MLDataType > & AllIEEEFloatTensorTypes()
static const std::vector< MLDataType > & AllFixedSizeTensorTypesIRv4()
static const std::vector< MLDataType > & AllFixedSizeTensorTypes()
static void Set(const onnx::TypeProto *elem_proto, ONNX_NAMESPACE::TypeProto &proto)
static const std::vector< MLDataType > & AllTensorTypes()
const SparseTensorTypeBase * AsSparseTensorType() const
static MLDataType GetSparseTensorType()
PrimitiveDataTypeBase(size_t size, int32_t data_type)
static const std::vector< MLDataType > & AllSequenceTensorTypesIRv4()
static const std::vector< MLDataType > & AllTensorAndSequenceTensorTypes()
static const std::vector< MLDataType > & AllSequenceTensorTypes()
std::map< int64_t, float > MapInt64ToFloat
void(*)(void *) DeleteFunc
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const override
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorTypeBase)
static void Set(const ONNX_NAMESPACE::TypeProto *elem_proto, ONNX_NAMESPACE::TypeProto &proto)
virtual MLDataType GetElementType() const
static const std::vector< MLDataType > & AllFixedSizeSequenceTensorTypes()
std::map< std::string, int64_t > MapStringToInt64
const NonTensorTypeBase * AsNonTensorType() const
Common base-class for all optional types.
static const std::vector< MLDataType > & AllTensorAndSequenceTensorTypesIRv4()
PrimitiveDataType Typed specialization for primitive types. Concrete instances of this class are used...
bool IsNonTensorType() const
const TensorTypeBase * AsTensorType() const
const PrimitiveDataTypeBase * AsPrimitiveDataType() const
bool IsTensorType() const
MLDataType GetElementType() const override
static MLDataType GetDataType(const std::string &)
DeleteFunc GetDeleteFunc() const override
STATIC_INLINE uint64_t H(uint64_t x, uint64_t y, uint64_t mul, int r)
PrimitiveDataTypeBase Base class for primitive Tensor contained types.
GLuint const GLchar * name
std::map< std::string, std::string > MapStringToString
Predefined registered types.
static MLDataType GetValueType()
static const SequenceTensorTypeBase * SequenceTensorTypeFromONNXEnum(int type)
void CopyMutableMapValue(const ONNX_NAMESPACE::TypeProto &, ONNX_NAMESPACE::TypeProto &)
Map helpers.
static MLDataType GetOptionalType()
const DataTypeImpl * MLDataType
static const std::vector< MLDataType > & AllSequenceTensorTypesIRv9()
static const std::vector< MLDataType > & AllTensorAndSequenceTensorAndOptionalTypes()
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const override
this API will be used to check type compatibility at runtime
std::map< int64_t, double > MapInt64ToDouble
static const std::vector< MLDataType > & AllFixedSizeSequenceTensorTypesIRv9()
static const std::vector< MLDataType > & AllNumericTensorTypes()
virtual MLDataType GetElementType() const
std::map< int64_t, std::string > MapInt64ToString
SequenceType. Use to register sequence for non-tensor types.
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const override
this API will be used to check type compatibility at runtime
const OptionalTypeBase * AsOptionalType() const
SequenceTensorType. Use to register sequence for non-tensor types.
LeafData & operator=(const LeafData &)=delete
DataTypeImpl(GeneralType type, size_t size)
bool IsSparseTensorType() const
MLDataType GetElementType() const override
virtual ~DataTypeImpl()=default
std::map< std::string, double > MapStringToDouble
GA_API const UT_StringHolder N
static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type, ONNX_NAMESPACE::TypeProto &proto)
static const std::vector< MLDataType > & AllFixedSizeTensorTypesIRv9()
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &) const override
this API will be used to check type compatibility at runtime
static const std::vector< MLDataType > & AllFixedSizeTensorAndSequenceTensorTypes()
Tensor type. This type does not have a C++ type associated with it at registration time except the el...
static const std::vector< MLDataType > & AllOptionalAndTensorAndSequenceTensorTypesIRv4()
static const std::vector< MLDataType > & AllFixedSizeTensorAndSequenceTensorTypesIRv4()
DeleteFunc GetDeleteFunc() const override
DeleteFunc GetDeleteFunc() const override
bool IsPrimitiveDataType() const
static const std::vector< MLDataType > & AllOptionalAndTensorAndSequenceTensorTypesIRv9()
static const std::vector< MLDataType > & AllTensorAndSequenceTensorAndOptionalTypesIRv9()
static const std::vector< MLDataType > & AllFixedSizeSequenceTensorTypesIRv4()
GLsizei GLenum GLenum * types
std::ostream & operator<<(std::ostream &out, AllocKind alloc_kind)
Base type for all non-tensors, maps, sequences and opaques.
const SequenceTensorTypeBase * AsSequenceTensorType() const
std::vector< std::string > VectorString
static MLDataType TypeFromProto(const ONNX_NAMESPACE::TypeProto &proto)
static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type, ONNX_NAMESPACE::TypeProto &proto)
static void Set(ONNX_NAMESPACE::TensorProto_DataType key_type, const ONNX_NAMESPACE::TypeProto *value_proto, ONNX_NAMESPACE::TypeProto &proto)
std::vector< MapInt64ToFloat > VectorMapInt64ToFloat
static MLDataType GetElemType()
~TensorTypeBase() override
static const SparseTensorTypeBase * SparseTensorTypeFromONNXEnum(int type)
std::vector< MapStringToFloat > VectorMapStringToFloat
static const std::vector< MLDataType > & AllFixedSizeTensorAndSequenceTensorTypesIRv9()
virtual DeleteFunc GetDeleteFunc() const =0
Provide a specialization for your C++ Non-tensor type so your implementation FromDataTypeContainer/To...
bool IsOptionalType() const
MLDataType GetElementType() const override
Return a MLDataType representing the element-type.
static MLDataType GetType()