support deepseek quant & mix-parallel with graphmode (#585)

### What this PR does / why we need it?
1. support deepseek with w8a8 quant;
2. support deepseek with mix-parallel(multi-DP, EP+TP);
3. support deepseek with graphmode.
---------

Signed-off-by: wen-jie666 <wenjie39@huawei.com>
Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com>
Signed-off-by: libaokui <libaokui@huawei.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: wen-jie666 <wenjie39@huawei.com>
This commit is contained in:
zzzzwwjj
2025-04-23 16:23:25 +08:00
committed by GitHub
parent e74331a1ed
commit 5c6d05a59e
13 changed files with 520 additions and 221 deletions

View File

@@ -11,8 +11,6 @@
import gc
import os
VLLM_ENABLE_GRAPGH_MODE = os.environ.get("VLLM_ENABLE_GRAPH_MODE") == "1"
def main():
dp_rank = int(os.environ['RANK'])
@@ -20,8 +18,8 @@ def main():
dp_size = int(os.environ['WORLD_SIZE'])
master_addr = os.environ['MASTER_ADDR']
master_port = os.environ['MASTER_PORT']
tp_size = 4
etp_size = 2
tp_size = 1
etp_size = 1
os.environ["VLLM_DP_RANK"] = str(dp_rank)
os.environ["VLLM_DP_SIZE"] = str(dp_size)
@@ -58,15 +56,15 @@ def main():
max_tokens=4,
min_tokens=4)
# Create an LLM.
llm = LLM(
model="deepseek-ai/DeepSeek-V2-Lite-Chat",
tensor_parallel_size=tp_size,
trust_remote_code=True,
expert_tensor_parallel_size=etp_size,
max_model_len=4096,
max_num_seqs=num_seqs,
compilation_config=1 if VLLM_ENABLE_GRAPGH_MODE else 0,
)
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat",
tensor_parallel_size=tp_size,
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=num_seqs,
additional_config={
'expert_tensor_parallel_size': etp_size,
'enable_graph_mode': False,
})
outputs = llm.generate(prompts, sampling_params)
for output in outputs:

View File

