[New model] Qwen3-next support (#2917)
### What this PR does / why we need it?
Add Qwen3-next support.
### Does this PR introduce _any_ user-facing change?
Yes, users can now run Qwen3-Next models.
Related doc: https://github.com/vllm-project/vllm-ascend/pull/2916; the
tutorial will be available
[here](https://vllm-ascend.readthedocs.io/en/latest/tutorials/multi_npu_qwen3_next.html).
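
As a quick illustration of the user-facing change, the snippet below runs a Qwen3-Next checkpoint through vLLM's offline Python API. This is a minimal sketch, not part of this PR: the model name and `tensor_parallel_size` are illustrative assumptions; see the tutorial above for the supported setup.

```python
# Minimal sketch (assumed model name and parallelism settings).
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-Next-80B-A3B-Instruct", tensor_parallel_size=4)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=32))
print(outputs[0].outputs[0].text)
```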
### How was this patch tested?
Doc CI passed
Related: https://github.com/vllm-project/vllm-ascend/issues/2884
Co-Authored-By: Angazenn <supperccell@163.com>
Co-Authored-By: zzzzwwjj <1183291235@qq.com>
Co-Authored-By: MengqingCao <cmq0113@163.com>
Co-Authored-By: linfeng-yuan <1102311262@qq.com>
Co-Authored-By: hust17yixuan <303660421@qq.com>
Co-Authored-By: SunnyLee219 <3294305115@qq.com>
Co-Authored-By: maoxx241 <maoxx241@umn.edu>
- vLLM version: v0.10.2
- vLLM main: b834b4cbf1
---------
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: Angazenn <supperccell@163.com>
Signed-off-by: Your Name <you@example.com>
Signed-off-by: zzzzwwjj <1183291235@qq.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: hust17yixuan <303660421@qq.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: Angazenn <supperccell@163.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: zzzzwwjj <1183291235@qq.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: hust17yixuan <303660421@qq.com>
@@ -350,7 +350,8 @@ class EagleProposer(Proposer):
             spec_attn_mask=self.runner.spec_attn_mask,
             attn_state=self.runner.attn_state,
             decode_token_per_req=self.runner.decode_token_per_req,
-        )
+            num_computed_tokens_cpu=None,
+            seq_lens=None)
         attn_metadata_i = self.runner.attn_metadata_builder.build(
             common_attn_metadata, self.runner.get_model())
         for layer_name in kv_cache_group_spec.layer_names:
@@ -436,7 +437,8 @@ class EagleProposer(Proposer):
             spec_attn_mask=self.runner.spec_attn_mask,
             attn_state=self.runner.attn_state,
             decode_token_per_req=self.runner.decode_token_per_req,
-        )
+            num_computed_tokens_cpu=None,
+            seq_lens=None)
         # FIXME(woosuk): The below two ops cause synchronization. Optimize.
         attn_metadata = self.runner.attn_metadata_builder.build(
             common_attn_metadata, self.runner.model)
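
Both `EagleProposer` call sites now pass `num_computed_tokens_cpu` and `seq_lens` explicitly when constructing the common attention metadata. A minimal sketch of the shape this implies, assuming a dataclass-style container (only the fields visible in this diff are shown; the real vllm-ascend class has more):

```python
from dataclasses import dataclass
from typing import Any, Optional

import torch


@dataclass
class CommonAttentionMetadata:
    # Only the fields visible in this diff are listed here.
    spec_attn_mask: Optional[torch.Tensor]
    attn_state: Any
    decode_token_per_req: Any
    num_computed_tokens_cpu: Optional[torch.Tensor] = None  # new in this PR
    seq_lens: Optional[torch.Tensor] = None                 # new in this PR
```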
@@ -91,7 +91,7 @@ class MtpProposer(Proposer):
             target_attn_layer_names)

-        assert len(draft_attn_layer_names) == 1
-        self.attn_layer_name = next(iter(draft_attn_layer_names))
+        self.attn_layer_name = list(draft_attn_layer_names)

         self.model.load_weights(
             loader.get_all_weights(
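
The MTP proposer previously asserted that the draft model contains exactly one attention layer; keeping `attn_layer_name` as a list lets a draft stack with several attention layers (as a hybrid architecture like Qwen3-Next may have) register all of them. A hedged sketch of the difference, with made-up layer names:

```python
# Illustrative only: these layer names are hypothetical.
draft_attn_layer_names = {
    "model.layers.0.self_attn.attn",
    "model.layers.1.self_attn.attn",
}

# Before: single-layer assumption.
# assert len(draft_attn_layer_names) == 1
# attn_layer_name = next(iter(draft_attn_layer_names))

# After: keep every draft attention layer name.
attn_layer_name = list(draft_attn_layer_names)
```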
@@ -186,6 +186,8 @@ class MtpProposer(Proposer):
                 hidden_states: torch.Tensor = None,
                 attn_metadata=None,
                 aux_hidden_states: torch.Tensor = None):
+        if attn_metadata is not None and isinstance(attn_metadata, dict):
+            attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
         next_token_ids: list[int] = []
         for i, token_ids in enumerate(valid_sampled_token_ids):
             if token_ids:
@@ -379,9 +381,21 @@ class MtpProposer(Proposer):
             attn_state=self.runner.attn_state,
             graph_pad_size=self.runner.graph_pad_size,
             decode_token_per_req=self.runner.decode_token_per_req,
-        )
-        attn_metadata = self.runner.attn_metadata_builder.build(
-            common_attn_metadata, self.runner.get_model())
+            num_computed_tokens_cpu=None,
+            seq_lens=None)
+
+        if not self.torchair_graph_enabled:
+            builder = self.runner.attn_groups[0][0].metadata_builder
+            attn_metadata_mtp = builder.build(0, common_attn_metadata,
+                                              self.runner.get_model())
+
+            attn_metadata = {}
+            for layer_name in self.attn_layer_name:
+                attn_metadata[layer_name] = attn_metadata_mtp
+
+        else:
+            attn_metadata = self.runner.attn_metadata_builder.build(
+                0, common_attn_metadata, self.runner.get_model())

         self.positions[:num_tokens] = target_positions
         self.hidden_states[:num_tokens] = target_hidden_states
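
In eager (non-torchair) mode the proposer now builds one metadata object and fans it out to every registered draft attention layer, matching how the model forward looks metadata up per layer name. A compact sketch of that fanout, with hypothetical values:

```python
# Hypothetical layer names and builder result, mirroring the fanout above.
attn_layer_name = ["model.layers.0.self_attn.attn"]
attn_metadata_mtp = object()  # stands in for builder.build(...) output

attn_metadata = {name: attn_metadata_mtp for name in attn_layer_name}
```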
@@ -392,7 +406,6 @@ class MtpProposer(Proposer):
             (num_input_tokens, num_tokens_across_dp, with_prefill,
              _) = self.runner._sync_metadata_across_dp(
                  num_tokens, self.runner.with_prefill, False)
-            attn_metadata.slot_mapping = target_slot_mapping
         else:
             # torchair mode can reuse self.runner.num_tokens_across_dp
             num_tokens_across_dp = self.runner.num_tokens_across_dp
@@ -466,18 +479,23 @@ class MtpProposer(Proposer):
             if step == self.num_speculative_tokens - 1 or with_prefill:
                 break

+            if not self.torchair_graph_enabled:
+                attn_metadata_i = attn_metadata[self.attn_layer_name[0]]
+            else:
+                attn_metadata_i = attn_metadata
+
             if step == 0:
                 positions = target_positions[last_token_indices]
                 hidden_states = hidden_states[last_token_indices]
-                slot_mapping = attn_metadata.slot_mapping[last_token_indices]
-                attn_metadata.slot_mapping.fill_(-1)
-                attn_metadata.query_start_loc = self.arange[:batch_size + 1]
+                slot_mapping = attn_metadata_i.slot_mapping[last_token_indices]
+                attn_metadata_i.slot_mapping.fill_(-1)
+                attn_metadata_i.query_start_loc = self.arange[:batch_size + 1]
                 last_token_indices = self.arange[:batch_size]
-                if attn_metadata.num_decode_tokens != 0:
-                    attn_metadata.num_decode_tokens = batch_size
+                if attn_metadata_i.num_decode_tokens != 0:
+                    attn_metadata_i.num_decode_tokens = batch_size
                 if is_running_torchair:
-                    attn_metadata.num_actual_tokens = batch_size
-                    attn_metadata.query_lens = [1] * batch_size
+                    attn_metadata_i.num_actual_tokens = batch_size
+                    attn_metadata_i.query_lens = [1] * batch_size

             input_ids = draft_token_ids_list[-1].int()
             positions += 1
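
Each drafting step first resolves the concrete metadata object (`attn_metadata_i`) and then mutates only that object, so the dict and flat layouts share one code path. A sketch of the resolution rule, assuming the two layouts shown above:

```python
from typing import Any, Dict, List, Union


def resolve_step_metadata(attn_metadata: Union[Dict[str, Any], Any],
                          attn_layer_name: List[str],
                          torchair_graph_enabled: bool) -> Any:
    """Pick the per-step metadata object the drafting loop mutates."""
    if not torchair_graph_enabled:
        # Eager mode: a dict keyed by layer name; entries share one object.
        return attn_metadata[attn_layer_name[0]]
    # Torchair mode: a single flat metadata object.
    return attn_metadata
```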
@@ -494,12 +512,12 @@ class MtpProposer(Proposer):
             clamped_positions = torch.where(exceeds_max_model_len, 0,
                                             positions)
             # Increment the sequence lengths.
-            attn_metadata.seq_lens[:batch_size] += 1
+            attn_metadata_i.seq_lens[:batch_size] += 1
             # For the requests that exceed the max model length, we set the
             # sequence length to 1 to minimize their overheads in attention.
             exceeds_max_model_len_cpu = exceeds_max_model_len.to(
-                attn_metadata.seq_lens.device, non_blocking=True)
-            attn_metadata.seq_lens[:batch_size].masked_fill_(
+                attn_metadata_i.seq_lens.device, non_blocking=True)
+            attn_metadata_i.seq_lens[:batch_size].masked_fill_(
                 exceeds_max_model_len_cpu, 1)
             # Mask out the slot mappings that exceed the max model length.
             # Otherwise, the KV cache will be inadvertently updated with the
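
The `seq_lens` bookkeeping is unchanged in substance by the rename: lengths grow by one per step, then any request past `max_model_len` gets its length forced to 1 to keep its attention cost negligible. A runnable toy version with made-up values:

```python
import torch

# Toy values; semantics mirror the hunk above.
seq_lens = torch.tensor([7, 2048, 15])            # per-request lengths
exceeds_max_model_len = torch.tensor([False, True, False])

seq_lens += 1                                     # increment the sequence lengths
seq_lens.masked_fill_(exceeds_max_model_len, 1)   # overflowed requests -> length 1
print(seq_lens)                                   # tensor([ 8,  1, 16])
```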
@@ -511,24 +529,24 @@ class MtpProposer(Proposer):
             self.input_ids[:batch_size] = input_ids
             self.positions[:batch_size] = clamped_positions
             self.hidden_states[:batch_size] = hidden_states
-            attn_metadata.slot_mapping[:batch_size] = slot_mapping
+            attn_metadata_i.slot_mapping[:batch_size] = slot_mapping

-            if attn_metadata.prefill is not None:
-                attn_metadata.prefill.seq_lens = attn_metadata.seq_lens
-                attn_metadata.prefill.context_lens = attn_metadata.seq_lens
-                attn_metadata.prefill.input_positions = self.positions[:
-                                                                       num_input_tokens]
-                attn_metadata.prefill.max_seq_lens += 1
-                attn_metadata.prefill.max_seq_lens = min(
-                    attn_metadata.prefill.max_seq_lens,
+            if attn_metadata_i.prefill is not None:
+                attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
+                attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens
+                attn_metadata_i.prefill.input_positions = self.positions[:
+                                                                         num_input_tokens]
+                attn_metadata_i.prefill.max_seq_lens += 1
+                attn_metadata_i.prefill.max_seq_lens = min(
+                    attn_metadata_i.prefill.max_seq_lens,
                     self.runner.model_config.max_model_len)
-            if attn_metadata.decode is not None:
-                attn_metadata.decode.seq_lens = attn_metadata.seq_lens
-                attn_metadata.decode.input_positions = self.positions[:
-                                                                      num_input_tokens]
-                attn_metadata.decode.max_seq_lens += 1
-                attn_metadata.decode.max_seq_lens = min(
-                    attn_metadata.decode.max_seq_lens,
+            if attn_metadata_i.decode is not None:
+                attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
+                attn_metadata_i.decode.input_positions = self.positions[:
+                                                                        num_input_tokens]
+                attn_metadata_i.decode.max_seq_lens += 1
+                attn_metadata_i.decode.max_seq_lens = min(
+                    attn_metadata_i.decode.max_seq_lens,
                     self.runner.model_config.max_model_len)

             # mtp>1: [batch_size, k]