diff --git a/config/hyperparams_icarus.yaml b/config/hyperparams_icarus.yaml
index 34fda75..f7ea11e 100644
--- a/config/hyperparams_icarus.yaml
+++ b/config/hyperparams_icarus.yaml
@@ -29,7 +29,7 @@ Model:
   use_fine_model : True # If set, creates a fine model
   d_filter_fine : 512 # Dimensions of linear layer filters of fine network
   n_layers_fine : 8 # Number of layers in fine network bottleneck
-  d_output : 1 # electron density
+  d_output : 4 # electron density (1 channel) + velocity (vx, vy, vz)
 Hierarchical sampling:
   n_samples_hierarchical : 128 # Number of samples per ray
diff --git a/config/psi_train.yaml b/config/psi_train.yaml
index 8815f18..dba59cf 100644
--- a/config/psi_train.yaml
+++ b/config/psi_train.yaml
@@ -6,5 +6,7 @@ Training:
 Lambda:
   regularization: 1
-
+  continuity: 1.e-4
+  radial_regularization: 1.e-4
+  velocity_regularization: 1.e-4
 Debug: False
\ No newline at end of file
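The three new Lambda weights correspond one-to-one to the regularization terms combined in training_step of sunerf/sunerf.py further down in this diff. Schematically (a restatement of that line with the self. prefixes dropped, not new behavior):

    loss = fine_loss + coarse_loss \
        + lambda_continuity * continuity_loss \
        + lambda_radial_regularization * radial_regularization_loss \
        + lambda_velocity_regularization * velocity_regularization_loss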
diff --git a/scripts/run_icarus.sh b/scripts/run_icarus.sh
index 2aef71c..4775885 100644
--- a/scripts/run_icarus.sh
+++ b/scripts/run_icarus.sh
@@ -1,12 +1,6 @@
 # Training from scratch
 # init workspace
-# PSI
-# convert data
-python -m sunerf.prep.prep_psi_cor --psi_path "/mnt/ground-data/PSI/pb_raw/*.fits" --output_path "/mnt/prep-data/prep_PSI/pb_raw"
-python -m sunerf.prep.prep_psi_cor --psi_path "/mnt/ground-data/PSI/b_raw/*.fits" --output_path "/mnt/prep-data/prep_PSI/b_raw"
-# full training PSI
-python -m sunerf.sunerf --wandb_name "psi" --data_path_pB "/mnt/ground-data/prep_PSI/pb_raw/*.fits" --data_path_tB "/mnt/ground-data/prep_PSI/b_raw/*.fits" --path_to_save "/mnt/training/PSI_v1" --train "config/train.yaml" --hyperparameters "config/hyperparams_icarus.yaml"
 
 # HAO
 # convert data
@@ -102,7 +96,8 @@ gsutil -m cp gs://fdl23_europe_helio_onground/ground-data/data_fits/dcmer_340W_
 gsutil -m cp gs://fdl23_europe_helio_onground/ground-data/data_fits/dcmer_360W_bang_0000_tB/stepnum_005.fits /mnt/ground-data/data_fits/dcmer_360W_bang_0000_tB/stepnum_005.fits
 gsutil -m cp gs://fdl23_europe_helio_onground/ground-data/data_fits/dcmer_360W_bang_0000_pB/stepnum_005.fits /mnt/ground-data/data_fits/dcmer_360W_bang_0000_pB/stepnum_005.fits
-
+# Download all of the PSI data
+gsutil -m cp -R gs://fdl23_europe_helio_onground/ground-data/PSI /mnt/ground-data/PSI/
 ################
 #              #
 #  Prep Data   #
 #              #
 ################
@@ -119,15 +114,27 @@ python -m sunerf.prep.prep_hao --resolution 512 --hao_path "/mnt/ground-data/dat
 # prep_HAO_allview
 python -m sunerf.prep.prep_hao --resolution 512 --hao_path "/mnt/ground-data/data_fits/**/*.fits" --output_path /mnt/prep-data/prep_HAO_allview --check_matching
 
+# PSI
+# convert data
+python -m sunerf.prep.prep_psi_cor --psi_path "/mnt/ground-data/PSI/pb_raw/*.fits" --output_path "/mnt/prep-data/prep_PSI/pb_raw"
+python -m sunerf.prep.prep_psi_cor --psi_path "/mnt/ground-data/PSI/b_raw/*.fits" --output_path "/mnt/prep-data/prep_PSI/b_raw"
+
+
 ######################
 #                    #
 #   Running ICARUS   #
 #                    #
 ######################
+# Prep_HAO_1view
+python -m sunerf.sunerf --wandb_name "hao_pinn_1view" --data_path_pB "/mnt/prep-data/prep_HAO_1view/*pB*.fits" --data_path_tB "/mnt/prep-data/prep_HAO_1view/*tB*.fits" --path_to_save "/mnt/training/HAO_pinn_1view" --train "config/train.yaml" --hyperparameters "config/hyperparams_hao.yaml"
+
 # Prep_HAO_2view
 python -m sunerf.sunerf --wandb_name "hao_pinn_2view" --data_path_pB "/mnt/prep-data/prep_HAO_2view/*pB*.fits" --data_path_tB "/mnt/prep-data/prep_HAO_2view/*tB*.fits" --path_to_save "/mnt/training/HAO_pinn_2view" --train "config/train.yaml" --hyperparameters "config/hyperparams_hao.yaml"
 
 # prep_HAO_2view_backgrounds
 python -m sunerf.sunerf --wandb_name "hao_pinn_2view_background" --data_path_pB "/mnt/prep-data/prep_HAO_2view_background/*pB*.fits" --data_path_tB "/mnt/prep-data/prep_HAO_2view_background/*tB*.fits" --path_to_save "/mnt/training/HAO_pinn_2view_background" --train "config/train.yaml" --hyperparameters "config/hyperparams_hao.yaml"
 
 # prep_HAO_allview
 python -m sunerf.sunerf --wandb_name "hao_pinn_all" --data_path_pB "/mnt/prep-data/prep_HAO_allview/*pB*.fits" --data_path_tB "/mnt/prep-data/prep_HAO_allview/*tB*.fits" --path_to_save "/mnt/training/HAO_pinn_allview" --train "config/train.yaml" --hyperparameters "config/hyperparams_hao.yaml"
+
+# full training PSI
+python -m sunerf.sunerf --wandb_name "psi" --data_path_pB "/mnt/prep-data/prep_PSI/pb_raw/*.fits" --data_path_tB "/mnt/prep-data/prep_PSI/b_raw/*.fits" --path_to_save "/mnt/training/PSI_v1" --train "config/train.yaml" --hyperparameters "config/hyperparams_icarus.yaml"
diff --git a/sunerf/evaluation/data_cube_model.py b/sunerf/evaluation/data_cube_model.py
new file mode 100644
index 0000000..a2670b2
--- /dev/null
+++ b/sunerf/evaluation/data_cube_model.py
@@ -0,0 +1,482 @@
+import os
+
+import numpy as np
+import pandas as pd
+import torch
+from matplotlib import pyplot as plt
+from tqdm import tqdm
+
+from sunerf.evaluation.loader import SuNeRFLoader
+from sunerf.utilities.data_loader import normalize_datetime
+import cv2
+
+from tvtk.api import tvtk, write_data
+
+from mpl_toolkits.mplot3d import Axes3D
+
+import imageio  # gifs
+import logging
+import sys
+
+#import vtk
+#import tvtk
+
+#logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))  # Output to console.
+from scipy.stats import multivariate_normal
+
+base_path = '/mnt/training/HAO_pinn_cr_2view_a26978f_heliographic_reformat'
+
+chk_path = os.path.join(base_path, 'save_state.snf')
+video_path_dens = os.path.join(base_path, 'video_cube')
+
+
+# parameter for filtering points that belong to a CME
+mask_mode = True
+
+# init loader
+loader = SuNeRFLoader(chk_path, resolution=512)
+os.makedirs(video_path_dens, exist_ok=True)
+n_cubes = 70  # How many cubes to generate
+
+# Points in R_solar
+num_points = 16
+
+x = np.linspace(-250, 250, num_points)
+y = np.linspace(-250, 250, num_points)
+z = np.linspace(-250, 250, num_points)
+
+xx, yy, zz = np.meshgrid(x, y, z, indexing="ij")
+solar_center = np.array([0, 0, 0])
+distance = np.sqrt((xx - solar_center[0])**2 + (yy - solar_center[1])**2 + (zz - solar_center[2])**2)
+# Cut out the inner solar radii, as in the rest of the program
+distance_mask = 21
+maximum_distance = 216  # Solar Radii - 1 AU = ~215 R_sun, so this restricts the cube to 1 AU
+print("Masking inner {} Solar Radii, as well as anything beyond {} Solar Radii.".format(distance_mask, maximum_distance))
+outside_sun_mask = distance > distance_mask
+inside_earth_orbit_mask = distance < maximum_distance
+x_filtered = xx[outside_sun_mask & inside_earth_orbit_mask]
+y_filtered = yy[outside_sun_mask & inside_earth_orbit_mask]
+z_filtered = zz[outside_sun_mask & inside_earth_orbit_mask]
+
+percentile = 95  # Use this percentile for masking
+percentile = np.clip(percentile, 0, 100)
+densities = []   # 1d - density at each point in each cube
+velocities = []  # 3d - velocity vector at each point in each cube
+speeds = []      # 1d - speed at each point in each cube
+
+for i, timei in tqdm(enumerate(pd.date_range(loader.start_time, loader.end_time, n_cubes)), total=n_cubes):
+
+    time = normalize_datetime(timei)
+    t = np.ones_like(x_filtered) * time
+    query_points_npy = np.stack([x_filtered, y_filtered, z_filtered, t], -1).astype(np.float32)  # (N, 4): one row per kept grid point
+
+    query_points = torch.from_numpy(query_points_npy)
+    enc_query_points = loader.encoding_fn(query_points.view(-1, 4))
+    raw = loader.fine_model(enc_query_points)
+    # electron_density = 10 ** (15 + raw[..., 0])
+    # velocity = torch.tanh(raw[..., 1:]) / 3 * 250 + 50
+    density = raw[..., 0]  # the raw-to-physical transform has been moved into the model; either keep it there or reinstate the lines above
+    velocity = raw[..., 1:]
+
+    if torch.isnan(density).any() or torch.isnan(velocity).any() or torch.isinf(density).any() or torch.isinf(velocity).any():
+        # remove nan/inf values
+        density = torch.nan_to_num(density, nan=0.0, posinf=0.0, neginf=0.0)
+        velocity = torch.nan_to_num(velocity, nan=0.0, posinf=0.0, neginf=0.0)
+
+    density = density.view(query_points_npy.shape[0]).cpu().detach().numpy()
+    velocity = velocity.view(query_points_npy.shape[:1] + velocity.shape[-1:]).cpu().detach().numpy()
+    # velocity = velocity / 10
+    mag = np.sqrt(velocity[..., 0]**2 + velocity[..., 1]**2 + velocity[..., 2]**2)
+
+    densities.append(density)
+    velocities.append(velocity)
+    speeds.append(mag)
+
+
+global_max_v = np.asarray(speeds).max()
+global_min_v = np.asarray(speeds).min()
+global_max_rho = np.asarray(densities).max()
+global_min_rho = np.asarray(densities).min()
+
+mean_density = np.asarray(densities).mean()
+mean_velocity = np.asarray(speeds).mean()
+
+perc_dens = np.percentile(np.asarray(densities), percentile)
+perc_speed = np.percentile(np.asarray(speeds), percentile)
+print("Mean Density: {} N_e per cm^3 \n {}% Percentile: {} N_e per cm^3 \n Max Density: {} N_e per cm^3 \n Min Density: {:.3f} N_e per cm^3.".format(mean_density, percentile, perc_dens, global_max_rho, global_min_rho))
+print("Mean Velocity: {:.3f} Solar Radii per 2 days \n {}% Percentile: {:.3f} Solar Radii per 2 days \n Max Velocity: {:.3f} Solar Radii per 2 days \n Min Velocity: {:.3f} Solar Radii per 2 days.".format(mean_velocity, percentile, perc_speed, global_max_v, global_min_v))
+# Choose thresholds at the chosen percentile
+density_threshold = perc_dens
+velocity_threshold = perc_speed  # Most CMEs appear to be moving with a speed of ~3 in these units
+print("Thresholding density at {} N_e per cm^3, velocity at {:.3f} Solar Radii per 2 days.".format(density_threshold, velocity_threshold))
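For orientation, the commented-out lines in the sampling loop above show the raw-to-physical transform that is assumed to now live inside fine_model (the loop comment flags this as an open question). A minimal sketch of that mapping, with the channel split implied by d_output: 4 (raw_to_physical is an illustrative name, not part of the codebase):

    # Sketch only: channel 0 is electron density, channels 1:4 are velocity.
    def raw_to_physical(raw: torch.Tensor):
        electron_density = 10 ** (15 + raw[..., 0])         # N_e per cm^3, log-scaled model output
        velocity = torch.tanh(raw[..., 1:]) / 3 * 250 + 50  # Solar Radii per 2 days, bounded
        return electron_density, velocity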
+
+def compute_alpha(alpha_mode, cube_norm, alpha_expon, distance_from_sun):
+    if alpha_mode == 0:
+        alpha = cube_norm ** alpha_expon
+    elif alpha_mode == 1:
+        sig_arg = cube_norm ** alpha_expon + distance_from_sun
+        alpha = np.tanh(sig_arg)
+    return alpha  # the return was missing, so callers always received None
+
+def plot_datacube_directly(cube, global_min, global_max, tag, idx, x_fil, y_fil, z_fil, alpha_expon, norm="linear", fname_subtag=None, alpha_mode=1):
+    cube_norm = (cube - global_min) / (global_max - global_min)
+    plt.close("all")
+    fig = plt.figure()
+    ax = fig.add_subplot(111, projection="3d")
+    distance_from_sun = np.sqrt(x_fil**2 + y_fil**2 + z_fil**2) / 250
+    alpha = compute_alpha(alpha_mode, cube_norm, alpha_expon, distance_from_sun)
+    if len(x_fil) and len(y_fil) and len(z_fil):
+        ax.scatter(x_fil, y_fil, z_fil, c=cube, marker='.', norm=norm, vmin=global_min, vmax=global_max, alpha=alpha)
+        if norm == "log":
+            ticks = np.linspace(np.log(global_min), np.log(global_max), 10, endpoint=True)
+        else:
+            ticks = np.linspace(global_min, global_max, 10, endpoint=True)
+        # the colorbar only exists if something was scattered; guard against empty masks
+        cbar = plt.colorbar(ax.collections[0], ax=ax, ticks=ticks)
+        cbar.set_label('{}'.format(tag))
+    ax.set_xlim(-250, 250)
+    ax.set_ylim(-250, 250)
+    ax.set_zlim(-250, 250)
+    ax.set_xlabel('X [Solar Radii]')
+    ax.set_ylabel('Y [Solar Radii]')
+    ax.set_zlabel('Z [Solar Radii]')
+    ax.set_title('3D scatter plot of points based on {} at timestep {}'.format(tag, idx))
+
+    # Add Sun
+    u = np.linspace(0, 2 * np.pi, 100)
+    v = np.linspace(0, np.pi, 100)
+    x_sun = distance_mask * np.outer(np.cos(u), np.sin(v))
+    y_sun = distance_mask * np.outer(np.sin(u), np.sin(v))
+    z_sun = distance_mask * np.outer(np.ones(np.size(u)), np.cos(v))
+
+    # Plot the Sun
+    ax.plot_surface(x_sun, y_sun, z_sun, color='gold', alpha=0.5)
+
+    # Add Earth
+    radius_earth = 0.009157683  # Solar radii = 6371 km - it's a dot at this scale, hence the factor 7 below
+    x_earth = np.outer(np.cos(u), np.sin(v)) * 7  # * radius_earth
+    y_earth = np.outer(np.sin(u), np.sin(v)) * 7  # * radius_earth
+    z_earth = np.outer(np.ones(np.size(u)), np.cos(v)) * 7  # * radius_earth
+
+    # Earth position defined on the x axis, 1 AU = 215.032 Solar Radii
+    r_earth = 215.032
+    center_earth = (-r_earth, 0, 0)
+    # If center_earth = (-r_earth, 0, 0) then on a circle, this is at pi radians. If L5 is trailing, then L5 is at (r_earth*cos(2/3*np.pi), r_earth*sin(2/3*np.pi), 0)
+    center_l5 = (r_earth * np.cos(2 / 3 * np.pi), r_earth * np.sin(2 / 3 * np.pi), 0)
+
+    # y_l5 and z_l5 previously reused x_earth, which misplaced and distorted the L5 marker
+    x_l5 = x_earth + center_l5[0]
+    y_l5 = y_earth + center_l5[1]
+    z_l5 = z_earth + center_l5[2]
+
+    x_earth += center_earth[0]
+    y_earth += center_earth[1]
+    z_earth += center_earth[2]
+
+    # Plot the Earth
+    ax.plot_surface(x_earth, y_earth, z_earth, color='cyan', alpha=1)
+    # Plot L5
+    ax.plot_surface(x_l5, y_l5, z_l5, color="red", alpha=1)
+
+    # Plot Earth's orbit
+    x_orbit = r_earth * np.cos(u)
+    y_orbit = r_earth * np.sin(u)
+    z_orbit = np.zeros_like(u)
+
+    ax.scatter(x_orbit, y_orbit, z_orbit, c='gray', marker='.', alpha=0.1)
+
+    filename = os.path.join(video_path_dens, f'{tag}_cube_{idx:03d}.jpg')
+
+    if fname_subtag is not None:
+        filename = os.path.join(video_path_dens, f'{tag}_{fname_subtag}_cube_{idx:03d}.jpg')
+    fig.savefig(filename, dpi=100)
+    return filename
+
+
+def plot_datacube(cube, global_min: float, global_max: float, tag: str, idx: int, plot_threshold: float, alpha_expon: float = 1.5, norm="linear", alpha_mode=1):
+    '''
+    Plots the datacube together with the Sun, Earth and L5, alongside Earth's orbit
+    '''
+    cube_norm = (cube - global_min) / (global_max - global_min)
+    plt.close("all")
+
+    fig = plt.figure()
+    ax = fig.add_subplot(111, projection='3d')
+    #ax = fig.add_subplot(121, projection='3d')  # Main 3D scatter plot
+    #ax_hist_x = fig.add_subplot(322)  # Histogram for x-axis
+    #ax_hist_y = fig.add_subplot(324)  # Histogram for y-axis
+    #ax_hist_z = fig.add_subplot(326)  # Histogram for z-axis
+
+    # Only plot values above a threshold
+    mask = cube > plot_threshold
+
+    x_filtered_again = x_filtered[mask]
+    y_filtered_again = y_filtered[mask]
+    z_filtered_again = z_filtered[mask]
+
+    cube = cube[mask]
+    cube_norm = cube_norm[mask]
+
+    distance_from_sun = np.sqrt(x_filtered_again**2 + y_filtered_again**2 + z_filtered_again**2) / 250
+    alpha = compute_alpha(alpha_mode, cube_norm, alpha_expon, distance_from_sun)
+    ax.scatter(x_filtered_again, y_filtered_again, z_filtered_again, c=cube, marker='.', norm=norm, vmin=global_min, vmax=global_max, alpha=alpha)
+    ax.set_xlim(-250, 250)
+    ax.set_ylim(-250, 250)
+    ax.set_zlim(-250, 250)
+    ax.set_xlabel('X [Solar Radii]')
+    ax.set_ylabel('Y [Solar Radii]')
+    ax.set_zlabel('Z [Solar Radii]')
+    ax.set_title('3D scatter plot of points based on {} at timestep {}'.format(tag, idx))
+    if norm == "log":
+        ticks = np.linspace(np.log(global_min), np.log(global_max), 10, endpoint=True)
+    else:
+        ticks = np.linspace(global_min, global_max, 10, endpoint=True)
+    cbar = plt.colorbar(ax.collections[0], ax=ax, ticks=ticks)
+    cbar.set_label('{}'.format(tag))
+    # Add Sun
+    u = np.linspace(0, 2 * np.pi, 100)
+    v = np.linspace(0, np.pi, 100)
+    x_sun = distance_mask * np.outer(np.cos(u), np.sin(v))
+    y_sun = distance_mask * np.outer(np.sin(u), np.sin(v))
+    z_sun = distance_mask * np.outer(np.ones(np.size(u)), np.cos(v))
+
+    # Plot the Sun
+    ax.plot_surface(x_sun, y_sun, z_sun, color='gold', alpha=0.5)
+
+    # Add Earth
+    radius_earth = 0.009157683  # Solar radii = 6371 km - it's a dot at this scale, hence the factor 7 below
+    x_earth = np.outer(np.cos(u), np.sin(v)) * 7  # * radius_earth
+    y_earth = np.outer(np.sin(u), np.sin(v)) * 7  # * radius_earth
+    z_earth = np.outer(np.ones(np.size(u)), np.cos(v)) * 7  # * radius_earth
+
+    # Earth position defined on the x axis, 1 AU = 215.032 Solar Radii
+    r_earth = 215.032
+    center_earth = (-r_earth, 0, 0)
+    # If center_earth = (-r_earth, 0, 0) then on a circle, this is at pi radians. If L5 is trailing, then L5 is at (r_earth*cos(2/3*np.pi), r_earth*sin(2/3*np.pi), 0)
+    center_l5 = (r_earth * np.cos(2 / 3 * np.pi), r_earth * np.sin(2 / 3 * np.pi), 0)
+    # same fix as in plot_datacube_directly: offset y/z from y_earth/z_earth, not x_earth
+    x_l5 = x_earth + center_l5[0]
+    y_l5 = y_earth + center_l5[1]
+    z_l5 = z_earth + center_l5[2]
+
+    x_earth += center_earth[0]
+    y_earth += center_earth[1]
+    z_earth += center_earth[2]
+
+    # Plot the Earth
+    ax.plot_surface(x_earth, y_earth, z_earth, color='cyan', alpha=1)
+    # Plot L5
+    ax.plot_surface(x_l5, y_l5, z_l5, color="red", alpha=1)
+
+    # Plot Earth's orbit
+    x_orbit = r_earth * np.cos(u)
+    y_orbit = r_earth * np.sin(u)
+    z_orbit = np.zeros_like(u)
+
+    ax.scatter(x_orbit, y_orbit, z_orbit, c='gray', marker='.', alpha=0.1)
+
+    # Histograms along each axis
+    #ax_hist_x.hist(x_filtered_again, bins=15, color='b', alpha=0.6)
+    #ax_hist_y.hist(y_filtered_again, bins=15, color='g', alpha=0.6)
+    #ax_hist_z.hist(z_filtered_again, bins=15, color='r', alpha=0.6)
+
+    # Set titles for histograms
+    #ax_hist_x.set_title('Histogram along X axis')
+    #ax_hist_y.set_title('Histogram along Y axis')
+    #ax_hist_z.set_title('Histogram along Z axis')
+
+    filename = os.path.join(video_path_dens, f'{tag}_cube_{idx:03d}.jpg')
+    fig.savefig(filename, dpi=100)
+    return filename
+
+
+density_filenames = []
+velocity_filenames = []
+masked_density_filenames = []
+masked_velocity_filenames = []
+
+masked_density = []
+masked_velocity = []
+
+last_mask = None
+mean_velocity = []
+
+for i, (rho, v, abs_v) in enumerate(zip(densities, velocities, speeds)):
+    density_filename = plot_datacube(rho, global_min_rho, global_max_rho, tag="density", idx=i, plot_threshold=density_threshold, alpha_expon=3, norm="log")
+    velocity_filename = plot_datacube(abs_v, global_min_v, global_max_v, tag="velocity", idx=i, plot_threshold=velocity_threshold, alpha_expon=3)
+    density_filenames.append(density_filename)
+    velocity_filenames.append(velocity_filename)
+    if mask_mode:
+        density_mask = rho > density_threshold
+        velocity_mask = abs_v > velocity_threshold
+        mask = density_mask & velocity_mask
+        if last_mask is not None:
+            # Last mask needs to be blurred - we want to remove voxels around the currently active points, as well as the points themselves
+            # last_mask exists, therefore we take out the background
+            last_mask = ~last_mask
+            # Every spot that has been accepted last time is now disabled
+            # every new spot is still possible - achieving recent background subtraction
+            mask = mask & last_mask
+
+        # Positions that belong only to outliers
+        x_filtered_again = x_filtered[mask]
+        y_filtered_again = y_filtered[mask]
+        z_filtered_again = z_filtered[mask]
+        # Values that belong only to outliers
+        masked_rho = rho[mask]
+        masked_v = abs_v[mask]
+        masked_density_fname = plot_datacube_directly(masked_rho, global_min_rho, global_max_rho, tag="density", x_fil=x_filtered_again, y_fil=y_filtered_again, z_fil=z_filtered_again, idx=i, alpha_expon=3, norm="log", fname_subtag="masked")
+        masked_velocity_fname = plot_datacube_directly(masked_v, global_min_v, global_max_v, tag="velocity", x_fil=x_filtered_again, y_fil=y_filtered_again, z_fil=z_filtered_again, idx=i, alpha_expon=3, fname_subtag="masked")
+        masked_density_filenames.append(masked_density_fname)
+        masked_velocity_filenames.append(masked_velocity_fname)
+
+        vmean = [0, 0, 0]
+        if len(v[mask]):
+            vmean = v[mask].mean(axis=0)
+        mean_velocity.append(vmean)
+        masked_density.append(masked_rho)
+        masked_velocity.append(masked_v)
+        if last_mask is None:
+            last_mask = mask  # set up last mask as the first possible mask - otherwise might flicker
+
+
+frame_duration = 0.5  # 2 fps
+if len(density_filenames):
+    with imageio.get_writer(os.path.join(video_path_dens, 'density.gif'), mode='I', duration=frame_duration) as writer:
+        for filename in density_filenames:
+            image = imageio.v3.imread(filename)
+            writer.append_data(image)
+if len(velocity_filenames):
+    with imageio.get_writer(os.path.join(video_path_dens, 'velocity.gif'), mode='I', duration=frame_duration) as writer:
+        for filename in velocity_filenames:
+            image = imageio.v3.imread(filename)
+            writer.append_data(image)
+
+if mask_mode:
+    if len(masked_velocity_filenames):
+        with imageio.get_writer(os.path.join(video_path_dens, 'masked_velocity.gif'), mode='I', duration=frame_duration) as writer:
+            for filename in masked_velocity_filenames:
+                image = imageio.v3.imread(filename)
+                writer.append_data(image)
+    if len(masked_density_filenames):
+        with imageio.get_writer(os.path.join(video_path_dens, 'masked_density.gif'), mode='I', duration=frame_duration) as writer:
+            for filename in masked_density_filenames:
+                image = imageio.v3.imread(filename)
+                writer.append_data(image)
+
+
+def calculate_mean_velocity_and_density(density_cube: np.ndarray, velocity_cube: np.ndarray, speed_cube: np.ndarray, percentile: float, speeds: np.array = None, densities: np.array = None, last_mask=None) -> dict:
+    '''
+    Given the density and velocity cubes, calculates the mask needed to extract the CME.
+    Using the mask, calculates the mean velocity (speed and direction) of the CME (in Solar Radii per 2 days),
+    alongside the mean density (in N_e per cm^3).
+
+    Input:
+        density_cube: np.ndarray: Density cube extracted from the fine model of the SuNeRF
+        velocity_cube: np.ndarray: Velocity cube from the same source
+        speed_cube: np.ndarray: Generated from the velocity cube, vector norm of the velocities
+        percentile: float: percentile used to set the threshold on speeds and densities
+        speeds: np.array: optional: default None: Array of speeds used for calculating the thresholds.
+            If None, calculated by flattening speed_cube
+        densities: np.array: optional: default None: Array of densities used for calculating a threshold.
+            If None, calculated by flattening density_cube
+        last_mask: np.array: optional: default None: Mask used to deactivate voxels in background subtraction
+    Output:
+        Dictionary with 7 keys:
+        "Density": Mean density of the detected CME in N_e/cm^3
+        "Velocity": Mean velocity vector of the CME
+        "Speed": |v| - vector norm of the mean velocity
+        "Direction": \hat{v} - unit vector of the mean velocity
+        "Densities": Collection of densities that have been accepted
+        "Velocities": Collection of velocities that have been accepted
+        "Mask": mask used to accept voxels (in either cube)
+    '''
+
+    percentile = np.clip(percentile, 0, 99)
+    if speeds is None:
+        speeds = speed_cube.flatten()  # np.flatten does not exist; flatten the cube itself
+    if densities is None:
+        densities = density_cube.flatten()  # was np.flatten(densities), which flattened the still-empty argument
+    density_threshold = np.percentile(np.asarray(densities), percentile)
+    velocity_threshold = np.percentile(np.asarray(speeds), percentile)
+
+    density_mask = densities > density_threshold
+    velocity_mask = speeds > velocity_threshold
+    mask = density_mask & velocity_mask
+    if last_mask is not None:
+        # Last mask needs to be blurred - we want to remove voxels around the currently active points, as well as the points themselves
+        # last_mask exists, therefore we take out the background
+        last_mask = ~last_mask
+        # Every spot that has been accepted last time is now disabled
+        # every new spot is still possible - achieving recent background subtraction
+        mask = mask & last_mask
+
+    rhomean = None
+    if len(density_cube[mask]):
+        rhomean = density_cube[mask].mean()
+    vmean = None
+    if len(velocity_cube[mask]):
+        vmean = velocity_cube[mask].mean(axis=0)
+    speedmean = None
+    if len(speed_cube[mask]):
+        speedmean = speed_cube[mask].mean()
+    # normalize by |v| (np.linalg.norm), not |v|^2 (np.dot), and guard against an empty mask
+    direction = vmean / np.linalg.norm(vmean) if vmean is not None else None
+
+    out_dict = {}
+    out_dict["Density"] = rhomean
+    out_dict["Velocity"] = vmean
+    out_dict["Speed"] = speedmean
+    out_dict["Direction"] = direction
+    out_dict["Densities"] = density_cube[mask]
+    out_dict["Velocities"] = velocity_cube[mask]
+    out_dict["Mask"] = mask
+
+    return out_dict
+
+def estimate_probability_of_hit(earth_position: np.array, mean_velocity: np.array, densities: np.ndarray, velocities: np.ndarray, positions_x: np.array, positions_y: np.array, positions_z: np.array) -> np.array:
+    """
+    This function estimates the probability of a detected CME hitting Earth.
+    Currently, this is a simplified model: the distribution of velocity directions is approximated
+    by a single multivariate normal (the variable is named kernel_density_estimate, but no true KDE is fitted).
+
+    It estimates the probability of the CME hitting Earth within the next two steps, with the following calculations:
+    1. Calculate the direction of Earth from the CME point positions
+    2. Fit the distribution of direction vectors from the velocity array
+    3. Use it to calculate the probability of the CME moving in the direction of Earth - P0
+    4. Move CME to next position based on velocity (new position vectors)
+
+
+    Args:
+        earth_position (np.array): [x,y,z] - position vector of Earth
+        mean_velocity (np.array): [vx,vy,vz] - mean velocity vector of the CME - output of calculate_mean_velocity_and_density
+        densities (np.ndarray): Density vector that is designated as CME - useful for classification
+        velocities (np.ndarray): Velocities of points designated as CME
+        positions_x (np.array): Positions that have been designated as CME
+        positions_y (np.array): "
+        positions_z (np.array): "
+
+    Returns:
+        probabilities: np.array: per-point probabilities of the CME hitting Earth
+    """
+    positions = np.stack([positions_x, positions_y, positions_z], axis=1)
+    # vectors pointing from the CME points towards Earth (was positions - earth_position, i.e. the opposite direction)
+    connection_vector = earth_position - positions
+    # unit direction vectors: divide by |v| (np.linalg.norm), not |v|^2 (np.dot)
+    direction_vectors = np.asarray([v / np.linalg.norm(v) for v in connection_vector])
+    # direction vector for mean velocity
+    # mean_direction = mean_velocity / np.linalg.norm(mean_velocity)
+
+    velocity_directions = np.asarray([v / np.linalg.norm(v) for v in velocities])
+    # Distribution of direction vectors
+    std_dx, std_dy, std_dz = velocity_directions.std(axis=0)
+    # diagonal covariance: scipy expects variances, hence the squared standard deviations
+    kernel_density_estimate = multivariate_normal(velocity_directions.mean(axis=0), velocity_directions.std(axis=0) ** 2)
+    # Use kernel_density_estimate.cdf about regions.
+    # Approximate the probability of each direction as the mass of the +/- one-sigma box around it
+    probabilities = [kernel_density_estimate.cdf(np.asarray([vx + std_dx, vy + std_dy, vz + std_dz])) - kernel_density_estimate.cdf(np.asarray([vx - std_dx, vy - std_dy, vz - std_dz])) for vx, vy, vz in direction_vectors]
+    probabilities = np.asarray(probabilities)
+    return probabilities
\ No newline at end of file
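For clarity, a hypothetical end-to-end call of the two helpers above (earth_pos, result and hit_probs are illustrative names, not part of the module; Earth sits at (-215.032, 0, 0) in this script's frame, matching center_earth in the plotting functions):

    # Sketch only: wire one timestep's cubes through both helpers.
    earth_pos = np.array([-215.032, 0.0, 0.0])
    result = calculate_mean_velocity_and_density(densities[0], velocities[0], speeds[0], percentile=95)
    if result["Velocity"] is not None:  # the mask may select no voxels
        cme_mask = result["Mask"]
        hit_probs = estimate_probability_of_hit(earth_pos, result["Velocity"],
                                                result["Densities"], result["Velocities"],
                                                x_filtered[cme_mask], y_filtered[cme_mask], z_filtered[cme_mask])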
diff --git a/sunerf/evaluation/density_cube_eval.py b/sunerf/evaluation/density_cube_eval.py
new file mode 100644
index 0000000..49bea52
--- /dev/null
+++ b/sunerf/evaluation/density_cube_eval.py
@@ -0,0 +1,113 @@
+import os
+
+import numpy as np
+import torch
+from tqdm import tqdm
+import scipy.io  # readsav lives in scipy.io; a plain "import scipy" is not guaranteed to expose it
+from sunpy.map import Map
+from datetime import datetime
+import pickle
+
+from sunerf.evaluation.loader import SuNeRFLoader
+from sunerf.utilities.data_loader import normalize_datetime
+
+START_STEPNUM = 37  # 5
+END_STEPNUM = 37  # 74
+CHUNKS = 4
+
+# R_SUN_CM = 6.957e+10
+# GRID_SIZE = 500 / 16  # solar radii
+
+def save_stepnum_to_datetime():
+    stepnum_to_datetime = dict()
+
+    for stepnum in range(5, 80, 1):
+        map_path = "/mnt/prep-data/prep_HAO_2view/dcmer_280W_bang_0000_pB_stepnum_%03d.fits" % stepnum
+        s_map = Map(map_path)
+        dt = s_map.date.datetime
+        stepnum_to_datetime[stepnum] = dt.strftime("%Y-%m-%d %H:%M:%S")
+
+    print(stepnum_to_datetime)
+    with open('/mnt/ground-data/stepnum_to_datetime.pkl', 'wb') as f:
+        pickle.dump(stepnum_to_datetime, f)
+
+def load_stepnum_to_datetime():
+    with open('/mnt/ground-data/stepnum_to_datetime.pkl', 'rb') as f:
+        stepnum_to_datetime = pickle.load(f)
+    return stepnum_to_datetime
+
+# load datetime for each stepnum
+stepnum_to_datetime = load_stepnum_to_datetime()
+# convert datetime from string to datetime.datetime
+def dtstr_to_datetime(dtstr):
+    return datetime.strptime(dtstr, "%Y-%m-%d %H:%M:%S")
+stepnum_to_datetime = dict(map(lambda kv: (kv[0], dtstr_to_datetime(kv[1])), stepnum_to_datetime.items()))
+
+mae_all_stepnums = []
+
+for stepnum in range(START_STEPNUM, END_STEPNUM + 1, 1):
+
+    # load ground truth
+    gt_fname = "/mnt/ground-data/density_cube/dens_stepnum_%03d.sav" % stepnum
+    o = scipy.io.readsav(gt_fname)
+    ph = o['ph1d']  # (258,)
+    th = o['th1d']  # (128,)
+    r = o['r1d']  # (256,)
+    density_gt = o['dens']  # (258, 128, 256) (phi, theta, r)
+
+    # ignore half of r
+    # r_size = len(o['r1d'])
+    # r = o['r1d'][:int(r_size / 2)]  # (256,) -> (128, 0)
+    # density_gt = o['dens'][:,:,:int(r_size / 2)]  # (258, 128, 256) (phi, theta, r)
+
+    # load model checkpoint (re-created on every stepnum; harmless for a single-step run)
+    base_path = '/mnt/training/HAO_pinn_cr_2view_a26978f_heliographic_reformat'
+    chk_path = os.path.join(base_path, 'save_state.snf')
+    loader = SuNeRFLoader(chk_path, resolution=512)
+
+    # put th into chunks to avoid CUDA out of memory error
+    th = th.reshape(CHUNKS, -1)
+
+    time = normalize_datetime(stepnum_to_datetime[stepnum])
+    observer_offset = np.deg2rad(90)
+    ph_copy = ph.copy() + observer_offset
+
+    density_chunks = []
+    for chunk in tqdm(range(CHUNKS)):
+        phph, thth, rr = np.meshgrid(ph_copy, th[chunk], r, indexing="ij")
+
+        x = rr * np.cos(phph) * np.sin(thth)
+        y = rr * np.sin(phph) * np.sin(thth)
+        z = rr * np.cos(thth)
+        t = np.ones_like(rr) * time
+        query_points_npy = np.stack([x, y, z, t], -1).astype(np.float32)  # (258, 32, 256, 4) for one chunk
+
+        query_points = torch.from_numpy(query_points_npy)
+        enc_query_points = loader.encoding_fn(query_points.view(-1, 4))
+
+        # model inference
+        with torch.no_grad():  # required for memory to be cleared properly
+            raw = loader.fine_model(enc_query_points)
+            density = raw[..., 0]
+            # density = 10 ** (15 + raw[..., 0])
+            density = density.view(query_points_npy.shape[:3]).cpu().detach().numpy()
+            density_chunks.append(density)
+
+    density = np.concatenate(density_chunks, 1)  # in electrons / r_sun
+    # convert unit to that of ground truth: electrons / cm^3
+    # density *= GRID_SIZE  # electrons / grid cell
+    # density *= (GRID_SIZE * R_SUN_CM) ** (-3)  # electrons / cm^3
+    # density *= GRID_SIZE ** (-2) * R_SUN_CM ** (-3)
+
+    # compare density to ground truth
+    rel_density = density / np.mean(density)
+    rel_density_gt = density_gt / np.mean(density_gt)
+
+    print(rel_density[0])
+    print(rel_density_gt[0])
+    mae = (np.abs(rel_density - rel_density_gt)).mean(axis=None)
+    print(mae)
+
+    mae_all_stepnums.append(mae)
+
+print(sum(mae_all_stepnums) / len(mae_all_stepnums))
diff --git a/sunerf/evaluation/density_cube_model.py b/sunerf/evaluation/density_cube_model.py
index 978a1a0..d888253 100644
--- a/sunerf/evaluation/density_cube_model.py
+++ b/sunerf/evaluation/density_cube_model.py
@@ -8,15 +8,28 @@
 from sunerf.evaluation.loader import SuNeRFLoader
 from sunerf.utilities.data_loader import normalize_datetime
+import cv2
+
+def visualise_velocity(velocity, file_path):
+    # Use the Hue, Saturation, Value colour model
+    hsv = np.zeros(velocity.shape, dtype=np.uint8)
+    hsv[..., 1] = 255
+    mag = np.sqrt(velocity[..., 0]**2 + velocity[..., 1]**2 + velocity[..., 2]**2)
+    _, ang = cv2.cartToPolar(velocity[..., 0], velocity[..., 1])
+    hsv[..., 0] = ang * 180 / np.pi / 2
+    hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
+    bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
+    cv2.imwrite(file_path, bgr)
+    return mag
 
 # base_path = '/mnt/training/HAO_pinn_cr_allview_a26978f_heliographic'
 # observer_offset = np.deg2rad(90)
-base_path = '/mnt/training/HAO_pinn_2viewpoints_backgrounds'
-observer_offset = np.deg2rad(180)
+base_path = '/mnt/training/HAO_pinn_cr_2view_a26978f_heliographic_reformat'
+observer_offset = np.deg2rad(90)
 
 chk_path = os.path.join(base_path, 'save_state.snf')
-video_path_dens = os.path.join(base_path, 'video_density_cube')
+video_path_dens = os.path.join(base_path, 'video_density_cube', 'video_cube')
 
 # init loader
 loader = SuNeRFLoader(chk_path, resolution=512)
@@ -24,43 +37,64 @@ os.makedirs(video_path_dens, exist_ok=True)
 
 densities = []
+velocities = []
 r = np.linspace(21, 200, 256)
 ph = np.linspace(-np.pi-np.pi/128, np.pi+np.pi/128, 258)
-rr, phph = np.meshgrid(r, ph, indexing = "ij")
-theta = (0.32395396 + 2.8176386) / 2
-
-
-
-for i, timei in tqdm(enumerate(pd.date_range(loader.start_time, loader.end_time, n_points)), total=n_points):
+th = np.linspace(0, np.pi, 10)
 
-    # DENSITY CUBE SLICE
-    time = normalize_datetime(timei)
-
-    x = rr * np.cos(phph) * np.sin(theta)
-    y = rr * np.sin(phph) * np.sin(theta)
-    z = rr * np.cos(theta)
-    t = np.ones_like(rr) * time
-    query_points_npy = np.stack([x, y, z, t], -1).astype(np.float32)
-    # (256, 258, 4)
-
-    query_points = torch.from_numpy(query_points_npy)
-
-    # Prepare points --> encoding.
-    enc_query_points = loader.encoding_fn(query_points.view(-1, 4))
+theta = (0.32395396 + 2.8176386) / 2
 
-    raw = loader.fine_model(enc_query_points)
-    # density = raw[..., 0]
-    density = 10 ** (15 + raw[..., 0])
-    # velocity = raw[..., 1:]
-    density = density.view(query_points_npy.shape[:2]).cpu().detach().numpy()
-    # velocity = velocity.view(query_points_npy.shape[:2] + velocity.shape[-1:]).cpu().detach().numpy()
-    # velocity = velocity / 10
+rr, phph = np.meshgrid(r, ph, indexing = "ij")
 
-    fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})
-    im = ax.pcolormesh(phph - observer_offset, rr, density, edgecolors='face', cmap='viridis', norm='log', vmin=2e24, vmax=8e26)
-    plt.colorbar(im, label='$N_e$')
-    plt.axis('on')
-    fig.savefig(os.path.join(video_path_dens, f'dens_cube_slice_{i:03d}.jpg'), dpi=100)
-    plt.close(fig)
+for i, timei in tqdm(enumerate(pd.date_range(loader.start_time, loader.end_time, n_points)), total=n_points):
+    densities_in_cube = []
+    velocities_in_cube = []
+    for j, theta in tqdm(enumerate(th)):
+
+        # DENSITY CUBE SLICE
+        time = normalize_datetime(timei)
+
+        x = rr * np.cos(phph) * np.sin(theta)
+        y = rr * np.sin(phph) * np.sin(theta)
+        z = rr * np.cos(theta)
+        t = np.ones_like(rr) * time
+        query_points_npy = np.stack([x, y, z, t], -1).astype(np.float32)
+        # (256, 258, 4)
+
+        query_points = torch.from_numpy(query_points_npy)
+
+        # Prepare points --> encoding.
+        enc_query_points = loader.encoding_fn(query_points.view(-1, 4))
+
+        raw = loader.fine_model(enc_query_points)
+        # raw = loader.fine_model(enc_query_points)
+        # density = raw[..., 0]
+        density = 10 ** (15 + raw[..., 0])
+        velocity = raw[..., 1:]
+        density = density.view(query_points_npy.shape[:2]).cpu().detach().numpy()
+        velocity = velocity.view(query_points_npy.shape[:2] + velocity.shape[-1:]).cpu().detach().numpy()
+        # velocity = velocity / 10
+        densities_in_cube.append(density)
+        velocities_in_cube.append(velocity)
+
+        fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})
+        im = ax.pcolormesh(phph - observer_offset, rr, density, edgecolors='face', cmap='viridis', norm='log', vmin=2e24, vmax=8e28)
+        plt.colorbar(im, label='$N_e$')
+        plt.axis('on')
+        fig.savefig(os.path.join(video_path_dens, f'dens_cube_slice_{j:03d}_{i:03d}.jpg'), dpi=100)
+        plt.close(fig)
+        # Plot magnitude of velocity
+        mag = np.sqrt(velocity[..., 0]**2 + velocity[..., 1]**2 + velocity[..., 2]**2)
+        mag_n = (mag - mag.min()) / (mag.max() - mag.min())
+
+        fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})
+        im = ax.pcolormesh(phph - observer_offset, rr, mag_n, edgecolors='face', cmap='viridis')
+        plt.colorbar(im, label='$abs(V)$')
+        plt.axis('on')
+        fig.savefig(os.path.join(video_path_dens, f'velocity_cube_slice_{j:03d}_{i:03d}.jpg'), dpi=100)
+        plt.close(fig)
+    densities.append(densities_in_cube)
+    velocities.append(velocities_in_cube)
+    
\ No newline at end of file
diff --git a/sunerf/evaluation/load_density_cube.py b/sunerf/evaluation/load_density_cube.py
new file mode 100644
index 0000000..6574724
--- /dev/null
+++ b/sunerf/evaluation/load_density_cube.py
@@ -0,0 +1,106 @@
+import os
+
+import numpy as np
+import pandas as pd
+import torch
+from tqdm import tqdm
+from tvtk.api import tvtk, write_data
+
+from sunerf.evaluation.loader import SuNeRFLoader
+from sunerf.utilities.data_loader import normalize_datetime
+
+base_path = '/mnt/training/HAO_pinn_cr_allview_a26978f_heliographic'
+chk_path = os.path.join(base_path, 'save_state.snf')
+save_path = os.path.join(base_path, 'vtk_128')
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# init loader
+resolution = 128
+loader = SuNeRFLoader(chk_path)
+
+n_time_points = 100
+batch_size = 4096 * 4 * torch.cuda.device_count()  # assumes at least one visible GPU
+os.makedirs(save_path, exist_ok=True)
+
+
+def save_vtk(vec, path, name, scalar=None, scalar_name='scalar', sr_per_pix=1):
+    """Save numpy array as VTK file
+
+    :param vec: numpy array of the vector field (x, y, z, c)
+    :param path: path to the target VTK file
+    :param name: label of the vector field (e.g., B)
+    :param scalar: optional scalar field stored alongside the vectors
+    :param scalar_name: label of the scalar field
+    :param sr_per_pix: voxel size in solar radii (the previous Mm_per_pix wording was a leftover from the HMI version of this helper)
+    """
+    # Unpack
+    dim = vec.shape[:-1]
+    # Generate the grid
+    pts = np.stack(np.mgrid[0:dim[0], 0:dim[1], 0:dim[2]], -1).astype(np.int64) * sr_per_pix
+    # reorder the points and vectors in agreement with VTK
+    # requirement of x first, y next and z last.
+    pts = pts.transpose(2, 1, 0, 3)
+    pts = pts.reshape((-1, 3))
+    vectors = vec.transpose(2, 1, 0, 3)
+    vectors = vectors.reshape((-1, 3))
+
+    sg = tvtk.StructuredGrid(dimensions=dim, points=pts)
+    sg.point_data.vectors = vectors
+    sg.point_data.vectors.name = name
+    if scalar is not None:
+        scalars = scalar.transpose(2, 1, 0)
+        scalars = scalars.reshape((-1))
+        sg.point_data.add_array(scalars)
+        sg.point_data.get_array(1).name = scalar_name
+        sg.point_data.update()
+
+    write_data(sg, path)
+
+
+with torch.no_grad():
+    for timei in tqdm(pd.date_range(loader.start_time, loader.end_time, n_time_points), total=n_time_points):
+        # DENSITY SLICE
+        time = normalize_datetime(timei)
+
+        query_points_npy = np.stack(np.meshgrid(
+            np.linspace(-100, 100, resolution, dtype=np.float32),
+            np.linspace(-100, 100, resolution, dtype=np.float32),
+            np.linspace(-100, 100, resolution, dtype=np.float32),
+            np.ones((1,), dtype=np.float32) * time, indexing='ij'), -1)
+
+        radius = np.sqrt(np.sum(query_points_npy[:, :, :, 0, :3] ** 2, axis=-1))
+        mask = (radius < 21) | (radius > 100)
+
+        r2_mask = radius ** 2  # ((radius - 20) / (60 - 20)) ** 2
+        # r2_mask = np.clip(r2_mask, 0, 1)
+
+        query_points = torch.from_numpy(query_points_npy)
+
+        # Prepare points --> encoding.
+        query_points = query_points.view(-1, 4)
+
+        print('load cube')
+        density, velocity = [], []
+        for i in range(np.ceil(query_points.shape[0] / batch_size).astype(int)):
+            batch = loader.encoding_fn(query_points[i * batch_size:(i + 1) * batch_size])
+            raw = loader.fine_model(batch.to(device))
+            density += [raw[..., 0].cpu().detach()]
+            velocity += [raw[..., 1:].cpu().detach()]
+
+        # stack results
+        density = torch.cat(density, dim=0)
+        velocity = torch.cat(velocity, dim=0)
+        # reshape
+        density = density.view(query_points_npy.shape[:3]).cpu().detach().numpy()
+        velocity = velocity.view(query_points_npy.shape[:3] + velocity.shape[-1:]).cpu().detach().numpy()
+
+        # scale density
+        density *= r2_mask
+
+        velocity = velocity  # * density[..., None] / 1e27  # scale to mass flux
+        # apply mask
+        density[mask] = 0  # np.nan
+        velocity[mask] = 0  # np.nan
+
+        print('save vtk')
+        vtk_filename = os.path.join(save_path, f"data_cube_{timei.isoformat('T', timespec='minutes')}.vtk")
+        save_vtk(velocity, vtk_filename, "v", density, "density", sr_per_pix=200 / resolution)
diff --git a/sunerf/evaluation/loader.py b/sunerf/evaluation/loader.py
index 4edda4f..3e4c871 100644
--- a/sunerf/evaluation/loader.py
+++ b/sunerf/evaluation/loader.py
@@ -30,9 +30,13 @@ def __init__(self, state_path, resolution=None, focal=None, device=None):
         encoder = PositionalEncoder(**state['encoder_kwargs'])
         self.encoding_fn = lambda x: encoder(x)
 
-        self.coarse_model = nn.DataParallel(state['coarse_model']).to(device)
-        self.fine_model = nn.DataParallel(state['fine_model']).to(device)
-
+        print(device)
+        if device is not None and torch.device(device).type == "cuda":  # compare device.type so "cuda:0" etc. also match
+            self.coarse_model = nn.DataParallel(state['coarse_model']).to(device)
+            self.fine_model = nn.DataParallel(state['fine_model']).to(device)
+        else:
+            self.coarse_model = state['coarse_model']
+            self.fine_model = state['fine_model']
         self.device = device
 
     def load_observer_image(self, lat: float, lon: float, time: datetime,
diff --git a/sunerf/sunerf.py b/sunerf/sunerf.py
index 072efbb..a77fdcd 100644
--- a/sunerf/sunerf.py
+++ b/sunerf/sunerf.py
@@ -147,7 +147,7 @@ def training_step(self, batch, batch_nb):
         loss = fine_loss + coarse_loss + self.lambda_continuity * continuity_loss + self.lambda_radial_regularization * radial_regularization_loss + self.lambda_velocity_regularization * velocity_regularization_loss
         formatted_loss_logstring = "="*25 + "\n Regularization and continuity" + "\n \t Continuity Loss: {}".format(continuity_loss) + "\n \t Radial Regularization Loss: {}".format(radial_regularization_loss) + "\n \t Velocity Regularization Loss: {}".format(velocity_regularization_loss) + "\n Model Losses" + "\n \t Fine Model Loss: {}".format(fine_loss) + "\n \t Coarse Model Loss: {} \n \n \t Complete Loss: {} \n".format(coarse_loss, loss) + "="*25
-        print(formatted_loss_logstring)
+        # print(formatted_loss_logstring)
 
         with torch.no_grad():
             psnr = -10. * torch.log10(fine_loss)
diff --git a/sunerf/train/volume_render.py b/sunerf/train/volume_render.py
index 2c654fc..7f2c5a2 100644
--- a/sunerf/train/volume_render.py
+++ b/sunerf/train/volume_render.py
@@ -58,7 +58,7 @@ def raw2outputs(raw: torch.Tensor,  # (batch, sampling_points, density_e)
     s_q = query_points.pow(2).sum(-1).pow(0.5)
     s_t = 1
     omega = torch.asin(s_t / s_q)
-
+    #print("Max S_q: {} - Omega Minimum: {} - Omega = 0? {}".format(torch.max(s_q), torch.min(omega), (omega == 0).any()))
     # z = distance Q to observer
     z = z_vals * torch.norm(rays_d[..., None, :], dim=-1)  # distance between observer and scattering point Q
@@ -85,11 +85,14 @@ def raw2outputs(raw: torch.Tensor,  # (batch, sampling_points, density_e)
     D = (1 / 8) * (5 + torch.sin(omega) ** 2 - cos2_sin * (5 - torch.sin(omega) ** 2) * ln)
 
     # equations 23, 24, 29
-    intensity_T = I0 * torch.pi * sigma_e / (2 * z ** 2) * ((1 - u) * C + u * D)
-    intensity_pB = I0 * torch.pi * sigma_e / (2 * z ** 2) * torch.sin(chi) ** 2 * ((1 - u) * A + u * B)
-
-    intensity_tB = 2 * intensity_T - intensity_pB
-
+    intensity_T = I0 * torch.pi * sigma_e / (2 * z ** 2) * ((1 - u) * C + u * D)  # I_T in paper - transverse
+    intensity_pB = I0 * torch.pi * sigma_e / (2 * z ** 2) * torch.sin(chi) ** 2 * ((1 - u) * A + u * B)  # I_p in paper
+
+    intensity_tB = 2 * intensity_T - intensity_pB  # I_tot in paper
+    # Intensities being negative is unphysical
+    intensity_T = torch.abs(intensity_T)
+    intensity_pB = torch.abs(intensity_pB)
+    intensity_tB = torch.abs(intensity_tB)
     if torch.isnan(intensity_tB).any() or torch.isnan(intensity_pB).any():
         cond = torch.isnan(intensity_tB) | torch.isnan(intensity_pB)
         print(f'Invalid values in intensity_tB or intensity_pB: query points {query_points[cond]}')
@@ -121,7 +124,8 @@ def raw2outputs(raw: torch.Tensor,  # (batch, sampling_points, density_e)
     # sum all intensity contributions along LOS
     pixel_tB = emerging_tB.sum(1)[:, None]
     pixel_pB = emerging_pB.sum(1)[:, None]
-
+    #print("pixel tB smaller than 0? - {} - Value: {}".format((pixel_tB < 0).any(), (pixel_tB < 0).nonzero()))
+    #print("Intensity tB smaller than 0? - {} - Value: {}".format((intensity_tB < 0).any(), (intensity_tB < 0).nonzero()))
     # height and density maps
     # electron_density: (batch, sampling_points, 1), s_q: (batch, sampling_points, 1)
     pixel_density = (electron_density * dists).sum(1)
@@ -132,7 +136,6 @@ def raw2outputs(raw: torch.Tensor,  # (batch, sampling_points, density_e)
     pixel_tB = (torch.log(pixel_tB) - v_min) / (v_max - v_min)  # normalization
     pixel_pB = (torch.log(pixel_pB) - v_min) / (v_max - v_min)  # normalization
     pixel_B = torch.cat([pixel_tB, pixel_pB], dim=-1)
-
    # set the weights to the intensity contributions (sample primary contributing regions)
    # need weights for sampling for fine model
    weights = electron_density / (electron_density.sum(1)[:, None] + 1e-10)
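For reference, the annotations added above follow the Thomson-scattering formulation the code already cites (its equations 23, 24, 29): intensity_pB is the polarized brightness I_p, intensity_T the transverse component I_T, and the total brightness is recovered as

    I_tot = 2 * I_T - I_p

The torch.abs calls only guard against numerically negative intermediate values; wherever the intensities are already positive, they leave this relation unchanged.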