
Rerank graph cli #3337

Merged: 49 commits, Jun 6, 2025

Commits
716cd2b
OV rerank calculator
michalkulakowski Jun 2, 2025
a666e2d
improvements
michalkulakowski Jun 4, 2025
d21112a
fix
michalkulakowski Jun 4, 2025
f49f1ca
update script
michalkulakowski Jun 4, 2025
88d561f
fix
michalkulakowski Jun 4, 2025
9cb0975
Export and cli for rerank
rasapala Jun 4, 2025
6459300
Merge branch 'mkulakow/ov_rerank_calculator' into rerank_graph_cli
rasapala Jun 4, 2025
dc6dad7
change required library
mzegla Jun 4, 2025
fa5e0f7
Remove param
rasapala Jun 4, 2025
73fca1c
change required library
mzegla Jun 4, 2025
d29b86f
Style
rasapala Jun 4, 2025
0ed4f8f
fix
michalkulakowski Jun 4, 2025
025105f
Merge branch 'mkulakow/ov_rerank_calculator' into rerank_graph_cli
rasapala Jun 5, 2025
b8c3d08
OV rerank calculator
michalkulakowski Jun 2, 2025
2bc89d9
improvements
michalkulakowski Jun 4, 2025
b801a69
fix
michalkulakowski Jun 4, 2025
02e2503
update script
michalkulakowski Jun 4, 2025
f35efeb
fix
michalkulakowski Jun 4, 2025
980abad
fix
michalkulakowski Jun 4, 2025
6f5ea53
fix
michalkulakowski Jun 5, 2025
12a7734
style
michalkulakowski Jun 5, 2025
9a7d9a7
fix
michalkulakowski Jun 5, 2025
7d6f071
fix
michalkulakowski Jun 5, 2025
c05d129
fix
michalkulakowski Jun 5, 2025
3b9bd2e
tmp
michalkulakowski Jun 5, 2025
6953c82
tmp
michalkulakowski Jun 5, 2025
aa6fc4d
Merge branch 'mkulakow/ov_rerank_calculator' into rerank_graph_cli
rasapala Jun 5, 2025
8eb13f6
OV rerank calculator
michalkulakowski Jun 2, 2025
1e12c27
improvements
michalkulakowski Jun 4, 2025
1476c7c
fix
michalkulakowski Jun 4, 2025
529c4c9
update script
michalkulakowski Jun 4, 2025
e51997f
fix
michalkulakowski Jun 4, 2025
ef400f6
fix
michalkulakowski Jun 4, 2025
1e00dc3
fix
michalkulakowski Jun 5, 2025
9b5809a
style
michalkulakowski Jun 5, 2025
b7839b7
fix
michalkulakowski Jun 5, 2025
74def21
fix
michalkulakowski Jun 5, 2025
0f56ab0
fix
michalkulakowski Jun 5, 2025
335e506
fix
michalkulakowski Jun 5, 2025
925b965
fix
michalkulakowski Jun 5, 2025
cede219
style
michalkulakowski Jun 5, 2025
d99eea5
fix
michalkulakowski Jun 5, 2025
8f791ba
Merge branch 'mkulakow/ov_rerank_calculator' into rerank_graph_cli
rasapala Jun 5, 2025
d77d65d
Merge branch 'main' into rerank_graph_cli
rasapala Jun 6, 2025
6eacd21
Merge fix
rasapala Jun 6, 2025
beb4b51
Fix compile
rasapala Jun 6, 2025
9813469
Merge branch 'main' into rerank_graph_cli
rasapala Jun 6, 2025
4dcef99
Fix compile
rasapala Jun 6, 2025
0b1ee9a
Remove log
rasapala Jun 6, 2025

Changes from all commits
4 changes: 2 additions & 2 deletions src/capi_frontend/server_settings.hpp
@@ -64,11 +64,11 @@ struct EmbeddingsGraphSettingsImpl {
};

struct RerankGraphSettingsImpl {
std::string modelPath = "./";
std::string targetDevice = "CPU";
std::string modelName = "";
uint32_t numStreams = 1;
uint32_t maxDocLength = 16000; // FIXME: export_rerank_tokenizer python method - not supported currently?
uint32_t version = 1; // FIXME: export_rerank_tokenizer python method - not supported currently?
uint64_t maxAllowedChunks = 10000;
};

struct ImageGenerationGraphSettingsImpl {
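
The struct above drops the maxDocLength and version fields (both carrying FIXME notes about an unsupported export path) and adds maxAllowedChunks, which the rerank graph template below writes into RerankCalculatorOVOptions. A minimal, hypothetical sketch of the kind of per-request guard such a limit implies; the helper name and container type are assumptions, not code from this PR:

#include <cstdint>
#include <string>
#include <vector>

// Hypothetical helper (not part of this PR): accept a rerank request only if its
// documents were split into no more chunks than the configured max_allowed_chunks.
bool chunkCountWithinLimit(const std::vector<std::string>& chunks, uint64_t maxAllowedChunks) {
    return chunks.size() <= maxAllowedChunks;
}
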
7 changes: 0 additions & 7 deletions src/config.cpp
@@ -159,13 +159,6 @@ bool Config::validate() {
std::cerr << "dynamic_split_fuse: " << settings.dynamicSplitFuse << " is not allowed. Supported values: true, false" << std::endl;
return false;
}

if (settings.targetDevice != "NPU") {
if (settings.pluginConfig.maxPromptLength.has_value()) {
std::cerr << "max_prompt_len is only supported for NPU target device";
return false;
}
}
}

if (this->serverSettings.hfSettings.task == EMBEDDINGS_GRAPH) {
95 changes: 21 additions & 74 deletions src/graph_export/graph_export.cpp
@@ -129,84 +129,35 @@ static Status createTextGenerationGraphTemplate(const std::string& directoryPath
return FileSystem::createFileOverwrite(fullPath, oss.str());
}

static Status validateSubconfigSchema(const std::string& subconfig, const std::string& type) {
rapidjson::Document subconfigJson;
rapidjson::ParseResult parseResult = subconfigJson.Parse(subconfig.c_str());
if (parseResult.Code()) {
SPDLOG_LOGGER_ERROR(modelmanager_logger, "Created {} subconfig file is not a valid JSON file. Error: {}", type, rapidjson::GetParseError_En(parseResult.Code()));
return StatusCode::JSON_INVALID;
}
if (validateJsonAgainstSchema(subconfigJson, MEDIAPIPE_SUBCONFIG_SCHEMA.c_str()) != StatusCode::OK) {
SPDLOG_ERROR("Created {} subconfig file is not in valid configuration format", type);
return StatusCode::JSON_INVALID;
}
return StatusCode::OK;
}

static Status createRerankSubconfigTemplate(const std::string& directoryPath, const RerankGraphSettingsImpl& graphSettings) {
std::ostringstream oss;
// clang-format off
oss << R"(
{
"model_config_list": [
{ "config":
{
"name": ")" << graphSettings.modelName << R"(_tokenizer_model",
"base_path": "tokenizer"
}
},
{ "config":
{
"name": ")" << graphSettings.modelName << R"(_rerank_model",
"base_path": "rerank",
"target_device": ")" << graphSettings.targetDevice << R"(",
"plugin_config": { "NUM_STREAMS": ")" << graphSettings.numStreams << R"(" }
}
}
]
})";
auto status = validateSubconfigSchema(oss.str(), "rerank");
if (!status.ok()){
return status;
}
// clang-format on
std::string fullPath = FileSystem::joinPath({directoryPath, "subconfig.json"});
return FileSystem::createFileOverwrite(fullPath, oss.str());
}

static Status createRerankGraphTemplate(const std::string& directoryPath, const RerankGraphSettingsImpl& graphSettings) {
std::ostringstream oss;
// Windows path creation - graph parser needs forward slashes in paths
std::string graphOkPath = graphSettings.modelPath;
if (FileSystem::getOsSeparator() != "/") {
std::replace(graphOkPath.begin(), graphOkPath.end(), '\\', '/');
}
// clang-format off
oss << R"(
input_stream: "REQUEST_PAYLOAD:input"
output_stream: "RESPONSE_PAYLOAD:output"
node {
name: ")"
<< graphSettings.modelName << R"(",
calculator: "RerankCalculatorOV"
input_side_packet: "RERANK_NODE_RESOURCES:rerank_servable"
input_stream: "REQUEST_PAYLOAD:input"
output_stream: "RESPONSE_PAYLOAD:output"
node {
calculator: "OpenVINOModelServerSessionCalculator"
output_side_packet: "SESSION:tokenizer"
node_options: {
[type.googleapis.com / mediapipe.OpenVINOModelServerSessionCalculatorOptions]: {
servable_name: ")"
<< graphSettings.modelName << R"(_tokenizer_model"
}
}
}
node {
calculator: "OpenVINOModelServerSessionCalculator"
output_side_packet: "SESSION:rerank"
node_options: {
[type.googleapis.com / mediapipe.OpenVINOModelServerSessionCalculatorOptions]: {
servable_name: ")"
<< graphSettings.modelName << R"(_rerank_model"
[type.googleapis.com / mediapipe.RerankCalculatorOVOptions]: {
models_path: ")"
<< graphOkPath << R"(",
max_allowed_chunks: )"
<< graphSettings.maxAllowedChunks << R"(,
target_device: ")" << graphSettings.targetDevice << R"(",
plugin_config: '{ "NUM_STREAMS": ")" << graphSettings.numStreams << R"("}',
}
}
}
node {
input_side_packet: "TOKENIZER_SESSION:tokenizer"
input_side_packet: "RERANK_SESSION:rerank"
calculator: "RerankCalculator"
input_stream: "REQUEST_PAYLOAD:input"
output_stream: "RESPONSE_PAYLOAD:output"
})";
})";

#if (MEDIAPIPE_DISABLE == 0)
::mediapipe::CalculatorGraphConfig config;
@@ -218,11 +169,7 @@ static Status createRerankGraphTemplate(const std::string& directoryPath, const
#endif
// clang-format on
std::string fullPath = FileSystem::joinPath({directoryPath, "graph.pbtxt"});
auto status = FileSystem::createFileOverwrite(fullPath, oss.str());
if (!status.ok())
return status;

return createRerankSubconfigTemplate(directoryPath, graphSettings);
return FileSystem::createFileOverwrite(fullPath, oss.str());
}

static Status createEmbeddingsGraphTemplate(const std::string& directoryPath, const EmbeddingsGraphSettingsImpl& graphSettings) {
15 changes: 5 additions & 10 deletions src/graph_export/rerank_graph_cli_parser.cpp
@@ -44,14 +44,10 @@ void RerankGraphCLIParser::createOptions() {
"The number of parallel execution streams to use for the model. Use at least 2 on 2 socket CPU systems.",
cxxopts::value<uint32_t>()->default_value("1"),
"NUM_STREAMS")
("max_doc_length",
"Maximum length of input documents in tokens.",
cxxopts::value<uint32_t>()->default_value("16000"),
"MAX_DOC_LENGTH")
("model_version",
"Version of the model.",
cxxopts::value<uint32_t>()->default_value("1"),
"MODEL_VERSION");
("max_allowed_chunks",
"Maximum allowed chunks.",
cxxopts::value<uint64_t>()->default_value("10000"),
"MAX_ALLOWED_CHUNKS");
}

void RerankGraphCLIParser::printHelp() {
@@ -91,8 +87,7 @@ void RerankGraphCLIParser::prepare(OvmsServerMode serverMode, HFSettingsImpl& hf
}
} else {
rerankGraphSettings.numStreams = result->operator[]("num_streams").as<uint32_t>();
rerankGraphSettings.maxDocLength = result->operator[]("max_doc_length").as<uint32_t>();
rerankGraphSettings.version = result->operator[]("model_version").as<std::uint32_t>();
rerankGraphSettings.maxAllowedChunks = result->operator[]("max_allowed_chunks").as<uint64_t>();
}

hfSettings.graphSettings = std::move(rerankGraphSettings);
118 changes: 40 additions & 78 deletions src/test/graph_export_test.cpp
@@ -139,76 +139,44 @@ const std::string expectedDefaultGraphContents = R"(
}
)";

const std::string expectedRerankJsonContents = R"(
{
"model_config_list": [
{ "config":
{
"name": "myModel_tokenizer_model",
"base_path": "tokenizer"
}
},
{ "config":
{
"name": "myModel_rerank_model",
"base_path": "rerank",
"target_device": "GPU",
"plugin_config": { "NUM_STREAMS": "2" }
}
}
]
}
)";

const std::string expectedEmbeddingsJsonContents = R"(
{
"model_config_list": [
{ "config":
{
"name": "myModel_tokenizer_model",
"base_path": "tokenizer"
}
},
{ "config":
{
"name": "myModel_embeddings_model",
"base_path": "embeddings",
"target_device": "GPU",
"plugin_config": { "NUM_STREAMS": "2" }
}
}
]
}
)";

const std::string expectedRerankGraphContents = R"(
const std::string expectedRerankGraphContentsNonDefault = R"(
input_stream: "REQUEST_PAYLOAD:input"
output_stream: "RESPONSE_PAYLOAD:output"
node {
name: "myModel",
calculator: "RerankCalculatorOV"
input_side_packet: "RERANK_NODE_RESOURCES:rerank_servable"
input_stream: "REQUEST_PAYLOAD:input"
output_stream: "RESPONSE_PAYLOAD:output"
node {
calculator: "OpenVINOModelServerSessionCalculator"
output_side_packet: "SESSION:tokenizer"
node_options: {
[type.googleapis.com / mediapipe.OpenVINOModelServerSessionCalculatorOptions]: {
servable_name: "myModel_tokenizer_model"
[type.googleapis.com / mediapipe.RerankCalculatorOVOptions]: {
models_path: "/some/path",
max_allowed_chunks: 18,
target_device: "GPU",
plugin_config: '{ "NUM_STREAMS": "2"}',
}
}
}
node {
calculator: "OpenVINOModelServerSessionCalculator"
output_side_packet: "SESSION:rerank"
}
)";

const std::string expectedRerankGraphContentsDefault = R"(
input_stream: "REQUEST_PAYLOAD:input"
output_stream: "RESPONSE_PAYLOAD:output"
node {
name: "",
calculator: "RerankCalculatorOV"
input_side_packet: "RERANK_NODE_RESOURCES:rerank_servable"
input_stream: "REQUEST_PAYLOAD:input"
output_stream: "RESPONSE_PAYLOAD:output"
node_options: {
[type.googleapis.com / mediapipe.OpenVINOModelServerSessionCalculatorOptions]: {
servable_name: "myModel_rerank_model"
[type.googleapis.com / mediapipe.RerankCalculatorOVOptions]: {
models_path: "./",
max_allowed_chunks: 10000,
target_device: "CPU",
plugin_config: '{ "NUM_STREAMS": "1"}',
}
}
}
node {
input_side_packet: "TOKENIZER_SESSION:tokenizer"
input_side_packet: "RERANK_SESSION:rerank"
calculator: "RerankCalculator"
input_stream: "REQUEST_PAYLOAD:input"
output_stream: "RESPONSE_PAYLOAD:output"
}
}
)";

const std::string expectedEmbeddingsGraphContents = R"(
@@ -317,15 +285,15 @@ TEST_F(GraphCreationTest, positiveDefault) {
ASSERT_EQ(expectedDefaultGraphContents, graphContents) << graphContents;
}

TEST_F(GraphCreationTest, rerankPositiveDefault) {
TEST_F(GraphCreationTest, rerankPositiveNonDefault) {
ovms::HFSettingsImpl hfSettings;
hfSettings.task = ovms::RERANK_GRAPH;
ovms::RerankGraphSettingsImpl rerankGraphSettings;
rerankGraphSettings.targetDevice = "GPU";
rerankGraphSettings.modelName = "myModel";
rerankGraphSettings.modelPath = "/some/path";
rerankGraphSettings.numStreams = 2;
rerankGraphSettings.maxDocLength = 18;
rerankGraphSettings.version = 2;
rerankGraphSettings.maxAllowedChunks = 18;
hfSettings.graphSettings = std::move(rerankGraphSettings);

std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt";
@@ -335,27 +303,23 @@
ASSERT_EQ(status, ovms::StatusCode::OK);

std::string graphContents = GetFileContents(graphPath);
ASSERT_EQ(expectedRerankGraphContents, graphContents) << graphContents;

std::string jsonContents = GetFileContents(subconfigPath);
ASSERT_EQ(expectedRerankJsonContents, jsonContents) << jsonContents;
ASSERT_EQ(expectedRerankGraphContentsNonDefault, graphContents) << graphContents;
}

TEST_F(GraphCreationTest, rerankCreatedJsonInvalid) {
TEST_F(GraphCreationTest, rerankPositiveDefault) {
ovms::HFSettingsImpl hfSettings;
hfSettings.task = ovms::RERANK_GRAPH;
ovms::RerankGraphSettingsImpl rerankGraphSettings;
rerankGraphSettings.targetDevice = "GPU";
rerankGraphSettings.modelName = "myModel\t";
rerankGraphSettings.numStreams = 2;
rerankGraphSettings.maxDocLength = 18;
rerankGraphSettings.version = 2;
hfSettings.graphSettings = std::move(rerankGraphSettings);

std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt";
std::string subconfigPath = ovms::FileSystem::appendSlash(this->directoryPath) + "subconfig.json";
std::unique_ptr<ovms::GraphExport> graphExporter = std::make_unique<ovms::GraphExport>();
auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
ASSERT_EQ(status, ovms::StatusCode::JSON_INVALID);
ASSERT_EQ(status, ovms::StatusCode::OK);

std::string graphContents = GetFileContents(graphPath);
ASSERT_EQ(expectedRerankGraphContentsDefault, graphContents) << graphContents;
}

TEST_F(GraphCreationTest, rerankCreatedPbtxtInvalid) {
Expand All @@ -365,8 +329,6 @@ TEST_F(GraphCreationTest, rerankCreatedPbtxtInvalid) {
rerankGraphSettings.targetDevice = "GPU";
rerankGraphSettings.modelName = "myModel\"";
rerankGraphSettings.numStreams = 2;
rerankGraphSettings.maxDocLength = 18;
rerankGraphSettings.version = 2;
hfSettings.graphSettings = std::move(rerankGraphSettings);
std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt";
std::string subconfigPath = ovms::FileSystem::appendSlash(this->directoryPath) + "subconfig.json";
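
Taken together, these changes replace the old two-servable rerank setup (a generated subconfig.json declaring <model>_tokenizer_model and <model>_rerank_model, wired through OpenVINOModelServerSessionCalculator nodes into RerankCalculator) with a single RerankCalculatorOV node configured directly in graph.pbtxt, and the CLI now exposes max_allowed_chunks in place of max_doc_length and model_version. A rough usage sketch, mirroring the updated tests rather than quoting code from this PR; the include paths and the main() wiring are assumptions:

#include <memory>
#include <utility>

#include "src/capi_frontend/server_settings.hpp"  // RerankGraphSettingsImpl, HFSettingsImpl (assumed path)
#include "src/graph_export/graph_export.hpp"      // GraphExport (assumed path)

int main() {
    ovms::HFSettingsImpl hfSettings;
    hfSettings.task = ovms::RERANK_GRAPH;

    ovms::RerankGraphSettingsImpl rerankGraphSettings;
    rerankGraphSettings.modelName = "myModel";
    rerankGraphSettings.modelPath = "/some/path";
    rerankGraphSettings.targetDevice = "GPU";
    rerankGraphSettings.numStreams = 2;         // --num_streams
    rerankGraphSettings.maxAllowedChunks = 18;  // --max_allowed_chunks (new in this PR)
    hfSettings.graphSettings = std::move(rerankGraphSettings);

    // Writes graph.pbtxt with one RerankCalculatorOV node (models_path, max_allowed_chunks,
    // target_device, plugin_config inside RerankCalculatorOVOptions); no subconfig.json is produced.
    auto graphExporter = std::make_unique<ovms::GraphExport>();
    auto status = graphExporter->createServableConfig("/some/export/dir", hfSettings);
    return status.ok() ? 0 : 1;
}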