HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
data_types_internal.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #include <array>
7 #include <cassert>
8 #include <cstdint>
9 #include <string>
10 #include <type_traits>
11 #include <vector>
12 
13 #include "boost/mp11.hpp"
14 
15 #include "core/common/common.h"
17 #ifndef SHARED_PROVIDER
18 #include "core/common/type_list.h"
20 #include "core/graph/onnx_protobuf.h"
21 #endif
22 
23 namespace onnxruntime {
24 namespace utils {
25 
26 // The following primitives are strongly recommended for switching on tensor input datatypes for
27 // kernel implementations.
28 //
29 // 1) If you need to handle all of the primitive tensor contained datatypes, the best choice would be macros
30 // DispatchOnTensorType or DispatchOnTensorTypeWithReturn. Use inline wrappers so your function can be invoked as function<T>().
31 // 2) if you have a few types, use Tensor.IsDataType<T>()/IsDataTypeString() or use utils::IsPrimitiveDataType<T>()
32 // if you have a standalone MLDataType with a sequence of if/else statements.
33 // 3) For something in between, we suggest to use CallDispatcher pattern.
34 //
35 // Invoking DataTypeImpl::GetType<T>() for switching on input types is discouraged and should be avoided.
36 // Every primitive type carries with it an integer constant that can be used for quick switching on types.
37 
38 #if !defined(DISABLE_FLOAT8_TYPES)
39 
// Invokes function<T>(__VA_ARGS__), where T is the C++ type corresponding to
// the primitive element type reported by tensor_type (an MLDataType).
// This variant handles every primitive tensor element type, including the
// float8 types (compiled when DISABLE_FLOAT8_TYPES is not defined).
// Any unrecognized element type triggers ORT_ENFORCE, which throws.
// NOTE(review): tensor_type->AsPrimitiveDataType() is dereferenced without a
// null check, so tensor_type is assumed to be a primitive tensor type.
#define DispatchOnTensorType(tensor_type, function, ...)        \
  switch (tensor_type->AsPrimitiveDataType()->GetDataType()) {  \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:            \
      function<float>(__VA_ARGS__);                             \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_BOOL:             \
      function<bool>(__VA_ARGS__);                              \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:           \
      function<double>(__VA_ARGS__);                            \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_STRING:           \
      function<std::string>(__VA_ARGS__);                       \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_INT8:             \
      function<int8_t>(__VA_ARGS__);                            \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT8:            \
      function<uint8_t>(__VA_ARGS__);                           \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_INT16:            \
      function<int16_t>(__VA_ARGS__);                           \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT16:           \
      function<uint16_t>(__VA_ARGS__);                          \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_INT32:            \
      function<int32_t>(__VA_ARGS__);                           \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT32:           \
      function<uint32_t>(__VA_ARGS__);                          \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_INT64:            \
      function<int64_t>(__VA_ARGS__);                           \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT64:           \
      function<uint64_t>(__VA_ARGS__);                          \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:          \
      function<MLFloat16>(__VA_ARGS__);                         \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:         \
      function<BFloat16>(__VA_ARGS__);                          \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:     \
      function<Float8E4M3FN>(__VA_ARGS__);                      \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ:   \
      function<Float8E4M3FNUZ>(__VA_ARGS__);                    \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:       \
      function<Float8E5M2>(__VA_ARGS__);                        \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ:   \
      function<Float8E5M2FNUZ>(__VA_ARGS__);                    \
      break;                                                    \
    default:                                                    \
      ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
  }
99 
// Invokes retval = function<T>(__VA_ARGS__), where T is the C++ type
// corresponding to the primitive element type reported by tensor_type.
// This variant handles every primitive tensor element type, including the
// float8 types (compiled when DISABLE_FLOAT8_TYPES is not defined).
// Any unrecognized element type triggers ORT_ENFORCE, which throws.
// Case order matches DispatchOnTensorType (INT16 before UINT16) for
// consistency; switch-case order has no behavioral effect.
// NOTE(review): tensor_type->AsPrimitiveDataType() is dereferenced without a
// null check, so tensor_type is assumed to be a primitive tensor type.
#define DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...) \
  switch (tensor_type->AsPrimitiveDataType()->GetDataType()) {             \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:                       \
      retval = function<float>(__VA_ARGS__);                               \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_BOOL:                       \
      retval = function<bool>(__VA_ARGS__);                                \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:                      \
      retval = function<double>(__VA_ARGS__);                              \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_STRING:                      \
      retval = function<std::string>(__VA_ARGS__);                         \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_INT8:                        \
      retval = function<int8_t>(__VA_ARGS__);                              \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT8:                       \
      retval = function<uint8_t>(__VA_ARGS__);                             \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_INT16:                       \
      retval = function<int16_t>(__VA_ARGS__);                             \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT16:                      \
      retval = function<uint16_t>(__VA_ARGS__);                            \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_INT32:                       \
      retval = function<int32_t>(__VA_ARGS__);                             \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT32:                      \
      retval = function<uint32_t>(__VA_ARGS__);                            \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_INT64:                       \
      retval = function<int64_t>(__VA_ARGS__);                             \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT64:                      \
      retval = function<uint64_t>(__VA_ARGS__);                            \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:                     \
      retval = function<MLFloat16>(__VA_ARGS__);                           \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:                    \
      retval = function<BFloat16>(__VA_ARGS__);                            \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:                \
      retval = function<Float8E4M3FN>(__VA_ARGS__);                        \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ:              \
      retval = function<Float8E4M3FNUZ>(__VA_ARGS__);                      \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:                  \
      retval = function<Float8E5M2>(__VA_ARGS__);                          \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ:              \
      retval = function<Float8E5M2FNUZ>(__VA_ARGS__);                      \
      break;                                                               \
    default:                                                               \
      ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type);          \
  }
159 
160 #else
161 
// Invokes function<T>(__VA_ARGS__), where T is the C++ type corresponding to
// the primitive element type reported by tensor_type (an MLDataType).
// This variant is compiled when DISABLE_FLOAT8_TYPES is defined, so the
// float8 element types are not handled and fall into the default case.
// Any unrecognized element type triggers ORT_ENFORCE, which throws.
// NOTE(review): tensor_type->AsPrimitiveDataType() is dereferenced without a
// null check, so tensor_type is assumed to be a primitive tensor type.
#define DispatchOnTensorType(tensor_type, function, ...)        \
  switch (tensor_type->AsPrimitiveDataType()->GetDataType()) {  \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:            \
      function<float>(__VA_ARGS__);                             \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_BOOL:             \
      function<bool>(__VA_ARGS__);                              \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:           \
      function<double>(__VA_ARGS__);                            \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_STRING:           \
      function<std::string>(__VA_ARGS__);                       \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_INT8:             \
      function<int8_t>(__VA_ARGS__);                            \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT8:            \
      function<uint8_t>(__VA_ARGS__);                           \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_INT16:            \
      function<int16_t>(__VA_ARGS__);                           \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT16:           \
      function<uint16_t>(__VA_ARGS__);                          \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_INT32:            \
      function<int32_t>(__VA_ARGS__);                           \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT32:           \
      function<uint32_t>(__VA_ARGS__);                          \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_INT64:            \
      function<int64_t>(__VA_ARGS__);                           \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT64:           \
      function<uint64_t>(__VA_ARGS__);                          \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:          \
      function<MLFloat16>(__VA_ARGS__);                         \
      break;                                                    \
    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:         \
      function<BFloat16>(__VA_ARGS__);                          \
      break;                                                    \
    default:                                                    \
      ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
  }
209 
// Invokes retval = function<T>(__VA_ARGS__), where T is the C++ type
// corresponding to the primitive element type reported by tensor_type.
// This variant is compiled when DISABLE_FLOAT8_TYPES is defined, so the
// float8 element types are not handled and fall into the default case.
// Any unrecognized element type triggers ORT_ENFORCE, which throws.
// Case order matches DispatchOnTensorType (INT16 before UINT16) for
// consistency; switch-case order has no behavioral effect.
// NOTE(review): tensor_type->AsPrimitiveDataType() is dereferenced without a
// null check, so tensor_type is assumed to be a primitive tensor type.
#define DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...) \
  switch (tensor_type->AsPrimitiveDataType()->GetDataType()) {             \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:                       \
      retval = function<float>(__VA_ARGS__);                               \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_BOOL:                       \
      retval = function<bool>(__VA_ARGS__);                                \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:                      \
      retval = function<double>(__VA_ARGS__);                              \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_STRING:                      \
      retval = function<std::string>(__VA_ARGS__);                         \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_INT8:                        \
      retval = function<int8_t>(__VA_ARGS__);                              \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT8:                       \
      retval = function<uint8_t>(__VA_ARGS__);                             \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_INT16:                       \
      retval = function<int16_t>(__VA_ARGS__);                             \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT16:                      \
      retval = function<uint16_t>(__VA_ARGS__);                            \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_INT32:                       \
      retval = function<int32_t>(__VA_ARGS__);                             \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT32:                      \
      retval = function<uint32_t>(__VA_ARGS__);                            \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_INT64:                       \
      retval = function<int64_t>(__VA_ARGS__);                             \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT64:                      \
      retval = function<uint64_t>(__VA_ARGS__);                            \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:                     \
      retval = function<MLFloat16>(__VA_ARGS__);                           \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:                    \
      retval = function<BFloat16>(__VA_ARGS__);                            \
      break;                                                               \
    default:                                                               \
      ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type);          \
  }
