[V1] MTP supports torchair (#2145)
### What this PR does / why we need it?
Support MTP with:
- [x] V0 Scheduler
- [x] TorchAir
- [x] Single DP
- [x] Multi DP
- [x] Disaggregate PD
Known issues:
- [ ] Not support V1 Scheduler (chunked prefill), will be supported in a
few weeks
- [ ] vllm v0.10.0 does not support metrics with `DP > 1` right now,
need to comment out the line 171-175 in file
`vllm/vllm/v1/metrics/loggers.py`
```
if (len(self.engine_indexes) > 1
and vllm_config.speculative_config is not None):
raise NotImplementedError("Prometheus metrics with Spec Decoding "
"with >1 EngineCore per AsyncLLM is not "
"supported yet.")
```
To start an online server with torchair enabled, here is an example:
```
python -m vllm.entrypoints.openai.api_server \
--model="/weights/DeepSeek-R1_w8a8/" \
--trust-remote-code \
--max-model-len 40000 \
--tensor-parallel-size 4 \
--data_parallel_size 4 \
--max-num-seqs 16 \
--no-enable-prefix-caching \
--enable_expert_parallel \
--served-model-name deepseekr1 \
--speculative-config '{"num_speculative_tokens": 1, "method":"deepseek_mtp"}' \
--quantization ascend \
--host 0.0.0.0 \
--port 1234 \
--additional-config '{"ascend_scheduler_config":{"enabled":true,"enable_chunked_prefill":false},"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]},"enable_weight_nz_layout":true}' \
--gpu_memory_utilization 0.9
```
offline example with torchair enabled
```
from vllm import LLM, SamplingParams
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=16, temperature=0)
# Create an LLM.
llm = LLM(
model="/home/data/DeepSeek-R1_w8a8/",
tensor_parallel_size=16,
max_num_seqs=16,
gpu_memory_utilization=0.9,
distributed_executor_backend="mp",
enable_expert_parallel=True,
speculative_config={
"method": "deepseek_mtp",
"num_speculative_tokens": 1,
},
trust_remote_code=True,
enforce_eager=False,
max_model_len=2000,
additional_config = {
'torchair_graph_config': {
'enabled': True,
"graph_batch_sizes": [16],
'enable_multistream_shared_expert': False,
},
"ascend_scheduler_config": {
"enabled": True
},
# 'expert_tensor_parallel_size': 16,
}
)
# Generate texts from the prompts.
# llm.start_profile()
outputs = llm.generate(prompts, sampling_params)
# llm.stop_profile()
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
- vLLM version: v0.10.0
- vLLM main:
302962e806
---------
Signed-off-by: xuyexiong <xuyexiong@huawei.com>
This commit is contained in:
@@ -206,9 +206,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
dtype=self.dtype,
|
||||
device=self.device)
|
||||
|
||||
self.graph_block_tables = np.zeros(
|
||||
(self.max_num_reqs, self.max_num_blocks_per_req), dtype=np.int32)
|
||||
|
||||
# Set up Attention
|
||||
self.attn_backend = get_attn_backend(
|
||||
0,
|
||||
@@ -231,8 +228,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.use_eagle = False
|
||||
self.drafter: Optional[Union[NgramProposer, EagleProposer,
|
||||
MtpProposer]] = None
|
||||
self.actual_seq_lengths_q = []
|
||||
self.spec_token_num = 0
|
||||
self.decode_token_per_req = 1
|
||||
if self.speculative_config:
|
||||
self.use_spec_decode = True
|
||||
self.spec_token_num = self.speculative_config.num_speculative_tokens
|
||||
assert self.spec_token_num > 0
|
||||
self.decode_token_per_req = 1 + self.spec_token_num
|
||||
self.actual_seq_lengths_q = [
|
||||
len for len in
|
||||
range(self.decode_token_per_req, self.max_num_tokens +
|
||||
1, self.decode_token_per_req)
|
||||
]
|
||||
self.spec_attn_mask = torch.triu(torch.ones(2048,
|
||||
2048,
|
||||
dtype=torch.bool),
|
||||
@@ -253,6 +261,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
f"{self.speculative_config.method}")
|
||||
self.rejection_sampler = AscendRejectionSampler()
|
||||
|
||||
# Persistent batch.
|
||||
self.input_ids = torch.zeros(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
@@ -338,9 +347,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
|
||||
if ascend_config.torchair_graph_config.graph_batch_sizes_init:
|
||||
self.init_torchair_graph_batch_sizes()
|
||||
if len(self.torchair_graph_batch_sizes) == 0:
|
||||
# TODO(zzzzwwjj): check torchair_graph_batch_sizes init code
|
||||
self.torchair_graph_batch_sizes = [self.max_num_reqs]
|
||||
|
||||
self.check_torchair_graph_batch_sizes()
|
||||
|
||||
# graph_block_tables shape: [num_request, cell(max_model_len / block_size)]
|
||||
self.graph_block_tables = np.zeros(
|
||||
(self.torchair_graph_batch_sizes[-1] // self.decode_token_per_req,
|
||||
(self.model_config.max_model_len + self.block_size - 1) //
|
||||
self.block_size),
|
||||
dtype=np.int32)
|
||||
|
||||
torch._dynamo.cache_size.config.cache_size_limit += len(
|
||||
self.torchair_graph_batch_sizes)
|
||||
@@ -558,17 +573,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
start_token_index:end_token_index] = new_token_ids
|
||||
self.input_batch.num_tokens_no_spec[
|
||||
req_index] = end_token_index
|
||||
# Add spec_token_ids to token_ids_cpu.
|
||||
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
|
||||
req_id, ())
|
||||
if spec_token_ids:
|
||||
start_index = end_token_index
|
||||
end_token_index += len(spec_token_ids)
|
||||
self.input_batch.token_ids_cpu[
|
||||
req_index,
|
||||
start_index:end_token_index] = spec_token_ids
|
||||
# NOTE(woosuk): `num_tokens` here may include spec decode tokens.
|
||||
self.input_batch.num_tokens[req_index] = end_token_index
|
||||
# Add spec_token_ids to token_ids_cpu.
|
||||
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
|
||||
req_id, ())
|
||||
if spec_token_ids:
|
||||
num_spec_tokens = len(spec_token_ids)
|
||||
start_index = self.input_batch.num_tokens_no_spec[req_index]
|
||||
end_token_index = start_index + num_spec_tokens
|
||||
self.input_batch.token_ids_cpu[
|
||||
req_index, start_index:end_token_index] = spec_token_ids
|
||||
# NOTE(woosuk): `num_tokens` here may include spec tokens.
|
||||
self.input_batch.num_tokens[req_index] += num_spec_tokens
|
||||
|
||||
# Check if the batch has changed. If not, we can skip copying the
|
||||
# sampling metadata from CPU to GPU.
|
||||
@@ -586,6 +601,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Append to the end.
|
||||
req_index = None
|
||||
self.input_batch.add_request(req_state, req_index)
|
||||
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
|
||||
req_id, ())
|
||||
if spec_token_ids:
|
||||
req_index = self.input_batch.num_reqs - 1
|
||||
start_index = len(req_state.prompt_token_ids) + len(
|
||||
req_state.output_token_ids)
|
||||
end_token_index = start_index + len(spec_token_ids)
|
||||
self.input_batch.token_ids_cpu[
|
||||
req_index, start_index:end_token_index] = spec_token_ids
|
||||
self.input_batch.num_tokens[req_index] = end_token_index
|
||||
|
||||
# Condense the batched states if there are empty indices.
|
||||
if removed_req_indices:
|
||||
@@ -615,6 +640,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
|
||||
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
|
||||
if self.dp_size == 1:
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
|
||||
num_tokens)
|
||||
return maybe_padded_num_tokens, None, with_prefill, enable_dbo
|
||||
return num_tokens, None, with_prefill, enable_dbo
|
||||
|
||||
if self.is_kv_producer and not envs_ascend.VLLM_ASCEND_ENABLE_CHUNK_MC2:
|
||||
@@ -1108,6 +1137,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
|
||||
elif np.all(num_scheduled_tokens == 1):
|
||||
attn_state = AscendAttentionState.DecodeOnly
|
||||
if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
|
||||
# SpecDecoding now supports seq_len=1 and seq_len=2
|
||||
# In Prefilling Decoding Disaggregation scenario, SpecDecoding need to supports seq_len=1
|
||||
attn_state = AscendAttentionState.SpecDecoding
|
||||
# Speculative decoding.
|
||||
elif np.all(num_valid_tokens == 1):
|
||||
if self.use_eagle:
|
||||
@@ -1154,10 +1187,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
|
||||
total_num_scheduled_tokens, with_prefill, enable_dbo)
|
||||
extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo
|
||||
self.with_prefill = with_prefill
|
||||
self.num_tokens_across_dp = num_tokens_across_dp
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
|
||||
|
||||
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
|
||||
self.graph_pad_size = padded_num_tokens_across_dp
|
||||
extra_builder_kwargs[
|
||||
'graph_pad_size'] = self.graph_pad_size # type: ignore
|
||||
else:
|
||||
self.graph_pad_size = -1
|
||||
|
||||
if self.vllm_config.model_config.use_mla:
|
||||
extra_builder_kwargs[
|
||||
@@ -1837,10 +1874,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# has num_tokens in total.
|
||||
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
|
||||
max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
num_reqs = min(num_tokens, max_num_reqs)
|
||||
if with_prefill:
|
||||
num_reqs = num_tokens
|
||||
else:
|
||||
num_reqs = (num_tokens + self.decode_token_per_req -
|
||||
1) // self.decode_token_per_req
|
||||
num_reqs = min(num_reqs, max_num_reqs)
|
||||
min_tokens_per_req = num_tokens // num_reqs
|
||||
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
|
||||
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
|
||||
assert sum(num_scheduled_tokens_list) == num_tokens
|
||||
assert len(num_scheduled_tokens_list) == num_reqs
|
||||
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
|
||||
dtype=np.int32)
|
||||
|
||||
@@ -1852,7 +1896,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# we can't skip_attn, it will cause graph recompile.
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
|
||||
num_reqs=num_tokens, num_actual_tokens=1)
|
||||
num_reqs=num_reqs, num_actual_tokens=1)
|
||||
elif skip_attn:
|
||||
attn_metadata = None
|
||||
else:
|
||||
@@ -1913,6 +1957,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
torch._dynamo.mark_static(attn_metadata.decode.sin)
|
||||
torch._dynamo.mark_static(attn_metadata.decode.cos)
|
||||
torch._dynamo.mark_static(attn_metadata.slot_mapping)
|
||||
if self.speculative_config:
|
||||
torch._dynamo.mark_static(
|
||||
attn_metadata.decode.attn_mask)
|
||||
for kv in self.kv_caches:
|
||||
assert isinstance(
|
||||
kv, tuple), "kv_cache must be a tuple"
|
||||
@@ -1949,6 +1996,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.use_spec_decode and isinstance(
|
||||
self.drafter, EagleProposer):
|
||||
self.drafter.dummy_run(num_tokens)
|
||||
if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
|
||||
assert isinstance(self.drafter, MtpProposer)
|
||||
self.drafter.dummy_run(
|
||||
num_tokens=num_tokens,
|
||||
with_prefill=with_prefill,
|
||||
skip_attn=skip_attn,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens_across_dp=num_tokens_across_dp)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@contextmanager
|
||||
@@ -2071,9 +2127,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
m.consumed_memory / float(2**30))
|
||||
|
||||
def _get_torchair_lazy_compiled_model(self, batch_size: int):
|
||||
if batch_size < 0 or batch_size > self.max_num_reqs:
|
||||
if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]:
|
||||
raise ValueError(
|
||||
f"Bad graph batch size:{batch_size}! max_num_reqs:{self.max_num_reqs}"
|
||||
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}"
|
||||
)
|
||||
|
||||
compiled_model = self.torchair_compiled_models.get(
|
||||
@@ -2537,7 +2593,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
next_token_ids = torch.tensor(next_token_ids,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
|
||||
accepted_token_indices = None
|
||||
if spec_decode_metadata is None:
|
||||
# input_ids can be None for multimodal models.
|
||||
target_token_ids = self.input_ids[:num_scheduled_tokens]
|
||||
@@ -2557,14 +2613,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
|
||||
cu_num_tokens, accepted_token_indices, target_token_ids, \
|
||||
target_positions, target_hidden_states, target_slot_mapping = self.drafter.prepare_inputs(
|
||||
attn_metadata.query_start_loc,
|
||||
num_rejected_tokens,
|
||||
self.input_ids[:num_scheduled_tokens],
|
||||
positions[:num_scheduled_tokens],
|
||||
hidden_states[:num_scheduled_tokens],
|
||||
attn_metadata.slot_mapping[:num_scheduled_tokens],
|
||||
is_torchair_graph=self.torchair_graph_enabled,
|
||||
)
|
||||
target_token_ids = self.input_ids[token_indices]
|
||||
target_positions = positions[token_indices]
|
||||
target_hidden_states = hidden_states[token_indices]
|
||||
target_slot_mapping = attn_metadata.slot_mapping[token_indices]
|
||||
|
||||
draft_token_ids = self.drafter.propose(
|
||||
target_token_ids=target_token_ids,
|
||||
@@ -2575,7 +2633,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
cu_num_tokens=cu_num_tokens,
|
||||
block_table=attn_metadata.block_tables,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
token_indices=accepted_token_indices)
|
||||
spec_token_ids = draft_token_ids.tolist()
|
||||
return spec_token_ids
|
||||
|
||||
@@ -2686,11 +2744,56 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
start_graph_batch_size *= 2
|
||||
|
||||
def select_torchair_padded_batch_size(self, batch_size: int):
|
||||
selected_batch_size = self.max_num_reqs
|
||||
for padded_batch_size in self.torchair_graph_batch_sizes:
|
||||
if batch_size <= padded_batch_size < selected_batch_size:
|
||||
selected_batch_size = padded_batch_size
|
||||
return selected_batch_size
|
||||
if batch_size <= padded_batch_size:
|
||||
# we treat batch_size as num of requests
|
||||
return padded_batch_size
|
||||
raise ValueError(
|
||||
f"cur batch_size is invalid, torchair_graph_batch_sizes is "
|
||||
f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
|
||||
)
|
||||
|
||||
def check_torchair_graph_batch_sizes(self):
|
||||
# return graph_batch_sizes according to the max number of tokens
|
||||
# first pad according to the number of requests
|
||||
if len(self.torchair_graph_batch_sizes) == 0:
|
||||
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
|
||||
else:
|
||||
self.torchair_graph_batch_sizes = sorted(
|
||||
self.torchair_graph_batch_sizes)
|
||||
while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs:
|
||||
self.torchair_graph_batch_sizes.pop()
|
||||
if len(self.torchair_graph_batch_sizes) == 0:
|
||||
logger.warning(
|
||||
"torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]"
|
||||
)
|
||||
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
|
||||
if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
|
||||
self.torchair_graph_batch_sizes.append(self.max_num_reqs)
|
||||
|
||||
# padded max number tokens = max_num_req * decode_token_per_req
|
||||
self.torchair_graph_batch_sizes = [
|
||||
graph_batch_size * self.decode_token_per_req
|
||||
for graph_batch_size in self.torchair_graph_batch_sizes
|
||||
]
|
||||
|
||||
# NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size`
|
||||
tp_size = self.parallel_config.tensor_parallel_size
|
||||
if self.parallel_config.enable_expert_parallel:
|
||||
new_graph_batch_sizes = []
|
||||
for graph_batch_size in self.torchair_graph_batch_sizes:
|
||||
cur_graph_batch_size = (graph_batch_size + tp_size -
|
||||
1) // tp_size * tp_size
|
||||
if cur_graph_batch_size not in new_graph_batch_sizes and \
|
||||
cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
|
||||
new_graph_batch_sizes.append(cur_graph_batch_size)
|
||||
elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
|
||||
and self.decode_token_per_req > 1:
|
||||
logger.warning(
|
||||
f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens",
|
||||
f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size."
|
||||
)
|
||||
self.torchair_graph_batch_sizes = new_graph_batch_sizes
|
||||
|
||||
def get_supported_pooling_tasks(self):
|
||||
model = self.get_model()
|
||||
|
||||
Reference in New Issue
Block a user