10 #include <type_traits>
13 #include "boost/mp11.hpp"
17 #ifndef SHARED_PROVIDER
18 #include "core/common/type_list.h"
20 #include "core/graph/onnx_protobuf.h"
23 namespace onnxruntime {
38 #if !defined(DISABLE_FLOAT8_TYPES)
// Invokes function<T>(__VA_ARGS__) where T is the C++ type matching the
// primitive element type of `tensor_type`. This float8-enabled variant covers
// all primitive tensor element types including Float8E4M3FN/E4M3FNUZ/E5M2/E5M2FNUZ.
// Each case must `break;` — without it every subsequent specialization would run
// and the dispatch would always fall into the ORT_ENFORCE failure.
// Fails via ORT_ENFORCE on an element type not handled here.
#define DispatchOnTensorType(tensor_type, function, ...)          \
  switch (tensor_type->AsPrimitiveDataType()->GetDataType()) {    \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:              \
      function<float>(__VA_ARGS__);                               \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_BOOL:               \
      function<bool>(__VA_ARGS__);                                \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:             \
      function<double>(__VA_ARGS__);                              \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_STRING:             \
      function<std::string>(__VA_ARGS__);                         \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_INT8:               \
      function<int8_t>(__VA_ARGS__);                              \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT8:              \
      function<uint8_t>(__VA_ARGS__);                             \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_INT16:              \
      function<int16_t>(__VA_ARGS__);                             \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT16:             \
      function<uint16_t>(__VA_ARGS__);                            \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_INT32:              \
      function<int32_t>(__VA_ARGS__);                             \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT32:             \
      function<uint32_t>(__VA_ARGS__);                            \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_INT64:              \
      function<int64_t>(__VA_ARGS__);                             \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT64:             \
      function<uint64_t>(__VA_ARGS__);                            \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:            \
      function<MLFloat16>(__VA_ARGS__);                           \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:           \
      function<BFloat16>(__VA_ARGS__);                            \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:       \
      function<Float8E4M3FN>(__VA_ARGS__);                        \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ:     \
      function<Float8E4M3FNUZ>(__VA_ARGS__);                      \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:         \
      function<Float8E5M2>(__VA_ARGS__);                          \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ:     \
      function<Float8E5M2FNUZ>(__VA_ARGS__);                      \
      break;                                                      \
    default:                                                      \
      ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
  }
// Like DispatchOnTensorType but assigns the result of function<T>(__VA_ARGS__)
// to `retval`. Float8-enabled variant: includes the Float8E* element types.
// Each case must `break;` — without it later assignments would overwrite
// `retval` and control would reach the ORT_ENFORCE failure unconditionally.
// Fails via ORT_ENFORCE on an element type not handled here.
#define DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...) \
  switch (tensor_type->AsPrimitiveDataType()->GetDataType()) {             \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:                       \
      retval = function<float>(__VA_ARGS__);                               \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_BOOL:                        \
      retval = function<bool>(__VA_ARGS__);                                \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:                      \
      retval = function<double>(__VA_ARGS__);                              \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_STRING:                      \
      retval = function<std::string>(__VA_ARGS__);                         \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_INT8:                       \
      retval = function<int8_t>(__VA_ARGS__);                              \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT8:                       \
      retval = function<uint8_t>(__VA_ARGS__);                             \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT16:                      \
      retval = function<uint16_t>(__VA_ARGS__);                            \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_INT16:                       \
      retval = function<int16_t>(__VA_ARGS__);                             \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_INT32:                       \
      retval = function<int32_t>(__VA_ARGS__);                             \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT32:                      \
      retval = function<uint32_t>(__VA_ARGS__);                            \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_INT64:                       \
      retval = function<int64_t>(__VA_ARGS__);                             \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT64:                      \
      retval = function<uint64_t>(__VA_ARGS__);                            \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:                     \
      retval = function<MLFloat16>(__VA_ARGS__);                           \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:                    \
      retval = function<BFloat16>(__VA_ARGS__);                            \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:               \
      retval = function<Float8E4M3FN>(__VA_ARGS__);                        \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ:             \
      retval = function<Float8E4M3FNUZ>(__VA_ARGS__);                      \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:                 \
      retval = function<Float8E5M2>(__VA_ARGS__);                          \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ:             \
      retval = function<Float8E5M2FNUZ>(__VA_ARGS__);                      \
      break;                                                               \
    default:                                                               \
      ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type);          \
  }
