[Bugfix] Improve Triton stability on Ascend for large grids (#6301)
### What this PR does / why we need it?
Improve Triton stability on Ascend for large grids
Set `TRITON_ALL_BLOCKS_PARALLEL=1` when the Triton grid size exceeds 65535.
- vLLM version: v0.14.1
- vLLM main: dc917cceb8
Signed-off-by: hfadzxy <starmoon_zhang@163.com>
This commit is contained in:
@@ -16,6 +16,7 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -547,6 +548,9 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
|||||||
|
|
||||||
class AscendMRotaryEmbedding(MRotaryEmbedding):
|
class AscendMRotaryEmbedding(MRotaryEmbedding):
|
||||||
|
|
||||||
|
# Empirical safety threshold for large Triton grids on Ascend NPU
|
||||||
|
_ASCEND_TRITON_GRID_LIMIT = 65535
|
||||||
|
|
||||||
def forward_triton(self,
|
def forward_triton(self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
@@ -568,6 +572,12 @@ class AscendMRotaryEmbedding(MRotaryEmbedding):
|
|||||||
|
|
||||||
assert self.mrope_section
|
assert self.mrope_section
|
||||||
|
|
||||||
|
# When the grid becomes large, enable TRITON_ALL_BLOCKS_PARALLEL
|
||||||
|
# to avoid scheduler/runtime failures.
|
||||||
|
if (query_shape[0] > self._ASCEND_TRITON_GRID_LIMIT and
|
||||||
|
os.environ.get("TRITON_ALL_BLOCKS_PARALLEL") != "1"):
|
||||||
|
os.environ["TRITON_ALL_BLOCKS_PARALLEL"] = "1"
|
||||||
|
|
||||||
q, k = triton_mrope(
|
q, k = triton_mrope(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
|
|||||||
Reference in New Issue
Block a user