26 #include "onnxruntime_c_api.h"
36 #include <unordered_map>
38 #include <type_traits>
40 #ifdef ORT_NO_EXCEPTIONS
/// Returns the exception text held in message_ as a C string (via message_.c_str()).
/// Declared noexcept and marked override, so it replaces a virtual what() from a base class.
/// The returned pointer refers to storage owned by message_.
57 const char*
what() const noexcept
override {
return message_.c_str(); }
64 #ifdef ORT_NO_EXCEPTIONS
67 #ifndef ORT_CXX_API_THROW
68 #define ORT_CXX_API_THROW(string, code) \
70 std::cerr << Ort::Exception(string, code) \
77 #define ORT_CXX_API_THROW(string, code) \
78 throw Ort::Exception(string, code)
91 #ifdef ORT_API_MANUAL_INIT
/// Sets Global<void>::api_ to the OrtApi table returned by
/// OrtGetApiBase()->GetApi(ORT_API_VERSION). Does not throw (noexcept).
/// NOTE(review): appears to belong to the ORT_API_MANUAL_INIT section opened just above — confirm.
93 inline void InitApi() noexcept { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); }
/// Overload that stores a caller-supplied OrtApi pointer in Global<void>::api_ instead of
/// querying OrtGetApiBase().
/// @param api Pointer assigned verbatim to Global<void>::api_; no validation is performed here.
109 inline void InitApi(
const OrtApi* api) noexcept { Global<void>::api_ = api; }
111 #if defined(_MSC_VER) && !defined(__clang__)
112 #pragma warning(push)
115 #pragma warning(disable : 26426)
117 const OrtApi* Global<T>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION);
118 #if defined(_MSC_VER) && !defined(__clang__)
281 using Base::operator==;
282 using Base::operator!=;
283 using Base::operator<;
286 static_assert(
sizeof(Float16_t) ==
sizeof(uint16_t),
"Sizes must match");
429 static_assert(
sizeof(BFloat16_t) ==
sizeof(uint16_t),
"Sizes must match");
440 constexpr
operator uint8_t() const noexcept {
return value; }
446 static_assert(
sizeof(Float8E4M3FN_t) ==
sizeof(uint8_t),
"Sizes must match");
457 constexpr
operator uint8_t() const noexcept {
return value; }
463 static_assert(
sizeof(Float8E4M3FNUZ_t) ==
sizeof(uint8_t),
"Sizes must match");
474 constexpr
operator uint8_t() const noexcept {
return value; }
480 static_assert(
sizeof(Float8E5M2_t) ==
sizeof(uint8_t),
"Sizes must match");
491 constexpr
operator uint8_t() const noexcept {
return value; }
497 static_assert(
sizeof(Float8E5M2FNUZ_t) ==
sizeof(uint8_t),
"Sizes must match");
// ORT_DEFINE_RELEASE(NAME) expands to an inline overload
//   void OrtRelease(Ort<NAME>* ptr)
// that forwards the pointer to GetApi().Release<NAME>(ptr). Token pasting (##) builds both
// the Ort-prefixed parameter type and the Release-prefixed API member from NAME.
// The macro is #undef'd after its uses further down in the file.
502 #define ORT_DEFINE_RELEASE(NAME) \
503 inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); }
526 #undef ORT_DEFINE_RELEASE
531 template <
typename T>
555 template <
typename T>
559 constexpr
Base() =
default;
588 template <
typename T>
589 struct Base<const
T>;
598 template <
typename T>
602 constexpr
Base() =
default;
635 struct AllocatorWithDefaultOptions;
639 struct ModelMetadata;
652 explicit Status(std::nullptr_t) noexcept {}
653 explicit Status(OrtStatus* status) noexcept;
655 explicit
Status(const std::exception&) noexcept;
659 bool IsOK() const noexcept;
683 ThreadingOptions& SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn);
686 ThreadingOptions& SetGlobalCustomThreadCreationOptions(
void* ort_custom_thread_creation_options);
689 ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn);
698 explicit Env(std::nullptr_t) {}
701 Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_
const char* logid =
"");
704 Env(OrtLoggingLevel logging_level,
const char* logid, OrtLoggingFunction logging_function,
void* logger_param);
707 Env(
const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_
const char* logid =
"");
710 Env(
const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function,
void* logger_param,
711 OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_
const char* logid =
"");
716 Env& EnableTelemetryEvents();
717 Env& DisableTelemetryEvents();
719 Env& UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level);
723 Env& CreateAndRegisterAllocatorV2(
const std::string& provider_type,
const OrtMemoryInfo* mem_info,
const std::unordered_map<std::string, std::string>& options,
const OrtArenaCfg* arena_cfg);
736 void Add(
const OrtCustomOp* op);
747 int GetRunLogVerbosityLevel()
const;
750 int GetRunLogSeverityLevel()
const;
753 const char* GetRunTag()
const;
755 RunOptions& AddConfigEntry(
const char* config_key,
const char* config_value);
802 CustomOpConfigs& AddConfig(
const char* custom_op_name,
const char* config_key,
const char* config_value);
812 const std::unordered_map<std::string, std::string>& GetFlattenedConfigs()
const;
815 std::unordered_map<std::string, std::string> flat_configs_;
828 template <
typename T>
835 std::string GetConfigEntry(
const char* config_key)
const;
836 bool HasConfigEntry(
const char* config_key)
const;
840 template <
typename T>
847 SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level);
875 SessionOptionsImpl& AddExternalInitializers(
const std::vector<std::string>& names,
const std::vector<Value>& ort_values);
877 SessionOptionsImpl& AppendExecutionProvider_CUDA(
const OrtCUDAProviderOptions& provider_options);
879 SessionOptionsImpl& AppendExecutionProvider_ROCM(
const OrtROCMProviderOptions& provider_options);
880 SessionOptionsImpl& AppendExecutionProvider_OpenVINO(
const OrtOpenVINOProviderOptions& provider_options);
882 SessionOptionsImpl& AppendExecutionProvider_OpenVINO_V2(
const std::unordered_map<std::string, std::string>& provider_options = {});
883 SessionOptionsImpl& AppendExecutionProvider_TensorRT(
const OrtTensorRTProviderOptions& provider_options);
885 SessionOptionsImpl& AppendExecutionProvider_MIGraphX(
const OrtMIGraphXProviderOptions& provider_options);
892 const std::unordered_map<std::string, std::string>& provider_options = {});
894 SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn);
895 SessionOptionsImpl& SetCustomThreadCreationOptions(
void* ort_custom_thread_creation_options);
896 SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn);
901 SessionOptionsImpl& RegisterCustomOpsLibrary(
const ORTCHAR_T* library_name,
const CustomOpConfigs& custom_op_configs = {});
903 SessionOptionsImpl& RegisterCustomOpsUsingFunction(
const char* function_name);
916 explicit SessionOptions(OrtSessionOptions* p) : SessionOptionsImpl<OrtSessionOptions>{p} {}
974 std::vector<AllocatedStringPtr> GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator)
const;
985 AllocatedStringPtr LookupCustomMetadataMapAllocated(
const char* key, OrtAllocator* allocator)
const;
996 template <
typename T>
1001 size_t GetInputCount()
const;
1002 size_t GetOutputCount()
const;
1003 size_t GetOverridableInitializerCount()
const;
1032 uint64_t GetProfilingStartTimeNs()
const;
1037 TypeInfo GetOverridableInitializerTypeInfo(
size_t index)
const;
1040 template <
typename T>
1062 std::vector<Value> Run(
const RunOptions& run_options,
const char*
const* input_names,
const Value* input_values,
size_t input_count,
1063 const char*
const* output_names,
size_t output_count);
1068 void Run(
const RunOptions& run_options,
const char*
const* input_names,
const Value* input_values,
size_t input_count,
1069 const char*
const* output_names,
Value* output_values,
size_t output_count);
1092 void RunAsync(
const RunOptions& run_options,
const char*
const* input_names,
const Value* input_values,
size_t input_count,
1093 const char*
const* output_names,
Value* output_values,
size_t output_count, RunAsyncCallbackFn callback,
void* user_data);
1116 OrtPrepackedWeightsContainer* prepacked_weights_container);
1119 OrtPrepackedWeightsContainer* prepacked_weights_container);
1126 template <
typename T>
1132 OrtAllocatorType GetAllocatorType()
const;
1133 int GetDeviceId()
const;
1134 OrtMemoryInfoDeviceType GetDeviceType()
const;
1135 OrtMemType GetMemoryType()
const;
1137 template <
typename U>
1149 static MemoryInfo CreateCpu(OrtAllocatorType
type, OrtMemType mem_type1);
1152 MemoryInfo(
const char*
name, OrtAllocatorType type,
int id, OrtMemType mem_type);
1157 template <
typename T>
1162 ONNXTensorElementDataType GetElementType()
const;
1163 size_t GetElementCount()
const;
1165 size_t GetDimensionsCount()
const;
1171 [[deprecated(
"use GetShape()")]]
void GetDimensions(int64_t*
values,
size_t values_count)
const;
1173 void GetSymbolicDimensions(
const char**
values,
size_t values_count)
const;
1175 std::vector<int64_t> GetShape()
const;
1192 template <
typename T>
1196 TypeInfo GetSequenceElementType()
const;
1208 explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl<OrtSequenceTypeInfo>{p} {}
1213 template <
typename T>
1217 TypeInfo GetOptionalElementType()
const;
1226 template <
typename T>
1230 ONNXTensorElementDataType GetMapKeyType()
const;
1243 explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl<OrtMapTypeInfo>{p} {}
1248 template <
typename T>
1258 ONNXType GetONNXType()
const;
1274 explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl<OrtTypeInfo>{p} {}
1304 template <
typename T>
1312 template <
typename R>
1313 void GetOpaqueData(
const char* domain,
const char*
type_name, R&)
const;
1315 bool IsTensor()
const;
1316 bool HasValue()
const;
1318 size_t GetCount()
const;
1319 Value GetValue(
int index, OrtAllocator* allocator)
const;
1327 size_t GetStringTensorDataLength()
const;
1343 void GetStringTensorContent(
void*
buffer,
size_t buffer_length,
size_t*
offsets,
size_t offsets_count)
const;
1351 template <
typename R>
1352 const R* GetTensorData()
const;
1358 const void* GetTensorRawData()
const;
1390 void GetStringTensorElement(
size_t buffer_length,
size_t element_index,
void*
buffer)
const;
1398 std::string GetStringTensorElement(
size_t element_index)
const;
1406 size_t GetStringTensorElementLength(
size_t element_index)
const;
1408 #if !defined(DISABLE_SPARSE_TENSORS)
1415 OrtSparseFormat GetSparseFormat()
const;
1442 template <
typename R>
1443 const R* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format,
size_t& num_indices)
const;
1449 bool IsSparseTensor()
const;
1459 template <
typename R>
1460 const R* GetSparseTensorValues()
const;
1465 template <
typename T>
1475 template <
typename R>
1476 R* GetTensorMutableData();
1482 void* GetTensorMutableRawData();
1491 template <
typename R>
1492 R& At(
const std::vector<int64_t>&
location);
1499 void FillStringTensor(
const char*
const*
s,
size_t s_len);
1506 void FillStringTensorElement(
const char*
s,
size_t index);
1520 char* GetResizedStringTensorElementBuffer(
size_t index,
size_t buffer_length);
1522 #if !defined(DISABLE_SPARSE_TENSORS)
1531 void UseCooIndices(int64_t* indices_data,
size_t indices_num);
1543 void UseCsrIndices(int64_t* inner_data,
size_t inner_num, int64_t* outer_data,
size_t outer_num);
1553 void UseBlockSparseIndices(
const Shape& indices_shape, int32_t* indices_data);
1565 const int64_t* indices_data,
size_t indices_num);
1578 void FillSparseTensorCsr(
const OrtMemoryInfo* data_mem_info,
1580 const int64_t* inner_indices_data,
size_t inner_indices_num,
1581 const int64_t* outer_indices_data,
size_t outer_indices_num);
1592 void FillSparseTensorBlockSparse(
const OrtMemoryInfo* data_mem_info,
1594 const Shape& indices_shape,
1595 const int32_t* indices_data);
1629 template <
typename T>
1630 static Value CreateTensor(
const OrtMemoryInfo* info, T* p_data,
size_t p_data_element_count,
const int64_t* shape,
size_t shape_len);
1641 static Value CreateTensor(
const OrtMemoryInfo* info,
void* p_data,
size_t p_data_byte_count,
const int64_t* shape,
size_t shape_len,
1642 ONNXTensorElementDataType
type);
1655 template <
typename T>
1656 static Value CreateTensor(OrtAllocator* allocator,
const int64_t* shape,
size_t shape_len);
1669 static Value CreateTensor(OrtAllocator* allocator,
const int64_t* shape,
size_t shape_len, ONNXTensorElementDataType
type);
1688 static Value CreateSequence(
const std::vector<Value>&
values);
1698 template <
typename T>
1701 #if !defined(DISABLE_SPARSE_TENSORS)
1712 template <
typename T>
1713 static Value CreateSparseTensor(
const OrtMemoryInfo* info, T* p_data,
const Shape& dense_shape,
1714 const Shape& values_shape);
1732 static Value CreateSparseTensor(
const OrtMemoryInfo* info,
void* p_data,
const Shape& dense_shape,
1733 const Shape& values_shape, ONNXTensorElementDataType
type);
1744 template <
typename T>
1745 static Value CreateSparseTensor(OrtAllocator* allocator,
const Shape& dense_shape);
1758 static Value CreateSparseTensor(OrtAllocator* allocator,
const Shape& dense_shape, ONNXTensorElementDataType
type);
1760 #endif // !defined(DISABLE_SPARSE_TENSORS)
1777 void*
get() {
return p_; }
1778 size_t size()
const {
return size_; }
1781 OrtAllocator* allocator_;
1787 template <
typename T>
1792 void* Alloc(
size_t size);
1819 namespace binding_utils {
1825 template <
typename T>
1830 std::vector<std::string> GetOutputNames()
const;
1831 std::vector<std::string> GetOutputNames(OrtAllocator*)
const;
1832 std::vector<Value> GetOutputValues()
const;
1833 std::vector<Value> GetOutputValues(OrtAllocator*)
const;
1836 template <
typename T>
1841 void BindInput(
const char*
name,
const Value&);
1842 void BindOutput(
const char*
name,
const Value&);
1844 void ClearBoundInputs();
1845 void ClearBoundOutputs();
1846 void SynchronizeInputs();
1847 void SynchronizeOutputs();
1879 ArenaCfg(
size_t max_mem,
int arena_extend_strategy,
int initial_chunk_size_bytes,
int max_dead_bytes_per_chunk);
1901 #define ORT_CXX_LOG(logger, message_severity, message) \
1903 if (message_severity >= logger.GetLoggingSeverityLevel()) { \
1904 Ort::ThrowOnError(logger.LogMessage(message_severity, ORT_FILE, __LINE__, \
1905 static_cast<const char*>(__FUNCTION__), message)); \
1917 #define ORT_CXX_LOG_NOEXCEPT(logger, message_severity, message) \
1919 if (message_severity >= logger.GetLoggingSeverityLevel()) { \
1920 static_cast<void>(logger.LogMessage(message_severity, ORT_FILE, __LINE__, \
1921 static_cast<const char*>(__FUNCTION__), message)); \
1936 #define ORT_CXX_LOGF(logger, message_severity, ...) \
1938 if (message_severity >= logger.GetLoggingSeverityLevel()) { \
1939 Ort::ThrowOnError(logger.LogFormattedMessage(message_severity, ORT_FILE, __LINE__, \
1940 static_cast<const char*>(__FUNCTION__), __VA_ARGS__)); \
1955 #define ORT_CXX_LOGF_NOEXCEPT(logger, message_severity, ...) \
1957 if (message_severity >= logger.GetLoggingSeverityLevel()) { \
1958 static_cast<void>(logger.LogFormattedMessage(message_severity, ORT_FILE, __LINE__, \
1959 static_cast<const char*>(__FUNCTION__), __VA_ARGS__)); \
1990 explicit Logger(
const OrtLogger* logger);
2005 OrtLoggingLevel GetLoggingSeverityLevel() const noexcept;
2019 Status LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path,
int line_number,
2020 const
char* func_name, const
char*
message) const noexcept;
2036 template <typename... Args>
2037 Status LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path,
int line_number,
2038 const
char* func_name, const
char*
format, Args&&...
args) const noexcept;
2041 const OrtLogger* logger_{};
2042 OrtLoggingLevel cached_severity_level_{};
2053 size_t GetInputCount()
const;
2054 size_t GetOutputCount()
const;
2056 UnownedValue GetOutput(
size_t index,
const int64_t* dim_values,
size_t dim_count)
const;
2057 UnownedValue GetOutput(
size_t index,
const std::vector<int64_t>& dims)
const;
2058 void* GetGPUComputeStream()
const;
2059 Logger GetLogger()
const;
2060 OrtAllocator* GetAllocator(
const OrtMemoryInfo& memory_info)
const;
2062 void ParallelFor(
void (*fn)(
void*,
size_t),
size_t total,
size_t num_batch,
void* usr_data)
const;
2065 OrtKernelContext*
ctx_;
2071 namespace attr_utils {
2072 void GetAttr(
const OrtKernelInfo* p,
const char*
name,
float&);
2073 void GetAttr(
const OrtKernelInfo* p,
const char* name, int64_t&);
2075 void GetAttrs(
const OrtKernelInfo* p,
const char* name, std::vector<float>&);
2076 void GetAttrs(
const OrtKernelInfo* p,
const char* name, std::vector<int64_t>&);
2079 template <
typename T>
2086 template <
typename R>
2093 template <
typename R>
2100 Value GetTensorAttribute(
const char*
name, OrtAllocator* allocator)
const;
2102 size_t GetInputCount()
const;
2103 size_t GetOutputCount()
const;
2108 TypeInfo GetInputTypeInfo(
size_t index)
const;
2109 TypeInfo GetOutputTypeInfo(
size_t index)
const;
2111 ConstValue GetTensorConstantInput(
size_t index,
int* is_constant)
const;
2114 Logger GetLogger()
const;
2137 explicit Op(std::nullptr_t) {}
2139 explicit Op(OrtOp*);
2141 static Op Create(
const OrtKernelInfo* info,
const char* op_name,
const char* domain,
2142 int version,
const char** type_constraint_names,
2143 const ONNXTensorElementDataType* type_constraint_values,
2144 size_t type_constraint_count,
2145 const OpAttr* attr_values,
2147 size_t input_count,
size_t output_count);
2149 void Invoke(
const OrtKernelContext* context,
2150 const Value* input_values,
2152 Value* output_values,
2153 size_t output_count);
2156 void Invoke(
const OrtKernelContext* context,
2157 const OrtValue*
const* input_values,
2160 size_t output_count);
2177 if (is_int_ == dim.is_int_) {
2179 return i_ == dim.
i_;
2189 const char*
AsSym()
const {
return s_; }
2191 static constexpr
int INVALID_INT_DIM = -2;
2201 using Shape = std::vector<SymbolicInteger>;
2209 Status SetOutputShape(
size_t indice,
const Shape& shape);
2211 int64_t GetAttrInt(
const char* attr_name);
2214 Ints GetAttrInts(
const char* attr_name);
2216 float GetAttrFloat(
const char* attr_name);
2219 Floats GetAttrFloats(
const char* attr_name);
2224 Strings GetAttrStrings(
const char* attr_name);
2227 const OrtOpAttr* GetAttrHdl(
const char* attr_name)
const;
2228 const OrtApi* ort_api_;
2229 OrtShapeInferContext*
ctx_;
2230 std::vector<Shape> input_shapes_;
// Value 2^31 - 1 (2147483647).
// The replacement list is fully parenthesized so the macro always expands to a single
// operand. Without the outer parentheses, an expression such as
// `MAX_CUSTOM_OP_END_VER * 2` would expand to `(1UL << 31) - 1 * 2` and silently
// evaluate to the wrong value.
#define MAX_CUSTOM_OP_END_VER ((1UL << 31) - 1)
2237 template <
typename TOp,
typename TKernel,
bool WithStatus = false>
2241 OrtCustomOp::GetName = [](
const OrtCustomOp* this_) {
return static_cast<const TOp*
>(this_)->GetName(); };
2243 OrtCustomOp::GetExecutionProviderType = [](
const OrtCustomOp* this_) {
return static_cast<const TOp*
>(this_)->GetExecutionProviderType(); };
2245 OrtCustomOp::GetInputTypeCount = [](
const OrtCustomOp* this_) {
return static_cast<const TOp*
>(this_)->GetInputTypeCount(); };
2246 OrtCustomOp::GetInputType = [](
const OrtCustomOp* this_,
size_t index) {
return static_cast<const TOp*
>(this_)->GetInputType(index); };
2247 OrtCustomOp::GetInputMemoryType = [](
const OrtCustomOp* this_,
size_t index) {
return static_cast<const TOp*
>(this_)->GetInputMemoryType(index); };
2249 OrtCustomOp::GetOutputTypeCount = [](
const OrtCustomOp* this_) {
return static_cast<const TOp*
>(this_)->GetOutputTypeCount(); };
2250 OrtCustomOp::GetOutputType = [](
const OrtCustomOp* this_,
size_t index) {
return static_cast<const TOp*
>(this_)->GetOutputType(index); };
2252 #if defined(_MSC_VER) && !defined(__clang__)
2253 #pragma warning(push)
2254 #pragma warning(disable : 26409)
2256 OrtCustomOp::KernelDestroy = [](
void* op_kernel) {
delete static_cast<TKernel*
>(op_kernel); };
2257 #if defined(_MSC_VER) && !defined(__clang__)
2258 #pragma warning(pop)
2260 OrtCustomOp::GetInputCharacteristic = [](
const OrtCustomOp* this_,
size_t index) {
return static_cast<const TOp*
>(this_)->GetInputCharacteristic(index); };
2261 OrtCustomOp::GetOutputCharacteristic = [](
const OrtCustomOp* this_,
size_t index) {
return static_cast<const TOp*
>(this_)->GetOutputCharacteristic(index); };
2263 OrtCustomOp::GetVariadicInputMinArity = [](
const OrtCustomOp* this_) {
return static_cast<const TOp*
>(this_)->GetVariadicInputMinArity(); };
2264 OrtCustomOp::GetVariadicInputHomogeneity = [](
const OrtCustomOp* this_) {
return static_cast<int>(
static_cast<const TOp*
>(this_)->GetVariadicInputHomogeneity()); };
2265 OrtCustomOp::GetVariadicOutputMinArity = [](
const OrtCustomOp* this_) {
return static_cast<const TOp*
>(this_)->GetVariadicOutputMinArity(); };
2266 OrtCustomOp::GetVariadicOutputHomogeneity = [](
const OrtCustomOp* this_) {
return static_cast<int>(
static_cast<const TOp*
>(this_)->GetVariadicOutputHomogeneity()); };
2267 #ifdef __cpp_if_constexpr
2268 if constexpr (WithStatus) {
2272 OrtCustomOp::CreateKernelV2 = [](
const OrtCustomOp* this_,
const OrtApi* api,
const OrtKernelInfo* info,
void** op_kernel) -> OrtStatusPtr {
2273 return static_cast<const TOp*
>(this_)->CreateKernelV2(*api, info, op_kernel);
2275 OrtCustomOp::KernelComputeV2 = [](
void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
2276 return static_cast<TKernel*
>(op_kernel)->ComputeV2(context);
2279 OrtCustomOp::CreateKernelV2 =
nullptr;
2280 OrtCustomOp::KernelComputeV2 =
nullptr;
2282 OrtCustomOp::CreateKernel = [](
const OrtCustomOp* this_,
const OrtApi* api,
const OrtKernelInfo* info) {
return static_cast<const TOp*
>(this_)->CreateKernel(*api, info); };
2283 OrtCustomOp::KernelCompute = [](
void* op_kernel, OrtKernelContext* context) {
2284 static_cast<TKernel*
>(op_kernel)->Compute(context);
2288 SetShapeInferFn<TOp>(0);
2290 OrtCustomOp::GetStartVersion = [](
const OrtCustomOp* this_) {
2291 return static_cast<const TOp*
>(this_)->start_ver_;
2294 OrtCustomOp::GetEndVersion = [](
const OrtCustomOp* this_) {
2295 return static_cast<const TOp*
>(this_)->end_ver_;
2305 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
2309 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
2314 return OrtMemTypeDefault;
2345 return std::vector<std::string>{};
2348 template <
typename C>
2349 decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) {
2350 OrtCustomOp::InferOutputShapeFn = [](
const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
2352 return C::InferOutputShape(ctx);
2357 template <
typename C>
2359 OrtCustomOp::InferOutputShapeFn = {};
2364 void GetSessionConfigs(std::unordered_map<std::string, std::string>& out,
ConstSessionOptions options)
const;
constexpr Float8E4M3FNUZ_t() noexcept
std::vector< int64_t > Ints
UnownedSession GetUnowned() const
std::string GetBuildInfoString()
This function returns the onnxruntime build information: including git branch, git commit id...
GLuint GLsizei const GLchar * message
AllocatorWithDefaultOptions(std::nullptr_t)
Convenience to create a class member and then replace with an instance.
This is a tagging template type. Use it with Base<T> to indicate that the C++ interface object has no...
SequenceTypeInfo(OrtSequenceTypeInfo *p)
TypeInfo(std::nullptr_t)
Create an empty TypeInfo object, must be assigned a valid one to be used.
constexpr Base(contained_type *p) noexcept
bool IsNaN() const noexcept
Tests if the value is NaN
size_t GetInputCount() const
Float16_t Abs() const noexcept
Creates an instance that represents absolute value.
std::string GetErrorMessage() const
std::vector< std::string > Strings
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const
BFloat16_t Negate() const noexcept
Creates a new instance with the sign flipped.
constexpr bool operator!=(const Float8E5M2FNUZ_t &rhs) const noexcept
Type information that may contain either TensorTypeAndShapeInfo or the information about contained se...
TensorTypeAndShapeInfo(std::nullptr_t)
Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used...
Value(OrtValue *p)
Used for interop with the C API.
Env(OrtEnv *p)
C Interop Helper.
std::vector< float > Floats
Value(std::nullptr_t)
Create an empty Value object, must be assigned a valid one to be used.
void swap(UT::ArraySet< Key, MULTI, MAX_LOAD_FACTOR_256, Clearer, Hash, KeyEqual > &a, UT::ArraySet< Key, MULTI, MAX_LOAD_FACTOR_256, Clearer, Hash, KeyEqual > &b)
Base & operator=(Base &&v) noexcept
bool IsSubnormal() const noexcept
Tests if the value is subnormal (denormal).
GLsizei const GLchar *const * string
Float16_t()=default
Default constructor
bool GetVariadicInputHomogeneity() const
constexpr Float8E5M2_t(uint8_t v) noexcept
std::unique_ptr< char, detail::AllocatedFree > AllocatedStringPtr
unique_ptr typedef used to own strings allocated by OrtAllocators and release them at the end of the ...
Used internally by the C++ API. C++ wrapper types inherit from this. This is a zero cost abstraction ...
ConstMemoryInfo GetConst() const
Take ownership of a pointer created by C Api.
void SetShapeInferFn(...)
const Shape & GetInputShape(size_t indice) const
std::vector< SymbolicInteger > Shape
BFloat16_t(float v) noexcept
Constructor from float. The float is converted into the bfloat16 16-bit representation.
TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo *p)
Used for interop with the C API.
Wrapper around ::OrtMapTypeInfo.
constexpr Float8E5M2FNUZ_t(uint8_t v) noexcept
This struct provides life time management for custom op attribute
float ToFloatImpl() const noexcept
Converts bfloat16 to float
detail::SequenceTypeInfoImpl< detail::Unowned< const OrtSequenceTypeInfo >> ConstSequenceTypeInfo
const int64_t * values_shape
MapTypeInfo(OrtMapTypeInfo *p)
static bool AreZero(const Float16Impl &lhs, const Float16Impl &rhs) noexcept
IEEE defines that positive and negative zero are equal, this gives us a quick equality check for two ...
bool operator!=(const BFloat16_t &rhs) const noexcept
std::vector< R > GetAttributes(const char *name) const
**But if you need a result
A structure that represents the configuration of an arena-based allocator.
Provide access to per-node attributes and input shapes, so one could compute and set output shapes...
OCIOEXPORT void LogMessage(LoggingLevel level, const char *message)
Log a message using the library logging function.
IoBinding(std::nullptr_t)
Create an empty object for convenience. Sometimes, we want to initialize members later.
static const OrtApi * api_
bool IsNegative() const noexcept
Checks if the value is negative
OrtMemType GetInputMemoryType(size_t) const
Env(std::nullptr_t)
Create an empty Env object, must be assigned a valid one to be used.
float8e4m3fnuz (Float8 Floating Point) data type
constexpr bool operator==(const Float8E4M3FN_t &rhs) const noexcept
ConstSession GetConst() const
float8e4m3fn (Float8 Floating Point) data type
void GetAttrs(const OrtKernelInfo *p, const char *name, std::vector< int64_t > &)
static constexpr uint16_t ToUint16Impl(float v) noexcept
Converts from float to uint16_t float16 representation
bool IsFinite() const noexcept
Tests if the value is finite
GLuint GLsizei const GLuint const GLintptr * offsets
const OrtApi & GetApi() noexcept
This returns a reference to the OrtApi interface in use.
std::vector< Value > GetOutputValuesHelper(const OrtIoBinding *binding, OrtAllocator *)
SymbolicInteger(const char *s)
bool operator==(const BaseDimensions< T > &a, const BaseDimensions< Y > &b)
Wrapper around ::OrtAllocator.
bool operator==(const SymbolicInteger &dim) const
constexpr Float8E5M2FNUZ_t() noexcept
bool IsPositiveInfinity() const noexcept
Tests if the value represents positive infinity.
Wrapper around OrtMemoryInfo.
std::vector< std::string > GetAvailableProviders()
This is a C++ wrapper for OrtApi::GetAvailableProviders() and returns a vector of strings representin...
Op(std::nullptr_t)
Create an empty Operator object, must be assigned a valid one to be used.
bool IsNormal() const noexcept
Tests if the value is normal (not zero, subnormal, infinite, or NaN).
bool operator<(const BFloat16_t &rhs) const noexcept
detail::MapTypeInfoImpl< detail::Unowned< const OrtMapTypeInfo >> ConstMapTypeInfo
Wrapper around ::OrtIoBinding.
bool IsFinite() const noexcept
Tests if the value is finite
IMATH_NAMESPACE::V2f float
The Status that holds ownership of the OrtStatus received from the C API. Use it to safely destroy OrtStatus* ...
void GetAttr(const OrtKernelInfo *p, const char *name, std::string &)
OrtKernelContext * GetOrtKernelContext() const
float ToFloat() const noexcept
Converts float16 to float
Shared implementation between public and internal classes. CRTP pattern.
constexpr bool operator==(const Float8E5M2_t &rhs) const noexcept
detail::SessionOptionsImpl< detail::Unowned< OrtSessionOptions >> UnownedSessionOptions
constexpr Float8E4M3FNUZ_t(uint8_t v) noexcept
bool IsNegative() const noexcept
Checks if the value is negative
All C++ methods that can fail will throw an exception of this type.
const char * what() const noexcept override
A generic, discriminated value, whose type may be queried dynamically.
typename Unowned< T >::Type contained_type
CustomOpDomain(std::nullptr_t)
Create an empty CustomOpDomain object, must be assigned a valid one to be used.
Wrapper around ::OrtSequenceTypeInfo.
This class wraps a raw pointer OrtKernelContext* that is being passed to the custom kernel Compute() ...
constexpr std::enable_if< I< type_count_base< T >::value, int >::type tuple_type_size(){return subtype_count< typename std::tuple_element< I, T >::type >::value+tuple_type_size< T, I+1 >);}template< typename T > struct type_count< T, typename std::enable_if< is_tuple_like< T >::value >::type >{static constexpr int value{tuple_type_size< T, 0 >)};};template< typename T > struct subtype_count{static constexpr int value{is_mutable_container< T >::value?expected_max_vector_size:type_count< T >::value};};template< typename T, typename Enable=void > struct type_count_min{static const int value{0};};template< typename T >struct type_count_min< T, typename std::enable_if<!is_mutable_container< T >::value &&!is_tuple_like< T >::value &&!is_wrapper< T >::value &&!is_complex< T >::value &&!std::is_void< T >::value >::type >{static constexpr int value{type_count< T >::value};};template< typename T > struct type_count_min< T, typename std::enable_if< is_complex< T >::value >::type >{static constexpr int value{1};};template< typename T >struct type_count_min< T, typename std::enable_if< is_wrapper< T >::value &&!is_complex< T >::value &&!is_tuple_like< T >::value >::type >{static constexpr int value{subtype_count_min< typename T::value_type >::value};};template< typename T, std::size_t I >constexpr typename std::enable_if< I==type_count_base< T >::value, int >::type tuple_type_size_min(){return 0;}template< typename T, std::size_t I > constexpr typename std::enable_if< I< type_count_base< T >::value, int >::type tuple_type_size_min(){return subtype_count_min< typename std::tuple_element< I, T >::type >::value+tuple_type_size_min< T, I+1 >);}template< typename T > struct type_count_min< T, typename std::enable_if< is_tuple_like< T >::value >::type >{static constexpr int value{tuple_type_size_min< T, 0 >)};};template< typename T > struct subtype_count_min{static constexpr int value{is_mutable_container< T >::value?((type_count< T >::value< expected_max_vector_size)?type_count< T 
>::value:0):type_count_min< T >::value};};template< typename T, typename Enable=void > struct expected_count{static const int value{0};};template< typename T >struct expected_count< T, typename std::enable_if<!is_mutable_container< T >::value &&!is_wrapper< T >::value &&!std::is_void< T >::value >::type >{static constexpr int value{1};};template< typename T > struct expected_count< T, typename std::enable_if< is_mutable_container< T >::value >::type >{static constexpr int value{expected_max_vector_size};};template< typename T >struct expected_count< T, typename std::enable_if<!is_mutable_container< T >::value &&is_wrapper< T >::value >::type >{static constexpr int value{expected_count< typename T::value_type >::value};};enum class object_category:int{char_value=1, integral_value=2, unsigned_integral=4, enumeration=6, boolean_value=8, floating_point=10, number_constructible=12, double_constructible=14, integer_constructible=16, string_assignable=23, string_constructible=24, other=45, wrapper_value=50, complex_number=60, tuple_value=70, container_value=80,};template< typename T, typename Enable=void > struct classify_object{static constexpr object_category value{object_category::other};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_integral< T >::value &&!std::is_same< T, char >::value &&std::is_signed< T >::value &&!is_bool< T >::value &&!std::is_enum< T >::value >::type >{static constexpr object_category value{object_category::integral_value};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_integral< T >::value &&std::is_unsigned< T >::value &&!std::is_same< T, char >::value &&!is_bool< T >::value >::type >{static constexpr object_category value{object_category::unsigned_integral};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_same< T, char >::value &&!std::is_enum< T >::value >::type >{static constexpr object_category 
value{object_category::char_value};};template< typename T > struct classify_object< T, typename std::enable_if< is_bool< T >::value >::type >{static constexpr object_category value{object_category::boolean_value};};template< typename T > struct classify_object< T, typename std::enable_if< std::is_floating_point< T >::value >::type >{static constexpr object_category value{object_category::floating_point};};template< typename T >struct classify_object< T, typename std::enable_if<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&std::is_assignable< T &, std::string >::value >::type >{static constexpr object_category value{object_category::string_assignable};};template< typename T >struct classify_object< T, typename std::enable_if<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&!std::is_assignable< T &, std::string >::value &&(type_count< T >::value==1)&&std::is_constructible< T, std::string >::value >::type >{static constexpr object_category value{object_category::string_constructible};};template< typename T > struct classify_object< T, typename std::enable_if< std::is_enum< T >::value >::type >{static constexpr object_category value{object_category::enumeration};};template< typename T > struct classify_object< T, typename std::enable_if< is_complex< T >::value >::type >{static constexpr object_category value{object_category::complex_number};};template< typename T > struct uncommon_type{using type=typename std::conditional<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&!std::is_assignable< T &, std::string >::value &&!std::is_constructible< T, std::string >::value &&!is_complex< T >::value &&!is_mutable_container< T >::value &&!std::is_enum< T >::value, std::true_type, std::false_type >::type;static constexpr bool value=type::value;};template< typename T >struct classify_object< T, typename std::enable_if<(!is_mutable_container< T >::value &&is_wrapper< T >::value &&!is_tuple_like< T >::value 
&&uncommon_type< T >::value)>::type >{static constexpr object_category value{object_category::wrapper_value};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&is_direct_constructible< T, double >::value &&is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::number_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&!is_direct_constructible< T, double >::value &&is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::integer_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&is_direct_constructible< T, double >::value &&!is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::double_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< is_tuple_like< T >::value &&((type_count< T >::value >=2 &&!is_wrapper< T >::value)||(uncommon_type< T >::value &&!is_direct_constructible< T, double >::value &&!is_direct_constructible< T, int >::value)||(uncommon_type< T >::value &&type_count< T >::value >=2))>::type >{static constexpr object_category value{object_category::tuple_value};};template< typename T > struct classify_object< T, typename std::enable_if< is_mutable_container< T >::value >::type >{static constexpr object_category value{object_category::container_value};};template< typename T, enable_if_t< classify_object< T >::value==object_category::char_value, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"CHAR";}template< typename T, enable_if_t< classify_object< T 
>::value==object_category::integral_value||classify_object< T >::value==object_category::integer_constructible, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"INT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::unsigned_integral, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"UINT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::floating_point||classify_object< T >::value==object_category::number_constructible||classify_object< T >::value==object_category::double_constructible, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"FLOAT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::enumeration, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"ENUM";}template< typename T, enable_if_t< classify_object< T >::value==object_category::boolean_value, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"BOOLEAN";}template< typename T, enable_if_t< classify_object< T >::value==object_category::complex_number, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"COMPLEX";}template< typename T, enable_if_t< classify_object< T >::value >=object_category::string_assignable &&classify_object< T >::value<=object_category::other, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"TEXT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value >=2, detail::enabler >=detail::dummy >std::string type_name();template< typename T, enable_if_t< classify_object< T >::value==object_category::container_value||classify_object< T >::value==object_category::wrapper_value, detail::enabler >=detail::dummy >std::string type_name();template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value==1, 
detail::enabler >=detail::dummy >inline std::string type_name(){return type_name< typename std::decay< typename std::tuple_element< 0, T >::type >::type >);}template< typename T, std::size_t I >inline typename std::enable_if< I==type_count_base< T >::value, std::string >::type tuple_name(){return std::string{};}template< typename T, std::size_t I >inline typename std::enable_if<(I< type_count_base< T >::value), std::string >::type tuple_name(){auto str=std::string{type_name< typename std::decay< typename std::tuple_element< I, T >::type >::type >)}+ ','+tuple_name< T, I+1 >);if(str.back()== ',') str.pop_back();return str;}template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value >=2, detail::enabler > > std::string type_name()
Recursively generate the tuple type name.
GLint GLint GLsizei GLint GLenum format
KernelInfo(std::nullptr_t)
Create an empty instance to initialize later.
detail::TypeInfoImpl< detail::Unowned< const OrtTypeInfo >> ConstTypeInfo
Contains a constant, unowned OrtTypeInfo that can be copied and passed around by value. Provides access to const OrtTypeInfo APIs.
Represents native memory allocation coming from one of the OrtAllocators registered with OnnxRuntime...
bool IsNegativeInfinity() const noexcept
Tests if the value represents negative infinity
float ToFloatImpl() const noexcept
Converts float16 to float
Session(std::nullptr_t)
Create an empty Session object, must be assigned a valid one to be used.
IEEE 754 half-precision floating point data type.
std::vector< std::string > GetSessionConfigKeys() const
constexpr Float8E4M3FN_t() noexcept
SessionOptions(OrtSessionOptions *p)
Create and own custom defined operation.
ConstIoBinding GetConst() const
bool IsPositiveInfinity() const noexcept
Tests if the value represents positive infinity.
float8e5m2fnuz (Float8 Floating Point) data type
SymbolicInteger(int64_t i)
Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT_V...
constexpr Float8E5M2_t() noexcept
constexpr bool operator!=(const Float8E4M3FNUZ_t &rhs) const noexcept
GLuint const GLchar * name
int GetVariadicInputMinArity() const
Allocator(std::nullptr_t)
Convenience to create a class member and then replace with an instance.
detail::ConstSessionOptionsImpl< detail::Unowned< const OrtSessionOptions >> ConstSessionOptions
RunOptions(std::nullptr_t)
Create an empty RunOptions object, must be assigned a valid one to be used.
Float16_t Negate() const noexcept
Creates a new instance with the sign flipped.
OCIOEXPORT const char * GetVersion()
Get the version number for the library, as a dot-delimited string (e.g., "1.0.0").
AllocatedFree(OrtAllocator *allocator)
bool IsNaNOrZero() const noexcept
Tests if the value is NaN or zero. Useful for comparisons.
ORT_DEFINE_RELEASE(Allocator)
bool GetVariadicOutputHomogeneity() const
bool IsInfinity() const noexcept
Tests if the value is either positive or negative infinity.
constexpr Base(contained_type *p) noexcept
BFloat16_t Abs() const noexcept
Creates an instance that represents absolute value.
Float16_t(float v) noexcept
Constructor from float. The float is converted into the 16-bit float16 representation.
GT_API const UT_StringHolder version
bool IsSubnormal() const noexcept
Tests if the value is subnormal (denormal).
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t) const
This struct owns the OrtKernelInfo* pointer when a copy is made. For convenient wrapping of OrtKernelIn...
OrtAllocator * allocator_
OrtErrorCode GetErrorCode() const
This class represents an ONNX Runtime logger that can be used to log information with an associated s...
ArenaCfg(std::nullptr_t)
Create an empty ArenaCfg object, must be assigned a valid one to be used.
Ort::Status(*)(Ort::ShapeInferContext &) ShapeInferFn
GLenum GLsizei GLsizei GLint * values
void operator()(void *ptr) const
ConstValue GetConst() const
constexpr bool operator==(const Float8E5M2FNUZ_t &rhs) const noexcept
float ToFloat() const noexcept
Converts bfloat16 to float
bool IsNaN() const noexcept
Tests if the value is NaN
constexpr bool operator!=(const Float8E5M2_t &rhs) const noexcept
bool operator==(const BFloat16_t &rhs) const noexcept
MemoryInfo(OrtMemoryInfo *p)
bool IsOK() const noexcept
Returns true if instance represents an OK (non-error) status.
static constexpr Float16_t FromBits(uint16_t v) noexcept
Explicit conversion from the uint16_t bit representation of float16.
float8e5m2 (Float8 Floating Point) data type
**If you just want to fire and args
bool IsNormal() const noexcept
Tests if the value is normal (not zero, subnormal, infinite, or NaN).
constexpr Float8E4M3FN_t(uint8_t v) noexcept
int GetVariadicOutputMinArity() const
constexpr bool operator==(const Float8E4M3FNUZ_t &rhs) const noexcept
bool IsNaNOrZero() const noexcept
Tests if the value is NaN or zero. Useful for comparisons.
std::string GetVersionString()
This function returns the onnxruntime version string
static bool AreZero(const BFloat16Impl &lhs, const BFloat16Impl &rhs) noexcept
IEEE defines that positive and negative zero are equal; this gives us a quick equality check for two ...
OrtErrorCode GetOrtErrorCode() const
#define MAX_CUSTOM_OP_END_VER
const char * GetExecutionProviderType() const
Base & operator=(Base &&v) noexcept
Wrapper around ::OrtTensorTypeAndShapeInfo.
MemoryInfo(std::nullptr_t)
No instance is created.
Exception(std::string &&string, OrtErrorCode code)
std::string MakeCustomOpConfigEntryKey(const char *custom_op_name, const char *config)
CustomOpConfigs.
MapTypeInfo(std::nullptr_t)
Create an empty MapTypeInfo object, must be assigned a valid one to be used.
Wrapper around ::OrtSessionOptions.
const char * AsSym() const
contained_type * release()
Relinquishes ownership of the contained C object pointer. The underlying object is not destroyed...
SessionOptions(std::nullptr_t)
Create an empty SessionOptions object, must be assigned a valid one to be used.
Base & operator=(const Base &)=delete
UnownedValue GetUnowned() const
R GetAttribute(const char *name) const
Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime.
Wrapper around ::OrtSession.
Options for the CUDA provider that are passed to SessionOptionsAppendExecutionProvider_CUDA_V2. Please note that this struct is similar to OrtCUDAProviderOptions but only to be used internally. Going forward, new cuda provider options are to be supported via this struct and usage of the publicly defined OrtCUDAProviderOptions will be deprecated over time. User can only get the instance of OrtCUDAProviderOptionsV2 via CreateCUDAProviderOptions.
static constexpr BFloat16_t FromBits(uint16_t v) noexcept
Explicit conversion from the uint16_t bit representation of bfloat16.
bfloat16 (Brain Floating Point) data type
ConstKernelInfo GetConst() const
bool IsInfinity() const noexcept
Tests if the value is either positive or negative infinity.
constexpr FMT_INLINE value()
Shared implementation between public and internal classes. CRTP pattern.
SequenceTypeInfo(std::nullptr_t)
Create an empty SequenceTypeInfo object, must be assigned a valid one to be used. ...
Class that represents session configuration entries for one or more custom operators.
Status(std::nullptr_t) noexcept
Create an empty object, must be assigned a valid one to be used.
constexpr bool operator!=(const Float8E4M3FN_t &rhs) const noexcept
std::vector< std::string > GetOutputNamesHelper(const OrtIoBinding *binding, OrtAllocator *)
static uint16_t ToUint16Impl(float v) noexcept
Converts from float to uint16_t float16 representation
ConstTensorTypeAndShapeInfo GetConst() const
bool IsNegativeInfinity() const noexcept
Tests if the value represents negative infinity
UnownedIoBinding GetUnowned() const