op_kernel.h
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "boost/mp11.hpp"

// It is safe to include the below header even if SHARED_PROVIDER macro is enabled
// as it doesn't include any pb headers.
#include "core/framework/prepacked_weights_container.h"

#ifndef SHARED_PROVIDER
#include <functional>

#include "core/common/exceptions.h"
#include "core/common/status.h"
#include "core/framework/tensor.h"
#include "core/graph/constants.h"
#include "core/graph/onnx_protobuf.h"
#include "core/common/gsl.h"
namespace onnxruntime {
class OpKernelContext;
}
#endif

namespace onnxruntime {

std::unique_ptr<OpKernelInfo> CopyOpKernelInfo(const OpKernelInfo& info);

class OpKernel {
 public:
  using DoneCallback = std::function<void()>;

  explicit OpKernel(const OpKernelInfo& info) : op_kernel_info_(CopyOpKernelInfo(info)) {}
  virtual ~OpKernel() = default;

  const onnxruntime::Node& Node() const;
  const onnxruntime::KernelDef& KernelDef() const;

  [[nodiscard]] virtual Status Compute(_Inout_ OpKernelContext* context) const = 0;

  [[nodiscard]] virtual bool IsAsync() const {
    // By default all kernels are synchronous.
    return false;
  }

  [[nodiscard]] virtual Status ComputeAsync(_Inout_ OpKernelContext*, DoneCallback) const {
    ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
  }

  // Override this function to pre-pack an initialized constant tensor into the format the kernel needs.
  // For example, the MatMul kernel can pack its input B if it is constant, as in the code below.
  //   Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
  //                  /*out*/ bool& is_packed,
  //                  /*out*/ PrePackedWeights* prepacked_weight_for_caching) override {
  //     is_packed = false;
  //     if (input_idx == 1) {
  //       is_packed = true;
  //       this->Pack(tensor, this->buffer_, alloc);
  //       if (prepacked_weight_for_caching) {
  //         // LOGIC TO CACHE `this->buffer_` SINCE THE KERNEL DOESN'T OWN THE PACKED WEIGHT
  //       }
  //     }
  //     return Status::OK();
  //   }
  // Please refer to MatMulIntegerToFloatBase for a complete example.
  // @param tensor: The initialized constant tensor.
  // @param input_idx: The input index of the tensor in this kernel.
  // @param alloc: The allocator the kernel's PrePack() method MUST use for allocating the pre-packed
  //               weights' buffers. It is either the allocator tied to the session (when the kernel
  //               owns the pre-packed buffer) or an allocator shared between sessions (when the
  //               pre-packed buffer is to be shared across sessions, i.e. the kernel does not own
  //               the buffer).
  // @param is_packed: Set to true if the kernel packed the tensor, false otherwise.
  //                   If is_packed is true, the kernel is responsible for keeping the packed data and
  //                   related metadata; the original initialized constant tensor will be released and
  //                   will no longer be accessible in the Compute function.
  // @param prepacked_weights: A PrePackedWeights instance is provided to the kernel only if the
  //                           pre-packed weights are meant to be stored in a shared container.

  virtual Status
  PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/,
          /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) {
    is_packed = false;
    return Status::OK();
  }

  // Override this function to return a list of attributes that the session can safely remove
  // after it is initialized and saved. This option is useful to reduce memory usage
  // when the kernel does not reuse the operator attributes but copies them.
  // All attributes returned by this method will be removed by the PruneRemovableAttributes
  // method if they exist.
  // @param removable_attributes: Set of attributes the session can safely remove.
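  // For example (an illustrative sketch; the "alpha" attribute is hypothetical), a kernel that
  // copies an attribute into a member in its constructor could report it as removable:
  //   Status GetRemovableAttributes(InlinedVector<std::string>& removable_attributes) const override {
  //     removable_attributes.push_back("alpha");
  //     return Status::OK();
  //   }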
  virtual Status GetRemovableAttributes(InlinedVector<std::string>& removable_attributes) const {
    removable_attributes.clear();
    return Status::OK();
  }

  // Override this function to use the provided pre-packed weight.
  //   Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
  //                                    int input_idx,
  //                                    /*out*/ bool& used_shared_buffers) override {
  //     used_shared_buffers = true;
  //     this->buffer_ = std::move(prepacked_buffers[0]);
  //     return Status::OK();
  //   }
  // Please refer to MatMulIntegerToFloatBase for a complete example.
  // @param prepacked_buffers: The pre-packed buffers to be used by this kernel for the provided input index.
  //                           (Sometimes a single constant initializer may have multiple pre-packed buffers
  //                           associated with it. It is up to the kernel developer to store them in any order
  //                           of their choice in PrePack(), and the same order must be used for retrieval in
  //                           UseSharedPrePackedBuffers(). Though each element of this vector is a
  //                           BufferUniquePtr, the deleter of the BufferUniquePtr is null, so they are
  //                           effectively raw pointers.)
  // @param input_idx: The input index of the tensor in this kernel.
  // @param used_shared_buffers: Boolean flag set by the kernel implementation indicating
  //                             that the provided weight has been used by the kernel.
  virtual Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& /*prepacked_buffers*/,
                                           int /*input_idx*/,
                                           /*out*/ bool& used_shared_buffers) {
    used_shared_buffers = false;
    return Status::OK();
  }

  const OrtDevice GetDevice(OrtMemType mem_type) const;
  const OpKernelInfo& Info() const {
    return *op_kernel_info_;
  }

 private:
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpKernel);
  std::unique_ptr<OpKernelInfo> op_kernel_info_;
};
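
// Example (an illustrative sketch; the MyIdentity kernel is hypothetical and not part of
// ONNX Runtime): a minimal synchronous kernel needs a constructor taking an OpKernelInfo and a
// Compute() override that reads its inputs and writes its outputs through the OpKernelContext.
//   class MyIdentity final : public OpKernel {
//    public:
//     explicit MyIdentity(const OpKernelInfo& info) : OpKernel(info) {}
//
//     Status Compute(OpKernelContext* context) const override {
//       const Tensor* X = context->Input<Tensor>(0);
//       Tensor* Y = context->Output(0, X->Shape());
//       memcpy(Y->MutableDataRaw(), X->DataRaw(), X->SizeInBytes());
//       return Status::OK();
//     }
//   };
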
class FuncManager;
using KernelCreateFn = std::function<Status(FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out)>;
using KernelCreatePtrFn = std::add_pointer<Status(FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out)>::type;

struct KernelCreateInfo {
  std::unique_ptr<KernelDef> kernel_def;  // Owned and stored in the global kernel registry.
  KernelCreateFn kernel_create_func;

  KernelCreateInfo(std::unique_ptr<KernelDef> definition,
                   KernelCreateFn create_func)
      : kernel_def(std::move(definition)),
        kernel_create_func(create_func) {
    assert(kernel_def != nullptr);
  }

  KernelCreateInfo(KernelCreateInfo&& other) noexcept
      : kernel_def(std::move(other.kernel_def)),
        kernel_create_func(std::move(other.kernel_create_func)) {}

  KernelCreateInfo() = default;
};

// Forward declarations for the non-specialized BuildKernelCreateInfo method.
template <typename T>
KernelCreateInfo BuildKernelCreateInfo();

namespace ml {
template <typename T>
KernelCreateInfo BuildKernelCreateInfo();
}  // namespace ml

namespace contrib {
template <typename T>
KernelCreateInfo BuildKernelCreateInfo();
}  // namespace contrib

namespace contrib {
namespace cuda {
template <typename T>
KernelCreateInfo BuildKernelCreateInfo();
}  // namespace cuda
}  // namespace contrib

namespace contrib {
namespace js {
template <typename T>
KernelCreateInfo BuildKernelCreateInfo();
}  // namespace js
}  // namespace contrib

namespace contrib {
namespace rocm {
template <typename T>
KernelCreateInfo BuildKernelCreateInfo();
}  // namespace rocm
}  // namespace contrib

namespace contrib {
namespace snpe {
template <typename T>
KernelCreateInfo BuildKernelCreateInfo();
}  // namespace snpe
}  // namespace contrib

using BuildKernelCreateInfoFn = KernelCreateInfo (*)();

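// The BuildKernelCreateInfo<T> specializations generated by the registration macros below are
// typically collected into a table of BuildKernelCreateInfoFn entries by an execution provider and
// registered with its kernel registry. For example (an illustrative sketch; the MyIdentity kernel
// is hypothetical):
//   static const BuildKernelCreateInfoFn function_table[] = {
//       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, MyIdentity)>,
//   };
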
// Naming convention for operator kernel classes
#define ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name) \
  provider##_##name##_##domain##_ver##ver

#define ONNX_CPU_OPERATOR_KERNEL(name, ver, builder, ...) \
  ONNX_OPERATOR_KERNEL_EX(name, kOnnxDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__)

#define ONNX_CPU_OPERATOR_ML_KERNEL(name, ver, builder, ...) \
  ONNX_OPERATOR_KERNEL_EX(name, kMLDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__)

#define ONNX_CPU_OPERATOR_MS_KERNEL(name, ver, builder, ...) \
  ONNX_OPERATOR_KERNEL_EX(name, kMSDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__)

#define ONNX_OPERATOR_KERNEL_EX(name, domain, ver, provider, builder, ...) \
  class ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name); \
  template <> \
  KernelCreateInfo \
  BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name)>() { \
    return KernelCreateInfo( \
        builder.SetName(#name) \
            .SetDomain(domain) \
            .SinceVersion(ver) \
            .Provider(provider) \
            .Build(), \
        static_cast<KernelCreatePtrFn>( \
            [](FuncManager&, \
               const OpKernelInfo& info, \
               std::unique_ptr<OpKernel>& out) -> Status { \
              out = std::make_unique<__VA_ARGS__>(info); \
              return Status::OK(); \
            })); \
  }
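
// Example usage (an illustrative sketch; it registers the hypothetical MyIdentity kernel sketched
// above for opset 14 of a hypothetical MyIdentity operator in the ONNX domain on CPU):
//   ONNX_CPU_OPERATOR_KERNEL(
//       MyIdentity, 14,
//       KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
//       MyIdentity);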

#define ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name) \
  provider##_##name##_##domain##_ver##startver##_##endver

#define ONNX_CPU_OPERATOR_VERSIONED_KERNEL(name, startver, endver, builder, ...) \
  ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, kOnnxDomain, startver, endver, kCpuExecutionProvider, builder, __VA_ARGS__)

#define ONNX_CPU_OPERATOR_VERSIONED_ML_KERNEL(name, startver, endver, builder, ...) \
  ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, kMLDomain, startver, endver, kCpuExecutionProvider, builder, __VA_ARGS__)

#define ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, domain, startver, endver, provider, builder, ...) \
  class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name); \
  template <> \
  KernelCreateInfo \
  BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name)>() { \
    return KernelCreateInfo( \
        builder.SetName(#name) \
            .SetDomain(domain) \
            .SinceVersion(startver, endver) \
            .Provider(provider) \
            .Build(), \
        static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
  }

#define ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name) \
  provider##_##name##_##domain##_ver##ver##_##type

#define ONNX_CPU_OPERATOR_TYPED_KERNEL(name, ver, type, builder, ...) \
  ONNX_OPERATOR_TYPED_KERNEL_EX(name, kOnnxDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)

#define ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(name, ver, type, builder, ...) \
  ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMLDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)

#define ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(name, ver, type, builder, ...) \
  ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMSDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)

#define ONNX_OPERATOR_TYPED_KERNEL_EX(name, domain, ver, type, provider, builder, ...) \
  class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name); \
  template <> \
  KernelCreateInfo \
  BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name)>() { \
    return KernelCreateInfo( \
        builder.SetName(#name) \
            .SetDomain(domain) \
            .SinceVersion(ver) \
            .Provider(provider) \
            .Build(), \
        static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
  }

#define ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name) \
  provider##_##name##_##domain##_ver##ver##_##type1##_##type2

#define ONNX_OPERATOR_TWO_TYPED_KERNEL_EX(name, domain, ver, type1, type2, provider, builder, ...) \
  class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name); \
  template <> \
  KernelCreateInfo \
  BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name)>() { \
    return KernelCreateInfo( \
        builder.SetName(#name) \
            .SetDomain(domain) \
            .SinceVersion(ver) \
            .Provider(provider) \
            .Build(), \
        static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
  }

#define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name) \
  provider##_##name##_##domain##_ver##startver##_##endver##_##type

#define ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(name, startver, endver, type, builder, ...) \
  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kOnnxDomain, startver, endver, type, kCpuExecutionProvider, builder, \
                                          __VA_ARGS__)

#define ONNX_CPU_OPERATOR_VERSIONED_TYPED_ML_KERNEL(name, startver, endver, type, builder, ...) \
  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kMLDomain, startver, endver, type, kCpuExecutionProvider, builder, \
                                          __VA_ARGS__)

#define ONNX_CPU_OPERATOR_VERSIONED_TYPED_MS_KERNEL(name, startver, endver, type, builder, ...) \
  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kMSDomain, startver, endver, type, kCpuExecutionProvider, builder, \
                                          __VA_ARGS__)

#define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, domain, startver, endver, type, provider, builder, ...) \
  class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name); \
  template <> \
  KernelCreateInfo \
  BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, \
                                                                        type, name)>() { \
    return KernelCreateInfo( \
        builder.SetName(#name) \
            .SetDomain(domain) \
            .SinceVersion(startver, endver) \
            .Provider(provider) \
            .Build(), \
        static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
  }

#define ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type1, type2, name) \
  provider##_##name##_##domain##_ver##startver##_##endver##_##type1##_##type2

#define ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX(name, domain, startver, endver, type1, type2, \
                                                    provider, builder, ...) \
  class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type1, type2, name); \
  template <> \
  KernelCreateInfo \
  BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, \
                                                                            type1, type2, name)>() { \
    return KernelCreateInfo( \
        builder.SetName(#name) \
            .SetDomain(domain) \
            .SinceVersion(startver, endver) \
            .Provider(provider) \
            .Build(), \
        static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
  }

template <typename... Types>
struct BuildKernelDefConstraintsImpl {
  std::vector<MLDataType> operator()() const {
    return {DataTypeImpl::GetTensorType<Types>()...};
  }
};

#if !defined(DISABLE_SPARSE_TENSORS)
template <typename... Types>
struct BuildKernelDefSparseConstraintsImpl {
  std::vector<MLDataType> operator()() const {
    return {DataTypeImpl::GetSparseTensorType<Types>()...};
  }
};
#endif

// Use within macro definitions to create a custom vector of constraints.
// Example: #define REG_KERNEL(OP, VERSION, KERNEL_CLASS, Type, ...)
//   .TypeConstraint("T", BuildKernelDefConstraints<Type, __VA_ARGS__>())
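// For instance (illustrative), BuildKernelDefConstraints<float, double>() evaluates to
// {DataTypeImpl::GetTensorType<float>(), DataTypeImpl::GetTensorType<double>()}.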
template <typename... Types>
inline std::vector<MLDataType> BuildKernelDefConstraints() {
  return BuildKernelDefConstraintsImpl<Types...>{}();
}

#if !defined(DISABLE_SPARSE_TENSORS)
template <typename... Types>
inline std::vector<MLDataType> BuildKernelDefSparseConstraints() {
  return BuildKernelDefSparseConstraintsImpl<Types...>{}();
}
#endif

// version of BuildKernelDefConstraints() which takes a type list
template <typename L>
inline std::vector<MLDataType> BuildKernelDefConstraintsFromTypeList() {
  return boost::mp11::mp_apply<BuildKernelDefConstraintsImpl, L>{}();
}
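// For example (illustrative; assumes a variadic type-list template such as onnxruntime::TypeList):
//   BuildKernelDefConstraintsFromTypeList<TypeList<float, double>>()
// produces the same result as BuildKernelDefConstraints<float, double>().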

#if !defined(DISABLE_SPARSE_TENSORS)
template <typename L>
inline std::vector<MLDataType> BuildKernelDefSparseConstraintsFromTypeList() {
  return boost::mp11::mp_apply<BuildKernelDefSparseConstraintsImpl, L>{}();
}
#endif

}  // namespace onnxruntime

#ifndef SHARED_PROVIDER
#endif