|
36 | 36 | "metadata": {},
|
37 | 37 | "outputs": [],
|
38 | 38 | "source": [
|
39 |
| - "sunspots_df = pd.read_csv('SN_d_tot_V2.0.csv', delimiter=';', header=None)" |
| 39 | + "sunspots_df = pd.read_csv(\"SN_d_tot_V2.0.csv\", delimiter=\";\", header=None)" |
40 | 40 | ]
|
41 | 41 | },
|
42 | 42 | {
|
|
45 | 45 | "metadata": {},
|
46 | 46 | "outputs": [],
|
47 | 47 | "source": [
|
48 |
| - "columns = ['year','month','day','date_fraction','sunspot_count','std_dev','observations','indicator']" |
| 48 | + "columns = [\n", |
| 49 | + " \"year\",\n", |
| 50 | + " \"month\",\n", |
| 51 | + " \"day\",\n", |
| 52 | + " \"date_fraction\",\n", |
| 53 | + " \"sunspot_count\",\n", |
| 54 | + " \"std_dev\",\n", |
| 55 | + " \"observations\",\n", |
| 56 | + " \"indicator\",\n", |
| 57 | + "]" |
49 | 58 | ]
|
50 | 59 | },
|
51 | 60 | {
|
|
54 | 63 | "metadata": {},
|
55 | 64 | "outputs": [],
|
56 | 65 | "source": [
|
57 |
| - "sunspots_df.columns = columns " |
| 66 | + "sunspots_df.columns = columns" |
58 | 67 | ]
|
59 | 68 | },
|
60 | 69 | {
|
|
138 | 147 | }
|
139 | 148 | ],
|
140 | 149 | "source": [
|
141 |
| - "#model = MAE(\n", |
| 150 | + "# model = MAE(\n", |
142 | 151 | "# **cfg.model.mae,\n",
|
143 |
| - " # **cfg.model.samae,\n", |
144 |
| - " # hmi_mask=data_module.hmi_mask,\n", |
| 152 | + "# **cfg.model.samae,\n", |
| 153 | + "# hmi_mask=data_module.hmi_mask,\n", |
145 | 154 | "# optimiser=cfg.model.opt.optimiser,\n",
|
146 | 155 | "# lr=cfg.model.opt.learning_rate,\n",
|
147 | 156 | "# weight_decay=cfg.model.opt.weight_decay,\n",
|
148 |
| - " \n", |
149 |
| - "#)\n", |
| 157 | + "\n", |
| 158 | + "# )\n", |
150 | 159 | "\n",
|
151 | 160 | "logger = WandbLogger(\n",
|
152 | 161 | " # WandbLogger params\n",
|
|
160 | 169 | " group=cfg.experiment.wandb.group,\n",
|
161 | 170 | " save_code=True,\n",
|
162 | 171 | " job_type=cfg.experiment.wandb.job_type,\n",
|
163 |
| - "\n", |
164 | 172 | ")\n",
|
165 |
| - "model = Pretrainer(cfg, logger=logger, is_backbone=True)\n" |
| 173 | + "model = Pretrainer(cfg, logger=logger, is_backbone=True)" |
166 | 174 | ]
|
167 | 175 | },
|
168 | 176 | {
|
|
181 | 189 | "metadata": {},
|
182 | 190 | "outputs": [],
|
183 | 191 | "source": [
|
184 |
| - "#model.model.to(\"cuda\");\n", |
| 192 | + "# model.model.to(\"cuda\");\n", |
185 | 193 | "model.model.eval();"
|
186 | 194 | ]
|
187 | 195 | },
|
|
212 | 220 | "outputs": [],
|
213 | 221 | "source": [
|
214 | 222 | "import pandas as pd\n",
|
| 223 | + "\n", |
215 | 224 | "dates = []\n",
|
216 | 225 | "for i in range(train_dataset.__len__()):\n",
|
217 | 226 | "\n",
|
|
228 | 237 | "metadata": {},
|
229 | 238 | "outputs": [],
|
230 | 239 | "source": [
|
231 |
| - "dates_df['year'] = pd.to_datetime(dates_df['date']).dt.year\n", |
232 |
| - "dates_df['month'] = pd.to_datetime(dates_df['date']).dt.month\n", |
233 |
| - "dates_df['day'] = pd.to_datetime(dates_df['date']).dt.day\n", |
234 |
| - "dates_df['hour'] = pd.to_datetime(dates_df['date']).dt.hour\n", |
235 |
| - "dates_df['time'] = pd.to_datetime(dates_df['date']).dt.time\n" |
| 240 | + "dates_df[\"year\"] = pd.to_datetime(dates_df[\"date\"]).dt.year\n", |
| 241 | + "dates_df[\"month\"] = pd.to_datetime(dates_df[\"date\"]).dt.month\n", |
| 242 | + "dates_df[\"day\"] = pd.to_datetime(dates_df[\"date\"]).dt.day\n", |
| 243 | + "dates_df[\"hour\"] = pd.to_datetime(dates_df[\"date\"]).dt.hour\n", |
| 244 | + "dates_df[\"time\"] = pd.to_datetime(dates_df[\"date\"]).dt.time" |
236 | 245 | ]
|
237 | 246 | },
|
238 | 247 | {
|
|
438 | 447 | "metadata": {},
|
439 | 448 | "outputs": [],
|
440 | 449 | "source": [
|
441 |
| - "dates_with_spots_df = dates_df.merge(sunspots_df[['year','month','day','sunspot_count']], on=['year', 'month','day'], how='left')" |
| 450 | + "dates_with_spots_df = dates_df.merge(\n", |
| 451 | + " sunspots_df[[\"year\", \"month\", \"day\", \"sunspot_count\"]],\n", |
| 452 | + " on=[\"year\", \"month\", \"day\"],\n", |
| 453 | + " how=\"left\",\n", |
| 454 | + ")" |
442 | 455 | ]
|
443 | 456 | },
|
444 | 457 | {
|
|
448 | 461 | "outputs": [],
|
449 | 462 | "source": [
|
450 | 463 | "import datetime\n",
|
451 |
| - "dates_with_spots_noon_df = dates_with_spots_df[dates_with_spots_df['time'] == datetime.time(12, 0)]\n", |
452 |
| - "dates_with_spots_noon_2011_df = dates_with_spots_noon_df[dates_with_spots_noon_df['year'] == 2011]\n" |
| 464 | + "\n", |
| 465 | + "dates_with_spots_noon_df = dates_with_spots_df[\n", |
| 466 | + " dates_with_spots_df[\"time\"] == datetime.time(12, 0)\n", |
| 467 | + "]\n", |
| 468 | + "dates_with_spots_noon_2011_df = dates_with_spots_noon_df[\n", |
| 469 | + " dates_with_spots_noon_df[\"year\"] == 2011\n", |
| 470 | + "]" |
453 | 471 | ]
|
454 | 472 | },
|
455 | 473 | {
|
|
481 | 499 | "source": [
|
482 | 500 | "# plot the sunspot count vs date\n",
|
483 | 501 | "import matplotlib.pyplot as plt\n",
|
484 |
| - "dates_with_spots_noon_2011_df = dates_with_spots_noon_df[dates_with_spots_noon_df['year'] == 2011]\n", |
485 |
| - "plt.plot(dates_with_spots_noon_df['date'], dates_with_spots_noon_df['sunspot_count'])\n" |
| 502 | + "\n", |
| 503 | + "dates_with_spots_noon_2011_df = dates_with_spots_noon_df[\n", |
| 504 | + " dates_with_spots_noon_df[\"year\"] == 2011\n", |
| 505 | + "]\n", |
| 506 | + "plt.plot(dates_with_spots_noon_df[\"date\"], dates_with_spots_noon_df[\"sunspot_count\"])" |
486 | 507 | ]
|
487 | 508 | },
|
488 | 509 | {
|
|
500 | 521 | }
|
501 | 522 | ],
|
502 | 523 | "source": [
|
503 |
| - "df_2011 = dates_df[dates_df['year'] == 2011]\n", |
| 524 | + "df_2011 = dates_df[dates_df[\"year\"] == 2011]\n", |
504 | 525 | "# groupby month, select 100 random samples\n",
|
505 |
| - "df_2011_subset = df_2011.groupby('month').apply(lambda x: x.sample(100, random_state=1)).reset_index(drop=True)\n", |
| 526 | + "df_2011_subset = (\n", |
| 527 | + " df_2011.groupby(\"month\")\n", |
| 528 | + " .apply(lambda x: x.sample(100, random_state=1))\n", |
| 529 | + " .reset_index(drop=True)\n", |
| 530 | + ")\n", |
506 | 531 | "quiet_months = [1, 2, 5, 6, 7, 8]\n",
|
507 |
| - "df_2011_subset['is_active'] = df_2011_subset['month'].apply(lambda x: 0 if x in quiet_months else 1)" |
| 532 | + "df_2011_subset[\"is_active\"] = df_2011_subset[\"month\"].apply(\n", |
| 533 | + " lambda x: 0 if x in quiet_months else 1\n", |
| 534 | + ")" |
508 | 535 | ]
|
509 | 536 | },
|
510 | 537 | {
|
|
525 | 552 | "mean_embeddings = []\n",
|
526 | 553 | "names = []\n",
|
527 | 554 | "\n",
|
528 |
| - "for idx in tqdm(dates_with_spots_noon_2011_df['index'].values):\n", |
| 555 | + "for idx in tqdm(dates_with_spots_noon_2011_df[\"index\"].values):\n", |
529 | 556 | " batch = train_dataset[idx]\n",
|
530 | 557 | " name = train_dataset.aligndata.iloc[idx].name\n",
|
531 |
| - " batch = torch.tensor(batch).unsqueeze(0) \n", |
532 |
| - " #batch = batch.to(\"cuda\") \n", |
533 |
| - " x, mask, ids_restore = model.model.forward_encoder(batch, mask_ratio = 0)\n", |
534 |
| - " # cls_token \n", |
535 |
| - " cls_embedding = x[:,0,:].detach().cpu()\n", |
536 |
| - " mean_embedding = x[:,1:,:].mean(dim=1).detach().cpu()\n", |
| 558 | + " batch = torch.tensor(batch).unsqueeze(0)\n", |
| 559 | + " # batch = batch.to(\"cuda\")\n", |
| 560 | + " x, mask, ids_restore = model.model.forward_encoder(batch, mask_ratio=0)\n", |
| 561 | + " # cls_token\n", |
| 562 | + " cls_embedding = x[:, 0, :].detach().cpu()\n", |
| 563 | + " mean_embedding = x[:, 1:, :].mean(dim=1).detach().cpu()\n", |
537 | 564 | " cls_embeddings.append(cls_embedding)\n",
|
538 | 565 | " mean_embeddings.append(mean_embedding)\n",
|
539 | 566 | " names.append(name)\n",
|
|
581 | 608 | "metadata": {},
|
582 | 609 | "outputs": [],
|
583 | 610 | "source": [
|
584 |
| - "\n", |
585 | 611 | "tsne = TSNE(n_components=2, random_state=0)\n",
|
586 | 612 | "\n",
|
587 | 613 | "cls_embeddings_np = cls_embeddings.numpy()\n",
|
588 | 614 | "cls_embeddings_tsne = tsne.fit_transform(cls_embeddings_np)\n",
|
589 | 615 | "\n",
|
590 | 616 | "\n",
|
591 |
| - "\n", |
592 |
| - "dates_with_spots_noon_2011_df['cls_tsne_x'] = cls_embeddings_tsne[:,0]\n", |
593 |
| - "dates_with_spots_noon_2011_df['cls_tsne_y'] = cls_embeddings_tsne[:,1]\n", |
| 617 | + "dates_with_spots_noon_2011_df[\"cls_tsne_x\"] = cls_embeddings_tsne[:, 0]\n", |
| 618 | + "dates_with_spots_noon_2011_df[\"cls_tsne_y\"] = cls_embeddings_tsne[:, 1]\n", |
594 | 619 | "\n",
|
595 | 620 | "\n",
|
596 | 621 | "tsne = TSNE(n_components=2, random_state=0)\n",
|
|
599 | 624 | "mean_embeddings_tsne = tsne.fit_transform(mean_embeddings_np)\n",
|
600 | 625 | "\n",
|
601 | 626 | "\n",
|
602 |
| - "\n", |
603 |
| - "dates_with_spots_noon_2011_df['avg_tsne_x'] = mean_embeddings_tsne[:,0]\n", |
604 |
| - "dates_with_spots_noon_2011_df['avg_tsne_y'] = mean_embeddings_tsne[:,1]\n", |
605 |
| - "\n" |
| 627 | + "dates_with_spots_noon_2011_df[\"avg_tsne_x\"] = mean_embeddings_tsne[:, 0]\n", |
| 628 | + "dates_with_spots_noon_2011_df[\"avg_tsne_y\"] = mean_embeddings_tsne[:, 1]" |
606 | 629 | ]
|
607 | 630 | },
|
608 | 631 | {
|
|
611 | 634 | "metadata": {},
|
612 | 635 | "outputs": [],
|
613 | 636 | "source": [
|
| 637 | + "fig = px.scatter(\n", |
| 638 | + " dates_with_spots_noon_2011_df,\n", |
| 639 | + " x=\"cls_tsne_x\",\n", |
| 640 | + " y=\"cls_tsne_y\",\n", |
| 641 | + " color=\"is_active\",\n", |
| 642 | + " hover_data=[\"month\", \"day\", \"hour\"],\n", |
| 643 | + ")\n", |
614 | 644 | "\n",
|
615 |
| - "fig = px.scatter(dates_with_spots_noon_2011_df, x=\"cls_tsne_x\", y=\"cls_tsne_y\", color=\"is_active\", hover_data=[\"month\", \"day\", \"hour\"])\n", |
616 |
| - "\n", |
617 |
| - "fig.show()\n" |
| 645 | + "fig.show()" |
618 | 646 | ]
|
619 | 647 | },
|
620 | 648 | {
|
|
623 | 651 | "metadata": {},
|
624 | 652 | "outputs": [],
|
625 | 653 | "source": [
|
626 |
| - "fig = px.scatter(dates_with_spots_noon_2011_df, x=\"avg_tsne_x\", y=\"avg_tsne_y\", color=\"is_active\", hover_data=[\"month\", \"day\", \"hour\"])\n", |
| 654 | + "fig = px.scatter(\n", |
| 655 | + " dates_with_spots_noon_2011_df,\n", |
| 656 | + " x=\"avg_tsne_x\",\n", |
| 657 | + " y=\"avg_tsne_y\",\n", |
| 658 | + " color=\"is_active\",\n", |
| 659 | + " hover_data=[\"month\", \"day\", \"hour\"],\n", |
| 660 | + ")\n", |
627 | 661 | "\n",
|
628 |
| - "fig.show()\n" |
| 662 | + "fig.show()" |
629 | 663 | ]
|
630 | 664 | },
|
631 | 665 | {
|
|
635 | 669 | "outputs": [],
|
636 | 670 | "source": [
|
637 | 671 | "# save html\n",
|
638 |
| - "fig.write_html(\"mean_pooling_tsne.html\")\n" |
| 672 | + "fig.write_html(\"mean_pooling_tsne.html\")" |
639 | 673 | ]
|
640 | 674 | },
|
641 | 675 | {
|
|
0 commit comments