diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 7dfb8a2..223f97e 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -117,6 +117,8 @@ class AscendMLAMetadata: # For logging. num_input_tokens: int = 0 # Number of tokens including padding. + with_prefill_across_dp: bool = False + # The dimension of the attention heads head_dim: Optional[int] = None attn_mask: torch.Tensor = None @@ -260,6 +262,10 @@ class AscendMLAMetadataBuilder: PAD_SLOT_ID, dtype=torch.int32, device=device) + query_start_loc = torch.full((num_reqs, ), + -1, + dtype=torch.int32, + device=device) decode_metadata = AscendMLADecodeMetadata( input_positions=input_positions, block_table=block_table, @@ -278,15 +284,21 @@ class AscendMLAMetadataBuilder: attn_state=AscendAttentionState.DecodeOnly, prefill=None, decode=decode_metadata, + query_start_loc=query_start_loc, + seq_lens=seq_lens, + block_tables=block_table, ) - def build(self, - num_reqs: int, - num_actual_tokens: int, - max_query_len: int, - common_attn_metadata: CommonAttentionMetadata, - common_prefix_len: Optional[int] = None, - graph_pad_size: int = -1) -> AscendMLAMetadata: + def build( + self, + num_reqs: int, + num_actual_tokens: int, + max_query_len: int, + common_attn_metadata: CommonAttentionMetadata, + common_prefix_len: Optional[int] = None, + graph_pad_size: int = -1, + with_prefill_across_dp: bool = False, + ) -> AscendMLAMetadata: assert self._num_decodes + self._num_prefills == num_reqs # Note(simon): be careful about the CPU <> GPU memory movement in this @@ -388,6 +400,7 @@ class AscendMLAMetadataBuilder: query_start_loc=query_start_loc, block_tables=block_table, seq_lens=seq_lens, + with_prefill_across_dp=with_prefill_across_dp, ) @@ -621,7 +634,7 @@ class AscendMLAImpl(MLAAttentionImpl): kv = self.kv_a_proj_with_mqa(hidden_states)[0] # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) - k_pe, k_nope, _, _ = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache( + k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( kv, self.kv_a_layernorm.weight, cos, @@ -643,7 +656,7 @@ class AscendMLAImpl(MLAAttentionImpl): B, N, D = x.shape S = 1 x = x.view(B, N, S, D) - x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin) + x = torch_npu.npu_interleave_rope(x, cos, sin) return x.view(B, N, D) def _forward_decode( @@ -766,6 +779,7 @@ class AscendMLAImpl(MLAAttentionImpl): sin = sin[attn_metadata.decode.input_positions] cos = cos[:, None, None, :] sin = sin[:, None, None, :] + decode_q_pe = self.rope_single(decode_q_pe, cos, sin) decode_k_pe, decode_k_nope = self.exec_kv( hidden_states_or_kv_c_normed, cos, sin, kv_cache, diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 38a8053..515ebe1 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -212,6 +212,14 @@ class CustomDeepseekV2MoE(nn.Module): self.tp_group = get_tp_group().device_group self.tp_rank = get_tp_group().rank_in_group + self.params_dtype = torch.get_default_dtype() + + self.enable_graph_mode = False + additional_config = get_current_vllm_config().additional_config + if additional_config: + self.enable_graph_mode = additional_config.get( + "enable_graph_mode", False) + def forward( self, hidden_states: torch.Tensor, @@ -228,33 +236,35 @@ class CustomDeepseekV2MoE(nn.Module): else: is_prefill = attn_metadata.num_prefills > 0 enable_force_load_balance = False - num_tokens, hidden_dim = hidden_states.shape + if hasattr(attn_metadata, 'with_prefill_across_dp'): + is_prefill = is_prefill or attn_metadata.with_prefill_across_dp + + num_tokens, hidden_size = hidden_states.shape if self.n_shared_experts is not None: shared_output = self.shared_experts(hidden_states) if self.tp_size > 1: - # pass - num_tokens, hidden_size = hidden_states.shape - if num_tokens < self.tp_size: - target_size = self.tp_size - new_hidden_states = torch.empty([target_size, hidden_size], - dtype=hidden_states.dtype, - device=hidden_states.device) - new_hidden_states[:num_tokens] = hidden_states - hidden_states = new_hidden_states - chunk_hidden_states = torch.tensor_split(hidden_states, - self.tp_size, - dim=0) - local_hidden_states = chunk_hidden_states[self.tp_rank] - else: - local_hidden_states = hidden_states + if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill: + chunks = torch.chunk(hidden_states, self.tp_size, dim=0) + hidden_states = chunks[self.tp_rank] + elif not self.enable_graph_mode: + num_padding_tokens = (self.tp_size - + num_tokens % self.tp_size) % self.tp_size + # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C + if num_padding_tokens > 0: + hidden_states = nn.functional.pad( + hidden_states, (0, 0, 0, num_padding_tokens)) + chunk_hidden_states = torch.tensor_split(hidden_states, + self.tp_size, + dim=0) + hidden_states = chunk_hidden_states[self.tp_rank] # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(local_hidden_states) + router_logits, _ = self.gate(hidden_states) - router_hidden_states = self.experts( - hidden_states=local_hidden_states, + hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits, is_prefill=is_prefill, top_k=CustomDeepseekV2MoE.top_k, @@ -262,18 +272,29 @@ class CustomDeepseekV2MoE(nn.Module): ) * self.routed_scaling_factor if self.tp_size > 1: - dist.all_gather(list(chunk_hidden_states), router_hidden_states, - self.tp_group) - final_hidden_states = torch.cat(chunk_hidden_states, dim=0) - if num_tokens < self.tp_size: - final_hidden_states = final_hidden_states[:num_tokens] - else: - final_hidden_states = router_hidden_states + if self.enable_graph_mode: + if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill: + final_hidden_states = torch.zeros( + [num_tokens, hidden_size], + dtype=self.params_dtype, + device="npu") + dist.all_gather_into_tensor(final_hidden_states, + hidden_states, self.tp_group) + hidden_states = final_hidden_states + else: + hidden_states = tensor_model_parallel_all_reduce( + hidden_states) + else: + dist.all_gather(list(chunk_hidden_states), hidden_states, + self.tp_group) + hidden_states = torch.cat(chunk_hidden_states, dim=0) + if num_padding_tokens > 0: + hidden_states = hidden_states[:-num_padding_tokens] if shared_output is not None: - final_hidden_states = final_hidden_states + shared_output + hidden_states = hidden_states + shared_output - return final_hidden_states.view(num_tokens, hidden_dim) + return hidden_states.view(num_tokens, hidden_size) class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 2f5cea0..688ea3a 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -587,6 +587,12 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): self.global_batch_size = vllm_config.scheduler_config.max_num_seqs self.local_batch_size = self.global_batch_size // self.ep_size + self.enable_graph_mode = False + additional_config = get_current_vllm_config().additional_config + if additional_config: + self.enable_graph_mode = additional_config.get( + "enable_graph_mode", False) + try: device_group = ep_group.device_group # TODO: Try local_rank = ep_group.rank_in_group @@ -664,7 +670,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): top_k=top_k, expert_map=expert_map, moe_all_to_all_group_name=self.moe_all_to_all_group_name) - elif get_ep_group().world_size == 1: + elif self.enable_graph_mode or get_ep_group().world_size == 1: return fused_experts(hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -750,26 +756,20 @@ class AscendFusedMoE(FusedMoE): self.expert_map = None self.activation = activation - if self.ep_size > 1: - # Create a tensor of size num_experts filled with -1 - self.local_num_experts, self.expert_map = determine_expert_map( - self.ep_size, - get_ep_group().rank_in_group, self.global_num_experts) + # Create a tensor of size num_experts filled with -1 + self.local_num_experts, self.expert_map = determine_expert_map( + self.ep_size, + get_ep_group().rank_in_group, self.global_num_experts) - self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group - self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group + self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group + self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group - else: - # Adjust TP size for DP attention - # haven't test its functionality yet, may remove in the future + self.enable_graph_mode = False + additional_config = get_current_vllm_config().additional_config + if additional_config: + self.enable_graph_mode = additional_config.get( + "enable_graph_mode", False) - self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank - self.moe_parallel_config.ep_rank = 0 - self.moe_parallel_config.tp_size = self.tp_size * self.dp_size - self.moe_parallel_config.ep_size = 1 - - self.local_num_experts, self.expert_map = (self.global_num_experts, - None) if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " "non-grouped topk.") @@ -807,8 +807,15 @@ class AscendFusedMoE(FusedMoE): in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")): moe_quant_params["intermediate_size_full"] = intermediate_size + self.ep_group = get_ep_group() self.quant_method.create_weights(layer=self, **moe_quant_params) + self.enable_graph_mode = False + additional_config = get_current_vllm_config().additional_config + if additional_config: + self.enable_graph_mode = additional_config.get( + "enable_graph_mode", False) + def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor, @@ -822,11 +829,28 @@ class AscendFusedMoE(FusedMoE): else: real_top_k = self.top_k - if VLLM_ENABLE_MC2 and not is_prefill: - ... + # MC2 ag/rs broadcast/all_reduce + # prefill_req x x √ + # decode_req √ x √ + # graph_mode √ √ x + if self.dp_size > 1: + if VLLM_ENABLE_MC2 and not is_prefill: + ... + elif self.enable_graph_mode: + if USING_LCCL_COM: # type: ignore + hidden_states = get_dp_group().all_gather( + hidden_states, 0, False) + router_logits = get_dp_group().all_gather( + router_logits, 0, False) + elif self.enable_graph_mode and not is_prefill: + hidden_states = get_dp_group().all_gather(hidden_states, 0) + router_logits = get_dp_group().all_gather(router_logits, 0) + else: + hidden_states, router_logits = get_ep_group().dispatch( + hidden_states, router_logits) # Matrix multiply. - final_hidden_states = self.quant_method.apply( + hidden_states = self.quant_method.apply( layer=self, x=hidden_states, router_logits=router_logits, @@ -843,11 +867,26 @@ class AscendFusedMoE(FusedMoE): is_prefill=is_prefill, enable_force_load_balance=enable_force_load_balance) - if VLLM_ENABLE_MC2 and not is_prefill: - ... + if self.dp_size > 1: + if VLLM_ENABLE_MC2 and not is_prefill: + ... + elif self.enable_graph_mode: + if USING_LCCL_COM: # type: ignore + hidden_states = dist._functional_collectives.reduce_scatter_tensor( + hidden_states, + "sum", + scatter_dim=0, + group=get_dp_group().device_group) + elif self.enable_graph_mode and not is_prefill: + hidden_states = dist._functional_collectives.reduce_scatter_tensor( + hidden_states, + "sum", + scatter_dim=0, + group=get_dp_group().device_group) + else: + hidden_states = get_ep_group().combine(hidden_states) if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + hidden_states = tensor_model_parallel_all_reduce(hidden_states) - return final_hidden_states + return hidden_states diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 8606eb4..413ba6f 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -138,7 +138,7 @@ class NPUPlatform(Platform): # Calculate expert parallel size based on world size parallel_config.expert_parallel_size = ( - parallel_config.world_size // + parallel_config.world_size_across_dp // parallel_config.expert_tensor_parallel_size) if model_config is None: @@ -167,6 +167,8 @@ class NPUPlatform(Platform): raise NotImplementedError( "enable_graph_mode only works with deepseek model." ) + # Set compilation level to NO_COMPILATION to disable ACL Graph + compilation_config.level = CompilationLevel.NO_COMPILATION elif envs.VLLM_USE_V1 and model_config is not None and not enforce_eager: model_type = model_config.hf_config.model_type diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index a847364..64ab5a3 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, Optional import torch import torch.distributed as dist import torch_npu +from vllm.config import get_current_vllm_config from vllm.distributed import GroupCoordinator import vllm_ascend.envs as envs_ascend @@ -508,6 +509,12 @@ class AscendW8A8DynamicFusedMoEMethod: self.ep_group = get_ep_group() + self.enable_graph_mode = False + additional_config = get_current_vllm_config().additional_config + if additional_config: + self.enable_graph_mode = additional_config.get( + "enable_graph_mode", False) + try: device_group = self.ep_group.device_group # TODO: Try local_rank = ep_group.rank_in_group @@ -629,7 +636,7 @@ class AscendW8A8DynamicFusedMoEMethod: top_k=top_k, expert_map=expert_map, moe_all_to_all_group_name=self.moe_all_to_all_group_name) - elif self.ep_group.world_size == 1: + elif self.enable_graph_mode or self.ep_group.world_size == 1: return fused_experts(hidden_states=x, w1=layer.w13_weight, w1_scale=layer.w13_weight_scale, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index e90c114..f64336d 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -29,12 +29,14 @@ import numpy as np import numpy.typing as npt import torch import torch._dynamo.cache_size +import torch.distributed as dist import torch.nn as nn +from torch.distributed import ReduceOp from vllm.attention import AttentionType, get_attn_backend from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.distributed.parallel_state import get_pp_group +from vllm.distributed.parallel_state import get_dp_group, get_pp_group from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY from vllm.logger import logger @@ -361,6 +363,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): torch._logging.set_logs( recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES) + self.dp_size = vllm_config.parallel_config.data_parallel_size + self.dp_rank = vllm_config.parallel_config.data_parallel_rank + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler output. @@ -512,6 +517,16 @@ class NPUModelRunner(LoRAModelRunnerMixin): if batch_changed: self.input_batch.refresh_sampling_metadata() + def _get_forward_metadata_across_dp( + self, batch_size: int, with_prefill: bool) -> tuple[int, bool]: + forward_metadata = torch.tensor([batch_size, with_prefill], + device="cpu", + dtype=torch.int32) + dist.all_reduce(forward_metadata, + op=ReduceOp.MAX, + group=get_dp_group().cpu_group) + return int(forward_metadata[0]), bool(forward_metadata[1] > 0) + def get_model(self) -> nn.Module: return self.model @@ -648,12 +663,24 @@ class NPUModelRunner(LoRAModelRunnerMixin): seq_lens = self.seq_lens[:num_reqs] common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, seq_lens=seq_lens) + with_prefill = attn_state != AscendAttentionState.DecodeOnly + + if self.dp_size > 1: + max_num_tokens, with_prefill = self._get_forward_metadata_across_dp( + total_num_scheduled_tokens, with_prefill) + extra_builder_kwargs['with_prefill_across_dp'] = with_prefill + # Add graph_pad_size here - if self.enable_torchair_graph_mode: - batchsize = len(seq_lens) - padded_batch_size = self.select_torchair_padded_batchsize( - batchsize) - graph_pad_size = padded_batch_size - batchsize + if envs_ascend.VLLM_ENABLE_MC2 or (self.enable_torchair_graph_mode + and not with_prefill): + batch_size = len(seq_lens) + if self.dp_size > 1: + padded_batch_size = self.select_torchair_padded_batch_size( + max_num_tokens) + else: + padded_batch_size = self.select_torchair_padded_batch_size( + batch_size) + graph_pad_size = padded_batch_size - batch_size extra_builder_kwargs['graph_pad_size'] = graph_pad_size if self.vllm_config.model_config.use_mla: @@ -687,7 +714,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) input_ids = self.input_ids[:num_input_tokens] - if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + if (envs_ascend.VLLM_ENABLE_MC2 + or self.enable_torchair_graph_mode) and not with_prefill: input_ids = self.input_ids[:padded_batch_size] positions = self.positions[:padded_batch_size] @@ -699,7 +727,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): if self.enable_torchair_graph_mode: model_kwargs["kv_caches"] = self.kv_caches model_kwargs["attn_metadata"] = attn_metadata - if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + if self.enable_torchair_graph_mode and not with_prefill: hidden_states = self.compile_model( input_ids=input_ids, positions=positions, @@ -1095,7 +1123,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): self, num_tokens: int, is_compile: bool = False, - attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill, + with_prefill: bool = True, ) -> torch.Tensor: # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively @@ -1139,8 +1167,10 @@ class NPUModelRunner(LoRAModelRunnerMixin): for k, v in self.intermediate_tensors.items() }) - with set_forward_context(None, self.vllm_config): - if self.enable_torchair_graph_mode and attn_state == AscendAttentionState.DecodeOnly: + with set_forward_context(None, + self.vllm_config, + num_tokens=num_tokens): + if self.enable_torchair_graph_mode and not with_prefill: attn_metadata = self.attn_metadata_builder.build_dummy( num_reqs=num_tokens, num_actual_tokens=1) # Only mark static while compiling @@ -1393,7 +1423,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): logger.info( "Capturing torchair graph, this usually takes %.1f~%.1f mins.", 0.5 * graph_num, 1.5 * graph_num) - attn_state = AscendAttentionState.DecodeOnly # Trigger torchair graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. @@ -1403,10 +1432,10 @@ class NPUModelRunner(LoRAModelRunnerMixin): cudagraph_num_of_warmups): self._dummy_run(num_tokens, is_compile=True, - attn_state=attn_state) + with_prefill=False) self._dummy_run(num_tokens, is_compile=True, - attn_state=attn_state) + with_prefill=False) logger.info("Batchsize %d is compiled successfully: %d/%d.", num_tokens, idx + 1, graph_num) elif self.use_aclgraph: @@ -1551,9 +1580,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.torchair_graph_batch_sizes.append(largest_batch_size) largest_batch_size += batch_size_step - def select_torchair_padded_batchsize(self, batchsize: int): - selected_batchsize = self.max_num_reqs - for padded_batchsize in self.torchair_graph_batch_sizes: - if batchsize <= padded_batchsize < selected_batchsize: - selected_batchsize = padded_batchsize - return selected_batchsize + def select_torchair_padded_batch_size(self, batch_size: int): + selected_batch_size = self.max_num_reqs + for padded_batch_size in self.torchair_graph_batch_sizes: + if batch_size <= padded_batch_size < selected_batch_size: + selected_batch_size = padded_batch_size + return selected_batch_size diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index d98b0fe..84abe03 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -544,7 +544,7 @@ class NPUWorker(LocalOrDistributedWorkerBase): init_ascend_model_parallel( parallel_config.expert_parallel_size, parallel_config.expert_tensor_parallel_size, - parallel_config.world_size, + parallel_config.world_size_across_dp, ) ensure_kv_transfer_initialized(vllm_config) diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 21c9955..ad7440d 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -41,6 +41,7 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.worker_base import WorkerBase +import vllm_ascend.envs as envs_ascend from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel from vllm_ascend.platform import NPUPlatform from vllm_ascend.utils import try_register_lib @@ -230,7 +231,18 @@ class NPUWorker(WorkerBase): return self.model_runner.pin_lora(lora_id) def execute_dummy_batch(self) -> None: - self.model_runner._dummy_run(1) + runner = self.model_runner + num_tokens = 1 + if runner.dp_size > 1: + max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp( + 1, False) + if envs_ascend.VLLM_ENABLE_MC2 or runner.enable_torchair_graph_mode: + if not with_prefill: + num_tokens = max_num_tokens + num_tokens = runner.select_torchair_padded_batch_size(num_tokens) + runner._dummy_run(num_tokens, + is_compile=False, + with_prefill=with_prefill) def _init_worker_distributed_environment(self) -> None: """Initialize the distributed environment.""" @@ -246,7 +258,7 @@ class NPUWorker(WorkerBase): init_ascend_model_parallel( parallel_config.expert_parallel_size, parallel_config.expert_tensor_parallel_size, - parallel_config.world_size, + parallel_config.world_size_across_dp, ) ensure_kv_transfer_initialized(self.vllm_config)