Skip to content

Commit

Permalink
Merge pull request #506 from beomki-yeo/rk-tolerance
Browse files Browse the repository at this point in the history
Add rk tolerance option
  • Loading branch information
beomki-yeo authored Dec 12, 2023
2 parents 103d9f9 + e36b765 commit e91fb9b
Show file tree
Hide file tree
Showing 17 changed files with 43 additions and 7 deletions.
3 changes: 2 additions & 1 deletion core/include/traccc/finding/finding_algorithm.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,9 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
m_cfg.overstep_tolerance);
propagation._stepping.template set_constraint<
detray::step::constraint::e_accuracy>(
m_cfg.constrained_step_size);
m_cfg.step_constraint);
propagation.set_mask_tolerance(m_cfg.mask_tolerance);
propagation._stepping.set_tolerance(m_cfg.rk_tolerance);

typename detray::pathlimit_aborter::state s0;
typename detray::parameter_transporter<
Expand Down
3 changes: 2 additions & 1 deletion core/include/traccc/finding/finding_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ struct finding_config {

/// Constrained step size for propagation
/// @TODO: Make a separate file for propagation config?
scalar_t constrained_step_size = std::numeric_limits<scalar_t>::max();
scalar_t step_constraint = std::numeric_limits<scalar_t>::max();

scalar_t overstep_tolerance = -100 * detray::unit<scalar_t>::um;
scalar_t mask_tolerance = 15.f * detray::unit<scalar_t>::um;
scalar_t rk_tolerance = 1e-4;

/// GPU-specific parameter for the number of measurements to be
/// iterated per thread
Expand Down
1 change: 1 addition & 0 deletions core/include/traccc/fitting/fitting_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ struct fitting_config {
scalar_t overstep_tolerance = -100 * detray::unit<scalar_t>::um;
scalar_t step_constraint = std::numeric_limits<scalar_t>::max();
scalar_t mask_tolerance = 15.f * detray::unit<scalar_t>::um;
scalar_t rk_tolerance = 1e-4;
};

} // namespace traccc
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ class kalman_fitter {
.template set_constraint<detray::step::constraint::e_accuracy>(
m_cfg.step_constraint);
propagation.set_mask_tolerance(m_cfg.mask_tolerance);
propagation._stepping.set_tolerance(m_cfg.rk_tolerance);

// Run forward filtering
propagator.propagate(propagation, fitter_state());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ TRACCC_DEVICE inline void propagate_to_next_surface(
propagation._stepping().set_overstep_tolerance(cfg.overstep_tolerance);
propagation._stepping
.template set_constraint<detray::step::constraint::e_accuracy>(
cfg.constrained_step_size);
cfg.step_constraint);
propagation.set_mask_tolerance(cfg.mask_tolerance);
propagation._stepping.set_tolerance(cfg.rk_tolerance);

// Actor state
// @TODO: simplify the syntax here
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ struct propagation_options {
scalar_t step_constraint{std::numeric_limits<scalar_t>::max()};
scalar_t overstep_tolerance{-100.f * detray::unit<scalar_t>::um};
scalar_t mask_tolerance{15.f * detray::unit<scalar_t>::um};
scalar_t rk_tolerance{1e-4};

propagation_options(po::options_description& desc) {
desc.add_options()("constraint-step-size-mm",
Expand All @@ -40,6 +41,9 @@ struct propagation_options {
desc.add_options()("mask-tolerance-um",
po::value<scalar_t>()->default_value(15.f),
"The mask tolerance [um]");
desc.add_options()("rk-tolerance",
po::value<scalar_t>()->default_value(1e-4),
"The Runge-Kutta stepper tolerance");
}

void read(const po::variables_map& vm) {
Expand All @@ -49,6 +53,7 @@ struct propagation_options {
detray::unit<scalar_t>::um;
mask_tolerance =
vm["mask-tolerance-um"].as<scalar_t>() * detray::unit<scalar_t>::um;
rk_tolerance = vm["rk-tolerance"].as<scalar_t>();
}
};

Expand Down
5 changes: 4 additions & 1 deletion examples/run/cpu/seeding_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,10 @@ int seq_run(const traccc::seeding_input_config& /*i_cfg*/,
cfg.min_track_candidates_per_track = finding_cfg.track_candidates_range[0];
cfg.max_track_candidates_per_track = finding_cfg.track_candidates_range[1];
cfg.chi2_max = finding_cfg.chi2_max;
cfg.constrained_step_size = propagation_opts.step_constraint;
cfg.step_constraint = propagation_opts.step_constraint;
cfg.overstep_tolerance = propagation_opts.overstep_tolerance;
cfg.mask_tolerance = propagation_opts.mask_tolerance;
cfg.rk_tolerance = propagation_opts.rk_tolerance;

traccc::finding_algorithm<rk_stepper_type, host_navigator_type>
host_finding(cfg);
Expand All @@ -141,6 +142,8 @@ int seq_run(const traccc::seeding_input_config& /*i_cfg*/,
fit_cfg.step_constraint = propagation_opts.step_constraint;
fit_cfg.overstep_tolerance = propagation_opts.overstep_tolerance;
fit_cfg.mask_tolerance = propagation_opts.mask_tolerance;
fit_cfg.rk_tolerance = propagation_opts.rk_tolerance;

traccc::fitting_algorithm<host_fitter_type> host_fitting(fit_cfg);

// Loop over events
Expand Down
5 changes: 4 additions & 1 deletion examples/run/cpu/truth_finding_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,10 @@ int seq_run(const traccc::finding_input_config<traccc::scalar>& i_cfg,
cfg.min_track_candidates_per_track = i_cfg.track_candidates_range[0];
cfg.max_track_candidates_per_track = i_cfg.track_candidates_range[1];
cfg.chi2_max = i_cfg.chi2_max;
cfg.constrained_step_size = propagation_opts.step_constraint;
cfg.step_constraint = propagation_opts.step_constraint;
cfg.overstep_tolerance = propagation_opts.overstep_tolerance;
cfg.mask_tolerance = propagation_opts.mask_tolerance;
cfg.rk_tolerance = propagation_opts.rk_tolerance;

// Finding algorithm object
traccc::finding_algorithm<rk_stepper_type, host_navigator_type>
Expand All @@ -126,6 +127,8 @@ int seq_run(const traccc::finding_input_config<traccc::scalar>& i_cfg,
fit_cfg.step_constraint = propagation_opts.step_constraint;
fit_cfg.overstep_tolerance = propagation_opts.overstep_tolerance;
fit_cfg.mask_tolerance = propagation_opts.mask_tolerance;
fit_cfg.rk_tolerance = propagation_opts.rk_tolerance;

traccc::fitting_algorithm<host_fitter_type> host_fitting(fit_cfg);

// Seed generator
Expand Down
2 changes: 2 additions & 0 deletions examples/run/cpu/truth_fitting_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ int main(int argc, char* argv[]) {
fit_cfg.step_constraint = propagation_opts.step_constraint;
fit_cfg.overstep_tolerance = propagation_opts.overstep_tolerance;
fit_cfg.mask_tolerance = propagation_opts.mask_tolerance;
fit_cfg.rk_tolerance = propagation_opts.rk_tolerance;

traccc::fitting_algorithm<host_fitter_type> host_fitting(fit_cfg);

// Seed generator
Expand Down
8 changes: 7 additions & 1 deletion examples/run/cuda/seeding_example_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,10 @@ int seq_run(const traccc::seeding_input_config& /*i_cfg*/,
cfg.min_track_candidates_per_track = finding_cfg.track_candidates_range[0];
cfg.max_track_candidates_per_track = finding_cfg.track_candidates_range[1];
cfg.chi2_max = finding_cfg.chi2_max;
cfg.constrained_step_size = propagation_opts.step_constraint;
cfg.step_constraint = propagation_opts.step_constraint;
cfg.overstep_tolerance = propagation_opts.overstep_tolerance;
cfg.mask_tolerance = propagation_opts.mask_tolerance;
cfg.rk_tolerance = propagation_opts.rk_tolerance;

// Finding algorithm object
traccc::finding_algorithm<rk_stepper_type, host_navigator_type>
Expand All @@ -191,6 +194,9 @@ int seq_run(const traccc::seeding_input_config& /*i_cfg*/,
// Fitting algorithm object
typename traccc::fitting_algorithm<host_fitter_type>::config_type fit_cfg;
fit_cfg.step_constraint = propagation_opts.step_constraint;
fit_cfg.overstep_tolerance = propagation_opts.overstep_tolerance;
fit_cfg.mask_tolerance = propagation_opts.mask_tolerance;
fit_cfg.rk_tolerance = propagation_opts.rk_tolerance;

traccc::fitting_algorithm<host_fitter_type> host_fitting(fit_cfg);
traccc::cuda::fitting_algorithm<device_fitter_type> device_fitting(
Expand Down
4 changes: 3 additions & 1 deletion examples/run/cuda/truth_finding_example_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,10 @@ int seq_run(const traccc::finding_input_config<traccc::scalar>& i_cfg,
cfg.min_track_candidates_per_track = i_cfg.track_candidates_range[0];
cfg.max_track_candidates_per_track = i_cfg.track_candidates_range[1];
cfg.chi2_max = i_cfg.chi2_max;
cfg.constrained_step_size = propagation_opts.step_constraint;
cfg.step_constraint = propagation_opts.step_constraint;
cfg.overstep_tolerance = propagation_opts.overstep_tolerance;
cfg.mask_tolerance = propagation_opts.mask_tolerance;
cfg.rk_tolerance = propagation_opts.rk_tolerance;

// Finding algorithm object
traccc::finding_algorithm<rk_stepper_type, host_navigator_type>
Expand All @@ -171,6 +172,7 @@ int seq_run(const traccc::finding_input_config<traccc::scalar>& i_cfg,
fit_cfg.step_constraint = propagation_opts.step_constraint;
fit_cfg.overstep_tolerance = propagation_opts.overstep_tolerance;
fit_cfg.mask_tolerance = propagation_opts.mask_tolerance;
fit_cfg.rk_tolerance = propagation_opts.rk_tolerance;

traccc::fitting_algorithm<host_fitter_type> host_fitting(fit_cfg);
traccc::cuda::fitting_algorithm<device_fitter_type> device_fitting(
Expand Down
4 changes: 4 additions & 0 deletions examples/run/cuda/truth_fitting_example_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ int main(int argc, char* argv[]) {
// Fitting algorithm object
typename traccc::fitting_algorithm<host_fitter_type>::config_type fit_cfg;
fit_cfg.step_constraint = propagation_opts.step_constraint;
fit_cfg.overstep_tolerance = propagation_opts.overstep_tolerance;
fit_cfg.mask_tolerance = propagation_opts.mask_tolerance;
fit_cfg.rk_tolerance = propagation_opts.rk_tolerance;

traccc::fitting_algorithm<host_fitter_type> host_fitting(fit_cfg);
traccc::cuda::fitting_algorithm<device_fitter_type> device_fitting(
fit_cfg, mr, async_copy, stream);
Expand Down
1 change: 1 addition & 0 deletions examples/simulation/simulate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ int main(int argc, char* argv[]) {
sim.get_config().step_constraint = propagation_opts.step_constraint;
sim.get_config().overstep_tolerance = propagation_opts.overstep_tolerance;
sim.get_config().mask_tolerance = propagation_opts.mask_tolerance;
sim.get_config().rk_tolerance = propagation_opts.rk_tolerance;

sim.run();

Expand Down
1 change: 1 addition & 0 deletions examples/simulation/simulate_telescope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ int simulate(std::string output_directory, unsigned int events,
sim.get_config().step_constraint = propagation_opts.step_constraint;
sim.get_config().overstep_tolerance = propagation_opts.overstep_tolerance;
sim.get_config().mask_tolerance = propagation_opts.mask_tolerance;
sim.get_config().rk_tolerance = propagation_opts.rk_tolerance;

sim.run();

Expand Down
1 change: 1 addition & 0 deletions examples/simulation/simulate_toy_detector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ int simulate(std::string output_directory, unsigned int events,
sim.get_config().step_constraint = propagation_opts.step_constraint;
sim.get_config().overstep_tolerance = propagation_opts.overstep_tolerance;
sim.get_config().mask_tolerance = propagation_opts.mask_tolerance;
sim.get_config().rk_tolerance = propagation_opts.rk_tolerance;

sim.run();

Expand Down
1 change: 1 addition & 0 deletions examples/simulation/simulate_wire_chamber.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ int simulate(std::string output_directory, unsigned int events,
sim.get_config().step_constraint = propagation_opts.step_constraint;
sim.get_config().overstep_tolerance = propagation_opts.overstep_tolerance;
sim.get_config().mask_tolerance = propagation_opts.mask_tolerance;
sim.get_config().rk_tolerance = propagation_opts.rk_tolerance;

sim.run();

Expand Down
2 changes: 2 additions & 0 deletions simulation/include/traccc/simulation/simulator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ struct simulator {
scalar_type overstep_tolerance{-100.f * detray::unit<scalar_type>::um};
scalar_type step_constraint{std::numeric_limits<scalar_type>::max()};
scalar_type mask_tolerance = 15.f * detray::unit<scalar_type>::um;
scalar_type rk_tolerance = 1e-4;
};

using transform3 = typename detector_t::transform3;
Expand Down Expand Up @@ -97,6 +98,7 @@ struct simulator {
detray::step::constraint::e_accuracy>(
m_cfg.step_constraint);
propagation.set_mask_tolerance(m_cfg.mask_tolerance);
propagation._stepping.set_tolerance(m_cfg.rk_tolerance);

p.propagate(propagation, actor_states);

Expand Down

0 comments on commit e91fb9b

Please sign in to comment.