HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
rocm_context.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #define ORT_ROCM_CTX
5 
6 #include "rocm_resource.h"
8 #include <hip/hip_runtime.h>
9 #include <miopen/miopen.h>
10 #include <rocblas/rocblas.h>
11 
12 namespace Ort {
13 
14 namespace Custom {
15 
16 struct RocmContext : public CustomOpContext {
18  miopenHandle_t miopen_handle = {};
19  rocblas_handle rblas_handle = {};
20 
21  void Init(const OrtKernelContext& kernel_ctx) {
22  const auto& ort_api = Ort::GetApi();
23  void* resource = {};
24  OrtStatus* status = nullptr;
25 
26  status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_ROCM_RESOUCE_VERSION, RocmResource::hip_stream_t, &resource);
27  if (status) {
28  ORT_CXX_API_THROW("failed to fetch hip stream", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
29  }
30  hip_stream = reinterpret_cast<hipStream_t>(resource);
31 
32  resource = {};
33  status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_ROCM_RESOUCE_VERSION, RocmResource::miopen_handle_t, &resource);
34  if (status) {
35  ORT_CXX_API_THROW("failed to fetch miopen handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
36  }
37  miopen_handle = reinterpret_cast<miopenHandle_t>(resource);
38 
39  resource = {};
40  status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_ROCM_RESOUCE_VERSION, RocmResource::rocblas_handle_t, &resource);
41  if (status) {
42  ORT_CXX_API_THROW("failed to fetch rocblas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
43  }
44  rblas_handle = reinterpret_cast<rocblas_handle>(resource);
45  }
46 };
47 
48 } // namespace Custom
49 } // namespace Ort
#define ORT_ROCM_RESOUCE_VERSION
Definition: rocm_resource.h:6
void Init(const OrtKernelContext &kernel_ctx)
Definition: rocm_context.h:21
miopenHandle_t miopen_handle
Definition: rocm_context.h:18
struct ihipStream_t * hipStream_t
Definition: oidn.h:25
const OrtApi & GetApi() noexcept
This returns a reference to the OrtApi interface in use.
rocblas_handle rblas_handle
Definition: rocm_context.h:19
#define ORT_CXX_API_THROW(string, code)