257 
258 #endif
259 
260 ////////////////////////////////////////////////////////////////////////////////
261 /// Use the following primitives if you have a few types to switch on so you
262 /// can write a short sequence of if/else statements.
263 
264 // This is a frequently used check so we make a separate utility function.
265 inline bool IsDataTypeString(MLDataType dt_type) {
266  auto prim_type = dt_type->AsPrimitiveDataType();
267  return (prim_type != nullptr && prim_type->GetDataType() == ONNX_NAMESPACE::TensorProto_DataType_STRING);
268 }
269 
270 // Test if MLDataType is a concrete type of PrimitiveDataTypeBase
271 // and it is T
272 template <class T>
273 inline bool IsPrimitiveDataType(MLDataType dt_type) {
274  auto prim_type = dt_type->AsPrimitiveDataType();
275  return (prim_type != nullptr && prim_type->GetDataType() == ToTensorProtoElementType<T>());
276 }
277 
278 // Use after AsPrimitiveDataType() is successful
279 // Check if PrimitiveDataTypeBase is of type T
280 template <class T>
281 inline bool IsPrimitiveDataType(const PrimitiveDataTypeBase* prim_type) {
282  assert(prim_type != nullptr);
283  return prim_type->GetDataType() == ToTensorProtoElementType<T>();
284 }
285 
286 // This implementation contains a workaround for GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=47226
287 // GCC until very recently does not support template parameter pack expansion within lambda context.
288 namespace mltype_dispatcher_internal {
289 
290 // T - type handled by this helper
292  int32_t dt_type_; // Type currently dispatched
293  size_t called_;
294 
295  public:
296  explicit CallableDispatchableHelper(int32_t dt_type) noexcept : dt_type_(dt_type), called_(0) {}
297 
298  // Must return integer to be in a expandable context
299  template <class T, class Fn, class... Args>
300  int Invoke(Fn&& fn, Args&&... args) {
301  if (utils::ToTensorProtoElementType<T>() == dt_type_) {
302  std::forward<Fn>(fn)(std::forward<Args>(args)...);
303  ++called_;
304  }
305  return 0;
306  }
307 
309  ORT_ENFORCE(called_ == 1, "Unsupported data type: ", dt_type_);
310  }
311 };
312 
313 // Default policy is to throw an exception.
314 // Other policies may set the second result argument accordingly.
315 template <class Ret>
317  void operator()(int32_t dt_type, Ret& /*result*/) const {
318  ORT_THROW("Unsupported data type: ", dt_type);
319  }
320 };
321 
322 // Helper with the result type
323 template <class Ret, class UnsupportedPolicy>
325  int32_t dt_type_; // Type currently dispatched
326  size_t called_;
327  Ret result_;
328 
329  public:
330  explicit CallableDispatchableRetHelper(int32_t dt_type) noexcept : dt_type_(dt_type), called_(0), result_() {}
331 
  // Returns the result produced by the matching Invoke<T> call. If no type
  // matched (called_ == 0), delegates to UnsupportedPolicy, which may throw
  // or overwrite result_ before it is returned.
  Ret Get() {
    // No type was invoked
    if (called_ == 0) {
      UnsupportedPolicy()(dt_type_, result_);
    }
    return result_;
  }
