diff --git a/vllm_ascend/distributed/moe_comm_method.py b/vllm_ascend/distributed/moe_comm_method.py index 02f6d52..ea32495 100644 --- a/vllm_ascend/distributed/moe_comm_method.py +++ b/vllm_ascend/distributed/moe_comm_method.py @@ -94,43 +94,6 @@ class MoECommMethod(ABC): pass -class DummyCommImpl(MoECommMethod): - - def prepare( - self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """Dummy prepare method that does nothing.""" - return hidden_states, router_logits - - def finalize(self, hidden_states: torch.Tensor, - reduce_results: bool) -> torch.Tensor: - """Dummy finalize method that does nothing.""" - return hidden_states - - def permute( - self, - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - expert_map: torch.Tensor, - num_experts: int, - ) -> tuple[torch.Tensor, torch.Tensor, int]: - """Dummy implementation, make sure the output shapes are correct.""" - top_k_num = topk_ids.shape[1] - permuted_hidden_states = hidden_states.repeat_interleave(top_k_num, - dim=0) - expert_tokens = torch.zeros((num_experts, ), - dtype=torch.int64, - device=hidden_states.device) - group_list_type = 0 - return permuted_hidden_states, expert_tokens, group_list_type - - def unpermute(self, mlp_output: torch.Tensor, - hidden_states: torch.Tensor) -> None: - """Dummy implementation that does nothing.""" - pass - - class AllGatherCommImpl(MoECommMethod): """This implementation is the same as NativeAllGatherCommImpl, but uses NPU-specific ops for better performance. diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index ffc1dea..72ee91b 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -26,7 +26,6 @@ from vllm.model_executor.layers.fused_moe.layer import ( from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl, - DummyCommImpl, MC2CommImpl, MoECommMethod) from vllm_ascend.distributed.parallel_state import get_mc2_group @@ -230,7 +229,7 @@ class AscendFusedMoE(FusedMoE): self.moe_config.ep_group = get_ep_group() self.moe_config.mc2_group = get_mc2_group() - for method in {AllGatherCommImpl, DummyCommImpl, MC2CommImpl}: + for method in {AllGatherCommImpl, MC2CommImpl}: setattr( self, method.__name__.lower(), method(moe_config=self.moe_config)) # type: ignore[abstract] @@ -241,8 +240,11 @@ class AscendFusedMoE(FusedMoE): forward_context = get_forward_context() moe_comm_method_name = forward_context.moe_comm_method_name - if not self.moe_config.use_ep and moe_comm_method_name != "dummycommimpl": + + # TODO: Can we refactor this logic to model_runner? + if not self.moe_config.use_ep: moe_comm_method_name = "allgathercommimpl" + forward_context.moe_comm_method = getattr(self, moe_comm_method_name) hidden_states, router_logits = forward_context.moe_comm_method.prepare( diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index fb4f583..24fd33a 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -70,7 +70,7 @@ class NPUTorchairModelRunner(NPUModelRunner): register_torchair_model() torchair_quant_method_register() - def _get_forward_metadata_across_dp_and_pad( + def _sync_metadata_across_dp( self, num_tokens: int, with_prefill: bool, enable_dbo: bool ) -> tuple[int, Optional[torch.Tensor], bool, bool]: """Override from NPUModelRunner to pad num_tokens""" @@ -81,8 +81,17 @@ class NPUTorchairModelRunner(NPUModelRunner): return maybe_padded_num_tokens, None, with_prefill, enable_dbo return num_tokens, None, with_prefill, enable_dbo - num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp( - num_tokens, with_prefill, enable_dbo) + num_tokens_across_dp = torch.zeros(self.dp_size + 2, + dtype=torch.int32, + device="npu") + num_tokens_across_dp[self.dp_rank] = num_tokens + num_tokens_across_dp[-2] = int(with_prefill) + num_tokens_across_dp[-1] = int(not enable_dbo) + dist.all_reduce(num_tokens_across_dp, + group=get_dp_group().device_group) + with_prefill = bool(num_tokens_across_dp[-2]) + enable_dbo = not bool(num_tokens_across_dp[-1]) + num_tokens_across_dp = num_tokens_across_dp[:-2] if not with_prefill: max_num_token = num_tokens_across_dp.max().item() diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 468f59e..6930a1c 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -43,8 +43,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.parallel_state import (get_dp_group, get_pp_group, get_tp_group, is_global_first_rank) -from vllm.forward_context import (BatchDescriptor, DPMetadata, - get_forward_context) +from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import logger from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding @@ -373,10 +372,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): device=self.device, ) - self.moe_comm_method = "mc2" - self.fallback_moe_comm_method = "allgather" - self.dummy_moe_comm_method = "dummy" - def _use_aclgraph(self) -> bool: return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager @@ -594,32 +589,43 @@ class NPUModelRunner(LoRAModelRunnerMixin): # Refresh batch metadata with any pending updates. self.input_batch.refresh_metadata() - def _get_forward_metadata_across_dp( - self, num_tokens: int, with_prefill: bool, - enable_dbo: bool) -> tuple[torch.Tensor, bool, bool]: - - # Compose: all_reduce metadata (num_tokens of each rank, with_prefill, enable_dbo) - num_tokens_across_dp = torch.zeros(self.dp_size + 2, - dtype=torch.int32, - device="cpu") - num_tokens_across_dp[self.dp_rank] = num_tokens - num_tokens_across_dp[-2] = int(with_prefill) - num_tokens_across_dp[-1] = int(not enable_dbo) - dist.all_reduce(num_tokens_across_dp, group=get_dp_group().cpu_group) - with_prefill = bool(num_tokens_across_dp[-2]) - enable_dbo = not bool(num_tokens_across_dp[-1]) - num_tokens_across_dp = num_tokens_across_dp[:-2] - return num_tokens_across_dp, with_prefill, enable_dbo - - def _get_forward_metadata_across_dp_and_pad( + def _sync_metadata_across_dp( self, num_tokens: int, with_prefill: bool, enable_dbo: bool ) -> tuple[int, Optional[torch.Tensor], bool, bool]: - if self.dp_size == 1: + if self.dp_size == 1 or self.vllm_config.model_config.enforce_eager: return num_tokens, None, with_prefill, enable_dbo - num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp( - num_tokens, with_prefill, enable_dbo) - return num_tokens, num_tokens_across_dp, with_prefill, enable_dbo + # Sync num_tokens, with_prefill, enable_dbo across dp ranks + num_tokens_tensor = torch.tensor([ + num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size) + ], + dtype=torch.int32, + device="npu") + + flags_tensor = torch.tensor( + [int(with_prefill), int(not enable_dbo)], + dtype=torch.int32, + device="npu") + + packed_tensor = torch.cat([num_tokens_tensor, flags_tensor]) + + dist.all_reduce(packed_tensor, group=get_dp_group().device_group) + + # Unpack the results + num_tokens_across_dp = packed_tensor[:-2] + synced_flags = packed_tensor[-2:] + + max_tokens_across_dp = torch.max(num_tokens_across_dp).item() + global_with_prefill = bool(synced_flags[0]) + global_enable_dbo = not bool(synced_flags[1]) + + # Create a tensor for num_tokens_after_padding + num_tokens_after_padding = torch.tensor([max_tokens_across_dp] * + self.dp_size, + device="npu", + dtype=torch.int32) + + return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill, global_enable_dbo def _check_dbo_is_valid(self, query_lens: torch.Tensor, attn_state: AscendAttentionState, @@ -1025,32 +1031,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): mm_embeds.append(mm_embeds_item) return mm_embeds - def get_dp_padding(self, - num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: - """This implementation is derived from vLLM's `GPUModelRunner.get_dp_padding`. - Please note that vLLM may refactor or modify this function over time, - at present, we are using the version introduced in PR #18935. - """ - dp_size = self.vllm_config.parallel_config.data_parallel_size - dp_rank = self.vllm_config.parallel_config.data_parallel_rank - - # For DP: Don't pad when setting enforce_eager. - # This lets us set enforce_eager on the prefiller in a P/D setup and - # still use ACL graphs (enabled by this padding) on the decoder. - - if dp_size == 1 or self.vllm_config.model_config.enforce_eager: - # Early exit. - return 0, None - - num_tokens_across_dp = DPMetadata.num_tokens_across_dp( - num_tokens, dp_size, dp_rank) - max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item() - num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] * - dp_size, - device="cpu", - dtype=torch.int32) - return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding - def _prepare_inputs( self, scheduler_output: "SchedulerOutput", @@ -1060,24 +1040,10 @@ class NPUModelRunner(LoRAModelRunnerMixin): torch.Tensor, int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: - # Check input valid total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs assert num_reqs > 0 - if (self.use_aclgraph and total_num_scheduled_tokens - <= self.aclgraph_batch_sizes[-1]): - # Add padding to the batch size. - num_input_tokens = self.vllm_config.pad_for_cudagraph( - total_num_scheduled_tokens) - else: - # Eager mode. - num_input_tokens = total_num_scheduled_tokens - - # Padding for DP - num_pad, num_tokens_across_dp_native = self.get_dp_padding( - num_input_tokens) - num_input_tokens += num_pad self.attn_metadata_builder.reorder_batch(self.input_batch, scheduler_output) @@ -1098,6 +1064,41 @@ class NPUModelRunner(LoRAModelRunnerMixin): max_num_scheduled_tokens = max(max_num_scheduled_tokens, num_tokens) + if (self.use_aclgraph and total_num_scheduled_tokens + <= self.aclgraph_batch_sizes[-1]): + # Add padding to the batch size. + num_input_tokens = self.vllm_config.pad_for_cudagraph( + total_num_scheduled_tokens) + else: + # Eager mode. + num_input_tokens = total_num_scheduled_tokens + + # Get the attention state. + attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, + num_valid_tokens) + self.attn_state = attn_state # type: ignore + + # Determine if it's a splitfuse batch + with_prefill = attn_state not in [ + AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding + ] + + self.query_lens = torch.from_numpy(num_scheduled_tokens) + enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), + attn_state, + total_num_scheduled_tokens) + + # Get info across DP ranks. + # NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP, + # Otherwise, it's just max_tokens_across_dp_cpu + (maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, + enable_dbo) = self._sync_metadata_across_dp(num_input_tokens, + with_prefill, enable_dbo) + + if self.use_aclgraph: + # When using TorchAir with DP, we have other plans for padding + num_input_tokens = maybe_padded_num_tokens + # Hot-Swap lora model if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) @@ -1166,20 +1167,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.seq_lens[num_reqs:].fill_(0) self.query_start_loc[num_reqs + 1:].fill_(-1) - with_prefill = attn_state not in [ - AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding - ] - - enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), - attn_state, - total_num_scheduled_tokens) - - (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill, - enable_dbo) = self._get_forward_metadata_across_dp_and_pad( - total_num_scheduled_tokens, with_prefill, enable_dbo) self.with_prefill = with_prefill self.num_tokens_across_dp = num_tokens_across_dp - self._update_graph_pad_size(with_prefill, padded_num_tokens_across_dp) + self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens) common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=self.query_start_loc[:num_reqs + 1], query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], @@ -1247,7 +1237,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): positions = self.positions[:num_input_tokens] input_ids, positions = self._update_input_ids_and_positions( input_ids, positions, num_input_tokens, with_prefill, - padded_num_tokens_across_dp) + maybe_padded_num_tokens) if get_pp_group().is_first_rank: intermediate_tensors = None @@ -1262,14 +1252,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): for k, v in self.intermediate_tensors.items() }) - # NOTE: Currently this padding logic is really messy, - # MC2 may not be available in eager mode - # TODO: Unify the padding logic between TorchAir and ACL Graph ASAP - if self.use_aclgraph: - num_tokens_across_dp = num_tokens_across_dp_native - else: - num_input_tokens = padded_num_tokens_across_dp - use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: @@ -1297,12 +1279,11 @@ class NPUModelRunner(LoRAModelRunnerMixin): return (attn_metadata, positions, num_scheduled_tokens, num_input_tokens, num_tokens_across_dp, - padded_num_tokens_across_dp, logits_indices, - spec_decode_metadata, input_ids, inputs_embeds, - intermediate_tensors) + maybe_padded_num_tokens, logits_indices, spec_decode_metadata, + input_ids, inputs_embeds, intermediate_tensors) def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, - padded_num_tokens_across_dp, + maybe_padded_num_tokens, input_ids, positions, intermediate_tensors, inputs_embeds): @@ -1345,7 +1326,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): def _update_input_ids_and_positions(self, input_ids, positions, num_input_tokens, with_prefill, - padded_num_tokens_across_dp): + maybe_padded_num_tokens): if self.uses_mrope: positions = self.mrope_positions[:, :num_input_tokens] return input_ids, positions @@ -1632,6 +1613,10 @@ class NPUModelRunner(LoRAModelRunnerMixin): kv_connector_output=kv_connector_output, ) + def _select_moe_comm_method(self, num_tokens: int) -> str: + return ("mc2" + if num_tokens <= self.mc2_tokens_capacity else "allgather") + @torch.inference_mode() def execute_model( self, @@ -1649,15 +1634,13 @@ class NPUModelRunner(LoRAModelRunnerMixin): return EMPTY_MODEL_RUNNER_OUTPUT return self.kv_connector_no_forward(scheduler_output) (attn_metadata, positions, num_scheduled_tokens_np, - num_input_tokens, num_tokens_across_dp, - padded_num_tokens_across_dp, logits_indices, spec_decode_metadata, - input_ids, inputs_embeds, + num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens, + logits_indices, spec_decode_metadata, input_ids, inputs_embeds, intermediate_tensors) = (self._prepare_inputs( scheduler_output, intermediate_tensors)) - moe_comm_method = (self.moe_comm_method - if num_input_tokens <= self.mc2_tokens_capacity else - self.fallback_moe_comm_method) + moe_comm_method = self._select_moe_comm_method(num_input_tokens) + batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, uniform_decode=False) aclgraph_runtime_mode, batch_descriptor = \ @@ -1680,9 +1663,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.maybe_setup_kv_connector(scheduler_output) hidden_states = self._generate_process_reqs_hidden_states( - attn_metadata, self.with_prefill, - padded_num_tokens_across_dp, input_ids, positions, - intermediate_tensors, inputs_embeds) + attn_metadata, self.with_prefill, maybe_padded_num_tokens, + input_ids, positions, intermediate_tensors, inputs_embeds) self.maybe_wait_for_kv_save() finished_sending, finished_recving = self.get_finished_kv_transfer( @@ -1988,7 +1970,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_tokens: int, with_prefill: bool = False, is_torchair_compile: bool = False, - moe_comm_method: str = "dummy", aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, force_attention: bool = False, uniform_decode: bool = False, @@ -2003,13 +1984,10 @@ class NPUModelRunner(LoRAModelRunnerMixin): ) # Padding for DP - num_pad, num_tokens_across_dp_native = self.get_dp_padding(num_tokens) - # num_tokens += num_pad ## Uncomment this after TorchAir is removed - - # Padding for DP (for TorchAir) (num_tokens, num_tokens_across_dp, with_prefill, - _) = self._get_forward_metadata_across_dp_and_pad( - num_tokens, with_prefill, False) + _) = self._sync_metadata_across_dp(num_tokens, with_prefill, False) + + moe_comm_method = self._select_moe_comm_method(num_tokens) # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.seperate_routine(). This means that we are using @@ -2518,12 +2496,10 @@ class NPUModelRunner(LoRAModelRunnerMixin): self._dummy_run(num_tokens, aclgraph_runtime_mode=CUDAGraphMode.NONE, force_attention=force_attention, - uniform_decode=uniform_decode, - moe_comm_method=self.moe_comm_method) + uniform_decode=uniform_decode) self._dummy_run(num_tokens, aclgraph_runtime_mode=aclgraph_runtime_mode, - uniform_decode=uniform_decode, - moe_comm_method=self.moe_comm_method) + uniform_decode=uniform_decode) def _capture_model(self): if not self.use_aclgraph: diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py index 1ec1436..120b17a 100644 --- a/vllm_ascend/worker/mtp_proposer_v1.py +++ b/vllm_ascend/worker/mtp_proposer_v1.py @@ -194,7 +194,7 @@ class MtpProposer: # torch mode need to update num_tokens_across_dp # TODO: adapt enable_dbo later (num_input_tokens, num_tokens_across_dp, with_prefill, - _) = self.runner._get_forward_metadata_across_dp_and_pad( + _) = self.runner._sync_metadata_across_dp( num_tokens, self.runner.with_prefill, False) attn_metadata.slot_mapping = target_slot_mapping else: @@ -281,8 +281,8 @@ class MtpProposer: if not self.torchair_graph_enabled: # TODO: adapt enable_dbo later (num_tokens, num_tokens_across_dp, with_prefill, - _) = self.runner._get_forward_metadata_across_dp_and_pad( - num_tokens, with_prefill, False) + _) = self.runner._sync_metadata_across_dp(num_tokens, + with_prefill, False) is_running_torchair = self.torchair_graph_enabled and \ not with_prefill