Skip to content

Commit

Permalink
Batch mode. this refs #33
Browse files Browse the repository at this point in the history
  • Loading branch information
JeanBilheux committed Aug 1, 2024
1 parent 283e303 commit 50e4dbe
Show file tree
Hide file tree
Showing 11 changed files with 281 additions and 154 deletions.
32 changes: 32 additions & 0 deletions __code/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,35 @@ class SvmbirParameters(RecParameters):
weight_type = 'weight type'
verbose = 'verbose'
temp_disk = 'temp disk'


class BatchJsonKeys:

list_raw_files = 'list_raw_files'
list_ob_files = 'list_ob_files'
list_dc_files = 'list_dc_files'
crop_region = 'crop_region'
gamma_filtering_flag = 'gamma_filtering_flag'
beam_fluctuation_flag = 'beam_fluctuation_flag'
beam_fluctuation_region = 'beam_fluctuation_region'
tilt_value = 'tilt_value'
remove_negative_values_flag = 'remove_negative_values_flag'
bm3d_flag = 'bm3d_flag'
tomopy_v0_flag = 'tomopy_v0_flag'
ketcham_flag = 'ketcham_flag'
range_slices_to_reconstruct = 'range_slices_to_reconstruct'
laminography_dict = 'laminography_dict'

angle = 'angle'
list_gpus = 'list_gpus'
num_iterations = 'num_iterations'
mrf_p = 'mrf_p'
mrf_sigma = 'mrf_sigma'
stop_threshold = 'stop_threshold'
verbose = 'verbose'
debug = 'debug'
log_file_name = 'log_file_name'

filt_cutoff = 'filt_cutoff'
filt_type = 'filt_type'

75 changes: 58 additions & 17 deletions __code/batch_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import os

from __code.parent import Parent
from __code import DataType
from __code import DataType, BatchJsonKeys
from __code.laminography_event_handler import LaminographyEventHandler
from __code.utilities.time import get_current_time_in_special_file_name_format
from __code.utilities.json import save_json


class BatchHandler(Parent):
Expand All @@ -12,19 +16,19 @@ def create_config_file(self):
list_ob_files = self.parent.input_files[DataType.ob]
list_dc_files = self.parent.input_files[DataType.dc]

# crop region
crop_left, crop_right, crop_top, crop_bottom = list(self.parent.cropping.result)
# crop region (left, right, top, bottom)
crop_region = list(self.parent.cropping.result)

## filter #1
# gamma filtering flag
gamma_filtering_flag = self.parent.gamma_filtering_ui.value

# beam fluctuation correction flag and region
# beam fluctuation correction flag and region (left, right, top, bottom)
beam_fluctuation_flag = self.parent.beam_fluctuation_ui.value
bf_left, bf_right, bf_top, bf_bottom = list(self.parent.beam_fluctuation_roi.value)
beam_fluctuation_region = list(self.parent.beam_fluctuation_roi.result)

# tilt value
tilt_value = self.parent.tilt_options_ui.value
tilt_value = self.parent.tilt_option_ui.value

## filter #2
# remove negative values
Expand All @@ -36,16 +40,53 @@ def create_config_file(self):
ketcham_flag = self.parent.ring_removal_ui.children[2].value

# range of slices to reconstruct
top_slice, bottom_slice = list(self.parent.z_range_selection.result)
range_slices_to_reconstruct = list(self.parent.z_range_selection.result)

# laminography parameters
laminography_dict = self.parent.laminography_settings_ui
angle = laminography_dict['angle'].value
list_gpu_index = LaminographyEventHandler(laminography_dict['list_gpus'])
num_iter = laminography_dict['num_iterations'].value
mrf_p = laminography_dict['mrf_p'].value
mrf_sigma = laminography_dict['mrf_sigma'].value
stop_threhsold = laminography_dict['stop_threshold'].value
verbose = laminography_dict['verbose'].value

