HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
op_node_proto_helper.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #ifndef SHARED_PROVIDER
7 #include "core/common/status.h"
10 #include "core/common/gsl.h"
11 #endif
12 
13 class IMLOpKernel;
14 
15 namespace onnxruntime {
16 
17 /**
18  A set of wrappers with common signatures for use with both OpKernelInfo
19  (as its base class) and InferenceContext. Used by ABI kernels for both
20  shape / type inference and kernel construction
21 */
22 template <class Impl_t>
24  public:
25  explicit OpNodeProtoHelper(const Impl_t* impl) : impl_(impl) {}
26 
27  /**
28  Get a single attribute
29  Call this function for a required attribute or when a default value for an optional attribute is specified in the op schema
30  */
31  template <typename T>
32  Status GetAttr(const std::string& name, T* value) const;
33 
34  /**
35  Get a single attribute
36  Call this function for a required attribute or when a default value for an optional attribute is specified in the op schema
37  Throws if an attribute with the specified type doesn't exist
38  */
39  template <typename T>
40  [[nodiscard]] T GetAttr(const std::string& name) const {
41  T value;
42  ORT_THROW_IF_ERROR(GetAttr(name, &value));
43  return value;
44  }
45 
46  /**
47  Get a single attribute
48  Call this function only when a default value for an optional attribute isn't specified in the op schema
49  */
50  template <typename T>
51  [[nodiscard]] T GetAttrOrDefault(const std::string& name, const T& default_value) const {
52  T tmp;
53  return GetAttr<T>(name, &tmp).IsOK() ? tmp : default_value;
54  }
55 
56  /**
57  Get a single attribute
58  Call this function only when a default value for an optional attribute isn't specified in the op schema
59  */
60  template <typename T>
61  void GetAttrOrDefault(const std::string& name, T* value, const T& default_value) const {
62  if (!GetAttr<T>(name, value).IsOK())
63  *value = default_value;
64  }
65 
66  /**
67  Get repeated attributes
68  Call this function only when a default value for an optional attribute isn't specified in the op schema
69  */
70  template <typename T>
71  [[nodiscard]] std::vector<T> GetAttrsOrDefault(const std::string& name,
72  const std::vector<T>& default_value = {}) const {
73  std::vector<T> tmp;
74  return GetAttrs<T>(name, tmp).IsOK() ? tmp : default_value;
75  }
76 
77  /// <summary>
78  /// Return a gsl::span that points to an array of primitive types held by AttributeProto
79  /// This function allows to avoid copying big attributes locally into a kernel and operate on
80  /// AttributeProto data directly.
81  ///
82  /// Does not apply to strings, Tensors and Sparse Tensors that require special treatment.
83  /// </summary>
84  /// <typeparam name="T">Primitive type contained in the array</typeparam>
85  /// <param name="name">Attribute name</param>
86  /// <param name="values">Attribute data in a span, out parameter</param>
87  /// <returns>Status</returns>
88  template <typename T>
89  Status GetAttrsAsSpan(const std::string& name, gsl::span<const T>& values) const;
90 
91  Status GetAttrs(const std::string& name, TensorShapeVector& out) const;
92 
93  [[nodiscard]] TensorShapeVector GetAttrsOrDefault(const std::string& name,
94  const TensorShapeVector& default_value = {}) const {
96  return GetAttrs(name, tmp).IsOK() ? tmp : default_value;
97  }
98 
99  /**
100  Get repeated attributes
101  */
102  template <typename T>
103  Status GetAttrs(const std::string& name, std::vector<T>& values) const;
104 
105  template <typename T>
106  Status GetAttrs(const std::string& name, gsl::span<T> values) const;
107 
109  std::vector<std::reference_wrapper<const std::string>>& refs) const;
110 
111  [[nodiscard]] uint32_t GetPrimitiveAttrElementCount(ONNX_NAMESPACE::AttributeProto_AttributeType type,
112  const std::string& name) const noexcept;
113 
114  [[nodiscard]] bool HasPrimitiveAttribute(ONNX_NAMESPACE::AttributeProto_AttributeType type,
115  const std::string& name) const noexcept;
116 
117  [[nodiscard]] uint32_t GetInputCount() const {
118  return gsl::narrow_cast<uint32_t>(impl_->getNumInputs());
119  }
120 
121  [[nodiscard]] uint32_t GetOutputCount() const {
122  return gsl::narrow_cast<uint32_t>(impl_->getNumOutputs());
123  }
124 
125  [[nodiscard]] const ONNX_NAMESPACE::TypeProto* GetInputType(size_t index) const {
126  return impl_->getInputType(index);
127  }
128 
129  [[nodiscard]] const ONNX_NAMESPACE::TypeProto* GetOutputType(size_t index) const {
130  // Work around lack of a const method from the onnx InferenceContext interface
131  return const_cast<Impl_t*>(impl_)->getOutputType(index);
132  }
133 
134  // Try to query an attribute, returning nullptr if it doesn't exist
135  [[nodiscard]] const ONNX_NAMESPACE::AttributeProto* TryGetAttribute(const std::string& name) const {
136  return impl_->getAttribute(name);
137  }
138 
139  [[nodiscard]] const ONNX_NAMESPACE::AttributeProto* GetAttribute(const std::string& name) const {
140  const ONNX_NAMESPACE::AttributeProto* attr = TryGetAttribute(name);
141  ORT_ENFORCE(attr != nullptr);
142  return attr;
143  }
144 
145  private:
146  OpNodeProtoHelper() = delete;
147  const Impl_t* impl_ = nullptr;
148 };
149 
150 // The methods on the following class are called by OpNodeProtoHelper, implementing
151 // the same signatures as InferenceContext other than const-ness.
153  public:
154  explicit ProtoHelperNodeContext(const onnxruntime::Node& node) : node_(node) {}
155  ProtoHelperNodeContext() = delete;
156 
157  const ONNX_NAMESPACE::AttributeProto* getAttribute(const std::string& name) const;
158  size_t getNumInputs() const;
159  const ONNX_NAMESPACE::TypeProto* getInputType(size_t index) const;
160  size_t getNumOutputs() const;
161  const ONNX_NAMESPACE::TypeProto* getOutputType(size_t index) const;
162 
163  private:
164  const onnxruntime::Node& node_;
165 };
166 
167 } // namespace onnxruntime
T GetAttr(const std::string &name) const
void GetAttrOrDefault(const std::string &name, T *value, const T &default_value) const
ProtoHelperNodeContext(const onnxruntime::Node &node)
T GetAttrOrDefault(const std::string &name, const T &default_value) const
uint32_t GetPrimitiveAttrElementCount(ONNX_NAMESPACE::AttributeProto_AttributeType type, const std::string &name) const noexcept
GLsizei const GLchar *const * string
Definition: glcorearb.h:814
GLsizei const GLfloat * value
Definition: glcorearb.h:824
const ONNX_NAMESPACE::AttributeProto * TryGetAttribute(const std::string &name) const
#define ORT_ENFORCE(condition,...)
Definition: common.h:172
bool HasPrimitiveAttribute(ONNX_NAMESPACE::AttributeProto_AttributeType type, const std::string &name) const noexcept
Status GetAttr(const std::string &name, T *value) const
TensorShapeVector GetAttrsOrDefault(const std::string &name, const TensorShapeVector &default_value={}) const
InlinedVector< int64_t > TensorShapeVector
Definition: tensor_shape.h:30
const ONNX_NAMESPACE::TypeProto * getInputType(size_t index) const
const ONNX_NAMESPACE::TypeProto * GetInputType(size_t index) const
GLuint const GLchar * name
Definition: glcorearb.h:786
const ONNX_NAMESPACE::AttributeProto * GetAttribute(const std::string &name) const
Status GetAttrs(const std::string &name, TensorShapeVector &out) const
std::vector< T > GetAttrsOrDefault(const std::string &name, const std::vector< T > &default_value={}) const
GLenum GLsizei GLsizei GLint * values
Definition: glcorearb.h:1602
const ONNX_NAMESPACE::TypeProto * GetOutputType(size_t index) const
const ONNX_NAMESPACE::TypeProto * getOutputType(size_t index) const
GLuint index
Definition: glcorearb.h:786
const ONNX_NAMESPACE::AttributeProto * getAttribute(const std::string &name) const
Definition: core.h:1131
Status GetAttrsAsSpan(const std::string &name, gsl::span< const T > &values) const
Return a gsl::span that points to an array of primitive types held by AttributeProto. This function avoids copying large attributes into a kernel by operating on the AttributeProto data directly.
type
Definition: core.h:1059
#define ORT_THROW_IF_ERROR(expr)
Definition: common.h:235
Status GetAttrsStringRefs(const std::string &name, std::vector< std::reference_wrapper< const std::string >> &refs) const