@@ -931,14 +931,12 @@ __attribute__((noinline)) void dlacpy(char *uplo_p, int *M_p, int *N_p, double *
931931
932932__attribute__ ((noinline)) cublasStatus_t
933933cublasDlascl(cublasHandle_t *handle, cublasOperation_t type, int KL, int KU,
934- double cfrom, double cto, int M, int N, double *A, int lda, int info) {
935- calls.push_back ((BlasCall){ABIType::CUBLAS,handle,
936- inDerivative, CallType::LASCL,
937- A, UNUSED_POINTER, UNUSED_POINTER,
938- cfrom, cto,
939- CUBLAS_LAYOUT,
940- (char )type, UNUSED_TRANS,
941- M, N, UNUSED_INT, lda, KL, KU});
934+ double *cfrom, double *cto, int M, int N, double *A, int lda,
935+ int info) {
936+ calls.push_back ((BlasCall){ABIType::CUBLAS, handle, inDerivative,
937+ CallType::LASCL, A, UNUSED_POINTER, UNUSED_POINTER,
938+ *cfrom, *cto, CUBLAS_LAYOUT, (char )type,
939+ UNUSED_TRANS, M, N, UNUSED_INT, lda, KL, KU});
942940 return cublasStatus_t::CUBLAS_STATUS_SUCCESS;
943941}
944942__attribute__ ((noinline)) cublasStatus_t cublasDlacpy(cublasHandle_t *handle, char uplo, int M,
@@ -1054,47 +1052,57 @@ __attribute__((noinline)) cublasStatus_t cublasDaxpy(cublasHandle_t *handle,
10541052}
10551053__attribute__ ((noinline)) cublasStatus_t
10561054cublasDgemv(cublasHandle_t *handle, cublasOperation_t trans, int M, int N,
1057- double alpha, double *A, int lda, double *X, int incx, double beta,
1058- double *Y, int incy) {
1059- BlasCall call = {ABIType::CUBLAS,handle,
1060- inDerivative, CallType::GEMV, Y, A, X, alpha, beta, CUBLAS_LAYOUT,
1061- (char )trans, UNUSED_TRANS, M, N, UNUSED_INT, lda, incx, incy};
1055+ double *alpha, double *A, int lda, double *X, int incx,
1056+ double *beta, double *Y, int incy) {
1057+ BlasCall call = {ABIType::CUBLAS,
1058+ handle,
1059+ inDerivative,
1060+ CallType::GEMV,
1061+ Y,
1062+ A,
1063+ X,
1064+ *alpha,
1065+ *beta,
1066+ CUBLAS_LAYOUT,
1067+ (char )trans,
1068+ UNUSED_TRANS,
1069+ M,
1070+ N,
1071+ UNUSED_INT,
1072+ lda,
1073+ incx,
1074+ incy};
10621075 calls.push_back (call);
10631076 return cublasStatus_t::CUBLAS_STATUS_SUCCESS;
10641077}
10651078__attribute__ ((noinline)) cublasStatus_t
10661079cublasDgemm(cublasHandle_t *handle, cublasOperation_t transA,
1067- cublasOperation_t transB, int M, int N, int K, double alpha,
1068- double *A, int lda, double *B, int ldb, double beta, double *C,
1069- int ldc) {
1070- calls.push_back ((BlasCall){ABIType::CUBLAS,handle,inDerivative, CallType::GEMM, C, A, B, alpha,
1071- beta,
1072- CUBLAS_LAYOUT,
1073- (char )transA, (char )transB, M, N, K, lda,
1074- ldb, ldc});
1080+ cublasOperation_t transB, int M, int N, int K, double *alpha,
1081+ double *A, int lda, double *B, int ldb, double *beta, double *C,
1082+ int ldc) {
1083+ calls.push_back ((BlasCall){ABIType::CUBLAS, handle, inDerivative,
1084+ CallType::GEMM, C, A, B, *alpha, *beta,
1085+ CUBLAS_LAYOUT, (char )transA, (char )transB, M, N, K,
1086+ lda, ldb, ldc});
10751087 return cublasStatus_t::CUBLAS_STATUS_SUCCESS;
10761088}
10771089__attribute__ ((noinline)) cublasStatus_t
1078- cublasDscal(cublasHandle_t *handle, int N, double alpha, double *X, int incX) {
1090+ cublasDscal(cublasHandle_t *handle, int N, double * alpha, double *X, int incX) {
10791091 calls.push_back ((BlasCall){
1080- ABIType::CUBLAS,handle,inDerivative, CallType::SCAL, X, UNUSED_POINTER, UNUSED_POINTER, alpha,
1081- UNUSED_DOUBLE,
1082- CUBLAS_LAYOUT,
1083- UNUSED_TRANS, UNUSED_TRANS, N, UNUSED_INT,
1084- UNUSED_INT, incX, UNUSED_INT, UNUSED_INT});
1092+ ABIType::CUBLAS, handle, inDerivative, CallType::SCAL, X, UNUSED_POINTER,
1093+ UNUSED_POINTER, *alpha, UNUSED_DOUBLE, CUBLAS_LAYOUT, UNUSED_TRANS,
1094+ UNUSED_TRANS, N, UNUSED_INT, UNUSED_INT, incX, UNUSED_INT, UNUSED_INT});
10851095 return cublasStatus_t::CUBLAS_STATUS_SUCCESS;
10861096}
10871097
10881098// A = alpha * X * transpose(Y) + A
10891099__attribute__ ((noinline)) cublasStatus_t
1090- cublasDger(cublasHandle_t *handle, int M, int N, double alpha, double *X,
1091- int incX, double *Y, int incY, double *A, int lda) {
1092- calls.push_back ((BlasCall){ABIType::CUBLAS,handle,inDerivative, CallType::GER, A, X, Y, alpha,
1093- UNUSED_DOUBLE,
1094- CUBLAS_LAYOUT,
1095- UNUSED_TRANS,
1096- UNUSED_TRANS, M, N, UNUSED_INT, incX, incY,
1097- lda});
1100+ cublasDger(cublasHandle_t *handle, int M, int N, double *alpha, double *X,
1101+ int incX, double *Y, int incY, double *A, int lda) {
1102+ calls.push_back ((BlasCall){ABIType::CUBLAS, handle, inDerivative,
1103+ CallType::GER, A, X, Y, *alpha, UNUSED_DOUBLE,
1104+ CUBLAS_LAYOUT, UNUSED_TRANS, UNUSED_TRANS, M, N,
1105+ UNUSED_INT, incX, incY, lda});
10981106 return cublasStatus_t::CUBLAS_STATUS_SUCCESS;
10991107}
11001108
0 commit comments