4 namespace onnxruntime {
7 namespace concurrency {
13 using ArgMap = std::unordered_map<std::string, size_t>;
39 return p_ml_value ? &(p_ml_value->
Get<
T>()) :
nullptr;
50 ORT_ENFORCE(input_ptr,
"Required input at index ", index,
" is not present.");
61 return p_ml_value ? p_ml_value->
GetMutable<
T>() :
nullptr;
74 ORT_ENFORCE(output_ptr,
"Required output at index ", index,
" is not present.");
78 #if !defined(DISABLE_SPARSE_TENSORS)
86 #if !defined(DISABLE_OPTIONAL_TYPE)
92 auto type = DataTypeImpl::GetType<T>();
94 output_ort_value->Init(
nullptr,
96 type->GetDeleteFunc());
209 int GetInputArgIndex(
int index)
const;
210 int GetImplicitInputArgIndex(
int index)
const;
211 int GetOutputArgIndex(
int index)
const;
213 IExecutionFrame*
const execution_frame_{};
214 const OpKernel*
const kernel_{};
215 concurrency::ThreadPool*
const threadpool_{};
216 const logging::Logger*
const logger_{};
219 int node_input_start_index_{-1};
220 int node_implicit_input_start_index_{-1};
221 int node_output_start_index_{-1};
230 ORT_ENFORCE(p_ml_value,
"Please fetch output tensor with specified shape.");
234 #if !defined(DISABLE_SPARSE_TENSORS)
238 ORT_ENFORCE(p_ml_value,
"Please fetch output sparse tensor with specified shape.");
const IExecutionProvider * GetExecutionProvider() const noexcept
const std::string & GetNodeName() const
virtual bool GetUseDeterministicCompute() const
OpKernelContext(_Inout_ IExecutionFrame *frame, _In_ const OpKernel *kernel, _In_ Stream *stream, _In_opt_ concurrency::ThreadPool *threadpool, _In_ const logging::Logger &logger)
ConstPointerContainer< std::vector< NodeArg * > > InputDefs() const noexcept
Base class for MLDataType.
virtual bool TryGetInferredOutputShape(int index, TensorShape &shape) const
virtual int GetDeviceId() const
_Ret_maybenull_ onnxruntime::concurrency::ThreadPool * GetOperatorThreadPool() const
virtual int OutputCount() const
virtual Stream * GetComputeStream() const
GLsizei const GLchar *const * string
const T * Input(int index) const
const OrtValue * GetInputOrtValue(int index) const
const logging::Logger & Logger() const
#define ORT_ENFORCE(condition,...)
virtual const OrtValue * GetInputMLValue(int index) const
virtual int NumVariadicInputs(size_t arg_num) const
void OutputOptionalWithoutData(int index)
ConstPointerContainer< std::vector< NodeArg * > > OutputDefs() const noexcept
This class implements SparseTensor. This class holds sparse non-zero data (values) and sparse format ...
ConstPointerContainer< std::vector< NodeArg * > > ImplicitInputDefs() const noexcept
AllocatorPtr GetAllocator(const OrtDevice &device) const
virtual OrtValue * OutputMLValue(int index, const TensorShape &shape)
virtual int ImplicitInputCount() const
virtual int InputCount() const
Tensor & RequiredOutput(int index, const TensorShape &shape)
const std::string & GetOpDomain() const
virtual OrtValue * GetOrCreateOutputMLValue(int index)
std::unordered_map< std::string, size_t > ArgMap
const T & RequiredInput(int index) const
SparseTensor * OutputSparse(int index, const TensorShape &shape)
virtual ~OpKernelContext()=default
virtual MLDataType InputType(int index) const
onnxruntime::NodeIndex GetNodeIndex() const
virtual int GetDeviceId() const
Status GetTempSpaceCPUAllocator(AllocatorPtr *output) const
OrtValue * GetOutputMLValue(int index)
std::shared_ptr< IAllocator > AllocatorPtr
const onnxruntime::Node & Node() const
const std::string & GetOpType() const
virtual bool TryGetInferredInputShape(int index, TensorShape &shape) const
virtual MLDataType OutputType(int index) const
const OpKernelInfo & Info() const
const OrtValue * GetImplicitInputMLValue(int index) const
virtual Status GetTempSpaceAllocator(AllocatorPtr *output) const