Skip to content

Commit 94a99c6

Browse files
authored
Rerank graph cli (#3337)
* OV rerank calculator
1 parent 09cc27a commit 94a99c6

File tree

6 files changed: 83 additions and 208 deletions.

6 files changed: 83 additions and 208 deletions.

src/capi_frontend/server_settings.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -64,11 +64,11 @@ struct EmbeddingsGraphSettingsImpl {
6464
};
6565

6666
struct RerankGraphSettingsImpl {
67+
std::string modelPath = "./";
6768
std::string targetDevice = "CPU";
6869
std::string modelName = "";
6970
uint32_t numStreams = 1;
70-
uint32_t maxDocLength = 16000; // FIXME: export_rerank_tokenizer python method - not supported currently?
71-
uint32_t version = 1; // FIXME: export_rerank_tokenizer python method - not supported currently?
71+
uint64_t maxAllowedChunks = 10000;
7272
};
7373

7474
struct ImageGenerationGraphSettingsImpl {

src/config.cpp

Lines changed: 0 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -159,13 +159,6 @@ bool Config::validate() {
159159
std::cerr << "dynamic_split_fuse: " << settings.dynamicSplitFuse << " is not allowed. Supported values: true, false" << std::endl;
160160
return false;
161161
}
162-
163-
if (settings.targetDevice != "NPU") {
164-
if (settings.pluginConfig.maxPromptLength.has_value()) {
165-
std::cerr << "max_prompt_len is only supported for NPU target device";
166-
return false;
167-
}
168-
}
169162
}
170163

171164
if (this->serverSettings.hfSettings.task == EMBEDDINGS_GRAPH) {

src/graph_export/graph_export.cpp

Lines changed: 21 additions & 74 deletions
Original file line number | Diff line number | Diff line change
@@ -129,84 +129,35 @@ static Status createTextGenerationGraphTemplate(const std::string& directoryPath
129129
return FileSystem::createFileOverwrite(fullPath, oss.str());
130130
}
131131

132-
static Status validateSubconfigSchema(const std::string& subconfig, const std::string& type) {
133-
rapidjson::Document subconfigJson;
134-
rapidjson::ParseResult parseResult = subconfigJson.Parse(subconfig.c_str());
135-
if (parseResult.Code()) {
136-
SPDLOG_LOGGER_ERROR(modelmanager_logger, "Created {} subconfig file is not a valid JSON file. Error: {}", type, rapidjson::GetParseError_En(parseResult.Code()));
137-
return StatusCode::JSON_INVALID;
138-
}
139-
if (validateJsonAgainstSchema(subconfigJson, MEDIAPIPE_SUBCONFIG_SCHEMA.c_str()) != StatusCode::OK) {
140-
SPDLOG_ERROR("Created {} subconfig file is not in valid configuration format", type);
141-
return StatusCode::JSON_INVALID;
142-
}
143-
return StatusCode::OK;
144-
}
145-
146-
static Status createRerankSubconfigTemplate(const std::string& directoryPath, const RerankGraphSettingsImpl& graphSettings) {
147-
std::ostringstream oss;
148-
// clang-format off
149-
oss << R"(
150-
{
151-
"model_config_list": [
152-
{ "config":
153-
{
154-
"name": ")" << graphSettings.modelName << R"(_tokenizer_model",
155-
"base_path": "tokenizer"
156-
}
157-
},
158-
{ "config":
159-
{
160-
"name": ")" << graphSettings.modelName << R"(_rerank_model",
161-
"base_path": "rerank",
162-
"target_device": ")" << graphSettings.targetDevice << R"(",
163-
"plugin_config": { "NUM_STREAMS": ")" << graphSettings.numStreams << R"(" }
164-
}
165-
}
166-
]
167-
})";
168-
auto status = validateSubconfigSchema(oss.str(), "rerank");
169-
if (!status.ok()){
170-
return status;
171-
}
172-
// clang-format on
173-
std::string fullPath = FileSystem::joinPath({directoryPath, "subconfig.json"});
174-
return FileSystem::createFileOverwrite(fullPath, oss.str());
175-
}
176-
177132
static Status createRerankGraphTemplate(const std::string& directoryPath, const RerankGraphSettingsImpl& graphSettings) {
178133
std::ostringstream oss;
134+
// Windows path creation - graph parser needs forward slashes in paths
135+
std::string graphOkPath = graphSettings.modelPath;
136+
if (FileSystem::getOsSeparator() != "/") {
137+
std::replace(graphOkPath.begin(), graphOkPath.end(), '\\', '/');
138+
}
179139
// clang-format off
180140
oss << R"(
141+
input_stream: "REQUEST_PAYLOAD:input"
142+
output_stream: "RESPONSE_PAYLOAD:output"
143+
node {
144+
name: ")"
145+
<< graphSettings.modelName << R"(",
146+
calculator: "RerankCalculatorOV"
147+
input_side_packet: "RERANK_NODE_RESOURCES:rerank_servable"
181148
input_stream: "REQUEST_PAYLOAD:input"
182149
output_stream: "RESPONSE_PAYLOAD:output"
183-
node {
184-
calculator: "OpenVINOModelServerSessionCalculator"
185-
output_side_packet: "SESSION:tokenizer"
186150
node_options: {
187-
[type.googleapis.com / mediapipe.OpenVINOModelServerSessionCalculatorOptions]: {
188-
servable_name: ")"
189-
<< graphSettings.modelName << R"(_tokenizer_model"
190-
}
191-
}
192-
}
193-
node {
194-
calculator: "OpenVINOModelServerSessionCalculator"
195-
output_side_packet: "SESSION:rerank"
196-
node_options: {
197-
[type.googleapis.com / mediapipe.OpenVINOModelServerSessionCalculatorOptions]: {
198-
servable_name: ")"
199-
<< graphSettings.modelName << R"(_rerank_model"
151+
[type.googleapis.com / mediapipe.RerankCalculatorOVOptions]: {
152+
models_path: ")"
153+
<< graphOkPath << R"(",
154+
max_allowed_chunks: )"
155+
<< graphSettings.maxAllowedChunks << R"(,
156+
target_device: ")" << graphSettings.targetDevice << R"(",
157+
plugin_config: '{ "NUM_STREAMS": ")" << graphSettings.numStreams << R"("}',
200158
}
201159
}
202-
}
203-
node {
204-
input_side_packet: "TOKENIZER_SESSION:tokenizer"
205-
input_side_packet: "RERANK_SESSION:rerank"
206-
calculator: "RerankCalculator"
207-
input_stream: "REQUEST_PAYLOAD:input"
208-
output_stream: "RESPONSE_PAYLOAD:output"
209-
})";
160+
})";
210161

