[FEAT] Support DeepSeek-V3.2 with FULL_DECODE_ONLY mode (#4706)
### What this PR does / why we need it?
The first commit supports `FULL_DECODE_ONLY`:
- Update `AscendSFAMetadataBuilder` to use `num_input_tokens` for
slicing slots and positions, ensuring fixed tensor shapes.
- Implement padding logic for `query_start_loc` in `NPUModelRunner` to
support uniform decode in full graph mode, aligning with GPU runner
behavior.
- Adjust MLA cosine cache allocation to occur independently of graph
mode and switch to using device-resident sequence lengths for attention
metadata.
- Remove redundant slicing of hidden states and outputs in
`AscendSFAImpl` and optimize `sin`/`cos` cache updates.
The second commit takes MTP into account:
- Update `AscendSFAMetadataBuilder` to use `num_input_tokens` for
slicing slots and positions, ensuring fixed tensor shapes.
- Implement padding logic for `query_start_loc` in `NPUModelRunner` to
support uniform decode in full graph mode, aligning with GPU runner
behavior.
- Adjust MLA cosine cache allocation to occur independently of graph
mode and switch to using device-resident sequence lengths for attention
metadata.
- Remove redundant slicing of hidden states and outputs in
`AscendSFAImpl` and optimize `sin`/`cos` cache updates.
The remaining commits are bug fixes.
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
Test cases needed.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -163,16 +163,12 @@ class AscendSFAMetadataBuilder:
|
||||
) -> AscendSFAMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
device = self.device
|
||||
num_input_tokens = common_attn_metadata.num_input_tokens
|
||||
|
||||
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:
|
||||
num_actual_tokens].to(
|
||||
device,
|
||||
non_blocking=True)
|
||||
block_table = common_attn_metadata.block_table_tensor[:num_reqs]
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:num_input_tokens]
|
||||
input_positions = common_attn_metadata.positions[:
|
||||
num_actual_tokens].long(
|
||||
num_input_tokens].long(
|
||||
)
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
query_lens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
@@ -189,15 +185,23 @@ class AscendSFAMetadataBuilder:
|
||||
self.sin_cache = self.sin_cache.to( # type: ignore
|
||||
self.model_config.dtype) # type: ignore
|
||||
|
||||
cum_query_lens = query_start_loc_cpu[1:num_reqs + 1].to(
|
||||
torch.int32).to(device, non_blocking=True)
|
||||
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs].to(
|
||||
torch.int32).to(device, non_blocking=True)
|
||||
cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1]
|
||||
seq_lens = common_attn_metadata.seq_lens[:num_reqs]
|
||||
|
||||
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
sin = self.sin_cache[input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
cos = common_attn_metadata.cos
|
||||
sin = common_attn_metadata.sin
|
||||
|
||||
assert self.cos_cache is not None and self.sin_cache is not None
|
||||
new_cos = self.cos_cache[input_positions][:, None, None]
|
||||
new_sin = self.sin_cache[input_positions][:, None, None]
|
||||
|
||||
if (cos is not None and sin is not None
|
||||
and num_input_tokens <= cos.shape[0]
|
||||
and num_input_tokens <= sin.shape[0]):
|
||||
cos[:num_input_tokens] = new_cos
|
||||
sin[:num_input_tokens] = new_sin
|
||||
else:
|
||||
cos, sin = new_cos, new_sin
|
||||
|
||||
sfa_cp_context = None
|
||||
if self.enable_sfa_cp:
|
||||
@@ -268,8 +272,8 @@ class AscendSFAMetadataBuilder:
|
||||
attn_mask=common_attn_metadata.attn_mask,
|
||||
attn_state=common_attn_metadata.attn_state,
|
||||
block_tables=block_table,
|
||||
sin=sin,
|
||||
cos=cos,
|
||||
sin=sin[:num_input_tokens],
|
||||
cos=cos[:num_input_tokens],
|
||||
sfa_cp_context=sfa_cp_context)
|
||||
|
||||
def build_for_graph_capture(
|
||||
@@ -278,7 +282,10 @@ class AscendSFAMetadataBuilder:
|
||||
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
|
||||
model: Optional[nn.Module] = None,
|
||||
):
|
||||
if attn_state == AscendAttentionState.DecodeOnly:
|
||||
if attn_state in {
|
||||
AscendAttentionState.DecodeOnly,
|
||||
AscendAttentionState.SpecDecoding
|
||||
}:
|
||||
attn_metadata = self.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
@@ -681,29 +688,29 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
self.q_proj.quant_bias = None
|
||||
torch.npu.empty_cache()
|
||||
|
||||
def _sfa_preprocessc_decode(
|
||||
def _sfa_preprocess_decode(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
attn_metadata: M,
|
||||
need_gather_q_kv: bool,
|
||||
num_actual_tokens: int,
|
||||
num_input_tokens: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states.contiguous(), need_gather_q_kv)
|
||||
k_nope, k_pe = kv_cache[0], kv_cache[1]
|
||||
ql_nope = torch.empty(
|
||||
(num_actual_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]),
|
||||
(num_input_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
q_pe = torch.empty(
|
||||
(num_actual_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]),
|
||||
(num_input_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
q_c = torch.empty(
|
||||
(num_actual_tokens, self.q_lora_rank),
|
||||
(num_input_tokens, self.q_lora_rank),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
@@ -721,7 +728,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
self.W_UK_T,
|
||||
k_nope,
|
||||
k_pe,
|
||||
attn_metadata.slot_mapping[:num_actual_tokens].flatten(),
|
||||
attn_metadata.slot_mapping,
|
||||
quant_scale0=self.quant_scale0,
|
||||
quant_offset0=self.quant_offset0,
|
||||
bias0=self.quant_bias_qkv,
|
||||
@@ -761,25 +768,22 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
reach_layer_for_shared_weight_series(self.o_proj)
|
||||
return output.fill_(0)
|
||||
has_prefill = attn_metadata.has_prefill
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
cos = attn_metadata.cos
|
||||
sin = attn_metadata.sin
|
||||
actual_seq_lengths_query = attn_metadata.cum_query_lens
|
||||
actual_seq_lengths_key = attn_metadata.seq_lens
|
||||
hidden_states = hidden_states[:num_actual_tokens]
|
||||
if self.enable_sfa_cp:
|
||||
need_gather_q_kv = False
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
output_padded = output
|
||||
output = output[:num_actual_tokens]
|
||||
|
||||
if self.enable_mlapo and not forward_context.with_prefill:
|
||||
hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocessc_decode(
|
||||
hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
need_gather_q_kv=need_gather_q_kv,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
num_input_tokens=attn_metadata.num_input_tokens,
|
||||
)
|
||||
else:
|
||||
assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
|
||||
@@ -802,7 +806,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
if has_prefill:
|
||||
wait_for_kv_layer_from_connector(layer_name)
|
||||
|
||||
slot_mapping = attn_metadata.slot_mapping[:num_actual_tokens]
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
slot_mapping_cp = None
|
||||
if self.enable_sfa_cp:
|
||||
assert attn_metadata.sfa_cp_context is not None
|
||||
|
||||
Reference in New Issue
Block a user