json_file_name = pass
ui_laminography_dict = self.parent.laminography_settings_ui
angle = ui_laminography_dict[BatchJsonKeys.angle].value
list_gpu_index = LaminographyEventHandler.get_gpu_index(laminography_dict[BatchJsonKeys.list_gpus])
num_iter = ui_laminography_dict[BatchJsonKeys.num_iterations].value
mrf_p = ui_laminography_dict[BatchJsonKeys.mrf_p].value
mrf_sigma = ui_laminography_dict[BatchJsonKeys.mrf_sigma].value
stop_threhsold = ui_laminography_dict[BatchJsonKeys.stop_threshold].value
verbose = ui_laminography_dict[BatchJsonKeys.verbose].value
laminograph_dict = {BatchJsonKeys.angle: angle,
BatchJsonKeys.list_gpus: list_gpu_index,
BatchJsonKeys.num_iterations: num_iter,
BatchJsonKeys.mrf_p: mrf_p,
BatchJsonKeys.mrf_sigma: mrf_sigma,
BatchJsonKeys.stop_threshold: stop_threhsold,
BatchJsonKeys.verbose: verbose}


# create json dictionary
json_dictionary = {BatchJsonKeys.list_raw_files: list_raw_files,
BatchJsonKeys.list_ob_files: list_ob_files,
BatchJsonKeys.list_dc_files: list_dc_files,
BatchJsonKeys.crop_region: crop_region,
BatchJsonKeys.gamma_filtering_flag: gamma_filtering_flag,
BatchJsonKeys.beam_fluctuation_flag:beam_fluctuation_flag,
BatchJsonKeys.beam_fluctuation_region: beam_fluctuation_region,
BatchJsonKeys.tilt_value: tilt_value,
BatchJsonKeys.remove_negative_values_flag: remove_negative_values_flag,
BatchJsonKeys.bm3d_flag: bm3d_flag,
BatchJsonKeys.tomopy_v0_flag: tomopy_v0_flag,
BatchJsonKeys.ketcham_flag: ketcham_flag,
BatchJsonKeys.range_slices_to_reconstruct: range_slices_to_reconstruct,
BatchJsonKeys.laminography_dict: laminography_dict,
}

_current_time = get_current_time_in_special_file_name_format()
base_folder_name = self.parent.raw_folder_name
json_file_name = os.path.join(os.path.expanduser("~"),
f"laminography_{base_folder_name}_{_current_time}.json")

log_file_name = os.path.join(os.path.expanduser("~"),
f"laminography_{base_folder_name}_{_current_time}.cfg")

# print(f"{json_dictionary =}")

save_json(json_file_name=json_file_name,
json_dictionary=json_dictionary)

33 changes: 17 additions & 16 deletions __code/laminography_event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __code.system import System
from __code.utilities.files import save_json
from __code.utilities.time import convert_time_s_in_time_hr_mn_s
from __code import BatchJsonKeys


class LaminographyEventHandler:
Expand Down Expand Up @@ -88,13 +89,13 @@ def set_settings(self):
display(tab)

# saving widgets for batch mode
self.parent.laminography_settings_ui = {'angle': self.laminography_angle_ui,
'list_gpus': self.children_gpus,
'num_iterations': self.num_iter_ui,
'mrf_p': self.mrf_p_ui,
'mrf_sigma': self.mrf_sigma_ui,
'stop_threshold': self.stop_threshold_ui,
'verbose': self.verbose_ui}
self.parent.laminography_settings_ui = {BatchJsonKeys.angle: self.laminography_angle_ui,
BatchJsonKeys.list_gpus: self.children_gpus,
BatchJsonKeys.num_iterations: self.num_iter_ui,
BatchJsonKeys.mrf_p: self.mrf_p_ui,
BatchJsonKeys.mrf_sigma: self.mrf_sigma_ui,
BatchJsonKeys.stop_threshold: self.stop_threshold_ui,
BatchJsonKeys.verbose: self.verbose_ui}

@staticmethod
def get_gpu_index(children_gpus_ui):
Expand All @@ -106,19 +107,19 @@ def get_gpu_index(children_gpus_ui):

def get_rec_params(self):
rec_params = {}
rec_params['num_iter'] = self.num_iter_ui.value
rec_params['gpu_index'] = LaminographyEventHandler.get_gpu_index(self.children_gpus[1:])
rec_params['MRF_P'] = self.mrf_p_ui.value
rec_params['MRF_SIGMA'] = self.mrf_sigma_ui.value
rec_params[BatchJsonKeys.num_iterations] = self.num_iter_ui.value
rec_params[BatchJsonKeys.list_gpus] = LaminographyEventHandler.get_gpu_index(self.children_gpus[1:])
rec_params[BatchJsonKeys.mrf_p] = self.mrf_p_ui.value
rec_params[BatchJsonKeys.mrf_sigma] = self.mrf_sigma_ui.value
# rec_params['huber_T'] = self.huber_t
# rec_params['huber_delta'] = self.huber_delta
# rec_params['sigma'] = self.sigma
# rec_params['reject_frac'] = self.reject_frac
rec_params['verbose'] = self.verbose_ui.value
rec_params['debug'] = self.debug
rec_params['stop_thresh'] = self.stop_threshold_ui.value
rec_params['filt_cutoff'] = 0.5
rec_params['filt_type'] = 'Ram-Lak'
rec_params[BatchJsonKeys.verbose] = self.verbose_ui.value
rec_params[BatchJsonKeys.debug] = self.debug
rec_params[BatchJsonKeys.stop_threshold] = self.stop_threshold_ui.value
rec_params[BatchJsonKeys.filt_cutoff] = 0.5
rec_params[BatchJsonKeys.filt_type] = 'Ram-Lak'

