HDK
#include <rule_based_graph_transformer.h>
Public Member Functions
  RuleBasedGraphTransformer (const std::string &name, const InlinedHashSet<std::string_view> &compatible_execution_providers = {})
  Status Register (std::unique_ptr<RewriteRule> rule)
  const InlinedVector<std::reference_wrapper<const RewriteRule>> * GetRewriteRulesForOpType (const std::string &op_type) const
  const InlinedVector<std::reference_wrapper<const RewriteRule>> * GetAnyOpRewriteRules () const
  size_t RulesCount () const
Public Member Functions inherited from onnxruntime::GraphTransformer
  GraphTransformer (const std::string &name, const InlinedHashSet<std::string_view> &compatible_execution_providers = {}) noexcept
  virtual ~GraphTransformer () = default
  const std::string & Name () const noexcept
  const InlinedHashSet<std::string_view> & GetCompatibleExecutionProviders () const noexcept
  Status Apply (Graph &graph, bool &modified, const logging::Logger &logger) const
  virtual bool ShouldOnlyApplyOnce () const
Protected Member Functions
  common::Status ApplyRulesOnNode (Graph &graph, Node &node, gsl::span<const std::reference_wrapper<const RewriteRule>> rules, RewriteRule::RewriteRuleEffect &rule_effect, const logging::Logger &logger) const
Protected Member Functions inherited from onnxruntime::GraphTransformer
  Status Recurse (Node &node, bool &modified, int graph_level, const logging::Logger &logger) const
Rule-based graph transformer that provides an API to register rewrite rules and an API to apply all applicable rules to a Graph.
Represents a GraphTransformer defined by a set of rewrite rules. The transformer applies all the rewrite rules iteratively, as determined by the underlying rewriting strategy. Several rewriting strategies are possible when traversing the graph and applying rewrite rules, each with different trade-offs. At the moment, only one is defined: a top-down traversal of the nodes.
Todo:
- Is a bottom-up traversal more efficient?
- Is it worth adding a maximum number of passes for which a rule should be applied?
- We need to define a contract about whether a rewrite rule is allowed to leave the graph in an inconsistent state (this will determine when and where Graph::Resolve() is called).
Definition at line 30 of file rule_based_graph_transformer.h.
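The strategy described above (top-down traversal, rules applied iteratively) can be illustrated with a toy model. This is a sketch under stated assumptions, not the onnxruntime implementation: `SketchNode`, `SketchGraph`, `SketchRule`, `ApplyTopDown`, `PromoteBeforeB` and the `max_passes` bound are all hypothetical, and the "graph" is just a list of nodes in topological order. It shows why one top-down pass may not be enough: a rewrite on a later node can enable a rule on an earlier one, so passes repeat until nothing changes.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-ins; these are NOT onnxruntime types.
struct SketchNode {
  std::string op_type;
};
using SketchGraph = std::vector<SketchNode>;  // nodes kept in topological order

// A rule inspects node `i` and returns true if it changed the graph.
using SketchRule = std::function<bool(SketchGraph& graph, std::size_t i)>;

// Visits nodes top-down (first to last) and tries every rule on each node.
// Whole passes repeat until a pass makes no change, because a rewrite later in
// the graph can enable a rule on an earlier node. max_passes (hypothetical)
// bounds the work if a rule set never converges.
// Returns the number of passes that modified the graph.
int ApplyTopDown(SketchGraph& graph, const std::vector<SketchRule>& rules,
                 int max_passes = 10) {
  int modifying_passes = 0;
  for (int pass = 0; pass < max_passes; ++pass) {
    bool modified = false;
    for (std::size_t i = 0; i < graph.size(); ++i) {
      for (const SketchRule& rule : rules) {
        if (rule(graph, i)) modified = true;
      }
    }
    if (!modified) break;  // fixed point reached
    ++modifying_passes;
  }
  return modifying_passes;
}

// Hypothetical example rule: rewrite an "A" node to "B" when the node after
// it is already "B". On [A, A, B], one top-down pass converts only the middle
// node; a second pass is needed for the first one.
bool PromoteBeforeB(SketchGraph& graph, std::size_t i) {
  if (i + 1 < graph.size() && graph[i].op_type == "A" &&
      graph[i + 1].op_type == "B") {
    graph[i].op_type = "B";
    return true;
  }
  return false;
}
```

In this toy example a bottom-up visit (last node to first) would convert both `A` nodes in a single pass, which is the kind of trade-off the first Todo item asks about; whether that holds for real rewrite rules is an open question in the source.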
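onnxruntime::RuleBasedGraphTransformer::RuleBasedGraphTransformer (const std::string &name, const InlinedHashSet<std::string_view> &compatible_execution_providers = {})   [inline]
Definition at line 32 of file rule_based_graph_transformer.h.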
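The constructor takes a transformer name and a set of compatible execution providers that defaults to empty. A minimal sketch of how such a set might be consulted, assuming (as the defaulted `= {}` argument suggests, but not confirmed here) that an empty set means "no restriction". `IsCompatibleProvider` is a hypothetical helper, and `std::unordered_set` stands in for `InlinedHashSet`.

```cpp
#include <cassert>
#include <string_view>
#include <unordered_set>

// Hypothetical helper, not an onnxruntime API.
// Assumption: an empty compatible set means the transformer may run on nodes
// assigned to any execution provider; otherwise the node's provider must be
// listed explicitly.
bool IsCompatibleProvider(const std::unordered_set<std::string_view>& compatible_eps,
                          std::string_view node_ep) {
  return compatible_eps.empty() || compatible_eps.count(node_ep) > 0;
}
```

Storing `std::string_view` keys is only safe when the referenced strings outlive the set, which is why string literals (static storage) are used in the usage below.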
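common::Status onnxruntime::RuleBasedGraphTransformer::ApplyRulesOnNode (Graph &graph, Node &node, gsl::span<const std::reference_wrapper<const RewriteRule>> rules, RewriteRule::RewriteRuleEffect &rule_effect, const logging::Logger &logger) const   [protected]
Applies the given set of rewrite rules to a Node of this Graph.
Parameters:
  [in]   graph         The Graph.
  [in]   node          The Node to apply the rules to.
  [in]   rules         The RewriteRules (a span of references) to apply to the Node.
  [out]  rule_effect   Enum that indicates whether and how the graph was modified as a result of applying the rules to this node.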
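The role of `rule_effect` can be sketched with simplified types. Everything here is hypothetical: `SketchEffect` only loosely mirrors `RewriteRule::RewriteRuleEffect` (the real enum may define different or additional values), and `SketchRule`, `ApplyRulesOnNodeSketch` and `ExampleRules` are illustrative names. The sketch tries each rule in order, lets each rule check its own condition, and stops early once a rule removes the node, since there is nothing left to rewrite.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical effect enum, loosely modeled on RewriteRule::RewriteRuleEffect.
enum class SketchEffect { kNone, kUpdatedCurrentNode, kRemovedCurrentNode };

struct SketchNode {
  std::string op_type;
  bool removed = false;  // stands in for deleting the node from a graph
};

// Hypothetical rule: a condition check plus a rewrite that reports its effect.
struct SketchRule {
  std::function<bool(const SketchNode&)> condition;
  std::function<SketchEffect(SketchNode&)> apply;
};

// Tries each rule on one node, in order. Returns kRemovedCurrentNode as soon
// as a rule removes the node (remaining rules are skipped); otherwise returns
// the last non-kNone effect, or kNone if no rule fired.
SketchEffect ApplyRulesOnNodeSketch(SketchNode& node, const std::vector<SketchRule>& rules) {
  SketchEffect effect = SketchEffect::kNone;
  for (const SketchRule& rule : rules) {
    if (!rule.condition(node)) continue;
    SketchEffect e = rule.apply(node);
    if (e == SketchEffect::kRemovedCurrentNode) return e;
    if (e != SketchEffect::kNone) effect = e;
  }
  return effect;
}

// Two illustrative rules: drop "Identity" nodes, rename "Relu" nodes.
std::vector<SketchRule> ExampleRules() {
  SketchRule drop_identity{
      [](const SketchNode& n) { return n.op_type == "Identity"; },
      [](SketchNode& n) {
        n.removed = true;
        return SketchEffect::kRemovedCurrentNode;
      }};
  SketchRule rename_relu{
      [](const SketchNode& n) { return n.op_type == "Relu"; },
      [](SketchNode& n) {
        n.op_type = "FusedRelu";
        return SketchEffect::kUpdatedCurrentNode;
      }};
  return {drop_identity, rename_relu};
}
```

Reporting the effect to the caller, rather than a plain "modified" flag, lets a traversal know whether it can keep working on the same node or must move on because the node no longer exists.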
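const InlinedVector<std::reference_wrapper<const RewriteRule>> * onnxruntime::RuleBasedGraphTransformer::GetAnyOpRewriteRules () const   [inline]
Gets the rewrite rules that are evaluated on all nodes, irrespective of their op type.
Definition at line 49 of file rule_based_graph_transformer.h.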
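const InlinedVector<std::reference_wrapper<const RewriteRule>> * onnxruntime::RuleBasedGraphTransformer::GetRewriteRulesForOpType (const std::string &op_type) const   [inline]
Gets the registered rewrite rules that this transformer triggers on nodes of the given op type.
Definition at line 42 of file rule_based_graph_transformer.h.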
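The two lookups, rules for one op type and rules for any op type, suggest an index of non-owning references keyed by op type plus a separate any-op list. A minimal sketch under that assumption: `SketchRule`, `SketchRuleIndex`, `RulesForNode` and the rule names are hypothetical, `std::vector` and `std::unordered_map` stand in for the InlinedVector/hash-map types, and returning `nullptr` for an op type with no rules, as well as checking op-type rules before any-op rules, are choices of this sketch rather than documented behavior.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical rule type; only a name is needed for this illustration.
struct SketchRule {
  std::string name;
};
using RuleRefs = std::vector<std::reference_wrapper<const SketchRule>>;

// Hypothetical index: non-owning references to rules, keyed by target op
// type, plus a list of rules considered for every node.
struct SketchRuleIndex {
  std::unordered_map<std::string, RuleRefs> by_op_type;
  RuleRefs any_op;

  // Returns nullptr when no rule targets op_type (an assumption of the sketch).
  const RuleRefs* ForOpType(const std::string& op_type) const {
    auto it = by_op_type.find(op_type);
    return it == by_op_type.end() ? nullptr : &it->second;
  }

  const RuleRefs* AnyOp() const { return &any_op; }
};

// Names of every rule that would be considered for a node of op_type:
// op-type-specific rules first, then the any-op rules.
std::vector<std::string> RulesForNode(const SketchRuleIndex& index, const std::string& op_type) {
  std::vector<std::string> names;
  if (const RuleRefs* specific = index.ForOpType(op_type)) {
    for (const SketchRule& r : *specific) names.push_back(r.name);
  }
  for (const SketchRule& r : *index.AnyOp()) names.push_back(r.name);
  return names;
}
```

Because the index holds `std::reference_wrapper` rather than copies, the rules themselves must be owned elsewhere and outlive the index.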
Status onnxruntime::RuleBasedGraphTransformer::Register (std::unique_ptr<RewriteRule> rule)
Registers a rewrite rule in this transformer.
size_t onnxruntime::RuleBasedGraphTransformer::RulesCount () const
Returns the total number of rules that are registered in this transformer.
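`Register` takes a `std::unique_ptr`, so the transformer owns its rules, while the lookup methods return references. One plausible arrangement, sketched here under assumptions: the registry keeps the owning pointers, indexes a reference under each target op type, and treats a rule with no target op types as an any-op rule; `RulesCount` then counts owned rules, so a rule targeting two op types counts once. `SketchRule`, `SketchRegistry`, `IndexedCount`, `AnyOpCount` and the rule names are hypothetical, and this `Register` returns `void` instead of `Status`.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical rule: an empty target_op_types list means "any op type".
struct SketchRule {
  std::string name;
  std::vector<std::string> target_op_types;
};

class SketchRegistry {
 public:
  // Takes ownership of the rule, then indexes a non-owning reference to it
  // under each target op type, or in the any-op list if none are given.
  // The reference stays valid after the move: only the pointer moves.
  void Register(std::unique_ptr<SketchRule> rule) {
    const SketchRule& ref = *rule;
    if (ref.target_op_types.empty()) {
      any_op_.push_back(std::cref(ref));
    } else {
      for (const std::string& op : ref.target_op_types) {
        by_op_type_[op].push_back(std::cref(ref));
      }
    }
    owned_.push_back(std::move(rule));
  }

  // Counts owned rules, so a rule indexed under several op types counts once.
  std::size_t RulesCount() const { return owned_.size(); }

  // Number of rules indexed under one op type (0 if none).
  std::size_t IndexedCount(const std::string& op_type) const {
    auto it = by_op_type_.find(op_type);
    return it == by_op_type_.end() ? 0 : it->second.size();
  }

  std::size_t AnyOpCount() const { return any_op_.size(); }

 private:
  std::vector<std::unique_ptr<SketchRule>> owned_;
  std::unordered_map<std::string, std::vector<std::reference_wrapper<const SketchRule>>> by_op_type_;
  std::vector<std::reference_wrapper<const SketchRule>> any_op_;
};
```

Separating ownership (`owned_`) from indexing keeps each rule alive exactly once while allowing it to appear in several per-op-type lists.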