vlm: enable radix cache for qwen-vl models (#5349)
Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>
@@ -190,25 +190,18 @@ class HFRunner:
         if attention_mask is not None:
             attention_mask = attention_mask.to(inputs_embeds.device)
 
-        outputs = self.model.model(
-            input_ids=None,
+        outputs = self.model(
+            input_ids=input_ids,
             position_ids=position_ids,
             attention_mask=attention_mask,
             past_key_values=past_key_values,
             output_hidden_states=True,
             return_dict=True,
-            inputs_embeds=inputs_embeds,
+            image_grid_thw=image_grid_thw,
         )
 
-        pooling_mask = attention_mask if pooling_mask is None else pooling_mask
-        left_padding = pooling_mask[:, -1].sum() == pooling_mask.shape[0]  # TODO
-        if left_padding:
-            embeddings = outputs.last_hidden_state[:, -1]
-        else:
-            sequence_lengths = pooling_mask.sum(dim=1) - 1
-            batch_size = outputs.last_hidden_state.shape[0]
-            embeddings = outputs.last_hidden_state[
-                torch.arange(batch_size, device=outputs.last_hidden_state.device),
-                sequence_lengths,
-            ]
+        embeddings = outputs.hidden_states[-1][:, -1]
         embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
         return embeddings.contiguous()
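For orientation, the updated runner drops the pooling-mask gather and instead takes the last position of the final hidden layer, then L2-normalizes it. Below is a minimal, self-contained sketch of that pooling against a Hugging Face Qwen2-VL checkpoint; the checkpoint name and the text-only input are illustrative assumptions, not part of this commit, and vision inputs would additionally pass pixel_values and image_grid_thw as in the diff above.

import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

# Illustrative checkpoint (assumption); any Qwen2-VL-family model with the
# standard transformers interface should behave similarly.
model_name = "Qwen/Qwen2-VL-2B-Instruct"
processor = AutoProcessor.from_pretrained(model_name)
model = Qwen2VLForConditionalGeneration.from_pretrained(model_name)
model.eval()

# Text-only input for simplicity; image inputs would add pixel_values / image_grid_thw.
inputs = processor(text=["a photo of a cat"], return_tensors="pt", padding=True)

with torch.no_grad():
    outputs = model(
        **inputs,
        output_hidden_states=True,
        return_dict=True,
    )

# Last-position pooling over the final hidden layer, then L2 normalization,
# mirroring the embedding extraction in the updated test code.
embeddings = outputs.hidden_states[-1][:, -1]
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
print(embeddings.shape)  # (batch_size, hidden_size)

Note that indexing position -1 assumes the last position holds each sequence's final token (left padding or no padding), whereas the removed branch also handled right-padded batches by gathering per-sequence last tokens via sequence_lengths.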