[Feat] Supports Aclgraph for bge-m3 (#3171)

### What this PR does / why we need it?
[Feat] Supports Aclgraph for bge-m3

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
```
pytest -s tests/e2e/singlecard/test_embedding.py
pytest -s tests/e2e/singlecard/test_embedding_aclgraph.py
```
to start an online server with bs 10, each batch's seq length=8192, we
set --max-num-batched-tokens=8192*10 to ensure encoder is not chunked:
```
vllm serve /home/data/bge-m3 --max_model_len 1024 --served-model-name "bge-m3" --task embed --host 0.0.0.0 --port 9095 --max-num-batched-tokens 81920 --compilation-config '{"cudagraph_capture_sizes":[8192, 10240, 20480, 40960, 81920]}'
```
For bs10, each batch's seq length=8192, QPS is improved from 85 to 104,
which is a 22% improvement, lots of host bound is reduced.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: xuyexiong <xuyexiong@huawei.com>
Co-authored-by: wangyongjun <1104133197@qq.com>
This commit is contained in:
xuyexiong
2025-10-14 23:07:45 +08:00
committed by GitHub
parent 434059e417
commit 02c26dcfc7
11 changed files with 307 additions and 21 deletions

View File

@@ -77,10 +77,11 @@ from vllm.v1.attention.backends.utils import (
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
# yapf conflicts with isort for this block
# yapf: disable
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheGroupSpec,
KVCacheSpec, MambaSpec,
MLAAttentionSpec,
from vllm.v1.kv_cache_interface import (AttentionSpec,
EncoderOnlyAttentionSpec,
FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
MambaSpec, MLAAttentionSpec,
UniformTypeKVCacheSpecs)
# yapf: enable
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
@@ -867,8 +868,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def _make_attention_mask(self, seq_lens, position,
attn_state) -> torch.Tensor:
# Pooling situation.
if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
return self.attn_mask_builder.get_pooling_mask(self.device)
# Chunk Prefill situation.
if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.ascend_config.use_sfa:
elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.ascend_config.use_sfa:
if torch.version.cann.startswith("8.3"):
return self.attn_mask_builder.get_splitfuse_attn_mask()
else:
@@ -1426,14 +1430,29 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
blk_table = self.input_batch.block_table[kv_cache_group_id]
blk_table_tensor = blk_table.get_device_tensor()
slot_mapping = blk_table.slot_mapping_cpu[:
total_num_scheduled_tokens]
self.slot_mapping[:total_num_scheduled_tokens].copy_(
slot_mapping[:total_num_scheduled_tokens],
non_blocking=True,
)
if isinstance(kv_cache_group_spec.kv_cache_spec,
EncoderOnlyAttentionSpec):
# Encoder-only layers do not have KV cache, so we need to
# create a dummy block table and slot mapping for them.
blk_table_tensor = torch.zeros(
(num_reqs, 1),
dtype=torch.int32,
device=self.device,
)
slot_mapping = torch.zeros(
(total_num_scheduled_tokens, ),
dtype=torch.int64,
device=self.device,
)
else:
blk_table = self.input_batch.block_table[kv_cache_group_id]
blk_table_tensor = blk_table.get_device_tensor()
slot_mapping = blk_table.slot_mapping_cpu[:
total_num_scheduled_tokens]
self.slot_mapping[:total_num_scheduled_tokens].copy_(
slot_mapping[:total_num_scheduled_tokens],
non_blocking=True,
)
# Make AscendCommonAttentionMetadata
common_attn_metadata = AscendCommonAttentionMetadata(
@@ -1469,7 +1488,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
common_prefix_len = 0
extra_attn_metadata_args = {}
builder = attn_group.get_metadata_builder()
if isinstance(builder, GDNAttentionMetadataBuilder):
if isinstance(builder, GDNAttentionMetadataBuilder
) or self.model_config.runner_type == "pooling":
if use_spec_decode:
extra_attn_metadata_args = dict(
num_accepted_tokens=self.num_accepted_tokens.
@@ -2625,7 +2645,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
"""
kv_cache_config = deepcopy(kv_cache_config)
self.kv_cache_config = kv_cache_config
self.initialize_attn_backend(kv_cache_config)
self.use_hybrid_blocks = (len(self.attn_groups) > 1)
# NOTE: Currently, we determine whether we need `num_accepted_tokens` through `MambaSpec`.
self.need_accepted_tokens = any([
@@ -2634,6 +2653,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
])
self.may_reinitialize_input_batch(kv_cache_config)
self.may_add_encoder_only_layers_to_kv_cache_config()
self.initialize_attn_backend(kv_cache_config)
if self.ascend_config.is_deepseek_sfa:
kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa(
@@ -3089,6 +3110,31 @@ class NPUModelRunner(LoRAModelRunnerMixin):
kernel_block_sizes=kernel_block_sizes,
)
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
"""
Add encoder-only layers to the KV cache config.
"""
block_size = self.vllm_config.cache_config.block_size
encoder_only_attn_specs: dict[AttentionSpec,
list[str]] = defaultdict(list)
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
attn_spec: AttentionSpec = EncoderOnlyAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype)
encoder_only_attn_specs[attn_spec].append(layer_name)
self.runner_only_attn_layers.add(layer_name)
if len(encoder_only_attn_specs) > 0:
assert len(
encoder_only_attn_specs
) == 1, "Only support one encoder-only attention spec now"
spec, layer_names = encoder_only_attn_specs.popitem()
self.kv_cache_config.kv_cache_groups.append(
KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec))
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize the attention backends and attention metadata builders.