[Model] Support pooling models (#3122)
### What this PR does / why we need it?

Support pooling models (like `bge-reranker-v2-m3`) in vllm-ascend. This PR covers the three pooling types used by embedding models: cls_token, mean_token, and last_token. Since this [commit](17373dcd93), vLLM has supported pooling models on the v1 engine; this PR adds the corresponding adaptations on the vllm-ascend side.

Fixes #1960

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: lianyibo <lianyibo1@kunlunit.com>
Signed-off-by: MengqingCao <cmq0113@163.com>

Co-authored-by: MengqingCao <cmq0113@163.com>
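For context, a minimal offline sketch of exercising a reranker once this support is in place. The model name comes from the description above; the `task`/`score` usage follows vLLM's public pooling API, and the exact output fields are an assumption here, not part of this change:

```python
# Minimal sketch: cross-encoder scoring with a pooling model on vllm-ascend.
# Assumes vllm-ascend is installed so the NPU platform is picked up
# automatically; API usage follows vLLM's offline pooling interface.
from vllm import LLM

llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score")

# Score one query against candidate passages.
outputs = llm.score(
    "What is the capital of France?",
    ["Paris is the capital of France.", "The Nile is a river in Africa."],
)
for output in outputs:
    print(output.outputs.score)
```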
@@ -377,6 +377,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
             self.block_size,
             use_mla=self.model_config.use_mla,
             use_sparse=self.use_sparse)
 
+        self.attn_mask_builder = AttentionMaskBuilder(self.device)
 
         self._set_up_drafter()
@@ -1029,8 +1030,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
         if self.attn_mask_builder is None:
             raise ValueError("Attn mask builder is None")
+        # Pooling situation.
+        if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
+            return self.attn_mask_builder.get_pooling_mask()
+        if self.model_config.runner_type == "pooling":
+            return self.attn_mask_builder.get_attn_mask(2048, torch.bool)
 
         if self.vllm_config.model_config.use_mla:
             if self.pcp_size > 1:
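The hunk above only routes between two masks; `AttentionMaskBuilder`'s internals are not part of this diff. As a rough illustration of the distinction (function names and semantics here are assumptions, not the plugin's implementation), CLS-style encoder pooling wants a fully bidirectional mask, while the generic path reuses a causal-style boolean mask of fixed size 2048:

```python
# Illustrative sketch only: the two mask shapes being selected above.
import torch

def causal_mask(n: int) -> torch.Tensor:
    # True above the diagonal: a token may not attend to future positions.
    return torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)

def bidirectional_pooling_mask(n: int) -> torch.Tensor:
    # Encoder-style (e.g. CLS) pooling lets every token attend everywhere,
    # so nothing is masked out.
    return torch.zeros(n, n, dtype=torch.bool)

print(causal_mask(4))
print(bidirectional_pooling_mask(4))
```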
@@ -1933,8 +1934,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
         common_prefix_len = 0
         extra_attn_metadata_args = {}
         builder = attn_group.get_metadata_builder()
-        if isinstance(builder, GDNAttentionMetadataBuilder):
+        if isinstance(builder, GDNAttentionMetadataBuilder
+                      ) or self.model_config.runner_type == "pooling":
+            if isinstance(builder, GDNAttentionMetadataBuilder):
                 if use_spec_decode:
                     extra_attn_metadata_args = dict(
                         num_accepted_tokens=self.num_accepted_tokens.
@@ -1946,6 +1946,11 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
                     common_prefix_len=common_prefix_len,
                     common_attn_metadata=common_attn_metadata,
                     **extra_attn_metadata_args)
+            elif self.model_config.runner_type == "pooling":
+                attn_metadata_i = builder.build(
+                    common_prefix_len=common_prefix_len,
+                    common_attn_metadata=common_attn_metadata,
+                    **extra_attn_metadata_args)
         else:
             attn_metadata_i = builder.build(
                 common_prefix_len=common_prefix_len,
@@ -1968,18 +1973,52 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
             input_ids, inputs_embeds, intermediate_tensors,
             max_num_scheduled_tokens)
+
+    def _init_model_kwargs(self):
+        model_kwargs = dict[str, Any]()
+        num_reqs = self.input_batch.num_reqs
+
+        num_pooling_reqs = len(self.input_batch.pooling_params)
+
+        if num_pooling_reqs == 0:
+            return model_kwargs
+
+        pooling_params = self.input_batch.get_pooling_params()
+
+        assert num_pooling_reqs == num_reqs
+
+        token_type_id_requests = dict[int, Any]()
+        for i, param in enumerate(pooling_params):
+            if param.extra_kwargs is not None and \
+                (token_types := param.extra_kwargs.get(
+                    "compressed_token_type_ids")) is not None:
+                token_type_id_requests[i] = token_types
+
+        if len(token_type_id_requests) == 0:
+            return model_kwargs
+
+        seq_lens = self.seq_lens[:num_reqs]
+        token_type_ids = []
+
+        for i in range(num_reqs):
+            pos = token_type_id_requests.get(i, seq_lens[i])
+            ids = (torch.arange(seq_lens[i]) >= pos).int()
+            token_type_ids.append(ids)
+
+        model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to(
+            device=self.device)
+        return model_kwargs
 
     def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
                                              maybe_padded_num_tokens,
                                              input_ids, positions,
                                              intermediate_tensors,
                                              inputs_embeds):
         assert self.model is not None
-        hidden_states = self.model(
-            input_ids=input_ids,
-            positions=positions,
-            intermediate_tensors=intermediate_tensors,
-            inputs_embeds=inputs_embeds,
-        )
+        hidden_states = self.model(input_ids=input_ids,
+                                   positions=positions,
+                                   intermediate_tensors=intermediate_tensors,
+                                   inputs_embeds=inputs_embeds,
+                                   **self._init_model_kwargs())
 
         forward_context = get_forward_context()
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL \
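`_init_model_kwargs` above reconstructs per-token `token_type_ids` from the compressed form that vLLM ships in `pooling_params.extra_kwargs`: only the boundary position where segment 1 begins is stored, and tokens at or after it get type 1. A standalone sketch of that expansion (values are illustrative):

```python
# Standalone sketch of the token_type_ids expansion performed by
# _init_model_kwargs: "compressed_token_type_ids" holds the position where
# segment 1 (e.g. the document of a query/document pair) starts.
import torch

seq_lens = [5, 4]    # per-request prompt lengths
boundaries = {0: 3}  # request 0 switches to type 1 at position 3;
                     # request 1 has no boundary, so it stays all zeros

token_type_ids = []
for i, seq_len in enumerate(seq_lens):
    pos = boundaries.get(i, seq_len)  # default: type 1 is never reached
    token_type_ids.append((torch.arange(seq_len) >= pos).int())

flat = torch.concat(token_type_ids)
print(flat)  # tensor([0, 0, 0, 1, 1, 0, 0, 0, 0])
```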
@@ -2022,7 +2061,14 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
 
     def _build_attn_state(self, num_reqs, num_scheduled_tokens,
                           num_valid_tokens):
-        if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
+        if self.model_config.runner_type == "pooling":
+            if isinstance(
+                    self.kv_cache_config.kv_cache_groups[0].kv_cache_spec,
+                    EncoderOnlyAttentionSpec):
+                attn_state = AscendAttentionState.PrefillNoCache
+            else:
+                attn_state = AscendAttentionState.PrefillCacheHit
+        elif np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
             attn_state = AscendAttentionState.PrefillNoCache
         # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
         elif np.all(num_scheduled_tokens == 1):
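The new branch above forces pooling batches into a prefill state: encoder-only attention keeps no KV cache, so the whole prompt is always recomputed, while pooling models with a KV cache spec are treated as cache hits. A reduced mirror of the branching (string states stand in for `AscendAttentionState` members; the branches not shown in the diff are assumptions):

```python
# Reduced mirror of _build_attn_state's branching, for illustration only.
import numpy as np

def build_attn_state(runner_type: str, encoder_only: bool,
                     seq_lens: np.ndarray, num_scheduled: np.ndarray) -> str:
    if runner_type == "pooling":
        # Encoder-only caches nothing, so every pooling run is a fresh
        # prefill; otherwise the prompt may already sit in the KV cache.
        return "PrefillNoCache" if encoder_only else "PrefillCacheHit"
    if np.array_equal(seq_lens, num_scheduled):
        return "PrefillNoCache"  # every scheduled token is a new prompt token
    if np.all(num_scheduled == 1):
        return "DecodeOnly"      # assumed decode-stage branch
    return "ChunkedPrefill"      # assumed fallback

print(build_attn_state("pooling", True, np.array([4]), np.array([4])))
```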
@@ -2251,7 +2297,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
             " a batch must be pooling request"
 
         hidden_states = hidden_states[:num_scheduled_tokens]
-        pooling_metadata = self.input_batch.pooling_metadata
+        pooling_metadata = self.input_batch.get_pooling_metadata()
         pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
                                               device=hidden_states.device)
         seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs]
@@ -4049,6 +4095,15 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
                 desc="Capturing ACL graphs ({}, {})".format(
                     "decode" if uniform_decode else "mixed prefill-decode",
                     aclgraph_runtime_mode.name))
+
+        force_attention = (aclgraph_runtime_mode == CUDAGraphMode.FULL)
+        # When the kv cache spec is empty, PiecewiseBackend is not initialized, and
+        # compilation_case=1 will cause the dynamic shape position to be incorrectly derived.
+        if not self.get_kv_cache_spec():
+            self._dummy_run(2,
+                            aclgraph_runtime_mode=CUDAGraphMode.NONE,
+                            force_attention=force_attention,
+                            uniform_decode=uniform_decode)
         # We skip EPLB here since we don't want to record dummy metrics
         for num_tokens in compilation_cases:
             for _ in range(self.compilation_config.cudagraph_num_of_warmups):
@@ -4057,7 +4112,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
                 # if we want to warm up attention or not. This is
                 # different from the case where `FULL` implies capture
                 # attention while `PIECEWISE` implies no attention.
-                force_attention = (aclgraph_runtime_mode == CUDAGraphMode.FULL)
                 self._dummy_run(num_tokens,
                                 aclgraph_runtime_mode=CUDAGraphMode.NONE,
                                 force_attention=force_attention,
@@ -793,17 +793,12 @@ class InputBatch:
             logitsprocs=self.logitsprocs,
         )
 
-    @property
-    def pooling_metadata(self) -> PoolingMetadata:
-        if len(self.pooling_params) == 0:
-            pooling_params = []
-        else:
-            # Note, for now this assumes that all request in the batch
-            # are either sampling or pooling requests
-            assert len(self.req_ids) == len(self.pooling_params)
-            pooling_params = [
-                self.pooling_params[req_id] for req_id in self.req_ids
-            ]
+    def get_pooling_params(self) -> list[PoolingParams]:
+        assert len(self.req_ids) == len(self.pooling_params)
+        return [self.pooling_params[req_id] for req_id in self.req_ids]
+
+    def get_pooling_metadata(self) -> PoolingMetadata:
+        pooling_params = self.get_pooling_params()
 
         return PoolingMetadata(
             prompt_lens=torch.from_numpy(
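The `pooling_metadata` property is split into `get_pooling_params()` and `get_pooling_metadata()`, and the all-or-nothing assumption becomes a hard assert. The key contract is ordering: params are stored per request id but must be emitted in batch (`req_ids`) order so that index i lines up with row i of the pooled hidden states. A reduced sketch of that pattern (plain dict/list stand-ins, not the real `InputBatch`):

```python
# Reduced sketch of the ordering contract behind get_pooling_params():
# pooling params live in a dict keyed by request id, but consumers need
# them in batch order so index i matches request i's hidden states.
req_ids = ["req-b", "req-a"]  # batch order
pooling_params = {"req-a": "params-a", "req-b": "params-b"}

assert len(req_ids) == len(pooling_params)  # batch must be pooling-only
ordered = [pooling_params[req_id] for req_id in req_ids]
print(ordered)  # ['params-b', 'params-a'] -- batch order, not dict order
```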