HDK
op_kernel_context.h
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

namespace onnxruntime {
class IExecutionFrame;
class Stream;
namespace concurrency {
class ThreadPool;
}

class OpKernelContext {
 public:
  using ArgMap = std::unordered_map<std::string, size_t>;

  OpKernelContext(_Inout_ IExecutionFrame* frame, _In_ const OpKernel* kernel,
                  _In_ Stream* stream,
                  _In_opt_ concurrency::ThreadPool* threadpool, _In_ const logging::Logger& logger);

  virtual ~OpKernelContext() = default;

  /**
  Return the number of inputs for a variadic argument.
  @param arg_num The operator argument number.
  @returns Number of inputs the argument has.
  */
  virtual int NumVariadicInputs(size_t arg_num) const;

  virtual MLDataType InputType(int index) const;
  virtual MLDataType OutputType(int index) const;

  const OrtValue* GetInputOrtValue(int index) const {
    return GetInputMLValue(index);
  }

  template <typename T>
  const T* Input(int index) const {
    const OrtValue* p_ml_value = GetInputMLValue(index);
    ORT_TRY {
      return p_ml_value ? &(p_ml_value->Get<T>()) : nullptr;
    }
    ORT_CATCH(const std::exception& /*e*/) {
      ORT_THROW("Missing Input: " + kernel_->Node().InputDefs()[index]->Name());
    }
  }

  // Fetch a required input, enforcing that it is present.
  template <typename T>
  const T& RequiredInput(int index) const {
    const T* input_ptr = Input<T>(index);
    ORT_ENFORCE(input_ptr, "Required input at index ", index, " is not present.");
    return *input_ptr;
  }
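
  // Usage sketch (illustrative, not part of the original header): inside a
  // kernel's Compute method, an optional input fetched with Input<T>() must be
  // checked for nullptr, while a mandatory one can use RequiredInput<T>().
  // "MyKernel" and its input layout are hypothetical.
  //
  //   Status MyKernel::Compute(OpKernelContext* ctx) const {
  //     const Tensor& X = ctx->RequiredInput<Tensor>(0);  // throws if absent
  //     const Tensor* B = ctx->Input<Tensor>(1);          // nullptr if the optional input is absent
  //     if (B != nullptr) { /* use bias */ }
  //     return Status::OK();
  //   }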

  // Fetch an output (non-tensor) with the specified index.
  template <typename T>
  T* Output(int index) {
    if (index < 0 || index >= OutputCount())
      return nullptr;

    OrtValue* p_ml_value = GetOrCreateOutputMLValue(index);
    return p_ml_value ? p_ml_value->GetMutable<T>() : nullptr;
  }

  // If memory has not yet been allocated for an output tensor, it is
  // allocated on-the-fly using the given tensor shape.
  // Returns nullptr if the output is an unused optional output.
  Tensor* Output(int index, const TensorShape& shape);
  Tensor* Output(int index, const std::vector<int64_t>& shape);
  Tensor* Output(int index, const std::initializer_list<int64_t>& shape);

  // Fetch a required tensor output, enforcing that it is present.
  Tensor& RequiredOutput(int index, const TensorShape& shape) {
    Tensor* output_ptr = Output(index, shape);
    ORT_ENFORCE(output_ptr, "Required output at index ", index, " is not present.");
    return *output_ptr;
  }
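
  // Usage sketch (illustrative): the shape passed to Output() drives the
  // on-the-fly allocation described above. "ctx", "batch" and "channels" are
  // hypothetical.
  //
  //   Tensor* Y = ctx->Output(0, TensorShape({batch, channels}));
  //   if (Y != nullptr) {  // nullptr means the output is an unused optional
  //     float* out = Y->MutableData<float>();
  //     // ... fill out ...
  //   }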

#if !defined(DISABLE_SPARSE_TENSORS)
  // Fetch a sparse-tensor output corresponding to the specified index.
  // shape must specify the shape of the underlying dense tensor.
  // Memory allocation for the output may happen when this method is invoked,
  // unless static optimization pre-allocates it.
  SparseTensor* OutputSparse(int index, const TensorShape& shape);
#endif

#if !defined(DISABLE_OPTIONAL_TYPE)
  // Use this API to output a "None" of a specific type (e.g. Tensor) at the specified index.
  template <typename T>
  void OutputOptionalWithoutData(int index) {
    auto* output_ort_value = GetOutputMLValue(index);

    auto type = DataTypeImpl::GetType<T>();

    output_ort_value->Init(nullptr,  // This OrtValue is "None" and has no data
                           type,
                           type->GetDeleteFunc());
  }
#endif
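
  // Usage sketch (illustrative): a kernel whose optional output is "None" for
  // this invocation can emit a typed None instead of allocating a tensor:
  //
  //   ctx->OutputOptionalWithoutData<Tensor>(1);  // output 1 becomes a typed "None"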

  // Retrieve the indexed shape obtained from memory planning before the actual
  // computation. If the indexed shape cannot be inferred, this function returns
  // false.
  virtual bool TryGetInferredInputShape(int index, TensorShape& shape) const;

  // Retrieve the indexed shape obtained from memory planning before the actual
  // computation. If the indexed shape cannot be inferred, this function returns
  // false.
  virtual bool TryGetInferredOutputShape(int index, TensorShape& shape) const;
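
  // Usage sketch (illustrative): a kernel can probe shapes known from memory
  // planning before computing, e.g. to choose a pre-sized fast path:
  //
  //   TensorShape inferred;
  //   if (ctx->TryGetInferredOutputShape(0, inferred)) {
  //     // "inferred" is valid here; otherwise fall back to runtime shape logic
  //   }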

  const logging::Logger& Logger() const {
    return *logger_;
  }

  // always >= 0
  virtual int InputCount() const {
    return static_cast<int>(kernel_->Node().InputDefs().size());
  }

  // always >= 0
  virtual int ImplicitInputCount() const {
    return static_cast<int>(kernel_->Node().ImplicitInputDefs().size());
  }

  // always >= 0
  virtual int OutputCount() const {
    return static_cast<int>(kernel_->Node().OutputDefs().size());
  }

  /**
  Return an allocator on device 0, with memtype of OrtMemTypeDefault.
  @remarks Use SafeInt when calculating the size of memory to allocate using AllocatorPtr->Alloc.
  */
  [[nodiscard]] virtual Status GetTempSpaceAllocator(AllocatorPtr* output) const;

  /**
  Return the allocator associated with the CPU EP, with memtype of OrtMemTypeDefault.
  @remarks Use SafeInt when calculating the size of memory to allocate using AllocatorPtr->Alloc.
  */
  [[nodiscard]] Status GetTempSpaceCPUAllocator(AllocatorPtr* output) const;
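
  // Usage sketch (illustrative): scratch-space allocation via the temp-space
  // allocator, with SafeInt guarding the byte-count arithmetic as the remarks
  // advise. "num_elements" is hypothetical.
  //
  //   AllocatorPtr alloc;
  //   ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc));
  //   size_t bytes = SafeInt<size_t>(num_elements) * sizeof(float);
  //   void* scratch = alloc->Alloc(bytes);
  //   // ... use scratch ...
  //   alloc->Free(scratch);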

  /**
  Return the device id that the current kernel runs on.
  */
  virtual int GetDeviceId() const {
    return kernel_->Info().GetExecutionProvider()->GetDeviceId();
  }

  /**
  Return the compute stream associated with the EP that the kernel is partitioned to.
  For EPs that do not have a compute stream (e.g. the CPU EP), nullptr is returned.
  */
  [[nodiscard]] virtual Stream* GetComputeStream() const {
    return stream_;
  }

  /**
  Returns the opset domain of the underlying kernel.
  */
  const std::string& GetOpDomain() const;

  /**
  Returns the op type of the underlying kernel.
  */
  const std::string& GetOpType() const;

  /**
  Returns the node name of the underlying kernel.
  */
  const std::string& GetNodeName() const;

  /**
  Returns the intra-op threadpool, if available.
  */
  _Ret_maybenull_ onnxruntime::concurrency::ThreadPool* GetOperatorThreadPool() const { return threadpool_; }
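
  // Usage sketch (illustrative): the returned pool may be null (hence the
  // _Ret_maybenull_ annotation), so it is typically passed to the ThreadPool
  // static helpers, which accept nullptr and then run sequentially.
  // "num_tasks" is hypothetical.
  //
  //   auto* tp = ctx->GetOperatorThreadPool();
  //   concurrency::ThreadPool::TrySimpleParallelFor(tp, num_tasks,
  //       [&](std::ptrdiff_t i) { /* process item i */ });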

  /**
  Returns whether deterministic computation is preferred.
  */
  virtual bool GetUseDeterministicCompute() const {
    return true;
  }

  /**
  Returns an allocator for the given OrtDevice.
  TODO(leca): Replace GetTempSpaceAllocator() and GetTempSpaceCPUAllocator() with this API in the future.
  */
  AllocatorPtr GetAllocator(const OrtDevice& device) const;

 protected:
  onnxruntime::NodeIndex GetNodeIndex() const;

  virtual const OrtValue* GetInputMLValue(int index) const;
  const OrtValue* GetImplicitInputMLValue(int index) const;
  OrtValue* GetOutputMLValue(int index);

#ifdef ENABLE_ATEN
  Status SetOutputMLValue(int index, const OrtValue& ort_value);
#endif

  // Creates the OrtValue* based on the shape, if it does not exist yet.
  virtual OrtValue* OutputMLValue(int index, const TensorShape& shape);

  virtual OrtValue* GetOrCreateOutputMLValue(int index);

 private:
  ORT_DISALLOW_COPY_AND_ASSIGNMENT(OpKernelContext);
  int GetInputArgIndex(int index) const;
  int GetImplicitInputArgIndex(int index) const;
  int GetOutputArgIndex(int index) const;

  IExecutionFrame* const execution_frame_{};
  const OpKernel* const kernel_{};
  concurrency::ThreadPool* const threadpool_{};
  const logging::Logger* const logger_{};

  // The argument starting index in ExecutionFrame.
  int node_input_start_index_{-1};
  int node_implicit_input_start_index_{-1};
  int node_output_start_index_{-1};

  Stream* stream_;
};

// Fetching an output tensor without a shape is not allowed, except when it already exists.
template <>
inline Tensor* OpKernelContext::Output<Tensor>(int index) {
  OrtValue* p_ml_value = GetOutputMLValue(index);
  ORT_ENFORCE(p_ml_value, "Please fetch output tensor with specified shape.");
  return p_ml_value->GetMutable<Tensor>();
}
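
// Usage sketch (illustrative): the shapeless form only retrieves an output
// whose OrtValue already exists (e.g. pre-allocated by the planner); otherwise
// the ORT_ENFORCE above fires and a shaped Output() overload must be used.
//
//   Tensor* Y = ctx->Output<Tensor>(0);  // OK only if output 0 already exists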

#if !defined(DISABLE_SPARSE_TENSORS)
template <>
inline SparseTensor* OpKernelContext::Output<SparseTensor>(int index) {
  OrtValue* p_ml_value = GetOutputMLValue(index);
  ORT_ENFORCE(p_ml_value, "Please fetch output sparse tensor with specified shape.");
  return p_ml_value->GetMutable<SparseTensor>();
}
#endif

}  // namespace onnxruntime