6 #ifndef SHARED_PROVIDER
8 #include <unordered_map>
9 #include <unordered_set>
13 #include "core/framework/data_transfer.h"
16 namespace onnxruntime {
18 struct ComputeCapability;
20 struct KernelCreateInfo;
29 #include "core/framework/allocator_utils.h"
34 #include "core/framework/tuning_context.h"
36 namespace onnxruntime {
45 using ComputeFunc = std::function<Status(FunctionState, const OrtApi*, OrtKernelContext*)>;
67 if (use_metadef_id_creator) {
68 metadef_id_generator_ = std::make_unique<ModelMetadefIdGenerator>();
116 virtual std::vector<std::unique_ptr<ComputeCapability>>
254 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
265 std::vector<NodeComputeInfo>& node_compute_funcs);
317 if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) {
346 class ModelMetadefIdGenerator {
351 std::unordered_map<HashValue, HashValue> main_graph_hash_;
352 std::unordered_map<HashValue, int> model_metadef_id_;
355 std::unique_ptr<ModelMetadefIdGenerator> metadef_id_generator_;
virtual FusionStyle GetFusionStyle() const
virtual common::Status Sync() const
virtual std::vector< std::unique_ptr< ComputeCapability > > GetCapability(const onnxruntime::GraphViewer &graph_viewer, const IKernelLookup &kernel_lookup) const
virtual bool IsGraphCaptured() const
virtual int GenerateMetaDefId(const onnxruntime::GraphViewer &graph_viewer, HashValue &model_hash) const
virtual ~IExecutionProvider()=default
virtual const KernelCreateInfo * LookUpKernel(const Node &node) const =0
GLsizei const GLchar *const * string
std::function< int(ComputeContext *, FunctionState *)> CreateFunctionStateFunc
virtual DataLayout GetPreferredLayout() const
const OrtDevice default_device_
virtual common::Status Compile(const std::vector< FusedNodeAndGraph > &fused_nodes_and_graphs, std::vector< NodeComputeInfo > &node_compute_funcs)
virtual std::shared_ptr< KernelRegistry > GetKernelRegistry() const
IExecutionProvider(const std::string &type, bool use_metadef_id_creator=false)
virtual OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const
IExecutionProvider(const std::string &type, OrtDevice device, bool use_metadef_id_creator=false)
const std::string & Type() const
virtual std::unique_ptr< profiling::EpProfiler > GetProfiler()
const std::reference_wrapper< GraphViewer > filtered_graph
virtual common::Status OnRunStart()
const logging::Logger * GetLogger() const
DestroyFunctionStateFunc release_state_func
virtual void RegisterStreamHandlers(IStreamCommandHandleRegistry &, AllocatorMap &) const
absl::InlinedVector< T, N, Allocator > InlinedVector
virtual const void * GetExecutionHandle() const noexcept
std::map< OrtDevice, AllocatorPtr > AllocatorMap
virtual std::unique_ptr< onnxruntime::IDataTransfer > GetDataTransfer() const
virtual ProviderOptions GetProviderOptions() const
std::function< void(FunctionState)> DestroyFunctionStateFunc
virtual bool IsGraphCaptureEnabled() const
std::unordered_map< std::string, std::string > ProviderOptions
virtual int GetDeviceId() const
virtual common::Status ReplayGraph()
CreateFunctionStateFunc create_state_func
virtual ITuningContext * GetTuningContext() const
std::function< Status(FunctionState, const OrtApi *, OrtKernelContext *)> ComputeFunc
void SetLogger(const logging::Logger *logger)
virtual const InlinedVector< const Node * > GetEpContextNodes() const
const std::reference_wrapper< onnxruntime::Node > fused_node
virtual void GetCustomOpDomainList(std::vector< OrtCustomOpDomain * > &) const
virtual bool ConcurrentRunSupported() const
virtual std::vector< AllocatorPtr > CreatePreferredAllocators()
virtual common::Status OnSessionInitializationEnd()
virtual common::Status OnRunEnd(bool)