fix fullgraph in ds. (#4016)
### What this PR does / why we need it?
DeepSeek (DS) models don't have an 'AscendAttentionMetadataBuilder' class, so
full-graph capture fails.
We resolved the issue by modifying the code to check only for
'GDNAttentionMetadataBuilder', while all other attention cases follow
the default branch.
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
@@ -108,8 +108,7 @@
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import (MoECommType,
                                                 set_ascend_forward_context)
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
-from vllm_ascend.attention.attention_v1 import (AscendAttentionMetadataBuilder,
-                                                AscendAttentionState)
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          AscendPrefillContextParallelMetadata)
 # yapf: disable
@@ -2887,12 +2886,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):

         for attn_group in self.attn_groups[kv_cache_group_id]:
             builder = attn_group.get_metadata_builder()
-            if isinstance(builder, AscendAttentionMetadataBuilder):
-                attn_metadata_full_attention = builder.build_for_graph_capture(
-                    common_attn_metadata, attn_state, self.get_model())
-            elif isinstance(builder, GDNAttentionMetadataBuilder):
+            if isinstance(builder, GDNAttentionMetadataBuilder):
                 attn_metadata_gdn_attention = builder.build_for_cudagraph_capture(
                     common_metadata)
+            else:
+                attn_metadata_full_attention = builder.build_for_graph_capture(
+                    common_attn_metadata, attn_state, self.get_model())
         for layer_name in kv_cache_group_spec.layer_names:
             if "linear_attn" in layer_name:
                 attn_metadata[
||||
Reference in New Issue
Block a user