Skip to content

Commit 5d46987

Browse files
committed
Add save and load
1 parent 13a6c7c commit 5d46987

File tree

2 files changed

+75
-5
lines changed

2 files changed

+75
-5
lines changed

notebooks/test_gzip_classify.ipynb

+56-2
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,60 @@
6060
"source": [
6161
"model.predict(\"ฉันดีใจ\", k=1)"
6262
]
63+
},
64+
{
65+
"cell_type": "code",
66+
"execution_count": 5,
67+
"id": "5a97f0d3",
68+
"metadata": {},
69+
"outputs": [],
70+
"source": [
71+
"model.save(\"d.model\")"
72+
]
73+
},
74+
{
75+
"cell_type": "code",
76+
"execution_count": 6,
77+
"id": "6e183243",
78+
"metadata": {},
79+
"outputs": [],
80+
"source": [
81+
"model2 = pythainlp.classify.param_free.GzipModel(model_path=\"d.model\")"
82+
]
83+
},
84+
{
85+
"cell_type": "code",
86+
"execution_count": 7,
87+
"id": "b30af6f0",
88+
"metadata": {},
89+
"outputs": [
90+
{
91+
"data": {
92+
"text/plain": [
93+
"'Positive'"
94+
]
95+
},
96+
"execution_count": 7,
97+
"metadata": {},
98+
"output_type": "execute_result"
99+
}
100+
],
101+
"source": [
102+
"model2.predict(x1=\"ฉันดีใจ\", k=1)"
103+
]
104+
},
105+
{
106+
"cell_type": "code",
107+
"execution_count": null,
108+
"id": "3e72a33b",
109+
"metadata": {},
110+
"outputs": [],
111+
"source": []
63112
}
64113
],
65114
"metadata": {
66115
"kernelspec": {
67-
"display_name": "Python 3 (ipykernel)",
116+
"display_name": "Python 3.8.13 ('base')",
68117
"language": "python",
69118
"name": "python3"
70119
},
@@ -78,7 +127,12 @@
78127
"name": "python",
79128
"nbconvert_exporter": "python",
80129
"pygments_lexer": "ipython3",
81-
"version": "3.10.9"
130+
"version": "3.8.13"
131+
},
132+
"vscode": {
133+
"interpreter": {
134+
"hash": "a1d6ff38954a1cdba4cf61ffa51e42f4658fc35985cd256cd89123cae8466a39"
135+
}
82136
}
83137
},
84138
"nbformat": 4,

pythainlp/classify/param_free.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import gzip
66
from typing import List, Tuple
77
import numpy as np
8+
import json
89

910

1011
class GzipModel:
@@ -16,9 +17,12 @@ class GzipModel:
1617
:param list training_data: list [(text_sample,label)]
1718
"""
1819

19-
def __init__(self, training_data: List[Tuple[str, str]]):
20-
self.training_data = np.array(training_data)
21-
self.Cx2_list = self.train()
20+
def __init__(self, training_data: List[Tuple[str, str]]=None, model_path=None):
21+
if model_path!=None:
22+
self.load(model_path)
23+
else:
24+
self.training_data = np.array(training_data)
25+
self.Cx2_list = self.train()
2226

2327
def train(self):
2428
Cx2_list = []
@@ -72,3 +76,15 @@ def predict(self, x1: str, k: int = 1) -> str:
7276
predict_class = top_k_class[counts.argmax()]
7377

7478
return predict_class
79+
80+
def save(self, path: str):
81+
with open(path, "w") as f:
82+
json.dump({
83+
"training_data": self.training_data.tolist(), "Cx2_list":self.Cx2_list
84+
}, f, ensure_ascii=False)
85+
86+
def load(self, path: str):
87+
with open(path, "r") as f:
88+
data = json.load(f)
89+
self.Cx2_list = data["Cx2_list"]
90+
self.training_data = np.array(data["training_data"])

0 commit comments

Comments
 (0)