// Invokes function<T>(__VA_ARGS__) where T is the C++ type matching the
// primitive element type of `tensor_type`. This variant is compiled when
// float8 types are disabled, so the Float8E* cases are omitted.
// Each case must `break;` — without it every subsequent specialization would
// run and the dispatch would always fall into the ORT_ENFORCE failure.
// Fails via ORT_ENFORCE on an element type not handled here.
#define DispatchOnTensorType(tensor_type, function, ...)          \
  switch (tensor_type->AsPrimitiveDataType()->GetDataType()) {    \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:              \
      function<float>(__VA_ARGS__);                               \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_BOOL:               \
      function<bool>(__VA_ARGS__);                                \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:             \
      function<double>(__VA_ARGS__);                              \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_STRING:             \
      function<std::string>(__VA_ARGS__);                         \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_INT8:               \
      function<int8_t>(__VA_ARGS__);                              \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT8:              \
      function<uint8_t>(__VA_ARGS__);                             \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_INT16:              \
      function<int16_t>(__VA_ARGS__);                             \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT16:             \
      function<uint16_t>(__VA_ARGS__);                            \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_INT32:              \
      function<int32_t>(__VA_ARGS__);                             \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT32:             \
      function<uint32_t>(__VA_ARGS__);                            \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_INT64:              \
      function<int64_t>(__VA_ARGS__);                             \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT64:             \
      function<uint64_t>(__VA_ARGS__);                            \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:            \
      function<MLFloat16>(__VA_ARGS__);                           \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:           \
      function<BFloat16>(__VA_ARGS__);                            \
      break;                                                      \
    default:                                                      \
      ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
  }
// Like DispatchOnTensorType but assigns the result of function<T>(__VA_ARGS__)
// to `retval`. This variant is compiled when float8 types are disabled, so the
// Float8E* cases are omitted.
// Each case must `break;` — without it later assignments would overwrite
// `retval` and control would reach the ORT_ENFORCE failure unconditionally.
// Fails via ORT_ENFORCE on an element type not handled here.
#define DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...) \
  switch (tensor_type->AsPrimitiveDataType()->GetDataType()) {             \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:                       \
      retval = function<float>(__VA_ARGS__);                               \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_BOOL:                        \
      retval = function<bool>(__VA_ARGS__);                                \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:                      \
      retval = function<double>(__VA_ARGS__);                              \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_STRING:                      \
      retval = function<std::string>(__VA_ARGS__);                         \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_INT8:                       \
      retval = function<int8_t>(__VA_ARGS__);                              \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT8:                       \
      retval = function<uint8_t>(__VA_ARGS__);                             \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT16:                      \
      retval = function<uint16_t>(__VA_ARGS__);                            \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_INT16:                       \
      retval = function<int16_t>(__VA_ARGS__);                             \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_INT32:                       \
      retval = function<int32_t>(__VA_ARGS__);                             \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT32:                      \
      retval = function<uint32_t>(__VA_ARGS__);                            \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_INT64:                       \
      retval = function<int64_t>(__VA_ARGS__);                             \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT64:                      \
      retval = function<uint64_t>(__VA_ARGS__);                            \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:                     \
      retval = function<MLFloat16>(__VA_ARGS__);                           \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:                    \
      retval = function<BFloat16>(__VA_ARGS__);                            \
      break;                                                               \
    default:                                                               \
      ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type);          \
  }