339 
340  // Must return integer to be in a expandable context
341  template <class T, class Fn, class... Args>
342  int Invoke(Fn&& fn, Args&&... args) {
343  if (utils::ToTensorProtoElementType<T>() == dt_type_) {
344  result_ = std::forward<Fn>(fn)(std::forward<Args>(args)...);
345  ++called_;
346  }
347  return 0;
348  }
349 };
350 
351 template <typename T>
353  std::integral_constant<ONNX_NAMESPACE::TensorProto_DataType, ToTensorProtoElementType<T>()>;
354 
356  std::integral_constant<ONNX_NAMESPACE::TensorProto_DataType, ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED>;
357 
358 } // namespace mltype_dispatcher_internal
359 
360 /**
361  * This class helps to efficiently dispatch calls to implementation function
362  * objects with a tensor element type template argument.
363  *
364  * The constructor accepts a value corresponding to a tensor element type.
365  * For example, it can be obtained from:
366  * input_tensor->GetElementType()
367  *
368  * The Invoke member functions will instantiate and invoke the provided
369  * function object template, Fn. Fn must be default constructible. Fn must also
370  * have a tensor element type template argument. This type template argument
371  * will be the type that corresponds to the value given in the constructor.
372  * These functions accept and forward arbitrary function arguments. They ensure
373  * that Fn is called once with the type specified in the constructor.
374  *
375  * @tparam Types The types supported by the implementation. This should be a
376  * set of ONNX tensor element types that are supported by ORT.
377  */
378 template <typename... Types>
380  using SupportedTypeList = TypeList<Types...>;
381  using SupportedTensorProtoElementTypeList =
382  boost::mp11::mp_transform<
384 
385  static_assert(
386  boost::mp11::mp_and<
387  boost::mp11::mp_is_set<SupportedTensorProtoElementTypeList>,
388  boost::mp11::mp_not<
389  boost::mp11::mp_set_contains<
390  SupportedTensorProtoElementTypeList,
392  "Types must map to a unique set of ONNX tensor element data types supported by ORT.");
393 
394  int32_t dt_type_;
395 
396  public:
397  /**
398  * Constructor.
399  * @param dt_type The value corresponding to the tensor element type to be
400  * dispatched to. This can be obtained from
401  * input_tensor->GetElementType() or
402  * utils::ToTensorProtoElementType<T>().
403  */
404  explicit MLTypeCallDispatcher(int32_t dt_type) noexcept : dt_type_(dt_type) {}
405 
406  /**
407  * Invokes Fn<T> with the specified arguments.
408  *
409  * @tparam Fn The function object template.
410  * @tparam Args The argument types.
411  */
  template <template <typename...> class Fn, typename... Args>
  void Invoke(Args&&... args) const {
    // Delegate with an empty leading-template-argument list, so Fn is
    // instantiated as Fn<T> for the dispatched type T.
    InvokeWithLeadingTemplateArgs<Fn, TypeList<>>(std::forward<Args>(args)...);
  }
