[Aclgraph][DP] Fix dp dummy run not in aclgraph error (#3208)
### What this PR does / why we need it? When running DP in a non-equilibrium scenario, which means there is some dp groups executing `dummy_run`, we need to make sure it running the same mode as other dp, thus improving then performance in dp scenario ### How was this patch tested? Tested by adding log in `_dummy_run` - vLLM version: v0.10.2 - vLLM main: https://github.com/vllm-project/vllm/commit/releases/v0.11.0 --------- Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -444,6 +444,8 @@ class TestNPUWorker(TestBase):
|
||||
# Create worker mock
|
||||
with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None):
|
||||
worker = NPUWorker()
|
||||
worker.compilation_config = MagicMock()
|
||||
worker.compilation_config.cudagraph_mode = MagicMock()
|
||||
mock_model_runner = MagicMock()
|
||||
worker.model_runner = mock_model_runner
|
||||
|
||||
@@ -451,7 +453,8 @@ class TestNPUWorker(TestBase):
|
||||
worker.execute_dummy_batch()
|
||||
|
||||
# Verify call
|
||||
mock_model_runner._dummy_run.assert_called_once_with(1)
|
||||
mock_model_runner._dummy_run.assert_called_once_with(
|
||||
num_tokens=1, uniform_decode=True, force_attention=False)
|
||||
|
||||
@patch("vllm_ascend.worker.worker_v1.envs_vllm")
|
||||
@patch("vllm_ascend.worker.worker_v1.logger")
|
||||
|
||||
@@ -2321,6 +2321,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
forward_context = get_forward_context()
|
||||
assert forward_context is not None
|
||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
update_attn_params(self.update_stream, forward_context,
|
||||
positions.shape[0])
|
||||
|
||||
if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
|
||||
hidden_states, _ = hidden_states
|
||||
else:
|
||||
@@ -2333,12 +2339,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_tokens: int,
|
||||
with_prefill: bool = False,
|
||||
is_torchair_compile: bool = False,
|
||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
|
||||
force_attention: bool = False,
|
||||
uniform_decode: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# only support eager mode and piecewise graph now
|
||||
assert aclgraph_runtime_mode in {
|
||||
assert aclgraph_runtime_mode is None or aclgraph_runtime_mode in {
|
||||
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
|
||||
}
|
||||
|
||||
@@ -2371,8 +2377,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
if uniform_decode:
|
||||
num_reqs = cdiv(num_tokens, max_query_len)
|
||||
assert num_reqs <= max_num_reqs, \
|
||||
"Do not capture num_reqs > max_num_reqs for uniform batch"
|
||||
num_scheduled_tokens_list = [max_query_len] * num_reqs
|
||||
if num_tokens % max_query_len != 0:
|
||||
num_scheduled_tokens_list[-1] = num_tokens % max_query_len
|
||||
@@ -2395,12 +2399,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.is_kv_producer and not self.is_kv_consumer:
|
||||
with_prefill = True
|
||||
|
||||
# TODO(cmq): check if with_prefill is reasonable
|
||||
attn_metadata = self._build_attention_metadata(
|
||||
with_prefill,
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
max_query_len,
|
||||
force_attention,
|
||||
False,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
max_query_len=max_query_len,
|
||||
force_attention=force_attention,
|
||||
)
|
||||
|
||||
if not self.in_profile_run and self.dynamic_eplb:
|
||||
@@ -2433,18 +2438,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
k: v[:num_tokens]
|
||||
for k, v in self.intermediate_tensors.items()
|
||||
})
|
||||
if aclgraph_runtime_mode == CUDAGraphMode.NONE:
|
||||
batch_descriptor = None
|
||||
else:
|
||||
# filter out the valid batch descriptor
|
||||
_cg_mode, batch_descriptor = \
|
||||
self.aclgraph_dispatcher.dispatch(
|
||||
BatchDescriptor(num_tokens=num_tokens,
|
||||
uniform_decode=uniform_decode))
|
||||
# sanity check
|
||||
assert aclgraph_runtime_mode == _cg_mode, (
|
||||
|
||||
# filter out the valid batch descriptor
|
||||
_ag_mode, batch_descriptor = \
|
||||
self.aclgraph_dispatcher.dispatch(
|
||||
BatchDescriptor(num_tokens=num_tokens,
|
||||
uniform_decode=uniform_decode))
|
||||
if aclgraph_runtime_mode is not None:
|
||||
# we allow forcing NONE when the dispatcher disagrees to support
|
||||
# warm ups for aclgraph capture
|
||||
assert aclgraph_runtime_mode == CUDAGraphMode.NONE or \
|
||||
aclgraph_runtime_mode == _ag_mode, (
|
||||
f"Aclgraph runtime mode mismatch at dummy_run. "
|
||||
f"Expected {_cg_mode}, but got {aclgraph_runtime_mode}.")
|
||||
f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}.")
|
||||
else:
|
||||
aclgraph_runtime_mode = _ag_mode
|
||||
|
||||
need_dummy_logits = (not self.in_profile_run
|
||||
and lmhead_tp_enable())
|
||||
|
||||
@@ -26,7 +26,7 @@ import torch_npu
|
||||
import vllm.envs as envs_vllm
|
||||
from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
|
||||
from torch_npu.profiler import dynamic_profile as dp
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
||||
@@ -356,7 +356,10 @@ class NPUWorker(WorkerBase):
|
||||
return self.model_runner.pin_lora(lora_id)
|
||||
|
||||
def execute_dummy_batch(self) -> None:
|
||||
self.model_runner._dummy_run(1)
|
||||
force_attention = self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
|
||||
self.model_runner._dummy_run(num_tokens=1,
|
||||
uniform_decode=True,
|
||||
force_attention=force_attention)
|
||||
|
||||
def _init_worker_distributed_environment(self) -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
|
||||
Reference in New Issue
Block a user