Skip to content

Commit 50e4dbe

Browse files
committed
Batch mode. this refs #33
1 parent 283e303 commit 50e4dbe

11 files changed

+281
-154
lines changed

__code/__init__.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,35 @@ class SvmbirParameters(RecParameters):
8989
weight_type = 'weight type'
9090
verbose = 'verbose'
9191
temp_disk = 'temp disk'
92+
93+
94+
class BatchJsonKeys:
95+
96+
list_raw_files = 'list_raw_files'
97+
list_ob_files = 'list_ob_files'
98+
list_dc_files = 'list_dc_files'
99+
crop_region = 'crop_region'
100+
gamma_filtering_flag = 'gamma_filtering_flag'
101+
beam_fluctuation_flag = 'beam_fluctuation_flag'
102+
beam_fluctuation_region = 'beam_fluctuation_region'
103+
tilt_value = 'tilt_value'
104+
remove_negative_values_flag = 'remove_negative_values_flag'
105+
bm3d_flag = 'bm3d_flag'
106+
tomopy_v0_flag = 'tomopy_v0_flag'
107+
ketcham_flag = 'ketcham_flag'
108+
range_slices_to_reconstruct = 'range_slices_to_reconstruct'
109+
laminography_dict = 'laminography_dict'
110+
111+
angle = 'angle'
112+
list_gpus = 'list_gpus'
113+
num_iterations = 'num_iterations'
114+
mrf_p = 'mrf_p'
115+
mrf_sigma = 'mrf_sigma'
116+
stop_threshold = 'stop_threshold'
117+
verbose = 'verbose'
118+
debug = 'debug'
119+
log_file_name = 'log_file_name'
120+
121+
filt_cutoff = 'filt_cutoff'
122+
filt_type = 'filt_type'
123+

__code/batch_handler.py

Lines changed: 58 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
import os
2+
13
from __code.parent import Parent
2-
from __code import DataType
4+
from __code import DataType, BatchJsonKeys
35
from __code.laminography_event_handler import LaminographyEventHandler
6+
from __code.utilities.time import get_current_time_in_special_file_name_format
7+
from __code.utilities.json import save_json
48

59

610
class BatchHandler(Parent):
@@ -12,19 +16,19 @@ def create_config_file(self):
1216
list_ob_files = self.parent.input_files[DataType.ob]
1317
list_dc_files = self.parent.input_files[DataType.dc]
1418

15-
# crop region
16-
crop_left, crop_right, crop_top, crop_bottom = list(self.parent.cropping.result)
19+
# crop region (left, right, top, bottom)
20+
crop_region = list(self.parent.cropping.result)
1721

1822
## filter #1
1923
# gamma filtering flag
2024
gamma_filtering_flag = self.parent.gamma_filtering_ui.value
2125

22-
# beam fluctuation correction flag and region
26+
# beam fluctuation correction flag and region (left, right, top, bottom)
2327
beam_fluctuation_flag = self.parent.beam_fluctuation_ui.value
24-
bf_left, bf_right, bf_top, bf_bottom = list(self.parent.beam_fluctuation_roi.value)
28+
beam_fluctuation_region = list(self.parent.beam_fluctuation_roi.result)
2529

2630
# tilt value
27-
tilt_value = self.parent.tilt_options_ui.value
31+
tilt_value = self.parent.tilt_option_ui.value
2832

2933
## filter #2
3034
# remove negative values
@@ -36,16 +40,53 @@ def create_config_file(self):
3640
ketcham_flag = self.parent.ring_removal_ui.children[2].value
3741

3842
# range of slices to reconstruct
39-
top_slice, bottom_slice = list(self.parent.z_range_selection.result)
43+
range_slices_to_reconstruct = list(self.parent.z_range_selection.result)
4044

