【Feature】refactor npu_modelrunner for profile_run (#4993)
### What this PR does / why we need it?
(1)refactor npu_model_runner for profile_run
(2) move _select_moe_comm_method to ascend_forward_context
(3) delete _init_model_kwargs in npu_model_runner
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Na
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Signed-off-by: zhenwenqi2024 <155598497+zhenwenqi2024@users.noreply.github.com>
This commit is contained in:
@@ -17,7 +17,6 @@
|
||||
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
|
||||
#
|
||||
|
||||
import gc
|
||||
import math
|
||||
import time
|
||||
from collections import defaultdict
|
||||
@@ -46,8 +45,8 @@ from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
|
||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group)
|
||||
from vllm.distributed.parallel_state import (get_dcp_group, get_dp_group,
|
||||
get_ep_group, get_pcp_group,
|
||||
get_pp_group, get_tp_group,
|
||||
get_pcp_group, get_pp_group,
|
||||
get_tp_group,
|
||||
is_global_first_rank)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import logger
|
||||
@@ -87,6 +86,7 @@ import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import (MoECommType,
|
||||
get_mc2_tokens_capacity,
|
||||
select_moe_comm_method,
|
||||
set_ascend_forward_context,
|
||||
set_cos_and_sin, set_mc2_mask,
|
||||
set_mc2_tokens_capacity)
|
||||
@@ -113,7 +113,6 @@ from vllm_ascend.eplb.eplb_updator import EplbUpdator
|
||||
from vllm_ascend.eplb.utils import model_register
|
||||
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
|
||||
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.sample.logits_processor import build_logitsprocs
|
||||
from vllm_ascend.sample.sampler import AscendSampler
|
||||
from vllm_ascend.spec_decode import get_spec_decode_method
|
||||
@@ -457,8 +456,8 @@ class NPUModelRunner(GPUModelRunner):
|
||||
# To ensure skipping all_reduce across dp group is valid, we need to ensure that
|
||||
# moe_comm_method of each rank is MC2 and recomputation would never happen in D
|
||||
# nodes. So here we check whether recompute_scheduler_enable is True.
|
||||
return self.is_kv_consumer and not self.in_profile_run and self.ascend_config.recompute_scheduler_enable and self._select_moe_comm_method(
|
||||
potential_max_num_tokens) == MoECommType.MC2
|
||||
return self.is_kv_consumer and not self.in_profile_run and self.ascend_config.recompute_scheduler_enable and select_moe_comm_method(
|
||||
potential_max_num_tokens, self.vllm_config) == MoECommType.MC2
|
||||
|
||||
def _sync_metadata_across_dp(
|
||||
self, num_tokens: int,
|
||||
@@ -1152,51 +1151,17 @@ class NPUModelRunner(GPUModelRunner):
|
||||
input_ids, inputs_embeds, intermediate_tensors,
|
||||
max_num_scheduled_tokens)
|
||||
|
||||
def _init_model_kwargs(self):
|
||||
model_kwargs = dict[str, Any]()
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
|
||||
num_pooling_reqs = len(self.input_batch.pooling_params)
|
||||
|
||||
if num_pooling_reqs == 0:
|
||||
return model_kwargs
|
||||
|
||||
pooling_params = self.input_batch.get_pooling_params()
|
||||
|
||||
assert num_pooling_reqs == num_reqs
|
||||
|
||||
token_type_id_requests = dict[int, Any]()
|
||||
for i, param in enumerate(pooling_params):
|
||||
if param.extra_kwargs is not None and \
|
||||
(token_types := param.extra_kwargs.get(
|
||||
"compressed_token_type_ids")) is not None:
|
||||
token_type_id_requests[i] = token_types
|
||||
|
||||
if len(token_type_id_requests) == 0:
|
||||
return model_kwargs
|
||||
|
||||
seq_lens = self.seq_lens.gpu[:num_reqs]
|
||||
token_type_ids = []
|
||||
|
||||
for i in range(num_reqs):
|
||||
pos = token_type_id_requests.get(i, seq_lens[i])
|
||||
ids = (torch.arange(seq_lens[i]) >= pos).int()
|
||||
token_type_ids.append(ids)
|
||||
|
||||
model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to(
|
||||
device=self.device)
|
||||
return model_kwargs
|
||||
|
||||
def _generate_process_reqs_hidden_states(self, maybe_padded_num_tokens,
|
||||
input_ids, positions,
|
||||
intermediate_tensors,
|
||||
inputs_embeds):
|
||||
assert self.model is not None
|
||||
hidden_states = self.model(input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**self._init_model_kwargs())
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**self._init_model_kwargs(maybe_padded_num_tokens))
|
||||
|
||||
forward_context = get_forward_context()
|
||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL \
|
||||
@@ -1386,73 +1351,6 @@ class NPUModelRunner(GPUModelRunner):
|
||||
hidden_states, aux_hidden_states)
|
||||
return draft_token_ids
|
||||
|
||||
def _select_moe_comm_method(self,
|
||||
num_tokens: int) -> Optional[MoECommType]:
|
||||
"""1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
|
||||
are designed for expert parallelism.
|
||||
2. If expert parallel is enabled, we need to consider the soc version and the
|
||||
number of tokens. This is based on the observation that all-gather is more
|
||||
efficient than all-to-all when running on A2.
|
||||
|
||||
a. For A2, we choose from MC2 and all-gather.
|
||||
|
||||
b. For A3, we choose from MC2 and all-to-all.
|
||||
|
||||
In both cases, we use MC2 when the number of tokens is smaller than
|
||||
a its capacity threshold.
|
||||
|
||||
Args:
|
||||
num_tokens (int): The number of tokens in the current batch.
|
||||
|
||||
Raises:
|
||||
ValueError: If the soc version is unsupported.
|
||||
|
||||
Returns:
|
||||
MoECommType: The selected MoE communication method.
|
||||
"""
|
||||
if not is_moe_model(self.vllm_config):
|
||||
return None
|
||||
mc2_tokens_capacity = get_mc2_tokens_capacity()
|
||||
soc_version = get_ascend_device_type()
|
||||
quant_type = getattr(
|
||||
self.vllm_config.model_config.hf_config, 'moe_quantize',
|
||||
getattr(self.vllm_config.model_config.hf_config, 'quantize', None))
|
||||
model_type = self.vllm_config.model_config.hf_config.model_type
|
||||
|
||||
if not self.parallel_config.enable_expert_parallel:
|
||||
moe_comm_type = MoECommType.ALLGATHER
|
||||
elif soc_version in {AscendDeviceType._910B}:
|
||||
if (num_tokens <= mc2_tokens_capacity
|
||||
and self.parallel_config.world_size_across_dp /
|
||||
self.parallel_config.pipeline_parallel_size >= 16):
|
||||
moe_comm_type = MoECommType.MC2
|
||||
else:
|
||||
# Currently, w4a8_dynamic does not support allgatherep
|
||||
if quant_type == "w4a8_dynamic":
|
||||
moe_comm_type = MoECommType.ALLTOALL
|
||||
else:
|
||||
moe_comm_type = MoECommType.ALLGATHER
|
||||
|
||||
elif soc_version in {AscendDeviceType._910_93}:
|
||||
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
|
||||
fused_all2all_enable = quant_type == "w8a8_dynamic" and get_ep_group(
|
||||
).world_size <= 16 and (not self.dynamic_eplb)
|
||||
moe_comm_type = (MoECommType.MC2
|
||||
if num_tokens <= mc2_tokens_capacity else
|
||||
MoECommType.FUSED_ALLTOALL
|
||||
if fused_all2all_enable else MoECommType.ALLTOALL)
|
||||
else:
|
||||
raise ValueError(f"Unsupported soc_version: {soc_version}")
|
||||
|
||||
# PanguProMoE only supports allgather
|
||||
if model_type == "PanguProMoE":
|
||||
moe_comm_type = MoECommType.ALLGATHER
|
||||
|
||||
if is_global_first_rank():
|
||||
logger.debug(f"num_tokens: {num_tokens}, "
|
||||
f"moe_comm_type: {moe_comm_type}")
|
||||
return moe_comm_type
|
||||
|
||||
@staticmethod
|
||||
def get_finished_kv_transfer(
|
||||
scheduler_output: "SchedulerOutput",
|
||||
@@ -1506,7 +1404,6 @@ class NPUModelRunner(GPUModelRunner):
|
||||
if self.dynamic_eplb:
|
||||
self.eplb_updator.take_update_info_from_eplb_process()
|
||||
|
||||
moe_comm_type = self._select_moe_comm_method(num_input_tokens)
|
||||
# prevent debugger is None
|
||||
need_dump = self.dump_enable and self.debugger is not None
|
||||
if need_dump:
|
||||
@@ -1535,7 +1432,6 @@ class NPUModelRunner(GPUModelRunner):
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
with_prefill=self.with_prefill,
|
||||
moe_comm_type=moe_comm_type,
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor,
|
||||
num_actual_tokens=scheduler_output.
|
||||
@@ -2084,6 +1980,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
|
||||
force_attention: bool = False,
|
||||
uniform_decode: bool = False,
|
||||
is_profile: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# only support eager mode and piecewise graph now
|
||||
assert aclgraph_runtime_mode is None or aclgraph_runtime_mode in {
|
||||
@@ -2161,8 +2058,6 @@ class NPUModelRunner(GPUModelRunner):
|
||||
num_tokens_across_dp[:] = num_tokens_padded
|
||||
num_scheduled_tokens = num_scheduled_tokens.repeat(num_reqs_padded)
|
||||
|
||||
moe_comm_type = self._select_moe_comm_method(num_tokens_padded)
|
||||
|
||||
# filter out the valid batch descriptor
|
||||
if aclgraph_runtime_mode is not None:
|
||||
# we allow forcing NONE when the dispatcher disagrees to support
|
||||
@@ -2252,9 +2147,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
num_tokens=num_tokens_padded,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
with_prefill=with_prefill,
|
||||
in_profile_run=self.in_profile_run,
|
||||
# reserved_mc2_mask=self.reserved_mc2_mask,
|
||||
moe_comm_type=moe_comm_type,
|
||||
in_profile_run=is_profile,
|
||||
num_actual_tokens=0,
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor,
|
||||
@@ -2281,60 +2174,43 @@ class NPUModelRunner(GPUModelRunner):
|
||||
if not self.in_profile_run and self.dynamic_eplb:
|
||||
self.eplb_updator.take_update_info_from_eplb_process()
|
||||
self.eplb_updator.forward_end()
|
||||
return hidden_states
|
||||
return hidden_states, hidden_states
|
||||
|
||||
@contextmanager
|
||||
def set_in_profile_run(self):
|
||||
self.in_profile_run = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.in_profile_run = False
|
||||
@torch.inference_mode()
|
||||
def _dummy_sampler_run(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
output = None
|
||||
|
||||
# For profile, have maximum num_reqs and that collectively have
|
||||
# maximum num_tokens.
|
||||
min_tokens_per_req = self.max_num_tokens // self.max_num_reqs
|
||||
num_scheduled_tokens_list = [min_tokens_per_req] * self.max_num_reqs
|
||||
num_scheduled_tokens_list[
|
||||
-1] += self.max_num_tokens % self.max_num_reqs
|
||||
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
|
||||
dtype=np.int32)
|
||||
logit_indices = np.cumsum(num_scheduled_tokens) - 1
|
||||
# TODO: need to rum a dummy sampler for generate task
|
||||
# Sometimes, after the model is compiled through the AOT backend,
|
||||
# the model output may become a list containing only one Tensor object.
|
||||
if isinstance(hidden_states, list) and \
|
||||
len(hidden_states) == 1 and \
|
||||
isinstance(hidden_states[0], torch.Tensor):
|
||||
hidden_states = hidden_states[0]
|
||||
hidden_states = hidden_states[logit_indices]
|
||||
output = self.model.compute_logits(hidden_states)
|
||||
return output
|
||||
|
||||
def profile_run(self) -> None:
|
||||
# Trigger compilation for general shape.
|
||||
with self.set_in_profile_run():
|
||||
hidden_states = self._dummy_run(
|
||||
self.max_num_tokens //
|
||||
self.pcp_size if self.pcp_size > 1 else self.max_num_tokens,
|
||||
with_prefill=True)
|
||||
# MC2 will consume additional NPU memory.
|
||||
# Therefore, we need to run the MC2 path once here to complete its initialization,
|
||||
# allowing vLLM to correctly estimate the maximum memory required.
|
||||
mc2_tokens_capacity = get_mc2_tokens_capacity()
|
||||
if self.max_num_tokens > mc2_tokens_capacity and \
|
||||
self._select_moe_comm_method(mc2_tokens_capacity) == MoECommType.MC2:
|
||||
self._dummy_run(mc2_tokens_capacity, with_prefill=True)
|
||||
|
||||
output = None
|
||||
if get_pp_group().is_last_rank:
|
||||
if self.is_pooling_model:
|
||||
output = self._dummy_pooler_run(hidden_states)
|
||||
else:
|
||||
# For profile, have maximum num_reqs and that collectively have
|
||||
# maximum num_tokens.
|
||||
min_tokens_per_req = self.max_num_tokens // self.max_num_reqs
|
||||
num_scheduled_tokens_list = [min_tokens_per_req
|
||||
] * self.max_num_reqs
|
||||
num_scheduled_tokens_list[
|
||||
-1] += self.max_num_tokens % self.max_num_reqs
|
||||
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
|
||||
dtype=np.int32)
|
||||
logit_indices = np.cumsum(num_scheduled_tokens) - 1
|
||||
# TODO: need to rum a dummy sampler for generate task
|
||||
# Sometimes, after the model is compiled through the AOT backend,
|
||||
# the model output may become a list containing only one Tensor object.
|
||||
if isinstance(hidden_states, list) and \
|
||||
len(hidden_states) == 1 and \
|
||||
isinstance(hidden_states[0], torch.Tensor):
|
||||
hidden_states = hidden_states[0]
|
||||
hidden_states = hidden_states[logit_indices]
|
||||
output = self.model.compute_logits(hidden_states)
|
||||
|
||||
NPUPlatform.synchronize()
|
||||
del hidden_states, output
|
||||
self.encoder_cache.clear()
|
||||
gc.collect()
|
||||
mc2_tokens_capacity = get_mc2_tokens_capacity()
|
||||
if self.max_num_tokens > mc2_tokens_capacity and \
|
||||
select_moe_comm_method(mc2_tokens_capacity, self.vllm_config) == MoECommType.MC2:
|
||||
self._dummy_run(mc2_tokens_capacity,
|
||||
with_prefill=True,
|
||||
is_profile=True)
|
||||
super().profile_run()
|
||||
|
||||
def eplb_warmup(self):
|
||||
if self.dynamic_eplb and not self.is_eplb_warmuped:
|
||||
|
||||
Reference in New Issue
Block a user