
Commit ae39b05

Refactor non-parameter kernel function and enhance VAE training script
- Renamed _non_para_kernel to _non_para_kernel_t4 for clarity
- Added device selection for GPU/CPU in the train_vae_nogcn function
- Improved folder creation logic with a warning for existing directories
1 parent 6638ac2 commit ae39b05

File tree

2 files changed: +7 additions, -4 deletions

dynamo/external/celldancer/utilities.py

Lines changed: 3 additions & 3 deletions
@@ -27,18 +27,18 @@ def __call__(self, *args, **kwargs):
         joblib.parallel.BatchCompletionCallBack = old_batch_callback
         tqdm_object.close()
 
-def _non_para_kernel(X,Y,down_sample_idx):
+def _non_para_kernel_t4(X,Y,down_sample_idx):
     # (no first cls),pseudotime r square calculation
     # this version has downsampling section
     # TO DO WHEN ONLY USING ONE GENE, WILL CAUSL PROBLEM WHEN COMBINING
     # Usage: Gene pseudotime fitting and r square (moved to utilities)
     # input: X,Y
     # return: estimator, r_square
-    # example:
+    # example:
     # X = pd.DataFrame(np.arange(100)*np.pi/100)
     # Y = pd.DataFrame(np.sin(X)+np.random.normal(loc = 0, scale = 0.5, size = (100,1)))
     # estimator,r_square=non_para_kernel(X,Y)
-
+
     # X2=pd.DataFrame(np.random.randint(0,100,size=[200,1]))
     # Y2=pd.DataFrame(np.random.normal(9,5,size=[200]))
     # X = pd.DataFrame(np.arange(100)*np.pi/100)
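For reference, the docstring example above can be run end to end. The sketch below assumes a statsmodels KernelReg local-linear fit as the estimator; the actual body of _non_para_kernel_t4 is not shown in this diff, so the estimator choice and the r_square formula are illustrative assumptions, not celldancer's implementation.

import numpy as np
import pandas as pd
from statsmodels.nonparametric.kernel_regression import KernelReg

# Noisy sine data, as in the docstring example.
X = pd.DataFrame(np.arange(100) * np.pi / 100)
Y = pd.DataFrame(np.sin(X.values) + np.random.normal(loc=0, scale=0.5, size=(100, 1)))

# Assumed estimator: nonparametric kernel regression with one
# continuous regressor ('c'). This stands in for whatever
# _non_para_kernel_t4 builds internally.
estimator = KernelReg(endog=Y.values.ravel(), exog=X.values.ravel(), var_type='c')
y_pred, _ = estimator.fit(X.values.ravel())

# r_square as 1 - SS_res / SS_tot over the fitted values.
y = Y.values.ravel()
r_square = 1 - np.sum((y - y_pred) ** 2) / np.sum((y - y.mean()) ** 2)
print(f'r_square = {r_square:.3f}')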

dynamo/external/latentvelo/trainer_nogcn.py

Lines changed: 4 additions & 1 deletion
@@ -11,12 +11,15 @@
 
 def train_vae_nogcn(model, adata, epochs = 50, learning_rate = 1e-2, batch_size = 200, grad_clip = 1, shuffle=True, test=0.1, name = '', optimizer='adam', random_seed=42):
 
+    # Set device (GPU if available, otherwise CPU)
+    device = th.device('cuda' if th.cuda.is_available() else 'cpu')
+
     results_folder = './' + name + '/'
     if not os.path.exists(results_folder):
         os.mkdir(results_folder)
     else:
         print('Warning, folder already exists. This may overwrite a previous fit.')
-
+
     if optimizer == 'adam':
         optimizer = th.optim.Adam(model.parameters(), lr = learning_rate)
     elif optimizer == 'adamW':
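
The two additions here follow a standard PyTorch setup pattern. A minimal, self-contained sketch of the same idea is below; setup_training is a hypothetical helper for illustration, not part of latentvelo's API.

import os
import torch as th

def setup_training(model, name='', learning_rate=1e-2):
    # Pick the GPU when one is available, otherwise fall back to CPU.
    device = th.device('cuda' if th.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Create the results folder, warning rather than failing when it
    # already exists, since a rerun with the same name may overwrite
    # a previous fit.
    results_folder = './' + name + '/'
    if not os.path.exists(results_folder):
        os.mkdir(results_folder)
    else:
        print('Warning, folder already exists. This may overwrite a previous fit.')

    optimizer = th.optim.Adam(model.parameters(), lr=learning_rate)
    return model, optimizer, device

An alternative for the folder step is os.makedirs(results_folder, exist_ok=True), which also creates missing parent directories, at the cost of losing the overwrite warning.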
