Skip to content

Commit eefcdd8

Browse files
committed
Appeased black
1 parent f9e1b10 commit eefcdd8

File tree

1 file changed

+79
-45
lines changed

1 file changed

+79
-45
lines changed

notebooks/embedding_visualizer.ipynb

Lines changed: 79 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@
3636
"metadata": {},
3737
"outputs": [],
3838
"source": [
39-
"sunspots_df = pd.read_csv('SN_d_tot_V2.0.csv', delimiter=';', header=None)"
39+
"sunspots_df = pd.read_csv(\"SN_d_tot_V2.0.csv\", delimiter=\";\", header=None)"
4040
]
4141
},
4242
{
@@ -45,7 +45,16 @@
4545
"metadata": {},
4646
"outputs": [],
4747
"source": [
48-
"columns = ['year','month','day','date_fraction','sunspot_count','std_dev','observations','indicator']"
48+
"columns = [\n",
49+
" \"year\",\n",
50+
" \"month\",\n",
51+
" \"day\",\n",
52+
" \"date_fraction\",\n",
53+
" \"sunspot_count\",\n",
54+
" \"std_dev\",\n",
55+
" \"observations\",\n",
56+
" \"indicator\",\n",
57+
"]"
4958
]
5059
},
5160
{
@@ -54,7 +63,7 @@
5463
"metadata": {},
5564
"outputs": [],
5665
"source": [
57-
"sunspots_df.columns = columns "
66+
"sunspots_df.columns = columns"
5867
]
5968
},
6069
{
@@ -138,15 +147,15 @@
138147
}
139148
],
140149
"source": [
141-
"#model = MAE(\n",
150+
"# model = MAE(\n",
142151
"# **cfg.model.mae,\n",
143-
" # **cfg.model.samae,\n",
144-
" # hmi_mask=data_module.hmi_mask,\n",
152+
"# **cfg.model.samae,\n",
153+
"# hmi_mask=data_module.hmi_mask,\n",
145154
"# optimiser=cfg.model.opt.optimiser,\n",
146155
"# lr=cfg.model.opt.learning_rate,\n",
147156
"# weight_decay=cfg.model.opt.weight_decay,\n",
148-
" \n",
149-
"#)\n",
157+
"\n",
158+
"# )\n",
150159
"\n",
151160
"logger = WandbLogger(\n",
152161
" # WandbLogger params\n",
@@ -160,9 +169,8 @@
160169
" group=cfg.experiment.wandb.group,\n",
161170
" save_code=True,\n",
162171
" job_type=cfg.experiment.wandb.job_type,\n",
163-
"\n",
164172
")\n",
165-
"model = Pretrainer(cfg, logger=logger, is_backbone=True)\n"
173+
"model = Pretrainer(cfg, logger=logger, is_backbone=True)"
166174
]
167175
},
168176
{
@@ -181,7 +189,7 @@
181189
"metadata": {},
182190
"outputs": [],
183191
"source": [
184-
"#model.model.to(\"cuda\");\n",
192+
"# model.model.to(\"cuda\");\n",
185193
"model.model.eval();"
186194
]
187195
},
@@ -212,6 +220,7 @@
212220
"outputs": [],
213221
"source": [
214222
"import pandas as pd\n",
223+
"\n",
215224
"dates = []\n",
216225
"for i in range(train_dataset.__len__()):\n",
217226
"\n",
@@ -228,11 +237,11 @@
228237
"metadata": {},
229238
"outputs": [],
230239
"source": [
231-
"dates_df['year'] = pd.to_datetime(dates_df['date']).dt.year\n",
232-
"dates_df['month'] = pd.to_datetime(dates_df['date']).dt.month\n",
233-
"dates_df['day'] = pd.to_datetime(dates_df['date']).dt.day\n",
234-
"dates_df['hour'] = pd.to_datetime(dates_df['date']).dt.hour\n",
235-
"dates_df['time'] = pd.to_datetime(dates_df['date']).dt.time\n"
240+
"dates_df[\"year\"] = pd.to_datetime(dates_df[\"date\"]).dt.year\n",
241+
"dates_df[\"month\"] = pd.to_datetime(dates_df[\"date\"]).dt.month\n",
242+
"dates_df[\"day\"] = pd.to_datetime(dates_df[\"date\"]).dt.day\n",
243+
"dates_df[\"hour\"] = pd.to_datetime(dates_df[\"date\"]).dt.hour\n",
244+
"dates_df[\"time\"] = pd.to_datetime(dates_df[\"date\"]).dt.time"
236245
]
237246
},
238247
{
@@ -438,7 +447,11 @@
438447
"metadata": {},
439448
"outputs": [],
440449
"source": [
441-
"dates_with_spots_df = dates_df.merge(sunspots_df[['year','month','day','sunspot_count']], on=['year', 'month','day'], how='left')"
450+
"dates_with_spots_df = dates_df.merge(\n",
451+
" sunspots_df[[\"year\", \"month\", \"day\", \"sunspot_count\"]],\n",
452+
" on=[\"year\", \"month\", \"day\"],\n",
453+
" how=\"left\",\n",
454+
")"
442455
]
443456
},
444457
{
@@ -448,8 +461,13 @@
448461
"outputs": [],
449462
"source": [
450463
"import datetime\n",
451-
"dates_with_spots_noon_df = dates_with_spots_df[dates_with_spots_df['time'] == datetime.time(12, 0)]\n",
452-
"dates_with_spots_noon_2011_df = dates_with_spots_noon_df[dates_with_spots_noon_df['year'] == 2011]\n"
464+
"\n",
465+
"dates_with_spots_noon_df = dates_with_spots_df[\n",
466+
" dates_with_spots_df[\"time\"] == datetime.time(12, 0)\n",
467+
"]\n",
468+
"dates_with_spots_noon_2011_df = dates_with_spots_noon_df[\n",
469+
" dates_with_spots_noon_df[\"year\"] == 2011\n",
470+
"]"
453471
]
454472
},
455473
{
@@ -481,8 +499,11 @@
481499
"source": [
482500
"# plot the sunspot count vs date\n",
483501
"import matplotlib.pyplot as plt\n",
484-
"dates_with_spots_noon_2011_df = dates_with_spots_noon_df[dates_with_spots_noon_df['year'] == 2011]\n",
485-
"plt.plot(dates_with_spots_noon_df['date'], dates_with_spots_noon_df['sunspot_count'])\n"
502+
"\n",
503+
"dates_with_spots_noon_2011_df = dates_with_spots_noon_df[\n",
504+
" dates_with_spots_noon_df[\"year\"] == 2011\n",
505+
"]\n",
506+
"plt.plot(dates_with_spots_noon_df[\"date\"], dates_with_spots_noon_df[\"sunspot_count\"])"
486507
]
487508
},
488509
{
@@ -500,11 +521,17 @@
500521
}
501522
],
502523
"source": [
503-
"df_2011 = dates_df[dates_df['year'] == 2011]\n",
524+
"df_2011 = dates_df[dates_df[\"year\"] == 2011]\n",
504525
"# groupby month, select 100 random samples\n",
505-
"df_2011_subset = df_2011.groupby('month').apply(lambda x: x.sample(100, random_state=1)).reset_index(drop=True)\n",
526+
"df_2011_subset = (\n",
527+
" df_2011.groupby(\"month\")\n",
528+
" .apply(lambda x: x.sample(100, random_state=1))\n",
529+
" .reset_index(drop=True)\n",
530+
")\n",
506531
"quiet_months = [1, 2, 5, 6, 7, 8]\n",
507-
"df_2011_subset['is_active'] = df_2011_subset['month'].apply(lambda x: 0 if x in quiet_months else 1)"
532+
"df_2011_subset[\"is_active\"] = df_2011_subset[\"month\"].apply(\n",
533+
" lambda x: 0 if x in quiet_months else 1\n",
534+
")"
508535
]
509536
},
510537
{
@@ -525,15 +552,15 @@
525552
"mean_embeddings = []\n",
526553
"names = []\n",
527554
"\n",
528-
"for idx in tqdm(dates_with_spots_noon_2011_df['index'].values):\n",
555+
"for idx in tqdm(dates_with_spots_noon_2011_df[\"index\"].values):\n",
529556
" batch = train_dataset[idx]\n",
530557
" name = train_dataset.aligndata.iloc[idx].name\n",
531-
" batch = torch.tensor(batch).unsqueeze(0) \n",
532-
" #batch = batch.to(\"cuda\") \n",
533-
" x, mask, ids_restore = model.model.forward_encoder(batch, mask_ratio = 0)\n",
534-
" # cls_token \n",
535-
" cls_embedding = x[:,0,:].detach().cpu()\n",
536-
" mean_embedding = x[:,1:,:].mean(dim=1).detach().cpu()\n",
558+
" batch = torch.tensor(batch).unsqueeze(0)\n",
559+
" # batch = batch.to(\"cuda\")\n",
560+
" x, mask, ids_restore = model.model.forward_encoder(batch, mask_ratio=0)\n",
561+
" # cls_token\n",
562+
" cls_embedding = x[:, 0, :].detach().cpu()\n",
563+
" mean_embedding = x[:, 1:, :].mean(dim=1).detach().cpu()\n",
537564
" cls_embeddings.append(cls_embedding)\n",
538565
" mean_embeddings.append(mean_embedding)\n",
539566
" names.append(name)\n",
@@ -581,16 +608,14 @@
581608
"metadata": {},
582609
"outputs": [],
583610
"source": [
584-
"\n",
585611
"tsne = TSNE(n_components=2, random_state=0)\n",
586612
"\n",
587613
"cls_embeddings_np = cls_embeddings.numpy()\n",
588614
"cls_embeddings_tsne = tsne.fit_transform(cls_embeddings_np)\n",
589615
"\n",
590616
"\n",
591-
"\n",
592-
"dates_with_spots_noon_2011_df['cls_tsne_x'] = cls_embeddings_tsne[:,0]\n",
593-
"dates_with_spots_noon_2011_df['cls_tsne_y'] = cls_embeddings_tsne[:,1]\n",
617+
"dates_with_spots_noon_2011_df[\"cls_tsne_x\"] = cls_embeddings_tsne[:, 0]\n",
618+
"dates_with_spots_noon_2011_df[\"cls_tsne_y\"] = cls_embeddings_tsne[:, 1]\n",
594619
"\n",
595620
"\n",
596621
"tsne = TSNE(n_components=2, random_state=0)\n",
@@ -599,10 +624,8 @@
599624
"mean_embeddings_tsne = tsne.fit_transform(mean_embeddings_np)\n",
600625
"\n",
601626
"\n",
602-
"\n",
603-
"dates_with_spots_noon_2011_df['avg_tsne_x'] = mean_embeddings_tsne[:,0]\n",
604-
"dates_with_spots_noon_2011_df['avg_tsne_y'] = mean_embeddings_tsne[:,1]\n",
605-
"\n"
627+
"dates_with_spots_noon_2011_df[\"avg_tsne_x\"] = mean_embeddings_tsne[:, 0]\n",
628+
"dates_with_spots_noon_2011_df[\"avg_tsne_y\"] = mean_embeddings_tsne[:, 1]"
606629
]
607630
},
608631
{
@@ -611,10 +634,15 @@
611634
"metadata": {},
612635
"outputs": [],
613636
"source": [
637+
"fig = px.scatter(\n",
638+
" dates_with_spots_noon_2011_df,\n",
639+
" x=\"cls_tsne_x\",\n",
640+
" y=\"cls_tsne_y\",\n",
641+
" color=\"is_active\",\n",
642+
" hover_data=[\"month\", \"day\", \"hour\"],\n",
643+
")\n",
614644
"\n",
615-
"fig = px.scatter(dates_with_spots_noon_2011_df, x=\"cls_tsne_x\", y=\"cls_tsne_y\", color=\"is_active\", hover_data=[\"month\", \"day\", \"hour\"])\n",
616-
"\n",
617-
"fig.show()\n"
645+
"fig.show()"
618646
]
619647
},
620648
{
@@ -623,9 +651,15 @@
623651
"metadata": {},
624652
"outputs": [],
625653
"source": [
626-
"fig = px.scatter(dates_with_spots_noon_2011_df, x=\"avg_tsne_x\", y=\"avg_tsne_y\", color=\"is_active\", hover_data=[\"month\", \"day\", \"hour\"])\n",
654+
"fig = px.scatter(\n",
655+
" dates_with_spots_noon_2011_df,\n",
656+
" x=\"avg_tsne_x\",\n",
657+
" y=\"avg_tsne_y\",\n",
658+
" color=\"is_active\",\n",
659+
" hover_data=[\"month\", \"day\", \"hour\"],\n",
660+
")\n",
627661
"\n",
628-
"fig.show()\n"
662+
"fig.show()"
629663
]
630664
},
631665
{
@@ -635,7 +669,7 @@
635669
"outputs": [],
636670
"source": [
637671
"# save html\n",
638-
"fig.write_html(\"mean_pooling_tsne.html\")\n"
672+
"fig.write_html(\"mean_pooling_tsne.html\")"
639673
]
640674
},
641675
{

0 commit comments

Comments
 (0)