[Aclgraph][DP] Fix dp dummy run not in aclgraph error (#3208)
### What this PR does / why we need it? When running DP in a non-equilibrium scenario, which means some DP groups are executing `dummy_run`, we need to make sure they run in the same mode as the other DP groups, thus improving the performance in the DP scenario ### How was this patch tested? Tested by adding log in `_dummy_run` - vLLM version: v0.10.2 - vLLM main: https://github.com/vllm-project/vllm/commit/releases/v0.11.0 --------- Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -444,6 +444,8 @@ class TestNPUWorker(TestBase):
|
|||||||
# Create worker mock
|
# Create worker mock
|
||||||
with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None):
|
with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None):
|
||||||
worker = NPUWorker()
|
worker = NPUWorker()
|
||||||
|
worker.compilation_config = MagicMock()
|
||||||
|
worker.compilation_config.cudagraph_mode = MagicMock()
|
||||||
mock_model_runner = MagicMock()
|
mock_model_runner = MagicMock()
|
||||||
worker.model_runner = mock_model_runner
|
worker.model_runner = mock_model_runner
|
||||||
|
|
||||||
@@ -451,7 +453,8 @@ class TestNPUWorker(TestBase):
|
|||||||
worker.execute_dummy_batch()
|
worker.execute_dummy_batch()
|
||||||
|
|
||||||
# Verify call
|
# Verify call
|
||||||
mock_model_runner._dummy_run.assert_called_once_with(1)
|
mock_model_runner._dummy_run.assert_called_once_with(
|
||||||
|
num_tokens=1, uniform_decode=True, force_attention=False)
|
||||||
|
|
||||||
@patch("vllm_ascend.worker.worker_v1.envs_vllm")
|
@patch("vllm_ascend.worker.worker_v1.envs_vllm")
|
||||||
@patch("vllm_ascend.worker.worker_v1.logger")
|
@patch("vllm_ascend.worker.worker_v1.logger")
|
||||||
|
|||||||
@@ -2321,6 +2321,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
positions=positions,
|
positions=positions,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
inputs_embeds=inputs_embeds)
|
inputs_embeds=inputs_embeds)
|
||||||
|
forward_context = get_forward_context()
|
||||||
|
assert forward_context is not None
|
||||||
|
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||||
|
update_attn_params(self.update_stream, forward_context,
|
||||||
|
positions.shape[0])
|
||||||
|
|
||||||
if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
|
if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
|
||||||
hidden_states, _ = hidden_states
|
hidden_states, _ = hidden_states
|
||||||
else:
|
else:
|
||||||
@@ -2333,12 +2339,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
with_prefill: bool = False,
|
with_prefill: bool = False,
|
||||||
is_torchair_compile: bool = False,
|
is_torchair_compile: bool = False,
|
||||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
|
||||||
force_attention: bool = False,
|
force_attention: bool = False,
|
||||||
uniform_decode: bool = False,
|
uniform_decode: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# only support eager mode and piecewise graph now
|
# only support eager mode and piecewise graph now
|
||||||
assert aclgraph_runtime_mode in {
|
assert aclgraph_runtime_mode is None or aclgraph_runtime_mode in {
|
||||||
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
|
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2371,8 +2377,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
max_num_reqs = self.scheduler_config.max_num_seqs
|
max_num_reqs = self.scheduler_config.max_num_seqs
|
||||||
if uniform_decode:
|
if uniform_decode:
|
||||||
num_reqs = cdiv(num_tokens, max_query_len)
|
num_reqs = cdiv(num_tokens, max_query_len)
|
||||||
assert num_reqs <= max_num_reqs, \
|
|
||||||
"Do not capture num_reqs > max_num_reqs for uniform batch"
|
|
||||||
num_scheduled_tokens_list = [max_query_len] * num_reqs
|
num_scheduled_tokens_list = [max_query_len] * num_reqs
|
||||||
if num_tokens % max_query_len != 0:
|
if num_tokens % max_query_len != 0:
|
||||||
num_scheduled_tokens_list[-1] = num_tokens % max_query_len
|
num_scheduled_tokens_list[-1] = num_tokens % max_query_len
|
||||||
@@ -2395,12 +2399,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if self.is_kv_producer and not self.is_kv_consumer:
|
if self.is_kv_producer and not self.is_kv_consumer:
|
||||||
with_prefill = True
|
with_prefill = True
|
||||||
|
|
||||||
|
# TODO(cmq): check if with_prefill is reasonable
|
||||||
attn_metadata = self._build_attention_metadata(
|
attn_metadata = self._build_attention_metadata(
|
||||||
with_prefill,
|
False,
|
||||||
num_reqs,
|
num_reqs=num_reqs,
|
||||||
num_tokens,
|
num_tokens=num_tokens,
|
||||||
max_query_len,
|
max_query_len=max_query_len,
|
||||||
force_attention,
|
force_attention=force_attention,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not self.in_profile_run and self.dynamic_eplb:
|
if not self.in_profile_run and self.dynamic_eplb:
|
||||||
@@ -2433,18 +2438,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
k: v[:num_tokens]
|
k: v[:num_tokens]
|
||||||
for k, v in self.intermediate_tensors.items()
|
for k, v in self.intermediate_tensors.items()
|
||||||
})
|
})
|
||||||
if aclgraph_runtime_mode == CUDAGraphMode.NONE:
|
|
||||||
batch_descriptor = None
|
# filter out the valid batch descriptor
|
||||||
else:
|
_ag_mode, batch_descriptor = \
|
||||||
# filter out the valid batch descriptor
|
self.aclgraph_dispatcher.dispatch(
|
||||||
_cg_mode, batch_descriptor = \
|
BatchDescriptor(num_tokens=num_tokens,
|
||||||
self.aclgraph_dispatcher.dispatch(
|
uniform_decode=uniform_decode))
|
||||||
BatchDescriptor(num_tokens=num_tokens,
|
if aclgraph_runtime_mode is not None:
|
||||||
uniform_decode=uniform_decode))
|
# we allow forcing NONE when the dispatcher disagrees to support
|
||||||
# sanity check
|
# warm ups for aclgraph capture
|
||||||
assert aclgraph_runtime_mode == _cg_mode, (
|
assert aclgraph_runtime_mode == CUDAGraphMode.NONE or \
|
||||||
|
aclgraph_runtime_mode == _ag_mode, (
|
||||||
f"Aclgraph runtime mode mismatch at dummy_run. "
|
f"Aclgraph runtime mode mismatch at dummy_run. "
|
||||||
f"Expected {_cg_mode}, but got {aclgraph_runtime_mode}.")
|
f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}.")
|
||||||
|
else:
|
||||||
|
aclgraph_runtime_mode = _ag_mode
|
||||||
|
|
||||||
need_dummy_logits = (not self.in_profile_run
|
need_dummy_logits = (not self.in_profile_run
|
||||||
and lmhead_tp_enable())
|
and lmhead_tp_enable())
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ import torch_npu
|
|||||||
import vllm.envs as envs_vllm
|
import vllm.envs as envs_vllm
|
||||||
from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
|
from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
|
||||||
from torch_npu.profiler import dynamic_profile as dp
|
from torch_npu.profiler import dynamic_profile as dp
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import CUDAGraphMode, VllmConfig
|
||||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||||
init_distributed_environment)
|
init_distributed_environment)
|
||||||
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
||||||
@@ -356,7 +356,10 @@ class NPUWorker(WorkerBase):
|
|||||||
return self.model_runner.pin_lora(lora_id)
|
return self.model_runner.pin_lora(lora_id)
|
||||||
|
|
||||||
def execute_dummy_batch(self) -> None:
|
def execute_dummy_batch(self) -> None:
|
||||||
self.model_runner._dummy_run(1)
|
force_attention = self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
|
||||||
|
self.model_runner._dummy_run(num_tokens=1,
|
||||||
|
uniform_decode=True,
|
||||||
|
force_attention=force_attention)
|
||||||
|
|
||||||
def _init_worker_distributed_environment(self) -> None:
|
def _init_worker_distributed_environment(self) -> None:
|
||||||
"""Initialize the distributed environment."""
|
"""Initialize the distributed environment."""
|
||||||
|
|||||||
Reference in New Issue
Block a user