From 48758227c664a6ab96152b132728be8323beb7a1 Mon Sep 17 00:00:00 2001 From: Anas Awadalla Date: Sat, 16 Sep 2023 23:22:15 -0700 Subject: [PATCH] init flamingo embeds new weights --- open_flamingo/src/vlm.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/open_flamingo/src/vlm.py b/open_flamingo/src/vlm.py index 09aa165b..f8cda167 100644 --- a/open_flamingo/src/vlm.py +++ b/open_flamingo/src/vlm.py @@ -64,6 +64,10 @@ def __init__( _weight=self.lang_model.get_input_embeddings().weight, pad_token_id=self.pad_token_id, ) + input_embeds.additional_fc.weight.data.normal_( + mean=0.0, std=self.lang_model.config.initializer_range + ) + self.lang_model.set_input_embeddings(input_embeds) out_embeds = DecoupledLinear( @@ -72,6 +76,9 @@ def __init__( _weight=self.lang_model.get_output_embeddings().weight, _bias=self.lang_model.get_output_embeddings().bias, ) + out_embeds.additional_fc.weight.data.normal_( + mean=0.0, std=self.lang_model.config.initializer_range + ) self.lang_model.set_output_embeddings(out_embeds) # gradient checkpointing