diff --git a/src/detection/src/detection_node.cpp b/src/detection/src/detection_node.cpp index 11dbb98..bd7d626 100644 --- a/src/detection/src/detection_node.cpp +++ b/src/detection/src/detection_node.cpp @@ -22,16 +22,6 @@ double angle(Point2f a, Point2f b); vector> getContours(Mat& frame, const string& color); vector calculateCenters(Mat& frame, const string& color, bool draw = false, bool debug = false); -Mat applyCanny(Mat& frame) { - Mat edges; - GaussianBlur(frame, frame, Size(5, 5), 1.5); - // Canny with thresholds 100 and 200 - Canny(frame, edges, 100, 200); - Mat edgesColor; - cvtColor(edges, edgesColor, COLOR_GRAY2BGR); - return edgesColor; -} - class CVNode : public rclcpp::Node { public: CVNode() : Node("CVNode"), writer_initialized(false), last_frame_time(this->now()) { @@ -59,15 +49,17 @@ class CVNode : public rclcpp::Node { bool flag_write_video_; std::string output_video_path_; + int frame_number = 0; + void topic_callback(const sensor_msgs::msg::CompressedImage &msg) { try { + frame_number ++; + cv_bridge::CvImagePtr cv_ptr = cv_bridge::toCvCopy(msg, sensor_msgs::image_encodings::BGR8); Mat frame = cv_ptr->image; - Mat edges = applyCanny(frame); - - vector centers = calculateCenters(frame, "blue", flag_write_video_, flag_debug_); + vector centers = calculateCenters(frame, "red", flag_write_video_, flag_debug_); Detection2DArray detections_msg; detections_msg.header = msg.header; @@ -76,12 +68,14 @@ class CVNode : public rclcpp::Node { Detection2D detection; detection.bbox.center.position.x = center.x; detection.bbox.center.position.y = center.y; + RCLCPP_INFO(this->get_logger(), "Frame Number: %d, Detection - x: %.2f, y: %.2f", frame_number, center.x, center.y); detections_msg.detections.push_back(detection); } detections_publisher_->publish(detections_msg); - RCLCPP_INFO(this->get_logger(), "Published %zu detections", detections_msg.detections.size()); + // RCLCPP_INFO(this->get_logger(), "Frame Number: %d, Published %zu detections", frame_number, detections_msg.detections.size()); + // Update the last received frame timestamp last_frame_time = this->now(); diff --git a/src/tracking/CMakeLists.txt b/src/tracking/CMakeLists.txt new file mode 100644 index 0000000..c80d7a0 --- /dev/null +++ b/src/tracking/CMakeLists.txt @@ -0,0 +1,43 @@ +cmake_minimum_required(VERSION 3.8) +project(tracking) + +if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") + add_compile_options(-Wall -Wextra -Wpedantic) +endif() + +# find dependencies +find_package(ament_cmake REQUIRED) +find_package(rclcpp REQUIRED) +find_package(vision_msgs REQUIRED) +find_package(geometry_msgs REQUIRED) +find_package(cv_bridge REQUIRED) +find_package(OpenCV REQUIRED) +find_package(message_filters REQUIRED) + +# uncomment the following section in order to fill in +# further dependencies manually. +# find_package( REQUIRED) + +add_executable(tracking_node src/tracking_node.cpp) +ament_target_dependencies(tracking_node rclcpp vision_msgs geometry_msgs cv_bridge OpenCV message_filters) +target_include_directories(tracking_node PUBLIC + $ + $) +target_compile_features(tracking_node PUBLIC c_std_99 cxx_std_17) # Require C99 and C++17 + +install(TARGETS tracking_node + DESTINATION lib/${PROJECT_NAME}) + +if(BUILD_TESTING) + find_package(ament_lint_auto REQUIRED) + # the following line skips the linter which checks for copyrights + # comment the line when a copyright and license is added to all source files + set(ament_cmake_copyright_FOUND TRUE) + # the following line skips cpplint (only works in a git repo) + # comment the line when this package is in a git repo and when + # a copyright and license is added to all source files + set(ament_cmake_cpplint_FOUND TRUE) + ament_lint_auto_find_test_dependencies() +endif() + +ament_package() diff --git a/src/tracking/package.xml b/src/tracking/package.xml new file mode 100644 index 0000000..e3eca3b --- /dev/null +++ b/src/tracking/package.xml @@ -0,0 +1,24 @@ + + + + tracking + 0.0.0 + TODO: Package description + endian + TODO: License declaration + + ament_cmake + rclcpp + vision_msgs + geometry_msgs + cv_bridge + opencv2 + message_filters + + ament_lint_auto + ament_lint_common + + + ament_cmake + + diff --git a/src/tracking/src/tracking_node.cpp b/src/tracking/src/tracking_node.cpp new file mode 100644 index 0000000..bbed656 --- /dev/null +++ b/src/tracking/src/tracking_node.cpp @@ -0,0 +1,288 @@ +#include +#include +#include +#include + +#include +#include + +#include +#include "cv_bridge/cv_bridge.h" +#include "opencv2/video/tracking.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +using namespace cv; + + +using std::placeholders::_1; +using std::placeholders::_2; + +using vision_msgs::msg::Detection2DArray; +using sensor_msgs::msg::CompressedImage; +using vision_msgs::msg::Detection2D; +using vision_msgs::msg::ObjectHypothesisWithPose; + +using vision_msgs::msg::Detection3DArray; +using vision_msgs::msg::Detection3D; + +using namespace std; +using namespace cv; + +// ──────────────── Kalman Filter Class ──────────────── +class Kalman3D { +public: + KalmanFilter kf; + Kalman3D(geometry_msgs::msg::Point32 pt) { + kf.init(6, 3, 0); + int dt = 1; + kf.transitionMatrix = (cv::Mat_(6, 6) << + 1,0,0,dt,0,0, + 0,1,0,0,dt,0, + 0,0,1,0,0,dt, + 0,0,0,1,0,0, + 0,0,0,0,1,0, + 0,0,0,0,0,1); + kf.measurementMatrix = Mat::eye(3, 6, CV_32F); + setIdentity(kf.processNoiseCov, Scalar::all(1e-2)); // Q + setIdentity(kf.measurementNoiseCov, Scalar::all(1e-1)); // R + setIdentity(kf.errorCovPost, Scalar::all(1)); // P + + kf.statePost.at(0) = pt.x; + kf.statePost.at(1) = pt.y; + kf.statePost.at(2) = pt.z; + kf.statePost.at(3) = 0; + kf.statePost.at(4) = 0; + kf.statePost.at(5) = 0; + } + + Mat predict() { + return kf.predict(); + } + + Mat correct(const geometry_msgs::msg::Point32& pt) { + Mat meas(3, 1, CV_32F); + meas.at(0) = pt.x; + meas.at(1) = pt.y; + meas.at(2) = pt.z; + Mat estimate = kf.correct(meas); + return estimate; + } + + Mat getState() const { + return kf.statePost; + } +}; + +// ──────────────── Track Management ──────────────── +struct Track { + Kalman3D filter; + int id; + int unseen; + Track(geometry_msgs::msg::Point32 pt, int id_) : filter(pt), id(id_), unseen(0) {} +}; + +// ──────────────── Assignment Algorithm ──────────────── +class AssignmentSolver { +public: + void Solve(const vector>& costMatrix, vector& assignment, float maxCost = FLT_MAX) { + size_t n = costMatrix.size(); + size_t m = costMatrix[0].size(); + + vector usedRows(n, false), usedCols(m, false); + + assignment.assign(n, -1); + + for (size_t i = 0; i < n; ++i) { + float minCost = std::numeric_limits::max(); + int bestJ = -1; + for (size_t j = 0; j < m; ++j) { + if (!usedCols[j] && costMatrix[i][j] < minCost) { + minCost = costMatrix[i][j]; + bestJ = j; + } + } + if (bestJ != -1 && minCost < maxCost) { + assignment[i] = bestJ; + usedCols[bestJ] = true; + } + } + } +}; + +class DetectionListener : public rclcpp::Node { +public: + DetectionListener() : Node("detection_listener"), tracks_(), next_id_(0) { + publisher_ = this->create_publisher("predicted_points", 10); + + detection_sub_.subscribe(this, "detections"); + image_sub_.subscribe(this, "/robot/rs2/color/image_raw/compressed"); + + sync_ = std::make_shared(SyncPolicy(10), detection_sub_, image_sub_); + sync_->registerCallback(std::bind(&DetectionListener::callback, this, _1, _2)); + + RCLCPP_INFO(this->get_logger(), "DetectionListener node has been started."); + } + +private: + bool writer_initialized = false; + cv::VideoWriter writer_; + rclcpp::Publisher::SharedPtr publisher_; + + std::vector tracks_; + int next_id_; + + geometry_msgs::msg::Point32 toPoint32(const Point3f point) const { + geometry_msgs::msg::Point32 pt; + pt.x = point.x; + pt.y = point.y; + pt.z = point.z; + return pt; + } + + void callback(const Detection3DArray::ConstSharedPtr msg, const CompressedImage::ConstSharedPtr image_msg) { + vector detections; + + for (const auto &detection : msg->detections) { + float x = detection.bbox.center.position.x; + float y = detection.bbox.center.position.y; + float z = detection.bbox.center.position.z; + + Point3f point; + + point.x = x; + point.y = y; + point.z = z; + + detections.push_back(point); + } + + vector predictions; + for (auto& t : tracks_) { + Mat val = t.filter.predict(); + predictions.push_back(Point3f(val.at(0), val.at(1), val.at(2))); + } + + vector> cost_matrix(tracks_.size(), vector(detections.size(), FLT_MAX)); + for (size_t i = 0; i < predictions.size(); ++i) { + for (size_t j = 0; j < detections.size(); ++j) { + float dx = predictions[i].x - detections[j].x; + float dy = predictions[i].y - detections[j].y; + float dz = predictions[i].z - detections[j].z; + cost_matrix[i][j] = dx * dx + dy * dy + dz * dz; + } + } + + vector assignment; + AssignmentSolver().Solve(cost_matrix, assignment, 2500.0f); + + RCLCPP_INFO(this->get_logger(), "Tracks: %zu, Detections: %zu", tracks_.size(), detections.size()); + + vector matched(detections.size(), false); + for (size_t i = 0; i < tracks_.size(); ++i) { + if (assignment[i] != -1) { + tracks_[i].filter.correct(toPoint32(detections[assignment[i]])); + tracks_[i].unseen = 0; + matched[assignment[i]] = true; + } else { + tracks_[i].unseen++; + } + } + + for (size_t j = 0; j < detections.size(); ++j) { + if (!matched[j]) { + tracks_.emplace_back(toPoint32(detections[j]), next_id_++); + } + } + + tracks_.erase(remove_if(tracks_.begin(), tracks_.end(), [](const Track& t) { + return t.unseen > 2; + }), tracks_.end()); + + Detection3DArray out; + out.header = msg->header; + + for (auto& t : tracks_) { + Mat m = t.filter.getState(); + Point3f pt(m.at(0), m.at(1), m.at(2)); + + Detection3D detection; + detection.bbox.center.position.x = pt.x; + detection.bbox.center.position.y = pt.y; + detection.bbox.center.position.z = pt.z; + out.detections.push_back(detection); + } + + publisher_->publish(out); + + // try { + // cv_bridge::CvImagePtr cv_ptr = cv_bridge::toCvCopy(image_msg, sensor_msgs::image_encodings::BGR8); + // Mat raw_frame = cv_ptr->image; + // // Downsample frame for efficiency + // Mat frame; + // cv::resize(raw_frame, frame, Size(), 0.5, 0.5, INTER_LINEAR); + + // if (!writer_initialized) { + // writer_.open("oneFile_predicted_video.mp4", cv::VideoWriter::fourcc('m','p','4','v'), 30, frame.size(), true); + // if (!writer_.isOpened()) { + // RCLCPP_ERROR(this->get_logger(), "Could not open the output video for write"); + // } + // writer_initialized = true; + // RCLCPP_INFO(this->get_logger(), "Video writer initialized."); + // } + + // for (auto& det : detections) { + // auto detectedPoint = cv::Point(); + // detectedPoint.x = det.x; + // detectedPoint.y = det.y; + // circle(frame, detectedPoint, 10, Scalar(255, 0, 255), -1); + // RCLCPP_INFO(this->get_logger(), "\tDetection - x: %.2f, y: %.2f", det.x, det.y); + // } + + // for (auto& pred : out.detections) { + // auto predictedPoint = cv::Point(); + // predictedPoint.x = (int) pred.bbox.center.position.x; + // predictedPoint.y = (int) pred.bbox.center.position.y; + // circle(frame, predictedPoint, 10, Scalar(0, 255, 0), -1); + // RCLCPP_INFO(this->get_logger(), "\tPrediction - x: %.2f, y: %.2f", pred.bbox.center.position.x, pred.bbox.center.position.y); + // } + + // writer_.write(frame); + // } catch (cv_bridge::Exception &e) { + // RCLCPP_ERROR(this->get_logger(), "cv_bridge exception: %s", e.what()); + // } + } + + message_filters::Subscriber detection_sub_; + message_filters::Subscriber image_sub_; + + using SyncPolicy = message_filters::sync_policies::ApproximateTime; + using Sync = message_filters::Synchronizer; + std::shared_ptr sync_; +}; + +int main(int argc, char **argv) { + rclcpp::init(argc, argv); + rclcpp::spin(std::make_shared()); + rclcpp::shutdown(); + return 0; +} \ No newline at end of file