From eefcdd8337ac853d60c3d978f6400a99098f17a3 Mon Sep 17 00:00:00 2001 From: dead-water Date: Tue, 6 Aug 2024 20:30:20 +0100 Subject: [PATCH] Appeased black --- notebooks/embedding_visualizer.ipynb | 124 +++++++++++++++++---------- 1 file changed, 79 insertions(+), 45 deletions(-) diff --git a/notebooks/embedding_visualizer.ipynb b/notebooks/embedding_visualizer.ipynb index 498bd94..50e259e 100644 --- a/notebooks/embedding_visualizer.ipynb +++ b/notebooks/embedding_visualizer.ipynb @@ -36,7 +36,7 @@ "metadata": {}, "outputs": [], "source": [ - "sunspots_df = pd.read_csv('SN_d_tot_V2.0.csv', delimiter=';', header=None)" + "sunspots_df = pd.read_csv(\"SN_d_tot_V2.0.csv\", delimiter=\";\", header=None)" ] }, { @@ -45,7 +45,16 @@ "metadata": {}, "outputs": [], "source": [ - "columns = ['year','month','day','date_fraction','sunspot_count','std_dev','observations','indicator']" + "columns = [\n", + " \"year\",\n", + " \"month\",\n", + " \"day\",\n", + " \"date_fraction\",\n", + " \"sunspot_count\",\n", + " \"std_dev\",\n", + " \"observations\",\n", + " \"indicator\",\n", + "]" ] }, { @@ -54,7 +63,7 @@ "metadata": {}, "outputs": [], "source": [ - "sunspots_df.columns = columns " + "sunspots_df.columns = columns" ] }, { @@ -138,15 +147,15 @@ } ], "source": [ - "#model = MAE(\n", + "# model = MAE(\n", "# **cfg.model.mae,\n", - " # **cfg.model.samae,\n", - " # hmi_mask=data_module.hmi_mask,\n", + "# **cfg.model.samae,\n", + "# hmi_mask=data_module.hmi_mask,\n", "# optimiser=cfg.model.opt.optimiser,\n", "# lr=cfg.model.opt.learning_rate,\n", "# weight_decay=cfg.model.opt.weight_decay,\n", - " \n", - "#)\n", + "\n", + "# )\n", "\n", "logger = WandbLogger(\n", " # WandbLogger params\n", @@ -160,9 +169,8 @@ " group=cfg.experiment.wandb.group,\n", " save_code=True,\n", " job_type=cfg.experiment.wandb.job_type,\n", - "\n", ")\n", - "model = Pretrainer(cfg, logger=logger, is_backbone=True)\n" + "model = Pretrainer(cfg, logger=logger, is_backbone=True)" ] }, { @@ -181,7 +189,7 @@ "metadata": {}, "outputs": [], "source": [ - "#model.model.to(\"cuda\");\n", + "# model.model.to(\"cuda\");\n", "model.model.eval();" ] }, @@ -212,6 +220,7 @@ "outputs": [], "source": [ "import pandas as pd\n", + "\n", "dates = []\n", "for i in range(train_dataset.__len__()):\n", "\n", @@ -228,11 +237,11 @@ "metadata": {}, "outputs": [], "source": [ - "dates_df['year'] = pd.to_datetime(dates_df['date']).dt.year\n", - "dates_df['month'] = pd.to_datetime(dates_df['date']).dt.month\n", - "dates_df['day'] = pd.to_datetime(dates_df['date']).dt.day\n", - "dates_df['hour'] = pd.to_datetime(dates_df['date']).dt.hour\n", - "dates_df['time'] = pd.to_datetime(dates_df['date']).dt.time\n" + "dates_df[\"year\"] = pd.to_datetime(dates_df[\"date\"]).dt.year\n", + "dates_df[\"month\"] = pd.to_datetime(dates_df[\"date\"]).dt.month\n", + "dates_df[\"day\"] = pd.to_datetime(dates_df[\"date\"]).dt.day\n", + "dates_df[\"hour\"] = pd.to_datetime(dates_df[\"date\"]).dt.hour\n", + "dates_df[\"time\"] = pd.to_datetime(dates_df[\"date\"]).dt.time" ] }, { @@ -438,7 +447,11 @@ "metadata": {}, "outputs": [], "source": [ - "dates_with_spots_df = dates_df.merge(sunspots_df[['year','month','day','sunspot_count']], on=['year', 'month','day'], how='left')" + "dates_with_spots_df = dates_df.merge(\n", + " sunspots_df[[\"year\", \"month\", \"day\", \"sunspot_count\"]],\n", + " on=[\"year\", \"month\", \"day\"],\n", + " how=\"left\",\n", + ")" ] }, { @@ -448,8 +461,13 @@ "outputs": [], "source": [ "import datetime\n", - "dates_with_spots_noon_df = dates_with_spots_df[dates_with_spots_df['time'] == datetime.time(12, 0)]\n", - "dates_with_spots_noon_2011_df = dates_with_spots_noon_df[dates_with_spots_noon_df['year'] == 2011]\n" + "\n", + "dates_with_spots_noon_df = dates_with_spots_df[\n", + " dates_with_spots_df[\"time\"] == datetime.time(12, 0)\n", + "]\n", + "dates_with_spots_noon_2011_df = dates_with_spots_noon_df[\n", + " dates_with_spots_noon_df[\"year\"] == 2011\n", + "]" ] }, { @@ -481,8 +499,11 @@ "source": [ "# plot the sunspot count vs date\n", "import matplotlib.pyplot as plt\n", - "dates_with_spots_noon_2011_df = dates_with_spots_noon_df[dates_with_spots_noon_df['year'] == 2011]\n", - "plt.plot(dates_with_spots_noon_df['date'], dates_with_spots_noon_df['sunspot_count'])\n" + "\n", + "dates_with_spots_noon_2011_df = dates_with_spots_noon_df[\n", + " dates_with_spots_noon_df[\"year\"] == 2011\n", + "]\n", + "plt.plot(dates_with_spots_noon_df[\"date\"], dates_with_spots_noon_df[\"sunspot_count\"])" ] }, { @@ -500,11 +521,17 @@ } ], "source": [ - "df_2011 = dates_df[dates_df['year'] == 2011]\n", + "df_2011 = dates_df[dates_df[\"year\"] == 2011]\n", "# groupby month, select 100 random samples\n", - "df_2011_subset = df_2011.groupby('month').apply(lambda x: x.sample(100, random_state=1)).reset_index(drop=True)\n", + "df_2011_subset = (\n", + " df_2011.groupby(\"month\")\n", + " .apply(lambda x: x.sample(100, random_state=1))\n", + " .reset_index(drop=True)\n", + ")\n", "quiet_months = [1, 2, 5, 6, 7, 8]\n", - "df_2011_subset['is_active'] = df_2011_subset['month'].apply(lambda x: 0 if x in quiet_months else 1)" + "df_2011_subset[\"is_active\"] = df_2011_subset[\"month\"].apply(\n", + " lambda x: 0 if x in quiet_months else 1\n", + ")" ] }, { @@ -525,15 +552,15 @@ "mean_embeddings = []\n", "names = []\n", "\n", - "for idx in tqdm(dates_with_spots_noon_2011_df['index'].values):\n", + "for idx in tqdm(dates_with_spots_noon_2011_df[\"index\"].values):\n", " batch = train_dataset[idx]\n", " name = train_dataset.aligndata.iloc[idx].name\n", - " batch = torch.tensor(batch).unsqueeze(0) \n", - " #batch = batch.to(\"cuda\") \n", - " x, mask, ids_restore = model.model.forward_encoder(batch, mask_ratio = 0)\n", - " # cls_token \n", - " cls_embedding = x[:,0,:].detach().cpu()\n", - " mean_embedding = x[:,1:,:].mean(dim=1).detach().cpu()\n", + " batch = torch.tensor(batch).unsqueeze(0)\n", + " # batch = batch.to(\"cuda\")\n", + " x, mask, ids_restore = model.model.forward_encoder(batch, mask_ratio=0)\n", + " # cls_token\n", + " cls_embedding = x[:, 0, :].detach().cpu()\n", + " mean_embedding = x[:, 1:, :].mean(dim=1).detach().cpu()\n", " cls_embeddings.append(cls_embedding)\n", " mean_embeddings.append(mean_embedding)\n", " names.append(name)\n", @@ -581,16 +608,14 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "tsne = TSNE(n_components=2, random_state=0)\n", "\n", "cls_embeddings_np = cls_embeddings.numpy()\n", "cls_embeddings_tsne = tsne.fit_transform(cls_embeddings_np)\n", "\n", "\n", - "\n", - "dates_with_spots_noon_2011_df['cls_tsne_x'] = cls_embeddings_tsne[:,0]\n", - "dates_with_spots_noon_2011_df['cls_tsne_y'] = cls_embeddings_tsne[:,1]\n", + "dates_with_spots_noon_2011_df[\"cls_tsne_x\"] = cls_embeddings_tsne[:, 0]\n", + "dates_with_spots_noon_2011_df[\"cls_tsne_y\"] = cls_embeddings_tsne[:, 1]\n", "\n", "\n", "tsne = TSNE(n_components=2, random_state=0)\n", @@ -599,10 +624,8 @@ "mean_embeddings_tsne = tsne.fit_transform(mean_embeddings_np)\n", "\n", "\n", - "\n", - "dates_with_spots_noon_2011_df['avg_tsne_x'] = mean_embeddings_tsne[:,0]\n", - "dates_with_spots_noon_2011_df['avg_tsne_y'] = mean_embeddings_tsne[:,1]\n", - "\n" + "dates_with_spots_noon_2011_df[\"avg_tsne_x\"] = mean_embeddings_tsne[:, 0]\n", + "dates_with_spots_noon_2011_df[\"avg_tsne_y\"] = mean_embeddings_tsne[:, 1]" ] }, { @@ -611,10 +634,15 @@ "metadata": {}, "outputs": [], "source": [ + "fig = px.scatter(\n", + " dates_with_spots_noon_2011_df,\n", + " x=\"cls_tsne_x\",\n", + " y=\"cls_tsne_y\",\n", + " color=\"is_active\",\n", + " hover_data=[\"month\", \"day\", \"hour\"],\n", + ")\n", "\n", - "fig = px.scatter(dates_with_spots_noon_2011_df, x=\"cls_tsne_x\", y=\"cls_tsne_y\", color=\"is_active\", hover_data=[\"month\", \"day\", \"hour\"])\n", - "\n", - "fig.show()\n" + "fig.show()" ] }, { @@ -623,9 +651,15 @@ "metadata": {}, "outputs": [], "source": [ - "fig = px.scatter(dates_with_spots_noon_2011_df, x=\"avg_tsne_x\", y=\"avg_tsne_y\", color=\"is_active\", hover_data=[\"month\", \"day\", \"hour\"])\n", + "fig = px.scatter(\n", + " dates_with_spots_noon_2011_df,\n", + " x=\"avg_tsne_x\",\n", + " y=\"avg_tsne_y\",\n", + " color=\"is_active\",\n", + " hover_data=[\"month\", \"day\", \"hour\"],\n", + ")\n", "\n", - "fig.show()\n" + "fig.show()" ] }, { @@ -635,7 +669,7 @@ "outputs": [], "source": [ "# save html\n", - "fig.write_html(\"mean_pooling_tsne.html\")\n" + "fig.write_html(\"mean_pooling_tsne.html\")" ] }, {