cli : auto activate conversation mode if chat template is available #11214

Merged · 6 commits · Jan 13, 2025
Changes from 3 commits
20 changes: 12 additions & 8 deletions common/arg.cpp
@@ -777,15 +777,19 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"-cnv", "--conversation"},
-        string_format(
-            "run in conversation mode:\n"
-            "- does not print special tokens and suffix/prefix\n"
-            "- interactive mode is also enabled\n"
-            "(default: %s)",
-            params.conversation ? "true" : "false"
-        ),
+        "run in conversation mode:\n"
+        "- does not print special tokens and suffix/prefix\n"
+        "- interactive mode is also enabled\n"
+        "(default: auto enabled if chat template is available)",
+        [](common_params & params) {
+            params.conversation_mode = COMMON_CONVERSATION_MODE_ENABLED;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    add_opt(common_arg(
+        {"-no-cnv", "--no-conversation"},
+        "force disable conversation mode (default: false)",
         [](common_params & params) {
-            params.conversation = true;
+            params.conversation_mode = COMMON_CONVERSATION_MODE_DISABLED;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN}));
     add_opt(common_arg(
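Taken together, the two options map onto the new tri-state mode: an explicit -cnv sets ENABLED, an explicit -no-cnv sets DISABLED, and the AUTO default defers the decision until the model is loaded. A standalone sketch of that precedence (not part of the diff; the resolve helper is hypothetical):

    #include "common.h" // for common_conversation_mode, as declared in this PR

    // Hypothetical helper restating the precedence: explicit flags are final,
    // and only the AUTO default is re-resolved once template availability is known.
    static common_conversation_mode resolve_conversation_mode(common_conversation_mode requested,
                                                              bool has_chat_template) {
        if (requested == COMMON_CONVERSATION_MODE_AUTO && has_chat_template) {
            return COMMON_CONVERSATION_MODE_ENABLED; // the auto-enable path added below
        }
        return requested; // -cnv / -no-cnv (or AUTO with no template) pass through unchanged
    }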
9 changes: 8 additions & 1 deletion common/common.h
@@ -103,6 +103,12 @@ enum dimre_method {
     DIMRE_METHOD_MEAN,
 };
 
+enum common_conversation_mode {
+    COMMON_CONVERSATION_MODE_DISABLED = 0,
+    COMMON_CONVERSATION_MODE_ENABLED = 1,
+    COMMON_CONVERSATION_MODE_AUTO = 2,
+};
+
 // sampling parameters
 struct common_params_sampling {
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
@@ -275,7 +281,6 @@ struct common_params {
     bool special = false; // enable special token output
     bool interactive = false; // interactive mode
     bool interactive_first = false; // wait for user input immediately
-    bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix)
     bool prompt_cache_all = false; // save user input and generations to prompt cache
     bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it
 
@@ -301,6 +306,8 @@ struct common_params {
     ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
     ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
 
+    common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
+
     // multimodal models (see examples/llava)
     std::string mmproj = ""; // path to multimodal projector // NOLINT
     std::vector<std::string> image; // path to image file(s)
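One detail worth noting: COMMON_CONVERSATION_MODE_DISABLED is 0, so the enum stays false-y exactly where the old bool conversation was false, and the many if (params.conversation_mode) checks in main.cpp below keep their old boolean meaning; ENABLED and AUTO are both non-zero and therefore truthy. A self-contained sketch of that property (enum values copied from the hunk above):

    #include <cassert>

    enum common_conversation_mode {
        COMMON_CONVERSATION_MODE_DISABLED = 0,
        COMMON_CONVERSATION_MODE_ENABLED = 1,
        COMMON_CONVERSATION_MODE_AUTO = 2,
    };

    int main() {
        common_conversation_mode mode = COMMON_CONVERSATION_MODE_DISABLED;
        assert(!mode); // 0 converts to false, matching the old `bool conversation = false`
        mode = COMMON_CONVERSATION_MODE_ENABLED;
        assert(mode);  // non-zero converts to true
        mode = COMMON_CONVERSATION_MODE_AUTO;
        assert(mode);  // AUTO is truthy too; main.cpp resolves it to ENABLED when a template exists
        return 0;
    }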
40 changes: 30 additions & 10 deletions examples/main/main.cpp
@@ -30,6 +30,8 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+static const char * DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant";
+
 static llama_context ** g_ctx;
 static llama_model ** g_model;
 static common_sampler ** g_smpl;
@@ -204,8 +206,20 @@ int main(int argc, char ** argv) {
         LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx);
     }
 
+    // auto enable conversation mode if chat template is available
+    const bool has_chat_template = !common_get_builtin_chat_template(model).empty() || !params.chat_template.empty();
+    if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO && has_chat_template) {
+        LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__);
+        params.conversation_mode = COMMON_CONVERSATION_MODE_ENABLED;
+    }
+
+    // in case user force-activate conversation mode (via -cnv) without proper chat template, we show a warning
+    if (params.conversation_mode && !has_chat_template) {
+        LOG_WRN("%s: chat template is not available or is not supported. This may cause the model to output suboptimal responses\n", __func__);
+    }
+
     // print chat template example in conversation mode
-    if (params.conversation) {
+    if (params.conversation_mode) {
         if (params.enable_chat_template) {
             LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template).c_str());
         } else {
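The two new checks above combine into a small decision matrix. A self-contained sketch that replays the same logic for every combination of startup mode and template availability (illustration only, not PR code):

    #include <cstdio>

    enum common_conversation_mode {
        COMMON_CONVERSATION_MODE_DISABLED = 0,
        COMMON_CONVERSATION_MODE_ENABLED = 1,
        COMMON_CONVERSATION_MODE_AUTO = 2,
    };

    int main() {
        const char * names[] = { "DISABLED", "ENABLED", "AUTO" };
        for (int m = 0; m <= 2; m++) {
            for (int t = 0; t <= 1; t++) {
                common_conversation_mode mode = (common_conversation_mode) m;
                const bool has_chat_template = t != 0;
                if (mode == COMMON_CONVERSATION_MODE_AUTO && has_chat_template) {
                    mode = COMMON_CONVERSATION_MODE_ENABLED; // the auto-enable branch
                }
                const bool warn = mode && !has_chat_template; // the suboptimal-output warning
                printf("start=%-8s template=%d -> mode=%-8s warn=%d\n",
                       names[m], t, names[mode], (int) warn);
            }
        }
        return 0;
    }

Note that a forced -cnv without a usable template keeps conversation mode on and only warns, while -no-cnv wins unconditionally.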
@@ -252,8 +266,10 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd_inp;
 
     {
-        auto prompt = (params.conversation && params.enable_chat_template && !params.prompt.empty())
-            ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
+        auto prompt = (params.conversation_mode && params.enable_chat_template)
+            // format the system prompt in conversation mode (fallback to default if empty)
+            ? chat_add_and_format(model, chat_msgs, "system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
+            // otherwise use the prompt as is
             : params.prompt;
         if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
             LOG_DBG("tokenize the prompt\n");
@@ -327,7 +343,7 @@ int main(int argc, char ** argv) {
         params.n_keep += add_bos; // always keep the BOS token
     }
 
-    if (params.conversation) {
+    if (params.conversation_mode) {
         params.interactive_first = true;
     }
 
@@ -451,7 +467,11 @@ int main(int argc, char ** argv) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
         LOG_INF( " - Press Ctrl+C to interject at any time.\n");
 #endif
-        LOG_INF( "%s\n", control_message);
+        LOG_INF( "%s", control_message);
+        if (params.conversation_mode && params.enable_chat_template && params.prompt.empty()) {
+            LOG_INF( " - Using default system message. To change it, set a different value via -p PROMPT or -f FILE argument.\n");
+        }
+        LOG_INF("\n");
 
         is_interacting = params.interactive_first;
     }
@@ -763,15 +783,15 @@ int main(int argc, char ** argv) {
         }
 
         // if current token is not EOG, we add it to current assistant message
-        if (params.conversation) {
+        if (params.conversation_mode) {
             const auto id = common_sampler_last(smpl);
             assistant_ss << common_token_to_piece(ctx, id, false);
         }
 
        if (n_past > 0 && is_interacting) {
            LOG_DBG("waiting for user input\n");
 
-            if (params.conversation) {
+            if (params.conversation_mode) {
                LOG("\n> ");
            }
 
@@ -781,7 +801,7 @@ int main(int argc, char ** argv) {
             }
 
             std::string buffer;
-            if (!params.input_prefix.empty() && !params.conversation) {
+            if (!params.input_prefix.empty() && !params.conversation_mode) {
                 LOG_DBG("appending input prefix: '%s'\n", params.input_prefix.c_str());
                 LOG("%s", params.input_prefix.c_str());
             }
@@ -805,7 +825,7 @@ int main(int argc, char ** argv) {
             // Entering a empty line lets the user pass control back
             if (buffer.length() > 1) {
                 // append input suffix if any
-                if (!params.input_suffix.empty() && !params.conversation) {
+                if (!params.input_suffix.empty() && !params.conversation_mode) {
                     LOG_DBG("appending input suffix: '%s'\n", params.input_suffix.c_str());
                     LOG("%s", params.input_suffix.c_str());
                 }
@@ -818,7 +838,7 @@ int main(int argc, char ** argv) {
                     string_process_escapes(buffer);
                 }
 
-                bool format_chat = params.conversation && params.enable_chat_template;
+                bool format_chat = params.conversation_mode && params.enable_chat_template;
                 std::string user_inp = format_chat
                     ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
                     : std::move(buffer);
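The same gate decides, per user turn, whether input is rendered through the chat template or passed through raw (with the legacy --in-prefix/--in-suffix handling shown earlier). A self-contained sketch; the real chat_add_and_format lives elsewhere in main.cpp, so the stub body here is hypothetical, though its call shape matches the diff:

    #include <string>
    #include <utility>
    #include <vector>

    struct common_chat_msg { std::string role; std::string content; };

    // Hypothetical stub for main.cpp's chat_add_and_format: record the message
    // and return the template-rendered text (placeholder output here).
    static std::string chat_add_and_format(std::vector<common_chat_msg> & msgs,
                                           const std::string & role, std::string text) {
        msgs.push_back({role, text});
        return "<rendered " + role + " turn: " + msgs.back().content + ">";
    }

    // The gate from the hunk above: format only when conversation mode is on
    // AND chat templates are enabled; otherwise the raw buffer is used as-is.
    static std::string route_user_input(bool conversation_mode, bool enable_chat_template,
                                        std::vector<common_chat_msg> & msgs, std::string buffer) {
        const bool format_chat = conversation_mode && enable_chat_template;
        return format_chat ? chat_add_and_format(msgs, "user", std::move(buffer))
                           : std::move(buffer);
    }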