support deepseek quant & mix-parallel with graphmode (#585)
### What this PR does / why we need it? 1. Support DeepSeek with W8A8 quantization; 2. Support DeepSeek with mixed parallelism (multi-DP, EP+TP); 3. Support DeepSeek in graph mode. --------- Signed-off-by: wen-jie666 <wenjie39@huawei.com> Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com> Signed-off-by: libaokui <libaokui@huawei.com> Signed-off-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: wen-jie666 <wenjie39@huawei.com>
This commit is contained in:
@@ -11,8 +11,6 @@
|
||||
import gc
|
||||
import os
|
||||
|
||||
VLLM_ENABLE_GRAPGH_MODE = os.environ.get("VLLM_ENABLE_GRAPH_MODE") == "1"
|
||||
|
||||
|
||||
def main():
|
||||
dp_rank = int(os.environ['RANK'])
|
||||
@@ -20,8 +18,8 @@ def main():
|
||||
dp_size = int(os.environ['WORLD_SIZE'])
|
||||
master_addr = os.environ['MASTER_ADDR']
|
||||
master_port = os.environ['MASTER_PORT']
|
||||
tp_size = 4
|
||||
etp_size = 2
|
||||
tp_size = 1
|
||||
etp_size = 1
|
||||
|
||||
os.environ["VLLM_DP_RANK"] = str(dp_rank)
|
||||
os.environ["VLLM_DP_SIZE"] = str(dp_size)
|
||||
@@ -58,15 +56,15 @@ def main():
|
||||
max_tokens=4,
|
||||
min_tokens=4)
|
||||
# Create an LLM.
|
||||
llm = LLM(
|
||||
model="deepseek-ai/DeepSeek-V2-Lite-Chat",
|
||||
tensor_parallel_size=tp_size,
|
||||
trust_remote_code=True,
|
||||
expert_tensor_parallel_size=etp_size,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=num_seqs,
|
||||
compilation_config=1 if VLLM_ENABLE_GRAPGH_MODE else 0,
|
||||
)
|
||||
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat",
|
||||
tensor_parallel_size=tp_size,
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=num_seqs,
|
||||
additional_config={
|
||||
'expert_tensor_parallel_size': etp_size,
|
||||
'enable_graph_mode': False,
|
||||
})
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
for output in outputs:
|
||||
|
||||
@@ -6,15 +6,13 @@ export HCCL_SOCKET_IFNAME=${ifname}
|
||||
# dp_size = node_size * dp_per_node
|
||||
node_size=1
|
||||
node_rank=0
|
||||
dp_per_node=2
|
||||
dp_per_node=4
|
||||
master_addr=127.0.0.1
|
||||
master_port=12345
|
||||
|
||||
rm -rf ./.torchair_cache/
|
||||
rm -rf ./dynamo_*
|
||||
rm -rf /root/ascend/log/debug/plog/*
|
||||
export VLLM_ENABLE_GRAPH_MODE=0
|
||||
export VLLM_ENABLE_MC2=0
|
||||
|
||||
torchrun --nproc_per_node ${dp_per_node} --nnodes ${node_size} \
|
||||
--node_rank ${node_rank} --master_addr ${master_addr} --master_port ${master_port} \
|
||||
|
||||
@@ -27,6 +27,7 @@ try:
|
||||
except ImportError:
|
||||
print("Failed to import torch_npu.")
|
||||
|
||||
import torchair._contrib.custom_torch_ops # type: ignore # noqa: F401
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata, AttentionType,
|
||||
@@ -36,9 +37,9 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
|
||||
compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_block_tables_empty)
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
|
||||
from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
|
||||
from vllm_ascend.worker.model_runner import (
|
||||
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
|
||||
|
||||
@@ -913,6 +914,12 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
self.w_kc = None
|
||||
self.w_vc = None
|
||||
|
||||
self.enable_graph_mode = False
|
||||
additional_config = get_current_vllm_config().additional_config
|
||||
if additional_config:
|
||||
self.enable_graph_mode = additional_config.get(
|
||||
"enable_graph_mode", False)
|
||||
|
||||
def exec_kv(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -1084,7 +1091,7 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
self.num_heads, -1)
|
||||
|
||||
# TODO: Replace the env with more flexible expressions
|
||||
if VLLM_ENABLE_GRAPH_MODE == '1':
|
||||
if self.enable_graph_mode:
|
||||
if len(kv_cache) > 0 and kv_cache[0].numel(
|
||||
) > 0 and attn_metadata.num_prefills > 0:
|
||||
slots = attn_metadata.slot_mapping
|
||||
@@ -1141,7 +1148,7 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
)
|
||||
elif attn_metadata.decode_metadata:
|
||||
assert kv_cache is not None
|
||||
if VLLM_ENABLE_GRAPH_MODE == '1':
|
||||
if self.enable_graph_mode:
|
||||
# TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
|
||||
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
|
||||
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
||||
|
||||
@@ -26,13 +26,13 @@
|
||||
# """Inference-only DeepseekV2/DeepseekV3 model."""
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention import Attention
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
|
||||
get_current_vllm_config)
|
||||
from vllm.distributed import (get_dp_group, get_pp_group,
|
||||
@@ -64,7 +64,6 @@ from vllm.model_executor.models.utils import (
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||
from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
|
||||
|
||||
|
||||
class CustomDeepseekV2MoE(nn.Module):
|
||||
@@ -133,7 +132,7 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.dp_size = get_dp_group().world_size
|
||||
batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", 0)) == 1
|
||||
self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", '0')) == 1
|
||||
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.final_hidden_states = torch.zeros(
|
||||
@@ -309,38 +308,36 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
||||
|
||||
self.prefix = prefix
|
||||
self.debug_layer_idx = int(self.prefix.split(".")[-2])
|
||||
if VLLM_ENABLE_GRAPH_MODE == "1":
|
||||
self.forward = self.forward_torchair
|
||||
else:
|
||||
self.forward = self.forward_eager # type: ignore
|
||||
self.enable_graph_mode = False
|
||||
additional_config = get_current_vllm_config().additional_config
|
||||
if additional_config:
|
||||
self.enable_graph_mode = additional_config.get(
|
||||
"enable_graph_mode", False)
|
||||
|
||||
def forward_torchair(self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor = None,
|
||||
attn_metadata=None):
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||
if self.q_lora_rank is not None:
|
||||
ckq = self.q_a_proj(hidden_states)[0]
|
||||
hidden_states_or_q_c = self.q_a_layernorm(ckq)
|
||||
else:
|
||||
hidden_states_or_q_c = hidden_states
|
||||
return self.mla_attn(hidden_states_or_q_c, hidden_states, None,
|
||||
kv_cache, attn_metadata)
|
||||
|
||||
def forward_eager(self, positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor):
|
||||
if self.q_lora_rank is not None:
|
||||
ckq = self.q_a_proj(hidden_states)[0]
|
||||
hidden_states_or_q_c = self.q_a_layernorm(ckq)
|
||||
if self.enable_graph_mode:
|
||||
return self.mla_attn.impl.forward(self.mla_attn,
|
||||
hidden_states_or_q_c,
|
||||
hidden_states, None, kv_cache,
|
||||
attn_metadata)
|
||||
else:
|
||||
hidden_states_or_q_c = hidden_states
|
||||
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
||||
return self.mla_attn(hidden_states_or_q_c,
|
||||
kv_c_normed,
|
||||
k_pe,
|
||||
output_shape=hidden_states.shape)
|
||||
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
||||
return self.mla_attn(hidden_states_or_q_c,
|
||||
kv_c_normed,
|
||||
k_pe,
|
||||
output_shape=hidden_states.shape)
|
||||
|
||||
|
||||
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
||||
@@ -408,6 +405,54 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
||||
eps=config.rms_norm_eps)
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
|
||||
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
    kv_cache: Optional[torch.Tensor] = None,
    attn_metadata: Optional[AttentionMetadata] = None,
) -> torch.Tensor:
    """Run one decoder layer: norm -> self-attention -> norm -> MLP.

    Args:
        positions: Forwarded unchanged to ``self.self_attn``.
        hidden_states: Layer input.
        residual: Running residual stream, or ``None`` on the first call,
            in which case the incoming ``hidden_states`` becomes the
            residual.
        kv_cache: Forwarded unchanged to ``self.self_attn``.
        attn_metadata: Forwarded unchanged to ``self.self_attn``.

    Returns:
        The pair ``(hidden_states, residual)``. Note that the declared
        return annotation is ``torch.Tensor`` although a tuple is returned.
    """
    # Pre-attention norm. With a residual the norm is called with both
    # tensors and returns the updated pair; without one the input itself
    # is kept as the residual.
    if residual is not None:
        hidden_states, residual = self.input_layernorm(
            hidden_states, residual)
    else:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

    attn_out = self.self_attn(
        positions=positions,
        hidden_states=hidden_states,
        kv_cache=kv_cache,
        attn_metadata=attn_metadata,
    )

    if attn_out.dtype == torch.float16:
        # FP16 overflow guard: scale both the attention output and the
        # residual before the RMS norm; the norm result is not affected by
        # a common scale.
        inv_scale = 1. / self.routed_scaling_factor
        attn_out *= inv_scale
        if self.layer_idx == 0:
            # The residual is shared by all layers, so it is only scaled
            # once, on the first layer.
            residual *= inv_scale

    # Feed-forward part.
    normed, residual = self.post_attention_layernorm(attn_out, residual)
    mlp_out = self.mlp(normed)

    if isinstance(self.mlp,
                  DeepseekV2MLP) and mlp_out.dtype == torch.float16:
        # FP16 overflow guard for the dense MLP: its output feeds the
        # input_layernorm of the next decoder layer. The MoE variant does
        # its own scaling inside its forward.
        mlp_out *= 1. / self.routed_scaling_factor

    return mlp_out, residual
|
||||
|
||||
|
||||
class CustomDeepseekV2Model(nn.Module):
|
||||
|
||||
@@ -459,7 +504,9 @@ class CustomDeepseekV2Model(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
@@ -473,8 +520,13 @@ class CustomDeepseekV2Model(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, residual,
|
||||
kv_caches[i -
|
||||
self.start_layer] if kv_caches is not None else None,
|
||||
attn_metadata)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
@@ -514,6 +566,20 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: Optional[List[torch.Tensor]] = None,
    attn_metadata: Optional[AttentionMetadata] = None,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Delegate the forward pass to the wrapped ``self.model``.

    All arguments are passed through positionally, in signature order,
    and the result of ``self.model`` is returned unchanged.
    """
    # Positional call on purpose: the argument order mirrors this
    # method's own signature.
    return self.model(input_ids, positions, kv_caches, attn_metadata,
                      intermediate_tensors, inputs_embeds)
|
||||
|
||||
|
||||
class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
|
||||
pass
|
||||
|
||||
@@ -330,17 +330,16 @@ def native_grouped_topk(
|
||||
|
||||
|
||||
def select_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: Optional[bool] = True
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Select top-k experts based on router logits.
|
||||
@@ -364,7 +363,6 @@ def select_experts(
|
||||
Raises:
|
||||
ValueError: If an unsupported scoring function is provided.
|
||||
"""
|
||||
|
||||
if custom_routing_function is not None:
|
||||
raise NotImplementedError(
|
||||
"Custom routing function is not supported now")
|
||||
@@ -466,21 +464,36 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
is_prefill=False,
|
||||
**kwargs,
|
||||
):
|
||||
# set prefill as false always, should fix this
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
is_prefill=is_prefill)
|
||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||
if global_num_experts == 256:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
router_logits,
|
||||
k=top_k, # topk当前写8
|
||||
bias=e_score_correction_bias,
|
||||
k_group=topk_group, # fix: 4
|
||||
group_count=num_expert_group, # fix 8
|
||||
group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
|
||||
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
||||
norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
||||
# out_flag=False, # todo new api; 第三个输出是否输出
|
||||
# y2_flag=False, # old api; 第三个输出是否输出
|
||||
routed_scaling_factor=1,
|
||||
eps=float(1e-20))
|
||||
else:
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
if os.environ.get("VLLM_ENABLE_MC2") == "1" and not is_prefill:
|
||||
if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill:
|
||||
return fused_experts_with_mc2(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
@@ -611,10 +624,11 @@ class AscendFusedMoE(FusedMoE):
|
||||
real_top_k = self.top_k
|
||||
|
||||
if self.dp_size > 1:
|
||||
if int(os.environ.get("VLLM_ENABLE_MC2") # type: ignore
|
||||
if int(os.environ.get("VLLM_ENABLE_MC2", '0') # type: ignore
|
||||
) == 1 and not is_prefill:
|
||||
...
|
||||
elif int(os.environ.get("USING_LCCL_COM")) == 1: # type: ignore
|
||||
elif int(os.environ.get("USING_LCCL_COM",
|
||||
'0')) == 1: # type: ignore
|
||||
hidden_states = get_dp_group().all_gather(
|
||||
hidden_states, 0, False)
|
||||
router_logits = get_dp_group().all_gather(
|
||||
@@ -631,7 +645,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
top_k=real_top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.num_experts,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
@@ -641,7 +655,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
is_prefill=is_prefill)
|
||||
|
||||
if self.dp_size > 1:
|
||||
if int(os.environ.get("VLLM_ENABLE_MC2") # type: ignore
|
||||
if int(os.environ.get("VLLM_ENABLE_MC2", '0') # type: ignore
|
||||
) == 1 and not is_prefill:
|
||||
...
|
||||
else:
|
||||
|
||||
@@ -24,6 +24,7 @@ import torch_npu # noqa: F401
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import logger
|
||||
from vllm.platforms import Platform, PlatformEnum
|
||||
from vllm.utils import supports_dynamo
|
||||
|
||||
CUSTOM_OP_ENABLED = False
|
||||
try:
|
||||
@@ -119,6 +120,15 @@ class NPUPlatform(Platform):
|
||||
compilation_config.level)
|
||||
compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||
|
||||
if vllm_config.additional_config is not None:
|
||||
enable_graph_mode = vllm_config.additional_config.get(
|
||||
"enable_graph_mode", False)
|
||||
if enable_graph_mode and not supports_dynamo():
|
||||
logger.warning(
|
||||
"enable_graph_mode is not supported because the version of torch is too low, forcing close enable_graph_mode"
|
||||
)
|
||||
vllm_config.additional_config["enable_graph_mode"] = False
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
if parallel_config and parallel_config.worker_cls == "auto":
|
||||
if envs.VLLM_USE_V1:
|
||||
|
||||
@@ -310,21 +310,22 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.quant_method.apply(layer, x, router_logits, top_k,
|
||||
renormalize, use_grouped_topk,
|
||||
topk_group, num_expert_group,
|
||||
global_num_experts, expert_map,
|
||||
topk_group, num_expert_group,
|
||||
custom_routing_function, scoring_func,
|
||||
e_score_correction_bias)
|
||||
e_score_correction_bias, is_prefill)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if hasattr(self.quant_method, "process_weights_after_loading"):
|
||||
|
||||
@@ -23,10 +23,8 @@ import torch_npu
|
||||
|
||||
def quant_per_tensor(in_tensor: torch.Tensor, input_scale: torch.Tensor,
|
||||
input_offset: torch.Tensor):
|
||||
out = torch.empty_like(in_tensor, dtype=torch.int8)
|
||||
torch_npu._npu_quantize_per_tensor(in_tensor, input_scale, input_offset,
|
||||
out)
|
||||
return out
|
||||
return torch_npu.npu_quantize(in_tensor, input_scale, input_offset,
|
||||
torch.qint8, -1, True)
|
||||
|
||||
|
||||
class AscendW8A8LinearMethod:
|
||||
@@ -88,7 +86,11 @@ class AscendW8A8LinearMethod:
|
||||
) -> torch.Tensor:
|
||||
original_dtype = x.dtype
|
||||
if original_dtype != torch.int8:
|
||||
x = quant_per_tensor(x, layer.input_scale, layer.input_offset)
|
||||
x = quant_per_tensor(
|
||||
x,
|
||||
layer.aclnn_input_scale,
|
||||
layer.aclnn_input_offset,
|
||||
)
|
||||
quant_bias = layer.quant_bias if tp_rank == 0 else None
|
||||
return torch_npu.npu_quant_matmul(
|
||||
x,
|
||||
@@ -99,6 +101,13 @@ class AscendW8A8LinearMethod:
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
expanding_factor = layer.weight.data.shape[1]
|
||||
layer.aclnn_input_scale = torch.nn.Parameter(
|
||||
layer.input_scale.data.repeat(expanding_factor),
|
||||
requires_grad=False)
|
||||
layer.aclnn_input_offset = torch.nn.Parameter(
|
||||
layer.input_offset.data.repeat(expanding_factor),
|
||||
requires_grad=False)
|
||||
if self.transpose_weight:
|
||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
|
||||
|
||||
@@ -15,14 +15,183 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import get_ep_group
|
||||
from vllm_ascend.ops.fused_moe import select_experts
|
||||
|
||||
|
||||
def apply_mlp(x: torch.Tensor,
              w1: torch.Tensor,
              w1_scale: torch.Tensor,
              w2: torch.Tensor,
              w2_scale: torch.Tensor,
              group_list: torch.Tensor,
              dynamic_scale: Optional[torch.Tensor] = None,
              group_list_type: int = 1) -> torch.Tensor:
    """
    Apply the quantized expert MLP: gate_up_proj -> swiglu -> down_proj.

    Args:
        x: input hidden states with shape (num_tokens, hidden_size).
        w1: expert weights1 with shape
            (num_experts, hidden_size, intermediate_size * 2)
        w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
        w2: expert weights2 with shape
            (num_experts, intermediate_size, hidden_size)
        w2_scale: weights2 scale with shape (num_experts, hidden_size)
        group_list: per-expert token grouping with shape (num_experts),
            passed straight to ``npu_grouped_matmul``; how it is read
            depends on ``group_list_type``.
        dynamic_scale: per-token scale that belongs to ``x``. When given,
            ``x`` is fed to the first grouped matmul as-is (presumably it
            is already quantized -- confirm against the caller); when
            ``None``, ``x`` is quantized here with ``npu_dynamic_quant``.
        group_list_type: forwarded to both ``npu_grouped_matmul`` calls.
            NOTE(review): presumably 0 means cumulative sums and 1 means
            per-expert counts -- confirm against the torch_npu docs.

    Weight layout:
        The shapes above are the already-transposed layouts, i.e.
        w1: (num_experts, intermediate_size * 2, hidden_size) ->
            (num_experts, hidden_size, intermediate_size * 2)
        w2: (num_experts, hidden_size, intermediate_size) ->
            (num_experts, intermediate_size, hidden_size)

    Returns:
        hidden_states: output hidden states after MLP.
    """

    if dynamic_scale is None:
        h, pertoken_scale = torch_npu.npu_dynamic_quant(x)
    else:
        h = x
        pertoken_scale = dynamic_scale

    # The output dtype follows the dtype of the weight scale.
    output_dtype = torch.bfloat16 if w1_scale.dtype == torch.bfloat16 else \
        torch.float16

    # gmm1: gate_up_proj
    gate_up_out_list = torch_npu.npu_grouped_matmul(
        x=[h],
        weight=[w1],
        scale=[w1_scale],
        per_token_scale=[pertoken_scale],
        split_item=3,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
        output_dtype=output_dtype)
    gate_up_out = gate_up_out_list[0]

    # swiglu, then re-quantize the activation for the second matmul
    swiglu_out = torch_npu.npu_swiglu(gate_up_out)
    swiglu_out, swiglu_out_scale = torch_npu.npu_dynamic_quant(swiglu_out)

    # down_proj
    down_out_list = torch_npu.npu_grouped_matmul(
        x=[swiglu_out],
        weight=[w2],
        scale=[w2_scale],
        per_token_scale=[swiglu_out_scale],
        split_item=3,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
        output_dtype=output_dtype)
    return down_out_list[0]
|
||||
|
||||
|
||||
def fused_experts_with_mc2(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    expert_map: Optional[torch.Tensor] = None,
    moe_all_to_all_group_name: str = "",
) -> torch.Tensor:
    """Run the quantized experts using the MC2 dispatch/combine operators.

    The tokens are sent out with ``npu_moe_distribute_dispatch``, processed
    by ``apply_mlp`` and gathered back with ``npu_moe_distribute_combine``.

    Args:
        hidden_states: tokens to route; passed as ``x`` to the dispatch op.
        w1: first expert weight, forwarded to ``apply_mlp``.
        w2: second expert weight, forwarded to ``apply_mlp``.
        w1_scale: scale of ``w1``, forwarded to ``apply_mlp``.
        w2_scale: scale of ``w2``, forwarded to ``apply_mlp``.
        topk_weights: routing weights; cast to float32 and passed as
            ``expert_scales`` to the combine op.
        topk_ids: selected expert ids; passed as ``expert_ids`` to both ops.
        top_k: accepted for signature compatibility; not read in the body.
        expert_map: required despite the ``None`` default; its length is
            used as ``moe_expert_num``.
        moe_all_to_all_group_name: communication group name used for both
            ``group_ep`` and ``group_tp``.

    Returns:
        The output of ``npu_moe_distribute_combine``.

    Raises:
        TypeError: if ``expert_map`` is ``None``.
    """
    if expert_map is None:
        # Fail with a clear message instead of the opaque
        # "object of type 'NoneType' has no len()".
        raise TypeError(
            "fused_experts_with_mc2 requires `expert_map`, but got None")

    global_bs = 0
    moe_expert_num = len(expert_map)
    # NOTE(review): presumably selects dynamic per-token quantization in
    # the dispatch op -- confirm against the torch_npu docs.
    quant_mode = 2

    # Derive the EP / TP coordinates of this rank from the process groups.
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    ep_group = get_ep_group().device_group
    local_rank = torch.distributed.get_rank(group=ep_group)
    all_to_all_group_size = torch.distributed.get_world_size(ep_group)
    tp_size = world_size // all_to_all_group_size
    tp_rank = rank % tp_size

    # Stage 1: dispatch tokens to the ranks that own their experts.
    dispatch_kwargs = {
        "x": hidden_states,
        "expert_ids": topk_ids,
        "expert_shard_type": 0,
        "shared_expert_rank_num": 0,
        "moe_expert_num": moe_expert_num,
        "global_bs": global_bs,
        "scales": None,
        "quant_mode": quant_mode,
        "group_ep": moe_all_to_all_group_name,
        "ep_world_size": all_to_all_group_size,
        "ep_rank_id": local_rank,
        # The same group name is used for EP and TP here.
        "group_tp": moe_all_to_all_group_name,
        "tp_world_size": tp_size,
        "tp_rank_id": tp_rank,
    }
    output = torch_npu.npu_moe_distribute_dispatch(**dispatch_kwargs)
    expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
        0:5]

    if quant_mode == 0:
        # Without quantization in the dispatch op there is no per-token
        # scale to hand to apply_mlp.
        dynamic_scale = None

    # Stage 2: expert MLP on the received tokens.
    down_out_list = apply_mlp(expand_x,
                              w1,
                              w1_scale,
                              w2,
                              w2_scale,
                              expert_token_nums,
                              dynamic_scale=dynamic_scale)

    # Stage 3: combine the expert outputs back onto the source ranks.
    # The TP send counts are passed as an uninitialized one-element tensor.
    tp_recv_counts = torch.empty(1,
                                 dtype=torch.int32,
                                 device=hidden_states.device)
    combine_kwargs = {
        "expand_x": down_out_list,
        "expert_ids": topk_ids,
        "expand_idx": expand_idx,
        "expert_scales": topk_weights.to(torch.float32),
        "expert_shard_type": 0,
        "shared_expert_rank_num": 0,
        "moe_expert_num": moe_expert_num,
        "global_bs": global_bs,
        "ep_send_counts": ep_recv_counts,
        "group_ep": moe_all_to_all_group_name,
        "ep_world_size": all_to_all_group_size,
        "ep_rank_id": local_rank,
        "tp_send_counts": tp_recv_counts,
        "group_tp": moe_all_to_all_group_name,
        "tp_world_size": tp_size,
        "tp_rank_id": tp_rank,
    }
    hidden_states = torch_npu.npu_moe_distribute_combine(**combine_kwargs)

    return hidden_states
|
||||
|
||||
|
||||
def fused_experts(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
@@ -75,11 +244,10 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
dtype=torch.int64)
|
||||
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
|
||||
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
|
||||
token_counts = token_counts[:num_experts]
|
||||
expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64)
|
||||
|
||||
expert_tokens = token_counts[:num_experts]
|
||||
# Rearrange hidden_states
|
||||
sorted_hidden_states = hidden_states[sorted_token_indices]
|
||||
group_list_type = 1
|
||||
else:
|
||||
row_idx_len = num_tokens * top_k
|
||||
row_idx = torch.arange(0,
|
||||
@@ -97,46 +265,15 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
||||
expanded_expert_idx, num_experts)
|
||||
expert_tokens = expert_tokens.to(torch.int64)
|
||||
group_list_type = 0
|
||||
|
||||
quant_x, x_dynamic_scale = torch_npu.npu_dynamic_quant(
|
||||
sorted_hidden_states)
|
||||
del sorted_hidden_states
|
||||
output_dtype = torch.bfloat16 if w1_scale.dtype == torch.bfloat16 else torch.float16
|
||||
|
||||
gate_up_out_list = torch_npu.npu_grouped_matmul(
|
||||
x=[quant_x],
|
||||
weight=[w1],
|
||||
scale=[w1_scale],
|
||||
per_token_scale=[x_dynamic_scale],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=expert_tokens,
|
||||
output_dtype=output_dtype)
|
||||
del quant_x
|
||||
|
||||
gate_up_out_list = gate_up_out_list[0] if len(
|
||||
gate_up_out_list) == 1 else torch.cat(gate_up_out_list, dim=0)
|
||||
gate_up_out_list = torch_npu.npu_swiglu(gate_up_out_list)
|
||||
|
||||
quant_gate_up_out_list, gate_up_out_dynamic_scale = torch_npu.npu_dynamic_quant(
|
||||
gate_up_out_list)
|
||||
del gate_up_out_list
|
||||
|
||||
down_out_list = torch_npu.npu_grouped_matmul(
|
||||
x=[quant_gate_up_out_list],
|
||||
weight=[w2],
|
||||
scale=[w2_scale],
|
||||
per_token_scale=[gate_up_out_dynamic_scale],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=expert_tokens,
|
||||
output_dtype=output_dtype)
|
||||
del quant_gate_up_out_list
|
||||
|
||||
down_out_list = down_out_list[0] if len(down_out_list) == 1 else torch.cat(
|
||||
down_out_list, dim=0)
|
||||
down_out_list = apply_mlp(sorted_hidden_states,
|
||||
w1,
|
||||
w1_scale,
|
||||
w2,
|
||||
w2_scale,
|
||||
expert_tokens,
|
||||
group_list_type=group_list_type)
|
||||
|
||||
if expert_map is not None:
|
||||
weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)
|
||||
@@ -144,12 +281,18 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
final_hidden_states = torch.zeros(*original_shape,
|
||||
device=hidden_states.device,
|
||||
dtype=dtype)
|
||||
final_hidden_states.index_add_(0, sorted_token_indices,
|
||||
weighted_down_out)
|
||||
# TODO: This should not happen! Look into it!
|
||||
# fill nan with 0.0
|
||||
final_hidden_states[torch.isnan(final_hidden_states)] = 0.0
|
||||
|
||||
num_valid_tokens = mask.sum()
|
||||
valid_token_mask = torch.arange(
|
||||
0, sorted_token_indices.shape[0],
|
||||
device=device).unsqueeze(1) < num_valid_tokens
|
||||
valid_output = torch.where(
|
||||
valid_token_mask, weighted_down_out,
|
||||
torch.zeros_like(weighted_down_out)).to(dtype)
|
||||
final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
|
||||
else:
|
||||
# TODO: Reorder device memory 2 times here, replace the current
|
||||
# implementation here when suitable operators become available.
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
down_out_list,
|
||||
skip1=None,
|
||||
@@ -157,7 +300,8 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
bias=None,
|
||||
scales=topk_weights,
|
||||
expanded_src_to_dst_row=expanded_row_idx,
|
||||
export_for_source_row=topk_ids)
|
||||
export_for_source_row=topk_ids,
|
||||
)
|
||||
del down_out_list
|
||||
if len(original_shape) == 3:
|
||||
final_hidden_states = final_hidden_states.view(original_shape)
|
||||
@@ -230,6 +374,18 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
def __init__(self):
    """Set up the method and look up the HCCL communicator name.

    ``moe_all_to_all_group_name`` is taken from the backend of the
    expert-parallel device group; if any attribute needed for that lookup
    is missing, it falls back to an empty string.
    """
    self.transpose_weight = True

    ep_group = get_ep_group()

    try:
        ep_device_group = ep_group.device_group
        # TODO: Try local_rank = ep_group.rank_in_group
        rank_in_ep = torch.distributed.get_rank(group=ep_device_group)
        npu_backend = ep_device_group._get_backend(torch.device("npu"))
        group_name = npu_backend.get_hccl_comm_name(rank_in_ep)
    except AttributeError:
        # Best-effort lookup: no communicator name available.
        group_name = ""
    self.moe_all_to_all_group_name = group_name
|
||||
|
||||
@staticmethod
|
||||
def get_weight(num_experts: int, intermediate_size_per_partition: int,
|
||||
hidden_sizes: int,
|
||||
@@ -272,48 +428,78 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
dtype=params_dtype)
|
||||
return param_dict
|
||||
|
||||
@staticmethod
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||
if global_num_experts == 256:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
router_logits,
|
||||
k=top_k, # topk当前写8
|
||||
bias=e_score_correction_bias,
|
||||
k_group=topk_group, # fix: 4
|
||||
group_count=num_expert_group, # fix 8
|
||||
group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
|
||||
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
||||
norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
||||
# out_flag=False, # todo new api; 第三个输出是否输出
|
||||
# y2_flag=False, # old api; 第三个输出是否输出
|
||||
routed_scaling_factor=1,
|
||||
eps=float(1e-20))
|
||||
else:
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2=layer.w2_weight,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map)
|
||||
if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill:
|
||||
return fused_experts_with_mc2(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
|
||||
else:
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2=layer.w2_weight,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if self.transpose_weight:
|
||||
|
||||
@@ -16,8 +16,6 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/vllm/worker/worker.py
|
||||
#
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch_npu # noqa: F401
|
||||
from packaging.version import Version
|
||||
@@ -25,8 +23,6 @@ from vllm.logger import logger
|
||||
|
||||
import vllm_ascend.envs as envs
|
||||
|
||||
VLLM_ENABLE_GRAPH_MODE = os.environ.get('VLLM_ENABLE_GRAPH_MODE', '0')
|
||||
|
||||
|
||||
def try_register_lib(lib_name: str, lib_info: str = ""):
|
||||
import importlib
|
||||
|
||||
@@ -17,53 +17,66 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import List, Tuple
|
||||
from typing import Any, List
|
||||
|
||||
import torch
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.utils import is_pin_memory_available
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
|
||||
from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
|
||||
|
||||
|
||||
def allocate_kv_cache(
|
||||
self,
|
||||
num_blocks: int,
|
||||
device: str,
|
||||
) -> List[Tuple]:
|
||||
) -> List[Any]:
|
||||
"""Allocates KV cache on the specified device."""
|
||||
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
|
||||
pin_memory = is_pin_memory_available() if device == "cpu" else False
|
||||
kv_cache: List[Tuple] = []
|
||||
kv_cache: List[Any] = []
|
||||
|
||||
# Align entries so they are 256 byte aligned for better performance
|
||||
# Primarily targets MLA as this typically only ends up having entries
|
||||
# be 128 byte aligned.
|
||||
alloc_shape = kv_cache_shape
|
||||
additional_config = get_current_vllm_config().additional_config
|
||||
if additional_config and additional_config.get("enable_graph_mode", False):
|
||||
# Align entries so they are 256 byte aligned for better performance
|
||||
# Primarily targets MLA as this typically only ends up having entries
|
||||
# be 128 byte aligned.
|
||||
alloc_shape = kv_cache_shape
|
||||
|
||||
for _ in range(self.num_attention_layers):
|
||||
# null block in CpuGpuBlockAllocator requires at least that
|
||||
# block to be zeroed-out.
|
||||
# We zero-out everything for simplicity.
|
||||
layer_kv_cache_nope = torch.zeros(
|
||||
alloc_shape[:-1] +
|
||||
(self.model_config.hf_text_config.kv_lora_rank, ),
|
||||
dtype=self.dtype,
|
||||
pin_memory=pin_memory,
|
||||
device=device)
|
||||
layer_kv_cache_pe = torch.zeros(
|
||||
alloc_shape[:-1] +
|
||||
(self.model_config.hf_text_config.qk_rope_head_dim, ),
|
||||
dtype=self.dtype,
|
||||
pin_memory=pin_memory,
|
||||
device=device)
|
||||
for _ in range(self.num_attention_layers):
|
||||
# null block in CpuGpuBlockAllocator requires at least that
|
||||
# block to be zeroed-out.
|
||||
# We zero-out everything for simplicity.
|
||||
layer_kv_cache_nope = torch.zeros(
|
||||
alloc_shape[:-1] +
|
||||
(self.model_config.hf_text_config.kv_lora_rank, ),
|
||||
dtype=self.dtype,
|
||||
pin_memory=pin_memory,
|
||||
device=device)
|
||||
layer_kv_cache_pe = torch.zeros(
|
||||
alloc_shape[:-1] +
|
||||
(self.model_config.hf_text_config.qk_rope_head_dim, ),
|
||||
dtype=self.dtype,
|
||||
pin_memory=pin_memory,
|
||||
device=device)
|
||||
|
||||
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
|
||||
# when entry_shape is higher than 1D
|
||||
kv_cache.append((layer_kv_cache_nope, layer_kv_cache_pe))
|
||||
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
|
||||
# when entry_shape is higher than 1D
|
||||
kv_cache.append((layer_kv_cache_nope, layer_kv_cache_pe))
|
||||
else:
|
||||
for _ in range(self.num_attention_layers):
|
||||
# null block in CpuGpuBlockAllocator requires at least that
|
||||
# block to be zeroed-out.
|
||||
# We zero-out everything for simplicity.
|
||||
layer_kv_cache = torch.zeros(kv_cache_shape,
|
||||
dtype=self.dtype,
|
||||
pin_memory=pin_memory,
|
||||
device=device)
|
||||
|
||||
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
|
||||
# when entry_shape is higher than 1D
|
||||
kv_cache.append(layer_kv_cache)
|
||||
return kv_cache
|
||||
|
||||
|
||||
if VLLM_ENABLE_GRAPH_MODE == '1':
|
||||
CacheEngine._allocate_kv_cache = allocate_kv_cache
|
||||
CacheEngine._allocate_kv_cache = allocate_kv_cache
|
||||
|
||||
@@ -32,7 +32,7 @@ import torch_npu
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.config import CompilationLevel, VllmConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.forward_context import set_forward_context
|
||||
@@ -56,7 +56,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
|
||||
from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, flatten_2d_lists,
|
||||
is_pin_memory_available, supports_dynamo)
|
||||
is_pin_memory_available)
|
||||
from vllm.worker.model_runner_base import (
|
||||
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
|
||||
_add_attn_metadata_broadcastable_dict,
|
||||
@@ -546,8 +546,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
}
|
||||
|
||||
# Add graph_pad_size here
|
||||
if self.runner.vllm_config.compilation_config.level ==\
|
||||
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
|
||||
if self.runner.enable_graph_mode:
|
||||
graph_pad_size = self.runner.scheduler_config.max_num_seqs - len(
|
||||
seq_lens)
|
||||
else:
|
||||
@@ -609,8 +608,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
]
|
||||
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
|
||||
|
||||
if self.runner.vllm_config.compilation_config.level ==\
|
||||
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
|
||||
if self.runner.enable_graph_mode:
|
||||
torch._dynamo.mark_static(input_tokens_tensor)
|
||||
torch._dynamo.mark_static(input_positions_tensor)
|
||||
torch._dynamo.mark_static(attn_metadata.block_tables)
|
||||
@@ -871,6 +869,12 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
|
||||
self.max_batchsize_to_capture = \
|
||||
self.vllm_config.compilation_config.max_capture_size
|
||||
|
||||
self.enable_graph_mode = False
|
||||
additional_config = vllm_config.additional_config
|
||||
if additional_config:
|
||||
self.enable_graph_mode = additional_config.get(
|
||||
"enable_graph_mode", False)
|
||||
|
||||
self.has_inner_state = model_config.has_inner_state
|
||||
|
||||
self.in_profile_run = False
|
||||
@@ -971,8 +975,7 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
|
||||
self.model = self.lora_manager.create_lora_manager(self.model)
|
||||
|
||||
# adapter torch compile with npu_backend
|
||||
if self.vllm_config.compilation_config.level ==\
|
||||
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
|
||||
if self.enable_graph_mode:
|
||||
import torchair # type: ignore
|
||||
from torchair import patch_for_hcom # type: ignore
|
||||
|
||||
@@ -1279,15 +1282,12 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
self.attn_state.begin_forward(model_input)
|
||||
|
||||
assert model_input.attn_metadata is not None
|
||||
if self.vllm_config.compilation_config.level ==\
|
||||
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
|
||||
# TODO(zzzzwwjj): Do we need to do it every time?
|
||||
if self.enable_graph_mode:
|
||||
torch._dynamo.mark_static(model_input.input_tokens)
|
||||
torch._dynamo.mark_static(model_input.input_positions)
|
||||
torch._dynamo.mark_static(model_input.attn_metadata.block_tables)
|
||||
torch._dynamo.mark_static(model_input.attn_metadata.slot_mapping)
|
||||
torch._dynamo.mark_static(
|
||||
model_input.attn_metadata.query_start_loc)
|
||||
torch._dynamo.mark_static(model_input.attn_metadata.seq_start_loc)
|
||||
for kv in kv_caches:
|
||||
if isinstance(kv, tuple):
|
||||
torch._dynamo.mark_static(kv[0])
|
||||
@@ -1298,7 +1298,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
virtual_engine = model_input.virtual_engine
|
||||
prefill_meta = model_input.attn_metadata.prefill_metadata
|
||||
previous_hidden_states = kwargs.get("previous_hidden_states")
|
||||
if prefill_meta is None and self.vllm_config.compilation_config.level > 0:
|
||||
if prefill_meta is None and self.enable_graph_mode:
|
||||
model_executable = self.compile_model
|
||||
# Note: graph_batch_size value not same as GPU
|
||||
graph_batch_size = model_input.input_tokens.shape[ # type: ignore
|
||||
@@ -1341,9 +1341,8 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
|
||||
} if self.has_inner_state else {}
|
||||
|
||||
if self.vllm_config.compilation_config.level ==\
|
||||
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
|
||||
model_kwargs = {"inputs_embeds": None}
|
||||
if self.enable_graph_mode:
|
||||
model_kwargs: Dict[str, Any] = {"inputs_embeds": None}
|
||||
else:
|
||||
model_kwargs = {}
|
||||
if previous_hidden_states is not None:
|
||||
@@ -1360,6 +1359,9 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
self.vllm_config, virtual_engine):
|
||||
if model_input.attn_metadata is not None:
|
||||
model_input.attn_metadata.input_positions = model_input.input_positions
|
||||
if self.enable_graph_mode:
|
||||
model_kwargs["kv_caches"] = kv_caches
|
||||
model_kwargs["attn_metadata"] = model_input.attn_metadata
|
||||
hidden_or_intermediate_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
@@ -1430,8 +1432,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
hidden_states = hidden_or_intermediate_states.index_select(
|
||||
0, indices)
|
||||
output.prefill_hidden_states = hidden_or_intermediate_states
|
||||
elif self.vllm_config.compilation_config.level == \
|
||||
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
|
||||
elif self.enable_graph_mode:
|
||||
hidden_states = hidden_or_intermediate_states[:len(indices)]
|
||||
else:
|
||||
hidden_states = hidden_or_intermediate_states
|
||||
|
||||
@@ -24,7 +24,7 @@ import torch
|
||||
import torch.distributed
|
||||
from torch import nn
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment,
|
||||
set_custom_all_reduce)
|
||||
@@ -300,7 +300,8 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
from contextlib import nullcontext
|
||||
context = nullcontext() # type: ignore
|
||||
with context:
|
||||
self._init_cache_engine()
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
self._init_cache_engine()
|
||||
self._warm_up_model()
|
||||
|
||||
def _init_cache_engine(self):
|
||||
@@ -511,10 +512,9 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
expert_tensor_parallel_size = 1
|
||||
if additional_config is not None and hasattr(
|
||||
additional_config, "expert_tensor_parallel_size"):
|
||||
expert_tensor_parallel_size = getattr(
|
||||
additional_config, "expert_tensor_parallel_size")
|
||||
if additional_config:
|
||||
expert_tensor_parallel_size = additional_config.get(
|
||||
"expert_tensor_parallel_size", 1)
|
||||
init_ascend_model_parallel(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size,
|
||||
expert_tensor_parallel_size)
|
||||
|
||||
Reference in New Issue
Block a user