416 
417  /**
418  * Invokes Fn<..., T> with leading template arguments and the specified
419  * arguments.
420  *
421  * @tparam Fn The function object template.
422  * @tparam LeadingTemplateArgTypeList A type list of the leading template
423  * arguments.
424  * @tparam Args The argument types.
425  */
426  template <template <typename...> class Fn, typename LeadingTemplateArgTypeList, typename... Args>
427  void InvokeWithLeadingTemplateArgs(Args&&... args) const {
428  static_assert(
430  "LeadingTemplateArgTypeList must be a type list (e.g., onnxruntime::TypeList<T1, T2, ...>).");
431 
433 
434  // given LeadingTemplateArgTypeList is a type list L<U1, U2, ...>,
435  // call helper.Invoke() with Fn<U1, U2, ..., T> for each T in Types
436  static_cast<void>(std::array<int, sizeof...(Types)>{
437  helper.template Invoke<Types>(
438  boost::mp11::mp_apply<Fn, boost::mp11::mp_push_back<LeadingTemplateArgTypeList, Types>>(),
439  std::forward<Args>(args)...)...});
440 
441  // avoid "unused parameter" warning for the case where Types is empty
442  static_cast<void>(std::array<int, sizeof...(Args)>{(ORT_UNUSED_PARAMETER(args), 0)...});
443 
444  helper.CheckCalledOnce();
445  }
446 
447  /**
448  * Invokes Fn<T> with the specified arguments and returns the result.
449  *
450  * @tparam Ret The return type. Fn should return a type convertible to Ret.
451  * @tparam Fn The function object template.
452  * @tparam Args The argument types.
453  */
454  template <class Ret, template <typename...> class Fn, typename... Args>
455  Ret InvokeRet(Args&&... args) const {
458  std::forward<Args>(args)...);
459  }
460 
461  /**
462  * Invokes Fn<T> with the specified arguments and returns the result.
463  *
464  * @tparam Ret The return type. Fn should return a type convertible to Ret.
465  * @tparam Fn The function object template.
466  * @tparam UnsupportedPolicy The policy used to handle unsupported types.
467  * See mltype_dispatcher_internal::UnsupportedTypeDefaultPolicy
468  * for an example.
469  * @tparam Args The argument types.
470  */
471  template <class Ret, template <typename...> class Fn, class UnsupportedPolicy, typename... Args>
472  Ret InvokeRetWithUnsupportedPolicy(Args&&... args) const {
474  Ret, Fn, UnsupportedPolicy, TypeList<>>(
475  std::forward<Args>(args)...);
476  }
477 
478  /**
479  * Invokes Fn<..., T> with leading template arguments and the specified
480  * arguments and returns the result.
481  *
482  * @tparam Ret The return type. Fn should return a type convertible to Ret.
483  * @tparam Fn The function object template.
484  * @tparam LeadingTemplateArgTypeList A type list of the leading template
485  * arguments.
486  * @tparam Args The argument types.
487  */
488  template <class Ret, template <typename...> class Fn, typename LeadingTemplateArgTypeList, typename... Args>
489  Ret InvokeRetWithLeadingTemplateArgs(Args&&... args) const {
491  Ret, Fn, mltype_dispatcher_internal::UnsupportedTypeDefaultPolicy<Ret>, LeadingTemplateArgTypeList>(
492  std::forward<Args>(args)...);
493  }
494 
495  /**
496  * Invokes Fn<..., T> with leading template arguments and the specified
497  * arguments and returns the result.
498  *
499  * @tparam Ret The return type. Fn should return a type convertible to Ret.
500  * @tparam Fn The function object template.
501  * @tparam UnsupportedPolicy The policy used to handle unsupported types.
502  * See mltype_dispatcher_internal::UnsupportedTypeDefaultPolicy
503  * for an example.
504  * @tparam LeadingTemplateArgTypeList A type list of the leading template
505  * arguments.
506  * @tparam Args The argument types.
507  */
508  template <class Ret,
509  template <typename...> class Fn,
510  class UnsupportedPolicy,
511  typename LeadingTemplateArgTypeList,
512  typename... Args>
515 
516  // given LeadingTemplateArgTypeList is a type list L<U1, U2, ...>,
517  // call helper.Invoke() with Fn<U1, U2, ..., T> for each T in Types
518  static_cast<void>(std::array<int, sizeof...(Types)>{
519  helper.template Invoke<Types>(
520  boost::mp11::mp_apply<Fn, boost::mp11::mp_push_back<LeadingTemplateArgTypeList, Types>>(),
521  std::forward<Args>(args)...)...});
522 
523  // avoid "unused parameter" warning for the case where Types is empty
524  static_cast<void>(std::array<int, sizeof...(Args)>{(ORT_UNUSED_PARAMETER(args), 0)...});
525 
526  return helper.Get();
527  }
528 };
529 
// Convenience alias: produces the type MLTypeCallDispatcher<T...> given a
// type list L<T...> (e.g. onnxruntime::TypeList<T1, T2, ...>).
template <typename L>
using MLTypeCallDispatcherFromTypeList = boost::mp11::mp_apply<MLTypeCallDispatcher, L>;
533 
534 namespace data_types_internal {
535 
// Tags the kind of container an entry in a flattened type description
// represents.
enum class ContainerType : uint16_t {
  kUndefined = 0,
  kTensor = 1,
  kMap = 2,
  kSequence = 3,
  kOpaque = 4,
  kOptional = 5
};

// One entry in the flattened description of a (possibly nested) type.
// type_ records the container kind; prim_type_ holds a TypeProto_DataType
// value whose meaning depends on that kind:
// - Tensor: the contained primitive type (always the last entry)
// - Map: the key type; the next entry describes the mapped value
// - Sequence: unused; the next entry describes the element type
class TypeNode {
  ContainerType type_;
  uint16_t prim_type_;

 public:
  TypeNode(ContainerType type, int32_t prim_type) noexcept
      : type_(type), prim_type_(static_cast<uint16_t>(prim_type)) {}

  // True when this entry is of the given container kind.
  bool IsType(ContainerType type) const noexcept { return type == type_; }

  // True when the stored primitive type equals prim_type
  // (compared after narrowing to uint16_t, as in the constructor).
  bool IsPrimType(int32_t prim_type) const noexcept {
    return static_cast<uint16_t>(prim_type) == prim_type_;
  }
};
571 
572 } // namespace data_types_internal
573 
574 ////////////////////////////////////////////////////////////////////
575 /// Provides generic interface to test whether MLDataType is a Sequence,
576 /// Map or an Opaque type including arbitrary recursive definitions
577 /// without querying DataTypeImpl::GetType<T> for all known complex types
578 
579 // T is a sequence contained element type
580 // If returns true then we know that the runtime
581 // representation is std::vector<T>
582 // T itself can be a runtime representation of another
583 // sequence, map, opaque type or a tensor
584 //
585 // That is it can be std::vector or a std::map
586 // If T is a primitive type sequence is tested whether it contains
587 // tensors of that type
588 //
589 // If T is an opaque type, then it is only tested to be opaque but not exactly
590 // a specific opaque type. To Test for a specific Opaque type use IsOpaqueType() below
591 //
592 // This class examines the supplied MLDataType and records
593 // its information in a vector so any subsequent checks for Sequences and Maps
594 // are quick.
596  using Cont = std::vector<data_types_internal::TypeNode>;
597  Cont types_;
598 
  // Default IsContainerOfType: reached when T is neither std::vector nor
  // std::map, i.e. T is expected to be an opaque type. Only verifies that the
  // entry at 'index' is an Opaque container entry; it does NOT check for a
  // specific opaque type (use IsOpaqueType() for that).
  template <class T>
  struct IsContainerOfType {
    static bool check(const Cont& c, size_t index) {
      if (index >= c.size()) {
        return false;
      }
      return c[index].IsType(data_types_internal::ContainerType::kOpaque);
    }
  };
