Skip to content

Commit 207b13a

Browse files
committed
Merge branch 'main' into tie
2 parents 38b2c1b + 8eed730 commit 207b13a

File tree

13 files changed

+348
-348
lines changed

13 files changed

+348
-348
lines changed

.pipelines/android-publishing.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@ extends:
2828
# For productions pipelines, use "Official".
2929
template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines
3030
parameters:
31+
sdl:
32+
policheck:
33+
enabled: true
34+
break: true # always break the build on policheck issues. You can disable it by setting to 'false'
35+
exclusionsFile: '$(Build.SourcesDirectory)\.pipelines\policheck_exclusions.xml'
36+
tsa:
37+
enabled: true
38+
configFile: '$(Build.SourcesDirectory)\.config\tsaoptions.json'
3139
# Update the pool with your team's 1ES hosted pool.
3240
pool:
3341
name: 'onnxruntime-Win-CPU-2022' # Name of your hosted pool

.pipelines/macos-ios-cocoapods-publishing.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@ extends:
2828
# For productions pipelines, use "Official".
2929
template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines
3030
parameters:
31+
sdl:
32+
policheck:
33+
enabled: true
34+
break: true # always break the build on policheck issues. You can disable it by setting to 'false'
35+
exclusionsFile: '$(Build.SourcesDirectory)\.pipelines\policheck_exclusions.xml'
36+
tsa:
37+
enabled: true
38+
configFile: '$(Build.SourcesDirectory)\.config\tsaoptions.json'
3139
# Update the pool with your team's 1ES hosted pool.
3240
pool:
3341
name: 'onnxruntime-Win-CPU-2022' # Name of your hosted pool

.pipelines/nuget-publishing.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,13 @@ extends:
125125
sourceRepositoriesToScan:
126126
include:
127127
- repository: manylinux
128+
policheck:
129+
enabled: true
130+
break: true # always break the build on policheck issues. You can disable it by setting to 'false'
131+
exclusionsFile: '$(Build.SourcesDirectory)\.pipelines\policheck_exclusions.xml'
132+
tsa:
133+
enabled: true
134+
configFile: '$(Build.SourcesDirectory)\.config\tsaoptions.json'
128135
stages:
129136
- template: stages/capi-packaging-stage.yml
130137
parameters:

.pipelines/policheck_exclusions.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
<?xml version="1.0" encoding="utf-8" ?>
2+
<PoliCheckExclusions>
3+
<Exclusion Type="FileName">THIRDPARTYNOTICES.TXT|BUILDER.PY|PROMPTS.JSON</Exclusion>
4+
</PoliCheckExclusions>

.pipelines/pypl-publishing.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@ extends:
8585
sourceRepositoriesToScan:
8686
include:
8787
- repository: manylinux
88+
policheck:
89+
enabled: true
90+
break: true # always break the build on policheck issues. You can disable it by setting to 'false'
91+
exclusionsFile: '$(Build.SourcesDirectory)\.pipelines\policheck_exclusions.xml'
92+
tsa:
93+
enabled: true
94+
configFile: '$(Build.SourcesDirectory)\.config\tsaoptions.json'
8895
stages:
8996
- template: stages/py-packaging-stage.yml
9097
parameters:

cmake/deps.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@ pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.13.6.zip;f78029
1414
googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034
1515
microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
1616
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
17-
onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;fc004859e82241e99d458a90d2a39d400050cc59
17+
onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;cb00b43f05409d6f70cc558f52fcff0c7e386a97

src/config.cpp

Lines changed: 12 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ struct ProviderOptionsObject_Element : JSON::Element {
5454
};
5555

5656
struct ProviderOptionsArray_Element : JSON::Element {
57-
explicit ProviderOptionsArray_Element(std::vector<Config::ProviderOptions>& v) : v_{v} {}
57+
explicit ProviderOptionsArray_Element(std::vector<Config::ProviderOptions>& v, std::vector<std::string>& providers)
58+
: v_{v}, providers_{providers} {}
5859

5960
JSON::Element& OnObject(std::string_view name) override { return object_; }
6061

@@ -68,11 +69,18 @@ struct ProviderOptionsArray_Element : JSON::Element {
6869
} else if (v.name == "dml") {
6970
v.name = "DML";
7071
}
72+
73+
if (std::find(providers_.begin(), providers_.end(), v.name) == providers_.end()) {
74+
// The providers array determines the which execution provider is picked for the session..
75+
// It also determines the order of the providers.
76+
providers_.push_back(v.name);
77+
}
7178
}
7279
}
7380

