diff --git a/tests/ut/spec_decode/test_eagle_proposer.py b/tests/ut/spec_decode/test_eagle_proposer.py
index 745be6e5..cbfda43a 100644
--- a/tests/ut/spec_decode/test_eagle_proposer.py
+++ b/tests/ut/spec_decode/test_eagle_proposer.py
@@ -165,7 +165,6 @@ class TestEagleProposerDummyRun(TestBase):
         self.vllm_config.speculative_config = MagicMock()
         self.device = torch.device("cpu")
         self.runner = MagicMock()
-        self.runner._select_moe_comm_method.return_value = "alltoall"
         self.vllm_config.cache_config.block_size = 16
         self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
 
@@ -192,8 +191,6 @@ class TestEagleProposerDummyRun(TestBase):
     def test_dummy_run_with_prefill(self, mock_context):
         mock_context.return_value.__enter__.return_value = None
        self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4)
-
-        self.runner._select_moe_comm_method.assert_called_with(64)
         self.proposer.model.assert_called_once()
 
diff --git a/tests/ut/spec_decode/test_mtp_proposer.py b/tests/ut/spec_decode/test_mtp_proposer.py
index daf23cf3..08ef8a68 100644
--- a/tests/ut/spec_decode/test_mtp_proposer.py
+++ b/tests/ut/spec_decode/test_mtp_proposer.py
@@ -158,7 +158,6 @@ class TestMtpProposer:
         proposer.model = MagicMock()
         proposer.enable_shared_expert_dp = False
         runner._sync_metadata_across_dp.return_value = (8, 8, False)
-        runner._select_moe_comm_method.return_value = "alltoall"
 
         mock_get_forward_context = MagicMock()
         mock_get_forward_context.cudagraph_runtime_mode = None
@@ -168,7 +167,6 @@ class TestMtpProposer:
 
         # Verify
         runner._sync_metadata_across_dp.assert_called_once()
-        runner._select_moe_comm_method.assert_called_once()
         mock_set_context.assert_called()
 
         # Check that model was called correct number of times
@@ -187,7 +185,6 @@ class TestMtpProposer:
         proposer.enable_shared_expert_dp = False
         proposer.model = MagicMock()
         runner._sync_metadata_across_dp.return_value = (8, 8, False)
-        runner._select_moe_comm_method.return_value = "alltoall"
         runner.attn_groups = []
 
         mock_get_forward_context = MagicMock()
@@ -200,7 +197,6 @@ class TestMtpProposer:
 
         # Verify
         runner._sync_metadata_across_dp.assert_called_once()
-        runner._select_moe_comm_method.assert_called_once()
         mock_set_context.assert_called()
 
         # Check that model was called correct number of times
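
With `_select_moe_comm_method` gone from the runner, the mocks dropped above are no longer needed: the comm method is now resolved inside `set_ascend_forward_context` through the module-level `select_moe_comm_method` helper added in the next file. A minimal sketch of how that helper could be exercised directly, in the same pytest/unittest.mock style as the test files above; this test is not part of the patch and assumes the new `vllm_ascend.ascend_forward_context` API shown below is importable:

```python
from unittest.mock import MagicMock, patch

from vllm_ascend.ascend_forward_context import (MoECommType,
                                                select_moe_comm_method)


def test_select_moe_comm_method_non_moe_returns_none():
    # Dense (non-MoE) models never pick a MoE comm method.
    with patch("vllm_ascend.ascend_forward_context.is_moe_model",
               return_value=False):
        assert select_moe_comm_method(16, MagicMock()) is None


def test_select_moe_comm_method_without_ep_uses_allgather():
    # Without expert parallelism the helper always falls back to all-gather.
    vllm_config = MagicMock()
    vllm_config.parallel_config.enable_expert_parallel = False
    with patch("vllm_ascend.ascend_forward_context.is_moe_model",
               return_value=True), \
         patch("vllm_ascend.ascend_forward_context.get_mc2_tokens_capacity",
               return_value=512), \
         patch("vllm_ascend.ascend_forward_context.get_ascend_device_type"):
        assert select_moe_comm_method(16,
                                      vllm_config) == MoECommType.ALLGATHER
```
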
diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index 22233bc9..b4343e76 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -5,12 +5,15 @@ from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 from vllm.config import CUDAGraphMode, VllmConfig
-from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import (get_dp_group, get_ep_group,
+                              get_tensor_model_parallel_world_size)
 from vllm.forward_context import (BatchDescriptor, get_forward_context,
                                   set_forward_context)
 
 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.utils import (enable_sp, flashcomm2_enable, has_layer_idx,
+from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.utils import (AscendDeviceType, enable_sp, flashcomm2_enable,
+                               get_ascend_device_type, has_layer_idx,
                                is_moe_model)
 
 if TYPE_CHECKING:
@@ -31,11 +34,10 @@ def set_ascend_forward_context(
         attn_metadata: Any,
         vllm_config: VllmConfig,
         virtual_engine: int = 0,
-        num_tokens: Optional[int] = None,
+        num_tokens: int = 0,
         num_tokens_across_dp: Optional[torch.Tensor] = None,
         with_prefill: bool = True,
         in_profile_run: bool = False,
-        moe_comm_type: Optional[MoECommType] = None,
         num_actual_tokens: Optional[int] = None,
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
         batch_descriptor: Optional[BatchDescriptor] = None,
@@ -60,6 +62,11 @@ def set_ascend_forward_context(
         from vllm_ascend.ops.fused_moe.moe_comm_method import \
             get_moe_comm_method
 
+        moe_comm_type = select_moe_comm_method(num_tokens, vllm_config)
+        # TODO: remove this after moe_comm_type selection logic is finalized
+        if in_profile_run and is_mtp_model:
+            moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type
+                             == MoECommType.FUSED_ALLTOALL else moe_comm_type)
         forward_context.moe_comm_type = moe_comm_type
         forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
 
@@ -231,3 +238,69 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
 
 def get_cos_and_sin():
     return _cos, _sin
+
+
+def select_moe_comm_method(num_tokens: int,
+                           vllm_config: VllmConfig) -> Optional[MoECommType]:
+    """1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
+    are designed for expert parallelism.
+    2. If expert parallel is enabled, we need to consider the soc version and the
+    number of tokens. This is based on the observation that all-gather is more
+    efficient than all-to-all when running on A2.
+
+        a. For A2, we choose from MC2 and all-gather.
+
+        b. For A3, we choose from MC2 and all-to-all.
+
+        In both cases, we use MC2 when the number of tokens is smaller than
+        its capacity threshold.
+
+    Args:
+        num_tokens (int): The number of tokens in the current batch.
+
+    Raises:
+        ValueError: If the soc version is unsupported.
+
+    Returns:
+        MoECommType: The selected MoE communication method.
+    """
+    if not is_moe_model(vllm_config):
+        return None
+    mc2_tokens_capacity = get_mc2_tokens_capacity()
+    soc_version = get_ascend_device_type()
+    quant_type = getattr(
+        vllm_config.model_config.hf_config, 'moe_quantize',
+        getattr(vllm_config.model_config.hf_config, 'quantize', None))
+    model_type = vllm_config.model_config.hf_config.model_type
+
+    if not vllm_config.parallel_config.enable_expert_parallel:
+        moe_comm_type = MoECommType.ALLGATHER
+    elif soc_version in {AscendDeviceType._910B}:
+        if (num_tokens <= mc2_tokens_capacity
+                and vllm_config.parallel_config.world_size_across_dp /
+                vllm_config.parallel_config.pipeline_parallel_size >= 16):
+            moe_comm_type = MoECommType.MC2
+        else:
+            # Currently, w4a8_dynamic does not support allgatherep
+            if quant_type == "w4a8_dynamic":
+                moe_comm_type = MoECommType.ALLTOALL
+            else:
+                moe_comm_type = MoECommType.ALLGATHER
+
+    elif soc_version in {AscendDeviceType._910_93}:
+        ascend_config = get_ascend_config()
+        dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
+        # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
+        fused_all2all_enable = quant_type == "w8a8_dynamic" and get_ep_group(
+        ).world_size <= 16 and (not dynamic_eplb)
+        moe_comm_type = (MoECommType.MC2 if num_tokens <= mc2_tokens_capacity
+                         else MoECommType.FUSED_ALLTOALL
+                         if fused_all2all_enable else MoECommType.ALLTOALL)
+    else:
+        raise ValueError(f"Unsupported soc_version: {soc_version}")
+    moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type
+                     == MoECommType.FUSED_ALLTOALL else moe_comm_type)
+    # PanguProMoE only supports allgather
+    if model_type == "PanguProMoE":
+        moe_comm_type = MoECommType.ALLGATHER
+    return moe_comm_type
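
For readers skimming the policy above, the following is a condensed, standalone mirror of the same decision table. It is an illustration only: the enum stands in for vllm-ascend's `MoECommType`, `soc` replaces `AscendDeviceType._910B`/`_910_93`, and everything the real `select_moe_comm_method` reads from `VllmConfig` and the EP group is passed as a plain parameter here.

```python
from enum import Enum
from typing import Optional


class MoECommType(Enum):
    ALLGATHER = "allgather"
    ALLTOALL = "alltoall"
    MC2 = "mc2"
    FUSED_ALLTOALL = "fused_alltoall"


def select_moe_comm_method_sketch(num_tokens: int,
                                  *,
                                  is_moe: bool,
                                  enable_expert_parallel: bool,
                                  soc: str,
                                  mc2_tokens_capacity: int,
                                  world_size_per_pp: int,
                                  quant_type: Optional[str],
                                  ep_world_size: int,
                                  dynamic_eplb: bool,
                                  model_type: str) -> Optional[MoECommType]:
    if not is_moe:
        return None  # dense models never pick a MoE comm method
    if not enable_expert_parallel:
        comm = MoECommType.ALLGATHER  # MC2 / all-to-all are EP-only
    elif soc == "a2":  # AscendDeviceType._910B
        if num_tokens <= mc2_tokens_capacity and world_size_per_pp >= 16:
            comm = MoECommType.MC2
        else:
            # w4a8_dynamic cannot use allgather-EP, fall back to all-to-all
            comm = (MoECommType.ALLTOALL
                    if quant_type == "w4a8_dynamic" else MoECommType.ALLGATHER)
    elif soc == "a3":  # AscendDeviceType._910_93
        fused_ok = (quant_type == "w8a8_dynamic" and ep_world_size <= 16
                    and not dynamic_eplb)
        comm = (MoECommType.MC2 if num_tokens <= mc2_tokens_capacity else
                MoECommType.FUSED_ALLTOALL if fused_ok else MoECommType.ALLTOALL)
    else:
        raise ValueError(f"Unsupported soc: {soc}")
    # the patch currently downgrades FUSED_ALLTOALL back to ALLTOALL unconditionally
    if comm == MoECommType.FUSED_ALLTOALL:
        comm = MoECommType.ALLTOALL
    if model_type == "PanguProMoE":
        comm = MoECommType.ALLGATHER  # PanguProMoE only supports all-gather
    return comm


print(select_moe_comm_method_sketch(
    128, is_moe=True, enable_expert_parallel=True, soc="a3",
    mc2_tokens_capacity=512, world_size_per_pp=16, quant_type="w8a8_dynamic",
    ep_world_size=8, dynamic_eplb=False, model_type="deepseek_v3"))
# -> MoECommType.MC2 (128 tokens fit within the MC2 capacity threshold)
```
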
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 266eadca..a71c3537 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -123,10 +123,9 @@ class EagleProposer(Proposer):
                   aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
                   batch_descriptor=None,
                   dummy_compute_logits=lambda hidden_states: None):
-        moe_comm_type = self.runner._select_moe_comm_method(num_tokens)
         with set_ascend_forward_context(None,
                                         self.vllm_config,
-                                        moe_comm_type=moe_comm_type,
+                                        in_profile_run=True,
                                         num_tokens=num_tokens):
             self.model(
                 input_ids=self.input_ids[:num_tokens],
@@ -458,15 +457,12 @@ class EagleProposer(Proposer):
         else:
             num_input_tokens = num_tokens
 
-        moe_comm_type = self.runner._select_moe_comm_method(num_input_tokens)
-
         # copy inputs to buffer for cudagraph
         self.positions[:num_tokens] = target_positions.to(device)
         self.hidden_states[:num_tokens] = target_hidden_states
         attn_metadata.block_tables = block_table.to(device)
         with set_ascend_forward_context(attn_metadata,
                                         self.vllm_config,
-                                        moe_comm_type=moe_comm_type,
                                         num_tokens=num_input_tokens):
             last_hidden_states, hidden_states = self.model(
                 input_ids=self.input_ids[:num_input_tokens],
@@ -498,8 +494,6 @@ class EagleProposer(Proposer):
         else:
             input_batch_size = batch_size
 
-        moe_comm_type = self.runner._select_moe_comm_method(input_batch_size)
-
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]
@@ -575,7 +569,6 @@ class EagleProposer(Proposer):
 
         # Run the model.
         with set_ascend_forward_context(attn_metadata,
                                         self.vllm_config,
-                                        moe_comm_type=moe_comm_type,
                                         num_tokens=input_batch_size):
             last_hidden_states, hidden_states = self.model(
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index 253f0ef4..348956d6 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -28,8 +28,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
-from vllm_ascend.ascend_forward_context import (MoECommType,
-                                                set_ascend_forward_context)
+from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
@@ -242,11 +241,6 @@ class MtpProposer(Proposer):
         # NOTE: we need to set aclgraph_runtime_mode to None in both dummy_run
         # and _propose.
         aclgraph_runtime_mode = CUDAGraphMode.NONE
-        moe_comm_type = self.runner._select_moe_comm_method(num_tokens)
-        # TODO: remove this after moe_comm_type selection logic is finalized
-        moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type
-                         == MoECommType.FUSED_ALLTOALL else moe_comm_type)
-
         if aclgraph_runtime_mode == CUDAGraphMode.FULL:
             if len(self.runner.attn_groups) > 0:
                 num_computed_tokens_cpu = (
@@ -299,9 +293,8 @@ class MtpProposer(Proposer):
                 self.vllm_config,
                 num_tokens=num_tokens,
                 with_prefill=with_prefill,
+                in_profile_run=True,
                 num_tokens_across_dp=num_tokens_across_dp,
-                moe_comm_type=moe_comm_type,
-                in_profile_run=self.runner.in_profile_run,
                 num_actual_tokens=0,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
@@ -720,11 +713,6 @@ class MtpProposer(Proposer):
         with_prefill) = self.runner._sync_metadata_across_dp(
             num_input_tokens, self.runner.with_prefill)
 
-        moe_comm_type = self.runner._select_moe_comm_method(num_input_tokens)
-        # TODO: remove this after moe_comm_type selection logic is finalized
-        moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type
-                         == MoECommType.FUSED_ALLTOALL else moe_comm_type)
-
         # Enable shared_expert_dp and MTP FULL graph may cause accuracy issues.
         if scheduler_output and not self.enable_shared_expert_dp:
             max_query_len = common_attn_metadata.max_query_len
@@ -771,7 +759,6 @@ class MtpProposer(Proposer):
                 num_tokens=num_input_tokens,
                 with_prefill=with_prefill,
                 num_tokens_across_dp=num_tokens_across_dp,
-                moe_comm_type=moe_comm_type,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
                 in_profile_run=self.runner.in_profile_run,
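
The FUSED_ALLTOALL-to-ALLTOALL downgrade removed from both proposers above now lives in `set_ascend_forward_context` (the `in_profile_run and is_mtp_model` branch in the ascend_forward_context.py hunk). A small hedged sketch of that behavior; `is_mtp_model` is an existing flag inside that function and is not shown in this diff, and the enum is again a stand-in for vllm-ascend's `MoECommType`:

```python
from enum import Enum, auto


class MoECommType(Enum):
    ALLTOALL = auto()
    FUSED_ALLTOALL = auto()
    MC2 = auto()


def apply_mtp_profile_downgrade(comm: MoECommType, in_profile_run: bool,
                                is_mtp_model: bool) -> MoECommType:
    # During MTP profile runs FUSED_ALLTOALL is replaced by plain ALLTOALL
    # (a temporary workaround per the TODO in the patch); other types pass through.
    if in_profile_run and is_mtp_model and comm == MoECommType.FUSED_ALLTOALL:
        return MoECommType.ALLTOALL
    return comm


assert apply_mtp_profile_downgrade(MoECommType.FUSED_ALLTOALL, True,
                                   True) == MoECommType.ALLTOALL
assert apply_mtp_profile_downgrade(MoECommType.MC2, True, True) == MoECommType.MC2
```
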
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index ad0f50b9..f8de7729 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -17,7 +17,6 @@
 # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
 #
 
-import gc
 import math
 import time
 from collections import defaultdict
@@ -46,8 +45,8 @@ from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.parallel_state import (get_dcp_group, get_dp_group,
-                                             get_ep_group, get_pcp_group,
-                                             get_pp_group, get_tp_group,
+                                             get_pcp_group, get_pp_group,
+                                             get_tp_group,
                                              is_global_first_rank)
 from vllm.forward_context import get_forward_context
 from vllm.logger import logger
@@ -87,6 +86,7 @@ import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import (MoECommType,
                                                 get_mc2_tokens_capacity,
+                                                select_moe_comm_method,
                                                 set_ascend_forward_context,
                                                 set_cos_and_sin, set_mc2_mask,
                                                 set_mc2_tokens_capacity)
@@ -113,7 +113,6 @@ from vllm_ascend.eplb.eplb_updator import EplbUpdator
 from vllm_ascend.eplb.utils import model_register
 from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
 from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
-from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.logits_processor import build_logitsprocs
 from vllm_ascend.sample.sampler import AscendSampler
 from vllm_ascend.spec_decode import get_spec_decode_method
@@ -457,8 +456,8 @@ class NPUModelRunner(GPUModelRunner):
         # To ensure skipping all_reduce across dp group is valid, we need to ensure that
         # moe_comm_method of each rank is MC2 and recomputation would never happen in D
         # nodes. So here we check whether recompute_scheduler_enable is True.
-        return self.is_kv_consumer and not self.in_profile_run and self.ascend_config.recompute_scheduler_enable and self._select_moe_comm_method(
-            potential_max_num_tokens) == MoECommType.MC2
+        return self.is_kv_consumer and not self.in_profile_run and self.ascend_config.recompute_scheduler_enable and select_moe_comm_method(
+            potential_max_num_tokens, self.vllm_config) == MoECommType.MC2
 
     def _sync_metadata_across_dp(
             self, num_tokens: int,
@@ -1152,51 +1151,17 @@ class NPUModelRunner(GPUModelRunner):
                                      input_ids, inputs_embeds, intermediate_tensors,
                                      max_num_scheduled_tokens)
 
-    def _init_model_kwargs(self):
-        model_kwargs = dict[str, Any]()
-        num_reqs = self.input_batch.num_reqs
-
-        num_pooling_reqs = len(self.input_batch.pooling_params)
-
-        if num_pooling_reqs == 0:
-            return model_kwargs
-
-        pooling_params = self.input_batch.get_pooling_params()
-
-        assert num_pooling_reqs == num_reqs
-
-        token_type_id_requests = dict[int, Any]()
-        for i, param in enumerate(pooling_params):
-            if param.extra_kwargs is not None and \
-                    (token_types := param.extra_kwargs.get(
-                        "compressed_token_type_ids")) is not None:
-                token_type_id_requests[i] = token_types
-
-        if len(token_type_id_requests) == 0:
-            return model_kwargs
-
-        seq_lens = self.seq_lens.gpu[:num_reqs]
-        token_type_ids = []
-
-        for i in range(num_reqs):
-            pos = token_type_id_requests.get(i, seq_lens[i])
-            ids = (torch.arange(seq_lens[i]) >= pos).int()
-            token_type_ids.append(ids)
-
-        model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to(
-            device=self.device)
-        return model_kwargs
-
     def _generate_process_reqs_hidden_states(self, maybe_padded_num_tokens,
                                              input_ids, positions,
                                              intermediate_tensors, inputs_embeds):
         assert self.model is not None
-        hidden_states = self.model(input_ids=input_ids,
-                                   positions=positions,
-                                   intermediate_tensors=intermediate_tensors,
-                                   inputs_embeds=inputs_embeds,
-                                   **self._init_model_kwargs())
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            intermediate_tensors=intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+            **self._init_model_kwargs(maybe_padded_num_tokens))
 
         forward_context = get_forward_context()
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL \
@@ -1386,73 +1351,6 @@ class NPUModelRunner(GPUModelRunner):
                                                  hidden_states, aux_hidden_states)
         return draft_token_ids
 
-    def _select_moe_comm_method(self,
-                                num_tokens: int) -> Optional[MoECommType]:
-        """1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
-        are designed for expert parallelism.
-        2. If expert parallel is enabled, we need to consider the soc version and the
-        number of tokens. This is based on the observation that all-gather is more
-        efficient than all-to-all when running on A2.
-
-            a. For A2, we choose from MC2 and all-gather.
-
-            b. For A3, we choose from MC2 and all-to-all.
-
-            In both cases, we use MC2 when the number of tokens is smaller than
-            a its capacity threshold.
-
-        Args:
-            num_tokens (int): The number of tokens in the current batch.
-
-        Raises:
-            ValueError: If the soc version is unsupported.
-
-        Returns:
-            MoECommType: The selected MoE communication method.
- """ - if not is_moe_model(self.vllm_config): - return None - mc2_tokens_capacity = get_mc2_tokens_capacity() - soc_version = get_ascend_device_type() - quant_type = getattr( - self.vllm_config.model_config.hf_config, 'moe_quantize', - getattr(self.vllm_config.model_config.hf_config, 'quantize', None)) - model_type = self.vllm_config.model_config.hf_config.model_type - - if not self.parallel_config.enable_expert_parallel: - moe_comm_type = MoECommType.ALLGATHER - elif soc_version in {AscendDeviceType._910B}: - if (num_tokens <= mc2_tokens_capacity - and self.parallel_config.world_size_across_dp / - self.parallel_config.pipeline_parallel_size >= 16): - moe_comm_type = MoECommType.MC2 - else: - # Currently, w4a8_dynamic does not support allgatherep - if quant_type == "w4a8_dynamic": - moe_comm_type = MoECommType.ALLTOALL - else: - moe_comm_type = MoECommType.ALLGATHER - - elif soc_version in {AscendDeviceType._910_93}: - # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes - fused_all2all_enable = quant_type == "w8a8_dynamic" and get_ep_group( - ).world_size <= 16 and (not self.dynamic_eplb) - moe_comm_type = (MoECommType.MC2 - if num_tokens <= mc2_tokens_capacity else - MoECommType.FUSED_ALLTOALL - if fused_all2all_enable else MoECommType.ALLTOALL) - else: - raise ValueError(f"Unsupported soc_version: {soc_version}") - - # PanguProMoE only supports allgather - if model_type == "PanguProMoE": - moe_comm_type = MoECommType.ALLGATHER - - if is_global_first_rank(): - logger.debug(f"num_tokens: {num_tokens}, " - f"moe_comm_type: {moe_comm_type}") - return moe_comm_type - @staticmethod def get_finished_kv_transfer( scheduler_output: "SchedulerOutput", @@ -1506,7 +1404,6 @@ class NPUModelRunner(GPUModelRunner): if self.dynamic_eplb: self.eplb_updator.take_update_info_from_eplb_process() - moe_comm_type = self._select_moe_comm_method(num_input_tokens) # prevent debugger is None need_dump = self.dump_enable and self.debugger is not None if need_dump: @@ -1535,7 +1432,6 @@ class NPUModelRunner(GPUModelRunner): num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, with_prefill=self.with_prefill, - moe_comm_type=moe_comm_type, aclgraph_runtime_mode=aclgraph_runtime_mode, batch_descriptor=batch_descriptor, num_actual_tokens=scheduler_output. 
@@ -2084,6 +1980,7 @@ class NPUModelRunner(GPUModelRunner):
         aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
         uniform_decode: bool = False,
+        is_profile: bool = False,
     ) -> torch.Tensor:
         # only support eager mode and piecewise graph now
         assert aclgraph_runtime_mode is None or aclgraph_runtime_mode in {
@@ -2161,8 +2058,6 @@ class NPUModelRunner(GPUModelRunner):
             num_tokens_across_dp[:] = num_tokens_padded
             num_scheduled_tokens = num_scheduled_tokens.repeat(num_reqs_padded)
 
-        moe_comm_type = self._select_moe_comm_method(num_tokens_padded)
-
         # filter out the valid batch descriptor
         if aclgraph_runtime_mode is not None:
             # we allow forcing NONE when the dispatcher disagrees to support
@@ -2252,9 +2147,7 @@ class NPUModelRunner(GPUModelRunner):
                 num_tokens=num_tokens_padded,
                 num_tokens_across_dp=num_tokens_across_dp,
                 with_prefill=with_prefill,
-                in_profile_run=self.in_profile_run,
-                # reserved_mc2_mask=self.reserved_mc2_mask,
-                moe_comm_type=moe_comm_type,
+                in_profile_run=is_profile,
                 num_actual_tokens=0,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
@@ -2281,60 +2174,43 @@ class NPUModelRunner(GPUModelRunner):
         if not self.in_profile_run and self.dynamic_eplb:
             self.eplb_updator.take_update_info_from_eplb_process()
             self.eplb_updator.forward_end()
-        return hidden_states
+        return hidden_states, hidden_states
 
-    @contextmanager
-    def set_in_profile_run(self):
-        self.in_profile_run = True
-        try:
-            yield
-        finally:
-            self.in_profile_run = False
+    @torch.inference_mode()
+    def _dummy_sampler_run(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        output = None
+
+        # For profiling, use the maximum num_reqs that together schedule the
+        # maximum num_tokens.
+        min_tokens_per_req = self.max_num_tokens // self.max_num_reqs
+        num_scheduled_tokens_list = [min_tokens_per_req] * self.max_num_reqs
+        num_scheduled_tokens_list[
+            -1] += self.max_num_tokens % self.max_num_reqs
+        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
+                                        dtype=np.int32)
+        logit_indices = np.cumsum(num_scheduled_tokens) - 1
+        # TODO: need to run a dummy sampler for generate task
+        # Sometimes, after the model is compiled through the AOT backend,
+        # the model output may become a list containing only one Tensor object.
+        if isinstance(hidden_states, list) and \
+                len(hidden_states) == 1 and \
+                isinstance(hidden_states[0], torch.Tensor):
+            hidden_states = hidden_states[0]
+        hidden_states = hidden_states[logit_indices]
+        output = self.model.compute_logits(hidden_states)
+        return output
 
     def profile_run(self) -> None:
-        # Trigger compilation for general shape.
-        with self.set_in_profile_run():
-            hidden_states = self._dummy_run(
-                self.max_num_tokens //
-                self.pcp_size if self.pcp_size > 1 else self.max_num_tokens,
-                with_prefill=True)
-            # MC2 will consume additional NPU memory.
-            # Therefore, we need to run the MC2 path once here to complete its initialization,
-            # allowing vLLM to correctly estimate the maximum memory required.
-            mc2_tokens_capacity = get_mc2_tokens_capacity()
-            if self.max_num_tokens > mc2_tokens_capacity and \
-                    self._select_moe_comm_method(mc2_tokens_capacity) == MoECommType.MC2:
-                self._dummy_run(mc2_tokens_capacity, with_prefill=True)
-
-            output = None
-            if get_pp_group().is_last_rank:
-                if self.is_pooling_model:
-                    output = self._dummy_pooler_run(hidden_states)
-                else:
-                    # For profile, have maximum num_reqs and that collectively have
-                    # maximum num_tokens.
-                    min_tokens_per_req = self.max_num_tokens // self.max_num_reqs
-                    num_scheduled_tokens_list = [min_tokens_per_req
-                                                 ] * self.max_num_reqs
-                    num_scheduled_tokens_list[
-                        -1] += self.max_num_tokens % self.max_num_reqs
-                    num_scheduled_tokens = np.array(num_scheduled_tokens_list,
-                                                    dtype=np.int32)
-                    logit_indices = np.cumsum(num_scheduled_tokens) - 1
-                    # TODO: need to rum a dummy sampler for generate task
-                    # Sometimes, after the model is compiled through the AOT backend,
-                    # the model output may become a list containing only one Tensor object.
-                    if isinstance(hidden_states, list) and \
-                            len(hidden_states) == 1 and \
-                            isinstance(hidden_states[0], torch.Tensor):
-                        hidden_states = hidden_states[0]
-                    hidden_states = hidden_states[logit_indices]
-                    output = self.model.compute_logits(hidden_states)
-
-        NPUPlatform.synchronize()
-        del hidden_states, output
-        self.encoder_cache.clear()
-        gc.collect()
+        mc2_tokens_capacity = get_mc2_tokens_capacity()
+        if self.max_num_tokens > mc2_tokens_capacity and \
+                select_moe_comm_method(mc2_tokens_capacity, self.vllm_config) == MoECommType.MC2:
+            self._dummy_run(mc2_tokens_capacity,
+                            with_prefill=True,
+                            is_profile=True)
+        super().profile_run()
 
     def eplb_warmup(self):
         if self.dynamic_eplb and not self.is_eplb_warmuped:
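
A worked example of the logit-index computation used by `_dummy_sampler_run` (and by the old inline profile path removed above): tokens are spread evenly over the dummy requests, the last request absorbs the remainder, and logits are taken at each request's final token. The numbers below are illustrative only.

```python
import numpy as np

max_num_tokens, max_num_reqs = 1030, 4
min_tokens_per_req = max_num_tokens // max_num_reqs            # 257
num_scheduled_tokens_list = [min_tokens_per_req] * max_num_reqs
num_scheduled_tokens_list[-1] += max_num_tokens % max_num_reqs  # last req: 259
num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
logit_indices = np.cumsum(num_scheduled_tokens) - 1

print(num_scheduled_tokens.tolist())  # [257, 257, 257, 259]
print(logit_indices.tolist())         # [256, 513, 770, 1029]
```
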