[Feat] Adapted mtp function to Qwen3-next (#3918)
### What this PR does / why we need it?
Adapts the MTP (multi-token prediction) function to Qwen3-Next.
- vLLM version: v0.11.0
- vLLM main: 83f478bb19
Signed-off-by: drslark <slarksblood@qq.com>
This commit is contained in:
@@ -260,6 +260,24 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
|
||||
mixed_qkv_spec = None
|
||||
mixed_qkv_non_spec = mixed_qkv
|
||||
|
||||
# 2.1: process the multi-query part
|
||||
if spec_sequence_masks is not None:
|
||||
mixed_qkv_spec = mixed_qkv_spec.view(
|
||||
attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1))
|
||||
mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l')
|
||||
mixed_qkv_spec = causal_conv1d_update(
|
||||
mixed_qkv_spec,
|
||||
conv_state,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=spec_state_indices_tensor[:, 0]
|
||||
[:attn_metadata.num_spec_decodes],
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
validate_data=False,
|
||||
)
|
||||
mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d')
|
||||
|
||||
# 2.2: process the remaining part
|
||||
if attn_metadata.num_prefills > 0:
|
||||
# - "cache_indices" updates the conv_state cache in positions
|
||||
|
||||
Reference in New Issue
Block a user