Skip to content

Commit

Permalink
added stream support
Browse files Browse the repository at this point in the history
  • Loading branch information
shrshi committed Nov 26, 2024
1 parent d93e9c2 commit 953a5d5
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 1 deletion.
2 changes: 2 additions & 0 deletions cpp/include/nvtext/byte_pair_encoding.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,15 @@ std::unique_ptr<bpe_merge_pairs> load_merge_pairs(
* @param merges_pairs Created by a call to @ref nvtext::load_merge_pairs.
* @param separator String used to build the output after encoding.
* Default is a space.
* @param stream CUDA stream used for device memory operations and kernel launches
* @param mr Memory resource to allocate any returned objects.
* @return An encoded column of strings.
*/
std::unique_ptr<cudf::column> byte_pair_encoding(
cudf::strings_column_view const& input,
bpe_merge_pairs const& merges_pairs,
cudf::string_scalar const& separator = cudf::string_scalar(" "),
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

/** @} */ // end of group
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/text/bpe/byte_pair_encoding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -459,10 +459,11 @@ std::unique_ptr<cudf::column> byte_pair_encoding(cudf::strings_column_view const
std::unique_ptr<cudf::column> byte_pair_encoding(cudf::strings_column_view const& input,
bpe_merge_pairs const& merges_table,
cudf::string_scalar const& separator,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_FUNC_RANGE();
return detail::byte_pair_encoding(input, merges_table, separator, cudf::get_default_stream(), mr);
return detail::byte_pair_encoding(input, merges_table, separator, stream, mr);
}

} // namespace nvtext
1 change: 1 addition & 0 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,7 @@ ConfigureTest(
)
ConfigureTest(
STREAM_TEXT_TEST
streams/text/bpe_test.cpp
streams/text/edit_distance_test.cpp
streams/text/ngrams_test.cpp
streams/text/replace_test.cpp
Expand Down
59 changes: 59 additions & 0 deletions cpp/tests/streams/text/bpe_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cudf_test/base_fixture.hpp>
#include <cudf_test/column_wrapper.hpp>
#include <cudf_test/default_stream.hpp>
#include <cudf_test/iterator_utilities.hpp>

#include <cudf/strings/strings_column_view.hpp>

#include <nvtext/byte_pair_encoding.hpp>

struct TextBytePairEncoding : public cudf::test::BaseFixture {};

TEST_F(TextBytePairEncoding, BytePairEncoding)
{
// partial table based on values from https://huggingface.co/gpt2/raw/main/merges.txt
auto mpt = cudf::test::strings_column_wrapper({
"e n", // 14
"i t", // 16
"i s", // 17
"e s", // 20
"en t", // 44
"c e", // 90
"es t", // 141
"en ce", // 340
"t h", // 146
"h i", // 5049
"th is", // 5407
"t est", // 9034
"s i", // 13142
"s ent" // 33832
});

auto merge_pairs =
nvtext::load_merge_pairs(cudf::strings_column_view(mpt), cudf::test::get_default_stream());

auto validity = cudf::test::iterators::null_at(4);
cudf::test::strings_column_wrapper input(
{"thisisit", "thisis test-sentence-1", "thisistestsentence-2", "this-istestsentence 3", "", ""},
validity);
auto sv = cudf::strings_column_view(input);

auto results = nvtext::byte_pair_encoding(
sv, *merge_pairs, cudf::string_scalar(" "), cudf::test::get_default_stream());
}

0 comments on commit 953a5d5

Please sign in to comment.