From 7221045777bd4c3d77037fcb20d9bddd7b4dba3c Mon Sep 17 00:00:00 2001
From: "jiahao.quan" <143778363+mikequan0425@users.noreply.github.com>
Date: Thu, 12 Feb 2026 10:55:34 +0800
Subject: [PATCH] [Attention] add gpt-oss support (#5901)

### What this PR does / why we need it?
Please refer to https://github.com/vllm-project/vllm-ascend/pull/4467 for the historical discussion. This PR incorporates the review comments from that earlier PR; since the attention_v1 component has been refactored in the meantime, the changes have also been adapted to the revised code.

### Does this PR introduce _any_ user-facing change?
1. Modified the attention code to support the sliding-window attention (SWA) and attention-sink features required by gpt-oss.
2. Modified the MoE code to add support for expert bias and the swigluoai activation.

### How was this patch tested?
Please refer to https://github.com/vllm-project/vllm-ascend/pull/4467 for the performance tests; on top of those, AIME2024 accuracy tests have been added.

![img_v3_02tu_501e88e3-2217-4565-8edf-b9acf4f43f2g](https://github.com/user-attachments/assets/024f8283-18ab-4d4d-ab12-27917b5d7d06)

- vLLM version: v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/bde38c11df0ea066a740efe9b77fff5418be45df

---------

Signed-off-by: wangxiyuan
Signed-off-by: mikequan0425
Signed-off-by: hfadzxy
Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
Signed-off-by: jiangyunfan1
Signed-off-by: pu-zhe
Signed-off-by: liziyu
Signed-off-by: wangxiaoteng
Signed-off-by: luomin2005
Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: SlightwindSec
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: MrZ20 <2609716663@qq.com>
Co-authored-by: wangxiyuan
Co-authored-by: leon_tao
Co-authored-by: nurxat <738457498@qq.com>
Co-authored-by: hfadzxy
Co-authored-by: mikequan <199741451@qq.com>
Co-authored-by: LI SHENGYONG <49200266+shenchuxiaofugui@users.noreply.github.com>
Co-authored-by: jiangyunfan1
Co-authored-by: pu-zhe
Co-authored-by: luomin2005
Co-authored-by: liziyu <56102866+liziyu179@users.noreply.github.com>
Co-authored-by: wangxiaoteng
Co-authored-by: whx <56632993+whx-sjtu@users.noreply.github.com>
Co-authored-by: Cao Yi
Co-authored-by: Icey <1790571317@qq.com>
Co-authored-by: SILONG ZENG <2609716663@qq.com>
---
 .github/workflows/scripts/config.yaml        |  2 +
 .../test_offline_inference_distributed.py    | 18 +++++-
 vllm_ascend/attention/attention_v1.py        | 64 ++++++++++++++-----
 vllm_ascend/ops/activation.py                | 13 +++-
 vllm_ascend/ops/fused_moe/fused_moe.py       |  4 ++
 vllm_ascend/ops/fused_moe/moe_comm_method.py |  7 ++
 vllm_ascend/ops/fused_moe/moe_mlp.py         | 19 +++++-
 vllm_ascend/ops/rotary_embedding.py          |  3 +
 8 files changed, 111 insertions(+), 19 deletions(-)
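Note for reviewers: the swigluoai path added below ultimately calls vLLM's `SwigluOAIAndMul.forward_native` with `alpha=1.702` and `limit=7.0`. As a rough reference, here is a minimal sketch of that activation; the interleaved gate/up slicing and the exact clamping follow the gpt-oss reference implementation as I understand it, and the function name is illustrative, not part of this patch:

```python
import torch


def swiglu_oai_reference(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
    # Assumed convention: gate and up projections are interleaved
    # along the last dimension (gate = even columns, up = odd columns).
    gate, up = x[..., ::2], x[..., 1::2]
    # Clamp for numerical stability: gate from above only, up symmetrically.
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    # SwiGLU variant: sigmoid steepened by alpha (~GELU approximation).
    glu = gate * torch.sigmoid(alpha * gate)
    # The +1 biases the linear branch so a zero "up" still passes glu through.
    return (up + 1) * glu
```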
diff --git a/.github/workflows/scripts/config.yaml b/.github/workflows/scripts/config.yaml
index a7f0cf49..8dcc20de 100644
--- a/.github/workflows/scripts/config.yaml
+++ b/.github/workflows/scripts/config.yaml
@@ -68,6 +68,8 @@ e2e-2card-light:
     estimated_time: 220
   - name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_deepseek3_2_w8a8_pruning_mtp_tp2_ep
     estimated_time: 90
+  - name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_gpt_oss_distributed_tp2
+    estimated_time: 180
 
 e2e-multicard-2-cards:
   # TODO: recover skipped tests
diff --git a/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py b/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py
index 8226237a..d006f5db 100644
--- a/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py
+++ b/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py
@@ -22,7 +22,6 @@ Run `pytest tests/test_offline_inference.py`.
 """
 import os
 from unittest.mock import patch
-
 import pytest
 
 from vllm import SamplingParams
@@ -48,6 +47,9 @@ DEEPSEEK_W4A8_MODELS = [
     "vllm-ascend/DeepSeek-V3.1-W4A8-puring",
 ]
 
+GPT_OSS_MODELS = [
+    "unsloth/gpt-oss-20b-BF16",
+]
 
 def test_deepseek_multistream_moe_tp2():
     example_prompts = [
@@ -289,3 +291,17 @@ def test_qwen3_w4a4_distributed_tp2(model):
             quantization="ascend",
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+@pytest.mark.parametrize("model", GPT_OSS_MODELS)
+def test_gpt_oss_distributed_tp2(model):
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    max_tokens = 5
+    with VllmRunner(
+            model,
+            tensor_parallel_size=2,
+            enforce_eager=True,
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 6955ccea..988f40d8 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -350,6 +350,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
         logits_soft_cap: float | None,
         attn_type: str,
         kv_sharing_target_layer_name: str | None,
+        sinks: torch.Tensor = None,
         **kwargs,
     ) -> None:
         self.vllm_config = get_current_vllm_config()
@@ -372,6 +373,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
         self.is_kv_producer = (
             self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
         )
+        self.sinks = sinks
 
     @staticmethod
     def update_graph_params(
@@ -766,6 +768,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             attn_metadata.attn_state == AscendAttentionState.DecodeOnly
             and self.sliding_window is not None
             and attn_metadata.seq_lens.shape[0] == query.size(0)
+            and self.sinks is None
         ):
             return self._forward_fia_slidingwindow(query, attn_metadata, output)
         key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)
@@ -778,23 +781,52 @@
             key = key[:num_tokens]
             value = value[:num_tokens]
         # Get workspace from cache or calculate it if not present.
-        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
-            query=query,
-            key=key,
-            value=value,
-            atten_mask=attn_metadata.attn_mask,
-            block_table=block_table,
-            input_layout="TND",
-            block_size=block_size,
-            actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
-            actual_seq_lengths_kv=actual_seq_lengths_kv,
-            num_key_value_heads=self.num_kv_heads,
-            num_heads=self.num_heads,
-            scale=self.scale,
-            sparse_mode=3,
-        )
+        if self.sinks is not None:
+            actual_seq_qlen = attn_metadata.actual_seq_lengths_q
+            if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                actual_seq_qlen = torch.tensor([1] * len(attn_metadata.seq_lens_list), dtype=torch.int32).cumsum(dim=0)
+            if self.sliding_window is not None:
+                atten_mask = attn_metadata.swa_mask
+                sparse_mode = 4
+            else:
+                atten_mask = attn_metadata.attn_mask
+                sparse_mode = 3
+            attn_output, _ = torch_npu.npu_fused_infer_attention_score_v2(
+                query,
+                key,
+                value,
+                num_query_heads=self.num_heads,
+                num_key_value_heads=self.num_kv_heads,
+                input_layout="TND",
+                pre_tokens=self.sliding_window if self.sliding_window is not None else SWA_INT_MAX,
+                next_tokens=0,
+                atten_mask=atten_mask,
+                sparse_mode=sparse_mode,
+                softmax_scale=self.scale,
+                block_table=block_table,
+                block_size=block_size,
+                actual_seq_qlen=actual_seq_qlen,
+                actual_seq_kvlen=actual_seq_lengths_kv,
+                learnable_sink=self.sinks,
+            )
+        else:
+            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
+                query=query,
+                key=key,
+                value=value,
+                atten_mask=attn_metadata.attn_mask,
+                block_table=block_table,
+                input_layout="TND",
+                block_size=block_size,
+                actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
+                actual_seq_lengths_kv=actual_seq_lengths_kv,
+                num_key_value_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale=self.scale,
+                sparse_mode=3,
+            )
 
-        attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size)
+        attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size)
         output[:num_tokens] = attn_output[:num_tokens]
         return output
diff --git a/vllm_ascend/ops/activation.py b/vllm_ascend/ops/activation.py
index e8b685c3..30b6b3c1 100644
--- a/vllm_ascend/ops/activation.py
+++ b/vllm_ascend/ops/activation.py
@@ -16,7 +16,7 @@
 #
 import torch
 
-from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
+from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul, SwigluOAIAndMul
 
 from vllm_ascend.utils import get_weight_prefetch_method
@@ -38,3 +38,14 @@ class AscendSiluAndMul(SiluAndMul):
         out = torch_npu.npu_swiglu(x)
         weight_prefetch_method.maybe_prefetch_mlp_weight_postprocess(out)
         return out
+
+
+class AscendSwigluOAIAndMul:
+    def swiglu_oai_forward(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+        class MinimalSwigluOAIAndMul:
+            def __init__(self):
+                self.alpha = alpha
+                self.limit = limit
+
+        layer = MinimalSwigluOAIAndMul()
+        return SwigluOAIAndMul.forward_native(layer, x)
diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index 26c59953..b1853b4a 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -94,6 +94,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         global_num_experts: int = -1,
         expert_map: torch.Tensor | None = None,
         apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
         enable_force_load_balance: bool = False,
         log2phy: torch.Tensor = None,
         **kwargs,
     ) -> None:
@@ -137,6 +138,9 @@
             hidden_states=x,
             w1=layer.w13_weight,
             w2=layer.w2_weight,
+            w1_bias=layer.w13_bias if self.moe.has_bias else None,
+            w2_bias=layer.w2_bias if self.moe.has_bias else None,
+            activation=activation,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             expert_map=expert_map,
diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py
index 59721d1c..7ca34e74 100644
--- a/vllm_ascend/ops/fused_moe/moe_comm_method.py
+++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py
@@ -110,6 +110,8 @@ class MoECommMethod(ABC):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str = "silu",
+        w1_bias: torch.Tensor = None,
+        w2_bias: torch.Tensor = None,
         apply_router_weight_on_input: bool = False,
         use_int8_w8a8: bool = False,
         use_int4_w4a8: bool = False,
@@ -158,6 +160,9 @@
             w1_scale=w1_scale,
             w2=w2,
             w2_scale=w2_scale,
+            w1_bias=w1_bias,
+            w2_bias=w2_bias,
+            activation=activation,
             group_list=dispatch_results.group_list,
             dynamic_scale=dispatch_results.dynamic_scale,
             group_list_type=dispatch_results.group_list_type,
@@ -286,6 +291,8 @@
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str = "silu",
+        w1_bias: torch.Tensor = None,
+        w2_bias: torch.Tensor = None,
         apply_router_weight_on_input: bool = False,
         use_int8_w8a8: bool = False,
         use_int4_w4a8: bool = False,
diff --git a/vllm_ascend/ops/fused_moe/moe_mlp.py b/vllm_ascend/ops/fused_moe/moe_mlp.py
index a3d3af6b..7b086b46 100644
--- a/vllm_ascend/ops/fused_moe/moe_mlp.py
+++ b/vllm_ascend/ops/fused_moe/moe_mlp.py
@@ -22,6 +22,7 @@ from vllm.forward_context import get_forward_context
 from vllm.triton_utils import HAS_TRITON
 
 from vllm_ascend.ascend_forward_context import MoECommType
+from vllm_ascend.ops.activation import AscendSwigluOAIAndMul
 from vllm_ascend.utils import (
     dispose_tensor,
     enable_custom_op,
@@ -270,6 +271,9 @@ def unquant_apply_mlp(
     w1: torch.Tensor,
     w2: torch.Tensor,
     group_list: torch.Tensor,
+    w1_bias: torch.Tensor = None,
+    w2_bias: torch.Tensor = None,
+    activation: str | None = None,
     group_list_type: int = 1,
     topk_scales: torch.Tensor | None = None,
     need_trans: bool = True,
@@ -281,12 +285,18 @@
     gate_up_out = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
         weight=[w1],
+        bias=[w1_bias.to(dtype=torch.float32)] if w1_bias is not None else None,
         split_item=2,
         group_list_type=group_list_type,
         group_type=0,
         group_list=group_list,
     )[0]
-    gate_up_out = torch_npu.npu_swiglu(gate_up_out)
+
+    if activation == "swigluoai":
+        num_experts, _, hidden_size = w1.shape
+        gate_up_out = AscendSwigluOAIAndMul.swiglu_oai_forward(gate_up_out.view(-1, hidden_size))
+    else:
+        gate_up_out = torch_npu.npu_swiglu(gate_up_out)
 
     if topk_scales is not None:
         gate_up_out *= topk_scales
@@ -294,6 +304,7 @@
     hidden_states = torch_npu.npu_grouped_matmul(
         x=[gate_up_out],
         weight=[w2],
+        bias=[w2_bias.to(dtype=torch.float32)] if w2_bias is not None else None,
         split_item=2,
         group_list_type=group_list_type,
         group_type=0,
@@ -309,6 +320,9 @@ def unified_apply_mlp(
     group_list: torch.Tensor,
     w1_scale: list[torch.Tensor] | None = None,
     w2_scale: list[torch.Tensor] | None = None,
+    activation: str | None = None,
+    w1_bias: torch.Tensor = None,
+    w2_bias: torch.Tensor = None,
     dynamic_scale: torch.Tensor = None,
     group_list_type: int = 1,
     w1_scale_bias: torch.Tensor = None,
@@ -344,6 +358,9 @@
             hidden_states=hidden_states,
             w1=w1,
             w2=w2,
+            w1_bias=w1_bias,
+            w2_bias=w2_bias,
+            activation=activation,
             group_list=group_list,
             group_list_type=group_list_type,
             topk_scales=topk_scales,
diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index ec5aa665..042bc57e 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -256,12 +256,15 @@ class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
         attn_factor: float = 1,
         beta_fast: int = 32,
         beta_slow: int = 1,
+        truncate: bool = False,
     ) -> None:
         extra_kwargs = {
             "extrapolation_factor": extrapolation_factor,
             "attn_factor": attn_factor,
             "beta_fast": beta_fast,
             "beta_slow": beta_slow,
+            # TODO: actual truncation is not supported yet; this parameter only keeps the signature compatible with vLLM
+            "truncate": truncate,
         }
         super().__init__(
             head_size, rotary_dim, max_position_embeddings, base, is_neox_style, scaling_factor, dtype, **extra_kwargs
         )
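Note for reviewers: the `learnable_sink` argument passed to `npu_fused_infer_attention_score_v2` above implements gpt-oss attention sinks: a per-head learnable logit that joins the softmax normalization without contributing a value vector, so heads can shed attention mass instead of spreading it over real tokens. A minimal single-head, eager-mode sketch of the idea follows; the helper name, shapes, and mask handling are illustrative assumptions, not the NPU kernel's API:

```python
import torch


def attention_with_sink(
    q: torch.Tensor,     # [T, D] queries
    k: torch.Tensor,     # [T, D] keys
    v: torch.Tensor,     # [T, D] values
    sink: torch.Tensor,  # scalar learnable sink logit for this head
    scale: float,
) -> torch.Tensor:
    T = q.size(0)
    logits = (q @ k.T) * scale  # [T, T]
    causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    logits = logits.masked_fill(causal, float("-inf"))
    # Append the sink logit as an extra "token" every query may attend to.
    logits = torch.cat([logits, sink.expand(T, 1)], dim=-1)  # [T, T+1]
    probs = torch.softmax(logits, dim=-1)
    # The sink column absorbs probability mass but has no value vector,
    # so it is dropped before the weighted sum over values.
    return probs[:, :T] @ v


# Toy usage: outputs are slightly "damped" relative to plain attention.
q = k = v = torch.randn(4, 8)
out = attention_with_sink(q, k, v, sink=torch.tensor(0.5), scale=8 ** -0.5)
```

The same extra column applies when a sliding window is in effect, which matches the patch's routing: sink-plus-SWA cases go through the v2 kernel with `swa_mask` and `sparse_mode=4` rather than the sink-free `_forward_fia_slidingwindow` fast path.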