-
Notifications
You must be signed in to change notification settings - Fork 43
[REVIEW] Replace function #106
base: master
Are you sure you want to change the base?
Changes from 14 commits
350dafc
8f805d9
5a8b93a
adf537d
3214b08
3935048
a4a7383
f5afc31
ef3f382
003bd9d
5553c67
0f3891d
26c522a
2a7a023
200da31
a2fa767
cc3beca
ef4de00
0b62cd9
5f2c338
35c765d
d3e50d3
7c12b1a
c01252b
321656e
7f80017
207fe0e
8767314
622abbd
d364fd8
95a8a21
99f6ebf
c5f27c1
b8e7ddd
fdb7afc
41b058c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
/* | ||
* Copyright 2018 BlazingDB, Inc. | ||
* Copyright 2018 Cristhian Alberto Gonzales Castillo <[email protected]> | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include <cmath>
#include <cstddef>
#include <cstdint>

#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/find.h>
#include <thrust/for_each.h>
#include <thrust/replace.h>

#include <gdf/gdf.h>
|
||
namespace { | ||
|
||
template <gdf_dtype DTYPE> | ||
struct gdf_dtype_traits {}; | ||
|
||
#define DTYPE_FACTORY(DTYPE, T) \ | ||
template <> \ | ||
struct gdf_dtype_traits<GDF_##DTYPE> { \ | ||
typedef T value_type; \ | ||
} | ||
|
||
DTYPE_FACTORY(INT8, std::int8_t); | ||
DTYPE_FACTORY(INT16, std::int16_t); | ||
DTYPE_FACTORY(INT32, std::int32_t); | ||
DTYPE_FACTORY(INT64, std::int64_t); | ||
DTYPE_FACTORY(FLOAT32, float); | ||
DTYPE_FACTORY(FLOAT64, double); | ||
DTYPE_FACTORY(DATE32, std::int32_t); | ||
DTYPE_FACTORY(DATE64, std::int64_t); | ||
DTYPE_FACTORY(TIMESTAMP, std::int64_t); | ||
|
||
#undef DTYPE_FACTORY | ||
|
||
template <class T> | ||
__global__ void | ||
replace_kernel(T *const data, | ||
const std::size_t data_size, | ||
const T *const values, | ||
const thrust::device_ptr<const T> to_replace_begin, | ||
const thrust::device_ptr<const T> to_replace_end) { | ||
for (std::size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < data_size; | ||
i += blockDim.x * gridDim.x) { | ||
// TODO: find by map kernel | ||
const thrust::device_ptr<const T> found_ptr = thrust::find( | ||
thrust::device, to_replace_begin, to_replace_end, data[i]); | ||
|
||
if (found_ptr != to_replace_end) { | ||
typename thrust::iterator_traits< | ||
const thrust::device_ptr<const T>>::difference_type | ||
value_found_index = thrust::distance(to_replace_begin, found_ptr); | ||
|
||
data[i] = values[value_found_index]; | ||
} | ||
} | ||
} | ||
|
||
template <class T> | ||
static inline gdf_error | ||
Replace(T *const data, | ||
const std::size_t data_size, | ||
const T *const to_replace, | ||
const T *const values, | ||
const std::ptrdiff_t replacement_ptrdiff) { | ||
const std::size_t blocks = std::ceil(data_size / 256.); | ||
|
||
const thrust::device_ptr<const T> to_replace_begin(to_replace); | ||
const thrust::device_ptr<const T> to_replace_end(to_replace_begin | ||
+ replacement_ptrdiff); | ||
|
||
replace_kernel<T> | ||
<<<blocks, 256>>>( // TODO: calc blocks and threads | ||
data, | ||
data_size, | ||
values, | ||
to_replace_begin, | ||
to_replace_end); | ||
|
||
return GDF_SUCCESS; | ||
} | ||
|
||
static inline bool | ||
NotEqualReplacementSize(const gdf_column *to_replace, | ||
const gdf_column *values) { | ||
return to_replace->size != values->size; | ||
} | ||
|
||
static inline bool | ||
NotSameDType(const gdf_column *column, | ||
const gdf_column *to_replace, | ||
const gdf_column *values) { | ||
return column->dtype != to_replace->dtype | ||
|| to_replace->dtype != values->dtype; | ||
} | ||
|
||
} // namespace | ||
|
||
// Find-and-replace on a column, performed in place.
//
// Validates the arguments, then dispatches on column->dtype to the typed
// Replace helper, passing the data pointers of the three columns and the
// element count of `values`.
//
//   column      column whose data is handed to Replace for rewriting
//   to_replace  values to search for inside `column`
//   values      replacement values, parallel to `to_replace`
//
// Returns:
//   GDF_COLUMN_SIZE_MISMATCH  to_replace->size != values->size
//   GDF_CUDA_ERROR            the three columns do not share one dtype
//   GDF_UNSUPPORTED_DTYPE     GDF_invalid, or any dtype without a case below
//   otherwise                 whatever Replace returns for the matching type
//
// NOTE(review): GDF_CUDA_ERROR for a dtype mismatch looks misleading, since
// no CUDA call has been made at that point -- confirm the intended code.
// NOTE(review): none of the three pointers is checked for null before it is
// dereferenced.
gdf_error | ||
gdf_replace(gdf_column * column, | ||
const gdf_column *to_replace, | ||
const gdf_column *values) { | ||
if (NotEqualReplacementSize(to_replace, values)) { | ||
return GDF_COLUMN_SIZE_MISMATCH; | ||
} | ||
|
||
if (NotSameDType(column, to_replace, values)) { return GDF_CUDA_ERROR; } | ||
|
||
// One `case` per supported dtype is generated by the WHEN macro below.
switch (column->dtype) { | ||
#define WHEN(DTYPE) \ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't really like calling this macro "WHEN" -- it's not descriptive. Something like
|
||
case GDF_##DTYPE: { \ | ||
using value_type = gdf_dtype_traits<GDF_##DTYPE>::value_type; \ | ||
return Replace(static_cast<value_type *>(column->data), \ | ||
static_cast<std::size_t>(column->size), \ | ||
static_cast<value_type *>(to_replace->data), \ | ||
static_cast<value_type *>(values->data), \ | ||
static_cast<std::ptrdiff_t>(values->size)); \ | ||
} | ||
|
||
WHEN(INT8); | ||
WHEN(INT16); | ||
WHEN(INT32); | ||
WHEN(INT64); | ||
WHEN(FLOAT32); | ||
WHEN(FLOAT64); | ||
WHEN(DATE32); | ||
WHEN(DATE64); | ||
WHEN(TIMESTAMP); | ||
|
||
#undef WHEN | ||
|
||
case GDF_invalid: | ||
default: return GDF_UNSUPPORTED_DTYPE; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
#============================================================================= | ||
# Copyright 2018 BlazingDB, Inc. | ||
# Copyright 2018 Cristhian Alberto Gonzales Castillo <[email protected]> | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
#============================================================================= | ||
|
||
# Registers the replace unit test through the project's configure_test
# helper (defined outside this file).
configure_test(replace-test replace-test.cu) | ||
|
||
# Benchmarks are opt-in: nothing below is configured unless GDF_BENCHMARK
# is set.
if (GDF_BENCHMARK) | ||
include(ExternalProject) | ||
|
||
# Fetch Google Benchmark v1.4.1 from GitHub and build it in Release mode as
# an external project. UPDATE_COMMAND is emptied so the checkout is not
# refreshed on every build.
ExternalProject_Add(benchmark_ep | ||
CMAKE_ARGS | ||
-DCMAKE_BUILD_TYPE=RELEASE | ||
-DCMAKE_INSTALL_PREFIX=build | ||
-DBENCHMARK_DOWNLOAD_DEPENDENCIES=ON | ||
GIT_REPOSITORY https://github.com/google/benchmark.git | ||
GIT_TAG v1.4.1 | ||
UPDATE_COMMAND "" | ||
) | ||
# NOTE(review): BENCHMARK_ROOT assumes the relative install prefix "build"
# resolves inside the external project's binary directory -- confirm.
ExternalProject_Get_property(benchmark_ep BINARY_DIR) | ||
set(BENCHMARK_ROOT ${BINARY_DIR}/build) | ||
|
||
# Create the install directories at configure time, before the external
# project has been built. Presumably needed so the imported targets below
# can reference paths that already exist -- TODO confirm.
file(MAKE_DIRECTORY ${BENCHMARK_ROOT}/include) | ||
file(MAKE_DIRECTORY ${BENCHMARK_ROOT}/lib) | ||
|
||
# Imported interface target for the benchmark library; depends on the
# external project so the library is built before anything linking to it.
add_library(Google::Benchmark INTERFACE IMPORTED) | ||
add_dependencies(Google::Benchmark benchmark_ep) | ||
set_target_properties(Google::Benchmark | ||
PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${BENCHMARK_ROOT}/include) | ||
set_target_properties(Google::Benchmark | ||
PROPERTIES INTERFACE_LINK_LIBRARIES ${BENCHMARK_ROOT}/lib/libbenchmark.a) | ||
|
||
# Imported interface target for libbenchmark_main.a. It carries no
# dependency on benchmark_ep of its own.
add_library(Google::Benchmark::Main INTERFACE IMPORTED) | ||
set_target_properties(Google::Benchmark::Main | ||
PROPERTIES INTERFACE_LINK_LIBRARIES ${BENCHMARK_ROOT}/lib/libbenchmark_main.a) | ||
|
||
# GDF_ADD_BENCHMARK(<target> <source>...)
# Builds a CUDA executable from the given sources and links it against both
# benchmark targets and gdf. The first entry of ARGV (the target name) is
# removed so that only the sources are passed on.
function(GDF_ADD_BENCHMARK TARGET) | ||
list(REMOVE_AT ARGV 0) | ||
cuda_add_executable(${TARGET} ${ARGV}) | ||
target_link_libraries(${TARGET} | ||
Google::Benchmark Google::Benchmark::Main gdf) | ||
endfunction() | ||
|
||
GDF_ADD_BENCHMARK(replace-benchmark replace-benchmark.cu) | ||
endif() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
/* | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be moved to a separate There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And separated into benchmarks and unit tests? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
* Copyright 2018 BlazingDB, Inc. | ||
* Copyright 2018 Cristhian Alberto Gonzales Castillo <[email protected]> | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include <benchmark/benchmark.h> | ||
|
||
#include <unordered_map> | ||
|
||
#include <thrust/device_vector.h> | ||
#include <thrust/sequence.h> | ||
|
||
#include <gdf/gdf.h> | ||
|
||
#include "utils.h" | ||
|
||
using T = std::int64_t; | ||
|
||
static void | ||
BM_CPU_LoopReplace(benchmark::State &state) { | ||
const std::size_t length = state.range(0); | ||
|
||
std::vector<T> vector(length); | ||
thrust::sequence(vector.begin(), vector.end(), 1); | ||
|
||
std::vector<T> to_replace_vector(10); | ||
thrust::sequence(to_replace_vector.begin(), to_replace_vector.end(), 1); | ||
|
||
std::vector<T> values_vector(10); | ||
thrust::sequence(values_vector.begin(), values_vector.end(), 1); | ||
|
||
for (auto _ : state) { | ||
for (std::size_t i = 0; i < vector.size(); i++) { | ||
auto current = std::find( | ||
to_replace_vector.begin(), to_replace_vector.end(), vector[i]); | ||
if (current != to_replace_vector.end()) { | ||
std::size_t j = | ||
std::distance(to_replace_vector.begin(), current); | ||
vector[i] = values_vector[j]; | ||
} | ||
} | ||
} | ||
} | ||
|
||
static void | ||
BM_CPU_MapReplace(benchmark::State &state) { | ||
const std::size_t length = state.range(0); | ||
|
||
std::vector<T> vector(length); | ||
thrust::sequence(vector.begin(), vector.end(), 1); | ||
|
||
std::vector<T> to_replace_vector(10); | ||
thrust::sequence(to_replace_vector.begin(), to_replace_vector.end(), 1); | ||
|
||
std::vector<T> values_vector(10); | ||
thrust::sequence(values_vector.begin(), values_vector.end(), 1); | ||
|
||
for (auto _ : state) { | ||
std::unordered_map<T, T> map; | ||
for (std::size_t i = 0; i < values_vector.size(); i++) { | ||
map.insert({to_replace_vector[i], values_vector[i]}); | ||
} | ||
|
||
for (std::size_t i = 0; i < vector.size(); i++) { | ||
try { | ||
vector[i] = map[vector[i]]; | ||
} catch (...) { continue; } | ||
} | ||
} | ||
} | ||
|
||
static void | ||
BM_GPU_LoopReplace(benchmark::State &state) { | ||
const std::size_t length = state.range(0); | ||
|
||
thrust::device_vector<T> device_vector(length); | ||
thrust::sequence(device_vector.begin(), device_vector.end(), 1); | ||
gdf_column column = MakeGdfColumn(device_vector); | ||
|
||
thrust::device_vector<T> to_replace_vector(10); | ||
thrust::sequence(to_replace_vector.begin(), to_replace_vector.end(), 1); | ||
gdf_column to_replace = MakeGdfColumn(to_replace_vector); | ||
|
||
thrust::device_vector<T> values_vector(10); | ||
thrust::sequence(values_vector.begin(), values_vector.end(), 1); | ||
gdf_column values = MakeGdfColumn(values_vector); | ||
|
||
for (auto _ : state) { | ||
const gdf_error status = gdf_replace(&column, &to_replace, &values); | ||
state.PauseTiming(); | ||
if (status != GDF_SUCCESS) { state.SkipWithError("Failed replace"); } | ||
state.ResumeTiming(); | ||
} | ||
} | ||
|
||
// Each benchmark reads a single argument, state.range(0): the column
// length. A second range dimension would only repeat identical runs, so a
// single Range over the length is registered.
BENCHMARK(BM_CPU_LoopReplace)->Range(8, 8 << 16);
BENCHMARK(BM_CPU_MapReplace)->Range(8, 8 << 16);
BENCHMARK(BM_GPU_LoopReplace)->Range(8, 8 << 16);
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The interface is confusing to me. Why does it have `column` and `to_replace` parameters? What happens to `column`? What happens to `to_replace`? Why not just make a `gdf_copy(out_column, in_column)`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now I see. The meaning of "replace" is not clear. It's not a copy. The semantics are actually those of "find_and_replace_all": for each value in `to_replace`, find all instances of that value in `column` and replace each one with the corresponding value in `values`. This should be made clear in the header documentation for the function. Consider changing the name to `gdf_find_and_replace_all()` or something like that...