[Feat][Graph]Support FULL_DECEDE_ONLY mode for MLA models (#3125)
### What this PR does / why we need it?
Adds support for capturing the Multi-Layer Attention (MLA) decode
operation into an ACL graph. This improves performance by compiling the
attention kernel for single-token decoding.
Key changes include:
- Implementing the graph capture logic for the MLA kernel, including
workspace management and parameter updates.
- Modifying the rotary embedding (RoPE) handling to use pre-allocated
tensors, which is a requirement for graph capture.
- Adding a `build_for_graph_capture` method to the MLA metadata builder
to create dummy metadata during the graph compilation phase.
Known issues:
- Currently, MTP is not supported in FULL_DECEDE_ONLY mode -- we're
working on a fix
- We are preparing to remove update_mla_attn_params with
auto_dispatch_capture
### Does this PR introduce _any_ user-facing change?
compilation_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
},
### How was this patch tested?
- vLLM version: v0.11.0
---------
Signed-off-by: panchao-hub <315134829@qq.com>
Signed-off-by: p00465316 <panchao13@huawei.com>
Co-authored-by: p00465316 <panchao13@huawei.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -20,6 +20,8 @@ Compare the outputs of vLLM with and without aclgraph.
|
||||
Run `pytest tests/compile/test_aclgraph.py`.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from vllm import SamplingParams
|
||||
|
||||
@@ -73,3 +75,76 @@ def test_models_with_aclgraph(
|
||||
name_0="vllm_eager_outputs",
|
||||
name_1="vllm_aclgraph_outputs",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("max_tokens", [5])
|
||||
def test_models_with_aclgraph_full_decode_only(
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
if 'HCCL_OP_EXPANSION_MODE' in os.environ:
|
||||
del os.environ['HCCL_OP_EXPANSION_MODE']
|
||||
prompts = [
|
||||
('Solve the following math problem step by step.'
|
||||
'The last line of your response should be of the form Answer: '
|
||||
'$Answer (without quotes) where $Answer is the answer to the problem.\n\n'
|
||||
'In triangle $ABC$, $\\sin \\angle A = \\frac{4}{5}$ and $\\angle A < 90^\\circ$. Let $D$'
|
||||
'be a point outside triangle $ABC$ such that $\\angle BAD = \\angle DAC$,'
|
||||
'$\\angle BDC = 90^\\circ$. Suppose $AD = 1$ and $\\frac{BD}{CD} = \\frac{3}{2}$.'
|
||||
'If $AB + AC$ can be expressed in the form $\\frac{a\\sqrt{b}}{c}$,'
|
||||
'where $a, b, c$ are pairwise relatively prime integers, find $a + b + c$.'
|
||||
),
|
||||
('Solve the following math problem step by step.'
|
||||
'The last line of your response should be of the form Answer: '
|
||||
'$Answer (without quotes) where $Answer is the answer to the problem.\n\n'
|
||||
'Let $ABCD$ be a unit square in the plane. Points $X$ and $Y$ are chosen'
|
||||
'independently and uniformly at random on the perimeter of $ABCD$.'
|
||||
'If the expected value of the area of triangle $\\triangle AXY$'
|
||||
'can be expressed as $\\frac{m}{n}$, for relatively prime positive'
|
||||
'integers $m$ and $n$, compute $m+n$.'),
|
||||
('Solve the following math problem step by step.'
|
||||
'The last line of your response should be of the form Answer: '
|
||||
'$Answer (without quotes) where $Answer is the answer to the problem.\n\n'
|
||||
'Let $a, b, c$ be distinct numbers such that the equations $x^2 + ax + 1 = 0$'
|
||||
'and $x^2 + bx + c = 0$ have a common real root, and the equations $x^2 + x + a = 0$'
|
||||
'and $x^2 + cx + b = 0$ also have a common real root.'
|
||||
'Compute the sum $a + b + c$.')
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(max_tokens=5,
|
||||
n=1,
|
||||
temperature=0.0,
|
||||
top_p=1.0,
|
||||
top_k=1)
|
||||
with VllmRunner(
|
||||
model,
|
||||
max_model_len=1024,
|
||||
enforce_eager=False,
|
||||
compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"},
|
||||
) as runner:
|
||||
vllm_aclgraph_outputs = runner.model.generate(prompts, sampling_params)
|
||||
|
||||
with VllmRunner(
|
||||
model,
|
||||
max_model_len=1024,
|
||||
enforce_eager=True,
|
||||
) as runner:
|
||||
vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
|
||||
|
||||
vllm_aclgraph_outputs_list = []
|
||||
for output in vllm_aclgraph_outputs:
|
||||
vllm_aclgraph_outputs_list.append(
|
||||
(output.outputs[0].index, output.outputs[0].text))
|
||||
|
||||
vllm_eager_outputs_list = []
|
||||
for output in vllm_eager_outputs:
|
||||
vllm_eager_outputs_list.append(
|
||||
(output.outputs[0].index, output.outputs[0].text))
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=vllm_eager_outputs_list,
|
||||
outputs_1_lst=vllm_aclgraph_outputs_list,
|
||||
name_0="vllm_eager_outputs",
|
||||
name_1="vllm_aclgraph_outputs",
|
||||
)
|
||||
|
||||
@@ -461,11 +461,13 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.assertEqual(out.shape, prefix_out.shape)
|
||||
self.assertEqual(lse.shape, prefix_lse.shape)
|
||||
|
||||
@patch('vllm_ascend.attention.mla_v1.get_forward_context')
|
||||
@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj")
|
||||
@patch("torch_npu.npu_fused_infer_attention_score")
|
||||
def test_forward_decode_without_graph(self,
|
||||
mock_npu_fused_infer_attention_score,
|
||||
mock_up_proj):
|
||||
mock_up_proj,
|
||||
mock_get_forward_context):
|
||||
num_tokens = 100
|
||||
block_size = 4
|
||||
q_nope = torch.randn(num_tokens, self.impl.num_heads,
|
||||
@@ -487,6 +489,7 @@ class TestAscendMLAImpl(TestBase):
|
||||
mock_up_proj.return_value = torch.randn(num_tokens,
|
||||
self.impl.num_heads,
|
||||
self.impl.v_head_dim)
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe,
|
||||
block_size, metadata)
|
||||
self.assertEqual(result.shape[0], num_tokens)
|
||||
@@ -614,12 +617,13 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.assertEqual(k_pe.shape[-1], self.impl.qk_rope_head_dim)
|
||||
self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank)
|
||||
|
||||
@patch('vllm_ascend.attention.mla_v1.get_forward_context')
|
||||
@patch("torch.npu.stream")
|
||||
@patch("vllm_ascend.attention.mla_v1.get_multistream_comm_context")
|
||||
@patch("torch_npu.npu_fused_infer_attention_score")
|
||||
def test_forward_decode(self, mock_npu_fused_infer_attention_score,
|
||||
mock_get_multistream_comm_context,
|
||||
mock_npu_stream):
|
||||
mock_get_multistream_comm_context, mock_npu_stream,
|
||||
mock_get_forward_context):
|
||||
B = 2
|
||||
N = self.impl.num_kv_heads
|
||||
BS = 100
|
||||
@@ -644,6 +648,7 @@ class TestAscendMLAImpl(TestBase):
|
||||
]
|
||||
mock_get_multistream_comm_context.return_value = None
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
|
||||
attn_metadata)
|
||||
|
||||
|
||||
@@ -237,6 +237,7 @@ class AscendAttentionMetadataBuilder:
|
||||
self,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
|
||||
model: Optional[nn.Module] = None,
|
||||
):
|
||||
if attn_state == AscendAttentionState.DecodeOnly:
|
||||
attn_metadata = self.build(
|
||||
|
||||
@@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
MLAAttentionImpl)
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.utils import cdiv, round_down
|
||||
@@ -21,6 +22,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
maybe_save_kv_layer_to_connector,
|
||||
split_decodes_and_prefills,
|
||||
wait_for_kv_layer_from_connector)
|
||||
from vllm_ascend.compilation.acl_graph import get_graph_params
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
@@ -169,7 +171,7 @@ M = TypeVar("M", bound=AscendMLAMetadata)
|
||||
class AscendMLAMetadataBuilder:
|
||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.NEVER
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
@@ -389,6 +391,8 @@ class AscendMLAMetadataBuilder:
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
cos = common_attn_metadata.cos
|
||||
sin = common_attn_metadata.sin
|
||||
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
|
||||
actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
|
||||
max_seq_lens = seq_lens[:num_decodes].max().item()
|
||||
@@ -397,21 +401,45 @@ class AscendMLAMetadataBuilder:
|
||||
block_table = block_table[:num_decodes, ...]
|
||||
seq_lens_list = seq_lens.tolist()
|
||||
|
||||
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
sin = self.sin_cache[input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
# TODO: After the fullgraph supports MTP, the if branch needs to deleted
|
||||
assert self.cos_cache is not None
|
||||
assert self.sin_cache is not None
|
||||
if cos is None and sin is None:
|
||||
cos = self.cos_cache[
|
||||
input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
sin = self.sin_cache[
|
||||
input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
|
||||
decode_metadata = AscendMLADecodeMetadata(
|
||||
input_positions=input_positions,
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens_list,
|
||||
max_seq_lens=max_seq_lens,
|
||||
attn_mask=common_attn_metadata.spec_attn_mask,
|
||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||
sin=sin,
|
||||
cos=cos)
|
||||
decode_metadata = AscendMLADecodeMetadata(
|
||||
input_positions=input_positions,
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens_list,
|
||||
max_seq_lens=max_seq_lens,
|
||||
attn_mask=common_attn_metadata.spec_attn_mask,
|
||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||
sin=sin,
|
||||
cos=cos)
|
||||
else:
|
||||
cos[:num_decodes,
|
||||
...] = self.cos_cache[input_positions].unsqueeze(
|
||||
1).unsqueeze(2)
|
||||
sin[:num_decodes,
|
||||
...] = self.sin_cache[input_positions].unsqueeze(
|
||||
1).unsqueeze(2)
|
||||
|
||||
decode_metadata = AscendMLADecodeMetadata(
|
||||
input_positions=input_positions,
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens_list,
|
||||
max_seq_lens=max_seq_lens,
|
||||
attn_mask=common_attn_metadata.spec_attn_mask,
|
||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||
sin=sin[:num_decodes, ...],
|
||||
cos=cos[:num_decodes, ...])
|
||||
|
||||
return self.metadata_cls( # type: ignore
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
@@ -431,6 +459,26 @@ class AscendMLAMetadataBuilder:
|
||||
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
|
||||
)
|
||||
|
||||
def build_for_graph_capture(
|
||||
self,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
|
||||
model: Optional[nn.Module] = None,
|
||||
):
|
||||
if attn_state == AscendAttentionState.DecodeOnly:
|
||||
attn_metadata = self.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
model=model,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Currently we only support building dummy metadata for DecodeOnly state"
|
||||
)
|
||||
|
||||
attn_metadata.attn_state = attn_state
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class DecodeMLAPreprocessResult(NamedTuple):
|
||||
ql_nope: Optional[torch.Tensor] = None
|
||||
@@ -834,24 +882,63 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
sparse_mode = 0
|
||||
spec_attn_mask = None
|
||||
|
||||
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||
q_nope,
|
||||
k_nope,
|
||||
k_nope,
|
||||
query_rope=q_pe,
|
||||
key_rope=k_pe,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
input_layout=input_layout,
|
||||
atten_mask=spec_attn_mask,
|
||||
sparse_mode=sparse_mode,
|
||||
scale=self.scale,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
block_table=decode_meta.block_table,
|
||||
block_size=block_size,
|
||||
actual_seq_lengths_kv=decode_meta.seq_lens_list,
|
||||
actual_seq_lengths=actual_seq_lengths)
|
||||
common_kwargs = {
|
||||
'query_rope': q_pe,
|
||||
'key_rope': k_pe,
|
||||
'num_heads': self.num_heads,
|
||||
'num_key_value_heads': self.num_kv_heads,
|
||||
'input_layout': input_layout,
|
||||
'atten_mask': spec_attn_mask,
|
||||
'sparse_mode': sparse_mode,
|
||||
'scale': self.scale,
|
||||
'antiquant_mode': 0,
|
||||
'antiquant_scale': None,
|
||||
'block_table': decode_meta.block_table,
|
||||
'block_size': block_size,
|
||||
"actual_seq_lengths": actual_seq_lengths,
|
||||
"actual_seq_lengths_kv": decode_meta.seq_lens_list,
|
||||
}
|
||||
graph_params = get_graph_params()
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
if forward_context.capturing:
|
||||
stream = torch_npu.npu.current_stream()
|
||||
|
||||
event = torch.npu.ExternalEvent()
|
||||
event.wait(stream)
|
||||
event.reset(stream)
|
||||
graph_params.events[num_tokens].append(event)
|
||||
|
||||
workspace = graph_params.workspaces.get(num_tokens)
|
||||
if workspace is None:
|
||||
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
||||
q_nope, k_nope, k_nope, **common_kwargs)
|
||||
graph_params.workspaces[num_tokens] = workspace
|
||||
|
||||
attn_output = torch.empty_like(q_nope)
|
||||
softmax_lse = torch.empty(num_tokens,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
|
||||
graph_params.attn_params[num_tokens].append(
|
||||
(q_nope, k_nope, q_pe, k_pe, self.num_heads, self.num_kv_heads,
|
||||
input_layout, spec_attn_mask, sparse_mode, self.scale,
|
||||
decode_meta.block_table, block_size,
|
||||
decode_meta.seq_lens_list, actual_seq_lengths, workspace,
|
||||
attn_output, softmax_lse))
|
||||
|
||||
torch.npu.graph_task_group_begin(stream)
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
q_nope,
|
||||
k_nope,
|
||||
k_nope,
|
||||
**common_kwargs,
|
||||
workspace=workspace,
|
||||
out=[attn_output, softmax_lse])
|
||||
handle = torch.npu.graph_task_group_end(stream)
|
||||
graph_params.handles[num_tokens].append(handle)
|
||||
else:
|
||||
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||
q_nope, k_nope, k_nope, **common_kwargs)
|
||||
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
if current_ms_metadata is None:
|
||||
|
||||
@@ -63,6 +63,10 @@ class AscendCommonAttentionMetadata:
|
||||
|
||||
graph_pad_size: int = -1
|
||||
|
||||
# NOTE: This is a temporary solution for rotary embedding in MLA
|
||||
cos: torch.Tensor = None
|
||||
sin: torch.Tensor = None
|
||||
|
||||
|
||||
def split_decodes_and_prefills(
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
|
||||
@@ -229,6 +229,52 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
|
||||
event.record(update_stream)
|
||||
|
||||
|
||||
def update_mla_attn_params(update_stream, forward_context, runtime_shape):
|
||||
graph_params = get_graph_params()
|
||||
# FIXME: Behold! We are using a temporary hack here to update the args
|
||||
# for each layer's attention op in the graph.
|
||||
for key, param, handle, event in zip(
|
||||
forward_context.attn_metadata,
|
||||
graph_params.attn_params[runtime_shape],
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
|
||||
spec_attn_mask, sparse_mode, scale, block_table, block_size,
|
||||
seq_lens_list, actual_seq_lengths, workspace, attn_output,
|
||||
softmax_lse) = param
|
||||
seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
|
||||
seq_lens_list = seq_lens_list + [0] * (runtime_shape -
|
||||
len(seq_lens_list))
|
||||
|
||||
with torch.npu.stream(update_stream):
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
q_nope,
|
||||
k_nope,
|
||||
k_nope,
|
||||
query_rope=q_pe,
|
||||
key_rope=k_pe,
|
||||
num_heads=num_heads,
|
||||
num_key_value_heads=num_kv_heads,
|
||||
input_layout=input_layout,
|
||||
atten_mask=spec_attn_mask,
|
||||
sparse_mode=sparse_mode,
|
||||
scale=scale,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
block_table=block_table,
|
||||
block_size=block_size,
|
||||
actual_seq_lengths_kv=seq_lens_list,
|
||||
actual_seq_lengths=actual_seq_lengths,
|
||||
workspace=workspace,
|
||||
out=[attn_output, softmax_lse])
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphParams:
|
||||
events: dict[int, list[torch.npu.ExternalEvent]]
|
||||
|
||||
@@ -214,12 +214,7 @@ class NPUPlatform(Platform):
|
||||
|
||||
if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
|
||||
compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||
# TODO: Currently MLA does not support FULL_DECODE_ONLY, remove the second condition
|
||||
# after MLA being supported
|
||||
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE or (
|
||||
compilation_config.cudagraph_mode
|
||||
== CUDAGraphMode.FULL_DECODE_ONLY and model_config is not None
|
||||
and model_config.use_mla):
|
||||
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
|
||||
logger.info(
|
||||
"PIECEWISE compilation enabled on NPU. use_inductor not supported - "
|
||||
"using only ACL Graph mode")
|
||||
|
||||
@@ -104,7 +104,8 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
|
||||
set_graph_params,
|
||||
update_attn_params)
|
||||
update_attn_params,
|
||||
update_mla_attn_params)
|
||||
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
|
||||
from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
|
||||
D2DExpertWeightLoader
|
||||
@@ -358,6 +359,25 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
|
||||
if self.vllm_config.model_config.use_mla and \
|
||||
self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
|
||||
rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
||||
self.cos = torch.ones(self.max_num_reqs,
|
||||
1,
|
||||
1,
|
||||
rope_dim,
|
||||
dtype=self.dtype,
|
||||
device=self.device)
|
||||
self.sin = torch.zeros(self.max_num_reqs,
|
||||
1,
|
||||
1,
|
||||
rope_dim,
|
||||
dtype=self.dtype,
|
||||
device=self.device)
|
||||
else:
|
||||
self.cos = None
|
||||
self.sin = None
|
||||
|
||||
self.uses_mrope = self.model_config.uses_mrope
|
||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||
if self.uses_mrope:
|
||||
@@ -1427,6 +1447,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
graph_pad_size=self.graph_pad_size,
|
||||
decode_token_per_req=self.decode_token_per_req,
|
||||
cos=self.cos,
|
||||
sin=self.sin,
|
||||
)
|
||||
|
||||
if self.speculative_config and \
|
||||
@@ -1453,7 +1475,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
attn_metadata_i = builder.build(
|
||||
common_prefix_len=common_prefix_len,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
model=self.model,
|
||||
model=self.get_model(),
|
||||
**extra_attn_metadata_args)
|
||||
|
||||
if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
|
||||
@@ -1488,8 +1510,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
forward_context = get_forward_context()
|
||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
update_attn_params(self.update_stream, forward_context,
|
||||
positions.shape[0])
|
||||
if self.vllm_config.model_config.use_mla:
|
||||
# FIXME: Try using `auto_dispatch_capture=True`
|
||||
update_mla_attn_params(self.update_stream, forward_context,
|
||||
positions.shape[0])
|
||||
else:
|
||||
update_attn_params(self.update_stream, forward_context,
|
||||
positions.shape[0])
|
||||
|
||||
if get_forward_context().sp_enabled:
|
||||
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
|
||||
@@ -2195,14 +2222,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
block_table_tensor=block_table_tensor[:num_reqs],
|
||||
slot_mapping=self.slot_mapping,
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
positions=self.positions,
|
||||
attn_mask=self.attn_mask,
|
||||
spec_attn_mask=self.spec_attn_mask,
|
||||
attn_state=self.attn_state,
|
||||
max_query_len=max_query_len,
|
||||
decode_token_per_req=self.decode_token_per_req,
|
||||
cos=self.cos,
|
||||
sin=self.sin,
|
||||
)
|
||||
|
||||
for attn_group in self.attn_groups[kv_cache_group_id]:
|
||||
builder = attn_group.get_metadata_builder()
|
||||
attn_metadata_i = builder.build_for_graph_capture(
|
||||
common_attn_metadata)
|
||||
common_attn_metadata, AscendAttentionState.DecodeOnly,
|
||||
self.get_model())
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
|
||||
@@ -2218,9 +2252,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
inputs_embeds=inputs_embeds)
|
||||
forward_context = get_forward_context()
|
||||
assert forward_context is not None
|
||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
update_attn_params(self.update_stream, forward_context,
|
||||
positions.shape[0])
|
||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
|
||||
not forward_context.capturing:
|
||||
if self.vllm_config.model_config.use_mla:
|
||||
# FIXME: Try using `auto_dispatch_capture=True`
|
||||
update_mla_attn_params(self.update_stream, forward_context,
|
||||
positions.shape[0])
|
||||
else:
|
||||
update_attn_params(self.update_stream, forward_context,
|
||||
positions.shape[0])
|
||||
|
||||
if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
|
||||
hidden_states, _ = hidden_states
|
||||
|
||||
Reference in New Issue
Block a user