HDK
Main Page
Related Pages
Modules
Namespaces
Classes
Files
Examples
File List
File Members
All
Classes
Namespaces
Files
Functions
Variables
Typedefs
Enumerations
Enumerator
Friends
Macros
Groups
Pages
rocm_context.h
Go to the documentation of this file.
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3
4
#define ORT_ROCM_CTX
5
6
#include "
rocm_resource.h
"
7
#include "
core/providers/custom_op_context.h
"
8
#include <hip/hip_runtime.h>
9
#include <miopen/miopen.h>
10
#include <rocblas/rocblas.h>
11
12
namespace
Ort {
13
14
namespace
Custom {
15
16
struct
RocmContext
:
public
CustomOpContext
{
17
hipStream_t
hip_stream
= {};
18
miopenHandle_t
miopen_handle
= {};
19
rocblas_handle
rblas_handle
= {};
20
21
void
Init
(
const
OrtKernelContext& kernel_ctx) {
22
const
auto
& ort_api =
Ort::GetApi
();
23
void
* resource = {};
24
OrtStatus* status =
nullptr
;
25
26
status = ort_api.KernelContext_GetResource(&kernel_ctx,
ORT_ROCM_RESOUCE_VERSION
,
RocmResource::hip_stream_t
, &resource);
27
if
(status) {
28
ORT_CXX_API_THROW
(
"failed to fetch hip stream"
, OrtErrorCode::ORT_RUNTIME_EXCEPTION);
29
}
30
hip_stream
=
reinterpret_cast<
hipStream_t
>
(resource);
31
32
resource = {};
33
status = ort_api.KernelContext_GetResource(&kernel_ctx,
ORT_ROCM_RESOUCE_VERSION
,
RocmResource::miopen_handle_t
, &resource);
34
if
(status) {
35
ORT_CXX_API_THROW
(
"failed to fetch miopen handle"
, OrtErrorCode::ORT_RUNTIME_EXCEPTION);
36
}
37
miopen_handle
=
reinterpret_cast<
miopenHandle_t
>
(resource);
38
39
resource = {};
40
status = ort_api.KernelContext_GetResource(&kernel_ctx,
ORT_ROCM_RESOUCE_VERSION
,
RocmResource::rocblas_handle_t
, &resource);
41
if
(status) {
42
ORT_CXX_API_THROW
(
"failed to fetch rocblas handle"
, OrtErrorCode::ORT_RUNTIME_EXCEPTION);
43
}
44
rblas_handle
=
reinterpret_cast<
rocblas_handle
>
(resource);
45
}
46
};
47
48
}
// namespace Custom
49
}
// namespace Ort
ORT_ROCM_RESOUCE_VERSION
#define ORT_ROCM_RESOUCE_VERSION
Definition:
rocm_resource.h:6
Ort::Custom::RocmContext
Definition:
rocm_context.h:16
Ort::Custom::RocmContext::Init
void Init(const OrtKernelContext &kernel_ctx)
Definition:
rocm_context.h:21
hip_stream_t
Definition:
rocm_resource.h:9
Ort::Custom::RocmContext::miopen_handle
miopenHandle_t miopen_handle
Definition:
rocm_context.h:18
hipStream_t
struct ihipStream_t * hipStream_t
Definition:
oidn.h:25
Ort::GetApi
const OrtApi & GetApi() noexcept
This returns a reference to the OrtApi interface in use.
Definition:
onnxruntime_cxx_api.h:124
miopen_handle_t
Definition:
rocm_resource.h:10
custom_op_context.h
Ort::Custom::RocmContext::rblas_handle
rocblas_handle rblas_handle
Definition:
rocm_context.h:19
rocm_resource.h
CustomOpContext
Definition:
custom_op_context.h:7
rocblas_handle_t
Definition:
rocm_resource.h:11
ORT_CXX_API_THROW
#define ORT_CXX_API_THROW(string, code)
Definition:
onnxruntime_cxx_api.h:77
Ort::Custom::RocmContext::hip_stream
hipStream_t hip_stream
Definition:
rocm_context.h:17
onnxruntime
core
providers
rocm
rocm_context.h
Generated on Sun Nov 17 2024 03:01:33 for HDK by doxygen 1.8.6