[New model] Qwen3-next support (#2917)

### What this PR does / why we need it?
Add Qwen3-next support.

### Does this PR introduce _any_ user-facing change?
Yes, users can now use Qwen3-next.
Related doc PR: https://github.com/vllm-project/vllm-ascend/pull/2916; the
tutorial will be published
[here](https://vllm-ascend.readthedocs.io/en/latest/tutorials/multi_npu_qwen3_next.html).

### How was this patch tested?
Doc CI passed

Related: https://github.com/vllm-project/vllm-ascend/issues/2884

Co-Authored-By: Angazenn <supperccell@163.com>
Co-Authored-By: zzzzwwjj <1183291235@qq.com>
Co-Authored-By: MengqingCao <cmq0113@163.com>
Co-Authored-By: linfeng-yuan <1102311262@qq.com>
Co-Authored-By: hust17yixuan <303660421@qq.com>
Co-Authored-By: SunnyLee219 <3294305115@qq.com>
Co-Authored-By: maoxx241 <maoxx241@umn.edu>


- vLLM version: v0.10.2
- vLLM main:
b834b4cbf1

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: Angazenn <supperccell@163.com>
Signed-off-by: Your Name <you@example.com>
Signed-off-by: zzzzwwjj <1183291235@qq.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: hust17yixuan <303660421@qq.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: Angazenn <supperccell@163.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: zzzzwwjj <1183291235@qq.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: hust17yixuan <303660421@qq.com>
This commit is contained in:
wangxiyuan
2025-09-16 01:17:42 +08:00
committed by GitHub
parent b5ccef6115
commit c556038ef0
26 changed files with 3959 additions and 258 deletions

View File

@@ -350,7 +350,8 @@ class EagleProposer(Proposer):
spec_attn_mask=self.runner.spec_attn_mask,
attn_state=self.runner.attn_state,
decode_token_per_req=self.runner.decode_token_per_req,
)
num_computed_tokens_cpu=None,
seq_lens=None)
attn_metadata_i = self.runner.attn_metadata_builder.build(
common_attn_metadata, self.runner.get_model())
for layer_name in kv_cache_group_spec.layer_names:
@@ -436,7 +437,8 @@ class EagleProposer(Proposer):
spec_attn_mask=self.runner.spec_attn_mask,
attn_state=self.runner.attn_state,
decode_token_per_req=self.runner.decode_token_per_req,
)
num_computed_tokens_cpu=None,
seq_lens=None)
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
attn_metadata = self.runner.attn_metadata_builder.build(
common_attn_metadata, self.runner.model)

View File

