onnxruntime_lite_custom_op.h
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Summary
// This header has APIs that save custom op authors the trouble of defining schemas,
// which are inferred from a function's signature, as long as its argument list uses the types supported here.
// An input could be:
// 1. A tensor of onnx data types.
// 2. A span of onnx data types.
// 3. A scalar of onnx data types.
// An input is optional if declared as std::optional<...>.
// An output must be a tensor of onnx data types.
// Further, the header also has a utility for registering a simple custom struct, in which resources could be kept, as a custom op.
// For concrete examples, please search for the keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
// Note - all APIs in this header are ABI.
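//
// E.g. (an illustrative sketch, not part of this header - "KernelOne" and its argument names are hypothetical):
// a plain function whose schema is inferred from its signature, i.e. two float tensor inputs and one float tensor output:
//
//   void KernelOne(const Ort::Custom::Tensor<float>& X,
//                  const Ort::Custom::Tensor<float>& Y,
//                  Ort::Custom::Tensor<float>& Z) {
//     auto input_shape = X.Shape();
//     auto z_raw = Z.Allocate(input_shape);
//     for (int64_t i = 0; i < Z.NumberOfElement(); ++i) {
//       z_raw[i] = X.Data()[i] + Y.Data()[i];
//     }
//   }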

#pragma once
#include "onnxruntime_cxx_api.h"
#include <optional>
#include <numeric>
#include <functional>
#include <unordered_set>

namespace Ort {
namespace Custom {

class ArgBase {
 public:
  ArgBase(OrtKernelContext* ctx,
          size_t indice,
          bool is_input) : ctx_(ctx), indice_(indice), is_input_(is_input) {}
  virtual ~ArgBase() {}

 protected:
  struct KernelContext ctx_;
  size_t indice_;
  bool is_input_;
};

using ArgPtr = std::unique_ptr<Custom::ArgBase>;
using ArgPtrs = std::vector<ArgPtr>;

class TensorBase : public ArgBase {
 public:
  TensorBase(OrtKernelContext* ctx,
             size_t indice,
             bool is_input) : ArgBase(ctx, indice, is_input) {}

  operator bool() const {
    return shape_.has_value();
  }

  const std::vector<int64_t>& Shape() const {
    if (!shape_.has_value()) {
      ORT_CXX_API_THROW("tensor shape is not yet initialized", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return shape_.value();
  }

  ONNXTensorElementDataType Type() const {
    return type_;
  }

  int64_t NumberOfElement() const {
    if (shape_.has_value()) {
      return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
    } else {
      return 0;
    }
  }

  std::string Shape2Str() const {
    if (shape_.has_value()) {
      std::string shape_str;
      for (const auto& dim : *shape_) {
        shape_str.append(std::to_string(dim));
        shape_str.append(", ");
      }
      return shape_str;
    } else {
      return "empty";
    }
  }

  bool IsCpuTensor() const {
    return strcmp("Cpu", mem_type_) == 0;
  }

  virtual const void* DataRaw() const = 0;
  virtual size_t SizeInBytes() const = 0;

 protected:
  std::optional<std::vector<int64_t>> shape_;
  ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
  const char* mem_type_ = "Cpu";
};

template <typename T>
struct Span {
  const T* data_ = {};
  size_t size_ = {};
  void Assign(const T* data, size_t size) {
    data_ = data;
    size_ = size;
  }
  size_t size() const { return size_; }
  T operator[](size_t indice) const {
    return data_[indice];
  }
  const T* data() const { return data_; }
};
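
// E.g. (an illustrative sketch, not part of this header - "KernelTwo" and its argument names are hypothetical):
// span and scalar arguments are read as views/copies of a tensor's data on the host, so they are only
// supported on the CPU execution provider (see the "CPUExecutionProvider" checks in CreateTuple below):
//
//   void KernelTwo(const Ort::Custom::Span<int64_t>& indices,   // 1-D int64 input
//                  int64_t offset,                              // int64 input of shape {1}
//                  Ort::Custom::Tensor<int64_t>& shifted) {
//     int64_t* out = shifted.Allocate({static_cast<int64_t>(indices.size())});
//     for (size_t i = 0; i < indices.size(); ++i) {
//       out[i] = indices[i] + offset;
//     }
//   }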

template <typename T>
class Tensor : public TensorBase {
 public:
  using TT = typename std::remove_reference<T>::type;
  Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
    if (is_input_) {
      if (indice >= ctx_.GetInputCount()) {
        ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
      }
      const_value_ = ctx_.GetInput(indice);
      auto type_shape_info = const_value_.GetTensorTypeAndShapeInfo();
      shape_ = type_shape_info.GetShape();
    }
  }
  const TT* Data() const {
    return reinterpret_cast<const TT*>(const_value_.GetTensorRawData());
  }
  TT* Allocate(const std::vector<int64_t>& shape) {
    shape_ = shape;
    if (!data_) {
      data_ = ctx_.GetOutput(indice_, shape).template GetTensorMutableData<TT>();
    }
    return data_;
  }
  static TT GetT() { return (TT)0; }
  const Span<T>& AsSpan() {
    if (!shape_.has_value() || shape_->size() != 1) {
      ORT_CXX_API_THROW("invalid shape while trying to get a span out of Ort::Custom::Tensor",
                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    span_.Assign(Data(), static_cast<size_t>((*shape_)[0]));
    return span_;
  }
  const T& AsScalar() {
    if (!shape_.has_value() || shape_->size() != 1 || (*shape_)[0] != 1) {
      ORT_CXX_API_THROW("invalid shape while trying to get a scalar from Ort::Custom::Tensor",
                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return *Data();
  }
  const void* DataRaw() const override {
    return reinterpret_cast<const void*>(Data());
  }

  size_t SizeInBytes() const override {
    return sizeof(TT) * static_cast<size_t>(NumberOfElement());
  }

 private:
  ConstValue const_value_;  // for input
  TT* data_{};              // for output
  Span<T> span_;
};
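
// E.g. (an illustrative sketch, not part of this header - "Identity" is a hypothetical kernel):
// DataRaw() and SizeInBytes() allow type-agnostic access, e.g. a byte-wise copy:
//
//   void Identity(const Ort::Custom::Tensor<float>& in, Ort::Custom::Tensor<float>& out) {
//     float* raw = out.Allocate(in.Shape());
//     memcpy(raw, in.DataRaw(), in.SizeInBytes());
//   }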

template <>
class Tensor<std::string> : public TensorBase {
 public:
  using strings = std::vector<std::string>;

  Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
    if (is_input_) {
      if (indice >= ctx_.GetInputCount()) {
        ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
      }
      auto const_value = ctx_.GetInput(indice);
      auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
      shape_ = type_shape_info.GetShape();
      auto num_chars = const_value.GetStringTensorDataLength();
      // note - there will be a copy ...
      auto num_strings = static_cast<size_t>(NumberOfElement());
      if (num_strings) {
        std::vector<char> chars(num_chars + 1, '\0');
        std::vector<size_t> offsets(num_strings);
        const_value.GetStringTensorContent(static_cast<void*>(chars.data()), num_chars, offsets.data(), offsets.size());
        auto upper_bound = num_strings - 1;
        input_strings_.resize(num_strings);
        for (size_t i = upper_bound;; --i) {
          if (i < upper_bound) {
            chars[offsets[i + 1]] = '\0';
          }
          input_strings_[i] = chars.data() + offsets[i];
          if (0 == i) {
            break;
          }
        }
      }
    }
  }
  const strings& Data() const {
    return input_strings_;
  }
  const void* DataRaw() const override {
    if (input_strings_.size() != 1) {
      ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return reinterpret_cast<const void*>(input_strings_[0].c_str());
  }
  size_t SizeInBytes() const override {
    if (input_strings_.size() != 1) {
      ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return input_strings_[0].size();
  }
  void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
    shape_ = dims;
    std::vector<const char*> raw;
    for (const auto& s : ss) {
      raw.push_back(s.data());
    }
    auto output = ctx_.GetOutput(indice_, dims.data(), dims.size());
    // note - there will be a copy ...
    output.FillStringTensor(raw.data(), raw.size());
  }
  const Span<std::string>& AsSpan() {
    ORT_CXX_API_THROW("span for TensorT of string not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
  }
  const std::string& AsScalar() {
    if (input_strings_.size() != 1) {
      ORT_CXX_API_THROW("invalid shape while trying to get a scalar string from Ort::Custom::Tensor",
                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return input_strings_[0];
  }

 private:
  std::vector<std::string> input_strings_;  // for input
};
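
// E.g. (an illustrative sketch, not part of this header - "EchoStrings" is a hypothetical kernel):
// string outputs are not written element by element; collect them and call SetStringOutput once:
//
//   void EchoStrings(const Ort::Custom::Tensor<std::string>& in, Ort::Custom::Tensor<std::string>& out) {
//     out.SetStringOutput(in.Data(), in.Shape());
//   }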

template <>
class Tensor<std::string_view> : public TensorBase {
 public:
  using strings = std::vector<std::string>;
  using string_views = std::vector<std::string_view>;

  Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
    if (is_input_) {
      if (indice >= ctx_.GetInputCount()) {
        ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
      }
      auto const_value = ctx_.GetInput(indice);
      auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
      shape_ = type_shape_info.GetShape();
      auto num_chars = const_value.GetStringTensorDataLength();
      chars_.resize(num_chars + 1, '\0');
      auto num_strings = static_cast<size_t>(NumberOfElement());
      if (num_strings) {
        std::vector<size_t> offsets(num_strings);
        const_value.GetStringTensorContent(static_cast<void*>(chars_.data()), num_chars, offsets.data(), offsets.size());
        offsets.push_back(num_chars);
        for (size_t i = 0; i < num_strings; ++i) {
          input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
        }
      }
    }
  }
  const string_views& Data() const {
    return input_string_views_;
  }
  const void* DataRaw() const override {
    if (input_string_views_.size() != 1) {
      ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return reinterpret_cast<const void*>(input_string_views_[0].data());
  }
  size_t SizeInBytes() const override {
    if (input_string_views_.size() != 1) {
      ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return input_string_views_[0].size();
  }
  void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
    shape_ = dims;
    std::vector<const char*> raw;
    for (const auto& s : ss) {
      raw.push_back(s.data());
    }
    auto output = ctx_.GetOutput(indice_, dims.data(), dims.size());
    // note - there will be a copy ...
    output.FillStringTensor(raw.data(), raw.size());
  }
  const Span<std::string_view>& AsSpan() {
    ORT_CXX_API_THROW("span for TensorT of string view not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
  }
  std::string_view AsScalar() {
    if (input_string_views_.size() != 1) {
      ORT_CXX_API_THROW("invalid shape while trying to get a scalar string view from Ort::Custom::Tensor",
                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return input_string_views_[0];
  }

 private:
  std::vector<char> chars_;                           // for input
  std::vector<std::string_view> input_string_views_;  // for input
};
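
// E.g. (an illustrative sketch, not part of this header - "CountChars" is a hypothetical kernel):
// prefer Tensor<std::string_view> for read-only string inputs - unlike Tensor<std::string>, which
// materializes one std::string per element, the views all point into a single character buffer (chars_):
//
//   void CountChars(const Ort::Custom::Tensor<std::string_view>& in, Ort::Custom::Tensor<int64_t>& out) {
//     int64_t* counts = out.Allocate(in.Shape());
//     size_t i = 0;
//     for (auto view : in.Data()) {
//       counts[i++] = static_cast<int64_t>(view.size());
//     }
//   }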

using TensorPtr = std::unique_ptr<Custom::TensorBase>;
using TensorPtrs = std::vector<TensorPtr>;

struct TensorArray : public ArgBase {
  TensorArray(OrtKernelContext* ctx,
              size_t start_indice,
              bool is_input) : ArgBase(ctx,
                                       start_indice,
                                       is_input) {
    if (is_input) {
      auto input_count = ctx_.GetInputCount();
      for (size_t ith_input = start_indice; ith_input < input_count; ++ith_input) {
        auto const_value = ctx_.GetInput(ith_input);
        auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
        auto type = type_shape_info.GetElementType();
        TensorPtr tensor;
        switch (type) {
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
            tensor = std::make_unique<Custom::Tensor<bool>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
            tensor = std::make_unique<Custom::Tensor<float>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
            tensor = std::make_unique<Custom::Tensor<double>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
            tensor = std::make_unique<Custom::Tensor<uint8_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
            tensor = std::make_unique<Custom::Tensor<int8_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
            tensor = std::make_unique<Custom::Tensor<uint16_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
            tensor = std::make_unique<Custom::Tensor<int16_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
            tensor = std::make_unique<Custom::Tensor<uint32_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
            tensor = std::make_unique<Custom::Tensor<int32_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
            tensor = std::make_unique<Custom::Tensor<uint64_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
            tensor = std::make_unique<Custom::Tensor<int64_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
            tensor = std::make_unique<Custom::Tensor<std::string>>(ctx, ith_input, true);
            break;
          default:
            ORT_CXX_API_THROW("unknown input type", ORT_RUNTIME_EXCEPTION);
            break;
        }
        tensors_.emplace_back(tensor.release());
      }  // for
    }
  }
  template <typename T>
  T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
    // ith_output is the index of the output relative to the tensor array
    // indice_ + ith_output is the index relative to the context
    auto tensor = std::make_unique<Tensor<T>>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false);
    auto raw_output = tensor.get()->Allocate(shape);
    tensors_.emplace_back(tensor.release());
    return raw_output;
  }
  Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
    // ith_output is the index of the output relative to the tensor array
    // indice_ + ith_output is the index relative to the context
    auto tensor = std::make_unique<Tensor<std::string>>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false);
    Tensor<std::string>& output = *tensor;
    tensors_.emplace_back(tensor.release());
    return output;
  }
  size_t Size() const {
    return tensors_.size();
  }
  const TensorPtr& operator[](size_t ith_input) const {
    // ith_input is the index of the tensor relative to the tensor array
    return tensors_.at(ith_input);
  }

 private:
  TensorPtrs tensors_;
};

using Variadic = TensorArray;

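// E.g. (an illustrative sketch, not part of this header - "ConcatFloats" is a hypothetical kernel):
// a variadic op consumes all remaining inputs through a TensorArray (aliased as Variadic above)
// and allocates its outputs through one as well:
//
//   void ConcatFloats(const Ort::Custom::Variadic& inputs, Ort::Custom::Variadic& outputs) {
//     int64_t total = 0;
//     for (size_t i = 0; i < inputs.Size(); ++i) {
//       total += inputs[i]->NumberOfElement();
//     }
//     float* out = outputs.AllocateOutput<float>(0, {total});
//     // ... then copy each input's DataRaw() into out ...
//   }
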
/*
Note:
OrtLiteCustomOp inherits from OrtCustomOp to bridge between a custom func/struct and the ort core.
The lifetime of an OrtLiteCustomOp instance is managed by customer code, not ort, so:
1. DO NOT cast OrtLiteCustomOp to OrtCustomOp and release it, since there is no virtual destructor in the hierarchy.
2. OrtLiteCustomFunc and OrtLiteCustomStruct, as two sub-structs, can be released in the form of OrtLiteCustomOp, since all members are kept in the OrtLiteCustomOp,
   hence memory could still be recycled properly.
Further, OrtCustomOp is a C struct bearing no v-table, so offspring structs are by design to have zero virtual functions, to maintain cast safety.
*/
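
// E.g. (an illustrative note, not part of this header - "MyKernel" is a hypothetical compute function):
// own the instance through its lite type, e.g. via the raw pointer returned by CreateLiteCustomOp below:
//
//   std::unique_ptr<Ort::Custom::OrtLiteCustomOp> op{
//       Ort::Custom::CreateLiteCustomOp("MyOp", "CPUExecutionProvider", MyKernel)};
//   // fine: releasing as OrtLiteCustomOp is supported by design (point 2 above)
//   // NOT fine: delete reinterpret_cast<OrtCustomOp*>(op.get()) - see point 1 above
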
struct OrtLiteCustomOp : public OrtCustomOp {
  using ConstOptionalFloatTensor = std::optional<const Custom::Tensor<float>&>;
  using OptionalFloatTensor = std::optional<Custom::Tensor<float>>;

  // CreateTuple
  template <size_t ith_input, size_t ith_output, typename... Ts>
  static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
  CreateTuple(OrtKernelContext*, ArgPtrs&, size_t, size_t, const std::string&) {
    return std::make_tuple();
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, OrtKernelContext&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    std::tuple<T> current = std::tuple<OrtKernelContext&>{*context};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

#ifdef ORT_CUDA_CTX
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const CudaContext&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    thread_local CudaContext cuda_context;
    cuda_context.Init(*context);
    std::tuple<T> current = std::tuple<const CudaContext&>{cuda_context};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }
#endif

#ifdef ORT_ROCM_CTX
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const RocmContext&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    thread_local RocmContext rocm_context;
    rocm_context.Init(*context);
    std::tuple<T> current = std::tuple<const RocmContext&>{rocm_context};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }
#endif

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const TensorArray*>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    args.push_back(std::make_unique<TensorArray>(context, ith_input, true));
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const TensorArray&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    args.push_back(std::make_unique<TensorArray>(context, ith_input, true));
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, TensorArray*>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    args.push_back(std::make_unique<TensorArray>(context, ith_output, false));
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, TensorArray&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    args.push_back(std::make_unique<TensorArray>(context, ith_output, false));
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

#define CREATE_TUPLE_INPUT(data_type) \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
      std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if ("CPUExecutionProvider" != ep) { \
      ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
    } \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if ("CPUExecutionProvider" != ep) { \
      ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
    } \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      if ("CPUExecutionProvider" != ep) { \
        ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
      } \
      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
      std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, data_type>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if ("CPUExecutionProvider" != ep) { \
      ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
    } \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsScalar()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<data_type>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      if ("CPUExecutionProvider" != ep) { \
        ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
      } \
      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
      std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsScalar()}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  }
#define CREATE_TUPLE_OUTPUT(data_type) \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())}; \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())}; \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_output < num_output) { \
      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
      std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())}; \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  }
#define CREATE_TUPLE(data_type) \
  CREATE_TUPLE_INPUT(data_type) \
  CREATE_TUPLE_OUTPUT(data_type)

  CREATE_TUPLE(bool)
  CREATE_TUPLE(float)
  CREATE_TUPLE(Ort::Float16_t)
  CREATE_TUPLE(Ort::BFloat16_t)
  CREATE_TUPLE(double)
  CREATE_TUPLE(int8_t)
  CREATE_TUPLE(int16_t)
  CREATE_TUPLE(int32_t)
  CREATE_TUPLE(int64_t)
  CREATE_TUPLE(uint8_t)
  CREATE_TUPLE(uint16_t)
  CREATE_TUPLE(uint32_t)
  CREATE_TUPLE(uint64_t)
  CREATE_TUPLE(std::string)
  CREATE_TUPLE_INPUT(std::string_view)
  CREATE_TUPLE(Ort::Float8E4M3FN_t)
  CREATE_TUPLE(Ort::Float8E4M3FNUZ_t)
  CREATE_TUPLE(Ort::Float8E5M2_t)
  CREATE_TUPLE(Ort::Float8E5M2FNUZ_t)

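  // E.g. (an illustrative note, not part of this header): for a compute function declared as
  //
  //   void Fn(OrtKernelContext* ctx,
  //           const Ort::Custom::Tensor<float>& in,
  //           std::optional<const Ort::Custom::Tensor<int64_t>*> maybe_in,
  //           Ort::Custom::Tensor<float>& out);
  //
  // CreateTuple<0, 0, ...> recurses over the parameter pack: the kernel context is forwarded as-is,
  // each input type consumes ith_input ("in" binds input 0; "maybe_in" binds input 1, or stays empty
  // at runtime when num_input <= 1), and each output type consumes ith_output ("out" binds output 0).
  // The assembled std::tuple is then std::apply'd to the user's function.
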
  // ParseArgs ...
  template <typename... Ts>
  static typename std::enable_if<0 == sizeof...(Ts)>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    ParseArgs<Ts...>(input_types, output_types);
  }

#ifdef ORT_CUDA_CTX
  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const CudaContext&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    ParseArgs<Ts...>(input_types, output_types);
  }
#endif

#ifdef ORT_ROCM_CTX
  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const RocmContext&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    ParseArgs<Ts...>(input_types, output_types);
  }
#endif

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const TensorArray&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const TensorArray*>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, TensorArray&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, TensorArray*>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }

#define PARSE_INPUT_BASE(pack_type, onnx_type) \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const std::optional<pack_type>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<pack_type>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  }

#define PARSE_INPUT(data_type, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type) \
  PARSE_INPUT_BASE(data_type, onnx_type)

#define PARSE_OUTPUT(data_type, onnx_type) \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>*>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>&>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  }

#define PARSE_ARGS(data_type, onnx_type) \
  PARSE_INPUT(data_type, onnx_type) \
  PARSE_OUTPUT(data_type, onnx_type)

  PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
  PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
  PARSE_ARGS(Ort::Float16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
  PARSE_ARGS(Ort::BFloat16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16)
  PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
  PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
  PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
  PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
  PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
  PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
  PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
  PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
  PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
  PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
  PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)  // todo - remove string_view output
  PARSE_ARGS(Ort::Float8E4M3FN_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN)
  PARSE_ARGS(Ort::Float8E4M3FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ)
  PARSE_ARGS(Ort::Float8E5M2_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2)
  PARSE_ARGS(Ort::Float8E5M2FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ)

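  // E.g. (an illustrative note, not part of this header): ParseArgs performs the schema inference
  // promised in the summary above. A signature such as
  //
  //   void Fn(const Ort::Custom::Tensor<float>&, Ort::Custom::Tensor<int64_t>&);
  //
  // yields input_types_ = {FLOAT} and output_types_ = {INT64}. A TensorArray argument contributes
  // ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED, which GetInputCharacteristic/GetOutputCharacteristic
  // in the constructor below report as variadic.
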
  OrtLiteCustomOp(const char* op_name,
                  const char* execution_provider,
                  ShapeInferFn shape_infer_fn,
                  int start_ver = 1,
                  int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name),
                                                         execution_provider_(execution_provider),
                                                         shape_infer_fn_(shape_infer_fn),
                                                         start_ver_(start_ver),
                                                         end_ver_(end_ver) {
    OrtCustomOp::version = ORT_API_VERSION;

    OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
    OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
    OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) { return OrtMemTypeDefault; };

    OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->input_types_.size();
    };

    OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->input_types_[indice];
    };

    OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->output_types_.size();
    };

    OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->output_types_[indice];
    };

    OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t indice) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->input_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL;
    };

    OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t indice) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->output_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL;
    };

    OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { return 0; };
    OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { return 0; };
    OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { return 0; };
    OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { return 0; };

    OrtCustomOp::CreateKernelV2 = {};
    OrtCustomOp::KernelComputeV2 = {};
    OrtCustomOp::KernelCompute = {};

    OrtCustomOp::InferOutputShapeFn = {};

    OrtCustomOp::GetStartVersion = [](const OrtCustomOp* op) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->start_ver_;
    };

    OrtCustomOp::GetEndVersion = [](const OrtCustomOp* op) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->end_ver_;
    };
  }

  const std::string op_name_;
  const std::string execution_provider_;

  std::vector<ONNXTensorElementDataType> input_types_;
  std::vector<ONNXTensorElementDataType> output_types_;

  ShapeInferFn shape_infer_fn_ = {};

  int start_ver_ = 1;
  int end_ver_ = MAX_CUSTOM_OP_END_VER;

  void* compute_fn_ = {};
  void* compute_fn_return_status_ = {};
};

//////////////////////////// OrtLiteCustomFunc ////////////////////////////////
// This struct implements function-as-op.
// E.g. a function might be defined as:
//   void Filter(const Ort::Custom::Tensor<float>& floats_in, Ort::Custom::Tensor<float>& floats_out) { ... }
// It could be registered this way:
//   Ort::CustomOpDomain v2_domain{"v2"};
//   std::unique_ptr<OrtLiteCustomOp> fil_op_ptr{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)};
//   v2_domain.Add(fil_op_ptr.get());
//   session_options.Add(v2_domain);
// For the complete example, please search for the keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
template <typename... Args>
struct OrtLiteCustomFunc : public OrtLiteCustomOp {
  using ComputeFn = void (*)(Args...);
  using ComputeFnReturnStatus = Status (*)(Args...);
  using MyType = OrtLiteCustomFunc<Args...>;

  struct Kernel {
    size_t num_input_{};
    size_t num_output_{};
    ComputeFn compute_fn_{};
    ComputeFnReturnStatus compute_fn_return_status_{};
    std::string ep_{};
  };

  OrtLiteCustomFunc(const char* op_name,
                    const char* execution_provider,
                    ComputeFn compute_fn,
                    ShapeInferFn shape_infer_fn = {},
                    int start_ver = 1,
                    int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
    compute_fn_ = reinterpret_cast<void*>(compute_fn);
    ParseArgs<Args...>(input_types_, output_types_);

    OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      std::vector<ArgPtr> args;
      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
      std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
    };

    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
      auto kernel = std::make_unique<Kernel>();
      auto me = static_cast<const MyType*>(this_);
      kernel->compute_fn_ = reinterpret_cast<ComputeFn>(me->compute_fn_);
      Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
      Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
      auto self = static_cast<const OrtLiteCustomFunc*>(this_);
      kernel->ep_ = self->execution_provider_;
      return reinterpret_cast<void*>(kernel.release());
    };

    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
      delete reinterpret_cast<Kernel*>(op_kernel);
    };

    if (shape_infer_fn_) {
      OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
        auto shape_info_fn = static_cast<const MyType*>(op)->shape_infer_fn_;
        ShapeInferContext ctx(&GetApi(), ort_ctx);
        return shape_info_fn(ctx);
      };
    }
  }

  OrtLiteCustomFunc(const char* op_name,
                    const char* execution_provider,
                    ComputeFnReturnStatus compute_fn_return_status,
                    ShapeInferFn shape_infer_fn = {},
                    int start_ver = 1,
                    int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
    compute_fn_return_status_ = reinterpret_cast<void*>(compute_fn_return_status);
    ParseArgs<Args...>(input_types_, output_types_);

    OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      std::vector<ArgPtr> args;
      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
      return std::apply([kernel](Args const&... t_args) { Status status = kernel->compute_fn_return_status_(t_args...); return status.release(); }, t);
    };

    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
      auto kernel = std::make_unique<Kernel>();
      auto me = static_cast<const MyType*>(this_);
      kernel->compute_fn_return_status_ = reinterpret_cast<ComputeFnReturnStatus>(me->compute_fn_return_status_);
      Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
      Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
      auto self = static_cast<const OrtLiteCustomFunc*>(this_);
      kernel->ep_ = self->execution_provider_;
      return reinterpret_cast<void*>(kernel.release());
    };

    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
      delete reinterpret_cast<Kernel*>(op_kernel);
    };

    if (shape_infer_fn_) {
      OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
        auto shape_info_fn = static_cast<const MyType*>(op)->shape_infer_fn_;
        ShapeInferContext ctx(&GetApi(), ort_ctx);
        return shape_info_fn(ctx);
      };
    }
  }
};  // struct OrtLiteCustomFunc
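
// E.g. (an illustrative sketch, not part of this header - "ReluChecked" is a hypothetical kernel,
// and the Ort::Status constructors used below are assumed from onnxruntime_cxx_api.h):
// a compute function may also return an Ort::Status, so errors surface without throwing:
//
//   Ort::Status ReluChecked(const Ort::Custom::Tensor<float>& in, Ort::Custom::Tensor<float>& out) {
//     if (!in) {
//       return Ort::Status("input not initialized", OrtErrorCode::ORT_INVALID_ARGUMENT);
//     }
//     float* raw = out.Allocate(in.Shape());
//     for (int64_t i = 0; i < in.NumberOfElement(); ++i) {
//       raw[i] = in.Data()[i] > 0.f ? in.Data()[i] : 0.f;
//     }
//     return Ort::Status{nullptr};  // success
//   }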

/////////////////////////// OrtLiteCustomStruct ///////////////////////////
// This struct implements struct-as-op.
// E.g. a struct might be defined as:
//   struct Merge {
//     Merge(const OrtApi* ort_api, const OrtKernelInfo* info) {...}
//     void Compute(const Ort::Custom::Tensor<std::string_view>& strings_in,
//                  std::string_view string_in,
//                  Ort::Custom::Tensor<std::string>* strings_out) {...}
//     bool reverse_ = false;
//   };
// It could be registered this way:
//   Ort::CustomOpDomain v2_domain{"v2"};
//   std::unique_ptr<OrtLiteCustomOp> mrg_op_ptr{Ort::Custom::CreateLiteCustomOp<Merge>("Merge", "CPUExecutionProvider")};
//   v2_domain.Add(mrg_op_ptr.get());
//   session_options.Add(v2_domain);
// For the complete example, please search for the keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
template <typename CustomOp>
struct OrtLiteCustomStruct : public OrtLiteCustomOp {
  template <typename... Args>
  using CustomComputeFn = void (CustomOp::*)(Args...);

  template <typename... Args>
  using CustomComputeFnReturnStatus = Status (CustomOp::*)(Args...);

  using MyType = OrtLiteCustomStruct<CustomOp>;

  struct Kernel {
    size_t num_input_{};
    size_t num_output_{};
    std::unique_ptr<CustomOp> custom_op_;
    std::string ep_{};
  };

  OrtLiteCustomStruct(const char* op_name,
                      const char* execution_provider,
                      int start_ver = 1,
                      int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, {}, start_ver, end_ver) {
    SetCompute(&CustomOp::Compute);

    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
      auto kernel = std::make_unique<Kernel>();
      Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
      Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
      kernel->custom_op_ = std::make_unique<CustomOp>(ort_api, info);
      auto self = static_cast<const OrtLiteCustomStruct*>(this_);
      kernel->ep_ = self->execution_provider_;
      return reinterpret_cast<void*>(kernel.release());
    };

    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
      delete reinterpret_cast<Kernel*>(op_kernel);
    };

    SetShapeInfer<CustomOp>(0);
  }

  template <typename... Args>
  void SetCompute(CustomComputeFn<Args...>) {
    ParseArgs<Args...>(input_types_, output_types_);
    OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      ArgPtrs args;
      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
      std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
    };
  }

  template <typename... Args>
  void SetCompute(CustomComputeFnReturnStatus<Args...>) {
    ParseArgs<Args...>(input_types_, output_types_);
    OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      ArgPtrs args;
      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
      return std::apply([kernel](Args const&... t_args) { Status status = kernel->custom_op_->Compute(t_args...); return status.release(); }, t);
    };
  }

  template <typename C>
  decltype(&C::InferOutputShape) SetShapeInfer(decltype(&C::InferOutputShape)) {
    OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
      ShapeInferContext ctx(&GetApi(), ort_ctx);
      return C::InferOutputShape(ctx);
    };
    return {};
  }

  template <typename C>
  void SetShapeInfer(...) {
    OrtCustomOp::InferOutputShapeFn = {};
  }
};  // struct OrtLiteCustomStruct
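
// E.g. (an illustrative sketch, not part of this header - "Negate" is a hypothetical struct,
// and ShapeInferContext::GetInputShape/SetOutputShape are assumed from onnxruntime_cxx_api.h):
// if the custom struct defines a static InferOutputShape, SetShapeInfer detects it via SFINAE above
// and wires it into OrtCustomOp::InferOutputShapeFn; otherwise shape inference stays unset:
//
//   struct Negate {
//     Negate(const OrtApi*, const OrtKernelInfo*) {}
//     void Compute(const Ort::Custom::Tensor<float>& in, Ort::Custom::Tensor<float>& out) {
//       float* raw = out.Allocate(in.Shape());
//       for (int64_t i = 0; i < in.NumberOfElement(); ++i) raw[i] = -in.Data()[i];
//     }
//     static Ort::Status InferOutputShape(Ort::ShapeInferContext& ctx) {
//       return ctx.SetOutputShape(0, ctx.GetInputShape(0));  // output 0 takes input 0's shape
//     }
//   };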

/////////////////////////// CreateLiteCustomOp ////////////////////////////

template <typename... Args>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
                                    const char* execution_provider,
                                    void (*custom_compute_fn)(Args...),
                                    Status (*shape_infer_fn)(ShapeInferContext&) = {},
                                    int start_ver = 1,
                                    int end_ver = MAX_CUSTOM_OP_END_VER) {
  using LiteOp = OrtLiteCustomFunc<Args...>;
  return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn, shape_infer_fn, start_ver, end_ver).release();
}

template <typename... Args>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
                                    const char* execution_provider,
                                    Status (*custom_compute_fn_v2)(Args...),
                                    Status (*shape_infer_fn)(ShapeInferContext&) = {},
                                    int start_ver = 1,
                                    int end_ver = MAX_CUSTOM_OP_END_VER) {
  using LiteOp = OrtLiteCustomFunc<Args...>;
  return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn, start_ver, end_ver).release();
}

template <typename CustomOp>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
                                    const char* execution_provider,
                                    int start_ver = 1,
                                    int end_ver = MAX_CUSTOM_OP_END_VER) {
  using LiteOp = OrtLiteCustomStruct<CustomOp>;
  return std::make_unique<LiteOp>(op_name, execution_provider, start_ver, end_ver).release();
}

}  // namespace Custom
}  // namespace Ort