HDK
onnxruntime::RuleBasedGraphTransformer Class Reference

#include <rule_based_graph_transformer.h>


Public Member Functions

 RuleBasedGraphTransformer (const std::string &name, const InlinedHashSet< std::string_view > &compatible_execution_providers={})
 
Status Register (std::unique_ptr< RewriteRule > rule)
 
const InlinedVector< std::reference_wrapper< const RewriteRule > > * GetRewriteRulesForOpType (const std::string &op_type) const
 
const InlinedVector< std::reference_wrapper< const RewriteRule > > * GetAnyOpRewriteRules () const
 
size_t RulesCount () const
 
Public Member Functions inherited from onnxruntime::GraphTransformer
 GraphTransformer (const std::string &name, const InlinedHashSet< std::string_view > &compatible_execution_providers={}) noexcept
 
virtual ~GraphTransformer ()=default
 
const std::string & Name () const noexcept
 
const InlinedHashSet< std::string_view > & GetCompatibleExecutionProviders () const noexcept
 
Status Apply (Graph &graph, bool &modified, const logging::Logger &logger) const
 
virtual bool ShouldOnlyApplyOnce () const
 

Protected Member Functions

common::Status ApplyRulesOnNode (Graph &graph, Node &node, gsl::span< const std::reference_wrapper< const RewriteRule >> rules, RewriteRule::RewriteRuleEffect &rule_effect, const logging::Logger &logger) const
 
Protected Member Functions inherited from onnxruntime::GraphTransformer
Status Recurse (Node &node, bool &modified, int graph_level, const logging::Logger &logger) const
 

Detailed Description

Rule-based graph transformer that provides an API to register rewrite rules and an API to apply all applicable rules to a Graph.

Represents a GraphTransformer defined by a set of rewrite rules. The transformer applies the rewrite rules iteratively, as determined by the underlying rewriting strategy. Several rewriting strategies are possible for traversing the graph and applying rewrite rules, each with different trade-offs. At the moment, only one is defined: a top-down traversal of the nodes.
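The iterative top-down strategy can be sketched as follows. This is a minimal, self-contained model, not ONNX Runtime code: ToyNode, ToyRule, TopDownPass and ApplyUntilFixedPoint are hypothetical names, the nodes are assumed to already be in topological order, and the max_passes cap is an illustrative assumption that echoes one of the open questions in this class's todo list rather than a documented parameter.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-ins, not ONNX Runtime types.
struct ToyNode {
  std::string op_type;
  bool removed = false;
};

// A toy rule rewrites a node in place and returns true if it changed anything.
using ToyRule = std::function<bool(ToyNode&)>;

// One top-down pass: visit the nodes in order (assumed topological) and try
// every rule on each node that has not been removed.
bool TopDownPass(std::vector<ToyNode>& nodes, const std::vector<ToyRule>& rules) {
  bool modified = false;
  for (ToyNode& node : nodes) {
    for (const ToyRule& rule : rules) {
      if (node.removed) break;  // nothing left to rewrite on this node
      if (rule(node)) modified = true;
    }
  }
  return modified;
}

// Repeats top-down passes until one makes no change or max_passes is reached.
// Returns the number of passes that ran.
int ApplyUntilFixedPoint(std::vector<ToyNode>& nodes, const std::vector<ToyRule>& rules,
                         int max_passes) {
  int passes = 0;
  while (passes < max_passes) {
    ++passes;
    if (!TopDownPass(nodes, rules)) break;
  }
  return passes;
}
```

Stopping after the first pass that changes nothing reaches a fixed point; the cap bounds the work if rules keep rewriting each other's output.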

Todo:
- Is a bottom-up traversal more efficient?
- Is it worth adding a maximum number of passes for which a rule should be applied?
- We need to define a contract about whether a rewrite rule is allowed to leave the graph in an inconsistent state (this will determine when and where Graph::Resolve() is called).

Definition at line 30 of file rule_based_graph_transformer.h.

Constructor & Destructor Documentation

onnxruntime::RuleBasedGraphTransformer::RuleBasedGraphTransformer (const std::string &name, const InlinedHashSet< std::string_view > &compatible_execution_providers = {})
inline

Definition at line 32 of file rule_based_graph_transformer.h.
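The compatible_execution_providers argument defaults to an empty set. One plausible reading, shown here as an assumption rather than documented behavior, is that an empty set places no restriction, while a non-empty set limits the transformer to nodes assigned to one of the listed execution providers. IsProviderCompatible is a hypothetical helper, and std::unordered_set stands in for InlinedHashSet.

```cpp
#include <cassert>
#include <string_view>
#include <unordered_set>

// Hypothetical helper, not part of ONNX Runtime. Assumption: an empty set
// (the constructor's default) means "compatible with every execution provider".
bool IsProviderCompatible(const std::unordered_set<std::string_view>& compatible_eps,
                          std::string_view node_ep) {
  return compatible_eps.empty() || compatible_eps.count(node_ep) > 0;
}
```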

Member Function Documentation

common::Status onnxruntime::RuleBasedGraphTransformer::ApplyRulesOnNode (Graph &graph, Node &node, gsl::span< const std::reference_wrapper< const RewriteRule >> rules, RewriteRule::RewriteRuleEffect &rule_effect, const logging::Logger &logger) const
protected

Applies the given rewrite rules to a Node of the Graph.

Parameters
[in] graph The Graph.
[in] node The Node to apply the rules to.
[in] rules The RewriteRules to apply to the Node.
[out] rule_effect Enum that indicates whether and how the graph was modified as a result of applying the rules to this node.
[in] logger The logger to use.

Returns
Status indicating success or providing error information.
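A minimal model of the per-node loop, under stated assumptions: ToyRuleEffect, ToyNode, ToyRule and ApplyToyRulesOnNode are hypothetical stand-ins for RewriteRule::RewriteRuleEffect, Node, RewriteRule and this method; the effect is returned directly instead of through a Status; and skipping the remaining rules once one removes the node is a choice of this sketch.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical, simplified model; none of these names are ONNX Runtime APIs.
enum class ToyRuleEffect { kNone, kUpdatedNode, kRemovedNode };

struct ToyNode {
  std::string op_type;
  bool removed = false;
};

struct ToyRule {
  std::function<bool(const ToyNode&)> condition;  // should the rule fire?
  std::function<ToyRuleEffect(ToyNode&)> apply;   // rewrite, reporting its effect
};

// Applies the rules in order to one node and reports the last non-trivial effect.
// Assumption: once a rule removes the node, the remaining rules are skipped.
ToyRuleEffect ApplyToyRulesOnNode(ToyNode& node, const std::vector<ToyRule>& rules) {
  ToyRuleEffect effect = ToyRuleEffect::kNone;
  for (const ToyRule& rule : rules) {
    if (!rule.condition(node)) continue;
    ToyRuleEffect this_effect = rule.apply(node);
    if (this_effect != ToyRuleEffect::kNone) effect = this_effect;
    if (effect == ToyRuleEffect::kRemovedNode) break;
  }
  return effect;
}
```

Checking the condition before applying lets a rule that does not match cost only one predicate call.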

const InlinedVector<std::reference_wrapper<const RewriteRule> >* onnxruntime::RuleBasedGraphTransformer::GetAnyOpRewriteRules () const
inline

Gets the rewrite rules that are evaluated on all nodes irrespective of their op type.

Returns
A pointer to the vector containing all such rewrite rules, or nullptr if no such rule is registered.

Definition at line 49 of file rule_based_graph_transformer.h.

const InlinedVector<std::reference_wrapper<const RewriteRule> >* onnxruntime::RuleBasedGraphTransformer::GetRewriteRulesForOpType (const std::string &op_type) const
inline

Gets the list of registered rewrite rules that will be triggered on nodes with the given op type by this rule-based transformer.

Returns
A pointer to the vector containing the rewrite rules registered for op_type.

Definition at line 42 of file rule_based_graph_transformer.h.
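The pointer-returning contract of GetRewriteRulesForOpType and GetAnyOpRewriteRules can be illustrated with a small standalone index. Everything here is hypothetical (ToyRule, ToyRuleIndex, RuleNamesForNode); the sketch assumes an op type with no registered rules yields nullptr, and it tries op-specific rules before any-op rules purely as its own choice. Because the index stores std::reference_wrapper, the rules must outlive it.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

struct ToyRule {
  std::string name;
};

using ToyRuleRefs = std::vector<std::reference_wrapper<const ToyRule>>;

// Hypothetical index whose getters return a pointer to stored rules,
// or nullptr when there is nothing to run.
class ToyRuleIndex {
 public:
  void AddForOpType(const std::string& op_type, const ToyRule& rule) {
    by_op_type_[op_type].push_back(std::cref(rule));
  }
  void AddForAnyOp(const ToyRule& rule) { any_op_.push_back(std::cref(rule)); }

  const ToyRuleRefs* GetForOpType(const std::string& op_type) const {
    auto it = by_op_type_.find(op_type);
    return it == by_op_type_.end() ? nullptr : &it->second;
  }
  // An empty any-op list is reported as nullptr.
  const ToyRuleRefs* GetForAnyOp() const { return any_op_.empty() ? nullptr : &any_op_; }

 private:
  std::unordered_map<std::string, ToyRuleRefs> by_op_type_;
  ToyRuleRefs any_op_;
};

// Collects the names of all rules that would be tried on a node of op_type,
// null-checking both getters before dereferencing them.
std::vector<std::string> RuleNamesForNode(const ToyRuleIndex& index, const std::string& op_type) {
  std::vector<std::string> names;
  if (const ToyRuleRefs* specific = index.GetForOpType(op_type)) {
    for (const ToyRule& rule : *specific) names.push_back(rule.name);
  }
  if (const ToyRuleRefs* generic = index.GetForAnyOp()) {
    for (const ToyRule& rule : *generic) names.push_back(rule.name);
  }
  return names;
}
```

Returning a pointer avoids copying the rule list and lets nullptr mean "no rules" without allocating an empty vector.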

Status onnxruntime::RuleBasedGraphTransformer::Register (std::unique_ptr< RewriteRule > rule)

Registers a rewrite rule in this transformer, which takes ownership of the rule.

size_t onnxruntime::RuleBasedGraphTransformer::RulesCount ( ) const

Returns the total number of rules that are registered in this transformer.
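Register takes a std::unique_ptr, so the transformer owns each rule, while the lookup methods hand out non-owning std::reference_wrapper entries. The sketch below models that split under assumptions: ToyRule, ToyRuleRegistry, IndexedUnder and AnyOpRuleCount are hypothetical; a rule with an empty target list is treated as an any-op rule; and a bool replaces Status. In this sketch RulesCount counts owned rules, so it can differ from the number of index entries when one rule targets several op types.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical rule: an empty target list means "applies to any op type".
struct ToyRule {
  std::string name;
  std::vector<std::string> target_op_types;
};

class ToyRuleRegistry {
 public:
  // Takes ownership of the rule and indexes non-owning references to it.
  // Returns false (instead of an error Status) for a null rule.
  bool Register(std::unique_ptr<ToyRule> rule) {
    if (!rule) return false;
    const ToyRule& ref = *rule;  // stays valid: moving the unique_ptr does not move the rule
    if (ref.target_op_types.empty()) {
      any_op_rules_.push_back(std::cref(ref));
    } else {
      for (const std::string& op_type : ref.target_op_types) {
        op_type_to_rules_[op_type].push_back(std::cref(ref));
      }
    }
    rules_.push_back(std::move(rule));
    return true;
  }

  // Counts owned rules, so a rule indexed under several op types counts once.
  std::size_t RulesCount() const { return rules_.size(); }

  // Number of rules indexed under one op type (0 if none).
  std::size_t IndexedUnder(const std::string& op_type) const {
    auto it = op_type_to_rules_.find(op_type);
    return it == op_type_to_rules_.end() ? 0 : it->second.size();
  }

  std::size_t AnyOpRuleCount() const { return any_op_rules_.size(); }

 private:
  std::vector<std::unique_ptr<ToyRule>> rules_;  // owning storage
  std::unordered_map<std::string, std::vector<std::reference_wrapper<const ToyRule>>>
      op_type_to_rules_;
  std::vector<std::reference_wrapper<const ToyRule>> any_op_rules_;
};
```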

