HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ort_value.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #include <string>
7 #ifndef SHARED_PROVIDER
8 #include "core/common/common.h"
12 #include "core/framework/tensor.h"
13 
14 namespace onnxruntime {
15 #if !defined(DISABLE_SPARSE_TENSORS)
16 class SparseTensor;
17 #endif
18 class TensorSeq;
19 } // namespace onnxruntime
20 
21 #endif
22 
23 /**
24  Represents both tensors and non-tensors.
25 */
26 struct OrtValue {
27  public:
28  OrtValue() : data_(nullptr) {}
29  ~OrtValue() = default;
30 
31  OrtValue(void* pData, onnxruntime::MLDataType type, onnxruntime::DeleteFunc deleter) {
32  Init(pData, type, deleter);
33  }
34 
35  void Init(void* pData, onnxruntime::MLDataType type, onnxruntime::DeleteFunc deleter) {
36  data_.reset(pData, deleter);
37  type_ = type;
38  }
39 
40  void Init(void* pData, onnxruntime::MLDataType type, const std::function<void(void*)>& deleter) {
41  data_.reset(pData, deleter);
42  type_ = type;
43  }
44 
45  bool IsAllocated() const {
46  return data_ && type_;
47  }
48 
49  template <typename T>
50  const T& Get() const {
51  ORT_ENFORCE(onnxruntime::DataTypeImpl::GetType<T>() == type_, onnxruntime::DataTypeImpl::GetType<T>(), " != ", type_);
52  return *static_cast<T*>(data_.get());
53  }
54 
55  // May return nullptr, if this OrtValue is an optional type and it is "None".
56  template <typename T>
57  T* GetMutable() {
58  ORT_ENFORCE(onnxruntime::DataTypeImpl::GetType<T>() == type_, onnxruntime::DataTypeImpl::GetType<T>(), " != ", type_);
59  return static_cast<T*>(data_.get());
60  }
61 
62  bool IsTensor() const noexcept {
63  return (type_ != nullptr && type_->IsTensorType());
64  }
65 
66  bool IsTensorSequence() const noexcept {
67  return (type_ != nullptr && type_->IsTensorSequenceType());
68  }
69 
70  bool IsSparseTensor() const {
71  return (type_ != nullptr && type_->IsSparseTensorType());
72  }
73 
74  onnxruntime::MLDataType Type() const {
75  return type_;
76  }
77 
78  private:
79  std::shared_ptr<void> data_;
80  onnxruntime::MLDataType type_{nullptr};
81 };
82 
83 template <>
84 inline const onnxruntime::Tensor& OrtValue::Get<onnxruntime::Tensor>() const {
85  ORT_ENFORCE(IsTensor(), "Trying to get a Tensor, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
86  return *static_cast<onnxruntime::Tensor*>(data_.get());
87 }
88 
89 template <>
90 inline onnxruntime::Tensor* OrtValue::GetMutable<onnxruntime::Tensor>() {
91  ORT_ENFORCE(IsTensor(), "Trying to get a Tensor, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
92  return static_cast<onnxruntime::Tensor*>(data_.get());
93 }
94 
95 template <>
96 inline const onnxruntime::TensorSeq& OrtValue::Get<onnxruntime::TensorSeq>() const {
97  ORT_ENFORCE(IsTensorSequence(), "Trying to get a TensorSeq, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
98  return *static_cast<onnxruntime::TensorSeq*>(data_.get());
99 }
100 
101 template <>
102 inline onnxruntime::TensorSeq* OrtValue::GetMutable<onnxruntime::TensorSeq>() {
103  ORT_ENFORCE(IsTensorSequence(), "Trying to get a TensorSeq, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
104  return static_cast<onnxruntime::TensorSeq*>(data_.get());
105 }
106 
107 #if !defined(DISABLE_SPARSE_TENSORS)
108 template <>
109 inline const onnxruntime::SparseTensor& OrtValue::Get<onnxruntime::SparseTensor>() const {
110  ORT_ENFORCE(IsSparseTensor(), "Trying to get a SparseTensor, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
111  return *static_cast<onnxruntime::SparseTensor*>(data_.get());
112 }
113 
114 template <>
115 inline onnxruntime::SparseTensor* OrtValue::GetMutable<onnxruntime::SparseTensor>() {
116  ORT_ENFORCE(IsSparseTensor(), "Trying to get a SparseTensor, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
117  return static_cast<onnxruntime::SparseTensor*>(data_.get());
118 }
119 #endif
onnxruntime::MLDataType Type() const
Definition: ort_value.h:74
static const char * ToString(MLDataType type)
bool IsTensorSequenceType() const
Definition: data_types.h:122
Base class for MLDataType.
Definition: data_types.h:76
bool IsTensor() const noexcept
Definition: ort_value.h:62
~OrtValue()=default
OrtValue()
Definition: ort_value.h:28
void Init(void *pData, onnxruntime::MLDataType type, const std::function< void(void *)> &deleter)
Definition: ort_value.h:40
void Init(void *pData, onnxruntime::MLDataType type, onnxruntime::DeleteFunc deleter)
Definition: ort_value.h:35
#define ORT_ENFORCE(condition,...)
Definition: common.h:172
bool IsTensorSequence() const noexcept
Definition: ort_value.h:66
This class implements SparseTensor. This class holds sparse non-zero data (values) and sparse format ...
Definition: sparse_tensor.h:55
void(*)(void *) DeleteFunc
Definition: data_types.h:69
bool IsSparseTensor() const
Definition: ort_value.h:70
bool IsTensorType() const
Definition: data_types.h:118
T * GetMutable()
Definition: ort_value.h:57
OrtValue(void *pData, onnxruntime::MLDataType type, onnxruntime::DeleteFunc deleter)
Definition: ort_value.h:31
bool IsSparseTensorType() const
Definition: data_types.h:126
type
Definition: core.h:1059
const T & Get() const
Definition: ort_value.h:50
bool IsAllocated() const
Definition: ort_value.h:45