Question about the initialization of the custom Q-Former text module #2

Open

wdr-RA02 opened this issue Mar 27, 2024
Hi, first of all, thank you for your great work! After reviewing the implementation of the model, I found the way the customized Q-Former text module is initialized a bit strange.

From the parts of the code related to its initialization below, I suspect that the Qformer_txt module is only ever randomly initialized.

EVCap/models/evcap.py

Lines 84 to 96 in 68bd158

self.Qformer_txt, self.query_tokens_txt = self.init_Qformer_txt(
    num_query_token_txt, self.Qformer.config.hidden_size
)
self.Qformer_txt.resize_token_embeddings(len(self.bert_tokenizer))
self.Qformer_txt.cls = None
self.load_from_pretrained(url_or_filename=q_former_model)
if freeze_qformer:
    for name, param in self.Qformer_txt.named_parameters():
        param.requires_grad = False
    self.Qformer_txt = self.Qformer_txt.eval()
    self.Qformer_txt.train = disabled_train
    self.query_tokens_txt.requires_grad = True
    logging.info("freeze Qformer")

EVCap/models/blip2.py

Lines 61 to 74 in 68bd158

@classmethod
def init_Qformer_txt(cls, num_query_token, vision_width, cross_attention_freq=2):
    encoder_config = BertConfig.from_pretrained("bert-base-uncased")
    encoder_config.encoder_width = vision_width
    # insert cross-attention layer every other block
    encoder_config.add_cross_attention = True
    encoder_config.cross_attention_freq = cross_attention_freq
    encoder_config.query_length = num_query_token
    Qformer = BertLMHeadModel_txt(config=encoder_config)
    query_tokens = nn.Parameter(
        torch.zeros(1, num_query_token, encoder_config.hidden_size)
    )
    query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
    return Qformer, query_tokens

In the init_Qformer_txt function above, only a BERT model with random weights is created, and no other code loads any kind of pretrained checkpoint into it. As far as I can tell, the self.load_from_pretrained call in Line 89 of evcap.py only loads the state_dict from the original BLIP-2 checkpoint into model.Qformer, leaving model.Qformer_txt randomly initialized.
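
For reference, a minimal sketch of what I had expected, assuming the BLIP-2 checkpoint stores its weights under a "model" key with "Qformer." prefixes (the helper name load_qformer_txt_from_blip2 and the strict=False handling are my own, not from the repo):

import torch

def load_qformer_txt_from_blip2(model, checkpoint_path):
    # Reuse the pretrained BLIP-2 Q-Former weights for the text Q-Former
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    state_dict = ckpt.get("model", ckpt)
    qformer_txt_sd = {
        k[len("Qformer."):]: v
        for k, v in state_dict.items()
        if k.startswith("Qformer.")
    }
    # strict=False because Qformer_txt drops the cls head and resizes its embeddings
    missing, unexpected = model.Qformer_txt.load_state_dict(
        qformer_txt_sd, strict=False
    )
    return missing, unexpected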

Moreover, in the default setting the custom Q-Former module is frozen except for the query tokens, which, as I understand it, means the randomly initialized weights persist through the entire training process.
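
For example, one can quickly list which parameters of the text branch remain trainable after construction (a sketch; model here stands for an instantiated EVCap model):

trainable_txt = [
    name for name, p in model.named_parameters()
    if "txt" in name and p.requires_grad
]
print(trainable_txt)  # with the default config, essentially only query_tokens_txt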

I think that for the module to carry out the retrieval task properly, some pretrained weights need to be loaded into it (either bert-base-uncased or some BLIP-2/InstructBLIP Q-Former checkpoint), but I found no description in the paper of any checkpoint used for this module.
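
For example, one option would be to build the module directly from the bert-base-uncased weights at construction time (a sketch, assuming BertLMHeadModel_txt subclasses Hugging Face's BertPreTrainedModel so that from_pretrained is available):

# in init_Qformer_txt, replacing the random construction:
Qformer = BertLMHeadModel_txt.from_pretrained(
    "bert-base-uncased", config=encoder_config
)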

I wonder whether the random initialization here is a mistake, whether the module is meant to work this way, or whether there is some weight-copying mechanism elsewhere in the code that I overlooked (in which case, my sincere apologies).

I hope to hear from you soon.
