diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 0c19b3e9..b4da71f3 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -16,6 +16,7 @@ # import math +import os from typing import Optional, Tuple import torch @@ -547,6 +548,9 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding): class AscendMRotaryEmbedding(MRotaryEmbedding): + # Empirical safety threshold for large Triton grids on Ascend NPU + _ASCEND_TRITON_GRID_LIMIT = 65535 + def forward_triton(self, positions: torch.Tensor, query: torch.Tensor, @@ -568,6 +572,12 @@ class AscendMRotaryEmbedding(MRotaryEmbedding): assert self.mrope_section + # When the grid becomes large, enable TRITON_ALL_BLOCKS_PARALLEL + # to avoid scheduler/runtime failures. + if (query_shape[0] > self._ASCEND_TRITON_GRID_LIMIT and + os.environ.get("TRITON_ALL_BLOCKS_PARALLEL") != "1"): + os.environ["TRITON_ALL_BLOCKS_PARALLEL"] = "1" + q, k = triton_mrope( query, key,