HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
cuda_context.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 // This header exposes a context for CUDA custom ops.
5 // Through the context, a custom CUDA operator can fetch existing resources,
6 // such as the CUDA stream and cuDNN handle, and reuse them.
7 
8 // For concrete usage, please see:
9 // https://onnxruntime.ai/docs/reference/operators/add-custom-op.html#custom-ops-for-cuda-and-rocm
10 
11 #pragma once
12 
13 #define ORT_CUDA_CTX
14 
15 #include "cuda_resource.h"
17 #include <cuda.h>
18 #include <cuda_runtime.h>
19 #include <cublas_v2.h>
20 #include <cudnn.h>
21 
22 namespace Ort {
23 
24 namespace Custom {
25 
26 struct CudaContext : public CustomOpContext {
28  cudnnHandle_t cudnn_handle = {};
29  cublasHandle_t cublas_handle = {};
30  OrtAllocator* deferred_cpu_allocator = {};
31  // below are cuda ep options
32  int16_t device_id = 0;
33  int32_t arena_extend_strategy = 0;
38  bool prefer_nhwc = false;
39 
40  void Init(const OrtKernelContext& kernel_ctx) {
41  cuda_stream = FetchResource<cudaStream_t>(kernel_ctx, CudaResource::cuda_stream_t);
42  cudnn_handle = FetchResource<cudnnHandle_t>(kernel_ctx, CudaResource::cudnn_handle_t);
43  cublas_handle = FetchResource<cublasHandle_t>(kernel_ctx, CudaResource::cublas_handle_t);
44  deferred_cpu_allocator = FetchResource<OrtAllocator*>(kernel_ctx, CudaResource::deferred_cpu_allocator_t);
45 
46  device_id = FetchResource<int16_t>(kernel_ctx, CudaResource::device_id_t);
47  arena_extend_strategy = FetchResource<int32_t>(kernel_ctx, CudaResource::arena_extend_strategy_t);
48  cudnn_conv_algo_search = FetchResource<int32_t>(kernel_ctx, CudaResource::cudnn_conv_algo_search_t);
50 
51  cudnn_conv1d_pad_to_nc1d = FetchResource<bool>(kernel_ctx, CudaResource::cudnn_conv1d_pad_to_nc1d_t);
53  prefer_nhwc = FetchResource<bool>(kernel_ctx, CudaResource::prefer_nhwc_t);
54  }
55 
56  template <typename T>
57  T FetchResource(const OrtKernelContext& kernel_ctx, CudaResource resource_type) {
58  if (sizeof(T) > sizeof(void*)) {
59  ORT_CXX_API_THROW("void* is not large enough to hold resource type: " + std::to_string(resource_type), OrtErrorCode::ORT_INVALID_ARGUMENT);
60  }
61  const auto& ort_api = Ort::GetApi();
62  void* resource = {};
63  OrtStatus* status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, resource_type, &resource);
64  if (status) {
65  ORT_CXX_API_THROW("Failed to fetch cuda ep resource, resouce type: " + std::to_string(resource_type), OrtErrorCode::ORT_RUNTIME_EXCEPTION);
66  }
67  T t = {};
68  memcpy(&t, &resource, sizeof(T));
69  return t;
70  }
71 
72  void* AllocDeferredCpuMem(size_t size) const {
73  if (0 == size) {
74  return {};
75  }
76  const auto& ort_api = Ort::GetApi();
77  void* mem = {};
78  auto status = ort_api.AllocatorAlloc(deferred_cpu_allocator, size, &mem);
79  if (status) {
80  ORT_CXX_API_THROW("failed to allocate deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
81  }
82  return mem;
83  }
84 
85  void FreeDeferredCpuMem(void* mem) const {
86  if (mem) {
87  const auto& ort_api = Ort::GetApi();
88  auto status = ort_api.AllocatorFree(deferred_cpu_allocator, mem);
89  if (status) {
90  ORT_CXX_API_THROW("failed to free deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
91  }
92  }
93  }
94 };
95 
96 } // namespace Custom
97 } // namespace Ort
#define ORT_CUDA_RESOUCE_VERSION
Definition: cuda_resource.h:6
auto to_string(const T &value) -> std::string
Definition: format.h:2597
cudaStream_t cuda_stream
Definition: cuda_context.h:27
OrtAllocator * deferred_cpu_allocator
Definition: cuda_context.h:30
void FreeDeferredCpuMem(void *mem) const
Definition: cuda_context.h:85
CudaResource
Definition: cuda_resource.h:8
const OrtApi & GetApi() noexcept
This returns a reference to the OrtApi interface in use.
void Init(const OrtKernelContext &kernel_ctx)
Definition: cuda_context.h:40
void * AllocDeferredCpuMem(size_t size) const
Definition: cuda_context.h:72
T FetchResource(const OrtKernelContext &kernel_ctx, CudaResource resource_type)
Definition: cuda_context.h:57
cudnnHandle_t cudnn_handle
Definition: cuda_context.h:28
struct CUstream_st * cudaStream_t
Definition: oidn.h:24
GLdouble t
Definition: glad.h:2397
GLsizeiptr size
Definition: glcorearb.h:664
cublasHandle_t cublas_handle
Definition: cuda_context.h:29
#define ORT_CXX_API_THROW(string, code)