vlm: enable radix cache for qwen-vl models (#5349)

Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>
Author: Mick
Date: 2025-04-24 12:35:05 +09:00
Committed by: GitHub
Parent: 7d0edf3cae
Commit: c998d04b46
26 changed files with 429 additions and 331 deletions


@@ -190,25 +190,18 @@ class HFRunner:
             if attention_mask is not None:
                 attention_mask = attention_mask.to(inputs_embeds.device)
-            outputs = self.model.model(
-                input_ids=None,
+            outputs = self.model(
+                input_ids=input_ids,
                 position_ids=position_ids,
                 attention_mask=attention_mask,
                 past_key_values=past_key_values,
                 output_hidden_states=True,
                 return_dict=True,
-                inputs_embeds=inputs_embeds,
+                image_grid_thw=image_grid_thw,
             )
-            pooling_mask = attention_mask if pooling_mask is None else pooling_mask
-            left_padding = pooling_mask[:, -1].sum() == pooling_mask.shape[0]  # TODO
-            if left_padding:
-                embeddings = outputs.last_hidden_state[:, -1]
-            else:
-                sequence_lengths = pooling_mask.sum(dim=1) - 1
-                batch_size = outputs.last_hidden_state.shape[0]
-                embeddings = outputs.last_hidden_state[
-                    torch.arange(batch_size, device=outputs.last_hidden_state.device),
-                    sequence_lengths,
-                ]
+            embeddings = outputs.hidden_states[-1][:, -1]
             embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
             return embeddings.contiguous()
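
For context, the replacement collapses the old pooling logic (which branched on left vs. right padding to locate each sequence's final token) into unconditional last-position pooling over the final hidden layer, matching the old left-padding branch. A minimal standalone sketch of that pooling step; the names and shapes here are illustrative, not taken from the PR:

```python
import torch
import torch.nn.functional as F

def last_token_embedding(last_hidden: torch.Tensor) -> torch.Tensor:
    """Pool [batch, seq_len, hidden] activations by taking the final
    position of each sequence, then L2-normalize, as the new code does."""
    embeddings = last_hidden[:, -1]        # last position per sequence
    embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings.contiguous()

# Toy usage with random activations standing in for model outputs.
fake_hidden = torch.randn(2, 16, 64)       # [batch, seq, hidden]
emb = last_token_embedding(fake_hidden)
print(emb.shape)                           # torch.Size([2, 64])
print(emb.norm(dim=1))                     # ~1.0 for each row
```

Note that taking `[:, -1]` is only correct when padding sits on the left (or the batch is unpadded), since the last position must hold a real token; the deleted `sequence_lengths` branch existed to handle right-padded batches.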