[FEAT] Support DeepSeek-V3.2 with FULL_DECODE_ONLY mode (#4706)
### What this PR does / why we need it?
The first commit supports `FULL_DECODE_ONLY`:
- Update `AscendSFAMetadataBuilder` to use `num_input_tokens` for
slicing slots and positions, ensuring fixed tensor shapes.
- Implement padding logic for `query_start_loc` in `NPUModelRunner` to
support uniform decode in full graph mode, aligning with GPU runner
behavior.
- Adjust MLA cosine cache allocation to occur independently of graph
mode and switch to using device-resident sequence lengths for attention
metadata.
- Remove redundant slicing of hidden states and outputs in
`AscendSFAImpl` and optimize `sin`/`cos` cache updates.
The second commit takes MTP into account:
- Update `AscendSFAMetadataBuilder` to use `num_input_tokens` for
slicing slots and positions, ensuring fixed tensor shapes.
- Implement padding logic for `query_start_loc` in `NPUModelRunner` to
support uniform decode in full graph mode, aligning with GPU runner
behavior.
- Adjust MLA cosine cache allocation to occur independently of graph
mode and switch to using device-resident sequence lengths for attention
metadata.
- Remove redundant slicing of hidden states and outputs in
`AscendSFAImpl` and optimize `sin`/`cos` cache updates.
The remaining commits are bug fixes.
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
Test cases needed.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -124,6 +124,9 @@ class TestAscendSFAMetadataBuilder(TestBase):
|
|||||||
common_attn_metadata.attn_mask = None
|
common_attn_metadata.attn_mask = None
|
||||||
common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
|
common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
|
||||||
common_attn_metadata.block_table_tensor = torch.randn(100, 4)
|
common_attn_metadata.block_table_tensor = torch.randn(100, 4)
|
||||||
|
common_attn_metadata.cos = None
|
||||||
|
common_attn_metadata.sin = None
|
||||||
|
common_attn_metadata.num_input_tokens = 100
|
||||||
|
|
||||||
model = MagicMock()
|
model = MagicMock()
|
||||||
model.model.layers = [MagicMock() for _ in range(10)]
|
model.model.layers = [MagicMock() for _ in range(10)]
|
||||||
@@ -166,6 +169,9 @@ class TestAscendSFAMetadataBuilder(TestBase):
|
|||||||
common_attn_metadata.attn_mask = None
|
common_attn_metadata.attn_mask = None
|
||||||
common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
|
common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
|
||||||
common_attn_metadata.block_table_tensor = torch.randn(100, 4)
|
common_attn_metadata.block_table_tensor = torch.randn(100, 4)
|
||||||
|
common_attn_metadata.cos = None
|
||||||
|
common_attn_metadata.sin = None
|
||||||
|
common_attn_metadata.num_input_tokens = 100
|
||||||
|
|
||||||
model = MagicMock()
|
model = MagicMock()
|
||||||
model.model.layers = [MagicMock() for _ in range(10)]
|
model.model.layers = [MagicMock() for _ in range(10)]
|
||||||
|
|||||||
@@ -297,30 +297,11 @@ class AscendAttentionMetadataBuilder:
|
|||||||
|
|
||||||
slot_mapping = common_attn_metadata.slot_mapping[:
|
slot_mapping = common_attn_metadata.slot_mapping[:
|
||||||
num_actual_tokens_pcp_padded]
|
num_actual_tokens_pcp_padded]
|
||||||
# slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
|
||||||
attn_mask = common_attn_metadata.attn_mask
|
attn_mask = common_attn_metadata.attn_mask
|
||||||
attn_state = common_attn_metadata.attn_state
|
attn_state = common_attn_metadata.attn_state
|
||||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
|
||||||
num_reqs
|
|
||||||
+ 1]
|
|
||||||
if common_attn_metadata.num_input_tokens > num_actual_tokens:
|
|
||||||
padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens
|
|
||||||
seq_lens = torch.cat([
|
|
||||||
seq_lens,
|
|
||||||
torch.tensor([padded_num_tokens
|
|
||||||
]).to(seq_lens.device).to(seq_lens.dtype)
|
|
||||||
])
|
|
||||||
block_table_padding = torch.zeros(
|
|
||||||
(padded_num_tokens, ) + block_table.shape[1:],
|
|
||||||
dtype=block_table.dtype,
|
|
||||||
device=block_table.device)
|
|
||||||
block_table = torch.cat([block_table, block_table_padding], dim=0)
|
|
||||||
query_start_loc_cpu = torch.cat([
|
|
||||||
query_start_loc_cpu,
|
|
||||||
torch.tensor([query_start_loc_cpu[-1] + padded_num_tokens]).to(
|
|
||||||
query_start_loc_cpu.device).to(query_start_loc_cpu.dtype)
|
|
||||||
])
|
|
||||||
|
|
||||||
|
# TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
|
||||||
query_start_loc = query_start_loc_cpu.pin_memory().to(
|
query_start_loc = query_start_loc_cpu.pin_memory().to(
|
||||||
self.device, non_blocking=True)
|
self.device, non_blocking=True)
|
||||||
is_causal_pooling = None
|
is_causal_pooling = None
|
||||||
|
|||||||
@@ -163,16 +163,12 @@ class AscendSFAMetadataBuilder:
|
|||||||
) -> AscendSFAMetadata:
|
) -> AscendSFAMetadata:
|
||||||
num_reqs = common_attn_metadata.num_reqs
|
num_reqs = common_attn_metadata.num_reqs
|
||||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
num_input_tokens = common_attn_metadata.num_input_tokens
|
||||||
device = self.device
|
|
||||||
|
|
||||||
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
|
block_table = common_attn_metadata.block_table_tensor[:num_reqs]
|
||||||
slot_mapping = common_attn_metadata.slot_mapping[:
|
slot_mapping = common_attn_metadata.slot_mapping[:num_input_tokens]
|
||||||
num_actual_tokens].to(
|
|
||||||
device,
|
|
||||||
non_blocking=True)
|
|
||||||
input_positions = common_attn_metadata.positions[:
|
input_positions = common_attn_metadata.positions[:
|
||||||
num_actual_tokens].long(
|
num_input_tokens].long(
|
||||||
)
|
)
|
||||||
query_start_loc = common_attn_metadata.query_start_loc
|
query_start_loc = common_attn_metadata.query_start_loc
|
||||||
query_lens = query_start_loc[1:] - query_start_loc[:-1]
|
query_lens = query_start_loc[1:] - query_start_loc[:-1]
|
||||||
@@ -189,15 +185,23 @@ class AscendSFAMetadataBuilder:
|
|||||||
self.sin_cache = self.sin_cache.to( # type: ignore
|
self.sin_cache = self.sin_cache.to( # type: ignore
|
||||||
self.model_config.dtype) # type: ignore
|
self.model_config.dtype) # type: ignore
|
||||||
|
|
||||||
cum_query_lens = query_start_loc_cpu[1:num_reqs + 1].to(
|
cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1]
|
||||||
torch.int32).to(device, non_blocking=True)
|
seq_lens = common_attn_metadata.seq_lens[:num_reqs]
|
||||||
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs].to(
|
|
||||||
torch.int32).to(device, non_blocking=True)
|
|
||||||
|
|
||||||
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
|
cos = common_attn_metadata.cos
|
||||||
1).unsqueeze(2)
|
sin = common_attn_metadata.sin
|
||||||
sin = self.sin_cache[input_positions].unsqueeze( # type: ignore
|
|
||||||
1).unsqueeze(2)
|
assert self.cos_cache is not None and self.sin_cache is not None
|
||||||
|
new_cos = self.cos_cache[input_positions][:, None, None]
|
||||||
|
new_sin = self.sin_cache[input_positions][:, None, None]
|
||||||
|
|
||||||
|
if (cos is not None and sin is not None
|
||||||
|
and num_input_tokens <= cos.shape[0]
|
||||||
|
and num_input_tokens <= sin.shape[0]):
|
||||||
|
cos[:num_input_tokens] = new_cos
|
||||||
|
sin[:num_input_tokens] = new_sin
|
||||||
|
else:
|
||||||
|
cos, sin = new_cos, new_sin
|
||||||
|
|
||||||
sfa_cp_context = None
|
sfa_cp_context = None
|
||||||
if self.enable_sfa_cp:
|
if self.enable_sfa_cp:
|
||||||
@@ -268,8 +272,8 @@ class AscendSFAMetadataBuilder:
|
|||||||
attn_mask=common_attn_metadata.attn_mask,
|
attn_mask=common_attn_metadata.attn_mask,
|
||||||
attn_state=common_attn_metadata.attn_state,
|
attn_state=common_attn_metadata.attn_state,
|
||||||
block_tables=block_table,
|
block_tables=block_table,
|
||||||
sin=sin,
|
sin=sin[:num_input_tokens],
|
||||||
cos=cos,
|
cos=cos[:num_input_tokens],
|
||||||
sfa_cp_context=sfa_cp_context)
|
sfa_cp_context=sfa_cp_context)
|
||||||
|
|
||||||
def build_for_graph_capture(
|
def build_for_graph_capture(
|
||||||
@@ -278,7 +282,10 @@ class AscendSFAMetadataBuilder:
|
|||||||
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
|
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
|
||||||
model: Optional[nn.Module] = None,
|
model: Optional[nn.Module] = None,
|
||||||
):
|
):
|
||||||
if attn_state == AscendAttentionState.DecodeOnly:
|
if attn_state in {
|
||||||
|
AscendAttentionState.DecodeOnly,
|
||||||
|
AscendAttentionState.SpecDecoding
|
||||||
|
}:
|
||||||
attn_metadata = self.build(
|
attn_metadata = self.build(
|
||||||
common_prefix_len=0,
|
common_prefix_len=0,
|
||||||
common_attn_metadata=common_attn_metadata,
|
common_attn_metadata=common_attn_metadata,
|
||||||
@@ -681,29 +688,29 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
self.q_proj.quant_bias = None
|
self.q_proj.quant_bias = None
|
||||||
torch.npu.empty_cache()
|
torch.npu.empty_cache()
|
||||||
|
|
||||||
def _sfa_preprocessc_decode(
|
def _sfa_preprocess_decode(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||||
attn_metadata: M,
|
attn_metadata: M,
|
||||||
need_gather_q_kv: bool,
|
need_gather_q_kv: bool,
|
||||||
num_actual_tokens: int,
|
num_input_tokens: int,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||||
hidden_states.contiguous(), need_gather_q_kv)
|
hidden_states.contiguous(), need_gather_q_kv)
|
||||||
k_nope, k_pe = kv_cache[0], kv_cache[1]
|
k_nope, k_pe = kv_cache[0], kv_cache[1]
|
||||||
ql_nope = torch.empty(
|
ql_nope = torch.empty(
|
||||||
(num_actual_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]),
|
(num_input_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]),
|
||||||
dtype=hidden_states.dtype,
|
dtype=hidden_states.dtype,
|
||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
)
|
)
|
||||||
q_pe = torch.empty(
|
q_pe = torch.empty(
|
||||||
(num_actual_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]),
|
(num_input_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]),
|
||||||
dtype=hidden_states.dtype,
|
dtype=hidden_states.dtype,
|
||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
)
|
)
|
||||||
q_c = torch.empty(
|
q_c = torch.empty(
|
||||||
(num_actual_tokens, self.q_lora_rank),
|
(num_input_tokens, self.q_lora_rank),
|
||||||
dtype=hidden_states.dtype,
|
dtype=hidden_states.dtype,
|
||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
)
|
)
|
||||||
@@ -721,7 +728,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
self.W_UK_T,
|
self.W_UK_T,
|
||||||
k_nope,
|
k_nope,
|
||||||
k_pe,
|
k_pe,
|
||||||
attn_metadata.slot_mapping[:num_actual_tokens].flatten(),
|
attn_metadata.slot_mapping,
|
||||||
quant_scale0=self.quant_scale0,
|
quant_scale0=self.quant_scale0,
|
||||||
quant_offset0=self.quant_offset0,
|
quant_offset0=self.quant_offset0,
|
||||||
bias0=self.quant_bias_qkv,
|
bias0=self.quant_bias_qkv,
|
||||||
@@ -761,25 +768,22 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
reach_layer_for_shared_weight_series(self.o_proj)
|
reach_layer_for_shared_weight_series(self.o_proj)
|
||||||
return output.fill_(0)
|
return output.fill_(0)
|
||||||
has_prefill = attn_metadata.has_prefill
|
has_prefill = attn_metadata.has_prefill
|
||||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
|
||||||
cos = attn_metadata.cos
|
cos = attn_metadata.cos
|
||||||
sin = attn_metadata.sin
|
sin = attn_metadata.sin
|
||||||
actual_seq_lengths_query = attn_metadata.cum_query_lens
|
actual_seq_lengths_query = attn_metadata.cum_query_lens
|
||||||
actual_seq_lengths_key = attn_metadata.seq_lens
|
actual_seq_lengths_key = attn_metadata.seq_lens
|
||||||
hidden_states = hidden_states[:num_actual_tokens]
|
|
||||||
if self.enable_sfa_cp:
|
if self.enable_sfa_cp:
|
||||||
need_gather_q_kv = False
|
need_gather_q_kv = False
|
||||||
# Inputs and outputs may be padded for CUDA graphs
|
# Inputs and outputs may be padded for CUDA graphs
|
||||||
output_padded = output
|
output_padded = output
|
||||||
output = output[:num_actual_tokens]
|
|
||||||
|
|
||||||
if self.enable_mlapo and not forward_context.with_prefill:
|
if self.enable_mlapo and not forward_context.with_prefill:
|
||||||
hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocessc_decode(
|
hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
kv_cache=kv_cache,
|
kv_cache=kv_cache,
|
||||||
attn_metadata=attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
need_gather_q_kv=need_gather_q_kv,
|
need_gather_q_kv=need_gather_q_kv,
|
||||||
num_actual_tokens=num_actual_tokens,
|
num_input_tokens=attn_metadata.num_input_tokens,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
|
assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
|
||||||
@@ -802,7 +806,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
if has_prefill:
|
if has_prefill:
|
||||||
wait_for_kv_layer_from_connector(layer_name)
|
wait_for_kv_layer_from_connector(layer_name)
|
||||||
|
|
||||||
slot_mapping = attn_metadata.slot_mapping[:num_actual_tokens]
|
slot_mapping = attn_metadata.slot_mapping
|
||||||
slot_mapping_cp = None
|
slot_mapping_cp = None
|
||||||
if self.enable_sfa_cp:
|
if self.enable_sfa_cp:
|
||||||
assert attn_metadata.sfa_cp_context is not None
|
assert attn_metadata.sfa_cp_context is not None
|
||||||
|
|||||||
@@ -273,6 +273,9 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
|
|||||||
key].decode.actual_seq_lengths_q
|
key].decode.actual_seq_lengths_q
|
||||||
block_table = forward_context.attn_metadata[
|
block_table = forward_context.attn_metadata[
|
||||||
key].decode.block_table
|
key].decode.block_table
|
||||||
|
# TODO: This is a hack and should be fixed in the future.
|
||||||
|
if speculative_config.disable_padded_drafter_batch:
|
||||||
|
block_table = block_table[:len(actual_seq_lengths)]
|
||||||
seq_lens_list = seq_lens_list + [0] * (
|
seq_lens_list = seq_lens_list + [0] * (
|
||||||
len(actual_seq_lengths) - len(seq_lens_list))
|
len(actual_seq_lengths) - len(seq_lens_list))
|
||||||
else:
|
else:
|
||||||
@@ -427,7 +430,7 @@ class GraphParams:
|
|||||||
_graph_params: Optional[GraphParams] = None
|
_graph_params: Optional[GraphParams] = None
|
||||||
|
|
||||||
|
|
||||||
def set_graph_params(aclgraph_capture_sizes: set[int]):
|
def set_graph_params(aclgraph_capture_sizes: list[int]):
|
||||||
global _graph_params
|
global _graph_params
|
||||||
if _graph_params is not None:
|
if _graph_params is not None:
|
||||||
raise ValueError("Graph parameters have already been set!")
|
raise ValueError("Graph parameters have already been set!")
|
||||||
@@ -456,7 +459,7 @@ def get_graph_params():
|
|||||||
_mtp_graph_params: Optional[GraphParams] = None
|
_mtp_graph_params: Optional[GraphParams] = None
|
||||||
|
|
||||||
|
|
||||||
def set_mtp_graph_params(aclgraph_capture_sizes: set[int]):
|
def set_mtp_graph_params(aclgraph_capture_sizes: list[int]):
|
||||||
global _mtp_graph_params
|
global _mtp_graph_params
|
||||||
if _mtp_graph_params is not None:
|
if _mtp_graph_params is not None:
|
||||||
raise ValueError("MTPGraph parameters have already been set!")
|
raise ValueError("MTPGraph parameters have already been set!")
|
||||||
|
|||||||
@@ -32,7 +32,6 @@ from vllm_ascend.ascend_forward_context import (MoECommType,
|
|||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||||
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
|
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
|
||||||
set_mtp_graph_params,
|
|
||||||
update_mla_attn_params)
|
update_mla_attn_params)
|
||||||
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
|
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
|
||||||
from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
|
from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
|
||||||
@@ -214,8 +213,6 @@ class MtpProposer(Proposer):
|
|||||||
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
|
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
|
||||||
):
|
):
|
||||||
self.update_stream: torch.npu.Stream = torch.npu.Stream()
|
self.update_stream: torch.npu.Stream = torch.npu.Stream()
|
||||||
set_mtp_graph_params(
|
|
||||||
self.vllm_config.compilation_config.cudagraph_capture_sizes)
|
|
||||||
self.model = ACLGraphWrapper(self.model,
|
self.model = ACLGraphWrapper(self.model,
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
runtime_mode=CUDAGraphMode.FULL)
|
runtime_mode=CUDAGraphMode.FULL)
|
||||||
@@ -254,9 +251,10 @@ class MtpProposer(Proposer):
|
|||||||
query_start_loc_cpu=self.runner.
|
query_start_loc_cpu=self.runner.
|
||||||
query_start_loc_cpu[:num_reqs + 1],
|
query_start_loc_cpu[:num_reqs + 1],
|
||||||
seq_lens_cpu=self.runner.seq_lens_cpu,
|
seq_lens_cpu=self.runner.seq_lens_cpu,
|
||||||
seq_lens=self.runner.seq_lens_cpu[:num_reqs],
|
seq_lens=self.runner.seq_lens[:num_reqs],
|
||||||
num_reqs=num_reqs,
|
num_reqs=num_reqs,
|
||||||
num_actual_tokens=num_tokens,
|
num_actual_tokens=num_tokens,
|
||||||
|
num_input_tokens=num_tokens,
|
||||||
max_query_len=self.num_speculative_tokens + 1,
|
max_query_len=self.num_speculative_tokens + 1,
|
||||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||||
@@ -289,7 +287,7 @@ class MtpProposer(Proposer):
|
|||||||
positions = self.positions[:num_tokens]
|
positions = self.positions[:num_tokens]
|
||||||
previous_hidden_states = self.hidden_states[:num_tokens]
|
previous_hidden_states = self.hidden_states[:num_tokens]
|
||||||
for i in range(self.num_speculative_tokens):
|
for i in range(self.num_speculative_tokens):
|
||||||
if i > 0:
|
if i > 0 and not skip_attn and aclgraph_runtime_mode == CUDAGraphMode.FULL:
|
||||||
aclgraph_runtime_mode = CUDAGraphMode.NONE
|
aclgraph_runtime_mode = CUDAGraphMode.NONE
|
||||||
with set_ascend_forward_context(
|
with set_ascend_forward_context(
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
@@ -316,7 +314,7 @@ class MtpProposer(Proposer):
|
|||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
|
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
|
||||||
not forward_context.capturing:
|
not forward_context.capturing:
|
||||||
if self.vllm_config.model_config.use_mla:
|
if self.vllm_config.model_config.use_mla and not self.use_sparse:
|
||||||
update_mla_attn_params(
|
update_mla_attn_params(
|
||||||
self.update_stream, forward_context, num_tokens,
|
self.update_stream, forward_context, num_tokens,
|
||||||
self.vllm_config.speculative_config)
|
self.vllm_config.speculative_config)
|
||||||
@@ -514,6 +512,7 @@ class MtpProposer(Proposer):
|
|||||||
# q1, q1 + 1, ..., q1 + q2 - n2 - 1,
|
# q1, q1 + 1, ..., q1 + q2 - n2 - 1,
|
||||||
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
|
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
|
||||||
|
|
||||||
|
num_actual_reqs = len(num_draft_tokens)
|
||||||
num_rejected_tokens = [
|
num_rejected_tokens = [
|
||||||
n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
|
n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
|
||||||
for i, n in enumerate(num_draft_tokens)
|
for i, n in enumerate(num_draft_tokens)
|
||||||
@@ -522,8 +521,11 @@ class MtpProposer(Proposer):
|
|||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
|
|
||||||
device = common_attn_metadata.query_start_loc.device
|
device = common_attn_metadata.query_start_loc.device
|
||||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
||||||
new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens
|
num_actual_reqs
|
||||||
|
+ 1]
|
||||||
|
seq_lens_cpu = common_attn_metadata.seq_lens_cpu[:num_actual_reqs]
|
||||||
|
new_seq_lens_cpu = seq_lens_cpu - num_rejected_tokens
|
||||||
|
|
||||||
# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
|
# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
|
||||||
new_query_len_per_req = query_start_loc_cpu[
|
new_query_len_per_req = query_start_loc_cpu[
|
||||||
@@ -587,6 +589,7 @@ class MtpProposer(Proposer):
|
|||||||
num_computed_tokens_cpu,
|
num_computed_tokens_cpu,
|
||||||
num_reqs=common_attn_metadata.num_reqs,
|
num_reqs=common_attn_metadata.num_reqs,
|
||||||
num_actual_tokens=total_num_tokens,
|
num_actual_tokens=total_num_tokens,
|
||||||
|
num_input_tokens=common_attn_metadata.num_input_tokens,
|
||||||
max_query_len=new_query_len_per_req.max().item(),
|
max_query_len=new_query_len_per_req.max().item(),
|
||||||
block_table_tensor=common_attn_metadata.block_table_tensor,
|
block_table_tensor=common_attn_metadata.block_table_tensor,
|
||||||
slot_mapping=common_attn_metadata.slot_mapping,
|
slot_mapping=common_attn_metadata.slot_mapping,
|
||||||
@@ -704,8 +707,8 @@ class MtpProposer(Proposer):
|
|||||||
|
|
||||||
assert self.runner is not None
|
assert self.runner is not None
|
||||||
|
|
||||||
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
|
if self.runner.use_aclgraph and num_scheduled_tokens <= self.cudagraph_batch_sizes[
|
||||||
) and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]:
|
-1]:
|
||||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(
|
num_input_tokens = self.vllm_config.pad_for_cudagraph(
|
||||||
num_scheduled_tokens)
|
num_scheduled_tokens)
|
||||||
elif self.use_aclgraph and num_tokens <= self.cudagraph_batch_sizes[-1]:
|
elif self.use_aclgraph and num_tokens <= self.cudagraph_batch_sizes[-1]:
|
||||||
@@ -797,7 +800,7 @@ class MtpProposer(Proposer):
|
|||||||
hidden_states=hidden_states)
|
hidden_states=hidden_states)
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||||
if self.vllm_config.model_config.use_mla:
|
if self.vllm_config.model_config.use_mla and not self.use_sparse:
|
||||||
update_mla_attn_params(
|
update_mla_attn_params(
|
||||||
self.update_stream, forward_context,
|
self.update_stream, forward_context,
|
||||||
num_input_tokens,
|
num_input_tokens,
|
||||||
@@ -1109,9 +1112,10 @@ class MtpProposer(Proposer):
|
|||||||
spec_common_attn_metadata = AscendCommonAttentionMetadata(
|
spec_common_attn_metadata = AscendCommonAttentionMetadata(
|
||||||
query_start_loc=common_attn_metadata.query_start_loc,
|
query_start_loc=common_attn_metadata.query_start_loc,
|
||||||
query_start_loc_cpu=query_start_loc_cpu,
|
query_start_loc_cpu=query_start_loc_cpu,
|
||||||
seq_lens_cpu=common_attn_metadata.seq_lens,
|
seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
|
||||||
num_reqs=common_attn_metadata.num_reqs,
|
num_reqs=common_attn_metadata.num_reqs,
|
||||||
num_actual_tokens=total_num_tokens,
|
num_actual_tokens=total_num_tokens,
|
||||||
|
num_input_tokens=common_attn_metadata.num_input_tokens,
|
||||||
max_query_len=new_query_len_per_req.max().item(),
|
max_query_len=new_query_len_per_req.max().item(),
|
||||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||||
block_table_tensor=common_attn_metadata.block_table_tensor,
|
block_table_tensor=common_attn_metadata.block_table_tensor,
|
||||||
|
|||||||
@@ -124,6 +124,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
|||||||
# yapf: disable
|
# yapf: disable
|
||||||
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
|
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
|
||||||
set_graph_params,
|
set_graph_params,
|
||||||
|
set_mtp_graph_params,
|
||||||
update_attn_dcp_pcp_params,
|
update_attn_dcp_pcp_params,
|
||||||
update_attn_params,
|
update_attn_params,
|
||||||
update_mla_attn_dcp_pcp_params,
|
update_mla_attn_dcp_pcp_params,
|
||||||
@@ -406,8 +407,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
|||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
|
|
||||||
if self.vllm_config.model_config.use_mla and \
|
# NOTE: This will have some extra memory allocated, is it OK?
|
||||||
self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
|
if self.vllm_config.model_config.use_mla:
|
||||||
rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
||||||
self.cos = torch.ones(self.max_num_reqs *
|
self.cos = torch.ones(self.max_num_reqs *
|
||||||
self.decode_token_per_req,
|
self.decode_token_per_req,
|
||||||
@@ -1843,6 +1844,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
|||||||
# in the same group share the same metadata.
|
# in the same group share the same metadata.
|
||||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||||
self.kv_cache_config.kv_cache_groups):
|
self.kv_cache_config.kv_cache_groups):
|
||||||
|
# NOTE: This is strange, why did we use total_num_scheduled_tokens before?
|
||||||
slot_mapping_size = (total_num_scheduled_tokens
|
slot_mapping_size = (total_num_scheduled_tokens
|
||||||
if self.pcp_size == 1 else
|
if self.pcp_size == 1 else
|
||||||
total_num_scheduled_tokens * self.pcp_size -
|
total_num_scheduled_tokens * self.pcp_size -
|
||||||
@@ -1864,7 +1866,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
|||||||
else:
|
else:
|
||||||
blk_table = self.input_batch.block_table[kv_cache_group_id]
|
blk_table = self.input_batch.block_table[kv_cache_group_id]
|
||||||
blk_table_tensor = blk_table.get_device_tensor()
|
blk_table_tensor = blk_table.get_device_tensor()
|
||||||
slot_mapping = blk_table.slot_mapping[:slot_mapping_size]
|
|
||||||
blk_table.slot_mapping[slot_mapping_size:].fill_(0)
|
blk_table.slot_mapping[slot_mapping_size:].fill_(0)
|
||||||
if self.pcp_size > 1:
|
if self.pcp_size > 1:
|
||||||
slot_mapping_for_pcp = blk_table.slot_mapping[:
|
slot_mapping_for_pcp = blk_table.slot_mapping[:
|
||||||
@@ -1884,14 +1885,48 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
|||||||
slot_mapping_size]
|
slot_mapping_size]
|
||||||
slot_mapping_for_pcp[:long_seq_metadata.
|
slot_mapping_for_pcp[:long_seq_metadata.
|
||||||
num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping
|
num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping
|
||||||
slot_mapping = slot_mapping_for_pcp
|
blk_table.slot_mapping[:long_seq_metadata.num_actual_tokens_pcp_padded] = \
|
||||||
|
slot_mapping_for_pcp
|
||||||
|
slot_mapping = blk_table.slot_mapping
|
||||||
|
|
||||||
|
# NOTE: This is a temporary hack, now in GPUModelRunner, this prepare_inputs
|
||||||
|
# has been split into multiple parts, and there are 3 parts that are related to this
|
||||||
|
# `num_reqs`, we'll take `query_start_loc` as an example:
|
||||||
|
# 1. self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
|
||||||
|
# 2. get `num_reqs_padded`; this depends on the dispatcher, which is why we have the
|
||||||
|
# following simplified `dispatch` logic here, we try to minimize the impact
|
||||||
|
# 3. query_start_loc = self.query_start_loc.gpu[: num_reqs_padded + 1]
|
||||||
|
uniform_decode = (max_num_scheduled_tokens == self.uniform_decode_query_len) \
|
||||||
|
and (total_num_scheduled_tokens == max_num_scheduled_tokens * num_reqs)
|
||||||
|
|
||||||
|
# TODO: We should make this official ASAP. Also note that if we pad here,
|
||||||
|
# the builders won’t need to add any extra padding.
|
||||||
|
if self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \
|
||||||
|
uniform_decode:
|
||||||
|
num_reqs_padded = num_input_tokens // self.uniform_decode_query_len
|
||||||
|
pad_size = num_reqs_padded - num_reqs
|
||||||
|
if pad_size > 0:
|
||||||
|
last_query_loc = self.query_start_loc[num_reqs]
|
||||||
|
|
||||||
|
steps = torch.arange(1,
|
||||||
|
pad_size + 1,
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.query_start_loc.dtype)
|
||||||
|
fill_values = last_query_loc + (
|
||||||
|
steps * self.uniform_decode_query_len)
|
||||||
|
|
||||||
|
self.query_start_loc[num_reqs + 1:num_reqs_padded +
|
||||||
|
1] = fill_values
|
||||||
|
# So we are trying to simulate the behavior of GPUModelRunner's
|
||||||
|
# prepare_inputs for uniform decode mode by padding query_start_loc
|
||||||
|
num_reqs = num_reqs_padded
|
||||||
|
|
||||||
# Make AscendCommonAttentionMetadata
|
# Make AscendCommonAttentionMetadata
|
||||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||||
query_start_loc=self.query_start_loc[:num_reqs + 1],
|
query_start_loc=self.query_start_loc[:num_reqs + 1],
|
||||||
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
|
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
|
||||||
seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
|
seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
|
||||||
seq_lens=self.seq_lens_cpu[:num_reqs],
|
seq_lens=self.seq_lens[:num_reqs],
|
||||||
num_reqs=num_reqs,
|
num_reqs=num_reqs,
|
||||||
num_actual_tokens=slot_mapping_size,
|
num_actual_tokens=slot_mapping_size,
|
||||||
num_input_tokens=num_input_tokens,
|
num_input_tokens=num_input_tokens,
|
||||||
@@ -2876,6 +2911,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
|||||||
seq_lens = max_query_len
|
seq_lens = max_query_len
|
||||||
self.seq_lens_np[:num_reqs] = seq_lens
|
self.seq_lens_np[:num_reqs] = seq_lens
|
||||||
self.seq_lens_np[num_reqs:] = 0
|
self.seq_lens_np[num_reqs:] = 0
|
||||||
|
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
|
||||||
|
non_blocking=True)
|
||||||
|
|
||||||
cu_num_tokens, arange = self._get_cumsum_and_arange(
|
cu_num_tokens, arange = self._get_cumsum_and_arange(
|
||||||
num_scheduled_tokens)
|
num_scheduled_tokens)
|
||||||
@@ -2906,21 +2943,22 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
|||||||
[0] * dcp_world_size for _ in range(pcp_world_size)
|
[0] * dcp_world_size for _ in range(pcp_world_size)
|
||||||
] for _ in range(num_tokens)]
|
] for _ in range(num_tokens)]
|
||||||
long_seq_metadata.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp
|
long_seq_metadata.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp
|
||||||
|
# QUESTION: Why do we separately set query_start_loc for spec in the first place?
|
||||||
|
# While in _prepare_inputs we don't?
|
||||||
if self.speculative_config:
|
if self.speculative_config:
|
||||||
query_start_loc = torch.tensor(
|
self.query_start_loc[:num_reqs + 1] = torch.tensor(
|
||||||
[0] + self.actual_seq_lengths_q[:num_reqs],
|
[0] + self.actual_seq_lengths_q[:num_reqs],
|
||||||
device=self.device,
|
device=self.device,
|
||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
else:
|
|
||||||
query_start_loc = self.query_start_loc[:num_reqs + 1]
|
|
||||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||||
query_start_loc=query_start_loc,
|
query_start_loc=self.query_start_loc[:num_reqs + 1],
|
||||||
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
|
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
|
||||||
1],
|
1],
|
||||||
seq_lens_cpu=self.seq_lens_cpu,
|
seq_lens_cpu=self.seq_lens_cpu,
|
||||||
seq_lens=self.seq_lens_cpu[:num_reqs],
|
seq_lens=self.seq_lens[:num_reqs],
|
||||||
num_reqs=num_reqs,
|
num_reqs=num_reqs,
|
||||||
num_actual_tokens=num_tokens,
|
num_actual_tokens=num_tokens,
|
||||||
|
num_input_tokens=num_tokens,
|
||||||
actual_seq_lengths_q=self.actual_seq_lengths_q,
|
actual_seq_lengths_q=self.actual_seq_lengths_q,
|
||||||
block_table_tensor=block_table_tensor[:num_reqs],
|
block_table_tensor=block_table_tensor[:num_reqs],
|
||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
@@ -3210,7 +3248,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
|||||||
num_tokens_across_dp=num_tokens_across_dp,
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||||
batch_descriptor=batch_descriptor,
|
batch_descriptor=batch_descriptor,
|
||||||
dummy_compute_logits=dummy_drafter_compute_logits)
|
dummy_compute_logits=dummy_drafter_compute_logits,
|
||||||
|
skip_attn=not force_attention)
|
||||||
if self.in_profile_run and self.dynamic_eplb:
|
if self.in_profile_run and self.dynamic_eplb:
|
||||||
self.model.clear_all_moe_loads()
|
self.model.clear_all_moe_loads()
|
||||||
if not self.in_profile_run and self.dynamic_eplb:
|
if not self.in_profile_run and self.dynamic_eplb:
|
||||||
@@ -3373,7 +3412,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
|||||||
# wrap the model with full graph wrapper if needed.
|
# wrap the model with full graph wrapper if needed.
|
||||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||||
self.update_stream: torch.npu.Stream = torch.npu.Stream()
|
self.update_stream: torch.npu.Stream = torch.npu.Stream()
|
||||||
set_graph_params(self.compilation_config.cudagraph_capture_sizes)
|
|
||||||
self.model = ACLGraphWrapper(self.model,
|
self.model = ACLGraphWrapper(self.model,
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
runtime_mode=CUDAGraphMode.FULL)
|
runtime_mode=CUDAGraphMode.FULL)
|
||||||
@@ -4092,6 +4130,12 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
|||||||
self.aclgraph_batch_sizes = (capture_sizes
|
self.aclgraph_batch_sizes = (capture_sizes
|
||||||
if capture_sizes is not None else [])
|
if capture_sizes is not None else [])
|
||||||
|
|
||||||
|
# NOTE: Since aclgraph_batch_sizes cannot be determined until here,
|
||||||
|
# we set the graph params right before initializing the keys.
|
||||||
|
set_graph_params(self.aclgraph_batch_sizes)
|
||||||
|
if self.speculative_config:
|
||||||
|
set_mtp_graph_params(self.aclgraph_batch_sizes)
|
||||||
|
|
||||||
self.aclgraph_dispatcher.initialize_cudagraph_keys(
|
self.aclgraph_dispatcher.initialize_cudagraph_keys(
|
||||||
self.compilation_config.cudagraph_mode,
|
self.compilation_config.cudagraph_mode,
|
||||||
self.uniform_decode_query_len)
|
self.uniform_decode_query_len)
|
||||||
|
|||||||
Reference in New Issue
Block a user