4145
# laminography parameters
42-
laminography_dict = self.parent.laminography_settings_ui
43-
angle = laminography_dict['angle'].value
44-
list_gpu_index = LaminographyEventHandler(laminography_dict['list_gpus'])
45-
num_iter = laminography_dict['num_iterations'].value
46-
mrf_p = laminography_dict['mrf_p'].value
47-
mrf_sigma = laminography_dict['mrf_sigma'].value
48-
stop_threhsold = laminography_dict['stop_threshold'].value
49-
verbose = laminography_dict['verbose'].value
50-
51-
json_file_name = pass
46+
ui_laminography_dict = self.parent.laminography_settings_ui
47+
angle = ui_laminography_dict[BatchJsonKeys.angle].value
48+
list_gpu_index = LaminographyEventHandler.get_gpu_index(laminography_dict[BatchJsonKeys.list_gpus])
49+
num_iter = ui_laminography_dict[BatchJsonKeys.num_iterations].value
50+
mrf_p = ui_laminography_dict[BatchJsonKeys.mrf_p].value
51+
mrf_sigma = ui_laminography_dict[BatchJsonKeys.mrf_sigma].value
52+
stop_threhsold = ui_laminography_dict[BatchJsonKeys.stop_threshold].value
53+
verbose = ui_laminography_dict[BatchJsonKeys.verbose].value
54+
laminograph_dict = {BatchJsonKeys.angle: angle,
55+
BatchJsonKeys.list_gpus: list_gpu_index,
56+
BatchJsonKeys.num_iterations: num_iter,
57+
BatchJsonKeys.mrf_p: mrf_p,
58+
BatchJsonKeys.mrf_sigma: mrf_sigma,
59+
BatchJsonKeys.stop_threshold: stop_threhsold,
60+
BatchJsonKeys.verbose: verbose}
61+
62+
63+
# create json dictionary
64+
json_dictionary = {BatchJsonKeys.list_raw_files: list_raw_files,
65+
BatchJsonKeys.list_ob_files: list_ob_files,
66+
BatchJsonKeys.list_dc_files: list_dc_files,
67+
BatchJsonKeys.crop_region: crop_region,
68+
BatchJsonKeys.gamma_filtering_flag: gamma_filtering_flag,
69+
BatchJsonKeys.beam_fluctuation_flag:beam_fluctuation_flag,
70+
BatchJsonKeys.beam_fluctuation_region: beam_fluctuation_region,
71+
BatchJsonKeys.tilt_value: tilt_value,
72+
BatchJsonKeys.remove_negative_values_flag: remove_negative_values_flag,
73+
BatchJsonKeys.bm3d_flag: bm3d_flag,
74+
BatchJsonKeys.tomopy_v0_flag: tomopy_v0_flag,
75+
BatchJsonKeys.ketcham_flag: ketcham_flag,
76+
BatchJsonKeys.range_slices_to_reconstruct: range_slices_to_reconstruct,
77+
BatchJsonKeys.laminography_dict: laminography_dict,
78+
}
79+
80+
_current_time = get_current_time_in_special_file_name_format()
81+
base_folder_name = self.parent.raw_folder_name
82+
json_file_name = os.path.join(os.path.expanduser("~"),
83+
f"laminography_{base_folder_name}_{_current_time}.json")
84+
85+
log_file_name = os.path.join(os.path.expanduser("~"),
86+
f"laminography_{base_folder_name}_{_current_time}.cfg")
87+
88+
# print(f"{json_dictionary =}")
89+
90+
save_json(json_file_name=json_file_name,
91+
json_dictionary=json_dictionary)
92+

__code/laminography_event_handler.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from __code.system import System
1313
from __code.utilities.files import save_json
1414
from __code.utilities.time import convert_time_s_in_time_hr_mn_s
15+
from __code import BatchJsonKeys
1516

1617

1718
class LaminographyEventHandler:
@@ -88,13 +89,13 @@ def set_settings(self):
8889
display(tab)
8990

9091
# saving widgets for batch mode
91-
self.parent.laminography_settings_ui = {'angle': self.laminography_angle_ui,
92-
'list_gpus': self.children_gpus,
93-
'num_iterations': self.num_iter_ui,
94-
'mrf_p': self.mrf_p_ui,
95-
'mrf_sigma': self.mrf_sigma_ui,
96-
'stop_threshold': self.stop_threshold_ui,
97-
'verbose': self.verbose_ui}
92+
self.parent.laminography_settings_ui = {BatchJsonKeys.angle: self.laminography_angle_ui,
93+
BatchJsonKeys.list_gpus: self.children_gpus,
94+
BatchJsonKeys.num_iterations: self.num_iter_ui,
95+
BatchJsonKeys.mrf_p: self.mrf_p_ui,
96+
BatchJsonKeys.mrf_sigma: self.mrf_sigma_ui,
97+
BatchJsonKeys.stop_threshold: self.stop_threshold_ui,
98+
BatchJsonKeys.verbose: self.verbose_ui}
9899

99100
@staticmethod
100101
def get_gpu_index(children_gpus_ui):
@@ -106,19 +107,19 @@ def get_gpu_index(children_gpus_ui):
106107

