HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
onnxruntime::OpKernelContext Class Reference

#include <op_kernel_context.h>

Public Types

using ArgMap = std::unordered_map< std::string, size_t >
 

Public Member Functions

 OpKernelContext (_Inout_ IExecutionFrame *frame, _In_ const OpKernel *kernel, _In_ Stream *stream, _In_opt_ concurrency::ThreadPool *threadpool, _In_ const logging::Logger &logger)
 
virtual ~OpKernelContext ()=default
 
virtual int NumVariadicInputs (size_t arg_num) const
 
virtual MLDataType InputType (int index) const
 
virtual MLDataType OutputType (int index) const
 
const OrtValue * GetInputOrtValue (int index) const
 
template<typename T >
const T * Input (int index) const
 
template<typename T >
const T & RequiredInput (int index) const
 
template<typename T >
T * Output (int index)
 
Tensor * Output (int index, const TensorShape &shape)
 
Tensor * Output (int index, const std::vector< int64_t > &shape)
 
Tensor * Output (int index, const std::initializer_list< int64_t > &shape)
 
Tensor & RequiredOutput (int index, const TensorShape &shape)
 
SparseTensor * OutputSparse (int index, const TensorShape &shape)
 
template<typename T >
void OutputOptionalWithoutData (int index)
 
virtual bool TryGetInferredInputShape (int index, TensorShape &shape) const
 
virtual bool TryGetInferredOutputShape (int index, TensorShape &shape) const
 
const logging::Logger & Logger () const
 
virtual int InputCount () const
 
virtual int ImplicitInputCount () const
 
virtual int OutputCount () const
 
virtual Status GetTempSpaceAllocator (AllocatorPtr *output) const
 
Status GetTempSpaceCPUAllocator (AllocatorPtr *output) const
 
virtual int GetDeviceId () const
 
virtual Stream * GetComputeStream () const
 
const std::string & GetOpDomain () const
 
const std::string & GetOpType () const
 
const std::string & GetNodeName () const
 
_Ret_maybenull_
onnxruntime::concurrency::ThreadPool *
GetOperatorThreadPool () const
 
virtual bool GetUseDeterministicCompute () const
 
AllocatorPtr GetAllocator (const OrtDevice &device) const
 
template<>
Tensor * Output (int index)
 
template<>
SparseTensor * Output (int index)
 

Protected Member Functions

 OpKernelContext (concurrency::ThreadPool *threadpool, const logging::Logger &logger, Stream *stream)
 
onnxruntime::NodeIndex GetNodeIndex () const
 
virtual const OrtValue * GetInputMLValue (int index) const
 
const OrtValue * GetImplicitInputMLValue (int index) const
 
OrtValue * GetOutputMLValue (int index)
 
virtual OrtValue * OutputMLValue (int index, const TensorShape &shape)
 
virtual OrtValue * GetOrCreateOutputMLValue (int index)
 

Detailed Description

Definition at line 11 of file op_kernel_context.h.

Member Typedef Documentation

using onnxruntime::OpKernelContext::ArgMap = std::unordered_map<std::string, size_t>

Definition at line 13 of file op_kernel_context.h.

Constructor & Destructor Documentation

onnxruntime::OpKernelContext::OpKernelContext ( _Inout_ IExecutionFrame *  frame,
_In_ const OpKernel *  kernel,
_In_ Stream *  stream,
_In_opt_ concurrency::ThreadPool *  threadpool,
_In_ const logging::Logger &  logger
)
virtual onnxruntime::OpKernelContext::~OpKernelContext ( )
virtual default
onnxruntime::OpKernelContext::OpKernelContext ( concurrency::ThreadPool *  threadpool,
const logging::Logger &  logger,
Stream *  stream
)
protected

Member Function Documentation

AllocatorPtr onnxruntime::OpKernelContext::GetAllocator ( const OrtDevice &  device) const

Returns Allocator from a specific OrtMemoryInfo object. TODO(leca): Replace GetTempSpaceAllocator() and GetTempSpaceCPUAllocator() with this API in the future

virtual Stream* onnxruntime::OpKernelContext::GetComputeStream ( ) const
inline virtual

Return the compute stream associated with the EP that the kernel is partitioned to. For EPs that do not have a compute stream (e.g. CPU EP), a nullptr is returned.

Definition at line 152 of file op_kernel_context.h.

virtual int onnxruntime::OpKernelContext::GetDeviceId ( ) const
inline virtual

Return the device id that current kernel runs on.

Definition at line 144 of file op_kernel_context.h.

const OrtValue* onnxruntime::OpKernelContext::GetImplicitInputMLValue ( int  index) const
protected
virtual const OrtValue* onnxruntime::OpKernelContext::GetInputMLValue ( int  index) const
protected virtual
const OrtValue* onnxruntime::OpKernelContext::GetInputOrtValue ( int  index) const
inline

