Skip to content

Commit

Permalink
updated scan class
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing Li committed Sep 16, 2024
1 parent 9b08e8d commit 357af9c
Show file tree
Hide file tree
Showing 207 changed files with 477 additions and 352 deletions.
323 changes: 167 additions & 156 deletions src/tavi/data/scan.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
from typing import Optional
from typing import Literal, Optional

import h5py
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -33,92 +33,184 @@ class Scan(object):
"""

def __init__(self, scan_info, sample_ub_info, instrument_info, data) -> None:
self.scan_info: Optional[ScanInfo] = scan_info
self.sample_ub_info: Optional[SampleUBInfo] = sample_ub_info
self.instrument_info: Optional[InstrumentInfo] = instrument_info
self.data: Optional[ScanData] = data
self.scan_info: ScanInfo = scan_info
self.sample_ub_info: SampleUBInfo = sample_ub_info
self.instrument_info: InstrumentInfo = instrument_info
self.data: ScanData = data

@classmethod
def from_nexus(cls, nexus_path):
with h5py.File(nexus_path, "r") as nexus_entry:
scan_info, sample_ub_info, instrument_info, data = nexus_entry_to_scan(nexus_entry)
return cls(scan_info, sample_ub_info, instrument_info, data)
(dataset_name, *scan_data) = nexus_entry_to_scan(nexus_entry)
(scan_info, sample_ub_info, instrument_info, data) = scan_data
return dataset_name, cls(scan_info, sample_ub_info, instrument_info, data)

# TODO
# TODO not yet implemented
@classmethod
def from_spice(cls, spice_path):
spice_entry = spice_path
scan_info, sample_ub_info, instrument_info, data = spice_entry_to_scan(spice_entry)
return cls(scan_info, sample_ub_info, instrument_info, data)
(dataset_name, *scan_data) = spice_entry_to_scan(spice_entry)
scan_info, sample_ub_info, instrument_info, data = scan_data
return dataset_name, cls(scan_info, sample_ub_info, instrument_info, data)

def get_scan_info(self):
"""Return scan_info in metadata.
Returns:
dict: dictionay of scan_info metadata.
"""
"""Return scan_info in metadata."""
return self.scan_info

def get_sample_ub_info(self):
"""Return sample_UB_info in metadata.
Returns:
dict: dictionay of sample_UB_info metadata.
"""
"""Return sample_UB_info in metadata."""
return self.sample_ub_info

def get_instrument_info(self):
"""Return instrument_info in metadata.
Returns:
dict: dictionay of instrument_info metadata.
"""
"""Return instrument_info in metadata."""
return self.instrument_info

# def save_metadata(self, metadata_entry):
# """Save metadata_entry into file

# Args:
# metadata_entry (dict): {key: value}
# """
# pass

def get_data_entry(self, entry_name):
"""Return data entry based on entry_name
"""Return data entry based on entry_name"""
return self.data[entry_name]

Args:
entry_name (str): a key of the dictionay in data
def _set_labels(
self,
x_str: str,
y_str: str,
norm_channel: Literal["time", "monitor", "mcu", None],
norm_val: float,
):
"""generate labels and title"""
if norm_channel is not None:
if norm_channel == "time":
norm_channel_str = "seconds"
else:
norm_channel_str = norm_channel
if norm_val == 1:
ylabel = y_str + "/ " + norm_channel_str
else:
ylabel = y_str + f" / {norm_val} " + norm_channel_str
else:
preset_val = self.scan_info.preset_value
ylabel = y_str + f" / {preset_val} " + self.scan_info.preset_channel

Returns:
tuple: data entry
"""
return self.data[entry_name]
xlabel = x_str
label = "scan " + str(self.scan_info.scan_num)
title = label + ": " + self.scan_info.scan_title

return (xlabel, ylabel, title, label)

def _rebin_tol(
self,
x_raw: np.ndarray,
y_raw: np.ndarray,
y_str: str,
rebin_step: float,
norm_channel: Literal["time", "monitor", "mcu", None],
norm_val: float,
):
x_grid = np.arange(
np.min(x_raw) + rebin_step / 2,
np.max(x_raw) + rebin_step / 2,
rebin_step,
)
x = np.zeros_like(x_grid)
y = np.zeros_like(x_grid)
counts = np.zeros_like(x_grid)
weights = np.zeros_like(x_grid)
yerr = None

