This repository was archived by the owner on Nov 16, 2023. It is now read-only.

Commit

update seq classification examples
saidbleik committed Jan 13, 2020
1 parent 74f6ba6 commit 5611740
Showing 6 changed files with 431 additions and 247 deletions.
205 changes: 146 additions & 59 deletions examples/text_classification/tc_mnli_transformers.ipynb
@@ -32,6 +32,7 @@
"from sklearn.preprocessing import LabelEncoder\n",
"from tqdm import tqdm\n",
"from utils_nlp.common.timer import Timer\n",
"from utils_nlp.common.pytorch_utils import dataloader_from_dataset\n",
"from utils_nlp.dataset.multinli import load_pandas_df\n",
"from utils_nlp.models.transformers.sequence_classification import (\n",
" Processor, SequenceClassifier)"
@@ -93,7 +94,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 222k/222k [01:25<00:00, 2.60kKB/s] \n"
"100%|██████████| 222k/222k [01:20<00:00, 2.74kKB/s] \n"
]
}
],
@@ -196,7 +197,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/media/bleik2/miniconda3/envs/nlp_gpu/lib/python3.6/site-packages/sklearn/model_selection/_split.py:2179: FutureWarning: From version 0.21, test_size will always complement train_size unless both are specified.\n",
"/media/bleik2/backup/.conda/envs/nlp_gpu/lib/python3.6/site-packages/sklearn/model_selection/_split.py:2179: FutureWarning: From version 0.21, test_size will always complement train_size unless both are specified.\n",
" FutureWarning)\n"
]
}
@@ -232,11 +233,11 @@
{
"data": {
"text/plain": [
"telephone 1055\n",
"slate 1003\n",
"travel 961\n",
"fiction 952\n",
"government 938\n",
"telephone 1043\n",
"slate 989\n",
"fiction 968\n",
"travel 964\n",
"government 945\n",
"Name: genre, dtype: int64"
]
},
@@ -385,32 +386,108 @@
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>roberta-base</td>\n",
" <td>bert-base-japanese</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>roberta-large</td>\n",
" <td>bert-base-japanese-whole-word-masking</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <td>roberta-large-mnli</td>\n",
" <td>bert-base-japanese-char</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <td>xlnet-base-cased</td>\n",
" <td>bert-base-japanese-char-whole-word-masking</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>xlnet-large-cased</td>\n",
" <td>bert-base-finnish-cased-v1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <td>distilbert-base-uncased</td>\n",
" <td>bert-base-finnish-uncased-v1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21</th>\n",
" <td>roberta-base</td>\n",
" </tr>\n",
" <tr>\n",
" <th>22</th>\n",
" <td>roberta-large</td>\n",
" </tr>\n",
" <tr>\n",
" <th>23</th>\n",
" <td>roberta-large-mnli</td>\n",
" </tr>\n",
" <tr>\n",
" <th>24</th>\n",
" <td>distilroberta-base</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25</th>\n",
" <td>roberta-base-openai-detector</td>\n",
" </tr>\n",
" <tr>\n",
" <th>26</th>\n",
" <td>roberta-large-openai-detector</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <td>xlnet-base-cased</td>\n",
" </tr>\n",
" <tr>\n",
" <th>28</th>\n",
" <td>xlnet-large-cased</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29</th>\n",
" <td>distilbert-base-uncased</td>\n",
" </tr>\n",
" <tr>\n",
" <th>30</th>\n",
" <td>distilbert-base-uncased-distilled-squad</td>\n",
" </tr>\n",
" <tr>\n",
" <th>31</th>\n",
" <td>distilbert-base-german-cased</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32</th>\n",
" <td>distilbert-base-multilingual-cased</td>\n",
" </tr>\n",
" <tr>\n",
" <th>33</th>\n",
" <td>albert-base-v1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>34</th>\n",
" <td>albert-large-v1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>35</th>\n",
" <td>albert-xlarge-v1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>36</th>\n",
" <td>albert-xxlarge-v1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>37</th>\n",
" <td>albert-base-v2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38</th>\n",
" <td>albert-large-v2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>39</th>\n",
" <td>albert-xlarge-v2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>40</th>\n",
" <td>albert-xxlarge-v2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
@@ -432,13 +509,32 @@
"12 bert-base-cased-finetuned-mrpc\n",
"13 bert-base-german-dbmdz-cased\n",
"14 bert-base-german-dbmdz-uncased\n",
"15 roberta-base\n",
"16 roberta-large\n",
"17 roberta-large-mnli\n",
"18 xlnet-base-cased\n",
"19 xlnet-large-cased\n",
"20 distilbert-base-uncased\n",
"21 distilbert-base-uncased-distilled-squad"
"15 bert-base-japanese\n",
"16 bert-base-japanese-whole-word-masking\n",
"17 bert-base-japanese-char\n",
"18 bert-base-japanese-char-whole-word-masking\n",
"19 bert-base-finnish-cased-v1\n",
"20 bert-base-finnish-uncased-v1\n",
"21 roberta-base\n",
"22 roberta-large\n",
"23 roberta-large-mnli\n",
"24 distilroberta-base\n",
"25 roberta-base-openai-detector\n",
"26 roberta-large-openai-detector\n",
"27 xlnet-base-cased\n",
"28 xlnet-large-cased\n",
"29 distilbert-base-uncased\n",
"30 distilbert-base-uncased-distilled-squad\n",
"31 distilbert-base-german-cased\n",
"32 distilbert-base-multilingual-cased\n",
"33 albert-base-v1\n",
"34 albert-large-v1\n",
"35 albert-xlarge-v1\n",
"36 albert-xxlarge-v1\n",
"37 albert-base-v2\n",
"38 albert-large-v2\n",
"39 albert-xlarge-v2\n",
"40 albert-xxlarge-v2"
]
},
"execution_count": 10,
@@ -492,18 +588,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 231508/231508 [00:00<00:00, 15545441.79B/s]\n",
"100%|██████████| 492/492 [00:00<00:00, 560455.61B/s]\n",
"100%|██████████| 267967963/267967963 [00:04<00:00, 61255588.46B/s]\n",
"/media/bleik2/miniconda3/envs/nlp_gpu/lib/python3.6/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
"100%|██████████| 898823/898823 [00:00<00:00, 23932308.55B/s]\n",
"100%|██████████| 456318/456318 [00:00<00:00, 23321916.66B/s]\n",
"100%|██████████| 473/473 [00:00<00:00, 477015.10B/s]\n",
"100%|██████████| 501200538/501200538 [00:07<00:00, 64332558.45B/s]\n",
"100%|██████████| 798011/798011 [00:00<00:00, 25002433.16B/s]\n",
"100%|██████████| 641/641 [00:00<00:00, 695974.34B/s]\n",
"100%|██████████| 467042463/467042463 [00:08<00:00, 55154509.21B/s]\n"
"/media/bleik2/backup/.conda/envs/nlp_gpu/lib/python3.6/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n"
]
}
],
@@ -518,11 +604,17 @@
" to_lower=model_name.endswith(\"uncased\"),\n",
" cache_dir=CACHE_DIR,\n",
" )\n",
" train_dataloader = processor.create_dataloader_from_df(\n",
" df_train, TEXT_COL, LABEL_COL, max_len=MAX_LEN, batch_size=BATCH_SIZE, num_gpus=NUM_GPUS, shuffle=True\n",
" train_dataset = processor.dataset_from_dataframe(\n",
" df_train, TEXT_COL, LABEL_COL, max_len=MAX_LEN\n",
" )\n",
" test_dataloader = processor.create_dataloader_from_df(\n",
" df_test, TEXT_COL, LABEL_COL, max_len=MAX_LEN, batch_size=BATCH_SIZE, num_gpus=NUM_GPUS, shuffle=False\n",
" train_dataloader = dataloader_from_dataset(\n",
" train_dataset, batch_size=BATCH_SIZE, num_gpus=NUM_GPUS, shuffle=True\n",
" )\n",
" test_dataset = processor.dataset_from_dataframe(\n",
" df_test, TEXT_COL, LABEL_COL, max_len=MAX_LEN\n",
" )\n",
" test_dataloader = dataloader_from_dataset(\n",
" test_dataset, batch_size=BATCH_SIZE, num_gpus=NUM_GPUS, shuffle=False\n",
" )\n",
"\n",
" # fine-tune\n",
@@ -531,17 +623,12 @@
" )\n",
" with Timer() as t:\n",
" classifier.fit(\n",
" train_dataloader,\n",
" num_epochs=NUM_EPOCHS,\n",
" num_gpus=NUM_GPUS,\n",
" verbose=False,\n",
" train_dataloader, num_epochs=NUM_EPOCHS, num_gpus=NUM_GPUS, verbose=False,\n",
" )\n",
" train_time = t.interval / 3600\n",
"\n",
" # predict\n",
" preds = classifier.predict(\n",
" test_dataloader, num_gpus=NUM_GPUS, verbose=False\n",
" )\n",
" preds = classifier.predict(test_dataloader, num_gpus=NUM_GPUS, verbose=False)\n",
"\n",
" # eval\n",
" accuracy = accuracy_score(df_test[LABEL_COL], preds)\n",
@@ -600,31 +687,31 @@
" <tbody>\n",
" <tr>\n",
" <th>accuracy</th>\n",
" <td>0.895477</td>\n",
" <td>0.879584</td>\n",
" <td>0.894866</td>\n",
" <td>0.889364</td>\n",
" <td>0.885697</td>\n",
" <td>0.886308</td>\n",
" </tr>\n",
" <tr>\n",
" <th>f1-score</th>\n",
" <td>0.896656</td>\n",
" <td>0.881218</td>\n",
" <td>0.896108</td>\n",
" <td>0.885225</td>\n",
" <td>0.880926</td>\n",
" <td>0.881819</td>\n",
" </tr>\n",
" <tr>\n",
" <th>time(hrs)</th>\n",
" <td>0.021865</td>\n",
" <td>0.035351</td>\n",
" <td>0.046295</td>\n",
" <td>0.023326</td>\n",
" <td>0.044209</td>\n",
" <td>0.052801</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" distilbert-base-uncased roberta-base xlnet-base-cased\n",
"accuracy 0.895477 0.879584 0.894866\n",
"f1-score 0.896656 0.881218 0.896108\n",
"time(hrs) 0.021865 0.035351 0.046295"
"accuracy 0.889364 0.885697 0.886308\n",
"f1-score 0.885225 0.880926 0.881819\n",
"time(hrs) 0.023326 0.044209 0.052801"
]
},
"execution_count": 13,
@@ -645,7 +732,7 @@
{
"data": {
"application/scrapbook.scrap.json+json": {
"data": 0.8899755501222494,
"data": 0.887123064384678,
"encoder": "json",
"name": "accuracy",
"version": 1
@@ -663,7 +750,7 @@
{
"data": {
"application/scrapbook.scrap.json+json": {
"data": 0.8913273009038569,
"data": 0.8826569624491233,
"encoder": "json",
"name": "f1",
"version": 1
@@ -688,9 +775,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "nlp_gpu",
"display_name": "Python 3.6.8 64-bit ('nlp_gpu': conda)",
"language": "python",
"name": "nlp_gpu"
"name": "python36864bitnlpgpucondaa579511bcea84c65877ff3dca4205921"
},
"language_info": {
"codemirror_mode": {
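Note on the diff above: the updated notebook replaces the single Processor.create_dataloader_from_df call with a two-step flow. Processor.dataset_from_dataframe builds an encoded PyTorch dataset, and the shared utility utils_nlp.common.pytorch_utils.dataloader_from_dataset wraps it in a DataLoader. Below is a minimal sketch of the new flow, assuming the utils_nlp package from this repository; the toy DataFrame, the "text"/"label" column names, the hyperparameter values, and the SequenceClassifier constructor arguments are illustrative assumptions, since they fall outside the hunks shown here.

    import pandas as pd
    from sklearn.preprocessing import LabelEncoder

    from utils_nlp.common.pytorch_utils import dataloader_from_dataset
    from utils_nlp.models.transformers.sequence_classification import (
        Processor, SequenceClassifier)

    MODEL_NAME = "distilbert-base-uncased"  # one of the models listed in the notebook
    CACHE_DIR = "./tmp"

    # Toy training data standing in for the MultiNLI DataFrame used in the notebook.
    df_train = pd.DataFrame({
        "text": ["the cabinet met on tuesday", "take the train to boston"],
        "label": ["government", "travel"],
    })
    # As in the notebook, labels are encoded to integers before dataset creation.
    df_train["label"] = LabelEncoder().fit_transform(df_train["label"])

    # Step 1: tokenize and encode the DataFrame into a dataset.
    processor = Processor(
        model_name=MODEL_NAME,
        to_lower=MODEL_NAME.endswith("uncased"),
        cache_dir=CACHE_DIR,
    )
    train_dataset = processor.dataset_from_dataframe(
        df_train, "text", "label", max_len=150
    )

    # Step 2: wrap the dataset in a DataLoader; batching and shuffling now live
    # in the shared dataloader_from_dataset utility rather than in the Processor.
    train_dataloader = dataloader_from_dataset(
        train_dataset, batch_size=16, num_gpus=1, shuffle=True
    )

    # Fine-tune and predict as before (constructor arguments are assumptions,
    # since the SequenceClassifier call is not shown in the hunks above).
    classifier = SequenceClassifier(
        model_name=MODEL_NAME, num_labels=2, cache_dir=CACHE_DIR
    )
    classifier.fit(train_dataloader, num_epochs=1, num_gpus=1, verbose=False)
    preds = classifier.predict(train_dataloader, num_gpus=1, verbose=False)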