HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
onnxruntime_cxx_api.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 // Summary: The Ort C++ API is a header only wrapper around the Ort C API.
5 //
6 // The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
7 // and automatically releasing resources in the destructors. The primary purpose of C++ API is exception safety so
8 // all the resources follow RAII and do not leak memory.
9 //
10 // Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
11 // To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). However, you can't use them
12 // until you assign an instance that actually holds an underlying object.
13 //
14 // For Ort objects only move assignment between objects is allowed, there are no copy constructors.
15 // Some objects have explicit 'Clone' methods for this purpose.
16 //
17 // ConstXXXX types are copyable since they do not own the underlying C object, so you can pass them to functions as arguments
18 // by value or by reference. ConstXXXX types are restricted to const only interfaces.
19 //
20 // UnownedXXXX are similar to ConstXXXX but also allow non-const interfaces.
21 //
22 // The lifetime of the corresponding owning object must eclipse the lifetimes of the ConstXXXX/UnownedXXXX types. They exist so you do not
23 // have to fall back to C types and the API with the usual pitfalls. In general, do not use C API from your C++ code.
24 
25 #pragma once
26 #include "onnxruntime_c_api.h"
27 #include "onnxruntime_float16.h"
28 
29 #include <cstddef>
30 #include <cstdio>
31 #include <array>
32 #include <memory>
33 #include <stdexcept>
34 #include <string>
35 #include <vector>
36 #include <unordered_map>
37 #include <utility>
38 #include <type_traits>
39 
40 #ifdef ORT_NO_EXCEPTIONS
41 #include <iostream>
42 #endif
43 
44 /** \brief All C++ Onnxruntime APIs are defined inside this namespace
45  *
46  */
47 namespace Ort {
48 
49 /** \brief All C++ methods that can fail will throw an exception of this type
50  *
51  * If <tt>ORT_NO_EXCEPTIONS</tt> is defined, then any error will result in a call to abort()
52  */
53 struct Exception : std::exception {
54  Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
55 
56  OrtErrorCode GetOrtErrorCode() const { return code_; }
57  const char* what() const noexcept override { return message_.c_str(); }
58 
59  private:
60  std::string message_;
61  OrtErrorCode code_;
62 };
63 
64 #ifdef ORT_NO_EXCEPTIONS
65 // The #ifndef is for the very special case where the user of this library wants to define their own way of handling errors.
66 // NOTE: This header expects control flow to not continue after calling ORT_CXX_API_THROW
67 #ifndef ORT_CXX_API_THROW
68 #define ORT_CXX_API_THROW(string, code) \
69  do { \
70  std::cerr << Ort::Exception(string, code) \
71  .what() \
72  << std::endl; \
73  abort(); \
74  } while (false)
75 #endif
76 #else
77 #define ORT_CXX_API_THROW(string, code) \
78  throw Ort::Exception(string, code)
79 #endif
80 
81 // This is used internally by the C++ API. This class holds the global variable that points to the OrtApi,
82 // it's in a template so that we can define a global variable in a header and make
83 // it transparent to the users of the API.
84 template <typename T>
85 struct Global {
86  static const OrtApi* api_;
87 };
88 
89 // If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it.
90 template <typename T>
91 #ifdef ORT_API_MANUAL_INIT
92 const OrtApi* Global<T>::api_{};
93 inline void InitApi() noexcept { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); }
94 
95 // Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is
96 // required by C++ APIs.
97 //
98 // Example mycustomop.cc:
99 //
100 // #define ORT_API_MANUAL_INIT
101 // #include <onnxruntime_cxx_api.h>
102 // #undef ORT_API_MANUAL_INIT
103 //
104 // OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) {
105 // Ort::InitApi(api_base->GetApi(ORT_API_VERSION));
106 // // ...
107 // }
108 //
109 inline void InitApi(const OrtApi* api) noexcept { Global<void>::api_ = api; }
110 #else
111 #if defined(_MSC_VER) && !defined(__clang__)
112 #pragma warning(push)
113 // "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers.
114 // Please define ORT_API_MANUAL_INIT if it concerns you.
115 #pragma warning(disable : 26426)
116 #endif
117 const OrtApi* Global<T>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION);
118 #if defined(_MSC_VER) && !defined(__clang__)
119 #pragma warning(pop)
120 #endif
121 #endif
122 
123 /// This returns a reference to the OrtApi interface in use
124 inline const OrtApi& GetApi() noexcept { return *Global<void>::api_; }
125 
126 /// <summary>
127 /// This function returns the onnxruntime version string
128 /// </summary>
129 /// <returns>version string major.minor.rev</returns>
131 
132 /// <summary>
133 /// This function returns the onnxruntime build information: including git branch,
134 /// git commit id, build type(Debug/Release/RelWithDebInfo) and cmake cpp flags.
135 /// </summary>
136 /// <returns>string</returns>
138 
139 /// <summary>
140 /// This is a C++ wrapper for OrtApi::GetAvailableProviders() and
141 /// returns a vector of strings representing the available execution providers.
142 /// </summary>
143 /// <returns>vector of strings</returns>
144 std::vector<std::string> GetAvailableProviders();
145 
146 /** \brief IEEE 754 half-precision floating point data type
147  *
148  * \details This struct is used for converting float to float16 and back
149  * so the user could feed inputs and fetch outputs using this type.
150  *
151  * The size of the structure should align with uint16_t and one can freely cast
152  * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data.
153  *
154  * \code{.unparsed}
155  * // This example demonstrates conversion from float to float16
156  * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f};
157  * std::vector<Ort::Float16_t> fp16_values;
158  * fp16_values.reserve(std::size(values));
159  * std::transform(std::begin(values), std::end(values), std::back_inserter(fp16_values),
160  * [](float value) { return Ort::Float16_t(value); });
161  *
162  * \endcode
163  */
165  private:
166  /// <summary>
167  /// Constructor from a 16-bit representation of a float16 value
168  /// No conversion is done here.
169  /// </summary>
170  /// <param name="v">16-bit representation</param>
171  constexpr explicit Float16_t(uint16_t v) noexcept { val = v; }
172 
173  public:
175 
176  /// <summary>
177  /// Default constructor
178  /// </summary>
179  Float16_t() = default;
180 
181  /// <summary>
182  /// Explicit conversion to uint16_t representation of float16.
183  /// </summary>
184  /// <param name="v">uint16_t bit representation of float16</param>
185  /// <returns>new instance of Float16_t</returns>
186  constexpr static Float16_t FromBits(uint16_t v) noexcept { return Float16_t(v); }
187 
188  /// <summary>
189  /// __ctor from float. Float is converted into float16 16-bit representation.
190  /// </summary>
191  /// <param name="v">float value</param>
192  explicit Float16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
193 
194  /// <summary>
195  /// Converts float16 to float
196  /// </summary>
197  /// <returns>float representation of float16 value</returns>
198  float ToFloat() const noexcept { return Base::ToFloatImpl(); }
199 
200  /// <summary>
201  /// Checks if the value is negative
202  /// </summary>
203  /// <returns>true if negative</returns>
204  using Base::IsNegative;
205 
206  /// <summary>
207  /// Tests if the value is NaN
208  /// </summary>
209  /// <returns>true if NaN</returns>
210  using Base::IsNaN;
211 
212  /// <summary>
213  /// Tests if the value is finite
214  /// </summary>
215  /// <returns>true if finite</returns>
216  using Base::IsFinite;
217 
218  /// <summary>
219  /// Tests if the value represents positive infinity.
220  /// </summary>
221  /// <returns>true if positive infinity</returns>
223 
224  /// <summary>
225  /// Tests if the value represents negative infinity
226  /// </summary>
227  /// <returns>true if negative infinity</returns>
229 
230  /// <summary>
231  /// Tests if the value is either positive or negative infinity.
232  /// </summary>
233  /// <returns>True if absolute value is infinity</returns>
234  using Base::IsInfinity;
235 
236  /// <summary>
237  /// Tests if the value is NaN or zero. Useful for comparisons.
238  /// </summary>
239  /// <returns>True if NaN or zero.</returns>
240  using Base::IsNaNOrZero;
241 
242  /// <summary>
243  /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
244  /// </summary>
245  /// <returns>True if so</returns>
246  using Base::IsNormal;
247 
248  /// <summary>
249  /// Tests if the value is subnormal (denormal).
250  /// </summary>
251  /// <returns>True if so</returns>
252  using Base::IsSubnormal;
253 
254  /// <summary>
255  /// Creates an instance that represents absolute value.
256  /// </summary>
257  /// <returns>Absolute value</returns>
258  using Base::Abs;
259 
260  /// <summary>
261  /// Creates a new instance with the sign flipped.
262  /// </summary>
263  /// <returns>Flipped sign instance</returns>
264  using Base::Negate;
265 
266  /// <summary>
267  /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
268  /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
269  /// and therefore equivalent, if the resulting value is still zero.
270  /// </summary>
271  /// <param name="lhs">first value</param>
272  /// <param name="rhs">second value</param>
273  /// <returns>True if both arguments represent zero</returns>
274  using Base::AreZero;
275 
276  /// <summary>
277  /// User defined conversion operator. Converts Float16_t to float.
278  /// </summary>
279  explicit operator float() const noexcept { return ToFloat(); }
280 
281  using Base::operator==;
282  using Base::operator!=;
283  using Base::operator<;
284 };
285 
286 static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
287 
288 /** \brief bfloat16 (Brain Floating Point) data type
289  *
290  * \details This struct is used for converting float to bfloat16 and back
291  * so the user could feed inputs and fetch outputs using this type.
292  *
293  * The size of the structure should align with uint16_t and one can freely cast
294  * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data.
295  *
296  * \code{.unparsed}
297  * // This example demonstrates conversion from float to bfloat16
298  * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f};
299  * std::vector<Ort::BFloat16_t> bfp16_values;
300  * bfp16_values.reserve(std::size(values));
301  * std::transform(std::begin(values), std::end(values), std::back_inserter(bfp16_values),
302  * [](float value) { return Ort::BFloat16_t(value); });
303  *
304  * \endcode
305  */
307  private:
308  /// <summary>
309  /// Constructor from a uint16_t representation of bfloat16
310  /// used in FromBits() to escape overload resolution issue with
311  /// constructor from float.
312  /// No conversion is done.
313  /// </summary>
314  /// <param name="v">16-bit bfloat16 value</param>
315  constexpr explicit BFloat16_t(uint16_t v) noexcept { val = v; }
316 
317  public:
319 
320  BFloat16_t() = default;
321 
322  /// <summary>
323  /// Explicit conversion to uint16_t representation of bfloat16.
324  /// </summary>
325  /// <param name="v">uint16_t bit representation of bfloat16</param>
326  /// <returns>new instance of BFloat16_t</returns>
327  static constexpr BFloat16_t FromBits(uint16_t v) noexcept { return BFloat16_t(v); }
328 
329  /// <summary>
330  /// __ctor from float. Float is converted into bfloat16 16-bit representation.
331  /// </summary>
332  /// <param name="v">float value</param>
333  explicit BFloat16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
334 
335  /// <summary>
336  /// Converts bfloat16 to float
337  /// </summary>
338  /// <returns>float representation of bfloat16 value</returns>
339  float ToFloat() const noexcept { return Base::ToFloatImpl(); }
340 
341  /// <summary>
342  /// Checks if the value is negative
343  /// </summary>
344  /// <returns>true if negative</returns>
345  using Base::IsNegative;
346 
347  /// <summary>
348  /// Tests if the value is NaN
349  /// </summary>
350  /// <returns>true if NaN</returns>
351  using Base::IsNaN;
352 
353  /// <summary>
354  /// Tests if the value is finite
355  /// </summary>
356  /// <returns>true if finite</returns>
357  using Base::IsFinite;
358 
359  /// <summary>
360  /// Tests if the value represents positive infinity.
361  /// </summary>
362  /// <returns>true if positive infinity</returns>
364 
365  /// <summary>
366  /// Tests if the value represents negative infinity
367  /// </summary>
368  /// <returns>true if negative infinity</returns>
370 
371  /// <summary>
372  /// Tests if the value is either positive or negative infinity.
373  /// </summary>
374  /// <returns>True if absolute value is infinity</returns>
375  using Base::IsInfinity;
376 
377  /// <summary>
378  /// Tests if the value is NaN or zero. Useful for comparisons.
379  /// </summary>
380  /// <returns>True if NaN or zero.</returns>
381  using Base::IsNaNOrZero;
382 
383  /// <summary>
384  /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
385  /// </summary>
386  /// <returns>True if so</returns>
387  using Base::IsNormal;
388 
389  /// <summary>
390  /// Tests if the value is subnormal (denormal).
391  /// </summary>
392  /// <returns>True if so</returns>
393  using Base::IsSubnormal;
394 
395  /// <summary>
396  /// Creates an instance that represents absolute value.
397  /// </summary>
398  /// <returns>Absolute value</returns>
399  using Base::Abs;
400 
401  /// <summary>
402  /// Creates a new instance with the sign flipped.
403  /// </summary>
404  /// <returns>Flipped sign instance</returns>
405  using Base::Negate;
406 
407  /// <summary>
408  /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
409  /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
410  /// and therefore equivalent, if the resulting value is still zero.
411  /// </summary>
412  /// <param name="lhs">first value</param>
413  /// <param name="rhs">second value</param>
414  /// <returns>True if both arguments represent zero</returns>
415  using Base::AreZero;
416 
417  /// <summary>
418  /// User defined conversion operator. Converts BFloat16_t to float.
419  /// </summary>
420  explicit operator float() const noexcept { return ToFloat(); }
421 
422  // We do not have an inherited impl for the below operators
423  // as the internal class implements them a little differently
424  bool operator==(const BFloat16_t& rhs) const noexcept;
425  bool operator!=(const BFloat16_t& rhs) const noexcept { return !(*this == rhs); }
426  bool operator<(const BFloat16_t& rhs) const noexcept;
427 };
428 
429 static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
430 
431 /** \brief float8e4m3fn (Float8 Floating Point) data type
432  * \details It is necessary for type dispatching to make use of C++ API
433  * The type is implicitly convertible to/from uint8_t.
434  * See https://onnx.ai/onnx/technical/float8.html for further details.
435  */
437  uint8_t value;
438  constexpr Float8E4M3FN_t() noexcept : value(0) {}
439  constexpr Float8E4M3FN_t(uint8_t v) noexcept : value(v) {}
440  constexpr operator uint8_t() const noexcept { return value; }
441  // nan values are treated like any other value for operator ==, !=
442  constexpr bool operator==(const Float8E4M3FN_t& rhs) const noexcept { return value == rhs.value; };
443  constexpr bool operator!=(const Float8E4M3FN_t& rhs) const noexcept { return value != rhs.value; };
444 };
445 
446 static_assert(sizeof(Float8E4M3FN_t) == sizeof(uint8_t), "Sizes must match");
447 
448 /** \brief float8e4m3fnuz (Float8 Floating Point) data type
449  * \details It is necessary for type dispatching to make use of C++ API
450  * The type is implicitly convertible to/from uint8_t.
451  * See https://onnx.ai/onnx/technical/float8.html for further details.
452  */
454  uint8_t value;
455  constexpr Float8E4M3FNUZ_t() noexcept : value(0) {}
456  constexpr Float8E4M3FNUZ_t(uint8_t v) noexcept : value(v) {}
457  constexpr operator uint8_t() const noexcept { return value; }
458  // nan values are treated like any other value for operator ==, !=
459  constexpr bool operator==(const Float8E4M3FNUZ_t& rhs) const noexcept { return value == rhs.value; };
460  constexpr bool operator!=(const Float8E4M3FNUZ_t& rhs) const noexcept { return value != rhs.value; };
461 };
462 
463 static_assert(sizeof(Float8E4M3FNUZ_t) == sizeof(uint8_t), "Sizes must match");
464 
/** \brief float8e5m2 (Float8 Floating Point) data type
 * \details It is necessary for type dispatching to make use of C++ API
 * The type is implicitly convertible to/from uint8_t.
 * See https://onnx.ai/onnx/technical/float8.html for further details.
 */
