[Bugfix] Improve Triton stability on Ascend for large grids (#6301)
### What this PR does / why we need it?
Improve Triton stability on Ascend for large grids
set `TRITON_ALL_BLOCKS_PARALLEL=1` when grids > 65535
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
Signed-off-by: hfadzxy <starmoon_zhang@163.com>
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
#
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -547,6 +548,9 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
|
||||
class AscendMRotaryEmbedding(MRotaryEmbedding):
|
||||
|
||||
# Empirical safety threshold for large Triton grids on Ascend NPU
|
||||
_ASCEND_TRITON_GRID_LIMIT = 65535
|
||||
|
||||
def forward_triton(self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
@@ -568,6 +572,12 @@ class AscendMRotaryEmbedding(MRotaryEmbedding):
|
||||
|
||||
assert self.mrope_section
|
||||
|
||||
# When the grid becomes large, enable TRITON_ALL_BLOCKS_PARALLEL
|
||||
# to avoid scheduler/runtime failures.
|
||||
if (query_shape[0] > self._ASCEND_TRITON_GRID_LIMIT and
|
||||
os.environ.get("TRITON_ALL_BLOCKS_PARALLEL") != "1"):
|
||||
os.environ["TRITON_ALL_BLOCKS_PARALLEL"] = "1"
|
||||
|
||||
q, k = triton_mrope(
|
||||
query,
|
||||
key,
|
||||
|
||||
Reference in New Issue
Block a user