10 #include <type_traits>
11 #include <unordered_map>
12 #include <unordered_set>
17 #pragma warning(disable : 4244)
24 #include "flatbuffers/flatbuffers.h"
30 #if !defined(ORT_MINIMAL_BUILD)
34 #include "core/common/path.h"
38 #include "core/graph/onnx_protobuf.h"
42 #if !defined(ORT_MINIMAL_BUILD)
43 #include "core/graph/function_template.h"
47 #include "core/graph/ort_format_load_options.h"
49 namespace onnxruntime {
51 struct IndexedSubGraph;
55 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
56 class RuntimeOptimizationRecordContainer;
77 explicit Node() =
default;
79 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
83 gsl::span<NodeArg* const> input_args,
84 gsl::span<NodeArg* const> output_args,
87 Init(name, op_type, description,
111 EdgeEnd(
const Node& node,
int src_arg_index,
int dst_arg_index) noexcept;
131 const int src_arg_index_;
132 const int dst_arg_index_;
154 int Priority() const noexcept {
return priority_; };
178 #if !defined(ORT_MINIMAL_BUILD)
181 const ONNX_NAMESPACE::OpSchema*
Op() const noexcept {
return op_; }
202 for (
size_t index = 0; index < node_args.size(); ++index) {
203 auto arg = node_args[index];
233 #if !defined(ORT_MINIMAL_BUILD)
244 for (
size_t index = 0; index < node_args.size(); ++index) {
245 auto arg = node_args[index];
257 #endif // !defined(ORT_MINIMAL_BUILD)
259 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
274 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
289 using EdgeSet = std::set<EdgeEnd, EdgeEndCompare>;
363 #define ADD_ATTR_SINGLE_INTERFACE(Type) \
364 void AddAttribute(std::string attr_name, Type value)
366 #define ADD_ATTR_LIST_INTERFACE(Type) \
367 void AddAttribute(std::string attr_name, gsl::span<const Type> values)
369 #define ADD_ATTR_INTERFACES(Type) \
370 ADD_ATTR_SINGLE_INTERFACE(Type); \
371 ADD_ATTR_LIST_INTERFACE(Type)
376 #if !defined(DISABLE_SPARSE_TENSORS)
383 #undef ADD_ATTR_SINGLE_INTERFACE
384 #undef ADD_ATTR_LIST_INTERFACE
385 #undef ADD_ATTR_INTERFACES
397 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
404 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
412 #if !defined(ORT_MINIMAL_BUILD)
425 #endif // !defined(ORT_MINIMAL_BUILD)
431 return !attr_to_subgraph_map_.empty();
436 std::vector<gsl::not_null<const Graph*>>
GetSubgraphs()
const;
443 return attr_to_subgraph_map_;
450 return attr_to_subgraph_map_;
464 execution_provider_type_ = execution_provider_type;
473 bool include_missing_optional_defs =
false)
const;
475 #if !defined(ORT_MINIMAL_BUILD)
479 void ReplaceDefs(
const std::map<const onnxruntime::NodeArg*, onnxruntime::NodeArg*>& replacements);
486 void ToProto(ONNX_NAMESPACE::NodeProto& proto,
bool update_subgraphs =
false)
const;
489 flatbuffers::Offset<onnxruntime::fbs::Node>& fbs_node)
const;
491 flatbuffers::Offset<onnxruntime::fbs::NodeEdge>
498 const OrtFormatLoadOptions& load_options,
502 const OrtFormatLoadOptions& load_options,
575 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
579 gsl::span<NodeArg* const> input_args,
580 gsl::span<NodeArg* const> output_args,
585 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
587 Definitions& MutableDefinitions() noexcept;
590 Relationships& MutableRelationships() noexcept;
592 void SetNodeType(
Node::
Type node_type) noexcept { node_type_ = node_type; }
598 std::vector<std::unique_ptr<Graph>>& MutableSubgraphs() noexcept {
return subgraphs_; }
603 const Definitions& GetDefinitions() const noexcept {
return definitions_; }
604 const Relationships& GetRelationships() const noexcept {
return relationships_; }
618 #if !defined(ORT_MINIMAL_BUILD)
620 const ONNX_NAMESPACE::OpSchema* op_ =
nullptr;
623 const FunctionTemplate* func_template_ =
nullptr;
630 int since_version_ = -1;
635 std::unique_ptr<Function> func_body_ =
nullptr;
641 Definitions definitions_;
644 Relationships relationships_;
654 Graph* graph_ =
nullptr;
657 std::unordered_map<std::string, gsl::not_null<Graph*>> attr_to_subgraph_map_;
660 std::vector<std::unique_ptr<Graph>> subgraphs_;
694 #if !defined(ORT_MINIMAL_BUILD)
709 #if !defined(DISABLE_EXTERNAL_INITIALIZERS)
714 #endif // !defined(DISABLE_EXTERNAL_INITIALIZERS)
716 #endif // !defined(ORT_MINIMAL_BUILD)
718 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
729 #if !defined(DISABLE_SPARSE_TENSORS)
771 const std::vector<const NodeArg*>&
GetInputs() const noexcept {
return graph_inputs_excluding_initializers_; }
777 return graph_inputs_including_initializers_;
782 return std::find(graph_inputs_including_initializers_.begin(),
783 graph_inputs_including_initializers_.end(), node_arg) != graph_inputs_including_initializers_.end();
791 return graph_overridable_initializers_;
796 const std::vector<const NodeArg*>&
GetOutputs() const noexcept {
return graph_outputs_; }
799 return std::find(graph_outputs_.begin(), graph_outputs_.end(), node_arg) != graph_outputs_.end();
806 auto end_outputs = graph_outputs_.cend();
808 if (
std::find(graph_outputs_.cbegin(), end_outputs, output_def) != end_outputs) {
818 std::vector<int> indexes;
821 indexes.push_back(output_idx);
833 const std::unordered_set<const NodeArg*>&
GetValueInfo() const noexcept {
return value_info_; }
835 #if !defined(ORT_MINIMAL_BUILD)
867 int MaxNodeIndex() const noexcept {
return static_cast<int>(nodes_.size()); }
877 auto iter = node_args_.find(name);
878 if (iter != node_args_.end()) {
879 return iter->second.get();
899 auto insert_result = node_args_.emplace(name,
nullptr);
900 if (insert_result.second) {
901 insert_result.first->second = std::make_unique<NodeArg>(name, p_arg_type);
903 return *(insert_result.first->second);
906 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
912 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
914 #if !defined(ORT_MINIMAL_BUILD)
923 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
938 gsl::span<NodeArg* const> input_args,
939 gsl::span<NodeArg* const> output_args,
946 std::initializer_list<NodeArg*> input_args,
947 std::initializer_list<NodeArg*> output_args,
950 return AddNode(name, op_type, description,
959 gsl::span<NodeArg* const> input_args,
960 std::initializer_list<NodeArg*> output_args,
963 return AddNode(name, op_type, description,
972 std::initializer_list<NodeArg*> input_args,
973 gsl::span<NodeArg* const> output_args,
976 return AddNode(name, op_type, description,
1011 #if !defined(ORT_MINIMAL_BUILD)
1018 #endif // !defined(ORT_MINIMAL_BUILD)
1023 graph_resolve_needed_ =
true;
1029 return graph_resolve_needed_;
1034 graph_proto_sync_needed_ =
true;
1040 return graph_proto_sync_needed_;
1051 const std::function<
void(
const Node*)>& enter,
1052 const std::function<
void(
const Node*)>& leave,
1053 const std::function<
bool(
const Node*,
const Node*)>& comp = {})
const;
1063 const std::function<
void(
const Node*)>& enter,
1064 const std::function<
void(
const Node*)>& leave,
1065 const std::function<
bool(
const Node*,
const Node*)>& comp = {})
const;
1076 const std::function<
void(
const Node*)>& enter,
1077 const std::function<
void(
const Node*)>& leave,
1078 const std::function<
bool(
const Node*,
const Node*)>& comp,
1079 const std::function<
bool(
const Node*,
const Node*)>& stop)
const;
1081 #if !defined(ORT_MINIMAL_BUILD)
1087 const std::function<
bool(
const Node*,
const Node*)>& comp)
const;
1093 return domain_to_version_;
1096 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1112 #if !defined(ORT_MINIMAL_BUILD)
1124 const PathString& file_path,
1125 size_t initializer_size_threshold)
const;
1202 void SetInputs(gsl::span<const NodeArg* const> inputs);
1204 void SetInputs(std::initializer_list<const NodeArg*> inputs) {
1209 return owning_model_;
1220 void SetOutputs(gsl::span<const NodeArg* const> outputs);
1226 #endif // !defined(ORT_MINIMAL_BUILD)
1228 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1233 return GetProducerNodeImpl(*
this, node_arg_name);
1237 return GetProducerNodeImpl(*
this, node_arg_name);
1241 auto iter = node_arg_to_producer_node_.find(node_arg_name);
1243 if (iter != node_arg_to_producer_node_.end()) {
1244 iter->second = node_index;
1246 node_arg_to_producer_node_[node_arg_name] = node_index;
1251 return GetConsumerNodesImpl(*
this, node_arg_name);
1256 node_arg_to_consumer_nodes_[node_arg_name].insert(consumer->
Index());
1261 node_arg_to_consumer_nodes_[node_arg_name].erase(consumer->
Index());
1263 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1265 #if !defined(ORT_MINIMAL_BUILD)
1267 return GetConsumerNodesImpl(*
this, node_arg_name);
1272 auto& nodes_for_arg = node_arg_to_consumer_nodes_[node_arg_name];
1273 if (!nodes_for_arg.empty()) {
1274 nodes_for_arg.clear();
1277 nodes_for_arg.reserve(nodes.size());
1278 for (
Node* node : nodes) {
1279 nodes_for_arg.insert(node->Index());
1319 return Resolve(default_options);
1323 return outer_scope_node_arg_names_;
1327 flatbuffers::Offset<onnxruntime::fbs::Graph>& fbs_graph)
const;
1329 #endif // !defined(ORT_MINIMAL_BUILD)
1336 if (!parent_node_)
return false;
1338 return std::any_of(implicit_input_defs.cbegin(), implicit_input_defs.cend(),
1339 [&name](
const NodeArg* implicit_input) {
1340 return implicit_input->Name() == name;
1344 #if !defined(ORT_MINIMAL_BUILD)
1351 Graph(
Graph& parent_graph,
const Node& parent_node, ONNX_NAMESPACE::GraphProto& subgraph_proto);
1355 ONNX_NAMESPACE::GraphProto& subgraph_proto,
1356 const std::unordered_map<std::string, int>& domain_version_map,
1358 bool strict_shape_type_inference);
1364 const std::unordered_map<std::string, int>& domain_to_version,
1365 #
if !defined(ORT_MINIMAL_BUILD)
1368 const OrtFormatLoadOptions& load_options,
1373 Graph& parent_graph,
const Node& parent_node,
1374 const OrtFormatLoadOptions& load_options,
1377 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1379 return runtime_optimizations_;
1383 return runtime_optimizations_;
1385 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1395 const std::unordered_map<std::string, int>& domain_to_version,
1396 #
if !defined(ORT_MINIMAL_BUILD)
1399 Graph* parent_graph,
const Node* parent_node,
1401 bool strict_shape_type_inference);
1405 const OrtFormatLoadOptions& load_options);
1407 #if !defined(ORT_MINIMAL_BUILD)
1411 ONNX_NAMESPACE::GraphProto* graph_proto,
1412 const std::unordered_map<std::string, int>& domain_to_version,
1416 bool strict_shape_type_inference);
1420 ONNX_NAMESPACE::GraphProto* graph_proto,
1421 const std::unordered_map<std::string, int>& domain_to_version,
1424 Graph* parent_graph,
1425 const Node* parent_node,
1427 bool strict_shape_type_inference);
1432 void InitializeStateFromModelFileGraphProto();
1435 Node&
AddNode(
const ONNX_NAMESPACE::NodeProto& node_proto,
1442 Status AddConstantProtoAsInitializer(
const ONNX_NAMESPACE::NodeProto& constant_node_proto,
1443 std::optional<std::string_view> new_name);
1447 Version IrVersion() const noexcept {
1452 graph_resolve_needed_ = needed;
1457 graph_proto_sync_needed_ = needed;
1466 struct ResolveContext {
1467 ResolveContext(
const Graph& owning_graph) : graph{owning_graph} {
1470 std::unordered_map<std::string_view, std::pair<Node*, int>> output_args;
1471 std::unordered_set<std::string_view> inputs_and_initializers;
1472 std::unordered_map<std::string_view, NodeIndex> node_name_to_index;
1473 std::unordered_set<Node*> nodes_with_subgraphs;
1484 output_args.clear();
1485 inputs_and_initializers.clear();
1486 node_name_to_index.clear();
1487 nodes_with_subgraphs.clear();
1491 bool IsInputInitializerOrOutput(
const std::string&
name,
bool check_ancestors)
const;
1501 void ComputeOverridableInitializers();
1503 #if !defined(ORT_MINIMAL_BUILD)
1506 common::Status BuildConnections(std::unordered_set<std::string>& outer_scope_node_args_consumed);
1516 common::Status PerformTypeAndShapeInferencing(
const ResolveOptions& options);
1519 void FindAllSubgraphs(std::vector<Graph*>& subgraphs);
1524 common::Status InferAndVerifyTypeMatch(
Node& node,
const ONNX_NAMESPACE::OpSchema& op,
const ResolveOptions& options);
1528 const std::vector<const ONNX_NAMESPACE::TypeProto*>& input_types,
1529 std::vector<const ONNX_NAMESPACE::TypeProto*>& output_types,
1530 const Graph::ResolveOptions& options);
1541 common::Status VerifyNodeAndOpMatch(
const ResolveOptions& options);
1548 common::Status SetOuterScopeNodeArgs(
const std::unordered_set<std::string>& outer_scope_node_args);
1551 Status ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initializer,
bool is_external);
1554 void CleanUnusedInitializersAndNodeArgs(
const std::unordered_set<std::string>* initializer_names_to_preserve =
nullptr);
1556 std::vector<NodeArg*> CreateNodeArgs(
const google::protobuf::RepeatedPtrField<std::string>& names,
1559 void ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto)
const;
1561 #endif // !defined(ORT_MINIMAL_BUILD)
1563 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1564 Status PopulateNodeArgToProducerConsumerLookupsFromNodes();
1566 template <
typename TInstance>
1567 static auto GetConsumerNodesImpl(
1568 TInstance& instance,
const std::string& node_arg_name) -> std::vector<decltype(instance.GetNode(0))> {
1569 std::vector<decltype(instance.GetNode(0))> results;
1570 auto iter = instance.node_arg_to_consumer_nodes_.find(node_arg_name);
1571 if (iter != instance.node_arg_to_consumer_nodes_.end()) {
1572 results.reserve(iter->second.size());
1573 for (
auto node_index : iter->second) {
1574 results.push_back(instance.GetNode(node_index));
1580 template <
typename TInstance>
1581 static auto GetProducerNodeImpl(
1582 TInstance& instance,
const std::string& node_arg_name) -> decltype(instance.GetNode(0)) {
1583 auto iter = instance.node_arg_to_producer_node_.find(node_arg_name);
1584 if (iter != instance.node_arg_to_producer_node_.end()) {
1585 auto node_index = iter->second;
1586 return instance.GetNode(node_index);
1591 gsl::not_null<Node*> AllocateNode();
1597 Node& CreateFusedSubGraphNode(
const IndexedSubGraph& sub_graph,
const std::string& fused_node_name);
1598 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1605 ORT_ENFORCE(node_index < nodes_.size(),
"Validating no unexpected access using an invalid node_index. Got:",
1606 node_index,
" Max:", nodes_.size());
1607 return nodes_[node_index].get();
1610 const Model& owning_model_;
1617 ONNX_NAMESPACE::GraphProto* graph_proto_;
1620 ONNX_NAMESPACE::GraphProto deserialized_proto_data_;
1624 std::unordered_set<std::reference_wrapper<const std::string>,
1625 std::hash<std::string>, std::equal_to<std::string>>
1626 sparse_tensor_names_;
1628 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1631 std::unique_ptr<RuntimeOptimizationRecordContainer> runtime_optimizations_ptr_;
1632 RuntimeOptimizationRecordContainer& runtime_optimizations_;
1633 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1635 #if !defined(ORT_MINIMAL_BUILD)
1641 InlinedVector<std::unique_ptr<ONNX_NAMESPACE::OpSchema>> fused_schemas_containers_;
1644 InlinedHashMap<std::string, std::reference_wrapper<ONNX_NAMESPACE::OpSchema>> reusable_fused_schema_map_;
1645 #endif // !defined(ORT_MINIMAL_BUILD)
1649 std::vector<std::unique_ptr<Node>> nodes_;
1652 GraphNodes iterable_nodes_{nodes_};
1658 int num_of_nodes_ = 0;
1661 bool graph_resolve_needed_ =
false;
1663 bool graph_proto_sync_needed_ =
false;
1666 std::vector<NodeIndex> nodes_in_topological_order_;
1669 std::vector<const NodeArg*> graph_inputs_including_initializers_;
1670 bool graph_inputs_manually_set_ =
false;
1673 std::vector<const NodeArg*> graph_inputs_excluding_initializers_;
1677 std::vector<const NodeArg*> graph_overridable_initializers_;
1680 std::vector<const NodeArg*> graph_outputs_;
1681 bool graph_outputs_manually_set_ =
false;
1684 std::unordered_set<const NodeArg*> value_info_;
1687 std::unordered_map<std::string, std::unique_ptr<NodeArg>> node_args_;
1689 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1690 int name_generator_ = 0;
1694 std::unordered_set<std::string> generated_node_names_;
1698 std::unordered_set<std::string> generated_node_arg_names_;
1701 std::unordered_map<std::string, NodeIndex> node_arg_to_producer_node_;
1704 std::unordered_map<std::string, std::unordered_set<NodeIndex>> node_arg_to_consumer_nodes_;
1705 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1707 const std::unordered_map<std::string, int> domain_to_version_;
1710 Version ir_version_{ONNX_NAMESPACE::Version::IR_VERSION};
1712 ResolveContext resolve_context_{*
this};
1715 Graph* parent_graph_;
1717 const Node* parent_node_;
1721 std::unordered_set<std::string> outer_scope_node_arg_names_;
1724 int num_resolves_ = 0;
1726 const logging::Logger& logger_;
1732 const bool strict_shape_type_inference_;
1735 const bool is_loaded_from_model_file_;
1738 #if !defined(ORT_MINIMAL_BUILD)
1743 std::ostream&
operator<<(std::ostream& out,
const NodeArg& node_arg);
1761 std::ostream&
operator<<(std::ostream& out,
const Graph& graph);
constexpr auto AsSpan(C &c)
void SetNodeArgType(NodeArg &arg, const ONNX_NAMESPACE::TypeProto &type_proto)
bool IsOuterScopeValue(const std::string &name) const
bool IsInitializedTensor(const std::string &name) const
void UpdateProducerNode(const std::string &node_arg_name, NodeIndex node_index)
std::unordered_map< std::string, const ONNX_NAMESPACE::TensorProto * > InitializedTensorSet
The node refers to a primitive operator.
const std::string & ProviderType
const Node * GetNode(NodeIndex node_index) const
void AddAttributeProto(ONNX_NAMESPACE::AttributeProto value)
const InitializedTensorSet & GetAllInitializedTensors() const noexcept
void ForEachDef(std::function< void(const onnxruntime::NodeArg &, bool is_input)> func, bool include_missing_optional_defs=false) const
IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const
ConstPointerContainer< std::vector< NodeArg * > > InputDefs() const noexcept
const std::vector< int > & InputArgCount() const noexcept
void SetInputs(gsl::span< const NodeArg *const > inputs)
const ONNX_NAMESPACE::GraphProto & ToGraphProto()
const Function * GetFunctionBody() const noexcept
std::shared_ptr< IOnnxRuntimeOpSchemaCollection > IOnnxRuntimeOpSchemaCollectionPtr
Node(NodeIndex index, Graph &graph)
int MaxNodeIndex() const noexcept
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Graph)
The node refers to a function.
const std::string & Description() const noexcept
void SetFunctionTemplate(const FunctionTemplate &func_template)
NodeIndex Index() const noexcept
void RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index)
int GetDstArgIndex() const
const Node * operator->() const
void UpdateConsumerNodes(const std::string &node_arg_name, std::initializer_list< Node * > nodes)
Node & FuseSubGraph(const IndexedSubGraph &sub_graph, const std::string &fused_node_name)
const RuntimeOptimizationRecordContainer & RuntimeOptimizations() const
bool NodeProducesGraphOutput(const Node &node) const
GLsizei const GLchar *const * string
const Graph * GetGraphAttribute(const std::string &attr_name) const
const NodeAttributes & GetAttributes() const noexcept
const Node * ParentNode() const
size_t GetOutputEdgesCount() const noexcept
common::Status InjectExternalInitializedTensors(const InlinedHashMap< std::string, OrtValue > &external_initializers)
RuntimeOptimizationRecordContainer & MutableRuntimeOptimizations()
static common::Status ForEachMutableWithIndex(std::vector< NodeArg * > &node_args, std::function< common::Status(NodeArg &arg, size_t index)> func)
bool SetOpSchemaFromRegistryForNode(Node &node)
Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder &builder, flatbuffers::Offset< onnxruntime::fbs::Node > &fbs_node) const
#define ORT_ENFORCE(condition,...)
bool IsSparseInitializer(const std::string &name) const
const std::unordered_map< std::string, int > & DomainToVersionMap() const noexcept
void FinalizeFuseSubGraph(const IndexedSubGraph &sub_graph, Node &fused_node)
void ToProto(ONNX_NAMESPACE::NodeProto &proto, bool update_subgraphs=false) const
NodeArg * GetNodeArg(const std::string &name)
std::vector< NodeArg * > & MutableImplicitInputDefs() noexcept
auto arg(const Char *name, const T &arg) -> detail::named_arg< Char, T >
const NodeArg * GetNodeArg(const std::string &name) const
NodeConstIterator OutputNodesEnd() const noexcept
static Status LoadFromOrtFormat(const onnxruntime::fbs::Graph &fbs_graph, const Model &owning_model, const std::unordered_map< std::string, int > &domain_to_version, IOnnxRuntimeOpSchemaCollectionPtr schema_registry, const OrtFormatLoadOptions &load_options, const logging::Logger &logger, std::unique_ptr< Graph > &graph)
std::vector< Node * > GetMutableConsumerNodes(const std::string &node_arg_name)
std::unordered_map< std::string, gsl::not_null< const Graph * > > GetAttributeNameToSubgraphMap() const
ConstPointerContainer< std::vector< NodeArg * > > OutputDefs() const noexcept
bool IsInputsIncludingInitializers(const NodeArg *node_arg) const noexcept
const ONNX_NAMESPACE::OpSchema * Op() const noexcept
const std::string & OpType() const noexcept
void SetOutputs(gsl::span< const NodeArg *const > outputs)
void SetExecutionProviderType(ProviderType execution_provider_type)
void AddConsumerNode(const std::string &node_arg_name, Node *consumer)
NodeConstIterator InputNodesBegin() const noexcept
basic_string_view< char > string_view
ConstPointerContainer< std::vector< NodeArg * > > ImplicitInputDefs() const noexcept
Node * GetNode(NodeIndex node_index)
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Definitions)
const std::unordered_set< const NodeArg * > & GetValueInfo() const noexcept
const std::string & Description() const noexcept
GraphNodes & Nodes() noexcept
Graph & SetGraphResolveNeeded() noexcept
Status LoadEdgesFromOrtFormat(const onnxruntime::fbs::NodeEdge &fbs_node_edgs, const Graph &graph)
std::vector< gsl::not_null< const Graph * > > GetSubgraphs() const
NodeConstIterator OutputNodesBegin() const noexcept
std::set< EdgeEnd, EdgeEndCompare > EdgeSet
NodeAttributes & GetMutableAttributes() noexcept
bool RemoveNode(NodeIndex node_index)
Node & AddNode(const std::string &name, const std::string &op_type, const std::string &description, std::initializer_list< NodeArg * > input_args, std::initializer_list< NodeArg * > output_args, const NodeAttributes *attributes=nullptr, const std::string &domain=kOnnxDomain)
ADD_ATTR_SINGLE_INTERFACE(ONNX_NAMESPACE::GraphProto)
bool TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto &func_proto) const
int Priority() const noexcept
int NumberOfNodes() const noexcept
void UpdateConsumerNodes(const std::string &node_arg_name, gsl::span< Node *const > nodes)
void AddAttribute(std::string attr_name, int64_t value)
std::string GenerateNodeName(const std::string &base_name)
const Node & operator*() const
void AddEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index)
std::vector< int > & MutableInputArgsCount()
#define ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TypeName)
bool ClearAttribute(const std::string &attr_name)
std::vector< int > GetNodeOutputsInGraphOutputs(const Node &node) const
bool AddControlEdge(NodeIndex src_node_index, NodeIndex dst_node_index)
static common::Status ForEachWithIndex(const ConstPointerContainer< std::vector< NodeArg * >> &node_args, std::function< common::Status(const NodeArg &arg, size_t index)> func)
const ONNX_NAMESPACE::TensorProto * GetInitializer(const std::string &name, bool check_outer_scope) const
Graph * MutableParentGraph()
void SetSinceVersion(int since_version) noexcept
const std::string & Name() const noexcept
std::unordered_map< std::string, ONNX_NAMESPACE::AttributeProto > NodeAttributes
const std::unordered_map< std::string, gsl::not_null< Graph * > > & GetAttributeNameToMutableSubgraphMap()
bool GetInitializedTensor(const std::string &tensor_name, const ONNX_NAMESPACE::TensorProto *&value) const
const Path & ModelPath() const noexcept
void RemoveConsumerNode(const std::string &node_arg_name, Node *consumer)
size_t GetInputEdgesCount() const noexcept
constexpr const char * kOnnxDomain
GLuint const GLchar * name
std::set< std::string > control_inputs
bool no_proto_sync_required
std::function< bool(NodeIndex)> NodeFilterFunc
Status InlineFunction(Node &node)
flatbuffers::Offset< onnxruntime::fbs::NodeEdge > SaveEdgesToOrtFormat(flatbuffers::FlatBufferBuilder &builder) const
bool GraphProtoSyncNeeded() const noexcept
Status InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto &func_to_inline)
Graph & SetGraphProtoSyncNeeded() noexcept
const std::vector< const NodeArg * > & GetOutputs() const noexcept
void RemoveInitializedTensor(const std::string &tensor_name)
std::vector< int > input_arg_count
bool IsOutput(const NodeArg *node_arg) const noexcept
common::Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder &builder, flatbuffers::Offset< onnxruntime::fbs::Graph > &fbs_graph) const
NodeArg * GetNodeArgIncludingParentGraphs(const std::string &node_arg_name)
const ONNX_NAMESPACE::TensorProto * GetConstantInitializer(const std::string &name, bool check_outer_scope) const
void KahnsTopologicalSort(const std::function< void(const Node *)> &enter, const std::function< bool(const Node *, const Node *)> &comp) const
const Model & GetModel() const
void AddOuterScopeNodeArg(const std::string &name)
const std::unordered_set< std::string > * initializer_names_to_preserve
common::Status ReplaceInitializedTensor(ONNX_NAMESPACE::TensorProto new_initializer)
const Node * GetProducerNode(const std::string &node_arg_name) const
void SetOutputs(std::initializer_list< const NodeArg * > outputs)
Status UpdateShapeInference(Node &node)
bool GraphResolveNeeded() const noexcept
const std::string & Name() const noexcept
std::string GenerateNodeArgName(const std::string &base_name)
std::vector< NodeArg * > implicit_input_defs
Graph * GetMutableGraphAttribute(const std::string &attr_name)
ProviderType GetExecutionProviderType() const noexcept
GLenum GLsizei GLsizei GLint * values
void SetPriority(int priority) noexcept
void SetInputs(std::initializer_list< const NodeArg * > inputs)
bool operator==(const NodeConstIterator &p_other) const
Node * GetMutableProducerNode(const std::string &node_arg_name)
const Path & ModelPath() const
Node & AddNode(const Node &other)
const std::string & Domain() const noexcept
NodeArg & GetOrCreateNodeArg(const std::string &name, const ONNX_NAMESPACE::TypeProto *p_arg_type)
void SetName(const std::string &name)
const Node & GetNode() const noexcept
bool StrictShapeTypeInference() const
#define ORT_RETURN_IF_ERROR(expr)
std::vector< const Node * > GetConsumerNodes(const std::string &node_arg_name) const
ImageBuf OIIO_API max(Image_or_Const A, Image_or_Const B, ROI roi={}, int nthreads=0)
const std::unordered_set< std::string > & GetOuterScopeNodeArgNames() const noexcept
GA_API const UT_StringHolder N
bool operator()(const EdgeEnd &lhs, const EdgeEnd &rhs) const
ADD_ATTR_INTERFACES(float)
NodeConstIterator InputNodesEnd() const noexcept
const std::set< std::string > & ControlInputs() const noexcept
const logging::Logger & GetLogger() const
std::vector< NodeArg * > & MutableOutputDefs() noexcept
EdgeEnd(const Node &node, int src_arg_index, int dst_arg_index) noexcept
static Status LoadFromOrtFormat(const onnxruntime::fbs::Node &fbs_node, Graph &graph, const OrtFormatLoadOptions &load_options, const logging::Logger &logger, std::unique_ptr< Node > &node)
const std::vector< const NodeArg * > & GetOverridableInitializers() const
EdgeConstIterator InputEdgesBegin() const noexcept
const std::vector< const NodeArg * > & GetInputs() const noexcept
NodeConstIterator(EdgeConstIterator p_iter)
int SinceVersion() const noexcept
#define ORT_IGNORE_RETURN_VALUE(fn)
Node & BeginFuseSubGraph(const IndexedSubGraph &sub_graph, const std::string &fused_node_name)
std::vector< NodeArg * > output_defs
std::vector< NodeArg * > input_defs
std::ostream & operator<<(std::ostream &out, AllocKind alloc_kind)
void AddInitializedTensor(const ONNX_NAMESPACE::TensorProto &tensor_proto)
Node & AddNode(const std::string &name, const std::string &op_type, const std::string &description, gsl::span< NodeArg *const > input_args, std::initializer_list< NodeArg * > output_args, const NodeAttributes *attributes=nullptr, const std::string &domain=kOnnxDomain)
const GraphNodes & Nodes() const noexcept
void ReplaceDefs(const std::map< const onnxruntime::NodeArg *, onnxruntime::NodeArg * > &replacements)
EdgeConstIterator InputEdgesEnd() const noexcept
EdgeConstIterator OutputEdgesBegin() const noexcept
std::unordered_map< std::string, ONNX_NAMESPACE::TypeProto > ArgNameToTypeMap
ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::string &external_file_name, const PathString &file_path, size_t initializer_size_threshold) const
void CleanAllInitializedTensors() noexcept
void SetDescription(const std::string &description)
bool ContainsSubgraph() const
ConstGraphNodes FilteredNodes(GraphNodes::NodeFilterFunc &&filter_func) const noexcept
std::vector< NodeArg * > & MutableInputDefs() noexcept
bool CanOverrideInitializer() const noexcept
Node::Type NodeType() const noexcept
Node(std::string_view name, std::string_view op_type, std::string_view description, gsl::span< NodeArg *const > input_args, gsl::span< NodeArg *const > output_args, const NodeAttributes *attributes, std::string_view domain)
const std::vector< const NodeArg * > & GetInputsIncludingInitializers() const noexcept
void AddValueInfo(const NodeArg *new_value_info)
bool operator!=(const NodeConstIterator &p_other) const
Status InlineIfSubgraph(bool condition_value, Node &if_node, const logging::Logger &logger)
EdgeConstIterator OutputEdgesEnd() const noexcept
EdgeSet::const_iterator EdgeConstIterator
int PruneRemovableAttributes(gsl::span< const std::string > removable_attributes)
Node & AddNode(const std::string &name, const std::string &op_type, const std::string &description, std::initializer_list< NodeArg * > input_args, gsl::span< NodeArg *const > output_args, const NodeAttributes *attributes=nullptr, const std::string &domain=kOnnxDomain)
int GetSrcArgIndex() const
const Graph * ParentGraph() const
std::unordered_map< std::string, gsl::not_null< Graph * > > & GetMutableMapOfAttributeNameToSubgraph()
FMT_CONSTEXPR auto find(Ptr first, Ptr last, T value, Ptr &out) -> bool
void ReverseDFSFrom(gsl::span< NodeIndex const > from, const std::function< void(const Node *)> &enter, const std::function< void(const Node *)> &leave, const std::function< bool(const Node *, const Node *)> &comp={}) const
void AddAttribute(std::string attr_name, const char(&value)[N])
bool CanBeInlined() const
bool Exists() const noexcept