[Feat] Make full graph mode compatible with MTP (#3276)
### What this PR does / why we need it?
Enable Full Graph mode (FULL_DECODE_ONLY graph capture) to run together with MTP speculative decoding.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
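For context, a minimal sketch of the relationship this change relies on. The helper below is illustrative only, not part of the diff; `num_speculative_tokens` follows the vLLM speculative config.

```python
# With MTP (deepseek_mtp) speculative decoding, every decode step verifies
# 1 + num_speculative_tokens tokens per request, so a "uniform decode"
# graph capture must budget that many query tokens per request instead of one.
def decode_token_per_req(num_speculative_tokens: int) -> int:
    # one target token plus the draft tokens proposed by the MTP head
    return 1 + num_speculative_tokens

assert decode_token_per_req(0) == 1   # plain decode, one token per request
assert decode_token_per_req(1) == 2   # MTP with a single draft token
```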
@@ -345,6 +345,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 self.speculative_config.method, self.vllm_config,
                 self.device, self)
             self.rejection_sampler = AscendRejectionSampler()
+            self.actual_seq_lengths_q = list(
+                range(self.decode_token_per_req, self.max_num_tokens + 1,
+                      self.decode_token_per_req))
 
         # Persistent batch.
         self.input_ids = torch.zeros(self.max_num_tokens,
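The added lines pre-compute the query-length grid used for full-graph capture. A minimal sketch of what that list contains, assuming toy values `decode_token_per_req = 2` and `max_num_tokens = 8` (not taken from the diff):

```python
decode_token_per_req = 2   # e.g. MTP with one draft token (assumed)
max_num_tokens = 8         # toy capture budget

# Mirrors the added lines: cumulative query lengths at a fixed stride,
# one entry per request slot in a uniform decode batch.
actual_seq_lengths_q = list(
    range(decode_token_per_req, max_num_tokens + 1, decode_token_per_req))

print(actual_seq_lengths_q)  # [2, 4, 6, 8]
```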
@@ -366,13 +369,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         if self.vllm_config.model_config.use_mla and \
             self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
             rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
-            self.cos = torch.ones(self.max_num_reqs,
+            self.cos = torch.ones(self.max_num_reqs *
+                                  self.decode_token_per_req,
                                   1,
                                   1,
                                   rope_dim,
                                   dtype=self.dtype,
                                   device=self.device)
-            self.sin = torch.zeros(self.max_num_reqs,
+            self.sin = torch.zeros(self.max_num_reqs *
+                                   self.decode_token_per_req,
                                    1,
                                    1,
                                    rope_dim,
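The persistent `cos`/`sin` rotary-embedding buffers used by the MLA decode graph are now sized per captured token rather than per request, since each request contributes `decode_token_per_req` positions. A small sketch of the resulting shape, with made-up numbers:

```python
import torch

max_num_reqs = 4            # toy value
decode_token_per_req = 2    # 1 target + 1 MTP draft token (assumed)
rope_dim = 64               # qk_rope_head_dim for the toy case

# One rotary-embedding row per captured decode token, not per request.
cos = torch.ones(max_num_reqs * decode_token_per_req, 1, 1, rope_dim)
sin = torch.zeros(max_num_reqs * decode_token_per_req, 1, 1, rope_dim)
print(cos.shape)  # torch.Size([8, 1, 1, 64])
```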
@@ -1554,7 +1559,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         if self.vllm_config.model_config.use_mla:
             # FIXME: Try using `auto_dispatch_capture=True`
             update_mla_attn_params(self.update_stream, forward_context,
-                                   maybe_padded_num_tokens)
+                                   maybe_padded_num_tokens,
+                                   self.speculative_config)
         else:
             update_attn_params(self.update_stream, forward_context,
                                maybe_padded_num_tokens)
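`update_mla_attn_params` now also receives the speculative config so the captured MLA graph can be refreshed with the spec-decode token layout. The real body lives in the attention utilities; the helper below is a hypothetical sketch of how the extra argument might be consumed, not the actual implementation:

```python
def tokens_per_request(speculative_config) -> int:
    # Hypothetical helper, for illustration only.
    if speculative_config is not None and \
            speculative_config.method == "deepseek_mtp":
        # the target token plus the MTP draft tokens being verified
        return 1 + speculative_config.num_speculative_tokens
    return 1

print(tokens_per_request(None))  # 1 -> plain single-token decode layout
```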
@@ -2255,7 +2261,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             block_table_tensor = self.input_batch.block_table[
                 kv_cache_group_id].get_device_tensor()
             common_attn_metadata = AscendCommonAttentionMetadata(
-                query_start_loc=self.query_start_loc[:num_reqs + 1],
+                query_start_loc=torch.tensor(
+                    [0] + self.actual_seq_lengths_q[:num_reqs],
+                    device=self.device,
+                    dtype=torch.int32),
                 query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
                                                              1],
                 seq_lens_cpu=self.seq_lens_cpu,
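During graph capture the query start locations are rebuilt from the pre-computed grid rather than from the live batch, so offsets advance by `decode_token_per_req` per request. Continuing the toy numbers from the earlier sketch:

```python
import torch

actual_seq_lengths_q = [2, 4, 6, 8]   # toy grid from the earlier sketch
num_reqs = 3

# Mirrors the replaced kwarg: request i owns tokens
# query_start_loc[i] : query_start_loc[i + 1] in the captured batch.
query_start_loc = torch.tensor([0] + actual_seq_lengths_q[:num_reqs],
                               dtype=torch.int32)
print(query_start_loc)  # tensor([0, 2, 4, 6], dtype=torch.int32)
```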
@@ -2275,12 +2284,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 cos=self.cos,
                 sin=self.sin,
             )
+            attn_state = AscendAttentionState.DecodeOnly
+            if self.speculative_config and \
+                self.speculative_config.method == "deepseek_mtp":
+                attn_state = AscendAttentionState.SpecDecoding
 
             for attn_group in self.attn_groups[kv_cache_group_id]:
                 builder = attn_group.get_metadata_builder()
                 attn_metadata_i = builder.build_for_graph_capture(
-                    common_attn_metadata, AscendAttentionState.DecodeOnly,
-                    self.get_model())
+                    common_attn_metadata, attn_state, self.get_model())
                 for layer_name in kv_cache_group_spec.layer_names:
                     attn_metadata[layer_name] = attn_metadata_i
 
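The capture path now labels the dummy batch as spec-decoding when MTP is active, so the metadata builder lays out multi-token queries instead of single-token decode. A self-contained sketch of why the state matters, using a stand-in enum rather than the real `AscendAttentionState`:

```python
from enum import Enum, auto

class AttentionState(Enum):        # stand-in for AscendAttentionState
    DecodeOnly = auto()
    SpecDecoding = auto()

def expected_query_len(state: AttentionState, decode_token_per_req: int) -> int:
    # SpecDecoding batches carry the target token plus MTP draft tokens per
    # request; DecodeOnly batches carry exactly one token per request.
    return decode_token_per_req if state is AttentionState.SpecDecoding else 1

assert expected_query_len(AttentionState.DecodeOnly, 2) == 1
assert expected_query_len(AttentionState.SpecDecoding, 2) == 2
```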
@@ -2301,7 +2313,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         if self.vllm_config.model_config.use_mla:
             # FIXME: Try using `auto_dispatch_capture=True`
             update_mla_attn_params(self.update_stream, forward_context,
-                                   positions.shape[0])
+                                   positions.shape[0],
+                                   self.speculative_config)
         else:
             update_attn_params(self.update_stream, forward_context,
                                positions.shape[0])
@@ -3388,23 +3401,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                 CUDAGraphMode.FULL_DECODE_ONLY)
             logger.warning(msg)
 
-        # check that if spec-decode + decode full-graphs is supported
-        if (aclgraph_mode.decode_mode() == CUDAGraphMode.FULL
-                and self.uniform_decode_query_len > 1 and min_ag_support.value
-                < AttentionCGSupport.UNIFORM_BATCH.value):
-            msg = (f"CUDAGraphMode.{aclgraph_mode.name} is not supported"
-                   f" with spec-decode for attention backend "
-                   f"{min_ag_builder_name} (support: {min_ag_support})")
-            if splitting_ops_contain_attention:
-                msg += "; setting cudagraph_mode=PIECEWISE"
-                aclgraph_mode = self.compilation_config.cudagraph_mode = \
-                    CUDAGraphMode.PIECEWISE
-            else:
-                msg += "; setting cudagraph_mode=NONE"
-                aclgraph_mode = self.compilation_config.cudagraph_mode = \
-                    CUDAGraphMode.NONE
-            logger.warning(msg)
-
         # double check that we can support full graph if they are requested
         # even after automatic downgrades
         if aclgraph_mode.has_full_cudagraphs() \
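The deleted block used to downgrade FULL decode graphs to PIECEWISE or NONE whenever spec-decode was active and the attention backend did not advertise uniform-batch support; with MTP now handled by the uniform-batch capture path above, that guard is dropped. A sketch of the ordering the old check relied on, using a stand-in enum rather than the vLLM `AttentionCGSupport` definition:

```python
from enum import IntEnum

class CGSupport(IntEnum):          # stand-in for AttentionCGSupport
    NEVER = 0
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    UNIFORM_BATCH = 2
    ALWAYS = 3

# The removed guard compared the backend's support level against
# UNIFORM_BATCH and, when it fell short, forced PIECEWISE or NONE.
backend_support = CGSupport.UNIFORM_BATCH
spec_decode_query_len = 2          # > 1 once MTP draft tokens are verified
needs_downgrade = (spec_decode_query_len > 1
                   and backend_support < CGSupport.UNIFORM_BATCH)
print(needs_downgrade)  # False -> full decode graphs stay enabled
```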