609 
  // Specialization for std::vector<T>: matches a Sequence entry whose element
  // type matches T. If T is a primitive type, the following entry must be a
  // Tensor of that primitive type; otherwise (T is itself a sequence, map, or
  // opaque type) the check recurses into T's own container check.
  template <class T>
  struct IsContainerOfType<std::vector<T>> {
    static bool check(const Cont& c, size_t index) {
      if (index >= c.size()) {
        return false;
      }
      if (c[index].IsType(data_types_internal::ContainerType::kSequence)) {
        // A sequence entry must be followed by an entry describing its element.
        ORT_ENFORCE(++index < c.size(), "Sequence is missing type entry for its element");
        constexpr int32_t prim_type = ToTensorProtoElementType<T>();
        // Check if this is a primitive type and it matches
        if constexpr (prim_type != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
          return c[index].IsType(data_types_internal::ContainerType::kTensor) &&
                 c[index].IsPrimType(prim_type);
        } else {
          // T is not primitive, check next entry for non-primitive proto
          return IsContainerOfType<T>::check(c, index);
        }
      }
      return false;
    }
  };
632 
  // Specialization for std::map<K, V>: matches a Map entry whose key type is
  // the primitive type K and whose value matches V. K must be primitive
  // (enforced at compile time). If V is primitive, the following entry must
  // be a Tensor of that type; otherwise the check recurses into V.
  template <class K, class V>
  struct IsContainerOfType<std::map<K, V>> {
    static bool check(const Cont& c, size_t index) {
      static_assert(ToTensorProtoElementType<K>() != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED,
                    "Map Key can not be a non-primitive type");
      if (index >= c.size()) {
        return false;
      }
      if (!c[index].IsType(data_types_internal::ContainerType::kMap)) {
        return false;
      }
      constexpr int32_t key_type = ToTensorProtoElementType<K>();
      if (!c[index].IsPrimType(key_type)) {
        return false;
      }
      // A map entry must be followed by an entry describing its value type.
      ORT_ENFORCE(++index < c.size(), "Map is missing type entry for its value");
      constexpr int32_t val_type = ToTensorProtoElementType<V>();
      if constexpr (val_type != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
        return c[index].IsType(data_types_internal::ContainerType::kTensor) &&
               c[index].IsPrimType(val_type);
      } else
        return IsContainerOfType<V>::check(c, index);
    }
  };