@@ -6,15 +6,13 @@ export HCCL_SOCKET_IFNAME=${ifname}
# dp_size = node_size * dp_per_node
node_size=1
node_rank=0
dp_per_node=2
dp_per_node=4
master_addr=127.0.0.1
master_port=12345
rm -rf ./.torchair_cache/
rm -rf ./dynamo_*
rm -rf /root/ascend/log/debug/plog/*
export VLLM_ENABLE_GRAPH_MODE=0
export VLLM_ENABLE_MC2=0
torchrun --nproc_per_node ${dp_per_node} --nnodes ${node_size} \
--node_rank ${node_rank} --master_addr ${master_addr} --master_port ${master_port} \

View File

@@ -27,6 +27,7 @@ try:
except ImportError:
print("Failed to import torch_npu.")
import torchair._contrib.custom_torch_ops # type: ignore # noqa: F401
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType,
@@ -36,9 +37,9 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.config import get_current_vllm_config
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
from vllm_ascend.worker.model_runner import (
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
@@ -913,6 +914,12 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
self.w_kc = None
self.w_vc = None
self.enable_graph_mode = False
additional_config = get_current_vllm_config().additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)
def exec_kv(
self,
hidden_states: torch.Tensor,
@@ -1084,7 +1091,7 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
self.num_heads, -1)
# TODO: Replace the env with more flexible expressions
if VLLM_ENABLE_GRAPH_MODE == '1':
if self.enable_graph_mode:
if len(kv_cache) > 0 and kv_cache[0].numel(
) > 0 and attn_metadata.num_prefills > 0:
slots = attn_metadata.slot_mapping
@@ -1141,7 +1148,7 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
)
elif attn_metadata.decode_metadata:
assert kv_cache is not None
if VLLM_ENABLE_GRAPH_MODE == '1':
if self.enable_graph_mode:
# TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)

View File

@@ -26,13 +26,13 @@
# """Inference-only DeepseekV2/DeepseekV3 model."""
import os
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union
import torch
import torch.distributed as dist
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention
from vllm.attention import Attention, AttentionMetadata
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config)
from vllm.distributed import (get_dp_group, get_pp_group,
@@ -64,7 +64,6 @@ from vllm.model_executor.models.utils import (
from vllm.sequence import IntermediateTensors
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
class CustomDeepseekV2MoE(nn.Module):
@@ -133,7 +132,7 @@ class CustomDeepseekV2MoE(nn.Module):
vllm_config = get_current_vllm_config()
self.dp_size = get_dp_group().world_size
batch_size = vllm_config.scheduler_config.max_num_seqs
self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", 0)) == 1
self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", '0')) == 1
params_dtype = torch.get_default_dtype()
self.final_hidden_states = torch.zeros(
@@ -309,38 +308,36 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])
if VLLM_ENABLE_GRAPH_MODE == "1":
self.forward = self.forward_torchair
else:
self.forward = self.forward_eager # type: ignore
self.enable_graph_mode = False
additional_config = get_current_vllm_config().additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)
def forward_torchair(self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor = None,
attn_metadata=None):
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
if self.q_lora_rank is not None:
ckq = self.q_a_proj(hidden_states)[0]
hidden_states_or_q_c = self.q_a_layernorm(ckq)
else:
hidden_states_or_q_c = hidden_states
return self.mla_attn(hidden_states_or_q_c, hidden_states, None,
kv_cache, attn_metadata)
def forward_eager(self, positions: torch.Tensor,
hidden_states: torch.Tensor):
if self.q_lora_rank is not None:
ckq = self.q_a_proj(hidden_states)[0]
hidden_states_or_q_c = self.q_a_layernorm(ckq)
if self.enable_graph_mode:
return self.mla_attn.impl.forward(self.mla_attn,
hidden_states_or_q_c,
hidden_states, None, kv_cache,
attn_metadata)
else:
hidden_states_or_q_c = hidden_states
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
return self.mla_attn(hidden_states_or_q_c,
kv_c_normed,
k_pe,
output_shape=hidden_states.shape)
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
return self.mla_attn(hidden_states_or_q_c,
kv_c_normed,
k_pe,
output_shape=hidden_states.shape)
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
@@ -408,6 +405,54 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
eps=config.rms_norm_eps)
self.routed_scaling_factor = config.routed_scaling_factor
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
    kv_cache: Optional[torch.Tensor] = None,
    attn_metadata: Optional[AttentionMetadata] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run one decoder layer: self-attention followed by the MLP/MoE block.

    Args:
        positions: position ids for rotary embedding inside self-attention.
        hidden_states: input activations for this layer.
        residual: running residual stream; ``None`` on the first layer,
            in which case it is initialized from ``hidden_states``.
        kv_cache: optional per-layer KV cache tensor passed to attention.
        attn_metadata: optional attention metadata for the current batch.

    Returns:
        Tuple of ``(hidden_states, residual)`` to feed the next layer.
        (Note: the method returns a 2-tuple, so the return annotation is
        ``tuple[...]`` rather than a bare ``torch.Tensor``.)
    """
    # Self Attention
    if residual is None:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
    else:
        hidden_states, residual = self.input_layernorm(
            hidden_states, residual)
    hidden_states = self.self_attn(
        positions=positions,
        hidden_states=hidden_states,
        kv_cache=kv_cache,
        attn_metadata=attn_metadata,
    )

    if hidden_states.dtype == torch.float16:
        # Fix FP16 overflow
        # We scale both hidden_states and residual before
        # rmsnorm, and the rmsnorm result is not affected by the scale.
        hidden_states *= 1. / self.routed_scaling_factor
        if self.layer_idx == 0:
            # The residual is shared by all layers, we only scale it on
            # the first layer.
            residual *= 1. / self.routed_scaling_factor

    # Fully Connected
    hidden_states, residual = self.post_attention_layernorm(
        hidden_states, residual)
    hidden_states = self.mlp(hidden_states)

    if isinstance(self.mlp,
                  DeepseekV2MLP) and hidden_states.dtype == torch.float16:
        # Fix FP16 overflow
        # Scaling the DeepseekV2MLP output, it is the input of
        # input_layernorm of the next decoder layer.
        # The scaling of DeepseekV2MOE output is done in the forward
        # of DeepseekV2MOE.
        hidden_states *= 1. / self.routed_scaling_factor

    return hidden_states, residual
class CustomDeepseekV2Model(nn.Module):
@@ -459,7 +504,9 @@ class CustomDeepseekV2Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
kv_caches: Optional[List[torch.Tensor]] = None,
attn_metadata: Optional[AttentionMetadata] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
@@ -473,8 +520,13 @@ class CustomDeepseekV2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(positions, hidden_states, residual)
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, residual,
kv_caches[i -
self.start_layer] if kv_caches is not None else None,
attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
@@ -514,6 +566,20 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: Optional[List[torch.Tensor]] = None,
    attn_metadata: Optional[AttentionMetadata] = None,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Delegate straight to the wrapped ``CustomDeepseekV2Model``.

    All arguments are forwarded unchanged; the model's output (either the
    final hidden states or pipeline-parallel intermediate tensors) is
    returned as-is.
    """
    return self.model(input_ids, positions, kv_caches, attn_metadata,
                      intermediate_tensors, inputs_embeds)
class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
pass

View File

@@ -330,17 +330,16 @@ def native_grouped_topk(
def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: Optional[bool] = True
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Select top-k experts based on router logits.
@@ -364,7 +363,6 @@ def select_experts(
Raises:
ValueError: If an unsupported scoring function is provided.
"""
if custom_routing_function is not None:
raise NotImplementedError(
"Custom routing function is not supported now")
@@ -466,21 +464,36 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
is_prefill=False,
**kwargs,
):
# set prefill as false always, should fix this
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
is_prefill=is_prefill)
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if global_num_experts == 256:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk当前写8
bias=e_score_correction_bias,
k_group=topk_group, # fix: 4
group_count=num_expert_group, # fix 8
group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=1, # 0: softmax; 1: sigmoid(fix)
# out_flag=False, # todo new api; 第三个输出是否输出
# y2_flag=False, # old api; 第三个输出是否输出
routed_scaling_factor=1,
eps=float(1e-20))
else:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
if os.environ.get("VLLM_ENABLE_MC2") == "1" and not is_prefill:
if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill:
return fused_experts_with_mc2(
hidden_states=x,
w1=layer.w13_weight,
@@ -611,10 +624,11 @@ class AscendFusedMoE(FusedMoE):
real_top_k = self.top_k
if self.dp_size > 1:
if int(os.environ.get("VLLM_ENABLE_MC2") # type: ignore
if int(os.environ.get("VLLM_ENABLE_MC2", '0') # type: ignore
) == 1 and not is_prefill:
...
elif int(os.environ.get("USING_LCCL_COM")) == 1: # type: ignore
elif int(os.environ.get("USING_LCCL_COM",
'0')) == 1: # type: ignore
hidden_states = get_dp_group().all_gather(
hidden_states, 0, False)
router_logits = get_dp_group().all_gather(
@@ -631,7 +645,7 @@ class AscendFusedMoE(FusedMoE):
top_k=real_top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.num_experts,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
@@ -641,7 +655,7 @@ class AscendFusedMoE(FusedMoE):
is_prefill=is_prefill)
if self.dp_size > 1:
if int(os.environ.get("VLLM_ENABLE_MC2") # type: ignore
if int(os.environ.get("VLLM_ENABLE_MC2", '0') # type: ignore
) == 1 and not is_prefill:
...
else:

View File

@@ -24,6 +24,7 @@ import torch_npu # noqa: F401
import vllm.envs as envs
from vllm.logger import logger
from vllm.platforms import Platform, PlatformEnum
from vllm.utils import supports_dynamo
CUSTOM_OP_ENABLED = False
try:
@@ -119,6 +120,15 @@ class NPUPlatform(Platform):
compilation_config.level)
compilation_config.level = CompilationLevel.NO_COMPILATION
if vllm_config.additional_config is not None:
enable_graph_mode = vllm_config.additional_config.get(
"enable_graph_mode", False)
if enable_graph_mode and not supports_dynamo():
logger.warning(
"enable_graph_mode is not supported because the version of torch is too low, forcing close enable_graph_mode"
)
vllm_config.additional_config["enable_graph_mode"] = False
parallel_config = vllm_config.parallel_config
if parallel_config and parallel_config.worker_cls == "auto":
if envs.VLLM_USE_V1:

View File

@@ -310,21 +310,22 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
**kwargs,
) -> torch.Tensor:
return self.quant_method.apply(layer, x, router_logits, top_k,
renormalize, use_grouped_topk,
topk_group, num_expert_group,
global_num_experts, expert_map,
topk_group, num_expert_group,
custom_routing_function, scoring_func,
e_score_correction_bias)
e_score_correction_bias, is_prefill)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):

View File

@@ -23,10 +23,8 @@ import torch_npu
def quant_per_tensor(in_tensor: torch.Tensor, input_scale: torch.Tensor,
                     input_offset: torch.Tensor):
    """Statically quantize *in_tensor* to int8 with a per-tensor scale/offset.

    Args:
        in_tensor: floating-point activation tensor to quantize.
        input_scale: per-tensor quantization scale.
        input_offset: per-tensor quantization offset (zero point).

    Returns:
        An int8 tensor with the same shape as *in_tensor*.
    """
    out = torch.empty_like(in_tensor, dtype=torch.int8)
    torch_npu._npu_quantize_per_tensor(in_tensor, input_scale, input_offset,
                                       out)
    return out
    # NOTE(review): everything below is unreachable — it follows an
    # unconditional return. This looks like merge/diff residue: the old
    # private-API path (_npu_quantize_per_tensor, above) and a newer
    # public npu_quantize path (below) are both present. Confirm which
    # implementation is intended and delete the other.
    return torch_npu.npu_quantize(in_tensor, input_scale, input_offset,
                                  torch.qint8, -1, True)
class AscendW8A8LinearMethod:
@@ -88,7 +86,11 @@ class AscendW8A8LinearMethod:
) -> torch.Tensor:
original_dtype = x.dtype
if original_dtype != torch.int8:
x = quant_per_tensor(x, layer.input_scale, layer.input_offset)
x = quant_per_tensor(
x,
layer.aclnn_input_scale,
layer.aclnn_input_offset,
)
quant_bias = layer.quant_bias if tp_rank == 0 else None
return torch_npu.npu_quant_matmul(
x,
@@ -99,6 +101,13 @@ class AscendW8A8LinearMethod:
)
def process_weights_after_loading(self, layer):
expanding_factor = layer.weight.data.shape[1]
layer.aclnn_input_scale = torch.nn.Parameter(
layer.input_scale.data.repeat(expanding_factor),
requires_grad=False)
layer.aclnn_input_offset = torch.nn.Parameter(
layer.input_offset.data.repeat(expanding_factor),
requires_grad=False)
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)

View File

@@ -15,14 +15,183 @@
# limitations under the License.
#
import os
from typing import Any, Callable, Dict, Optional
import torch
import torch_npu
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import select_experts
def apply_mlp(x: torch.Tensor,
              w1: torch.Tensor,
              w1_scale: torch.Tensor,
              w2: torch.Tensor,
              w2_scale: torch.Tensor,
              group_list: torch.Tensor,
              dynamic_scale: torch.Tensor = None,
              group_list_type: int = 1) -> torch.Tensor:
    """Run the quantized expert MLP: gate_up_proj -> swiglu -> down_proj.

    Args:
        x: input hidden states with shape (num_tokens, hidden_size).
        w1: expert weights1 with shape
            (num_experts, hidden_size, intermediate_size * 2)
        w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
        w2: expert weights2 with shape
            (num_experts, intermediate_size, hidden_size)
        w2_scale: weights2 scale with shape (num_experts, hidden_size)
        group_list: number of tokens for each expert, follow cumsum mode, and
            with shape (num_experts).
        dynamic_scale: per-token scales of an already int8-quantized ``x``;
            when ``None``, ``x`` is dynamically quantized here instead.
        group_list_type: group-list mode forwarded to npu_grouped_matmul.

    Returns:
        hidden_states: output hidden states after MLP.
    """
    # Quantize the input on the fly unless the caller already did so and
    # supplied the matching per-token scales.
    if dynamic_scale is not None:
        quantized_x = x
        per_token_scale = dynamic_scale
    else:
        quantized_x, per_token_scale = torch_npu.npu_dynamic_quant(x)

    out_dtype = (torch.bfloat16
                 if w1_scale.dtype == torch.bfloat16 else torch.float16)

    # gmm1: gate_up_proj (grouped matmul, one group per expert)
    gate_up_out = torch_npu.npu_grouped_matmul(
        x=[quantized_x],
        weight=[w1],
        scale=[w1_scale],
        per_token_scale=[per_token_scale],
        split_item=3,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
        output_dtype=out_dtype)[0]

    # swiglu activation, re-quantized for the second grouped matmul
    activated, activated_scale = torch_npu.npu_dynamic_quant(
        torch_npu.npu_swiglu(gate_up_out))

    # gmm2: down_proj
    return torch_npu.npu_grouped_matmul(
        x=[activated],
        weight=[w2],
        scale=[w2_scale],
        per_token_scale=[activated_scale],
        split_item=3,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
        output_dtype=out_dtype)[0]
def fused_experts_with_mc2(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    expert_map: torch.Tensor = None,
    moe_all_to_all_group_name: str = "",
) -> torch.Tensor:
    """Quantized fused MoE using the MC2 dispatch/combine path.

    Pipeline: ``npu_moe_distribute_dispatch`` scatters tokens to their
    experts across the expert-parallel (EP) group (quantizing activations
    per token on the way), ``apply_mlp`` runs the grouped expert MLP, and
    ``npu_moe_distribute_combine`` gathers and weight-combines the expert
    outputs back into the original token order.

    Args:
        hidden_states: input hidden states, (num_tokens, hidden_size).
        w1: gate_up projection weights of the local experts.
        w2: down projection weights of the local experts.
        w1_scale: quantization scales for ``w1``.
        w2_scale: quantization scales for ``w2``.
        topk_weights: routing weights of the selected experts per token.
        topk_ids: ids of the selected experts per token.
        top_k: number of experts selected per token.
        expert_map: mapping over the global experts; its length is used as
            the global expert count, so it must not be ``None`` here.
        moe_all_to_all_group_name: HCCL communicator name used by the
            dispatch/combine kernels.

    Returns:
        Combined expert outputs with the shape of ``hidden_states``.
    """
    global_bs = 0
    # expert_map has one entry per global expert.
    moe_expert_num = len(expert_map)

    kwargs = {
        "x": hidden_states,
        "expert_ids": topk_ids,
        "expert_shard_type": 0,
        "shared_expert_rank_num": 0,
        "moe_expert_num": moe_expert_num,
        "global_bs": global_bs,
    }

    rank = torch.distributed.get_rank()

    # quant_mode=2: dispatch dynamically quantizes the activations and
    # returns the per-token scales alongside the dispatched tokens.
    quant_mode = 2
    ep_group = get_ep_group().device_group
    local_rank = torch.distributed.get_rank(group=ep_group)
    all_to_all_group_size = torch.distributed.get_world_size(ep_group)

    world_size = torch.distributed.get_world_size()
    tp_size = world_size // all_to_all_group_size
    tp_rank = rank % tp_size

    stage1_kwargs = {
        "scales": None,
        "quant_mode": quant_mode,
        "group_ep": moe_all_to_all_group_name,
        "ep_world_size": all_to_all_group_size,
        "ep_rank_id": local_rank,
        # NOTE: the TP side currently reuses the EP communicator name;
        # switch to a dedicated TP group name once one is available.
        "group_tp": moe_all_to_all_group_name,
        "tp_world_size": tp_size,
        "tp_rank_id": tp_rank,
    }
    kwargs.update(stage1_kwargs)

    output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
    expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = \
        output[0:5]
    if quant_mode == 0:
        # Unquantized dispatch produces no per-token scales.
        dynamic_scale = None

    down_out_list = apply_mlp(expand_x,
                              w1,
                              w1_scale,
                              w2,
                              w2_scale,
                              expert_token_nums,
                              dynamic_scale=dynamic_scale)

    # moeCombine: route expert outputs back and apply the routing weights.
    kwargs = {
        "expand_x": down_out_list,
        "expert_ids": topk_ids,
        "expand_idx": expand_idx,
        "expert_scales": topk_weights.to(torch.float32),
        "expert_shard_type": 0,
        "shared_expert_rank_num": 0,
        "moe_expert_num": moe_expert_num,
        "global_bs": 0,
    }
    tp_recv_counts = torch.empty(1,
                                 dtype=torch.int32,
                                 device=hidden_states.device)
    stage3_kwargs = {
        "ep_send_counts": ep_recv_counts,
        "group_ep": moe_all_to_all_group_name,
        "ep_world_size": all_to_all_group_size,
        "ep_rank_id": local_rank,
        "tp_send_counts": tp_recv_counts,
        # See the NOTE above: TP reuses the EP communicator name for now.
        "group_tp": moe_all_to_all_group_name,
        "tp_world_size": tp_size,
        "tp_rank_id": tp_rank,
    }
    kwargs.update(stage3_kwargs)

    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
    return hidden_states
def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
@@ -75,11 +244,10 @@ def fused_experts(hidden_states: torch.Tensor,
dtype=torch.int64)
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
token_counts = token_counts[:num_experts]
expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64)
expert_tokens = token_counts[:num_experts]
# Rearrange hidden_states
sorted_hidden_states = hidden_states[sorted_token_indices]
group_list_type = 1
else:
row_idx_len = num_tokens * top_k
row_idx = torch.arange(0,
@@ -97,46 +265,15 @@ def fused_experts(hidden_states: torch.Tensor,
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
expanded_expert_idx, num_experts)
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 0
quant_x, x_dynamic_scale = torch_npu.npu_dynamic_quant(
sorted_hidden_states)
del sorted_hidden_states
output_dtype = torch.bfloat16 if w1_scale.dtype == torch.bfloat16 else torch.float16
gate_up_out_list = torch_npu.npu_grouped_matmul(
x=[quant_x],
weight=[w1],
scale=[w1_scale],
per_token_scale=[x_dynamic_scale],
split_item=2,
group_list_type=0,
group_type=0,
group_list=expert_tokens,
output_dtype=output_dtype)
del quant_x
gate_up_out_list = gate_up_out_list[0] if len(
gate_up_out_list) == 1 else torch.cat(gate_up_out_list, dim=0)
gate_up_out_list = torch_npu.npu_swiglu(gate_up_out_list)
quant_gate_up_out_list, gate_up_out_dynamic_scale = torch_npu.npu_dynamic_quant(
gate_up_out_list)
del gate_up_out_list
down_out_list = torch_npu.npu_grouped_matmul(
x=[quant_gate_up_out_list],
weight=[w2],
scale=[w2_scale],
per_token_scale=[gate_up_out_dynamic_scale],
split_item=2,
group_list_type=0,
group_type=0,
group_list=expert_tokens,
output_dtype=output_dtype)
del quant_gate_up_out_list
down_out_list = down_out_list[0] if len(down_out_list) == 1 else torch.cat(
down_out_list, dim=0)
down_out_list = apply_mlp(sorted_hidden_states,
w1,
w1_scale,
w2,
w2_scale,
expert_tokens,
group_list_type=group_list_type)
if expert_map is not None:
weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)
@@ -144,12 +281,18 @@ def fused_experts(hidden_states: torch.Tensor,
final_hidden_states = torch.zeros(*original_shape,
device=hidden_states.device,
dtype=dtype)
final_hidden_states.index_add_(0, sorted_token_indices,
weighted_down_out)
# TODO: This should not happen! Look into it!
# fill nan with 0.0
final_hidden_states[torch.isnan(final_hidden_states)] = 0.0
num_valid_tokens = mask.sum()
valid_token_mask = torch.arange(
0, sorted_token_indices.shape[0],
device=device).unsqueeze(1) < num_valid_tokens
valid_output = torch.where(
valid_token_mask, weighted_down_out,
torch.zeros_like(weighted_down_out)).to(dtype)
final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
else:
# TODO: Reorder device memory 2 times here, replace the current
# implementation here when suitable operators become available.
final_hidden_states = torch_npu.npu_moe_finalize_routing(
down_out_list,
skip1=None,
@@ -157,7 +300,8 @@ def fused_experts(hidden_states: torch.Tensor,
bias=None,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids)
export_for_source_row=topk_ids,
)
del down_out_list
if len(original_shape) == 3:
final_hidden_states = final_hidden_states.view(original_shape)
@@ -230,6 +374,18 @@ class AscendW8A8DynamicFusedMoEMethod:
def __init__(self):
    """Set up the W8A8-dynamic fused-MoE method.

    Resolves the HCCL communicator name of the expert-parallel group,
    which the MC2 dispatch/combine kernels take as their ``group_ep`` /
    ``group_tp`` argument.
    """
    # Expert weights are transposed in process_weights_after_loading.
    self.transpose_weight = True

    ep_group = get_ep_group()
    try:
        device_group = ep_group.device_group
        # TODO: Try local_rank = ep_group.rank_in_group
        local_rank = torch.distributed.get_rank(group=device_group)
        # NOTE(review): _get_backend is a private torch.distributed API —
        # presumably only the NPU/HCCL backend exposes get_hccl_comm_name;
        # confirm against the torch_npu version in use.
        backend = device_group._get_backend(torch.device("npu"))
        self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
            local_rank)
    except AttributeError:
        # Backend does not expose an HCCL comm name (e.g. non-NPU backend);
        # leave the group name empty so callers can fall back.
        self.moe_all_to_all_group_name = ""
@staticmethod
def get_weight(num_experts: int, intermediate_size_per_partition: int,
hidden_sizes: int,
@@ -272,48 +428,78 @@ class AscendW8A8DynamicFusedMoEMethod:
dtype=params_dtype)
return param_dict
@staticmethod
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[
1] == global_num_experts, "Number of global experts mismatch"
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if global_num_experts == 256:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk当前写8
bias=e_score_correction_bias,
k_group=topk_group, # fix: 4
group_count=num_expert_group, # fix 8
group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=1, # 0: softmax; 1: sigmoid(fix)
# out_flag=False, # todo new api; 第三个输出是否输出
# y2_flag=False, # old api; 第三个输出是否输出
routed_scaling_factor=1,
eps=float(1e-20))
else:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map)
if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill:
return fused_experts_with_mc2(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
else:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map)
def process_weights_after_loading(self, layer):
if self.transpose_weight:

View File

@@ -16,8 +16,6 @@
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/worker.py
#
import os
import torch
import torch_npu # noqa: F401
from packaging.version import Version
@@ -25,8 +23,6 @@ from vllm.logger import logger
import vllm_ascend.envs as envs
VLLM_ENABLE_GRAPH_MODE = os.environ.get('VLLM_ENABLE_GRAPH_MODE', '0')
def try_register_lib(lib_name: str, lib_info: str = ""):
import importlib

View File

@@ -17,53 +17,66 @@
# limitations under the License.
#
from typing import List, Tuple
from typing import Any, List
import torch
from vllm.config import get_current_vllm_config
from vllm.utils import is_pin_memory_available
from vllm.worker.cache_engine import CacheEngine
from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
def allocate_kv_cache(
self,
num_blocks: int,
device: str,
) -> List[Tuple]:
) -> List[Any]:
"""Allocates KV cache on the specified device."""
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
pin_memory = is_pin_memory_available() if device == "cpu" else False
kv_cache: List[Tuple] = []
kv_cache: List[Any] = []
# Align entries so they are 256 byte aligned for better performance
# Primarily targets MLA as this typically only ends up having entries
# be 128 byte aligned.
alloc_shape = kv_cache_shape
additional_config = get_current_vllm_config().additional_config
if additional_config and additional_config.get("enable_graph_mode", False):
# Align entries so they are 256 byte aligned for better performance
# Primarily targets MLA as this typically only ends up having entries
# be 128 byte aligned.
alloc_shape = kv_cache_shape
for _ in range(self.num_attention_layers):
# null block in CpuGpuBlockAllocator requires at least that
# block to be zeroed-out.
# We zero-out everything for simplicity.
layer_kv_cache_nope = torch.zeros(
alloc_shape[:-1] +
(self.model_config.hf_text_config.kv_lora_rank, ),
dtype=self.dtype,
pin_memory=pin_memory,
device=device)
layer_kv_cache_pe = torch.zeros(
alloc_shape[:-1] +
(self.model_config.hf_text_config.qk_rope_head_dim, ),
dtype=self.dtype,
pin_memory=pin_memory,
device=device)
for _ in range(self.num_attention_layers):
# null block in CpuGpuBlockAllocator requires at least that
# block to be zeroed-out.
# We zero-out everything for simplicity.
layer_kv_cache_nope = torch.zeros(
alloc_shape[:-1] +
(self.model_config.hf_text_config.kv_lora_rank, ),
dtype=self.dtype,
pin_memory=pin_memory,
device=device)
layer_kv_cache_pe = torch.zeros(
alloc_shape[:-1] +
(self.model_config.hf_text_config.qk_rope_head_dim, ),
dtype=self.dtype,
pin_memory=pin_memory,
device=device)
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
# when entry_shape is higher than 1D
kv_cache.append((layer_kv_cache_nope, layer_kv_cache_pe))
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
# when entry_shape is higher than 1D
kv_cache.append((layer_kv_cache_nope, layer_kv_cache_pe))
else:
for _ in range(self.num_attention_layers):
# null block in CpuGpuBlockAllocator requires at least that
# block to be zeroed-out.
# We zero-out everything for simplicity.
layer_kv_cache = torch.zeros(kv_cache_shape,
dtype=self.dtype,
pin_memory=pin_memory,
device=device)
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
# when entry_shape is higher than 1D
kv_cache.append(layer_kv_cache)
return kv_cache
if VLLM_ENABLE_GRAPH_MODE == '1':
CacheEngine._allocate_kv_cache = allocate_kv_cache
CacheEngine._allocate_kv_cache = allocate_kv_cache

View File

@@ -32,7 +32,7 @@ import torch_npu
import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import CompilationLevel, VllmConfig
from vllm.config import VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context
@@ -56,7 +56,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, flatten_2d_lists,
is_pin_memory_available, supports_dynamo)
is_pin_memory_available)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
@@ -546,8 +546,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
}
# Add graph_pad_size here
if self.runner.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
if self.runner.enable_graph_mode:
graph_pad_size = self.runner.scheduler_config.max_num_seqs - len(
seq_lens)
else:
@@ -609,8 +608,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
]
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
if self.runner.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
if self.runner.enable_graph_mode:
torch._dynamo.mark_static(input_tokens_tensor)
torch._dynamo.mark_static(input_positions_tensor)
torch._dynamo.mark_static(attn_metadata.block_tables)
@@ -871,6 +869,12 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
self.max_batchsize_to_capture = \
self.vllm_config.compilation_config.max_capture_size
self.enable_graph_mode = False
additional_config = vllm_config.additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)
self.has_inner_state = model_config.has_inner_state
self.in_profile_run = False
@@ -971,8 +975,7 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
self.model = self.lora_manager.create_lora_manager(self.model)
# adapter torch compile with npu_backend
if self.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
if self.enable_graph_mode:
import torchair # type: ignore
from torchair import patch_for_hcom # type: ignore
@@ -1279,15 +1282,12 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
self.attn_state.begin_forward(model_input)
assert model_input.attn_metadata is not None
if self.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
# TODO(zzzzwwjj): Do we need to do it every time?
if self.enable_graph_mode:
torch._dynamo.mark_static(model_input.input_tokens)
torch._dynamo.mark_static(model_input.input_positions)
torch._dynamo.mark_static(model_input.attn_metadata.block_tables)
torch._dynamo.mark_static(model_input.attn_metadata.slot_mapping)
torch._dynamo.mark_static(
model_input.attn_metadata.query_start_loc)
torch._dynamo.mark_static(model_input.attn_metadata.seq_start_loc)
for kv in kv_caches:
if isinstance(kv, tuple):
torch._dynamo.mark_static(kv[0])
@@ -1298,7 +1298,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
virtual_engine = model_input.virtual_engine
prefill_meta = model_input.attn_metadata.prefill_metadata
previous_hidden_states = kwargs.get("previous_hidden_states")
if prefill_meta is None and self.vllm_config.compilation_config.level > 0:
if prefill_meta is None and self.enable_graph_mode:
model_executable = self.compile_model
# Note: graph_batch_size value not same as GPU
graph_batch_size = model_input.input_tokens.shape[ # type: ignore
@@ -1341,9 +1341,8 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_inner_state else {}
if self.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
model_kwargs = {"inputs_embeds": None}
if self.enable_graph_mode:
model_kwargs: Dict[str, Any] = {"inputs_embeds": None}
else:
model_kwargs = {}
if previous_hidden_states is not None:
@@ -1360,6 +1359,9 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
self.vllm_config, virtual_engine):
if model_input.attn_metadata is not None:
model_input.attn_metadata.input_positions = model_input.input_positions
if self.enable_graph_mode:
model_kwargs["kv_caches"] = kv_caches
model_kwargs["attn_metadata"] = model_input.attn_metadata
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
@@ -1430,8 +1432,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
hidden_states = hidden_or_intermediate_states.index_select(
0, indices)
output.prefill_hidden_states = hidden_or_intermediate_states
elif self.vllm_config.compilation_config.level == \
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
elif self.enable_graph_mode:
hidden_states = hidden_or_intermediate_states[:len(indices)]
else:
hidden_states = hidden_or_intermediate_states

View File

@@ -24,7 +24,7 @@ import torch
import torch.distributed
from torch import nn
from vllm import envs
from vllm.config import VllmConfig
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
@@ -300,7 +300,8 @@ class NPUWorker(LocalOrDistributedWorkerBase):
from contextlib import nullcontext
context = nullcontext() # type: ignore
with context:
self._init_cache_engine()
with set_current_vllm_config(self.vllm_config):
self._init_cache_engine()
self._warm_up_model()
def _init_cache_engine(self):
@@ -511,10 +512,9 @@ class NPUWorker(LocalOrDistributedWorkerBase):
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
expert_tensor_parallel_size = 1
if additional_config is not None and hasattr(
additional_config, "expert_tensor_parallel_size"):
expert_tensor_parallel_size = getattr(
additional_config, "expert_tensor_parallel_size")
if additional_config:
expert_tensor_parallel_size = additional_config.get(
"expert_tensor_parallel_size", 1)
init_ascend_model_parallel(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
expert_tensor_parallel_size)