Skip to content

Commit 484b9b2

Browse files
authored
Bugfix in GPT2 model loading (#227)
1 parent c6935c2 commit 484b9b2

File tree

1 file changed

+4
-6
lines changed
  • texar/torch/modules/pretrained

1 file changed

+4
-6
lines changed

texar/torch/modules/pretrained/gpt2.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -226,14 +226,12 @@ def _init_from_checkpoint(self, pretrained_model_name: str,
226226
for name, array in zip(names, arrays):
227227
if name in global_tensor_map:
228228
v_name = global_tensor_map[name]
229-
230229
if name == "model/wte":
231-
if load_output_layer:
232-
pointer = self._name_to_variable(
233-
"word_embedder.embedding")
234-
assert pointer.shape == array.shape
235-
pointer.data = torch.from_numpy(array)
230+
pointer = self._name_to_variable("word_embedder.embedding")
231+
assert pointer.shape == array.shape
232+
pointer.data = torch.from_numpy(array)
236233

234+
if load_output_layer:
237235
output_pointer = self._name_to_variable(
238236
"_output_layer.weight")
239237
assert output_pointer.shape == array.shape

0 commit comments

Comments
 (0)