[Triton] support swiglu_quant triton in w4a8 (#5161)

### What this PR does / why we need it?
Add a fused SwiGLU + dynamic-quantization Triton kernel (`swiglu_quant`) to the w4a8 quantized MoE path. When Triton is available, the fused kernel replaces the separate `torch_npu.npu_swiglu` and `torch_npu.npu_dynamic_quant` calls between the two grouped matmuls; otherwise the original two-op path is kept as a fallback.
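
For context, a minimal PyTorch sketch of the semantics the fused kernel is expected to implement, assuming the usual gate/up split and symmetric per-token int8 scaling (the function name `swiglu_quant_ref` is hypothetical, and the real Triton kernel additionally uses `group_list` to skip padded rows):

```python
import torch


def swiglu_quant_ref(hidden_states: torch.Tensor):
    # Split the gmm1 (up_proj) output into gate and up halves.
    gate, up = hidden_states.chunk(2, dim=-1)
    # SwiGLU activation: SiLU(gate) * up.
    activated = torch.nn.functional.silu(gate) * up
    # Per-token dynamic quantization: one fp32 scale per row.
    scale = (activated.abs().amax(dim=-1, keepdim=True) / 127.0).clamp_(min=1e-7)
    quantized = torch.round(activated / scale).clamp(-128, 127).to(torch.int8)
    return quantized, scale.squeeze(-1)
```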
### Does this PR introduce _any_ user-facing change?
No

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: GDzhu01 <809721801@qq.com>

```diff
@@ -20,6 +20,7 @@ import torch
 import torch_npu
 from torch.nn.functional import pad
 from vllm.forward_context import get_forward_context
+from vllm.triton_utils import HAS_TRITON
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.utils import (AscendDeviceType, dispose_tensor,
@@ -243,9 +244,17 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
     if quantized_hidden_states is not None:
         dispose_tensor(quantized_hidden_states)
     # act_fn: swiglu
-    hidden_states = torch_npu.npu_swiglu(hidden_states)
-    hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
-        hidden_states)
+    if HAS_TRITON:
+        from vllm_ascend.ops.triton.activation.swiglu_quant import \
+            swiglu_quant
+        hidden_states, swiglu_out_scale = swiglu_quant(
+            hidden_states,
+            group_list=group_list,
+            group_list_type=group_list_type)
+    else:
+        hidden_states = torch_npu.npu_swiglu(hidden_states)
+        hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
+            hidden_states)
     # gmm2: down_proj
     hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
```