### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
|`vllm_ascend/ops/layer_shard_linear.py`|
|`vllm_ascend/ops/linear.py`|
|`vllm_ascend/ops/linear_op.py`|
|`vllm_ascend/worker/worker.py`|
| ` vllm_ascend/patch/worker/patch_bert.py` |
| ` vllm_ascend/patch/worker/patch_deepseek.py` |
| ` vllm_ascend/patch/worker/patch_distributed.py` |
| ` vllm_ascend/patch/worker/patch_module.py` |
| ` vllm_ascend/patch/worker/patch_multimodal_merge.py` |
| ` vllm_ascend/patch/worker/patch_qwen3_next.py` |
| ` vllm_ascend/patch/worker/patch_qwen3_next_mtp.py` |
| ` vllm_ascend/patch/worker/patch_rejection_sampler.py` |
| ` vllm_ascend/patch/worker/patch_rope.py` |
| ` vllm_ascend/patch/worker/patch_triton.py` |
| ` vllm_ascend/patch/worker/patch_unquantized_gemm.py` |
| ` vllm_ascend/patch/worker/patch_v2_egale.py` |
|` vllm_ascend/worker/npu_input_batch.py`|
|` vllm_ascend/worker/v2/aclgraph_utils.py`|
|` vllm_ascend/worker/v2/attn_utils.py`|
|` vllm_ascend/worker/v2/model_runner.py`|
|` vllm_ascend/worker/v2/sample/gumbel.py`|
|` vllm_ascend/worker/v2/sample/penalties.py`|
|` vllm_ascend/worker/v2/sample/sampler.py`|
|` vllm_ascend/worker/v2/spec_decode/__init__.py`|
|` vllm_ascend/worker/v2/spec_decode/eagle.py`|
|` vllm_ascend/worker/v2/states.py`|
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.0
- vLLM main:
d68209402d
Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -13,33 +13,27 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#from collections.abc import Iterable
|
||||
# from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
from vllm.config import CUDAGraphMode
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fla.ops import (
|
||||
chunk_gated_delta_rule, fused_recurrent_gated_delta_rule)
|
||||
from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
from vllm.model_executor.models.qwen3_next import (Qwen3NextGatedDeltaNet,
|
||||
fused_gdn_gating)
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
||||
from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.v1.attention.backend import AttentionMetadata # type: ignore
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
|
||||
|
||||
from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import \
|
||||
fused_qkvzba_split_reshape_cat
|
||||
from vllm_ascend.ops.triton.fla.sigmoid_gating import \
|
||||
fused_sigmoid_gating_delta_rule_update
|
||||
from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat
|
||||
from vllm_ascend.ops.triton.fla.sigmoid_gating import fused_sigmoid_gating_delta_rule_update
|
||||
from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
|
||||
|
||||
|
||||
class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -61,10 +55,8 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
|
||||
forward_context = get_forward_context()
|
||||
is_cuda_graph = forward_context.cudagraph_runtime_mode != CUDAGraphMode.NONE
|
||||
# triton grid should be less than 66536
|
||||
divide_grid = projected_states_qkvz.shape[0] * triton.cdiv(
|
||||
self.num_k_heads, self.tp_size)
|
||||
if self.num_v_heads // self.num_k_heads in [1, 2, 4] and \
|
||||
is_cuda_graph and divide_grid < 65536:
|
||||
divide_grid = projected_states_qkvz.shape[0] * triton.cdiv(self.num_k_heads, self.tp_size)
|
||||
if self.num_v_heads // self.num_k_heads in [1, 2, 4] and is_cuda_graph and divide_grid < 65536:
|
||||
mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
|
||||
projected_states_qkvz,
|
||||
projected_states_ba,
|
||||
@@ -74,10 +66,8 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
|
||||
self.head_v_dim,
|
||||
)
|
||||
else:
|
||||
query, key, value, z, b, a = self.fix_query_key_value_ordering(
|
||||
projected_states_qkvz, projected_states_ba)
|
||||
query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'),
|
||||
(query, key, value))
|
||||
query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)
|
||||
query, key, value = map(lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value))
|
||||
mixed_qkv = torch.cat((query, key, value), dim=-1)
|
||||
|
||||
# ============================================================
|
||||
@@ -150,16 +140,14 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
|
||||
a = a[:num_actual_tokens]
|
||||
|
||||
# 1. Convolution sequence transformation
|
||||
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
|
||||
self.conv1d.weight.size(2))
|
||||
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
|
||||
if spec_sequence_masks is not None:
|
||||
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
|
||||
mixed_qkv_spec = mixed_qkv
|
||||
mixed_qkv_non_spec = None
|
||||
else:
|
||||
mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
|
||||
mixed_qkv_non_spec = mixed_qkv.index_select(
|
||||
0, non_spec_token_indx)
|
||||
mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx)
|
||||
else:
|
||||
mixed_qkv_spec = None
|
||||
mixed_qkv_non_spec = mixed_qkv
|
||||
@@ -172,8 +160,7 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=spec_state_indices_tensor[:, 0]
|
||||
[:attn_metadata.num_spec_decodes],
|
||||
conv_state_indices=spec_state_indices_tensor[:, 0][: attn_metadata.num_spec_decodes],
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
query_start_loc=spec_query_start_loc,
|
||||
max_query_len=spec_state_indices_tensor.size(-1),
|
||||
@@ -204,21 +191,16 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=
|
||||
non_spec_state_indices_tensor[:attn_metadata.
|
||||
num_actual_tokens],
|
||||
conv_state_indices=non_spec_state_indices_tensor[: attn_metadata.num_actual_tokens],
|
||||
validate_data=True,
|
||||
)
|
||||
else:
|
||||
mixed_qkv_non_spec = None
|
||||
query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(
|
||||
mixed_qkv_spec)
|
||||
query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
|
||||
mixed_qkv_non_spec)
|
||||
query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec)
|
||||
query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(mixed_qkv_non_spec)
|
||||
|
||||
if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None:
|
||||
g, beta = fused_gdn_gating_patch(self.A_log, a, b,
|
||||
self.dt_bias)
|
||||
g, beta = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias)
|
||||
if spec_sequence_masks is not None:
|
||||
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
|
||||
g_spec = g
|
||||
@@ -248,8 +230,7 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
|
||||
beta=beta_spec,
|
||||
initial_state=ssm_state,
|
||||
inplace_final_state=True,
|
||||
cu_seqlens=spec_query_start_loc[:attn_metadata.
|
||||
num_spec_decodes + 1],
|
||||
cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1],
|
||||
ssm_state_indices=spec_state_indices_tensor,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
@@ -259,8 +240,7 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
|
||||
|
||||
# 2.2: Process the remaining part
|
||||
if attn_metadata.num_prefills > 0:
|
||||
initial_state = ssm_state[
|
||||
non_spec_state_indices_tensor].contiguous()
|
||||
initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
|
||||
initial_state[~has_initial_state, ...] = 0
|
||||
(
|
||||
core_attn_out_non_spec,
|
||||
@@ -278,24 +258,20 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
# Init cache
|
||||
ssm_state[
|
||||
non_spec_state_indices_tensor] = last_recurrent_state.to(
|
||||
ssm_state.dtype)
|
||||
ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(ssm_state.dtype)
|
||||
elif attn_metadata.num_decodes > 0:
|
||||
core_attn_out_non_spec, last_recurrent_state = (
|
||||
fused_recurrent_gated_delta_rule(
|
||||
q=query_non_spec,
|
||||
k=key_non_spec,
|
||||
v=value_non_spec,
|
||||
g=g_non_spec,
|
||||
beta=beta_non_spec,
|
||||
initial_state=ssm_state,
|
||||
inplace_final_state=True,
|
||||
cu_seqlens=non_spec_query_start_loc[:attn_metadata.
|
||||
num_decodes + 1],
|
||||
ssm_state_indices=non_spec_state_indices_tensor,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
))
|
||||
core_attn_out_non_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
|
||||
q=query_non_spec,
|
||||
k=key_non_spec,
|
||||
v=value_non_spec,
|
||||
g=g_non_spec,
|
||||
beta=beta_non_spec,
|
||||
initial_state=ssm_state,
|
||||
inplace_final_state=True,
|
||||
cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1],
|
||||
ssm_state_indices=non_spec_state_indices_tensor,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
else:
|
||||
core_attn_out_non_spec, last_recurrent_state = None, None
|
||||
|
||||
@@ -324,14 +300,12 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
|
||||
device=core_attn_out_non_spec.device,
|
||||
)
|
||||
merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
|
||||
merged_out.index_copy_(1, non_spec_token_indx,
|
||||
core_attn_out_non_spec)
|
||||
merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
|
||||
core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
|
||||
elif spec_sequence_masks is not None:
|
||||
core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
|
||||
else:
|
||||
core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(
|
||||
0)
|
||||
core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
|
||||
|
||||
|
||||
Qwen3NextGatedDeltaNet.forward = AscendQwen3Next_GatedDeltaNet.forward
|
||||
|
||||
Reference in New Issue
Block a user