execution_provider.h

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#ifndef SHARED_PROVIDER
#include <memory>
#include <unordered_map>
#include <unordered_set>

#include "core/common/logging/logging.h"
#include "core/common/status.h"
#include "core/framework/data_transfer.h"
#include "core/framework/tensor.h"

namespace onnxruntime {
class GraphViewer;
struct ComputeCapability;
class KernelRegistry;
struct KernelCreateInfo;
class Node;
}  // namespace onnxruntime
#else
#include <memory>
#endif

#include "core/common/basic_types.h"
#include "core/common/profiler_common.h"
#include "core/framework/allocator_utils.h"
#include "core/framework/func_api.h"
#include "core/framework/provider_options.h"
#include "core/framework/framework_provider_common.h"
#include "core/framework/stream_handles.h"
#include "core/framework/tuning_context.h"

namespace onnxruntime {

/**
   Logical device representation.
*/

// Even if the fused function is exported to a dll, the function will still be in the same binary as onnxruntime.
// Use std::function so the execution provider has a chance to capture state.
using CreateFunctionStateFunc = std::function<int(ComputeContext*, FunctionState*)>;
using ComputeFunc = std::function<Status(FunctionState, const OrtApi*, OrtKernelContext*)>;
using DestroyFunctionStateFunc = std::function<void(FunctionState)>;

struct NodeComputeInfo {
  CreateFunctionStateFunc create_state_func;
  ComputeFunc compute_func;
  DestroyFunctionStateFunc release_state_func;
};
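
// Example (an illustrative sketch; `MyKernelState` is a hypothetical per-node state type):
// a compiling EP typically fills NodeComputeInfo with three callbacks that create state,
// run the fused computation, and release the state.
//
//   NodeComputeInfo info;
//   info.create_state_func = [](ComputeContext* ctx, FunctionState* state) {
//     *state = new MyKernelState(ctx->node_name);  // allocate per-node state; 0 == success
//     return 0;
//   };
//   info.compute_func = [](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
//     return static_cast<MyKernelState*>(state)->Run(api, context);  // execute the fused node
//   };
//   info.release_state_func = [](FunctionState state) {
//     delete static_cast<MyKernelState*>(state);  // free per-node state
//   };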

enum class DataLayout {
  NCHW,
  NHWC,
  NCHWC,
};

class IExecutionProvider {
 protected:
  IExecutionProvider(const std::string& type, bool use_metadef_id_creator = false)
      : IExecutionProvider(type, OrtDevice(), use_metadef_id_creator) {}

  IExecutionProvider(const std::string& type, OrtDevice device, bool use_metadef_id_creator = false)
      : default_device_(device), type_{type} {
    if (use_metadef_id_creator) {
      metadef_id_generator_ = std::make_unique<ModelMetadefIdGenerator>();
    }
  }

  /*
     Default device for this ExecutionProvider.
  */
  const OrtDevice default_device_;

 public:
  virtual ~IExecutionProvider() = default;

  /**
   * Returns a data transfer object that implements methods to copy to and
   * from this device.
   * If no copy is required for the successful operation of this provider,
   * return a nullptr.
   */
  virtual std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const {
    return nullptr;
  }
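
// Example (illustrative; `MyDataTransfer` is a hypothetical IDataTransfer that
// copies tensors between CPU and this EP's device):
//
//   std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const override {
//     return std::make_unique<MyDataTransfer>();
//   }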

  /**
   * Interface for performing kernel lookup within kernel registries.
   * Abstracts away lower-level details about kernel registries and kernel matching.
   */
  class IKernelLookup {
   public:
    /**
     * Given `node`, try to find a matching kernel for this EP.
     * The return value is non-null if and only if a matching kernel was found.
     */
    virtual const KernelCreateInfo* LookUpKernel(const Node& node) const = 0;

   protected:
    ~IKernelLookup() = default;
  };

  /**
     Get the execution provider's capability for the specified <graph>.
     Return a collection of IndexedSubGraphs that <*this> execution provider can run if
     the sub-graph contains only one node, or can fuse to run if the sub-graph
     contains more than one node. The node indexes contained in sub-graphs may
     overlap, and it is ONNXRuntime's responsibility to do the partitioning
     and decide whether a node will be assigned to <*this> execution provider.
     For kernels registered in a kernel registry, `kernel_lookup` must be used
     to find a matching kernel for this EP.
  */
  virtual std::vector<std::unique_ptr<ComputeCapability>>
  GetCapability(const onnxruntime::GraphViewer& graph_viewer,
                const IKernelLookup& kernel_lookup) const;
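
// Example (a sketch of the common single-node pattern; error handling omitted):
// claim every node for which this EP has a registered kernel. This mirrors the
// usual behavior of non-compiling EPs.
//
//   std::vector<std::unique_ptr<ComputeCapability>>
//   MyEp::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
//                       const IKernelLookup& kernel_lookup) const {
//     std::vector<std::unique_ptr<ComputeCapability>> result;
//     for (const auto& node : graph_viewer.Nodes()) {
//       if (kernel_lookup.LookUpKernel(node) != nullptr) {
//         auto sub_graph = std::make_unique<IndexedSubGraph>();
//         sub_graph->nodes.push_back(node.Index());
//         result.push_back(std::make_unique<ComputeCapability>(std::move(sub_graph)));
//       }
//     }
//     return result;
//   }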

  /**
     Get the kernel registry per execution provider type.
     The KernelRegistry shared pointer returned is shared across sessions.

     NOTE: this approach was taken to achieve the following goals:
     1. The execution-provider-type-based kernel registry should be shared
        across sessions.
        Only one copy of this kind of kernel registry exists in ONNXRuntime
        with multiple sessions/models.
     2. Adding an execution provider into ONNXRuntime does not need to touch ONNXRuntime
        framework/session code.
     3. onnxruntime (framework/session) does not depend on any specific
        execution provider lib.
  */
  virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const { return nullptr; }
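
// Example (illustrative; `RegisterMyEpKernels` is a hypothetical function that
// populates the registry): creating the registry once in a function-local static
// is the usual way to share one copy across all sessions.
//
//   std::shared_ptr<KernelRegistry> GetKernelRegistry() const override {
//     static std::shared_ptr<KernelRegistry> registry = []() {
//       auto r = std::make_shared<KernelRegistry>();
//       RegisterMyEpKernels(*r);
//       return r;
//     }();
//     return registry;
//   }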

  /**
     Get the device id of the current execution provider.
  */
  virtual int GetDeviceId() const { return 0; }

  /**
     Get the execution provider's configuration options.
  */
  virtual ProviderOptions GetProviderOptions() const { return {}; }

  /**
     Get the provider-specific custom op domain list.
     The provider has the responsibility to release the OrtCustomOpDomain instances it creates.

     NOTE: When an ONNX model has EP-specific custom nodes and we don't want to ask the user to register
     those nodes, the EP needs a way to register them itself. This API was added so an EP can leverage the
     ORT custom op machinery to register those custom nodes under one or more custom op domains.

     For example, the TensorRT EP uses this API to support TRT plugins: each custom op is mapped to a TRT
     plugin, and no kernel implementation is needed for the custom op since the real implementation is
     inside TRT. The custom op's role is to help the ONNX model pass validation.
  */
  virtual void GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& /*provider custom op domain list*/) const {}

  /**
     Returns an opaque handle whose exact type varies based on the provider
     and is interpreted accordingly by the corresponding kernel implementation.
     For Direct3D operator kernels, this may return an IUnknown supporting
     QueryInterface to ID3D12GraphicsCommandList1.
  */
  virtual const void* GetExecutionHandle() const noexcept {
    return nullptr;
  }

  /**
     @return type of the execution provider; should match that set in the node
     through the SetExecutionProvider API. Example valid return values are:
     kCpuExecutionProvider, kCudaExecutionProvider
  */
  const std::string& Type() const { return type_; }

  /**
     Blocks until the device has completed all preceding requested tasks.
     Currently this is primarily used by the IOBinding object to ensure that all
     inputs have been copied to the device before execution begins.
  */
  virtual common::Status Sync() const { return Status::OK(); }

  /**
     Called when InferenceSession::Run starts.
     NOTE that due to asynchronous execution in the provider, the actual work of the
     previous Run may not be finished on the device. This function should be regarded
     as the point after which a new Run starts to submit commands from the CPU.
  */
  virtual common::Status OnRunStart() { return Status::OK(); }

  /**
     Called when InferenceSession::Run ends.
     NOTE that due to asynchronous execution in the provider, the actual work of this
     Run may not be finished on the device. This function should be regarded as the
     point at which all commands of the current Run have been submitted by the CPU.
  */
  virtual common::Status OnRunEnd(bool /*sync_stream*/) { return Status::OK(); }
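
// Example (illustrative; `stream_` and `SynchronizeMyStream` are hypothetical):
// an EP that executes asynchronously might block here until submitted work completes.
//
//   common::Status OnRunEnd(bool sync_stream) override {
//     if (sync_stream) {
//       ORT_RETURN_IF_ERROR(SynchronizeMyStream(stream_));  // wait for in-flight work
//     }
//     return Status::OK();
//   }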

  /**
     Indicate whether the graph capturing mode (e.g., CUDA graph) is enabled for
     the provider. Currently only the CUDA execution provider supports it.
  */
  virtual bool IsGraphCaptureEnabled() const { return false; }

  /**
     Indicate whether the graph has been captured and instantiated. Currently
     only the CUDA execution provider supports it.
  */
  virtual bool IsGraphCaptured() const { return false; }

  /**
     Run the instantiated graph. Currently only the CUDA execution provider supports it.
  */
  virtual common::Status ReplayGraph() { return Status::OK(); }

  /**
     Called when session creation is complete.
     This provides an opportunity for execution providers to optionally synchronize and
     clean up temporary resources, reducing memory use and ensuring the first run is fast.
  */
  virtual common::Status OnSessionInitializationEnd() { return common::Status::OK(); }

  struct FusedNodeAndGraph {
    const std::reference_wrapper<onnxruntime::Node> fused_node;
    // GraphViewer that filters the full graph to the nodes covered by 'fused_node'
    const std::reference_wrapper<GraphViewer> filtered_graph;
  };

  // Fusion approaches that are supported.
  // !!! The "Function" FusionStyle is deprecated.
  // !!! If your EP is using this fusion style, please migrate it to the "FilteredGraphViewer" style.
  enum class FusionStyle {
    // The node fusion will create an onnxruntime::Function based Node that contains a completely new Graph instance
    // in the Node body. The original nodes and initializers are copied to the new Graph instance in Function::Body().
    // A GraphProto can be produced from the Node body.
    Function,

    // The node fusion will create a new Node that defines the inputs and outputs using the IndexedSubGraph
    // that GetCapability returned. The Node will not be onnxruntime::Function based, so it will have no Body().
    // Instead, a GraphViewer that filters the full Graph to the fused Nodes will be created.
    // This is significantly cheaper as it doesn't incur the cost of creating a new Graph instance,
    // and it can be supported in a minimal build.
    FilteredGraphViewer,
  };

  virtual FusionStyle GetFusionStyle() const {
    // All of the ORT built-in EPs have migrated to the FilteredGraphViewer style.
    // For newer EPs, please avoid the Function style as it is deprecated.
    return FusionStyle::FilteredGraphViewer;
  }

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
  /**
     Given a collection of fused Nodes and the respective GraphViewer instance for the nodes that were fused,
     return create_state/compute/release_state funcs for each node.
     @remarks This is now the default interface when an execution provider wants to compile nodes
              for both the minimal build and the complete ORT build.

              Do NOT cache the GraphViewer in FusedNodeAndGraph.filtered_graph in any of the NodeComputeInfo functions
              as it is only valid for the duration of the call to Compile.
  */
  virtual common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
                                 std::vector<NodeComputeInfo>& node_compute_funcs);
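
// Example (an abbreviated sketch; `CompileToMyBackend` and `MyBackendProgram` are
// hypothetical): compile each filtered graph once, then expose the result through
// the three NodeComputeInfo callbacks. Note the GraphViewer is only used inside
// Compile, never cached.
//
//   common::Status MyEp::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
//                                std::vector<NodeComputeInfo>& node_compute_funcs) {
//     for (const auto& fused : fused_nodes_and_graphs) {
//       const GraphViewer& graph = fused.filtered_graph;  // valid only during this call
//       std::shared_ptr<MyBackendProgram> program = CompileToMyBackend(graph);
//       NodeComputeInfo info;
//       info.create_state_func = [program](ComputeContext*, FunctionState* state) {
//         *state = new std::shared_ptr<MyBackendProgram>(program);
//         return 0;
//       };
//       info.compute_func = [](FunctionState state, const OrtApi* api, OrtKernelContext* ctx) {
//         return (*static_cast<std::shared_ptr<MyBackendProgram>*>(state))->Run(api, ctx);
//       };
//       info.release_state_func = [](FunctionState state) {
//         delete static_cast<std::shared_ptr<MyBackendProgram>*>(state);
//       };
//       node_compute_funcs.push_back(std::move(info));
//     }
//     return Status::OK();
//   }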

#endif

  void SetLogger(const logging::Logger* logger) {
    logger_ = logger;
  }

  const logging::Logger* GetLogger() const {
    return logger_;
  }

  /** Generate a unique id that can be used in a MetaDef name. Values are unique for a model instance.
      The model hash is also returned if you wish to include that in the MetaDef name to ensure uniqueness across models.
      @param[in] graph_viewer Graph viewer that GetCapability was called with. Can be for the main graph or a nested graph.
      @param[out] model_hash Returns the hash for the main (i.e. top-level) graph in the model.
                  This is created using the model path if available,
                  or the model input names and the output names from all nodes in the main graph.
      @remarks e.g. the TensorRT Execution Provider is used in multiple sessions and the underlying infrastructure caches
               compiled kernels, so the name must be unique and deterministic across models and sessions.
               NOTE: Ideally this would be a protected method, but to work across the EP bridge it has to be public and
               virtual, and ModelMetadefIdGenerator must be defined in the header as well.
  */
  virtual int GenerateMetaDefId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) const;
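
// Example (illustrative): typically called from GetCapability when naming the MetaDef
// of a fused sub-graph, so cached compiled artifacts get stable, unique names.
//
//   HashValue model_hash = 0;
//   int id = GenerateMetaDefId(graph_viewer, model_hash);
//   auto meta_def = std::make_unique<IndexedSubGraph::MetaDef>();
//   meta_def->name = "MyEp_" + std::to_string(model_hash) + "_" + std::to_string(id);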

  virtual std::unique_ptr<profiling::EpProfiler> GetProfiler() {
    return {};
  }

  virtual DataLayout GetPreferredLayout() const {
    // NCHW is the default ONNX standard data layout, so default to it.
    // EPs which prefer a different layout should override this to return their preferred layout.
    return DataLayout::NCHW;
  }
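
// Example (illustrative): an EP whose kernels expect channels-last tensors would
// override this so ORT's layout transformation converts applicable ops to NHWC.
//
//   DataLayout GetPreferredLayout() const override { return DataLayout::NHWC; }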

  virtual void RegisterStreamHandlers(IStreamCommandHandleRegistry& /*stream_handle_registry*/, AllocatorMap&) const {}

  /** Does the EP support concurrent calls to InferenceSession::Run to execute the model.
  */
  virtual bool ConcurrentRunSupported() const { return true; }

  /**
   * Return the tuning context which holds all TunableOp state.
   */
  virtual ITuningContext* GetTuningContext() const {
    return nullptr;
  }

  /**
   * Return the appropriate OrtDevice object given OrtMemType.
   */
  virtual OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const {
    if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) {
      return OrtDevice();  // default to returning the CPU device
    }
    return default_device_;
  }

  /**
   * Create the preferred allocators for the current execution provider.
   * This function is stateless: it creates new Allocator instances without storing them in the EP.
   */
  virtual std::vector<AllocatorPtr> CreatePreferredAllocators() { return std::vector<AllocatorPtr>(); }
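
// Example (a sketch; `MyDeviceAllocator` is a hypothetical IAllocator for this EP's
// device): build an AllocatorCreationInfo per allocator and return fresh instances.
//
//   std::vector<AllocatorPtr> CreatePreferredAllocators() override {
//     AllocatorCreationInfo device_info{
//         [](OrtDevice::DeviceId device_id) {
//           return std::make_unique<MyDeviceAllocator>(device_id);
//         },
//         /*device_id*/ 0};
//     return std::vector<AllocatorPtr>{CreateAllocator(device_info)};
//   }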

  /**
   * Get the array of pointers for EPContext nodes.
   * An EP needs to implement this if it is required to generate the context cache model. Otherwise leave it.
   * Defaults to returning an empty vector if not provided by the execution provider.
   */
  virtual const InlinedVector<const Node*> GetEpContextNodes() const {
    return InlinedVector<const Node*>();
  }

 private:
  const std::string type_;

  // It will be set when this object is registered to a session.
  const logging::Logger* logger_ = nullptr;

  // Helper to generate ids that are unique to the model and deterministic, even if the execution provider is
  // shared across multiple sessions.
  class ModelMetadefIdGenerator {
   public:
    int GenerateId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash);

   private:
    std::unordered_map<HashValue, HashValue> main_graph_hash_;  // map graph instance hash to model contents hash
    std::unordered_map<HashValue, int> model_metadef_id_;       // current unique id for model
  };

  std::unique_ptr<ModelMetadefIdGenerator> metadef_id_generator_;
};
}  // namespace onnxruntime