[Kernel] Optimize the selection and update OPs of the SSM state
@@ -3,7 +3,7 @@
 """Inference-only Qwen3Next model."""
 from collections.abc import Iterable
 from itertools import islice
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -69,10 +69,8 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer,
                                               maybe_prefix)
 from vllm_kunlun.ops.activation import SiluAndMul
 from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
-
-
-from typing import Optional, Union
 from vllm.model_executor.layers.vocab_parallel_embedding import get_masked_input_and_mask
+import xtorch_ops
 
 
 @torch.compile(dynamic=True, backend="aot_eager")
@@ -613,13 +611,17 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
 
         # 3.2: process the remaining part
         if attn_metadata.num_prefills > 0:
-            initial_state = ssm_state[
-                non_spec_state_indices_tensor].contiguous()
-            initial_state[~has_initial_state, ...] = 0
+            if non_spec_state_indices_tensor.shape[0] > 100:
+                initial_state = ssm_state[
+                    non_spec_state_indices_tensor].contiguous()
+            else:
+                initial_state_shape = non_spec_state_indices_tensor.shape + ssm_state.shape[1:]
+                initial_state = torch.empty(initial_state_shape, dtype=ssm_state.dtype, device=ssm_state.device)
+                for i in range(non_spec_state_indices_tensor.shape[0]):
+                    initial_state[i] = ssm_state[non_spec_state_indices_tensor[i]]
+
+            initial_state = initial_state * has_initial_state.view(has_initial_state.shape[0], 1, 1, 1)
             initial_state = initial_state.transpose(-1, -2).contiguous()
-            if self.num_v_heads // self.num_k_heads > 1:
-                query_non_spec = query_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
-                key_non_spec = key_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
             (
                 core_attn_out_non_spec,
                 last_recurrent_state,
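The hunk above changes the prefill path in two ways: the initial-state gather now picks between one batched fancy-indexing read for large batches and a per-row copy loop for small ones (presumably cheaper than the gather kernel at small scale on this backend), and the in-place `initial_state[~has_initial_state, ...] = 0` fill becomes a broadcast multiply with the `has_initial_state` mask. A minimal standalone sketch of that selection logic follows; the threshold of 100 is taken from the diff, while the helper name and tensor shapes are illustrative assumptions, not part of the source:

import torch

def gather_initial_state(ssm_state: torch.Tensor,
                         indices: torch.Tensor,
                         has_initial_state: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring the gather heuristic in the diff."""
    if indices.shape[0] > 100:
        # Large batch: a single fancy-indexing gather amortizes launch cost.
        initial_state = ssm_state[indices].contiguous()
    else:
        # Small batch: per-row slice copies sidestep the index-select kernel.
        out_shape = indices.shape + ssm_state.shape[1:]
        initial_state = torch.empty(out_shape, dtype=ssm_state.dtype,
                                    device=ssm_state.device)
        for i in range(indices.shape[0]):
            initial_state[i] = ssm_state[indices[i]]
    # Zero the rows that have no prior state via a broadcast multiply,
    # replacing the old in-place boolean-indexed fill.
    return initial_state * has_initial_state.view(has_initial_state.shape[0], 1, 1, 1)

state = torch.randn(16, 4, 8, 8)      # [num_blocks, heads, k, v] (illustrative)
idx = torch.tensor([3, 7, 0])
mask = torch.tensor([True, False, True])
print(gather_initial_state(state, idx, mask).shape)  # torch.Size([3, 4, 8, 8])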
@@ -635,9 +637,15 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
                 cu_seqlens=non_spec_query_start_loc,
             )
             # Init cache
-            last_recurrent_state = last_recurrent_state.transpose(-1, -2).contiguous()
-            ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
-                ssm_state.dtype)
+            last_recurrent_state = last_recurrent_state.transpose(-1, -2).contiguous().to(ssm_state.dtype).view(
+                last_recurrent_state.shape[0], -1, last_recurrent_state.shape[-1])
+            cast_ssm_state = ssm_state.view(ssm_state.shape[0], 1, -1, ssm_state.shape[-1])
+            xtorch_ops.reshape_and_cache_flash(
+                last_recurrent_state,
+                last_recurrent_state,
+                cast_ssm_state,
+                cast_ssm_state,
+                non_spec_state_indices_tensor)
         elif attn_metadata.num_decodes > 0:
             core_attn_out_non_spec, last_recurrent_state = (
                 fused_recurrent_gated_delta_rule(
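The write-back hunk replaces the advanced-indexing scatter `ssm_state[indices] = state.to(dtype)` with a single `xtorch_ops.reshape_and_cache_flash` call: the recurrent state is flattened to [num_tokens, heads, head_size], the cache is viewed as [num_blocks, block_size=1, heads, head_size], and the state indices serve as the slot mapping, so with a block size of one each slot lands in exactly one cache block. Passing the same tensor as both key and value (and the same view as both caches) reuses the paged-KV scatter kernel for a plain state update. Below is a pure-PyTorch sketch of that behavior, assuming the op follows the usual reshape_and_cache_flash contract; this is an assumed reference semantics, not the Kunlun kernel itself:

import torch

def reshape_and_cache_flash_ref(key, value, key_cache, value_cache,
                                slot_mapping):
    """Assumed reference semantics: write key/value rows into the paged
    caches at the positions named by slot_mapping."""
    block_size = key_cache.shape[1]
    block_idx = slot_mapping // block_size
    block_off = slot_mapping % block_size
    key_cache[block_idx, block_off] = key
    value_cache[block_idx, block_off] = value

cache = torch.zeros(16, 1, 4, 8)      # [num_blocks, block_size=1, heads, head]
new_state = torch.randn(3, 4, 8)      # [num_tokens, heads, head]
slots = torch.tensor([3, 7, 0])
reshape_and_cache_flash_ref(new_state, new_state, cache, cache, slots)
# With block_size == 1 this degenerates to cache[slots, 0] = new_state,
# i.e. the same result as the old ssm_state[indices] = ... assignment.
assert torch.equal(cache[slots, 0], new_state)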