Skip to content

Commit

Permalink
init flamingo embeds new weights
Browse files Browse the repository at this point in the history
  • Loading branch information
anas-awadalla committed Sep 17, 2023
1 parent 82d1c69 commit 4875822
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions open_flamingo/src/vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def __init__(
_weight=self.lang_model.get_input_embeddings().weight,
pad_token_id=self.pad_token_id,
)
input_embeds.additional_fc.weight.data.normal_(
mean=0.0, std=self.lang_model.config.initializer_range
)

self.lang_model.set_input_embeddings(input_embeds)

out_embeds = DecoupledLinear(
Expand All @@ -72,6 +76,9 @@ def __init__(
_weight=self.lang_model.get_output_embeddings().weight,
_bias=self.lang_model.get_output_embeddings().bias,
)
out_embeds.additional_fc.weight.data.normal_(
mean=0.0, std=self.lang_model.config.initializer_range
)
self.lang_model.set_output_embeddings(out_embeds)

# gradient checkpointing
Expand Down

0 comments on commit 4875822

Please sign in to comment.