#include "onnxruntime_cxx_api.h"

#include <functional>
#include <numeric>
#include <optional>
#include <unordered_set>
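// This header implements ONNX Runtime's "lite" custom-op API: custom kernels
// are written as plain C++ functions or structs over typed Ort::Custom::Tensor
// arguments, and the OrtLiteCustomOp machinery below adapts them to the
// C-level OrtCustomOp interface.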
namespace Ort {
namespace Custom {

// ArgBase: common base for kernel arguments (tensors and tensor arrays).
class ArgBase {
 public:
  ArgBase(OrtKernelContext* ctx, size_t indice, bool is_input)
      : ctx_(ctx), indice_(indice), is_input_(is_input) {}
  virtual ~ArgBase() {}

 protected:
  struct KernelContext ctx_;
  size_t indice_;
  bool is_input_;
};

using ArgPtr = std::unique_ptr<Custom::ArgBase>;
using ArgPtrs = std::vector<ArgPtr>;
// TensorBase: type-erased view of one tensor argument; typed element access
// lives in the Tensor<T> subclasses below.
class TensorBase : public ArgBase {
 public:
  TensorBase(OrtKernelContext* ctx,
             size_t indice,
             bool is_input) : ArgBase(ctx, indice, is_input) {}

  operator bool() const {
    return shape_.has_value();
  }

  const std::vector<int64_t>& Shape() const {
    if (!shape_.has_value()) {
      ORT_CXX_API_THROW("tensor shape is not yet initialized", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return shape_.value();
  }

  ONNXTensorElementDataType Type() const {
    return type_;
  }

  int64_t NumberOfElement() const {
    if (shape_.has_value()) {
      return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
    }
    return 0;
  }

  std::string Shape2Str() const {
    if (shape_.has_value()) {
      std::string shape_str;
      for (const auto& dim : *shape_) {
        shape_str.append(std::to_string(dim));
        shape_str.append(", ");
      }
      return shape_str;
    }
    return "empty";
  }

  virtual const void* DataRaw() const = 0;
  virtual size_t SizeInBytes() const = 0;

 protected:
  std::optional<std::vector<int64_t>> shape_;
  ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
};
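// Span is a minimal non-owning (pointer, size) view used for 1-D CPU inputs;
// Tensor<T> binds a single input or output slot of the kernel context.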
template <typename T>
struct Span {
  const T* data_ = {};
  size_t size_ = {};
  void Assign(const T* data, size_t size) { data_ = data; size_ = size; }
  size_t size() const { return size_; }
  T operator[](size_t indice) const { return data_[indice]; }
  const T* data() const { return data_; }
};

template <typename T>
class Tensor : public TensorBase {
 public:
  using TT = typename std::remove_reference<T>::type;

  Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
    if (is_input_) {
      if (indice >= ctx_.GetInputCount()) {
        ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
      }
      const_value_ = ctx_.GetInput(indice);
      auto type_shape_info = const_value_.GetTensorTypeAndShapeInfo();
      shape_ = type_shape_info.GetShape();
    }
  }

  const TT* Data() const { return reinterpret_cast<const TT*>(const_value_.GetTensorRawData()); }

  // Lazily creates the output value with the given shape and caches the
  // mutable data pointer.
  TT* Allocate(const std::vector<int64_t>& shape) {
    if (!data_) {
      shape_ = shape;
      auto output = ctx_.GetOutput(indice_, shape);
      data_ = output.GetTensorMutableData<TT>();
    }
    return data_;
  }

  const Span<T>& AsSpan() {
    if (!shape_.has_value() || shape_->size() != 1) {
      ORT_CXX_API_THROW("invalid shape while trying to get a span out of Ort::Custom::Tensor",
                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    span_.Assign(Data(), static_cast<size_t>((*shape_)[0]));
    return span_;
  }

  const T& AsScalar() {
    if (!shape_.has_value() || shape_->size() != 1 || (*shape_)[0] != 1) {
      ORT_CXX_API_THROW("invalid shape while trying to get a scalar from Ort::Custom::Tensor",
                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return *Data();
  }

  const void* DataRaw() const override { return reinterpret_cast<const void*>(Data()); }
  size_t SizeInBytes() const override { return sizeof(TT) * static_cast<size_t>(NumberOfElement()); }

 private:
  ConstValue const_value_;  // for input
  TT* data_{};              // for output
  Span<T> span_;
};
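// Specialization for string tensors. ORT stores a string tensor as one packed
// character buffer plus per-element offsets; the constructor copies the
// elements out into std::string values.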
template <>
class Tensor<std::string> : public TensorBase {
 public:
  using strings = std::vector<std::string>;

  Tensor(OrtKernelContext* ctx,
         size_t indice,
         bool is_input) : TensorBase(ctx, indice, is_input) {
    if (is_input_) {
      if (indice >= ctx_.GetInputCount()) {
        ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
      }
      auto const_value = ctx_.GetInput(indice);
      auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
      shape_ = type_shape_info.GetShape();
      auto num_chars = const_value.GetStringTensorDataLength();
      auto num_strings = static_cast<size_t>(NumberOfElement());
      if (num_strings) {
        std::vector<char> chars(num_chars + 1, '\0');
        std::vector<size_t> offsets(num_strings);
        const_value.GetStringTensorContent(static_cast<void*>(chars.data()), num_chars, offsets.data(), offsets.size());
        auto upper_bound = num_strings - 1;
        input_strings_.resize(num_strings);
        // Walk backwards so each element can be null-terminated in place
        // before it is copied out.
        for (size_t i = upper_bound;; --i) {
          if (i < upper_bound) {
            chars[offsets[i + 1]] = '\0';
          }
          input_strings_[i] = chars.data() + offsets[i];
          if (0 == i) {
            break;
          }
        }
      }
    }
  }

  const strings& Data() const {
    return input_strings_;
  }

  const void* DataRaw() const override {
    if (input_strings_.size() != 1) {
      ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return reinterpret_cast<const void*>(input_strings_[0].c_str());
  }

  size_t SizeInBytes() const override {
    if (input_strings_.size() != 1) {
      ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return input_strings_[0].size();
  }

  void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
    shape_ = dims;
    std::vector<const char*> raw;
    for (const auto& s : ss) {
      raw.push_back(s.data());
    }
    auto output = ctx_.GetOutput(indice_, dims);
    output.FillStringTensor(raw.data(), raw.size());
  }

  const Span<std::string>& AsSpan() {
    ORT_CXX_API_THROW("span for TensorT of string not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
  }

  const std::string& AsScalar() {
    if (input_strings_.size() != 1) {
      ORT_CXX_API_THROW("invalid shape while trying to get a scalar string from Ort::Custom::Tensor",
                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return input_strings_[0];
  }

 private:
  std::vector<std::string> input_strings_;  // for input
};
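// Specialization for std::string_view: the packed character buffer is kept
// alive in chars_ and the views point into it, avoiding per-string copies.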
template <>
class Tensor<std::string_view> : public TensorBase {
 public:
  using strings = std::vector<std::string>;
  using string_views = std::vector<std::string_view>;

  Tensor(OrtKernelContext* ctx,
         size_t indice,
         bool is_input) : TensorBase(ctx, indice, is_input) {
    if (is_input_) {
      if (indice >= ctx_.GetInputCount()) {
        ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
      }
      auto const_value = ctx_.GetInput(indice);
      auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
      shape_ = type_shape_info.GetShape();
      auto num_chars = const_value.GetStringTensorDataLength();
      chars_.resize(num_chars + 1, '\0');
      auto num_strings = static_cast<size_t>(NumberOfElement());
      if (num_strings) {
        std::vector<size_t> offsets(num_strings);
        const_value.GetStringTensorContent(static_cast<void*>(chars_.data()), num_chars, offsets.data(), offsets.size());
        offsets.push_back(num_chars);
        for (size_t i = 0; i < num_strings; ++i) {
          input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
        }
      }
    }
  }

  const string_views& Data() const {
    return input_string_views_;
  }

  const void* DataRaw() const override {
    if (input_string_views_.size() != 1) {
      ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return reinterpret_cast<const void*>(input_string_views_[0].data());
  }

  size_t SizeInBytes() const override {
    if (input_string_views_.size() != 1) {
      ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return input_string_views_[0].size();
  }

  void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
    shape_ = dims;
    std::vector<const char*> raw;
    for (const auto& s : ss) {
      raw.push_back(s.data());
    }
    auto output = ctx_.GetOutput(indice_, dims);
    output.FillStringTensor(raw.data(), raw.size());
  }

  const Span<std::string_view>& AsSpan() {
    ORT_CXX_API_THROW("span for TensorT of string view not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
  }

  std::string_view AsScalar() {
    if (input_string_views_.size() != 1) {
      ORT_CXX_API_THROW("invalid shape while trying to get a scalar string view from Ort::Custom::Tensor",
                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return input_string_views_[0];
  }

 private:
  std::vector<char> chars_;                            // for input
  std::vector<std::string_view> input_string_views_;  // for input
};
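// TensorArray implements variadic inputs/outputs: it wraps every input from
// start_indice onward in a type-erased TensorBase, dispatching on the runtime
// element type reported by the kernel context.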
class TensorArray : public ArgBase {
 public:
  using TensorPtr = std::unique_ptr<Custom::TensorBase>;

  TensorArray(OrtKernelContext* ctx,
              size_t start_indice,
              bool is_input) : ArgBase(ctx, start_indice, is_input) {
    if (is_input) {
      auto input_count = ctx_.GetInputCount();
      for (size_t ith_input = start_indice; ith_input < input_count; ++ith_input) {
        auto const_value = ctx_.GetInput(ith_input);
        auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
        auto type = type_shape_info.GetElementType();
        TensorPtr tensor;
        switch (type) {
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
            tensor = std::make_unique<Custom::Tensor<bool>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
            tensor = std::make_unique<Custom::Tensor<float>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
            tensor = std::make_unique<Custom::Tensor<double>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
            tensor = std::make_unique<Custom::Tensor<uint8_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
            tensor = std::make_unique<Custom::Tensor<int8_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
            tensor = std::make_unique<Custom::Tensor<uint16_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
            tensor = std::make_unique<Custom::Tensor<int16_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
            tensor = std::make_unique<Custom::Tensor<uint32_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
            tensor = std::make_unique<Custom::Tensor<int32_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
            tensor = std::make_unique<Custom::Tensor<uint64_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
            tensor = std::make_unique<Custom::Tensor<int64_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
            tensor = std::make_unique<Custom::Tensor<std::string>>(ctx, ith_input, true);
            break;
          default:
            ORT_CXX_API_THROW("unknown input type", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
        }
        tensors_.emplace_back(tensor.release());
      }
    }
  }
  template <typename T>
  T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
    auto tensor = std::make_unique<Tensor<T>>(ctx_.GetOrtKernelContext(), ith_output, false);
    auto raw_output = tensor.get()->Allocate(shape);
    tensors_.emplace_back(tensor.release());
    return raw_output;
  }

  Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
    auto tensor = std::make_unique<Tensor<std::string>>(ctx_.GetOrtKernelContext(), ith_output, false);
    Tensor<std::string>& output = *tensor;
    tensors_.emplace_back(tensor.release());
    return output;
  }

  size_t Size() const {
    return tensors_.size();
  }

  const TensorPtr& operator[](size_t ith_input) const {
    return tensors_.at(ith_input);
  }

 private:
  std::vector<TensorPtr> tensors_;
};

using Variadic = TensorArray;
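// The CreateTuple<ith_input, ith_output, Ts...> overloads below unpack a
// kernel's parameter list one type at a time: each overload binds the current
// parameter T to the next input or output slot (materializing it into args),
// then recurses on the remaining Ts... For example, a hypothetical kernel
//   void Fuse(const Tensor<float>& a, const Tensor<float>& b, Tensor<float>& out)
// binds inputs 0 and 1 and output 0 starting from CreateTuple<0, 0, ...>.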
// OrtLiteCustomOp bridges a templated kernel signature to the C OrtCustomOp
// interface.
struct OrtLiteCustomOp : public OrtCustomOp {
  // CreateTuple base case: no parameter types left to bind.
  template <size_t ith_input, size_t ith_output, typename... Ts>
  static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
  CreateTuple(OrtKernelContext*, ArgPtrs&, size_t, size_t, const std::string&) {
    return std::make_tuple();
  }
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, OrtKernelContext&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    std::tuple<T> current = std::tuple<OrtKernelContext&>{*context};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }
#ifdef ORT_CUDA_CTX
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const CudaContext&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    thread_local CudaContext cuda_context;
    cuda_context.Init(*context);
    std::tuple<T> current = std::tuple<const CudaContext&>{cuda_context};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }
#endif

#ifdef ORT_ROCM_CTX
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const RocmContext&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    thread_local RocmContext rocm_context;
    rocm_context.Init(*context);
    std::tuple<T> current = std::tuple<const RocmContext&>{rocm_context};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }
#endif
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const TensorArray*>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    args.push_back(std::make_unique<TensorArray>(context, ith_input, true));
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const TensorArray&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    args.push_back(std::make_unique<TensorArray>(context, ith_input, true));
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, TensorArray*>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    args.push_back(std::make_unique<TensorArray>(context, ith_output, false));
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, TensorArray&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    args.push_back(std::make_unique<TensorArray>(context, ith_output, false));
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }
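// CREATE_TUPLE_INPUT / CREATE_TUPLE_OUTPUT stamp out the same CreateTuple
// recursion for each concrete element type: pointer, reference, and
// std::optional forms, plus span and scalar input forms that are rejected at
// runtime on non-CPU EPs, where the data is not directly addressable.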
#define CREATE_TUPLE_INPUT(data_type) \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
      std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if ("CPUExecutionProvider" != ep) { \
      ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
    } \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if ("CPUExecutionProvider" != ep) { \
      ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
    } \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      if ("CPUExecutionProvider" != ep) { \
        ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
      } \
      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
      std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, data_type>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if ("CPUExecutionProvider" != ep) { \
      ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
    } \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsScalar()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<data_type>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      if ("CPUExecutionProvider" != ep) { \
        ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
      } \
      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
      std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsScalar()}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  }
#define CREATE_TUPLE_OUTPUT(data_type) \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())}; \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())}; \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_output < num_output) { \
      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
      std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())}; \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  }
#define CREATE_TUPLE(data_type) \
  CREATE_TUPLE_INPUT(data_type) \
  CREATE_TUPLE_OUTPUT(data_type)

// CREATE_TUPLE(...) is then instantiated for every supported element type
// (bool, the floating-point and integer types, std::string, std::string_view).
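// ParseArgs mirrors CreateTuple at registration time: it walks the same
// parameter pack and records each parameter's ONNX element type, so the op's
// schema (input/output counts and types) is derived from the C++ signature
// alone. Context-only parameters contribute nothing; TensorArray parameters
// are recorded as UNDEFINED, which later marks them as variadic.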
  template <typename... Ts>
  static typename std::enable_if<0 == sizeof...(Ts)>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
  }
  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    ParseArgs<Ts...>(input_types, output_types);
  }

#ifdef ORT_CUDA_CTX
  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const CudaContext&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    ParseArgs<Ts...>(input_types, output_types);
  }
#endif

#ifdef ORT_ROCM_CTX
  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const RocmContext&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    ParseArgs<Ts...>(input_types, output_types);
  }
#endif

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const TensorArray&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const TensorArray*>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, TensorArray&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, TensorArray*>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }
#define PARSE_INPUT_BASE(pack_type, onnx_type) \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const std::optional<pack_type>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<pack_type>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  }
#define PARSE_INPUT(data_type, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type) \
  PARSE_INPUT_BASE(data_type, onnx_type)
#define PARSE_OUTPUT(data_type, onnx_type) \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>*>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>&>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  }
#define PARSE_ARGS(data_type, onnx_type) \
  PARSE_INPUT(data_type, onnx_type) \
  PARSE_OUTPUT(data_type, onnx_type)
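// Map every supported C++ element type to its ONNX tensor element type. Note
// that std::string and std::string_view both map to the ONNX string type.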
PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
PARSE_ARGS(Ort::Float16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
PARSE_ARGS(Ort::BFloat16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16)
PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
PARSE_ARGS(Ort::Float8E4M3FN_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN)
PARSE_ARGS(Ort::Float8E4M3FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ)
PARSE_ARGS(Ort::Float8E5M2_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2)
PARSE_ARGS(Ort::Float8E5M2FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ)
  using ShapeInferFn = Ort::Status (*)(Ort::ShapeInferContext&);

  OrtLiteCustomOp(const char* op_name,
                  const char* execution_provider,
                  ShapeInferFn shape_infer_fn,
                  int start_ver = 1,
                  int end_ver = MAX_CUSTOM_OP_END_VER)
      : op_name_(op_name),
        execution_provider_(execution_provider),
        shape_infer_fn_(shape_infer_fn),
        start_ver_(start_ver),
        end_ver_(end_ver) {
    OrtCustomOp::version = ORT_API_VERSION;

    OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
    OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
    OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) { return OrtMemTypeDefault; };

    OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->input_types_.size();
    };

    OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->input_types_[indice];
    };

    OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->output_types_.size();
    };

    OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->output_types_[indice];
    };

    // An UNDEFINED element type marks a TensorArray parameter, i.e. a
    // variadic input or output; everything else is treated as optional.
    OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t indice) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->input_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL;
    };

    OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t indice) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->output_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL;
    };
    OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) {
      return 1;
    };

    OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) {
      return 0;
    };

    OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) {
      return 1;
    };

    OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) {
      return 0;
    };

    OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { return 0; };
    OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { return 0; };
    OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { return 0; };
    OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { return 0; };
    OrtCustomOp::CreateKernelV2 = {};
    OrtCustomOp::KernelComputeV2 = {};
    OrtCustomOp::KernelCompute = {};

    OrtCustomOp::InferOutputShapeFn = {};

    OrtCustomOp::GetStartVersion = [](const OrtCustomOp* op) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->start_ver_;
    };

    OrtCustomOp::GetEndVersion = [](const OrtCustomOp* op) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->end_ver_;
    };
  }

  const std::string op_name_;
  const std::string execution_provider_;

  std::vector<ONNXTensorElementDataType> input_types_;
  std::vector<ONNXTensorElementDataType> output_types_;

  ShapeInferFn shape_infer_fn_ = {};

  int start_ver_ = 1;
  int end_ver_ = MAX_CUSTOM_OP_END_VER;
};
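// OrtLiteCustomFunc adapts a free function, optionally returning Ort::Status,
// into an OrtCustomOp. The function's parameter pack drives both schema
// inference (ParseArgs) and per-call argument marshalling (CreateTuple).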
template <typename... Args>
struct OrtLiteCustomFunc : public OrtLiteCustomOp {
  using ComputeFn = void (*)(Args...);
  using ComputeFnReturnStatus = Status (*)(Args...);
  using MyType = OrtLiteCustomFunc<Args...>;

  struct Kernel {
    size_t num_input_{};
    size_t num_output_{};
    ComputeFn compute_fn_{};
    ComputeFnReturnStatus compute_fn_return_status_{};
    std::string ep_{};
  };

  OrtLiteCustomFunc(const char* op_name,
                    const char* execution_provider,
                    ComputeFn compute_fn,
                    ShapeInferFn shape_infer_fn = {},
                    int start_ver = 1,
                    int end_ver = MAX_CUSTOM_OP_END_VER)
      : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver),
        compute_fn_(reinterpret_cast<void*>(compute_fn)) {
    ParseArgs<Args...>(input_types_, output_types_);

    OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      std::vector<ArgPtr> args;
      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
      std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
    };

    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
      auto kernel = std::make_unique<Kernel>();
      auto me = static_cast<const MyType*>(this_);
      kernel->compute_fn_ = reinterpret_cast<ComputeFn>(me->compute_fn_);
      Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
      Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
      auto self = static_cast<const OrtLiteCustomFunc*>(this_);
      kernel->ep_ = self->execution_provider_;
      return reinterpret_cast<void*>(kernel.release());
    };

    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
      delete reinterpret_cast<Kernel*>(op_kernel);
    };

    if (shape_infer_fn_) {
      OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
        auto shape_info_fn = static_cast<const MyType*>(op)->shape_infer_fn_;
        ShapeInferContext ctx(&GetApi(), ort_ctx);
        return shape_info_fn(ctx);
      };
    }
  }
  OrtLiteCustomFunc(const char* op_name,
                    const char* execution_provider,
                    ComputeFnReturnStatus compute_fn_return_status,
                    ShapeInferFn shape_infer_fn = {},
                    int start_ver = 1,
                    int end_ver = MAX_CUSTOM_OP_END_VER)
      : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver),
        compute_fn_return_status_(reinterpret_cast<void*>(compute_fn_return_status)) {
    ParseArgs<Args...>(input_types_, output_types_);

    OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      std::vector<ArgPtr> args;
      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
      return std::apply([kernel](Args const&... t_args) {
        Status status = kernel->compute_fn_return_status_(t_args...);
        return status.release();
      }, t);
    };

    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
      auto kernel = std::make_unique<Kernel>();
      auto me = static_cast<const MyType*>(this_);
      kernel->compute_fn_return_status_ = reinterpret_cast<ComputeFnReturnStatus>(me->compute_fn_return_status_);
      Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
      Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
      auto self = static_cast<const OrtLiteCustomFunc*>(this_);
      kernel->ep_ = self->execution_provider_;
      return reinterpret_cast<void*>(kernel.release());
    };

    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
      delete reinterpret_cast<Kernel*>(op_kernel);
    };

    if (shape_infer_fn_) {
      OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
        auto shape_info_fn = static_cast<const MyType*>(op)->shape_infer_fn_;
        ShapeInferContext ctx(&GetApi(), ort_ctx);
        return shape_info_fn(ctx);
      };
    }
  }

  void* compute_fn_ = {};
  void* compute_fn_return_status_ = {};
};  // struct OrtLiteCustomFunc
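// OrtLiteCustomStruct adapts a struct with a Compute member function; the
// struct is constructed per kernel from (OrtApi*, OrtKernelInfo*) and may
// optionally provide a static InferOutputShape, detected via the SFINAE
// overloads of SetShapeInfer below.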
template <typename CustomOp>
struct OrtLiteCustomStruct : public OrtLiteCustomOp {
  template <typename... Args>
  using CustomComputeFn = void (CustomOp::*)(Args...);

  template <typename... Args>
  using CustomComputeFnReturnStatus = Status (CustomOp::*)(Args...);

  using MyType = OrtLiteCustomStruct<CustomOp>;

  struct Kernel {
    size_t num_input_{};
    size_t num_output_{};
    std::unique_ptr<CustomOp> custom_op_;
    std::string ep_{};
  };
  OrtLiteCustomStruct(const char* op_name,
                      const char* execution_provider,
                      int start_ver = 1,
                      int end_ver = MAX_CUSTOM_OP_END_VER)
      : OrtLiteCustomOp(op_name, execution_provider, {}, start_ver, end_ver) {
    SetCompute(&CustomOp::Compute);

    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
      auto kernel = std::make_unique<Kernel>();
      Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
      Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
      kernel->custom_op_ = std::make_unique<CustomOp>(ort_api, info);
      auto self = static_cast<const OrtLiteCustomStruct*>(this_);
      kernel->ep_ = self->execution_provider_;
      return reinterpret_cast<void*>(kernel.release());
    };

    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
      delete reinterpret_cast<Kernel*>(op_kernel);
    };

    SetShapeInfer<CustomOp>(0);
  }
  template <typename... Args>
  void SetCompute(CustomComputeFn<Args...>) {
    ParseArgs<Args...>(input_types_, output_types_);
    OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      ArgPtrs args;
      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
      std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
    };
  }
  template <typename... Args>
  void SetCompute(CustomComputeFnReturnStatus<Args...>) {
    ParseArgs<Args...>(input_types_, output_types_);
    OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      ArgPtrs args;
      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
      return std::apply([kernel](Args const&... t_args) {
        Status status = kernel->custom_op_->Compute(t_args...);
        return status.release();
      }, t);
    };
  }
  template <typename C>
  decltype(&C::InferOutputShape) SetShapeInfer(decltype(&C::InferOutputShape)) {
    OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
      ShapeInferContext ctx(&GetApi(), ort_ctx);
      return C::InferOutputShape(ctx);
    };
    return {};
  }

  template <typename C>
  void SetShapeInfer(...) {
    OrtCustomOp::InferOutputShapeFn = {};
  }
};  // struct OrtLiteCustomStruct
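// Factory helpers. Each returns a heap-allocated op whose ownership passes to
// the caller; the pointer is typically handed to Ort::CustomOpDomain::Add and
// must outlive the sessions that use it.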
template <typename... Args>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
                                    const char* execution_provider,
                                    void (*custom_compute_fn)(Args...),
                                    Status (*shape_infer_fn)(ShapeInferContext&) = {},
                                    int start_ver = 1,
                                    int end_ver = MAX_CUSTOM_OP_END_VER) {
  using LiteOp = OrtLiteCustomFunc<Args...>;
  return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn, shape_infer_fn, start_ver, end_ver).release();
}
template <typename... Args>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
                                    const char* execution_provider,
                                    Status (*custom_compute_fn_v2)(Args...),
                                    Status (*shape_infer_fn)(ShapeInferContext&) = {},
                                    int start_ver = 1,
                                    int end_ver = MAX_CUSTOM_OP_END_VER) {
  using LiteOp = OrtLiteCustomFunc<Args...>;
  return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn, start_ver, end_ver).release();
}
template <typename CustomOp>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
                                    const char* execution_provider,
                                    int start_ver = 1,
                                    int end_ver = MAX_CUSTOM_OP_END_VER) {
  using LiteOp = OrtLiteCustomStruct<CustomOp>;
  return std::make_unique<LiteOp>(op_name, execution_provider, start_ver, end_ver).release();
}

}  // namespace Custom
}  // namespace Ort
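// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the header). It assumes the
// public C++ API from onnxruntime_cxx_api.h; the kernel name "MyMul" and the
// domain name "v1" are made up for the example. Guarded out so the listing
// stays a pure header.
#if 0
// The kernel is just a function over typed tensors; inputs/outputs and their
// element types are inferred from the signature.
void MyMul(const Ort::Custom::Tensor<float>& a,
           const Ort::Custom::Tensor<float>& b,
           Ort::Custom::Tensor<float>& out) {
  const float* ra = a.Data();
  const float* rb = b.Data();
  float* ro = out.Allocate(a.Shape());  // create the output with a's shape
  for (int64_t i = 0; i < a.NumberOfElement(); ++i) {
    ro[i] = ra[i] * rb[i];
  }
}

// The op and domain must outlive every session that uses them, hence static.
void RegisterMyMul(Ort::SessionOptions& session_options) {
  static Ort::CustomOpDomain domain{"v1"};
  static std::unique_ptr<Ort::Custom::OrtLiteCustomOp> op{
      Ort::Custom::CreateLiteCustomOp("MyMul", "CPUExecutionProvider", MyMul)};
  domain.Add(op.get());
  session_options.Add(domain);
}
#endif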