267 return (prim_type !=
nullptr && prim_type->GetDataType() == ONNX_NAMESPACE::TensorProto_DataType_STRING);
275 return (prim_type !=
nullptr && prim_type->GetDataType() == ToTensorProtoElementType<T>());
282 assert(prim_type !=
nullptr);
283 return prim_type->
GetDataType() == ToTensorProtoElementType<T>();
288 namespace mltype_dispatcher_internal {
299 template <
class T,
class Fn,
class... Args>
301 if (utils::ToTensorProtoElementType<T>() == dt_type_) {
302 std::forward<Fn>(fn)(std::forward<Args>(
args)...);
309 ORT_ENFORCE(called_ == 1,
"Unsupported data type: ", dt_type_);
318 ORT_THROW(
"Unsupported data type: ", dt_type);
323 template <
class Ret,
class UnsupportedPolicy>
335 UnsupportedPolicy()(dt_type_, result_);
341 template <
class T,
class Fn,
class... Args>
343 if (utils::ToTensorProtoElementType<T>() == dt_type_) {
344 result_ = std::forward<Fn>(fn)(std::forward<Args>(
args)...);
351 template <
typename T>
353 std::integral_constant<ONNX_NAMESPACE::TensorProto_DataType, ToTensorProtoElementType<T>()>;
356 std::integral_constant<ONNX_NAMESPACE::TensorProto_DataType, ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED>;
378 template <
typename... Types>
380 using SupportedTypeList = TypeList<Types...>;
381 using SupportedTensorProtoElementTypeList =
382 boost::mp11::mp_transform<
387 boost::mp11::mp_is_set<SupportedTensorProtoElementTypeList>,
389 boost::mp11::mp_set_contains<
390 SupportedTensorProtoElementTypeList,
392 "Types must map to a unique set of ONNX tensor element data types supported by ORT.");
412 template <
template <
typename...>
class Fn,
typename... Args>
414 InvokeWithLeadingTemplateArgs<Fn, TypeList<>>(std::forward<Args>(
args)...);
426 template <
template <
typename...>
class Fn,
typename LeadingTemplateArgTypeList,
typename... Args>
430 "LeadingTemplateArgTypeList must be a type list (e.g., onnxruntime::TypeList<T1, T2, ...>).");
436 static_cast<void>(std::array<
int,
sizeof...(Types)>{
437 helper.template Invoke<Types>(
438 boost::mp11::mp_apply<Fn, boost::mp11::mp_push_back<LeadingTemplateArgTypeList, Types>>(),
439 std::forward<Args>(
args)...)...});
454 template <
class Ret,
template <
typename...>
class Fn,
typename... Args>
458 std::forward<Args>(
args)...);
471 template <
class Ret,
template <
typename...>
class Fn,
class UnsupportedPolicy,
typename... Args>
474 Ret, Fn, UnsupportedPolicy, TypeList<>>(
475 std::forward<Args>(
args)...);
488 template <
class Ret,
template <
typename...>
class Fn,
typename LeadingTemplateArgTypeList,
typename... Args>
492 std::forward<Args>(
args)...);
509 template <
typename...>
class Fn,
510 class UnsupportedPolicy,
511 typename LeadingTemplateArgTypeList,
518 static_cast<void>(std::array<
int,
sizeof...(Types)>{
519 helper.template Invoke<Types>(
520 boost::mp11::mp_apply<Fn, boost::mp11::mp_push_back<LeadingTemplateArgTypeList, Types>>(),
521 std::forward<Args>(
args)...)...});
531 template <
typename L>
534 namespace data_types_internal {
560 prim_type_ =
static_cast<uint16_t
>(prim_type);
564 return type_ ==
type;
568 return prim_type_ ==
static_cast<uint16_t
>(prim_type);
596 using Cont = std::vector<data_types_internal::TypeNode>;
601 struct IsContainerOfType {
602 static bool check(
const Cont&
c,
size_t index) {
603 if (index >= c.size()) {
612 struct IsContainerOfType<std::vector<T>> {
613 static bool check(
const Cont&
c,
size_t index) {
614 if (index >= c.size()) {
618 ORT_ENFORCE(++index < c.size(),
"Sequence is missing type entry for its element");
619 constexpr int32_t prim_type = ToTensorProtoElementType<T>();
621 if constexpr (prim_type != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
623 c[index].IsPrimType(prim_type);
626 return IsContainerOfType<T>::check(c, index);
633 template <
class K,
class V>
634 struct IsContainerOfType<std::map<K, V>> {
635 static bool check(
const Cont&
c,
size_t index) {
636 static_assert(ToTensorProtoElementType<K>() != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED,
637 "Map Key can not be a non-primitive type");
638 if (index >= c.size()) {
644 constexpr int32_t key_type = ToTensorProtoElementType<K>();
645 if (!c[index].IsPrimType(key_type)) {
648 ORT_ENFORCE(++index < c.size(),
"Map is missing type entry for its value");
649 constexpr int32_t val_type = ToTensorProtoElementType<V>();
650 if constexpr (val_type != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
652 c[index].IsPrimType(val_type);
654 return IsContainerOfType<V>::check(c, index);
663 assert(!types_.empty());
668 assert(!types_.empty());
674 assert(!types_.empty());
675 return IsContainerOfType<std::vector<T>>::check(types_, 0);
678 template <
class K,
class V>
680 assert(!types_.empty());
681 return IsContainerOfType<std::map<K, V>>::check(types_, 0);
typedef int(APIENTRYP RE_PFNGLXSWAPINTERVALSGIPROC)(int)
Base class for MLDataType.
MLTypeCallDispatcher(int32_t dt_type) noexcept
Ret InvokeRetWithUnsupportedPolicy(Args &&...args) const
int Invoke(Fn &&fn, Args &&...args)
Ret InvokeRetWithUnsupportedPolicyAndLeadingTemplateArgs(Args &&...args) const
GLsizei const GLfloat * value
bool IsType(ContainerType type) const noexcept
#define ORT_ENFORCE(condition,...)
void operator()(int32_t dt_type, Ret &) const
Ret InvokeRetWithLeadingTemplateArgs(Args &&...args) const
bool IsPrimType(int32_t prim_type) const noexcept
int32_t GetDataType() const
bool IsSequence() const noexcept
CallableDispatchableHelper(int32_t dt_type) noexcept
boost::mp11::mp_apply< MLTypeCallDispatcher, L > MLTypeCallDispatcherFromTypeList
void InvokeWithLeadingTemplateArgs(Args &&...args) const
bool IsPrimitiveDataType(MLDataType dt_type)
const PrimitiveDataTypeBase * AsPrimitiveDataType() const
std::integral_constant< ONNX_NAMESPACE::TensorProto_DataType, ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED > UndefinedTensorProtoElementTypeConstant
#define ORT_UNUSED_PARAMETER(x)
bool IsOpaqueType(MLDataType ml_type, const char *domain, const char *name)
PrimitiveDataTypeBase Base class for primitive Tensor contained types.
GLuint const GLchar * name
std::integral_constant< ONNX_NAMESPACE::TensorProto_DataType, ToTensorProtoElementType< T >()> TensorProtoElementTypeConstant
CallableDispatchableRetHelper(int32_t dt_type) noexcept
bool IsDataTypeString(MLDataType dt_type)
Use the following primitives if you have a few types to switch on so you.
TypeNode(ContainerType type, int32_t prim_type) noexcept
int Invoke(Fn &&fn, Args &&...args)
Ret InvokeRet(Args &&...args) const
**If you just want to fire and args
bool IsMap() const noexcept
void Invoke(Args &&...args) const
bool IsSequenceOf() const