Skip to content

Commit

Permalink
implementing ScanGroup class
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing Li committed Oct 17, 2024
1 parent 7027701 commit 82321cc
Show file tree
Hide file tree
Showing 13 changed files with 242 additions and 100 deletions.
10 changes: 7 additions & 3 deletions src/tavi/data/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ def make_labels(
self,
x_str: str,
y_str: str,
norm_channel: Optional[str],
norm_val: float,
norm_to: Optional[tuple[float, str]],
scan_info,
):
"""Create axes labels, plot title and curve label"""
if norm_channel is not None:
if norm_to is not None:
norm_val, norm_channel = norm_to
if norm_channel == "time":
norm_channel_str = "seconds"
else:
Expand Down Expand Up @@ -84,3 +84,7 @@ def plot_curve(self, ax):
ax.set_ylabel(self.ylabel)
ax.grid(alpha=0.6)
ax.legend()


class Plot2D(object):
pass
29 changes: 16 additions & 13 deletions src/tavi/data/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
class ScanInfo:
"""Metadata containing scan information"""

exp_id: str
scan_num: int
scan_title: str = ""
def_y: str = "detector"
Expand Down Expand Up @@ -105,6 +106,7 @@ def from_nexus(cls, path_to_nexus: str, scan_num: Optional[int] = None):
@property
def scan_info(self):
scan_info = ScanInfo(
exp_id=self._nexus_dict["attrs"].get("dataset_name"),
scan_num=int(self.name[-4:]),
scan_title=self._nexus_dict.get("title"),
def_y=self._nexus_dict.get("data", ATTRS=True)["signal"],
Expand Down Expand Up @@ -187,37 +189,36 @@ def validate_rebin_params(rebin_params: float | tuple) -> tuple:

def get_plot_data(
self,
x_str: Optional[str] = None,
y_str: Optional[str] = None,
norm_channel: Literal["time", "monitor", "mcu", None] = None,
norm_val: float = 1.0,
axes: tuple[Optional[str], Optional[str]] = (None, None),
norm_to: Optional[tuple[float, Literal["time", "monitor", "mcu"]]] = None,
rebin_type: Literal["tol", "grid", None] = None,
rebin_params: Union[float, tuple] = 0.0,
) -> Plot1D:
"""Generate a curve from a single scan to plot, with the options
to normalize the y-axis and rebin x-axis.
Args:
x_str (str): x-axis variable
y_str (str): y-axis variable
norm_channel (str | None): choose from "time", "monitor", "mcu"
norm_val (float): value to normalized to
axes (x_str, y_stre): x-axis and y-axis variables
norm_to (norm_val (float), norm_channel(str)): value and channel for normalization
norm_channel should be "time", "monitor" or"mcu".
rebin_type (str | None): "tol" or "grid"
rebin_params (float | tuple(flot, float, float)): take as step size if a numer is given,
take as (min, max, step) if a tuple of size 3 is given
"""

x_str, y_str = axes
x_str = self.scan_info.def_x if x_str is None else x_str
y_str = self.scan_info.def_y if y_str is None else y_str

scan_data_1d = ScanData1D(x=self.data[x_str], y=self.data[y_str])

if rebin_type is None: # no rebin
if norm_channel is not None: # normalize y-axis without rebining along x-axis
if norm_to is not None: # normalize y-axis without rebining along x-axis
norm_val, norm_channel = norm_to
scan_data_1d.renorm(norm_col=self.data[norm_channel] / norm_val)

plot1d = Plot1D(x=scan_data_1d.x, y=scan_data_1d.y, yerr=scan_data_1d.err)
plot1d.make_labels(x_str, y_str, norm_channel, norm_val, self.scan_info)
plot1d.make_labels(x_str, y_str, norm_to, self.scan_info)

return plot1d

Expand All @@ -226,19 +227,21 @@ def get_plot_data(

match rebin_type:
case "tol":
if norm_channel is None: # x weighted by preset channel
if norm_to is None: # x weighted by preset channel
weight_channel = self.scan_info.preset_channel
scan_data_1d.rebin_tol(rebin_params_tuple, weight_col=self.data[weight_channel])
else: # x weighted by normalization channel
norm_val, norm_channel = norm_to
scan_data_1d.rebin_tol_renorm(
rebin_params_tuple,
norm_col=self.data[norm_channel],
norm_val=norm_val,
)
case "grid":
if norm_channel is None:
if norm_to is None:
scan_data_1d.rebin_grid(rebin_params_tuple)
else:
norm_val, norm_channel = norm_to
scan_data_1d.rebin_grid_renorm(
rebin_params_tuple,
norm_col=self.data[norm_channel],
Expand All @@ -248,7 +251,7 @@ def get_plot_data(
raise ValueError('Unrecogonized rebin type. Needs to be "tol" or "grid".')

plot1d = Plot1D(x=scan_data_1d.x, y=scan_data_1d.y, yerr=scan_data_1d.err)
plot1d.make_labels(x_str, y_str, norm_channel, norm_val, self.scan_info)
plot1d.make_labels(x_str, y_str, norm_to, self.scan_info)
return plot1d

def plot(
Expand Down
3 changes: 1 addition & 2 deletions src/tavi/data/scan_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, x: np.ndarray, y: np.ndarray) -> None:
self.err = np.sqrt(y)
self._ind = ind

def __add__(self, other): # addition is not really needed
def __add__(self, other):
# check x length, rebin other if do not match
if len(self.x) != len(other.x):
rebin_intervals = np.diff(self.x)
Expand Down Expand Up @@ -173,7 +173,6 @@ def __init__(self, x: np.ndarray, y: np.ndarray, z: np.ndarray) -> None:
self.x = x
self.y = y
self.y = z
self.err = np.sqrt(z)

def __sub__(self, other):
pass
Expand Down
190 changes: 134 additions & 56 deletions src/tavi/data/scan_group.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,11 @@
from dataclasses import dataclass
from typing import Optional
from typing import Literal, Optional, Union

import matplotlib.pyplot as plt
import numpy as np

from tavi.data.plotter import Plot1D, Plot2D
from tavi.data.scan import Scan


@dataclass
class SGInfo:
"""Information needed to generate a ScanGroup"""

scan_num: int
x_axis: Optional[str] = None
y_axis: Optional[str] = None
z_axis: Optional[str] = None
norm_channel: Optional[str] = None
exp_id: Optional[str] = None
from tavi.data.scan_data import ScanData1D, ScanData2D


class ScanGroup(object):
Expand All @@ -35,62 +24,90 @@ class ScanGroup(object):

def __init__(
self,
scan_path_list,
scans: list[Scan],
name: Optional[str] = None,
):
scans = {}
for scan_path in scan_path_list:
if "/" in scan_path:
exp_id, scan_name = scan_path.split("/")
else:
exp_id = next(iter(self.data))
scan_name = scan_path
scan_path = "/".join([exp_id, scan_name])
scans.update({scan_path: Scan(scan_name, self.data[exp_id][scan_name])})

# axes: tuple,
# rebin_params: tuple,
# sg_info_list: list[SGInfo],
# scan_group_name: Optional[str] = None,
# self.axes = axes
# self.dim = len(axes)
# if len(rebin_params) != self.dim:
# raise ValueError(f"Mismatched dimension with axes={axes} and rebin_params={rebin_params}")

# for scan in sg_info_list:
# self.add_scan(scan)

# if self.dim == 2: # 1D data
# ScanData1D()
# elif self.dim == 3: # 2D data
# ScanData2D()

# self.axes = axes

self.name = scan_group_name if scan_group_name is not None else f"ScanGroup{ScanGroup.scan_group_number}"
self.scans = scans
self.name = name if name is not None else f"ScanGroup{ScanGroup.scan_group_number}"
ScanGroup.scan_group_number += 1

# TODO
def add_scan(self, scan_path: str):
def add_scan(self, scan_num: Union[tuple[str, int], int]):
pass

# TODO
def remove_scan(self, scan_path: str):
def remove_scan(self, scan_num: Union[tuple[str, int], int]):
pass

# TODO non-orthogonal axes for constant E contours

# @staticmethod
# def validate_rebin_params(rebin_params: float | tuple) -> tuple:
# return rebin_params

def get_plot_data(
def set_axes(
self,
norm_channel=None,
norm_val=1,
rebin_steps=(None, None),
x: Union[str, tuple[str], None] = None,
y: Union[str, tuple[str], None] = None,
z: Union[str, tuple[str], None] = None,
norm_to: Union[tuple[float, str], tuple[tuple[float, str]], None] = None,
):
"""Generate a 2D contour plot"""
"""Set axes and normalization parameters
Args:
norm_to (norm_val (float), norm_channel(str)): value and channel for normalization
norm_channel should be "time", "monitor" or"mcu".
"""
num = len(self.scans)

if x is None:
x_axes = [scan.scan_info.def_x for scan in self.scans]
elif isinstance(x, str):
x_axes = [x] * num
elif isinstance(x, tuple):
if num != len(x):
raise ValueError(f"length of x-axes={x} does not match number of scans.")
x_axes = list(x)

if y is None:
y_axes = [scan.scan_info.def_y for scan in self.scans]
elif isinstance(y, str):
y_axes = [y] * num
elif isinstance(y, tuple):
if num != len(y):
raise ValueError(f"length of y-axes={y} does not match number of scans.")
y_axes = list(y)

if z is None:
z_axes = [None] * num
elif isinstance(z, str):
z_axes = [z] * num
elif isinstance(z, tuple):
if num != len(z):
raise ValueError(f"length of z-axes={z} does not match number of scans.")
z_axes = list(z)

if norm_to is None:
norms = [None] * num
elif isinstance(norm_to, tuple):
for item in norm_to:
if isinstance(item, tuple):
if num != len(norm_to):
raise ValueError(f"length of normalization channels={norm_to} does not match number of scans.")
norms = list(norm_to)
else:
norms = [norm_to] * num

self.axes = list(zip(x_axes, y_axes, z_axes, norms))

# TODO
def get_plot_data_1d(
self,
rebin_type: Literal["tol", "grid", None] = None,
rebin_params: Union[float, tuple] = 0.0,
) -> Plot1D:
"""
rebin_type (str | None): "tol" or "grid"
rebin_params (float | tuple(flot, float, float)): take as step size if a numer is given,
take as (min, max, step) if a tuple of size 3 is given
"""
ScanData1D()
num_scans = np.size(self.signals)

signal_x, signal_y, signal_z = self.signal_axes
Expand Down Expand Up @@ -166,6 +183,67 @@ def get_plot_data(

return (xv, yv, z, x_step, y_step, xlabel, ylabel, zlabel, title)

@staticmethod
def validate_rebin_params_2d(rebin_params_2d: tuple) -> tuple:

params = []
for rebin_params in rebin_params_2d:
if isinstance(rebin_params, tuple):
if len(rebin_params) != 3:
raise ValueError("Rebin parameters should have the form (min, max, step)")
rebin_min, rebin_max, rebin_step = rebin_params
if (rebin_min >= rebin_max) or (rebin_step < 0):
raise ValueError(f"Nonsensical rebin parameters {rebin_params}")
params.append(rebin_params)

elif isinstance(rebin_params, float | int):
if rebin_params < 0:
raise ValueError("Rebin step needs to be greater than zero.")
params.append((None, None, float(rebin_params)))
else:
raise ValueError(f"Unrecogonized rebin parameters {rebin_params}")
return tuple(params)

def get_plot_data_2d(
self,
axes: tuple[str, str, str],
rebin_params: tuple[Union[float, tuple], Union[float, tuple]],
norm_to: Optional[tuple[float, Literal["monitor", "time", "mcu"]]] = None,
) -> Plot2D:
"""
Args:
rebin_params (float | tuple(flot, float, float)): take as step size if a numer is given,
take as (min, max, step) if a tuple of size 3 is given
"""
x_axis, y_axis, z_axis = axes

x_data = []
y_data = []
z_data = []

for scan in self.scans:
x_data.append(scan.data.get(x_axis))
y_data.append(scan.data.get(y_axis))
z_data.append(scan.data.get(z_axis))

if norm_to is not None:
norm_data = []
norm_val, norm_channel = norm_to
for scan in self.scans:
norm_data.append(scan.data.get(norm_channel))

scan_data_2d = ScanData2D(x=np.concatenate(x_data), y=np.concatenate(y_data), z=np.concatenate(z_data))
# Rebin, first validate rebin params
rebin_params_2d = ScanGroup.validate_rebin_params_2d(rebin_params)
if norm_to is not None:
pass
else:
scan_data_2d.rebin_grid(rebin_params_2d)
plot2d = Plot2D(scan_data_2d.x, scan_data_2d.y, scan_data_2d.z)
# plot2d.make_labels(self.axes)
return plot2d

def plot(self, contour_plot, cmap="turbo", vmax=100, vmin=0, ylim=None, xlim=None):
"""Plot contour"""

Expand Down
Loading

0 comments on commit 82321cc

Please sign in to comment.