107108
def get_rec_params(self):
108109
rec_params = {}
109-
rec_params['num_iter'] = self.num_iter_ui.value
110-
rec_params['gpu_index'] = LaminographyEventHandler.get_gpu_index(self.children_gpus[1:])
111-
rec_params['MRF_P'] = self.mrf_p_ui.value
112-
rec_params['MRF_SIGMA'] = self.mrf_sigma_ui.value
110+
rec_params[BatchJsonKeys.num_iterations] = self.num_iter_ui.value
111+
rec_params[BatchJsonKeys.list_gpus] = LaminographyEventHandler.get_gpu_index(self.children_gpus[1:])
112+
rec_params[BatchJsonKeys.mrf_p] = self.mrf_p_ui.value
113+
rec_params[BatchJsonKeys.mrf_sigma] = self.mrf_sigma_ui.value
113114
# rec_params['huber_T'] = self.huber_t
114115
# rec_params['huber_delta'] = self.huber_delta
115116
# rec_params['sigma'] = self.sigma
116117
# rec_params['reject_frac'] = self.reject_frac
117-
rec_params['verbose'] = self.verbose_ui.value
118-
rec_params['debug'] = self.debug
119-
rec_params['stop_thresh'] = self.stop_threshold_ui.value
120-
rec_params['filt_cutoff'] = 0.5
121-
rec_params['filt_type'] = 'Ram-Lak'
118+
rec_params[BatchJsonKeys.verbose] = self.verbose_ui.value
119+
rec_params[BatchJsonKeys.debug] = self.debug
120+
rec_params[BatchJsonKeys.stop_threshold] = self.stop_threshold_ui.value
121+
rec_params[BatchJsonKeys.filt_cutoff] = 0.5
122+
rec_params[BatchJsonKeys.filt_type] = 'Ram-Lak'
122123

123124
return rec_params
124125

__code/laminographyui.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ class LaminographyUi:
9797
ob_raw = None
9898
dc_raw = None
9999

100+
# name of the raw folder
101+
raw_folder_name = None
102+
100103
investigate_data_flag = False
101104

102105
o_tilt = None
@@ -116,6 +119,7 @@ def __init__(self, working_dir="./"):
116119
self.working_dir[DataType.raw] = os.path.join(init_path_to_raw, default_input_folder[DataType.raw])
117120
self.working_dir[DataType.ob] = os.path.join(init_path_to_raw, default_input_folder[DataType.ob])
118121
self.working_dir[DataType.dc] = os.path.join(init_path_to_raw, default_input_folder[DataType.dc])
122+
print("version 07-30-2024")
119123

120124
# SELECT INPUT DATA ===============================================================================================
121125
def select_raw(self):

__code/laminographyui_batch_mode.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ class LaminographyUi:
105105
ob_raw = None
106106
dc_raw = None
107107

108+
# name of the raw folder
109+
raw_folder_name = None
110+
108111
investigate_data_flag = False
109112

110113
o_tilt = None
@@ -168,6 +171,7 @@ def define_parameters(self):
168171
self.display_section_title(name='Beam fluctuation')
169172
o_beam = BeamFluctuationCorrection(parent=self)
170173
o_beam.beam_fluctuation_correction_option()
174+
o_beam.apply_select_beam_fluctuation(batch_mode=True)
171175

172176
self.display_section_title(name='Tilt calculation')
173177
self.tilt_correction_options()

__code/utilities/json.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import json
2+
import os
3+
4+
5+
def load_json(json_file_name):
6+
if not os.path.exists(json_file_name):
7+
return None
8+
9+
with open(json_file_name) as json_file:
10+
data = json.load(json_file)
11+
12+
return data
13+
14+
15+
def save_json(json_file_name, json_dictionary=None):
16+
with open(json_file_name, 'w') as outfile:
17+
json.dump(json_dictionary, outfile)

__code/workflow/load.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,6 @@ def load_percentage_of_data(self, percentage_to_load=5):
7676
max_workers=20) # use 20 workers
7777
)
7878

79-
80-
81-
82-
83-
84-
85-
8679
def load_data(self):
8780

8881
self.parent.proj_raw, self.parent.ob_raw, self.parent.dc_raw, self.parent.rot_angles = (
@@ -100,6 +93,12 @@ def load_data(self):
10093

10194
self.parent.dc_raw = np.array([np.zeros_like(self.parent.proj_raw[0])])
10295

96+
# debugging - use np.float16 instead of default np.float64
97+
print(f"Before conversion: {self.parent.proj_raw.dtype= }")
98+
# self.parent.proj_raw = self.parent.proj_raw.astype(np.float16)
99+
# self.parent.ob_raw = self.parent.ob_raw.astype(np.float16)
100+
# self.parent.dc_raw = self.parent.dc_raw.astype(np.float16)
101+
# print(f"After conversion: {self.parent.proj_raw.dtype= }")
103102
self.parent.untouched_sample_data = copy.deepcopy(self.parent.proj_raw)
104103

105104
def select_dc_options(self):

__code/workflow/tilt.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -533,13 +533,15 @@ def plot_comparisons(algo_selected, color_range, col, row, zoom_x, zoom_y):
533533
display(self.test_tilt)
534534

535535
def display_batch_options(self):
536-
self.parent.tilt_options_ui = widgets.VBox([
536+
tilt_options_ui = widgets.VBox([
537537
widgets.Label("Tilt value (degrees)",
538538
layout=widgets.Layout(width='200px'),
539539
),
540540
widgets.FloatSlider(min=-90,
541541
max=90,
542542
value=0)
543543
])
544-
display(self.parent.tilt_options_ui)
544+
display(tilt_options_ui)
545+
546+
self.parent.tilt_option_ui = tilt_options_ui.children[1]
545547

0 commit comments

Comments
 (0)