return rec_params

Expand Down
4 changes: 4 additions & 0 deletions __code/laminographyui.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ class LaminographyUi:
ob_raw = None
dc_raw = None

# name of the raw folder
raw_folder_name = None

investigate_data_flag = False

o_tilt = None
Expand All @@ -116,6 +119,7 @@ def __init__(self, working_dir="./"):
self.working_dir[DataType.raw] = os.path.join(init_path_to_raw, default_input_folder[DataType.raw])
self.working_dir[DataType.ob] = os.path.join(init_path_to_raw, default_input_folder[DataType.ob])
self.working_dir[DataType.dc] = os.path.join(init_path_to_raw, default_input_folder[DataType.dc])
print("version 07-30-2024")

# SELECT INPUT DATA ===============================================================================================
def select_raw(self):
Expand Down
4 changes: 4 additions & 0 deletions __code/laminographyui_batch_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ class LaminographyUi:
ob_raw = None
dc_raw = None

# name of the raw folder
raw_folder_name = None

investigate_data_flag = False

o_tilt = None
Expand Down Expand Up @@ -168,6 +171,7 @@ def define_parameters(self):
self.display_section_title(name='Beam fluctuation')
o_beam = BeamFluctuationCorrection(parent=self)
o_beam.beam_fluctuation_correction_option()
o_beam.apply_select_beam_fluctuation(batch_mode=True)

self.display_section_title(name='Tilt calculation')
self.tilt_correction_options()
Expand Down
17 changes: 17 additions & 0 deletions __code/utilities/json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import json
import os


def load_json(json_file_name):
if not os.path.exists(json_file_name):
return None

with open(json_file_name) as json_file:
data = json.load(json_file)

return data


def save_json(json_file_name, json_dictionary=None):
with open(json_file_name, 'w') as outfile:
json.dump(json_dictionary, outfile)
13 changes: 6 additions & 7 deletions __code/workflow/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,6 @@ def load_percentage_of_data(self, percentage_to_load=5):
max_workers=20) # use 20 workers
)








def load_data(self):

self.parent.proj_raw, self.parent.ob_raw, self.parent.dc_raw, self.parent.rot_angles = (
Expand All @@ -100,6 +93,12 @@ def load_data(self):

self.parent.dc_raw = np.array([np.zeros_like(self.parent.proj_raw[0])])

# debugging - use np.float16 instead of default np.float64
print(f"Before conversion: {self.parent.proj_raw.dtype= }")
# self.parent.proj_raw = self.parent.proj_raw.astype(np.float16)
# self.parent.ob_raw = self.parent.ob_raw.astype(np.float16)
# self.parent.dc_raw = self.parent.dc_raw.astype(np.float16)
# print(f"After conversion: {self.parent.proj_raw.dtype= }")
self.parent.untouched_sample_data = copy.deepcopy(self.parent.proj_raw)

def select_dc_options(self):
Expand Down
6 changes: 4 additions & 2 deletions __code/workflow/tilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,13 +533,15 @@ def plot_comparisons(algo_selected, color_range, col, row, zoom_x, zoom_y):
display(self.test_tilt)

def display_batch_options(self):
self.parent.tilt_options_ui = widgets.VBox([
tilt_options_ui = widgets.VBox([
widgets.Label("Tilt value (degrees)",
layout=widgets.Layout(width='200px'),
),
widgets.FloatSlider(min=-90,
max=90,
value=0)
])
display(self.parent.tilt_options_ui)
display(tilt_options_ui)

self.parent.tilt_option_ui = tilt_options_ui.children[1]

Loading

0 comments on commit 50e4dbe

Please sign in to comment.