from torch.utils.data import DataLoader
import argparse
from torch import nn, optim
import numpy as np
import torch
from collections import Counter
from nltk.tokenize import sent_tokenize, word_tokenize
import pandas as pd
import math

sequence_length = 1
batch_size = 450
max_epochs = 1
write_perplexities = True
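
# Word-level LSTM language model: builds a vocabulary from ./train.txt, trains
# for max_epochs, and reports per-sentence perplexities on the train,
# validation and test splits when write_perplexities is True.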


class Model(nn.Module):
    def __init__(self, dataset):
        super(Model, self).__init__()
        self.sequence_length = sequence_length
        self.batch_size = batch_size
        self.embedding_dim = 128
        self.num_layers = 3
        self.max_epochs = max_epochs
        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(num_embeddings=n_vocab, embedding_dim=self.embedding_dim)
        self.lstm_size = 128
        self.fc = nn.Linear(self.lstm_size, n_vocab)
        self.lstm = nn.LSTM(input_size=self.lstm_size, hidden_size=self.lstm_size,
                            num_layers=self.num_layers, dropout=0.2)

    def forward(self, x, prev_state):
        embed = self.embedding(x)  # look up word embeddings
        output, state = self.lstm(embed, prev_state)  # prev_state is the (h, c) state from the previous step
        logits = self.fc(output)  # project LSTM outputs onto the vocabulary
        return logits, state

    def init_state(self, sequence_length):
        # Zero-initialised (h, c) of shape (num_layers, sequence_length, hidden);
        # with the default sequence_length = 1 this lines up with the
        # (num_layers, batch, hidden) layout nn.LSTM expects here.
        return (torch.zeros(self.num_layers, sequence_length, self.lstm_size),
                torch.zeros(self.num_layers, sequence_length, self.lstm_size))

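# Corpus tokenises the three text splits, builds the vocabulary from the
# training split, and serves (input, target) index windows of length
# sequence_length to the DataLoader.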
class Corpus(torch.utils.data.Dataset):
    def __init__(self):
        self.sequence_length = sequence_length
        self.batch_size = batch_size
        self.max_epochs = max_epochs
        self.words, self.validation_words, self.testing_words = self.load_words()
        word_counts = Counter(self.words)
        self.uniq_words = sorted(word_counts, key=word_counts.get, reverse=True)

        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}
        self.words_indexes = [self.word_to_index[w] for w in self.words]
        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}

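    # load_words: read each split, strip special characters and brackets,
    # sentence- and word-tokenise the text, and prefix every sentence with a
    # <sent> marker; rare training words are then mapped to <unk>.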
    def load_words(self):
        with open("./train.txt", "r") as f1:
            dat = f1.read()
        dat = dat.replace("\n", " ")
        dat = dat.replace("@", "")
        dat = dat.replace("#", "")
        dat = dat.replace("*", "")
        dat = dat.replace("+", "")
        dat = dat.replace("^", "")
        dat = dat.replace("&", "")
        dat = dat.replace("~", "")
        dat = dat.replace("  ", " ")
        dat = dat.replace("{", "")
        dat = dat.replace("}", "")
        dat = dat.replace("[", "")
        dat = dat.replace("]", "")
        dat = dat.replace("(", "")
        dat = dat.replace(")", "")
        dat = dat.replace(":", "")
        dat = dat.replace("\\", "")
        dat = dat.replace("`", "")

        sentences = sent_tokenize(dat)
        training_arr = []
        for line in sentences:
            a = line.strip()
            a = a.replace('"', "")
            training_arr += ["<sent>"] + word_tokenize(a)

        with open("./valid.txt", "r") as f1:
            dat = f1.read()
        dat = dat.replace("\n", " ")
        dat = dat.replace("@", "")
        dat = dat.replace("#", "")
        dat = dat.replace("*", "")
        dat = dat.replace("+", "")
        dat = dat.replace("^", "")
        dat = dat.replace("&", "")
        dat = dat.replace("~", "")
        dat = dat.replace("  ", " ")
        dat = dat.replace("{", "")
        dat = dat.replace("}", "")
        dat = dat.replace("[", "")
        dat = dat.replace("]", "")
        dat = dat.replace("(", "")
        dat = dat.replace(")", "")
        dat = dat.replace(":", "")
        dat = dat.replace("\\", "")
        dat = dat.replace("`", "")

        sentences = sent_tokenize(dat)
        validation_arr = []
        for line in sentences:
            a = line.strip()
            a = a.replace('"', "")
            validation_arr += ["<sent>"] + word_tokenize(a)

        with open("./test.txt", "r") as f1:
            dat = f1.read()
        dat = dat.replace("\n", " ")
        dat = dat.replace("@", "")
        dat = dat.replace("#", "")
        dat = dat.replace("*", "")
        dat = dat.replace("+", "")
        dat = dat.replace("^", "")
        dat = dat.replace("&", "")
        dat = dat.replace("~", "")
        dat = dat.replace("  ", " ")
        dat = dat.replace("{", "")
        dat = dat.replace("}", "")
        dat = dat.replace("[", "")
        dat = dat.replace("]", "")
        dat = dat.replace("(", "")
        dat = dat.replace(")", "")
        dat = dat.replace(":", "")
        dat = dat.replace("\\", "")
        dat = dat.replace("`", "")

        sentences = sent_tokenize(dat)
        testing_arr = []
        for line in sentences:
            a = line.strip()
            a = a.replace('"', "")
            testing_arr += ["<sent>"] + word_tokenize(a)

        # Only the training vocabulary is closed with <unk>; out-of-vocabulary
        # validation and test words are mapped to <unk> at scoring time.
        training_arr = self.changeunknownwords(training_arr)
        return training_arr, validation_arr, testing_arr

    def changeunknownwords(self, arr):
        # Replace every word that occurs at most 3 times with the <unk> token.
        count_dic = {}
        for element in arr:
            if element in count_dic:
                count_dic[element] += 1
            else:
                count_dic[element] = 1
        for index, element in enumerate(arr):
            if count_dic[element] <= 3:
                arr[index] = "<unk>"
        return arr

    def __len__(self):
        return len(self.words_indexes) - self.sequence_length

    def __getitem__(self, index):
        # An input window and the same window shifted by one word (the targets).
        return (
            torch.tensor(self.words_indexes[index:index + self.sequence_length]),
            torch.tensor(self.words_indexes[index + 1:index + self.sequence_length + 1]),
        )

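# Per-sentence perplexity: exp(-(1/N) * sum_i log p(w_i | preceding words)).
# The data is fed one window at a time, the model's softmax gives the
# probability of the actual next word, and the function returns the average
# perplexity over sentences (optionally writing each sentence's score to
# filename, an open file handle).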
def calculate_perplexity(dataset, model, current_data, savetofile=False, filename=""):
    model.eval()
    sent_count = 0
    perplexity = -1
    probab_product = 0   # running sum of log-probabilities for the current sentence
    word_count = 0
    final_answer = 0     # sum of per-sentence perplexities
    sentence_count = 0
    state_h, state_c = model.init_state(sequence_length)
    sentence = ""

    for i in range(0, len(current_data) - sequence_length):
        if current_data[i] == "<sent>":
            if i != 0:
                # Sentence boundary: close off the previous sentence.
                perplexity = math.exp((-1 / word_count) * probab_product)
                sent_count += 1
                if savetofile:
                    print(perplexity)
                    filename.write(sentence + "\t" + str(perplexity) + "\n")
                final_answer += perplexity
                sentence_count += 1
            sentence = ""
            word_count = 0
            probab_product = 0
        else:
            word_count += 1
            sentence = sentence + " " + current_data[i]

        # Encode the current window, falling back to <unk> for unseen words.
        x_vector = []
        for w in current_data[i:i + sequence_length]:
            if w in dataset.word_to_index:
                x_vector.append(dataset.word_to_index[w])
            else:
                x_vector.append(dataset.word_to_index["<unk>"])
        x = torch.tensor([x_vector])
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))

        last_word_logits = y_pred[0][-1]
        p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().numpy()
        if current_data[i + 1] in dataset.word_to_index:
            word_to_predict = dataset.word_to_index[current_data[i + 1]]
        else:
            word_to_predict = dataset.word_to_index["<unk>"]
        probab_product += math.log(p[word_to_predict])

    # Close off the final sentence and average over all sentences.
    perplexity = math.exp((-1 / word_count) * probab_product)
    if savetofile:
        print(perplexity)
        filename.write(sentence + "\t" + str(perplexity) + "\n")
    final_answer += perplexity
    sentence_count += 1
    avg_perplexity = final_answer / sentence_count
    if savetofile:
        filename.write(str(avg_perplexity))
    return avg_perplexity

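# Training loop: iterate over (input, target) windows from the DataLoader,
# minimise cross-entropy over the vocabulary, checkpoint the model after every
# epoch, and keep track of the epoch with the lowest validation perplexity.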
def train(dataset, model):
    min_perplexity = 1e15
    best_model = -1

    dataloader = DataLoader(dataset, batch_size=batch_size, drop_last=True)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(max_epochs):
        print("EPOCH: ", epoch)
        model.train()  # calculate_perplexity() leaves the model in eval mode, so reset it each epoch
        state_h, state_c = model.init_state(sequence_length)

        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            loss = criterion(y_pred.transpose(1, 2), y)

            # Detach the recurrent state so gradients do not flow across batches.
            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()
            print({'epoch': epoch, 'batch': batch, 'loss': loss.item()})

        # Checkpoint this epoch and score it on the validation split.
        torch.save(model, f"./model{epoch}")
        perplexity = calculate_perplexity(dataset=dataset, model=model, current_data=dataset.validation_words, savetofile=False)
        if write_perplexities:
            with open(f"./2019114006-LM{epoch}-train-perplexity", "w") as f1:
                calculate_perplexity(dataset=dataset, model=model, current_data=dataset.words[:30000], savetofile=True, filename=f1)
            with open(f"./2019114006-LM{epoch}-validate-perplexity", "w") as f1:
                perplexity = calculate_perplexity(dataset=dataset, model=model, current_data=dataset.validation_words[:30000], savetofile=True, filename=f1)
            with open(f"./2019114006-LM{epoch}-test-perplexity", "w") as f1:
                calculate_perplexity(dataset=dataset, model=model, current_data=dataset.testing_words[:30000], savetofile=True, filename=f1)
        with open("./output.txt", "a") as output_pipe:
            output_pipe.write("epoch " + str(epoch) + ": " + str(perplexity) + "\n")
        if perplexity < min_perplexity:
            min_perplexity = perplexity
            best_model = epoch

    # Reload and return the checkpoint with the lowest validation perplexity.
    model = torch.load(f"./model{best_model}")
    return model

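# Build the corpus and the model, then train; train() returns the checkpoint
# with the lowest validation perplexity.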
dataset = Corpus()
model = Model(dataset)

model = train(dataset, model)

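# Either write per-sentence perplexity files for all three splits with the best
# model, or score a single sentence entered on stdin.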
if write_perplexities:
    with open("./2019114006-LMBEST-train-perplexity", "w") as f1:
        calculate_perplexity(dataset=dataset, model=model, current_data=dataset.words[:30000], savetofile=True, filename=f1)
    with open("./2019114006-LMBEST-validate-perplexity", "w") as f1:
        calculate_perplexity(dataset=dataset, model=model, current_data=dataset.validation_words[:30000], savetofile=True, filename=f1)
    with open("./2019114006-LMBEST-test-perplexity", "w") as f1:
        calculate_perplexity(dataset=dataset, model=model, current_data=dataset.testing_words[:30000], savetofile=True, filename=f1)
else:
    dat = input("ENTER SENTENCE")
    dat = dat.replace("\n", " ")
    dat = dat.replace("@", "")
    dat = dat.replace("#", "")
    dat = dat.replace("*", "")
    dat = dat.replace("+", "")
    dat = dat.replace("^", "")
    dat = dat.replace("&", "")
    dat = dat.replace("~", "")
    dat = dat.replace("  ", " ")
    dat = dat.replace("{", "")
    dat = dat.replace("}", "")
    dat = dat.replace("[", "")
    dat = dat.replace("]", "")
    dat = dat.replace("(", "")
    dat = dat.replace(")", "")
    dat = dat.replace(":", "")
    dat = dat.replace("\\", "")
    dat = dat.replace("`", "")
    sent = ["<sent>"] + word_tokenize(dat) + ["<sent>"]
    print(calculate_perplexity(dataset=dataset, model=model, current_data=sent, savetofile=False))