Definition at line 31 of file op_kernel_context.h.

onnxruntime::NodeIndex onnxruntime::OpKernelContext::GetNodeIndex ( ) const
protected
const std::string& onnxruntime::OpKernelContext::GetNodeName ( ) const

Returns the node name of the underlying kernel

const std::string& onnxruntime::OpKernelContext::GetOpDomain ( ) const

Returns the opset domain of the underlying kernel

_Ret_maybenull_ onnxruntime::concurrency::ThreadPool* onnxruntime::OpKernelContext::GetOperatorThreadPool ( ) const
inline

Returns the intra-op threadpool, if available.

Definition at line 174 of file op_kernel_context.h.

const std::string& onnxruntime::OpKernelContext::GetOpType ( ) const

Returns the optype of the underlying kernel

virtual OrtValue* onnxruntime::OpKernelContext::GetOrCreateOutputMLValue ( int  index)
protected virtual
OrtValue* onnxruntime::OpKernelContext::GetOutputMLValue ( int  index)
protected
virtual Status onnxruntime::OpKernelContext::GetTempSpaceAllocator ( AllocatorPtr *  output) const
virtual

Return an allocator on device 0, with memtype of OrtMemTypeDefault.

Remarks
Use SafeInt when calculating the size of memory to allocate using AllocatorPtr->Alloc.
Status onnxruntime::OpKernelContext::GetTempSpaceCPUAllocator ( AllocatorPtr *  output) const

Return the allocator associated with the CPU EP with memtype of OrtMemTypeDefault.

Remarks
Use SafeInt when calculating the size of memory to allocate using AllocatorPtr->Alloc.
virtual bool onnxruntime::OpKernelContext::GetUseDeterministicCompute ( ) const
inline virtual

Returns whether deterministic computation is preferred.

Definition at line 179 of file op_kernel_context.h.

virtual int onnxruntime::OpKernelContext::ImplicitInputCount ( ) const
inline virtual

Definition at line 120 of file op_kernel_context.h.

template<typename T >
const T* onnxruntime::OpKernelContext::Input ( int  index) const
inline

Definition at line 36 of file op_kernel_context.h.

virtual int onnxruntime::OpKernelContext::InputCount ( ) const
inline virtual

Definition at line 115 of file op_kernel_context.h.

virtual MLDataType onnxruntime::OpKernelContext::InputType ( int  index) const
virtual
const logging::Logger& onnxruntime::OpKernelContext::Logger ( ) const
inline

Definition at line 110 of file op_kernel_context.h.

virtual int onnxruntime::OpKernelContext::NumVariadicInputs ( size_t  arg_num) const
virtual

Return the number of inputs for a variadic argument.

Parameters
arg_num: The operator argument number.
Returns
Number of inputs the argument has.
template<typename T >
T* onnxruntime::OpKernelContext::Output ( int  index)
inline

Definition at line 56 of file op_kernel_context.h.

Tensor* onnxruntime::OpKernelContext::Output ( int  index,
const TensorShape &  shape
)
Tensor* onnxruntime::OpKernelContext::Output ( int  index,
const std::vector< int64_t > &  shape 
)
Tensor* onnxruntime::OpKernelContext::Output ( int  index,
const std::initializer_list< int64_t > &  shape 
)
template<>
Tensor* onnxruntime::OpKernelContext::Output ( int  index)
inline

Definition at line 228 of file op_kernel_context.h.

template<>
SparseTensor* onnxruntime::OpKernelContext::Output ( int  index)
inline

Definition at line 236 of file op_kernel_context.h.

virtual int onnxruntime::OpKernelContext::OutputCount ( ) const
inline virtual

Definition at line 125 of file op_kernel_context.h.

virtual OrtValue* onnxruntime::OpKernelContext::OutputMLValue ( int  index,
const TensorShape &  shape
)
protected virtual
template<typename T >
void onnxruntime::OpKernelContext::OutputOptionalWithoutData ( int  index)
inline

Definition at line 89 of file op_kernel_context.h.

SparseTensor* onnxruntime::OpKernelContext::OutputSparse ( int  index,
const TensorShape &  shape
)
virtual MLDataType onnxruntime::OpKernelContext::OutputType ( int  index) const
virtual
template<typename T >
const T& onnxruntime::OpKernelContext::RequiredInput ( int  index) const
inline

Definition at line 48 of file op_kernel_context.h.

Tensor& onnxruntime::OpKernelContext::RequiredOutput ( int  index,
const TensorShape &  shape
)
inline

Definition at line 72 of file op_kernel_context.h.

virtual bool onnxruntime::OpKernelContext::TryGetInferredInputShape ( int  index,
TensorShape &  shape
) const
virtual
virtual bool onnxruntime::OpKernelContext::TryGetInferredOutputShape ( int  index,
TensorShape &  shape
) const
virtual

The documentation for this class was generated from the following file: