[Performance] Overlap async exponential with model execution (#4501)
### What this PR does / why we need it?
Add a control that enables the exponential distribution operator to
overlap with model execution (default is OFF, because this feature
may not perform well on MoE models, e.g. Qwen3-30B).
Enabling async exponential overlap provides a performance
improvement.
Overlapping the exponential operator with model execution also hides
the performance drop introduced by the AI-CPU version of the
exponential operator.
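For intuition, here is a minimal sketch of the overlap pattern. This is not the PR's exact code: `side_stream`, `done`, and `launch_async_exponential` are illustrative names, and it assumes an Ascend build of PyTorch with `torch_npu` installed.

```python
import torch
import torch_npu  # registers the torch.npu backend

side_stream = torch.npu.Stream()
done = torch.npu.Event()

def launch_async_exponential(batch_size: int, vocab_size: int) -> torch.Tensor:
    # Order the side stream after work already queued on the main stream,
    # then generate the exponential noise there so it runs concurrently
    # with the model forward pass issued next on the main stream.
    side_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(side_stream):
        noise = torch.empty((batch_size, vocab_size),
                            device="npu", dtype=torch.float32).exponential_()
        done.record()  # marks the point at which the noise is ready
    return noise

# Right before sampling on the main stream:
#   done.synchronize()  # block until the noise has actually been produced
```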
**UPDATE**: (12/12)
The overlap now uses the same stream introduced in PR #4908.
We moved `do_async_exponential` from `model_runner_v1.py` to
`sampler.py`.
Async exponential is now enabled through `additional_config`:
add `"enable_async_exponential": 1` to `additional_config`.
We now **ONLY** support the default exponential and the AI-CPU
exponential; the old `"enable_async_exponential": 2` option has been
dropped for consistency.
### Does this PR introduce _any_ user-facing change?
**YES**, this PR adds a new `additional_config` option:
`"enable_async_exponential": 1`.
When `enable_async_exponential` is set to 1, async exponential is
enabled and overlapped with the model runner.
When `enable_async_exponential` is set to 0 (the default), async
exponential is disabled, but the exponential still runs on a
separate stream, using the stream introduced in #4908.
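A usage sketch, assuming vLLM's offline `LLM` entrypoint forwards `additional_config` to the engine in the same way vllm-ascend's other `additional_config` options are passed:

```python
from vllm import LLM, SamplingParams

# additional_config is forwarded to vllm-ascend's AscendConfig.
llm = LLM(model="Qwen/Qwen3-0.6B",
          additional_config={"enable_async_exponential": 1})
params = SamplingParams(max_tokens=5, temperature=1.0, top_k=50, top_p=0.9)
outputs = llm.generate(["Hello, my name is"], params)
```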
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: YuhanBai <yuhan.bai0830@gmail.com>
@@ -42,6 +42,7 @@ The following table lists additional configuration options available in vLLM Asc
| `expert_map_record_path` | str | `None` | Save the expert load calculation results to a new expert table in the specified directory. |
| `init_redundancy_expert` | int | `0` | Specify redundant experts during initialization. |
| `dump_config` | str | `None` | Configuration file path for msprobe dump (eager mode). |
| `enable_async_exponential` | int | `0` | Whether to enable async exponential overlap. To enable it, set this config to `1`. |

The details of each configuration option are as follows:
@@ -47,3 +47,21 @@ def test_models_prompt_logprobs() -> None:
        runner.generate_greedy_logprobs(example_prompts,
                                        max_tokens=5,
                                        num_logprobs=1)


def test_exponential_overlap() -> None:
    example_prompts = [
        "Hello, my name is",
    ]
    sampling_params = SamplingParams(max_tokens=5,
                                     temperature=1.0,
                                     top_k=50,
                                     top_p=0.9)

    with VllmRunner("Qwen/Qwen3-0.6B",
                    max_model_len=8192,
                    gpu_memory_utilization=0.7,
                    additional_config={
                        "enable_async_exponential": 1,
                    }) as runner:
        runner.generate(example_prompts, sampling_params)

@@ -161,6 +161,11 @@ class AscendConfig:
                   False):
            kv_cfg.engine_id = f"{kv_cfg.engine_id}-{uuid4().hex}"
            kv_cfg._engine_id_patched = True
        self.enable_async_exponential = additional_config.get(
            "enable_async_exponential", 0)
        if self.enable_async_exponential not in (0, 1):
            raise AssertionError(
                "Enable async exponential can only be set to 0 or 1.")


class FinegrainedTPConfig:
@@ -3,6 +3,7 @@ import torch_npu
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
from vllm.v1.sample.sampler import Sampler

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
                               global_stream, npu_stream_switch)

@@ -41,10 +42,35 @@ class AscendSampler(Sampler):
        # TODO: support logprobs_mode in vllm-ascend
        super().__init__(logprobs_mode=logprobs_mode)
        self.topk_topp_sampler = AscendTopKTopPSampler()
        self.async_exponential_event = torch.npu.Event()

    def set_q_event(self, q, event):
        self.topk_topp_sampler.set_q_event(q, event)

    def do_async_exponential(self, b_s, head_dim, generators):
        # Compute the exponential randoms on a separate stream,
        # overlapping them with model execution.
        with torch.npu.stream(global_stream()):
            global_stream().wait_stream(torch.npu.current_stream())
            q = torch.empty((b_s, head_dim), device="npu", dtype=torch.float32)
            # Works with both the AI-CPU exponential and the default exponential.
            if len(generators) != q.shape[0]:
                q.exponential_()
            if generators:
                for i, generator in generators.items():
                    q[i].exponential_(generator=generator)
            self.async_exponential_event.record()
        self.set_q_event(q, self.async_exponential_event)


class AscendTopKTopPSampler(TopKTopPSampler):

    def set_q_event(self, q, event):
        # Store the async exponential results, together with the event
        # that guards them against premature use.
        self.q = q
        self.async_event = event

    def _apply_top_k_top_p(
        self,
        logits: torch.Tensor,
@@ -99,4 +125,8 @@ class AscendTopKTopPSampler(TopKTopPSampler):
            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)

        probs = logits.softmax(dim=-1, dtype=torch.float32)
        if get_ascend_config().enable_async_exponential == 1:
            # Wait until the side stream has produced the noise.
            self.async_event.synchronize()
            return probs.div_(self.q).argmax(dim=-1).view(-1), logits_to_return
        return random_sample(probs, generators), logits_to_return
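As a side note on why `probs.div_(self.q).argmax(dim=-1)` is a valid sample: for i.i.d. exponential noise `E_i ~ Exp(1)`, `argmax_i(p_i / E_i) = argmin_i(E_i / p_i)` picks index `i` with probability `p_i` (the exponential-race trick, equivalent to Gumbel-max sampling). A quick CPU sanity check:

```python
import torch

probs = torch.tensor([0.1, 0.3, 0.6])
draws = torch.empty(200_000, 3).exponential_()  # E ~ Exp(1)
samples = (probs / draws).argmax(dim=-1)        # exponential race
print(torch.bincount(samples, minlength=3) / samples.numel())
# -> approximately tensor([0.1, 0.3, 0.6])
```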
@@ -1385,6 +1385,12 @@ class NPUModelRunner(GPUModelRunner):
        aclgraph_runtime_mode, batch_descriptor = \
            self.cudagraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)

        if self.ascend_config.enable_async_exponential != 0:
            self.sampler.do_async_exponential(
                b_s=logits_indices.shape[0],
                head_dim=self.model_config.get_vocab_size(),
                generators=self.input_batch.sampling_metadata.generators)

        # Run forward pass
        with ProfileExecuteDuration().capture_async("forward"):
            with set_ascend_forward_context(