Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

De-Template (Host) Track Fitting, main branch (2024.10.28.) #756

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/common/benchmarks/toy_detector_benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
// Traccc include(s).
#include "traccc/definitions/common.hpp"
#include "traccc/finding/finding_config.hpp"
#include "traccc/fitting/fitting_algorithm.hpp"
#include "traccc/fitting/fitting_config.hpp"
#include "traccc/io/utils.hpp"
#include "traccc/seeding/seeding_algorithm.hpp"
#include "traccc/seeding/track_params_estimation.hpp"
Expand Down
14 changes: 4 additions & 10 deletions benchmarks/cpu/toy_detector_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

// Traccc algorithm include(s).
#include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp"
#include "traccc/fitting/fitting_algorithm.hpp"
#include "traccc/fitting/kalman_fitting_algorithm.hpp"
#include "traccc/seeding/seeding_algorithm.hpp"
#include "traccc/seeding/track_params_estimation.hpp"

Expand Down Expand Up @@ -40,14 +40,7 @@
BENCHMARK_F(ToyDetectorBenchmark, CPU)(benchmark::State& state) {

// Type declarations
using rk_stepper_type =
detray::rk_stepper<b_field_t::view_t,
typename detector_type::algebra_type,
detray::constrained_step<>>;
using host_detector_type = traccc::default_detector::host;
using host_navigator_type = detray::navigator<const host_detector_type>;
using host_fitter_type =
traccc::kalman_fitter<rk_stepper_type, host_navigator_type>;

// Read back detector file
host_detector_type det{host_mr};
Expand All @@ -64,7 +57,7 @@ BENCHMARK_F(ToyDetectorBenchmark, CPU)(benchmark::State& state) {
traccc::track_params_estimation tp(host_mr);
traccc::host::combinatorial_kalman_filter_algorithm host_finding(
finding_cfg);
traccc::fitting_algorithm<host_fitter_type> host_fitting(fitting_cfg);
traccc::host::kalman_fitting_algorithm host_fitting(fitting_cfg, host_mr);

for (auto _ : state) {

Expand All @@ -87,7 +80,8 @@ BENCHMARK_F(ToyDetectorBenchmark, CPU)(benchmark::State& state) {
vecmem::get_data(params));

// Track fitting with KF
auto track_states = host_fitting(det, field, track_candidates);
auto track_states =
host_fitting(det, field, traccc::get_data(track_candidates));
}
}

Expand Down
1 change: 1 addition & 0 deletions benchmarks/cuda/toy_detector_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "traccc/cuda/seeding/seeding_algorithm.hpp"
#include "traccc/cuda/seeding/track_params_estimation.hpp"
#include "traccc/device/container_d2h_copy_alg.hpp"
#include "traccc/fitting/kalman_filter/kalman_fitter.hpp"
#include "traccc/geometry/detector.hpp"
#include "traccc/io/read_detector.hpp"
#include "traccc/io/read_geometry.hpp"
Expand Down
6 changes: 5 additions & 1 deletion core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ traccc_add_library( traccc_core core TYPE SHARED
"include/traccc/fitting/kalman_filter/kalman_fitter.hpp"
"include/traccc/fitting/kalman_filter/kalman_step_aborter.hpp"
"include/traccc/fitting/kalman_filter/statistics_updater.hpp"
"include/traccc/fitting/fitting_algorithm.hpp"
"include/traccc/fitting/details/fit_tracks.hpp"
"include/traccc/fitting/kalman_fitting_algorithm.hpp"
"src/fitting/kalman_fitting_algorithm.cpp"
"src/fitting/kalman_fitting_algorithm_constant_field_default_detector.cpp"
"src/fitting/kalman_fitting_algorithm_constant_field_telescope_detector.cpp"
# Seed finding algorithmic code.
"include/traccc/seeding/detail/lin_circle.hpp"
"include/traccc/seeding/detail/doublet.hpp"
Expand Down
76 changes: 76 additions & 0 deletions core/include/traccc/fitting/details/fit_tracks.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2022-2024 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

#pragma once

// Project include(s).
#include "traccc/edm/track_candidate.hpp"
#include "traccc/edm/track_state.hpp"

// VecMem include(s).
#include <vecmem/memory/memory_resource.hpp>

namespace traccc::host::details {

/// Templated implementation of the track fitting algorithm.
///
/// Concrete track fitting algorithms can use this function with the appropriate
/// specializations, to fit tracks on top of a specific detector type, magnetic
/// field type, and track fitting configuration.
///
/// @note The memory resource received by this function is not used thoroughly
/// for the setup of the output container. Inner vectors in the output's
/// jagged vector are created using the default memory resource.
///
/// @tparam fitter_t The fitter type used for the track fitting
///
/// @param[in] fitter The fitter object to use on the track candidates
/// @param[in] track_candidates All track candidates to fit
/// @param[in] mr Memory resource to use for the output container
///
/// @return A container of the fitted track states
///
template <typename fitter_t>
krasznaa marked this conversation as resolved.
Show resolved Hide resolved
track_state_container_types::host fit_tracks(
fitter_t& fitter,
const track_candidate_container_types::const_view& track_candidates_view,
vecmem::memory_resource& mr) {

// Create the output container.
track_state_container_types::host result{&mr};

// Iterate over the tracks,
const track_candidate_container_types::const_device track_candidates{
track_candidates_view};
for (track_candidate_container_types::const_device::size_type i = 0;
i < track_candidates.size(); ++i) {

// Make a vector of track states for this track.
vecmem::vector<track_state<typename fitter_t::algebra_type> >
input_states;
input_states.reserve(track_candidates.get_items()[i].size());
for (auto& measurement : track_candidates.get_items()[i]) {
input_states.emplace_back(measurement);
}

// Make a fitter state
typename fitter_t::state fitter_state(std::move(input_states));

// Run the fitter.
fitter.fit(track_candidates.get_headers()[i], fitter_state);

// Save the results into the output container.
result.push_back(
std::move(fitter_state.m_fit_res),
std::move(fitter_state.m_fit_actor_state.m_track_states));
}

// Return the fitted track states.
return result;
}

} // namespace traccc::host::details
88 changes: 0 additions & 88 deletions core/include/traccc/fitting/fitting_algorithm.hpp

This file was deleted.

86 changes: 86 additions & 0 deletions core/include/traccc/fitting/kalman_fitting_algorithm.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2022-2024 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

#pragma once

// Project include(s).
#include "traccc/edm/track_candidate.hpp"
#include "traccc/edm/track_state.hpp"
#include "traccc/fitting/fitting_config.hpp"
#include "traccc/geometry/detector.hpp"
#include "traccc/utils/algorithm.hpp"

// Detray include(s).
#include <detray/detectors/bfield.hpp>

// VecMem include(s).
#include <vecmem/memory/memory_resource.hpp>

// System include(s).
#include <functional>

namespace traccc::host {

/// Kalman filter based track fitting algorithm
class kalman_fitting_algorithm
: public algorithm<track_state_container_types::host(
const default_detector::host&,
const detray::bfield::const_field_t::view_t&,
const track_candidate_container_types::const_view&)>,
public algorithm<track_state_container_types::host(
const telescope_detector::host&,
const detray::bfield::const_field_t::view_t&,
const track_candidate_container_types::const_view&)> {

public:
/// Configuration type
using config_type = fitting_config;
/// Output type
using output_type = track_state_container_types::host;

/// Constructor with the algorithm's configuration
///
/// @param config The configuration object
///
explicit kalman_fitting_algorithm(const config_type& config,
vecmem::memory_resource& mr);

/// Execute the algorithm
///
/// @param det The (default) detector object
/// @param field The (constant) magnetic field object
/// @param track_candidates All track candidates to fit
///
/// @return A container of the fitted track states
///
output_type operator()(const default_detector::host& det,
const detray::bfield::const_field_t::view_t& field,
const track_candidate_container_types::const_view&
track_candidates) const override;

/// Execute the algorithm
///
/// @param det The (telescope) detector object
/// @param field The (constant) magnetic field object
/// @param track_candidates All track candidates to fit
///
/// @return A container of the fitted track states
///
output_type operator()(const telescope_detector::host& det,
const detray::bfield::const_field_t::view_t& field,
const track_candidate_container_types::const_view&
track_candidates) const override;

private:
/// Algorithm configuration
config_type m_config;
/// Memory resource to use in the algorithm
std::reference_wrapper<vecmem::memory_resource> m_mr;

}; // class kalman_fitting_algorithm

} // namespace traccc::host
17 changes: 17 additions & 0 deletions core/src/fitting/kalman_fitting_algorithm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2022-2024 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

// Project include(s).
#include "traccc/fitting/kalman_fitting_algorithm.hpp"

namespace traccc::host {

kalman_fitting_algorithm::kalman_fitting_algorithm(const config_type& config,
vecmem::memory_resource& mr)
: m_config{config}, m_mr{mr} {}

} // namespace traccc::host
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2022-2024 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

// Project include(s).
#include "traccc/fitting/details/fit_tracks.hpp"
#include "traccc/fitting/kalman_filter/kalman_fitter.hpp"
#include "traccc/fitting/kalman_fitting_algorithm.hpp"

// Detray include(s).
#include <detray/navigation/navigator.hpp>
#include <detray/propagator/rk_stepper.hpp>

namespace traccc::host {

kalman_fitting_algorithm::output_type kalman_fitting_algorithm::operator()(
const default_detector::host& det,
const detray::bfield::const_field_t::view_t& field,
const track_candidate_container_types::const_view& track_candidates) const {

// Create the fitter object.
kalman_fitter<
detray::rk_stepper<detray::bfield::const_field_t::view_t,
traccc::default_detector::host::algebra_type,
detray::constrained_step<>>,
detray::navigator<const traccc::default_detector::host>>
fitter{det, field, m_config};

// Perform the track fitting using a common, templated function.
return details::fit_tracks(fitter, track_candidates, m_mr.get());
}

} // namespace traccc::host
Loading
Loading