if norm_channel is None: # rebin, no renorm
weight_channel = self.scan_info.preset_channel
weight = getattr(self.data, weight_channel)
for i, x0 in enumerate(x_raw):
idx = np.nanargmax(x_grid + rebin_step / 2 >= x0)
y[idx] += y_raw[i]
x[idx] += x_raw[i] * weight[i]
weights[idx] += weight[i]
counts[idx] += 1

# errror bars for detector only
if "detector" in y_str:
yerr = np.sqrt(y) / counts
y = y / counts
x = x / weights
return (x, y, yerr)

# rebin and renorm
norm = getattr(self.data, norm_channel)
for i, x0 in enumerate(x_raw):
idx = np.nanargmax(x_grid + rebin_step / 2 >= x0)
y[idx] += y_raw[i]
x[idx] += x_raw[i] * norm[i]
counts[idx] += norm[i]

# errror bars for detector only
if "detector" in y_str:
yerr = np.sqrt(y) / counts * norm_val
y = y / counts * norm_val
x = x / counts
return (x, y, yerr)

def _rebin_grid(
self,
x_raw: np.ndarray,
y_raw: np.ndarray,
y_str: str,
rebin_step: float,
norm_channel: Literal["time", "monitor", "mcu", None],
norm_val: float,
):
x = np.arange(
np.min(x_raw) + rebin_step / 2,
np.max(x_raw) + rebin_step / 2,
rebin_step,
)
y = np.zeros_like(x)
cts = np.zeros_like(x)
yerr = None
# rebin, no renorm
if norm_channel is None:
for i, x0 in enumerate(x_raw):
idx = np.nanargmax(x + rebin_step / 2 >= x0)
y[idx] += y_raw[i]
cts[idx] += 1

# errror bars for detector only
if "detector" in y_str:
yerr = np.sqrt(y) / cts
y = y / cts
return (x, y, yerr)

# rebin and renorm
norm = getattr(self.data, norm_channel)
for i, x0 in enumerate(x_raw):
idx = np.nanargmax(x + rebin_step / 2 >= x0)
y[idx] += y_raw[i]
cts[idx] += norm[i]

# errror bars for detector only
if "detector" in y_str:
yerr = np.sqrt(y) / cts * norm_val
y = y / cts * norm_val
return (x, y, yerr)

