From e986aef928031ab7233c45cd0570b02d88f03aa0 Mon Sep 17 00:00:00 2001 From: Attila Krasznahorkay Date: Thu, 24 Oct 2024 16:53:34 +0200 Subject: [PATCH 1/4] Introduced traccc::host::kf_algorithm. It is a replacement for traccc::fitting_algorithm, without a templated API. --- core/CMakeLists.txt | 5 ++ core/include/traccc/fitting/kf_algorithm.hpp | 76 +++++++++++++++++++ core/src/fitting/fit_tracks.hpp | 74 ++++++++++++++++++ core/src/fitting/kf_algorithm.cpp | 15 ++++ .../fitting/kf_algorithm_defdet_cfield.cpp | 38 ++++++++++ .../fitting/kf_algorithm_teldet_cfield.cpp | 38 ++++++++++ 6 files changed, 246 insertions(+) create mode 100644 core/include/traccc/fitting/kf_algorithm.hpp create mode 100644 core/src/fitting/fit_tracks.hpp create mode 100644 core/src/fitting/kf_algorithm.cpp create mode 100644 core/src/fitting/kf_algorithm_defdet_cfield.cpp create mode 100644 core/src/fitting/kf_algorithm_teldet_cfield.cpp diff --git a/core/CMakeLists.txt b/core/CMakeLists.txt index 2dc99f138..e19def9b9 100644 --- a/core/CMakeLists.txt +++ b/core/CMakeLists.txt @@ -75,6 +75,11 @@ traccc_add_library( traccc_core core TYPE SHARED "include/traccc/fitting/kalman_filter/kalman_step_aborter.hpp" "include/traccc/fitting/kalman_filter/statistics_updater.hpp" "include/traccc/fitting/fitting_algorithm.hpp" + "src/fitting/fit_tracks.hpp" + "include/traccc/fitting/kf_algorithm.hpp" + "src/fitting/kf_algorithm.cpp" + "src/fitting/kf_algorithm_defdet_cfield.cpp" + "src/fitting/kf_algorithm_teldet_cfield.cpp" # Seed finding algorithmic code. "include/traccc/seeding/detail/lin_circle.hpp" "include/traccc/seeding/detail/doublet.hpp" diff --git a/core/include/traccc/fitting/kf_algorithm.hpp b/core/include/traccc/fitting/kf_algorithm.hpp new file mode 100644 index 000000000..3e23f7d57 --- /dev/null +++ b/core/include/traccc/fitting/kf_algorithm.hpp @@ -0,0 +1,76 @@ +/** TRACCC library, part of the ACTS project (R&D line) + * + * (c) 2022-2024 CERN for the benefit of the ACTS project + * + * Mozilla Public License Version 2.0 + */ + +#pragma once + +// Project include(s). +#include "traccc/edm/track_candidate.hpp" +#include "traccc/edm/track_state.hpp" +#include "traccc/fitting/fitting_config.hpp" +#include "traccc/geometry/detector.hpp" +#include "traccc/utils/algorithm.hpp" + +// Detray include(s). +#include + +namespace traccc::host { + +/// Kalman filter based track fitting algorithm +class kf_algorithm : public algorithm, + public algorithm { + + public: + /// Configuration type + using config_type = fitting_config; + /// Output type + using output_type = track_state_container_types::host; + + /// Constructor with the algorithm's configuration + /// + /// @param config The configuration object + /// + kf_algorithm(const config_type& config); + + /// Execute the algorithm + /// + /// @param det The (default) detector object + /// @param field The (constant) magnetic field object + /// @param track_candidates All track candidates to fit + /// + /// @return A container of the fitted track states + /// + output_type operator()(const default_detector::host& det, + const detray::bfield::const_field_t::view_t& field, + const track_candidate_container_types::const_view& + track_candidates) const override; + + /// Execute the algorithm + /// + /// @param det The (telescope) detector object + /// @param field The (constant) magnetic field object + /// @param track_candidates All track candidates to fit + /// + /// @return A container of the fitted track states + /// + output_type operator()(const telescope_detector::host& det, + const detray::bfield::const_field_t::view_t& field, + const track_candidate_container_types::const_view& + track_candidates) const override; + + private: + /// Algorithm configuration + config_type m_config; + +}; // class kf_algorithm + +} // namespace traccc::host diff --git a/core/src/fitting/fit_tracks.hpp b/core/src/fitting/fit_tracks.hpp new file mode 100644 index 000000000..00cc6bb3d --- /dev/null +++ b/core/src/fitting/fit_tracks.hpp @@ -0,0 +1,74 @@ +/** TRACCC library, part of the ACTS project (R&D line) + * + * (c) 2022-2024 CERN for the benefit of the ACTS project + * + * Mozilla Public License Version 2.0 + */ + +#pragma once + +// Project include(s). +#include "traccc/edm/track_candidate.hpp" +#include "traccc/edm/track_state.hpp" + +namespace traccc::host::details { + +/// Templated implementation of the track fitting algorithm. +/// +/// Concrete track fitting algorithms can use this function with the appropriate +/// specializations, to fit tracks on top of a specific detector type, magnetic +/// field type, and track fitting configuration. +/// +/// @tparam fitter_t The fitter type used for the track fitting +/// +/// @param det The detector object +/// @param field The magnetic field object +/// @param track_candidates All track candidates to fit +/// @param config The track fitting configuration +/// +/// @return A container of the fitted track states +/// +template +track_state_container_types::host fit_tracks( + const typename fitter_t::detector_type& det, + const typename fitter_t::bfield_type& field, + const track_candidate_container_types::const_view& track_candidates_view, + const typename fitter_t::config_type& config) { + + // Create the fitter object. + fitter_t fitter(det, field, config); + + // Output container. + track_state_container_types::host output_states; + + // Iterate over the tracks, + const track_candidate_container_types::const_device track_candidates{ + track_candidates_view}; + for (track_candidate_container_types::const_device::size_type i = 0; + i < track_candidates.size(); ++i) { + + // Make a vector of track states for this track. + vecmem::vector > + input_states; + input_states.reserve(track_candidates.get_items()[i].size()); + for (auto& measurement : track_candidates.get_items()[i]) { + input_states.emplace_back(measurement); + } + + // Make a fitter state + typename fitter_t::state fitter_state(std::move(input_states)); + + // Run the fitter. + fitter.fit(track_candidates.get_headers()[i], fitter_state); + + // Save the results into the output container. + output_states.push_back( + std::move(fitter_state.m_fit_res), + std::move(fitter_state.m_fit_actor_state.m_track_states)); + } + + // Return the fitted track states. + return output_states; +} + +} // namespace traccc::host::details diff --git a/core/src/fitting/kf_algorithm.cpp b/core/src/fitting/kf_algorithm.cpp new file mode 100644 index 000000000..bca1fa0cf --- /dev/null +++ b/core/src/fitting/kf_algorithm.cpp @@ -0,0 +1,15 @@ +/** TRACCC library, part of the ACTS project (R&D line) + * + * (c) 2022-2024 CERN for the benefit of the ACTS project + * + * Mozilla Public License Version 2.0 + */ + +// Project include(s). +#include "traccc/fitting/kf_algorithm.hpp" + +namespace traccc::host { + +kf_algorithm::kf_algorithm(const config_type& config) : m_config(config) {} + +} // namespace traccc::host diff --git a/core/src/fitting/kf_algorithm_defdet_cfield.cpp b/core/src/fitting/kf_algorithm_defdet_cfield.cpp new file mode 100644 index 000000000..1b18de4ff --- /dev/null +++ b/core/src/fitting/kf_algorithm_defdet_cfield.cpp @@ -0,0 +1,38 @@ +/** TRACCC library, part of the ACTS project (R&D line) + * + * (c) 2022-2024 CERN for the benefit of the ACTS project + * + * Mozilla Public License Version 2.0 + */ + +// Project include(s). +#include "fit_tracks.hpp" +#include "traccc/fitting/kalman_filter/kalman_fitter.hpp" +#include "traccc/fitting/kf_algorithm.hpp" + +// Detray include(s). +#include +#include + +namespace traccc::host { + +kf_algorithm::output_type kf_algorithm::operator()( + const default_detector::host& det, + const detray::bfield::const_field_t::view_t& field, + const track_candidate_container_types::const_view& track_candidates) const { + + // Set up the fitter type(s). + using stepper_type = + detray::rk_stepper>; + using navigator_type = + detray::navigator; + using fitter_type = kalman_fitter; + + // Perform the track fitting using a common, templated function. + return details::fit_tracks(det, field, track_candidates, + m_config); +} + +} // namespace traccc::host diff --git a/core/src/fitting/kf_algorithm_teldet_cfield.cpp b/core/src/fitting/kf_algorithm_teldet_cfield.cpp new file mode 100644 index 000000000..5ca4f96bc --- /dev/null +++ b/core/src/fitting/kf_algorithm_teldet_cfield.cpp @@ -0,0 +1,38 @@ +/** TRACCC library, part of the ACTS project (R&D line) + * + * (c) 2022-2024 CERN for the benefit of the ACTS project + * + * Mozilla Public License Version 2.0 + */ + +// Project include(s). +#include "fit_tracks.hpp" +#include "traccc/fitting/kalman_filter/kalman_fitter.hpp" +#include "traccc/fitting/kf_algorithm.hpp" + +// Detray include(s). +#include +#include + +namespace traccc::host { + +kf_algorithm::output_type kf_algorithm::operator()( + const telescope_detector::host& det, + const detray::bfield::const_field_t::view_t& field, + const track_candidate_container_types::const_view& track_candidates) const { + + // Set up the fitter type(s). + using stepper_type = + detray::rk_stepper>; + using navigator_type = + detray::navigator; + using fitter_type = kalman_fitter; + + // Perform the track fitting using a common, templated function. + return details::fit_tracks(det, field, track_candidates, + m_config); +} + +} // namespace traccc::host From 5f432dbe6d30f52143278e4483ca87ac9a6f7edc Mon Sep 17 00:00:00 2001 From: Attila Krasznahorkay Date: Mon, 28 Oct 2024 15:32:07 +0100 Subject: [PATCH 2/4] Removed traccc::fitting_algorithm. Updating all users to use traccc::host::kf_algorithm instread. --- .../benchmarks/toy_detector_benchmark.hpp | 2 +- benchmarks/cpu/toy_detector_cpu.cpp | 14 +-- benchmarks/cuda/toy_detector_cuda.cpp | 1 + core/CMakeLists.txt | 1 - .../traccc/fitting/fitting_algorithm.hpp | 88 ------------------- examples/run/cpu/full_chain_algorithm.cpp | 11 ++- examples/run/cpu/full_chain_algorithm.hpp | 15 +--- examples/run/cpu/seeding_example.cpp | 20 ++--- examples/run/cpu/seq_example.cpp | 15 +--- examples/run/cpu/truth_finding_example.cpp | 21 ++--- examples/run/cpu/truth_fitting_example.cpp | 24 ++--- examples/run/cuda/seeding_example_cuda.cpp | 14 ++- examples/run/cuda/seq_example_cuda.cpp | 10 +-- .../run/cuda/truth_finding_example_cuda.cpp | 15 ++-- .../run/cuda/truth_fitting_example_cuda.cpp | 23 ++--- examples/run/sycl/full_chain_algorithm.hpp | 6 +- .../cpu/test_ckf_combinatorics_telescope.cpp | 1 - .../cpu/test_ckf_sparse_tracks_telescope.cpp | 9 +- tests/cpu/test_kalman_fitter_telescope.cpp | 9 +- tests/cpu/test_kalman_fitter_wire_chamber.cpp | 9 +- tests/cuda/test_kalman_fitter_telescope.cpp | 1 - 21 files changed, 79 insertions(+), 230 deletions(-) delete mode 100644 core/include/traccc/fitting/fitting_algorithm.hpp diff --git a/benchmarks/common/benchmarks/toy_detector_benchmark.hpp b/benchmarks/common/benchmarks/toy_detector_benchmark.hpp index f76d677ed..ef1a0b95c 100644 --- a/benchmarks/common/benchmarks/toy_detector_benchmark.hpp +++ b/benchmarks/common/benchmarks/toy_detector_benchmark.hpp @@ -8,7 +8,7 @@ // Traccc include(s). #include "traccc/definitions/common.hpp" #include "traccc/finding/finding_config.hpp" -#include "traccc/fitting/fitting_algorithm.hpp" +#include "traccc/fitting/fitting_config.hpp" #include "traccc/io/utils.hpp" #include "traccc/seeding/seeding_algorithm.hpp" #include "traccc/seeding/track_params_estimation.hpp" diff --git a/benchmarks/cpu/toy_detector_cpu.cpp b/benchmarks/cpu/toy_detector_cpu.cpp index 363c6eb7b..825c38316 100644 --- a/benchmarks/cpu/toy_detector_cpu.cpp +++ b/benchmarks/cpu/toy_detector_cpu.cpp @@ -10,7 +10,7 @@ // Traccc algorithm include(s). #include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp" -#include "traccc/fitting/fitting_algorithm.hpp" +#include "traccc/fitting/kf_algorithm.hpp" #include "traccc/seeding/seeding_algorithm.hpp" #include "traccc/seeding/track_params_estimation.hpp" @@ -40,14 +40,7 @@ BENCHMARK_F(ToyDetectorBenchmark, CPU)(benchmark::State& state) { // Type declarations - using rk_stepper_type = - detray::rk_stepper>; using host_detector_type = traccc::default_detector::host; - using host_navigator_type = detray::navigator; - using host_fitter_type = - traccc::kalman_fitter; // Read back detector file host_detector_type det{host_mr}; @@ -64,7 +57,7 @@ BENCHMARK_F(ToyDetectorBenchmark, CPU)(benchmark::State& state) { traccc::track_params_estimation tp(host_mr); traccc::host::combinatorial_kalman_filter_algorithm host_finding( finding_cfg); - traccc::fitting_algorithm host_fitting(fitting_cfg); + traccc::host::kf_algorithm host_fitting(fitting_cfg); for (auto _ : state) { @@ -87,7 +80,8 @@ BENCHMARK_F(ToyDetectorBenchmark, CPU)(benchmark::State& state) { vecmem::get_data(params)); // Track fitting with KF - auto track_states = host_fitting(det, field, track_candidates); + auto track_states = + host_fitting(det, field, traccc::get_data(track_candidates)); } } diff --git a/benchmarks/cuda/toy_detector_cuda.cpp b/benchmarks/cuda/toy_detector_cuda.cpp index ec0985551..01a57adb1 100644 --- a/benchmarks/cuda/toy_detector_cuda.cpp +++ b/benchmarks/cuda/toy_detector_cuda.cpp @@ -11,6 +11,7 @@ #include "traccc/cuda/seeding/seeding_algorithm.hpp" #include "traccc/cuda/seeding/track_params_estimation.hpp" #include "traccc/device/container_d2h_copy_alg.hpp" +#include "traccc/fitting/kalman_filter/kalman_fitter.hpp" #include "traccc/geometry/detector.hpp" #include "traccc/io/read_detector.hpp" #include "traccc/io/read_geometry.hpp" diff --git a/core/CMakeLists.txt b/core/CMakeLists.txt index e19def9b9..c4f5ff070 100644 --- a/core/CMakeLists.txt +++ b/core/CMakeLists.txt @@ -74,7 +74,6 @@ traccc_add_library( traccc_core core TYPE SHARED "include/traccc/fitting/kalman_filter/kalman_fitter.hpp" "include/traccc/fitting/kalman_filter/kalman_step_aborter.hpp" "include/traccc/fitting/kalman_filter/statistics_updater.hpp" - "include/traccc/fitting/fitting_algorithm.hpp" "src/fitting/fit_tracks.hpp" "include/traccc/fitting/kf_algorithm.hpp" "src/fitting/kf_algorithm.cpp" diff --git a/core/include/traccc/fitting/fitting_algorithm.hpp b/core/include/traccc/fitting/fitting_algorithm.hpp deleted file mode 100644 index f6402555b..000000000 --- a/core/include/traccc/fitting/fitting_algorithm.hpp +++ /dev/null @@ -1,88 +0,0 @@ -/** TRACCC library, part of the ACTS project (R&D line) - * - * (c) 2022-2023 CERN for the benefit of the ACTS project - * - * Mozilla Public License Version 2.0 - */ - -#pragma once - -// Project include(s). -#include "traccc/definitions/qualifiers.hpp" -#include "traccc/edm/track_candidate.hpp" -#include "traccc/edm/track_state.hpp" -#include "traccc/fitting/fitting_config.hpp" -#include "traccc/fitting/kalman_filter/kalman_fitter.hpp" -#include "traccc/utils/algorithm.hpp" - -namespace traccc { - -/// Fitting algorithm for a set of tracks -template -class fitting_algorithm - : public algorithm { - - public: - using algebra_type = typename fitter_t::algebra_type; - using bfield_type = typename fitter_t::bfield_type; - /// Configuration type - using config_type = typename fitter_t::config_type; - - /// Constructor for the fitting algorithm - /// - /// @param cfg Configuration object - fitting_algorithm(const config_type& cfg) : m_cfg(cfg) {} - - /// Run the algorithm - /// - /// @param track_candidates the candidate measurements from track finding - /// @return the container of the fitted track parameters - track_state_container_types::host operator()( - const typename fitter_t::detector_type& det, - const typename fitter_t::bfield_type& field, - const typename track_candidate_container_types::host& track_candidates) - const override { - - fitter_t fitter(det, field, m_cfg); - - track_state_container_types::host output_states; - - // The number of tracks - std::size_t n_tracks = track_candidates.size(); - - // Iterate over tracks - for (std::size_t i = 0; i < n_tracks; i++) { - - // Seed parameter - const auto& seed_param = track_candidates[i].header; - - // Make a vector of track state - auto& cands = track_candidates[i].items; - vecmem::vector> input_states; - input_states.reserve(cands.size()); - for (auto& cand : cands) { - input_states.emplace_back(cand); - } - - // Make a fitter state - typename fitter_t::state fitter_state(std::move(input_states)); - - // Run fitter - fitter.fit(seed_param, fitter_state); - - output_states.push_back( - std::move(fitter_state.m_fit_res), - std::move(fitter_state.m_fit_actor_state.m_track_states)); - } - - return output_states; - } - - /// Config object - config_type m_cfg; -}; - -} // namespace traccc diff --git a/examples/run/cpu/full_chain_algorithm.cpp b/examples/run/cpu/full_chain_algorithm.cpp index 9abda5159..349fa09d9 100644 --- a/examples/run/cpu/full_chain_algorithm.cpp +++ b/examples/run/cpu/full_chain_algorithm.cpp @@ -61,11 +61,14 @@ full_chain_algorithm::output_type full_chain_algorithm::operator()( const bound_track_parameters_collection_types::const_view track_params_view = vecmem::get_data(track_params); - // Return the final container, after track finding and fitting. - return m_fitting(*m_detector, m_field, - m_finding(*m_detector, m_field, measurements_view, - track_params_view)); + // Run the track finding. + const finding_algorithm::output_type track_candidates = m_finding( + *m_detector, m_field, measurements_view, track_params_view); + // Run the track fitting, and return its results. + const track_candidate_container_types::const_view + track_candidates_view = get_data(track_candidates); + return m_fitting(*m_detector, m_field, track_candidates_view); } // If not, just return an empty object. else { diff --git a/examples/run/cpu/full_chain_algorithm.hpp b/examples/run/cpu/full_chain_algorithm.hpp index 546dbe79c..35c7811b5 100644 --- a/examples/run/cpu/full_chain_algorithm.hpp +++ b/examples/run/cpu/full_chain_algorithm.hpp @@ -12,8 +12,7 @@ #include "traccc/edm/silicon_cell_collection.hpp" #include "traccc/edm/track_state.hpp" #include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp" -#include "traccc/fitting/fitting_algorithm.hpp" -#include "traccc/fitting/kalman_filter/kalman_fitter.hpp" +#include "traccc/fitting/kf_algorithm.hpp" #include "traccc/geometry/detector.hpp" #include "traccc/geometry/silicon_detector_description.hpp" #include "traccc/seeding/seeding_algorithm.hpp" @@ -50,14 +49,7 @@ class full_chain_algorithm : public algorithm>; - /// Navigator type used by the track finding and fitting algorithms - using navigator_type = detray::navigator; - + /// Clusterization algorithm type using clustering_algorithm = host::clusterization_algorithm; /// Spacepoint formation algorithm type using spacepoint_formation_algorithm = @@ -66,8 +58,7 @@ class full_chain_algorithm : public algorithm>; + using fitting_algorithm = traccc::host::kf_algorithm; /// @} diff --git a/examples/run/cpu/seeding_example.cpp b/examples/run/cpu/seeding_example.cpp index da7c81abd..9502e8031 100644 --- a/examples/run/cpu/seeding_example.cpp +++ b/examples/run/cpu/seeding_example.cpp @@ -19,7 +19,7 @@ // algorithms #include "traccc/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.hpp" #include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp" -#include "traccc/fitting/fitting_algorithm.hpp" +#include "traccc/fitting/kf_algorithm.hpp" #include "traccc/seeding/seeding_algorithm.hpp" #include "traccc/seeding/track_params_estimation.hpp" @@ -65,17 +65,6 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts, const traccc::opts::detector& detector_opts, const traccc::opts::performance& performance_opts) { - /// Type declarations - using b_field_t = covfie::field; - using rk_stepper_type = - detray::rk_stepper>; - using host_navigator_type = - detray::navigator; - using host_fitter_type = - traccc::kalman_fitter; - // Memory resource used by the EDM. vecmem::host_memory_resource host_mr; @@ -132,10 +121,10 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts, traccc::host::combinatorial_kalman_filter_algorithm host_finding(cfg); // Fitting algorithm object - typename traccc::fitting_algorithm::config_type fit_cfg; + traccc::fitting_config fit_cfg; fit_cfg.propagation = propagation_config; - traccc::fitting_algorithm host_fitting(fit_cfg); + traccc::host::kf_algorithm host_fitting(fit_cfg); traccc::greedy_ambiguity_resolution_algorithm host_ambiguity_resolution{}; @@ -192,7 +181,8 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts, Track Fitting with KF ------------------------*/ - track_states = host_fitting(detector, field, track_candidates); + track_states = + host_fitting(detector, field, traccc::get_data(track_candidates)); n_fitted_tracks += track_states.size(); /*----------------------------------------- diff --git a/examples/run/cpu/seq_example.cpp b/examples/run/cpu/seq_example.cpp index 26cc1aa63..33a63416d 100644 --- a/examples/run/cpu/seq_example.cpp +++ b/examples/run/cpu/seq_example.cpp @@ -16,7 +16,7 @@ #include "traccc/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.hpp" #include "traccc/clusterization/clusterization_algorithm.hpp" #include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp" -#include "traccc/fitting/fitting_algorithm.hpp" +#include "traccc/fitting/kf_algorithm.hpp" #include "traccc/seeding/seeding_algorithm.hpp" #include "traccc/seeding/silicon_pixel_spacepoint_formation_algorithm.hpp" #include "traccc/seeding/track_params_estimation.hpp" @@ -100,16 +100,9 @@ int seq_run(const traccc::opts::input_data& input_opts, // Type definitions using spacepoint_formation_algorithm = traccc::host::silicon_pixel_spacepoint_formation_algorithm; - using stepper_type = - detray::rk_stepper>; - using navigator_type = - detray::navigator; using finding_algorithm = traccc::host::combinatorial_kalman_filter_algorithm; - using fitting_algorithm = traccc::fitting_algorithm< - traccc::kalman_fitter>; + using fitting_algorithm = traccc::host::kf_algorithm; // Constant B field for the track finding and fitting const traccc::vector3 field_vec = {0.f, 0.f, @@ -257,8 +250,8 @@ int seq_run(const traccc::opts::input_data& input_opts, { traccc::performance::timer timer{"Track fitting", elapsedTimes}; - track_states = - fitting_alg(detector, field, track_candidates); + track_states = fitting_alg( + detector, field, traccc::get_data(track_candidates)); } } diff --git a/examples/run/cpu/truth_finding_example.cpp b/examples/run/cpu/truth_finding_example.cpp index 208b4d864..5258e8119 100644 --- a/examples/run/cpu/truth_finding_example.cpp +++ b/examples/run/cpu/truth_finding_example.cpp @@ -10,8 +10,7 @@ #include "traccc/definitions/primitives.hpp" #include "traccc/efficiency/finding_performance_writer.hpp" #include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp" -#include "traccc/fitting/fitting_algorithm.hpp" -#include "traccc/fitting/kalman_filter/kalman_fitter.hpp" +#include "traccc/fitting/kf_algorithm.hpp" #include "traccc/io/read_detector.hpp" #include "traccc/io/read_detector_description.hpp" #include "traccc/io/read_measurements.hpp" @@ -52,17 +51,6 @@ int seq_run(const traccc::opts::track_finding& finding_opts, const traccc::opts::detector& detector_opts, const traccc::opts::performance& performance_opts) { - /// Type declarations - using b_field_t = covfie::field; - using rk_stepper_type = - detray::rk_stepper>; - - using host_navigator_type = - detray::navigator; - using host_fitter_type = - traccc::kalman_fitter; - // Memory resources used by the application. vecmem::host_memory_resource host_mr; @@ -112,10 +100,10 @@ int seq_run(const traccc::opts::track_finding& finding_opts, traccc::host::combinatorial_kalman_filter_algorithm host_finding(cfg); // Fitting algorithm object - typename traccc::fitting_algorithm::config_type fit_cfg; + traccc::fitting_config fit_cfg; fit_cfg.propagation = propagation_config; - traccc::fitting_algorithm host_fitting(fit_cfg); + traccc::host::kf_algorithm host_fitting(fit_cfg); // Seed generator traccc::seed_generator sg(detector, @@ -157,7 +145,8 @@ int seq_run(const traccc::opts::track_finding& finding_opts, << std::endl; // Run fitting - auto track_states = host_fitting(detector, field, track_candidates); + auto track_states = + host_fitting(detector, field, traccc::get_data(track_candidates)); std::cout << "Number of fitted tracks: " << track_states.size() << std::endl; diff --git a/examples/run/cpu/truth_fitting_example.cpp b/examples/run/cpu/truth_fitting_example.cpp index 2874cb562..792e158a0 100644 --- a/examples/run/cpu/truth_fitting_example.cpp +++ b/examples/run/cpu/truth_fitting_example.cpp @@ -8,8 +8,8 @@ // Project include(s). #include "traccc/definitions/common.hpp" #include "traccc/definitions/primitives.hpp" -#include "traccc/fitting/fitting_algorithm.hpp" -#include "traccc/fitting/kalman_filter/kalman_fitter.hpp" +#include "traccc/fitting/kf_algorithm.hpp" +#include "traccc/geometry/detector.hpp" #include "traccc/io/read_geometry.hpp" #include "traccc/io/utils.hpp" #include "traccc/options/detector.hpp" @@ -57,17 +57,7 @@ int main(int argc, char* argv[]) { argv}; /// Type declarations - using host_detector_type = detray::detector; - - using b_field_t = covfie::field; - using rk_stepper_type = - detray::rk_stepper>; - - using host_navigator_type = detray::navigator; - using host_fitter_type = - traccc::kalman_fitter; + using host_detector_type = traccc::default_detector::host; // Memory resources used by the application. vecmem::host_memory_resource host_mr; @@ -114,10 +104,10 @@ int main(int argc, char* argv[]) { 1.f * detray::unit::ns}; // Fitting algorithm object - typename traccc::fitting_algorithm::config_type fit_cfg; + traccc::fitting_config fit_cfg; fit_cfg.propagation = propagation_opts; - traccc::fitting_algorithm host_fitting(fit_cfg); + traccc::host::kf_algorithm host_fitting(fit_cfg); // Seed generator traccc::seed_generator sg(host_det, stddevs); @@ -135,8 +125,8 @@ int main(int argc, char* argv[]) { evt_data.generate_truth_candidates(sg, host_mr); // Run fitting - auto track_states = - host_fitting(host_det, field, truth_track_candidates); + auto track_states = host_fitting( + host_det, field, traccc::get_data(truth_track_candidates)); std::cout << "Number of fitted tracks: " << track_states.size() << std::endl; diff --git a/examples/run/cuda/seeding_example_cuda.cpp b/examples/run/cuda/seeding_example_cuda.cpp index edf548778..3898dc588 100644 --- a/examples/run/cuda/seeding_example_cuda.cpp +++ b/examples/run/cuda/seeding_example_cuda.cpp @@ -18,7 +18,8 @@ #include "traccc/efficiency/seeding_performance_writer.hpp" #include "traccc/efficiency/track_filter.hpp" #include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp" -#include "traccc/fitting/fitting_algorithm.hpp" +#include "traccc/fitting/kalman_filter/kalman_fitter.hpp" +#include "traccc/fitting/kf_algorithm.hpp" #include "traccc/io/read_detector.hpp" #include "traccc/io/read_detector_description.hpp" #include "traccc/io/read_measurements.hpp" @@ -76,10 +77,6 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts, b_field_t::view_t, typename traccc::default_detector::host::algebra_type, detray::constrained_step<>>; - using host_navigator_type = - detray::navigator; - using host_fitter_type = - traccc::kalman_fitter; using device_navigator_type = detray::navigator; using device_fitter_type = @@ -179,10 +176,10 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts, device_finding(cfg, mr, async_copy, stream); // Fitting algorithm object - typename traccc::fitting_algorithm::config_type fit_cfg; + traccc::fitting_config fit_cfg; fit_cfg.propagation = propagation_config; - traccc::fitting_algorithm host_fitting(fit_cfg); + traccc::host::kf_algorithm host_fitting(fit_cfg); traccc::cuda::fitting_algorithm device_fitting( fit_cfg, mr, async_copy, stream); @@ -333,7 +330,8 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts, if (accelerator_opts.compare_with_cpu) { traccc::performance::timer t("Track fitting with KF (cpu)", elapsedTimes); - track_states = host_fitting(host_det, field, track_candidates); + track_states = host_fitting(host_det, field, + traccc::get_data(track_candidates)); } } // Stop measuring wall time diff --git a/examples/run/cuda/seq_example_cuda.cpp b/examples/run/cuda/seq_example_cuda.cpp index f48821a7b..9086c2c9e 100644 --- a/examples/run/cuda/seq_example_cuda.cpp +++ b/examples/run/cuda/seq_example_cuda.cpp @@ -18,7 +18,7 @@ #include "traccc/device/container_d2h_copy_alg.hpp" #include "traccc/efficiency/seeding_performance_writer.hpp" #include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp" -#include "traccc/fitting/fitting_algorithm.hpp" +#include "traccc/fitting/kf_algorithm.hpp" #include "traccc/io/read_cells.hpp" #include "traccc/io/read_detector.hpp" #include "traccc/io/read_detector_description.hpp" @@ -131,8 +131,6 @@ int seq_run(const traccc::opts::detector& detector_opts, detray::rk_stepper>; - using host_navigator_type = - detray::navigator; using device_navigator_type = detray::navigator; @@ -141,8 +139,7 @@ int seq_run(const traccc::opts::detector& detector_opts, using device_finding_algorithm = traccc::cuda::finding_algorithm; - using host_fitting_algorithm = traccc::fitting_algorithm< - traccc::kalman_fitter>; + using host_fitting_algorithm = traccc::host::kf_algorithm; using device_fitting_algorithm = traccc::cuda::fitting_algorithm< traccc::kalman_fitter>; @@ -343,7 +340,8 @@ int seq_run(const traccc::opts::detector& detector_opts, traccc::performance::timer timer{"Track fitting (cpu)", elapsedTimes}; track_states = - fitting_alg(host_detector, field, track_candidates); + fitting_alg(host_detector, field, + traccc::get_data(track_candidates)); } } diff --git a/examples/run/cuda/truth_finding_example_cuda.cpp b/examples/run/cuda/truth_finding_example_cuda.cpp index 501c6a144..d8bf17bf2 100644 --- a/examples/run/cuda/truth_finding_example_cuda.cpp +++ b/examples/run/cuda/truth_finding_example_cuda.cpp @@ -15,8 +15,8 @@ #include "traccc/device/container_h2d_copy_alg.hpp" #include "traccc/efficiency/finding_performance_writer.hpp" #include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp" -#include "traccc/fitting/fitting_algorithm.hpp" #include "traccc/fitting/kalman_filter/kalman_fitter.hpp" +#include "traccc/fitting/kf_algorithm.hpp" #include "traccc/io/read_detector.hpp" #include "traccc/io/read_detector_description.hpp" #include "traccc/io/read_measurements.hpp" @@ -69,10 +69,6 @@ int seq_run(const traccc::opts::track_finding& finding_opts, using rk_stepper_type = detray::rk_stepper>; - using host_navigator_type = - detray::navigator; - using host_fitter_type = - traccc::kalman_fitter; using device_navigator_type = detray::navigator; using device_fitter_type = @@ -156,10 +152,10 @@ int seq_run(const traccc::opts::track_finding& finding_opts, device_finding(cfg, mr, async_copy, stream); // Fitting algorithm object - typename traccc::fitting_algorithm::config_type fit_cfg; + traccc::fitting_config fit_cfg; fit_cfg.propagation = propagation_config; - traccc::fitting_algorithm host_fitting(fit_cfg); + traccc::host::kf_algorithm host_fitting(fit_cfg); traccc::cuda::fitting_algorithm device_fitting( fit_cfg, mr, async_copy, stream); @@ -245,7 +241,7 @@ int seq_run(const traccc::opts::track_finding& finding_opts, // CPU containers traccc::host::combinatorial_kalman_filter_algorithm::output_type track_candidates; - traccc::fitting_algorithm::output_type track_states; + traccc::host::kf_algorithm::output_type track_states; if (accelerator_opts.compare_with_cpu) { @@ -264,7 +260,8 @@ int seq_run(const traccc::opts::track_finding& finding_opts, elapsedTimes); // Run fitting - track_states = host_fitting(detector, field, track_candidates); + track_states = host_fitting(detector, field, + traccc::get_data(track_candidates)); } } diff --git a/examples/run/cuda/truth_fitting_example_cuda.cpp b/examples/run/cuda/truth_fitting_example_cuda.cpp index ca5b2095b..fb5cb55c4 100644 --- a/examples/run/cuda/truth_fitting_example_cuda.cpp +++ b/examples/run/cuda/truth_fitting_example_cuda.cpp @@ -12,8 +12,9 @@ #include "traccc/definitions/primitives.hpp" #include "traccc/device/container_d2h_copy_alg.hpp" #include "traccc/device/container_h2d_copy_alg.hpp" -#include "traccc/fitting/fitting_algorithm.hpp" #include "traccc/fitting/kalman_filter/kalman_fitter.hpp" +#include "traccc/fitting/kf_algorithm.hpp" +#include "traccc/geometry/detector.hpp" #include "traccc/io/read_geometry.hpp" #include "traccc/io/read_measurements.hpp" #include "traccc/io/utils.hpp" @@ -71,19 +72,13 @@ int main(int argc, char* argv[]) { argv}; /// Type declarations - using host_detector_type = detray::detector; - using device_detector_type = - detray::detector; + using host_detector_type = traccc::default_detector::host; + using device_detector_type = traccc::default_detector::device; using b_field_t = covfie::field; using rk_stepper_type = detray::rk_stepper>; - using host_navigator_type = detray::navigator; - using host_fitter_type = - traccc::kalman_fitter; using device_navigator_type = detray::navigator; using device_fitter_type = traccc::kalman_fitter; @@ -157,10 +152,10 @@ int main(int argc, char* argv[]) { 1.f * detray::unit::ns}; // Fitting algorithm object - typename traccc::fitting_algorithm::config_type fit_cfg; + traccc::fitting_config fit_cfg; fit_cfg.propagation = propagation_opts; - traccc::fitting_algorithm host_fitting(fit_cfg); + traccc::host::kf_algorithm host_fitting(fit_cfg); traccc::cuda::fitting_algorithm device_fitting( fit_cfg, mr, async_copy, stream); @@ -202,7 +197,7 @@ int main(int argc, char* argv[]) { track_state_d2h(track_states_cuda_buffer); // CPU container(s) - traccc::fitting_algorithm::output_type track_states; + traccc::host::kf_algorithm::output_type track_states; if (accelerator_opts.compare_with_cpu) { @@ -211,8 +206,8 @@ int main(int argc, char* argv[]) { elapsedTimes); // Run fitting - track_states = - host_fitting(host_det, field, truth_track_candidates); + track_states = host_fitting( + host_det, field, traccc::get_data(truth_track_candidates)); } } diff --git a/examples/run/sycl/full_chain_algorithm.hpp b/examples/run/sycl/full_chain_algorithm.hpp index 4a1692dba..e6bb28c1f 100644 --- a/examples/run/sycl/full_chain_algorithm.hpp +++ b/examples/run/sycl/full_chain_algorithm.hpp @@ -10,8 +10,7 @@ // Project include(s). #include "traccc/edm/silicon_cell_collection.hpp" #include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp" -#include "traccc/fitting/fitting_algorithm.hpp" -#include "traccc/fitting/kalman_filter/kalman_fitter.hpp" +#include "traccc/fitting/kf_algorithm.hpp" #include "traccc/geometry/detector.hpp" #include "traccc/geometry/silicon_detector_description.hpp" #include "traccc/sycl/clusterization/clusterization_algorithm.hpp" @@ -76,8 +75,7 @@ class full_chain_algorithm using finding_algorithm = traccc::host::combinatorial_kalman_filter_algorithm; /// Track fitting algorithm type - using fitting_algorithm = traccc::fitting_algorithm< - traccc::kalman_fitter>; + using fitting_algorithm = traccc::host::kf_algorithm; /// @} diff --git a/tests/cpu/test_ckf_combinatorics_telescope.cpp b/tests/cpu/test_ckf_combinatorics_telescope.cpp index 79da9735b..f72684653 100644 --- a/tests/cpu/test_ckf_combinatorics_telescope.cpp +++ b/tests/cpu/test_ckf_combinatorics_telescope.cpp @@ -7,7 +7,6 @@ // Project include(s). #include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp" -#include "traccc/fitting/fitting_algorithm.hpp" #include "traccc/io/read_measurements.hpp" #include "traccc/io/utils.hpp" #include "traccc/resolution/fitting_performance_writer.hpp" diff --git a/tests/cpu/test_ckf_sparse_tracks_telescope.cpp b/tests/cpu/test_ckf_sparse_tracks_telescope.cpp index d044c8595..c3cbc86a7 100644 --- a/tests/cpu/test_ckf_sparse_tracks_telescope.cpp +++ b/tests/cpu/test_ckf_sparse_tracks_telescope.cpp @@ -7,7 +7,7 @@ // Project include(s). #include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp" -#include "traccc/fitting/fitting_algorithm.hpp" +#include "traccc/fitting/kf_algorithm.hpp" #include "traccc/io/read_measurements.hpp" #include "traccc/io/utils.hpp" #include "traccc/resolution/fitting_performance_writer.hpp" @@ -131,12 +131,12 @@ TEST_P(CkfSparseTrackTelescopeTests, Run) { traccc::host::combinatorial_kalman_filter_algorithm host_finding(cfg); // Fitting algorithm object - typename traccc::fitting_algorithm::config_type fit_cfg; + traccc::fitting_config fit_cfg; fit_cfg.ptc_hypothesis = ptc; fit_cfg.propagation.navigation.overstep_tolerance = -100.f * unit::um; fit_cfg.propagation.navigation.max_mask_tolerance = 1.f * unit::mm; - traccc::fitting_algorithm host_fitting(fit_cfg); + traccc::host::kf_algorithm host_fitting(fit_cfg); // Iterate over events for (std::size_t i_evt = 0; i_evt < n_events; i_evt++) { @@ -168,7 +168,8 @@ TEST_P(CkfSparseTrackTelescopeTests, Run) { ASSERT_EQ(track_candidates.size(), n_truth_tracks); // Run fitting - auto track_states = host_fitting(host_det, field, track_candidates); + auto track_states = + host_fitting(host_det, field, traccc::get_data(track_candidates)); ASSERT_EQ(track_states.size(), n_truth_tracks); diff --git a/tests/cpu/test_kalman_fitter_telescope.cpp b/tests/cpu/test_kalman_fitter_telescope.cpp index 28c6e1781..af7203dae 100644 --- a/tests/cpu/test_kalman_fitter_telescope.cpp +++ b/tests/cpu/test_kalman_fitter_telescope.cpp @@ -7,7 +7,7 @@ // Project include(s). #include "traccc/edm/track_state.hpp" -#include "traccc/fitting/fitting_algorithm.hpp" +#include "traccc/fitting/kf_algorithm.hpp" #include "traccc/io/utils.hpp" #include "traccc/resolution/fitting_performance_writer.hpp" #include "traccc/simulation/simulator.hpp" @@ -118,12 +118,12 @@ TEST_P(KalmanFittingTelescopeTests, Run) { seed_generator sg(host_det, stddevs); // Fitting algorithm object - typename traccc::fitting_algorithm::config_type fit_cfg; + traccc::fitting_config fit_cfg; fit_cfg.ptc_hypothesis = ptc; fit_cfg.propagation.navigation.overstep_tolerance = -100.f * unit::um; fit_cfg.propagation.navigation.max_mask_tolerance = 1.f * unit::mm; - fitting_algorithm fitting(fit_cfg); + traccc::host::kf_algorithm fitting(fit_cfg); // Iterate over events for (std::size_t i_evt = 0; i_evt < n_events; i_evt++) { @@ -138,7 +138,8 @@ TEST_P(KalmanFittingTelescopeTests, Run) { ASSERT_EQ(track_candidates.size(), n_truth_tracks); // Run fitting - auto track_states = fitting(host_det, field, track_candidates); + auto track_states = + fitting(host_det, field, traccc::get_data(track_candidates)); // Iterator over tracks const std::size_t n_tracks = track_states.size(); diff --git a/tests/cpu/test_kalman_fitter_wire_chamber.cpp b/tests/cpu/test_kalman_fitter_wire_chamber.cpp index 80dce2471..8bb283e40 100644 --- a/tests/cpu/test_kalman_fitter_wire_chamber.cpp +++ b/tests/cpu/test_kalman_fitter_wire_chamber.cpp @@ -7,7 +7,7 @@ // Project include(s). #include "traccc/edm/track_state.hpp" -#include "traccc/fitting/fitting_algorithm.hpp" +#include "traccc/fitting/kf_algorithm.hpp" #include "traccc/io/utils.hpp" #include "traccc/resolution/fitting_performance_writer.hpp" #include "traccc/simulation/measurement_smearer.hpp" @@ -119,12 +119,12 @@ TEST_P(KalmanFittingWireChamberTests, Run) { seed_generator sg(host_det, stddevs); // Fitting algorithm object - typename traccc::fitting_algorithm::config_type fit_cfg; + traccc::fitting_config fit_cfg; fit_cfg.propagation.navigation.min_mask_tolerance = static_cast(mask_tolerance); fit_cfg.propagation.navigation.search_window = search_window; fit_cfg.ptc_hypothesis = ptc; - fitting_algorithm fitting(fit_cfg); + traccc::host::kf_algorithm fitting(fit_cfg); // Iterate over events for (std::size_t i_evt = 0; i_evt < n_events; i_evt++) { @@ -139,7 +139,8 @@ TEST_P(KalmanFittingWireChamberTests, Run) { ASSERT_EQ(track_candidates.size(), n_truth_tracks); // Run fitting - auto track_states = fitting(host_det, field, track_candidates); + auto track_states = + fitting(host_det, field, traccc::get_data(track_candidates)); // Iterator over tracks const std::size_t n_tracks = track_states.size(); diff --git a/tests/cuda/test_kalman_fitter_telescope.cpp b/tests/cuda/test_kalman_fitter_telescope.cpp index 6801ecf98..cff0402d3 100644 --- a/tests/cuda/test_kalman_fitter_telescope.cpp +++ b/tests/cuda/test_kalman_fitter_telescope.cpp @@ -10,7 +10,6 @@ #include "traccc/device/container_d2h_copy_alg.hpp" #include "traccc/device/container_h2d_copy_alg.hpp" #include "traccc/edm/track_state.hpp" -#include "traccc/fitting/fitting_algorithm.hpp" #include "traccc/io/utils.hpp" #include "traccc/performance/details/is_same_object.hpp" #include "traccc/resolution/fitting_performance_writer.hpp" From 5d0e763e6a10a772aeada25b91e8c6d63fe52978 Mon Sep 17 00:00:00 2001 From: Attila Krasznahorkay Date: Tue, 29 Oct 2024 17:34:06 +0100 Subject: [PATCH 3/4] Renamed the new algorithm and its files following Beomki's suggestions. --- benchmarks/cpu/toy_detector_cpu.cpp | 4 ++-- core/CMakeLists.txt | 8 +++---- ...rithm.hpp => kalman_fitting_algorithm.hpp} | 21 ++++++++++--------- ...rithm.cpp => kalman_fitting_algorithm.cpp} | 5 +++-- ...rithm_constant_field_default_detector.cpp} | 4 ++-- ...thm_constant_field_telescope_detector.cpp} | 4 ++-- examples/run/cpu/full_chain_algorithm.hpp | 4 ++-- examples/run/cpu/seeding_example.cpp | 4 ++-- examples/run/cpu/seq_example.cpp | 4 ++-- examples/run/cpu/truth_finding_example.cpp | 4 ++-- examples/run/cpu/truth_fitting_example.cpp | 4 ++-- examples/run/cuda/seeding_example_cuda.cpp | 4 ++-- examples/run/cuda/seq_example_cuda.cpp | 5 +++-- .../run/cuda/truth_finding_example_cuda.cpp | 6 +++--- .../run/cuda/truth_fitting_example_cuda.cpp | 6 +++--- examples/run/sycl/full_chain_algorithm.hpp | 4 ++-- .../cpu/test_ckf_sparse_tracks_telescope.cpp | 4 ++-- tests/cpu/test_kalman_fitter_telescope.cpp | 4 ++-- tests/cpu/test_kalman_fitter_wire_chamber.cpp | 4 ++-- 19 files changed, 53 insertions(+), 50 deletions(-) rename core/include/traccc/fitting/{kf_algorithm.hpp => kalman_fitting_algorithm.hpp} (76%) rename core/src/fitting/{kf_algorithm.cpp => kalman_fitting_algorithm.cpp} (60%) rename core/src/fitting/{kf_algorithm_defdet_cfield.cpp => kalman_fitting_algorithm_constant_field_default_detector.cpp} (90%) rename core/src/fitting/{kf_algorithm_teldet_cfield.cpp => kalman_fitting_algorithm_constant_field_telescope_detector.cpp} (90%) diff --git a/benchmarks/cpu/toy_detector_cpu.cpp b/benchmarks/cpu/toy_detector_cpu.cpp index 825c38316..b9abbf9b7 100644 --- a/benchmarks/cpu/toy_detector_cpu.cpp +++ b/benchmarks/cpu/toy_detector_cpu.cpp @@ -10,7 +10,7 @@ // Traccc algorithm include(s). #include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp" -#include "traccc/fitting/kf_algorithm.hpp" +#include "traccc/fitting/kalman_fitting_algorithm.hpp" #include "traccc/seeding/seeding_algorithm.hpp" #include "traccc/seeding/track_params_estimation.hpp" @@ -57,7 +57,7 @@ BENCHMARK_F(ToyDetectorBenchmark, CPU)(benchmark::State& state) { traccc::track_params_estimation tp(host_mr); traccc::host::combinatorial_kalman_filter_algorithm host_finding( finding_cfg); - traccc::host::kf_algorithm host_fitting(fitting_cfg); + traccc::host::kalman_fitting_algorithm host_fitting(fitting_cfg); for (auto _ : state) { diff --git a/core/CMakeLists.txt b/core/CMakeLists.txt index c4f5ff070..e55bc281c 100644 --- a/core/CMakeLists.txt +++ b/core/CMakeLists.txt @@ -75,10 +75,10 @@ traccc_add_library( traccc_core core TYPE SHARED "include/traccc/fitting/kalman_filter/kalman_step_aborter.hpp" "include/traccc/fitting/kalman_filter/statistics_updater.hpp" "src/fitting/fit_tracks.hpp" - "include/traccc/fitting/kf_algorithm.hpp" - "src/fitting/kf_algorithm.cpp" - "src/fitting/kf_algorithm_defdet_cfield.cpp" - "src/fitting/kf_algorithm_teldet_cfield.cpp" + "include/traccc/fitting/kalman_fitting_algorithm.hpp" + "src/fitting/kalman_fitting_algorithm.cpp" + "src/fitting/kalman_fitting_algorithm_constant_field_default_detector.cpp" + "src/fitting/kalman_fitting_algorithm_constant_field_telescope_detector.cpp" # Seed finding algorithmic code. "include/traccc/seeding/detail/lin_circle.hpp" "include/traccc/seeding/detail/doublet.hpp" diff --git a/core/include/traccc/fitting/kf_algorithm.hpp b/core/include/traccc/fitting/kalman_fitting_algorithm.hpp similarity index 76% rename from core/include/traccc/fitting/kf_algorithm.hpp rename to core/include/traccc/fitting/kalman_fitting_algorithm.hpp index 3e23f7d57..2924ccac0 100644 --- a/core/include/traccc/fitting/kf_algorithm.hpp +++ b/core/include/traccc/fitting/kalman_fitting_algorithm.hpp @@ -20,14 +20,15 @@ namespace traccc::host { /// Kalman filter based track fitting algorithm -class kf_algorithm : public algorithm, - public algorithm { +class kalman_fitting_algorithm + : public algorithm, + public algorithm { public: /// Configuration type @@ -39,7 +40,7 @@ class kf_algorithm : public algorithm @@ -16,7 +16,7 @@ namespace traccc::host { -kf_algorithm::output_type kf_algorithm::operator()( +kalman_fitting_algorithm::output_type kalman_fitting_algorithm::operator()( const default_detector::host& det, const detray::bfield::const_field_t::view_t& field, const track_candidate_container_types::const_view& track_candidates) const { diff --git a/core/src/fitting/kf_algorithm_teldet_cfield.cpp b/core/src/fitting/kalman_fitting_algorithm_constant_field_telescope_detector.cpp similarity index 90% rename from core/src/fitting/kf_algorithm_teldet_cfield.cpp rename to core/src/fitting/kalman_fitting_algorithm_constant_field_telescope_detector.cpp index 5ca4f96bc..53404596d 100644 --- a/core/src/fitting/kf_algorithm_teldet_cfield.cpp +++ b/core/src/fitting/kalman_fitting_algorithm_constant_field_telescope_detector.cpp @@ -8,7 +8,7 @@ // Project include(s). #include "fit_tracks.hpp" #include "traccc/fitting/kalman_filter/kalman_fitter.hpp" -#include "traccc/fitting/kf_algorithm.hpp" +#include "traccc/fitting/kalman_fitting_algorithm.hpp" // Detray include(s). #include @@ -16,7 +16,7 @@ namespace traccc::host { -kf_algorithm::output_type kf_algorithm::operator()( +kalman_fitting_algorithm::output_type kalman_fitting_algorithm::operator()( const telescope_detector::host& det, const detray::bfield::const_field_t::view_t& field, const track_candidate_container_types::const_view& track_candidates) const { diff --git a/examples/run/cpu/full_chain_algorithm.hpp b/examples/run/cpu/full_chain_algorithm.hpp index 35c7811b5..1c4c2b85e 100644 --- a/examples/run/cpu/full_chain_algorithm.hpp +++ b/examples/run/cpu/full_chain_algorithm.hpp @@ -12,7 +12,7 @@ #include "traccc/edm/silicon_cell_collection.hpp" #include "traccc/edm/track_state.hpp" #include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp" -#include "traccc/fitting/kf_algorithm.hpp" +#include "traccc/fitting/kalman_fitting_algorithm.hpp" #include "traccc/geometry/detector.hpp" #include "traccc/geometry/silicon_detector_description.hpp" #include "traccc/seeding/seeding_algorithm.hpp" @@ -58,7 +58,7 @@ class full_chain_algorithm : public algorithm sg(detector, diff --git a/examples/run/cpu/truth_fitting_example.cpp b/examples/run/cpu/truth_fitting_example.cpp index 792e158a0..6463d3d26 100644 --- a/examples/run/cpu/truth_fitting_example.cpp +++ b/examples/run/cpu/truth_fitting_example.cpp @@ -8,7 +8,7 @@ // Project include(s). #include "traccc/definitions/common.hpp" #include "traccc/definitions/primitives.hpp" -#include "traccc/fitting/kf_algorithm.hpp" +#include "traccc/fitting/kalman_fitting_algorithm.hpp" #include "traccc/geometry/detector.hpp" #include "traccc/io/read_geometry.hpp" #include "traccc/io/utils.hpp" @@ -107,7 +107,7 @@ int main(int argc, char* argv[]) { traccc::fitting_config fit_cfg; fit_cfg.propagation = propagation_opts; - traccc::host::kf_algorithm host_fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg); // Seed generator traccc::seed_generator sg(host_det, stddevs); diff --git a/examples/run/cuda/seeding_example_cuda.cpp b/examples/run/cuda/seeding_example_cuda.cpp index 3898dc588..a45e2a4cf 100644 --- a/examples/run/cuda/seeding_example_cuda.cpp +++ b/examples/run/cuda/seeding_example_cuda.cpp @@ -19,7 +19,7 @@ #include "traccc/efficiency/track_filter.hpp" #include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp" #include "traccc/fitting/kalman_filter/kalman_fitter.hpp" -#include "traccc/fitting/kf_algorithm.hpp" +#include "traccc/fitting/kalman_fitting_algorithm.hpp" #include "traccc/io/read_detector.hpp" #include "traccc/io/read_detector_description.hpp" #include "traccc/io/read_measurements.hpp" @@ -179,7 +179,7 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts, traccc::fitting_config fit_cfg; fit_cfg.propagation = propagation_config; - traccc::host::kf_algorithm host_fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg); traccc::cuda::fitting_algorithm device_fitting( fit_cfg, mr, async_copy, stream); diff --git a/examples/run/cuda/seq_example_cuda.cpp b/examples/run/cuda/seq_example_cuda.cpp index 9086c2c9e..a97852f86 100644 --- a/examples/run/cuda/seq_example_cuda.cpp +++ b/examples/run/cuda/seq_example_cuda.cpp @@ -18,7 +18,8 @@ #include "traccc/device/container_d2h_copy_alg.hpp" #include "traccc/efficiency/seeding_performance_writer.hpp" #include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp" -#include "traccc/fitting/kf_algorithm.hpp" +#include "traccc/fitting/kalman_filter/kalman_fitter.hpp" +#include "traccc/fitting/kalman_fitting_algorithm.hpp" #include "traccc/io/read_cells.hpp" #include "traccc/io/read_detector.hpp" #include "traccc/io/read_detector_description.hpp" @@ -139,7 +140,7 @@ int seq_run(const traccc::opts::detector& detector_opts, using device_finding_algorithm = traccc::cuda::finding_algorithm; - using host_fitting_algorithm = traccc::host::kf_algorithm; + using host_fitting_algorithm = traccc::host::kalman_fitting_algorithm; using device_fitting_algorithm = traccc::cuda::fitting_algorithm< traccc::kalman_fitter>; diff --git a/examples/run/cuda/truth_finding_example_cuda.cpp b/examples/run/cuda/truth_finding_example_cuda.cpp index d8bf17bf2..0e6af2f84 100644 --- a/examples/run/cuda/truth_finding_example_cuda.cpp +++ b/examples/run/cuda/truth_finding_example_cuda.cpp @@ -16,7 +16,7 @@ #include "traccc/efficiency/finding_performance_writer.hpp" #include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp" #include "traccc/fitting/kalman_filter/kalman_fitter.hpp" -#include "traccc/fitting/kf_algorithm.hpp" +#include "traccc/fitting/kalman_fitting_algorithm.hpp" #include "traccc/io/read_detector.hpp" #include "traccc/io/read_detector_description.hpp" #include "traccc/io/read_measurements.hpp" @@ -155,7 +155,7 @@ int seq_run(const traccc::opts::track_finding& finding_opts, traccc::fitting_config fit_cfg; fit_cfg.propagation = propagation_config; - traccc::host::kf_algorithm host_fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg); traccc::cuda::fitting_algorithm device_fitting( fit_cfg, mr, async_copy, stream); @@ -241,7 +241,7 @@ int seq_run(const traccc::opts::track_finding& finding_opts, // CPU containers traccc::host::combinatorial_kalman_filter_algorithm::output_type track_candidates; - traccc::host::kf_algorithm::output_type track_states; + traccc::host::kalman_fitting_algorithm::output_type track_states; if (accelerator_opts.compare_with_cpu) { diff --git a/examples/run/cuda/truth_fitting_example_cuda.cpp b/examples/run/cuda/truth_fitting_example_cuda.cpp index fb5cb55c4..c93990bb3 100644 --- a/examples/run/cuda/truth_fitting_example_cuda.cpp +++ b/examples/run/cuda/truth_fitting_example_cuda.cpp @@ -13,7 +13,7 @@ #include "traccc/device/container_d2h_copy_alg.hpp" #include "traccc/device/container_h2d_copy_alg.hpp" #include "traccc/fitting/kalman_filter/kalman_fitter.hpp" -#include "traccc/fitting/kf_algorithm.hpp" +#include "traccc/fitting/kalman_fitting_algorithm.hpp" #include "traccc/geometry/detector.hpp" #include "traccc/io/read_geometry.hpp" #include "traccc/io/read_measurements.hpp" @@ -155,7 +155,7 @@ int main(int argc, char* argv[]) { traccc::fitting_config fit_cfg; fit_cfg.propagation = propagation_opts; - traccc::host::kf_algorithm host_fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg); traccc::cuda::fitting_algorithm device_fitting( fit_cfg, mr, async_copy, stream); @@ -197,7 +197,7 @@ int main(int argc, char* argv[]) { track_state_d2h(track_states_cuda_buffer); // CPU container(s) - traccc::host::kf_algorithm::output_type track_states; + traccc::host::kalman_fitting_algorithm::output_type track_states; if (accelerator_opts.compare_with_cpu) { diff --git a/examples/run/sycl/full_chain_algorithm.hpp b/examples/run/sycl/full_chain_algorithm.hpp index e6bb28c1f..5a280baad 100644 --- a/examples/run/sycl/full_chain_algorithm.hpp +++ b/examples/run/sycl/full_chain_algorithm.hpp @@ -10,7 +10,7 @@ // Project include(s). #include "traccc/edm/silicon_cell_collection.hpp" #include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp" -#include "traccc/fitting/kf_algorithm.hpp" +#include "traccc/fitting/kalman_fitting_algorithm.hpp" #include "traccc/geometry/detector.hpp" #include "traccc/geometry/silicon_detector_description.hpp" #include "traccc/sycl/clusterization/clusterization_algorithm.hpp" @@ -75,7 +75,7 @@ class full_chain_algorithm using finding_algorithm = traccc::host::combinatorial_kalman_filter_algorithm; /// Track fitting algorithm type - using fitting_algorithm = traccc::host::kf_algorithm; + using fitting_algorithm = traccc::host::kalman_fitting_algorithm; /// @} diff --git a/tests/cpu/test_ckf_sparse_tracks_telescope.cpp b/tests/cpu/test_ckf_sparse_tracks_telescope.cpp index c3cbc86a7..cdb9d06d1 100644 --- a/tests/cpu/test_ckf_sparse_tracks_telescope.cpp +++ b/tests/cpu/test_ckf_sparse_tracks_telescope.cpp @@ -7,7 +7,7 @@ // Project include(s). #include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp" -#include "traccc/fitting/kf_algorithm.hpp" +#include "traccc/fitting/kalman_fitting_algorithm.hpp" #include "traccc/io/read_measurements.hpp" #include "traccc/io/utils.hpp" #include "traccc/resolution/fitting_performance_writer.hpp" @@ -136,7 +136,7 @@ TEST_P(CkfSparseTrackTelescopeTests, Run) { fit_cfg.propagation.navigation.overstep_tolerance = -100.f * unit::um; fit_cfg.propagation.navigation.max_mask_tolerance = 1.f * unit::mm; - traccc::host::kf_algorithm host_fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg); // Iterate over events for (std::size_t i_evt = 0; i_evt < n_events; i_evt++) { diff --git a/tests/cpu/test_kalman_fitter_telescope.cpp b/tests/cpu/test_kalman_fitter_telescope.cpp index af7203dae..19e2b24fc 100644 --- a/tests/cpu/test_kalman_fitter_telescope.cpp +++ b/tests/cpu/test_kalman_fitter_telescope.cpp @@ -7,7 +7,7 @@ // Project include(s). #include "traccc/edm/track_state.hpp" -#include "traccc/fitting/kf_algorithm.hpp" +#include "traccc/fitting/kalman_fitting_algorithm.hpp" #include "traccc/io/utils.hpp" #include "traccc/resolution/fitting_performance_writer.hpp" #include "traccc/simulation/simulator.hpp" @@ -123,7 +123,7 @@ TEST_P(KalmanFittingTelescopeTests, Run) { fit_cfg.propagation.navigation.overstep_tolerance = -100.f * unit::um; fit_cfg.propagation.navigation.max_mask_tolerance = 1.f * unit::mm; - traccc::host::kf_algorithm fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm fitting(fit_cfg); // Iterate over events for (std::size_t i_evt = 0; i_evt < n_events; i_evt++) { diff --git a/tests/cpu/test_kalman_fitter_wire_chamber.cpp b/tests/cpu/test_kalman_fitter_wire_chamber.cpp index 8bb283e40..9a82e01d2 100644 --- a/tests/cpu/test_kalman_fitter_wire_chamber.cpp +++ b/tests/cpu/test_kalman_fitter_wire_chamber.cpp @@ -7,7 +7,7 @@ // Project include(s). #include "traccc/edm/track_state.hpp" -#include "traccc/fitting/kf_algorithm.hpp" +#include "traccc/fitting/kalman_fitting_algorithm.hpp" #include "traccc/io/utils.hpp" #include "traccc/resolution/fitting_performance_writer.hpp" #include "traccc/simulation/measurement_smearer.hpp" @@ -124,7 +124,7 @@ TEST_P(KalmanFittingWireChamberTests, Run) { static_cast(mask_tolerance); fit_cfg.propagation.navigation.search_window = search_window; fit_cfg.ptc_hypothesis = ptc; - traccc::host::kf_algorithm fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm fitting(fit_cfg); // Iterate over events for (std::size_t i_evt = 0; i_evt < n_events; i_evt++) { From b4c53e0cb99c4ddaa48e4239ea614e8b592bc73b Mon Sep 17 00:00:00 2001 From: Attila Krasznahorkay Date: Mon, 4 Nov 2024 13:04:07 +0100 Subject: [PATCH 4/4] Made all templated code of track fitting public. This way clients can have access to the full details of the code if they want to, while also getting algorithms that would perform fitting in one very specific way. --- benchmarks/cpu/toy_detector_cpu.cpp | 2 +- core/CMakeLists.txt | 2 +- .../traccc/fitting/details}/fit_tracks.hpp | 30 ++++++++++--------- .../fitting/kalman_fitting_algorithm.hpp | 11 ++++++- core/src/fitting/kalman_fitting_algorithm.cpp | 5 ++-- ...orithm_constant_field_default_detector.cpp | 16 +++++----- ...ithm_constant_field_telescope_detector.cpp | 16 +++++----- examples/run/cpu/full_chain_algorithm.cpp | 2 +- examples/run/cpu/seeding_example.cpp | 2 +- examples/run/cpu/seq_example.cpp | 2 +- examples/run/cpu/truth_finding_example.cpp | 2 +- examples/run/cpu/truth_fitting_example.cpp | 2 +- examples/run/cuda/seeding_example_cuda.cpp | 2 +- examples/run/cuda/seq_example_cuda.cpp | 2 +- .../run/cuda/truth_finding_example_cuda.cpp | 2 +- .../run/cuda/truth_fitting_example_cuda.cpp | 2 +- .../cpu/test_ckf_sparse_tracks_telescope.cpp | 2 +- tests/cpu/test_kalman_fitter_telescope.cpp | 2 +- tests/cpu/test_kalman_fitter_wire_chamber.cpp | 2 +- 19 files changed, 57 insertions(+), 49 deletions(-) rename core/{src/fitting => include/traccc/fitting/details}/fit_tracks.hpp (73%) diff --git a/benchmarks/cpu/toy_detector_cpu.cpp b/benchmarks/cpu/toy_detector_cpu.cpp index b9abbf9b7..3fd92d468 100644 --- a/benchmarks/cpu/toy_detector_cpu.cpp +++ b/benchmarks/cpu/toy_detector_cpu.cpp @@ -57,7 +57,7 @@ BENCHMARK_F(ToyDetectorBenchmark, CPU)(benchmark::State& state) { traccc::track_params_estimation tp(host_mr); traccc::host::combinatorial_kalman_filter_algorithm host_finding( finding_cfg); - traccc::host::kalman_fitting_algorithm host_fitting(fitting_cfg); + traccc::host::kalman_fitting_algorithm host_fitting(fitting_cfg, host_mr); for (auto _ : state) { diff --git a/core/CMakeLists.txt b/core/CMakeLists.txt index e55bc281c..d2d2af375 100644 --- a/core/CMakeLists.txt +++ b/core/CMakeLists.txt @@ -74,7 +74,7 @@ traccc_add_library( traccc_core core TYPE SHARED "include/traccc/fitting/kalman_filter/kalman_fitter.hpp" "include/traccc/fitting/kalman_filter/kalman_step_aborter.hpp" "include/traccc/fitting/kalman_filter/statistics_updater.hpp" - "src/fitting/fit_tracks.hpp" + "include/traccc/fitting/details/fit_tracks.hpp" "include/traccc/fitting/kalman_fitting_algorithm.hpp" "src/fitting/kalman_fitting_algorithm.cpp" "src/fitting/kalman_fitting_algorithm_constant_field_default_detector.cpp" diff --git a/core/src/fitting/fit_tracks.hpp b/core/include/traccc/fitting/details/fit_tracks.hpp similarity index 73% rename from core/src/fitting/fit_tracks.hpp rename to core/include/traccc/fitting/details/fit_tracks.hpp index 00cc6bb3d..a5407f8ff 100644 --- a/core/src/fitting/fit_tracks.hpp +++ b/core/include/traccc/fitting/details/fit_tracks.hpp @@ -11,6 +11,9 @@ #include "traccc/edm/track_candidate.hpp" #include "traccc/edm/track_state.hpp" +// VecMem include(s). +#include + namespace traccc::host::details { /// Templated implementation of the track fitting algorithm. @@ -19,27 +22,26 @@ namespace traccc::host::details { /// specializations, to fit tracks on top of a specific detector type, magnetic /// field type, and track fitting configuration. /// +/// @note The memory resource received by this function is not used thoroughly +/// for the setup of the output container. Inner vectors in the output's +/// jagged vector are created using the default memory resource. +/// /// @tparam fitter_t The fitter type used for the track fitting /// -/// @param det The detector object -/// @param field The magnetic field object -/// @param track_candidates All track candidates to fit -/// @param config The track fitting configuration +/// @param[in] fitter The fitter object to use on the track candidates +/// @param[in] track_candidates All track candidates to fit +/// @param[in] mr Memory resource to use for the output container /// /// @return A container of the fitted track states /// template track_state_container_types::host fit_tracks( - const typename fitter_t::detector_type& det, - const typename fitter_t::bfield_type& field, + fitter_t& fitter, const track_candidate_container_types::const_view& track_candidates_view, - const typename fitter_t::config_type& config) { - - // Create the fitter object. - fitter_t fitter(det, field, config); + vecmem::memory_resource& mr) { - // Output container. - track_state_container_types::host output_states; + // Create the output container. + track_state_container_types::host result{&mr}; // Iterate over the tracks, const track_candidate_container_types::const_device track_candidates{ @@ -62,13 +64,13 @@ track_state_container_types::host fit_tracks( fitter.fit(track_candidates.get_headers()[i], fitter_state); // Save the results into the output container. - output_states.push_back( + result.push_back( std::move(fitter_state.m_fit_res), std::move(fitter_state.m_fit_actor_state.m_track_states)); } // Return the fitted track states. - return output_states; + return result; } } // namespace traccc::host::details diff --git a/core/include/traccc/fitting/kalman_fitting_algorithm.hpp b/core/include/traccc/fitting/kalman_fitting_algorithm.hpp index 2924ccac0..31fafbc91 100644 --- a/core/include/traccc/fitting/kalman_fitting_algorithm.hpp +++ b/core/include/traccc/fitting/kalman_fitting_algorithm.hpp @@ -17,6 +17,12 @@ // Detray include(s). #include +// VecMem include(s). +#include + +// System include(s). +#include + namespace traccc::host { /// Kalman filter based track fitting algorithm @@ -40,7 +46,8 @@ class kalman_fitting_algorithm /// /// @param config The configuration object /// - kalman_fitting_algorithm(const config_type& config); + explicit kalman_fitting_algorithm(const config_type& config, + vecmem::memory_resource& mr); /// Execute the algorithm /// @@ -71,6 +78,8 @@ class kalman_fitting_algorithm private: /// Algorithm configuration config_type m_config; + /// Memory resource to use in the algorithm + std::reference_wrapper m_mr; }; // class kalman_fitting_algorithm diff --git a/core/src/fitting/kalman_fitting_algorithm.cpp b/core/src/fitting/kalman_fitting_algorithm.cpp index ff85039b2..ef2002837 100644 --- a/core/src/fitting/kalman_fitting_algorithm.cpp +++ b/core/src/fitting/kalman_fitting_algorithm.cpp @@ -10,7 +10,8 @@ namespace traccc::host { -kalman_fitting_algorithm::kalman_fitting_algorithm(const config_type& config) - : m_config(config) {} +kalman_fitting_algorithm::kalman_fitting_algorithm(const config_type& config, + vecmem::memory_resource& mr) + : m_config{config}, m_mr{mr} {} } // namespace traccc::host diff --git a/core/src/fitting/kalman_fitting_algorithm_constant_field_default_detector.cpp b/core/src/fitting/kalman_fitting_algorithm_constant_field_default_detector.cpp index f944fb36c..f4a3933c6 100644 --- a/core/src/fitting/kalman_fitting_algorithm_constant_field_default_detector.cpp +++ b/core/src/fitting/kalman_fitting_algorithm_constant_field_default_detector.cpp @@ -6,7 +6,7 @@ */ // Project include(s). -#include "fit_tracks.hpp" +#include "traccc/fitting/details/fit_tracks.hpp" #include "traccc/fitting/kalman_filter/kalman_fitter.hpp" #include "traccc/fitting/kalman_fitting_algorithm.hpp" @@ -21,18 +21,16 @@ kalman_fitting_algorithm::output_type kalman_fitting_algorithm::operator()( const detray::bfield::const_field_t::view_t& field, const track_candidate_container_types::const_view& track_candidates) const { - // Set up the fitter type(s). - using stepper_type = + // Create the fitter object. + kalman_fitter< detray::rk_stepper>; - using navigator_type = - detray::navigator; - using fitter_type = kalman_fitter; + detray::constrained_step<>>, + detray::navigator> + fitter{det, field, m_config}; // Perform the track fitting using a common, templated function. - return details::fit_tracks(det, field, track_candidates, - m_config); + return details::fit_tracks(fitter, track_candidates, m_mr.get()); } } // namespace traccc::host diff --git a/core/src/fitting/kalman_fitting_algorithm_constant_field_telescope_detector.cpp b/core/src/fitting/kalman_fitting_algorithm_constant_field_telescope_detector.cpp index 53404596d..d28fe814c 100644 --- a/core/src/fitting/kalman_fitting_algorithm_constant_field_telescope_detector.cpp +++ b/core/src/fitting/kalman_fitting_algorithm_constant_field_telescope_detector.cpp @@ -6,7 +6,7 @@ */ // Project include(s). -#include "fit_tracks.hpp" +#include "traccc/fitting/details/fit_tracks.hpp" #include "traccc/fitting/kalman_filter/kalman_fitter.hpp" #include "traccc/fitting/kalman_fitting_algorithm.hpp" @@ -21,18 +21,16 @@ kalman_fitting_algorithm::output_type kalman_fitting_algorithm::operator()( const detray::bfield::const_field_t::view_t& field, const track_candidate_container_types::const_view& track_candidates) const { - // Set up the fitter type(s). - using stepper_type = + // Create the fitter object. + kalman_fitter< detray::rk_stepper>; - using navigator_type = - detray::navigator; - using fitter_type = kalman_fitter; + detray::constrained_step<>>, + detray::navigator> + fitter{det, field, m_config}; // Perform the track fitting using a common, templated function. - return details::fit_tracks(det, field, track_candidates, - m_config); + return details::fit_tracks(fitter, track_candidates, m_mr.get()); } } // namespace traccc::host diff --git a/examples/run/cpu/full_chain_algorithm.cpp b/examples/run/cpu/full_chain_algorithm.cpp index 349fa09d9..8756b945a 100644 --- a/examples/run/cpu/full_chain_algorithm.cpp +++ b/examples/run/cpu/full_chain_algorithm.cpp @@ -28,7 +28,7 @@ full_chain_algorithm::full_chain_algorithm( m_seeding(finder_config, grid_config, filter_config, mr), m_track_parameter_estimation(mr), m_finding(finding_config), - m_fitting(fitting_config), + m_fitting(fitting_config, mr), m_finder_config(finder_config), m_grid_config(grid_config), m_filter_config(filter_config), diff --git a/examples/run/cpu/seeding_example.cpp b/examples/run/cpu/seeding_example.cpp index 4216b8ba4..0b41e8af2 100644 --- a/examples/run/cpu/seeding_example.cpp +++ b/examples/run/cpu/seeding_example.cpp @@ -124,7 +124,7 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts, traccc::fitting_config fit_cfg; fit_cfg.propagation = propagation_config; - traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg, host_mr); traccc::greedy_ambiguity_resolution_algorithm host_ambiguity_resolution{}; diff --git a/examples/run/cpu/seq_example.cpp b/examples/run/cpu/seq_example.cpp index e83fbc6a4..df80a11cc 100644 --- a/examples/run/cpu/seq_example.cpp +++ b/examples/run/cpu/seq_example.cpp @@ -128,7 +128,7 @@ int seq_run(const traccc::opts::input_data& input_opts, seeding_opts.seedfilter, host_mr); traccc::track_params_estimation tp(host_mr); finding_algorithm finding_alg(finding_cfg); - fitting_algorithm fitting_alg(fitting_cfg); + fitting_algorithm fitting_alg(fitting_cfg, host_mr); traccc::greedy_ambiguity_resolution_algorithm resolution_alg; // performance writer diff --git a/examples/run/cpu/truth_finding_example.cpp b/examples/run/cpu/truth_finding_example.cpp index a02163b39..c6f7b9456 100644 --- a/examples/run/cpu/truth_finding_example.cpp +++ b/examples/run/cpu/truth_finding_example.cpp @@ -103,7 +103,7 @@ int seq_run(const traccc::opts::track_finding& finding_opts, traccc::fitting_config fit_cfg; fit_cfg.propagation = propagation_config; - traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg, host_mr); // Seed generator traccc::seed_generator sg(detector, diff --git a/examples/run/cpu/truth_fitting_example.cpp b/examples/run/cpu/truth_fitting_example.cpp index 6463d3d26..bebe633b9 100644 --- a/examples/run/cpu/truth_fitting_example.cpp +++ b/examples/run/cpu/truth_fitting_example.cpp @@ -107,7 +107,7 @@ int main(int argc, char* argv[]) { traccc::fitting_config fit_cfg; fit_cfg.propagation = propagation_opts; - traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg, host_mr); // Seed generator traccc::seed_generator sg(host_det, stddevs); diff --git a/examples/run/cuda/seeding_example_cuda.cpp b/examples/run/cuda/seeding_example_cuda.cpp index a45e2a4cf..d365eb3df 100644 --- a/examples/run/cuda/seeding_example_cuda.cpp +++ b/examples/run/cuda/seeding_example_cuda.cpp @@ -179,7 +179,7 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts, traccc::fitting_config fit_cfg; fit_cfg.propagation = propagation_config; - traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg, host_mr); traccc::cuda::fitting_algorithm device_fitting( fit_cfg, mr, async_copy, stream); diff --git a/examples/run/cuda/seq_example_cuda.cpp b/examples/run/cuda/seq_example_cuda.cpp index a97852f86..591a2f47f 100644 --- a/examples/run/cuda/seq_example_cuda.cpp +++ b/examples/run/cuda/seq_example_cuda.cpp @@ -166,7 +166,7 @@ int seq_run(const traccc::opts::detector& detector_opts, seeding_opts.seedfilter, host_mr); traccc::track_params_estimation tp(host_mr); host_finding_algorithm finding_alg(finding_cfg); - host_fitting_algorithm fitting_alg(fitting_cfg); + host_fitting_algorithm fitting_alg(fitting_cfg, host_mr); traccc::cuda::clusterization_algorithm ca_cuda(mr, copy, stream, clusterization_opts); diff --git a/examples/run/cuda/truth_finding_example_cuda.cpp b/examples/run/cuda/truth_finding_example_cuda.cpp index 0e6af2f84..90d968a74 100644 --- a/examples/run/cuda/truth_finding_example_cuda.cpp +++ b/examples/run/cuda/truth_finding_example_cuda.cpp @@ -155,7 +155,7 @@ int seq_run(const traccc::opts::track_finding& finding_opts, traccc::fitting_config fit_cfg; fit_cfg.propagation = propagation_config; - traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg, host_mr); traccc::cuda::fitting_algorithm device_fitting( fit_cfg, mr, async_copy, stream); diff --git a/examples/run/cuda/truth_fitting_example_cuda.cpp b/examples/run/cuda/truth_fitting_example_cuda.cpp index c93990bb3..5adba022d 100644 --- a/examples/run/cuda/truth_fitting_example_cuda.cpp +++ b/examples/run/cuda/truth_fitting_example_cuda.cpp @@ -155,7 +155,7 @@ int main(int argc, char* argv[]) { traccc::fitting_config fit_cfg; fit_cfg.propagation = propagation_opts; - traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg, host_mr); traccc::cuda::fitting_algorithm device_fitting( fit_cfg, mr, async_copy, stream); diff --git a/tests/cpu/test_ckf_sparse_tracks_telescope.cpp b/tests/cpu/test_ckf_sparse_tracks_telescope.cpp index cdb9d06d1..595f57122 100644 --- a/tests/cpu/test_ckf_sparse_tracks_telescope.cpp +++ b/tests/cpu/test_ckf_sparse_tracks_telescope.cpp @@ -136,7 +136,7 @@ TEST_P(CkfSparseTrackTelescopeTests, Run) { fit_cfg.propagation.navigation.overstep_tolerance = -100.f * unit::um; fit_cfg.propagation.navigation.max_mask_tolerance = 1.f * unit::mm; - traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg, host_mr); // Iterate over events for (std::size_t i_evt = 0; i_evt < n_events; i_evt++) { diff --git a/tests/cpu/test_kalman_fitter_telescope.cpp b/tests/cpu/test_kalman_fitter_telescope.cpp index 19e2b24fc..5cc3e4e80 100644 --- a/tests/cpu/test_kalman_fitter_telescope.cpp +++ b/tests/cpu/test_kalman_fitter_telescope.cpp @@ -123,7 +123,7 @@ TEST_P(KalmanFittingTelescopeTests, Run) { fit_cfg.propagation.navigation.overstep_tolerance = -100.f * unit::um; fit_cfg.propagation.navigation.max_mask_tolerance = 1.f * unit::mm; - traccc::host::kalman_fitting_algorithm fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm fitting(fit_cfg, host_mr); // Iterate over events for (std::size_t i_evt = 0; i_evt < n_events; i_evt++) { diff --git a/tests/cpu/test_kalman_fitter_wire_chamber.cpp b/tests/cpu/test_kalman_fitter_wire_chamber.cpp index 9a82e01d2..952a4e389 100644 --- a/tests/cpu/test_kalman_fitter_wire_chamber.cpp +++ b/tests/cpu/test_kalman_fitter_wire_chamber.cpp @@ -124,7 +124,7 @@ TEST_P(KalmanFittingWireChamberTests, Run) { static_cast(mask_tolerance); fit_cfg.propagation.navigation.search_window = search_window; fit_cfg.ptc_hypothesis = ptc; - traccc::host::kalman_fitting_algorithm fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm fitting(fit_cfg, host_mr); // Iterate over events for (std::size_t i_evt = 0; i_evt < n_events; i_evt++) {