Skip to content

Commit 913f2f3

Browse files
committed
Compute projection instead of dictionary
1 parent 214c770 commit 913f2f3

File tree

1 file changed

+15
-95
lines changed

1 file changed

+15
-95
lines changed

nanshe_ipython.ipynb

Lines changed: 15 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@
241241
"source": [
242242
"from nanshe_workflow.par import halo_block_parallel\n",
243243
"\n",
244-
"from nanshe_workflow.imp2 import extract_f0, wavelet_transform, renormalized_images, normalize_data\n",
244+
"from nanshe_workflow.imp2 import extract_f0, wavelet_transform, normalize_data\n",
245245
"\n",
246246
"from nanshe_workflow.par import halo_block_generate_dictionary_parallel\n",
247247
"from nanshe_workflow.imp import block_postprocess_data_parallel\n",
@@ -993,11 +993,10 @@
993993
"cell_type": "markdown",
994994
"metadata": {},
995995
"source": [
996-
"### Normalize Data\n",
996+
"### Project\n",
997997
"\n",
998998
"* `block_frames` (`int`): number of frames to work with in each full frame block (run in parallel).\n",
999-
"* `block_space` (`int`): extent of each spatial dimension for each block (run in parallel).\n",
1000-
"* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel)."
999+
"* `proj_type` (`str`): type of projection to take."
10011000
]
10021001
},
10031002
{
@@ -1007,12 +1006,11 @@
10071006
"outputs": [],
10081007
"source": [
10091008
"block_frames = 40\n",
1010-
"block_space = 300\n",
1011-
"norm_frames = 100\n",
1009+
"proj_type = \"max\"\n",
10121010
"\n",
10131011
"\n",
10141012
"with get_executor(client) as executor:\n",
1015-
" dask_io_remove(data_basename + postfix_norm + zarr_ext, executor)\n",
1013+
" dask_io_remove(data_basename + postfix_dict + zarr_ext, executor)\n",
10161014
"\n",
10171015
"\n",
10181016
" with open_zarr(data_basename + postfix_wt + zarr_ext, \"r\") as f:\n",
@@ -1027,106 +1025,28 @@
10271025
" da_imgs_flt.dtype.itemsize >= 4):\n",
10281026
" da_imgs_flt = da_imgs_flt.astype(np.float32)\n",
10291027
"\n",
1030-
" da_imgs_flt_mins = da_imgs_flt.min(\n",
1031-
" axis=tuple(irange(1, da_imgs_flt.ndim)),\n",
1032-
" keepdims=True\n",
1033-
" )\n",
1034-
"\n",
1035-
" da_imgs_flt_shift = da_imgs_flt - da_imgs_flt_mins\n",
1036-
"\n",
1037-
" da_result = renormalized_images(da_imgs_flt_shift)\n",
1028+
" da_result = da_imgs\n",
1029+
" if proj_type == \"max\":\n",
1030+
" da_result = da_result.max(axis=0, keepdims=True)\n",
1031+
" elif proj_type == \"std\":\n",
1032+
" da_result = da_result.std(axis=0, keepdims=True)\n",
10381033
"\n",
10391034
" # Store denoised data\n",
1040-
" dask_store_zarr(data_basename + postfix_norm + zarr_ext, [\"images\"], [da_result], executor)\n",
1041-
"\n",
1042-
"\n",
1043-
" zip_zarr(data_basename + postfix_norm + zarr_ext, executor)\n",
1044-
"\n",
1045-
"\n",
1046-
"if __IPYTHON__:\n",
1047-
" result_image_stack = LazyZarrDataset(data_basename + postfix_norm + zarr_ext, \"images\")\n",
1048-
"\n",
1049-
" mplsv = plt.figure(FigureClass=MPLViewer)\n",
1050-
" mplsv.set_images(\n",
1051-
" result_image_stack,\n",
1052-
" vmin=par_compute_min_projection(num_frames=norm_frames)(result_image_stack).min(),\n",
1053-
" vmax=par_compute_max_projection(num_frames=norm_frames)(result_image_stack).max()\n",
1054-
" )"
1055-
]
1056-
},
1057-
{
1058-
"cell_type": "markdown",
1059-
"metadata": {},
1060-
"source": [
1061-
"### Dictionary Learning\n",
1062-
"\n",
1063-
"* `n_components` (`int`): number of basis images in the dictionary.\n",
1064-
"* `batchsize` (`int`): minibatch size to use.\n",
1065-
"* `iters` (`int`): number of iterations to run before getting dictionary.\n",
1066-
"* `lambda1` (`float`): weight for L<sup>1</sup> sparsity enforcement on sparse code.\n",
1067-
"* `lambda2` (`float`): weight for L<sup>2</sup> sparsity enforcement on sparse code.\n",
1068-
"\n",
1069-
"<br>\n",
1070-
"* `block_frames` (`int`): number of frames to work with in each full frame block (run in parallel).\n",
1071-
"* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel)."
1072-
]
1073-
},
1074-
{
1075-
"cell_type": "code",
1076-
"execution_count": null,
1077-
"metadata": {},
1078-
"outputs": [],
1079-
"source": [
1080-
"n_components = 50\n",
1081-
"batchsize = 256\n",
1082-
"iters = 100\n",
1083-
"lambda1 = 0.2\n",
1084-
"lambda2 = 0.0\n",
1085-
"\n",
1086-
"block_frames = 51\n",
1087-
"norm_frames = 100\n",
1035+
" dask_store_zarr(data_basename + postfix_dict + zarr_ext, [\"images\"], [da_result], executor)\n",
10881036
"\n",
10891037
"\n",
1090-
"with get_executor(client) as executor:\n",
1091-
" dask_io_remove(data_basename + postfix_dict + zarr_ext, executor)\n",
1092-
"\n",
1093-
"\n",
1094-
"result = LazyZarrDataset(data_basename + postfix_norm + zarr_ext, \"images\")\n",
1095-
"block_shape = (block_frames,) + result.shape[1:]\n",
1096-
"with open_zarr(data_basename + postfix_dict + zarr_ext, \"w\") as f2:\n",
1097-
" new_result = f2.create_dataset(\"images\", shape=(n_components,) + result.shape[1:], dtype=result.dtype, chunks=True)\n",
1098-
"\n",
1099-
" result = par_generate_dictionary(block_shape)(\n",
1100-
" result,\n",
1101-
" n_components=n_components,\n",
1102-
" out=new_result,\n",
1103-
" **{\"sklearn.decomposition.dict_learning_online\" : {\n",
1104-
" \"n_jobs\" : 1,\n",
1105-
" \"n_iter\" : iters,\n",
1106-
" \"batch_size\" : batchsize,\n",
1107-
" \"alpha\" : lambda1\n",
1108-
" }\n",
1109-
" }\n",
1110-
" )\n",
1111-
"\n",
1112-
" result_j = f2.create_dataset(\"images_j\", shape=new_result.shape, dtype=numpy.uint16, chunks=True)\n",
1113-
" par_norm_layer(num_frames=norm_frames)(result, out=result_j)\n",
1114-
"\n",
1115-
"\n",
1116-
"with get_executor(client) as executor:\n",
11171038
" zip_zarr(data_basename + postfix_dict + zarr_ext, executor)\n",
11181039
"\n",
11191040
"\n",
11201041
"if __IPYTHON__:\n",
1121-
" result_image_stack = LazyZarrDataset(data_basename + postfix_dict + zarr_ext, \"images\")\n",
1042+
" result_image_stack = LazyZarrDataset(data_basename + postfix_dict + zarr_ext, \"images\")[...][...]\n",
11221043
"\n",
11231044
" mplsv = plt.figure(FigureClass=MPLViewer)\n",
11241045
" mplsv.set_images(\n",
11251046
" result_image_stack,\n",
1126-
" vmin=par_compute_min_projection(num_frames=norm_frames)(result_image_stack).min(),\n",
1127-
" vmax=par_compute_max_projection(num_frames=norm_frames)(result_image_stack).max()\n",
1128-
" )\n",
1129-
" mplsv.time_nav.stime.label.set_text(\"Basis Image\")"
1047+
" vmin=result_image_stack.min(),\n",
1048+
" vmax=result_image_stack.max()\n",
1049+
" )"
11301050
]
11311051
},
11321052
{

0 commit comments

Comments
 (0)