211162
#if (MEDIAPIPE_DISABLE == 0)
212163
::mediapipe::CalculatorGraphConfig config;
@@ -218,11 +169,7 @@ static Status createRerankGraphTemplate(const std::string& directoryPath, const
218169
#endif
219170
// clang-format on
220171
std::string fullPath = FileSystem::joinPath({directoryPath, "graph.pbtxt"});
221-
auto status = FileSystem::createFileOverwrite(fullPath, oss.str());
222-
if (!status.ok())
223-
return status;
224-
225-
return createRerankSubconfigTemplate(directoryPath, graphSettings);
172+
return FileSystem::createFileOverwrite(fullPath, oss.str());
226173
}
227174

228175
static Status createEmbeddingsGraphTemplate(const std::string& directoryPath, const EmbeddingsGraphSettingsImpl& graphSettings) {

src/graph_export/rerank_graph_cli_parser.cpp

Lines changed: 5 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -44,14 +44,10 @@ void RerankGraphCLIParser::createOptions() {
4444
"The number of parallel execution streams to use for the model. Use at least 2 on 2 socket CPU systems.",
4545
cxxopts::value<uint32_t>()->default_value("1"),
4646
"NUM_STREAMS")
47-
("max_doc_length",
48-
"Maximum length of input documents in tokens.",
49-
cxxopts::value<uint32_t>()->default_value("16000"),
50-
"MAX_DOC_LENGTH")
51-
("model_version",
52-
"Version of the model.",
53-
cxxopts::value<uint32_t>()->default_value("1"),
54-
"MODEL_VERSION");
47+
("max_allowed_chunks",
48+
"Maximum allowed chunks.",
49+
cxxopts::value<uint64_t>()->default_value("10000"),
50+
"MAX_ALLOWED_CHUNKS");
5551
}
5652

