HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ML_Model.h
Go to the documentation of this file.
1 /*
2  * PROPRIETARY INFORMATION. This software is proprietary to
3  * Side Effects Software Inc., and is not to be reproduced,
4  * transmitted, or disclosed in any way without written permission.
5  *
6  * COMMENTS: Wrapper for the ONNX inference engine
7  *
8  */
9 
10 #pragma once
11 
12 #include "ML_API.h"
13 #include <UT/UT_Array.h>
14 #include <UT/UT_NonCopyable.h>
15 #include <UT/UT_SharedPtr.h>
16 #include <UT/UT_StringHolder.h>
17 
18 class UT_WorkBuffer;
19 
21 {
22 public:
/// Constructs an ML_Model. NOTE(review): presumably no model is loaded until
/// initializeModel() succeeds — confirm against the implementation.
23  ML_Model();
/// Destructor, declared out of line. NOTE(review): likely required because
/// SessionInfo is only forward-declared in this header — confirm.
24  ~ML_Model();
26 
27  class SessionInfo; // Our environment: opaque session state that is only
// forward-declared here and defined elsewhere. NOTE(review): presumably it
// wraps the ONNX Runtime environment/session objects — confirm.
29 
30  /// Initializes this ML_Model from an ONNX model file.
31  /// \param model_filepath path to the ONNX model to load.
32  /// \param usecuda if true, the CUDA execution provider is added; if false, it is not.
33  /// \param errors receives any error messages generated when an error
34  /// occurs during initialization.
/// \param warnings receives any warning messages generated during
/// initialization. NOTE(review): contents undocumented in the original — confirm.
/// \returns presumably true on success and false on failure — confirm.
35  bool initializeModel(const UT_StringRef &model_filepath,
36  bool usecuda,
37  UT_WorkBuffer &errors,
38  UT_WorkBuffer &warnings);
39 
/// Runs inference on the initialized model.
/// \param inputs flattened float data for the model inputs. NOTE(review):
/// presumably one entry per model input, in model input order — confirm.
/// \param input_shapes the caller-specified shape of each entry in inputs.
/// \param outputs receives the flattened float data of the model outputs.
/// \param output_shapes the caller-specified shape of each output.
/// \param error_message receives a description of any failure.
/// \returns presumably true on success and false on failure — confirm.
40  bool run(const UT_Array<UT_Array<float>> &inputs,
41  const UT_Array<Shape> &input_shapes,
42  UT_Array<UT_Array<float>> &outputs,
43  const UT_Array<Shape> &output_shapes,
44  UT_WorkBuffer &error_message);
45 
/// Retrieves the names of the model's inputs and outputs.
/// \param input_names out-parameter receiving the input names.
/// \param output_names out-parameter receiving the output names.
/// NOTE(review): presumably ordered by input/output index, and only valid after
/// a successful initializeModel() — confirm.
46  void getNames(UT_StringArray &input_names,
47  UT_StringArray &output_names) const;
48 
/// Retrieves the shapes the model declares for its inputs and outputs.
/// \param input_shapes out-parameter receiving the input shapes.
/// \param output_shapes out-parameter receiving the output shapes.
/// NOTE(review): how dynamic axes are represented in the returned Shape values
/// is not visible from this header — confirm.
49  void getShapes(UT_Array<Shape> &input_shapes,
50  UT_Array<Shape> &output_shapes) const;
51 
52  /// Computes the product of all non-dynamic axes of a tensor shape.
53  /// Sets has_dynamic_axes to report whether any dynamic axes were found.
54  /// \returns 0 if any dimension is zero;
55  /// \returns 1 if all axes are dynamic (the product over no axes).
/// \param tensor_dimensions the size of each axis. NOTE(review): the encoding of
/// dynamic axes is not visible here (ONNX Runtime reports symbolic dims as -1) — confirm.
/// \param has_dynamic_axes out-parameter; see above.
56  static exint tensorElementsSize(const UT_Array<exint> &tensor_dimensions,
57  bool &has_dynamic_axes);
58 
59  /// Acquires the shape of a tensor from a parameter value stored as a 3x3 matrix.
60  /// \param tensor_shape the array to fill with the shape.
/// \param shape_vector the matrix holding the shape. NOTE(review): how its nine
/// entries map to axes, and how unused axes are marked, is defined in the
/// implementation — confirm.
/// \returns presumably false if the matrix does not describe a valid shape — confirm.
61  static bool mat3ToShape(Shape &tensor_shape, const UT_Matrix3D &shape_vector);
62 
63  /// Parses the output data for nodes
64  /// \param maxtuplesize -1 for unlimited
65  static bool parseOutputData(const UT_StringHolder &srcpattern, int maxtuplesize,
67 
68 private:
69  /// Writes information about the model into model_info.
/// NOTE(review): whether model_info is appended to or cleared first is not
/// visible from this header — confirm.
70  void info(UT_WorkBuffer &model_info) const;
71 
72  /// Writes a string representing the shape of input number input_index into the_string.
/// NOTE(review): behavior for an out-of-range input_index is not visible here — confirm.
73  void inputShapeString(int input_index, UT_WorkBuffer &the_string) const;
74 
75  /// Writes a string representing the shape of output number output_index into the_string.
/// NOTE(review): behavior for an out-of-range output_index is not visible here — confirm.
76  void outputShapeString(int output_index, UT_WorkBuffer &the_string) const;
77 
78  /// Checks whether the sizes and shapes of the inputs and outputs are
79  /// compatible with the model, i.e. whether the model should be able to run.
/// \param inputs flattened input data, as passed to run().
/// \param specified_input_shapes caller-specified shape of each input.
/// \param outputs output arrays. NOTE(review): taken by non-const reference, so
/// they may be resized or otherwise prepared here — confirm.
/// \param specified_output_shapes caller-specified shape of each output.
/// \param error_message receives a description of any incompatibility found.
/// \returns presumably true if compatible and false otherwise — confirm.
80  bool sizeAndShapeErrorChecking(const UT_Array<UT_Array<float>> &inputs,
81  const UT_Array<Shape> &specified_input_shapes,
82  UT_Array<UT_Array<float>> &outputs,
83  const UT_Array<Shape> &specified_output_shapes,
84  UT_WorkBuffer &error_message);
85 
87 
88 };
GLuint GLsizei const GLuint const GLintptr const GLsizeiptr * sizes
Definition: glcorearb.h:2621
int64 exint
Definition: SYS_Types.h:125
< returns > If no error
Definition: snippets.dox:2
std::shared_ptr< T > UT_SharedPtr
Wrapper around std::shared_ptr.
Definition: UT_SharedPtr.h:36
#define ML_API
Definition: ML_API.h:10
#define UT_NON_COPYABLE(CLASS)
Define deleted copy constructor and assignment operator inside a class.