diff --git a/src/ctapipe/tools/process.py b/src/ctapipe/tools/process.py index 484a42f0fe4..a39e49d8e9c 100644 --- a/src/ctapipe/tools/process.py +++ b/src/ctapipe/tools/process.py @@ -8,7 +8,7 @@ from ..calib import CameraCalibrator, GainSelector from ..core import QualityQuery, Tool -from ..core.traits import Bool, classes_with_traits, flag +from ..core.traits import Bool, ComponentName, classes_with_traits, flag from ..image import ImageCleaner, ImageModifier, ImageProcessor from ..image.extractor import ImageExtractor from ..image.muon import MuonProcessor @@ -24,6 +24,7 @@ from ..io.datawriter import DATA_MODEL_VERSION from ..reco import Reconstructor, ShowerProcessor from ..utils import EventTypeFilter +from ..visualization import EventViewer COMPATIBLE_DATALEVELS = [ DataLevel.R1, @@ -76,6 +77,13 @@ class ProcessorTool(Tool): default_value=False, ).tag(config=True) + event_viewer_name = ComponentName( + EventViewer, + default_value="QtEventViewer", + ).tag(config=True) + + open_viewer = Bool(False, help="Open EventViewer").tag(config=True) + aliases = { ("i", "input"): "EventSource.input_url", ("o", "output"): "DataWriter.output_path", @@ -137,6 +145,12 @@ class ProcessorTool(Tool): "store DL1/Event/Telescope muon parameters in output", "don't store DL1/Event/Telescope muon parameters in output", ), + **flag( + "viewer", + "ProcessorTool.open_viewer", + "Open EventViewer", + "Do not open EventViewer", + ), "camera-frame": ( {"ImageProcessor": {"use_telescope_frame": False}}, "Use camera frame for image parameters instead of telescope frame", @@ -162,6 +176,7 @@ class ProcessorTool(Tool): + classes_with_traits(ImageModifier) + classes_with_traits(EventTypeFilter) + classes_with_traits(Reconstructor) + + classes_with_traits(EventViewer) ) def setup(self): @@ -207,6 +222,11 @@ def setup(self): "shower distributions read from the input Simulation file are invalid)." ) + if self.open_viewer: + self.event_viewer = EventViewer.from_name(self.event_viewer_name, subarray) + else: + self.event_viewer = None + @property def should_compute_dl2(self): """returns true if we should compute DL2 info""" @@ -323,6 +343,9 @@ def start(self): if self.should_compute_dl2: self.process_shower(event) + if self.event_viewer is not None: + self.event_viewer(event) + self.write(event) def finish(self): diff --git a/src/ctapipe/visualization/__init__.py b/src/ctapipe/visualization/__init__.py index a1b300e2410..7ce4b8ac601 100644 --- a/src/ctapipe/visualization/__init__.py +++ b/src/ctapipe/visualization/__init__.py @@ -1,13 +1,14 @@ -# Licensed under a 3-clause BSD style license - see LICENSE.rst """ Visualization: Methods for displaying data """ +from .mpl_array import ArrayDisplay +from .mpl_camera import CameraDisplay +from .qt_eventviewer import QtEventViewer +from .viewer import EventViewer -try: - from .mpl_array import ArrayDisplay - from .mpl_camera import CameraDisplay -except ImportError: - pass - - -__all__ = ["CameraDisplay", "ArrayDisplay"] +__all__ = [ + "CameraDisplay", + "ArrayDisplay", + "EventViewer", + "QtEventViewer", +] diff --git a/src/ctapipe/visualization/_qt_viewer_impl.py b/src/ctapipe/visualization/_qt_viewer_impl.py new file mode 100644 index 00000000000..0fc7c92e925 --- /dev/null +++ b/src/ctapipe/visualization/_qt_viewer_impl.py @@ -0,0 +1,274 @@ +from queue import Empty + +import astropy.units as u +import numpy as np +from PySide6 import QtGui +from PySide6.QtCore import Qt, QThread, Signal +from PySide6.QtWidgets import ( + QApplication, + QComboBox, + QHBoxLayout, + QLabel, + QMainWindow, + QPushButton, + QStackedLayout, + QTabWidget, + QVBoxLayout, + QWidget, +) + +# import matplotlib after qt so it can detect which bindings are in use +from matplotlib.backends import backend_qtagg # isort: skip +from matplotlib.figure import Figure # isort: skip + +from ..containers import ArrayEventContainer +from ..coordinates import EastingNorthingFrame, GroundFrame +from .mpl_array import ArrayDisplay +from .mpl_camera import CameraDisplay + + +class CameraDisplayWidget(QWidget): + def __init__(self, geometry, **kwargs): + super().__init__(**kwargs) + + self.geometry = geometry + + self.fig = Figure(layout="constrained") + self.canvas = backend_qtagg.FigureCanvasQTAgg(self.fig) + + self.ax = self.fig.add_subplot(1, 1, 1) + self.display = CameraDisplay(geometry, ax=self.ax) + self.display.add_colorbar() + + layout = QVBoxLayout() + layout.addWidget(self.canvas) + self.setLayout(layout) + + +class TelescopeDataWidget(QWidget): + def __init__(self, subarray, **kwargs): + super().__init__(**kwargs) + self.subarray = subarray + self.current_event = None + + layout = QVBoxLayout() + + top = QHBoxLayout() + label = QLabel(text="tel_id: ") + label.setAlignment(Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignCenter) + top.addWidget(label) + self.tel_selector = QComboBox(self) + self.tel_selector.currentTextChanged.connect(self.update_tel_image) + top.addWidget(self.tel_selector) + layout.addLayout(top) + + self.camera_displays = [] + self.widget_index = {} + self.camera_display_stack = QStackedLayout() + + for i, tel in enumerate(self.subarray.telescope_types): + widget = CameraDisplayWidget(tel.camera.geometry) + self.camera_displays.append(widget) + self.camera_display_stack.addWidget(widget) + + for tel_id in subarray.get_tel_ids_for_type(tel): + self.widget_index[tel_id] = i + + layout.addLayout(self.camera_display_stack) + self.setLayout(layout) + + def update_tel_image(self, tel_id): + # tel_selector.clear also calls this, but with an empty tel_id + if tel_id == "": + return + + tel_id = int(tel_id) + index = self.widget_index[tel_id] + widget = self.camera_displays[index] + + self.camera_display_stack.setCurrentIndex(index) + widget.display.image = self.current_event.dl1.tel[tel_id].image + widget.display.axes.figure.canvas.draw() + + def update_event(self, event): + self.current_event = event + + if event.dl1 is not None: + tels_with_image = [ + str(tel_id) + for tel_id, dl1 in event.dl1.tel.items() + if dl1.image is not None + ] + self.tel_selector.clear() + self.tel_selector.addItems(tels_with_image) + self.tel_selector.setCurrentIndex(0) + + +class SubarrayDataWidget(QWidget): + def __init__(self, subarray, **kwargs): + super().__init__(**kwargs) + self.subarray = subarray + + self.fig = Figure(layout="constrained") + self.canvas = backend_qtagg.FigureCanvasQTAgg(self.fig) + + self.ax = self.fig.add_subplot(1, 1, 1) + self.display = ArrayDisplay( + subarray, + axes=self.ax, + frame=EastingNorthingFrame(), + ) + self.display.add_labels() + self.display.telescopes.set_linewidth(0) + + layout = QVBoxLayout() + layout.addWidget(self.canvas) + self.setLayout(layout) + self.tel_types = self.display.telescopes.get_array().astype(float) + + (self.true_impact,) = self.ax.plot( + [], + [], + marker="x", + linestyle="", + ms=15, + color="k", + label="true impact", + ) + self.display.legend_elements.append(self.true_impact) + self.ax.legend(handles=self.display.legend_elements) + self.reco_impacts = {} + + def update_event(self, event): + trigger_pattern = self.tel_types.copy() + mask = self.subarray.tel_ids_to_mask(event.trigger.tels_with_trigger) + trigger_pattern[~mask] = np.nan + self.display.values = trigger_pattern + + if (sim := event.simulation) is not None and (shower := sim.shower) is not None: + impact = GroundFrame(shower.core_x, shower.core_y, 0 * u.m) + impact = impact.transform_to(self.display.frame) + x = impact.easting.to_value(u.m) + y = impact.northing.to_value(u.m) + self.true_impact.set_data(x, y) + + for key, reco in event.dl2.stereo.geometry.items(): + if key not in self.reco_impacts: + (marker,) = self.ax.plot( + [], + [], + marker="x", + linestyle="", + ms=15, + label=key, + ) + self.reco_impacts[key] = marker + self.display.legend_elements.append(marker) + self.ax.legend(handles=self.display.legend_elements) + + impact = GroundFrame(reco.core_x, reco.core_y, 0 * u.m) + impact = impact.transform_to(self.display.frame) + x = impact.easting.to_value(u.m) + y = impact.northing.to_value(u.m) + self.reco_impacts[key].set_data(x, y) + + self.canvas.draw() + + +class ViewerMainWindow(QMainWindow): + new_event_signal = Signal(ArrayEventContainer) + + def __init__(self, subarray, queue, **kwargs): + super().__init__(**kwargs) + self.subarray = subarray + self.queue = queue + self.current_event = None + self.setWindowTitle("ctapipe event display") + + layout = QVBoxLayout() + + top = QHBoxLayout() + self.label = QLabel(self) + self.label.setAlignment( + Qt.AlignmentFlag.AlignHCenter | Qt.AlignmentFlag.AlignCenter + ) + top.addWidget(self.label) + layout.addLayout(top) + + tabs = QTabWidget() + self.subarray_data = SubarrayDataWidget(subarray) + tabs.addTab(self.subarray_data, "Subarray Data") + self.tel_data = TelescopeDataWidget(subarray) + tabs.addTab(self.tel_data, "Telescope Data") + layout.addWidget(tabs) + + self.next_button = QPushButton("Next Event", parent=self) + self.next_button.pressed.connect(self.next) + layout.addWidget(self.next_button) + + widget = QWidget(self) + widget.setLayout(layout) + self.setCentralWidget(widget) + + self.event_thread = EventLoop(self) + self.event_thread.start() + + self.new_event_signal.connect(self.update_event) + + # set window size slightly smaller than available desktop space + size = QtGui.QGuiApplication.primaryScreen().availableGeometry().size() + self.resize(0.9 * size) + + def update_event(self, event): + if event is None: + return + + self.current_event = event + + label = f"obs_id: {event.index.obs_id}" f", event_id: {event.index.event_id}" + if event.simulation is not None and event.simulation.shower is not None: + label += f", E={event.simulation.shower.energy:.3f}" + + self.label.setText(label) + self.subarray_data.update_event(event) + self.tel_data.update_event(event) + + def next(self): + if self.current_event is not None: + self.queue.task_done() + + def closeEvent(self, event): + self.event_thread.stop_signal.emit() + self.event_thread.wait() + self.next() + super().closeEvent(event) + + +class EventLoop(QThread): + stop_signal = Signal() + + def __init__(self, display): + super().__init__() + self.display = display + self.closed = False + self.stop_signal.connect(self.close) + + def close(self): + self.closed = True + + def run(self): + while not self.closed: + try: + event = self.display.queue.get(timeout=0.1) + self.display.new_event_signal.emit(event) + except Empty: + continue + except ValueError: + break + + +def viewer_main(subarray, queue): + app = QApplication() + window = ViewerMainWindow(subarray, queue) + window.show() + app.exec_() diff --git a/src/ctapipe/visualization/mpl_array.py b/src/ctapipe/visualization/mpl_array.py index 3bc0064f1b6..45438a98c1c 100644 --- a/src/ctapipe/visualization/mpl_array.py +++ b/src/ctapipe/visualization/mpl_array.py @@ -5,6 +5,7 @@ from astropy.coordinates import Angle, SkyCoord from matplotlib import pyplot as plt from matplotlib.collections import PatchCollection +from matplotlib.colors import ListedColormap, Normalize from matplotlib.lines import Line2D from matplotlib.patches import Circle @@ -70,7 +71,6 @@ def __init__( self.frame = frame # set up colors per telescope type - tel_types = [str(tel) for tel in subarray.tels.values()] if radius is None: # set radius to the mirror radius (so big tels appear big) radius = [ @@ -80,55 +80,57 @@ def __init__( self.radii = radius else: - self.radii = np.ones(len(tel_types)) * radius + self.radii = np.ones(len(subarray)) * radius if title is None: title = subarray.name + tel_types = list({str(tel) for tel in subarray.telescope_types}) + self.tel_type_idx = np.array( + [tel_types.index(str(tel)) for tel in self.subarray.tel.values()] + ) + # get default matplotlib color cycle (depends on the current style) color_cycle = cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"]) - - # map a color to each telescope type: - tel_type_to_color = {} - for tel_type in list(set(tel_types)): - tel_type_to_color[tel_type] = next(color_cycle) - - tel_color = [tel_type_to_color[ttype] for ttype in tel_types] + cmap = ListedColormap([next(color_cycle) for _ in tel_types]) + cmap.set_bad("gray") + norm = Normalize(vmin=-0.5, vmax=len(tel_types) - 0.5) patches = [] - for x, y, r, c in zip( + for x, y, r in zip( list(self.tel_coords.x.to_value("m")), list(self.tel_coords.y.to_value("m")), list(radius), - tel_color, ): - patches.append(Circle(xy=(x, y), radius=r, fill=True, color=c, alpha=alpha)) + patches.append(Circle(xy=(x, y), radius=r, fill=True, alpha=alpha)) # build the legend: - legend_elements = [] - for ttype in list(set(tel_types)): - color = tel_type_to_color[ttype] - legend_elements.append( + self.legend_elements = [] + for i, tel_type in enumerate(tel_types): + self.legend_elements.append( Line2D( [0], [0], marker="o", - color=color, - label=ttype, + color=cmap(norm(i)), + label=tel_type, markersize=10, alpha=alpha, linewidth=0, ) ) - self.axes.legend(handles=legend_elements) + self.legend = self.axes.legend(handles=self.legend_elements) self.add_radial_grid() # create the plot - self.tel_colors = tel_color self.autoupdate = autoupdate self.telescopes = PatchCollection(patches, match_original=True) + self.telescopes.set_edgecolor(cmap(norm(self.tel_type_idx))) self.telescopes.set_linewidth(2.0) + self.telescopes.set_cmap(cmap) + self.telescopes.set_norm(norm) + self.telescopes.set_array(self.tel_type_idx) self.axes.add_collection(self.telescopes) self.axes.set_aspect(1.0) @@ -136,8 +138,8 @@ def __init__( xunit = self.tel_coords.x.unit.to_string("latex") yunit = self.tel_coords.y.unit.to_string("latex") xname, yname, _ = frame.get_representation_component_names().keys() - self.axes.set_xlabel(f"{xname} [{xunit}] $\\rightarrow$") - self.axes.set_ylabel(f"{yname} [{yunit}] $\\rightarrow$") + self.axes.set_xlabel(f"{xname} / {xunit} $\\rightarrow$") + self.axes.set_ylabel(f"{yname} / {yunit} $\\rightarrow$") self._labels = [] self._quiver = None self.axes.autoscale_view() diff --git a/src/ctapipe/visualization/qt_eventviewer.py b/src/ctapipe/visualization/qt_eventviewer.py new file mode 100644 index 00000000000..4354f00d5b3 --- /dev/null +++ b/src/ctapipe/visualization/qt_eventviewer.py @@ -0,0 +1,58 @@ +from multiprocessing import JoinableQueue, Process + +from ..containers import ArrayEventContainer +from .viewer import EventViewer + + +class QtEventViewer(EventViewer): + """ + EventViewer implementation using QT. + + Requires the ctapipe optional dependency ``pyside6``. + On Linux using wayland, make sure to have qt6 with wayland support, + e.g. when using conda-forge, also install ``qt6-wayland``. + + Qt requires to have the GUI thread be the main thread, so it is started + as a subprocess and communication happens through a ``JoinableQueue``. + + Actual GUI implementation is in ``_qt_viewer_impl`` to make this class + available always but make the qt dependency optional and error with a + nice message in case the optional dependencies are not installed. + """ + + def __init__(self, subarray, **kwargs): + try: + from ._qt_viewer_impl import viewer_main + except ModuleNotFoundError: + raise ModuleNotFoundError( + "PySide6 is needed for this EventViewer" + ) from None + + super().__init__(subarray=subarray, **kwargs) + + self.queue = JoinableQueue() + # don't wait for the GUI process to consume the queue at exit + self.queue.cancel_join_thread() + + # qt GUIs need to run as main thread -> subprocess + self.gui_process = Process( + target=viewer_main, + args=( + self.subarray, + self.queue, + ), + ) + self.gui_process.daemon = True + self.gui_process.start() + + def __call__(self, event: ArrayEventContainer): + # if the user closed the viewer window, we just continue processing events + # and this becomes a no-op + if self.gui_process.is_alive(): + self.queue.join() + self.queue.put(event) + + def close(self): + self.queue.close() + if self.gui_process.is_alive(): + self.gui_process.terminate() diff --git a/src/ctapipe/visualization/viewer.py b/src/ctapipe/visualization/viewer.py new file mode 100644 index 00000000000..e9df1671b5b --- /dev/null +++ b/src/ctapipe/visualization/viewer.py @@ -0,0 +1,18 @@ +from abc import ABCMeta, abstractmethod + +from ..containers import ArrayEventContainer +from ..core import TelescopeComponent + + +class EventViewer(TelescopeComponent, metaclass=ABCMeta): + """ + A component that can display events in some form, e.g. a GUI or Web UI. + """ + + @abstractmethod + def __call__(self, event: ArrayEventContainer): + pass + + @abstractmethod + def close(self): + pass