HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ort_kernel_invoker.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #include <string>
7 #include <vector>
8 
9 #include "core/common/common.h"
11 #include "core/framework/tensor.h"
13 #include "core/graph/constants.h"
15 #include "core/graph/basic_types.h"
16 #include "core/graph/model.h"
17 
18 namespace onnxruntime {
19 #ifdef __GNUC__
20 #pragma GCC diagnostic push
21 #endif
22 
23 class ORTInvoker {
24  public:
25  ORTInvoker(std::shared_ptr<IExecutionProvider> execution_provider,
26  const logging::Logger& logger,
27  const IOnnxRuntimeOpSchemaRegistryList& custom_op_registries) : execution_provider_(std::move(execution_provider)), logger_(logger), custom_op_registries_(custom_op_registries) {
28  if (!execution_provider_) {
29  ORT_THROW("Execution provider is nullptr");
30  }
31  }
32 
34  return *execution_provider_;
35  }
36 
37  common::Status Invoke(const std::string& op_name,
38  // optional inputs / outputs?
39  const std::vector<OrtValue>& inputs,
40  std::vector<OrtValue>& outputs,
41  const NodeAttributes* attributes,
42  const std::string& domain = kOnnxDomain,
43  const int version = -1);
44 
45  private:
46  std::shared_ptr<IExecutionProvider> execution_provider_;
47  const logging::Logger& logger_;
48  // custom ops for current execution provider
49  // we need the op schema to resolve the output type during invoke
50  const IOnnxRuntimeOpSchemaRegistryList& custom_op_registries_;
51 };
52 
53 #ifdef __GNUC__
54 #pragma GCC diagnostic pop
55 #endif
56 } // namespace onnxruntime
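GLsizei const GLchar *const * string
Definition: glcorearb.h:814
(Note: this is a Doxygen cross-reference mislink. The `string` in `std::string` in the listing above refers to the C++ standard library type, not to this OpenGL parameter.)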
ORTInvoker(std::shared_ptr< IExecutionProvider > execution_provider, const logging::Logger &logger, const IOnnxRuntimeOpSchemaRegistryList &custom_op_registries)
std::unordered_map< std::string, ONNX_NAMESPACE::AttributeProto > NodeAttributes
Definition: basic_types.h:42
common::Status Invoke(const std::string &op_name, const std::vector< OrtValue > &inputs, std::vector< OrtValue > &outputs, const NodeAttributes *attributes, const std::string &domain=kOnnxDomain, const int version=-1)
constexpr const char * kOnnxDomain
Definition: constants.h:14
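GT_API const UT_StringHolder version
(Note: this is a Doxygen cross-reference mislink. The `version` in `Invoke(...)` is its own `const int` parameter defaulting to -1, not this HDK global.)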
#define ORT_THROW(...)
Definition: common.h:162
IExecutionProvider & GetCurrentExecutionProvider()