Skip to content

Commit 4193e1c

Browse files
committed
Reduce thread divergence in covariance transport
Currently, the covariance transport uses the Jacobian engine which is templated on the frame type. While this makes the code easier to read and write, it requires the compiler to duplicate a lot of code for each of the frame types, and this duplicated code counts as multiple branches for the sake of GPU execution. Thus, this templating increases the amount of thread divergence. This commit refactors the Jacobian engine into smaller parts, some of which are templated on the frame type and some of which are not. Client code and then take a more fine-grained approach to branching and improve divergence.
1 parent 861e70e commit 4193e1c

File tree

10 files changed

+215
-117
lines changed

10 files changed

+215
-117
lines changed

core/include/detray/geometry/detail/tracking_surface_kernels.hpp

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
// Project include(s)
1111
#include "detray/definitions/detail/qualifiers.hpp"
12+
#include "detray/definitions/track_parametrization.hpp"
1213
#include "detray/geometry/detail/surface_kernels.hpp"
1314
#include "detray/propagator/detail/jacobian_engine.hpp"
1415
#include "detray/tracks/detail/transform_track_parameters.hpp"
@@ -28,6 +29,7 @@ struct tracking_surface_kernels : public surface_kernels<algebra_t> {
2829
using bound_param_vector_type = bound_parameters_vector<algebra_t>;
2930
using free_param_vector_type = free_parameters_vector<algebra_t>;
3031
using free_matrix_type = free_matrix<algebra_t>;
32+
using free_to_path_matrix_type = free_to_path_matrix<algebra_t>;
3133

3234
/// A functor to get from a free to a bound vector
3335
struct free_to_bound_vector {
@@ -71,8 +73,9 @@ struct tracking_surface_kernels : public surface_kernels<algebra_t> {
7173

7274
using frame_t = typename mask_group_t::value_type::local_frame;
7375

74-
return detail::jacobian_engine<frame_t>::free_to_bound_jacobian(
75-
trf3, free_vec);
76+
return detail::jacobian_engine<
77+
algebra_t>::template free_to_bound_jacobian<frame_t>(trf3,
78+
free_vec);
7679
}
7780
};
7881

@@ -87,8 +90,9 @@ struct tracking_surface_kernels : public surface_kernels<algebra_t> {
8790

8891
using frame_t = typename mask_group_t::value_type::local_frame;
8992

90-
return detail::jacobian_engine<frame_t>::bound_to_free_jacobian(
91-
trf3, mask_group[index], bound_vec);
93+
return detail::jacobian_engine<algebra_t>::
94+
template bound_to_free_jacobian<frame_t>(
95+
trf3, mask_group[index], bound_vec);
9296
}
9397
};
9498

@@ -105,10 +109,27 @@ struct tracking_surface_kernels : public surface_kernels<algebra_t> {
105109

106110
using frame_t = typename mask_group_t::value_type::local_frame;
107111

108-
return detail::jacobian_engine<frame_t>::path_correction(
109-
pos, dir, dtds, dqopds, trf3);
112+
return detail::jacobian_engine<algebra_t>::template path_correction<
113+
frame_t>(pos, dir, dtds, dqopds, trf3);
110114
}
111115
};
112-
};
113116

117+
/// A function object to get the free to path derivative
118+
struct free_to_path_derivative {
119+
120+
template <typename mask_group_t, typename index_t>
121+
DETRAY_HOST_DEVICE inline free_to_path_matrix_type operator()(
122+
const mask_group_t& /*mask_group*/, const index_t& /*index*/,
123+
const transform3_type& trf3, const vector3_type& pos,
124+
const vector3_type& dir, const vector3_type& dtds) const {
125+
126+
using frame_t = typename mask_group_t::value_type::local_frame;
127+
128+
return detail::jacobian_engine<
129+
algebra_t>::template free_to_path_derivative<frame_t>(pos, dir,
130+
dtds,
131+
trf3);
132+
}
133+
};
134+
};
114135
} // namespace detray::detail

core/include/detray/geometry/tracking_surface.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,17 @@ class tracking_surface : public geometry::surface<detector_t> {
100100
return this->template visit_mask<typename kernels::path_correction>(
101101
this->transform(ctx), pos, dir, dtds, dqopds);
102102
}
103+
104+
/// @returns the free to path derivative
105+
DETRAY_HOST_DEVICE
106+
constexpr auto free_to_path_derivative(const context &ctx,
107+
const vector3_type &pos,
108+
const vector3_type &dir,
109+
const vector3_type &dtds) const {
110+
return this
111+
->template visit_mask<typename kernels::free_to_path_derivative>(
112+
this->transform(ctx), pos, dir, dtds);
113+
}
103114
};
104115

