-
Notifications
You must be signed in to change notification settings - Fork 0
/
KMeans_CUDA.h
36 lines (27 loc) · 856 Bytes
/
KMeans_CUDA.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
// KMeans in CUDA: declaration of the KMeans_CUDA class.
//
// Include guard. The macro defined below must be spelled exactly like the one
// tested by #ifndef; otherwise the guard never takes effect and including this
// header twice in one translation unit redefines the class.
// NOTE(review): identifiers containing a double underscore are reserved for
// the implementation. The name is kept so it still matches the comment on the
// closing #endif at the end of the file; consider renaming both together.
#ifndef __KMEANS_CU_H__
#define __KMEANS_CU_H__
// Declares a k-means clustering object. Only declarations live in this block;
// the member functions are defined elsewhere.
// The h_/d_ prefixes on the pointer members appear to distinguish host-side
// buffers from GPU-side buffers (see the per-member comments below).
// NOTE(review): the class declares a destructor and holds raw pointers, but no
// copy constructor or copy assignment is declared. If the destructor releases
// these buffers, copying an instance would free them twice — confirm against
// the implementation.
class KMeans_CUDA {
public:
// Builds the object from `data` and the three sizes n, d and k, which
// presumably correspond to the members of the same names declared below.
// NOTE(review): whether `data` is copied or only referenced is not visible
// from this header — confirm ownership in the implementation.
KMeans_CUDA(float *data, int n, int d, int k);
// Destructor; presumably releases the buffers held below — TODO confirm.
~KMeans_CUDA();
// Presumably runs a single k-means iteration — TODO confirm.
void one_epoch();
// Presumably prints the current centroids — TODO confirm output format.
void print_centroids();
// Presumably prints the cluster predicted for each data element — TODO confirm.
void print_predictions();
// Returns an error value as a float; the exact metric is not visible from
// this header — TODO confirm.
float compute_error();
private:
// Input data
float *h_data; // Data buffer of size n*d (presumably the host-side copy)
float *d_data; // Pointer to the data on the GPU. Size n*d
// Learned centroids
float *h_centroids; // Pointer to the centroids on the heap. Size k*d
float *d_centroids; // Pointer to the centroids on the GPU. Size k*d
// Per-cluster sum and count buffers (the d_ prefix suggests they live on the
// GPU — TODO confirm)
float *d_sum; // Size k*d
int *d_count; // Size k
// Dataset dimensions
int n; // Number of data elements
int d; // Number of dimensions
int k; // Number of clusters
};
#endif // __KMEANS_CU_H__