Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for OpenBLAS #240

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ rungnu:
runompgnu:
$(CC) -Ofast -fopenmp -std=gnu11 run.c -lm -o run

.PHONY: runblas
runblas: run.c
$(CC) -DOPENBLAS -march=native -Ofast -o run run.c -lm -lpthread -lopenblas

.PHONY: clean
clean:
rm -f run
17 changes: 17 additions & 0 deletions run.c
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,22 @@ void softmax(float* x, int size) {
}
}


#ifdef OPENBLAS
#include "cblas.h"

void matmul(float* xout, float* x, float* w, int n, int d) {
// Use the cblas_sgemv function from the BLAS library
// cblas_sgemv computes y = alpha*A*x + beta*y
// In our case, A is the matrix w, x is the vector x, and y is the output vector xout

float alpha = 1.0f; // Scalar to multiply with A*x
float beta = 0.0f; // Scalar to multiply with y (we want the initial value of y to have no effect)

cblas_sgemv(CblasRowMajor, CblasNoTrans, d, n, alpha, w, n, x, 1, beta, xout, 1);
}
#else

void matmul(float* xout, float* x, float* w, int n, int d) {
// W (d,n) @ x (n,) -> xout (d,)
// by far the most amount of time is spent inside this little function
Expand All @@ -212,6 +228,7 @@ void matmul(float* xout, float* x, float* w, int n, int d) {
xout[i] = val;
}
}
#endif

void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) {

Expand Down