Skip to content

Commit ce40d8c

Browse files
authored
Implement faster annulus check (#1100)
Implement annulus check from global coordinates
1 parent 820b09d commit ce40d8c

File tree

7 files changed

+103
-51
lines changed

7 files changed

+103
-51
lines changed

core/include/detray/geometry/shapes/annulus2D.hpp

Lines changed: 91 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "detray/definitions/indexing.hpp"
1515
#include "detray/definitions/math.hpp"
1616
#include "detray/definitions/units.hpp"
17+
#include "detray/geometry/coordinates/cartesian3D.hpp"
1718
#include "detray/geometry/coordinates/polar2D.hpp"
1819
#include "detray/geometry/detail/shape_utils.hpp"
1920
#include "detray/geometry/detail/vertexer.hpp"
@@ -160,11 +161,60 @@ class annulus2D {
160161
std::numeric_limits<dscalar<algebra_t>>::epsilon(),
161162
const dscalar<algebra_t> edge_tol = 0.f) const {
162163

163-
// Get the full local position
164-
const dpoint2D<algebra_t> loc_p =
165-
local_frame_type<algebra_t>::global_to_local(trf, glob_p, {});
164+
using scalar_t = dscalar<algebra_t>;
165+
166+
// Move point to local plane: Focal frame in cartesian coordinates
167+
const auto loc_3D{
168+
cartesian3D<algebra_t>::global_to_local(trf, glob_p, {})};
169+
170+
// Shift local 3D position into beam frame to check the radius
171+
const scalar_t new_x{loc_3D[0] + bounds[e_shift_x]};
172+
const scalar_t new_y{loc_3D[1] + bounds[e_shift_y]};
173+
174+
const scalar_t r_beam{
175+
math::sqrt(math::fma(new_x, new_x, new_y * new_y))};
176+
177+
auto inside_mask = ((r_beam + tol) >= bounds[e_min_r]) &&
178+
(r_beam <= (bounds[e_max_r] + tol));
179+
180+
// Try to avoid the costly phi calculation
181+
auto phi_focal{detail::invalid_value<scalar_t>()};
182+
if (detail::any_of(inside_mask)) {
183+
// Get phi for phi-bounds check and rotate by average phi
184+
phi_focal = vector::phi(loc_3D) - bounds[e_average_phi];
185+
// Estimate angular tolerance along r
186+
const scalar_t phi_tol{detail::phi_tolerance(tol, r_beam)};
187+
188+
inside_mask = (phi_focal >= (bounds[e_min_phi_rel] - phi_tol)) &&
189+
(phi_focal <= (bounds[e_max_phi_rel] + phi_tol)) &&
190+
inside_mask;
191+
}
192+
193+
decltype(inside_mask) inside_edge{false};
194+
if (detail::any_of(edge_tol > 0.f)) {
195+
// Edge tolerance
196+
const scalar_t full_tol{tol + edge_tol};
197+
198+
inside_edge = ((r_beam + full_tol) >= bounds[e_min_r]) &&
199+
(r_beam <= (bounds[e_max_r] + full_tol));
200+
201+
if (detail::any_of(inside_edge)) {
202+
// If phi had not been calculated before, do it now
203+
if (detail::is_invalid_value(phi_focal)) {
204+
phi_focal = vector::phi(loc_3D) - bounds[e_average_phi];
205+
}
206+
207+
const scalar_t phi_tol_full{
208+
detail::phi_tolerance(full_tol, r_beam)};
209+
210+
inside_edge =
211+
(phi_focal >= (bounds[e_min_phi_rel] - phi_tol_full)) &&
212+
(phi_focal <= (bounds[e_max_phi_rel] + phi_tol_full)) &&
213+
inside_edge;
214+
}
215+
}
166216

167-
return check_boundaries(bounds, loc_p, tol, edge_tol);
217+
return result_type<decltype(inside_mask)>{inside_mask, inside_edge};
168218
}
169219

170220
/// @note the point is expected to be given in local coordinates by the
@@ -188,36 +238,54 @@ class annulus2D {
188238

189239
// Check phi boundaries, which are well def. in focal frame
190240
const scalar_t phi_tol = detail::phi_tolerance(tol, loc_p[0]);
191-
const auto phi_check =
192-
!((phi_focal < (bounds[e_min_phi_rel] - phi_tol)) ||
193-
(phi_focal > (bounds[e_max_phi_rel] + phi_tol)));
241+
auto inside_mask = !((phi_focal < (bounds[e_min_phi_rel] - phi_tol)) ||
242+
(phi_focal > (bounds[e_max_phi_rel] + phi_tol)));
243+
244+
// Try to avoid the costly r_beam calculation
245+
auto r_beam2{detail::invalid_value<scalar_t>()};
246+
if (detail::any_of(inside_mask)) {
194247

195-
const auto r_beam2 = get_r2_beam_frame(bounds, loc_p);
248+
r_beam2 = get_r2_beam_frame(bounds, loc_p);
196249

197-
// Apply tolerances as squares: 0 <= a, 0 <= b: a^2 <= b^2 <=> a <= b
198-
const scalar_t minR_tol =
199-
math::max(bounds[e_min_r] - tol, scalar_t(0.f));
200-
const scalar_t maxR_tol = bounds[e_max_r] + tol;
250+
// Apply tolerances as squares: 0 <= a, 0 <= b: a^2 <= b^2 <=> a <=
251+
// b
252+
const scalar_t minR_tol =
253+
math::max(bounds[e_min_r] - tol, scalar_t(0.f));
254+
const scalar_t maxR_tol = bounds[e_max_r] + tol;
201255

202-
assert(detail::all_of(minR_tol >= scalar_t(0.f)));
256+
assert(detail::all_of(minR_tol >= scalar_t(0.f)));
203257

204-
auto inside_mask{((r_beam2 >= (minR_tol * minR_tol)) &&
205-
(r_beam2 <= (maxR_tol * maxR_tol))) &&
206-
phi_check};
258+
inside_mask = (r_beam2 >= (minR_tol * minR_tol)) &&
259+
(r_beam2 <= (maxR_tol * maxR_tol)) && inside_mask;
260+
}
207261

208262
decltype(inside_mask) inside_edge{false};
209263
if (detail::any_of(edge_tol > 0.f)) {
210264
// Edge tolerance
211265
const scalar_t full_tol{tol + edge_tol};
212-
const scalar_t minR_tol_edge =
213-
math::max(bounds[e_min_r] - full_tol, scalar_t(0.f));
214-
const scalar_t maxR_tol_edge = bounds[e_max_r] + full_tol;
266+
const scalar_t phi_tol_full =
267+
detail::phi_tolerance(full_tol, loc_p[0]);
268+
269+
const auto phi_check_edge =
270+
(phi_focal >= (bounds[e_min_phi_rel] - phi_tol_full)) &&
271+
(phi_focal <= (bounds[e_max_phi_rel] + phi_tol_full));
272+
273+
if (detail::any_of(inside_edge)) {
274+
// If phi had not been calculated before, do it now
275+
if (detail::is_invalid_value(r_beam2)) {
276+
r_beam2 = get_r2_beam_frame(bounds, loc_p);
277+
}
278+
279+
const scalar_t minR_tol_edge =
280+
math::max(bounds[e_min_r] - full_tol, scalar_t(0.f));
281+
const scalar_t maxR_tol_edge = bounds[e_max_r] + full_tol;
215282

216-
assert(detail::all_of(minR_tol_edge >= scalar_t(0.f)));
283+
assert(detail::all_of(minR_tol_edge >= scalar_t(0.f)));
217284

218-
inside_edge = ((r_beam2 >= (minR_tol_edge * minR_tol_edge)) &&
219-
(r_beam2 <= (maxR_tol_edge * maxR_tol_edge))) &&
220-
phi_check;
285+
inside_edge = (r_beam2 >= (minR_tol_edge * minR_tol_edge)) &&
286+
(r_beam2 <= (maxR_tol_edge * maxR_tol_edge)) &&
287+
phi_check_edge;
288+
}
221289
}
222290

223291
return result_type<decltype(inside_mask)>{inside_mask, inside_edge};

tests/unit_tests/cpu/geometry/masks/annulus2D.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,7 @@ GTEST_TEST(detray_masks, annulus2D_ratio_test) {
133133
bool operator()(const point3 &p,
134134
const mask<annulus2D, test_algebra> &ann,
135135
const test::transform3 &trf, const scalar t) {
136-
137-
const point3 loc_p{ann.to_local_frame3D(trf, p)};
138-
return ann.is_inside(loc_p, t);
136+
return ann.is_inside(trf, p, t);
139137
}
140138
};
141139

tests/unit_tests/cpu/geometry/masks/cylinder.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,7 @@ GTEST_TEST(detray_masks, cylinder2D_ratio_test) {
7676
bool operator()(const point3 &p,
7777
const mask<cylinder2D, test_algebra> &cyl,
7878
const test::transform3 &trf, const scalar t) {
79-
80-
const point3 loc_p{cyl.to_local_frame3D(trf, p)};
81-
return cyl.is_inside(loc_p, t);
79+
return cyl.is_inside(trf, p, t);
8280
}
8381
};
8482

@@ -156,9 +154,7 @@ GTEST_TEST(detray_masks, cylinder3D_ratio_test) {
156154
bool operator()(const point3 &p,
157155
const mask<cylinder3D, test_algebra> &cyl,
158156
const test::transform3 &trf, const scalar t) {
159-
160-
const point3 loc_p{cyl.to_local_frame3D(trf, p)};
161-
return cyl.is_inside(loc_p, t);
157+
return cyl.is_inside(trf, p, t);
162158
}
163159
};
164160

tests/unit_tests/cpu/geometry/masks/line.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,9 @@ GTEST_TEST(detray_masks, line_circular_ratio_test) {
7979
struct mask_check {
8080
bool operator()(const point3 &p,
8181
const mask<line_circular, test_algebra> &st,
82-
const test::transform3 &trf, const test::vector3 &dir,
83-
const scalar t) {
84-
85-
const point3 loc_p{st.to_local_frame3D(trf, p, dir)};
86-
return st.is_inside(loc_p, t);
82+
const test::transform3 &trf,
83+
const test::vector3 & /*dir*/, const scalar t) {
84+
return st.is_inside(trf, p, t);
8785
}
8886
};
8987

@@ -155,11 +153,9 @@ GTEST_TEST(detray_masks, line_square_ratio_test) {
155153
struct mask_check {
156154
bool operator()(const point3 &p,
157155
const mask<line_square, test_algebra> &dcl,
158-
const test::transform3 &trf, const test::vector3 &dir,
159-
const scalar t) {
160-
161-
const point3 loc_p{dcl.to_local_frame3D(trf, p, dir)};
162-
return dcl.is_inside(loc_p, t);
156+
const test::transform3 &trf,
157+
const test::vector3 & /*dir*/, const scalar t) {
158+
return dcl.is_inside(trf, p, t);
163159
}
164160
};
165161

tests/unit_tests/cpu/geometry/masks/rectangle2D.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,7 @@ GTEST_TEST(detray_masks, rectangle2D_ratio_test) {
7676
bool operator()(const point3 &p,
7777
const mask<rectangle2D, test_algebra> &r,
7878
const test::transform3 &trf, const scalar t) {
79-
80-
const point3 loc_p{r.to_local_frame3D(trf, p)};
81-
return r.is_inside(loc_p, t);
79+
return r.is_inside(trf, p, t);
8280
}
8381
};
8482

tests/unit_tests/cpu/geometry/masks/ring2D.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,7 @@ GTEST_TEST(detray_masks, ring2D_ratio_test) {
7474
struct mask_check {
7575
bool operator()(const point3 &p, const mask<ring2D, test_algebra> &r,
7676
const test::transform3 &trf, const scalar t) {
77-
78-
const point3 loc_p{r.to_local_frame3D(trf, p)};
79-
return r.is_inside(loc_p, t);
77+
return r.is_inside(trf, p, t);
8078
}
8179
};
8280

tests/unit_tests/cpu/geometry/masks/trapezoid2D.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,7 @@ GTEST_TEST(detray_masks, trapezoid2D_ratio_test) {
7878
bool operator()(const point3 &p,
7979
const mask<trapezoid2D, test_algebra> &tp,
8080
const test::transform3 &trf, const scalar t) {
81-
82-
const point3 loc_p{tp.to_local_frame3D(trf, p)};
83-
return tp.is_inside(loc_p, t);
81+
return tp.is_inside(trf, p, t);
8482
}
8583
};
8684

0 commit comments

Comments
 (0)