6 #include "boost/mp11.hpp"
10 #include "core/framework/prepacked_weights_container.h"
12 #ifndef SHARED_PROVIDER
27 #include "core/graph/onnx_protobuf.h"
29 namespace onnxruntime {
30 class OpKernelContext;
34 namespace onnxruntime {
50 [[nodiscard]]
virtual bool IsAsync()
const {
91 bool& is_packed, PrePackedWeights* ) {
103 removable_attributes.clear();
127 bool& used_shared_buffers) {
128 used_shared_buffers =
false;
134 return *op_kernel_info_;
138 ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(
OpKernel);
139 std::unique_ptr<OpKernelInfo> op_kernel_info_;
142 using KernelCreateFn = std::function<Status(FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out)>;
143 using KernelCreatePtrFn = std::add_pointer<Status(FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out)>::
type;
165 template <
typename T>
169 template <
typename T>
174 template <
typename T>
180 template <
typename T>
187 template <
typename T>
194 template <
typename T>
201 template <
typename T>
209 #define ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name) \
210 provider##_##name##_##domain##_ver##ver
212 #define ONNX_CPU_OPERATOR_KERNEL(name, ver, builder, ...) \
213 ONNX_OPERATOR_KERNEL_EX(name, kOnnxDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__)
215 #define ONNX_CPU_OPERATOR_ML_KERNEL(name, ver, builder, ...) \
216 ONNX_OPERATOR_KERNEL_EX(name, kMLDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__)
218 #define ONNX_CPU_OPERATOR_MS_KERNEL(name, ver, builder, ...) \
219 ONNX_OPERATOR_KERNEL_EX(name, kMSDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__)
221 #define ONNX_OPERATOR_KERNEL_EX(name, domain, ver, provider, builder, ...) \
222 class ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name); \
225 BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name)>() { \
226 return KernelCreateInfo( \
227 builder.SetName(#name) \
230 .Provider(provider) \
232 static_cast<KernelCreatePtrFn>( \
234 const OpKernelInfo& info, \
235 std::unique_ptr<OpKernel>& out) -> Status { \
236 out = std::make_unique<__VA_ARGS__>(info); \
237 return Status::OK(); \
241 #define ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name) \
242 provider##_##name##_##domain##_ver##startver##_##endver
244 #define ONNX_CPU_OPERATOR_VERSIONED_KERNEL(name, startver, endver, builder, ...) \
245 ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, kOnnxDomain, startver, endver, kCpuExecutionProvider, builder, __VA_ARGS__)
247 #define ONNX_CPU_OPERATOR_VERSIONED_ML_KERNEL(name, startver, endver, builder, ...) \
248 ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, kMLDomain, startver, endver, kCpuExecutionProvider, builder, __VA_ARGS__)
250 #define ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, domain, startver, endver, provider, builder, ...) \
251 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name); \
254 BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name)>() { \
255 return KernelCreateInfo( \
256 builder.SetName(#name) \
258 .SinceVersion(startver, endver) \
259 .Provider(provider) \
261 static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
264 #define ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name) \
265 provider##_##name##_##domain##_ver##ver##_##type
267 #define ONNX_CPU_OPERATOR_TYPED_KERNEL(name, ver, type, builder, ...) \
268 ONNX_OPERATOR_TYPED_KERNEL_EX(name, kOnnxDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)
270 #define ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(name, ver, type, builder, ...) \
271 ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMLDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)
273 #define ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(name, ver, type, builder, ...) \
274 ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMSDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)
276 #define ONNX_OPERATOR_TYPED_KERNEL_EX(name, domain, ver, type, provider, builder, ...) \
277 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name); \
280 BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name)>() { \
281 return KernelCreateInfo( \
282 builder.SetName(#name) \
285 .Provider(provider) \
287 static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
290 #define ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name) \
291 provider##_##name##_##domain##_ver##ver##_##type1##_##type2
293 #define ONNX_OPERATOR_TWO_TYPED_KERNEL_EX(name, domain, ver, type1, type2, provider, builder, ...) \
294 class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name); \
297 BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name)>() { \
298 return KernelCreateInfo( \
299 builder.SetName(#name) \
302 .Provider(provider) \
304 static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
307 #define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name) \
308 provider##_##name##_##domain##_ver##startver##_##endver##_##type
310 #define ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(name, startver, endver, type, builder, ...) \
311 ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kOnnxDomain, startver, endver, type, kCpuExecutionProvider, builder, \
314 #define ONNX_CPU_OPERATOR_VERSIONED_TYPED_ML_KERNEL(name, startver, endver, type, builder, ...) \
315 ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kMLDomain, startver, endver, type, kCpuExecutionProvider, builder, \
318 #define ONNX_CPU_OPERATOR_VERSIONED_TYPED_MS_KERNEL(name, startver, endver, type, builder, ...) \
319 ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kMSDomain, startver, endver, type, kCpuExecutionProvider, builder, \
322 #define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, domain, startver, endver, type, provider, builder, ...) \
323 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name); \
326 BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, \
328 return KernelCreateInfo( \
329 builder.SetName(#name) \
331 .SinceVersion(startver, endver) \
332 .Provider(provider) \
334 static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
337 #define ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type1, type2, name) \
338 provider##_##name##_##domain##_ver##startver##_##endver##_##type1##_##type2
340 #define ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX(name, domain, startver, endver, type1, type2, \
341 provider, builder, ...) \
342 class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type1, type2, name); \
345 BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, \
346 type1, type2, name)>() { \
347 return KernelCreateInfo( \
348 builder.SetName(#name) \
350 .SinceVersion(startver, endver) \
351 .Provider(provider) \
353 static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
356 template <
typename... Types>
359 return {DataTypeImpl::GetTensorType<Types>()...};
363 #if !defined(DISABLE_SPARSE_TENSORS)
364 template <
typename... Types>
367 return {DataTypeImpl::GetSparseTensorType<Types>()...};
375 template <
typename... Types>
380 #if !defined(DISABLE_SPARSE_TENSORS)
381 template <
typename... Types>
388 template <
typename L>
390 return boost::mp11::mp_apply<BuildKernelDefConstraintsImpl, L>{}();
393 #if !defined(DISABLE_SPARSE_TENSORS)
394 template <
typename L>
396 return boost::mp11::mp_apply<BuildKernelDefSparseConstraintsImpl, L>{}();
402 #ifndef SHARED_PROVIDER
KernelCreateInfo BuildKernelCreateInfo()
KernelCreateInfo BuildKernelCreateInfo()
const OrtDevice GetDevice(OrtMemType mem_type) const
std::unique_ptr< KernelDef > kernel_def
KernelCreateInfo(KernelCreateInfo &&other) noexcept
KernelCreateInfo BuildKernelCreateInfo()
std::function< void()> DoneCallback
virtual Status UseSharedPrePackedBuffers(std::vector< BufferUniquePtr > &, int, bool &used_shared_buffers)
KernelCreateInfo BuildKernelCreateInfo()
#define ORT_NOT_IMPLEMENTED(...)
std::vector< MLDataType > BuildKernelDefConstraintsFromTypeList()
const onnxruntime::KernelDef & KernelDef() const
OpKernel(const OpKernelInfo &info)
virtual Status ComputeAsync(_Inout_ OpKernelContext *, DoneCallback) const
KernelCreateInfo BuildKernelCreateInfo()
std::unique_ptr< OpKernelInfo > CopyOpKernelInfo(const OpKernelInfo &info)
std::vector< MLDataType > BuildKernelDefConstraints()
virtual Status PrePack(const Tensor &, int, AllocatorPtr, bool &is_packed, PrePackedWeights *)
std::vector< MLDataType > operator()() const
absl::InlinedVector< T, N, Allocator > InlinedVector
KernelCreateInfo BuildKernelCreateInfo()
virtual ~OpKernel()=default
std::function< Status(FuncManager &func_mgr, const OpKernelInfo &info, std::unique_ptr< OpKernel > &out)> KernelCreateFn
std::vector< MLDataType > BuildKernelDefSparseConstraintsFromTypeList()
KernelCreateInfo(*)( BuildKernelCreateInfoFn)
std::shared_ptr< IAllocator > AllocatorPtr
KernelCreateInfo(std::unique_ptr< KernelDef > definition, KernelCreateFn create_func)
KernelCreateFn kernel_create_func
const onnxruntime::Node & Node() const
virtual Status Compute(_Inout_ OpKernelContext *context) const =0
std::vector< MLDataType > BuildKernelDefSparseConstraints()
std::add_pointer< Status(FuncManager &func_mgr, const OpKernelInfo &info, std::unique_ptr< OpKernel > &out)>::type KernelCreatePtrFn
KernelCreateInfo()=default
virtual Status GetRemovableAttributes(InlinedVector< std::string > &removable_attributes) const
std::vector< MLDataType > operator()() const
KernelCreateInfo BuildKernelCreateInfo()
virtual bool IsAsync() const
const OpKernelInfo & Info() const