[OPS] add split_qkv_tp_rmsnorm_rope ops (#7376)
### What this PR does / why we need it?
This PR introduces a new fused Triton operator, `split_qkv_tp_rmsnorm_rope`, for Minimax-m2.5.
The implementation includes two Triton kernels (a reference sketch of the combined computation follows the list):
1. `_split_qkv_and_compute_local_qk_var_kernel`: Splits the QKV input
and computes the local variance for RMSNorm.
2. `_apply_global_rmsnorm_kernel`: Applies the global RMSNorm (using a TP all-reduce of the partial variance) and Neox-style RoPE.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
```bash
pytest tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_tp_rmsnorm_rope.py
```
### Test Data
A3, TP16

Baseline:
| input/output @ batch | TTFT (ms) | TPOT (ms) | TPS |
|------------|---------:|---------:|-------:|
| 4k/1k@bs1 | 267.55 | 25.5 | 38.85 |
| 4k/1k@bs4 | 542.4 | 26.51 | 148.06 |
With this PR:
| input/output @ batch | TTFT (ms) | TPOT (ms) | TPS |
|------------|---------:|---------:|-------:|
| 4k/1k@bs1 | 234.64 | 20.96 | 47.24 |
| 4k/1k@bs4 | 508.36 | 22.16 | 176.69 |
- vLLM version: v0.17.0
- vLLM main: 4034c3d32e
Signed-off-by: xutianyi <xutianyi5@huawei.com>
Co-authored-by: xutianyi <xutianyi5@huawei.com>
```diff
@@ -28,6 +28,8 @@ from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01RMSNormTP
 from vllm.model_executor.models.minimax_m2 import MiniMaxM2Attention, MiniMaxM2Model, MiniMaxM2MoE
 from vllm.platforms import current_platform
 
+from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_slice
+
 FP8_DTYPES = tuple(
     getattr(torch, dtype_name)
     for dtype_name in (
@@ -172,3 +174,31 @@ def _patched_load_weights(
 
 
 MiniMaxM2Model.load_weights = _patched_load_weights
+
+
+def _patch_forward(
+    self,
+    positions: torch.Tensor,
+    hidden_states: torch.Tensor,
+) -> torch.Tensor:
+    qkv, _ = self.qkv_proj(hidden_states)
+    cos, sin = get_cos_and_sin_slice()
+    q, k, v = torch.ops.vllm.split_qkv_tp_rmsnorm_rope(
+        input=qkv,
+        q_weight=self.q_norm.weight,
+        k_weight=self.k_norm.weight,
+        q_hidden_size=self.q_size,
+        kv_hidden_size=self.kv_size,
+        head_dim=self.head_dim,
+        rotary_dim=getattr(self.rotary_emb, "rotary_dim", self.head_dim),
+        eps=self.q_norm.variance_epsilon,
+        tp_world=self.q_norm.tp_world,
+        cos=cos,
+        sin=sin,
+    )
+    attn_output = self.attn(q, k, v)
+    output, _ = self.o_proj(attn_output)
+    return output
+
+
+MiniMaxM2Attention.forward = _patch_forward
```