Skip to content

Commit 4149c07

Browse files
ai-edge-botcopybara-github
authored andcommitted
Added helper methods to retrieve input/output tensor types.
LiteRT-PiperOrigin-RevId: 760817334
1 parent f9675c2 commit 4149c07

File tree

2 files changed

+62
-25
lines changed

2 files changed

+62
-25
lines changed

litert/cc/litert_model.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,48 @@ class Model : public internal::Handle<LiteRtModel, LiteRtDestroyModel> {
469469
return key;
470470
}
471471

472+
// Returns the tensor type for the given n-th input tensor.
473+
Expected<RankedTensorType> GetInputTensorType(size_t signature_index,
474+
size_t input_index) const {
475+
auto subgraph = Subgraph(signature_index);
476+
return subgraph->Inputs()[input_index].RankedTensorType();
477+
}
478+
479+
// Returns the tensor type for the given input tensor name.
480+
Expected<RankedTensorType> GetInputTensorType(
481+
size_t signature_index, absl::string_view input_name) const {
482+
auto subgraph = Subgraph(signature_index);
483+
LITERT_ASSIGN_OR_RETURN(auto tensor, subgraph->Input(input_name));
484+
return tensor.RankedTensorType();
485+
}
486+
487+
// Get input tensor type of the default signature for input name.
488+
Expected<RankedTensorType> GetInputTensorType(
489+
absl::string_view input_name) const {
490+
return GetInputTensorType(/*signature_index=*/0, input_name);
491+
}
492+
493+
// Returns the tensor type for the given n-th output tensor.
494+
Expected<RankedTensorType> GetOutputTensorType(size_t signature_index,
495+
size_t output_index) const {
496+
auto subgraph = Subgraph(signature_index);
497+
return subgraph->Outputs()[output_index].RankedTensorType();
498+
}
499+
500+
// Returns the tensor type for the given output tensor name.
501+
Expected<RankedTensorType> GetOutputTensorType(
502+
size_t signature_index, absl::string_view output_name) const {
503+
auto subgraph = Subgraph(signature_index);
504+
LITERT_ASSIGN_OR_RETURN(auto tensor, subgraph->Output(output_name));
505+
return tensor.RankedTensorType();
506+
}
507+
508+
// Get output tensor type of the default signature for output name.
509+
Expected<RankedTensorType> GetOutputTensorType(
510+
absl::string_view output_name) const {
511+
return GetOutputTensorType(/*signature_index=*/0, output_name);
512+
}
513+
472514
private:
473515
// Parameter `owned` indicates if the created TensorBuffer object should take
474516
// ownership of the provided `tensor_buffer` handle.

