[BugFix] Fix Qwen3-next break (#3428)
### What this PR does / why we need it? Fix Qwen3NextGatedDeltaNet, caused by https://github.com/vllm-project/vllm/pull/26437 ### How was this patch tested? ``` def main(): prompts = [ "窗前明月光,", "The president of the United States is Mr.", "The capital of France is", "The future of AI is", "感时花溅泪,", "家书抵万金啥意思?", "plz tell me a story: ", ] # Create a sampling params object. sampling_params = SamplingParams(max_tokens=100, temperature=0.6, top_k=40, top_p=0.95) # Create an LLM. llm = LLM( model="/root/.cache/modelscope/hub/models/Qwen/Qwen3-Next-80B-A3B-Instruct", tensor_parallel_size=4, enforce_eager=True, trust_remote_code=True, max_model_len=256, gpu_memory_utilization=0.7, block_size=64 ) # Generate texts from the prompts. outputs = llm.generate(prompts, sampling_params) for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: Icey <1790571317@qq.com>
This commit is contained in:
@@ -51,6 +51,8 @@ from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.transformers_utils.configs import Qwen3NextConfig
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
from vllm.model_executor.models.qwen3_next import ( # isort: skip
|
||||
Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextForCausalLM,
|
||||
Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock,
|
||||
@@ -201,7 +203,11 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
|
||||
spec_query_start_loc = attn_metadata.spec_query_start_loc
|
||||
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
|
||||
spec_sequence_masks = attn_metadata.spec_sequence_masks
|
||||
spec_token_masks = attn_metadata.spec_token_masks
|
||||
if vllm_version_is("0.11.0"):
|
||||
spec_token_masks = attn_metadata.spec_token_masks
|
||||
else:
|
||||
spec_token_indx = attn_metadata.spec_token_indx
|
||||
non_spec_token_indx = attn_metadata.non_spec_token_indx
|
||||
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
|
||||
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
@@ -216,8 +222,9 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
|
||||
|
||||
# 1. Set up dimensions for reshapes later
|
||||
projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens])
|
||||
if spec_token_masks is not None:
|
||||
spec_token_masks = spec_token_masks[:num_actual_tokens]
|
||||
if vllm_version_is("0.11.0"):
|
||||
if spec_token_masks is not None:
|
||||
spec_token_masks = spec_token_masks[:num_actual_tokens]
|
||||
projected_states_qkvz, projected_states_ba = torch.split(
|
||||
projected_states,
|
||||
[
|
||||
@@ -242,8 +249,13 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
|
||||
mixed_qkv_spec = mixed_qkv
|
||||
mixed_qkv_non_spec = None
|
||||
else:
|
||||
mixed_qkv_spec = mixed_qkv[spec_token_masks]
|
||||
mixed_qkv_non_spec = mixed_qkv[~spec_token_masks]
|
||||
if vllm_version_is("0.11.0"):
|
||||
mixed_qkv_spec = mixed_qkv[spec_token_masks]
|
||||
mixed_qkv_non_spec = mixed_qkv[~spec_token_masks]
|
||||
else:
|
||||
mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
|
||||
mixed_qkv_non_spec = mixed_qkv.index_select(
|
||||
0, non_spec_token_indx)
|
||||
else:
|
||||
mixed_qkv_spec = None
|
||||
mixed_qkv_non_spec = mixed_qkv
|
||||
@@ -293,10 +305,16 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
|
||||
g_non_spec = None
|
||||
beta_non_spec = None
|
||||
else:
|
||||
g_spec = g[:, spec_token_masks]
|
||||
beta_spec = beta[:, spec_token_masks]
|
||||
g_non_spec = g[:, ~spec_token_masks]
|
||||
beta_non_spec = beta[:, ~spec_token_masks]
|
||||
if vllm_version_is("0.11.0"):
|
||||
g_spec = g[:, spec_token_masks]
|
||||
beta_spec = beta[:, spec_token_masks]
|
||||
g_non_spec = g[:, ~spec_token_masks]
|
||||
beta_non_spec = beta[:, ~spec_token_masks]
|
||||
else:
|
||||
g_spec = g.index_select(1, spec_token_indx)
|
||||
beta_spec = beta.index_select(1, spec_token_indx)
|
||||
g_non_spec = g.index_select(1, non_spec_token_indx)
|
||||
beta_non_spec = beta.index_select(1, non_spec_token_indx)
|
||||
else:
|
||||
g_spec = None
|
||||
beta_spec = None
|
||||
@@ -404,8 +422,14 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
|
||||
dtype=core_attn_out_non_spec.dtype,
|
||||
device=core_attn_out_non_spec.device,
|
||||
)
|
||||
core_attn_out[:, spec_token_masks] = core_attn_out_spec
|
||||
core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
|
||||
if vllm_version_is("0.11.0"):
|
||||
core_attn_out[:, spec_token_masks] = core_attn_out_spec
|
||||
core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
|
||||
else:
|
||||
core_attn_out.index_copy_(1, spec_token_indx,
|
||||
core_attn_out_spec)
|
||||
core_attn_out.index_copy_(1, non_spec_token_indx,
|
||||
core_attn_out_non_spec)
|
||||
elif spec_sequence_masks is not None:
|
||||
core_attn_out = core_attn_out_spec
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user