diff --git a/finetune/blip/blip.py b/finetune/blip/blip.py index 7d192cb2..13b69ffd 100644 --- a/finetune/blip/blip.py +++ b/finetune/blip/blip.py @@ -134,8 +134,9 @@ class BLIP_Decoder(nn.Module): def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0): image_embeds = self.visual_encoder(image) - if not sample: - image_embeds = image_embeds.repeat_interleave(num_beams,dim=0) + # recent versions of transformers seem to do repeat_interleave automatically + # if not sample: + # image_embeds = image_embeds.repeat_interleave(num_beams,dim=0) image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}