7481
private:
7582
std::vector<Config::ProviderOptions>& v_;
83+
std::vector<std::string>& providers_;
7684
ProviderOptionsObject_Element object_{v_};
7785
};
7886

@@ -143,7 +151,7 @@ struct SessionOptions_Element : JSON::Element {
143151

144152
private:
145153
Config::SessionOptions& v_;
146-
ProviderOptionsArray_Element provider_options_{v_.provider_options};
154+
ProviderOptionsArray_Element provider_options_{v_.provider_options, v_.providers};
147155
NamedStrings_Element config_entries_{v_.config_entries};
148156
};
149157

@@ -589,38 +597,6 @@ struct Embedding_Element : JSON::Element {
589597
EmbeddingOutputs_Element outputs_{v_.outputs};
590598
};
591599

592-
struct PromptTemplates_Element : JSON::Element {
593-
explicit PromptTemplates_Element(std::optional<Config::Model::PromptTemplates>& v) : v_{v} {}
594-
595-
void OnValue(std::string_view name, JSON::Value value) override {
596-
// if one of templates is given in json, then any non-specified template will be default "{Content}"
597-
if (name == "assistant") {
598-
EnsureAvailable();
599-
v_->assistant = JSON::Get<std::string_view>(value);
600-
} else if (name == "prompt") {
601-
EnsureAvailable();
602-
v_->prompt = JSON::Get<std::string_view>(value);
603-
} else if (name == "system") {
604-
EnsureAvailable();
605-
v_->system = JSON::Get<std::string_view>(value);
606-
} else if (name == "user") {
607-
EnsureAvailable();
608-
v_->user = JSON::Get<std::string_view>(value);
609-
} else {
610-
throw JSON::unknown_value_error{};
611-
}
612-
}
613-
614-
private:
615-
std::optional<Config::Model::PromptTemplates>& v_;
616-
617-
void EnsureAvailable() {
618-
if (!v_.has_value()) {
619-
v_.emplace();
620-
}
621-
}
622-
};
623-
624600
struct Model_Element : JSON::Element {
625601
explicit Model_Element(Config::Model& v) : v_{v} {}
626602

@@ -664,9 +640,6 @@ struct Model_Element : JSON::Element {
664640
if (name == "embedding") {
665641
return embedding_;
666642
}
667-
if (name == "prompt_templates") {
668-
return prompt_templates_;
669-
}
670643
if (name == "speech") {
671644
return speech_;
672645
}
@@ -680,7 +653,6 @@ struct Model_Element : JSON::Element {
680653
Eos_Array_Element eos_token_ids_{v_};
681654
Vision_Element vision_{v_.vision};
682655
Embedding_Element embedding_{v_.embedding};
683-
PromptTemplates_Element prompt_templates_{v_.prompt_templates};
684656
Speech_Element speech_{v_.speech};
685657
};
686658

@@ -747,7 +719,7 @@ void SetSearchBool(Config::Search& search, std::string_view name, bool value) {
747719
}
748720

749721
void ClearProviders(Config& config) {
750-
config.model.decoder.session_options.provider_options.clear();
722+
config.model.decoder.session_options.providers.clear();
751723
}
752724

753725
void SetProviderOption(Config& config, std::string_view provider_name, std::string_view option_name, std::string_view option_value) {
@@ -757,7 +729,7 @@ void SetProviderOption(Config& config, std::string_view provider_name, std::stri
757729
json << R"(")" << option_name << R"(":")" << option_value << R"(")";
758730
}
759731
json << R"(}})";
760-
ProviderOptionsArray_Element element{config.model.decoder.session_options.provider_options};
732+
ProviderOptionsArray_Element element{config.model.decoder.session_options.provider_options, config.model.decoder.session_options.providers};
761733
JSON::Parse(element, json.str());
762734
}
763735

src/config.h

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ struct Config {
2525
static constexpr std::string_view InputsEmbedsName = "inputs_embeds";
2626
static constexpr std::string_view CurrentSequenceLengthName = "current_sequence_length";
2727
static constexpr std::string_view PastSequenceLengthName = "past_sequence_length";
28-
static constexpr std::string_view promptTemplate = "{Content}";
2928
static constexpr std::string_view TotalSequenceLengthName = "total_sequence_length";
3029
static constexpr std::string_view TokenTypeIdsName = "token_type_ids";
3130

@@ -75,6 +74,7 @@ struct Config {
7574
std::vector<NamedString> config_entries; // Entries go into OrtSessionOptions::AddConfigEntry
7675

7776
std::vector<ProviderOptions> provider_options;
77+
std::vector<std::string> providers; // List of providers to use at runtime, not persisted in the json currently
7878
std::optional<GraphOptimizationLevel> graph_optimization_level;
7979
};
8080

@@ -206,13 +206,6 @@ struct Config {
206206

207207
} decoder;
208208

209-
struct PromptTemplates {
210-
std::string assistant{Defaults::promptTemplate};
211-
std::string prompt{Defaults::promptTemplate};
212-
std::string system{Defaults::promptTemplate};
213-
std::string user{Defaults::promptTemplate};
214-
};
215-
std::optional<PromptTemplates> prompt_templates;
216209
} model;
217210

218211
struct Search {

src/models/model.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -274,12 +274,21 @@ int32_t Tokenizer::TokenToTokenId(const char* token) const {
274274
}
275275

276276
DeviceInterface* SetProviderSessionOptions(OrtSessionOptions& session_options,
277+
const std::vector<std::string>& providers,
277278
const std::vector<Config::ProviderOptions>& provider_options_list,
278279
bool is_primary_session_options,
279280
bool disable_graph_capture) {
280281
DeviceInterface* p_device{};
281282

282-
for (auto& provider_options : provider_options_list) {
283+
for (auto& provider : providers) {
284+
auto provider_options_it = std::find_if(provider_options_list.begin(), provider_options_list.end(),
285+
[&provider](const Config::ProviderOptions& po) { return po.name == provider; });
286+
287+
if (provider_options_it == provider_options_list.end()) {
288+
throw std::runtime_error("Provider options not found for provider: " + provider);
289+
}
290+
const auto& provider_options = *provider_options_it;
291+
283292
if (provider_options.name == "cuda") {
284293
auto ort_provider_options = OrtCUDAProviderOptionsV2::Create();
285294
std::vector<const char*> keys, values;
@@ -299,7 +308,6 @@ DeviceInterface* SetProviderSessionOptions(OrtSessionOptions& session_options,
299308
}
300309

301310
session_options.AppendExecutionProvider_CUDA_V2(*ort_provider_options);
302-
303311
} else if (provider_options.name == "rocm") {
304312
OrtROCMProviderOptions ort_provider_options;
305313

@@ -416,7 +424,8 @@ void EnsureDeviceOrtInit(DeviceInterface& device) {
416424
auto session_options = OrtSessionOptions::Create();
417425
std::vector<Config::ProviderOptions> provider_options_list;
418426
provider_options_list.emplace_back(Config::ProviderOptions{device_type_names[static_cast<int>(type)], {}});
419-
SetProviderSessionOptions(*session_options, provider_options_list, true, false);
427+
const std::vector<std::string> providers{device_type_names[static_cast<int>(type)]};
428+
SetProviderSessionOptions(*session_options, providers, provider_options_list, true, false);
420429
session_options->SetLogSeverityLevel(ORT_LOGGING_LEVEL_ERROR); // Errors only here, as warnings are not useful to the user
421430

422431
allocator.session_ = OrtSession::Create(GetOrtEnv(), g_trivial_model, sizeof(g_trivial_model), session_options.get());
@@ -613,7 +622,9 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
613622
session_options.SetGraphOptimizationLevel(config_session_options.graph_optimization_level.value());
614623
}
615624

616-
p_device_ = SetProviderSessionOptions(session_options, config_session_options.provider_options, is_primary_session_options, disable_graph_capture);
625+
p_device_ = SetProviderSessionOptions(session_options, config_session_options.providers,
626+
config_session_options.provider_options, is_primary_session_options,
627+
disable_graph_capture);
617628

618629
// Fallback to CPU if no provider specific interface was set
619630
if (!p_device_)

0 commit comments

Comments
 (0)