litert/samples/async_segmentation/segmentation_model.cc

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,15 @@
4848
// buffer type.
4949
litert::Expected<std::vector<litert::TensorBuffer>> CreateGlInputBuffers(
5050
LiteRtEnvironment env, litert::CompiledModel& compiled_model,
51-
litert::Signature& signature) {
52-
LiteRtSubgraph subgraph_handle = signature.Subgraph();
53-
litert::Subgraph subgraph = litert::Subgraph(subgraph_handle);
54-
51+
litert::Model& model, int signature_index) {
52+
auto signature = model.GetSignature(signature_index);
5553
std::vector<litert::TensorBuffer> input_buffers;
56-
input_buffers.reserve(subgraph.Inputs().size());
57-
for (litert::Tensor& input_tensor : subgraph.Inputs()) {
54+
for (int i = 0; i < signature->InputNames().size(); ++i) {
5855
LITERT_ASSIGN_OR_RETURN(
5956
litert::TensorBufferRequirements input_buffer_requirements,
60-
compiled_model.GetInputBufferRequirements(signature.Key(),
61-
input_tensor.Name()));
57+
compiled_model.GetInputBufferRequirements(signature_index, i));
6258
LITERT_ASSIGN_OR_RETURN(litert::RankedTensorType ranked_tensor_type,
63-
input_tensor.RankedTensorType());
59+
model.GetInputTensorType(signature_index, i));
6460
LITERT_ASSIGN_OR_RETURN(size_t buffer_size,
6561
input_buffer_requirements.BufferSize());
6662
LITERT_ASSIGN_OR_RETURN(auto input_buffer,
@@ -76,21 +72,19 @@ litert::Expected<std::vector<litert::TensorBuffer>> CreateGlInputBuffers(
7672
// buffer type.
7773
litert::Expected<std::vector<litert::TensorBuffer>> CreateGlOutputBuffers(
7874
LiteRtEnvironment env, litert::CompiledModel& compiled_model,
79-
litert::Signature& signature) {
80-
LiteRtSubgraph subgraph_handle = signature.Subgraph();
81-
litert::Subgraph subgraph = litert::Subgraph(subgraph_handle);
75+
litert::Model& model, int signature_index) {
76+
auto signature = model.GetSignature(signature_index);
8277

8378
std::vector<litert::TensorBuffer> output_buffers;
84-
output_buffers.reserve(subgraph.Outputs().size());
85-
for (litert::Tensor& output_tensor : subgraph.Outputs()) {
79+
output_buffers.reserve(signature->OutputNames().size());
80+
for (int i = 0; i < signature->OutputNames().size(); ++i) {
8681
LITERT_ASSIGN_OR_RETURN(
87-
litert::TensorBufferRequirements input_buffer_requirements,
88-
compiled_model.GetOutputBufferRequirements(signature.Key(),
89-
output_tensor.Name()));
82+
litert::TensorBufferRequirements output_buffer_requirements,
83+
compiled_model.GetOutputBufferRequirements(signature_index, i));
9084
LITERT_ASSIGN_OR_RETURN(litert::RankedTensorType ranked_tensor_type,
91-
output_tensor.RankedTensorType());
85+
model.GetOutputTensorType(signature_index, i));
9286
LITERT_ASSIGN_OR_RETURN(size_t buffer_size,
93-
input_buffer_requirements.BufferSize());
87+
output_buffer_requirements.BufferSize());
9488
LITERT_ASSIGN_OR_RETURN(auto output_buffer,
9589
litert::TensorBuffer::CreateManaged(
9690
env, kLiteRtTensorBufferTypeGlBuffer,
@@ -157,20 +151,18 @@ bool SegmentationModel::InitializeModel(const std::string& model_path,
157151
}
158152
}
159153

160-
env_ = std::make_unique<litert::Environment>(std::move(env));
161-
162154
LITERT_ASSIGN_OR_ABORT(auto signatures, model_.GetSignatures());
163155

164156
size_t signature_index = 0;
165157

166158
if (use_gl_buffers_) {
167-
LITERT_ASSIGN_OR_ABORT(input_buffers_,
168-
CreateGlInputBuffers(env.Get(), compiled_model_,
169-
signatures[signature_index]));
159+
LITERT_ASSIGN_OR_ABORT(
160+
input_buffers_, CreateGlInputBuffers(env.Get(), compiled_model_, model_,
161+
signature_index));
170162

171163
LITERT_ASSIGN_OR_ABORT(output_buffers_,
172164
CreateGlOutputBuffers(env.Get(), compiled_model_,
173-
signatures[signature_index]));
165+
model_, signature_index));
174166

175167
} else {
176168
LITERT_ASSIGN_OR_ABORT(input_buffers_,
@@ -179,6 +171,9 @@ bool SegmentationModel::InitializeModel(const std::string& model_path,
179171
LITERT_ASSIGN_OR_ABORT(
180172
output_buffers_, compiled_model_.CreateOutputBuffers(signature_index));
181173
}
174+
175+
env_ = std::make_unique<litert::Environment>(std::move(env));
176+
182177
std::cout << "SegmentationModel: Model initialized." << std::endl;
183178
std::cout << "SegmentationModel: warming up model..." << std::endl;
184179
auto start_time = absl::Now();

0 commit comments

Comments
 (0)