HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
rewrite_rule.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #include "core/common/common.h"
8 
9 namespace onnxruntime {
10 
11 /**
12 @class RewriteRule
13 
14 The base class for a rewrite rule. A rewrite rule represents a semantics-preserving transformation of a
15 computation graph. It can be used to represent, for example, the elimination of operators that serve as
16 no-ops (e.g., dropout during inference), as well as inlining of "function" definitions or the dual operation
17 of replacing a complex expression by an equivalent function-call). Unlike the more general GraphTransformer,
18 a rewrite rule is a more local transformation that is triggered on a particular node of the graph.
19 
20 Each rule has a set of conditions and a body. The conditions have to be satisfied for the body of the rule
21 to be triggered. Therefore, when creating a new rewrite rule, two main functions have to be implemented:
22 - SatisfyCondition defines the condition checks. It is advisable to add the more selective checks first,
23  because those will lead to discarding fast rules that cannot be applied on a node.
24 - Apply is the actual body of the rule that will be executed if SatisfyCondition returns true for a particular
25  node. Note that additional, more complex checks can be included in the Apply if putting them in the
26  SatisfyCondition would lead to duplicate work (e.g., when we make a check on a Node attribute but we need
27  that attribute to execute the rule too).
28 In general, simple fast checks are a better fit for SatisfyCondition, whereas more complex ones can be added
29 in the Apply.
30 
31 In order to avoid evaluating the SatisfyCondition for each rule and each node of the graph, each rewrite rule
32 should specify the target op types for which a rule will be evaluated, by overriding the TargetOpTypes() function.
33 If the op type of a node is not included in the target op types of a rule, that rule would not be considered at all.
34 If the list of op types is left empty, that rule will be triggered for every op type.
35 */
36 class RewriteRule {
37  public:
38  /**
39  @class RewriteRuleEffect
40 
41  Class used to indicate the effect of rule application on a graph's node.
42  */
43  enum class RewriteRuleEffect : uint8_t {
44  kNone, // The rewrite rule has not modified the graph.
45  kUpdatedCurrentNode, // The rewrite rule updated (but did not remove) the node on which it was triggered.
46  kRemovedCurrentNode, // The rewrite rule removed the node on which it was triggered.
47  kModifiedRestOfGraph // The rewrite rule modified nodes other than the one it was triggered on.
48  };
49 
50  RewriteRule(const std::string& name) : name_(name) {}
51 
52  virtual ~RewriteRule() = default;
53 
54  /** Gets the name of this rewrite rule. */
55  const std::string& Name() const noexcept {
56  return name_;
57  }
58 
59  /** Returns the node op types for which this rule will be triggered. If the op type of a node is not included in the
60  target op types of a rule, that rule would not be considered at all. Returning an empty list indicates that we
61  will attempt to trigger the rule for every op type. */
62  virtual std::vector<std::string> TargetOpTypes() const noexcept = 0;
63 
64  /** Checks if the condition of the rule is satisfied, and if so applies the body of the rule.
65  @param[in] graph The Graph.
66  @param[in] node The Node to apply the rewrite to.
67  @param[out] rule_effect Enum to indicate if and how the graph was modified as a result of the rule application.
68  @returns Status indicating success or providing error information */
69  common::Status CheckConditionAndApply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const {
70  return SatisfyCondition(graph, node, logger) ? Apply(graph, node, rule_effect, logger) : Status::OK();
71  }
72 
73  private:
74  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(RewriteRule);
75 
76  const std::string name_;
77 
78  /** Checks if the Node of the given Graph satisfies the conditions of this rule. The body of the rule will be
79  evaluated if this condition function returns true. This can include a more complex pattern matching (conditions
80  on the ascending or descending nodes of the node for which this rule was triggered) or some other properties
81  of the nodes. */
82  virtual bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const = 0;
83 
84  /** This is the actual body of the rule that performs the graph transformation. The transformation happens in-place.
85  The return-value of node may be different from the input-value due to rewriting.
86  The value of "rule_effect" indicates whether and how the graph was modified by the rule. */
87  virtual common::Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const = 0;
88 };
89 } // namespace onnxruntime
virtual ~RewriteRule()=default
RewriteRule(const std::string &name)
Definition: rewrite_rule.h:50
GLsizei const GLchar *const * string
Definition: glcorearb.h:814
common::Status CheckConditionAndApply(Graph &graph, Node &node, RewriteRuleEffect &rule_effect, const logging::Logger &logger) const
Definition: rewrite_rule.h:69
GLuint const GLchar * name
Definition: glcorearb.h:786
virtual std::vector< std::string > TargetOpTypes() const noexcept=0
const std::string & Name() const noexcept
Definition: rewrite_rule.h:55