[Feat]: Add custom lmhead tensor model parallel (#2309)
### What this PR does / why we need it?
This PR introduces LM head tensor model parallelism to reduce memory consumption and improve TPOT. It supports both eager mode and graph mode.
On a DeepSeek R1 W8A8 PD-disaggregated decode instance using pure DP, with `lmhead_tensor_parallel_size = 8`, we measured a 1 ms TPOT improvement and 1.48 GB of NPU memory saved per rank.
Performance data:
![performance data](https://github.com/user-attachments/assets/3c5ef0d3-a7c7-46fd-9797-4de728eb0cb0)
### Does this PR introduce _any_ user-facing change?
This PR introduces one new config in `additional_config`.
| Name | Effect | Required | Type | Constraints |
| :--- | :--- | :--- | :--- | :--- |
| lmhead_tensor_parallel_size | Splits the lm_head matrix along the column (vocab_size) dimension into lmhead_tensor_parallel_size pieces | No | int | Defaults to None; setting any value enables the feature. vocab_size must be divisible by this value. |
Example:
`--additional_config={"lmhead_tensor_parallel_size": 8}`
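
Conceptually, this option makes the LM head column-parallel: each rank in the lmhead TP group holds a `vocab_size / lmhead_tensor_parallel_size` slice of the weight, computes partial logits, and the full logits are assembled with an all-gather. A minimal sketch of that idea (illustrative names such as `lmhead_forward` and `weight_shard`, not the PR's actual classes):

```python
import torch
import torch.distributed as dist

def lmhead_forward(hidden: torch.Tensor, weight_shard: torch.Tensor,
                   lmhead_tp_size: int, group=None) -> torch.Tensor:
    """Column-parallel LM head sketch.

    hidden:       [num_tokens, hidden_size]
    weight_shard: [vocab_size // lmhead_tp_size, hidden_size], this rank's slice
    """
    partial_logits = hidden @ weight_shard.t()  # [num_tokens, vocab // tp]
    if lmhead_tp_size == 1:
        return partial_logits
    # Assemble the full vocabulary dimension from all ranks in the group.
    gathered = [torch.empty_like(partial_logits) for _ in range(lmhead_tp_size)]
    dist.all_gather(gathered, partial_logits, group=group)
    return torch.cat(gathered, dim=-1)          # [num_tokens, vocab_size]
```

Since each rank then stores only its slice of the lm_head weight, per-rank memory drops, which is consistent with the reported 1.48 GB per-rank saving.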
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main: de533ab2a1
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: zhangzihang <zzh_201018@outlook.com>
@@ -90,7 +90,7 @@ from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata
 from vllm_ascend.torchair.torchair_mla import AscendMLATorchairMetadata
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                ProfileExecuteDuration, is_310p,
-                               vllm_version_is)
+                               lmhead_tp_enable, vllm_version_is)
 from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
 from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
@@ -1277,6 +1277,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 num_draft_tokens, cu_num_tokens)
             logits_indices = spec_decode_metadata.logits_indices
 
+        if lmhead_tp_enable():
+            max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs
+            logits_indices = nn.functional.pad(
+                logits_indices,
+                (0, max_num_reqs_across_dp - logits_indices.shape[0]))
+
         return (attn_metadata, positions, num_scheduled_tokens,
                 num_input_tokens, num_tokens_across_dp,
                 maybe_padded_num_tokens, logits_indices, spec_decode_metadata,
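
The padding above exists because, with lmhead TP enabled, all DP ranks must feed the same number of rows into the shared logits computation. A standalone sketch of the pad-then-slice round trip (toy sizes; `vocab_size` here is arbitrary):

```python
import torch
import torch.nn as nn

# Toy stand-ins for the runner's state (illustrative only).
logits_indices = torch.tensor([3, 7, 11])  # one sampling position per real request
max_num_reqs_across_dp = 8                 # padded size shared by all DP ranks
vocab_size = 1000

num_indices = logits_indices.shape[0]
# Pad with zeros so every rank indexes the same number of hidden states.
logits_indices = nn.functional.pad(
    logits_indices, (0, max_num_reqs_across_dp - num_indices))
assert logits_indices.shape[0] == max_num_reqs_across_dp

# ... compute_logits(hidden_states[logits_indices], None) runs here ...
logits = torch.randn(max_num_reqs_across_dp, vocab_size)  # stand-in result
logits = logits[:num_indices]  # slice away the rows produced by padding
```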
@@ -1734,11 +1740,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
         if spec_decode_metadata is None:
+            if lmhead_tp_enable() and logits is not None:
+                logits = logits[:self.input_batch.num_reqs]
             sampler_output = self.sampler(
                 logits=logits,
                 sampling_metadata=sampling_metadata,
             )
         else:
+            if lmhead_tp_enable() and logits is not None:
+                logits = logits[:len(spec_decode_metadata.logits_indices)]
             # When indexing with a tensor (bonus_logits_indices), PyTorch
             # creates a new tensor with separate storage from the original
             # logits tensor. This means any in-place operations on bonus_logits
@@ -2081,6 +2091,18 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     f"Aclgraph runtime mode mismatch at dummy_run. "
                     f"Expected {_cg_mode}, but got {aclgraph_runtime_mode}.")
 
+        need_dummy_logits = (not self.in_profile_run
+                             and lmhead_tp_enable())
+
+        if need_dummy_logits:
+            max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
+            dummy_indices = torch.zeros(max_num_reqs_across_dp,
+                                        dtype=torch.int32)
+
+            def dummy_compute_logits(hidden_states):
+                return self.model.compute_logits(
+                    hidden_states[dummy_indices], None)
+
         with set_ascend_forward_context(
                 attn_metadata,
                 self.vllm_config,
@@ -2097,6 +2119,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     with_prefill, is_torchair_compile, input_ids, positions,
                     attn_metadata, num_tokens, intermediate_tensors,
                     inputs_embeds)
+            if need_dummy_logits:
+                dummy_compute_logits(hidden_states)
             if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
                 assert isinstance(self.drafter, MtpProposer)
                 self.drafter.dummy_run(
@@ -2105,7 +2130,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     skip_attn=True,
                     num_reqs=num_reqs,
                     num_tokens_across_dp=num_tokens_across_dp)
-
+                if need_dummy_logits:
+                    dummy_compute_logits(hidden_states)
             return hidden_states
 
     @contextmanager
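
The `dummy_compute_logits` hook in the two hunks above matters because, with lmhead TP, `compute_logits` contains a collective over the lmhead TP group: a rank that skips it during a dummy run would leave its peers blocked. A generic sketch of the participation pattern (hypothetical `model` and sizes, not the runner's actual API):

```python
import torch

def dummy_run_lmhead(model, hidden_size: int, max_num_reqs: int,
                     lmhead_tp_enabled: bool) -> None:
    """Keep this rank participating in the logits collective during dummy runs."""
    hidden_states = torch.zeros(max_num_reqs, hidden_size)
    if lmhead_tp_enabled:
        # The index values are irrelevant; the call exists solely so this
        # rank joins the all-gather inside compute_logits.
        dummy_indices = torch.zeros(max_num_reqs, dtype=torch.int32)
        model.compute_logits(hidden_states[dummy_indices], None)
```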
@@ -19,7 +19,7 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
 from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
-from vllm_ascend.utils import ProfileExecuteDuration
+from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable
 
 
 class MtpProposer:
@@ -235,8 +235,20 @@ class MtpProposer:
                 previous_hidden_states=self.
                 hidden_states[:num_input_tokens],
                 kv_caches=self.runner.kv_caches[-1:])
+
+        num_indices = last_token_indices.shape[0]
+        if lmhead_tp_enable():
+            if not self.runner.with_prefill:
+                max_num_reqs_across_dp = num_input_tokens
+            else:
+                max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs
+            last_token_indices = nn.functional.pad(
+                last_token_indices, (0, max_num_reqs_across_dp - num_indices))
+
         sample_hidden_states = hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
+        if lmhead_tp_enable() and num_indices < logits.shape[0]:
+            logits = logits[:num_indices]
         draft_token_ids = logits.argmax(dim=-1)
 
         # [batch_size, 1]