diff --git a/examples/dp_offline/data_parallel.py b/examples/dp_offline/data_parallel.py index ae5b104..1e94940 100644 --- a/examples/dp_offline/data_parallel.py +++ b/examples/dp_offline/data_parallel.py @@ -11,8 +11,6 @@ import gc import os -VLLM_ENABLE_GRAPGH_MODE = os.environ.get("VLLM_ENABLE_GRAPH_MODE") == "1" - def main(): dp_rank = int(os.environ['RANK']) @@ -20,8 +18,8 @@ def main(): dp_size = int(os.environ['WORLD_SIZE']) master_addr = os.environ['MASTER_ADDR'] master_port = os.environ['MASTER_PORT'] - tp_size = 4 - etp_size = 2 + tp_size = 1 + etp_size = 1 os.environ["VLLM_DP_RANK"] = str(dp_rank) os.environ["VLLM_DP_SIZE"] = str(dp_size) @@ -58,15 +56,15 @@ def main(): max_tokens=4, min_tokens=4) # Create an LLM. - llm = LLM( - model="deepseek-ai/DeepSeek-V2-Lite-Chat", - tensor_parallel_size=tp_size, - trust_remote_code=True, - expert_tensor_parallel_size=etp_size, - max_model_len=4096, - max_num_seqs=num_seqs, - compilation_config=1 if VLLM_ENABLE_GRAPGH_MODE else 0, - ) + llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat", + tensor_parallel_size=tp_size, + trust_remote_code=True, + max_model_len=4096, + max_num_seqs=num_seqs, + additional_config={ + 'expert_tensor_parallel_size': etp_size, + 'enable_graph_mode': False, + }) outputs = llm.generate(prompts, sampling_params) for output in outputs: diff --git a/examples/dp_offline/run_dp.sh b/examples/dp_offline/run_dp.sh index 0e525f4..405df60 100644 --- a/examples/dp_offline/run_dp.sh +++ b/examples/dp_offline/run_dp.sh @@ -6,15 +6,13 @@ export HCCL_SOCKET_IFNAME=${ifname} # dp_size = node_size * dp_per_node node_size=1 node_rank=0 -dp_per_node=2 +dp_per_node=4 master_addr=127.0.0.1 master_port=12345 rm -rf ./.torchair_cache/ rm -rf ./dynamo_* rm -rf /root/ascend/log/debug/plog/* -export VLLM_ENABLE_GRAPH_MODE=0 -export VLLM_ENABLE_MC2=0 torchrun --nproc_per_node ${dp_per_node} --nnodes ${node_size} \ --node_rank ${node_rank} --master_addr ${master_addr} --master_port ${master_port} \ diff --git 
a/vllm_ascend/attention/attention.py b/vllm_ascend/attention/attention.py index 6943fe8..2e0262c 100644 --- a/vllm_ascend/attention/attention.py +++ b/vllm_ascend/attention/attention.py @@ -27,6 +27,7 @@ try: except ImportError: print("Failed to import torch_npu.") +import torchair._contrib.custom_torch_ops # type: ignore # noqa: F401 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionMetadata, AttentionType, @@ -36,9 +37,9 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) +from vllm.config import get_current_vllm_config from vllm.utils import async_tensor_h2d, make_tensor_with_pad -from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE from vllm_ascend.worker.model_runner import ( ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata) @@ -913,6 +914,12 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl): self.w_kc = None self.w_vc = None + self.enable_graph_mode = False + additional_config = get_current_vllm_config().additional_config + if additional_config: + self.enable_graph_mode = additional_config.get( + "enable_graph_mode", False) + def exec_kv( self, hidden_states: torch.Tensor, @@ -1084,7 +1091,7 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl): self.num_heads, -1) # TODO: Replace the env with more flexible expressions - if VLLM_ENABLE_GRAPH_MODE == '1': + if self.enable_graph_mode: if len(kv_cache) > 0 and kv_cache[0].numel( ) > 0 and attn_metadata.num_prefills > 0: slots = attn_metadata.slot_mapping @@ -1141,7 +1148,7 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl): ) elif attn_metadata.decode_metadata: assert kv_cache is not None - if VLLM_ENABLE_GRAPH_MODE == '1': + if self.enable_graph_mode: # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim] q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1) q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) diff 
--git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index d0ef762..76f9468 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -26,13 +26,13 @@ # """Inference-only DeepseekV2/DeepseekV3 model.""" import os -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch import torch.distributed as dist from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention +from vllm.attention import Attention, AttentionMetadata from vllm.config import (CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config) from vllm.distributed import (get_dp_group, get_pp_group, @@ -64,7 +64,6 @@ from vllm.model_executor.models.utils import ( from vllm.sequence import IntermediateTensors from vllm_ascend.ops.fused_moe import AscendFusedMoE -from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE class CustomDeepseekV2MoE(nn.Module): @@ -133,7 +132,7 @@ class CustomDeepseekV2MoE(nn.Module): vllm_config = get_current_vllm_config() self.dp_size = get_dp_group().world_size batch_size = vllm_config.scheduler_config.max_num_seqs - self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", 0)) == 1 + self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", '0')) == 1 params_dtype = torch.get_default_dtype() self.final_hidden_states = torch.zeros( @@ -309,38 +308,36 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): self.prefix = prefix self.debug_layer_idx = int(self.prefix.split(".")[-2]) - if VLLM_ENABLE_GRAPH_MODE == "1": - self.forward = self.forward_torchair - else: - self.forward = self.forward_eager # type: ignore + self.enable_graph_mode = False + additional_config = get_current_vllm_config().additional_config + if additional_config: + self.enable_graph_mode = additional_config.get( + "enable_graph_mode", False) - def forward_torchair(self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor = 
None, - attn_metadata=None): + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: if self.q_lora_rank is not None: ckq = self.q_a_proj(hidden_states)[0] hidden_states_or_q_c = self.q_a_layernorm(ckq) else: hidden_states_or_q_c = hidden_states - return self.mla_attn(hidden_states_or_q_c, hidden_states, None, - kv_cache, attn_metadata) - - def forward_eager(self, positions: torch.Tensor, - hidden_states: torch.Tensor): - if self.q_lora_rank is not None: - ckq = self.q_a_proj(hidden_states)[0] - hidden_states_or_q_c = self.q_a_layernorm(ckq) + if self.enable_graph_mode: + return self.mla_attn.impl.forward(self.mla_attn, + hidden_states_or_q_c, + hidden_states, None, kv_cache, + attn_metadata) else: - hidden_states_or_q_c = hidden_states - kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) - return self.mla_attn(hidden_states_or_q_c, - kv_c_normed, - k_pe, - output_shape=hidden_states.shape) + kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + return self.mla_attn(hidden_states_or_q_c, + kv_c_normed, + k_pe, + output_shape=hidden_states.shape) class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): @@ -408,6 +405,54 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): eps=config.rms_norm_eps) self.routed_scaling_factor = config.routed_scaling_factor + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = 
self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + if hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # We scale both hidden_states and residual before + # rmsnorm, and rmsnorm result would not affect by scale. + hidden_states *= 1. / self.routed_scaling_factor + if self.layer_idx == 0: + # The residual is shared by all layers, we only scale it on + # first layer. + residual *= 1. / self.routed_scaling_factor + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + if isinstance(self.mlp, + DeepseekV2MLP) and hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # Scaling the DeepseekV2MLP output, it is the input of + # input_layernorm of next decoder layer. + # The scaling of DeepseekV2MOE output would be done in the forward + # of DeepseekV2MOE + hidden_states *= 1. 
/ self.routed_scaling_factor + + return hidden_states, residual + class CustomDeepseekV2Model(nn.Module): @@ -459,7 +504,9 @@ class CustomDeepseekV2Model(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: @@ -473,8 +520,13 @@ class CustomDeepseekV2Model(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: - hidden_states, residual = layer(positions, hidden_states, residual) + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, hidden_states, residual, + kv_caches[i - + self.start_layer] if kv_caches is not None else None, + attn_metadata) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -514,6 +566,20 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states + class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM): pass diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 50781b4..7eebc7d 100644 --- 
a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -330,17 +330,16 @@ def native_grouped_topk( def select_experts( - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - is_prefill: Optional[bool] = True + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Select top-k experts based on router logits. @@ -364,7 +363,6 @@ def select_experts( Raises: ValueError: If an unsupported scoring function is provided. 
""" - if custom_routing_function is not None: raise NotImplementedError( "Custom routing function is not supported now") @@ -466,21 +464,36 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): is_prefill=False, **kwargs, ): - # set prefill as false always, should fix this - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - is_prefill=is_prefill) + # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern + if global_num_experts == 256: + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits, + k=top_k, # topk当前写8 + bias=e_score_correction_bias, + k_group=topk_group, # fix: 4 + group_count=num_expert_group, # fix 8 + group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=1, # 0: softmax; 1: sigmoid(fix) + # out_flag=False, # todo new api; 第三个输出是否输出 + # y2_flag=False, # old api; 第三个输出是否输出 + routed_scaling_factor=1, + eps=float(1e-20)) + else: + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) - if os.environ.get("VLLM_ENABLE_MC2") == "1" and not is_prefill: + if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill: return fused_experts_with_mc2( hidden_states=x, w1=layer.w13_weight, @@ -611,10 +624,11 @@ class AscendFusedMoE(FusedMoE): real_top_k = self.top_k if self.dp_size > 1: - if 
int(os.environ.get("VLLM_ENABLE_MC2") # type: ignore + if int(os.environ.get("VLLM_ENABLE_MC2", '0') # type: ignore ) == 1 and not is_prefill: ... - elif int(os.environ.get("USING_LCCL_COM")) == 1: # type: ignore + elif int(os.environ.get("USING_LCCL_COM", + '0')) == 1: # type: ignore hidden_states = get_dp_group().all_gather( hidden_states, 0, False) router_logits = get_dp_group().all_gather( @@ -631,7 +645,7 @@ class AscendFusedMoE(FusedMoE): top_k=real_top_k, renormalize=self.renormalize, use_grouped_topk=self.use_grouped_topk, - global_num_experts=self.num_experts, + global_num_experts=self.global_num_experts, expert_map=self.expert_map, topk_group=self.topk_group, num_expert_group=self.num_expert_group, @@ -641,7 +655,7 @@ class AscendFusedMoE(FusedMoE): is_prefill=is_prefill) if self.dp_size > 1: - if int(os.environ.get("VLLM_ENABLE_MC2") # type: ignore + if int(os.environ.get("VLLM_ENABLE_MC2", '0') # type: ignore ) == 1 and not is_prefill: ... else: diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 174f84b..ccaee9d 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -24,6 +24,7 @@ import torch_npu # noqa: F401 import vllm.envs as envs from vllm.logger import logger from vllm.platforms import Platform, PlatformEnum +from vllm.utils import supports_dynamo CUSTOM_OP_ENABLED = False try: @@ -119,6 +120,15 @@ class NPUPlatform(Platform): compilation_config.level) compilation_config.level = CompilationLevel.NO_COMPILATION + if vllm_config.additional_config is not None: + enable_graph_mode = vllm_config.additional_config.get( + "enable_graph_mode", False) + if enable_graph_mode and not supports_dynamo(): + logger.warning( + "enable_graph_mode is not supported because the version of torch is too low, forcing close enable_graph_mode" + ) + vllm_config.additional_config["enable_graph_mode"] = False + parallel_config = vllm_config.parallel_config if parallel_config and parallel_config.worker_cls == "auto": if envs.VLLM_USE_V1: diff 
--git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index 3f3646b..da8e96b 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -310,21 +310,22 @@ class AscendFusedMoEMethod(FusedMoEMethodBase): top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + is_prefill: bool = True, **kwargs, ) -> torch.Tensor: return self.quant_method.apply(layer, x, router_logits, top_k, renormalize, use_grouped_topk, - topk_group, num_expert_group, global_num_experts, expert_map, + topk_group, num_expert_group, custom_routing_function, scoring_func, - e_score_correction_bias) + e_score_correction_bias, is_prefill) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if hasattr(self.quant_method, "process_weights_after_loading"): diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index b1e081d..ae9dd46 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -23,10 +23,8 @@ import torch_npu def quant_per_tensor(in_tensor: torch.Tensor, input_scale: torch.Tensor, input_offset: torch.Tensor): - out = torch.empty_like(in_tensor, dtype=torch.int8) - torch_npu._npu_quantize_per_tensor(in_tensor, input_scale, input_offset, - out) - return out + return torch_npu.npu_quantize(in_tensor, input_scale, input_offset, + torch.qint8, -1, True) class AscendW8A8LinearMethod: @@ -88,7 +86,11 @@ class AscendW8A8LinearMethod: ) -> torch.Tensor: original_dtype = x.dtype if original_dtype != torch.int8: - x = quant_per_tensor(x, layer.input_scale, layer.input_offset) + 
x = quant_per_tensor( + x, + layer.aclnn_input_scale, + layer.aclnn_input_offset, + ) quant_bias = layer.quant_bias if tp_rank == 0 else None return torch_npu.npu_quant_matmul( x, @@ -99,6 +101,13 @@ class AscendW8A8LinearMethod: ) def process_weights_after_loading(self, layer): + expanding_factor = layer.weight.data.shape[1] + layer.aclnn_input_scale = torch.nn.Parameter( + layer.input_scale.data.repeat(expanding_factor), + requires_grad=False) + layer.aclnn_input_offset = torch.nn.Parameter( + layer.input_offset.data.repeat(expanding_factor), + requires_grad=False) if self.transpose_weight: layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() layer.weight_scale.data = torch.flatten(layer.weight_scale.data) diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 52796a8..71c71cd 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -15,14 +15,183 @@ # limitations under the License. # +import os from typing import Any, Callable, Dict, Optional import torch import torch_npu +from vllm_ascend.distributed.parallel_state import get_ep_group from vllm_ascend.ops.fused_moe import select_experts +def apply_mlp(x: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + group_list: torch.Tensor, + dynamic_scale: torch.Tensor = None, + group_list_type: int = 1) -> torch.Tensor: + """ + apply MLP: gate_up_proj -> swiglu -> down_proj + + Args: + x: input hidden states with shape (num_tokens, hidden_size). + w1: expert weights1 with shape + (num_experts, hidden_size, intermediate_size * 2) + w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2) + w2: expert weights2 with shape + (num_experts, intermediate_size, hidden_size) + w2_scale: weights2 scale with shape (num_experts, hidden_size) + group_list: number of tokens for each expert, follow cumsum mode, and + with shape (num_experts). 
+ transpose_weight: + w1: (num_experts, intermediate_size * 2, hidden_size) -> + (num_experts, hidden_size, intermediate_size * 2) + w2: (num_experts, hidden_size, intermediate_size) -> + (num_experts, intermediate_size, hidden_size) + + Returns: + hidden_states: output hidden states after MLP. + """ + + if dynamic_scale is None: + h, pertoken_scale = torch_npu.npu_dynamic_quant(x) + else: + h = x + pertoken_scale = dynamic_scale + + output_dtype = torch.bfloat16 if w1_scale.dtype == torch.bfloat16 else \ + torch.float16 + + # gmm1: gate_up_proj + gate_up_out_list = torch_npu.npu_grouped_matmul( + x=[h], + weight=[w1], + scale=[w1_scale], + per_token_scale=[pertoken_scale], + split_item=3, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=output_dtype) + gate_up_out = gate_up_out_list[0] + + # swiglu + swiglu_out = torch_npu.npu_swiglu(gate_up_out) + swiglu_out, swiglu_out_scale = torch_npu.npu_dynamic_quant(swiglu_out) + + # down_proj + down_out_list = torch_npu.npu_grouped_matmul( + x=[swiglu_out], + weight=[w2], + scale=[w2_scale], + per_token_scale=[swiglu_out_scale], + split_item=3, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=output_dtype) + return down_out_list[0] + + +def fused_experts_with_mc2( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + expert_map: torch.Tensor = None, + moe_all_to_all_group_name: str = "", +) -> torch.Tensor: + global_bs = 0 + moe_expert_num = len(expert_map) + # hidden_states = hidden_states.bfloat16() + kwargs = { + "x": hidden_states, + "expert_ids": topk_ids, + "expert_shard_type": 0, + "shared_expert_rank_num": 0, + "moe_expert_num": moe_expert_num, + "global_bs": global_bs, + } + + rank = torch.distributed.get_rank() + + quant_mode = 2 + ep_group = get_ep_group().device_group + local_rank = 
torch.distributed.get_rank(group=ep_group) + all_to_all_group_size = torch.distributed.get_world_size(ep_group) + + world_size = torch.distributed.get_world_size() + tp_size = world_size // all_to_all_group_size + tp_rank = rank % tp_size + + stage1_kwargs = { + "scales": None, + "quant_mode": quant_mode, + "group_ep": moe_all_to_all_group_name, + "ep_world_size": all_to_all_group_size, + "ep_rank_id": local_rank, + # "group_tp": self.moe_rs_group_name, + "group_tp": moe_all_to_all_group_name, + "tp_world_size": tp_size, + "tp_rank_id": tp_rank, + } + kwargs.update(stage1_kwargs) + + output = torch_npu.npu_moe_distribute_dispatch(**kwargs) + # comm_stream.wait_stream(torch.npu.current_stream()) + expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[ + 0:5] + + if quant_mode == 0: + dynamic_scale = None + + down_out_list = apply_mlp(expand_x, + w1, + w1_scale, + w2, + w2_scale, + expert_token_nums, + dynamic_scale=dynamic_scale) + + # moeCombine + kwargs = { + "expand_x": down_out_list, + "expert_ids": topk_ids, + "expand_idx": expand_idx, + "expert_scales": topk_weights.to(torch.float32), + "expert_shard_type": 0, + "shared_expert_rank_num": 0, + "moe_expert_num": moe_expert_num, + "global_bs": 0, + } + tp_recv_counts = torch.empty(1, + dtype=torch.int32, + device=hidden_states.device) + stage3_kwargs = { + "ep_send_counts": ep_recv_counts, + "group_ep": moe_all_to_all_group_name, + "ep_world_size": all_to_all_group_size, + "ep_rank_id": local_rank, + "tp_send_counts": tp_recv_counts, + # "group_tp": self.moe_rs_group_name, + "group_tp": moe_all_to_all_group_name, + "tp_world_size": tp_size, + "tp_rank_id": tp_rank, + } + kwargs.update(stage3_kwargs) + + hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs) + + return hidden_states + + def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, w1_scale: torch.Tensor, @@ -75,11 +244,10 @@ def fused_experts(hidden_states: torch.Tensor, dtype=torch.int64) ones = 
torch.ones_like(filtered_experts, dtype=torch.int64) token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones) - token_counts = token_counts[:num_experts] - expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64) - + expert_tokens = token_counts[:num_experts] # Rearrange hidden_states sorted_hidden_states = hidden_states[sorted_token_indices] + group_list_type = 1 else: row_idx_len = num_tokens * top_k row_idx = torch.arange(0, @@ -97,46 +265,15 @@ def fused_experts(hidden_states: torch.Tensor, expert_tokens = torch_npu.npu_moe_compute_expert_tokens( expanded_expert_idx, num_experts) expert_tokens = expert_tokens.to(torch.int64) + group_list_type = 0 - quant_x, x_dynamic_scale = torch_npu.npu_dynamic_quant( - sorted_hidden_states) - del sorted_hidden_states - output_dtype = torch.bfloat16 if w1_scale.dtype == torch.bfloat16 else torch.float16 - - gate_up_out_list = torch_npu.npu_grouped_matmul( - x=[quant_x], - weight=[w1], - scale=[w1_scale], - per_token_scale=[x_dynamic_scale], - split_item=2, - group_list_type=0, - group_type=0, - group_list=expert_tokens, - output_dtype=output_dtype) - del quant_x - - gate_up_out_list = gate_up_out_list[0] if len( - gate_up_out_list) == 1 else torch.cat(gate_up_out_list, dim=0) - gate_up_out_list = torch_npu.npu_swiglu(gate_up_out_list) - - quant_gate_up_out_list, gate_up_out_dynamic_scale = torch_npu.npu_dynamic_quant( - gate_up_out_list) - del gate_up_out_list - - down_out_list = torch_npu.npu_grouped_matmul( - x=[quant_gate_up_out_list], - weight=[w2], - scale=[w2_scale], - per_token_scale=[gate_up_out_dynamic_scale], - split_item=2, - group_list_type=0, - group_type=0, - group_list=expert_tokens, - output_dtype=output_dtype) - del quant_gate_up_out_list - - down_out_list = down_out_list[0] if len(down_out_list) == 1 else torch.cat( - down_out_list, dim=0) + down_out_list = apply_mlp(sorted_hidden_states, + w1, + w1_scale, + w2, + w2_scale, + expert_tokens, + group_list_type=group_list_type) if 
expert_map is not None: weighted_down_out = down_out_list * sorted_weights.unsqueeze(1) @@ -144,12 +281,18 @@ def fused_experts(hidden_states: torch.Tensor, final_hidden_states = torch.zeros(*original_shape, device=hidden_states.device, dtype=dtype) - final_hidden_states.index_add_(0, sorted_token_indices, - weighted_down_out) - # TODO: This should not happen! Look into it! - # fill nan with 0.0 - final_hidden_states[torch.isnan(final_hidden_states)] = 0.0 + + num_valid_tokens = mask.sum() + valid_token_mask = torch.arange( + 0, sorted_token_indices.shape[0], + device=device).unsqueeze(1) < num_valid_tokens + valid_output = torch.where( + valid_token_mask, weighted_down_out, + torch.zeros_like(weighted_down_out)).to(dtype) + final_hidden_states.index_add_(0, sorted_token_indices, valid_output) else: + # TODO: Reorder device memory 2 times here, replace the current + # implementation here when suitable operators become available. final_hidden_states = torch_npu.npu_moe_finalize_routing( down_out_list, skip1=None, @@ -157,7 +300,8 @@ def fused_experts(hidden_states: torch.Tensor, bias=None, scales=topk_weights, expanded_src_to_dst_row=expanded_row_idx, - export_for_source_row=topk_ids) + export_for_source_row=topk_ids, + ) del down_out_list if len(original_shape) == 3: final_hidden_states = final_hidden_states.view(original_shape) @@ -230,6 +374,18 @@ class AscendW8A8DynamicFusedMoEMethod: def __init__(self): self.transpose_weight = True + ep_group = get_ep_group() + + try: + device_group = ep_group.device_group + # TODO: Try local_rank = ep_group.rank_in_group + local_rank = torch.distributed.get_rank(group=device_group) + backend = device_group._get_backend(torch.device("npu")) + self.moe_all_to_all_group_name = backend.get_hccl_comm_name( + local_rank) + except AttributeError: + self.moe_all_to_all_group_name = "" + @staticmethod def get_weight(num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, @@ -272,48 +428,78 @@ class 
AscendW8A8DynamicFusedMoEMethod: dtype=params_dtype) return param_dict - @staticmethod def apply( + self, layer: torch.nn.Module, x: torch.Tensor, router_logits: torch.Tensor, top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + is_prefill: bool = True, **kwargs, ) -> torch.Tensor: assert router_logits.shape[ 1] == global_num_experts, "Number of global experts mismatch" - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - ) + # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern + if global_num_experts == 256: + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits, + k=top_k, # topk当前写8 + bias=e_score_correction_bias, + k_group=topk_group, # fix: 4 + group_count=num_expert_group, # fix 8 + group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=1, # 0: softmax; 1: sigmoid(fix) + # out_flag=False, # todo new api; 第三个输出是否输出 + # y2_flag=False, # old api; 第三个输出是否输出 + routed_scaling_factor=1, + eps=float(1e-20)) + else: + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + 
custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) - return fused_experts(hidden_states=x, - w1=layer.w13_weight, - w1_scale=layer.w13_weight_scale, - w2=layer.w2_weight, - w2_scale=layer.w2_weight_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map) + if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill: + return fused_experts_with_mc2( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map, + moe_all_to_all_group_name=self.moe_all_to_all_group_name) + else: + return fused_experts(hidden_states=x, + w1=layer.w13_weight, + w1_scale=layer.w13_weight_scale, + w2=layer.w2_weight, + w2_scale=layer.w2_weight_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map) def process_weights_after_loading(self, layer): if self.transpose_weight: diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 54279d0..dfd0f68 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -16,8 +16,6 @@ # This file is a part of the vllm-ascend project. # Adapted from vllm-project/vllm/vllm/worker/worker.py # -import os - import torch import torch_npu # noqa: F401 from packaging.version import Version @@ -25,8 +23,6 @@ from vllm.logger import logger import vllm_ascend.envs as envs -VLLM_ENABLE_GRAPH_MODE = os.environ.get('VLLM_ENABLE_GRAPH_MODE', '0') - def try_register_lib(lib_name: str, lib_info: str = ""): import importlib diff --git a/vllm_ascend/worker/cache_engine.py b/vllm_ascend/worker/cache_engine.py index 018a66b..72de201 100644 --- a/vllm_ascend/worker/cache_engine.py +++ b/vllm_ascend/worker/cache_engine.py @@ -17,53 +17,66 @@ # limitations under the License. 
# -from typing import List, Tuple +from typing import Any, List import torch +from vllm.config import get_current_vllm_config from vllm.utils import is_pin_memory_available from vllm.worker.cache_engine import CacheEngine -from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE - def allocate_kv_cache( self, num_blocks: int, device: str, -) -> List[Tuple]: +) -> List[Any]: """Allocates KV cache on the specified device.""" kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, self.block_size, self.num_kv_heads, self.head_size) pin_memory = is_pin_memory_available() if device == "cpu" else False - kv_cache: List[Tuple] = [] + kv_cache: List[Any] = [] - # Align entries so they are 256 byte aligned for better performance - # Primarily targets MLA as this typically only ends up having entries - # be 128 byte aligned. - alloc_shape = kv_cache_shape + additional_config = get_current_vllm_config().additional_config + if additional_config and additional_config.get("enable_graph_mode", False): + # Align entries so they are 256 byte aligned for better performance + # Primarily targets MLA as this typically only ends up having entries + # be 128 byte aligned. + alloc_shape = kv_cache_shape - for _ in range(self.num_attention_layers): - # null block in CpuGpuBlockAllocator requires at least that - # block to be zeroed-out. - # We zero-out everything for simplicity. - layer_kv_cache_nope = torch.zeros( - alloc_shape[:-1] + - (self.model_config.hf_text_config.kv_lora_rank, ), - dtype=self.dtype, - pin_memory=pin_memory, - device=device) - layer_kv_cache_pe = torch.zeros( - alloc_shape[:-1] + - (self.model_config.hf_text_config.qk_rope_head_dim, ), - dtype=self.dtype, - pin_memory=pin_memory, - device=device) + for _ in range(self.num_attention_layers): + # null block in CpuGpuBlockAllocator requires at least that + # block to be zeroed-out. + # We zero-out everything for simplicity. 
+ layer_kv_cache_nope = torch.zeros( + alloc_shape[:-1] + + (self.model_config.hf_text_config.kv_lora_rank, ), + dtype=self.dtype, + pin_memory=pin_memory, + device=device) + layer_kv_cache_pe = torch.zeros( + alloc_shape[:-1] + + (self.model_config.hf_text_config.qk_rope_head_dim, ), + dtype=self.dtype, + pin_memory=pin_memory, + device=device) - # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases - # when entry_shape is higher than 1D - kv_cache.append((layer_kv_cache_nope, layer_kv_cache_pe)) + # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases + # when entry_shape is higher than 1D + kv_cache.append((layer_kv_cache_nope, layer_kv_cache_pe)) + else: + for _ in range(self.num_attention_layers): + # null block in CpuGpuBlockAllocator requires at least that + # block to be zeroed-out. + # We zero-out everything for simplicity. + layer_kv_cache = torch.zeros(kv_cache_shape, + dtype=self.dtype, + pin_memory=pin_memory, + device=device) + + # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) 
for cases + # when entry_shape is higher than 1D + kv_cache.append(layer_kv_cache) return kv_cache -if VLLM_ENABLE_GRAPH_MODE == '1': - CacheEngine._allocate_kv_cache = allocate_kv_cache \ No newline at end of file +CacheEngine._allocate_kv_cache = allocate_kv_cache diff --git a/vllm_ascend/worker/model_runner.py b/vllm_ascend/worker/model_runner.py index d30070c..f2a62d9 100644 --- a/vllm_ascend/worker/model_runner.py +++ b/vllm_ascend/worker/model_runner.py @@ -32,7 +32,7 @@ import torch_npu import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.utils import CommonAttentionState -from vllm.config import CompilationLevel, VllmConfig +from vllm.config import VllmConfig from vllm.core.scheduler import SchedulerOutputs from vllm.distributed import get_pp_group from vllm.forward_context import set_forward_context @@ -56,7 +56,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, flatten_2d_lists, - is_pin_memory_available, supports_dynamo) + is_pin_memory_available) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, @@ -546,8 +546,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]): } # Add graph_pad_size here - if self.runner.vllm_config.compilation_config.level ==\ - CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): + if self.runner.enable_graph_mode: graph_pad_size = self.runner.scheduler_config.max_num_seqs - len( seq_lens) else: @@ -609,8 +608,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]): ] multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) - if self.runner.vllm_config.compilation_config.level ==\ - CompilationLevel.DYNAMO_AS_IS 
and supports_dynamo(): + if self.runner.enable_graph_mode: torch._dynamo.mark_static(input_tokens_tensor) torch._dynamo.mark_static(input_positions_tensor) torch._dynamo.mark_static(attn_metadata.block_tables) @@ -871,6 +869,12 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]): self.max_batchsize_to_capture = \ self.vllm_config.compilation_config.max_capture_size + self.enable_graph_mode = False + additional_config = vllm_config.additional_config + if additional_config: + self.enable_graph_mode = additional_config.get( + "enable_graph_mode", False) + self.has_inner_state = model_config.has_inner_state self.in_profile_run = False @@ -971,8 +975,7 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]): self.model = self.lora_manager.create_lora_manager(self.model) # adapter torch compile with npu_backend - if self.vllm_config.compilation_config.level ==\ - CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): + if self.enable_graph_mode: import torchair # type: ignore from torchair import patch_for_hcom # type: ignore @@ -1279,15 +1282,12 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]): self.attn_state.begin_forward(model_input) assert model_input.attn_metadata is not None - if self.vllm_config.compilation_config.level ==\ - CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): + # TODO(zzzzwwjj): Do we need to do it every time? 
+ if self.enable_graph_mode: torch._dynamo.mark_static(model_input.input_tokens) torch._dynamo.mark_static(model_input.input_positions) torch._dynamo.mark_static(model_input.attn_metadata.block_tables) torch._dynamo.mark_static(model_input.attn_metadata.slot_mapping) - torch._dynamo.mark_static( - model_input.attn_metadata.query_start_loc) - torch._dynamo.mark_static(model_input.attn_metadata.seq_start_loc) for kv in kv_caches: if isinstance(kv, tuple): torch._dynamo.mark_static(kv[0]) @@ -1298,7 +1298,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]): virtual_engine = model_input.virtual_engine prefill_meta = model_input.attn_metadata.prefill_metadata previous_hidden_states = kwargs.get("previous_hidden_states") - if prefill_meta is None and self.vllm_config.compilation_config.level > 0: + if prefill_meta is None and self.enable_graph_mode: model_executable = self.compile_model # Note: graph_batch_size value not same as GPU graph_batch_size = model_input.input_tokens.shape[ # type: ignore @@ -1341,9 +1341,8 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]): "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, } if self.has_inner_state else {} - if self.vllm_config.compilation_config.level ==\ - CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): - model_kwargs = {"inputs_embeds": None} + if self.enable_graph_mode: + model_kwargs: Dict[str, Any] = {"inputs_embeds": None} else: model_kwargs = {} if previous_hidden_states is not None: @@ -1360,6 +1359,9 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]): self.vllm_config, virtual_engine): if model_input.attn_metadata is not None: model_input.attn_metadata.input_positions = model_input.input_positions + if self.enable_graph_mode: + model_kwargs["kv_caches"] = kv_caches + model_kwargs["attn_metadata"] = model_input.attn_metadata hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, 
positions=model_input.input_positions, @@ -1430,8 +1432,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]): hidden_states = hidden_or_intermediate_states.index_select( 0, indices) output.prefill_hidden_states = hidden_or_intermediate_states - elif self.vllm_config.compilation_config.level == \ - CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): + elif self.enable_graph_mode: hidden_states = hidden_or_intermediate_states[:len(indices)] else: hidden_states = hidden_or_intermediate_states diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index 1479bdc..6cfa9ca 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -24,7 +24,7 @@ import torch import torch.distributed from torch import nn from vllm import envs -from vllm.config import VllmConfig +from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) @@ -300,7 +300,8 @@ class NPUWorker(LocalOrDistributedWorkerBase): from contextlib import nullcontext context = nullcontext() # type: ignore with context: - self._init_cache_engine() + with set_current_vllm_config(self.vllm_config): + self._init_cache_engine() self._warm_up_model() def _init_cache_engine(self): @@ -511,10 +512,9 @@ class NPUWorker(LocalOrDistributedWorkerBase): parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) expert_tensor_parallel_size = 1 - if additional_config is not None and hasattr( - additional_config, "expert_tensor_parallel_size"): - expert_tensor_parallel_size = getattr( - additional_config, "expert_tensor_parallel_size") + if additional_config: + expert_tensor_parallel_size = additional_config.get( + "expert_tensor_parallel_size", 1) init_ascend_model_parallel(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size, expert_tensor_parallel_size)