HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
op_kernel_info.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
11 #include "core/common/gsl.h"
12 
13 namespace onnxruntime {
14 
15 class DataTransferManager;
16 class FuncManager;
17 class OrtValueNameIdxMap;
18 struct AllocPlanPerValue;
19 
20 // A very light-weight class, which works as an aggregated
21 // view of all data needed for constructing a Kernel instance.
22 // NOTE: it does not own/hold any objects.
23 class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
24  public:
25  explicit OpKernelInfo(const onnxruntime::Node& node,
26  const KernelDef& kernel_def,
27  const IExecutionProvider& execution_provider,
28  const std::unordered_map<int, OrtValue>& constant_initialized_tensors,
29  const OrtValueNameIdxMap& mlvalue_name_idx_map,
30  const DataTransferManager& data_transfer_mgr,
31  const AllocatorMap& allocators,
32  const ConfigOptions& config_options);
33 
34  OpKernelInfo(const OpKernelInfo& other);
35 
36  const OrtDevice GetDevice(OrtMemType mem_type) const;
37 
38  AllocatorPtr GetAllocator(OrtMemType mem_type) const;
39 
40  const KernelDef& GetKernelDef() const;
41 
42  const IExecutionProvider* GetExecutionProvider() const noexcept;
43 
44  const DataTransferManager& GetDataTransferManager() const noexcept;
45 
46  const onnxruntime::Node& node() const noexcept;
47 
48  bool TryGetConstantInput(int input_index, const Tensor** constant_input_value) const;
49 
50  bool TryGetConstantInput(int input_index, const OrtValue** constant_input_value) const;
51 
52  const AllocatorMap& GetAllocators() const { return allocators_; }
53 
54  const ConfigOptions& GetConfigOptions() const { return config_options_; }
55 
56  private:
57  ORT_DISALLOW_MOVE(OpKernelInfo);
58  ORT_DISALLOW_ASSIGNMENT(OpKernelInfo);
59 
60  const onnxruntime::Node& node_;
61  const KernelDef& kernel_def_;
62  // For non cpu/cuda case, this pointer should be set so that function kernel
63  // will delegate kernel compute call to <execution_provider> compute call.
64  gsl::not_null<const ::onnxruntime::IExecutionProvider*> execution_provider_;
65  const std::unordered_map<int, OrtValue>& constant_initialized_tensors_;
66  const OrtValueNameIdxMap& ort_value_name_idx_map_;
67  const DataTransferManager& data_transfer_mgr_;
68  ProtoHelperNodeContext proto_helper_context_;
69  const AllocatorMap& allocators_;
70  const ConfigOptions& config_options_;
71 };
72 
73 } // namespace onnxruntime
const IExecutionProvider * GetExecutionProvider() const noexcept
const DataTransferManager & GetDataTransferManager() const noexcept
const ConfigOptions & GetConfigOptions() const
OpKernelInfo(const onnxruntime::Node &node, const KernelDef &kernel_def, const IExecutionProvider &execution_provider, const std::unordered_map< int, OrtValue > &constant_initialized_tensors, const OrtValueNameIdxMap &mlvalue_name_idx_map, const DataTransferManager &data_transfer_mgr, const AllocatorMap &allocators, const ConfigOptions &config_options)
AllocatorPtr GetAllocator(OrtMemType mem_type) const
const onnxruntime::Node & node() const noexcept
const AllocatorMap & GetAllocators() const
std::map< OrtDevice, AllocatorPtr > AllocatorMap
Definition: allocator.h:262
std::shared_ptr< IAllocator > AllocatorPtr
Definition: allocator.h:261
const OrtDevice GetDevice(OrtMemType mem_type) const
bool TryGetConstantInput(int input_index, const Tensor **constant_input_value) const
const KernelDef & GetKernelDef() const