[Model] Support pooling models (#3122)

### What this PR does / why we need it?

Support pooling models (like `bge-reranker-v2-m3`) in vllm-ascend. This
PR covers the three embedding pooling types (cls_token, mean_token,
last_token).

Since this
[commit](17373dcd93),
vLLM has supported adapting pooling models on the v1 engine.
This PR adds the corresponding adaptations on the vllm-ascend side.

Fixes #1960

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: lianyibo <lianyibo1@kunlunit.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
lianyibo
2025-12-10 11:37:57 +08:00
committed by GitHub
parent 1a7a34c5ec
commit e32014ac1d
17 changed files with 577 additions and 338 deletions

View File

@@ -221,6 +221,10 @@ class AscendMetadata:
# dcp
decode_meta: Optional[AscendMetadataForDecode] = None
# Whether the pooling model uses causal attention;
# guides the attention computation for pooling models.
is_causal_pooling: Optional[bool] = None
class AscendAttentionMetadataBuilder:
# Does this backend/builder support ACL Graphs for attention (default: no).
@@ -319,6 +323,10 @@ class AscendAttentionMetadataBuilder:
query_start_loc = query_start_loc_cpu.pin_memory().to(
self.device, non_blocking=True)
is_causal_pooling = None
if self.model_config.runner_type == "pooling":
is_causal_pooling = common_attn_metadata.causal if hasattr(
common_attn_metadata, 'causal') else True
attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens,
@@ -336,7 +344,8 @@ class AscendAttentionMetadataBuilder:
attn_mask=attn_mask,
attn_state=attn_state,
num_prefills=num_prefills,
num_decodes=num_decodes)
num_decodes=num_decodes,
is_causal_pooling=is_causal_pooling)
return attn_metadata
def build_for_graph_capture(
@@ -597,30 +606,39 @@ class AscendAttentionBackendImpl(AttentionImpl):
out=output)
return output
def _forward_encode(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_metadata: AscendMetadata,
    output: torch.Tensor,
) -> torch.Tensor:
    """Run fused attention for an encoder-only layer.

    Sequence boundaries are taken from ``attn_metadata.query_start_loc``
    with its first entry dropped, and the same boundaries are used for
    both the query and the key/value side.

    Args:
        query: Query tensor, passed through to the fused kernel.
        key: Key tensor, passed through to the fused kernel.
        value: Value tensor, passed through to the fused kernel.
        attn_metadata: Supplies ``query_start_loc``, ``attn_mask`` and
            ``max_query_len``.
        output: Not read by this method; the name is only rebound in
            the original implementation.

    Returns:
        The first element of the ``torch_npu.npu_fusion_attention``
        result.
    """
    # Drop the leading entry of query_start_loc and hand the rest to the
    # kernel as a plain Python list.
    seq_ends = attn_metadata.query_start_loc[1:].tolist()
    # The same bound is used for both pre_tockens and next_tockens.
    # NOTE(review): keyword spelling kept exactly as the call requires.
    token_window = attn_metadata.max_query_len

    fused_result = torch_npu.npu_fusion_attention(
        query,
        key,
        value,
        head_num=self.num_heads,
        input_layout="TND",
        scale=self.scale,
        sparse_mode=4,
        atten_mask=attn_metadata.attn_mask,
        pre_tockens=token_window,
        next_tockens=token_window,
        actual_seq_qlen=seq_ends,
        actual_seq_kvlen=seq_ends,
    )
    return fused_result[0]
def _forward_encoder_attention(self, query: torch.Tensor,
                               key: torch.Tensor, value: torch.Tensor,
                               attn_metadata: AscendMetadata,
                               _: torch.Tensor) -> torch.Tensor:
    """Run fused attention for the pooling-model path.

    The kernel arguments depend on ``attn_metadata.is_causal_pooling``:
    when it is truthy, ``sparse_mode=3`` and the attention mask are
    passed; otherwise neither is passed and the kernel defaults apply.

    Args:
        query: Query tensor, passed through to the fused kernel.
        key: Key tensor, passed through to the fused kernel.
        value: Value tensor, passed through to the fused kernel.
        attn_metadata: Must be non-None with ``is_causal_pooling`` set;
            also supplies ``actual_seq_lengths_q`` and ``attn_mask``.
        _: Unused positional placeholder.

    Returns:
        The first element of the ``torch_npu.npu_fusion_attention``
        result.
    """
    assert attn_metadata is not None
    assert attn_metadata.is_causal_pooling is not None

    # Query lengths are used for both the query and the key/value side.
    seq_lengths = attn_metadata.actual_seq_lengths_q
    fusion_kwargs = dict(
        query=query,
        key=key,
        value=value,
        head_num=self.num_heads,
        input_layout="TND",
        scale=self.scale,
        actual_seq_qlen=seq_lengths,
        actual_seq_kvlen=seq_lengths,
    )

    if attn_metadata.is_causal_pooling:
        # use sparse_mode 3 in causal scenario
        fusion_kwargs.update(
            sparse_mode=3,
            atten_mask=attn_metadata.attn_mask,
        )
    # Otherwise sparse_mode is left unset; per the original author's
    # note the default is 0, which means no mask is applied.

    return torch_npu.npu_fusion_attention(**fusion_kwargs)[0]
def reshape_and_cache(
self,
@@ -697,18 +715,22 @@ class AscendAttentionBackendImpl(AttentionImpl):
" for AscendAttentionBackendImpl")
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
if self.attn_type != AttentionType.DECODER and self.attn_type != AttentionType.ENCODER_ONLY:
raise NotImplementedError("Encoder/decoder cross-attention "
"are not implemented for "
attn_type = self.attn_type
if attn_type not in [
AttentionType.DECODER, AttentionType.ENCODER_ONLY
]:
raise NotImplementedError("Encoder/Decoder cross-attention "
"is not implemented for "
"PallasAttentionBackendImpl")
num_tokens = query.shape[0]
if attn_metadata is None:
return output.fill_(0)
key, value = self.reshape_and_cache(key, value, kv_cache,
attn_metadata)
if self.attn_type == AttentionType.ENCODER_ONLY:
attn_output = self._forward_encode(query, key, value,
attn_metadata, output)
# pooling model branch
if isinstance(attn_metadata.is_causal_pooling, bool):
attn_output = self._forward_encoder_attention(
query, key, value, attn_metadata, output)
output[:num_tokens] = attn_output[:num_tokens]
return output
output = self.forward_impl(query, key, value, kv_cache, attn_metadata,