170 changes: 162 additions & 8 deletions ggml/src/ggml-cuda/common.cuh
@@ -21,10 +21,12 @@
#include "ggml-common.h"

#include <array>
#include <algorithm>
#include <cassert>
#include <cfloat>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

#if defined(GGML_USE_HIP)
@@ -972,6 +974,154 @@ struct ggml_cuda_graph {
#endif
};

struct ggml_cuda_concurrent_event {
std::vector<cudaEvent_t> join_events;
cudaEvent_t fork_event = nullptr;

int n_streams = 0;
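// maps every node in the concurrent region to the branch stream it runs on
// (branch streams are numbered from 1; 0 is the main stream)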
std::unordered_map<const ggml_tensor *, int> stream_mapping;

const ggml_tensor * join_node = nullptr;

ggml_cuda_concurrent_event() = default;

ggml_cuda_concurrent_event(const ggml_cuda_concurrent_event &) = delete;
ggml_cuda_concurrent_event & operator=(const ggml_cuda_concurrent_event &) = delete;

explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) {
join_events.resize(n_streams);

for (size_t i = 0; i < join_events.size(); ++i) {
CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming));
}

CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));
}

ggml_cuda_concurrent_event(ggml_cuda_concurrent_event && other) noexcept
: join_events(std::move(other.join_events))
, fork_event(other.fork_event)
, n_streams(other.n_streams)
, stream_mapping(std::move(other.stream_mapping))
, join_node(other.join_node) {
other.fork_event = nullptr;
}

// 1. check if any branches write to overlapping memory ranges (except the join node)
// 2. check whether all srcs are either within the branch or outside the nodes covered by ggml_cuda_concurrent_event
// we assume all nodes have the same buffer
bool is_valid() const {
std::vector<std::vector<std::pair<int64_t, int64_t>>> write_ranges;
write_ranges.resize(n_streams);

// get join_node's memory range to exclude from overlap checking.
// multiple nodes can use join_node's buffer; we synchronize on the join node.
const ggml_tensor * join_t = join_node->view_src ? join_node->view_src : join_node;
const int64_t join_start = (int64_t) join_t->data;
const int64_t join_end = join_start + ggml_nbytes(join_t);

for (const auto & [tensor, stream] : stream_mapping) {
const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
const int64_t t_start = (int64_t) t->data;
const int64_t t_end = t_start + ggml_nbytes(t);

// skip tensors that overlap with join_node's buffer.
if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
continue;
}

// concurrent streams are numbered from 1 (stream 0 is the main stream), so index with stream - 1
write_ranges[stream - 1].emplace_back(t_start, t_end);
}

for (int i = 0; i < n_streams; ++i) {
// sorts first by start then by end of write range
std::sort(write_ranges[i].begin(), write_ranges[i].end());
}

bool writes_overlap = false;
bool dependent_srcs = false;
for (const auto & [tensor, stream] : stream_mapping) {
const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
const int64_t t_start = (int64_t) t->data;
const int64_t t_end = t_start + ggml_nbytes(t);

// skip tensors that overlap with join_node's buffer
if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
continue;
}

// check if this buffer's write data overlaps with another stream's
std::pair<int64_t, int64_t> data_range = std::make_pair(t_start, t_end);
for (int i = 0; i < n_streams; ++i) {
if (i == stream - 1) {
continue;
}
auto it = std::lower_bound(write_ranges[i].begin(), write_ranges[i].end(), data_range);

if (it != write_ranges[i].end()) {
const std::pair<int64_t, int64_t> & other = *it;

// std::lower_bound returns the first element where other >= data_range (lexicographically).
// This guarantees other.first >= data_range.first.
// Therefore, overlap occurs iff other.first < data_range.second
// (i.e., the other range starts before this range ends).
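// Note: a range that starts before data_range and extends into it is not
// returned by this lookup, but the overlap is still caught because every
// tensor is checked against every other stream's ranges: when the
// earlier-starting range is the query, lower_bound finds this one.
// Example (hypothetical addresses): querying (96,160) against {(0,100)}
// finds nothing, but querying (0,100) against {(96,160)} returns (96,160),
// and 96 < 100 flags the overlap.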
if (other.first < data_range.second) {
GGML_LOG_DEBUG("Writes overlap for %s\n", tensor->name);
writes_overlap = true;
break;
}
}
}

// check that every src is either mapped to the same branch or lies outside the concurrent region
for (int i = 0; i < GGML_MAX_SRC; ++i) {
if (!tensor->src[i]) {
continue;
}

auto it = stream_mapping.find(tensor->src[i]);

if (it == stream_mapping.end()) {
continue;
}

if (it->second != stream) {
dependent_srcs = true;
break;
}
}

if (dependent_srcs || writes_overlap) {
break;
}
}

return !writes_overlap && !dependent_srcs;
}

~ggml_cuda_concurrent_event() {
if (fork_event != nullptr) {
CUDA_CHECK(cudaEventDestroy(fork_event));
}
for (cudaEvent_t e : join_events) {
if (e != nullptr) {
CUDA_CHECK(cudaEventDestroy(e));
}
}
}
};
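// Illustrative only (not part of this header): the fork/join protocol the
// events above are designed for. `ev`, `main_stream` and `branch_streams`
// are hypothetical names; the real dispatch lives in the backend's graph
// evaluation code.
//
//   CUDA_CHECK(cudaEventRecord(ev.fork_event, main_stream));
//   for (int i = 0; i < ev.n_streams; ++i) {
//       // every branch stream waits until the fork point is reached
//       CUDA_CHECK(cudaStreamWaitEvent(branch_streams[i], ev.fork_event, 0));
//   }
//   // ... each branch's kernels run concurrently on its own stream ...
//   for (int i = 0; i < ev.n_streams; ++i) {
//       CUDA_CHECK(cudaEventRecord(ev.join_events[i], branch_streams[i]));
//       // the main stream waits for all branches before running join_node
//       CUDA_CHECK(cudaStreamWaitEvent(main_stream, ev.join_events[i], 0));
//   }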

struct ggml_cuda_stream_context {
std::vector<const ggml_tensor *> original_nodes;
std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;

void reset() {
original_nodes.clear();
concurrent_events.clear();
}
};
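// Illustrative only: how a dispatcher might consult this context while
// walking the graph. Keying `concurrent_events` by the fork node is an
// assumption of this sketch.
//
//   ggml_cuda_stream_context & sc = ctx.stream_context();
//   auto it = sc.concurrent_events.find(node);
//   if (it != sc.concurrent_events.end() && it->second.is_valid()) {
//       // fork: subsequent nodes run on the streams given by stream_mapping,
//       // switching ctx.curr_stream_no per node until the join node
//   }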

struct ggml_backend_cuda_context {
int device;
std::string name;
@@ -982,11 +1132,15 @@ struct ggml_backend_cuda_context {

std::unique_ptr<ggml_cuda_graph> cuda_graph;

int curr_stream_no = 0;
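// curr_stream_no selects the stream returned by stream() and the pool used by pool();
// 0 is the default stream, values >= 1 are concurrent branch streams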

explicit ggml_backend_cuda_context(int device) :
device(device),
name(GGML_CUDA_NAME + std::to_string(device)) {
}

ggml_cuda_stream_context concurrent_stream_context;

~ggml_backend_cuda_context();

cudaStream_t stream(int device, int stream) {
@@ -997,9 +1151,9 @@
return streams[device][stream];
}

-    cudaStream_t stream() {
-        return stream(device, 0);
-    }
+    cudaStream_t stream() { return stream(device, curr_stream_no); }

ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; }

cublasHandle_t cublas_handle(int device) {
if (cublas_handles[device] == nullptr) {
@@ -1015,15 +1169,15 @@
}

// pool
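// one pool per (device, stream): scratch buffers recycled through one stream's
// pool must not be handed out to work running concurrently on another stream
// (the assumed rationale for the per-stream split)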
-    std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
+    std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];

-    static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);
+    static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);

    ggml_cuda_pool & pool(int device) {
-       if (pools[device] == nullptr) {
-           pools[device] = new_pool_for_device(device);
+       if (pools[device][curr_stream_no] == nullptr) {
+           pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no);
        }
-       return *pools[device];
+       return *pools[device][curr_stream_no];
    }

ggml_cuda_pool & pool() {