[FEAT] Support DeepSeek-V3.2 with FULL_DECODE_ONLY mode (#4706)

### What this PR does / why we need it?
The first commit supports `FULL_DECODE_ONLY`:
- Update `AscendSFAMetadataBuilder` to use `num_input_tokens` for
slicing slots and positions, ensuring fixed tensor shapes.
- Implement padding logic for `query_start_loc` in `NPUModelRunner` to
support uniform decode in full graph mode, aligning with GPU runner
behavior (see the sketch after this list).
- Adjust MLA cosine cache allocation to occur independently of graph
mode and switch to using device-resident sequence lengths for attention
metadata.
- Remove redundant slicing of hidden states and outputs in
`AscendSFAImpl` and optimize `sin`/`cos` cache updates.
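
To make the `query_start_loc` padding concrete, here is a minimal sketch of the idea; the helper name `pad_query_start_loc` and the `max_num_reqs` bound are illustrative and are not the actual `NPUModelRunner` code:

```python
import torch


def pad_query_start_loc(query_start_loc: torch.Tensor, num_reqs: int,
                        max_num_reqs: int) -> torch.Tensor:
    # Illustrative only: a captured full-decode graph needs metadata tensors
    # with a fixed shape, so the cumulative query offsets are padded out to
    # the maximum request count. Padded entries repeat the last real offset,
    # i.e. the padded "requests" contribute zero tokens.
    padded = torch.empty(max_num_reqs + 1,
                         dtype=query_start_loc.dtype,
                         device=query_start_loc.device)
    padded[:num_reqs + 1] = query_start_loc[:num_reqs + 1]
    padded[num_reqs + 1:] = query_start_loc[num_reqs]
    return padded
```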

The second commit takes MTP into account, applying the same changes to the speculative-decoding path.

The remaining commits are bug fixes.
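
For the `sin`/`cos` cache handling described above (and visible in the diff below), the pattern is roughly the following; `update_rope_buffers` is a hypothetical name used only for illustration:

```python
from typing import Optional, Tuple

import torch


def update_rope_buffers(
        cos_buf: Optional[torch.Tensor], sin_buf: Optional[torch.Tensor],
        cos_cache: torch.Tensor, sin_cache: torch.Tensor,
        positions: torch.Tensor,
        num_input_tokens: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # Illustrative only: gather the per-token cos/sin values, then write them
    # into the pre-allocated, graph-resident buffers in place when those
    # buffers are large enough, so a captured graph keeps reading from the
    # same tensor addresses across replays.
    assert positions.shape[0] == num_input_tokens
    new_cos = cos_cache[positions][:, None, None]
    new_sin = sin_cache[positions][:, None, None]
    if (cos_buf is not None and sin_buf is not None
            and num_input_tokens <= cos_buf.shape[0]
            and num_input_tokens <= sin_buf.shape[0]):
        cos_buf[:num_input_tokens] = new_cos
        sin_buf[:num_input_tokens] = new_sin
        return cos_buf, sin_buf
    # Fall back to fresh tensors when no suitable buffers were provided.
    return new_cos, new_sin
```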

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
Test cases needed.


- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Author: Yizhou
Date: 2025-12-10 20:11:09 +08:00
Committed by: GitHub
Parent: 0d8c0f1a24
Commit: 5b179c53f1
6 changed files with 120 additions and 78 deletions


@@ -163,16 +163,12 @@ class AscendSFAMetadataBuilder:
     ) -> AscendSFAMetadata:
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
-        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
-        device = self.device
+        num_input_tokens = common_attn_metadata.num_input_tokens
-        block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
-        slot_mapping = common_attn_metadata.slot_mapping[:
-                                                         num_actual_tokens].to(
-                                                             device,
-                                                             non_blocking=True)
+        block_table = common_attn_metadata.block_table_tensor[:num_reqs]
+        slot_mapping = common_attn_metadata.slot_mapping[:num_input_tokens]
         input_positions = common_attn_metadata.positions[:
-                                                          num_actual_tokens].long(
+                                                          num_input_tokens].long(
                                                          )
         query_start_loc = common_attn_metadata.query_start_loc
         query_lens = query_start_loc[1:] - query_start_loc[:-1]
@@ -189,15 +185,23 @@ class AscendSFAMetadataBuilder:
             self.sin_cache = self.sin_cache.to(  # type: ignore
                 self.model_config.dtype)  # type: ignore
-        cum_query_lens = query_start_loc_cpu[1:num_reqs + 1].to(
-            torch.int32).to(device, non_blocking=True)
-        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs].to(
-            torch.int32).to(device, non_blocking=True)
+        cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1]
+        seq_lens = common_attn_metadata.seq_lens[:num_reqs]
-        cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
-            1).unsqueeze(2)
-        sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
-            1).unsqueeze(2)
+        cos = common_attn_metadata.cos
+        sin = common_attn_metadata.sin
+        assert self.cos_cache is not None and self.sin_cache is not None
+        new_cos = self.cos_cache[input_positions][:, None, None]
+        new_sin = self.sin_cache[input_positions][:, None, None]
+        if (cos is not None and sin is not None
+                and num_input_tokens <= cos.shape[0]
+                and num_input_tokens <= sin.shape[0]):
+            cos[:num_input_tokens] = new_cos
+            sin[:num_input_tokens] = new_sin
+        else:
+            cos, sin = new_cos, new_sin
         sfa_cp_context = None
         if self.enable_sfa_cp:
@@ -268,8 +272,8 @@ class AscendSFAMetadataBuilder:
             attn_mask=common_attn_metadata.attn_mask,
             attn_state=common_attn_metadata.attn_state,
             block_tables=block_table,
-            sin=sin,
-            cos=cos,
+            sin=sin[:num_input_tokens],
+            cos=cos[:num_input_tokens],
             sfa_cp_context=sfa_cp_context)

     def build_for_graph_capture(
@@ -278,7 +282,10 @@
         attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
         model: Optional[nn.Module] = None,
     ):
-        if attn_state == AscendAttentionState.DecodeOnly:
+        if attn_state in {
+                AscendAttentionState.DecodeOnly,
+                AscendAttentionState.SpecDecoding
+        }:
             attn_metadata = self.build(
                 common_prefix_len=0,
                 common_attn_metadata=common_attn_metadata,
@@ -681,29 +688,29 @@ class AscendSFAImpl(MLAAttentionImpl):
             self.q_proj.quant_bias = None
         torch.npu.empty_cache()

-    def _sfa_preprocessc_decode(
+    def _sfa_preprocess_decode(
         self,
         hidden_states: torch.Tensor,
         kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
         attn_metadata: M,
         need_gather_q_kv: bool,
-        num_actual_tokens: int,
+        num_input_tokens: int,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
             hidden_states.contiguous(), need_gather_q_kv)
         k_nope, k_pe = kv_cache[0], kv_cache[1]
         ql_nope = torch.empty(
-            (num_actual_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]),
+            (num_input_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]),
             dtype=hidden_states.dtype,
             device=hidden_states.device,
         )
         q_pe = torch.empty(
-            (num_actual_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]),
+            (num_input_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]),
             dtype=hidden_states.dtype,
             device=hidden_states.device,
         )
         q_c = torch.empty(
-            (num_actual_tokens, self.q_lora_rank),
+            (num_input_tokens, self.q_lora_rank),
             dtype=hidden_states.dtype,
             device=hidden_states.device,
         )
@@ -721,7 +728,7 @@ class AscendSFAImpl(MLAAttentionImpl):
             self.W_UK_T,
             k_nope,
             k_pe,
-            attn_metadata.slot_mapping[:num_actual_tokens].flatten(),
+            attn_metadata.slot_mapping,
             quant_scale0=self.quant_scale0,
             quant_offset0=self.quant_offset0,
             bias0=self.quant_bias_qkv,
@@ -761,25 +768,22 @@
             reach_layer_for_shared_weight_series(self.o_proj)
             return output.fill_(0)
         has_prefill = attn_metadata.has_prefill
-        num_actual_tokens = attn_metadata.num_actual_tokens
         cos = attn_metadata.cos
         sin = attn_metadata.sin
         actual_seq_lengths_query = attn_metadata.cum_query_lens
         actual_seq_lengths_key = attn_metadata.seq_lens
-        hidden_states = hidden_states[:num_actual_tokens]
         if self.enable_sfa_cp:
             need_gather_q_kv = False
-        # Inputs and outputs may be padded for CUDA graphs
-        output_padded = output
-        output = output[:num_actual_tokens]
         if self.enable_mlapo and not forward_context.with_prefill:
-            hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocessc_decode(
+            hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
                 hidden_states=hidden_states,
                 kv_cache=kv_cache,
                 attn_metadata=attn_metadata,
                 need_gather_q_kv=need_gather_q_kv,
-                num_actual_tokens=num_actual_tokens,
+                num_input_tokens=attn_metadata.num_input_tokens,
             )
         else:
             assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
@@ -802,7 +806,7 @@
         if has_prefill:
             wait_for_kv_layer_from_connector(layer_name)
-        slot_mapping = attn_metadata.slot_mapping[:num_actual_tokens]
+        slot_mapping = attn_metadata.slot_mapping
         slot_mapping_cp = None
         if self.enable_sfa_cp:
             assert attn_metadata.sfa_cp_context is not None