[BugFix] Fix Qwen3-next break (#3428)
### What this PR does / why we need it? Fix Qwen3NextGatedDeltaNet, caused by https://github.com/vllm-project/vllm/pull/26437 ### How was this patch tested? ``` def main(): prompts = [ "窗前明月光,", "The president of the United States is Mr.", "The capital of France is", "The future of AI is", "感时花溅泪,", "家书抵万金啥意思?", "plz tell me a story: ", ] # Create a sampling params object. sampling_params = SamplingParams(max_tokens=100, temperature=0.6, top_k=40, top_p=0.95) # Create an LLM. llm = LLM( model="/root/.cache/modelscope/hub/models/Qwen/Qwen3-Next-80B-A3B-Instruct", tensor_parallel_size=4, enforce_eager=True, trust_remote_code=True, max_model_len=256, gpu_memory_utilization=0.7, block_size=64 ) # Generate texts from the prompts. outputs = llm.generate(prompts, sampling_params) for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: Icey <1790571317@qq.com>
This commit is contained in:
@@ -51,6 +51,8 @@ from vllm.model_executor.utils import set_weight_attrs
|
|||||||
from vllm.transformers_utils.configs import Qwen3NextConfig
|
from vllm.transformers_utils.configs import Qwen3NextConfig
|
||||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
|
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
|
||||||
|
|
||||||
|
from vllm_ascend.utils import vllm_version_is
|
||||||
|
|
||||||
from vllm.model_executor.models.qwen3_next import ( # isort: skip
|
from vllm.model_executor.models.qwen3_next import ( # isort: skip
|
||||||
Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextForCausalLM,
|
Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextForCausalLM,
|
||||||
Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock,
|
Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock,
|
||||||
@@ -201,7 +203,11 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
|
|||||||
spec_query_start_loc = attn_metadata.spec_query_start_loc
|
spec_query_start_loc = attn_metadata.spec_query_start_loc
|
||||||
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
|
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
|
||||||
spec_sequence_masks = attn_metadata.spec_sequence_masks
|
spec_sequence_masks = attn_metadata.spec_sequence_masks
|
||||||
spec_token_masks = attn_metadata.spec_token_masks
|
if vllm_version_is("0.11.0"):
|
||||||
|
spec_token_masks = attn_metadata.spec_token_masks
|
||||||
|
else:
|
||||||
|
spec_token_indx = attn_metadata.spec_token_indx
|
||||||
|
non_spec_token_indx = attn_metadata.non_spec_token_indx
|
||||||
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
|
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
|
||||||
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
|
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
|
||||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||||
@@ -216,8 +222,9 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
|
|||||||
|
|
||||||
# 1. Set up dimensions for reshapes later
|
# 1. Set up dimensions for reshapes later
|
||||||
projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens])
|
projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens])
|
||||||
if spec_token_masks is not None:
|
if vllm_version_is("0.11.0"):
|
||||||
spec_token_masks = spec_token_masks[:num_actual_tokens]
|
if spec_token_masks is not None:
|
||||||
|
spec_token_masks = spec_token_masks[:num_actual_tokens]
|
||||||
projected_states_qkvz, projected_states_ba = torch.split(
|
projected_states_qkvz, projected_states_ba = torch.split(
|
||||||
projected_states,
|
projected_states,
|
||||||
[
|
[
|
||||||
@@ -242,8 +249,13 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
|
|||||||
mixed_qkv_spec = mixed_qkv
|
mixed_qkv_spec = mixed_qkv
|
||||||
mixed_qkv_non_spec = None
|
mixed_qkv_non_spec = None
|
||||||
else:
|
else:
|
||||||
mixed_qkv_spec = mixed_qkv[spec_token_masks]
|
if vllm_version_is("0.11.0"):
|
||||||
mixed_qkv_non_spec = mixed_qkv[~spec_token_masks]
|
mixed_qkv_spec = mixed_qkv[spec_token_masks]
|
||||||
|
mixed_qkv_non_spec = mixed_qkv[~spec_token_masks]
|
||||||
|
else:
|
||||||
|
mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
|
||||||
|
mixed_qkv_non_spec = mixed_qkv.index_select(
|
||||||
|
0, non_spec_token_indx)
|
||||||
else:
|
else:
|
||||||
mixed_qkv_spec = None
|
mixed_qkv_spec = None
|
||||||
mixed_qkv_non_spec = mixed_qkv
|
mixed_qkv_non_spec = mixed_qkv
|
||||||
@@ -293,10 +305,16 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
|
|||||||
g_non_spec = None
|
g_non_spec = None
|
||||||
beta_non_spec = None
|
beta_non_spec = None
|
||||||
else:
|
else:
|
||||||
g_spec = g[:, spec_token_masks]
|
if vllm_version_is("0.11.0"):
|
||||||
beta_spec = beta[:, spec_token_masks]
|
g_spec = g[:, spec_token_masks]
|
||||||
g_non_spec = g[:, ~spec_token_masks]
|
beta_spec = beta[:, spec_token_masks]
|
||||||
beta_non_spec = beta[:, ~spec_token_masks]
|
g_non_spec = g[:, ~spec_token_masks]
|
||||||
|
beta_non_spec = beta[:, ~spec_token_masks]
|
||||||
|
else:
|
||||||
|
g_spec = g.index_select(1, spec_token_indx)
|
||||||
|
beta_spec = beta.index_select(1, spec_token_indx)
|
||||||
|
g_non_spec = g.index_select(1, non_spec_token_indx)
|
||||||
|
beta_non_spec = beta.index_select(1, non_spec_token_indx)
|
||||||
else:
|
else:
|
||||||
g_spec = None
|
g_spec = None
|
||||||
beta_spec = None
|
beta_spec = None
|
||||||
@@ -404,8 +422,14 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
|
|||||||
dtype=core_attn_out_non_spec.dtype,
|
dtype=core_attn_out_non_spec.dtype,
|
||||||
device=core_attn_out_non_spec.device,
|
device=core_attn_out_non_spec.device,
|
||||||
)
|
)
|
||||||
core_attn_out[:, spec_token_masks] = core_attn_out_spec
|
if vllm_version_is("0.11.0"):
|
||||||
core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
|
core_attn_out[:, spec_token_masks] = core_attn_out_spec
|
||||||
|
core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
|
||||||
|
else:
|
||||||
|
core_attn_out.index_copy_(1, spec_token_indx,
|
||||||
|
core_attn_out_spec)
|
||||||
|
core_attn_out.index_copy_(1, non_spec_token_indx,
|
||||||
|
core_attn_out_non_spec)
|
||||||
elif spec_sequence_masks is not None:
|
elif spec_sequence_masks is not None:
|
||||||
core_attn_out = core_attn_out_spec
|
core_attn_out = core_attn_out_spec
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user