Skip to content

Commit 062e35e

Browse files
committed
coreglib: clean up at/ct correction code
1 parent c5a8a57 commit 062e35e

File tree

1 file changed

+76
-70
lines changed

1 file changed

+76
-70
lines changed

demcoreg/coreglib.py

Lines changed: 76 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import matplotlib.pyplot as plt
1111

1212
from pygeotools.lib import malib, iolib, warplib
13-
from imview import pltlib
1413

1514
def apply_xy_shift(ds, dx, dy, createcopy=True):
1615
"""
@@ -486,21 +485,23 @@ def find_subpixel_peak_position(corr, subpixel_method='gaussian'):
486485
return subp_peak_position[0], subp_peak_position[1]
487486

488487
## functions for along-track cross-track correction and plotting
489-
def successive_med(a, first_axis=1,first_axis_only=False,sav_filter=False,sg_window=101,sg_poly=2,min_axes_count=350):
488+
489+
#TO DO: add number of iterations as argument
490+
def successive_med(a, first_axis=1, first_axis_only=False, sav_filter=False, sg_window=101, sg_poly=2, min_axes_count=350):
490491
"""
491492
Subtract median values from each axis of the input difference map array
492493
Parameters
493494
-----------
494495
a: np.ma.array
495496
input array
496497
first_axis: int
497-
1 or 2 (which axis) to operate on first
498+
0 or 1 (which axis) to operate on first
498499
first_axis_only: bool
499500
wether to apply the correction over the first axis only (True) or both the axes (False)
500501
sav_filter: bool
501-
whether to smooth the axis-medians or not using savgol filter
502+
whether to smooth the median values along each axis using savgol filter
502503
sg_window: int
503-
odd window sizes to be used in savgol filtering
504+
odd window size to be used in savgol filtering
504505
sg_poly: int
505506
polynomial order to be used in savgol filtering
506507
min_axes_count: int
@@ -510,39 +511,43 @@ def successive_med(a, first_axis=1,first_axis_only=False,sav_filter=False,sg_win
510511
b: np.ma.array
511512
Corrected array
512513
med_first: np.array
513-
first_axis corrections
514+
first_axis 1D corrections
514515
med_second: np.array
515-
second_axis corrections
516+
second_axis 1D corrections
516517
first_correction_surface: np.array
517-
correction surface along input first axis
518-
second_correction_surface:np.array
519-
correction surface along input second axis
518+
2D correction surface using 1D corrections along first axis
519+
second_correction_surface: np.array
520+
2D correction surface using 1D corrections along second axis
520521
med_first_smooth: np.array
521-
smoothend first_axis corrections (optional, if sav_filter=True)
522+
smoothed first_axis corrections (optional, if sav_filter=True)
522523
med_second_smooth: np.array
523-
smoothened second_axis corrections (optional, if sav_filter=True)
524+
smoothed second_axis corrections (optional, if sav_filter=True)
524525
525526
"""
526527
import scipy.signal
527528
### Parts of the function was first written by David for the Arctic DEM snow drift correction
529+
# Determine specified axis order
528530
second_axis = 0
529531
if first_axis == 0:
530532
second_axis = 1
531-
# compute axis-wide median and count metrics
533+
534+
# compute median and count metrics along first axis
532535
med_first = np.ma.median(a, axis=first_axis)
533-
count_first = np.ma.count(a,axis=first_axis)
534-
535-
# if an axis has lower than the min_axes_count the number of pixels, the median offset for that axis is set to 0, assuming it to be unreliable
536-
idx_first = count_first<min_axes_count
536+
count_first = np.ma.count(a, axis=first_axis)
537+
538+
# Each row/col must have a minimum number of samples for reliable statistics
539+
# DES note TODO: this should be set to masked, not 0
540+
idx_first = count_first < min_axes_count
537541
med_first[idx_first] = 0
538542

539-
# apply savgol filter to smoothen the signal
543+
# apply savgol filter to smooth the 1D correction
540544
if sav_filter:
541-
med_first_smooth = scipy.signal.savgol_filter(med_first,window_length=sg_window, polyorder=sg_poly, mode='nearest')
542-
first_correction_surface = np.expand_dims(med_first_smooth,axis=first_axis)
545+
med_first_smooth = scipy.signal.savgol_filter(med_first, window_length=sg_window, polyorder=sg_poly, mode='nearest')
546+
first_correction_surface = np.expand_dims(med_first_smooth, axis=first_axis)
543547
else:
544-
first_correction_surface = np.expand_dims(med_first,axis=first_axis)
545-
# correct the first axis with the smoothed curve
548+
first_correction_surface = np.expand_dims(med_first, axis=first_axis)
549+
550+
# correct the array using the corrections along the first axis
546551
b = a - first_correction_surface
547552

548553
# if correction is only to be performed over first axis, return a zero magnitude signal
@@ -553,28 +558,32 @@ def successive_med(a, first_axis=1,first_axis_only=False,sav_filter=False,sg_win
553558
shp_axis = 1
554559
med_second = np.zeros(a.shape[shp_axis])
555560

556-
# if correction is to be performed both axes
561+
# Compute a correction along the second axis
557562
else:
558-
# compute axis-wide median and count metrics
563+
#DES note why b and a here? Is this because b is filled with 0 above?
559564
med_second = np.ma.median(b, axis=second_axis)
560-
count_second = np.ma.count(a,axis=second_axis)
561-
# if an axis has lower than the min_axes_count the number of pixels, the median offset for that axis is set to 0, assuming it to be unreliable
562-
idx_second = count_second<min_axes_count
565+
count_second = np.ma.count(a, axis=second_axis)
566+
567+
# Each row/col must have a minimum number of samples for reliable statistics
568+
idx_second = count_second < min_axes_count
563569
med_second[idx_second] = 0
564570

565-
# apply savgol filter to smoothen the signal
571+
# apply savgol filter to smooth the signal
566572
if sav_filter:
567-
med_second_smooth = scipy.signal.savgol_filter(med_second,window_length=sg_window, polyorder=sg_poly, mode='nearest')
568-
second_correction_surface = np.expand_dims(med_second_smooth,axis=second_axis)
573+
med_second_smooth = scipy.signal.savgol_filter(med_second, window_length=sg_window, polyorder=sg_poly, mode='nearest')
574+
second_correction_surface = np.expand_dims(med_second_smooth, axis=second_axis)
569575
else:
570-
second_correction_surface = np.expand_dims(med_second_smooth,axis=second_axis)
576+
second_correction_surface = np.expand_dims(med_second_smooth, axis=second_axis)
577+
578+
# correct the array along the second axis
571579
b = b - second_correction_surface
572-
out = [b,med_first,med_second,first_correction_surface,second_correction_surface]
580+
581+
out = [b, med_first, med_second, first_correction_surface, second_correction_surface]
573582
if sav_filter:
574-
out.extend([med_first_smooth,med_second_smooth])
583+
out.extend([med_first_smooth, med_second_smooth])
575584
return out
576585

577-
def plot_ct_at_dh_map(ax,dh_init,clim_dh,ct_correction_surface,at_correction_surface,dh_final):
586+
def plot_ct_at_dh_map(ax, dh_init, clim_dh, ct_correction_surface, at_correction_surface, dh_final):
578587
"""
579588
Plot initial elevation difference map, Across-track (Row-wise) and Along-track (Column-wise) correction surface and corrected elevation difference map
580589
Parameters
@@ -592,16 +601,17 @@ def plot_ct_at_dh_map(ax,dh_init,clim_dh,ct_correction_surface,at_correction_sur
592601
dh_final: np.ma.array
593602
final dh map (after correction applied)
594603
"""
595-
pltlib.iv(dh_init,cmap='RdBu',clim=clim_dh,label='Elevation difference (m)',title='dh before',ax=ax[0])
596-
pltlib.add_scalebar(ax=ax[0],res=1)
604+
from imview import pltlib
605+
pltlib.iv(dh_init, cmap='RdBu', clim=clim_dh, label='Elevation difference (m)', title='dh before', ax=ax[0])
606+
pltlib.add_scalebar(ax=ax[0], res=1)
597607
#across_track_clim = malib.calcperc_sym(ct_correction_surface,(2,98))
598-
pltlib.iv(np.zeros(dh_init.shape)+ct_correction_surface,cmap='RdBu',clim=clim_dh,label='Elevation difference (m)',title='Row-wise correction surface',ax=ax[1])
608+
pltlib.iv(np.zeros(dh_init.shape)+ct_correction_surface, cmap='RdBu', clim=clim_dh, label='Elevation difference (m)', title='Row-wise correction surface', ax=ax[1])
599609
#along_track_clim = malib.calcperc_sym(at_correction_surface,(2,98))
600-
pltlib.iv(np.zeros(dh_init.shape)+at_correction_surface,cmap='RdBu',clim=clim_dh,label='Elevation difference (m)',title='Column-wise correction surface',ax=ax[2])
601-
pltlib.iv(dh_final,cmap='RdBu',clim=clim_dh,label='Elevation difference (m)',title='dh after',ax=ax[3])
610+
pltlib.iv(np.zeros(dh_init.shape)+at_correction_surface, cmap='RdBu', clim=clim_dh, label='Elevation difference (m)', title='Column-wise correction surface', ax=ax[2])
611+
pltlib.iv(dh_final, cmap='RdBu', clim=clim_dh, label='Elevation difference (m)', title='dh after', ax=ax[3])
602612
plt.tight_layout()
603613

604-
def plot_ct_at_dh_fits(f,ct_med,ct_smooth,at_med,at_smooth,clim_dh=None):
614+
def plot_ct_at_dh_fits(f, ct_med, ct_smooth, at_med, at_smooth, clim_dh=None):
605615
"""
606616
Plot Across-track (Row-wise) and Along-track (Column-wise) correction fits
607617
Parameters
@@ -613,26 +623,26 @@ def plot_ct_at_dh_fits(f,ct_med,ct_smooth,at_med,at_smooth,clim_dh=None):
613623
ct_smooth: np.array
614624
Smooth fit to median error per-row computed using Sav-Golay fit
615625
at_med: np.array
616-
median error per column (1,dh_map.shape[0])
626+
median error per column (1, dh_map.shape[0])
617627
at_smooth:np.array
618628
Smooth fit to median error per-column computed using Sav-Golay fit
619629
clim_dh: tuple
620630
symmetrical min/max values to limit correction fits
621631
622632
"""
623633
ax1 = plt.subplot(1,2,1)
624-
ax1.plot(ct_med,np.arange(len(ct_med)),c='k',label='median correction')
625-
ax1.plot(ct_smooth,np.arange(len(ct_med)),c='r',label='smooth Sav-Golay fit')
634+
ax1.plot(ct_med, np.arange(len(ct_med)), c='k', label='median correction')
635+
ax1.plot(ct_smooth, np.arange(len(ct_med)), c='r', label='smooth Sav-Golay fit')
626636
ax1.set_title('Row-wise correction')
627637
ax1.set_xlabel('Elevation difference (m)')
628-
ax1.axvline(x=0,ls='--',alpha=0.6,c='teal')
638+
ax1.axvline(x=0, ls='--', alpha=0.6, c='teal')
629639
ax1.set_ylabel('Row number')
630640
ax1.legend()
631641

632642
ax2 = plt.subplot(1,2,2)
633-
ax2.plot(np.arange(len(at_med)),at_med,c='k',label='median correction')
634-
ax2.plot(np.arange(len(at_med)),at_smooth,c='r',label='smooth Sav-Golay fit')
635-
ax2.axhline(y=0,ls='--',alpha=0.6,c='teal')
643+
ax2.plot(np.arange(len(at_med)), at_med, c='k', label='median correction')
644+
ax2.plot(np.arange(len(at_med)), at_smooth, c='r', label='smooth Sav-Golay fit')
645+
ax2.axhline(y=0, ls='--', alpha=0.6, c='teal')
636646
ax2.set_title('Column-wise correction')
637647
ax2.set_ylabel('Elevation difference (m)')
638648
ax2.set_xlabel('Col number')
@@ -647,7 +657,8 @@ def plot_ct_at_dh_fits(f,ct_med,ct_smooth,at_med,at_smooth,clim_dh=None):
647657
ax2.set_ylim(clim_dh)
648658
plt.tight_layout()
649659

650-
def ct_at_correction_wrapper(src_dem_fn,dh_fn,dh_filt_fn,ct_only=False,sg_window=101,sg_poly=2,min_axes_count=350,outdir=None):
660+
#DES TO DO: add number of iterations as argument
661+
def ct_at_correction_wrapper(src_dem_fn, dh_fn, dh_filt_fn, ct_only=False, sg_window=101, sg_poly=2, min_axes_count=350, outdir=None):
651662
"""
652663
Wrapper function to apply across-track (row-wise) and along-track correction to difference maps and src DEM
653664
Parameters
@@ -672,54 +683,49 @@ def ct_at_correction_wrapper(src_dem_fn,dh_fn,dh_filt_fn,ct_only=False,sg_window
672683
Write corrected source DEM, difference maps and plots to disc (not used currently)
673684
"""
674685
# warp the difference maps to extent and resolution of source DEMs
675-
print("Step1: Warping difference maps to extent of source DEM")
676-
ds_list = warplib.memwarp_multi_fn([src_dem_fn,dh_fn,dh_filt_fn],res='first',extent='first',r='cubic')
686+
print("Warping difference maps to extent of source DEM")
687+
ds_list = warplib.memwarp_multi_fn([src_dem_fn, dh_fn, dh_filt_fn], res='first', extent='first', r='cubic')
677688

678689
# read into memory
679-
src_dem,dh,dh_filt = [iolib.ds_getma(ds) for ds in ds_list]
690+
src_dem, dh, dh_filt = [iolib.ds_getma(ds) for ds in ds_list]
680691

681692
# calculate clim
682-
clim_dh = malib.calcperc_sym(dh_filt,(5,95))
693+
clim_dh = malib.calcperc_sym(dh_filt, (5,95))
683694
# perform the correction
684695
print("Computing Across-track (Row-wise) and Along-track (Column-wise) correction")
685-
dh_filt_corr,ct_med,at_med,ct_correction_surface,at_correction_surface, ct_med_smooth, at_med_smooth = successive_med(dh_filt,
686-
first_axis_only=ct_only,sav_filter=True,
687-
sg_window=sg_window,sg_poly=sg_poly,
688-
min_axes_count=min_axes_count)
696+
dh_filt_corr, ct_med, at_med, ct_correction_surface, at_correction_surface, ct_med_smooth, at_med_smooth = \
697+
successive_med(dh_filt, first_axis_only=ct_only, sav_filter=True, sg_window=sg_window, sg_poly=sg_poly, min_axes_count=min_axes_count)
698+
689699
# prepare the plots
690700
out_dh_fig = os.path.splitext(dh_fn)[0] + '_ct_at_dh_map.png'
691701
print(f"Creating Across-track (Row-wise) and Along-track (Column-wise) difference map figure at {out_dh_fig}")
692-
f,ax = plt.subplots(1,4,figsize=(12,5))
702+
f, ax = plt.subplots(1, 4, figsize=(12,5))
693703

694-
plot_ct_at_dh_map(ax,dh_filt,clim_dh,ct_correction_surface,at_correction_surface,dh_filt_corr)
695-
f.savefig(out_dh_fig, dpi=300,bbox_inches='tight', pad_inches=0.1)
704+
plot_ct_at_dh_map(ax, dh_filt, clim_dh, ct_correction_surface, at_correction_surface, dh_filt_corr)
705+
f.savefig(out_dh_fig, dpi=300, bbox_inches='tight', pad_inches=0.1)
696706

697707
out_lineplot_fig = os.path.splitext(dh_fn)[0] + '_ct_at_correction_fit.png'
698708
print(f"Creating Across-track (Row-wise) and Along-track (Column-wise) correction fits figure at {out_lineplot_fig}")
699709
fig = plt.figure(figsize=(8,4))
700-
plot_ct_at_dh_fits(fig,ct_med,ct_med_smooth,at_med,at_med_smooth,clim_dh=clim_dh)
701-
fig.savefig(out_lineplot_fig,dpi=300,bbox_inches='tight', pad_inches=0.1)
710+
plot_ct_at_dh_fits(fig, ct_med, ct_med_smooth, at_med, at_med_smooth, clim_dh=clim_dh)
711+
fig.savefig(out_lineplot_fig, dpi=300, bbox_inches='tight', pad_inches=0.1)
702712

703713
# Correct source DEM
704-
src_dem_corrected = src_dem - ct_correction_surface
705-
src_dem_corrected = src_dem_corrected - at_correction_surface
714+
src_dem_corrected = src_dem - ct_correction_surface - at_correction_surface
706715
src_dem_corrected_fn = os.path.splitext(src_dem_fn)[0]+'_ct_at_corrected.tif'
707716
print(f"Writing out corrected source DEM at {src_dem_corrected_fn}")
708-
iolib.writeGTiff(src_dem_corrected,src_dem_corrected_fn,src_ds=ds_list[0])
717+
iolib.writeGTiff(src_dem_corrected, src_dem_corrected_fn, src_ds=ds_list[0])
709718

710719
# correct difference map
711720
# this is the entire difference map from which the filtered map is derived (containing glacier etc)
712-
dh_corrected = dh - ct_correction_surface
713-
dh_corrected = dh_corrected - at_correction_surface
721+
dh_corrected = dh - ct_correction_surface - at_correction_surface
714722
dh_corrected_fn = os.path.splitext(dh_fn)[0]+'_ct_at_corrected.tif'
715723
print(f"Writing out corrected elevation difference map at {dh_corrected_fn}")
716-
iolib.writeGTiff(dh_corrected,dh_corrected_fn,src_ds=ds_list[1])
724+
iolib.writeGTiff(dh_corrected, dh_corrected_fn, src_ds=ds_list[1])
717725

718726
# write out filtered difference map
719727
dh_filt_corrected_fn = os.path.splitext(dh_filt_fn)[0]+'_ct_at_corrected.tif'
720728
print(f"Writing out corrected elevation difference map at {dh_filt_fn}")
721-
iolib.writeGTiff(dh_filt_corr,dh_filt_corrected_fn,src_ds=ds_list[2])
729+
iolib.writeGTiff(dh_filt_corr, dh_filt_corrected_fn, src_ds=ds_list[2])
722730

723731
print("Across-track (row-wise), Along-track (col-wise) correction complete")
724-
725-

0 commit comments

Comments
 (0)