def generate_curve(
self,
x_str=None,
y_str=None,
norm_channel=None,
norm_val=1,
rebin_type=None,
rebin_step=0,
x_str: Optional[str] = None,
y_str: Optional[str] = None,
norm_channel: Literal["time", "monitor", "mcu", None] = None,
norm_val: float = 1.0,
rebin_type: Literal["tol", "grid", None] = None,
rebin_step: float = 0.0,
):
"""Generate a curve from a single scan to plot,
with the options to
with the options to
normalize the y-axis and rebin x-axis.
Args:
x_str (str): string of x axis
y_str (str): string of x axis
norm_channel (str): None, "time", "monitor" or "mcu"
norm_val (float):
rebin_type (str): None, "tol", or "grid"
rebin_size (float):
Returns:
x:
y:
xerr: if rebin
yerr: if "detector"
xlabel:
ylabel:
Expand All @@ -135,8 +227,6 @@ def generate_curve(
x_raw = getattr(self.data, x_str)
y_raw = getattr(self.data, y_str)

# xerr NOT used
xerr = None
yerr = None

if rebin_type is None: # no rebin
Expand All @@ -152,116 +242,37 @@ def generate_curve(
if yerr is not None:
yerr = yerr / norm

else:
if rebin_step > 0:
match rebin_type:
case "tol": # x weighted by normalization channel
x_grid = np.arange(
np.min(x_raw) + rebin_step / 2,
np.max(x_raw) + rebin_step / 2,
rebin_step,
)
x = np.zeros_like(x_grid)
y = np.zeros_like(x_grid)
cts = np.zeros_like(x_grid)
wts = np.zeros_like(x_grid)

if norm_channel is not None: # rebin and renorm
norm = self.data[norm_channel]
for i, x0 in enumerate(x_raw):
idx = np.nanargmax(x_grid + rebin_step / 2 >= x0)
y[idx] += y_raw[i]
x[idx] += x_raw[i] * norm[i]
cts[idx] += norm[i]

# errror bars for detector only
if "detector" in y_str:
yerr = np.sqrt(y) / cts * norm_val
y = y / cts * norm_val
x = x / cts

else: # rebin, no renorm
weight_channel = self.scan_info["preset_channel"]
weight = self.data[weight_channel]
for i, x0 in enumerate(x_raw):
idx = np.nanargmax(x_grid + rebin_step / 2 >= x0)
y[idx] += y_raw[i]
x[idx] += x_raw[i] * weight[i]
wts[idx] += weight[i]
cts[idx] += 1

# errror bars for detector only
if "detector" in y_str:
yerr = np.sqrt(y) / cts
y = y / cts
x = x / wts
case "grid":
x = np.arange(
np.min(x_raw) + rebin_step / 2,
np.max(x_raw) + rebin_step / 2,
rebin_step,
)
y = np.zeros_like(x)
cts = np.zeros_like(x)

if norm_channel is not None: # rebin and renorm
norm = self.data[norm_channel]
for i, x0 in enumerate(x_raw):
idx = np.nanargmax(x + rebin_step / 2 >= x0)
y[idx] += y_raw[i]
cts[idx] += norm[i]

# errror bars for detector only
if "detector" in y_str:
yerr = np.sqrt(y) / cts * norm_val
y = y / cts * norm_val

else: # rebin, no renorm
for i, x0 in enumerate(x_raw):
idx = np.nanargmax(x + rebin_step / 2 >= x0)
y[idx] += y_raw[i]
cts[idx] += 1

# errror bars for detector only
if "detector" in y_str:
yerr = np.sqrt(y) / cts
y = y / cts

case _:
print('Unrecogonized rebin type. Needs to be "tol" or "grid".')
else:
print("Rebin step needs to be greater than zero.")
(xlabel, ylabel, title, label) = self._set_labels(x_str, y_str, norm_channel, norm_val)

# generate labels and title
if norm_channel is not None:
if norm_channel == "time":
norm_channel = "seconds"
if norm_val == 1:
ylabel = y_str + "/ " + norm_channel
else:
ylabel = y_str + f" / {norm_val} " + norm_channel
else:
preset_val = self.scan_info["preset_value"]
ylabel = y_str + f" / {preset_val} " + self.scan_info["preset_channel"]
return (x, y, yerr, xlabel, ylabel, title, label)

xlabel = x_str
label = "scan " + str(self.scan_info.scan_num)
title = label + ": " + self.scan_info.scan_title
if not rebin_step > 0:
raise ValueError("Rebin step needs to be greater than zero.")

match rebin_type:
case "tol": # x weighted by normalization channel
x, y, yerr = self._rebin_tol(x_raw, y_raw, y_str, rebin_step, norm_channel, norm_val)
case "grid":
x, y, yerr = self._rebin_grid(x_raw, y_raw, y_str, rebin_step, norm_channel, norm_val)
case _:
print('Unrecogonized rebin type. Needs to be "tol" or "grid".')

(xlabel, ylabel, title, label) = self._set_labels(x_str, y_str, norm_channel, norm_val)

return (x, y, xerr, yerr, xlabel, ylabel, title, label)
return (x, y, yerr, xlabel, ylabel, title, label)

def plot_curve(
self,
x_str=None,
y_str=None,
norm_channel=None,
norm_val=1,
rebin_type=None,
rebin_step=0,
x_str: Optional[str] = None,
y_str: Optional[str] = None,
norm_channel: Literal["time", "monitor", "mcu", None] = None,
norm_val: float = 1.0,
rebin_type: Literal["tol", "grid", None] = None,
rebin_step: float = 0.0,
):
"""Plot a 1D curve gnerated from a singal scan in a new window"""

x, y, xerr, yerr, xlabel, ylabel, title, _ = self.generate_curve(
x, y, yerr, xlabel, ylabel, title, _ = self.generate_curve(
x_str,
y_str,
norm_channel,
Expand All @@ -271,7 +282,7 @@ def plot_curve(
)

fig, ax = plt.subplots()
ax.errorbar(x, y, xerr=xerr, yerr=yerr, fmt="o")
ax.errorbar(x, y, yerr=yerr, fmt="o")
ax.set_title(title)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
Expand Down
Loading

0 comments on commit 357af9c

Please sign in to comment.