105116
template <typename detector_t, typename descr_t>

core/include/detray/propagator/actors/parameter_transporter.hpp

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -29,44 +29,29 @@ struct parameter_transporter : actor {
2929
using bound_matrix_t = bound_matrix<algebra_t>;
3030
// Matrix type for bound to free jacobian
3131
using bound_to_free_matrix_t = bound_to_free_matrix<algebra_t>;
32+
// Matrix type for free to bound jacobian
33+
using free_to_bound_matrix_t = free_to_bound_matrix<algebra_t>;
3234
/// @}
3335

34-
struct get_full_jacobian_kernel {
35-
36+
struct get_free_to_bound_jacobian_kernel {
3637
template <typename mask_group_t, typename index_t,
3738
typename stepper_state_t>
38-
DETRAY_HOST_DEVICE inline bound_matrix_t operator()(
39+
DETRAY_HOST_DEVICE inline free_to_bound_matrix_t operator()(
3940
const mask_group_t& /*mask_group*/, const index_t& /*index*/,
4041
const transform3_type& trf3,
41-
const bound_to_free_matrix_t& bound_to_free_jacobian,
42-
const material<scalar_type>* vol_mat_ptr,
4342
const stepper_state_t& stepping) const {
44-
4543
using frame_t = typename mask_group_t::value_type::shape::
4644
template local_frame_type<algebra_t>;
4745

48-
using jacobian_engine_t = detail::jacobian_engine<frame_t>;
49-
50-
using free_matrix_t = free_matrix<algebra_t>;
51-
using free_to_bound_matrix_t =
52-
typename jacobian_engine_t::free_to_bound_matrix_type;
46+
// Declare jacobian for bound to free coordinate transform
47+
free_to_bound_matrix_t jac_to_local =
48+
matrix::zero<free_to_bound_matrix_t>();
5349

54-
// Free to bound jacobian at the destination surface
55-
const free_to_bound_matrix_t free_to_bound_jacobian =
56-
jacobian_engine_t::free_to_bound_jacobian(trf3, stepping());
50+
detail::jacobian_engine<algebra_t>::
51+
template free_to_bound_jacobian_step_1<frame_t>(
52+
jac_to_local, trf3, stepping().pos(), stepping().dir());
5753

58-
// Path correction factor
59-
const free_matrix_t path_correction =
60-
jacobian_engine_t::path_correction(
61-
stepping().pos(), stepping().dir(), stepping.dtds(),
62-
stepping.dqopds(vol_mat_ptr), trf3);
63-
64-
const free_matrix_t correction_term =
65-
matrix::identity<free_matrix_t>() + path_correction;
66-
67-
return free_to_bound_jacobian *
68-
(correction_term *
69-
(stepping.transport_jacobian() * bound_to_free_jacobian));
54+
return jac_to_local;
7055
}
7156
};
7257

@@ -143,9 +128,30 @@ struct parameter_transporter : actor {
143128
? vol.material_parameters(stepping().pos())
144129
: nullptr;
145130

146-
return sf.template visit_mask<get_full_jacobian_kernel>(
147-
sf.transform(gctx), bound_to_free_jacobian, vol_mat_ptr,
148-
propagation._stepping);
131+
auto free_to_bound_jacobian =
132+
sf.template visit_mask<get_free_to_bound_jacobian_kernel>(
133+
sf.transform(gctx), propagation._stepping);
134+
135+
detail::jacobian_engine<algebra_t>::free_to_bound_jacobian_step_2(
136+
free_to_bound_jacobian, stepping().dir());
137+
138+
const auto path_to_free_derivative =
139+
detail::jacobian_engine<algebra_t>::path_to_free_derivative(
140+
stepping().dir(), stepping.dtds(),
141+
stepping.dqopds(vol_mat_ptr));
142+
143+
const auto free_to_path_derivative = sf.free_to_path_derivative(
144+
gctx, stepping().pos(), stepping().dir(), stepping.dtds());
145+
146+
const auto path_correction =
147+
path_to_free_derivative * free_to_path_derivative;
148+
149+
const auto correction_term =
150+
matrix::identity<free_matrix<algebra_t>>() + path_correction;
151+
152+
return free_to_bound_jacobian *
153+
(correction_term *
154+
(stepping.transport_jacobian() * bound_to_free_jacobian));
149155
}
150156

151157
}; // namespace detray

0 commit comments

Comments
 (0)