Skip to content

Commit d582781

Browse files
authored
Do not store links which will be rejected (acts-project#933)
* Cut by max_num_branches_per_seed as soon as possible * Cut by max_num_skipping_per_cand as soon as possible * Cut by ..._track_candidates_per_track as soon as possible
1 parent b8be856 commit d582781

20 files changed

+239
-541
lines changed

device/alpaka/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ traccc_add_alpaka_library( traccc_alpaka alpaka TYPE SHARED
5757
"src/finding/kernels/make_barcode_sequence.hpp"
5858
"src/finding/kernels/apply_interaction.hpp"
5959
"src/finding/kernels/fill_sort_keys.hpp"
60-
"src/finding/kernels/prune_tracks.hpp"
6160
"src/finding/kernels/build_tracks.hpp"
6261
"src/finding/kernels/find_tracks.hpp"
6362
"src/finding/kernels/propagate_to_next_surface.hpp"

device/alpaka/src/finding/finding_algorithm.cpp

Lines changed: 38 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#include "./kernels/find_tracks.hpp"
1818
#include "./kernels/make_barcode_sequence.hpp"
1919
#include "./kernels/propagate_to_next_surface.hpp"
20-
#include "./kernels/prune_tracks.hpp"
2120
#include "traccc/definitions/primitives.hpp"
2221
#include "traccc/definitions/qualifiers.hpp"
2322
#include "traccc/edm/device/sort_key.hpp"
@@ -219,6 +218,9 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
219218
n_in_params * m_cfg.max_num_branches_per_surface, m_mr.main);
220219
m_copy.setup(updated_liveness_buffer)->ignore();
221220

221+
// Reset the number of tracks per seed
222+
m_copy.memset(n_tracks_per_seed_buffer, 0)->ignore();
223+
222224
const unsigned int links_size = m_copy.get_size(links_buffer);
223225

224226
if (links_size + n_max_candidates > link_buffer_capacity) {
@@ -259,20 +261,22 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
259261
::alpaka::allocBuf<PayloadType, Idx>(devHost, 1u);
260262
PayloadType* payload = ::alpaka::getPtrNative(bufHost_payload);
261263

262-
new (payload) PayloadType{
263-
.det_data = det_view,
264-
.measurements_view = measurements,
265-
.in_params_view = in_params_buffer,
266-
.in_params_liveness_view = param_liveness_buffer,
267-
.n_in_params = n_in_params,
268-
.barcodes_view = barcodes_buffer,
269-
.upper_bounds_view = upper_bounds_buffer,
270-
.links_view = links_buffer,
271-
.prev_links_idx = prev_link_idx,
272-
.curr_links_idx = step_to_link_idx_map[step],
273-
.step = step,
274-
.out_params_view = updated_params_buffer,
275-
.out_params_liveness_view = updated_liveness_buffer};
264+
new (payload)
265+
PayloadType{.det_data = det_view,
266+
.measurements_view = measurements,
267+
.in_params_view = in_params_buffer,
268+
.in_params_liveness_view = param_liveness_buffer,
269+
.n_in_params = n_in_params,
270+
.barcodes_view = barcodes_buffer,
271+
.upper_bounds_view = upper_bounds_buffer,
272+
.links_view = links_buffer,
273+
.prev_links_idx = prev_link_idx,
274+
.curr_links_idx = step_to_link_idx_map[step],
275+
.step = step,
276+
.out_params_view = updated_params_buffer,
277+
.out_params_liveness_view = updated_liveness_buffer,
278+
.tips_view = tips_buffer,
279+
.n_tracks_per_seed_view = n_tracks_per_seed_buffer};
276280

277281
auto bufAcc_payload =
278282
::alpaka::allocBuf<PayloadType, Idx>(devAcc, 1u);
@@ -294,6 +298,10 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
294298
::alpaka::wait(queue);
295299
}
296300

301+
if (step == m_cfg.max_track_candidates_per_track - 1) {
302+
break;
303+
}
304+
297305
if (n_candidates > 0) {
298306
/*****************************************************************
299307
* Kernel4: Get key and value for parameter sorting
@@ -333,9 +341,6 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
333341
*****************************************************************/
334342

335343
{
336-
// Reset the number of tracks per seed
337-
m_copy.memset(n_tracks_per_seed_buffer, 0)->ignore();
338-
339344
Idx blocksPerGrid =
340345
(n_candidates + threadsPerBlock - 1) / threadsPerBlock;
341346
auto workDiv = makeWorkDiv<Acc>(blocksPerGrid, threadsPerBlock);
@@ -348,18 +353,17 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
348353
::alpaka::allocBuf<PayloadType, Idx>(devHost, 1u);
349354
PayloadType* payload = ::alpaka::getPtrNative(bufHost_payload);
350355

351-
new (payload) PayloadType{
352-
.det_data = det_view,
353-
.field_data = field_view,
354-
.params_view = in_params_buffer,
355-
.params_liveness_view = param_liveness_buffer,
356-
.param_ids_view = param_ids_buffer,
357-
.links_view = links_buffer,
358-
.prev_links_idx = step_to_link_idx_map[step],
359-
.step = step,
360-
.n_in_params = n_candidates,
361-
.tips_view = tips_buffer,
362-
.n_tracks_per_seed_view = n_tracks_per_seed_buffer};
356+
new (payload)
357+
PayloadType{.det_data = det_view,
358+
.field_data = field_view,
359+
.params_view = in_params_buffer,
360+
.params_liveness_view = param_liveness_buffer,
361+
.param_ids_view = param_ids_buffer,
362+
.links_view = links_buffer,
363+
.prev_links_idx = step_to_link_idx_map[step],
364+
.step = step,
365+
.n_in_params = n_candidates,
366+
.tips_view = tips_buffer};
363367

364368
auto bufAcc_payload =
365369
::alpaka::allocBuf<PayloadType, Idx>(devAcc, 1u);
@@ -403,65 +407,21 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
403407
m_copy.setup(track_candidates_buffer.headers)->ignore();
404408
m_copy.setup(track_candidates_buffer.items)->ignore();
405409

406-
// Create buffer for valid indices
407-
vecmem::data::vector_buffer<unsigned int> valid_indices_buffer(n_tips_total,
408-
m_mr.main);
409-
410-
// Count the number of valid tracks
411-
auto bufHost_n_valid_tracks =
412-
::alpaka::allocBuf<unsigned int, Idx>(devHost, 1u);
413-
unsigned int* n_valid_tracks =
414-
::alpaka::getPtrNative(bufHost_n_valid_tracks);
415-
::alpaka::memset(queue, bufHost_n_valid_tracks, 0);
416-
::alpaka::wait(queue);
417-
418410
// @Note: nBlocks can be zero in case there is no tip. This happens when
419411
// chi2_max config is set tightly and no tips are found
420412
if (n_tips_total > 0) {
421-
auto n_valid_tracks_device =
422-
::alpaka::allocBuf<unsigned int, Idx>(devAcc, 1u);
423-
::alpaka::memset(queue, n_valid_tracks_device, 0);
424-
425413
Idx blocksPerGrid =
426414
(n_tips_total + threadsPerBlock - 1) / threadsPerBlock;
427415
auto workDiv = makeWorkDiv<Acc>(blocksPerGrid, threadsPerBlock);
428416

429417
::alpaka::exec<Acc>(
430-
queue, workDiv, BuildTracksKernel{}, m_cfg,
431-
device::build_tracks_payload{
432-
measurements, seeds_view, links_buffer, tips_buffer,
433-
track_candidates_buffer, valid_indices_buffer,
434-
::alpaka::getPtrNative(n_valid_tracks_device)});
435-
::alpaka::wait(queue);
436-
437-
// Global counter object: Device -> Host
438-
::alpaka::memcpy(queue, bufHost_n_valid_tracks, n_valid_tracks_device);
439-
::alpaka::wait(queue);
440-
}
441-
442-
// Create pruned candidate buffer
443-
track_candidate_container_types::buffer prune_candidates_buffer{
444-
{*n_valid_tracks, m_mr.main},
445-
{std::vector<std::size_t>(*n_valid_tracks,
446-
m_cfg.max_track_candidates_per_track),
447-
m_mr.main, m_mr.host, vecmem::data::buffer_type::resizable}};
448-
449-
m_copy.setup(prune_candidates_buffer.headers)->ignore();
450-
m_copy.setup(prune_candidates_buffer.items)->ignore();
451-
452-
if (*n_valid_tracks > 0) {
453-
Idx blocksPerGrid =
454-
(*n_valid_tracks + threadsPerBlock - 1) / threadsPerBlock;
455-
auto workDiv = makeWorkDiv<Acc>(blocksPerGrid, threadsPerBlock);
456-
457-
::alpaka::exec<Acc>(queue, workDiv, PruneTracksKernel{},
458-
device::prune_tracks_payload{
459-
track_candidates_buffer, valid_indices_buffer,
460-
prune_candidates_buffer});
418+
queue, workDiv, BuildTracksKernel{},
419+
device::build_tracks_payload{measurements, seeds_view, links_buffer,
420+
tips_buffer, track_candidates_buffer});
461421
::alpaka::wait(queue);
462422
}
463423

464-
return prune_candidates_buffer;
424+
return track_candidates_buffer;
465425
}
466426

467427
// Explicit template instantiation

device/alpaka/src/finding/kernels/build_tracks.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ namespace traccc::alpaka {
1919

2020
struct BuildTracksKernel {
2121
template <typename TAcc>
22-
ALPAKA_FN_ACC void operator()(TAcc const& acc, const finding_config cfg,
22+
ALPAKA_FN_ACC void operator()(TAcc const& acc,
2323
device::build_tracks_payload payload) const {
2424

2525
device::global_index_t globalThreadIdx =
2626
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0];
2727

28-
device::build_tracks(globalThreadIdx, cfg, payload);
28+
device::build_tracks(globalThreadIdx, payload);
2929
}
3030
};
3131

device/alpaka/src/finding/kernels/prune_tracks.hpp

Lines changed: 0 additions & 28 deletions
This file was deleted.

device/common/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,12 @@ traccc_add_library( traccc_device_common device_common TYPE INTERFACE
6464
"include/traccc/finding/device/fill_sort_keys.hpp"
6565
"include/traccc/finding/device/make_barcode_sequence.hpp"
6666
"include/traccc/finding/device/propagate_to_next_surface.hpp"
67-
"include/traccc/finding/device/prune_tracks.hpp"
6867
"include/traccc/finding/device/impl/apply_interaction.ipp"
6968
"include/traccc/finding/device/impl/build_tracks.ipp"
7069
"include/traccc/finding/device/impl/find_tracks.ipp"
7170
"include/traccc/finding/device/impl/fill_sort_keys.ipp"
7271
"include/traccc/finding/device/impl/make_barcode_sequence.ipp"
7372
"include/traccc/finding/device/impl/propagate_to_next_surface.ipp"
74-
"include/traccc/finding/device/impl/prune_tracks.ipp"
7573
# Track fitting funtions(s).
7674
"include/traccc/fitting/device/fit.hpp"
7775
"include/traccc/fitting/device/impl/fit.ipp"

device/common/include/traccc/finding/device/build_tracks.hpp

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,6 @@ struct build_tracks_payload {
5252
* @brief View object to the vector of track candidates
5353
*/
5454
track_candidate_container_types::view track_candidates_view;
55-
56-
/**
57-
* @brief View object to the vector of indices meeting the selection
58-
* criteria
59-
*/
60-
vecmem::data::vector_view<unsigned int> valid_indices_view;
61-
62-
/**
63-
* @brief The number of valid tracks meeting criteria
64-
*/
65-
unsigned int* n_valid_tracks;
6655
};
6756

6857
/// Function for building full tracks from the link container:
@@ -75,8 +64,7 @@ struct build_tracks_payload {
7564
/// @param[inout] payload The function call payload
7665
///
7766
TRACCC_HOST_DEVICE inline void build_tracks(
78-
global_index_t globalIndex, const finding_config& cfg,
79-
const build_tracks_payload& payload);
67+
global_index_t globalIndex, const build_tracks_payload& payload);
8068

8169
} // namespace traccc::device
8270

device/common/include/traccc/finding/device/find_tracks.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,17 @@ struct find_tracks_payload {
9898
* @brief View object to the output track parameter liveness vector
9999
*/
100100
vecmem::data::vector_view<unsigned int> out_params_liveness_view;
101+
102+
/**
103+
* @brief View object to the vector of tips
104+
*/
105+
vecmem::data::vector_view<unsigned int> tips_view;
106+
107+
/**
108+
* @brief View object to the vector of the number of tracks per initial
109+
* input seed
110+
*/
111+
vecmem::data::vector_view<unsigned int> n_tracks_per_seed_view;
101112
};
102113

103114
/// (Shared Event Data) Payload for the @c traccc::device::find_tracks function

device/common/include/traccc/finding/device/impl/build_tracks.ipp

Lines changed: 13 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
namespace traccc::device {
1414

1515
TRACCC_HOST_DEVICE inline void build_tracks(
16-
const global_index_t globalIndex, const finding_config& cfg,
17-
const build_tracks_payload& payload) {
16+
const global_index_t globalIndex, const build_tracks_payload& payload) {
1817

1918
const measurement_collection_types::const_device measurements(
2019
payload.measurements_view);
@@ -29,9 +28,6 @@ TRACCC_HOST_DEVICE inline void build_tracks(
2928
track_candidate_container_types::device track_candidates(
3029
payload.track_candidates_view);
3130

32-
vecmem::device_vector<unsigned int> valid_indices(
33-
payload.valid_indices_view);
34-
3531
if (globalIndex >= tips.size()) {
3632
return;
3733
}
@@ -50,8 +46,6 @@ TRACCC_HOST_DEVICE inline void build_tracks(
5046
// Resize the candidates with the exact size
5147
cands_per_track.resize(n_cands);
5248

53-
bool success = true;
54-
5549
// Track summary variables
5650
scalar ndf_sum = 0.f;
5751
scalar chi2_sum = 0.f;
@@ -67,11 +61,7 @@ TRACCC_HOST_DEVICE inline void build_tracks(
6761
L = links.at(L.previous_candidate_idx);
6862
}
6963

70-
// Break if the measurement is still invalid
71-
if (L.meas_idx >= measurements.size()) {
72-
success = false;
73-
break;
74-
}
64+
assert(L.meas_idx < n_meas);
7565

7666
*it = {measurements.at(L.meas_idx)};
7767
num_inserted++;
@@ -97,36 +87,21 @@ TRACCC_HOST_DEVICE inline void build_tracks(
9787
}
9888

9989
#ifndef NDEBUG
100-
if (success) {
101-
// Assert that we inserted exactly as many elements as we reserved
102-
// space for.
103-
assert(num_inserted == cands_per_track.size());
104-
105-
// Assert that we did not make any duplicate track states.
106-
for (unsigned int i = 0; i < cands_per_track.size(); ++i) {
107-
for (unsigned int j = 0; j < cands_per_track.size(); ++j) {
108-
if (i != j) {
109-
// TODO: Re-enable me!
110-
// assert(cands_per_track.at(i).measurement_id !=
111-
// cands_per_track.at(j).measurement_id);
112-
}
90+
// Assert that we inserted exactly as many elements as we reserved
91+
// space for.
92+
assert(num_inserted == cands_per_track.size());
93+
94+
// Assert that we did not make any duplicate track states.
95+
for (unsigned int i = 0; i < cands_per_track.size(); ++i) {
96+
for (unsigned int j = 0; j < cands_per_track.size(); ++j) {
97+
if (i != j) {
98+
// TODO: Re-enable me!
99+
// assert(cands_per_track.at(i).measurement_id !=
100+
// cands_per_track.at(j).measurement_id);
113101
}
114102
}
115103
}
116104
#endif
117-
118-
// NOTE: We may at some point want to assert that `success` is true
119-
120-
// Criteria for valid tracks
121-
if (n_cands >= cfg.min_track_candidates_per_track &&
122-
n_cands <= cfg.max_track_candidates_per_track && success) {
123-
124-
vecmem::device_atomic_ref<unsigned int> num_valid_tracks(
125-
*payload.n_valid_tracks);
126-
127-
const unsigned int pos = num_valid_tracks.fetch_add(1);
128-
valid_indices[pos] = globalIndex;
129-
}
130105
}
131106

132107
} // namespace traccc::device

0 commit comments

Comments
 (0)