## What this PR does / why we need it?
Implements [RFC
#6954](https://github.com/vllm-project/vllm-ascend/issues/6954):
full adaptation of `profile_prefix` in the NPUWorker profiler, for API
parity with upstream vLLM.
### Changes
- **Lazy profiler init**: Defer profiler creation until the first
`profile(is_start=True)` call
- **profile_prefix param**: Add `profile_prefix` to `profile()`; compute
`trace_name` from prefix + `get_worker_rank_suffix()`
- **Refactor `_init_profiler` → `_create_profiler(trace_name)`**: Pass
`worker_name` to `tensorboard_trace_handler` for unique trace files per
worker
- Unique trace files per worker; no collision in multi-worker setups
### Testing
- Unit tests updated/added in `tests/ut/worker/test_worker_v1.py`
- `pytest tests/ut/worker/test_worker_v1.py::TestNPUWorker` passed
## Does this PR introduce _any_ user-facing change?
Yes. Trace file names may differ: they are now more descriptive and include
the worker rank suffix. `profile(is_start=True, profile_prefix="warmup")` is
now supported.
## How was this patch tested?
- Unit tests: `pytest tests/ut/worker/test_worker_v1.py::TestNPUWorker`
- Manual: ran vLLM serve with a profiler config, started and stopped
profiling, and verified the trace files
- vLLM version: v0.16.0
- vLLM main:
15d76f74e2
---------
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
@@ -125,7 +125,9 @@ class NPUWorker(WorkerBase):
|
||||
else:
|
||||
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[self.cache_config.cache_dtype]
|
||||
|
||||
self.profiler = self._init_profiler()
|
||||
# Profiler is lazily initialized on first profile(is_start=True) call (RFC #6954)
|
||||
self.profiler_config = vllm_config.profiler_config
|
||||
self.profiler = None
|
||||
if vllm_config.model_config and vllm_config.model_config.enable_sleep_mode:
|
||||
# Buffers saved before sleep
|
||||
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
|
||||
@@ -511,12 +513,34 @@ class NPUWorker(WorkerBase):
|
||||
with context:
|
||||
self.model_runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
def profile(self, is_start: bool = True):
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
def profile(self, is_start: bool = True, profile_prefix: str | None = None):
|
||||
# Check if profiling is enabled (RFC #6954 - align with upstream vLLM)
|
||||
if self.profiler_config is None or self.profiler_config.profiler is None:
|
||||
raise RuntimeError(
|
||||
"Profiling is not enabled. Please set --profiler-config to enable "
|
||||
"profiling. Example: "
|
||||
"'--profiler-config.profiler=torch --profiler-config.torch_profiler_dir"
|
||||
"=YOUR_DIR_PATH_TO_DUMP_TRACE'"
|
||||
)
|
||||
|
||||
if is_start:
|
||||
self.profiler.start()
|
||||
from vllm.distributed.utils import get_worker_rank_suffix
|
||||
|
||||
rank_suffix = get_worker_rank_suffix(global_rank=self.rank)
|
||||
trace_name = f"{profile_prefix}_{rank_suffix}" if profile_prefix else rank_suffix
|
||||
|
||||
if self.profiler is None:
|
||||
self.profiler = self._create_profiler(trace_name)
|
||||
logger.debug("Starting torch profiler with trace name: %s", trace_name)
|
||||
self.profiler.start() # type: ignore[attr-defined]
|
||||
else:
|
||||
# Profiler already initialized. Restart profiling but keep
|
||||
# the original trace name from the first initialization.
|
||||
self.profiler.start()
|
||||
else:
|
||||
if self.profiler is None:
|
||||
logger.warning("Profiler was not started, nothing to stop.")
|
||||
return
|
||||
self.profiler.stop()
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
@@ -553,43 +577,45 @@ class NPUWorker(WorkerBase):
|
||||
ensure_kv_transfer_initialized(self.vllm_config)
|
||||
ensure_ec_transfer_initialized(self.vllm_config)
|
||||
|
||||
def _init_profiler(self):
|
||||
# Torch profiler. Enabled through profiler_config:
|
||||
# --profiler-config.profiler=torch --profiler-config.torch_profiler_dir=/path/to/save/trace
|
||||
profiler_config = self.vllm_config.profiler_config
|
||||
if profiler_config.profiler == "torch" and profiler_config.torch_profiler_dir:
|
||||
if envs_ascend.MSMONITOR_USE_DAEMON:
|
||||
raise RuntimeError("MSMONITOR_USE_DAEMON and torch profiler cannot be both enabled at the same time.")
|
||||
torch_profiler_trace_dir = profiler_config.torch_profiler_dir
|
||||
logger.info("Profiling enabled. Traces will be saved to: %s", torch_profiler_trace_dir)
|
||||
def _create_profiler(self, trace_name: str):
|
||||
"""Create torch_npu profiler with trace naming for unique files per worker (RFC #6954)."""
|
||||
profiler_config = self.profiler_config
|
||||
|
||||
experimental_config = torch_npu.profiler._ExperimentalConfig(
|
||||
export_type=torch_npu.profiler.ExportType.Text,
|
||||
profiler_level=torch_npu.profiler.ProfilerLevel.Level1,
|
||||
msprof_tx=False,
|
||||
aic_metrics=torch_npu.profiler.AiCMetrics.AiCoreNone,
|
||||
l2_cache=False,
|
||||
op_attr=False,
|
||||
data_simplification=True,
|
||||
record_op_args=False,
|
||||
gc_detect_threshold=None,
|
||||
)
|
||||
if profiler_config.profiler != "torch":
|
||||
raise RuntimeError(f"Unrecognized profiler: {profiler_config.profiler}")
|
||||
if not profiler_config.torch_profiler_dir:
|
||||
raise RuntimeError("torch_profiler_dir cannot be empty.")
|
||||
if envs_ascend.MSMONITOR_USE_DAEMON:
|
||||
raise RuntimeError("MSMONITOR_USE_DAEMON and torch profiler cannot be both enabled at the same time.")
|
||||
|
||||
return torch_npu.profiler.profile(
|
||||
activities=[
|
||||
torch_npu.profiler.ProfilerActivity.CPU,
|
||||
torch_npu.profiler.ProfilerActivity.NPU,
|
||||
],
|
||||
with_stack=False,
|
||||
profile_memory=profiler_config.torch_profiler_with_memory,
|
||||
# NOTE: torch_npu.profiler.with_modules is equivalent to torch.profiler.with_stack.
|
||||
# The with_stack option in torch_npu.profiler introduces significant time overhead.
|
||||
with_modules=profiler_config.torch_profiler_with_stack,
|
||||
experimental_config=experimental_config,
|
||||
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(torch_profiler_trace_dir),
|
||||
)
|
||||
else:
|
||||
return None
|
||||
experimental_config = torch_npu.profiler._ExperimentalConfig(
|
||||
export_type=torch_npu.profiler.ExportType.Text,
|
||||
profiler_level=torch_npu.profiler.ProfilerLevel.Level1,
|
||||
msprof_tx=False,
|
||||
aic_metrics=torch_npu.profiler.AiCMetrics.AiCoreNone,
|
||||
l2_cache=False,
|
||||
op_attr=False,
|
||||
data_simplification=True,
|
||||
record_op_args=False,
|
||||
gc_detect_threshold=None,
|
||||
)
|
||||
|
||||
return torch_npu.profiler.profile(
|
||||
activities=[
|
||||
torch_npu.profiler.ProfilerActivity.CPU,
|
||||
torch_npu.profiler.ProfilerActivity.NPU,
|
||||
],
|
||||
with_stack=False,
|
||||
profile_memory=profiler_config.torch_profiler_with_memory,
|
||||
# NOTE: torch_npu.profiler.with_modules is equivalent to torch.profiler.with_stack.
|
||||
# The with_stack option in torch_npu.profiler introduces significant time overhead.
|
||||
with_modules=profiler_config.torch_profiler_with_stack,
|
||||
experimental_config=experimental_config,
|
||||
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(
|
||||
profiler_config.torch_profiler_dir,
|
||||
worker_name=trace_name,
|
||||
),
|
||||
)
|
||||
|
||||
def get_supported_pooling_tasks(self):
|
||||
return self.model_runner.get_supported_pooling_tasks()
|
||||
|
||||
Reference in New Issue
Block a user