HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
graph_transformer.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 #include <string>
6 
7 #include "core/common/common.h"
12 
13 namespace onnxruntime {
14 
15 /**
16 @class GraphTransformer
17 
18 The interface for in-place transformation of a Graph.
19 */
21  public:
23  const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
24  : name_(name), compatible_provider_types_(compatible_execution_providers) {
25  }
26 
27  virtual ~GraphTransformer() = default;
28 
29  /** Gets the name of this graph transformer. */
30  const std::string& Name() const noexcept {
31  return name_;
32  }
33 
35  return compatible_provider_types_;
36  }
37 
38  /** Apply the in-place transformation defined by this transformer to the provided Graph instance.
39  @param[out] modified Set to true if the Graph was modified.
40  @returns Status with success or error information.
41  */
42  Status Apply(Graph& graph, bool& modified, const logging::Logger& logger) const;
43 
44  virtual bool ShouldOnlyApplyOnce() const { return false; }  // Defaults to false; derived transformers may override. NOTE(review): presumably signals the caller not to re-run this transformer after its first application - confirm against the caller.
45 
46  protected:
47  /** Helper method to call ApplyImpl on every subgraph held in the Node's attribute-name-to-subgraph map, passing graph_level + 1; an error from ApplyImpl is handed to ORT_RETURN_IF_ERROR, otherwise Status::OK() is returned. */
48  Status Recurse(Node& node, bool& modified, int graph_level, const logging::Logger& logger) const {
49  int subgraph_level = ++graph_level;  // subgraphs sit one level below the graph that contains `node`
50  for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
51  auto& subgraph = *entry.second;  // map value is a gsl::not_null<Graph*>, so the dereference needs no null check
52  ORT_RETURN_IF_ERROR(ApplyImpl(subgraph, modified, subgraph_level, logger));  // presumably returns early on the first failing subgraph (macro defined in common.h)
53  }
54 
55  return Status::OK();
56  }
57 
58  private:
59  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphTransformer);
60 
61  // Apply the transform to the graph.
62  // graph_level is 0 for the main graph, and is incremented when descending into the subgraph of a node.
63  // You MUST call Recurse for all valid Nodes in the graph to ensure any subgraphs in control flow nodes
64  // (Scan/If/Loop) are processed as well.
65  // You should avoid calling Graph::Resolve in ApplyImpl unless you are 100% sure it's required. In most cases
66  // the call to Graph::Resolve in GraphTransformer::Apply after the call to ApplyImpl (if 'modified' is true)
67  // should suffice.
68  virtual Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const = 0;
69 
70  const std::string name_;
71  const InlinedHashSet<std::string_view> compatible_provider_types_;
72 };
73 
74 /**
75  * @brief Immutable object to identify a kernel registration.
76  *
77  * This data structure is used by the graph transformers to check whether
78  * a kernel is registered with the execution provider (i.e. has an
79  * implementation). If not, the transformer cannot generate a node with
80  * such kernel.
81  */
85  const int version_;
87 
89  const std::basic_string_view<char>& op,
90  const std::basic_string_view<char>& domain,
91  const int version,
92  const std::initializer_list<std::pair<const std::string, MLDataType>>& init_list)
93  : op_type_(op), domain_(domain), version_(version), type_constraints_(init_list) {}
94 };
95 
96 } // namespace onnxruntime
Status Recurse(Node &node, bool &modified, int graph_level, const logging::Logger &logger) const
GLsizei const GLchar *const * string
Definition: glcorearb.h:814
Immutable object to identify a kernel registration.
virtual bool ShouldOnlyApplyOnce() const
virtual ~GraphTransformer()=default
const InlinedHashSet< std::string_view > & GetCompatibleExecutionProviders() const noexcept
OpKernelRegistryId(const std::basic_string_view< char > &op, const std::basic_string_view< char > &domain, const int version, const std::initializer_list< std::pair< const std::string, MLDataType >> &init_list)
const std::unordered_map< std::string, gsl::not_null< Graph * > > & GetAttributeNameToMutableSubgraphMap()
Definition: graph.h:442
GLuint const GLchar * name
Definition: glcorearb.h:786
Status Apply(Graph &graph, bool &modified, const logging::Logger &logger) const
GT_API const UT_StringHolder version
const InlinedHashMap< std::string, MLDataType > type_constraints_
#define ORT_RETURN_IF_ERROR(expr)
Definition: common.h:233
const std::string & Name() const noexcept
GraphTransformer(const std::string &name, const InlinedHashSet< std::string_view > &compatible_execution_providers={}) noexcept