feat: support data parallel for deepseek (#1012)
### What this PR does / why we need it?
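Add data-parallel (DP) support for DeepSeek models. Before each forward pass the DP ranks run a MAX all-reduce over their batch size and prefill flag, so that every rank pads to the same batch size and knows whether any rank is running prefill (`with_prefill_across_dp`).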
### Does this PR introduce _any_ user-facing change?
Yes. DeepSeek models can now be served with data parallelism enabled (e.g. `-dp=2`).
### How was this patch tested?
```
export VLLM_ENABLE_MC2=0
export VLLM_USE_V1=1
export TASK_QUEUE_ENABLE=1
source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh
nohup python -m vllm.entrypoints.openai.api_server \
--model=/path/to/DeepSeek-R1-W8A8 \
--quantization ascend \
--served-model-name auto \
--trust-remote-code \
--distributed-executor-backend=mp \
--port 8006 \
-tp=8 \
-dp=2 \
--max-num-seqs 24 \
--max-model-len 4096 \
--max-num-batched-tokens 4096 \
--block-size 128 \
-O 0 \
--no-enable-prefix-caching \
--additional-config '{"torchair_graph_batch_sizes":[24],"expert_tensor_parallel_size":16,"ascend_scheduler_config":{},"enable_graph_mode":true}' \
--gpu-memory-utilization 0.95 &> run.log &
disown
```
Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
@@ -117,6 +117,8 @@ class AscendMLAMetadata:
|
|||||||
# For logging.
|
# For logging.
|
||||||
num_input_tokens: int = 0 # Number of tokens including padding.
|
num_input_tokens: int = 0 # Number of tokens including padding.
|
||||||
|
|
||||||
|
with_prefill_across_dp: bool = False
|
||||||
|
|
||||||
# The dimension of the attention heads
|
# The dimension of the attention heads
|
||||||
head_dim: Optional[int] = None
|
head_dim: Optional[int] = None
|
||||||
attn_mask: torch.Tensor = None
|
attn_mask: torch.Tensor = None
|
||||||
@@ -260,6 +262,10 @@ class AscendMLAMetadataBuilder:
|
|||||||
PAD_SLOT_ID,
|
PAD_SLOT_ID,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=device)
|
device=device)
|
||||||
|
query_start_loc = torch.full((num_reqs, ),
|
||||||
|
-1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device)
|
||||||
decode_metadata = AscendMLADecodeMetadata(
|
decode_metadata = AscendMLADecodeMetadata(
|
||||||
input_positions=input_positions,
|
input_positions=input_positions,
|
||||||
block_table=block_table,
|
block_table=block_table,
|
||||||
@@ -278,15 +284,21 @@ class AscendMLAMetadataBuilder:
|
|||||||
attn_state=AscendAttentionState.DecodeOnly,
|
attn_state=AscendAttentionState.DecodeOnly,
|
||||||
prefill=None,
|
prefill=None,
|
||||||
decode=decode_metadata,
|
decode=decode_metadata,
|
||||||
|
query_start_loc=query_start_loc,
|
||||||
|
seq_lens=seq_lens,
|
||||||
|
block_tables=block_table,
|
||||||
)
|
)
|
||||||
|
|
||||||
def build(self,
|
def build(
|
||||||
num_reqs: int,
|
self,
|
||||||
num_actual_tokens: int,
|
num_reqs: int,
|
||||||
max_query_len: int,
|
num_actual_tokens: int,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
max_query_len: int,
|
||||||
common_prefix_len: Optional[int] = None,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
graph_pad_size: int = -1) -> AscendMLAMetadata:
|
common_prefix_len: Optional[int] = None,
|
||||||
|
graph_pad_size: int = -1,
|
||||||
|
with_prefill_across_dp: bool = False,
|
||||||
|
) -> AscendMLAMetadata:
|
||||||
assert self._num_decodes + self._num_prefills == num_reqs
|
assert self._num_decodes + self._num_prefills == num_reqs
|
||||||
|
|
||||||
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
||||||
@@ -388,6 +400,7 @@ class AscendMLAMetadataBuilder:
|
|||||||
query_start_loc=query_start_loc,
|
query_start_loc=query_start_loc,
|
||||||
block_tables=block_table,
|
block_tables=block_table,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
|
with_prefill_across_dp=with_prefill_across_dp,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -621,7 +634,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
|
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||||
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
||||||
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||||
k_pe, k_nope, _, _ = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
|
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||||
kv,
|
kv,
|
||||||
self.kv_a_layernorm.weight,
|
self.kv_a_layernorm.weight,
|
||||||
cos,
|
cos,
|
||||||
@@ -643,7 +656,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
B, N, D = x.shape
|
B, N, D = x.shape
|
||||||
S = 1
|
S = 1
|
||||||
x = x.view(B, N, S, D)
|
x = x.view(B, N, S, D)
|
||||||
x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin)
|
x = torch_npu.npu_interleave_rope(x, cos, sin)
|
||||||
return x.view(B, N, D)
|
return x.view(B, N, D)
|
||||||
|
|
||||||
def _forward_decode(
|
def _forward_decode(
|
||||||
@@ -766,6 +779,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
sin = sin[attn_metadata.decode.input_positions]
|
sin = sin[attn_metadata.decode.input_positions]
|
||||||
cos = cos[:, None, None, :]
|
cos = cos[:, None, None, :]
|
||||||
sin = sin[:, None, None, :]
|
sin = sin[:, None, None, :]
|
||||||
|
|
||||||
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
|
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
|
||||||
decode_k_pe, decode_k_nope = self.exec_kv(
|
decode_k_pe, decode_k_nope = self.exec_kv(
|
||||||
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
|
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
|
||||||
|
|||||||
@@ -212,6 +212,14 @@ class CustomDeepseekV2MoE(nn.Module):
|
|||||||
self.tp_group = get_tp_group().device_group
|
self.tp_group = get_tp_group().device_group
|
||||||
self.tp_rank = get_tp_group().rank_in_group
|
self.tp_rank = get_tp_group().rank_in_group
|
||||||
|
|
||||||
|
self.params_dtype = torch.get_default_dtype()
|
||||||
|
|
||||||
|
self.enable_graph_mode = False
|
||||||
|
additional_config = get_current_vllm_config().additional_config
|
||||||
|
if additional_config:
|
||||||
|
self.enable_graph_mode = additional_config.get(
|
||||||
|
"enable_graph_mode", False)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -228,33 +236,35 @@ class CustomDeepseekV2MoE(nn.Module):
|
|||||||
else:
|
else:
|
||||||
is_prefill = attn_metadata.num_prefills > 0
|
is_prefill = attn_metadata.num_prefills > 0
|
||||||
enable_force_load_balance = False
|
enable_force_load_balance = False
|
||||||
num_tokens, hidden_dim = hidden_states.shape
|
if hasattr(attn_metadata, 'with_prefill_across_dp'):
|
||||||
|
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
|
||||||
|
|
||||||
|
num_tokens, hidden_size = hidden_states.shape
|
||||||
|
|
||||||
if self.n_shared_experts is not None:
|
if self.n_shared_experts is not None:
|
||||||
shared_output = self.shared_experts(hidden_states)
|
shared_output = self.shared_experts(hidden_states)
|
||||||
|
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
# pass
|
if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
|
||||||
num_tokens, hidden_size = hidden_states.shape
|
chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
|
||||||
if num_tokens < self.tp_size:
|
hidden_states = chunks[self.tp_rank]
|
||||||
target_size = self.tp_size
|
elif not self.enable_graph_mode:
|
||||||
new_hidden_states = torch.empty([target_size, hidden_size],
|
num_padding_tokens = (self.tp_size -
|
||||||
dtype=hidden_states.dtype,
|
num_tokens % self.tp_size) % self.tp_size
|
||||||
device=hidden_states.device)
|
# Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
|
||||||
new_hidden_states[:num_tokens] = hidden_states
|
if num_padding_tokens > 0:
|
||||||
hidden_states = new_hidden_states
|
hidden_states = nn.functional.pad(
|
||||||
chunk_hidden_states = torch.tensor_split(hidden_states,
|
hidden_states, (0, 0, 0, num_padding_tokens))
|
||||||
self.tp_size,
|
chunk_hidden_states = torch.tensor_split(hidden_states,
|
||||||
dim=0)
|
self.tp_size,
|
||||||
local_hidden_states = chunk_hidden_states[self.tp_rank]
|
dim=0)
|
||||||
else:
|
hidden_states = chunk_hidden_states[self.tp_rank]
|
||||||
local_hidden_states = hidden_states
|
|
||||||
|
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits, _ = self.gate(local_hidden_states)
|
router_logits, _ = self.gate(hidden_states)
|
||||||
|
|
||||||
router_hidden_states = self.experts(
|
hidden_states = self.experts(
|
||||||
hidden_states=local_hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
is_prefill=is_prefill,
|
is_prefill=is_prefill,
|
||||||
top_k=CustomDeepseekV2MoE.top_k,
|
top_k=CustomDeepseekV2MoE.top_k,
|
||||||
@@ -262,18 +272,29 @@ class CustomDeepseekV2MoE(nn.Module):
|
|||||||
) * self.routed_scaling_factor
|
) * self.routed_scaling_factor
|
||||||
|
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
dist.all_gather(list(chunk_hidden_states), router_hidden_states,
|
if self.enable_graph_mode:
|
||||||
self.tp_group)
|
if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
|
||||||
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
final_hidden_states = torch.zeros(
|
||||||
if num_tokens < self.tp_size:
|
[num_tokens, hidden_size],
|
||||||
final_hidden_states = final_hidden_states[:num_tokens]
|
dtype=self.params_dtype,
|
||||||
else:
|
device="npu")
|
||||||
final_hidden_states = router_hidden_states
|
dist.all_gather_into_tensor(final_hidden_states,
|
||||||
|
hidden_states, self.tp_group)
|
||||||
|
hidden_states = final_hidden_states
|
||||||
|
else:
|
||||||
|
hidden_states = tensor_model_parallel_all_reduce(
|
||||||
|
hidden_states)
|
||||||
|
else:
|
||||||
|
dist.all_gather(list(chunk_hidden_states), hidden_states,
|
||||||
|
self.tp_group)
|
||||||
|
hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
||||||
|
if num_padding_tokens > 0:
|
||||||
|
hidden_states = hidden_states[:-num_padding_tokens]
|
||||||
|
|
||||||
if shared_output is not None:
|
if shared_output is not None:
|
||||||
final_hidden_states = final_hidden_states + shared_output
|
hidden_states = hidden_states + shared_output
|
||||||
|
|
||||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
return hidden_states.view(num_tokens, hidden_size)
|
||||||
|
|
||||||
|
|
||||||
class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
||||||
|
|||||||
@@ -587,6 +587,12 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
|
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||||
self.local_batch_size = self.global_batch_size // self.ep_size
|
self.local_batch_size = self.global_batch_size // self.ep_size
|
||||||
|
|
||||||
|
self.enable_graph_mode = False
|
||||||
|
additional_config = get_current_vllm_config().additional_config
|
||||||
|
if additional_config:
|
||||||
|
self.enable_graph_mode = additional_config.get(
|
||||||
|
"enable_graph_mode", False)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
device_group = ep_group.device_group
|
device_group = ep_group.device_group
|
||||||
# TODO: Try local_rank = ep_group.rank_in_group
|
# TODO: Try local_rank = ep_group.rank_in_group
|
||||||
@@ -664,7 +670,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
|
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
|
||||||
elif get_ep_group().world_size == 1:
|
elif self.enable_graph_mode or get_ep_group().world_size == 1:
|
||||||
return fused_experts(hidden_states=x,
|
return fused_experts(hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
w2=layer.w2_weight,
|
w2=layer.w2_weight,
|
||||||
@@ -750,26 +756,20 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
self.expert_map = None
|
self.expert_map = None
|
||||||
self.activation = activation
|
self.activation = activation
|
||||||
|
|
||||||
if self.ep_size > 1:
|
# Create a tensor of size num_experts filled with -1
|
||||||
# Create a tensor of size num_experts filled with -1
|
self.local_num_experts, self.expert_map = determine_expert_map(
|
||||||
self.local_num_experts, self.expert_map = determine_expert_map(
|
self.ep_size,
|
||||||
self.ep_size,
|
get_ep_group().rank_in_group, self.global_num_experts)
|
||||||
get_ep_group().rank_in_group, self.global_num_experts)
|
|
||||||
|
|
||||||
self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
|
self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
|
||||||
self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
|
self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
|
||||||
|
|
||||||
else:
|
self.enable_graph_mode = False
|
||||||
# Adjust TP size for DP attention
|
additional_config = get_current_vllm_config().additional_config
|
||||||
# haven't test its functionality yet, may remove in the future
|
if additional_config:
|
||||||
|
self.enable_graph_mode = additional_config.get(
|
||||||
|
"enable_graph_mode", False)
|
||||||
|
|
||||||
self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
|
|
||||||
self.moe_parallel_config.ep_rank = 0
|
|
||||||
self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
|
|
||||||
self.moe_parallel_config.ep_size = 1
|
|
||||||
|
|
||||||
self.local_num_experts, self.expert_map = (self.global_num_experts,
|
|
||||||
None)
|
|
||||||
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
||||||
raise ValueError("Only softmax scoring function is supported for "
|
raise ValueError("Only softmax scoring function is supported for "
|
||||||
"non-grouped topk.")
|
"non-grouped topk.")
|
||||||
@@ -807,8 +807,15 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
|
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
|
||||||
moe_quant_params["intermediate_size_full"] = intermediate_size
|
moe_quant_params["intermediate_size_full"] = intermediate_size
|
||||||
|
|
||||||
|
self.ep_group = get_ep_group()
|
||||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||||
|
|
||||||
|
self.enable_graph_mode = False
|
||||||
|
additional_config = get_current_vllm_config().additional_config
|
||||||
|
if additional_config:
|
||||||
|
self.enable_graph_mode = additional_config.get(
|
||||||
|
"enable_graph_mode", False)
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
@@ -822,11 +829,28 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
else:
|
else:
|
||||||
real_top_k = self.top_k
|
real_top_k = self.top_k
|
||||||
|
|
||||||
if VLLM_ENABLE_MC2 and not is_prefill:
|
# MC2 ag/rs broadcast/all_reduce
|
||||||
...
|
# prefill_req x x √
|
||||||
|
# decode_req √ x √
|
||||||
|
# graph_mode √ √ x
|
||||||
|
if self.dp_size > 1:
|
||||||
|
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||||
|
...
|
||||||
|
elif self.enable_graph_mode:
|
||||||
|
if USING_LCCL_COM: # type: ignore
|
||||||
|
hidden_states = get_dp_group().all_gather(
|
||||||
|
hidden_states, 0, False)
|
||||||
|
router_logits = get_dp_group().all_gather(
|
||||||
|
router_logits, 0, False)
|
||||||
|
elif self.enable_graph_mode and not is_prefill:
|
||||||
|
hidden_states = get_dp_group().all_gather(hidden_states, 0)
|
||||||
|
router_logits = get_dp_group().all_gather(router_logits, 0)
|
||||||
|
else:
|
||||||
|
hidden_states, router_logits = get_ep_group().dispatch(
|
||||||
|
hidden_states, router_logits)
|
||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
final_hidden_states = self.quant_method.apply(
|
hidden_states = self.quant_method.apply(
|
||||||
layer=self,
|
layer=self,
|
||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
@@ -843,11 +867,26 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
is_prefill=is_prefill,
|
is_prefill=is_prefill,
|
||||||
enable_force_load_balance=enable_force_load_balance)
|
enable_force_load_balance=enable_force_load_balance)
|
||||||
|
|
||||||
if VLLM_ENABLE_MC2 and not is_prefill:
|
if self.dp_size > 1:
|
||||||
...
|
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||||
|
...
|
||||||
|
elif self.enable_graph_mode:
|
||||||
|
if USING_LCCL_COM: # type: ignore
|
||||||
|
hidden_states = dist._functional_collectives.reduce_scatter_tensor(
|
||||||
|
hidden_states,
|
||||||
|
"sum",
|
||||||
|
scatter_dim=0,
|
||||||
|
group=get_dp_group().device_group)
|
||||||
|
elif self.enable_graph_mode and not is_prefill:
|
||||||
|
hidden_states = dist._functional_collectives.reduce_scatter_tensor(
|
||||||
|
hidden_states,
|
||||||
|
"sum",
|
||||||
|
scatter_dim=0,
|
||||||
|
group=get_dp_group().device_group)
|
||||||
|
else:
|
||||||
|
hidden_states = get_ep_group().combine(hidden_states)
|
||||||
|
|
||||||
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||||
final_hidden_states)
|
|
||||||
|
|
||||||
return final_hidden_states
|
return hidden_states
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ class NPUPlatform(Platform):
|
|||||||
|
|
||||||
# Calculate expert parallel size based on world size
|
# Calculate expert parallel size based on world size
|
||||||
parallel_config.expert_parallel_size = (
|
parallel_config.expert_parallel_size = (
|
||||||
parallel_config.world_size //
|
parallel_config.world_size_across_dp //
|
||||||
parallel_config.expert_tensor_parallel_size)
|
parallel_config.expert_tensor_parallel_size)
|
||||||
|
|
||||||
if model_config is None:
|
if model_config is None:
|
||||||
@@ -167,6 +167,8 @@ class NPUPlatform(Platform):
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"enable_graph_mode only works with deepseek model."
|
"enable_graph_mode only works with deepseek model."
|
||||||
)
|
)
|
||||||
|
# Set compilation level to NO_COMPILATION to disable ACL Graph
|
||||||
|
compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||||
|
|
||||||
elif envs.VLLM_USE_V1 and model_config is not None and not enforce_eager:
|
elif envs.VLLM_USE_V1 and model_config is not None and not enforce_eager:
|
||||||
model_type = model_config.hf_config.model_type
|
model_type = model_config.hf_config.model_type
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch_npu
|
import torch_npu
|
||||||
|
from vllm.config import get_current_vllm_config
|
||||||
from vllm.distributed import GroupCoordinator
|
from vllm.distributed import GroupCoordinator
|
||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
import vllm_ascend.envs as envs_ascend
|
||||||
@@ -508,6 +509,12 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
|
|
||||||
self.ep_group = get_ep_group()
|
self.ep_group = get_ep_group()
|
||||||
|
|
||||||
|
self.enable_graph_mode = False
|
||||||
|
additional_config = get_current_vllm_config().additional_config
|
||||||
|
if additional_config:
|
||||||
|
self.enable_graph_mode = additional_config.get(
|
||||||
|
"enable_graph_mode", False)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
device_group = self.ep_group.device_group
|
device_group = self.ep_group.device_group
|
||||||
# TODO: Try local_rank = ep_group.rank_in_group
|
# TODO: Try local_rank = ep_group.rank_in_group
|
||||||
@@ -629,7 +636,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
|
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
|
||||||
elif self.ep_group.world_size == 1:
|
elif self.enable_graph_mode or self.ep_group.world_size == 1:
|
||||||
return fused_experts(hidden_states=x,
|
return fused_experts(hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
w1_scale=layer.w13_weight_scale,
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
|||||||
@@ -29,12 +29,14 @@ import numpy as np
|
|||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
import torch
|
import torch
|
||||||
import torch._dynamo.cache_size
|
import torch._dynamo.cache_size
|
||||||
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from torch.distributed import ReduceOp
|
||||||
from vllm.attention import AttentionType, get_attn_backend
|
from vllm.attention import AttentionType, get_attn_backend
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.config import CompilationLevel, VllmConfig
|
from vllm.config import CompilationLevel, VllmConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.distributed.parallel_state import get_pp_group
|
from vllm.distributed.parallel_state import get_dp_group, get_pp_group
|
||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.inputs import INPUT_REGISTRY
|
from vllm.inputs import INPUT_REGISTRY
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
@@ -361,6 +363,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
torch._logging.set_logs(
|
torch._logging.set_logs(
|
||||||
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
|
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
|
||||||
|
|
||||||
|
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||||
|
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||||
|
|
||||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||||
"""Update the cached states and the persistent batch with the scheduler
|
"""Update the cached states and the persistent batch with the scheduler
|
||||||
output.
|
output.
|
||||||
@@ -512,6 +517,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if batch_changed:
|
if batch_changed:
|
||||||
self.input_batch.refresh_sampling_metadata()
|
self.input_batch.refresh_sampling_metadata()
|
||||||
|
|
||||||
|
def _get_forward_metadata_across_dp(
|
||||||
|
self, batch_size: int, with_prefill: bool) -> tuple[int, bool]:
|
||||||
|
forward_metadata = torch.tensor([batch_size, with_prefill],
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.int32)
|
||||||
|
dist.all_reduce(forward_metadata,
|
||||||
|
op=ReduceOp.MAX,
|
||||||
|
group=get_dp_group().cpu_group)
|
||||||
|
return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
|
||||||
|
|
||||||
def get_model(self) -> nn.Module:
|
def get_model(self) -> nn.Module:
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
@@ -648,12 +663,24 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
seq_lens = self.seq_lens[:num_reqs]
|
seq_lens = self.seq_lens[:num_reqs]
|
||||||
common_attn_metadata = CommonAttentionMetadata(
|
common_attn_metadata = CommonAttentionMetadata(
|
||||||
query_start_loc=query_start_loc, seq_lens=seq_lens)
|
query_start_loc=query_start_loc, seq_lens=seq_lens)
|
||||||
|
with_prefill = attn_state != AscendAttentionState.DecodeOnly
|
||||||
|
|
||||||
|
if self.dp_size > 1:
|
||||||
|
max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
|
||||||
|
total_num_scheduled_tokens, with_prefill)
|
||||||
|
extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
|
||||||
|
|
||||||
# Add graph_pad_size here
|
# Add graph_pad_size here
|
||||||
if self.enable_torchair_graph_mode:
|
if envs_ascend.VLLM_ENABLE_MC2 or (self.enable_torchair_graph_mode
|
||||||
batchsize = len(seq_lens)
|
and not with_prefill):
|
||||||
padded_batch_size = self.select_torchair_padded_batchsize(
|
batch_size = len(seq_lens)
|
||||||
batchsize)
|
if self.dp_size > 1:
|
||||||
graph_pad_size = padded_batch_size - batchsize
|
padded_batch_size = self.select_torchair_padded_batch_size(
|
||||||
|
max_num_tokens)
|
||||||
|
else:
|
||||||
|
padded_batch_size = self.select_torchair_padded_batch_size(
|
||||||
|
batch_size)
|
||||||
|
graph_pad_size = padded_batch_size - batch_size
|
||||||
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
|
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
|
||||||
|
|
||||||
if self.vllm_config.model_config.use_mla:
|
if self.vllm_config.model_config.use_mla:
|
||||||
@@ -687,7 +714,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
|
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
|
||||||
input_ids = self.input_ids[:num_input_tokens]
|
input_ids = self.input_ids[:num_input_tokens]
|
||||||
|
|
||||||
if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
if (envs_ascend.VLLM_ENABLE_MC2
|
||||||
|
or self.enable_torchair_graph_mode) and not with_prefill:
|
||||||
input_ids = self.input_ids[:padded_batch_size]
|
input_ids = self.input_ids[:padded_batch_size]
|
||||||
positions = self.positions[:padded_batch_size]
|
positions = self.positions[:padded_batch_size]
|
||||||
|
|
||||||
@@ -699,7 +727,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if self.enable_torchair_graph_mode:
|
if self.enable_torchair_graph_mode:
|
||||||
model_kwargs["kv_caches"] = self.kv_caches
|
model_kwargs["kv_caches"] = self.kv_caches
|
||||||
model_kwargs["attn_metadata"] = attn_metadata
|
model_kwargs["attn_metadata"] = attn_metadata
|
||||||
if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
if self.enable_torchair_graph_mode and not with_prefill:
|
||||||
hidden_states = self.compile_model(
|
hidden_states = self.compile_model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
@@ -1095,7 +1123,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self,
|
self,
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
is_compile: bool = False,
|
is_compile: bool = False,
|
||||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill,
|
with_prefill: bool = True,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
|
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
|
||||||
# for dummy run with LoRA so that the num_reqs collectively
|
# for dummy run with LoRA so that the num_reqs collectively
|
||||||
@@ -1139,8 +1167,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
for k, v in self.intermediate_tensors.items()
|
for k, v in self.intermediate_tensors.items()
|
||||||
})
|
})
|
||||||
|
|
||||||
with set_forward_context(None, self.vllm_config):
|
with set_forward_context(None,
|
||||||
if self.enable_torchair_graph_mode and attn_state == AscendAttentionState.DecodeOnly:
|
self.vllm_config,
|
||||||
|
num_tokens=num_tokens):
|
||||||
|
if self.enable_torchair_graph_mode and not with_prefill:
|
||||||
attn_metadata = self.attn_metadata_builder.build_dummy(
|
attn_metadata = self.attn_metadata_builder.build_dummy(
|
||||||
num_reqs=num_tokens, num_actual_tokens=1)
|
num_reqs=num_tokens, num_actual_tokens=1)
|
||||||
# Only mark static while compiling
|
# Only mark static while compiling
|
||||||
@@ -1393,7 +1423,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
logger.info(
|
logger.info(
|
||||||
"Capturing torchair graph, this usually takes %.1f~%.1f mins.",
|
"Capturing torchair graph, this usually takes %.1f~%.1f mins.",
|
||||||
0.5 * graph_num, 1.5 * graph_num)
|
0.5 * graph_num, 1.5 * graph_num)
|
||||||
attn_state = AscendAttentionState.DecodeOnly
|
|
||||||
# Trigger torchair graph capture for specific shapes.
|
# Trigger torchair graph capture for specific shapes.
|
||||||
# Capture the large shapes first so that the smaller shapes
|
# Capture the large shapes first so that the smaller shapes
|
||||||
# can reuse the memory pool allocated for the large shapes.
|
# can reuse the memory pool allocated for the large shapes.
|
||||||
@@ -1403,10 +1432,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
cudagraph_num_of_warmups):
|
cudagraph_num_of_warmups):
|
||||||
self._dummy_run(num_tokens,
|
self._dummy_run(num_tokens,
|
||||||
is_compile=True,
|
is_compile=True,
|
||||||
attn_state=attn_state)
|
with_prefill=False)
|
||||||
self._dummy_run(num_tokens,
|
self._dummy_run(num_tokens,
|
||||||
is_compile=True,
|
is_compile=True,
|
||||||
attn_state=attn_state)
|
with_prefill=False)
|
||||||
logger.info("Batchsize %d is compiled successfully: %d/%d.",
|
logger.info("Batchsize %d is compiled successfully: %d/%d.",
|
||||||
num_tokens, idx + 1, graph_num)
|
num_tokens, idx + 1, graph_num)
|
||||||
elif self.use_aclgraph:
|
elif self.use_aclgraph:
|
||||||
@@ -1551,9 +1580,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.torchair_graph_batch_sizes.append(largest_batch_size)
|
self.torchair_graph_batch_sizes.append(largest_batch_size)
|
||||||
largest_batch_size += batch_size_step
|
largest_batch_size += batch_size_step
|
||||||
|
|
||||||
def select_torchair_padded_batchsize(self, batchsize: int):
|
def select_torchair_padded_batch_size(self, batch_size: int):
|
||||||
selected_batchsize = self.max_num_reqs
|
selected_batch_size = self.max_num_reqs
|
||||||
for padded_batchsize in self.torchair_graph_batch_sizes:
|
for padded_batch_size in self.torchair_graph_batch_sizes:
|
||||||
if batchsize <= padded_batchsize < selected_batchsize:
|
if batch_size <= padded_batch_size < selected_batch_size:
|
||||||
selected_batchsize = padded_batchsize
|
selected_batch_size = padded_batch_size
|
||||||
return selected_batchsize
|
return selected_batch_size
|
||||||
|
|||||||
@@ -544,7 +544,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
|||||||
init_ascend_model_parallel(
|
init_ascend_model_parallel(
|
||||||
parallel_config.expert_parallel_size,
|
parallel_config.expert_parallel_size,
|
||||||
parallel_config.expert_tensor_parallel_size,
|
parallel_config.expert_tensor_parallel_size,
|
||||||
parallel_config.world_size,
|
parallel_config.world_size_across_dp,
|
||||||
)
|
)
|
||||||
ensure_kv_transfer_initialized(vllm_config)
|
ensure_kv_transfer_initialized(vllm_config)
|
||||||
|
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ from vllm.v1.outputs import ModelRunnerOutput
|
|||||||
from vllm.v1.utils import bind_kv_cache
|
from vllm.v1.utils import bind_kv_cache
|
||||||
from vllm.v1.worker.worker_base import WorkerBase
|
from vllm.v1.worker.worker_base import WorkerBase
|
||||||
|
|
||||||
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
||||||
from vllm_ascend.platform import NPUPlatform
|
from vllm_ascend.platform import NPUPlatform
|
||||||
from vllm_ascend.utils import try_register_lib
|
from vllm_ascend.utils import try_register_lib
|
||||||
@@ -230,7 +231,18 @@ class NPUWorker(WorkerBase):
|
|||||||
return self.model_runner.pin_lora(lora_id)
|
return self.model_runner.pin_lora(lora_id)
|
||||||
|
|
||||||
def execute_dummy_batch(self) -> None:
|
def execute_dummy_batch(self) -> None:
|
||||||
self.model_runner._dummy_run(1)
|
runner = self.model_runner
|
||||||
|
num_tokens = 1
|
||||||
|
if runner.dp_size > 1:
|
||||||
|
max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp(
|
||||||
|
1, False)
|
||||||
|
if envs_ascend.VLLM_ENABLE_MC2 or runner.enable_torchair_graph_mode:
|
||||||
|
if not with_prefill:
|
||||||
|
num_tokens = max_num_tokens
|
||||||
|
num_tokens = runner.select_torchair_padded_batch_size(num_tokens)
|
||||||
|
runner._dummy_run(num_tokens,
|
||||||
|
is_compile=False,
|
||||||
|
with_prefill=with_prefill)
|
||||||
|
|
||||||
def _init_worker_distributed_environment(self) -> None:
|
def _init_worker_distributed_environment(self) -> None:
|
||||||
"""Initialize the distributed environment."""
|
"""Initialize the distributed environment."""
|
||||||
@@ -246,7 +258,7 @@ class NPUWorker(WorkerBase):
|
|||||||
init_ascend_model_parallel(
|
init_ascend_model_parallel(
|
||||||
parallel_config.expert_parallel_size,
|
parallel_config.expert_parallel_size,
|
||||||
parallel_config.expert_tensor_parallel_size,
|
parallel_config.expert_tensor_parallel_size,
|
||||||
parallel_config.world_size,
|
parallel_config.world_size_across_dp,
|
||||||
)
|
)
|
||||||
ensure_kv_transfer_initialized(self.vllm_config)
|
ensure_kv_transfer_initialized(self.vllm_config)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user