5753
void RerankGraphCLIParser::printHelp() {
@@ -91,8 +87,7 @@ void RerankGraphCLIParser::prepare(OvmsServerMode serverMode, HFSettingsImpl& hf
9187
}
9288
} else {
9389
rerankGraphSettings.numStreams = result->operator[]("num_streams").as<uint32_t>();
94-
rerankGraphSettings.maxDocLength = result->operator[]("max_doc_length").as<uint32_t>();
95-
rerankGraphSettings.version = result->operator[]("model_version").as<std::uint32_t>();
90+
rerankGraphSettings.maxAllowedChunks = result->operator[]("max_allowed_chunks").as<uint64_t>();
9691
}
9792

9893
hfSettings.graphSettings = std::move(rerankGraphSettings);

src/test/graph_export_test.cpp

Lines changed: 40 additions & 78 deletions
Original file line number | Diff line number | Diff line change
@@ -139,76 +139,44 @@ const std::string expectedDefaultGraphContents = R"(
139139
}
140140
)";
141141

142-
const std::string expectedRerankJsonContents = R"(
143-
{
144-
"model_config_list": [
145-
{ "config":
146-
{
147-
"name": "myModel_tokenizer_model",
148-
"base_path": "tokenizer"
149-
}
150-
},
151-
{ "config":
152-
{
153-
"name": "myModel_rerank_model",
154-
"base_path": "rerank",
155-
"target_device": "GPU",
156-
"plugin_config": { "NUM_STREAMS": "2" }
157-
}
158-
}
159-
]
160-
}
161-
)";
162-
163-
const std::string expectedEmbeddingsJsonContents = R"(
164-
{
165-
"model_config_list": [
166-
{ "config":
167-
{
168-
"name": "myModel_tokenizer_model",
169-
"base_path": "tokenizer"
170-
}
171-
},
172-
{ "config":
173-
{
174-
"name": "myModel_embeddings_model",
175-
"base_path": "embeddings",
176-
"target_device": "GPU",
177-
"plugin_config": { "NUM_STREAMS": "2" }
178-
}
179-
}
180-
]
181-
}
182-
)";
183-
184-
const std::string expectedRerankGraphContents = R"(
142+
const std::string expectedRerankGraphContentsNonDefault = R"(
143+
input_stream: "REQUEST_PAYLOAD:input"
144+
output_stream: "RESPONSE_PAYLOAD:output"
145+
node {
146+
name: "myModel",
147+
calculator: "RerankCalculatorOV"
148+
input_side_packet: "RERANK_NODE_RESOURCES:rerank_servable"
185149
input_stream: "REQUEST_PAYLOAD:input"
186150
output_stream: "RESPONSE_PAYLOAD:output"
187-
node {
188-
calculator: "OpenVINOModelServerSessionCalculator"
189-
output_side_packet: "SESSION:tokenizer"
190151
node_options: {
191-
[type.googleapis.com / mediapipe.OpenVINOModelServerSessionCalculatorOptions]: {
192-
servable_name: "myModel_tokenizer_model"
152+
[type.googleapis.com / mediapipe.RerankCalculatorOVOptions]: {
153+
models_path: "/some/path",
154+
max_allowed_chunks: 18,
155+
target_device: "GPU",
156+
plugin_config: '{ "NUM_STREAMS": "2"}',
193157
}
194158
}
195-
}
196-
node {
197-
calculator: "OpenVINOModelServerSessionCalculator"
198-
output_side_packet: "SESSION:rerank"
159+
}
160+
)";
161+
162+
const std::string expectedRerankGraphContentsDefault = R"(
163+
input_stream: "REQUEST_PAYLOAD:input"
164+
output_stream: "RESPONSE_PAYLOAD:output"
165+
node {
166+
name: "",
167+
calculator: "RerankCalculatorOV"
168+
input_side_packet: "RERANK_NODE_RESOURCES:rerank_servable"
169+
input_stream: "REQUEST_PAYLOAD:input"
170+
output_stream: "RESPONSE_PAYLOAD:output"
199171
node_options: {
200-
[type.googleapis.com / mediapipe.OpenVINOModelServerSessionCalculatorOptions]: {
201-
servable_name: "myModel_rerank_model"
172+
[type.googleapis.com / mediapipe.RerankCalculatorOVOptions]: {
173+
models_path: "./",
174+
max_allowed_chunks: 10000,
175+
target_device: "CPU",
176+
plugin_config: '{ "NUM_STREAMS": "1"}',
202177
}
203178
}
204-
}
205-
node {
206-
input_side_packet: "TOKENIZER_SESSION:tokenizer"
207-
input_side_packet: "RERANK_SESSION:rerank"
208-
calculator: "RerankCalculator"
209-
input_stream: "REQUEST_PAYLOAD:input"
210-
output_stream: "RESPONSE_PAYLOAD:output"
211-
}
179+
}
212180
)";
213181

