48
48
// buffer type.
49
49
litert::Expected<std::vector<litert::TensorBuffer>> CreateGlInputBuffers (
50
50
LiteRtEnvironment env, litert::CompiledModel& compiled_model,
51
- litert::Signature& signature) {
52
- LiteRtSubgraph subgraph_handle = signature.Subgraph ();
53
- litert::Subgraph subgraph = litert::Subgraph (subgraph_handle);
54
-
51
+ litert::Model& model, int signature_index) {
52
+ auto signature = model.GetSignature (signature_index);
55
53
std::vector<litert::TensorBuffer> input_buffers;
56
- input_buffers.reserve (subgraph.Inputs ().size ());
57
- for (litert::Tensor& input_tensor : subgraph.Inputs ()) {
54
+ for (int i = 0 ; i < signature->InputNames ().size (); ++i) {
58
55
LITERT_ASSIGN_OR_RETURN (
59
56
litert::TensorBufferRequirements input_buffer_requirements,
60
- compiled_model.GetInputBufferRequirements (signature.Key (),
61
- input_tensor.Name ()));
57
+ compiled_model.GetInputBufferRequirements (signature_index, i));
62
58
LITERT_ASSIGN_OR_RETURN (litert::RankedTensorType ranked_tensor_type,
63
- input_tensor. RankedTensorType ( ));
59
+ model. GetInputTensorType (signature_index, i ));
64
60
LITERT_ASSIGN_OR_RETURN (size_t buffer_size,
65
61
input_buffer_requirements.BufferSize ());
66
62
LITERT_ASSIGN_OR_RETURN (auto input_buffer,
@@ -76,21 +72,19 @@ litert::Expected<std::vector<litert::TensorBuffer>> CreateGlInputBuffers(
76
72
// buffer type.
77
73
litert::Expected<std::vector<litert::TensorBuffer>> CreateGlOutputBuffers (
78
74
LiteRtEnvironment env, litert::CompiledModel& compiled_model,
79
- litert::Signature& signature) {
80
- LiteRtSubgraph subgraph_handle = signature.Subgraph ();
81
- litert::Subgraph subgraph = litert::Subgraph (subgraph_handle);
75
+ litert::Model& model, int signature_index) {
76
+ auto signature = model.GetSignature (signature_index);
82
77
83
78
std::vector<litert::TensorBuffer> output_buffers;
84
- output_buffers.reserve (subgraph. Outputs ().size ());
85
- for (litert::Tensor& output_tensor : subgraph. Outputs () ) {
79
+ output_buffers.reserve (signature-> OutputNames ().size ());
80
+ for (int i = 0 ; i < signature-> OutputNames (). size (); ++i ) {
86
81
LITERT_ASSIGN_OR_RETURN (
87
- litert::TensorBufferRequirements input_buffer_requirements,
88
- compiled_model.GetOutputBufferRequirements (signature.Key (),
89
- output_tensor.Name ()));
82
+ litert::TensorBufferRequirements output_buffer_requirements,
83
+ compiled_model.GetOutputBufferRequirements (signature_index, i));
90
84
LITERT_ASSIGN_OR_RETURN (litert::RankedTensorType ranked_tensor_type,
91
- output_tensor. RankedTensorType ( ));
85
+ model. GetOutputTensorType (signature_index, i ));
92
86
LITERT_ASSIGN_OR_RETURN (size_t buffer_size,
93
- input_buffer_requirements .BufferSize ());
87
+ output_buffer_requirements .BufferSize ());
94
88
LITERT_ASSIGN_OR_RETURN (auto output_buffer,
95
89
litert::TensorBuffer::CreateManaged (
96
90
env, kLiteRtTensorBufferTypeGlBuffer ,
@@ -157,20 +151,18 @@ bool SegmentationModel::InitializeModel(const std::string& model_path,
157
151
}
158
152
}
159
153
160
- env_ = std::make_unique<litert::Environment>(std::move (env));
161
-
162
154
LITERT_ASSIGN_OR_ABORT (auto signatures, model_.GetSignatures ());
163
155
164
156
size_t signature_index = 0 ;
165
157
166
158
if (use_gl_buffers_) {
167
- LITERT_ASSIGN_OR_ABORT (input_buffers_,
168
- CreateGlInputBuffers (env.Get (), compiled_model_,
169
- signatures[ signature_index] ));
159
+ LITERT_ASSIGN_OR_ABORT (
160
+ input_buffers_, CreateGlInputBuffers (env.Get (), compiled_model_, model_ ,
161
+ signature_index));
170
162
171
163
LITERT_ASSIGN_OR_ABORT (output_buffers_,
172
164
CreateGlOutputBuffers (env.Get (), compiled_model_,
173
- signatures[ signature_index] ));
165
+ model_, signature_index));
174
166
175
167
} else {
176
168
LITERT_ASSIGN_OR_ABORT (input_buffers_,
@@ -179,6 +171,9 @@ bool SegmentationModel::InitializeModel(const std::string& model_path,
179
171
LITERT_ASSIGN_OR_ABORT (
180
172
output_buffers_, compiled_model_.CreateOutputBuffers (signature_index));
181
173
}
174
+
175
+ env_ = std::make_unique<litert::Environment>(std::move (env));
176
+
182
177
std::cout << " SegmentationModel: Model initialized." << std::endl;
183
178
std::cout << " SegmentationModel: warming up model..." << std::endl;
184
179
auto start_time = absl::Now ();
0 commit comments