[Bugfix] Resolve MTP > 1 issue when lm head tp > 1 (#4254)
### What this PR does / why we need it?
Previously, the dummy run executed compute_logits only once, regardless
of num_speculative_tokens. This caused execute_model to hang on
compute_logits when lm head tensor parallelism exceeded 1. The fix
ensures compute_logits executes correctly during dummy run, matching
num_speculative_tokens.
I set the `non_blocking` argument to False when moving
`exceeds_max_model_len` to the CPU. From what I understand, using
`non_blocking=True` and immediately accessing the tensor on the CPU can
cause accuracy problems. However, this issue doesn't happen when
transferring data to a device. ref:
https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/18
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
---------
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
@@ -123,7 +123,8 @@ class EagleProposer(Proposer):
|
|||||||
num_reqs: int = 0,
|
num_reqs: int = 0,
|
||||||
num_tokens_across_dp: Optional[torch.Tensor] = None,
|
num_tokens_across_dp: Optional[torch.Tensor] = None,
|
||||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||||
batch_descriptor=None):
|
batch_descriptor=None,
|
||||||
|
dummy_compute_logits=lambda hidden_states: None):
|
||||||
moe_comm_type = self.runner._select_moe_comm_method(num_tokens)
|
moe_comm_type = self.runner._select_moe_comm_method(num_tokens)
|
||||||
with set_ascend_forward_context(None,
|
with set_ascend_forward_context(None,
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
@@ -134,6 +135,7 @@ class EagleProposer(Proposer):
|
|||||||
positions=self.positions[:num_tokens],
|
positions=self.positions[:num_tokens],
|
||||||
hidden_states=self.hidden_states[:num_tokens],
|
hidden_states=self.hidden_states[:num_tokens],
|
||||||
)
|
)
|
||||||
|
dummy_compute_logits(self.hidden_states)
|
||||||
|
|
||||||
def generate_token_ids(self,
|
def generate_token_ids(self,
|
||||||
valid_sampled_token_ids: list[np.ndarray],
|
valid_sampled_token_ids: list[np.ndarray],
|
||||||
|
|||||||
@@ -213,7 +213,8 @@ class MtpProposer(Proposer):
|
|||||||
num_reqs: int = 0,
|
num_reqs: int = 0,
|
||||||
num_tokens_across_dp=None,
|
num_tokens_across_dp=None,
|
||||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||||
batch_descriptor=None) -> None:
|
batch_descriptor=None,
|
||||||
|
dummy_compute_logits=lambda hidden_states: None) -> None:
|
||||||
|
|
||||||
(
|
(
|
||||||
num_tokens,
|
num_tokens,
|
||||||
@@ -296,6 +297,7 @@ class MtpProposer(Proposer):
|
|||||||
self.update_stream, forward_context,
|
self.update_stream, forward_context,
|
||||||
positions.shape[0],
|
positions.shape[0],
|
||||||
self.vllm_config.speculative_config)
|
self.vllm_config.speculative_config)
|
||||||
|
dummy_compute_logits(previous_hidden_states)
|
||||||
if with_prefill:
|
if with_prefill:
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -756,6 +758,7 @@ class MtpProposer(Proposer):
|
|||||||
logits = self.model.compute_logits(sample_hidden_states)
|
logits = self.model.compute_logits(sample_hidden_states)
|
||||||
if lmhead_tp_enable() and num_indices < logits.shape[0]:
|
if lmhead_tp_enable() and num_indices < logits.shape[0]:
|
||||||
logits = logits[:num_indices]
|
logits = logits[:num_indices]
|
||||||
|
last_token_indices = last_token_indices[:num_indices]
|
||||||
draft_token_ids = logits.argmax(dim=-1)
|
draft_token_ids = logits.argmax(dim=-1)
|
||||||
|
|
||||||
if self.num_speculative_tokens == 1:
|
if self.num_speculative_tokens == 1:
|
||||||
@@ -821,7 +824,7 @@ class MtpProposer(Proposer):
|
|||||||
# For the requests that exceed the max model length, we set the
|
# For the requests that exceed the max model length, we set the
|
||||||
# sequence length to 1 to minimize their overheads in attention.
|
# sequence length to 1 to minimize their overheads in attention.
|
||||||
exceeds_max_model_len_cpu = exceeds_max_model_len.to(
|
exceeds_max_model_len_cpu = exceeds_max_model_len.to(
|
||||||
attn_metadata_i.seq_lens.device, non_blocking=True)
|
attn_metadata_i.seq_lens.device, non_blocking=False)
|
||||||
attn_metadata_i.seq_lens[:batch_size].masked_fill_(
|
attn_metadata_i.seq_lens[:batch_size].masked_fill_(
|
||||||
exceeds_max_model_len_cpu, 1)
|
exceeds_max_model_len_cpu, 1)
|
||||||
# Mask out the slot mappings that exceed the max model length.
|
# Mask out the slot mappings that exceed the max model length.
|
||||||
|
|||||||
@@ -27,7 +27,8 @@ class NgramProposer(VllmNgramProposer, Proposer):
|
|||||||
num_reqs=None,
|
num_reqs=None,
|
||||||
num_tokens_across_dp=None,
|
num_tokens_across_dp=None,
|
||||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||||
batch_descriptor=None):
|
batch_descriptor=None,
|
||||||
|
dummy_compute_logits=lambda hidden_states: None):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def generate_token_ids(self,
|
def generate_token_ids(self,
|
||||||
|
|||||||
@@ -81,7 +81,8 @@ class TorchairMtpProposer(MtpProposer):
|
|||||||
num_reqs: int = 0,
|
num_reqs: int = 0,
|
||||||
num_tokens_across_dp=None,
|
num_tokens_across_dp=None,
|
||||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||||
batch_descriptor=None) -> None:
|
batch_descriptor=None,
|
||||||
|
dummy_compute_logits=lambda hidden_states: None) -> None:
|
||||||
moe_comm_type = self.runner._select_moe_comm_method(num_tokens)
|
moe_comm_type = self.runner._select_moe_comm_method(num_tokens)
|
||||||
|
|
||||||
if not with_prefill:
|
if not with_prefill:
|
||||||
@@ -143,6 +144,7 @@ class TorchairMtpProposer(MtpProposer):
|
|||||||
self.model(input_ids=input_ids,
|
self.model(input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
hidden_states=previous_hidden_states)
|
hidden_states=previous_hidden_states)
|
||||||
|
dummy_compute_logits(previous_hidden_states)
|
||||||
if with_prefill:
|
if with_prefill:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
@@ -3003,14 +3003,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
need_dummy_logits = (not self.in_profile_run
|
need_dummy_logits = (not self.in_profile_run
|
||||||
and lmhead_tp_enable())
|
and lmhead_tp_enable())
|
||||||
|
max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
|
||||||
|
dummy_indices = torch.zeros(max_num_reqs_across_dp,
|
||||||
|
dtype=torch.int32)
|
||||||
|
|
||||||
if need_dummy_logits:
|
def dummy_compute_logits(hidden_states):
|
||||||
max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
|
if not need_dummy_logits:
|
||||||
dummy_indices = torch.zeros(max_num_reqs_across_dp,
|
return None
|
||||||
dtype=torch.int32)
|
return self.model.compute_logits(hidden_states[dummy_indices])
|
||||||
|
|
||||||
def dummy_compute_logits(hidden_states):
|
def dummy_drafter_compute_logits(hidden_states):
|
||||||
return self.model.compute_logits(
|
if not need_dummy_logits or self.drafter is None:
|
||||||
|
return
|
||||||
|
if hasattr(self.drafter, "model") and hasattr(
|
||||||
|
self.drafter.model, "compute_logits"):
|
||||||
|
return self.drafter.model.compute_logits(
|
||||||
hidden_states[dummy_indices])
|
hidden_states[dummy_indices])
|
||||||
|
|
||||||
with set_ascend_forward_context(
|
with set_ascend_forward_context(
|
||||||
@@ -3032,8 +3039,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
with_prefill, is_torchair_compile, input_ids, positions,
|
with_prefill, is_torchair_compile, input_ids, positions,
|
||||||
attn_metadata, num_tokens, intermediate_tensors,
|
attn_metadata, num_tokens, intermediate_tensors,
|
||||||
inputs_embeds)
|
inputs_embeds)
|
||||||
if need_dummy_logits:
|
dummy_compute_logits(hidden_states)
|
||||||
dummy_compute_logits(hidden_states)
|
|
||||||
|
|
||||||
if self.drafter:
|
if self.drafter:
|
||||||
self.drafter.dummy_run(
|
self.drafter.dummy_run(
|
||||||
@@ -3042,10 +3048,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_reqs=num_reqs,
|
num_reqs=num_reqs,
|
||||||
num_tokens_across_dp=num_tokens_across_dp,
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||||
batch_descriptor=batch_descriptor)
|
batch_descriptor=batch_descriptor,
|
||||||
if need_dummy_logits:
|
dummy_compute_logits=dummy_drafter_compute_logits)
|
||||||
self.drafter.model.compute_logits(
|
|
||||||
hidden_states[dummy_indices])
|
|
||||||
if self.in_profile_run and self.dynamic_eplb:
|
if self.in_profile_run and self.dynamic_eplb:
|
||||||
self.model.clear_all_moe_loads()
|
self.model.clear_all_moe_loads()
|
||||||
if not self.in_profile_run and self.dynamic_eplb:
|
if not self.in_profile_run and self.dynamic_eplb:
|
||||||
|
|||||||
Reference in New Issue
Block a user