struct Float8E5M2_t {
  uint8_t value;  ///< raw 8-bit representation

  /// Default constructor; the stored bits are zero.
  constexpr Float8E5M2_t() noexcept : value(0) {}

  /// Implicit construction from the raw 8-bit representation; the bits are stored as given.
  constexpr Float8E5M2_t(uint8_t v) noexcept : value(v) {}

  /// Implicit conversion back to the raw 8-bit representation.
  constexpr operator uint8_t() const noexcept { return value; }

  // Comparison is on the stored bits only:
  // nan values are treated like any other value for operator ==, !=
  constexpr bool operator==(const Float8E5M2_t& rhs) const noexcept { return value == rhs.value; }
  constexpr bool operator!=(const Float8E5M2_t& rhs) const noexcept { return value != rhs.value; }
};
479 
480 static_assert(sizeof(Float8E5M2_t) == sizeof(uint8_t), "Sizes must match");
481 
482 /** \brief float8e5m2fnuz (Float8 Floating Point) data type
483  * \details It is necessary for type dispatching to make use of C++ API
484  * The type is implicitly convertible to/from uint8_t.
485  * See https://onnx.ai/onnx/technical/float8.html for further details.
486  */
488  uint8_t value;
489  constexpr Float8E5M2FNUZ_t() noexcept : value(0) {}
490  constexpr Float8E5M2FNUZ_t(uint8_t v) noexcept : value(v) {}
491  constexpr operator uint8_t() const noexcept { return value; }
492  // nan values are treated like any other value for operator ==, !=
493  constexpr bool operator==(const Float8E5M2FNUZ_t& rhs) const noexcept { return value == rhs.value; };
494  constexpr bool operator!=(const Float8E5M2FNUZ_t& rhs) const noexcept { return value != rhs.value; };
495 };
496 
497 static_assert(sizeof(Float8E5M2FNUZ_t) == sizeof(uint8_t), "Sizes must match");
498 
499 namespace detail {
500 // This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type
501 // This can't be done in the C API since C doesn't have function overloading.
502 #define ORT_DEFINE_RELEASE(NAME) \
503  inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); }
504 
505 ORT_DEFINE_RELEASE(Allocator);
506 ORT_DEFINE_RELEASE(MemoryInfo);
507 ORT_DEFINE_RELEASE(CustomOpDomain);
508 ORT_DEFINE_RELEASE(ThreadingOptions);
509 ORT_DEFINE_RELEASE(Env);
511 ORT_DEFINE_RELEASE(Session);
512 ORT_DEFINE_RELEASE(SessionOptions);
513 ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
514 ORT_DEFINE_RELEASE(SequenceTypeInfo);
515 ORT_DEFINE_RELEASE(MapTypeInfo);
516 ORT_DEFINE_RELEASE(TypeInfo);
518 ORT_DEFINE_RELEASE(ModelMetadata);
519 ORT_DEFINE_RELEASE(IoBinding);
520 ORT_DEFINE_RELEASE(ArenaCfg);
522 ORT_DEFINE_RELEASE(OpAttr);
524 ORT_DEFINE_RELEASE(KernelInfo);
525 
526 #undef ORT_DEFINE_RELEASE
527 
/** \brief Tagging template type. Use it as Base<Unowned<T>> to indicate that the C++ interface
 * object has no ownership of the underlying C object.
 */
