[feature] vllm-ascend support msprobe (eager mode dump) (#4241)

### What this PR does / why we need it?
vllm-ascend needs to dump data during model execution to debug precision
problems. msprobe provides the corresponding abilities, so msprobe is
integrated into vllm-ascend to make debugging easier.

### Does this PR introduce _any_ user-facing change?
Yes. A new `dump_config` option can be passed via the additional config to
enable msprobe dumping (eager mode only), pointing at an msprobe config file:
```
'dump_config': '/path/to/config.json'
```



- vLLM version: v0.11.0
- vLLM main:
2918c1b49c

---------

Signed-off-by: Tjh-UKN <2559659915@qq.com>
This commit is contained in:
Tjh-UKN
2025-11-24 21:58:31 +08:00
committed by GitHub
parent 5b1a7514eb
commit 00ea61ec88
17 changed files with 1385 additions and 159 deletions

View File

@@ -311,6 +311,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.intermediate_tensors: Optional[IntermediateTensors] = None
self.runner_only_attn_layers: set[str] = set()
# Ascend-specific configurations
self.ascend_config = get_ascend_config()
if self.ascend_config.ascend_scheduler_config.enabled:
self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled
@@ -318,6 +319,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.chunked_prefill_enabled = True
self.weight_prefetch_method = WeightPrefetchMethod(
self.ascend_config.weight_prefetch_config)
# Dump / PrecisionDebugger configuration now comes from AscendConfig
dump_cfg = self.ascend_config.dump_config
self.dump_enable = dump_cfg.enable_dump
self.debugger = None
if self.dump_enable:
if self.model_config.enforce_eager:
from msprobe.pytorch import PrecisionDebugger
self.debugger = PrecisionDebugger(dump_cfg.config_path)
else:
raise RuntimeError(
"Dumping/debugging only works in eager mode.")
if self.cache_config.cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
@@ -2284,6 +2296,18 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.eplb_updator.take_update_info_from_eplb_process()
moe_comm_type = self._select_moe_comm_method(num_input_tokens)
# prevent debugger is None
need_dump = self.dump_enable and self.debugger is not None
if need_dump:
assert self.debugger is not None
dbg_cfg = getattr(self.debugger, "config", None)
dump_level = str(
getattr(dbg_cfg, "level",
"L1")).upper() if dbg_cfg is not None else "L1"
if dump_level in ("L0", "MIX"):
self.debugger.start(model=self.model)
else:
self.debugger.start()
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
scheduler_output.total_num_scheduled_tokens
@@ -2341,6 +2365,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# For mid-pipeline stages, return the hidden states.
if not broadcast_pp_output:
hidden_states.kv_connector_output = kv_connector_output
if need_dump:
assert self.debugger is not None
self.debugger.stop()
self.debugger.step()
return hidden_states
assert isinstance(hidden_states, IntermediateTensors)
get_pp_group().send_tensor_dict(
@@ -2348,11 +2376,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logits = None
else:
if self.input_batch.pooling_params:
return self._pool(
pool_output = self._pool(
hidden_states,
scheduler_output.total_num_scheduled_tokens,
num_scheduled_tokens_np, finished_sending,
finished_recving, kv_connector_output)
if need_dump:
assert self.debugger is not None
self.debugger.stop()
self.debugger.step()
return pool_output
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states)
if broadcast_pp_output:
@@ -2558,8 +2591,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.dynamic_eplb:
self.eplb_updator.forward_end()
if not self.use_async_scheduling:
if need_dump:
assert self.debugger is not None
self.debugger.stop()
self.debugger.step()
return model_runner_output
if need_dump:
assert self.debugger is not None
self.debugger.stop()
self.debugger.step()
return AsyncNPUModelRunnerOutput(
model_runner_output=model_runner_output,
sampled_token_ids=sampled_token_ids,