13 #define RETURN_ON_API_FAIL(expression) \
15 auto err = (expression); \
44 inline Status::Status(OrtStatus* status) noexcept : Base<OrtStatus>{status} {
48 p_ =
GetApi().CreateStatus(ORT_FAIL, e.what());
52 p_ =
GetApi().CreateStatus(e.GetOrtErrorCode(), e.what());
69 return (
p_ ==
nullptr);
74 struct TypeToTensorType;
76 struct TypeToTensorType<
float> {
77 static constexpr ONNXTensorElementDataType
type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
81 static constexpr ONNXTensorElementDataType
type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
85 static constexpr ONNXTensorElementDataType
type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
88 struct TypeToTensorType<double> {
89 static constexpr ONNXTensorElementDataType
type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
92 struct TypeToTensorType<int8_t> {
93 static constexpr ONNXTensorElementDataType
type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
96 struct TypeToTensorType<int16_t> {
97 static constexpr ONNXTensorElementDataType
type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
100 struct TypeToTensorType<int32_t> {
101 static constexpr ONNXTensorElementDataType
type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
104 struct TypeToTensorType<int64_t> {
105 static constexpr ONNXTensorElementDataType
type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
108 struct TypeToTensorType<uint8_t> {
109 static constexpr ONNXTensorElementDataType
type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
112 struct TypeToTensorType<uint16_t> {
113 static constexpr ONNXTensorElementDataType
type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
116 struct TypeToTensorType<uint32_t> {
117 static constexpr ONNXTensorElementDataType
type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
120 struct TypeToTensorType<uint64_t> {
121 static constexpr ONNXTensorElementDataType
type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
124 struct TypeToTensorType<bool> {
125 static constexpr ONNXTensorElementDataType
type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
130 static constexpr ONNXTensorElementDataType
type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN;
134 static constexpr ONNXTensorElementDataType
type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ;
138 static constexpr ONNXTensorElementDataType
type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2;
142 static constexpr ONNXTensorElementDataType
type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ;
146 if (IsNaN() || rhs.IsNaN()) {
150 return val == rhs.val;
154 if (IsNaN() || rhs.IsNaN()) {
159 const bool left_is_negative = IsNegative();
160 if (left_is_negative != rhs.IsNegative()) {
164 return left_is_negative && !AreZero(*
this, rhs);
166 return (
val != rhs.val) && ((
val < rhs.val) ^ left_is_negative);
170 : allocator_(allocator), p_(p), size_(size) {
176 auto ret =
GetApi().AllocatorFree(allocator_, p_);
177 static_cast<void>(ret);
182 *
this = std::move(o);
186 OrtAllocator* alloc =
nullptr;
207 template <
typename T>
214 template <
typename T>
222 template <
typename T>
227 template <
typename T>
246 template <
typename T>
248 const char*
name =
nullptr;
253 template <
typename T>
255 OrtAllocatorType
type;
260 template <
typename T>
267 template <
typename T>
269 OrtMemoryInfoDeviceType
type;
270 GetApi().MemoryInfoGetDeviceType(this->p_, &type);
274 template <
typename T>
281 template <
typename T>
282 template <
typename U>
286 return comp_result == 0;
302 template <
typename T>
308 template <
typename T>
313 template <
typename T>
319 template <
typename T>
324 template <
typename T>
329 template <
typename T>
334 template <
typename T>
339 template <
typename T>
341 GetApi().ClearBoundInputs(this->p_);
344 template <
typename T>
346 GetApi().ClearBoundOutputs(this->p_);
349 template <
typename T>
354 template <
typename T>
359 namespace binding_utils {
361 std::vector<std::string>
result;
363 using Ptr = std::unique_ptr<void, decltype(free_fn)>;
368 ThrowOnError(
GetApi().GetBoundOutputNames(binding, allocator, &buffer, &lengths, &count));
374 Ptr buffer_g(buffer, free_fn);
375 Ptr lengths_g(lengths, free_fn);
377 result.reserve(count);
378 for (
size_t i = 0; i < count; ++i) {
380 result.emplace_back(buffer, sz);
388 std::vector<Value>
result;
390 size_t output_count = 0;
393 auto free_fn = [&owned, &output_count, allocator](
OrtValue**
buffer) {
395 while (owned < output_count) {
396 auto* p =
buffer + owned++;
397 GetApi().ReleaseValue(*p);
399 allocator->Free(allocator,
buffer);
402 using Ptr = std::unique_ptr<OrtValue*, decltype(free_fn)>;
405 ThrowOnError(
GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count));
406 if (output_count == 0) {
410 Ptr buffer_g(output_buffer, free_fn);
412 result.reserve(output_count);
413 for (
size_t i = 0; i < output_count; ++i) {
414 result.emplace_back(output_buffer[i]);
427 inline ArenaCfg::ArenaCfg(
size_t max_mem,
int arena_extend_strategy,
int initial_chunk_size_bytes,
int max_dead_bytes_per_chunk) {
428 ThrowOnError(
GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &
p_));
470 inline Env::Env(OrtLoggingLevel logging_level, _In_
const char* logid) {
472 if (strcmp(logid,
"onnxruntime-node") == 0) {
475 ThrowOnError(
GetApi().SetLanguageProjection(
p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
479 inline Env::Env(OrtLoggingLevel logging_level,
const char* logid, OrtLoggingFunction logging_function,
void* logger_param) {
480 ThrowOnError(
GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, logging_level, logid, &
p_));
481 if (strcmp(logid,
"onnxruntime-node") == 0) {
484 ThrowOnError(
GetApi().SetLanguageProjection(
p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
488 inline Env::Env(
const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level, _In_
const char* logid) {
490 if (strcmp(logid,
"onnxruntime-node") == 0) {
493 ThrowOnError(
GetApi().SetLanguageProjection(
p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
497 inline Env::Env(
const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function,
void* logger_param,
498 OrtLoggingLevel logging_level, _In_
const char* logid) {
499 ThrowOnError(
GetApi().CreateEnvWithCustomLoggerAndGlobalThreadPools(logging_function, logger_param, logging_level, logid, tp_options, &
p_));
500 if (strcmp(logid,
"onnxruntime-node") == 0) {
503 ThrowOnError(
GetApi().SetLanguageProjection(
p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
507 inline Env& Env::EnableTelemetryEvents() {
528 std::vector<const char*> keys,
values;
529 auto num_entries = options.size();
530 if (num_entries > 0) {
531 keys.reserve(num_entries);
532 values.reserve(num_entries);
533 for (
const auto& entry : options) {
534 keys.push_back(entry.first.c_str());
535 values.push_back(entry.second.c_str());
604 template <
typename T>
606 OrtSessionOptions* out;
611 template <
typename T>
620 out.resize(size - 1);
625 template <
typename T>
629 return static_cast<bool>(out);
632 template <
typename T>
634 if (!this->HasConfigEntry(config_key)) {
638 return this->GetConfigEntry(config_key);
641 template <
typename T>
647 template <
typename T>
653 template <
typename T>
655 ThrowOnError(
GetApi().SetSessionGraphOptimizationLevel(this->p_, graph_optimization_level));
659 template <
typename T>
665 template <
typename T>
671 template <
typename T>
677 template <
typename T>
683 template <
typename T>
689 template <
typename T>
695 template <
typename T>
701 template <
typename T>
707 template <
typename T>
713 template <
typename T>
719 template <
typename T>
725 template <
typename T>
731 template <
typename T>
737 template <
typename T>
743 template <
typename T>
749 template <
typename T>
755 template <
typename T>
757 const std::vector<Value>& ort_values) {
758 const size_t inputs_num = names.size();
759 if (inputs_num != ort_values.size()) {
760 ORT_CXX_API_THROW(
"Expecting names and ort_values to have the same length", ORT_INVALID_ARGUMENT);
762 std::vector<const char*> names_ptr;
763 std::vector<const OrtValue*> ort_values_ptrs;
764 names_ptr.reserve(inputs_num);
765 ort_values_ptrs.reserve(inputs_num);
766 for (
size_t i = 0; i < inputs_num; ++i) {
767 names_ptr.push_back(names[i].c_str());
768 ort_values_ptrs.push_back(ort_values[i]);
770 ThrowOnError(
GetApi().AddExternalInitializers(this->p_, names_ptr.data(), ort_values_ptrs.data(), inputs_num));
774 template <
typename T>
776 ThrowOnError(
GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options));
780 template <
typename T>
782 ThrowOnError(
GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(this->p_, &provider_options));
786 template <
typename T>
788 ThrowOnError(
GetApi().SessionOptionsAppendExecutionProvider_ROCM(this->p_, &provider_options));
792 template <
typename T>
794 ThrowOnError(
GetApi().SessionOptionsAppendExecutionProvider_TensorRT(this->p_, &provider_options));
798 template <
typename T>
800 ThrowOnError(
GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(this->p_, &provider_options));
804 template <
typename T>
806 ThrowOnError(
GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(this->p_, &provider_options));
810 template <
typename T>
812 ThrowOnError(
GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options));
816 template <
typename T>
818 ThrowOnError(
GetApi().SessionOptionsAppendExecutionProvider_Dnnl(this->p_, &provider_options));
822 template <
typename T>
825 const std::unordered_map<std::string, std::string>& provider_options) {
826 auto num_entries = provider_options.size();
827 std::vector<const char*> keys,
values;
828 if (num_entries > 0) {
829 keys.reserve(num_entries);
830 values.reserve(num_entries);
832 for (
const auto& entry : provider_options) {
833 keys.push_back(entry.first.c_str());
834 values.push_back(entry.second.c_str());
838 ThrowOnError(
GetApi().SessionOptionsAppendExecutionProvider(this->p_, provider_name.c_str(),
839 keys.data(), values.data(), num_entries));
844 template <
typename T>
846 ThrowOnError(
GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn));
850 template <
typename T>
852 ThrowOnError(
GetApi().SessionOptionsSetCustomThreadCreationOptions(this->p_, ort_custom_thread_creation_options));
856 template <
typename T>
858 ThrowOnError(
GetApi().SessionOptionsSetCustomJoinThreadFn(this->p_, ort_custom_join_thread_fn));
862 template <
typename T>
864 ThrowOnError(
GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(this->p_, &provider_options));
868 template <
typename T>
870 auto num_entries = provider_options.size();
871 std::vector<const char*> keys,
values;
872 if (num_entries > 0) {
873 keys.reserve(num_entries);
874 values.reserve(num_entries);
876 for (
const auto& entry : provider_options) {
877 keys.push_back(entry.first.c_str());
878 values.push_back(entry.second.c_str());
883 keys.data(), values.data(), num_entries));
888 template <
typename T>
894 AddConfigEntry(config_iter.first.c_str(), config_iter.second.c_str());
901 template <
typename T>
908 template <
typename T>
915 template <
typename T>
922 template <
typename T>
929 template <
typename T>
936 template <
typename T>
943 template <
typename T>
946 ThrowOnError(
GetApi().SessionGetOverridableInitializerName(this->p_, index, allocator, &out));
950 template <
typename T>
957 template <
typename T>
959 OrtModelMetadata* out;
964 template <
typename T>
971 template <
typename T>
978 template <
typename T>
981 ThrowOnError(
GetApi().SessionGetOverridableInitializerTypeInfo(this->p_, index, &out));
985 template <
typename T>
987 const char*
const* output_names,
size_t output_count) {
988 std::vector<Value> output_values;
989 output_values.reserve(output_count);
990 for (
size_t i = 0; i < output_count; i++)
991 output_values.emplace_back(
nullptr);
992 Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_count);
993 return output_values;
996 template <
typename T>
998 const char*
const* output_names,
Value* output_values,
size_t output_count) {
999 static_assert(
sizeof(
Value) ==
sizeof(
OrtValue*),
"Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
1000 auto ort_input_values =
reinterpret_cast<const OrtValue* const*
>(input_values);
1001 auto ort_output_values =
reinterpret_cast<OrtValue**
>(output_values);
1002 ThrowOnError(
GetApi().Run(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values));
1005 template <
typename T>
1010 template <
typename T>
1012 const char*
const* output_names,
Value* output_values,
size_t output_count, RunAsyncCallbackFn callback,
void* user_data) {
1013 auto ort_input_values =
reinterpret_cast<const OrtValue* const*
>(input_values);
1014 auto ort_output_values =
reinterpret_cast<OrtValue**
>(output_values);
1016 ort_input_values, input_count, output_names, output_count,
1017 ort_output_values, callback, user_data));
1020 template <
typename T>
1022 char* out =
nullptr;
1037 config_key += custom_op_name;
1039 config_key += config;
1046 flat_configs_[full_flat_key] = config_value;
1051 return flat_configs_;
1059 OrtPrepackedWeightsContainer* prepacked_weights_container) {
1060 ThrowOnError(
GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->
p_));
1064 ThrowOnError(
GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->
p_));
1068 const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) {
1069 ThrowOnError(
GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options,
1070 prepacked_weights_container, &this->
p_));
1073 inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator)
const {
1111 std::vector<AllocatedStringPtr>
result;
1113 char** out =
nullptr;
1114 int64_t num_keys = 0;
1116 if (num_keys <= 0) {
1121 std::unique_ptr<void, decltype(deletor)> array_guard(out, deletor);
1123 auto strings_deletor = [&deletor, num_keys](
char** out) {
for(int64_t i = 0; i < num_keys; ++i) deletor(out[i]); };
1124 std::unique_ptr<char*, decltype(strings_deletor)> strings_guard(out, strings_deletor);
1125 result.reserve(static_cast<size_t>(num_keys));
1126 strings_guard.release();
1127 for (int64_t i = 0; i < num_keys; ++i) {
1142 template <
typename T>
1144 ONNXTensorElementDataType out;
1149 template <
typename T>
1153 return static_cast<size_t>(out);
1156 template <
typename T>
1163 template <
typename T>
1168 template <
typename T>
1173 template <
typename T>
1175 std::vector<int64_t> out(GetDimensionsCount(), 0);
1180 template <
typename T>
1182 const OrtTensorTypeAndShapeInfo* out;
1187 template <
typename T>
1189 const OrtSequenceTypeInfo* out;
1194 template <
typename T>
1196 const OrtMapTypeInfo* out;
1201 template <
typename T>
1208 template <
typename T>
1210 OrtTypeInfo* output;
1215 template <
typename T>
1222 template <
typename T>
1224 ONNXTensorElementDataType out;
1229 template <
typename T>
1231 OrtTypeInfo* output;
1236 template <
typename T>
1238 const OrtOptionalTypeInfo* info;
1247 template <
typename T>
1248 template <
typename R>
1253 template <
typename T>
1260 template <
typename T>
1267 template <
typename T>
1274 template <
typename T>
1281 template <
typename T>
1288 template <
typename T>
1295 template <
typename T>
1296 template <
typename R>
1299 ThrowOnError(
GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), (
void**)&out));
1303 template <
typename T>
1310 template <
typename T>
1312 OrtTypeInfo* output;
1317 template <
typename T>
1319 OrtTensorTypeAndShapeInfo* output;
1324 template <
typename T>
1331 template <
typename T>
1333 ThrowOnError(
GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, buffer));
1336 template <
typename T>
1338 size_t buffer_length;
1339 ThrowOnError(
GetApi().GetStringTensorElementLength(this->p_, element_index, &buffer_length));
1342 s.resize(buffer_length);
1343 ThrowOnError(
GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, &s[0]));
1347 template <
typename T>
1349 ThrowOnError(
GetApi().GetStringTensorContent(this->p_, buffer, buffer_length, offsets, offsets_count));
1352 #if !defined(DISABLE_SPARSE_TENSORS)
1353 template <
typename T>
1360 template <
typename T>
1362 OrtTensorTypeAndShapeInfo* output;
1367 template <
typename T>
1369 OrtTensorTypeAndShapeInfo* output;
1370 ThrowOnError(
GetApi().GetSparseTensorIndicesTypeShape(this->p_, indices_format, &output));
1374 template <
typename T>
1375 template <
typename R>
1378 ThrowOnError(
GetApi().GetSparseTensorIndices(this->p_, indices_format, &num_indices, &out));
1379 return reinterpret_cast<const R*
>(out);
1382 template <
typename T>
1389 template <
typename T>
1390 template <
typename R>
1394 return reinterpret_cast<const R*
>(out);
1399 template <
typename T>
1404 template <
typename T>
1409 template <
typename T>
1412 ThrowOnError(
GetApi().GetResizedStringTensorElementBuffer(this->p_, index, buffer_length, &result));
1416 template <
typename T>
1423 template <
typename T>
1424 template <
typename R>
1431 template <
typename T>
1432 template <
typename R>
1436 ThrowOnError(
GetApi().TensorAt(this->p_, location.data(), location.size(), (
void**)&out));
1440 #if !defined(DISABLE_SPARSE_TENSORS)
1441 template <
typename T>
1446 template <
typename T>
1448 ThrowOnError(
GetApi().UseCsrIndices(this->p_, inner_data, inner_num, outer_data, outer_num));
1451 template <
typename T>
1456 template <
typename T>
1458 const int64_t* indices_data,
size_t indices_num) {
1461 indices_data, indices_num));
1464 template <
typename T>
1467 const int64_t* inner_indices_data,
size_t inner_indices_num,
1468 const int64_t* outer_indices_data,
size_t outer_indices_num) {
1470 inner_indices_data, inner_indices_num,
1471 outer_indices_data, outer_indices_num));
1474 template <
typename T>
1477 const Shape& indices_shape,
1478 const int32_t* indices_data) {
1484 #endif // !defined(DISABLE_SPARSE_TENSORS)
1488 template <
typename T>
1494 ONNXTensorElementDataType type) {
1496 ThrowOnError(
GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out));
1500 template <
typename T>
1507 ThrowOnError(
GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out));
1511 #if !defined(DISABLE_SPARSE_TENSORS)
1513 template <
typename T>
1515 const Shape& values_shape) {
1520 const Shape& values_shape, ONNXTensorElementDataType type) {
1527 template <
typename T>
1533 ONNXTensorElementDataType type) {
1538 #endif // !defined(DISABLE_SPARSE_TENSORS)
1542 const OrtValue* inputs[2] = {keys, values};
1549 std::vector<const OrtValue*> values_ort{values.data(), values.data() + values.size()};
1550 ThrowOnError(
GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out));
1554 template <
typename T>
1557 ThrowOnError(
GetApi().CreateOpaqueValue(domain, type_name, &data_container,
sizeof(
T), &out));
1569 return cached_severity_level_;
1573 const char* func_name,
const char*
message)
const noexcept {
1574 OrtStatus* status =
GetApi().Logger_LogMessage(logger_, log_severity_level,
message, file_path, line_number,
1582 #if defined(__GNUC__)
1583 #pragma GCC diagnostic push
1584 #pragma GCC diagnostic ignored "-Wformat-nonliteral"
1585 #pragma GCC diagnostic ignored "-Wformat-security"
1586 #elif defined(__clang__)
1587 #pragma clang diagnostic push
1588 #pragma clang diagnostic ignored "-Wformat-nonliteral"
1589 #pragma clang diagnostic ignored "-Wformat-security"
1591 template <
typename... Args>
1593 int line_number,
const char* func_name,
const char*
format,
1594 Args&&...
args) const noexcept {
1595 int msg_len = std::snprintf(
nullptr, 0U, format, std::forward<Args>(
args)...);
1598 return Status(
"Failed to log message due to formatting error", OrtErrorCode::ORT_FAIL);
1601 OrtStatus* status =
nullptr;
1602 const size_t buffer_size =
static_cast<size_t>(msg_len) + 1U;
1604 constexpr
size_t kStackBufferSize = 1024;
1606 if (buffer_size < kStackBufferSize) {
1607 char buffer[kStackBufferSize];
1608 snprintf(buffer, kStackBufferSize, format, std::forward<Args>(
args)...);
1609 status =
GetApi().Logger_LogMessage(logger_, log_severity_level, buffer, file_path, line_number, func_name);
1612 #if (__cplusplus >= 201402L) || (_MSC_VER >= 1900)
1613 auto buffer = std::make_unique<char[]>(buffer_size);
1615 std::unique_ptr<char[]>
buffer(
new char[buffer_size]);
1617 std::snprintf(buffer.get(), buffer_size,
format, std::forward<Args>(
args)...);
1618 status =
GetApi().Logger_LogMessage(logger_, log_severity_level, buffer.get(), file_path, line_number, func_name);
1624 #if defined(__GNUC__)
1625 #pragma GCC diagnostic pop
1626 #elif defined(__clang__)
1627 #pragma clang diagnostic pop
1664 void* out =
nullptr;
1670 OrtAllocator* out =
nullptr;
1676 const OrtLogger* out =
nullptr;
1682 ThrowOnError(
GetApi().KernelContext_ParallelFor(ctx_, fn, total, num_batch, usr_data));
1690 template <
typename T>
1692 OrtKernelInfo* info_copy =
nullptr;
1697 template <
typename T>
1704 template <
typename T>
1711 template <
typename T>
1721 out.resize(size - 1);
1726 template <
typename T>
1736 out.resize(size - 1);
1741 template <
typename T>
1743 OrtTypeInfo* out =
nullptr;
1748 template <
typename T>
1750 OrtTypeInfo* out =
nullptr;
1755 template <
typename T>
1758 ThrowOnError(
GetApi().KernelInfoGetAttribute_tensor(this->p_, name, allocator, &out));
1762 template <
typename T>
1765 ThrowOnError(
GetApi().KernelInfoGetConstantInput_tensor(this->p_, index, is_constant, &out));
1769 template <
typename T>
1779 out.resize(size - 1);
1784 template <
typename T>
1786 const OrtLogger* out =
nullptr;
1807 out.resize(size - 1);
1816 std::vector<float> out;
1828 std::vector<int64_t> out;
1837 inline Op::Op(OrtOp* p) : Base<OrtOp>(p) {}
1839 inline Op Op::Create(
const OrtKernelInfo* info,
const char* op_name,
const char* domain,
int version,
1840 const char** type_constraint_names,
1841 const ONNXTensorElementDataType* type_constraint_values,
1842 size_t type_constraint_count,
1843 const OpAttr* attr_values,
size_t attr_count,
1844 size_t input_count,
size_t output_count) {
1845 static_assert(
sizeof(OpAttr) ==
sizeof(OrtOpAttr*),
1846 "OpAttr's is expected to be just an array of OrtOpAttr in memory so we can reinterpret safely");
1847 auto attr_input_values =
reinterpret_cast<const OrtOpAttr* const*
>(attr_values);
1849 Ort::ThrowOnError(
GetApi().CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values,
1850 static_cast<int>(type_constraint_count),
1852 static_cast<int>(attr_count),
1853 static_cast<int>(input_count),
1854 static_cast<int>(output_count), &op));
1858 inline void Op::Invoke(
const OrtKernelContext* context,
1859 const Value* input_values,
1861 Value* output_values,
1862 size_t output_count) {
1864 "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
1865 auto ort_input_values =
reinterpret_cast<const OrtValue* const*
>(input_values);
1866 auto ort_output_values =
reinterpret_cast<OrtValue**
>(output_values);
1868 ort_output_values, static_cast<int>(output_count)));
1871 inline void Op::Invoke(
const OrtKernelContext* context,
1872 const OrtValue*
const* input_values,
1875 size_t output_count) {
1877 output_values, static_cast<int>(output_count)));
1881 return OrtGetApiBase()->GetVersionString();
1885 return GetApi().GetBuildInfoString();
1892 auto release_fn = [&len](
char** providers) {
1898 std::unique_ptr<char*, decltype(release_fn)> guard(providers, release_fn);
1899 std::vector<std::string> available_providers;
1900 available_providers.reserve(static_cast<size_t>(len));
1901 for (
int i = 0; i < len; ++i) {
1902 available_providers.emplace_back(providers[i]);
1904 return available_providers;
1907 template <
typename TOp,
typename TKernel,
bool WithStatus>
1910 const TOp* derived =
static_cast<const TOp*
>(
this);
1911 std::vector<std::string> keys = derived->GetSessionConfigKeys();
1913 out.reserve(keys.size());
1916 const size_t prefix_size = config_entry_key.length();
1918 for (
const auto& key : keys) {
1919 config_entry_key.resize(prefix_size);
1920 config_entry_key.append(key);
1921 out[key] = options.GetConfigEntryOrDefault(config_entry_key.c_str(),
"");
1926 OrtShapeInferContext* ctx) : ort_api_(ort_api),
ctx_(ctx) {
1927 size_t input_count = 0;
1929 for (
size_t ith_input = 0; ith_input < input_count; ++ith_input) {
1930 OrtTensorTypeAndShapeInfo* info{};
1931 Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputTypeShape(ctx, ith_input, &info));
1933 auto integer_shape = type_shape_info.GetShape();
1934 std::vector<const char*> symbolic_shape(integer_shape.size(), {});
1935 type_shape_info.GetSymbolicDimensions(&symbolic_shape[0], integer_shape.size());
1937 for (
size_t ith = 0; ith < integer_shape.size(); ++ith) {
1938 if (symbolic_shape[ith] &&
std::string{symbolic_shape[ith]}.size() > 0) {
1939 shape.emplace_back(symbolic_shape[ith]);
1941 shape.emplace_back(integer_shape[ith]);
1944 input_shapes_.push_back(std::move(shape));
1945 type_shape_info.release();
1949 inline Status ShapeInferContext::SetOutputShape(
size_t indice,
const Shape& shape) {
1950 OrtTensorTypeAndShapeInfo* info = {};
1953 using InfoPtr = std::unique_ptr<OrtTensorTypeAndShapeInfo, std::function<void(OrtTensorTypeAndShapeInfo*)>>;
1955 InfoPtr
info_ptr(info, [
this](OrtTensorTypeAndShapeInfo* obj) {
1956 ort_api_->ReleaseTensorTypeAndShapeInfo(obj);
1959 std::vector<int64_t> integer_dims;
1960 std::vector<const char*> symbolic_dims;
1962 for (
const auto dim : shape) {
1964 integer_dims.push_back(dim.IsInt());
1965 symbolic_dims.push_back(
"");
1967 if (!dim.AsSym() ||
std::string{dim.AsSym()}.empty()) {
1968 ORT_CXX_API_THROW(
"Symbolic dim must not be an empty string", ORT_INVALID_ARGUMENT);
1970 integer_dims.push_back(SymbolicInteger::INVALID_INT_DIM);
1971 symbolic_dims.push_back(dim.AsSym());
1975 RETURN_ON_API_FAIL(ort_api_->SetDimensions(info, integer_dims.data(), integer_dims.size()));
1976 RETURN_ON_API_FAIL(ort_api_->SetSymbolicDimensions(info, symbolic_dims.data(), symbolic_dims.size()));
1981 inline int64_t ShapeInferContext::GetAttrInt(
const char* attr_name) {
1982 const auto* attr = GetAttrHdl(attr_name);
1985 Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INT, &i,
sizeof(i), &out));
1990 const auto* attr = GetAttrHdl(attr_name);
1994 auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, &i,
sizeof(i), &out);
1996 size_t num_i = out /
sizeof(int64_t);
1998 Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, ints.data(), out, &out));
2005 inline float ShapeInferContext::GetAttrFloat(
const char* attr_name) {
2006 const auto* attr = GetAttrHdl(attr_name);
2009 Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOAT, &f,
sizeof(f), &out));
2014 const auto* attr = GetAttrHdl(attr_name);
2018 auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, &f,
sizeof(f), &out);
2020 size_t num_f = out /
sizeof(
float);
2022 Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, floats.data(), out, &out));
2029 inline std::string ShapeInferContext::GetAttrString(
const char* attr_name) {
2030 const auto* attr = GetAttrHdl(attr_name);
2034 auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, &c,
sizeof(
char), &out);
2036 std::vector<char> chars(out,
'\0');
2037 Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, chars.data(), out, &out));
2038 return {chars.data()};
2045 const auto* attr = GetAttrHdl(attr_name);
2049 auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, &c,
sizeof(
char), &out);
2051 std::vector<char> chars(out,
'\0');
2052 Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, chars.data(), out, &out));
2054 char* char_st = chars.data();
2055 char* char_ed = char_st + out;
2056 while (char_st < char_ed) {
2057 strings.emplace_back(char_st);
2058 while (*char_st !=
'\0') {
2069 inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(
const char* attr_name)
const {
2070 const OrtOpAttr* attr_hdl = {};
OrtMemType GetMemoryType() const
OrtMemoryInfoDeviceType GetDeviceType() const
Status LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T *file_path, int line_number, const char *func_name, const char *format, Args &&...args) const noexcept
std::vector< int64_t > Ints
void Invoke(const OrtKernelContext *context, const Value *input_values, size_t input_count, Value *output_values, size_t output_count)
std::string GetBuildInfoString()
This function returns the onnxruntime build information: including git branch, git commit id...
SessionOptionsImpl & SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn)
Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn.
GLuint GLsizei const GLchar * message
size_t GetElementCount() const
Wraps OrtApi::GetTensorShapeElementCount.
ONNXType GetONNXType() const
MemoryAllocation & operator=(const MemoryAllocation &)=delete
Env & DisableTelemetryEvents()
Wraps OrtApi::EnableTelemetryEvents.
AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of the overridable initializer name at then specified index.
ThreadingOptions & SetGlobalSpinControl(int allow_spinning)
Wraps OrtApi::SetGlobalSpinControl.
union Ort::detail::OrtSparseValuesParam::@164 data
OrtAllocator * GetAllocator(const OrtMemoryInfo &memory_info) const
std::string GetErrorMessage() const
png_const_structrp png_const_inforp info_ptr
size_t GetInputCount() const
Returns the number of model inputs.
SessionOptionsImpl & DisablePerSessionThreads()
Wraps OrtApi::DisablePerSessionThreads.
std::vector< std::string > Strings
Env & UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level)
Wraps OrtApi::UpdateEnvWithCustomLogLevel.
void UseBlockSparseIndices(const Shape &indices_shape, int32_t *indices_data)
Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSp...
RunOptions & AddConfigEntry(const char *config_key, const char *config_value)
Wraps OrtApi::AddRunConfigEntry.
ConstMapTypeInfo GetMapTypeInfo() const
Wraps OrtApi::CastTypeInfoToMapTypeInfo.
void FillSparseTensorCsr(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values, const int64_t *inner_indices_data, size_t inner_indices_num, const int64_t *outer_indices_data, size_t outer_indices_num)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
TypeInfo GetInputTypeInfo(size_t index) const
Wraps OrtApi::SessionGetInputTypeInfo.
size_t GetOutputCount() const
void GetSymbolicDimensions(const char **values, size_t values_count) const
Wraps OrtApi::GetSymbolicDimensions.
Type information that may contain either TensorTypeAndShapeInfo or the information about contained se...
SessionOptionsImpl & EnableMemPattern()
Wraps OrtApi::EnableMemPattern.
SessionOptionsImpl & EnableCpuMemArena()
Wraps OrtApi::EnableCpuMemArena.
static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1)
std::vector< float > Floats
bool IsTensor() const
Returns true if Value is a tensor, false for other types like map/sequence/etc.
AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of output name at then specified index.
void GetOpaqueData(const char *domain, const char *type_name, R &) const
Obtains a pointer to a user defined data for experimental purposes
const void * GetTensorRawData() const
Returns a non-typed pointer to a tensor contained data.
ConstMemoryInfo GetInfo() const
void GetStringTensorElement(size_t buffer_length, size_t element_index, void *buffer) const
The API copies UTF-8 encoded bytes for the requested string element contained within a tensor or a sp...
SessionOptionsImpl & SetLogId(const char *logid)
Wraps OrtApi::SetSessionLogId.
void swap(UT::ArraySet< Key, MULTI, MAX_LOAD_FACTOR_256, Clearer, Hash, KeyEqual > &a, UT::ArraySet< Key, MULTI, MAX_LOAD_FACTOR_256, Clearer, Hash, KeyEqual > &b)
Value GetValue(int index, OrtAllocator *allocator) const
uint64_t GetProfilingStartTimeNs() const
Wraps OrtApi::SessionGetProfilingStartTimeNs.
SessionOptionsImpl & AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX.
UnownedValue GetOutput(size_t index, const int64_t *dim_values, size_t dim_count) const
static Value CreateSequence(const std::vector< Value > &values)
Creates an OrtValue with a Sequence Onnx type representation. The API would ref-count the supplied Or...
GLsizei const GLchar *const * string
void ThrowOnError(OrtStatus *ort_status)
GLsizei const GLfloat * value
void BindInput(const char *name, const Value &)
ONNXTensorElementDataType GetMapKeyType() const
Wraps OrtApi::GetMapKeyType.
ConstValue GetInput(size_t index) const
SessionOptionsImpl & AppendExecutionProvider_CANN(const OrtCANNProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN.
MemoryAllocation GetAllocation(size_t size)
std::unique_ptr< char, detail::AllocatedFree > AllocatedStringPtr
unique_ptr typedef used to own strings allocated by OrtAllocators and release them at the end of the ...
SessionOptionsImpl & DisableCpuMemArena()
Wraps OrtApi::DisableCpuMemArena.
bool HasConfigEntry(const char *config_key) const
Wraps OrtApi::HasSessionConfigEntry.
void UseCsrIndices(int64_t *inner_data, size_t inner_num, int64_t *outer_data, size_t outer_num)
Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tens...
ONNXTensorElementDataType GetElementType() const
Wraps OrtApi::GetTensorElementType.
static Value CreateOpaque(const char *domain, const char *type_name, const T &value)
Creates an OrtValue wrapping an Opaque type. This is used for experimental support of non-tensor type...
Env & CreateAndRegisterAllocator(const OrtMemoryInfo *mem_info, const OrtArenaCfg *arena_cfg)
Wraps OrtApi::CreateAndRegisterAllocator.
void FillStringTensorElement(const char *s, size_t index)
Set a single string in a string tensor
std::vector< SymbolicInteger > Shape
std::vector< Value > Run(const RunOptions &run_options, const char *const *input_names, const Value *input_values, size_t input_count, const char *const *output_names, size_t output_count)
Run the model returning results in an Ort allocated vector.
SessionOptionsImpl & AddInitializer(const char *name, const OrtValue *ort_val)
Wraps OrtApi::AddInitializer.
const R * GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t &num_indices) const
The API retrieves a pointer to the internal indices buffer. The API merely performs a convenience dat...
void SynchronizeOutputs()
std::string GetConfigEntryOrDefault(const char *config_key, const std::string &def)
std::string GetNodeName() const
OpAttr(const char *name, const void *data, int len, OrtOpAttrType type)
SessionOptionsImpl & AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT.
SessionOptionsImpl & EnableOrtCustomOps()
Wraps OrtApi::EnableOrtCustomOps.
size_t GetDimensionsCount() const
Wraps OrtApi::GetDimensionsCount.
R & At(const std::vector< int64_t > &location)
std::string GetAllocatorName() const
void RunAsync(const RunOptions &run_options, const char *const *input_names, const Value *input_values, size_t input_count, const char *const *output_names, Value *output_values, size_t output_count, RunAsyncCallbackFn callback, void *user_data)
Run the model asynchronously in a thread owned by the intra-op thread pool.
SessionOptionsImpl & AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA.
const int64_t * values_shape
SessionOptionsImpl & SetIntraOpNumThreads(int intra_op_num_threads)
Wraps OrtApi::SetIntraOpNumThreads.
**But if you need a result
SessionOptionsImpl & RegisterCustomOpsUsingFunction(const char *function_name)
Wraps OrtApi::RegisterCustomOpsUsingFunction.
IoBinding(std::nullptr_t)
Create an empty object for convenience. Sometimes, we want to initialize members later.
OrtLoggingLevel GetLoggingSeverityLevel() const noexcept
ConstOptionalTypeInfo GetOptionalTypeInfo() const
wraps OrtApi::CastTypeInfoToOptionalTypeInfo
SessionOptionsImpl & SetOptimizedModelFilePath(const ORTCHAR_T *optimized_model_file)
Wraps OrtApi::SetOptimizedModelFilePath.
Env(std::nullptr_t)
Create an empty Env object, must be assigned a valid one to be used.
ThreadingOptions & SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn)
Wraps OrtApi::SetGlobalCustomJoinThreadFn.
float8e4m3fnuz (Float8 Floating Point) data type
const R * GetSparseTensorValues() const
The API returns a pointer to an internal buffer of the sparse tensor containing non-zero values...
void * Alloc(size_t size)
TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const
The API returns type and shape information for the specified indices. Each supported indices have the...
SessionOptionsImpl & SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level)
Wraps OrtApi::SetSessionGraphOptimizationLevel.
const R * GetTensorData() const
Returns a const typed pointer to the tensor contained data. No type checking is performed, the caller must ensure the type matches the tensor type.
SessionOptionsImpl & SetCustomThreadCreationOptions(void *ort_custom_thread_creation_options)
Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions.
void * GetTensorMutableRawData()
Returns a non-typed non-const pointer to a tensor contained data.
static Op Create(const OrtKernelInfo *info, const char *op_name, const char *domain, int version, const char **type_constraint_names, const ONNXTensorElementDataType *type_constraint_values, size_t type_constraint_count, const OpAttr *attr_values, size_t attr_count, size_t input_count, size_t output_count)
std::string GetConfigEntry(const char *config_key) const
Wraps OrtApi::GetSessionConfigEntry.
void ThrowStatus(const Status &st)
std::vector< std::string > GetOutputNames() const
float8e4m3fn (Float8 Floating Point) data type
SessionOptionsImpl & AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2 &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT_V2.
ConstValue GetTensorConstantInput(size_t index, int *is_constant) const
GLuint GLsizei const GLuint const GLintptr * offsets
const OrtApi & GetApi() noexcept
This returns a reference to the OrtApi interface in use.
bool IsSparseTensor() const
Returns true if the OrtValue contains a sparse tensor
std::vector< Value > GetOutputValuesHelper(const OrtIoBinding *binding, OrtAllocator *)
ModelMetadata GetModelMetadata() const
Wraps OrtApi::SessionGetModelMetadata.
ThreadingOptions & SetGlobalInterOpNumThreads(int inter_op_num_threads)
Wraps OrtApi::SetGlobalInterOpNumThreads.
void FillSparseTensorBlockSparse(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values, const Shape &indices_shape, const int32_t *indices_data)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
TypeInfo GetOutputTypeInfo(size_t index) const
std::vector< int64_t > GetShape() const
Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape.
Wrapper around OrtMemoryInfo.
std::vector< std::string > GetAvailableProviders()
This is a C++ wrapper for OrtApi::GetAvailableProviders() and returns a vector of strings representin...
Op(std::nullptr_t)
Create an empty Operator object, must be assigned a valid one to be used.
bool operator<(const BFloat16_t &rhs) const noexcept
SessionOptions Clone() const
Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions.
Wrapper around ::OrtIoBinding.
void GetAttrs(const OrtKernelInfo *p, const char *name, std::vector< float > &)
IMATH_NAMESPACE::V2f float
char * GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length)
Allocate if necessary and obtain a pointer to a UTF-8 encoded string element buffer indexed by the fl...
SessionOptionsImpl & SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn)
Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn.
The Status that holds ownership of OrtStatus received from C API. Use it to safely destroy OrtStatus* ...
void GetStringTensorContent(void *buffer, size_t buffer_length, size_t *offsets, size_t offsets_count) const
The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor into...
void GetDimensions(int64_t *values, size_t values_count) const
Wraps OrtApi::GetDimensions.
A generic, discriminated value, whose type may be queried dynamically.
void GetAttr(const OrtKernelInfo *p, const char *name, float &)
CustomOpDomain(std::nullptr_t)
Create an empty CustomOpDomain object, must be assigned a valid one to be used.
SessionOptionsImpl & AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO.
const std::unordered_map< std::string, std::string > & GetFlattenedConfigs() const
Returns a flattened map of custom operator configuration entries and their values.
SessionOptionsImpl & AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2 &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2.
constexpr std::enable_if< I< type_count_base< T >::value, int >::type tuple_type_size(){return subtype_count< typename std::tuple_element< I, T >::type >::value+tuple_type_size< T, I+1 >);}template< typename T > struct type_count< T, typename std::enable_if< is_tuple_like< T >::value >::type >{static constexpr int value{tuple_type_size< T, 0 >)};};template< typename T > struct subtype_count{static constexpr int value{is_mutable_container< T >::value?expected_max_vector_size:type_count< T >::value};};template< typename T, typename Enable=void > struct type_count_min{static const int value{0};};template< typename T >struct type_count_min< T, typename std::enable_if<!is_mutable_container< T >::value &&!is_tuple_like< T >::value &&!is_wrapper< T >::value &&!is_complex< T >::value &&!std::is_void< T >::value >::type >{static constexpr int value{type_count< T >::value};};template< typename T > struct type_count_min< T, typename std::enable_if< is_complex< T >::value >::type >{static constexpr int value{1};};template< typename T >struct type_count_min< T, typename std::enable_if< is_wrapper< T >::value &&!is_complex< T >::value &&!is_tuple_like< T >::value >::type >{static constexpr int value{subtype_count_min< typename T::value_type >::value};};template< typename T, std::size_t I >constexpr typename std::enable_if< I==type_count_base< T >::value, int >::type tuple_type_size_min(){return 0;}template< typename T, std::size_t I > constexpr typename std::enable_if< I< type_count_base< T >::value, int >::type tuple_type_size_min(){return subtype_count_min< typename std::tuple_element< I, T >::type >::value+tuple_type_size_min< T, I+1 >);}template< typename T > struct type_count_min< T, typename std::enable_if< is_tuple_like< T >::value >::type >{static constexpr int value{tuple_type_size_min< T, 0 >)};};template< typename T > struct subtype_count_min{static constexpr int value{is_mutable_container< T >::value?((type_count< T >::value< expected_max_vector_size)?type_count< T 
>::value:0):type_count_min< T >::value};};template< typename T, typename Enable=void > struct expected_count{static const int value{0};};template< typename T >struct expected_count< T, typename std::enable_if<!is_mutable_container< T >::value &&!is_wrapper< T >::value &&!std::is_void< T >::value >::type >{static constexpr int value{1};};template< typename T > struct expected_count< T, typename std::enable_if< is_mutable_container< T >::value >::type >{static constexpr int value{expected_max_vector_size};};template< typename T >struct expected_count< T, typename std::enable_if<!is_mutable_container< T >::value &&is_wrapper< T >::value >::type >{static constexpr int value{expected_count< typename T::value_type >::value};};enum class object_category:int{char_value=1, integral_value=2, unsigned_integral=4, enumeration=6, boolean_value=8, floating_point=10, number_constructible=12, double_constructible=14, integer_constructible=16, string_assignable=23, string_constructible=24, other=45, wrapper_value=50, complex_number=60, tuple_value=70, container_value=80,};template< typename T, typename Enable=void > struct classify_object{static constexpr object_category value{object_category::other};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_integral< T >::value &&!std::is_same< T, char >::value &&std::is_signed< T >::value &&!is_bool< T >::value &&!std::is_enum< T >::value >::type >{static constexpr object_category value{object_category::integral_value};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_integral< T >::value &&std::is_unsigned< T >::value &&!std::is_same< T, char >::value &&!is_bool< T >::value >::type >{static constexpr object_category value{object_category::unsigned_integral};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_same< T, char >::value &&!std::is_enum< T >::value >::type >{static constexpr object_category 
value{object_category::char_value};};template< typename T > struct classify_object< T, typename std::enable_if< is_bool< T >::value >::type >{static constexpr object_category value{object_category::boolean_value};};template< typename T > struct classify_object< T, typename std::enable_if< std::is_floating_point< T >::value >::type >{static constexpr object_category value{object_category::floating_point};};template< typename T >struct classify_object< T, typename std::enable_if<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&std::is_assignable< T &, std::string >::value >::type >{static constexpr object_category value{object_category::string_assignable};};template< typename T >struct classify_object< T, typename std::enable_if<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&!std::is_assignable< T &, std::string >::value &&(type_count< T >::value==1)&&std::is_constructible< T, std::string >::value >::type >{static constexpr object_category value{object_category::string_constructible};};template< typename T > struct classify_object< T, typename std::enable_if< std::is_enum< T >::value >::type >{static constexpr object_category value{object_category::enumeration};};template< typename T > struct classify_object< T, typename std::enable_if< is_complex< T >::value >::type >{static constexpr object_category value{object_category::complex_number};};template< typename T > struct uncommon_type{using type=typename std::conditional<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&!std::is_assignable< T &, std::string >::value &&!std::is_constructible< T, std::string >::value &&!is_complex< T >::value &&!is_mutable_container< T >::value &&!std::is_enum< T >::value, std::true_type, std::false_type >::type;static constexpr bool value=type::value;};template< typename T >struct classify_object< T, typename std::enable_if<(!is_mutable_container< T >::value &&is_wrapper< T >::value &&!is_tuple_like< T >::value 
&&uncommon_type< T >::value)>::type >{static constexpr object_category value{object_category::wrapper_value};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&is_direct_constructible< T, double >::value &&is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::number_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&!is_direct_constructible< T, double >::value &&is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::integer_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&is_direct_constructible< T, double >::value &&!is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::double_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< is_tuple_like< T >::value &&((type_count< T >::value >=2 &&!is_wrapper< T >::value)||(uncommon_type< T >::value &&!is_direct_constructible< T, double >::value &&!is_direct_constructible< T, int >::value)||(uncommon_type< T >::value &&type_count< T >::value >=2))>::type >{static constexpr object_category value{object_category::tuple_value};};template< typename T > struct classify_object< T, typename std::enable_if< is_mutable_container< T >::value >::type >{static constexpr object_category value{object_category::container_value};};template< typename T, enable_if_t< classify_object< T >::value==object_category::char_value, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"CHAR";}template< typename T, enable_if_t< classify_object< T 
>::value==object_category::integral_value||classify_object< T >::value==object_category::integer_constructible, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"INT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::unsigned_integral, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"UINT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::floating_point||classify_object< T >::value==object_category::number_constructible||classify_object< T >::value==object_category::double_constructible, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"FLOAT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::enumeration, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"ENUM";}template< typename T, enable_if_t< classify_object< T >::value==object_category::boolean_value, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"BOOLEAN";}template< typename T, enable_if_t< classify_object< T >::value==object_category::complex_number, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"COMPLEX";}template< typename T, enable_if_t< classify_object< T >::value >=object_category::string_assignable &&classify_object< T >::value<=object_category::other, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"TEXT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value >=2, detail::enabler >=detail::dummy >std::string type_name();template< typename T, enable_if_t< classify_object< T >::value==object_category::container_value||classify_object< T >::value==object_category::wrapper_value, detail::enabler >=detail::dummy >std::string type_name();template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value==1, 
detail::enabler >=detail::dummy >inline std::string type_name(){return type_name< typename std::decay< typename std::tuple_element< 0, T >::type >::type >);}template< typename T, std::size_t I >inline typename std::enable_if< I==type_count_base< T >::value, std::string >::type tuple_name(){return std::string{};}template< typename T, std::size_t I >inline typename std::enable_if<(I< type_count_base< T >::value), std::string >::type tuple_name(){auto str=std::string{type_name< typename std::decay< typename std::tuple_element< I, T >::type >::type >)}+ ','+tuple_name< T, I+1 >);if(str.back()== ',') str.pop_back();return str;}template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value >=2, detail::enabler > > std::string type_name()
Recursively generate the tuple type name.
static Value CreateTensor(const OrtMemoryInfo *info, T *p_data, size_t p_data_element_count, const int64_t *shape, size_t shape_len)
Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
GLint GLint GLsizei GLint GLenum format
ThreadingOptions & SetGlobalIntraOpNumThreads(int intra_op_num_threads)
Wraps OrtApi::SetGlobalIntraOpNumThreads.
Value GetTensorAttribute(const char *name, OrtAllocator *allocator) const
void Add(const OrtCustomOp *op)
Wraps CustomOpDomain_Add.
KernelInfo(std::nullptr_t)
Create an empty instance to initialize later.
Represents native memory allocation coming from one of the OrtAllocators registered with OnnxRuntime...
Session(std::nullptr_t)
Create an empty Session object, must be assigned a valid one to be used.
ThreadingOptions & SetGlobalDenormalAsZero()
Wraps OrtApi::SetGlobalDenormalAsZero.
IEEE 754 half-precision floating point data type.
size_t GetInputCount() const
static Value CreateMap(const Value &keys, const Value &values)
Creates an OrtValue with a Map Onnx type representation. The API would ref-count the supplied OrtValu...
SessionOptionsImpl & SetInterOpNumThreads(int inter_op_num_threads)
Wraps OrtApi::SetInterOpNumThreads.
size_t GetOutputCount() const
float8e5m2fnuz (Float8 Floating Point) data type
Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT_V...
R * GetTensorMutableData()
Returns a non-const typed pointer to an OrtValue/Tensor contained buffer. No type checking is performe...
SessionOptionsImpl & AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions &provider_options)
SessionOptions()
Wraps OrtApi::CreateSessionOptions.
void * GetGPUComputeStream() const
void BindOutput(const char *name, const Value &)
GLuint const GLchar * name
RunOptions & SetRunTag(const char *run_tag)
wraps OrtApi::RunOptionsSetRunTag
Allocator(std::nullptr_t)
Convenience to create a class member and then replace with an instance.
SessionOptionsImpl & EnableProfiling(const ORTCHAR_T *profile_file_prefix)
Wraps OrtApi::EnableProfiling.
size_t GetStringTensorDataLength() const
This API returns a full length of string data contained within either a tensor or a sparse Tensor...
detail::ConstSessionOptionsImpl< detail::Unowned< const OrtSessionOptions >> ConstSessionOptions
GLsizei const GLchar *const * strings
RunOptions & SetTerminate()
Terminates all currently executing Session::Run calls that were made using this RunOptions instance...
OrtAllocatorType GetAllocatorType() const
SessionOptionsImpl & AppendExecutionProvider_OpenVINO_V2(const std::unordered_map< std::string, std::string > &provider_options={})
size_t GetInputCount() const
TypeInfo GetSequenceElementType() const
Wraps OrtApi::GetSequenceElementType.
static Value CreateSparseTensor(const OrtMemoryInfo *info, T *p_data, const Shape &dense_shape, const Shape &values_shape)
This is a simple forwarding method to the other overload that helps deducing data type enum value fro...
void GetSessionConfigs(std::unordered_map< std::string, std::string > &out, ConstSessionOptions options) const
TypeInfo GetOutputTypeInfo(size_t index) const
Wraps OrtApi::SessionGetOutputTypeInfo.
size_t GetStringTensorElementLength(size_t element_index) const
The API returns a byte length of UTF-8 encoded string element contained in either a tensor or a sparse...
detail::ValueImpl< detail::Unowned< OrtValue >> UnownedValue
SessionOptionsImpl & Add(OrtCustomOpDomain *custom_op_domain)
Wraps OrtApi::AddCustomOpDomain.
TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const
The API returns type and shape information for stored non-zero values of the sparse tensor...
GT_API const UT_StringHolder version
TypeInfo GetTypeInfo() const
The API returns type information for data contained in a tensor. For sparse tensors it returns type i...
RunOptions & SetRunLogSeverityLevel(int)
Wraps OrtApi::RunOptionsSetRunLogSeverityLevel.
ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const
Wraps OrtApi::CastTypeInfoToTensorInfo.
bool operator==(const MemoryInfoImpl< U > &o) const
TypeInfo GetMapValueType() const
Wraps OrtApi::GetMapValueType.
RunOptions()
Wraps OrtApi::CreateRunOptions.
This struct owns the OrtKernelInfo* pointer when a copy is made. For convenient wrapping of OrtKernelIn...
void FillSparseTensorCoo(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values_param, const int64_t *indices_data, size_t indices_num)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
std::vector< Value > GetOutputValues() const
OrtErrorCode GetErrorCode() const
void FillStringTensor(const char *const *s, size_t s_len)
Set all strings at once in a string tensor
void ParallelFor(void(*fn)(void *, size_t), size_t total, size_t num_batch, void *usr_data) const
KernelContext(OrtKernelContext *context)
This class represents an ONNX Runtime logger that can be used to log information with an associated s...
Env & CreateAndRegisterAllocatorV2(const std::string &provider_type, const OrtMemoryInfo *mem_info, const std::unordered_map< std::string, std::string > &options, const OrtArenaCfg *arena_cfg)
Wraps OrtApi::CreateAndRegisterAllocatorV2.
ArenaCfg(std::nullptr_t)
Create an empty ArenaCfg object, must be assigned a valid one to be used.
SessionOptionsImpl & SetDeterministicCompute(bool value)
Wraps OrtApi::SetDeterministicCompute.
GT_API const UT_StringHolder st
ThreadingOptions()
Wraps OrtApi::CreateThreadingOptions.
GLenum GLsizei GLsizei GLint * values
std::string GetOutputName(size_t index) const
bool operator==(const BFloat16_t &rhs) const noexcept
TypeInfo GetOverridableInitializerTypeInfo(size_t index) const
Wraps OrtApi::SessionGetOverridableInitializerTypeInfo.
SessionOptionsImpl & AddConfigEntry(const char *config_key, const char *config_value)
Wraps OrtApi::AddSessionConfigEntry.
bool IsOK() const noexcept
Returns true if instance represents an OK (non-error) status.
SessionOptionsImpl & SetExecutionMode(ExecutionMode execution_mode)
Wraps OrtApi::SetSessionExecutionMode.
AllocatorWithDefaultOptions()
AllocatedStringPtr EndProfilingAllocated(OrtAllocator *allocator)
End profiling and return a copy of the profiling file name.
ShapeInferContext(const OrtApi *ort_api, OrtShapeInferContext *ctx)
SessionOptionsImpl & AppendExecutionProvider(const std::string &provider_name, const std::unordered_map< std::string, std::string > &provider_options={})
Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK...
SessionOptionsImpl & DisableProfiling()
Wraps OrtApi::DisableProfiling.
float8e5m2 (Float8 Floating Point) data type
**If you just want to fire and args
RunOptions & UnsetTerminate()
Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without ...
size_t GetOverridableInitializerCount() const
Returns the number of inputs that have defaults that can be overridden.
const char * GetRunTag() const
Wraps OrtApi::RunOptionsGetRunTag.
int GetRunLogVerbosityLevel() const
Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel.
std::string GetVersionString()
This function returns the onnxruntime version string
size_t GetOutputCount() const
Returns the number of model outputs.
Status LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T *file_path, int line_number, const char *func_name, const char *message) const noexcept
SessionOptionsImpl & AddExternalInitializers(const std::vector< std::string > &names, const std::vector< Value > &ort_values)
Wraps OrtApi::AddExternalInitializers.
TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const
The API returns type information for data contained in a tensor. For sparse tensors it returns type i...
int GetRunLogSeverityLevel() const
Wraps OrtApi::RunOptionsGetRunLogSeverityLevel.
ConstMemoryInfo GetTensorMemoryInfo() const
This API returns information about the memory allocation used to hold data.
AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of input name at the specified index.
Wrapper around ::OrtTensorTypeAndShapeInfo.
MemoryInfo(std::nullptr_t)
No instance is created.
std::string MakeCustomOpConfigEntryKey(const char *custom_op_name, const char *config)
CustomOpConfigs.
size_t GetCount() const
< Return true if OrtValue contains data and returns false if the OrtValue is a None ...
std::string GetInputName(size_t index) const
ThreadingOptions & SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn)
Wraps OrtApi::SetGlobalCustomCreateThreadFn.
Wrapper around ::OrtSessionOptions.
OrtSparseFormat GetSparseFormat() const
The API returns the sparse data format this OrtValue holds in a sparse tensor. If the sparse tensor w...
detail::MemoryInfoImpl< detail::Unowned< const OrtMemoryInfo >> ConstMemoryInfo
void UseCooIndices(int64_t *indices_data, size_t indices_num)
Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tens...
MemoryAllocation(OrtAllocator *allocator, void *p, size_t size)
CustomOpConfigs & AddConfig(const char *custom_op_name, const char *config_key, const char *config_value)
Adds a session configuration entry/value for a specific custom operator.
TypeInfo GetInputTypeInfo(size_t index) const
SessionOptionsImpl & SetLogSeverityLevel(int level)
Wraps OrtApi::SetSessionLogSeverityLevel.
Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime.
Wrapper around ::OrtSession.
SessionOptionsImpl & DisableMemPattern()
Wraps OrtApi::DisableMemPattern.
SessionOptionsImpl & AppendExecutionProvider_ROCM(const OrtROCMProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM.
Options for the CUDA provider that are passed to SessionOptionsAppendExecutionProvider_CUDA_V2. Please note that this struct is similar to OrtCUDAProviderOptions but only to be used internally. Going forward, new cuda provider options are to be supported via this struct and usage of the publicly defined OrtCUDAProviderOptions will be deprecated over time. User can only get the instance of OrtCUDAProviderOptionsV2 via CreateCUDAProviderOptions.
RunOptions & SetRunLogVerbosityLevel(int)
Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel.
ThreadingOptions & SetGlobalCustomThreadCreationOptions(void *ort_custom_thread_creation_options)
Wraps OrtApi::SetGlobalCustomThreadCreationOptions.
#define RETURN_ON_API_FAIL(expression)
TypeInfo GetOptionalElementType() const
Wraps OrtApi::CastOptionalTypeToContainedTypeInfo.
bfloat16 (Brain Floating Point) data type
Class that represents session configuration entries for one or more custom operators.
Status(std::nullptr_t) noexcept
Create an empty object, must be assigned a valid one to be used.
std::vector< std::string > GetOutputNamesHelper(const OrtIoBinding *binding, OrtAllocator *)
#define ORT_CXX_API_THROW(string, code)
SessionOptionsImpl & RegisterCustomOpsLibrary(const ORTCHAR_T *library_name, const CustomOpConfigs &custom_op_configs={})
GLsizei GLenum GLenum GLuint GLenum GLsizei * lengths
ConstSequenceTypeInfo GetSequenceTypeInfo() const
Wraps OrtApi::CastTypeInfoToSequenceTypeInfo.