214182
const std::string expectedEmbeddingsGraphContents = R"(
@@ -317,15 +285,15 @@ TEST_F(GraphCreationTest, positiveDefault) {
317285
ASSERT_EQ(expectedDefaultGraphContents, graphContents) << graphContents;
318286
}
319287

320-
TEST_F(GraphCreationTest, rerankPositiveDefault) {
288+
TEST_F(GraphCreationTest, rerankPositiveNonDefault) {
321289
ovms::HFSettingsImpl hfSettings;
322290
hfSettings.task = ovms::RERANK_GRAPH;
323291
ovms::RerankGraphSettingsImpl rerankGraphSettings;
324292
rerankGraphSettings.targetDevice = "GPU";
325293
rerankGraphSettings.modelName = "myModel";
294+
rerankGraphSettings.modelPath = "/some/path";
326295
rerankGraphSettings.numStreams = 2;
327-
rerankGraphSettings.maxDocLength = 18;
328-
rerankGraphSettings.version = 2;
296+
rerankGraphSettings.maxAllowedChunks = 18;
329297
hfSettings.graphSettings = std::move(rerankGraphSettings);
330298

331299
std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt";
@@ -335,27 +303,23 @@ TEST_F(GraphCreationTest, rerankPositiveDefault) {
335303
ASSERT_EQ(status, ovms::StatusCode::OK);
336304

337305
std::string graphContents = GetFileContents(graphPath);
338-
ASSERT_EQ(expectedRerankGraphContents, graphContents) << graphContents;
339-
340-
std::string jsonContents = GetFileContents(subconfigPath);
341-
ASSERT_EQ(expectedRerankJsonContents, jsonContents) << jsonContents;
306+
ASSERT_EQ(expectedRerankGraphContentsNonDefault, graphContents) << graphContents;
342307
}
343308

344-
TEST_F(GraphCreationTest, rerankCreatedJsonInvalid) {
309+
TEST_F(GraphCreationTest, rerankPositiveDefault) {
345310
ovms::HFSettingsImpl hfSettings;
346311
hfSettings.task = ovms::RERANK_GRAPH;
347312
ovms::RerankGraphSettingsImpl rerankGraphSettings;
348-
rerankGraphSettings.targetDevice = "GPU";
349-
rerankGraphSettings.modelName = "myModel\t";
350-
rerankGraphSettings.numStreams = 2;
351-
rerankGraphSettings.maxDocLength = 18;
352-
rerankGraphSettings.version = 2;
353313
hfSettings.graphSettings = std::move(rerankGraphSettings);
314+
354315
std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt";
355316
std::string subconfigPath = ovms::FileSystem::appendSlash(this->directoryPath) + "subconfig.json";
356317
std::unique_ptr<ovms::GraphExport> graphExporter = std::make_unique<ovms::GraphExport>();
357318
auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
358-
ASSERT_EQ(status, ovms::StatusCode::JSON_INVALID);
319+
ASSERT_EQ(status, ovms::StatusCode::OK);
320+
321+
std::string graphContents = GetFileContents(graphPath);
322+
ASSERT_EQ(expectedRerankGraphContentsDefault, graphContents) << graphContents;
359323
}
360324

361325
TEST_F(GraphCreationTest, rerankCreatedPbtxtInvalid) {
@@ -365,8 +329,6 @@ TEST_F(GraphCreationTest, rerankCreatedPbtxtInvalid) {
365329
rerankGraphSettings.targetDevice = "GPU";
366330
rerankGraphSettings.modelName = "myModel\"";
367331
rerankGraphSettings.numStreams = 2;
368-
rerankGraphSettings.maxDocLength = 18;
369-
rerankGraphSettings.version = 2;
370332
hfSettings.graphSettings = std::move(rerankGraphSettings);
371333
std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt";
372334
std::string subconfigPath = ovms::FileSystem::appendSlash(this->directoryPath) + "subconfig.json";

0 commit comments

Comments
 (0)