Skip to content
135 changes: 129 additions & 6 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@
#endif
#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
#define LLAMA_CURL_MAX_HEADER_LENGTH 256
#define LLAMA_PROGRESS_UPDATE_INTERVAL 1
#define LLAMA_PROGRESS_PERCENTAGE_WIDTH 10
#define LLAMA_DEFAULT_CONSOLE_WIDTH 80
#endif // LLAMA_USE_CURL

using json = nlohmann::ordered_json;
Expand Down Expand Up @@ -1866,7 +1869,87 @@ void llama_batch_add(

#ifdef LLAMA_USE_CURL

static bool llama_download_file(CURL * curl, const char * url, const char * path) {
// Download progress of a single shard, as last reported by the curl progress callback.
struct shard_file_progress {
    // last path component of the shard URL, used as the label in the progress table
    std::string filename;
    // total number of bytes expected for this shard; 0 until a value has been reported
    double total_bytes = 0.0;
    // number of bytes received so far
    double received_bytes = 0.0;
};

// Per-shard download progress, keyed by the shard URL passed to the progress callback.
// Every access must hold progress_mutex.
// File-local (static) so the generic name cannot clash with symbols in other translation units.
static std::map<std::string, shard_file_progress> progress_table;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It must be done without global variables.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think this is easily doable, @TevinWang?

// Guards progress_table, which is written by the download threads' progress callbacks
// and read by the thread that prints the progress table.
// File-local (static) so the generic name cannot clash with symbols in other translation units.
static std::mutex progress_mutex;

// Collects the log lines produced while shards are downloading, so they can be written out
// in one piece after the progress table is finished instead of being interleaved with it.
static std::stringstream download_done_buffer;

// libcurl progress callback used for sharded downloads.
//
// clientp is expected to be the NUL-terminated URL string registered through
// CURLOPT_PROGRESSDATA; it is used as the key into progress_table.
// dltotal and dlnow are the expected and already received download byte counts
// reported by libcurl. The upload counters are ignored.
//
// Always returns 0, which tells libcurl to continue the transfer.
static int shard_progress_callback(void* clientp, double dltotal, double dlnow, double ultotal, double ulnow) {
    // upload not needed for downloading
    (void) ultotal;
    (void) ulnow;

    const char * url = static_cast<const char *>(clientp);
    if (url == nullptr) {
        // nothing to record the progress under; building a std::string from a
        // null pointer would be undefined behavior
        return 0;
    }

    // build the key once, before entering the critical section
    const std::string url_string(url);

    std::lock_guard<std::mutex> lock(progress_mutex);

    shard_file_progress & progress = progress_table[url_string];
    progress.total_bytes    = dltotal;
    progress.received_bytes = dlnow;

    // the label never changes for a given URL, so derive it only once
    if (progress.filename.empty()) {
        // when the URL has no '/', find_last_of returns npos and npos + 1 wraps
        // around to 0, so the whole URL becomes the label
        progress.filename = url_string.substr(url_string.find_last_of('/') + 1);
    }

    return 0;
}

// Returns the width, in columns, of the console attached to standard output.
//
// Falls back to LLAMA_DEFAULT_CONSOLE_WIDTH when the platform is not supported,
// when the query fails (for example because standard output is redirected to a
// file or a pipe), or when the reported width is not positive. Without these
// checks the result would be read from an uninitialized structure.
//
// NOTE(review): print_shard_progress_table writes to stderr while this queries
// stdout - confirm whether the width of stderr should be queried instead.
static int get_console_width() {
#ifdef _WIN32
    CONSOLE_SCREEN_BUFFER_INFO csbi;
    if (GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi) && csbi.dwSize.X > 0) {
        return csbi.dwSize.X;
    }
#elif defined(__linux__) || defined(__APPLE__)
    struct winsize ws;
    if (ioctl(STDOUT_FILENO, TIOCGWINSZ, &ws) == 0 && ws.ws_col > 0) {
        return ws.ws_col;
    }
#endif
    return LLAMA_DEFAULT_CONSOLE_WIDTH; // Default value
}

static void print_shard_progress_table(bool first_progress) {
if (first_progress) {
fprintf(stderr, "=========================\n");
} else {
// use updating output
{
std::lock_guard<std::mutex> lock(progress_mutex);
for (unsigned int i = 0; i < progress_table.size(); i++) {
fprintf(stderr, "\033[1A\033[K\033[1A\033[K");
}
fprintf(stderr, "\r");
}
}


int progress_bar_width = get_console_width() - LLAMA_PROGRESS_PERCENTAGE_WIDTH;

// Print the progress information for each downloading file
{
std::lock_guard<std::mutex> lock(progress_mutex);
for (const auto& entry : progress_table) {
shard_file_progress progress = entry.second;
int progress_width = static_cast<int>((progress.received_bytes / progress.total_bytes) * progress_bar_width);

fprintf(stderr, "%s\n", progress.filename.c_str());
fprintf(stderr, "[");
for (int i = 0; i < progress_width; ++i) {
fprintf(stderr, "=");
}
for (int i = progress_width; i < progress_bar_width; ++i) {
fprintf(stderr, " ");
}
fprintf(stderr, "] %d%%\n", static_cast<int>((progress.received_bytes / progress.total_bytes) * 100));
}
}
}

static bool llama_download_file(CURL * curl, const char * url, const char * path, bool is_shard) {
bool force_download = false;

// Set the URL, allow to follow http redirection
Expand Down Expand Up @@ -2001,6 +2084,12 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path

// display download progress
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);

// custom progress callback on sharded download
if (is_shard) {
curl_easy_setopt(curl, CURLOPT_PROGRESSFUNCTION, shard_progress_callback);
curl_easy_setopt(curl, CURLOPT_PROGRESSDATA, url);
}

// helper function to hide password in URL
auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
Expand Down Expand Up @@ -2046,7 +2135,11 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path
if (etag_file) {
fputs(headers.etag, etag_file);
fclose(etag_file);
fprintf(stderr, "%s: file etag saved %s: %s\n", __func__, etag_path, headers.etag);
if (is_shard) {
download_done_buffer << __func__ << ": file etag saved " << etag_path << ": " << headers.etag << "\n";
} else {
fprintf(stderr, "%s: file etag saved %s: %s\n", __func__, etag_path, headers.etag);
}
}
}

