[BugFix]add all2all when dp_size > 1 && downgrade npu_dequant_swiglu_quant (#819)
### What this PR does / why we need it?
1. This PR introduces a native `all_to_all` communication operator to fix `allgather` bugs when `dp_size > 1`. It also adds a naive implementation of force load balance during profile runs: experts are forced to receive balanced token loads so that no single rank sees excessive memory consumption.
2. The operator `npu_dequant_swiglu_quant` only supports input `hidden_states` with dtype `torch.int32`. That tensor occupies `global_bs * seq_len * topk * hidden_size` elements, which can become very large as `ep_size` grows. We therefore disable this operator and fall back to the original `swiglu` and `quantize` ops.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By performing offline inference:

---------

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
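A quick back-of-envelope sketch of point 2 (all shapes below are made-up example values, not numbers from the PR) shows how fast the `torch.int32` input of `npu_dequant_swiglu_quant` grows:

```python
# Illustrative only: assumed example shapes, not measured from a real deployment.
global_bs = 128      # assumed global batch size across DP ranks
seq_len = 4096       # assumed sequence length during prefill
topk = 8             # experts activated per token
hidden_size = 2048   # assumed per-expert hidden width

elems = global_bs * seq_len * topk * hidden_size   # 8_589_934_592 elements
bytes_int32 = elems * 4                            # torch.int32 = 4 bytes/elem
print(f"{bytes_int32 / 2**30:.0f} GiB")            # -> 32 GiB for one buffer
```

Since `global_bs` scales with the deployment size, this single intermediate buffer can quickly outgrow device memory as `ep_size` grows, hence the fallback to plain `swiglu` plus `quantize`.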
```diff
@@ -205,50 +205,66 @@ class CustomDeepseekV2MoE(nn.Module):
         )
         CustomDeepseekV2MoE.top_k = config.num_experts_per_tok
 
         vllm_config = get_current_vllm_config()
+        self.dp_size = get_dp_group().world_size
         batch_size = vllm_config.scheduler_config.max_num_seqs
+
         params_dtype = torch.get_default_dtype()
         self.final_hidden_states = torch.zeros(
             [batch_size, config.hidden_size], dtype=params_dtype, device="npu")
         self.tp_group = get_tp_group().device_group
+        self.tp_rank = get_tp_group().rank_in_group
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         attn_metadata = get_forward_context().attn_metadata
+        # when profile runs, force experts to load balanced tokens
+        # to avoid high memory consumption on a single rank.
+        # TODO: need a better flag to indicate whether in profile run or not.
         if attn_metadata is None:
             # for profile run
             is_prefill = True
+            enable_force_load_balance = True
         else:
             is_prefill = attn_metadata.num_prefills > 0
+            enable_force_load_balance = False
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
-        if (self.tp_size > 1 and VLLM_ENABLE_MC2 and not is_prefill):
-            chunks = torch.chunk(hidden_states,
-                                 get_tp_group().world_size,
-                                 dim=0)
-            hidden_states = chunks[get_tp_group().rank_in_group]
-
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            is_prefill=is_prefill,
-            top_k=CustomDeepseekV2MoE.top_k) * self.routed_scaling_factor
-
-        if self.tp_size > 1:
-            if VLLM_ENABLE_MC2 and not is_prefill:
-                dist.all_gather_into_tensor(self.final_hidden_states,
-                                            final_hidden_states, self.tp_group)
-                final_hidden_states = self.final_hidden_states
-            else:
-                final_hidden_states = tensor_model_parallel_all_reduce(
-                    final_hidden_states)
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
 
+        if self.tp_size > 1:
+            # pass
+            num_tokens, hidden_size = hidden_states.shape
+            if num_tokens < self.tp_size:
+                target_size = self.tp_size
+                new_hidden_states = torch.empty([target_size, hidden_size],
+                                                dtype=hidden_states.dtype,
+                                                device=hidden_states.device)
+                new_hidden_states[:num_tokens] = hidden_states
+                hidden_states = new_hidden_states
+            chunk_hidden_states = torch.tensor_split(hidden_states,
+                                                     self.tp_size,
+                                                     dim=0)
+            local_hidden_states = chunk_hidden_states[self.tp_rank]
+        else:
+            local_hidden_states = hidden_states
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(local_hidden_states)
+
+        router_hidden_states = self.experts(
+            hidden_states=local_hidden_states,
+            router_logits=router_logits,
+            is_prefill=is_prefill,
+            top_k=CustomDeepseekV2MoE.top_k,
+            enable_force_load_balance=enable_force_load_balance,
+        ) * self.routed_scaling_factor
+
+        if self.tp_size > 1:
+            dist.all_gather(list(chunk_hidden_states), router_hidden_states,
+                            self.tp_group)
+            final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
+            if num_tokens < self.tp_size:
+                final_hidden_states = final_hidden_states[:num_tokens]
+        else:
+            final_hidden_states = router_hidden_states
+
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
 
         return final_hidden_states.view(num_tokens, hidden_dim)
```
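For readers unfamiliar with the pattern in the new `forward`, here is a minimal standalone sketch of the pad → `tensor_split` → `all_gather` round-trip. It uses a CPU `gloo` stand-in and mocks the per-rank expert computation as a multiply; the PR itself runs on NPU with the TP device group:

```python
import torch
import torch.distributed as dist

# Minimal sketch of the scatter/gather round-trip from the new forward.
# Run with: torchrun --nproc_per_node=2 roundtrip_sketch.py
# (gloo/CPU stand-in; the PR runs this on NPU process groups.)

def moe_roundtrip(hidden_states: torch.Tensor) -> torch.Tensor:
    rank, world = dist.get_rank(), dist.get_world_size()
    num_tokens, hidden_size = hidden_states.shape
    # Pad so every rank owns at least one row (mirrors the torch.empty pad).
    if num_tokens < world:
        padded = torch.empty([world, hidden_size], dtype=hidden_states.dtype)
        padded[:num_tokens] = hidden_states
        hidden_states = padded
    # tensor_split tolerates sizes that don't divide evenly, unlike chunk.
    chunks = list(torch.tensor_split(hidden_states, world, dim=0))
    local = chunks[rank] * 2.0  # stand-in for the local expert computation
    # Gather every rank's processed chunk back into the preallocated list.
    dist.all_gather(chunks, local)
    return torch.cat(chunks, dim=0)[:num_tokens]  # drop any padding rows

if __name__ == "__main__":
    dist.init_process_group("gloo")
    x = torch.ones(1, 4)  # fewer tokens than ranks -> exercises the padding
    assert torch.allclose(moe_roundtrip(x), x * 2.0)
    dist.destroy_process_group()
```

Note that `dist.all_gather` writes into the tensors of `chunks` in place, which is why the same list serves as the output buffer before the final `torch.cat`.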
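The hunk above only threads the new `enable_force_load_balance` flag into `self.experts`; the balancing itself lives in the experts implementation and is not shown here. As a purely hypothetical illustration of what forcing balanced loads during a profile run can mean, a round-robin override of the router's expert ids looks like this (`force_balanced_topk_ids` is an invented name, not the PR's code):

```python
import torch

# Hypothetical sketch: override router choices with a round-robin
# assignment so every expert receives the same number of token slots.
def force_balanced_topk_ids(num_tokens: int, top_k: int,
                            num_experts: int) -> torch.Tensor:
    ids = torch.arange(num_tokens * top_k) % num_experts
    return ids.view(num_tokens, top_k)

topk_ids = force_balanced_topk_ids(num_tokens=6, top_k=2, num_experts=4)
# Each expert gets exactly 12 / 4 = 3 slots -> tensor([3, 3, 3, 3])
print(torch.bincount(topk_ids.flatten(), minlength=4))
```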