657 
 public:
  // Parses the supplied MLDataType into the flat types_ vector so subsequent
  // Is*() queries are cheap. (Definition is not in this header.)
  explicit ContainerChecker(MLDataType);
  ~ContainerChecker() = default;
661 
  // True if the outermost container described by the checked type is a Map.
  bool IsMap() const noexcept {
    assert(!types_.empty());
    return types_[0].IsType(data_types_internal::ContainerType::kMap);
  }
666 
  // True if the outermost container described by the checked type is a
  // Sequence.
  bool IsSequence() const noexcept {
    assert(!types_.empty());
    return types_[0].IsType(data_types_internal::ContainerType::kSequence);
  }
671 
  // True if the checked type is a sequence whose runtime representation is
  // std::vector<T>; T may itself be a sequence, map, or opaque type.
  template <class T>
  bool IsSequenceOf() const {
    assert(!types_.empty());
    return IsContainerOfType<std::vector<T>>::check(types_, 0);
  }
677 
  // True if the checked type is a map whose runtime representation is
  // std::map<K, V>; K must be a primitive type, V may be a container.
  template <class K, class V>
  bool IsMapOf() const {
    assert(!types_.empty());
    return IsContainerOfType<std::map<K, V>>::check(types_, 0);
  }
683 };
684 
685 bool IsOpaqueType(MLDataType ml_type, const char* domain, const char* name);
686 
687 } // namespace utils
688 } // namespace onnxruntime
typedef int(APIENTRYP RE_PFNGLXSWAPINTERVALSGIPROC)(int)
Base class for MLDataType.
Definition: data_types.h:76
MLTypeCallDispatcher(int32_t dt_type) noexcept
Ret InvokeRetWithUnsupportedPolicy(Args &&...args) const
Ret InvokeRetWithUnsupportedPolicyAndLeadingTemplateArgs(Args &&...args) const
GLsizei const GLfloat * value
Definition: glcorearb.h:824
bool IsType(ContainerType type) const noexcept
#define ORT_ENFORCE(condition,...)
Definition: common.h:172
Ret InvokeRetWithLeadingTemplateArgs(Args &&...args) const
bool IsPrimType(int32_t prim_type) const noexcept
boost::mp11::mp_apply< MLTypeCallDispatcher, L > MLTypeCallDispatcherFromTypeList
void InvokeWithLeadingTemplateArgs(Args &&...args) const
bool IsPrimitiveDataType(MLDataType dt_type)
const PrimitiveDataTypeBase * AsPrimitiveDataType() const
Definition: data_types.h:998
std::integral_constant< ONNX_NAMESPACE::TensorProto_DataType, ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED > UndefinedTensorProtoElementTypeConstant
#define ORT_UNUSED_PARAMETER(x)
Definition: common.h:47
bool IsOpaqueType(MLDataType ml_type, const char *domain, const char *name)
PrimitiveDataTypeBase Base class for primitive Tensor contained types.
Definition: data_types.h:923
GLuint const GLchar * name
Definition: glcorearb.h:786
std::integral_constant< ONNX_NAMESPACE::TensorProto_DataType, ToTensorProtoElementType< T >()> TensorProtoElementTypeConstant
#define ORT_THROW(...)
Definition: common.h:162
bool IsDataTypeString(MLDataType dt_type)
Use the following primitives if you have a few types to switch on so you.
TypeNode(ContainerType type, int32_t prim_type) noexcept
GLuint index
Definition: glcorearb.h:786
**If you just want to fire and args
Definition: thread.h:609
Definition: core.h:1131
type
Definition: core.h:1059