[Bugfix] Improve Triton stability on Ascend for large grids (#6301)

### What this PR does / why we need it?
Improve Triton stability on Ascend for large grids
set `TRITON_ALL_BLOCKS_PARALLEL=1` when grids > 65535

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

Signed-off-by: hfadzxy <starmoon_zhang@163.com>
This commit is contained in:
zhangxinyuehfad
2026-02-03 10:32:27 +08:00
committed by GitHub
parent 05cc03d785
commit 26b83f8bde

View File

@@ -16,6 +16,7 @@
#
import math
import os
from typing import Optional, Tuple
import torch
@@ -547,6 +548,9 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
class AscendMRotaryEmbedding(MRotaryEmbedding):
# Empirical safety threshold for large Triton grids on Ascend NPU
_ASCEND_TRITON_GRID_LIMIT = 65535
def forward_triton(self,
positions: torch.Tensor,
query: torch.Tensor,
@@ -568,6 +572,12 @@ class AscendMRotaryEmbedding(MRotaryEmbedding):
assert self.mrope_section
# When the grid becomes large, enable TRITON_ALL_BLOCKS_PARALLEL
# to avoid scheduler/runtime failures.
if (query_shape[0] > self._ASCEND_TRITON_GRID_LIMIT and
os.environ.get("TRITON_ALL_BLOCKS_PARALLEL") != "1"):
os.environ["TRITON_ALL_BLOCKS_PARALLEL"] = "1"
q, k = triton_mrope(
query,
key,