HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
node_arg.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #include "onnx/onnx_pb.h"
7 
9 #include "core/common/status.h"
11 
12 namespace onnxruntime {
13 
14 // Node argument definition, for both input and output,
15 // including arg name, arg type (contains both type and shape).
16 //
17 // Design Question: in my opinion, shape should not be part of type.
18 // We may align the protobuf design with our operator registry interface,
19 // which has type specified for each operator, but no shape. Well, shape
20 // should be inferred with a separate shape inference function given
21 // input shapes, or input tensor data sometimes.
22 // With shape as part of type (current protobuf design),
23 // 1) we'll have to split the "TypeProto" into type and shape in this internal
24 // representation interface so that it could be easily used when doing type
25 // inference and matching with operator registry.
26 // 2) SetType should be always called before SetShape, otherwise, SetShape()
27 // will fail. Because shape is located in a TypeProto.
28 // Thoughts?
29 //
30 
31 /**
32 @class NodeArg
33 Class representing a data type that is input or output for a Node, including the shape if it is a Tensor.
34 */
35 class NodeArg {
36  public:
37  /**
38  Construct a new NodeArg.
39  @param name The name to use.
40  @param p_arg_type Optional TypeProto specifying type and shape information.
41  */
42  NodeArg(const std::string& name,
43  const ONNX_NAMESPACE::TypeProto* p_arg_type);
44 
45  NodeArg(NodeArg&&) = default;
46  NodeArg& operator=(NodeArg&& other) = default;
47 
48  /** Gets the name. */
49  const std::string& Name() const noexcept;
50 
51  /** Gets the data type. */
52  const std::string* Type() const noexcept;
53 
54  /** Gets the TypeProto
55  @returns TypeProto if type is set. nullptr otherwise. */
56  const ONNX_NAMESPACE::TypeProto* TypeAsProto() const noexcept;
57 
58  /** Gets the shape if NodeArg is for a Tensor.
59  @returns TensorShapeProto if shape is set. nullptr if there's no shape specified. */
60  const ONNX_NAMESPACE::TensorShapeProto* Shape() const;
61 
62  /** Return an indicator.
63  @returns true if NodeArg is a normal tensor with a non-empty shape or a scalar with an empty shape. Otherwise, returns false. */
64  bool HasTensorOrScalarShape() const;
65 
66 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
67 
68  /** Sets the shape.
69  @remarks Shape can only be set if the TypeProto was provided to the ctor, or #SetType has been called,
70  as the shape information is stored as part of TypeProto. */
71  void SetShape(const ONNX_NAMESPACE::TensorShapeProto& shape);
72 
73  /** Clears shape info.
74  @remarks If there is a mismatch during shape inferencing that can't be resolved the shape info may be removed. */
75  void ClearShape();
76 
77 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
78 
79 #if !defined(ORT_MINIMAL_BUILD)
80 
81  /** Override current type from input_type if override_types is set to true, return failure status otherwise.
82  @param input_tensor_elem_type Tensor element type parsed input_type
83  @param current_tensor_elem_type Tensor element type parsed from existing type
84  @param override_types If true, resolve the two inputs or two outputs type when different
85  @returns Success unless there is existing type or shape info that can't be successfully updated. */
86  common::Status OverrideTypesHelper(const ONNX_NAMESPACE::TypeProto& input_type,
87  int32_t input_tensor_elem_type,
88  int32_t current_tensor_elem_type,
89  bool override_types);
90 
91  /** Validate and merge type [and shape] info from input_type.
92  @param strict If true, the shape update will fail if there are incompatible values.
93  If false, will be lenient and merge only shape info that can be validly processed.
94  @param override_types If true, resolve the two inputs or two outputs type when different
95  @returns Success unless there is existing type or shape info that can't be successfully updated. */
96  common::Status UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, bool override_types, const logging::Logger& logger);
97 
98  /** Validate and merge type [and shape] info from node_arg.
99  @param strict If true, the shape update will fail if there are incompatible values.
100  If false, will be lenient and merge only shape info that can be validly processed.
101  @param override_types If true, resolve the two inputs or two outputs type when different
102  @returns Success unless there is existing type or shape info that can't be successfully updated. */
103  common::Status UpdateTypeAndShape(const NodeArg& node_arg, bool strict, bool override_types, const logging::Logger& logger);
104 
105 #endif // !defined(ORT_MINIMAL_BUILD)
106 
107  /** Gets this NodeArg as a NodeArgInfo, AKA ValueInfoProto. */
108  const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; }
109 
110  /** Gets a flag indicating whether this NodeArg exists or not.
111  Optional inputs are allowed in ONNX and an empty #Name represents a non-existent input argument. */
112  bool Exists() const noexcept;
113 
114  friend class Graph;
115 
116  NodeArg(NodeArgInfo&& node_arg_info);
117 
118  private:
119  ORT_DISALLOW_COPY_AND_ASSIGNMENT(NodeArg);
120  void SetType(const std::string* p_type);
121 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
122  void SetType(const ONNX_NAMESPACE::TypeProto& type_proto);
123 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
124 
125  // Node arg PType.
126  const std::string* type_;
127 
128  // Node arg name, type and shape.
129  NodeArgInfo node_arg_info_;
130 
131  // Flag indicates whether <*this> node arg exists or not.
132  bool exists_;
133 };
134 } // namespace onnxruntime
const std::string * Type() const noexcept
const NodeArgInfo & ToProto() const noexcept
Definition: node_arg.h:108
get result (waiting if necessary). A common idiom is to fire a bunch of sub tasks at the queue and then wait for them all to complete. We provide a helper class
Definition: thread.h:623
const std::string & Name() const noexcept
GLsizei const GLchar *const * string
Definition: glcorearb.h:814
NodeArg & operator=(NodeArg &&other)=default
common::Status OverrideTypesHelper(const ONNX_NAMESPACE::TypeProto &input_type, int32_t input_tensor_elem_type, int32_t current_tensor_elem_type, bool override_types)
bool HasTensorOrScalarShape() const
GLuint const GLchar * name
Definition: glcorearb.h:786
void SetShape(const ONNX_NAMESPACE::TensorShapeProto &shape)
common::Status UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto &input_type, bool strict, bool override_types, const logging::Logger &logger)
const ONNX_NAMESPACE::TensorShapeProto * Shape() const
ONNX_NAMESPACE::ValueInfoProto NodeArgInfo
Definition: basic_types.h:32
const ONNX_NAMESPACE::TypeProto * TypeAsProto() const noexcept
NodeArg(const std::string &name, const ONNX_NAMESPACE::TypeProto *p_arg_type)
bool Exists() const noexcept