Skip to content

Commit

Permalink
Appeased black
Browse files: browse the repository at this point in the history
  • Loading branch information
dead-water committed Aug 6, 2024
1 parent f9e1b10 commit eefcdd8
Showing 1 changed file with 79 additions and 45 deletions.
124 changes: 79 additions & 45 deletions notebooks/embedding_visualizer.ipynb
Original file line number | Diff line number | Diff line change
Expand Up @@ -36,7 +36,7 @@
"metadata": {},
"outputs": [],
"source": [
"sunspots_df = pd.read_csv('SN_d_tot_V2.0.csv', delimiter=';', header=None)"
"sunspots_df = pd.read_csv(\"SN_d_tot_V2.0.csv\", delimiter=\";\", header=None)"
]
},
{
Expand All @@ -45,7 +45,16 @@
"metadata": {},
"outputs": [],
"source": [
"columns = ['year','month','day','date_fraction','sunspot_count','std_dev','observations','indicator']"
"columns = [\n",
" \"year\",\n",
" \"month\",\n",
" \"day\",\n",
" \"date_fraction\",\n",
" \"sunspot_count\",\n",
" \"std_dev\",\n",
" \"observations\",\n",
" \"indicator\",\n",
"]"
]
},
{
Expand All @@ -54,7 +63,7 @@
"metadata": {},
"outputs": [],
"source": [
"sunspots_df.columns = columns "
"sunspots_df.columns = columns"
]
},
{
Expand Down Expand Up @@ -138,15 +147,15 @@
}
],
"source": [
"#model = MAE(\n",
"# model = MAE(\n",
"# **cfg.model.mae,\n",
" # **cfg.model.samae,\n",
" # hmi_mask=data_module.hmi_mask,\n",
"# **cfg.model.samae,\n",
"# hmi_mask=data_module.hmi_mask,\n",
"# optimiser=cfg.model.opt.optimiser,\n",
"# lr=cfg.model.opt.learning_rate,\n",
"# weight_decay=cfg.model.opt.weight_decay,\n",
" \n",
"#)\n",
"\n",
"# )\n",
"\n",
"logger = WandbLogger(\n",
" # WandbLogger params\n",
Expand All @@ -160,9 +169,8 @@
" group=cfg.experiment.wandb.group,\n",
" save_code=True,\n",
" job_type=cfg.experiment.wandb.job_type,\n",
"\n",
")\n",
"model = Pretrainer(cfg, logger=logger, is_backbone=True)\n"
"model = Pretrainer(cfg, logger=logger, is_backbone=True)"
]
},
{
Expand All @@ -181,7 +189,7 @@
"metadata": {},
"outputs": [],
"source": [
"#model.model.to(\"cuda\");\n",
"# model.model.to(\"cuda\");\n",
"model.model.eval();"
]
},
Expand Down Expand Up @@ -212,6 +220,7 @@
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"dates = []\n",
"for i in range(train_dataset.__len__()):\n",
"\n",
Expand All @@ -228,11 +237,11 @@
"metadata": {},
"outputs": [],
"source": [
"dates_df['year'] = pd.to_datetime(dates_df['date']).dt.year\n",
"dates_df['month'] = pd.to_datetime(dates_df['date']).dt.month\n",
"dates_df['day'] = pd.to_datetime(dates_df['date']).dt.day\n",
"dates_df['hour'] = pd.to_datetime(dates_df['date']).dt.hour\n",
"dates_df['time'] = pd.to_datetime(dates_df['date']).dt.time\n"
"dates_df[\"year\"] = pd.to_datetime(dates_df[\"date\"]).dt.year\n",
"dates_df[\"month\"] = pd.to_datetime(dates_df[\"date\"]).dt.month\n",
"dates_df[\"day\"] = pd.to_datetime(dates_df[\"date\"]).dt.day\n",
"dates_df[\"hour\"] = pd.to_datetime(dates_df[\"date\"]).dt.hour\n",
"dates_df[\"time\"] = pd.to_datetime(dates_df[\"date\"]).dt.time"
]
},
{
Expand Down Expand Up @@ -438,7 +447,11 @@
"metadata": {},
"outputs": [],
"source": [
"dates_with_spots_df = dates_df.merge(sunspots_df[['year','month','day','sunspot_count']], on=['year', 'month','day'], how='left')"
"dates_with_spots_df = dates_df.merge(\n",
" sunspots_df[[\"year\", \"month\", \"day\", \"sunspot_count\"]],\n",
" on=[\"year\", \"month\", \"day\"],\n",
" how=\"left\",\n",
")"
]
},
{
Expand All @@ -448,8 +461,13 @@
"outputs": [],
"source": [
"import datetime\n",
"dates_with_spots_noon_df = dates_with_spots_df[dates_with_spots_df['time'] == datetime.time(12, 0)]\n",
"dates_with_spots_noon_2011_df = dates_with_spots_noon_df[dates_with_spots_noon_df['year'] == 2011]\n"
"\n",
"dates_with_spots_noon_df = dates_with_spots_df[\n",
" dates_with_spots_df[\"time\"] == datetime.time(12, 0)\n",
"]\n",
"dates_with_spots_noon_2011_df = dates_with_spots_noon_df[\n",
" dates_with_spots_noon_df[\"year\"] == 2011\n",
"]"
]
},
{
Expand Down Expand Up @@ -481,8 +499,11 @@
"source": [
"# plot the sunspot count vs date\n",
"import matplotlib.pyplot as plt\n",
"dates_with_spots_noon_2011_df = dates_with_spots_noon_df[dates_with_spots_noon_df['year'] == 2011]\n",
"plt.plot(dates_with_spots_noon_df['date'], dates_with_spots_noon_df['sunspot_count'])\n"
"\n",
"dates_with_spots_noon_2011_df = dates_with_spots_noon_df[\n",
" dates_with_spots_noon_df[\"year\"] == 2011\n",
"]\n",
"plt.plot(dates_with_spots_noon_df[\"date\"], dates_with_spots_noon_df[\"sunspot_count\"])"
]
},
{
Expand All @@ -500,11 +521,17 @@
}
],
"source": [
"df_2011 = dates_df[dates_df['year'] == 2011]\n",
"df_2011 = dates_df[dates_df[\"year\"] == 2011]\n",
"# groupby month, select 100 random samples\n",
"df_2011_subset = df_2011.groupby('month').apply(lambda x: x.sample(100, random_state=1)).reset_index(drop=True)\n",
"df_2011_subset = (\n",
" df_2011.groupby(\"month\")\n",
" .apply(lambda x: x.sample(100, random_state=1))\n",
" .reset_index(drop=True)\n",
")\n",
"quiet_months = [1, 2, 5, 6, 7, 8]\n",
"df_2011_subset['is_active'] = df_2011_subset['month'].apply(lambda x: 0 if x in quiet_months else 1)"
"df_2011_subset[\"is_active\"] = df_2011_subset[\"month\"].apply(\n",
" lambda x: 0 if x in quiet_months else 1\n",
")"
]
},
{
Expand All @@ -525,15 +552,15 @@
"mean_embeddings = []\n",
"names = []\n",
"\n",
"for idx in tqdm(dates_with_spots_noon_2011_df['index'].values):\n",
"for idx in tqdm(dates_with_spots_noon_2011_df[\"index\"].values):\n",
" batch = train_dataset[idx]\n",
" name = train_dataset.aligndata.iloc[idx].name\n",
" batch = torch.tensor(batch).unsqueeze(0) \n",
" #batch = batch.to(\"cuda\") \n",
" x, mask, ids_restore = model.model.forward_encoder(batch, mask_ratio = 0)\n",
" # cls_token \n",
" cls_embedding = x[:,0,:].detach().cpu()\n",
" mean_embedding = x[:,1:,:].mean(dim=1).detach().cpu()\n",
" batch = torch.tensor(batch).unsqueeze(0)\n",
" # batch = batch.to(\"cuda\")\n",
" x, mask, ids_restore = model.model.forward_encoder(batch, mask_ratio=0)\n",
" # cls_token\n",
" cls_embedding = x[:, 0, :].detach().cpu()\n",
" mean_embedding = x[:, 1:, :].mean(dim=1).detach().cpu()\n",
" cls_embeddings.append(cls_embedding)\n",
" mean_embeddings.append(mean_embedding)\n",
" names.append(name)\n",
Expand Down Expand Up @@ -581,16 +608,14 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"tsne = TSNE(n_components=2, random_state=0)\n",
"\n",
"cls_embeddings_np = cls_embeddings.numpy()\n",
"cls_embeddings_tsne = tsne.fit_transform(cls_embeddings_np)\n",
"\n",
"\n",
"\n",
"dates_with_spots_noon_2011_df['cls_tsne_x'] = cls_embeddings_tsne[:,0]\n",
"dates_with_spots_noon_2011_df['cls_tsne_y'] = cls_embeddings_tsne[:,1]\n",
"dates_with_spots_noon_2011_df[\"cls_tsne_x\"] = cls_embeddings_tsne[:, 0]\n",
"dates_with_spots_noon_2011_df[\"cls_tsne_y\"] = cls_embeddings_tsne[:, 1]\n",
"\n",
"\n",
"tsne = TSNE(n_components=2, random_state=0)\n",
Expand All @@ -599,10 +624,8 @@
"mean_embeddings_tsne = tsne.fit_transform(mean_embeddings_np)\n",
"\n",
"\n",
"\n",
"dates_with_spots_noon_2011_df['avg_tsne_x'] = mean_embeddings_tsne[:,0]\n",
"dates_with_spots_noon_2011_df['avg_tsne_y'] = mean_embeddings_tsne[:,1]\n",
"\n"
"dates_with_spots_noon_2011_df[\"avg_tsne_x\"] = mean_embeddings_tsne[:, 0]\n",
"dates_with_spots_noon_2011_df[\"avg_tsne_y\"] = mean_embeddings_tsne[:, 1]"
]
},
{
Expand All @@ -611,10 +634,15 @@
"metadata": {},
"outputs": [],
"source": [
"fig = px.scatter(\n",
" dates_with_spots_noon_2011_df,\n",
" x=\"cls_tsne_x\",\n",
" y=\"cls_tsne_y\",\n",
" color=\"is_active\",\n",
" hover_data=[\"month\", \"day\", \"hour\"],\n",
")\n",
"\n",
"fig = px.scatter(dates_with_spots_noon_2011_df, x=\"cls_tsne_x\", y=\"cls_tsne_y\", color=\"is_active\", hover_data=[\"month\", \"day\", \"hour\"])\n",
"\n",
"fig.show()\n"
"fig.show()"
]
},
{
Expand All @@ -623,9 +651,15 @@
"metadata": {},
"outputs": [],
"source": [
"fig = px.scatter(dates_with_spots_noon_2011_df, x=\"avg_tsne_x\", y=\"avg_tsne_y\", color=\"is_active\", hover_data=[\"month\", \"day\", \"hour\"])\n",
"fig = px.scatter(\n",
" dates_with_spots_noon_2011_df,\n",
" x=\"avg_tsne_x\",\n",
" y=\"avg_tsne_y\",\n",
" color=\"is_active\",\n",
" hover_data=[\"month\", \"day\", \"hour\"],\n",
")\n",
"\n",
"fig.show()\n"
"fig.show()"
]
},
{
Expand All @@ -635,7 +669,7 @@
"outputs": [],
"source": [
"# save html\n",
"fig.write_html(\"mean_pooling_tsne.html\")\n"
"fig.write_html(\"mean_pooling_tsne.html\")"
]
},
{
Expand Down

0 comments on commit eefcdd8

Please sign in to comment.