HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
kernel_def_builder.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
#include <limits.h>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "core/common/common.h"
#include "core/graph/basic_types.h"
17 
18 namespace onnxruntime {
20 
21 typedef std::map<size_t, OrtMemType> MemTypeMap;
22 
23 class KernelDef {
24  private:
25  // note that input/output might be on CPU implicitly when the node is from CPU execution provider
26  constexpr static inline bool MemTypeOnCpuExplicitly(OrtMemType mem_type) {
27  return mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput;
28  }
29 
30  public:
31  explicit KernelDef() = default;
32 
33  const std::string& OpName() const {
34  return op_name_;
35  }
36 
37  const std::string& Domain() const {
38  return op_domain_;
39  }
40 
41  void SinceVersion(/*out*/ int* start, /*out*/ int* end) const {
42  *start = op_since_version_start_;
43  *end = op_since_version_end_;
44  }
45 
46  const std::pair<int, int> SinceVersion() const {
47  return std::pair<int, int>(op_since_version_start_, op_since_version_end_);
48  }
49 
51  return provider_type_;
52  }
53 
54  // type constraints with types supported in this build
55  const std::unordered_map<std::string, std::vector<MLDataType>>& TypeConstraints() const {
56  return type_constraints_;
57  }
58 
59  const std::vector<std::pair<int, int>>& MayInplace() const {
60  return inplace_map_;
61  }
62 
63  const std::vector<std::pair<int, int>>& Alias() const {
64  return alias_map_;
65  }
66 
67  const std::optional<std::pair<int, int>>& VariadicAlias() const {
68  return variadic_alias_offsets_;
69  }
70 
71  OrtMemType InputMemoryType(size_t input_index) const {
72  auto it = input_memory_type_args_.find(input_index);
73  if (it == input_memory_type_args_.end())
74  return default_inputs_mem_type_;
75  return it->second;
76  }
77 
78  bool IsInputOnCpu(size_t input_index) const { return MemTypeOnCpuExplicitly(InputMemoryType(input_index)); }
79 
80  bool IsOutputOnCpu(size_t output_index) const { return MemTypeOnCpuExplicitly(OutputMemoryType(output_index)); }
81 
82  bool AllocateInputsContiguously() const { return allocate_inputs_contiguously_; }
83 
84  bool HasExternalOutputs() const { return external_outputs_; }
85 
86 #ifdef ENABLE_STRIDED_TENSORS
87  const std::vector<int>& MayStridedInput() const { return may_strided_inputs_; }
88  const std::vector<std::pair<int, int>>& MayStridedOutput() const { return may_strided_output_map_; }
89 #endif
90 
91  OrtMemType OutputMemoryType(size_t output_index) const {
92  auto it = output_memory_type_args_.find(output_index);
93  if (it == output_memory_type_args_.end())
94  return default_outputs_mem_type_;
95  return it->second;
96  }
97 
98  int ExecQueueId() const {
99  return exec_queue_id_;
100  }
101 
102  bool IsConflict(const KernelDef& other) const;
103 
104  private:
105  friend class KernelDefBuilder;
106 
107  // The operator name supported by <*this> kernel..
108  std::string op_name_;
109 
110  // The operator since_version range supported by <*this> kernel.
111  // A kernel could support an operator definition between <op_since_version_start>
112  // and <op_since_version_end> (inclusive).
113  int op_since_version_start_ = 1;
114  int op_since_version_end_ = INT_MAX;
115 
116  // The operator domain supported by <*this> kernel.
117  // Default to 'onnxruntime::kOnnxDomain'.
118  // Please note the behavior of std::string("") and std::string() are different
119  std::string op_domain_;
120 
121  // The type of the execution provider.
122  std::string provider_type_;
123 
124  // The data types that are supported in this build (enabled) for inputs/outputs.
125  // Key is input/output/type constraint name defined in op schema, Value is supported types.
126  std::unordered_map<std::string, std::vector<MLDataType>> type_constraints_;
127 
128  // An element <i, j> means that output j reuses the memory of input i.
129  std::vector<std::pair<int, int>> inplace_map_;
130 
131  // An element <i, j> means that output j is an alias of input i.
132  std::vector<std::pair<int, int>> alias_map_;
133 
134  // This variable stores <input_offset, output_offset> for the variadic alias mapping
135  // output 'i + output_offset' is an alias of input 'i + input_offset' for all i >= 0
136  std::optional<std::pair<int, int>> variadic_alias_offsets_;
137 
138  // Require input tensors to be allocated contiguously.
139  bool allocate_inputs_contiguously_ = false;
140 
141  // Whether the outputs are from external.
142  bool external_outputs_ = false;
143 
144 #ifdef ENABLE_STRIDED_TENSORS
145  // An element i means i-th input can be strided tensor.
146  std::vector<int> may_strided_inputs_;
147 
148  // An element <i, j> means j-th output can be a strided tensor, which share the data from i-th input.
149  std::vector<std::pair<int, int>> may_strided_output_map_;
150 #endif
151 
152  // The memory types of inputs/outputs of this kernel
153  MemTypeMap input_memory_type_args_;
154  MemTypeMap output_memory_type_args_;
155 
156  // execution command queue id, 0 for default queue in execution provider
157  int exec_queue_id_ = 0;
158  // Default memory type for all inputs
159  OrtMemType default_inputs_mem_type_{OrtMemTypeDefault};
160  // Default memory type for all outputs
161  OrtMemType default_outputs_mem_type_{OrtMemTypeDefault};
162 };
163 
165  public:
166  static std::unique_ptr<KernelDefBuilder> Create() { return std::make_unique<KernelDefBuilder>(); }
167 
168  explicit KernelDefBuilder()
169  : kernel_def_(std::make_unique<KernelDef>()) {}
170 
171  KernelDefBuilder& SetName(const std::string& op_name);
172  KernelDefBuilder& SetName(const char* op_name);
173 
174  KernelDefBuilder& SetDomain(const std::string& domain);
175  KernelDefBuilder& SetDomain(const char* domain);
176 
177  /**
178  This kernel supports operator definition since <since_version> (to latest).
179  */
180  KernelDefBuilder& SinceVersion(int since_version) {
181  kernel_def_->op_since_version_start_ = since_version;
182  return *this;
183  }
184 
185  /**
186  The start and end version should be set accordingly per version range for
187  each domain registered in OpSchemaRegistry::DomainToVersionRange in
188  \onnxruntime\onnxruntime\core\graph\op.h as below.
189  Key: domain. Value: <lowest version, highest version> pair.
190  std::unordered_map<std::string, std::pair<int, int>> map_;
191  */
192  KernelDefBuilder& SinceVersion(int since_version_start, int since_version_end) {
193  kernel_def_->op_since_version_start_ = since_version_start;
194  kernel_def_->op_since_version_end_ = since_version_end;
195  return *this;
196  }
197 
198  /**
199  The execution provider type of the kernel.
200  */
201  KernelDefBuilder& Provider(ProviderType provider_type);
202  KernelDefBuilder& Provider(const char* provider_type);
203 
204  /**
205  Specify the set of types that this kernel supports. A further restriction
206  of the set of types specified in the op schema.
207 
208  @param arg_name The arg name can be either op formal parameter name, say "X", or type
209  argument name specified in op schema, say "T".
210  @param types The types that are supported in this build.
211  */
212  KernelDefBuilder& TypeConstraint(const std::string& arg_name, std::vector<MLDataType> types);
213  KernelDefBuilder& TypeConstraint(const char* arg_name, std::vector<MLDataType> types);
214 
215  /**
216  Like TypeConstraint but supports just a single type.
217  */
219  KernelDefBuilder& TypeConstraint(const char* arg_name, MLDataType type);
220 
221  /**
222  Inplace mapping from inputs to outputs allowed.
223  It means that uplayer runtime could do memory in-place optimization
224  as it will not impact the correctness of this kernel.
225  */
226  KernelDefBuilder& MayInplace(const std::vector<std::pair<int, int>>& inplaces);
227  KernelDefBuilder& MayInplace(int input_index, int output_index);
228 
229  /**
230  Alias mapping from inputs to outputs. Different from Inplace that the
231  content of the tensor is not changed. This is to take care of operators
232  such as Identity and Reshape.
233  */
234  KernelDefBuilder& Alias(const std::vector<std::pair<int, int>>& aliases);
235  KernelDefBuilder& Alias(int input_index, int output_index);
236 
237  /**
238  Apply variadic number of alias mapping from inputs to outputs.
239  This is effectively applying Alias(i + input_offset, i + output_offset) for i >= 0
240  */
241  KernelDefBuilder& VariadicAlias(int input_offset, int output_offset);
242 
243  /**
244  Specify that this kernel requires input tensors to be allocated
245  contiguously. This allows kernels to execute as a single large
246  computation, rather than numerous smaller computations.
247  */
249  kernel_def_->allocate_inputs_contiguously_ = true;
250  return *this;
251  }
252 
253  /**
254  Specify that this kernel's output buffers are passed from external,
255  i.e. not created or managed by ORT's memory allocator.
256  */
258  kernel_def_->external_outputs_ = true;
259  return *this;
260  }
261 
262 #ifdef ENABLE_STRIDED_TENSORS
263  /**
264  Specify that the input_index-th input can be strided tensor.
265  */
266  KernelDefBuilder& MayStridedInput(int input_index);
267 
268  /**
269  Specify that the output_index-th output can be strided tensor, and share the data
270  from input_index-th input.
271  */
272  KernelDefBuilder& MayStridedOutput(int input_index, int output_index);
273 #endif
274 
275  /**
276  Specify that this kernel requires an input arg
277  in certain memory type (instead of the default, device memory).
278  */
279  KernelDefBuilder& InputMemoryType(OrtMemType type, int input_index) {
280  kernel_def_->input_memory_type_args_.insert(std::make_pair(input_index, type));
281  return *this;
282  }
283 
284  /**
285  Specify that this kernel requires input arguments
286  in certain memory type (instead of the default, device memory).
287  */
288  KernelDefBuilder& InputMemoryType(OrtMemType type, const std::vector<int>& input_indexes) {
289  for (auto input_index : input_indexes) {
290  kernel_def_->input_memory_type_args_.insert(std::make_pair(input_index, type));
291  }
292  return *this;
293  }
294 
295  /**
296  Specify that this kernel provides an output arg
297  in certain memory type (instead of the default, device memory).
298  */
299  KernelDefBuilder& OutputMemoryType(OrtMemType type, int output_index) {
300  kernel_def_->output_memory_type_args_.insert(std::make_pair(output_index, type));
301  return *this;
302  }
303 
304  /**
305  Specify that this kernel provides an output arguments
306  in certain memory type (instead of the default, device memory).
307  */
308  KernelDefBuilder& OutputMemoryType(OrtMemType type, const std::vector<int>& output_indexes) {
309  for (auto output_index : output_indexes) {
310  kernel_def_->output_memory_type_args_.insert(std::make_pair(output_index, type));
311  }
312  return *this;
313  }
314 
315  /**
316  Specify that this kernel runs on which execution queue in the provider
317  */
318  KernelDefBuilder& ExecQueueId(int queue_id) {
319  kernel_def_->exec_queue_id_ = queue_id;
320  return *this;
321  }
322 
323  /**
324  Specify the default inputs memory type, if not specified, it is DefaultMemory
325  */
327  kernel_def_->default_inputs_mem_type_ = mem_type;
328  return *this;
329  }
330 
331  /**
332  Specify the default outputs memory type, if not specified, it is DefaultMemory
333  */
335  kernel_def_->default_outputs_mem_type_ = mem_type;
336  return *this;
337  }
338 
339  /**
340  Return the kernel definition, passing ownership of the KernelDef to the caller
341  */
342  std::unique_ptr<KernelDef> Build() {
343  return std::move(kernel_def_);
344  }
345 
346  private:
347  // we own the KernelDef until Build() is called.
348  std::unique_ptr<KernelDef> kernel_def_;
349 };
350 
351 } // namespace onnxruntime
KernelDefBuilder & Provider(ProviderType provider_type)
static std::unique_ptr< KernelDefBuilder > Create()
KernelDefBuilder & Alias(const std::vector< std::pair< int, int >> &aliases)
const std::string & ProviderType
Definition: basic_types.h:35
Base class for MLDataType.
Definition: data_types.h:76
KernelDefBuilder & ExternalOutputs()
const std::vector< std::pair< int, int > > & MayInplace() const
GLuint start
Definition: glcorearb.h:475
GLsizei const GLchar *const * string
Definition: glcorearb.h:814
OrtMemType InputMemoryType(size_t input_index) const
const std::string & Domain() const
const std::vector< std::pair< int, int > > & Alias() const
bool HasExternalOutputs() const
KernelDefBuilder & SetDomain(const std::string &domain)
onnxruntime::ProviderType Provider() const
KernelDefBuilder & ExecQueueId(int queue_id)
KernelDefBuilder & VariadicAlias(int input_offset, int output_offset)
bool AllocateInputsContiguously() const
bool IsConflict(const KernelDef &other) const
void SinceVersion(int *start, int *end) const
KernelDefBuilder & OutputMemoryType(OrtMemType type, const std::vector< int > &output_indexes)
KernelDefBuilder & SetDefaultOutputMemoryType(OrtMemType mem_type)
KernelDefBuilder & AllocateInputsContiguously()
KernelDefBuilder & SetName(const std::string &op_name)
GLuint GLuint end
Definition: glcorearb.h:475
const std::unordered_map< std::string, std::vector< MLDataType > > & TypeConstraints() const
KernelDefBuilder & SinceVersion(int since_version)
bool IsOutputOnCpu(size_t output_index) const
std::map< size_t, OrtMemType > MemTypeMap
KernelDefBuilder & TypeConstraint(const std::string &arg_name, std::vector< MLDataType > types)
const std::pair< int, int > SinceVersion() const
const std::string & OpName() const
OrtMemType OutputMemoryType(size_t output_index) const
KernelDefBuilder & MayInplace(const std::vector< std::pair< int, int >> &inplaces)
KernelDefBuilder & InputMemoryType(OrtMemType type, int input_index)
KernelDefBuilder & InputMemoryType(OrtMemType type, const std::vector< int > &input_indexes)
std::unique_ptr< KernelDef > Build()
KernelDefBuilder & SetDefaultInputsMemoryType(OrtMemType mem_type)
GLsizei GLenum GLenum * types
Definition: glcorearb.h:2542
type
Definition: core.h:1059
const std::optional< std::pair< int, int > > & VariadicAlias() const
KernelDefBuilder & OutputMemoryType(OrtMemType type, int output_index)
KernelDefBuilder & SinceVersion(int since_version_start, int since_version_end)
bool IsInputOnCpu(size_t input_index) const