Expand All @@ -2056,8 +2149,11 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path
if (last_modified_file) {
fputs(headers.last_modified, last_modified_file);
fclose(last_modified_file);
fprintf(stderr, "%s: file last modified saved %s: %s\n", __func__, last_modified_path,
headers.last_modified);
if (is_shard) {
download_done_buffer << __func__ << ": unable to rename file: " << path_temporary << " to " << path << "\n";
} else {
fprintf(stderr, "%s: unable to rename file: %s to %s\n", __func__, path_temporary, path);
}
}
}

Expand Down Expand Up @@ -2089,7 +2185,7 @@ struct llama_model * llama_load_model_from_url(
return NULL;
}

if (!llama_download_file(curl, model_url, path_model)) {
if (!llama_download_file(curl, model_url, path_model, false)) {
return NULL;
}

Expand Down Expand Up @@ -2148,12 +2244,39 @@ struct llama_model * llama_load_model_from_url(
llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);

auto * curl = curl_easy_init();
bool res = llama_download_file(curl, split_url, split_path);
bool res = llama_download_file(curl, split_url, split_path, true);
curl_easy_cleanup(curl);

return res;
}, idx));
}

bool first_progress = true;
while (true) {
// Print the progress table periodically
std::this_thread::sleep_for(std::chrono::seconds(LLAMA_PROGRESS_UPDATE_INTERVAL));
// Print the progress table header
print_shard_progress_table(first_progress);
first_progress = false;

// Check if all downloads are complete
bool all_complete = true;
{
std::lock_guard<std::mutex> lock(progress_mutex);
for (const auto& entry : progress_table) {
const shard_file_progress& progress = entry.second;
if (progress.received_bytes < progress.total_bytes) {
all_complete = false;
break;
}
}
}

if (all_complete) {
fprintf(stderr, "%s", download_done_buffer.str().c_str());
break;
}
}

// Wait for all downloads to complete
for (auto & f : futures_download) {
Expand Down