@@ -91,7 +91,7 @@ class MtpProposer(Proposer):
target_attn_layer_names)
assert len(draft_attn_layer_names) == 1
self.attn_layer_name = next(iter(draft_attn_layer_names))
self.attn_layer_name = list(draft_attn_layer_names)
self.model.load_weights(
loader.get_all_weights(
@@ -186,6 +186,8 @@ class MtpProposer(Proposer):
hidden_states: torch.Tensor = None,
attn_metadata=None,
aux_hidden_states: torch.Tensor = None):
if attn_metadata is not None and isinstance(attn_metadata, dict):
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):
if token_ids:
@@ -379,9 +381,21 @@ class MtpProposer(Proposer):
attn_state=self.runner.attn_state,
graph_pad_size=self.runner.graph_pad_size,
decode_token_per_req=self.runner.decode_token_per_req,
)
attn_metadata = self.runner.attn_metadata_builder.build(
common_attn_metadata, self.runner.get_model())
num_computed_tokens_cpu=None,
seq_lens=None)
if not self.torchair_graph_enabled:
builder = self.runner.attn_groups[0][0].metadata_builder
attn_metadata_mtp = builder.build(0, common_attn_metadata,
self.runner.get_model())
attn_metadata = {}
for layer_name in self.attn_layer_name:
attn_metadata[layer_name] = attn_metadata_mtp
else:
attn_metadata = self.runner.attn_metadata_builder.build(
0, common_attn_metadata, self.runner.get_model())
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
@@ -392,7 +406,6 @@ class MtpProposer(Proposer):
(num_input_tokens, num_tokens_across_dp, with_prefill,
_) = self.runner._sync_metadata_across_dp(
num_tokens, self.runner.with_prefill, False)
attn_metadata.slot_mapping = target_slot_mapping
else:
# torchair mode can reuse self.runner.num_tokens_across_dp
num_tokens_across_dp = self.runner.num_tokens_across_dp
@@ -466,18 +479,23 @@ class MtpProposer(Proposer):
if step == self.num_speculative_tokens - 1 or with_prefill:
break
if not self.torchair_graph_enabled:
attn_metadata_i = attn_metadata[self.attn_layer_name[0]]
else:
attn_metadata_i = attn_metadata
if step == 0:
positions = target_positions[last_token_indices]
hidden_states = hidden_states[last_token_indices]
slot_mapping = attn_metadata.slot_mapping[last_token_indices]
attn_metadata.slot_mapping.fill_(-1)
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
slot_mapping = attn_metadata_i.slot_mapping[last_token_indices]
attn_metadata_i.slot_mapping.fill_(-1)
attn_metadata_i.query_start_loc = self.arange[:batch_size + 1]
last_token_indices = self.arange[:batch_size]
if attn_metadata.num_decode_tokens != 0:
attn_metadata.num_decode_tokens = batch_size
if attn_metadata_i.num_decode_tokens != 0:
attn_metadata_i.num_decode_tokens = batch_size
if is_running_torchair:
attn_metadata.num_actual_tokens = batch_size
attn_metadata.query_lens = [1] * batch_size
attn_metadata_i.num_actual_tokens = batch_size
attn_metadata_i.query_lens = [1] * batch_size
input_ids = draft_token_ids_list[-1].int()
positions += 1
@@ -494,12 +512,12 @@ class MtpProposer(Proposer):
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)
# Increment the sequence lengths.
attn_metadata.seq_lens[:batch_size] += 1
attn_metadata_i.seq_lens[:batch_size] += 1
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
exceeds_max_model_len_cpu = exceeds_max_model_len.to(
attn_metadata.seq_lens.device, non_blocking=True)
attn_metadata.seq_lens[:batch_size].masked_fill_(
attn_metadata_i.seq_lens.device, non_blocking=True)
attn_metadata_i.seq_lens[:batch_size].masked_fill_(
exceeds_max_model_len_cpu, 1)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
@@ -511,24 +529,24 @@ class MtpProposer(Proposer):
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states
attn_metadata.slot_mapping[:batch_size] = slot_mapping
attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
if attn_metadata.prefill is not None:
attn_metadata.prefill.seq_lens = attn_metadata.seq_lens
attn_metadata.prefill.context_lens = attn_metadata.seq_lens
attn_metadata.prefill.input_positions = self.positions[:
num_input_tokens]
attn_metadata.prefill.max_seq_lens += 1
attn_metadata.prefill.max_seq_lens = min(
attn_metadata.prefill.max_seq_lens,
if attn_metadata_i.prefill is not None:
attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens
attn_metadata_i.prefill.input_positions = self.positions[:
num_input_tokens]
attn_metadata_i.prefill.max_seq_lens += 1
attn_metadata_i.prefill.max_seq_lens = min(
attn_metadata_i.prefill.max_seq_lens,
self.runner.model_config.max_model_len)
if attn_metadata.decode is not None:
attn_metadata.decode.seq_lens = attn_metadata.seq_lens
attn_metadata.decode.input_positions = self.positions[:
num_input_tokens]
attn_metadata.decode.max_seq_lens += 1
attn_metadata.decode.max_seq_lens = min(
attn_metadata.decode.max_seq_lens,
if attn_metadata_i.decode is not None:
attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
attn_metadata_i.decode.input_positions = self.positions[:
num_input_tokens]
attn_metadata_i.decode.max_seq_lens += 1
attn_metadata_i.decode.max_seq_lens = min(
attn_metadata_i.decode.max_seq_lens,
self.runner.model_config.max_model_len)
# mtp>1: [batch_size, k]