HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
graph.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #include <functional>
7 #include <limits>
8 #include <memory>
9 #include <string>
10 #include <type_traits>
11 #include <unordered_map>
12 #include <unordered_set>
13 
14 #ifdef _WIN32
15 #pragma warning(push)
16 // disable some warnings from protobuf to pass Windows build
17 #pragma warning(disable : 4244)
18 #endif
19 
20 #ifdef _WIN32
21 #pragma warning(pop)
22 #endif
23 
24 #include "flatbuffers/flatbuffers.h"
25 
26 #include "core/common/gsl.h"
27 
28 #include "core/common/common.h"
30 #if !defined(ORT_MINIMAL_BUILD)
32 #endif
34 #include "core/common/path.h"
35 #include "core/common/span_utils.h"
36 #include "core/common/status.h"
38 #include "core/graph/onnx_protobuf.h"
39 #include "core/graph/basic_types.h"
40 #include "core/graph/constants.h"
41 #include "core/graph/function.h"
42 #if !defined(ORT_MINIMAL_BUILD)
43 #include "core/graph/function_template.h"
44 #endif
45 #include "core/graph/graph_nodes.h"
46 #include "core/graph/node_arg.h"
47 #include "core/graph/ort_format_load_options.h"
48 
49 namespace onnxruntime {
50 class Graph;
51 struct IndexedSubGraph;
52 class Model;
53 class OpSignature;
54 
55 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
56 class RuntimeOptimizationRecordContainer;
57 #endif
58 
59 namespace fbs {
60 struct Graph;
61 struct Node;
62 struct NodeEdge;
63 } // namespace fbs
64 
65 /**
66 @class Node
67 Class representing a node in the graph.
68 */
69 class Node {
70  public:
71  /** Node types */
72  enum class Type {
73  Primitive = 0, ///< The node refers to a primitive operator.
74  Fused = 1, ///< The node refers to a function.
75  };
76 
77  explicit Node() = default;
78 
79 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
81  std::string_view op_type,
82  std::string_view description,
83  gsl::span<NodeArg* const> input_args,
84  gsl::span<NodeArg* const> output_args,
85  const NodeAttributes* attributes,
86  std::string_view domain) {
87  Init(name, op_type, description,
88  input_args,
89  output_args,
90  attributes, domain);
91  }
92 #endif
93 
94  ~Node() = default;
95 
96  /**
97  @class EdgeEnd
98  Class representing the end of an edge. It could be an input or output edge end of a node.
99  For the node's input edge end, it's the source end, as the destination end is the node itself.
100  For the node's output edge end, it's the destination end, as the source end is the node itself.
101  */
102  class EdgeEnd {
103  public:
104  /**
105  Construct an EdgeEnd
106  @param node The source node if this is an input edge to the current node,
107  or the destination node if this is an output edge from the current node.
108  @param src_arg_index The node arg index of source node of the edge.
109  @param dst_arg_index The node arg index of destination node of the edge.
110  */
111  EdgeEnd(const Node& node, int src_arg_index, int dst_arg_index) noexcept;
112 
113  /** Construct a control edge.
114  @param node The node the edge joins to the current node.
115  */
116  explicit EdgeEnd(const Node& node) noexcept;
117 
118  /** Gets the Node that this EdgeEnd refers to. */
119  const Node& GetNode() const noexcept { return *node_; }
120 
121  /** Gets the source arg index.
122  @returns the source arg index of <*this> edge.*/
123  int GetSrcArgIndex() const { return src_arg_index_; }
124 
125  /** Gets the destination arg index.
126  @returns the destination arg index of <*this> edge.*/
127  int GetDstArgIndex() const { return dst_arg_index_; }
128 
129  private:
130  const Node* node_;
131  const int src_arg_index_;
132  const int dst_arg_index_;
133  };
134 
135  /** Gets the Node's NodeIndex. */
136  NodeIndex Index() const noexcept { return index_; }
137 
138  /** Gets the Node's name. */
139  const std::string& Name() const noexcept { return name_; }
140 
141  /** Gets the Node's operator type. */
142  const std::string& OpType() const noexcept { return op_type_; }
143 
144  /** Gets the domain of the OperatorSet that specifies the operator returned by #OpType.
145  * @remarks If this is an ONNX operator the value will be kOnnxDomain not kOnnxDomainAlias
146  */
147  const std::string& Domain() const noexcept { return domain_; }
148 
149  /** Gets the path of the owning model if any. */
150  const Path& ModelPath() const noexcept;
151 
152  /** Gets the Node's execution priority.
153  @remarks Lower value means higher priority */
154  int Priority() const noexcept { return priority_; };
155 
156  /** Sets the execution priority of a node.
157  @remarks Lower value means higher priority */
158  void SetPriority(int priority) noexcept;
159 
160  /** Gets the node description. */
161  const std::string& Description() const noexcept { return description_; }
162 
163  /** Gets the Node's Node::Type. */
164  Node::Type NodeType() const noexcept { return node_type_; }
165 
166  /** Gets the opset version that the Node's operator was first defined in.
167  @returns Opset version. If -1 the Node's operator has not been set.
168  @remarks Prefer over Op()->SinceVersion() as Op() is disabled in a minimal build
169  */
170  int SinceVersion() const noexcept { return since_version_; }
171 
172  /** Sets the since version (opset version that the Node's operator was first defined in.) for this node.
173  @remarks Used during layout transformation for setting since version for layout transformed nodes with
174  domain kMSNHWC.
175  */
176  void SetSinceVersion(int since_version) noexcept { since_version_ = since_version; }
177 
178 #if !defined(ORT_MINIMAL_BUILD)
179  /** Gets the Node's OpSchema.
180  @remarks The graph containing this node must be resolved, otherwise nullptr will be returned. */
181  const ONNX_NAMESPACE::OpSchema* Op() const noexcept { return op_; }
182 
183  /** Create a copy of the called op's FunctionProto if it has one. Returns true if successful. */
184  bool TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& func_proto) const;
185 
186  bool CanBeInlined() const;
187 
188  /** Gets the function body if applicable otherwise nullptr. */
189  const Function* GetFunctionBody() const noexcept { return func_body_.get(); }
190 #endif
191 
192  /**
193  Helper to iterate through the container returned by #InputDefs() or #OutputDefs() and call the provided function.
194  @param node_args Collection of NodeArgs returned by #InputDefs() or #OutputDefs()
195  @param func Function to call for each valid NodeArg in the node_args. The function is called with the NodeArg
196  and the index number in the container.
197  @returns common::Status with success or error information.
198  @remarks Returns immediately on error.
199  */
200  static common::Status ForEachWithIndex(const ConstPointerContainer<std::vector<NodeArg*>>& node_args,
201  std::function<common::Status(const NodeArg& arg, size_t index)> func) {
202  for (size_t index = 0; index < node_args.size(); ++index) {
203  auto arg = node_args[index];
204  if (!arg->Exists())
205  continue;
206  ORT_RETURN_IF_ERROR(func(*arg, index));
207  }
208  return common::Status::OK();
209  }
210 
211  /** Gets the count of arguments for each of the Node's explicit inputs. */
212  const std::vector<int>& InputArgCount() const noexcept { return definitions_.input_arg_count; }
213 
214  /** Gets the Node's input definitions.
215  @remarks requires ConstPointerContainer wrapper to apply const to the NodeArg pointers so access is read-only. */
218  }
219 
220  /** Gets the implicit inputs to this Node.
221  If this Node contains a subgraph, these are the NodeArg's that are implicitly consumed by Nodes within that
222  subgraph. e.g. If and Loop operators.*/
225  }
226 
227  /** Gets the Node's output definitions.
228  @remarks requires ConstPointerContainer wrapper to apply const to the NodeArg pointers so access is read-only. */
231  }
232 
233 #if !defined(ORT_MINIMAL_BUILD)
234  /**
235  Helper to iterate through the container returned by #MutableInputDefs() or #MutableOutputDefs() and call the provided function.
236  @param node_args Collection of NodeArgs returned by #MutableInputDefs() or #MutableOutputDefs()
237  @param func Function to call for each valid NodeArg in the node_args. The function is called with the NodeArg
238  and the index number in the container.
239  @returns common::Status with success or error information.
240  @remarks Returns immediately on error.
241  */
242  static common::Status ForEachMutableWithIndex(std::vector<NodeArg*>& node_args,
243  std::function<common::Status(NodeArg& arg, size_t index)> func) {
244  for (size_t index = 0; index < node_args.size(); ++index) {
245  auto arg = node_args[index];
246  if (!arg->Exists())
247  continue;
248  ORT_RETURN_IF_ERROR(func(*arg, index));
249  }
250  return common::Status::OK();
251  }
252 
253  /** Gets a modifiable collection of the Node's implicit input definitions. */
254  std::vector<NodeArg*>& MutableImplicitInputDefs() noexcept {
255  return definitions_.implicit_input_defs;
256  }
257 #endif // !defined(ORT_MINIMAL_BUILD)
258 
259 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
260  /** Gets a modifiable count of arguments for each of the Node's explicit inputs.
261  @todo This should be removed in favor of a method that updates the input args and the count.
262  Currently these operations are separate which is not a good setup. */
263  std::vector<int>& MutableInputArgsCount() { return definitions_.input_arg_count; }
264 
265  /** Gets a modifiable collection of the Node's input definitions. */
266  std::vector<NodeArg*>& MutableInputDefs() noexcept {
267  return definitions_.input_defs;
268  }
269 
270  /** Gets a modifiable collection of the Node's output definitions. */
271  std::vector<NodeArg*>& MutableOutputDefs() noexcept {
272  return definitions_.output_defs;
273  }
274 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
275 
276  /** Struct to provide sorting between EdgeEnd instances based on NodeIndex first, and NodeArg::Name second. */
277  struct EdgeEndCompare {
278  bool operator()(const EdgeEnd& lhs, const EdgeEnd& rhs) const {
279  if (lhs.GetNode().Index() == rhs.GetNode().Index()) {
280  if (lhs.GetSrcArgIndex() == rhs.GetSrcArgIndex()) {
281  return lhs.GetDstArgIndex() < rhs.GetDstArgIndex();
282  }
283  return lhs.GetSrcArgIndex() < rhs.GetSrcArgIndex();
284  }
285  return lhs.GetNode().Index() < rhs.GetNode().Index();
286  }
287  };
288 
289  using EdgeSet = std::set<EdgeEnd, EdgeEndCompare>;
290  using EdgeConstIterator = EdgeSet::const_iterator;
291 
292  /**
293  @class NodeConstIterator
294  Class to provide const access to Node instances iterated via an EdgeConstIterator. */
296  public:
298 
299  bool operator==(const NodeConstIterator& p_other) const;
300 
301  bool operator!=(const NodeConstIterator& p_other) const;
302 
303  void operator++();
304  void operator--();
305 
306  const Node& operator*() const;
307  const Node* operator->() const;
308 
309  private:
310  EdgeConstIterator m_iter;
311  };
312 
313  // Functions defined to traverse a Graph as below.
314 
315  /** Gets an iterator to the beginning of the input nodes to this Node. */
316  NodeConstIterator InputNodesBegin() const noexcept { return NodeConstIterator(relationships_.input_edges.cbegin()); };
317  /** Gets an iterator to the end of the input nodes to this Node. */
318  NodeConstIterator InputNodesEnd() const noexcept { return NodeConstIterator(relationships_.input_edges.cend()); }
319 
320  /** Gets an iterator to the beginning of the output nodes from this Node. */
322  return NodeConstIterator(relationships_.output_edges.cbegin());
323  }
324 
325  /** Gets an iterator to the end of the output nodes from this Node. */
326  NodeConstIterator OutputNodesEnd() const noexcept { return NodeConstIterator(relationships_.output_edges.cend()); }
327 
328  /** Gets an iterator to the beginning of the input edges to this Node.
329  @remarks There are no nullptr entries in this collection. */
330  EdgeConstIterator InputEdgesBegin() const noexcept { return relationships_.input_edges.cbegin(); }
331 
332  /** Gets an iterator to the end of the input edges to this Node. */
333  EdgeConstIterator InputEdgesEnd() const noexcept { return relationships_.input_edges.cend(); }
334 
335  /** Gets an iterator to the beginning of the output edges from this Node.
336  @remarks There are no nullptr entries in this collection. */
337  EdgeConstIterator OutputEdgesBegin() const noexcept { return relationships_.output_edges.cbegin(); }
338 
339  /** Gets an iterator to the end of the output edges from this Node. */
340  EdgeConstIterator OutputEdgesEnd() const noexcept { return relationships_.output_edges.cend(); }
341 
342  /** Gets the Node's control inputs. */
343  const std::set<std::string>& ControlInputs() const noexcept { return relationships_.control_inputs; }
344 
345  /** Gets the number of input edges to this Node */
346  size_t GetInputEdgesCount() const noexcept { return relationships_.input_edges.size(); }
347 
348  /** Gets the number of output edges from this Node */
349  size_t GetOutputEdgesCount() const noexcept { return relationships_.output_edges.size(); }
350 
351  /** Adds an AttributeProto to this Node.
352  @remarks The attribute name is used as the key in the attribute map. */
353  void AddAttributeProto(ONNX_NAMESPACE::AttributeProto value);
354 
355  // keep this signature in sync with ADD_ATTR_SINGLE_INTERFACE below
356  /** Adds an attribute to this Node with the specified attribute name and value. */
357  void AddAttribute(std::string attr_name, int64_t value);
358 
359  // keep this signature in sync with ADD_ATTR_LIST_INTERFACE below
360  /** Adds an attribute to this Node with the specified attribute name and values. */
361  void AddAttribute(std::string attr_name, gsl::span<const int64_t> values);
362 
363 #define ADD_ATTR_SINGLE_INTERFACE(Type) \
364  void AddAttribute(std::string attr_name, Type value)
365 
366 #define ADD_ATTR_LIST_INTERFACE(Type) \
367  void AddAttribute(std::string attr_name, gsl::span<const Type> values)
368 
369 #define ADD_ATTR_INTERFACES(Type) \
370  ADD_ATTR_SINGLE_INTERFACE(Type); \
371  ADD_ATTR_LIST_INTERFACE(Type)
372 
373  ADD_ATTR_INTERFACES(float);
375  ADD_ATTR_INTERFACES(ONNX_NAMESPACE::TensorProto);
376 #if !defined(DISABLE_SPARSE_TENSORS)
377  ADD_ATTR_INTERFACES(ONNX_NAMESPACE::SparseTensorProto);
378 #endif
379  ADD_ATTR_INTERFACES(ONNX_NAMESPACE::TypeProto);
380 
381  ADD_ATTR_SINGLE_INTERFACE(ONNX_NAMESPACE::GraphProto);
382 
383 #undef ADD_ATTR_SINGLE_INTERFACE
384 #undef ADD_ATTR_LIST_INTERFACE
385 #undef ADD_ATTR_INTERFACES
386 
387  // The below overload is made so the compiler does not attempt to resolve
388  // string literals with the gsl::span overload
389  template <size_t N>
390  void AddAttribute(std::string attr_name, const char (&value)[N]) {
391  this->AddAttribute(std::move(attr_name), std::string(value, N - 1));
392  }
393 
394  /** Gets the Node's attributes. */
395  const NodeAttributes& GetAttributes() const noexcept { return attributes_; }
396 
397 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
398  /** Remove the specified attribute from this Node */
399  bool ClearAttribute(const std::string& attr_name);
400 
401  /** Gets the Node's mutable attributes. */
402  NodeAttributes& GetMutableAttributes() noexcept { return attributes_; }
403 
404 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
405 
406  /**
407  * Clears removable attributes. These are no longer needed after the initialization
408  * of the session. The function returns the number of removed attributes.
409  */
410  int PruneRemovableAttributes(gsl::span<const std::string> removable_attributes);
411 
412 #if !defined(ORT_MINIMAL_BUILD)
413 
414  /** Gets the Graph instance that is instantiated from a GraphProto attribute during Graph::Resolve.
415  @param attr_name Attribute name for the GraphProto attribute.
416  @returns nullptr if the Graph instance has not been instantiated or attribute does not contain a GraphProto.
417  */
418  const Graph* GetGraphAttribute(const std::string& attr_name) const;
419 
420  /** Gets the mutable Graph instance that is instantiated from a GraphProto attribute during Graph::Resolve.
421  @param attr_name Attribute name for the GraphProto attribute.
422  @returns nullptr if the Graph instance has not been instantiated or attribute does not contain a GraphProto.
423  */
424  Graph* GetMutableGraphAttribute(const std::string& attr_name);
425 #endif // !defined(ORT_MINIMAL_BUILD)
426 
427  /** Checks if the Node contains at least one subgraph (this is the case for control flow operators, such as If, Scan, Loop).
428  @returns true if the Node contains a subgraph.
429  */
430  bool ContainsSubgraph() const {
431  return !attr_to_subgraph_map_.empty();
432  }
433 
434  /** Get the const subgraphs from a node.
435  @remarks Creates a new vector so calling ContainsSubgraphs first is preferred. */
436  std::vector<gsl::not_null<const Graph*>> GetSubgraphs() const;
437 
438  /** Gets a map of attribute name to the mutable Graph instances for all subgraphs of the Node.
439  @returns Map of the attribute name that defines the subgraph to the subgraph's Graph instance.
440  nullptr if the Node has no subgraphs.
441  */
442  const std::unordered_map<std::string, gsl::not_null<Graph*>>& GetAttributeNameToMutableSubgraphMap() {
443  return attr_to_subgraph_map_;
444  }
445 
446  /** Gets a map of attribute name to the mutable Graph instances for all subgraphs of the Node.
447  * @returns a mutable map of mutable subgraphs.
448  */
449  std::unordered_map<std::string, gsl::not_null<Graph*>>& GetMutableMapOfAttributeNameToSubgraph() {
450  return attr_to_subgraph_map_;
451  }
452 
453  /** Gets a map of attribute name to the const Graph instances for all subgraphs of the Node.
454  @returns Map of the attribute name that defines the subgraph to the subgraph's Graph instance.
455  nullptr if the Node has no subgraphs.
456  */
457  std::unordered_map<std::string, gsl::not_null<const Graph*>> GetAttributeNameToSubgraphMap() const;
458 
459  /** Gets the execution ProviderType that this node will be executed by. */
460  ProviderType GetExecutionProviderType() const noexcept { return execution_provider_type_; }
461 
462  /** Sets the execution ProviderType that this Node will be executed by. */
463  void SetExecutionProviderType(ProviderType execution_provider_type) {
464  execution_provider_type_ = execution_provider_type;
465  }
466 
467  /** Call the provided function for all explicit inputs, implicit inputs, and outputs of this Node.
468  If the NodeArg is an explicit or implicit input, is_input will be true when func is called.
469  @param include_missing_optional_defs Include NodeArgs that are optional and were not provided
470  i.e. NodeArg::Exists() == false.
471  */
472  void ForEachDef(std::function<void(const onnxruntime::NodeArg&, bool is_input)> func,
473  bool include_missing_optional_defs = false) const;
474 
475 #if !defined(ORT_MINIMAL_BUILD)
476  /** Replaces any matching definitions in the Node's explicit inputs or explicit outputs.
477  @param replacements Map of current NodeArg to replacement NodeArg.
478  */
479  void ReplaceDefs(const std::map<const onnxruntime::NodeArg*, onnxruntime::NodeArg*>& replacements);
480 
481  /** Gets the NodeProto representation of this Node.
482  @param update_subgraphs Update the GraphProto values for any subgraphs in the returned NodeProto.
483  If graph optimization has been run this is most likely required
484  to ensure the complete Graph is valid.
485  */
486  void ToProto(ONNX_NAMESPACE::NodeProto& proto, bool update_subgraphs = false) const;
487 
488  Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder,
489  flatbuffers::Offset<onnxruntime::fbs::Node>& fbs_node) const;
490 
491  flatbuffers::Offset<onnxruntime::fbs::NodeEdge>
492  SaveEdgesToOrtFormat(flatbuffers::FlatBufferBuilder& builder) const;
493 
494  void SetFunctionTemplate(const FunctionTemplate& func_template);
495 #endif
496 
497  static Status LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node, Graph& graph,
498  const OrtFormatLoadOptions& load_options,
499  const logging::Logger& logger, std::unique_ptr<Node>& node);
500 
501  Status LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node,
502  const OrtFormatLoadOptions& load_options,
503  const logging::Logger& logger);
504  Status LoadEdgesFromOrtFormat(const onnxruntime::fbs::NodeEdge& fbs_node_edgs, const Graph& graph);
505 
506  /**
507  @class Definitions
508  The input and output definitions for this Node.
509  */
510  class Definitions {
511  public:
512  Definitions() = default;
513 
514  /** The Node's explicit input definitions. */
515  std::vector<NodeArg*> input_defs;
516 
517  /**
518  The number of inputs for each argument of the operator or function which this node refers.
519  @remarks For example, #input_defs has 10 elements (inputs), and #input_arg_count is {4, 6}.
520  This means that 4 elements (inputs) of input_defs map to the first argument of the operator or function, and
521  the other 6 map to the second argument.
522  */
523  std::vector<int> input_arg_count;
524 
525  /** The Node's output definitions. */
526  std::vector<NodeArg*> output_defs;
527 
528  /** The Node's implicit input definitions if the Node contains one or more subgraphs
529  (i.e. GraphProto attributes) and the subgraph/s implicitly consume these values.
530  @remarks For example, a subgraph in an 'If' node gets all its input values via this mechanism rather than
531  there being explicit inputs to the 'If' node that are passed to the subgraph.
532  They are pseudo-inputs to this Node as it has an implicit dependency on them. */
533  std::vector<NodeArg*> implicit_input_defs;
534 
536 
537  private:
538  };
539 
540  /**
541  @class Relationships
542  Defines the relationships between this Node and other Nodes in the Graph.
543  */
545  public:
546  Relationships() = default;
547 
548  void Clear() noexcept {
549  input_edges.clear();
550  output_edges.clear();
551  control_inputs.clear();
552  }
553 
554  /** The edges for Nodes that provide inputs to this Node. */
556 
557  /** The edges for Nodes that receive outputs from this Node. */
559 
560  /** The Node names of the control inputs to this Node. */
561  std::set<std::string> control_inputs;
562 
563  private:
564  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Relationships);
565  };
566 
567  // NOTE: This friendship relationship should ONLY be used for calling methods of the Node class and not accessing
568  // the data members directly, so that the Node can maintain its internal invariants.
569  friend class Graph;
570  Node(NodeIndex index, Graph& graph) : index_(index), graph_(&graph), can_be_saved_(true) {}
571 
572  private:
574 
575 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
576  void Init(std::string_view name,
577  std::string_view op_type,
578  std::string_view description,
579  gsl::span<NodeArg* const> input_args,
580  gsl::span<NodeArg* const> output_args,
581  const NodeAttributes* attributes,
582  std::string_view domain);
583 #endif
584 
585 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
586  // internal only method to allow selected classes to directly alter the input/output definitions and arg counts
587  Definitions& MutableDefinitions() noexcept;
588 
589  // internal only method to allow selected classes to directly alter the links between nodes.
590  Relationships& MutableRelationships() noexcept;
591 
592  void SetNodeType(Node::Type node_type) noexcept { node_type_ = node_type; }
593 #endif
594 
595  // create a Graph instance for an attribute that contains a GraphProto
596  void CreateSubgraph(const std::string& attr_name);
597 
598  std::vector<std::unique_ptr<Graph>>& MutableSubgraphs() noexcept { return subgraphs_; }
599 
600  // validate and update the input arg count
601  common::Status UpdateInputArgCount();
602 
603  const Definitions& GetDefinitions() const noexcept { return definitions_; }
604  const Relationships& GetRelationships() const noexcept { return relationships_; }
605 
606  // Node index. Default to impossible value rather than 0.
608 
609  // Node name.
610  std::string name_;
611 
612  // Node operator type.
613  std::string op_type_;
614 
615  // OperatorSet domain of op_type_.
616  std::string domain_;
617 
618 #if !defined(ORT_MINIMAL_BUILD)
619  // OperatorSchema that <*this> node refers to.
620  const ONNX_NAMESPACE::OpSchema* op_ = nullptr;
621 
622  // Reference to the function template defined in the model.
623  const FunctionTemplate* func_template_ = nullptr;
624 #endif
625 
626  // Execution priority, lower value for higher priority
627  int priority_ = 0;
628 
629  // set from op_->SinceVersion() or via deserialization when OpSchema is not available
630  int since_version_ = -1;
631 
632  Node::Type node_type_ = Node::Type::Primitive;
633 
634  // The function body is owned by graph_
635  std::unique_ptr<Function> func_body_ = nullptr;
636 
637  // Node doc string.
638  std::string description_;
639 
640  // input/output defs and arg count
641  Definitions definitions_;
642 
643  // Relationships between this node and others in the graph
644  Relationships relationships_;
645 
646  // Device.
647  std::string execution_provider_type_;
648 
649  // Map from attribute name to attribute.
650  // This allows attribute adding and removing.
651  NodeAttributes attributes_;
652 
653  // Graph that contains this Node
654  Graph* graph_ = nullptr;
655 
656  // Map of attribute name to the Graph instance created from the GraphProto attribute
657  std::unordered_map<std::string, gsl::not_null<Graph*>> attr_to_subgraph_map_;
658 
659  // Graph instances for subgraphs that are owned by this Node
660  std::vector<std::unique_ptr<Graph>> subgraphs_;
661 
// Can be saved? The node cannot be saved anymore if removable attributes have been cleared.
// Default member initializer so that every constructor (including the defaulted Node() and the
// Init-based constructor) leaves this flag in a defined state instead of an indeterminate one.
bool can_be_saved_ = true;
664 };
665 
666 /**
667 @class Graph
668 The Graph representation containing the graph inputs and outputs, the Node instances,
669 and the edges connecting the nodes.
670 */
671 class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve existing data member order for readability
672  public:
673  /** Gets the Graph name. */
674  const std::string& Name() const noexcept;
675 
676  /** Gets the Graph description. */
677  const std::string& Description() const noexcept;
678 
679  /** Gets the path of the owning model, if any. */
680  const Path& ModelPath() const;
681 
682  /** Returns true if this is a subgraph or false if it is a high-level graph. */
683  bool IsSubgraph() const { return parent_graph_ != nullptr; }
684 
685  /** Returns the parent graph if this is a subgraph */
686  const Graph* ParentGraph() const { return parent_graph_; }
687 
688  /** Returns the mutable parent graph if this is a subgraph */
689  Graph* MutableParentGraph() { return parent_graph_; }
690 
691  /** Returns the strict_shape_type_inference that was passed into the constructor. */
692  bool StrictShapeTypeInference() const { return strict_shape_type_inference_; }
693 
694 #if !defined(ORT_MINIMAL_BUILD)
695  /** Sets the Graph name. */
696  void SetName(const std::string& name);
697 
698  /** Gets the Graph description. */
699  void SetDescription(const std::string& description);
700 
701  /** Replaces the initializer tensor with the same name as the given initializer tensor.
702  The replacement initializer tensor must have the same type and shape as the existing initializer tensor.
703 
704  Note: This currently has linear time complexity. There is room for improvement but it would likely require changes to
705  how initializer tensors are stored and tracked.
706  */
707  common::Status ReplaceInitializedTensor(ONNX_NAMESPACE::TensorProto new_initializer);
708 
709 #if !defined(DISABLE_EXTERNAL_INITIALIZERS)
710  /** This function takes externally provided data for initializers with external data
711  * and replaces graph initializers with its content.
712  */
714 #endif // !defined(DISABLE_EXTERNAL_INITIALIZERS)
715 
716 #endif // !defined(ORT_MINIMAL_BUILD)
717 
718 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
719  /** Add an initializer tensor to the Graph. */
720  void AddInitializedTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto);
721 #endif
722 
723  /** Remove the initializer tensor with the provided name from the Graph. */
724  void RemoveInitializedTensor(const std::string& tensor_name);
725 
726  /** Check if a given name is an initializer tensor's name in this graph. */
727  bool IsInitializedTensor(const std::string& name) const;
728 
729 #if !defined(DISABLE_SPARSE_TENSORS)
730  /** Check if a given name is a sparse initializer's name in the model
731  * we currently convert sparse_initializer field in the model into dense Tensor instances.
732  * However, we sometimes want to check if this initializer was stored as sparse in the model.
733  */
734  bool IsSparseInitializer(const std::string& name) const;
735 #endif
736 
737  /** Gets an initializer tensor with the provided name.
738  @param[out] value Set to the TensorProto* if the initializer is found, or nullptr if not.
739  @returns True if found.
740  */
741  bool GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const;
742 
743  /** Gets all the initializer tensors in this Graph. */
744  const InitializedTensorSet& GetAllInitializedTensors() const noexcept { return name_to_initial_tensor_; }
745 
746  /** Removes all initializer tensors from this Graph and releases the memory they were using. */
747  void CleanAllInitializedTensors() noexcept;
748 
749  /** Returns true if an initializer value can be overridden by a graph input with the same name. */
750  bool CanOverrideInitializer() const noexcept { return ir_version_ >= 4; }
751 
752  /** returns the initializer's TensorProto if 'name' is an initializer, is constant and
753  cannot be overridden at runtime. If the initializer is not found or is not constant, a nullptr is returned.
754  @param check_outer_scope If true and the graph is a subgraph,
755  check ancestor graph/s for 'name' if not found in 'graph'.
756  @remarks check_outer_scope of true is not supported in a minimal build
757  */
758  const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, bool check_outer_scope) const;
759 
760  /** returns the initializer's TensorProto if 'name' is an initializer (both constant and overridable).
761  If the initializer is not found, a nullptr is returned.
762  @param check_outer_scope If true and the graph is a subgraph,
763  check ancestor graph/s for 'name' if not found in 'graph'.
764  @remarks check_outer_scope of true is not supported in a minimal build
765  */
766  const ONNX_NAMESPACE::TensorProto* GetInitializer(const std::string& name, bool check_outer_scope) const;
767 
768  /** Gets the Graph inputs excluding initializers.
769  These are the required inputs to the Graph as the initializers can be optionally overridden via graph inputs.
770  @remarks Contains no nullptr values. */
771  const std::vector<const NodeArg*>& GetInputs() const noexcept { return graph_inputs_excluding_initializers_; }
772 
773  /** Gets the Graph inputs including initializers.
774  This is the full set of inputs, in the same order as defined in the GraphProto.
775  @remarks Contains no nullptr values. */
776  const std::vector<const NodeArg*>& GetInputsIncludingInitializers() const noexcept {
777  return graph_inputs_including_initializers_;
778  }
779 
780  /** Return true if "node_arg" is a input or an initializer. Otherwise, returns false. */
781  bool IsInputsIncludingInitializers(const NodeArg* node_arg) const noexcept {
782  return std::find(graph_inputs_including_initializers_.begin(),
783  graph_inputs_including_initializers_.end(), node_arg) != graph_inputs_including_initializers_.end();
784  }
785 
786  /** Gets the Graph inputs that are initializers
787  These are overridable initializers. This is a difference between
788  graph_inputs_including_initializers_ and graph_inputs_excluding_initializers_
789  @remarks Contains no nullptr values. */
790  const std::vector<const NodeArg*>& GetOverridableInitializers() const {
791  return graph_overridable_initializers_;
792  }
793 
794  /** Gets the Graph outputs.
795  @remarks Contains no nullptr values.*/
796  const std::vector<const NodeArg*>& GetOutputs() const noexcept { return graph_outputs_; }
797 
798  bool IsOutput(const NodeArg* node_arg) const noexcept {
799  return std::find(graph_outputs_.begin(), graph_outputs_.end(), node_arg) != graph_outputs_.end();
800  }
801 
802  /** Returns true if one or more of the Node outputs are Graph outputs.
803  @remarks Cheaper than calling GetNodeOutputsInGraphOutputs.
804  */
805  bool NodeProducesGraphOutput(const Node& node) const {
806  auto end_outputs = graph_outputs_.cend();
807  for (auto output_def : node.OutputDefs()) {
808  if (std::find(graph_outputs_.cbegin(), end_outputs, output_def) != end_outputs) {
809  return true;
810  }
811  }
812  return false;
813  }
814 
815  /** Returns a vector with the indexes of the outputs of the given Node that are also Graph outputs. */
816  std::vector<int> GetNodeOutputsInGraphOutputs(const Node& node) const {
817  int output_idx = 0;
818  std::vector<int> indexes;
819  for (auto output_def : node.OutputDefs()) {
820  if (std::find(GetOutputs().cbegin(), GetOutputs().cend(), output_def) != GetOutputs().cend()) {
821  indexes.push_back(output_idx);
822  }
823 
824  ++output_idx;
825  }
826 
827  return indexes;
828  }
829 
830  /** Gets the NodeArgs that represent value_info instances in the Graph.
831  These are the values that are neither Graph inputs nor outputs.
832  @remarks Contains no nullptr values. */
833  const std::unordered_set<const NodeArg*>& GetValueInfo() const noexcept { return value_info_; }
834 
835 #if !defined(ORT_MINIMAL_BUILD)
836  void AddValueInfo(const NodeArg* new_value_info);
837 #endif
838 
839  /** Gets the Node with the specified node index.
840  @returns Node instance if found. nullptr if node_index is invalid or node has been freed.
841  */
842  const Node* GetNode(NodeIndex node_index) const { return NodeAtIndexImpl(node_index); }
843 
844  /** Gets the mutable Node with the specified node index.
845  @returns Mutable Node instance if found. nullptr if node_index is invalid or node has been freed.
846  */
847  Node* GetNode(NodeIndex node_index) { return NodeAtIndexImpl(node_index); }
848 
849  /** Get a GraphNodes instance that provides mutable access to all valid Nodes in the Graph. */
850  GraphNodes& Nodes() noexcept { return iterable_nodes_; }
851 
852  /** Get a GraphNodes instance that provides const access to all valid Nodes in the Graph. */
853  const GraphNodes& Nodes() const noexcept { return iterable_nodes_; }
854 
855  /** Get a ConstGraphNodes instance that provides access to a filtered set of valid Nodes in the Graph.
856  @remarks We can't use GraphNodes as that would provide mutable access to the nodes by default, and we can't prevent
857  that by returning a const instance of GraphNodes as we're creating a new instance here due to the filter
858  being something we don't control (i.e. we have to return a new instance so it can't be const).
859  */
// NOTE(review): listing line 860 (this function's signature) is missing from this rendering.
// The body below builds a ConstGraphNodes over nodes_ using the moved-in filter_func.
861  return ConstGraphNodes(nodes_, std::move(filter_func));
862  }
863 
864  /** Gets the maximum NodeIndex value used in the Graph.
865  WARNING: This actually returns the max index value used + 1.
866  */
867  int MaxNodeIndex() const noexcept { return static_cast<int>(nodes_.size()); } // assume the casting won't overflow
868 
869  /** Gets the number of valid Nodes in the Graph.
870  @remarks This may be smaller than MaxNodeIndex(), as Nodes may be removed during optimization.
871  */
872  int NumberOfNodes() const noexcept { return num_of_nodes_; }
873 
874  /** Gets the mutable NodeArg with the provided name.
875  @returns Pointer to NodeArg if found, nullptr if not. */
// NOTE(review): listing line 876 (this function's signature) is missing from this rendering.
// The body only looks 'name' up in node_args_; unlike GetOrCreateNodeArg it never creates an entry.
877  auto iter = node_args_.find(name);
878  if (iter != node_args_.end()) {
879  return iter->second.get();
880  }
881  return nullptr;
882  }
883 
884  /** Gets the const NodeArg with the provided name.
885  @returns Pointer to const NodeArg if found, nullptr if not. */
886  const NodeArg* GetNodeArg(const std::string& name) const {
887  return const_cast<Graph*>(this)->GetNodeArg(name);
888  }
889 
890  // search this and up through any parent_graph_ instance for a NodeArg
891  NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name);
892 
893  /** Gets a mutable NodeArg by name. Creates a new NodeArg that is owned by this Graph if not found.
894  @param name The NodeArg name.
895  @param[in] p_arg_type Optional TypeProto to use if the NodeArg needs to be created.
896  @returns NodeArg reference.
897  */
898  NodeArg& GetOrCreateNodeArg(const std::string& name, const ONNX_NAMESPACE::TypeProto* p_arg_type) {
899  auto insert_result = node_args_.emplace(name, nullptr);
900  if (insert_result.second) {
901  insert_result.first->second = std::make_unique<NodeArg>(name, p_arg_type);
902  }
903  return *(insert_result.first->second);
904  }
905 
906 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
907  /** Generate a unique name in this Graph for a NodeArg */
908  std::string GenerateNodeArgName(const std::string& base_name);
909 
910  /** Generate a unique name in this Graph for a Node */
911  std::string GenerateNodeName(const std::string& base_name);
912 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
913 
914 #if !defined(ORT_MINIMAL_BUILD)
915  /** Copy a Node and add it to this Graph.
916  @param other Node to copy
917  @returns Reference to the Node that was created and added to this Graph.
918  @remarks Do not call AddNode and Remove Node concurrently as they are not thread-safe.
919  */
920  Node& AddNode(const Node& other);
921 #endif
922 
923 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
924  /** Add a Node to this Graph.
925  @param name The Node name. Must be unique in this Graph.
926  @param op_type The operator type. e.g. ONNX operator name.
927  @param description Arbitrary description of the Node.
928  @param input_args The explicit inputs to this Node.
929  @param output_args The outputs from this Node.
930  @param attributes Optional NodeAttributes to add.
931  @param domain The domain for the op_type.
932  @returns Reference to the new Node.
933  @remarks Do not call AddNode and Remove Node concurrently as they are not thread-safe.
934  */
935  Node& AddNode(const std::string& name,
936  const std::string& op_type,
937  const std::string& description,
938  gsl::span<NodeArg* const> input_args,
939  gsl::span<NodeArg* const> output_args,
940  const NodeAttributes* attributes = nullptr,
941  const std::string& domain = kOnnxDomain);
942 
// NOTE(review): the first line of each convenience AddNode overload below (listing lines 943, 956 and 969)
// is missing from this rendering. Each overload converts its std::initializer_list argument(s) with AsSpan
// and forwards to the gsl::span-based AddNode.
// Overload 1: input_args and output_args are both initializer_lists.
944  const std::string& op_type,
945  const std::string& description,
946  std::initializer_list<NodeArg*> input_args,
947  std::initializer_list<NodeArg*> output_args,
948  const NodeAttributes* attributes = nullptr,
949  const std::string& domain = kOnnxDomain) {
950  return AddNode(name, op_type, description,
951  AsSpan(input_args),
952  AsSpan(output_args),
953  attributes, domain);
954  }
955 
// Overload 2: input_args is a span, output_args is an initializer_list.
957  const std::string& op_type,
958  const std::string& description,
959  gsl::span<NodeArg* const> input_args,
960  std::initializer_list<NodeArg*> output_args,
961  const NodeAttributes* attributes = nullptr,
962  const std::string& domain = kOnnxDomain) {
963  return AddNode(name, op_type, description,
964  input_args,
965  AsSpan(output_args),
966  attributes, domain);
967  }
968 
// Overload 3: input_args is an initializer_list, output_args is a span.
970  const std::string& op_type,
971  const std::string& description,
972  std::initializer_list<NodeArg*> input_args,
973  gsl::span<NodeArg* const> output_args,
974  const NodeAttributes* attributes = nullptr,
975  const std::string& domain = kOnnxDomain) {
976  return AddNode(name, op_type, description,
977  AsSpan(input_args),
978  output_args,
979  attributes, domain);
980  }
981 
982  /** Remove a Node from this Graph and free it.
983  The output edges of this specified node MUST have been removed before removing the node.
984  The input edges of this specified node is removed while removing the node. The process of
985  removing a node from a graph should be,
986  1. Remove out edges of this specified node.
987  2. Remove this specified node.
988  3. Add new input edges connected with all out nodes.
989  @returns true if the node_index was valid
990  @remarks Do not call AddNode and Remove Node concurrently as they are not thread-safe.
991  */
992  bool RemoveNode(NodeIndex node_index);
993 
994  /** Add an edge between two Nodes.
995  @param src_node_index NodeIndex of source Node that is providing output to the destination Node.
996  @param dst_node_index NodeIndex of destination Node that is receiving input from the source Node.
997  @param src_arg_index node arg index of source node.
998  @param dst_arg_index node arg index of destination node.
999  */
1000  void AddEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index);
1001 
1002  /** Remove an edge between two Nodes.
1003  @param src_node_index NodeIndex of source Node to remove an output edge from.
1004  @param dst_node_index NodeIndex of destination Node to remove an input edge from.
1005  @param src_arg_index node arg index of source node.
1006  @param dst_arg_index node arg index of destination node.
1007  */
1008  void RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index);
1009 #endif
1010 
1011 #if !defined(ORT_MINIMAL_BUILD)
1012  /**
1013  Add a control edge between two Nodes in this Graph.
1014  The source Node does not produce output that is directly consumed by the destination Node, however the
1015  destination Node must execute after the source node. The control edge allows this ordering to occur.
1016  */
1017  bool AddControlEdge(NodeIndex src_node_index, NodeIndex dst_node_index);
1018 #endif // !defined(ORT_MINIMAL_BUILD)
1019 
1020  /** Mark the Graph as needing Resolve() to be called.
1021  This should be done after modifying any aspect of the Graph that changes the Nodes or relationships between them. */
// NOTE(review): listing line 1022 (this function's signature) is missing from this rendering.
// The body sets graph_resolve_needed_ and returns *this so calls can be chained.
1023  graph_resolve_needed_ = true;
1024  return *this;
1025  }
1026 
1027  /** Gets flag indicating whether Graph::Resolve needs to be called before using the Graph. */
1028  bool GraphResolveNeeded() const noexcept {
1029  return graph_resolve_needed_;
1030  }
1031 
1032  /** Sets flag that Graph::graph_proto_ needs to be updated to reflect changes in the Graph. */
// NOTE(review): listing line 1033 (this function's signature) is missing from this rendering.
// The body sets graph_proto_sync_needed_ and returns *this so calls can be chained.
1034  graph_proto_sync_needed_ = true;
1035  return *this;
1036  }
1037 
1038  /** Gets flag indicating whether Graph::graph_proto_ needs to be synchronized with this Graph instance. */
1039  bool GraphProtoSyncNeeded() const noexcept {
1040  return graph_proto_sync_needed_;
1041  }
1042 
1043  /** Performs a reverse depth-first search (DFS) traversal from a set of nodes, via their inputs,
1044  up to their source node/s.
1045  @param from NodeIndex values for a set of Nodes to traverse from.
1046  @param enter Visit function that will be invoked on a node when it is visited but its parents haven't been.
1047  @param leave Visit function invoked on the node after its parents have all been visited.
1048  @param comp Comparison function to stabilize the traversal order by making Node ordering deterministic.
1049  */
1050  void ReverseDFSFrom(gsl::span<NodeIndex const> from,
1051  const std::function<void(const Node*)>& enter,
1052  const std::function<void(const Node*)>& leave,
1053  const std::function<bool(const Node*, const Node*)>& comp = {}) const;
1054 
1055  /** Performs a reverse depth-first search (DFS) traversal from a set of nodes, via their inputs,
1056  up to their source node/s.
1057  @param from Set of Nodes to traverse from.
1058  @param enter Visit function that will be invoked on a node when it is visited but its parents haven't been.
1059  @param leave Visit function invoked on the node after its parents have all been visited.
1060  @param comp Comparison function to stabilize the traversal order by making Node ordering deterministic.
1061  */
1062  void ReverseDFSFrom(gsl::span<const Node* const> from,
1063  const std::function<void(const Node*)>& enter,
1064  const std::function<void(const Node*)>& leave,
1065  const std::function<bool(const Node*, const Node*)>& comp = {}) const;
1066 
1067  /** Performs a reverse depth-first search (DFS) traversal from a set of nodes, via their inputs,
1068  up to their source node/s.
1069  @param from Set of Nodes to traverse from.
1070  @param enter Visit function that will be invoked on a node when it is visited but its parents haven't been.
1071  @param leave Visit function invoked on the node after its parents have all been visited.
1072  @param stop Stop traversal from node n to input node p if stop(n, p) is true.
1073  @param comp Comparison function to stabilize the traversal order by making Node ordering deterministic.
1074  */
1075  void ReverseDFSFrom(gsl::span<const Node* const> from,
1076  const std::function<void(const Node*)>& enter,
1077  const std::function<void(const Node*)>& leave,
1078  const std::function<bool(const Node*, const Node*)>& comp,
1079  const std::function<bool(const Node*, const Node*)>& stop) const;
1080 
1081 #if !defined(ORT_MINIMAL_BUILD)
1082  /** Performs topological sort with Kahn's algorithm on the graph/s.
1083  @param enter Visit function that will be invoked on a node when it is visited.
1084  @param comp Comparison function to stabilize the traversal order by making Node ordering deterministic.
1085  */
1086  void KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
1087  const std::function<bool(const Node*, const Node*)>& comp) const;
1088 
1089 #endif
1090 
1091  /** Gets the map of operator domains to their opset versions. */
1092  const std::unordered_map<std::string, int>& DomainToVersionMap() const noexcept {
1093  return domain_to_version_;
1094  }
1095 
1096 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1097  /**
1098  Create a single Node that will be the result of the a fusion of multiple nodes in this Graph.
1099  @param sub_graph A IndexSubGraph instance with details of the nodes to fuse.
1100  @param fused_node_name The name for the new Node.
1101  @returns Node with fused subgraph.
1102  @remarks As a new Graph instance for the fused nodes is not created, a GraphViewer can be constructed with the
1103  IndexedSubGraph information to provide a view of the subgraph. The original nodes are left in place
1104  while this is in use.
1105  Call FinalizeFuseSubGraph to remove them once the fused replacement node is fully created.
1106  */
1107  Node& BeginFuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name);
1108 
1109  void FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_node);
1110 #endif
1111 
1112 #if !defined(ORT_MINIMAL_BUILD)
1113  /** Gets the GraphProto representation of this Graph. */
1114  const ONNX_NAMESPACE::GraphProto& ToGraphProto();
1115  ONNX_NAMESPACE::GraphProto ToGraphProto() const;
1116 
1117  /** Gets the GraphProto representation of this Graph
1118  @params external_file_name name of the binary file to use for initializers
1119  @param initializer_size_threshold initializers larger or equal to this threshold (in bytes) are saved
1120  in the external file. Initializer smaller than this threshold are included in the onnx file.
1121  @returns GraphProto serialization of the graph.
1122  */
1123  ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::string& external_file_name,
1124  const PathString& file_path,
1125  size_t initializer_size_threshold) const;
1126 
1127  /** Gets the ISchemaRegistry instances being used with this Graph. */
// NOTE(review): listing line 1128 (the declaration documented above) is missing from this rendering.
1129 
1130  /**
1131  Looks up the op schema in the schema registry and sets it for the given node.
1132  @param node The node to update.
1133  @return Whether the node's op schema was set to a valid value.
1134  */
// NOTE(review): listing line 1135 (the declaration documented above) is missing from this rendering.
1136 
1137  /**
1138  Create a single Function based Node that is the result of the a fusion of multiple nodes in this Graph.
1139  A new Graph instance will be created for the fused nodes.
1140  @param sub_graph A IndexSubGraph instance with details of the nodes to fuse. Ownership is transferred to the new Node
1141  @param fused_node_name The name for the new Node.
1142  @returns Function based Node with fused subgraph. The Node body will contain a Function instance.
1143  */
1144  Node& FuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name);
1145 
1146  /**
1147  Directly insert one of the If node branches into this Graph.
1148  `If` node condition must be a constant. The function would
1149  rename the nodes of the corresponding subgraph to make sure there is no conflict.
1150 
1151  Explicit and implicit inputs references stay the same.
1152 
1153  All of the outputs of the subgraph being inlined should be renamed
1154  to the outputs of the If node.
1155 
1156  The function will process any subgraphs in each of the nodes being inlined,
1157  and will rename any references to the new names introduced.
1158 
1159  @param condition_value If condition value
1160  @param if_node - the node that contains the graph_to_inline. This node is going
1161  to be deleted and replaced by the corresponding graph (either then or else)
1162  @param logger
1163  */
1164  Status InlineIfSubgraph(bool condition_value, Node& if_node, const logging::Logger& logger);
1165 
1166  /**
1167  Directly insert the nodes in the function Node provided into this Graph.
1168  The Graph needs to be Resolve()d after this call.
1169  @param node Node with Node::Type of Node::Type::Fused
1170  @returns Status indicating success or providing an error message.
1171  */
1172  Status InlineFunction(Node& node);
1173 
1174  /**
1175  Directly insert the nodes in the function proto provided into the graph.
1176  The function converts Constant nodes into the initializers in the graph.
1177  It then creates a node in the graph for each of the function nodes.
1178  All of the names are expected to be specialized, and, therefore unique.
1179  See function_utils::Specialize().
1180 
1181  The Graph needs to be Resolve()d after this call.
1182  @param func_to_inline
1183  @returns Status indicating success or providing an error message.
1184  */
1185 
1186  Status InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline);
1187 
1188  /** Mark a NodeArg name as coming from the outer scope when programmatically constructing a Graph that will
1189  be used as a GraphProto attribute in another Node.
1190  e.g. when creating a Graph instance that will be used as a subgraph in a control flow operator, it is necessary to
1191  define placeholder NodeArgs for outer scope values. This prevents these values from becoming explicit graph inputs
1192  when the Graph is resolved.
1193  */
// NOTE(review): listing line 1194 (this function's signature) is missing from this rendering.
// The body adds 'name' to outer_scope_node_arg_names_; whether it was already present is deliberately ignored.
1195  ORT_IGNORE_RETURN_VALUE(outer_scope_node_arg_names_.insert(name));
1196  }
1197 
1198  /** Explicitly set graph inputs.
1199  @param inputs NodeArgs that represent complete graph inputs which need to be explicitly ordered.
1200  @remarks Note that the input order matters for subgraphs.
1201  */
1202  void SetInputs(gsl::span<const NodeArg* const> inputs);
1203 
1204  void SetInputs(std::initializer_list<const NodeArg*> inputs) {
1205  SetInputs(AsSpan(inputs));
1206  }
1207 
1208  const Model& GetModel() const {
1209  return owning_model_;
1210  }
1211 
1212  const logging::Logger& GetLogger() const {
1213  return logger_;
1214  }
1215 
1216  /** Explicitly set graph outputs.
1217  @param outputs NodeArgs that represent complete graph outputs which need to be explicitly ordered.
1218  @remarks Note that the output order matters for subgraphs.
1219  */
1220  void SetOutputs(gsl::span<const NodeArg* const> outputs);
1221 
1222  void SetOutputs(std::initializer_list<const NodeArg*> outputs) {
1223  SetOutputs(AsSpan(outputs));
1224  }
1225 
1226 #endif // !defined(ORT_MINIMAL_BUILD)
1227 
1228 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1229  /** Sets the type of a NodeArg, replacing existing type/shape if any */
1230  void SetNodeArgType(NodeArg& arg, const ONNX_NAMESPACE::TypeProto& type_proto);
1231 
1232  const Node* GetProducerNode(const std::string& node_arg_name) const {
1233  return GetProducerNodeImpl(*this, node_arg_name);
1234  }
1235 
1236  Node* GetMutableProducerNode(const std::string& node_arg_name) {
1237  return GetProducerNodeImpl(*this, node_arg_name);
1238  }
1239 
1240  void UpdateProducerNode(const std::string& node_arg_name, NodeIndex node_index) {
1241  auto iter = node_arg_to_producer_node_.find(node_arg_name);
1242 
1243  if (iter != node_arg_to_producer_node_.end()) {
1244  iter->second = node_index;
1245  } else {
1246  node_arg_to_producer_node_[node_arg_name] = node_index;
1247  }
1248  }
1249 
1250  std::vector<const Node*> GetConsumerNodes(const std::string& node_arg_name) const {
1251  return GetConsumerNodesImpl(*this, node_arg_name);
1252  }
1253 
1254  // Without removing the existing consumers, add a consumer to the give node arg name.
1255  void AddConsumerNode(const std::string& node_arg_name, Node* consumer) {
1256  node_arg_to_consumer_nodes_[node_arg_name].insert(consumer->Index());
1257  }
1258 
1259  // Remove a consumer from the set
1260  void RemoveConsumerNode(const std::string& node_arg_name, Node* consumer) {
1261  node_arg_to_consumer_nodes_[node_arg_name].erase(consumer->Index());
1262  }
1263 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1264 
1265 #if !defined(ORT_MINIMAL_BUILD)
1266  std::vector<Node*> GetMutableConsumerNodes(const std::string& node_arg_name) {
1267  return GetConsumerNodesImpl(*this, node_arg_name);
1268  }
1269 
1270  void UpdateConsumerNodes(const std::string& node_arg_name, gsl::span<Node* const> nodes) {
1271  // Replace nodes for the arg
1272  auto& nodes_for_arg = node_arg_to_consumer_nodes_[node_arg_name];
1273  if (!nodes_for_arg.empty()) {
1274  nodes_for_arg.clear();
1275  }
1276 
1277  nodes_for_arg.reserve(nodes.size());
1278  for (Node* node : nodes) {
1279  nodes_for_arg.insert(node->Index());
1280  }
1281  }
1282 
1283  void UpdateConsumerNodes(const std::string& node_arg_name, std::initializer_list<Node*> nodes) {
1284  UpdateConsumerNodes(node_arg_name, AsSpan(nodes));
1285  }
1286 
1287  /** During constant folding it may become possible to infer the shape for a node.
1288  To avoid running a full Resolve, allow an individual node to have the shape inferencing re-run.
1289  */
// NOTE(review): listing line 1290 (the declaration documented above) is missing from this rendering.
1291 
1292  // Options to control Graph::Resolve.
// NOTE(review): listing line 1293 (the opening line of this options struct, referred to elsewhere in this
// class as ResolveOptions) is missing from this rendering.
1294  // Whether to override existing types with inferred types.
1295  bool override_types = false;
1296  // Names of initializers to keep even if unused (optional).
1297  const std::unordered_set<std::string>* initializer_names_to_preserve = nullptr;
1298  // Whether to set that no proto sync is required after resolving.
1299  // Useful for resolving right after loading from a GraphProto.
// NOTE(review): listing line 1300 (the member described by the two comments above) is missing from this rendering.
1301  };
1302 
1303  /**
1304  Resolve this Graph to ensure it is completely valid, fully initialized, and able to be executed.
1305  1. Run through all validation rules.
1306  a. Node name and node output's names should be unique.
1307  b. Attribute match between node and op definition.
1308  c. Input/Output match between node and op definition.
1309  d. Graph is acyclic and sort nodes in topological order.
1310  2. Check & Setup inner nodes' dependency.
1311  3. Cleanup function definition lists.
1312  Note: the weights for training can't be cleaned during resolve.
1313  @returns common::Status with success or error information.
1314  */
1315  common::Status Resolve(const ResolveOptions& options);
1316 
// NOTE(review): listing line 1317 (this overload's signature) is missing from this rendering.
// The body resolves the Graph using a default-constructed ResolveOptions.
1318  ResolveOptions default_options;
1319  return Resolve(default_options);
1320  }
1321 
1322  const std::unordered_set<std::string>& GetOuterScopeNodeArgNames() const noexcept {
1323  return outer_scope_node_arg_names_;
1324  }
1325 
1326  common::Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder,
1327  flatbuffers::Offset<onnxruntime::fbs::Graph>& fbs_graph) const;
1328 
1329 #endif // !defined(ORT_MINIMAL_BUILD)
1330 
1331  /** Returns the Node containing the GraphProto for this Graph instance if IsSubgraph is true */
1332  const Node* ParentNode() const { return parent_node_; }
1333 
1334  /** Returns true if the name is for a value that is coming from outer scope */
1335  bool IsOuterScopeValue(const std::string& name) const {
1336  if (!parent_node_) return false;
1337  const auto& implicit_input_defs = parent_node_->ImplicitInputDefs();
1338  return std::any_of(implicit_input_defs.cbegin(), implicit_input_defs.cend(),
1339  [&name](const NodeArg* implicit_input) {
1340  return implicit_input->Name() == name;
1341  });
1342  }
1343 
1344 #if !defined(ORT_MINIMAL_BUILD)
1345  /** Construct a Graph instance for a subgraph that is created from a GraphProto attribute in a Node.
1346  Inherits some properties from the parent graph.
1347  @param parent_graph The Graph containing the Node that has the GraphProto attribute.
1348  @param parent_node The Node that has the GraphProto attribute.
1349  @param subgraph_proto The GraphProto from the Node attribute.
1350  */
1351  Graph(Graph& parent_graph, const Node& parent_node, ONNX_NAMESPACE::GraphProto& subgraph_proto);
1352 
1353  Graph(const Model& owning_model,
1354  IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
1355  ONNX_NAMESPACE::GraphProto& subgraph_proto,
1356  const std::unordered_map<std::string, int>& domain_version_map,
1357  const logging::Logger& logger,
1358  bool strict_shape_type_inference);
1359 #endif
1360 
1361  virtual ~Graph();
1362 
1363  static Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, const Model& owning_model,
1364  const std::unordered_map<std::string, int>& domain_to_version,
1365 #if !defined(ORT_MINIMAL_BUILD)
1366  IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
1367 #endif
1368  const OrtFormatLoadOptions& load_options,
1369  const logging::Logger& logger, std::unique_ptr<Graph>& graph);
1370 
1371  // deserialize a subgraph
1372  static Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
1373  Graph& parent_graph, const Node& parent_node,
1374  const OrtFormatLoadOptions& load_options,
1375  const logging::Logger& logger, std::unique_ptr<Graph>& graph);
1376 
1377 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1378  const RuntimeOptimizationRecordContainer& RuntimeOptimizations() const {
1379  return runtime_optimizations_;
1380  }
1381 
1382  RuntimeOptimizationRecordContainer& MutableRuntimeOptimizations() {
1383  return runtime_optimizations_;
1384  }
1385 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1386 
1387  // This friendship relationship should only be used to call Graph::Graph and
1388  // Graph::LoadGraph All other access should be via the public API.
1389  friend class Model;
1390 
1391  Graph() = delete;
1392 
1393  // Create empty Graph instance to re-create from ORT format serialized data.
1394  Graph(const Model& owning_model,
1395  const std::unordered_map<std::string, int>& domain_to_version,
1396 #if !defined(ORT_MINIMAL_BUILD)
1397  IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
1398 #endif
1399  Graph* parent_graph, const Node* parent_node,
1400  const logging::Logger& logger,
1401  bool strict_shape_type_inference);
1402 
1403  // Populate Graph instance from ORT format serialized data.
1404  Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
1405  const OrtFormatLoadOptions& load_options);
1406 
1407 #if !defined(ORT_MINIMAL_BUILD)
1408  // Constructor: Given a <GraphProto> loaded from model file, construct
1409  // a <Graph> object. Used by Model to create a Graph instance.
1410  Graph(const Model& owning_model,
1411  ONNX_NAMESPACE::GraphProto* graph_proto,
1412  const std::unordered_map<std::string, int>& domain_to_version,
1413  Version ir_version,
1414  IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
1415  const logging::Logger& logger,
1416  bool strict_shape_type_inference);
1417 
1418  // internal use by the Graph class only
1419  Graph(const Model& owning_model,
1420  ONNX_NAMESPACE::GraphProto* graph_proto,
1421  const std::unordered_map<std::string, int>& domain_to_version,
1422  Version ir_version,
1423  IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
1424  Graph* parent_graph,
1425  const Node* parent_node,
1426  const logging::Logger& logger,
1427  bool strict_shape_type_inference);
1428 
1430 
// NOTE(review): listing line 1429 (the line just before the private section) is missing from this rendering.
1431  private:
// NOTE(review): presumably initializes this Graph's state from the model file's GraphProto, per its name; confirm in graph.cc.
1432  void InitializeStateFromModelFileGraphProto();
1433 
1434  // Add node with specified <node_proto>, using <name_to_type> for argument types.
1435  Node& AddNode(const ONNX_NAMESPACE::NodeProto& node_proto,
1436  const ArgNameToTypeMap& name_to_type);
1437 
1438  /** Helper that converts a Constant node proto into an initializer and adds it to the graph.
1439  @param constant_node_proto Constant node to convert.
1440  @param new_name Optional name to use for the initializer.
1441  */
1442  Status AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& constant_node_proto,
1443  std::optional<std::string_view> new_name);
1444 
1445 #endif
1446 
1447  Version IrVersion() const noexcept {
1448  return ir_version_;
1449  }
1450 
1451  Graph& GraphResolveNeeded(bool needed) noexcept {
1452  graph_resolve_needed_ = needed;
1453  return *this;
1454  }
1455 
1456  Graph& GraphProtoSyncNeeded(bool needed) noexcept {
1457  graph_proto_sync_needed_ = needed;
1458  return *this;
1459  }
1460 
1461  // During the Resolve of a Graph it is necessary to recursively descend into subgraphs (created from GraphProto
1462  // Node attributes in the Graph) if present.
1463  // The ResolveContext holds the collection of values for the current Graph instance, be it the main graph
1464  // or a subgraph, so that the various operations that are part of the Resolve can work iteratively or
1465  // recursively as needed.
1466  struct ResolveContext {
1467  ResolveContext(const Graph& owning_graph) : graph{owning_graph} {
1468  }
1469 
1470  std::unordered_map<std::string_view, std::pair<Node*, int>> output_args;
1471  std::unordered_set<std::string_view> inputs_and_initializers;
1472  std::unordered_map<std::string_view, NodeIndex> node_name_to_index;
1473  std::unordered_set<Node*> nodes_with_subgraphs;
1474 
1475  // check if the provided name is an input/initialize/node output of this Graph instance during Graph::Resolve.
1476  // Graph::node_args_ can have stale entries so we can't rely on that.
1477  bool IsLocalValue(const std::string& name) const;
1478 
1479  // check if an ancestor graph has a valid value with the provided name during Graph::Resolve.
1480  // Once Graph::Resolve completes Graph::IsOuterScopeValue can be used and is more efficient.
1481  bool IsOuterScopeValue(const std::string& name) const;
1482 
1483  void Clear() {
1484  output_args.clear();
1485  inputs_and_initializers.clear();
1486  node_name_to_index.clear();
1487  nodes_with_subgraphs.clear();
1488  }
1489 
1490  private:
1491  bool IsInputInitializerOrOutput(const std::string& name, bool check_ancestors) const;
1492 
1493  const Graph& graph;
1494  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ResolveContext);
1495  };
1496 
1497  // Initialize all the graph inputs, initializers and outputs.
1498  common::Status InitInputsInitializersOutputs();
1499
1500  // Initialize the overridable initializers container
// (presumably graph_overridable_initializers_ — TODO confirm).
1501  void ComputeOverridableInitializers();
1502
1503 #if !defined(ORT_MINIMAL_BUILD)
1504  // Build and verify node connections (edges).
1505  // Verify that NodeArg name/type/shape match correctly.
// outer_scope_node_args_consumed is a non-const reference (out-parameter); presumably it receives the names
// of outer-scope values consumed by this graph — TODO confirm.
1506  common::Status BuildConnections(std::unordered_set<std::string>& outer_scope_node_args_consumed);
1507
// Presumably returns a failing Status when duplicate names are found, as the name suggests — TODO confirm.
1508  common::Status VerifyNoDuplicateName();
1509
1510  // Check whether <*this> graph is acyclic while performing a topological sort.
1511  // Depth-first going from bottom up through the graph and checking whether there are any back edges.
1512  // nodes_in_topological_order_ is updated with the nodes' indexes in topological
1513  // order if <Status> returned is "OK", otherwise it's undefined.
1514  common::Status PerformTopologicalSortAndCheckIsAcyclic();
1515
// Type and shape inferencing for this graph, configured by `options` (presumably a step of Resolve).
1516  common::Status PerformTypeAndShapeInferencing(const ResolveOptions& options);
1517
1518  // Recursively find all subgraphs, including nested subgraphs, appending them to `subgraphs`.
1519  void FindAllSubgraphs(std::vector<Graph*>& subgraphs);
1520
1521  // Iterate this Graph instance and all subgraphs, calling the provided function for each.
1522  common::Status ForThisAndAllSubgraphs(const std::vector<Graph*>& subgraphs, std::function<Status(Graph&)> func);
1523
// Presumably infers/verifies the types of `node` against the schema `op` — TODO confirm.
1524  common::Status InferAndVerifyTypeMatch(Node& node, const ONNX_NAMESPACE::OpSchema& op, const ResolveOptions& options);
1525
1526  // Perform type and shape inferencing on the subgraph and Resolve it to validate.
// Static: it works on the `subgraph` argument, not on *this. output_types is an out-parameter.
1527  static common::Status InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
1528  const std::vector<const ONNX_NAMESPACE::TypeProto*>& input_types,
1529  std::vector<const ONNX_NAMESPACE::TypeProto*>& output_types,
1530  const Graph::ResolveOptions& options);
1531
1532  // Apply type-inference and type-checking to all inputs and initializers.
1533  common::Status TypeCheckInputsAndInitializers();
1534
1535  // Compute the set of input and initializer names and check for duplicate names.
1536  common::Status VerifyInputAndInitializerNames();
1537
1538  // Infer and set type information across <*this> graph if needed, and verify type/attribute
1539  // information matches between node and op.
1540
1541  common::Status VerifyNodeAndOpMatch(const ResolveOptions& options);
1542
1543  // Set graph inputs/outputs when resolving a graph.
1544  common::Status SetGraphInputsOutputs();
1545
1546  // Recursively accumulate and set the outer scope node args in the resolve context for all subgraphs
1547  // so they can be used to resolve outer scope dependencies when running BuildConnections for the subgraphs.
1548  common::Status SetOuterScopeNodeArgs(const std::unordered_set<std::string>& outer_scope_node_args);
1549
1550  // Implementation for initializer replacement. new_initializer is taken by value so callers can move into it;
// is_external presumably flags initializers whose data lives outside the proto — TODO confirm.
1551  Status ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initializer, bool is_external);
1552
1553  // Clear all unused initializers and NodeArgs.
// initializer_names_to_preserve defaults to nullptr; presumably names listed there are kept even if unused —
// TODO confirm.
1554  void CleanUnusedInitializersAndNodeArgs(const std::unordered_set<std::string>* initializer_names_to_preserve = nullptr);
1555
// Presumably returns one NodeArg per entry in `names`, typed via `name_to_type_map` — TODO confirm.
1556  std::vector<NodeArg*> CreateNodeArgs(const google::protobuf::RepeatedPtrField<std::string>& names,
1557  const ArgNameToTypeMap& name_to_type_map);
1558
// const member: presumably serializes this graph's contents into the caller-supplied `graph_proto`.
1559  void ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const;
1560
1561 #endif // !defined(ORT_MINIMAL_BUILD)
1562 
1563 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
// Presumably (re)builds node_arg_to_producer_node_ and node_arg_to_consumer_nodes_ from the current nodes,
// as the name suggests — TODO confirm against graph.cc.
1564  Status PopulateNodeArgToProducerConsumerLookupsFromNodes();
1565 
// Shared implementation behind the const and non-const consumer lookups.
// TInstance is Graph or const Graph; the element type (Node* or const Node*) follows whatever
// instance.GetNode() returns. Returns one entry per consumer NodeIndex recorded for node_arg_name,
// in the iteration order of the underlying set, or an empty vector if the name has no entry.
template <typename TInstance>
static auto GetConsumerNodesImpl(TInstance& instance, const std::string& node_arg_name)
    -> std::vector<decltype(instance.GetNode(0))> {
  using NodePtr = decltype(instance.GetNode(0));
  std::vector<NodePtr> consumers;

  const auto found = instance.node_arg_to_consumer_nodes_.find(node_arg_name);
  if (found == instance.node_arg_to_consumer_nodes_.end()) {
    return consumers;
  }

  const auto& consumer_indices = found->second;
  consumers.reserve(consumer_indices.size());
  for (const auto consumer_index : consumer_indices) {
    consumers.push_back(instance.GetNode(consumer_index));
  }
  return consumers;
}
1579 
// Shared implementation behind the const and non-const producer lookups.
// Returns instance.GetNode(<recorded producer index>) when node_arg_name has a producer entry,
// otherwise nullptr. The pointer type (Node* or const Node*) follows instance.GetNode().
template <typename TInstance>
static auto GetProducerNodeImpl(TInstance& instance, const std::string& node_arg_name)
    -> decltype(instance.GetNode(0)) {
  const auto& producers = instance.node_arg_to_producer_node_;
  const auto found = producers.find(node_arg_name);
  return found == producers.end() ? nullptr : instance.GetNode(found->second);
}
1590 
// Creates a new Node for this Graph (presumably stored in nodes_ — TODO confirm). The gsl::not_null return
// type guarantees the pointer is non-null.
1591  gsl::not_null<Node*> AllocateNode();
1592
1593  // Release the node.
1594  // @returns false if node_index was invalid.
1595  bool ReleaseNode(NodeIndex node_index);
1596
// Presumably creates the Node representing `sub_graph` under the name `fused_node_name` — TODO confirm.
1597  Node& CreateFusedSubGraphNode(const IndexedSubGraph& sub_graph, const std::string& fused_node_name);
1598 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1599 
1600  Node* NodeAtIndexImpl(NodeIndex node_index) const {
1601  // if we are trying to access a node that doesn't exist there's (most
1602  // likely) either a logic issue or a graph consistency/correctness issue.
1603  // use ORT_ENFORCE to prove that or uncover scenarios where we actually
1604  // expect attempts to retrieve a non-existent node.
1605  ORT_ENFORCE(node_index < nodes_.size(), "Validating no unexpected access using an invalid node_index. Got:",
1606  node_index, " Max:", nodes_.size());
1607  return nodes_[node_index].get();
1608  }
1609 
1610  const Model& owning_model_;
1611
1612  // GraphProto to store name, version, initializer.
1613  // When serializing <*this> Graph to a GraphProto, the nodes and
1614  // functions in <Graph> will also be fed into <graph_proto_> so that
1615  // it's consistent with <*this> graph.
1616  // This pointer is owned by the parent model.
1617  ONNX_NAMESPACE::GraphProto* graph_proto_;
1618
1619  // GraphProto that provides storage for the ONNX proto types deserialized from a flexbuffer/flatbuffer.
1620  ONNX_NAMESPACE::GraphProto deserialized_proto_data_;
1621
// Initializers keyed by name (InitializedTensorSet: name -> const TensorProto*).
1622  InitializedTensorSet name_to_initial_tensor_;
1623
// Names of sparse initializers (presumably; see IsSparseInitializer). Entries are std::reference_wrapper,
// so the referenced strings are NOT owned here and must outlive their entries in this set.
1624  std::unordered_set<std::reference_wrapper<const std::string>,
1625  std::hash<std::string>, std::equal_to<std::string>>
1626  sparse_tensor_names_;
1627
1628 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1629  // Runtime optimization storage.
1630  // Note: runtime_optimizations_ == *runtime_optimizations_ptr_ and must be initialized from it. It is declared
// after runtime_optimizations_ptr_ so that, with declaration-order member initialization, the pointer is set
// first.
1631  std::unique_ptr<RuntimeOptimizationRecordContainer> runtime_optimizations_ptr_;
1632  RuntimeOptimizationRecordContainer& runtime_optimizations_;
1633 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1634
1635 #if !defined(ORT_MINIMAL_BUILD)
1636  IOnnxRuntimeOpSchemaCollectionPtr schema_registry_;
1637
1638  // Currently, to make the ORT in-memory graph work, we have to create a temporary op schema
1639  // for the fused kernel. This is a short-term workaround; the generated schemas are
1640  // owned here.
1641  InlinedVector<std::unique_ptr<ONNX_NAMESPACE::OpSchema>> fused_schemas_containers_;
1642  // In some cases the same fused sub-graph occurs multiple times in one model, so reusable
1643  // schemas are kept in this map for lookup. Values are non-owning references.
1644  InlinedHashMap<std::string, std::reference_wrapper<ONNX_NAMESPACE::OpSchema>> reusable_fused_schema_map_;
1645 #endif // !defined(ORT_MINIMAL_BUILD)
1646
1647  // Graph nodes.
1648  // Element in <nodes_> may be nullptr due to graph optimization.
1649  std::vector<std::unique_ptr<Node>> nodes_;
1650
1651  // Wrapper of Graph nodes to provide iteration services that hide nullptr entries.
// Must stay declared after nodes_: it is initialized from nodes_.
1652  GraphNodes iterable_nodes_{nodes_};
1653
1654  // Number of nodes.
1655  // Normally this is smaller than the size of <nodes_>, as some
1656  // elements in <nodes_> may be removed when doing graph optimization,
1657  // or some elements may be merged, etc.
1658  int num_of_nodes_ = 0;
1659
1660  // A flag indicating whether <*this> graph needs to be resolved.
1661  bool graph_resolve_needed_ = false;
1662
// A flag presumably indicating whether graph_proto_ must be re-synchronized with <*this> graph.
1663  bool graph_proto_sync_needed_ = false;
1664
1665  // The topological order of node index used to do node and op match verification temporarily.
1666  std::vector<NodeIndex> nodes_in_topological_order_;
1667
1668  // Full list of graph inputs. Matches number and order of inputs in the GraphProto.
1669  std::vector<const NodeArg*> graph_inputs_including_initializers_;
// Presumably true once inputs were set explicitly rather than derived — TODO confirm.
1670  bool graph_inputs_manually_set_ = false;
1671
1672  // Graph inputs excluding initializers.
1673  std::vector<const NodeArg*> graph_inputs_excluding_initializers_;
1674
1675  // Overridable initializers: the entries of graph_inputs_including_initializers_ that are not
1676  // in graph_inputs_excluding_initializers_.
1677  std::vector<const NodeArg*> graph_overridable_initializers_;
1678
1679  // Graph outputs.
1680  std::vector<const NodeArg*> graph_outputs_;
// Presumably true once outputs were set explicitly rather than derived — TODO confirm.
1681  bool graph_outputs_manually_set_ = false;
1682
1683  // Graph value_info.
1684  std::unordered_set<const NodeArg*> value_info_;
1685
1686  // All node args owned by <*this> graph. Key is node arg name.
1687  std::unordered_map<std::string, std::unique_ptr<NodeArg>> node_args_;
1688
1689 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
// Counter presumably used to make generated node / node_arg names unique — TODO confirm.
1690  int name_generator_ = 0;
1691
1692  // Strings which have been used as node names.
1693  // New node name should not conflict with this set.
1694  std::unordered_set<std::string> generated_node_names_;
1695
1696  // Strings which have been used as node_arg names.
1697  // New node_arg name should not conflict with this set.
1698  std::unordered_set<std::string> generated_node_arg_names_;
1699
1700  // node arg name -> index of its producer node
1701  std::unordered_map<std::string, NodeIndex> node_arg_to_producer_node_;
1702
1703  // node arg name -> indices of its consumer nodes
1704  std::unordered_map<std::string, std::unordered_set<NodeIndex>> node_arg_to_consumer_nodes_;
1705 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1706
1707  const std::unordered_map<std::string, int> domain_to_version_;
1708
1709  // Model IR version.
1710  Version ir_version_{ONNX_NAMESPACE::Version::IR_VERSION};
1711
// Scratch state for Resolve, bound to <*this> graph (see ResolveContext).
1712  ResolveContext resolve_context_{*this};
1713
1714  // The parent graph if this is a subgraph.
1715  Graph* parent_graph_;
1716  // The node containing the graph if parent_graph_ is not nullptr.
1717  const Node* parent_node_;
1718
1719  // NodeArgs that come from outer scope. Used when building a graph so that
1720  // these don't get recorded as graph inputs in the GraphProto.
1721  std::unordered_set<std::string> outer_scope_node_arg_names_;
1722
1723  // Number of times Resolve has run.
1724  int num_resolves_ = 0;
1725
1726  const logging::Logger& logger_;
1727
1728  // If true, all inconsistencies encountered during shape and type inference
1729  // will be exposed to the caller as failures. If false, in some cases
1730  // warnings will be logged but processing will continue and no error will
1731  // be returned.
1732  const bool strict_shape_type_inference_;
1733
1734  // Distinguishes between a graph loaded from a model file and a graph created from scratch.
1735  const bool is_loaded_from_model_file_;
1736 };
1737 
1738 #if !defined(ORT_MINIMAL_BUILD)
// The stream operators below are declared only when ORT_MINIMAL_BUILD is not defined.
1739 // Prints a NodeArg as
1740 // name : type
1741 // For example,
1742 // "110": tensor(float)
1743 std::ostream& operator<<(std::ostream& out, const NodeArg& node_arg);
1744 // Prints a Node as
1745 // (operator's name, operator's type, domain, version) : (input0, input1, ...) -> (output0, output1, ...)
1746 // For example,
1747 // ("Add_14", Add, "", 7) : ("110": tensor(float),"109": tensor(float),) -> ("111": tensor(float),)
1748 std::ostream& operator<<(std::ostream& out, const Node& node);
1749 // Prints a Graph as, for example,
1750 // Inputs:
1751 // "Input": tensor(float)
1752 // Nodes:
1753 // ("add0", Add, "", 7) : ("Input": tensor(float),"Bias": tensor(float),) -> ("add0_out": tensor(float),)
1754 // ("matmul", MatMul, "", 9) : ("add0_out": tensor(float),"matmul_weight": tensor(float),) -> ("matmul_out": tensor(float),)
1755 // ("add1", Add, "", 7) : ("matmul_out": tensor(float),"add_weight": tensor(float),) -> ("add1_out": tensor(float),)
1756 // ("reshape", Reshape, "", 5) : ("add1_out": tensor(float),"concat_out": tensor(int64),) -> ("Result": tensor(float),)
1757 // Outputs:
1758 // "Result": tensor(float)
1759 // The format of inputs and outputs is described in the documentation of NodeArg's operator<< above.
1760 // The node format is described in Node's operator<< above.
1761 std::ostream& operator<<(std::ostream& out, const Graph& graph);
1762 #endif  // !defined(ORT_MINIMAL_BUILD)
1763 
1764 } // namespace onnxruntime
constexpr auto AsSpan(C &c)
Definition: span_utils.h:42
void SetNodeArgType(NodeArg &arg, const ONNX_NAMESPACE::TypeProto &type_proto)
bool IsOuterScopeValue(const std::string &name) const
Definition: graph.h:1335
bool IsInitializedTensor(const std::string &name) const
void UpdateProducerNode(const std::string &node_arg_name, NodeIndex node_index)
Definition: graph.h:1240
std::unordered_map< std::string, const ONNX_NAMESPACE::TensorProto * > InitializedTensorSet
Definition: basic_types.h:33
The node refers to a primitive operator.
const std::string & ProviderType
Definition: basic_types.h:35
const Node * GetNode(NodeIndex node_index) const
Definition: graph.h:842
void AddAttributeProto(ONNX_NAMESPACE::AttributeProto value)
const InitializedTensorSet & GetAllInitializedTensors() const noexcept
Definition: graph.h:744
void ForEachDef(std::function< void(const onnxruntime::NodeArg &, bool is_input)> func, bool include_missing_optional_defs=false) const
IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const
ConstPointerContainer< std::vector< NodeArg * > > InputDefs() const noexcept
Definition: graph.h:216
const std::vector< int > & InputArgCount() const noexcept
Definition: graph.h:212
void SetInputs(gsl::span< const NodeArg *const > inputs)
const ONNX_NAMESPACE::GraphProto & ToGraphProto()
const Function * GetFunctionBody() const noexcept
Definition: graph.h:189
std::shared_ptr< IOnnxRuntimeOpSchemaCollection > IOnnxRuntimeOpSchemaCollectionPtr
Definition: basic_types.h:44
Node(NodeIndex index, Graph &graph)
Definition: graph.h:570
int MaxNodeIndex() const noexcept
Definition: graph.h:867
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Graph)
The node refers to a function.
const std::string & Description() const noexcept
Definition: graph.h:161
bool IsSubgraph() const
Definition: graph.h:683
void SetFunctionTemplate(const FunctionTemplate &func_template)
NodeIndex Index() const noexcept
Definition: graph.h:136
void RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index)
int GetDstArgIndex() const
Definition: graph.h:127
Definition: Node.h:52
const Node * operator->() const
void UpdateConsumerNodes(const std::string &node_arg_name, std::initializer_list< Node * > nodes)
Definition: graph.h:1283
Node & FuseSubGraph(const IndexedSubGraph &sub_graph, const std::string &fused_node_name)
const RuntimeOptimizationRecordContainer & RuntimeOptimizations() const
Definition: graph.h:1378
bool NodeProducesGraphOutput(const Node &node) const
Definition: graph.h:805
GLsizei const GLchar *const * string
Definition: glcorearb.h:814
const Graph * GetGraphAttribute(const std::string &attr_name) const
const NodeAttributes & GetAttributes() const noexcept
Definition: graph.h:395
const Node * ParentNode() const
Definition: graph.h:1332
size_t GetOutputEdgesCount() const noexcept
Definition: graph.h:349
common::Status InjectExternalInitializedTensors(const InlinedHashMap< std::string, OrtValue > &external_initializers)
RuntimeOptimizationRecordContainer & MutableRuntimeOptimizations()
Definition: graph.h:1382
static common::Status ForEachMutableWithIndex(std::vector< NodeArg * > &node_args, std::function< common::Status(NodeArg &arg, size_t index)> func)
Definition: graph.h:242
bool SetOpSchemaFromRegistryForNode(Node &node)
Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder &builder, flatbuffers::Offset< onnxruntime::fbs::Node > &fbs_node) const
#define ORT_ENFORCE(condition,...)
Definition: common.h:172
bool IsSparseInitializer(const std::string &name) const
const std::unordered_map< std::string, int > & DomainToVersionMap() const noexcept
Definition: graph.h:1092
void FinalizeFuseSubGraph(const IndexedSubGraph &sub_graph, Node &fused_node)
void ToProto(ONNX_NAMESPACE::NodeProto &proto, bool update_subgraphs=false) const
NodeArg * GetNodeArg(const std::string &name)
Definition: graph.h:876
std::vector< NodeArg * > & MutableImplicitInputDefs() noexcept
Definition: graph.h:254
auto arg(const Char *name, const T &arg) -> detail::named_arg< Char, T >
Definition: core.h:1736
const NodeArg * GetNodeArg(const std::string &name) const
Definition: graph.h:886
NodeConstIterator OutputNodesEnd() const noexcept
Definition: graph.h:326
static Status LoadFromOrtFormat(const onnxruntime::fbs::Graph &fbs_graph, const Model &owning_model, const std::unordered_map< std::string, int > &domain_to_version, IOnnxRuntimeOpSchemaCollectionPtr schema_registry, const OrtFormatLoadOptions &load_options, const logging::Logger &logger, std::unique_ptr< Graph > &graph)
std::vector< Node * > GetMutableConsumerNodes(const std::string &node_arg_name)
Definition: graph.h:1266
std::unordered_map< std::string, gsl::not_null< const Graph * > > GetAttributeNameToSubgraphMap() const
ConstPointerContainer< std::vector< NodeArg * > > OutputDefs() const noexcept
Definition: graph.h:229
bool IsInputsIncludingInitializers(const NodeArg *node_arg) const noexcept
Definition: graph.h:781
const ONNX_NAMESPACE::OpSchema * Op() const noexcept
Definition: graph.h:181
const std::string & OpType() const noexcept
Definition: graph.h:142
friend class Model
Definition: graph.h:1389
void SetOutputs(gsl::span< const NodeArg *const > outputs)
void SetExecutionProviderType(ProviderType execution_provider_type)
Definition: graph.h:463
void AddConsumerNode(const std::string &node_arg_name, Node *consumer)
Definition: graph.h:1255
NodeConstIterator InputNodesBegin() const noexcept
Definition: graph.h:316
basic_string_view< char > string_view
Definition: core.h:522
ConstPointerContainer< std::vector< NodeArg * > > ImplicitInputDefs() const noexcept
Definition: graph.h:223
Node * GetNode(NodeIndex node_index)
Definition: graph.h:847
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Definitions)
const std::unordered_set< const NodeArg * > & GetValueInfo() const noexcept
Definition: graph.h:833
const std::string & Description() const noexcept
GraphNodes & Nodes() noexcept
Definition: graph.h:850
Graph & SetGraphResolveNeeded() noexcept
Definition: graph.h:1022
Status LoadEdgesFromOrtFormat(const onnxruntime::fbs::NodeEdge &fbs_node_edgs, const Graph &graph)
std::vector< gsl::not_null< const Graph * > > GetSubgraphs() const
NodeConstIterator OutputNodesBegin() const noexcept
Definition: graph.h:321
std::set< EdgeEnd, EdgeEndCompare > EdgeSet
Definition: graph.h:289
NodeAttributes & GetMutableAttributes() noexcept
Definition: graph.h:402
bool RemoveNode(NodeIndex node_index)
Node & AddNode(const std::string &name, const std::string &op_type, const std::string &description, std::initializer_list< NodeArg * > input_args, std::initializer_list< NodeArg * > output_args, const NodeAttributes *attributes=nullptr, const std::string &domain=kOnnxDomain)
Definition: graph.h:943
ADD_ATTR_SINGLE_INTERFACE(ONNX_NAMESPACE::GraphProto)
bool TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto &func_proto) const
int Priority() const noexcept
Definition: graph.h:154
int NumberOfNodes() const noexcept
Definition: graph.h:872
void UpdateConsumerNodes(const std::string &node_arg_name, gsl::span< Node *const > nodes)
Definition: graph.h:1270
static Status OK()
Definition: status.h:160
common::Status Resolve()
Definition: graph.h:1317
void AddAttribute(std::string attr_name, int64_t value)
std::string GenerateNodeName(const std::string &base_name)
void AddEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index)
std::vector< int > & MutableInputArgsCount()
Definition: graph.h:263
#define ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TypeName)
Definition: common.h:219
bool ClearAttribute(const std::string &attr_name)
std::vector< int > GetNodeOutputsInGraphOutputs(const Node &node) const
Definition: graph.h:816
bool AddControlEdge(NodeIndex src_node_index, NodeIndex dst_node_index)
static common::Status ForEachWithIndex(const ConstPointerContainer< std::vector< NodeArg * >> &node_args, std::function< common::Status(const NodeArg &arg, size_t index)> func)
Definition: graph.h:200
const ONNX_NAMESPACE::TensorProto * GetInitializer(const std::string &name, bool check_outer_scope) const
Graph * MutableParentGraph()
Definition: graph.h:689
void SetSinceVersion(int since_version) noexcept
Definition: graph.h:176
const std::string & Name() const noexcept
std::unordered_map< std::string, ONNX_NAMESPACE::AttributeProto > NodeAttributes
Definition: basic_types.h:42
const std::unordered_map< std::string, gsl::not_null< Graph * > > & GetAttributeNameToMutableSubgraphMap()
Definition: graph.h:442
bool GetInitializedTensor(const std::string &tensor_name, const ONNX_NAMESPACE::TensorProto *&value) const
const Path & ModelPath() const noexcept
void RemoveConsumerNode(const std::string &node_arg_name, Node *consumer)
Definition: graph.h:1260
size_t GetInputEdgesCount() const noexcept
Definition: graph.h:346
constexpr const char * kOnnxDomain
Definition: constants.h:14
GLuint const GLchar * name
Definition: glcorearb.h:786
std::set< std::string > control_inputs
Definition: graph.h:561
Status InlineFunction(Node &node)
flatbuffers::Offset< onnxruntime::fbs::NodeEdge > SaveEdgesToOrtFormat(flatbuffers::FlatBufferBuilder &builder) const
bool GraphProtoSyncNeeded() const noexcept
Definition: graph.h:1039
Status InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto &func_to_inline)
Graph & SetGraphProtoSyncNeeded() noexcept
Definition: graph.h:1033
const std::vector< const NodeArg * > & GetOutputs() const noexcept
Definition: graph.h:796
void RemoveInitializedTensor(const std::string &tensor_name)
std::vector< int > input_arg_count
Definition: graph.h:523
bool IsOutput(const NodeArg *node_arg) const noexcept
Definition: graph.h:798
common::Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder &builder, flatbuffers::Offset< onnxruntime::fbs::Graph > &fbs_graph) const
NodeArg * GetNodeArgIncludingParentGraphs(const std::string &node_arg_name)
const ONNX_NAMESPACE::TensorProto * GetConstantInitializer(const std::string &name, bool check_outer_scope) const
void KahnsTopologicalSort(const std::function< void(const Node *)> &enter, const std::function< bool(const Node *, const Node *)> &comp) const
const Model & GetModel() const
Definition: graph.h:1208
void AddOuterScopeNodeArg(const std::string &name)
Definition: graph.h:1194
const std::unordered_set< std::string > * initializer_names_to_preserve
Definition: graph.h:1297
common::Status ReplaceInitializedTensor(ONNX_NAMESPACE::TensorProto new_initializer)
const Node * GetProducerNode(const std::string &node_arg_name) const
Definition: graph.h:1232
void SetOutputs(std::initializer_list< const NodeArg * > outputs)
Definition: graph.h:1222
Status UpdateShapeInference(Node &node)
bool GraphResolveNeeded() const noexcept
Definition: graph.h:1028
const std::string & Name() const noexcept
Definition: graph.h:139
GLenum func
Definition: glcorearb.h:783
std::string GenerateNodeArgName(const std::string &base_name)
std::vector< NodeArg * > implicit_input_defs
Definition: graph.h:533
Graph * GetMutableGraphAttribute(const std::string &attr_name)
ProviderType GetExecutionProviderType() const noexcept
Definition: graph.h:460
GLenum GLsizei GLsizei GLint * values
Definition: glcorearb.h:1602
void SetPriority(int priority) noexcept
void SetInputs(std::initializer_list< const NodeArg * > inputs)
Definition: graph.h:1204
bool operator==(const NodeConstIterator &p_other) const
Node * GetMutableProducerNode(const std::string &node_arg_name)
Definition: graph.h:1236
const Path & ModelPath() const
Node & AddNode(const Node &other)
const std::string & Domain() const noexcept
Definition: graph.h:147
NodeArg & GetOrCreateNodeArg(const std::string &name, const ONNX_NAMESPACE::TypeProto *p_arg_type)
Definition: graph.h:898
void SetName(const std::string &name)
GLuint index
Definition: glcorearb.h:786
const Node & GetNode() const noexcept
Definition: graph.h:119
bool StrictShapeTypeInference() const
Definition: graph.h:692
#define ORT_RETURN_IF_ERROR(expr)
Definition: common.h:233
std::vector< const Node * > GetConsumerNodes(const std::string &node_arg_name) const
Definition: graph.h:1250
ImageBuf OIIO_API max(Image_or_Const A, Image_or_Const B, ROI roi={}, int nthreads=0)
const std::unordered_set< std::string > & GetOuterScopeNodeArgNames() const noexcept
Definition: graph.h:1322
GA_API const UT_StringHolder N
bool operator()(const EdgeEnd &lhs, const EdgeEnd &rhs) const
Definition: graph.h:278
ADD_ATTR_INTERFACES(float)
NodeConstIterator InputNodesEnd() const noexcept
Definition: graph.h:318
const std::set< std::string > & ControlInputs() const noexcept
Definition: graph.h:343
const logging::Logger & GetLogger() const
Definition: graph.h:1212
std::vector< NodeArg * > & MutableOutputDefs() noexcept
Definition: graph.h:271
EdgeEnd(const Node &node, int src_arg_index, int dst_arg_index) noexcept
int64_t Version
Definition: basic_types.h:31
static Status LoadFromOrtFormat(const onnxruntime::fbs::Node &fbs_node, Graph &graph, const OrtFormatLoadOptions &load_options, const logging::Logger &logger, std::unique_ptr< Node > &node)
const std::vector< const NodeArg * > & GetOverridableInitializers() const
Definition: graph.h:790
EdgeConstIterator InputEdgesBegin() const noexcept
Definition: graph.h:330
const std::vector< const NodeArg * > & GetInputs() const noexcept
Definition: graph.h:771
NodeConstIterator(EdgeConstIterator p_iter)
int SinceVersion() const noexcept
Definition: graph.h:170
Definition: core.h:1131
#define ORT_IGNORE_RETURN_VALUE(fn)
Definition: common.h:78
Node & BeginFuseSubGraph(const IndexedSubGraph &sub_graph, const std::string &fused_node_name)
std::vector< NodeArg * > output_defs
Definition: graph.h:526
std::vector< NodeArg * > input_defs
Definition: graph.h:515
std::ostream & operator<<(std::ostream &out, AllocKind alloc_kind)
void AddInitializedTensor(const ONNX_NAMESPACE::TensorProto &tensor_proto)
Node & AddNode(const std::string &name, const std::string &op_type, const std::string &description, gsl::span< NodeArg *const > input_args, std::initializer_list< NodeArg * > output_args, const NodeAttributes *attributes=nullptr, const std::string &domain=kOnnxDomain)
Definition: graph.h:956
const GraphNodes & Nodes() const noexcept
Definition: graph.h:853
friend class Graph
Definition: graph.h:569
void ReplaceDefs(const std::map< const onnxruntime::NodeArg *, onnxruntime::NodeArg * > &replacements)
EdgeConstIterator InputEdgesEnd() const noexcept
Definition: graph.h:333
EdgeConstIterator OutputEdgesBegin() const noexcept
Definition: graph.h:337
std::unordered_map< std::string, ONNX_NAMESPACE::TypeProto > ArgNameToTypeMap
Definition: basic_types.h:34
ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::string &external_file_name, const PathString &file_path, size_t initializer_size_threshold) const
void CleanAllInitializedTensors() noexcept
void SetDescription(const std::string &description)
bool ContainsSubgraph() const
Definition: graph.h:430
ConstGraphNodes FilteredNodes(GraphNodes::NodeFilterFunc &&filter_func) const noexcept
Definition: graph.h:860
std::vector< NodeArg * > & MutableInputDefs() noexcept
Definition: graph.h:266
bool CanOverrideInitializer() const noexcept
Definition: graph.h:750
Node::Type NodeType() const noexcept
Definition: graph.h:164
Node(std::string_view name, std::string_view op_type, std::string_view description, gsl::span< NodeArg *const > input_args, gsl::span< NodeArg *const > output_args, const NodeAttributes *attributes, std::string_view domain)
Definition: graph.h:80
const std::vector< const NodeArg * > & GetInputsIncludingInitializers() const noexcept
Definition: graph.h:776
void AddValueInfo(const NodeArg *new_value_info)
bool operator!=(const NodeConstIterator &p_other) const
Status InlineIfSubgraph(bool condition_value, Node &if_node, const logging::Logger &logger)
EdgeConstIterator OutputEdgesEnd() const noexcept
Definition: graph.h:340
EdgeSet::const_iterator EdgeConstIterator
Definition: graph.h:290
int PruneRemovableAttributes(gsl::span< const std::string > removable_attributes)
size_t NodeIndex
Definition: basic_types.h:30
Node & AddNode(const std::string &name, const std::string &op_type, const std::string &description, std::initializer_list< NodeArg * > input_args, gsl::span< NodeArg *const > output_args, const NodeAttributes *attributes=nullptr, const std::string &domain=kOnnxDomain)
Definition: graph.h:969
int GetSrcArgIndex() const
Definition: graph.h:123
const Graph * ParentGraph() const
Definition: graph.h:686
std::unordered_map< std::string, gsl::not_null< Graph * > > & GetMutableMapOfAttributeNameToSubgraph()
Definition: graph.h:449
FMT_CONSTEXPR auto find(Ptr first, Ptr last, T value, Ptr &out) -> bool
Definition: core.h:2089
void ReverseDFSFrom(gsl::span< NodeIndex const > from, const std::function< void(const Node *)> &enter, const std::function< void(const Node *)> &leave, const std::function< bool(const Node *, const Node *)> &comp={}) const
void AddAttribute(std::string attr_name, const char(&value)[N])
Definition: graph.h:390
bool CanBeInlined() const
bool Exists() const noexcept