HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
kernel_registry.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #include <string_view>
7 
9 
10 namespace onnxruntime {
11 
12 using KernelCreateMap = std::multimap<std::string, KernelCreateInfo>;
13 using KernelDefHashes = std::vector<std::pair<std::string, HashValue>>;
14 
15 class IKernelTypeStrResolver;
16 
17 /**
18  * Each provider has a KernelRegistry. Often, the KernelRegistry only belongs to that specific provider.
19  */
21  public:
22  KernelRegistry() = default;
23 
24  // Register a kernel with kernel definition and function to create the kernel.
25  Status Register(KernelDefBuilder& kernel_def_builder, const KernelCreateFn& kernel_creator);
26 
27  Status Register(KernelCreateInfo&& create_info);
28 
29  // TODO(edgchen1) for TryFindKernel(), consider using `out` != nullptr as indicator of whether kernel was found and
30  // Status as an indication of failure
31 
32  // Check if an execution provider can create kernel for a node and return the kernel if so.
33  // Kernel matching uses the types from the node and the kernel_type_str_resolver.
34  Status TryFindKernel(const Node& node, ProviderType exec_provider,
35  const IKernelTypeStrResolver& kernel_type_str_resolver,
36  const KernelCreateInfo** out) const;
37 
38  // map of type constraint name to required type
40 
41  // Check if an execution provider can create kernel for a node and return the kernel if so.
42  // Kernel matching uses the explicit type constraint name to required type map in type_constraints.
43  Status TryFindKernel(const Node& node, ProviderType exec_provider,
44  const TypeConstraintMap& type_constraints,
45  const KernelCreateInfo** out) const;
46 
47  /**
48  * @brief Find out whether a kernel is registered, without a node.
49  * This should be useful in graph optimizers, to check whether
50  * the node it is about to generate, is supported or not.
51  * @param exec_provider
52  * @param op_type
53  * @param domain
54  * @param version
55  * @param type_constraints
56  * @param out
57  * @return
58  */
59  Status TryFindKernel(ProviderType exec_provider,
60  std::string_view op_type,
61  std::string_view domain,
62  int version,
63  const KernelRegistry::TypeConstraintMap& type_constraints,
64  const KernelCreateInfo** out) const;
65 
66  static bool HasImplementationOf(const KernelRegistry& r, const Node& node,
67  ProviderType exec_provider,
68  const IKernelTypeStrResolver& kernel_type_str_resolver) {
69  const KernelCreateInfo* info;
70  Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, &info);
71  return st.IsOK();
72  }
73 
74  bool IsEmpty() const { return kernel_creator_fn_map_.empty(); }
75 
76  // This is used by the opkernel doc generator to enlist all registered operators for a given provider's opkernel
78  return kernel_creator_fn_map_;
79  }
80 
81  private:
82  // TryFindKernel implementation. Either kernel_type_str_resolver or type_constraints is provided.
83  Status TryFindKernelImpl(const Node& node, ProviderType exec_provider,
84  const IKernelTypeStrResolver* kernel_type_str_resolver,
85  const TypeConstraintMap* type_constraints,
86  const KernelCreateInfo** out) const;
87 
88  // Check whether the types of inputs/outputs of the given node match the extra
89  // type-constraints of the given kernel. This serves two purposes: first, to
90  // select the right kernel implementation based on the types of the arguments
91  // when we have multiple kernels, e.g., Clip<float> and Clip<int>; second, to
92  // accommodate (and check) mapping of ONNX (specification) type to the onnxruntime
93  // implementation type (e.g., if we want to implement ONNX's float16 as a regular
94  // float in onnxruntime). (The second, however, requires a globally uniform mapping.)
95  //
96  // Note that this is not intended for type-checking the node against the ONNX
97  // type specification of the corresponding op, which is done before this check.
98  //
99  // In typical usage kernel_type_str_resolver is provided and type information from the node is used with
100  // kernel_type_str_resolver.
101  //
102  // There is also usage from a node dynamically created within a custom op via OrtApi CreateOp where an explicit
103  // type value for each type constraint is provided in type_constraints.
104  //
105  // Either kernel_type_str_resolver or type_constraints is provided and not both.
106  static bool VerifyKernelDef(const Node& node, const KernelDef& kernel_def,
107  const IKernelTypeStrResolver* kernel_type_str_resolver,
108  const TypeConstraintMap* type_constraints,
109  std::string& error_str);
110 
111  static std::string GetMapKey(std::string_view op_name, std::string_view domain, std::string_view provider) {
112  std::string key(op_name);
113  // use the kOnnxDomainAlias of 'ai.onnx' instead of kOnnxDomain's empty string
114  key.append(1, ' ').append(domain.empty() ? kOnnxDomainAlias : domain).append(1, ' ').append(provider);
115  return key;
116  }
117 
118  static std::string GetMapKey(const KernelDef& kernel_def) {
119  return GetMapKey(kernel_def.OpName(), kernel_def.Domain(), kernel_def.Provider());
120  }
121  // Kernel create function map from op name to kernel creation info.
122  // key is opname+domain_name+provider_name
123  KernelCreateMap kernel_creator_fn_map_;
124 };
125 } // namespace onnxruntime
const std::string & ProviderType
Definition: basic_types.h:35
constexpr const char * kOnnxDomainAlias
Definition: constants.h:16
std::multimap< std::string, KernelCreateInfo > KernelCreateMap
GLsizei const GLchar *const * string
Definition: glcorearb.h:814
static bool HasImplementationOf(const KernelRegistry &r, const Node &node, ProviderType exec_provider, const IKernelTypeStrResolver &kernel_type_str_resolver)
basic_string_view< char > string_view
Definition: core.h:522
Status TryFindKernel(const Node &node, ProviderType exec_provider, const IKernelTypeStrResolver &kernel_type_str_resolver, const KernelCreateInfo **out) const
std::vector< std::pair< std::string, HashValue >> KernelDefHashes
InlinedHashMap< std::string, MLDataType > TypeConstraintMap
Status Register(KernelDefBuilder &kernel_def_builder, const KernelCreateFn &kernel_creator)
std::function< Status(FuncManager &func_mgr, const OpKernelInfo &info, std::unique_ptr< OpKernel > &out)> KernelCreateFn
Definition: op_kernel.h:142
const KernelCreateMap & GetKernelCreateMap() const
GT_API const UT_StringHolder version
GT_API const UT_StringHolder st
GLboolean r
Definition: glcorearb.h:1222