template <typename T>
struct Unowned {
  typedef T Type;
};
535 
536 /** \brief Used internally by the C++ API. C++ wrapper types inherit from this.
537  * This is a zero cost abstraction to wrap the C API objects and delete them on destruction.
538  *
539  * All of the C++ classes
540  * a) serve as containers for pointers to objects that are created by the underlying C API.
541  * Their size is just a pointer size, no need to dynamically allocate them. Use them by value.
542  * b) Each of struct XXXX, XXX instances function as smart pointers to the underlying C API objects.
543  * they would release objects owned automatically when going out of scope, they are move-only.
544  * c) ConstXXXX and UnownedXXX structs function as non-owning, copyable containers for the above pointers.
545  * ConstXXXX allow calling const interfaces only. They give access to objects that are owned by somebody else
546  * such as Onnxruntime or instances of XXXX classes.
547  * d) serve convenient interfaces that return C++ objects and further enhance exception and type safety so they can be used
548  * in C++ code.
549  *
550  */
551 
552 /// <summary>
553 /// This is a non-const pointer holder that is move-only. Disposes of the pointer on destruction.
554 /// </summary>
555 template <typename T>
556 struct Base {
557  using contained_type = T;
558 
559  constexpr Base() = default;
560  constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
561  ~Base() { OrtRelease(p_); }
562 
563  Base(const Base&) = delete;
564  Base& operator=(const Base&) = delete;
565 
566  Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
567  Base& operator=(Base&& v) noexcept {
568  OrtRelease(p_);
569  p_ = v.release();
570  return *this;
571  }
572 
573  constexpr operator contained_type*() const noexcept { return p_; }
574 
575  /// \brief Relinquishes ownership of the contained C object pointer
576  /// The underlying object is not destroyed
578  T* p = p_;
579  p_ = nullptr;
580  return p;
581  }
582 
583  protected:
585 };
586 
587 // Undefined. For const types use Base<Unowned<const T>>
588 template <typename T>
589 struct Base<const T>;
590 
591 /// <summary>
592 /// Covers unowned pointers owned by either the ORT
593 /// or some other instance of CPP wrappers.
594 /// Used for ConstXXX and UnownedXXXX types that are copyable.
595 /// Also convenient to wrap raw OrtXX pointers .
596 /// </summary>
597 /// <typeparam name="T"></typeparam>
598 template <typename T>
599 struct Base<Unowned<T>> {
601 
602  constexpr Base() = default;
603  constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
604 
605  ~Base() = default;
606 
607  Base(const Base&) = default;
608  Base& operator=(const Base&) = default;
609 
610  Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
611  Base& operator=(Base&& v) noexcept {
612  p_ = nullptr;
613  std::swap(p_, v.p_);
614  return *this;
615  }
616 
617  constexpr operator contained_type*() const noexcept { return p_; }
618 
619  protected:
621 };
622 
623 // Light functor to release memory with OrtAllocator
625  OrtAllocator* allocator_;
626  explicit AllocatedFree(OrtAllocator* allocator)
627  : allocator_(allocator) {}
628  void operator()(void* ptr) const {
629  if (ptr) allocator_->Free(allocator_, ptr);
630  }
631 };
632 
633 } // namespace detail
634 
635 struct AllocatorWithDefaultOptions;
636 struct Env;
637 struct TypeInfo;
638 struct Value;
639 struct ModelMetadata;
640 
641 /** \brief unique_ptr typedef used to own strings allocated by OrtAllocators
642  * and release them at the end of the scope. The lifespan of the given allocator
643  * must eclipse the lifespan of AllocatedStringPtr instance
644  */
645 using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;
646 
647 /** \brief The Status that holds ownership of OrtStatus received from C API
648  * Use it to safely destroy OrtStatus* returned from the C API. Use appropriate
649  * constructors to construct an instance of a Status object from exceptions.
650  */
651 struct Status : detail::Base<OrtStatus> {
652  explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used
653  explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API.
654  explicit Status(const Exception&) noexcept; ///< Creates status instance out of exception
655  explicit Status(const std::exception&) noexcept; ///< Creates status instance out of exception
656  Status(const char* message, OrtErrorCode code) noexcept; ///< Creates status instance out of null-terminated string message.
657  std::string GetErrorMessage() const;
658  OrtErrorCode GetErrorCode() const;
659  bool IsOK() const noexcept; ///< Returns true if instance represents an OK (non-error) status.
660 };
661 
662 /** \brief The ThreadingOptions
663  *
664  * ThreadingOptions is used to set the options of the Env's global thread pools.
665  */
666 struct ThreadingOptions : detail::Base<OrtThreadingOptions> {
667  /// \brief Wraps OrtApi::CreateThreadingOptions
669 
670  /// \brief Wraps OrtApi::SetGlobalIntraOpNumThreads
671  ThreadingOptions& SetGlobalIntraOpNumThreads(int intra_op_num_threads);
672 
673  /// \brief Wraps OrtApi::SetGlobalInterOpNumThreads
674  ThreadingOptions& SetGlobalInterOpNumThreads(int inter_op_num_threads);
675 
676  /// \brief Wraps OrtApi::SetGlobalSpinControl
677  ThreadingOptions& SetGlobalSpinControl(int allow_spinning);
678 
679  /// \brief Wraps OrtApi::SetGlobalDenormalAsZero
680  ThreadingOptions& SetGlobalDenormalAsZero();
681 
682  /// \brief Wraps OrtApi::SetGlobalCustomCreateThreadFn
683  ThreadingOptions& SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn);
684 
685  /// \brief Wraps OrtApi::SetGlobalCustomThreadCreationOptions
686  ThreadingOptions& SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options);
687 
688  /// \brief Wraps OrtApi::SetGlobalCustomJoinThreadFn
689  ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn);
690 };
691 
692 /** \brief The Env (Environment)
693  *
694  * The Env holds the logging state used by all other objects.
695  * <b>Note:</b> One Env must be created before using any other Onnxruntime functionality
696  */
697 struct Env : detail::Base<OrtEnv> {
698  explicit Env(std::nullptr_t) {} ///< Create an empty Env object, must be assigned a valid one to be used
699 
700  /// \brief Wraps OrtApi::CreateEnv
701  Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
702 
703  /// \brief Wraps OrtApi::CreateEnvWithCustomLogger
704  Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
705 
706  /// \brief Wraps OrtApi::CreateEnvWithGlobalThreadPools
707  Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
708 
709  /// \brief Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools
710  Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
711  OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
712 
713  /// \brief C Interop Helper
714  explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
715 
716  Env& EnableTelemetryEvents(); ///< Wraps OrtApi::EnableTelemetryEvents
717  Env& DisableTelemetryEvents(); ///< Wraps OrtApi::DisableTelemetryEvents
718 
719  Env& UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level); ///< Wraps OrtApi::UpdateEnvWithCustomLogLevel
720 
721  Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocator
722 
723  Env& CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocatorV2
724 };
725 
726 /** \brief Custom Op Domain
727  *
728  */
729 struct CustomOpDomain : detail::Base<OrtCustomOpDomain> {
730  explicit CustomOpDomain(std::nullptr_t) {} ///< Create an empty CustomOpDomain object, must be assigned a valid one to be used
731 
732  /// \brief Wraps OrtApi::CreateCustomOpDomain
733  explicit CustomOpDomain(const char* domain);
734 
735  // This does not take ownership of the op, simply registers it.
736  void Add(const OrtCustomOp* op); ///< Wraps CustomOpDomain_Add
737 };
738 
739 /** \brief RunOptions
740  *
741  */
742 struct RunOptions : detail::Base<OrtRunOptions> {
743  explicit RunOptions(std::nullptr_t) {} ///< Create an empty RunOptions object, must be assigned a valid one to be used
744  RunOptions(); ///< Wraps OrtApi::CreateRunOptions
745 
746  RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel
747  int GetRunLogVerbosityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel
748 
749  RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel
750  int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel
751 
752  RunOptions& SetRunTag(const char* run_tag); ///< wraps OrtApi::RunOptionsSetRunTag
753  const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag
754 
755  RunOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddRunConfigEntry
756 
757  /** \brief Terminates all currently executing Session::Run calls that were made using this RunOptions instance
758  *
759  * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error
760  * Wraps OrtApi::RunOptionsSetTerminate
761  */
762  RunOptions& SetTerminate();
763 
764  /** \brief Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without it instantly terminating
765  *
766  * Wraps OrtApi::RunOptionsUnsetTerminate
767  */
768  RunOptions& UnsetTerminate();
769 };
770 
771 namespace detail {
772 // Utility function that returns a SessionOption config entry key for a specific custom operator.
773 // Ex: custom_op.[custom_op_name].[config]
774 std::string MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config);
775 } // namespace detail
776 
777 /// <summary>
778 /// Class that represents session configuration entries for one or more custom operators.
779 ///
780 /// Example:
781 /// Ort::CustomOpConfigs op_configs;
782 /// op_configs.AddConfig("my_custom_op", "device_type", "CPU");
783 ///
784 /// Passed to Ort::SessionOptions::RegisterCustomOpsLibrary.
785 /// </summary>
787  CustomOpConfigs() = default;
788  ~CustomOpConfigs() = default;
789  CustomOpConfigs(const CustomOpConfigs&) = default;
790  CustomOpConfigs& operator=(const CustomOpConfigs&) = default;
791  CustomOpConfigs(CustomOpConfigs&& o) = default;
792  CustomOpConfigs& operator=(CustomOpConfigs&& o) = default;
793 
794  /** \brief Adds a session configuration entry/value for a specific custom operator.
795  *
796  * \param custom_op_name The name of the custom operator for which to add a configuration entry.
797  * Must match the name returned by the CustomOp's GetName() method.
798  * \param config_key The name of the configuration entry.
799  * \param config_value The value of the configuration entry.
800  * \return A reference to this object to enable call chaining.
801  */
802  CustomOpConfigs& AddConfig(const char* custom_op_name, const char* config_key, const char* config_value);
803 
804  /** \brief Returns a flattened map of custom operator configuration entries and their values.
805  *
806  * The keys have been flattened to include both the custom operator name and the configuration entry key name.
807  * For example, a prior call to AddConfig("my_op", "key", "value") corresponds to the flattened key/value pair
808  * {"my_op.key", "value"}.
809  *
810  * \return An unordered map of flattened configurations.
811  */
812  const std::unordered_map<std::string, std::string>& GetFlattenedConfigs() const;
813 
814  private:
815  std::unordered_map<std::string, std::string> flat_configs_;
816 };
817 
818 /** \brief Options object used when creating a new Session object
819  *
820  * Wraps ::OrtSessionOptions object and methods
821  */
822 
823 struct SessionOptions;
824 
825 namespace detail {
826 // we separate const-only methods because passing const ptr to non-const methods
827 // is only discovered when inline methods are compiled which is counter-intuitive
828 template <typename T>
830  using B = Base<T>;
831  using B::B;
832 
833  SessionOptions Clone() const; ///< Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions
834 
835  std::string GetConfigEntry(const char* config_key) const; ///< Wraps OrtApi::GetSessionConfigEntry
836  bool HasConfigEntry(const char* config_key) const; ///< Wraps OrtApi::HasSessionConfigEntry
837  std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def); ///< NOTE(review): presumably returns the entry for config_key, or def when the entry is absent - confirm. Not const-qualified, unlike the other members of this const-only interface.
838 };
839 
840 template <typename T>
843  using B::B;
844 
845  SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads
846  SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads
847  SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel
848  SessionOptionsImpl& SetDeterministicCompute(bool value); ///< Wraps OrtApi::SetDeterministicCompute
849 
850  SessionOptionsImpl& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena
851  SessionOptionsImpl& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena
852 
853  SessionOptionsImpl& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); ///< Wraps OrtApi::SetOptimizedModelFilePath
854 
855  SessionOptionsImpl& EnableProfiling(const ORTCHAR_T* profile_file_prefix); ///< Wraps OrtApi::EnableProfiling
856  SessionOptionsImpl& DisableProfiling(); ///< Wraps OrtApi::DisableProfiling
857 
858  SessionOptionsImpl& EnableOrtCustomOps(); ///< Wraps OrtApi::EnableOrtCustomOps
859 
860  SessionOptionsImpl& EnableMemPattern(); ///< Wraps OrtApi::EnableMemPattern
861  SessionOptionsImpl& DisableMemPattern(); ///< Wraps OrtApi::DisableMemPattern
862 
863  SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode
864 
865  SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId
866  SessionOptionsImpl& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel
867 
868  SessionOptionsImpl& Add(OrtCustomOpDomain* custom_op_domain); ///< Wraps OrtApi::AddCustomOpDomain
869 
870  SessionOptionsImpl& DisablePerSessionThreads(); ///< Wraps OrtApi::DisablePerSessionThreads
871 
872  SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry
873 
874  SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer
875  SessionOptionsImpl& AddExternalInitializers(const std::vector<std::string>& names, const std::vector<Value>& ort_values); ///< Wraps OrtApi::AddExternalInitializers
876 
877  SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA
878  SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2
879  SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
880  SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
881  /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO_V2
882  SessionOptionsImpl& AppendExecutionProvider_OpenVINO_V2(const std::unordered_map<std::string, std::string>& provider_options = {});
883  SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
884  SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT_V2
885  SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
886  /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN
887  SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options);
888  /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl
889  SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options);
890  /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK.
891  SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name,
892  const std::unordered_map<std::string, std::string>& provider_options = {});
893 
894  SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn
895  SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions
896  SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn
897 
898  /// Registers the custom operator from the specified shared library via OrtApi::RegisterCustomOpsLibrary_V2.
899  /// The custom operator configurations are optional. If provided, custom operator configs are set via
900  /// OrtApi::AddSessionConfigEntry.
901  SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {});
902 
903  SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name); ///< Wraps OrtApi::RegisterCustomOpsUsingFunction
904 };
905 } // namespace detail
906 
909 
910 /** \brief Wrapper around ::OrtSessionOptions
911  *
 * GetUnowned() and GetConst() return wrappers built from this object's underlying pointer. As stated in the
 * file header, the lifetime of the owning object must exceed the lifetime of such ConstXXXX/UnownedXXXX objects.
912  */
913 struct SessionOptions : detail::SessionOptionsImpl<OrtSessionOptions> {
914  explicit SessionOptions(std::nullptr_t) {} ///< Create an empty SessionOptions object, must be assigned a valid one to be used
915  SessionOptions(); ///< Wraps OrtApi::CreateSessionOptions
916  explicit SessionOptions(OrtSessionOptions* p) : SessionOptionsImpl<OrtSessionOptions>{p} {} ///< Used for interop with the C API
917  UnownedSessionOptions GetUnowned() const { return UnownedSessionOptions{this->p_}; } ///< Returns an UnownedSessionOptions constructed from this object's underlying pointer
918  ConstSessionOptions GetConst() const { return ConstSessionOptions{this->p_}; } ///< Returns a ConstSessionOptions constructed from this object's underlying pointer
919 };
920 
921 /** \brief Wrapper around ::OrtModelMetadata
922  *
923  */
924 struct ModelMetadata : detail::Base<OrtModelMetadata> {
925  explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used
926  explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {} ///< Used for interop with the C API
927 
928  /** \brief Returns a copy of the producer name.
929  *
930  * \param allocator to allocate memory for the copy of the name returned
931  * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
932  * The OrtAllocator instance must be valid at the point of memory release.
933  */
934  AllocatedStringPtr GetProducerNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName
935 
936  /** \brief Returns a copy of the graph name.
937  *
938  * \param allocator to allocate memory for the copy of the name returned
939  * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
940  * The OrtAllocator instance must be valid at the point of memory release.
941  */
942  AllocatedStringPtr GetGraphNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName
943 
944  /** \brief Returns a copy of the domain name.
945  *
946  * \param allocator to allocate memory for the copy of the name returned
947  * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
948  * The OrtAllocator instance must be valid at the point of memory release.
949  */
950  AllocatedStringPtr GetDomainAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain
951 
952  /** \brief Returns a copy of the description.
953  *
954  * \param allocator to allocate memory for the copy of the string returned
955  * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
956  * The OrtAllocator instance must be valid at the point of memory release.
957  */
958  AllocatedStringPtr GetDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription
959 
960  /** \brief Returns a copy of the graph description.
961  *
962  * \param allocator to allocate memory for the copy of the string returned
963  * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
964  * The OrtAllocator instance must be valid at the point of memory release.
965  */
966  AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription
967 
968  /** \brief Returns a vector of copies of the custom metadata keys.
969  *
970  * \param allocator to allocate memory for the copies of the strings returned
971  * \return a std::vector of smart pointers that deallocate the buffers when they go out of scope.
972  * The OrtAllocator instance must be valid at the point of memory release.
973  */
974  std::vector<AllocatedStringPtr> GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys
975 
976  /** \brief Looks up a value by a key in the Custom Metadata map
977  *
978  * \param key zero terminated string key to lookup
979  * \param allocator to allocate memory for the copy of the string returned
980  * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
981  * May be nullptr if the key is not found.
982  *
983  * The OrtAllocator instance must be valid at the point of memory release.
984  */
985  AllocatedStringPtr LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap
986 
987  int64_t GetVersion() const; ///< Wraps OrtApi::ModelMetadataGetVersion
988 };
989 
990 struct IoBinding;
991 
992 namespace detail {
993 
994 // We separate const-only methods because passing a const ptr to non-const methods
995 // is only discovered when inline methods are compiled, which is counter-intuitive
996 template <typename T>
997 struct ConstSessionImpl : Base<T> {
998  using B = Base<T>;
999  using B::B;
1000 
1001  size_t GetInputCount() const; ///< Returns the number of model inputs
1002  size_t GetOutputCount() const; ///< Returns the number of model outputs
1003  size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden
1004 
1005  /** \brief Returns a copy of the input name at the specified index.
1006  *
1007  * \param index must be less than the value returned by GetInputCount()
1008  * \param allocator to allocate memory for the copy of the name returned
1009  * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
1010  * The OrtAllocator instance must be valid at the point of memory release.
1011  */
1012  AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator* allocator) const;
1013 
1014  /** \brief Returns a copy of the output name at the specified index.
1015  *
1016  * \param index must be less than the value returned by GetOutputCount()
1017  * \param allocator to allocate memory for the copy of the name returned
1018  * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
1019  * The OrtAllocator instance must be valid at the point of memory release.
1020  */
1021  AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const;
1022 
1023  /** \brief Returns a copy of the overridable initializer name at the specified index.
1024  *
1025  * \param index must be less than the value returned by GetOverridableInitializerCount()
1026  * \param allocator to allocate memory for the copy of the name returned
1027  * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
1028  * The OrtAllocator instance must be valid at the point of memory release.
1029  */
1030  AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName
1031 
1032  uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs
1033  ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata
1034 
1035  TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo
1036  TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo
1037  TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo
1038 };
1039 
1040 template <typename T>
1043  using B::B;
1044 
1045  /** \brief Run the model returning results in an Ort allocated vector.
1046  *
1047  * Wraps OrtApi::Run
1048  *
1049  * The caller provides a list of inputs and a list of the desired outputs to return.
1050  *
1051  * See the output logs for more information on warnings/errors that occur while processing the model.
1052  * Common errors are.. (TODO)
1053  *
1054  * \param[in] run_options
1055  * \param[in] input_names Array of null terminated strings of length input_count that is the list of input names
1056  * \param[in] input_values Array of Value objects of length input_count that is the list of input values
1057  * \param[in] input_count Number of inputs (the size of the input_names & input_values arrays)
1058  * \param[in] output_names Array of C style strings of length output_count that is the list of output names
1059  * \param[in] output_count Number of outputs (the size of the output_names array)
1060  * \return A std::vector of Value objects that directly maps to the output_names array (eg. output_name[0] is the first entry of the returned vector)
1061  */
1062  std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1063  const char* const* output_names, size_t output_count);
1064 
1065  /** \brief Run the model returning results in user provided outputs
1066  * Same as Run(const RunOptions&, const char* const*, const Value*, size_t,const char* const*, size_t)
1067  */
1068  void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1069  const char* const* output_names, Value* output_values, size_t output_count);
1070 
1071  void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding
1072 
1073  /** \brief Run the model asynchronously in a thread owned by the intra-op thread pool
1074  *
1075  * Wraps OrtApi::RunAsync
1076  *
1077  * \param[in] run_options
1078  * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names
1079  * \param[in] input_values Array of Value objects of length input_count
1080  * \param[in] input_count Number of elements in the input_names and input_values arrays
1081  * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
1082  * \param[out] output_values Array of provided Values to be filled with outputs.
1083  * On calling RunAsync, output_values[i] can either be initialized by a null pointer or a preallocated OrtValue*.
1084  * Later, on invoking the callback, each output_values[i] that is null will be filled with an OrtValue* allocated by onnxruntime.
1085  * Then, an OrtValue** pointer will be cast from output_values and passed to the callback.
1086  * NOTE: it is the customer's duty to finally release output_values and each of its members,
1087  * regardless of whether the member (Ort::Value) is allocated by onnxruntime or preallocated by the customer.
1088  * \param[in] output_count Number of elements in the output_names and output_values arrays
1089  * \param[in] callback Callback function on model run completion
1090  * \param[in] user_data User data that is passed back to the callback
1091  */
1092  void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1093  const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data);
1094 
1095  /** \brief End profiling and return a copy of the profiling file name.
1096  *
1097  * \param allocator to allocate memory for the copy of the string returned
1098  * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
1099  * The OrtAllocator instance must be valid at the point of memory release.
1100  */
1101  AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator); ///< Wraps OrtApi::SessionEndProfiling
1102 };
1103 
1104 } // namespace detail
1105 
1108 
1109 /** \brief Wrapper around ::OrtSession
1110  *
 * GetConst() and GetUnowned() return wrappers built from this object's underlying pointer. As stated in the
 * file header, the lifetime of the owning object must exceed the lifetime of such ConstXXXX/UnownedXXXX objects.
1111  */
1112 struct Session : detail::SessionImpl<OrtSession> {
1113  explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used
1114  Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession
1115  Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
1116  OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer
1117  Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray
1118  Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options,
1119  OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer
1120 
1121  ConstSession GetConst() const { return ConstSession{this->p_}; } ///< Returns a ConstSession constructed from this object's underlying pointer
1122  UnownedSession GetUnowned() const { return UnownedSession{this->p_}; } ///< Returns an UnownedSession constructed from this object's underlying pointer
1123 };
1124 
1125 namespace detail {
// Const accessors shared by the wrappers that derive from MemoryInfoImpl<T>.
// NOTE(review): the OrtApi functions wrapped by these accessors are not visible from the declarations - confirm in the inline definitions.
1126 template <typename T>
1127 struct MemoryInfoImpl : Base<T> {
1128  using B = Base<T>;
1129  using B::B;
1130 
1131  std::string GetAllocatorName() const;
1132  OrtAllocatorType GetAllocatorType() const;
1133  int GetDeviceId() const;
1134  OrtMemoryInfoDeviceType GetDeviceType() const;
1135  OrtMemType GetMemoryType() const;
1136 
1137  template <typename U>
1138  bool operator==(const MemoryInfoImpl<U>& o) const; ///< Templated on U so the right-hand side may be a MemoryInfoImpl over a different handle type
1139 };
1140 } // namespace detail
1141 
1142 // Const object holder that does not own the underlying object
1144 
1145 /** \brief Wrapper around ::OrtMemoryInfo
1146  *
1147  */
1148 struct MemoryInfo : detail::MemoryInfoImpl<OrtMemoryInfo> {
1149  static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1);
1150  explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created
1151  explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl<OrtMemoryInfo>{p} {} ///< Take ownership of a pointer created by C Api
1152  MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type);
1153  ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; } ///< Returns a ConstMemoryInfo constructed from this object's underlying pointer
1154 };
1155 
1156 namespace detail {
1157 template <typename T>
1159  using B = Base<T>;
1160  using B::B;
1161 
1162  ONNXTensorElementDataType GetElementType() const; ///< Wraps OrtApi::GetTensorElementType
1163  size_t GetElementCount() const; ///< Wraps OrtApi::GetTensorShapeElementCount
1164 
1165  size_t GetDimensionsCount() const; ///< Wraps OrtApi::GetDimensionsCount
1166 
1167  /** \deprecated use GetShape() returning std::vector
1168  * [[deprecated]]
1169  * This interface is unsafe to use
1170  */
1171  [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const; ///< Wraps OrtApi::GetDimensions
1172 
1173  void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions
1174 
1175  std::vector<int64_t> GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape
1176 };
1177 
1178 } // namespace detail
1179 
1181 
1182 /** \brief Wrapper around ::OrtTensorTypeAndShapeInfo
1183  *
1184  */
1185 struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl<OrtTensorTypeAndShapeInfo> {
1186  explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used
1187  explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API
1189 };
1190 
1191 namespace detail {
1192 template <typename T>
1194  using B = Base<T>;
1195  using B::B;
1196  TypeInfo GetSequenceElementType() const; ///< Wraps OrtApi::GetSequenceElementType
1197 };
1198 
1199 } // namespace detail
1200 
1202 
1203 /** \brief Wrapper around ::OrtSequenceTypeInfo
1204  *
1205  */
1206 struct SequenceTypeInfo : detail::SequenceTypeInfoImpl<OrtSequenceTypeInfo> {
1207  explicit SequenceTypeInfo(std::nullptr_t) {} ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used
1208  explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl<OrtSequenceTypeInfo>{p} {} ///< Used for interop with the C API
1209  ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; } ///< Returns a ConstSequenceTypeInfo constructed from this object's underlying pointer
1210 };
1211 
1212 namespace detail {
1213 template <typename T>
1215  using B = Base<T>;
1216  using B::B;
1217  TypeInfo GetOptionalElementType() const; ///< Wraps OrtApi::CastOptionalTypeToContainedTypeInfo
1218 };
1219 
1220 } // namespace detail
1221 
1222 // This is always owned by the TypeInfo and can only be obtained from it.
1224 
1225 namespace detail {
1226 template <typename T>
1228  using B = Base<T>;
1229  using B::B;
1230  ONNXTensorElementDataType GetMapKeyType() const; ///< Wraps OrtApi::GetMapKeyType
1231  TypeInfo GetMapValueType() const; ///< Wraps OrtApi::GetMapValueType
1232 };
1233 
1234 } // namespace detail
1235 
1237 
1238 /** \brief Wrapper around ::OrtMapTypeInfo
1239  *
1240  */
1241 struct MapTypeInfo : detail::MapTypeInfoImpl<OrtMapTypeInfo> {
1242  explicit MapTypeInfo(std::nullptr_t) {} ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used
1243  explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl<OrtMapTypeInfo>{p} {} ///< Used for interop with the C API
1244  ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; } ///< Returns a ConstMapTypeInfo constructed from this object's underlying pointer
1245 };
1246 
1247 namespace detail {
1248 template <typename T>
1250  using B = Base<T>;
1251  using B::B;
1252 
1253  ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; ///< Wraps OrtApi::CastTypeInfoToTensorInfo
1254  ConstSequenceTypeInfo GetSequenceTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToSequenceTypeInfo
1255  ConstMapTypeInfo GetMapTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToMapTypeInfo
1256  ConstOptionalTypeInfo GetOptionalTypeInfo() const; ///< wraps OrtApi::CastTypeInfoToOptionalTypeInfo
1257 
1258  ONNXType GetONNXType() const;
1259 };
1260 } // namespace detail
1261 
1262 /// <summary>
1263 /// Contains a constant, unowned OrtTypeInfo that can be copied and passed around by value.
1264 /// Provides access to const OrtTypeInfo APIs.
1265 /// </summary>
1267 
1268 /// <summary>
1269 /// Type information that may contain either TensorTypeAndShapeInfo or
1270 /// the information about the contained sequence or map, depending on the ONNXType.
1271 /// </summary>
1272 struct TypeInfo : detail::TypeInfoImpl<OrtTypeInfo> {
1273  explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used
1274  explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl<OrtTypeInfo>{p} {} ///< C API Interop
1275 
1276  ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; } ///< Returns a ConstTypeInfo constructed from this object's underlying pointer
1277 };
1278 
1279 namespace detail {
1280 // This structure is used to feed sparse tensor values
1281 // information for use with FillSparseTensor<Format>() API
1282 // if the data type for the sparse tensor values is numeric
1283 // use data.p_data, otherwise, use data.str pointer to feed
1284 // values. data.str is an array of const char* that are zero terminated.
1285 // number of strings in the array must match shape size.
1286 // For fully sparse tensors use shape {0} and set p_data/str
1287 // to nullptr.
1289  const int64_t* values_shape;
1291  union {
1292  const void* p_data;
1293  const char** str;
1294  } data;
1295 };
1296 
1297 // Provides a way to pass a shape (pointer plus length) in a single
1298 // argument
1299 struct Shape {
1300  const int64_t* shape; // pointer to the dimension values; presumably shape_len entries - confirm against callers
1301  size_t shape_len; // presumably the number of entries pointed to by shape - confirm against callers
1302 };
1303 
1304 template <typename T>
1305 struct ConstValueImpl : Base<T> {
1306  using B = Base<T>;
1307  using B::B;
1308 
1309  /// <summary>
1310  /// Obtains a pointer to a user defined data for experimental purposes
1311  /// </summary>
1312  template <typename R>
1313  void GetOpaqueData(const char* domain, const char* type_name, R&) const; ///< Wraps OrtApi::GetOpaqueValue
1314 
1315  bool IsTensor() const; ///< Returns true if Value is a tensor, false for other types like map/sequence/etc
1316  bool HasValue() const; ///< Returns true if the OrtValue contains data and false if the OrtValue is a None
1317 
1318  size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements
1319  Value GetValue(int index, OrtAllocator* allocator) const;
1320 
1321  /// <summary>
1322  /// This API returns a full length of string data contained within either a tensor or a sparse Tensor.
1323  /// For sparse tensor it returns a full length of stored non-empty strings (values). The API is useful
1324  /// for allocating necessary memory and calling GetStringTensorContent().
1325  /// </summary>
1326  /// <returns>total length of UTF-8 encoded bytes contained. No zero terminators counted.</returns>
1327  size_t GetStringTensorDataLength() const;
1328 
1329  /// <summary>
1330  /// The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor
1331  /// into a supplied buffer. Use GetStringTensorDataLength() to find out the length of the buffer to allocate.
1332  /// The user must also allocate offsets buffer with the number of entries equal to that of the contained
1333  /// strings.
1334  ///
1335  /// Strings are always assumed to be on CPU, no X-device copy.
1336  /// </summary>
1337  /// <param name="buffer">user allocated buffer</param>
1338  /// <param name="buffer_length">length in bytes of the allocated buffer</param>
1339  /// <param name="offsets">a pointer to the offsets user allocated buffer</param>
1340  /// <param name="offsets_count">count of offsets, must be equal to the number of strings contained.
1341  /// that can be obtained from the shape of the tensor or from GetSparseTensorValuesTypeAndShapeInfo()
1342  /// for sparse tensors</param>
1343  void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;
1344 
1345  /// <summary>
1346  /// Returns a const typed pointer to the tensor contained data.
1347  /// No type checking is performed, the caller must ensure the type matches the tensor type.
1348  /// </summary>
1349  /// <typeparam name="R">element type of the returned pointer</typeparam>
1350  /// <returns>const pointer to data, no copies made</returns>
1351  template <typename R>
1352  const R* GetTensorData() const; ///< Wraps OrtApi::GetTensorMutableData
1353 
1354  /// <summary>
1355  /// Returns a non-typed pointer to a tensor contained data.
1356  /// </summary>
1357  /// <returns>const pointer to data, no copies made</returns>
1358  const void* GetTensorRawData() const;
1359 
1360  /// <summary>
1361  /// The API returns type information for data contained in a tensor. For sparse
1362  /// tensors it returns type information for contained non-zero values.
1363  /// It returns dense shape for sparse tensors.
1364  /// </summary>
1365  /// <returns>TypeInfo</returns>
1366  TypeInfo GetTypeInfo() const;
1367 
1368  /// <summary>
1369  /// The API returns type information for data contained in a tensor. For sparse
1370  /// tensors it returns type information for contained non-zero values.
1371  /// It returns dense shape for sparse tensors.
1372  /// </summary>
1373  /// <returns>TensorTypeAndShapeInfo</returns>
1374  TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const;
1375 
1376  /// <summary>
1377  /// This API returns information about the memory allocation used to hold data.
1378  /// </summary>
1379  /// <returns>Non owning instance of MemoryInfo</returns>
1380  ConstMemoryInfo GetTensorMemoryInfo() const;
1381 
1382  /// <summary>
1383  /// The API copies UTF-8 encoded bytes for the requested string element
1384  /// contained within a tensor or a sparse tensor into a provided buffer.
1385  /// Use GetStringTensorElementLength() to obtain the length of the buffer to allocate.
1386  /// </summary>
1387  /// <param name="buffer_length"></param>
1388  /// <param name="element_index"></param>
1389  /// <param name="buffer"></param>
1390  void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const;
1391 
1392  /// <summary>
1393  /// Returns string tensor UTF-8 encoded string element.
1394  /// Use of this API is recommended over GetStringTensorElement() that takes void* buffer pointer.
1395  /// </summary>
1396  /// <param name="element_index"></param>
1397  /// <returns>std::string</returns>
1398  std::string GetStringTensorElement(size_t element_index) const;
1399 
1400  /// <summary>
1401  /// The API returns the byte length of a UTF-8 encoded string element
1402  /// contained in either a tensor or a sparse tensor's values.
1403  /// </summary>
1404  /// <param name="element_index"></param>
1405  /// <returns>byte length for the specified string element</returns>
1406  size_t GetStringTensorElementLength(size_t element_index) const;
1407 
1408 #if !defined(DISABLE_SPARSE_TENSORS)
1409  /// <summary>
1410  /// The API returns the sparse data format this OrtValue holds in a sparse tensor.
1411  /// If the sparse tensor was not fully constructed, i.e. Use*() or Fill*() API were not used
1412  /// the value returned is ORT_SPARSE_UNDEFINED.
1413  /// </summary>
1414  /// <returns>Format enum</returns>
1415  OrtSparseFormat GetSparseFormat() const;
1416 
1417  /// <summary>
1418  /// The API returns type and shape information for stored non-zero values of the
1419  /// sparse tensor. Use GetSparseTensorValues() to obtain values buffer pointer.
1420  /// </summary>
1421  /// <returns>TensorTypeAndShapeInfo values information</returns>
1422  TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const;
1423 
1424  /// <summary>
1425  /// The API returns type and shape information for the specified indices. Each supported
1426  /// kind of indices has its own enum value, even if a given format has more than one kind of indices.
1427  /// Use GetSparseTensorIndicesData() to obtain pointer to indices buffer.
1428  /// </summary>
1429  /// <param name="format">enum requested</param>
1430  /// <returns>type and shape information</returns>
1431  TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const;
1432 
1433  /// <summary>
1434  /// The API retrieves a pointer to the internal indices buffer. The API merely performs
1435  /// a convenience data type casting on the return type pointer. Make sure you are requesting
1436  /// the right type, use GetSparseTensorIndicesTypeShapeInfo();
1437  /// </summary>
1438  /// <typeparam name="R">type to cast to</typeparam>
1439  /// <param name="indices_format">requested indices kind</param>
1440  /// <param name="num_indices">number of indices entries</param>
1441  /// <returns>Pointer to the internal sparse tensor buffer containing indices. Do not free this pointer.</returns>
1442  template <typename R>
1443  const R* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const;
1444 
1445  /// <summary>
1446  /// Returns true if the OrtValue contains a sparse tensor
1447  /// </summary>
1448  /// <returns></returns>
1449  bool IsSparseTensor() const;
1450 
1451  /// <summary>
1452  /// The API returns a pointer to an internal buffer of the sparse tensor
1453  /// containing non-zero values. The API merely does casting. Make sure you
1454  /// are requesting the right data type by calling GetSparseTensorValuesTypeAndShapeInfo()
1455  /// first.
1456  /// </summary>
1457  /// <typeparam name="R">numeric data types only. Use GetStringTensor*() to retrieve strings.</typeparam>
1458  /// <returns>a pointer to the internal values buffer. Do not free this pointer.</returns>
1459  template <typename R>
1460  const R* GetSparseTensorValues() const;
1461 
1462 #endif
1463 };
1464 
1465 template <typename T>
1468  using B::B;
1469 
1470  /// <summary>
1471  /// Returns a non-const typed pointer to an OrtValue/Tensor contained buffer
1472  /// No type checking is performed, the caller must ensure the type matches the tensor type.
1473  /// </summary>
1474  /// <returns>non-const pointer to data, no copies made</returns>
1475  template <typename R>
1476  R* GetTensorMutableData();
1477 
1478  /// <summary>
1479  /// Returns a non-typed non-const pointer to a tensor contained data.
1480  /// </summary>
1481  /// <returns>pointer to data, no copies made</returns>
1482  void* GetTensorMutableRawData();
1483 
1484  /// <summary>
1485  /// Obtain a reference to an element of data at the location specified
1486  /// by the vector of dims.
1487  /// </summary>
1488  /// <typeparam name="R">element type of the returned reference</typeparam>
1489  /// <param name="location">[in] expressed by a vector of dimension offsets</param>
1490  /// <returns></returns>
1491  template <typename R>
1492  R& At(const std::vector<int64_t>& location);
1493 
1494  /// <summary>
1495  /// Set all strings at once in a string tensor
1496  /// </summary>
1497  /// <param name="s">[in] An array of strings. Each string in this array must be null terminated.</param>
1498  /// <param name="s_len">[in] Count of strings in s (Must match the size of \p value's tensor shape)</param>
1499  void FillStringTensor(const char* const* s, size_t s_len);
1500 
1501  /// <summary>
1502  /// Set a single string in a string tensor
1503  /// </summary>
1504  /// <param name="s">[in] A null terminated UTF-8 encoded string</param>
1505  /// <param name="index">[in] Index of the string in the tensor to set</param>
1506  void FillStringTensorElement(const char* s, size_t index);
1507 
1508  /// <summary>
1509  /// Allocate if necessary and obtain a pointer to a UTF-8
1510  /// encoded string element buffer indexed by the flat element index,
1511  /// of the specified length.
1512  ///
1513  /// This API is for advanced usage. It avoids a need to construct
1514  /// an auxiliary array of string pointers, and allows to write data directly
1515  /// (do not zero terminate).
1516  /// </summary>
1517  /// <param name="index"></param>
1518  /// <param name="buffer_length"></param>
1519  /// <returns>a pointer to a writable buffer</returns>
1520  char* GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length);
1521 
1522 #if !defined(DISABLE_SPARSE_TENSORS)
1523  /// <summary>
1524  /// Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tensor.
1525  /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
1526  /// allocated buffers lifespan must eclipse that of the OrtValue.
1527  /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
1528  /// </summary>
1529  /// <param name="indices_data">pointer to the user allocated buffer with indices. Use nullptr for fully sparse tensors.</param>
1530  /// <param name="indices_num">number of indices entries. Use 0 for fully sparse tensors</param>
1531  void UseCooIndices(int64_t* indices_data, size_t indices_num);
1532 
1533  /// <summary>
1534  /// Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tensor.
1535  /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
1536  /// allocated buffers lifespan must eclipse that of the OrtValue.
1537  /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
1538  /// </summary>
1539  /// <param name="inner_data">pointer to the user allocated buffer with inner indices or nullptr for fully sparse tensors</param>
1540  /// <param name="inner_num">number of csr inner indices or 0 for fully sparse tensors</param>
1541  /// <param name="outer_data">pointer to the user allocated buffer with outer indices or nullptr for fully sparse tensors</param>
1542  /// <param name="outer_num">number of csr outer indices or 0 for fully sparse tensors</param>
1543  void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num);
1544 
1545  /// <summary>
1546  /// Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSparse format tensor.
1547  /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
1548  /// allocated buffers lifespan must eclipse that of the OrtValue.
1549  /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
1550  /// </summary>
1551  /// <param name="indices_shape">indices shape or a {0} for fully sparse</param>
1552  /// <param name="indices_data">user allocated buffer with indices or nullptr for fully sparse tensors</param>
1553  void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data);
1554 
1555  /// <summary>
1556  /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
1557  /// and copy the values and COO indices into it. If data_mem_info specifies that the data is located
1558  /// on a different device than the allocator, a cross-device copy will be performed if possible.
1559  /// </summary>
1560  /// <param name="data_mem_info">specified buffer memory description</param>
1561  /// <param name="values_param">values buffer information.</param>
1562  /// <param name="indices_data">coo indices buffer or nullptr for fully sparse data</param>
1563  /// <param name="indices_num">number of COO indices or 0 for fully sparse data</param>
1564  void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param,
1565  const int64_t* indices_data, size_t indices_num);
1566 
1567  /// <summary>
1568  /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
1569  /// and copy the values and CSR indices into it. If data_mem_info specifies that the data is located
1570  /// on a different device than the allocator, a cross-device copy will be performed if possible.
1571  /// </summary>
1572  /// <param name="data_mem_info">specified buffer memory description</param>
1573  /// <param name="values">values buffer information</param>
1574  /// <param name="inner_indices_data">csr inner indices pointer or nullptr for fully sparse tensors</param>
1575  /// <param name="inner_indices_num">number of csr inner indices or 0 for fully sparse tensors</param>
1576  /// <param name="outer_indices_data">pointer to csr indices data or nullptr for fully sparse tensors</param>
1577  /// <param name="outer_indices_num">number of csr outer indices or 0</param>
1578  void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
1580  const int64_t* inner_indices_data, size_t inner_indices_num,
1581  const int64_t* outer_indices_data, size_t outer_indices_num);
1582 
1583  /// <summary>
1584  /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
1585  /// and copy the values and BlockSparse indices into it. If data_mem_info specifies that the data is located
1586  /// on a different device than the allocator, a cross-device copy will be performed if possible.
1587  /// </summary>
1588  /// <param name="data_mem_info">specified buffer memory description</param>
1589  /// <param name="values">values buffer information</param>
1590  /// <param name="indices_shape">indices shape. use {0} for fully sparse tensors</param>
1591  /// <param name="indices_data">pointer to indices data or nullptr for fully sparse tensors</param>
1592  void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
1594  const Shape& indices_shape,
1595  const int32_t* indices_data);
1596 
1597 #endif
1598 };
1599 
1600 } // namespace detail
1601 
1604 
1605 /** \brief Wrapper around ::OrtValue
1606  *
1607  */
1608 struct Value : detail::ValueImpl<OrtValue> {
1612 
1613  explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used
1614  explicit Value(OrtValue* p) : Base{p} {} ///< Used for interop with the C API
1615  Value(Value&&) = default;
1616  Value& operator=(Value&&) = default;
1617 
1618  ConstValue GetConst() const { return ConstValue{this->p_}; }
1619  UnownedValue GetUnowned() const { return UnownedValue{this->p_}; }
1620 
1621  /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
1622  * \tparam T The numeric datatype. This API is not suitable for strings.
1623  * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc).
1624  * \param p_data Pointer to the data buffer.
1625  * \param p_data_element_count The number of elements in the data buffer.
1626  * \param shape Pointer to the tensor shape dimensions.
1627  * \param shape_len The number of tensor shape dimensions.
1628  */
1629  template <typename T>
1630  static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len);
1631 
1632  /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
1633  *
1634  * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc).
1635  * \param p_data Pointer to the data buffer.
1636  * \param p_data_byte_count The number of bytes in the data buffer.
1637  * \param shape Pointer to the tensor shape dimensions.
1638  * \param shape_len The number of tensor shape dimensions.
1639  * \param type The data type.
1640  */
1641  static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
1642  ONNXTensorElementDataType type);
1643 
1644  /** \brief Creates an OrtValue with a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue.
1645  * This overload will allocate the buffer for the tensor according to the supplied shape and data type.
1646  * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released.
1647  * The input data would need to be copied into the allocated buffer.
1648  * This API is not suitable for strings.
1649  *
1650  * \tparam T The numeric datatype. This API is not suitable for strings.
1651  * \param allocator The allocator to use.
1652  * \param shape Pointer to the tensor shape dimensions.
1653  * \param shape_len The number of tensor shape dimensions.
1654  */
1655  template <typename T>
1656  static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len);
1657 
1658  /** \brief Creates an OrtValue with a tensor using the supplied OrtAllocator.
1659  * Wraps OrtApi::CreateTensorAsOrtValue.
1660  * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released.
1661  * The input data would need to be copied into the allocated buffer.
1662  * This API is not suitable for strings.
1663  *
1664  * \param allocator The allocator to use.
1665  * \param shape Pointer to the tensor shape dimensions.
1666  * \param shape_len The number of tensor shape dimensions.
1667  * \param type The data type.
1668  */
1669  static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type);
1670 
1671  /** \brief Creates an OrtValue with a Map Onnx type representation.
1672  * The API would ref-count the supplied OrtValues and they will be released
1673  * when the returned OrtValue is released. The caller may release keys and values after the call
1674  * returns.
1675  *
1676  * \param keys an OrtValue containing a tensor with primitive data type keys.
1677  * \param values an OrtValue that may contain a tensor. Ort currently supports only primitive data type values.
1678  */
1679  static Value CreateMap(const Value& keys, const Value& values); ///< Wraps OrtApi::CreateValue
1680 
1681  /** \brief Creates an OrtValue with a Sequence Onnx type representation.
1682  * The API would ref-count the supplied OrtValues and they will be released
1683  * when the returned OrtValue is released. The caller may release the values after the call
1684  * returns.
1685  *
1686  * \param values a vector of OrtValues that must have the same Onnx value type.
1687  */
1688  static Value CreateSequence(const std::vector<Value>& values); ///< Wraps OrtApi::CreateValue
1689 
1690  /** \brief Creates an OrtValue wrapping an Opaque type.
1691  * This is used for experimental support of non-tensor types.
1692  *
1693  * \tparam T - the type of the value.
1694  * \param domain - zero terminated utf-8 string. Domain of the type.
1695  * \param type_name - zero terminated utf-8 string. Name of the type.
1696  * \param value - the value to be wrapped.
1697  */
1698  template <typename T>
1699  static Value CreateOpaque(const char* domain, const char* type_name, const T& value); ///< Wraps OrtApi::CreateOpaqueValue
1700 
1701 #if !defined(DISABLE_SPARSE_TENSORS)
1702  /// <summary>
1703  /// This is a simple forwarding method to the other overload that helps deducing
1704  /// data type enum value from the type of the buffer.
1705  /// </summary>
1706  /// <typeparam name="T">numeric datatype. This API is not suitable for strings.</typeparam>
1707  /// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
1708  /// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
1709  /// <param name="dense_shape">a would be dense shape of the tensor</param>
1710  /// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
1711  /// <returns></returns>
1712  template <typename T>
1713  static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
1714  const Shape& values_shape);
1715 
1716  /// <summary>
1717  /// Creates an OrtValue instance containing SparseTensor. This constructs
1718  /// a sparse tensor that makes use of user allocated buffers. It does not make copies
1719  /// of the user provided data and does not modify it. The lifespan of user provided buffers should
1720  /// eclipse the life span of the resulting OrtValue. This call constructs an instance that only contains
1721  /// a pointer to non-zero values. To fully populate the sparse tensor call Use<Format>Indices() API below
1722  /// to supply a sparse format specific indices.
1723  /// This API is not suitable for string data. Use CreateSparseTensor() with allocator specified so strings
1724  /// can be properly copied into the allocated buffer.
1725  /// </summary>
1726  /// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
1727  /// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
1728  /// <param name="dense_shape">a would be dense shape of the tensor</param>
1729  /// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
1730  /// <param name="type">data type</param>
1731  /// <returns>Ort::Value instance containing SparseTensor</returns>
1732  static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
1733  const Shape& values_shape, ONNXTensorElementDataType type);
1734 
1735  /// <summary>
1736  /// This is a simple forwarding method to the below CreateSparseTensor.
1737  /// This helps to specify data type enum in terms of C++ data type.
1738  /// Use CreateSparseTensor<T>
1739  /// </summary>
1740  /// <typeparam name="T">numeric data type only. String data enum must be specified explicitly.</typeparam>
1741  /// <param name="allocator">allocator to use</param>
1742  /// <param name="dense_shape">a would be dense shape of the tensor</param>
1743  /// <returns>Ort::Value</returns>
1744  template <typename T>
1745  static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape);
1746 
1747  /// <summary>
1748  /// Creates an instance of OrtValue containing sparse tensor. The created instance has no data.
1749  /// The data must be supplied by one of the FillSparseTensor<Format>() methods that take both non-zero values
1750  /// and indices. The data will be copied into a buffer that would be allocated using the supplied allocator.
1751  /// Use this API to create OrtValues that contain sparse tensors with all supported data types including
1752  /// strings.
1753  /// </summary>
1754  /// <param name="allocator">allocator to use. The allocator lifespan must eclipse that of the resulting OrtValue</param>
1755  /// <param name="dense_shape">a would be dense shape of the tensor</param>
1756  /// <param name="type">data type</param>
1757  /// <returns>an instance of Ort::Value</returns>
1758  static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type);
1759 
1760 #endif // !defined(DISABLE_SPARSE_TENSORS)
1761 };
1762 
1763 /// <summary>
1764 /// Represents native memory allocation coming from one of the
1765 /// OrtAllocators registered with OnnxRuntime.
1766 /// Use it to wrap an allocation made by an allocator
1767 /// so it can be automatically released when no longer needed.
1768 /// </summary>
1770  MemoryAllocation(OrtAllocator* allocator, void* p, size_t size);
1771  ~MemoryAllocation();
1772  MemoryAllocation(const MemoryAllocation&) = delete;
1773  MemoryAllocation& operator=(const MemoryAllocation&) = delete;
1774  MemoryAllocation(MemoryAllocation&&) noexcept;
1776 
1777  void* get() { return p_; }
1778  size_t size() const { return size_; }
1779 
1780  private:
1781  OrtAllocator* allocator_;
1782  void* p_;
1783  size_t size_;
1784 };
1785 
1786 namespace detail {
1787 template <typename T> // T is OrtAllocator for Allocator and Unowned<OrtAllocator> for AllocatorWithDefaultOptions
1788 struct AllocatorImpl : Base<T> { // Member functions shared by the allocator wrapper types
1789  using B = Base<T>;
1790  using B::B; // Inherit the constructors of Base<T>
1791 
1792  void* Alloc(size_t size); ///< Returns a raw pointer for a request of \p size (presumably bytes obtained from the wrapped allocator -- defined outside this section, confirm)
1793  MemoryAllocation GetAllocation(size_t size); ///< Like Alloc(), but the result is returned as a MemoryAllocation object instead of a raw pointer
1794  void Free(void* p); ///< Counterpart of Alloc() (presumably \p p must come from this allocator -- confirm against the definition)
1795  ConstMemoryInfo GetInfo() const; ///< Returns a ConstMemoryInfo (non-owning, const-only wrapper type) for this allocator
1796 };
1797 
1798 } // namespace detail
1799 
1800 /** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime
1801  *
1802  */
1803 struct AllocatorWithDefaultOptions : detail::AllocatorImpl<detail::Unowned<OrtAllocator>> {
1804  explicit AllocatorWithDefaultOptions(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance
1806 };
1807 
1808 /** \brief Wrapper around ::OrtAllocator
1809  * Derives from detail::AllocatorImpl<OrtAllocator>, so it exposes Alloc(), GetAllocation(), Free() and GetInfo().
1810  */
1811 struct Allocator : detail::AllocatorImpl<OrtAllocator> {
1812  explicit Allocator(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance
1813  Allocator(const Session& session, const OrtMemoryInfo*); ///< Creates an allocator from a session and a memory description (presumably wraps OrtApi::CreateAllocator -- defined outside this section, confirm)
1814 };
1815 
1817 
1818 namespace detail {
1819 namespace binding_utils {
1820 // Non-template helpers kept out of the IoBinding impl templates; their parameters mirror the GetOutputNames(OrtAllocator*) / GetOutputValues(OrtAllocator*) members declared below
1821 std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator*);
1822 std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator*);
1823 } // namespace binding_utils
1824 
1825 template <typename T>
1827  using B = Base<T>;
1828  using B::B;
1829 
1830  std::vector<std::string> GetOutputNames() const;
1831  std::vector<std::string> GetOutputNames(OrtAllocator*) const;
1832  std::vector<Value> GetOutputValues() const;
1833  std::vector<Value> GetOutputValues(OrtAllocator*) const;
1834 };
1835 
1836 template <typename T>
1839  using B::B;
1840 
1841  void BindInput(const char* name, const Value&);
1842  void BindOutput(const char* name, const Value&);
1843  void BindOutput(const char* name, const OrtMemoryInfo*);
1844  void ClearBoundInputs();
1845  void ClearBoundOutputs();
1846  void SynchronizeInputs();
1847  void SynchronizeOutputs();
1848 };
1849 
1850 } // namespace detail
1851 
1854 
1855 /** \brief Wrapper around ::OrtIoBinding
1856  * Use GetConst()/GetUnowned() to obtain non-owning views that can be passed around by value.
1857  */
1858 struct IoBinding : detail::IoBindingImpl<OrtIoBinding> {
1859  explicit IoBinding(std::nullptr_t) {} ///< Create an empty object for convenience. Sometimes, we want to initialize members later.
1860  explicit IoBinding(Session& session); ///< Create a binding for \p session (presumably wraps OrtApi::CreateIoBinding -- defined outside this section, confirm)
1861  ConstIoBinding GetConst() const { return ConstIoBinding{this->p_}; } ///< Non-owning, const-only view over the same underlying pointer
1862  UnownedIoBinding GetUnowned() const { return UnownedIoBinding{this->p_}; } ///< Non-owning view over the same underlying pointer that also allows the non-const interface
1863 };
1864 
1865 /*! \struct Ort::ArenaCfg
1866  * \brief Represents the configuration of an arena-based allocator
1867  * \details Please see docs/C_API.md for details
1868  */
1869 struct ArenaCfg : detail::Base<OrtArenaCfg> {
1870  explicit ArenaCfg(std::nullptr_t) {} ///< Create an empty ArenaCfg object, must be assigned a valid one to be used
1871  /**
1872  * Wraps OrtApi::CreateArenaCfg
1873  * \param max_mem - use 0 to allow ORT to choose the default
1874  * \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
1875  * \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default
1876  * \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default
1877  * See docs/C_API.md for details on what the parameters above mean and how to choose these values
1878  */
1879  ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk);
1880 };
1881 
1882 //
1883 // Custom OPs (only needed to implement custom OPs)
1884 //
1885 
1886 /// <summary>
1887 /// This struct provides lifetime management for a custom op attribute (::OrtOpAttr)
1888 /// </summary>
1889 struct OpAttr : detail::Base<OrtOpAttr> {
1890  OpAttr(const char* name, const void* data, int len, OrtOpAttrType type); ///< Builds an attribute from a name, a data buffer with its length and a type tag (presumably wraps OrtApi::CreateOpAttr; the unit of \p len per type is not shown here -- confirm)
1891 };
1892 
1893 /**
1894  * Macro that logs a message using the provided logger. Throws an exception if OrtApi::Logger_LogMessage fails.
1895  * Example: ORT_CXX_LOG(logger, ORT_LOGGING_LEVEL_INFO, "Log a message");
1896  * \note \p logger and \p message_severity are each expanded twice; avoid passing expressions with side effects.
1897  * \param logger The Ort::Logger instance to use. Must be a value or reference.
1898  * \param message_severity The logging severity level of the message.
1899  * \param message A null-terminated UTF-8 message to log.
1900  */
1901 #define ORT_CXX_LOG(logger, message_severity, message) \
1902  do { \
1903  if ((message_severity) >= (logger).GetLoggingSeverityLevel()) { \
1904  Ort::ThrowOnError((logger).LogMessage(message_severity, ORT_FILE, __LINE__, \
1905  static_cast<const char*>(__FUNCTION__), message)); \
1906  } \
1907  } while (false)
1908 
1909 /**
1910  * Macro that logs a message using the provided logger. Can be used in noexcept code since errors are silently ignored.
1911  * Example: ORT_CXX_LOG_NOEXCEPT(logger, ORT_LOGGING_LEVEL_INFO, "Log a message");
1912  * \note \p logger and \p message_severity are each expanded twice; avoid passing expressions with side effects.
1913  * \param logger The Ort::Logger instance to use. Must be a value or reference.
1914  * \param message_severity The logging severity level of the message.
1915  * \param message A null-terminated UTF-8 message to log.
1916  */
1917 #define ORT_CXX_LOG_NOEXCEPT(logger, message_severity, message) \
1918  do { \
1919  if ((message_severity) >= (logger).GetLoggingSeverityLevel()) { \
1920  static_cast<void>((logger).LogMessage(message_severity, ORT_FILE, __LINE__, \
1921  static_cast<const char*>(__FUNCTION__), message)); \
1922  } \
1923  } while (false)
1924 
1925 /**
1926  * Macro that logs a printf-like formatted message using the provided logger. Throws an exception if
1927  * OrtApi::Logger_LogMessage fails or if a formatting error occurs.
1928  * Example: ORT_CXX_LOGF(logger, ORT_LOGGING_LEVEL_INFO, "Log an int: %d", 12);
1929  * \note \p logger and \p message_severity are each expanded twice; avoid passing expressions with side effects.
1930  * \param logger The Ort::Logger instance to use. Must be a value or reference.
1931  * \param message_severity The logging severity level of the message.
1932  * \param format A null-terminated UTF-8 format string forwarded to a printf-like function.
1933  * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats.
1934  * \param ... Zero or more variadic arguments referenced by the format string.
1935  */
1936 #define ORT_CXX_LOGF(logger, message_severity, /*format,*/...) \
1937  do { \
1938  if ((message_severity) >= (logger).GetLoggingSeverityLevel()) { \
1939  Ort::ThrowOnError((logger).LogFormattedMessage(message_severity, ORT_FILE, __LINE__, \
1940  static_cast<const char*>(__FUNCTION__), __VA_ARGS__)); \
1941  } \
1942  } while (false)
1943 
1944 /**
1945  * Macro that logs a printf-like formatted message using the provided logger. Can be used in noexcept code since errors
1946  * are silently ignored.
1947  * Example: ORT_CXX_LOGF_NOEXCEPT(logger, ORT_LOGGING_LEVEL_INFO, "Log an int: %d", 12);
1948  * \note \p logger and \p message_severity are each expanded twice; avoid passing expressions with side effects.
1949  * \param logger The Ort::Logger instance to use. Must be a value or reference.
1950  * \param message_severity The logging severity level of the message.
1951  * \param format A null-terminated UTF-8 format string forwarded to a printf-like function.
1952  * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats.
1953  * \param ... Zero or more variadic arguments referenced by the format string.
1954  */
1955 #define ORT_CXX_LOGF_NOEXCEPT(logger, message_severity, /*format,*/...) \
1956  do { \
1957  if ((message_severity) >= (logger).GetLoggingSeverityLevel()) { \
1958  static_cast<void>((logger).LogFormattedMessage(message_severity, ORT_FILE, __LINE__, \
1959  static_cast<const char*>(__FUNCTION__), __VA_ARGS__)); \
1960  } \
1961  } while (false)
1962 
1963 /// <summary>
1964 /// This class represents an ONNX Runtime logger that can be used to log information with an
1965 /// associated severity level and source code location (file path, line number, function name).
1966 ///
1967 /// A Logger can be obtained from within custom operators by calling Ort::KernelInfo::GetLogger() or Ort::KernelContext::GetLogger().
1968 /// Instances of Ort::Logger are the size of two pointers and can be passed by value.
1969 ///
1970 /// Use the ORT_CXX_LOG macros to ensure the source code location is set properly from the callsite
1971 /// and to take advantage of a cached logging severity level that can bypass calls to the underlying C API.
1972 /// </summary>
1973 struct Logger {
1974  /**
1975  * Creates an empty Ort::Logger. Must be initialized from a valid Ort::Logger before use.
1976  */
1977  Logger() = default;
1978 
1979  /**
1980  * Creates an empty Ort::Logger. Must be initialized from a valid Ort::Logger before use.
1981  */
1982  explicit Logger(std::nullptr_t) {}
1983 
1984  /**
1985  * Creates a logger from an ::OrtLogger instance. Caches the logger's current severity level by calling
1986  * OrtApi::Logger_GetLoggingSeverityLevel. Throws an exception if OrtApi::Logger_GetLoggingSeverityLevel fails.
1987  *
1988  * \param logger The ::OrtLogger to wrap.
1989  */
1990  explicit Logger(const OrtLogger* logger);
1991 
1992  ~Logger() = default; ///< Defaulted: the wrapped ::OrtLogger pointer is not released by this class.
1993 
1994  Logger(const Logger&) = default; ///< Copies the wrapped pointer and the cached severity level.
1995  Logger& operator=(const Logger&) = default;
1996 
1997  Logger(Logger&& v) noexcept = default;
1998  Logger& operator=(Logger&& v) noexcept = default;
1999 
2000  /**
2001  * Returns the logger's current severity level from the cached member.
2002  *
2003  * \return The current ::OrtLoggingLevel.
2004  */
2005  OrtLoggingLevel GetLoggingSeverityLevel() const noexcept;
2006 
2007  /**
2008  * Logs the provided message via OrtApi::Logger_LogMessage. Use the ORT_CXX_LOG or ORT_CXX_LOG_NOEXCEPT
2009  * macros to properly set the source code location and to use the cached severity level to potentially bypass
2010  * calls to the underlying C API.
2011  *
2012  * \param log_severity_level The message's logging severity level.
2013  * \param file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE.
2014  * \param line_number The file line number in which the message is logged. Usually the value of __LINE__.
2015  * \param func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__.
2016  * \param message The message to log.
2017  * \return A Ort::Status value to indicate error or success.
2018  */
2019  Status LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
2020  const char* func_name, const char* message) const noexcept;
2021 
2022  /**
2023  * Logs a printf-like formatted message via OrtApi::Logger_LogMessage. Use the ORT_CXX_LOGF or ORT_CXX_LOGF_NOEXCEPT
2024  * macros to properly set the source code location and to use the cached severity level to potentially bypass
2025  * calls to the underlying C API. Returns an error status if a formatting error occurs.
2026  *
2027  * \param log_severity_level The message's logging severity level.
2028  * \param file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE.
2029  * \param line_number The file line number in which the message is logged. Usually the value of __LINE__.
2030  * \param func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__.
2031  * \param format A null-terminated UTF-8 format string forwarded to a printf-like function.
2032  * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats.
2033  * \param args Zero or more variadic arguments referenced by the format string.
2034  * \return A Ort::Status value to indicate error or success.
2035  */
2036  template <typename... Args>
2037  Status LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
2038  const char* func_name, const char* format, Args&&... args) const noexcept;
2039 
2040  private:
2041  const OrtLogger* logger_{}; ///< Wrapped ::OrtLogger; value-initialized to null for an empty Logger.
2042  OrtLoggingLevel cached_severity_level_{}; ///< Severity level cached at construction; value-initialized for an empty Logger.
2043 };
2044 
2045 /// <summary>
2046 /// This class wraps a raw pointer OrtKernelContext* that is being passed
2047 /// to the custom kernel Compute() method. Use it to safely access context
2048 /// attributes, input and output parameters with exception safety guarantees.
2049 /// See usage example in onnxruntime/test/testdata/custom_op_library/custom_op_library.cc
2050 /// </summary>
2052  explicit KernelContext(OrtKernelContext* context);
2053  size_t GetInputCount() const;
2054  size_t GetOutputCount() const;
2055  ConstValue GetInput(size_t index) const;
2056  UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
2057  UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
2058  void* GetGPUComputeStream() const;
2059  Logger GetLogger() const;
2060  OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const;
2061  OrtKernelContext* GetOrtKernelContext() const { return ctx_; }
2062  void ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const;
2063 
2064  private:
2065  OrtKernelContext* ctx_;
2066 };
2067 
2068 struct KernelInfo;
2069 
2070 namespace detail {
2071 namespace attr_utils {
2072 void GetAttr(const OrtKernelInfo* p, const char* name, float&);
2073 void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&);
2074 void GetAttr(const OrtKernelInfo* p, const char* name, std::string&);
2075 void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>&);
2076 void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>&);
2077 } // namespace attr_utils
2078 
2079 template <typename T>
2080 struct KernelInfoImpl : Base<T> {
2081  using B = Base<T>;
2082  using B::B;
2083 
2084  KernelInfo Copy() const;
2085 
2086  template <typename R> // R is only implemented for float, int64_t, and std::string (the attr_utils::GetAttr overloads)
2087  R GetAttribute(const char* name) const {
2088  R val;
2089  attr_utils::GetAttr(this->p_, name, val);
2090  return val;
2091  }
2092 
2093  template <typename R> // R is the element type; only implemented for float and int64_t (the attr_utils::GetAttrs overloads)
2094  std::vector<R> GetAttributes(const char* name) const {
2095  std::vector<R> result;
2096  attr_utils::GetAttrs(this->p_, name, result);
2097  return result;
2098  }
2099 
2100  Value GetTensorAttribute(const char* name, OrtAllocator* allocator) const;
2101 
2102  size_t GetInputCount() const;
2103  size_t GetOutputCount() const;
2104 
2105  std::string GetInputName(size_t index) const;
2106  std::string GetOutputName(size_t index) const;
2107 
2108  TypeInfo GetInputTypeInfo(size_t index) const;
2109  TypeInfo GetOutputTypeInfo(size_t index) const;
2110 
2111  ConstValue GetTensorConstantInput(size_t index, int* is_constant) const;
2112 
2113  std::string GetNodeName() const;
2114  Logger GetLogger() const;
2115 };
2116 
2117 } // namespace detail
2118 
2120 
2121 /// <summary>
2122 /// This struct owns the OrtKernelInfo* pointer when a copy is made.
2123 /// For convenient wrapping of the OrtKernelInfo* passed to a kernel constructor
2124 /// and for querying attributes, wrap the pointer in an Ort::Unowned<KernelInfo> instance
2125 /// so that it does not destroy a pointer the kernel does not own.
2126 /// </summary>
2127 struct KernelInfo : detail::KernelInfoImpl<OrtKernelInfo> {
2128  explicit KernelInfo(std::nullptr_t) {} ///< Create an empty instance to initialize later
2129  explicit KernelInfo(OrtKernelInfo* info); ///< Take ownership of the instance
2130  ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; } ///< Wrap the same underlying pointer in a ConstKernelInfo
2131 };
2132 
2133 /// <summary>
2134 /// Create and own custom defined operation.
2135 /// </summary>
2136 struct Op : detail::Base<OrtOp> {
2137  explicit Op(std::nullptr_t) {} ///< Create an empty Operator object, must be assigned a valid one to be used
2138 
2139  explicit Op(OrtOp*); ///< Take ownership of the OrtOp
2140 
2141  static Op Create(const OrtKernelInfo* info, const char* op_name, const char* domain,
2142  int version, const char** type_constraint_names,
2143  const ONNXTensorElementDataType* type_constraint_values,
2144  size_t type_constraint_count,
2145  const OpAttr* attr_values,
2146  size_t attr_count,
2147  size_t input_count, size_t output_count);
2148 
2149  void Invoke(const OrtKernelContext* context,
2150  const Value* input_values,
2151  size_t input_count,
2152  Value* output_values,
2153  size_t output_count);
2154 
2155  // For easier refactoring: overload that takes raw OrtValue pointers instead of Ort::Value objects
2156  void Invoke(const OrtKernelContext* context,
2157  const OrtValue* const* input_values,
2158  size_t input_count,
2159  OrtValue* const* output_values,
2160  size_t output_count);
2161 };
2162 
2163 /// <summary>
2164 /// Provide access to per-node attributes and input shapes, so one could compute and set output shapes.
2165 /// </summary>
2168  SymbolicInteger(int64_t i) : i_(i), is_int_(true){};
2169  SymbolicInteger(const char* s) : s_(s), is_int_(false){};
2170  SymbolicInteger(const SymbolicInteger&) = default;
2171  SymbolicInteger(SymbolicInteger&&) = default;
2172 
2173  SymbolicInteger& operator=(const SymbolicInteger&) = default;
2175 
2176  bool operator==(const SymbolicInteger& dim) const {
2177  if (is_int_ == dim.is_int_) {
2178  if (is_int_) {
2179  return i_ == dim.i_;
2180  } else {
2181  return std::string{s_} == std::string{dim.s_};
2182  }
2183  }
2184  return false;
2185  }
2186 
2187  bool IsInt() const { return is_int_; }
2188  int64_t AsInt() const { return i_; }
2189  const char* AsSym() const { return s_; }
2190 
2191  static constexpr int INVALID_INT_DIM = -2;
2192 
2193  private:
2194  union {
2195  int64_t i_;
2196  const char* s_;
2197  };
2198  bool is_int_;
2199  };
2200 
2201  using Shape = std::vector<SymbolicInteger>;
2202 
2203  ShapeInferContext(const OrtApi* ort_api, OrtShapeInferContext* ctx);
2204 
2205  const Shape& GetInputShape(size_t indice) const { return input_shapes_.at(indice); }
2206 
2207  size_t GetInputCount() const { return input_shapes_.size(); }
2208 
2209  Status SetOutputShape(size_t indice, const Shape& shape);
2210 
2211  int64_t GetAttrInt(const char* attr_name);
2212 
2213  using Ints = std::vector<int64_t>;
2214  Ints GetAttrInts(const char* attr_name);
2215 
2216  float GetAttrFloat(const char* attr_name);
2217 
2218  using Floats = std::vector<float>;
2219  Floats GetAttrFloats(const char* attr_name);
2220 
2221  std::string GetAttrString(const char* attr_name);
2222 
2223  using Strings = std::vector<std::string>;
2224  Strings GetAttrStrings(const char* attr_name);
2225 
2226  private:
2227  const OrtOpAttr* GetAttrHdl(const char* attr_name) const;
2228  const OrtApi* ort_api_;
2229  OrtShapeInferContext* ctx_;
2230  std::vector<Shape> input_shapes_;
2231 };
2232 
2234 
2235 #define MAX_CUSTOM_OP_END_VER ((1UL << 31) - 1)  // largest custom op end version; fully parenthesized so it expands as a single operand
2236 
2237 template <typename TOp, typename TKernel, bool WithStatus = false>
2238 struct CustomOpBase : OrtCustomOp {
2240  OrtCustomOp::version = ORT_API_VERSION;
2241  OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
2242 
2243  OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
2244 
2245  OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
2246  OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
2247  OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputMemoryType(index); };
2248 
2249  OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
2250  OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
2251 
2252 #if defined(_MSC_VER) && !defined(__clang__)
2253 #pragma warning(push)
2254 #pragma warning(disable : 26409)
2255 #endif
2256  OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
2257 #if defined(_MSC_VER) && !defined(__clang__)
2258 #pragma warning(pop)
2259 #endif
2260  OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
2261  OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
2262 
2263  OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicInputMinArity(); };
2264  OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicInputHomogeneity()); };
2265  OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicOutputMinArity(); };
2266  OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicOutputHomogeneity()); };
2267 #ifdef __cpp_if_constexpr
2268  if constexpr (WithStatus) {
2269 #else
2270  if (WithStatus) {
2271 #endif
2272  OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
2273  return static_cast<const TOp*>(this_)->CreateKernelV2(*api, info, op_kernel);
2274  };
2275  OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
2276  return static_cast<TKernel*>(op_kernel)->ComputeV2(context);
2277  };
2278  } else {
2279  OrtCustomOp::CreateKernelV2 = nullptr;
2280  OrtCustomOp::KernelComputeV2 = nullptr;
2281 
2282  OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
2283  OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
2284  static_cast<TKernel*>(op_kernel)->Compute(context);
2285  };
2286  }
2287 
2288  SetShapeInferFn<TOp>(0);
2289 
2290  OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) {
2291  return static_cast<const TOp*>(this_)->start_ver_;
2292  };
2293 
2294  OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) {
2295  return static_cast<const TOp*>(this_)->end_ver_;
2296  };
2297  }
2298 
2299  // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
2300  const char* GetExecutionProviderType() const { return nullptr; }
2301 
2302  // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
2303  // (inputs and outputs are required by default)
2304  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
2305  return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
2306  }
2307 
2308  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
2309  return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
2310  }
2311 
2312  // Default implementation of GetInputMemoryType() that returns OrtMemTypeDefault
2313  OrtMemType GetInputMemoryType(size_t /*index*/) const {
2314  return OrtMemTypeDefault;
2315  }
2316 
2317  // Default implementation of GetVariadicInputMinArity() returns 1 to specify that a variadic input
2318  // should expect at least 1 argument.
2320  return 1;
2321  }
2322 
2323  // Default implementation of GetVariadicInputHomogeneity() returns true to specify that all arguments
2324  // to a variadic input should be of the same type.
2326  return true;
2327  }
2328 
2329  // Default implementation of GetVariadicOutputMinArity() returns 1 to specify that a variadic output
2330  // should produce at least 1 output value.
2332  return 1;
2333  }
2334 
2335  // Default implementation of GetVariadicOutputHomogeneity() returns true to specify that all output values
2336  // produced by a variadic output should be of the same type.
2338  return true;
2339  }
2340 
2341  // Declare list of session config entries used by this Custom Op.
2342  // Implement this function in order to get configs from CustomOpBase::GetSessionConfigs().
2343  // This default implementation returns an empty vector of config entries.
2344  std::vector<std::string> GetSessionConfigKeys() const {
2345  return std::vector<std::string>{};
2346  }
2347 
2348  template <typename C>
2349  decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) {
2350  OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
2351  ShapeInferContext ctx(&GetApi(), ort_ctx);
2352  return C::InferOutputShape(ctx);
2353  };
2354  return {};
2355  }
2356 
2357  template <typename C>
2358  void SetShapeInferFn(...) {
2359  OrtCustomOp::InferOutputShapeFn = {};
2360  }
2361 
2362  protected:
2363  // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys.
2364  void GetSessionConfigs(std::unordered_map<std::string, std::string>& out, ConstSessionOptions options) const;
2365 
2366  int start_ver_ = 1;
2367  int end_ver_ = MAX_CUSTOM_OP_END_VER;
2368 };
2369 
2370 } // namespace Ort
2371 
2372 #include "onnxruntime_cxx_inline.h"
constexpr Float8E4M3FNUZ_t() noexcept
std::vector< int64_t > Ints
UnownedSession GetUnowned() const
std::string GetBuildInfoString()
This function returns the onnxruntime build information: including git branch, git commit id...
GLuint GLsizei const GLchar * message
Definition: glcorearb.h:2543
AllocatorWithDefaultOptions(std::nullptr_t)
Convenience to create a class member and then replace with an instance.
Logger(std::nullptr_t)
This is a tagging template type. Use it with Base<T> to indicate that the C++ interface object has no...
SequenceTypeInfo(OrtSequenceTypeInfo *p)
TypeInfo(std::nullptr_t)
Create an empty TypeInfo object, must be assigned a valid one to be used.
constexpr Base(contained_type *p) noexcept
bool IsNaN() const noexcept
Tests if the value is NaN
Float16_t Abs() const noexcept
Creates an instance that represents absolute value.
std::string GetErrorMessage() const
std::vector< std::string > Strings
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const
BFloat16_t Negate() const noexcept
Creates a new instance with the sign flipped.
constexpr bool operator!=(const Float8E5M2FNUZ_t &rhs) const noexcept
Type information that may contain either TensorTypeAndShapeInfo or the information about contained se...
TensorTypeAndShapeInfo(std::nullptr_t)
Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used...
Value(OrtValue *p)
Used for interop with the C API.
Env(OrtEnv *p)
C Interop Helper.
std::vector< float > Floats
Value(std::nullptr_t)
Create an empty Value object, must be assigned a valid one to be used.
Custom Op Domain.
GLboolean * data
Definition: glcorearb.h:131
void swap(UT::ArraySet< Key, MULTI, MAX_LOAD_FACTOR_256, Clearer, Hash, KeyEqual > &a, UT::ArraySet< Key, MULTI, MAX_LOAD_FACTOR_256, Clearer, Hash, KeyEqual > &b)
Definition: UT_ArraySet.h:1639
const GLdouble * v
Definition: glcorearb.h:837
Base & operator=(Base &&v) noexcept
bool IsSubnormal() const noexcept
Tests if the value is subnormal (denormal).
GLsizei const GLchar *const * string
Definition: glcorearb.h:814
Float16_t()=default
Default constructor
bool GetVariadicInputHomogeneity() const
constexpr Float8E5M2_t(uint8_t v) noexcept
std::unique_ptr< char, detail::AllocatedFree > AllocatedStringPtr
unique_ptr typedef used to own strings allocated by OrtAllocators and release them at the end of the ...
Used internally by the C++ API. C++ wrapper types inherit from this. This is a zero cost abstraction ...
ConstMemoryInfo GetConst() const
Take ownership of a pointer created by C Api.
const Shape & GetInputShape(size_t indice) const
std::vector< SymbolicInteger > Shape
Wrapper around ::OrtModelMetadata.
GLint level
Definition: glcorearb.h:108
BFloat16_t(float v) noexcept
__ctor from float. Float is converted into bfloat16 16-bit representation.
TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo *p)
Used for interop with the C API.
Wrapper around ::OrtMapTypeInfo.
constexpr Float8E5M2FNUZ_t(uint8_t v) noexcept
GLdouble s
Definition: glad.h:3009
This struct provides life time management for custom op attribute
float ToFloatImpl() const noexcept
Converts bfloat16 to float
TypeInfo(OrtTypeInfo *p)
detail::SequenceTypeInfoImpl< detail::Unowned< const OrtSequenceTypeInfo >> ConstSequenceTypeInfo
Wrapper around OrtValue.
MapTypeInfo(OrtMapTypeInfo *p)
static bool AreZero(const Float16Impl &lhs, const Float16Impl &rhs) noexcept
IEEE defines that positive and negative zero are equal, this gives us a quick equality check for two ...
bool operator!=(const BFloat16_t &rhs) const noexcept
std::vector< R > GetAttributes(const char *name) const
**But if you need a result
Definition: thread.h:613
it is a structure that represents the configuration of an arena based allocator
Provide access to per-node attributes and input shapes, so one could compute and set output shapes...
OCIOEXPORT void LogMessage(LoggingLevel level, const char *message)
Log a message using the library logging function.
IoBinding(std::nullptr_t)
Create an empty object for convenience. Sometimes, we want to initialize members later.
The Env (Environment)
static const OrtApi * api_
bool IsNegative() const noexcept
Checks if the value is negative
OrtMemType GetInputMemoryType(size_t) const
Env(std::nullptr_t)
Create an empty Env object, must be assigned a valid one to be used.
float8e4m3fnuz (Float8 Floating Point) data type
constexpr bool operator==(const Float8E4M3FN_t &rhs) const noexcept
ConstSession GetConst() const
float8e4m3fn (Float8 Floating Point) data type
void GetAttrs(const OrtKernelInfo *p, const char *name, std::vector< int64_t > &)
static constexpr uint16_t ToUint16Impl(float v) noexcept
Converts from float to uint16_t float16 representation
ModelMetadata(std::nullptr_t)
Create an empty ModelMetadata object, must be assigned a valid one to be used.
bool IsFinite() const noexcept
Tests if the value is finite
GLuint GLsizei const GLuint const GLintptr * offsets
Definition: glcorearb.h:2621
const OrtApi & GetApi() noexcept
This returns a reference to the OrtApi interface in use.
std::vector< Value > GetOutputValuesHelper(const OrtIoBinding *binding, OrtAllocator *)
bool operator==(const BaseDimensions< T > &a, const BaseDimensions< Y > &b)
Definition: Dimensions.h:137
Wrapper around ::OrtAllocator.
bool operator==(const SymbolicInteger &dim) const
constexpr Base()=default
constexpr Float8E5M2FNUZ_t() noexcept
bool IsPositiveInfinity() const noexcept
Tests if the value represents positive infinity.
Wrapper around OrtMemoryInfo.
std::vector< std::string > GetAvailableProviders()
This is a C++ wrapper for OrtApi::GetAvailableProviders() and returns a vector of strings representin...
Op(std::nullptr_t)
Create an empty Operator object, must be assigned a valid one to be used.
bool IsNormal() const noexcept
Tests if the value is normal (not zero, subnormal, infinite, or NaN).
bool operator<(const BFloat16_t &rhs) const noexcept
detail::MapTypeInfoImpl< detail::Unowned< const OrtMapTypeInfo >> ConstMapTypeInfo
Definition: core.h:760
Wrapper around ::OrtIoBinding.
bool IsFinite() const noexcept
Tests if the value is finite
IMATH_NAMESPACE::V2f float
The Status that holds ownership of OrtStatus received from C API Use it to safely destroy OrtStatus* ...
BFloat16_t()=default
void GetAttr(const OrtKernelInfo *p, const char *name, std::string &)
OrtKernelContext * GetOrtKernelContext() const
float ToFloat() const noexcept
Converts float16 to float
Shared implementation between public and internal classes. CRTP pattern.
constexpr bool operator==(const Float8E5M2_t &rhs) const noexcept
detail::SessionOptionsImpl< detail::Unowned< OrtSessionOptions >> UnownedSessionOptions
constexpr Float8E4M3FNUZ_t(uint8_t v) noexcept
bool IsNegative() const noexcept
Checks if the value is negative
All C++ methods that can fail will throw an exception of this type.
const char * what() const noexceptoverride
A generic, discriminated value, whose type may be queried dynamically.
Definition: Value.h:44
Base(Base &&v) noexcept
typename Unowned< T >::Type contained_type
CustomOpDomain(std::nullptr_t)
Create an empty CustomOpDomain object, must be assigned a valid one to be used.
Wrapper around ::OrtSequenceTypeInfo.
This class wraps a raw pointer OrtKernelContext* that is being passed to the custom kernel Compute() ...
constexpr std::enable_if< I< type_count_base< T >::value, int >::type tuple_type_size(){return subtype_count< typename std::tuple_element< I, T >::type >::value+tuple_type_size< T, I+1 >);}template< typename T > struct type_count< T, typename std::enable_if< is_tuple_like< T >::value >::type >{static constexpr int value{tuple_type_size< T, 0 >)};};template< typename T > struct subtype_count{static constexpr int value{is_mutable_container< T >::value?expected_max_vector_size:type_count< T >::value};};template< typename T, typename Enable=void > struct type_count_min{static const int value{0};};template< typename T >struct type_count_min< T, typename std::enable_if<!is_mutable_container< T >::value &&!is_tuple_like< T >::value &&!is_wrapper< T >::value &&!is_complex< T >::value &&!std::is_void< T >::value >::type >{static constexpr int value{type_count< T >::value};};template< typename T > struct type_count_min< T, typename std::enable_if< is_complex< T >::value >::type >{static constexpr int value{1};};template< typename T >struct type_count_min< T, typename std::enable_if< is_wrapper< T >::value &&!is_complex< T >::value &&!is_tuple_like< T >::value >::type >{static constexpr int value{subtype_count_min< typename T::value_type >::value};};template< typename T, std::size_t I >constexpr typename std::enable_if< I==type_count_base< T >::value, int >::type tuple_type_size_min(){return 0;}template< typename T, std::size_t I > constexpr typename std::enable_if< I< type_count_base< T >::value, int >::type tuple_type_size_min(){return subtype_count_min< typename std::tuple_element< I, T >::type >::value+tuple_type_size_min< T, I+1 >);}template< typename T > struct type_count_min< T, typename std::enable_if< is_tuple_like< T >::value >::type >{static constexpr int value{tuple_type_size_min< T, 0 >)};};template< typename T > struct subtype_count_min{static constexpr int value{is_mutable_container< T >::value?((type_count< T >::value< expected_max_vector_size)?type_count< T 
>::value:0):type_count_min< T >::value};};template< typename T, typename Enable=void > struct expected_count{static const int value{0};};template< typename T >struct expected_count< T, typename std::enable_if<!is_mutable_container< T >::value &&!is_wrapper< T >::value &&!std::is_void< T >::value >::type >{static constexpr int value{1};};template< typename T > struct expected_count< T, typename std::enable_if< is_mutable_container< T >::value >::type >{static constexpr int value{expected_max_vector_size};};template< typename T >struct expected_count< T, typename std::enable_if<!is_mutable_container< T >::value &&is_wrapper< T >::value >::type >{static constexpr int value{expected_count< typename T::value_type >::value};};enum class object_category:int{char_value=1, integral_value=2, unsigned_integral=4, enumeration=6, boolean_value=8, floating_point=10, number_constructible=12, double_constructible=14, integer_constructible=16, string_assignable=23, string_constructible=24, other=45, wrapper_value=50, complex_number=60, tuple_value=70, container_value=80,};template< typename T, typename Enable=void > struct classify_object{static constexpr object_category value{object_category::other};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_integral< T >::value &&!std::is_same< T, char >::value &&std::is_signed< T >::value &&!is_bool< T >::value &&!std::is_enum< T >::value >::type >{static constexpr object_category value{object_category::integral_value};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_integral< T >::value &&std::is_unsigned< T >::value &&!std::is_same< T, char >::value &&!is_bool< T >::value >::type >{static constexpr object_category value{object_category::unsigned_integral};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_same< T, char >::value &&!std::is_enum< T >::value >::type >{static constexpr object_category 
value{object_category::char_value};};template< typename T > struct classify_object< T, typename std::enable_if< is_bool< T >::value >::type >{static constexpr object_category value{object_category::boolean_value};};template< typename T > struct classify_object< T, typename std::enable_if< std::is_floating_point< T >::value >::type >{static constexpr object_category value{object_category::floating_point};};template< typename T >struct classify_object< T, typename std::enable_if<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&std::is_assignable< T &, std::string >::value >::type >{static constexpr object_category value{object_category::string_assignable};};template< typename T >struct classify_object< T, typename std::enable_if<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&!std::is_assignable< T &, std::string >::value &&(type_count< T >::value==1)&&std::is_constructible< T, std::string >::value >::type >{static constexpr object_category value{object_category::string_constructible};};template< typename T > struct classify_object< T, typename std::enable_if< std::is_enum< T >::value >::type >{static constexpr object_category value{object_category::enumeration};};template< typename T > struct classify_object< T, typename std::enable_if< is_complex< T >::value >::type >{static constexpr object_category value{object_category::complex_number};};template< typename T > struct uncommon_type{using type=typename std::conditional<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&!std::is_assignable< T &, std::string >::value &&!std::is_constructible< T, std::string >::value &&!is_complex< T >::value &&!is_mutable_container< T >::value &&!std::is_enum< T >::value, std::true_type, std::false_type >::type;static constexpr bool value=type::value;};template< typename T >struct classify_object< T, typename std::enable_if<(!is_mutable_container< T >::value &&is_wrapper< T >::value &&!is_tuple_like< T >::value 
&&uncommon_type< T >::value)>::type >{static constexpr object_category value{object_category::wrapper_value};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&is_direct_constructible< T, double >::value &&is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::number_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&!is_direct_constructible< T, double >::value &&is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::integer_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&is_direct_constructible< T, double >::value &&!is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::double_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< is_tuple_like< T >::value &&((type_count< T >::value >=2 &&!is_wrapper< T >::value)||(uncommon_type< T >::value &&!is_direct_constructible< T, double >::value &&!is_direct_constructible< T, int >::value)||(uncommon_type< T >::value &&type_count< T >::value >=2))>::type >{static constexpr object_category value{object_category::tuple_value};};template< typename T > struct classify_object< T, typename std::enable_if< is_mutable_container< T >::value >::type >{static constexpr object_category value{object_category::container_value};};template< typename T, enable_if_t< classify_object< T >::value==object_category::char_value, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"CHAR";}template< typename T, enable_if_t< classify_object< T 
>::value==object_category::integral_value||classify_object< T >::value==object_category::integer_constructible, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"INT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::unsigned_integral, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"UINT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::floating_point||classify_object< T >::value==object_category::number_constructible||classify_object< T >::value==object_category::double_constructible, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"FLOAT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::enumeration, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"ENUM";}template< typename T, enable_if_t< classify_object< T >::value==object_category::boolean_value, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"BOOLEAN";}template< typename T, enable_if_t< classify_object< T >::value==object_category::complex_number, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"COMPLEX";}template< typename T, enable_if_t< classify_object< T >::value >=object_category::string_assignable &&classify_object< T >::value<=object_category::other, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"TEXT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value >=2, detail::enabler >=detail::dummy >std::string type_name();template< typename T, enable_if_t< classify_object< T >::value==object_category::container_value||classify_object< T >::value==object_category::wrapper_value, detail::enabler >=detail::dummy >std::string type_name();template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value==1, 
detail::enabler >=detail::dummy >inline std::string type_name(){return type_name< typename std::decay< typename std::tuple_element< 0, T >::type >::type >);}template< typename T, std::size_t I >inline typename std::enable_if< I==type_count_base< T >::value, std::string >::type tuple_name(){return std::string{};}template< typename T, std::size_t I >inline typename std::enable_if<(I< type_count_base< T >::value), std::string >::type tuple_name(){auto str=std::string{type_name< typename std::decay< typename std::tuple_element< I, T >::type >::type >)}+ ','+tuple_name< T, I+1 >);if(str.back()== ',') str.pop_back();return str;}template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value >=2, detail::enabler > > std::string type_name()
Recursively generate the tuple type name.
Definition: CLI11.h:1729
GLint GLint GLsizei GLint GLenum format
Definition: glcorearb.h:108
KernelInfo(std::nullptr_t)
Create an empty instance to initialize later.
detail::TypeInfoImpl< detail::Unowned< const OrtTypeInfo >> ConstTypeInfo
Contains a constant, unowned OrtTypeInfo that can be copied and passed around by value. Provides access to const OrtTypeInfo APIs.
Represents native memory allocation coming from one of the OrtAllocators registered with OnnxRuntime...
bool IsNegativeInfinity() const noexcept
Tests if the value represents negative infinity
float ToFloatImpl() const noexcept
Converts float16 to float
Session(std::nullptr_t)
Create an empty Session object, must be assigned a valid one to be used.
IEEE 754 half-precision floating point data type.
std::vector< std::string > GetSessionConfigKeys() const
constexpr Float8E4M3FN_t() noexcept
GLint location
Definition: glcorearb.h:805
SessionOptions(OrtSessionOptions *p)
Create and own custom defined operation.
ConstIoBinding GetConst() const
bool IsPositiveInfinity() const noexcept
Tests if the value represents positive infinity.
float8e5m2fnuz (Float8 Floating Point) data type
Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT_V...
constexpr Float8E5M2_t() noexcept
constexpr bool operator!=(const Float8E4M3FNUZ_t &rhs) const noexcept
GLuint const GLchar * name
Definition: glcorearb.h:786
int GetVariadicInputMinArity() const
Allocator(std::nullptr_t)
Convenience to create a class member and then replace with an instance.
detail::ConstSessionOptionsImpl< detail::Unowned< const OrtSessionOptions >> ConstSessionOptions
RunOptions(std::nullptr_t)
Create an empty RunOptions object, must be assigned a valid one to be used.
Float16_t Negate() const noexcept
Creates a new instance with the sign flipped.
OrtRunOptions RunOptions
Definition: run_options.h:48
OCIOEXPORT const char * GetVersion()
Get the version number for the library, as a dot-delimited string (e.g., "1.0.0").
AllocatedFree(OrtAllocator *allocator)
bool IsNaNOrZero() const noexcept
Tests if the value is NaN or zero. Useful for comparisons.
ORT_DEFINE_RELEASE(Allocator)
contained_type * p_
bool GetVariadicOutputHomogeneity() const
bool IsInfinity() const noexcept
Tests if the value is either positive or negative infinity.
constexpr Base(contained_type *p) noexcept
BFloat16_t Abs() const noexcept
Creates an instance that represents absolute value.
Float16_t(float v) noexcept
Constructor from float. The float is converted into the 16-bit float16 representation.
GT_API const UT_StringHolder version
bool IsSubnormal() const noexcept
Tests if the value is subnormal (denormal).
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t) const
This struct owns the OrtKernelInfo* pointer when a copy is made. For convenient wrapping of OrtKernelIn...
GLsizeiptr size
Definition: glcorearb.h:664
OrtErrorCode GetErrorCode() const
This class represents an ONNX Runtime logger that can be used to log information with an associated s...
ArenaCfg(std::nullptr_t)
Create an empty ArenaCfg object, must be assigned a valid one to be used.
Ort::Status(*)(Ort::ShapeInferContext &) ShapeInferFn
GLenum GLsizei GLsizei GLint * values
Definition: glcorearb.h:1602
void operator()(void *ptr) const
ConstValue GetConst() const
constexpr bool operator==(const Float8E5M2FNUZ_t &rhs) const noexcept
The ThreadingOptions.
float ToFloat() const noexcept
Converts bfloat16 to float
bool IsNaN() const noexcept
Tests if the value is NaN
constexpr bool operator!=(const Float8E5M2_t &rhs) const noexcept
ModelMetadata(OrtModelMetadata *p)
bool operator==(const BFloat16_t &rhs) const noexcept
MemoryInfo(OrtMemoryInfo *p)
bool IsOK() const noexcept
Returns true if instance represents an OK (non-error) status.
GLuint index
Definition: glcorearb.h:786
auto ptr(T p) -> const void *
Definition: format.h:2448
GLuint GLfloat * val
Definition: glcorearb.h:1608
static constexpr Float16_t FromBits(uint16_t v) noexcept
Explicit conversion from the uint16_t bit representation of float16.
float8e5m2 (Float8 Floating Point) data type
**If you just want to fire and args
Definition: thread.h:609
bool IsNormal() const noexcept
Tests if the value is normal (not zero, subnormal, infinite, or NaN).
constexpr Float8E4M3FN_t(uint8_t v) noexcept
int GetVariadicOutputMinArity() const
constexpr bool operator==(const Float8E4M3FNUZ_t &rhs) const noexcept
bool IsNaNOrZero() const noexcept
Tests if the value is NaN or zero. Useful for comparisons.
std::string GetVersionString()
This function returns the onnxruntime version string
static bool AreZero(const BFloat16Impl &lhs, const BFloat16Impl &rhs) noexcept
IEEE defines that positive and negative zero are equal, this gives us a quick equality check for two ...
OrtErrorCode GetOrtErrorCode() const
Definition: core.h:1131
#define MAX_CUSTOM_OP_END_VER
const char * GetExecutionProviderType() const
Base & operator=(Base &&v) noexcept
Wrapper around ::OrtTensorTypeAndShapeInfo.
MemoryInfo(std::nullptr_t)
No instance is created.
Exception(std::string &&string, OrtErrorCode code)
std::string MakeCustomOpConfigEntryKey(const char *custom_op_name, const char *config)
CustomOpConfigs.
MapTypeInfo(std::nullptr_t)
Create an empty MapTypeInfo object, must be assigned a valid one to be used.
Wrapper around ::OrtSessionOptions.
contained_type * release()
Relinquishes ownership of the contained C object pointer. The underlying object is not destroyed...
SessionOptions(std::nullptr_t)
Create an empty SessionOptions object, must be assigned a valid one to be used.
Base & operator=(const Base &)=delete
type
Definition: core.h:1059
UnownedValue GetUnowned() const
R GetAttribute(const char *name) const
Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime.
Wrapper around ::OrtSession.
Options for the CUDA provider that are passed to SessionOptionsAppendExecutionProvider_CUDA_V2. Please note that this struct is similar to OrtCUDAProviderOptions but is only to be used internally. Going forward, new CUDA provider options are to be supported via this struct, and usage of the publicly defined OrtCUDAProviderOptions will be deprecated over time. Users can only get an instance of OrtCUDAProviderOptionsV2 via CreateCUDAProviderOptions.
static constexpr BFloat16_t FromBits(uint16_t v) noexcept
Explicit conversion from the uint16_t bit representation of bfloat16.
bfloat16 (Brain Floating Point) data type
ConstKernelInfo GetConst() const
bool IsInfinity() const noexcept
Tests if the value is either positive or negative infinity.
constexpr FMT_INLINE value()
Definition: core.h:1154
Shared implementation between public and internal classes. CRTP pattern.
SequenceTypeInfo(std::nullptr_t)
Create an empty SequenceTypeInfo object, must be assigned a valid one to be used. ...
Class that represents session configuration entries for one or more custom operators.
Status(std::nullptr_t) noexcept
Create an empty object, must be assigned a valid one to be used.
constexpr bool operator!=(const Float8E4M3FN_t &rhs) const noexcept
Definition: format.h:895
std::vector< std::string > GetOutputNamesHelper(const OrtIoBinding *binding, OrtAllocator *)
static uint16_t ToUint16Impl(float v) noexcept
Converts from float to uint16_t float16 representation
ConstTensorTypeAndShapeInfo GetConst() const
bool IsNegativeInfinity() const noexcept
Tests if the value represents negative infinity
UnownedIoBinding GetUnowned() const