[Main][Ops] Make triton rope support index_selecting from cos_sin_cache (#5450)
### What this PR does / why we need it?
This PR extends the original `rope_forward_triton` and `split_qkv_rmsnorm_rope`
kernels to accept `cos_sin_cache` and `positions` as inputs, fully aligning them
with the vLLM RoPE API. Compared with the earlier RoPE implementation, the
benefits are:
1. It avoids pre-computing `cos` and `sin` before model execution, which
removes redundant code.
2. It allows an eagle3 draft model to use RoPE parameters that differ from the
main model's (see #6612 ), which restores the accept rate and accuracy in that
case.
In addition, this kernel change introduces only a very small performance
degradation: the `index_select` and `chunk` operations become simple memory
accesses inside the Triton kernel (for example,
https://github.com/vllm-project/vllm-ascend/pull/5450/changes#diff-a4c2d3071530df193b98f9bf38553874bc4d47571336711f116c26d019cfbb6aR77-R81).
A host-side sketch of the equivalence is shown below.
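For reference, the host-side computation that the kernel now performs internally
looks roughly like this (a minimal sketch; shapes follow the new unit tests, and
each row of `cos_sin_cache` stores the cos half followed by the sin half):

```python
import torch

max_position_embeddings, rotary_dim, num_tokens = 4096, 128, 16
cos_sin_cache = torch.randn(max_position_embeddings, rotary_dim)
positions = torch.randint(0, max_position_embeddings, (num_tokens,), dtype=torch.int64)

# Roughly what previously had to happen on the host before every model step:
cos, sin = cos_sin_cache.index_select(0, positions).chunk(2, dim=-1)

# Inside the Triton kernel this is now a per-token load at
# cos_sin_cache[positions[i], :rotary_dim // 2] (cos) and
# cos_sin_cache[positions[i], rotary_dim // 2:] (sin).
assert cos.shape == (num_tokens, rotary_dim // 2)
assert sin.shape == (num_tokens, rotary_dim // 2)
```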
**Highlights**
- **RoPE Cache Unification**: Replaced the separate `_sin` and `_cos` global
tensors with a unified `cos_sin_cache` and an explicit `positions` tensor for
rotary positional embeddings (RoPE), streamlining data handling.
- **Triton Kernel Integration**: Updated the Triton kernels
(`split_qkv_rmsnorm_rope_kernel`, `_triton_rope`) to consume `cos_sin_cache`
and `positions` directly, making the RoPE calculation more efficient and
self-contained.
- **Custom Operation Registration**: Registered `rope_forward_oot` as a new
custom operation (`torch.ops.vllm.npu_rotary_embedding`), allowing its use in
fused compilation passes and providing a dedicated entry point for the new RoPE
implementation.
- **Refactored RoPE Forward Pass**: Modified the `rope_forward_oot` function to
accept the new `cos_sin_cache` and `positions` arguments, enabling a more
flexible and integrated RoPE application. An illustrative call is sketched
below.
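The following is a minimal usage sketch of the unified entry point (it assumes
vllm-ascend is installed so that `torch.ops.vllm.npu_rotary_embedding` is
registered and that the tensors live on an NPU device); the argument order
matches the fusion patterns in the diff below:

```python
import torch
import torch_npu  # noqa  # provides the "npu" device
import vllm_ascend.ops.register_custom_ops  # noqa  # registers torch.ops.vllm.*

num_tokens, num_q_heads, num_kv_heads, head_dim = 8, 32, 4, 128
positions = torch.arange(num_tokens, dtype=torch.int64, device="npu")
query = torch.randn(num_tokens, num_q_heads * head_dim, dtype=torch.bfloat16, device="npu")
key = torch.randn(num_tokens, num_kv_heads * head_dim, dtype=torch.bfloat16, device="npu")
cos_sin_cache = torch.randn(16384, head_dim, dtype=torch.bfloat16, device="npu")

# positions, query, key, cos_sin_cache, head_dim, rotary_dim, is_neox_style
q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
    positions, query, key, cos_sin_cache, head_dim, head_dim, True
)
```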
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: 5326c89803
Additional accuracy results for Qwen3-235B:
| AIME 2024 | GSM8K | LiveCodeBench |
| -------- | -------- | -------- |
| 83.33 | 96.26 | 70.23 |
---------
Signed-off-by: Angazenn <supperccell@163.com>
@@ -8,6 +8,7 @@ from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton

IS_NEOX_STYLE = [True, False]
DTYPES = [torch.bfloat16, torch.float16]
MAX_POSITION_EMBEDDINGS = [262144]
HEAD_SIZES = [64, 128]
ROTARY_DIMS = [32, 64]
NUM_Q_HEADS = [64]

@@ -139,3 +140,83 @@ def test_rotary_embedding_triton_kernel(
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()


@pytest.mark.parametrize("max_position_embeddings", MAX_POSITION_EMBEDDINGS)
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads", NUM_Q_HEADS)
@pytest.mark.parametrize("num_k_heads", NUM_K_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_rotary_embedding_triton_kernel_with_cos_sin_cache(
    max_position_embeddings: int,
    is_neox_style: bool,
    num_tokens: int,
    num_q_heads: int,
    num_k_heads: int,
    head_size: int,
    rotary_dim: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    torch.manual_seed(seed)
    torch.set_default_device(device)
    init_device_properties_triton()
    if rotary_dim == -1:
        rotary_dim = head_size
    cos_sin_cache = torch.randn(max_position_embeddings, rotary_dim, dtype=dtype, device=device)
    positions = torch.randint(low=0, high=max_position_embeddings, size=(num_tokens,), dtype=torch.int64, device=device)
    q_trt = torch.randn(num_tokens, num_q_heads, head_size, dtype=dtype, device=device)
    k_trt = torch.randn(num_tokens, num_k_heads, head_size, dtype=dtype, device=device)
    q_gold = torch.randn(num_tokens, num_q_heads, head_size, dtype=dtype, device=device)
    k_gold = torch.randn(num_tokens, num_k_heads, head_size, dtype=dtype, device=device)
    q_trt.copy_(q_gold)
    k_trt.copy_(k_gold)
    q_trt, k_trt = rope_forward_triton(q_trt,
                                       k_trt,
                                       cos_sin_cache=cos_sin_cache,
                                       positions=positions,
                                       rope_dim=rotary_dim,
                                       is_neox_style=is_neox_style)
    cos, sin = cos_sin_cache.index_select(0, positions).chunk(2, dim=-1)
    q_gold, k_gold = _rope_pytorch_native(q_gold,
                                          k_gold,
                                          cos,
                                          sin,
                                          rope_dim=rotary_dim,
                                          is_neox_style=is_neox_style)
    # Compare the results.
    torch.testing.assert_close(q_trt.view(q_gold.size()), q_gold, atol=DEFAULT_ATOL, rtol=DEFAULT_RTOL)
    torch.testing.assert_close(k_trt.view(k_gold.size()), k_gold, atol=DEFAULT_ATOL, rtol=DEFAULT_RTOL)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()

@@ -7,6 +7,7 @@ import torch
|
||||
import vllm_ascend.ops.register_custom_ops # noqa
|
||||
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
|
||||
|
||||
MAX_POSITION_EMBEDDINGS = [262144]
|
||||
NUM_TOKENS = [1, 4, 8, 16, 1024]
|
||||
NUM_QKV_HEADS = [(12, 1), (16, 1), (32, 4), (64, 4)]
|
||||
HEAD_SIZES = [128]
|
||||
@@ -55,6 +56,7 @@ def rms_norm(
|
||||
return out
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_position_embeddings", MAX_POSITION_EMBEDDINGS)
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@@ -63,7 +65,7 @@ def rms_norm(
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_split_qkv_rmsnorm_rope(num_tokens, num_q_heads, num_kv_heads,
|
||||
def test_split_qkv_rmsnorm_rope(max_position_embeddings, num_tokens, num_q_heads, num_kv_heads,
|
||||
head_size, eps, dtype, seed, device):
|
||||
torch.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
@@ -77,12 +79,10 @@ def test_split_qkv_rmsnorm_rope(num_tokens, num_q_heads, num_kv_heads,
|
||||
device=device)
|
||||
q_weight = torch.randn(head_size, dtype=dtype, device=device)
|
||||
k_weight = torch.randn(head_size, dtype=dtype, device=device)
|
||||
sin = torch.from_numpy(
|
||||
cos_sin_cache = torch.from_numpy(
|
||||
np.random.uniform(0, 1,
|
||||
[num_tokens, 1, 1, head_size])).to(dtype).npu()
|
||||
cos = torch.from_numpy(
|
||||
np.random.uniform(0, 1,
|
||||
[num_tokens, 1, 1, head_size])).to(dtype).npu()
|
||||
[max_position_embeddings, head_size])).to(dtype).npu()
|
||||
positions = torch.randint(low=0, high=max_position_embeddings, size=(num_tokens,), dtype=torch.int64, device=device)
|
||||
# fused kernel
|
||||
q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,
|
||||
q_weight=q_weight,
|
||||
@@ -91,8 +91,12 @@ def test_split_qkv_rmsnorm_rope(num_tokens, num_q_heads, num_kv_heads,
|
||||
kv_hidden_size=kv_hidden_size,
|
||||
head_dim=head_size,
|
||||
eps=eps,
|
||||
cos=cos,
|
||||
sin=sin)
|
||||
cos_sin_cache=cos_sin_cache,
|
||||
positions=positions)
|
||||
|
||||
cos, sin = cos_sin_cache.index_select(0, positions).view(num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)
|
||||
cos = cos.unsqueeze(1)
|
||||
sin = sin.unsqueeze(1)
|
||||
|
||||
# split
|
||||
_q, _k, v_gold = qkv.cpu().split(
|
||||
@@ -129,6 +133,7 @@ def test_split_qkv_rmsnorm_rope(num_tokens, num_q_heads, num_kv_heads,
|
||||
torch.npu.reset_peak_memory_stats()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_position_embeddings", MAX_POSITION_EMBEDDINGS)
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@@ -137,7 +142,7 @@ def test_split_qkv_rmsnorm_rope(num_tokens, num_q_heads, num_kv_heads,
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_split_qkv_rmsnorm_rope_with_bias(num_tokens, num_q_heads,
|
||||
def test_split_qkv_rmsnorm_rope_with_bias(max_position_embeddings, num_tokens, num_q_heads,
|
||||
num_kv_heads, head_size, eps, dtype,
|
||||
seed, device):
|
||||
torch.manual_seed(seed)
|
||||
@@ -154,12 +159,10 @@ def test_split_qkv_rmsnorm_rope_with_bias(num_tokens, num_q_heads,
|
||||
k_weight = torch.randn(head_size, dtype=dtype, device=device)
|
||||
q_bias = torch.randn(head_size, dtype=dtype, device=device)
|
||||
k_bias = torch.randn(head_size, dtype=dtype, device=device)
|
||||
sin = torch.from_numpy(
|
||||
cos_sin_cache = torch.from_numpy(
|
||||
np.random.uniform(0, 1,
|
||||
[num_tokens, 1, 1, head_size])).to(dtype).npu()
|
||||
cos = torch.from_numpy(
|
||||
np.random.uniform(0, 1,
|
||||
[num_tokens, 1, 1, head_size])).to(dtype).npu()
|
||||
[max_position_embeddings, head_size])).to(dtype).npu()
|
||||
positions = torch.randint(low=0, high=max_position_embeddings, size=(num_tokens,), dtype=torch.int64, device=device)
|
||||
# fused kernel
|
||||
q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,
|
||||
q_weight=q_weight,
|
||||
@@ -170,8 +173,12 @@ def test_split_qkv_rmsnorm_rope_with_bias(num_tokens, num_q_heads,
|
||||
eps=eps,
|
||||
q_bias=q_bias,
|
||||
k_bias=k_bias,
|
||||
cos=cos,
|
||||
sin=sin)
|
||||
cos_sin_cache=cos_sin_cache,
|
||||
positions=positions)
|
||||
|
||||
cos, sin = cos_sin_cache.index_select(0, positions).view(num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)
|
||||
cos = cos.unsqueeze(1)
|
||||
sin = sin.unsqueeze(1)
|
||||
|
||||
# split
|
||||
_q, _k, v_gold = qkv.cpu().split(
|
||||
|
||||
@@ -60,7 +60,7 @@ class ModelQKNormRopeWithoutBias(nn.Module):
|
||||
self.q_weight = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
|
||||
self.k_weight = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, qkv, cos, sin):
|
||||
def forward(self, qkv, cos_sin_cache, positions):
|
||||
"""
|
||||
Args:
|
||||
qkv: [T, q_size + 2*kv_size]
|
||||
@@ -82,13 +82,12 @@ class ModelQKNormRopeWithoutBias(nn.Module):
|
||||
|
||||
# Reshape for RoPE: [T, num_heads, head_dim] -> [1, T, num_heads, head_dim]
|
||||
q_flat = q_norm_out.view(q.shape)
|
||||
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
k_flat = k_norm_out.view(k.shape)
|
||||
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
# Apply RoPE
|
||||
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
|
||||
q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
|
||||
positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True
|
||||
)
|
||||
|
||||
return q_rope, k_rope, v
|
||||
|
||||
@@ -116,7 +115,7 @@ class ModelQKNormRopeWithBias(nn.Module):
|
||||
self.q_bias = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
|
||||
self.k_bias = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, qkv, cos, sin):
|
||||
def forward(self, qkv, cos_sin_cache, positions):
|
||||
# Split QKV
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
@@ -132,13 +131,12 @@ class ModelQKNormRopeWithBias(nn.Module):
|
||||
|
||||
# Reshape for RoPE
|
||||
q_flat = q_normed.view(q.shape)
|
||||
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
k_flat = k_normed.view(k.shape)
|
||||
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
# Apply RoPE
|
||||
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
|
||||
q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
|
||||
positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True
|
||||
)
|
||||
|
||||
return q_rope, k_rope, v
|
||||
|
||||
@@ -147,7 +145,7 @@ def assert_qknorm_rope_fusion(after_gm, expect_fused=True, use_bias=False):
|
||||
check_rules = [
|
||||
(torch.ops.vllm.qkv_rmsnorm_rope.default, expect_fused),
|
||||
(torch.ops.npu.npu_rms_norm.default, not expect_fused),
|
||||
(torch.ops.npu.npu_apply_rotary_pos_emb.default, not expect_fused),
|
||||
(torch.ops.vllm.npu_rotary_embedding.default, not expect_fused),
|
||||
]
|
||||
if use_bias:
|
||||
check_rules.append((torch.ops.aten.add.Tensor, not expect_fused))
|
||||
|
||||
@@ -25,10 +25,10 @@ CASE_QWEN_ACLGRAPH = LLMTestCase(
|
||||
model="Qwen/Qwen3-0.6B",
|
||||
prompts=PROMPTS_SHORT,
|
||||
golden_answers=[
|
||||
" Lina. I'm a 22-year-old student from China. I'm interested in studying in the US. I want to know if there are any",
|
||||
" Lina. I'm a 22-year-old student from China. I'm interested in studying in the US. I'm looking for a job in the",
|
||||
' the same as the president of the United Nations. This is because the president of the United States is the same as the president of the United Nations. The president',
|
||||
' Paris. The capital of France is also the capital of the Republic of France. The capital of France is also the capital of the European Union. The capital of',
|
||||
' not just a technological frontier but a profound transformation of how we live, work, and interact with the world. As we stand at the intersection of artificial intelligence and'
|
||||
' not just a technological challenge but a profound transformation of how we live, work, and interact with the world. As we stand at the intersection of artificial intelligence and'
|
||||
],
|
||||
)
|
||||
|
||||
@@ -48,10 +48,10 @@ CASE_QWEN_FULL = LLMTestCase(
|
||||
model="Qwen/Qwen3-0.6B",
|
||||
prompts=PROMPTS_SHORT,
|
||||
golden_answers=[
|
||||
" Lina. I'm a 22-year-old student from China. I'm interested in studying in the US. I want to know if there are any",
|
||||
" Lina. I'm a 22-year-old student from China. I'm interested in studying in the US. I'm looking for a job in the",
|
||||
' the same as the president of the United Nations. This is because the president of the United States is the same as the president of the United Nations. The president',
|
||||
' Paris. The capital of France is also the capital of the Republic of France. The capital of France is also the capital of the European Union. The capital of',
|
||||
' not just a technological frontier but a profound transformation of how we live, work, and interact with the world. As we stand at the intersection of artificial intelligence and'
|
||||
' not just a technological challenge but a profound transformation of how we live, work, and interact with the world. As we stand at the intersection of artificial intelligence and'
|
||||
],
|
||||
)
|
||||
|
||||
@@ -72,8 +72,8 @@ CASE_QWEN_FULL_DECODE_ONLY = LLMTestCase(
|
||||
prompts=PROMPTS_LONG,
|
||||
golden_answers=[
|
||||
' \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the',
|
||||
" \n\nTo solve this problem, we can use the following approach: Let $P$ be the perimeter of the square. Then, the expected value of the area",
|
||||
' \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations $x^2 +'
|
||||
" \n\nTo solve this problem, we can use the fact that the expected value of the area of a triangle with vertices on a square can be calculated by integrating over",
|
||||
' \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can'
|
||||
])
|
||||
|
||||
CASE_DS_FULL_DECODE_ONLY = LLMTestCase(
|
||||
@@ -91,8 +91,8 @@ CASE_QWEN_EX = LLMTestCase(
|
||||
prompts=PROMPTS_LONG,
|
||||
golden_answers=[
|
||||
' \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the',
|
||||
" \n\nTo solve this problem, we can use the following approach: Let $P$ be the perimeter of the square. Then, the expected value of the area",
|
||||
' \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations $x^2 +'
|
||||
" \n\nTo solve this problem, we can use the fact that the expected value of the area of a triangle with vertices on a square can be calculated by integrating over",
|
||||
' \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can'
|
||||
])
|
||||
|
||||
CASE_DS_EX = LLMTestCase(model="vllm-ascend/DeepSeek-V2-Lite-W8A8",
|
||||
|
||||
@@ -164,8 +164,6 @@ def set_ascend_forward_context(

_mc2_tokens_capacity: int | None = None
_reserved_mc2_mask: torch.Tensor | None = None
_sin: torch.Tensor | None = None
_cos: torch.Tensor | None = None


def set_mc2_tokens_capacity(vllm_config, max_num_reqs, uniform_decode_query_len):

@@ -47,17 +47,21 @@ class GraphEXQKNormRopeFusionPattern:
|
||||
|
||||
def get_inputs(self):
|
||||
T = 5
|
||||
max_position_embeddings = 16384
|
||||
qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu")
|
||||
q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
return [qkv, q_weight, k_weight, cos, sin]
|
||||
cos_sin_cache = torch.empty(max_position_embeddings, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
positions = torch.ones(T, dtype=torch.int64, device="npu")
|
||||
return [qkv, q_weight, k_weight, cos_sin_cache, positions]
|
||||
|
||||
# The replacement registered here will be actually executed after AOT.
|
||||
def register(self):
|
||||
def pattern(
|
||||
qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
qkv: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
):
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
@@ -68,17 +72,19 @@ class GraphEXQKNormRopeFusionPattern:
|
||||
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps)
|
||||
|
||||
q_flat = q_norm_out.view(q.shape)
|
||||
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
k_flat = k_norm_out.view(k.shape)
|
||||
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
|
||||
q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
|
||||
positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True
|
||||
)
|
||||
|
||||
return q_rope, k_rope, v
|
||||
|
||||
def replacement(
|
||||
qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
qkv: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
):
|
||||
results = torch.ops.vllm.qkv_rmsnorm_rope(
|
||||
input=qkv,
|
||||
@@ -90,8 +96,8 @@ class GraphEXQKNormRopeFusionPattern:
|
||||
eps=self.eps,
|
||||
q_bias=None,
|
||||
k_bias=None,
|
||||
sin=sin,
|
||||
cos=cos,
|
||||
cos_sin_cache=cos_sin_cache,
|
||||
positions=positions,
|
||||
)
|
||||
|
||||
return results
|
||||
@@ -117,16 +123,17 @@ class GraphEXQKNormRopeFusionPatternWithBias:
|
||||
|
||||
def get_inputs(self):
|
||||
T = 5
|
||||
max_position_embeddings = 16384
|
||||
qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu")
|
||||
q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
return [qkv, q_weight, k_weight, q_bias, k_bias, cos, sin]
|
||||
cos_sin_cache = torch.empty(max_position_embeddings, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
positions = torch.ones(T, dtype=torch.int64, device="npu")
|
||||
|
||||
return [qkv, q_weight, k_weight, q_bias, k_bias, cos_sin_cache, positions]
|
||||
|
||||
# The replacement registered here will be actually executed after AOT.
|
||||
def register(self):
|
||||
def pattern(
|
||||
qkv: torch.Tensor,
|
||||
@@ -134,8 +141,8 @@ class GraphEXQKNormRopeFusionPatternWithBias:
|
||||
k_weight: torch.Tensor,
|
||||
q_bias: torch.Tensor,
|
||||
k_bias: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
):
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
@@ -148,12 +155,10 @@ class GraphEXQKNormRopeFusionPatternWithBias:
|
||||
k_normed = k_norm_out + k_bias
|
||||
|
||||
q_flat = q_normed.view(q.shape)
|
||||
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
k_flat = k_normed.view(k.shape)
|
||||
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
|
||||
q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
|
||||
positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True
|
||||
)
|
||||
|
||||
return q_rope, k_rope, v
|
||||
|
||||
@@ -163,8 +168,8 @@ class GraphEXQKNormRopeFusionPatternWithBias:
|
||||
k_weight: torch.Tensor,
|
||||
q_bias: torch.Tensor,
|
||||
k_bias: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
):
|
||||
results = torch.ops.vllm.qkv_rmsnorm_rope(
|
||||
input=qkv,
|
||||
@@ -176,10 +181,9 @@ class GraphEXQKNormRopeFusionPatternWithBias:
|
||||
eps=self.eps,
|
||||
q_bias=q_bias,
|
||||
k_bias=k_bias,
|
||||
sin=sin,
|
||||
cos=cos,
|
||||
cos_sin_cache=cos_sin_cache,
|
||||
positions=positions,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
torchair.register_replacement(
|
||||
@@ -197,7 +201,7 @@ class GraphEXQKNormRopeFusionPass:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
dtype = vllm_config.model_config.dtype
|
||||
if dtype not in (torch.bfloat16, torch.float16):
|
||||
if dtype not in (torch.bfloat16,):
|
||||
logger.debug("QKNorm and Rope fusion not enabled: unsupported dtype %s", dtype)
|
||||
return
|
||||
# use one attn layer to get meta (such as head_dim) for QKNormRopeFusionPattern
|
||||
|
||||
@@ -45,16 +45,21 @@ class QKNormRopeFusionPattern:
|
||||
|
||||
def get_inputs(self):
|
||||
T = 5
|
||||
max_position_embeddings = 16384
|
||||
qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu")
|
||||
q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
return [qkv, q_weight, k_weight, cos, sin]
|
||||
cos_sin_cache = torch.empty(max_position_embeddings, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
positions = torch.ones(T, dtype=torch.int64, device="npu")
|
||||
return [qkv, q_weight, k_weight, cos_sin_cache, positions]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
qkv: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
):
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
@@ -65,17 +70,19 @@ class QKNormRopeFusionPattern:
|
||||
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps)
|
||||
|
||||
q_flat = q_norm_out.view(q.shape)
|
||||
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
k_flat = k_norm_out.view(k.shape)
|
||||
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
|
||||
q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
|
||||
positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True
|
||||
)
|
||||
|
||||
return q_rope, k_rope, v
|
||||
|
||||
def replacement(
|
||||
qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
qkv: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
):
|
||||
results = torch.ops.vllm.qkv_rmsnorm_rope(
|
||||
input=qkv,
|
||||
@@ -87,8 +94,8 @@ class QKNormRopeFusionPattern:
|
||||
eps=self.eps,
|
||||
q_bias=None,
|
||||
k_bias=None,
|
||||
sin=sin,
|
||||
cos=cos,
|
||||
cos_sin_cache=cos_sin_cache,
|
||||
positions=positions,
|
||||
)
|
||||
|
||||
return results
|
||||
@@ -109,15 +116,16 @@ class QKNormRopeFusionPatternWithBias:
|
||||
|
||||
def get_inputs(self):
|
||||
T = 5
|
||||
max_position_embeddings = 16384
|
||||
qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu")
|
||||
q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
cos_sin_cache = torch.empty(max_position_embeddings, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
positions = torch.ones(T, dtype=torch.int64, device="npu")
|
||||
|
||||
return [qkv, q_weight, k_weight, q_bias, k_bias, cos, sin]
|
||||
return [qkv, q_weight, k_weight, q_bias, k_bias, cos_sin_cache, positions]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
@@ -126,8 +134,8 @@ class QKNormRopeFusionPatternWithBias:
|
||||
k_weight: torch.Tensor,
|
||||
q_bias: torch.Tensor,
|
||||
k_bias: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
):
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
@@ -140,12 +148,10 @@ class QKNormRopeFusionPatternWithBias:
|
||||
k_normed = k_norm_out + k_bias
|
||||
|
||||
q_flat = q_normed.view(q.shape)
|
||||
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
k_flat = k_normed.view(k.shape)
|
||||
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
|
||||
q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
|
||||
positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True
|
||||
)
|
||||
|
||||
return q_rope, k_rope, v
|
||||
|
||||
@@ -155,8 +161,8 @@ class QKNormRopeFusionPatternWithBias:
|
||||
k_weight: torch.Tensor,
|
||||
q_bias: torch.Tensor,
|
||||
k_bias: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
):
|
||||
results = torch.ops.vllm.qkv_rmsnorm_rope(
|
||||
input=qkv,
|
||||
@@ -168,8 +174,8 @@ class QKNormRopeFusionPatternWithBias:
|
||||
eps=self.eps,
|
||||
q_bias=q_bias,
|
||||
k_bias=k_bias,
|
||||
cos=cos,
|
||||
sin=sin,
|
||||
cos_sin_cache=cos_sin_cache,
|
||||
positions=positions,
|
||||
)
|
||||
return results
|
||||
|
||||
@@ -186,7 +192,7 @@ class QKNormRopeFusionPass(VllmInductorPass):
|
||||
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="qknorm_rope_fusion_pass")
|
||||
|
||||
dtype = vllm_config.model_config.dtype
|
||||
if dtype not in (torch.bfloat16, torch.float16):
|
||||
if dtype not in (torch.bfloat16,):
|
||||
logger.debug("QKNorm and Rope fusion not enabled: unsupported dtype %s", dtype)
|
||||
return
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from vllm.forward_context import get_forward_context
from vllm.utils.torch_utils import direct_register_custom_op

from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.rotary_embedding import rope_forward_oot
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.utils import npu_stream_switch, prefetch_stream

@@ -188,15 +188,16 @@ def _quantize_impl_fake(
    return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal, input_offset, torch.qint8, -1, False)


def _rope_forward_triton_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    rope_dim: int = -1,
def _rope_forward_oot_impl_fake(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    head_dim: int,
    rotary_dim: int,
    is_neox_style: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(q), torch.empty_like(k)
    return query, key


direct_register_custom_op(
@@ -262,10 +263,11 @@ direct_register_custom_op(
    mutates_args=[],
    dispatch_key="PrivateUse1",
)

direct_register_custom_op(
    op_name="rope_forward_triton",
    op_func=rope_forward_triton,
    fake_impl=_rope_forward_triton_fake,
    op_name="npu_rotary_embedding",
    op_func=rope_forward_oot,
    fake_impl=_rope_forward_oot_impl_fake,
    mutates_args=[],
    dispatch_key="PrivateUse1",
)

@@ -29,11 +29,13 @@ from vllm.model_executor.layers.rotary_embedding import (
|
||||
from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, has_rope, is_vl_model
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope
|
||||
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, has_rope, is_vl_model
|
||||
from vllm_ascend.ops.triton.rope import rope_forward_triton
|
||||
|
||||
# Currently, rope ops used on npu requires detached cos && sin as inputs.
|
||||
# However, RotaryEmbedding in vllm use cos_sin_cache as a whole variable.
|
||||
@@ -146,54 +148,38 @@ def get_cos_and_sin_slice():
|
||||
return _cos_slice, _sin_slice
|
||||
|
||||
|
||||
def _rope_forward_oot(
|
||||
self,
|
||||
def rope_forward_oot(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
is_neox_style: bool,
|
||||
offsets: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
query_shape, key_shape = query.shape, key.shape
|
||||
if self.cos_sin_cache.device != query.device:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
|
||||
if self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
|
||||
cos, sin = get_cos_and_sin_slice()
|
||||
if offsets is not None:
|
||||
raise NotImplementedError("Batched rotary embedding is currently not supported on NPU.")
|
||||
if (
|
||||
is_neox_style
|
||||
and self.head_size == 128
|
||||
and self.cos_sin_cache.shape[-1] == 128
|
||||
and cos is not None
|
||||
and sin is not None
|
||||
):
|
||||
# If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
|
||||
# This method requires head_size and rotary_dim equal 128 and neox_style is True
|
||||
query = query.contiguous().view(1, query.shape[0], -1, self.head_size)
|
||||
key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
|
||||
# Although this function modifies in-place, please retain the function's return value.
|
||||
# Otherwise, the graph fusion operation may fail.
|
||||
query, key = torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin)
|
||||
elif self.rotary_dim < self.head_size:
|
||||
if HAS_TRITON:
|
||||
cos = cos.view(-1, self.rotary_dim)
|
||||
sin = sin.view(-1, self.rotary_dim)
|
||||
q = query.contiguous().view(query.shape[0], -1, self.head_size)
|
||||
k = key.contiguous().view(key.shape[0], -1, self.head_size)
|
||||
query, key = torch.ops.vllm.rope_forward_triton(
|
||||
q, k, cos, sin, rope_dim=self.rotary_dim, is_neox_style=True
|
||||
)
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
else:
|
||||
if HAS_TRITON:
|
||||
num_tokens = query.shape[0]
|
||||
query, key = rope_forward_triton(
|
||||
query.view(num_tokens, -1, head_size),
|
||||
key.view(num_tokens, -1, head_size),
|
||||
cos_sin_cache=cos_sin_cache,
|
||||
positions=positions,
|
||||
rope_dim=rotary_dim,
|
||||
is_neox_style=is_neox_style,
|
||||
)
|
||||
else:
|
||||
if rotary_dim < head_size:
|
||||
num_tokens = query.shape[0]
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
q_rot = query[..., : self.rotary_dim]
|
||||
q_pass = query[..., self.rotary_dim :]
|
||||
k_rot = key[..., : self.rotary_dim]
|
||||
k_pass = key[..., self.rotary_dim :]
|
||||
query = query.view(num_tokens, -1, head_size)
|
||||
key = key.view(num_tokens, -1, head_size)
|
||||
q_rot = query[..., :rotary_dim]
|
||||
q_pass = query[..., rotary_dim:]
|
||||
k_rot = key[..., :rotary_dim]
|
||||
k_pass = key[..., rotary_dim:]
|
||||
q_rot = q_rot.contiguous().view(num_tokens, -1)
|
||||
k_rot = k_rot.contiguous().view(num_tokens, -1)
|
||||
# only the rotary part is processed here,
|
||||
@@ -202,27 +188,26 @@ def _rope_forward_oot(
|
||||
positions,
|
||||
q_rot,
|
||||
k_rot,
|
||||
self.rotary_dim,
|
||||
self.cos_sin_cache,
|
||||
rotary_dim,
|
||||
cos_sin_cache,
|
||||
is_neox_style,
|
||||
)
|
||||
q_rot = q_rot.view(num_tokens, -1, rotary_dim)
|
||||
k_rot = k_rot.view(num_tokens, -1, rotary_dim)
|
||||
query = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
|
||||
key = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
|
||||
else:
|
||||
# TODO: Remove the contiguous in the future.
|
||||
query = query.contiguous().view(query.shape[0], -1)
|
||||
key = key.contiguous().view(key.shape[0], -1)
|
||||
torch_npu._npu_rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
is_neox_style,
|
||||
)
|
||||
q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
|
||||
k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
|
||||
q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
|
||||
k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
|
||||
return q, k
|
||||
else:
|
||||
# TODO: Remove the contiguous in the future.
|
||||
query = query.contiguous().view(query.shape[0], -1)
|
||||
key = key.contiguous().view(key.shape[0], -1)
|
||||
torch_npu._npu_rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
is_neox_style,
|
||||
)
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
|
||||
|
||||
@@ -251,7 +236,9 @@ class AscendRotaryEmbedding(RotaryEmbedding):
        is_neox_style = self.is_neox_style
        if is_neox_style_override is not None:
            is_neox_style = is_neox_style_override
        return _rope_forward_oot(self, positions, query, key, is_neox_style, offsets)
        return torch.ops.vllm.npu_rotary_embedding(
            positions, query, key, self.cos_sin_cache, self.head_size, self.rotary_dim, is_neox_style
        )


class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):

@@ -460,7 +447,9 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
        query = query.view(b, h_q, d // 2, 2).transpose(3, 2).reshape(b, h_q, d)
        b, h_k, d = key.shape
        key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d)
        q_pe, k_pe = _rope_forward_oot(self, positions, query, key, is_neox_style, offsets)
        q_pe, k_pe = torch.ops.vllm.npu_rotary_embedding(
            positions, query, key, self.cos_sin_cache, self.head_size, self.rotary_dim, is_neox_style
        )
        return q_pe, k_pe

@@ -26,8 +26,8 @@ from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
@triton.jit
def split_qkv_rmsnorm_rope_kernel(
    input_ptr,
    sin_ptr,
    cos_ptr,
    cos_sin_ptr,
    pos_ptr,
    q_ptr,
    k_ptr,
    v_ptr,
@@ -74,9 +74,11 @@ def split_qkv_rmsnorm_rope_kernel(
    else:
        normalized_values = (normalized_values * weight_values).to(tl.bfloat16)

    sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
    sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM)
    cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM)
    pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64)
    cos_offsets = pos_idx * HEAD_DIM + tl.arange(0, HALF_HEAD_DIM)
    sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM)
    cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM)
    sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM)
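    # Each row of cos_sin_cache stores the cos half followed by the sin half
    # (HALF_HEAD_DIM each), so a single row load indexed by pos_idx replaces
    # the host-side index_select + chunk that the caller used to perform.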
x1 = tl.extract_slice(
|
||||
normalized_values,
|
||||
offsets=(0, 0),
|
||||
@@ -89,22 +91,24 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
cat_x = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
|
||||
cat_x = tl.insert_slice(
|
||||
cat_x,
|
||||
-x2,
|
||||
roped_q1 = x1 * cos - x2 * sin
|
||||
roped_q2 = x2 * cos + x1 * sin
|
||||
|
||||
roped_q = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
|
||||
roped_q = tl.insert_slice(
|
||||
roped_q,
|
||||
roped_q1,
|
||||
offsets=(0, 0),
|
||||
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
cat_x = tl.insert_slice(
|
||||
cat_x,
|
||||
x1,
|
||||
roped_q = tl.insert_slice(
|
||||
roped_q,
|
||||
roped_q2,
|
||||
offsets=(0, HALF_HEAD_DIM),
|
||||
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
roped_q = cat_x * sin + normalized_values * cos
|
||||
tl.store(
|
||||
q_ptr + output_offset + col_indices,
|
||||
roped_q.reshape(Q_BLOCK_SIZE).to(q_ptr.dtype.element_ty),
|
||||
@@ -135,9 +139,12 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
normalized_values = (normalized_values * weight_values + bias_values).to(tl.bfloat16)
|
||||
else:
|
||||
normalized_values = (normalized_values * weight_values).to(tl.bfloat16)
|
||||
sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
|
||||
sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM)
|
||||
cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM)
|
||||
|
||||
pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64)
|
||||
cos_offsets = pos_idx * HEAD_DIM + tl.arange(0, HALF_HEAD_DIM)
|
||||
sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM)
|
||||
cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM)
|
||||
sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM)
|
||||
x1 = tl.extract_slice(
|
||||
normalized_values,
|
||||
offsets=(0, 0),
|
||||
@@ -150,23 +157,24 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
cat_x = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
|
||||
cat_x = tl.insert_slice(
|
||||
cat_x,
|
||||
-x2,
|
||||
roped_k1 = x1 * cos - x2 * sin
|
||||
roped_k2 = x2 * cos + x1 * sin
|
||||
|
||||
roped_k = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
|
||||
roped_k = tl.insert_slice(
|
||||
roped_k,
|
||||
roped_k1,
|
||||
offsets=(0, 0),
|
||||
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
cat_x = tl.insert_slice(
|
||||
cat_x,
|
||||
x1,
|
||||
roped_k = tl.insert_slice(
|
||||
roped_k,
|
||||
roped_k2,
|
||||
offsets=(0, HALF_HEAD_DIM),
|
||||
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
roped_k = cat_x * sin + normalized_values * cos
|
||||
|
||||
tl.store(
|
||||
k_ptr + output_offset + col_indices,
|
||||
roped_k.to(tl.bfloat16).reshape(KV_BLOCK_SIZE),
|
||||
@@ -188,8 +196,8 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
|
||||
def split_qkv_rmsnorm_rope_impl(
|
||||
input: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
q_hidden_size: int,
|
||||
@@ -216,8 +224,8 @@ def split_qkv_rmsnorm_rope_impl(
|
||||
|
||||
split_qkv_rmsnorm_rope_kernel[(n_rows, n_cols, 1)](
|
||||
input,
|
||||
sin,
|
||||
cos,
|
||||
cos_sin_cache,
|
||||
positions,
|
||||
q_output,
|
||||
k_output,
|
||||
v_output,
|
||||
@@ -241,8 +249,8 @@ def split_qkv_rmsnorm_rope_impl(
|
||||
|
||||
def split_qkv_rmsnorm_rope_impl_fake(
|
||||
input: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
q_hidden_size: int,
|
||||
|
||||
@@ -14,7 +14,6 @@
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import torch
from vllm.triton_utils import tl, triton

@@ -30,10 +29,13 @@ def _triton_rope(
    q_row_stride,
    k_ptr,
    k_row_stride,
    cos,
    cos_ptr,
    cos_row_stride,
    sin,
    sin_ptr,
    sin_row_stride,
    cos_sin_ptr,
    cos_sin_row_stride,
    pos_ptr,
    num_tokens,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
@@ -44,6 +46,7 @@ def _triton_rope(
    pad_rope_dim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    IS_NEOX_STYLE: tl.constexpr,
    USE_COS_SIN: tl.constexpr,
):
    """
    This triton kernel applies rotary embedding on q and k.
@@ -84,13 +87,19 @@ def _triton_rope(
    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
    # m of this program instance
    # ####################################################################
    cos_start_ptr = cos + row_idx * cos_row_stride
    sin_start_ptr = sin + row_idx * sin_row_stride

    cos_offsets = tl.arange(0, pad_rope_dim // 2)
    sin_offsets = tl.arange(pad_rope_dim // 2, pad_rope_dim)
    cos_mask = cos_offsets < (rope_dim // 2)
    cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
    sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
    if USE_COS_SIN:
        pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64)
        cos_start_ptr = cos_sin_ptr + pos_idx * cos_sin_row_stride
        cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
        sin_row = tl.load(cos_start_ptr + sin_offsets, mask=cos_mask, other=0).to(tl.float32)
    else:
        cos_start_ptr = cos_ptr + row_idx * cos_row_stride
        sin_start_ptr = sin_ptr + row_idx * sin_row_stride
        cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
        sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)

    # ####################################################################
    # Load the left and right half of q and k for the current
@@ -140,8 +149,10 @@ def _triton_rope(
|
||||
def rope_forward_triton(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
cos: torch.Tensor = None,
|
||||
sin: torch.Tensor = None,
|
||||
cos_sin_cache: torch.Tensor = None,
|
||||
positions: torch.Tensor = None,
|
||||
rope_dim: int = -1,
|
||||
is_neox_style: bool = True,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -152,12 +163,6 @@ def rope_forward_triton(
|
||||
|
||||
num_tokens, n_q_head, head_dim = q.shape
|
||||
n_kv_head = k.shape[1]
|
||||
cos = cos.view(num_tokens, -1)
|
||||
sin = sin.view(num_tokens, -1)
|
||||
if rope_dim == -1:
|
||||
# If rope_dim is not specified, we assume that input cos/sin is not
|
||||
# duplicated to rope_dim, which means rope_dim == cos.shape[-1] * 2
|
||||
rope_dim = cos.shape[-1] * 2
|
||||
assert rope_dim <= head_dim
|
||||
pad_rope_dim = triton.next_power_of_2(rope_dim)
|
||||
pad_n_q_head = triton.next_power_of_2(n_q_head)
|
||||
@@ -166,24 +171,69 @@ def rope_forward_triton(
|
||||
num_vectorcore = get_vectorcore_num()
|
||||
n_row = min(num_tokens, num_vectorcore)
|
||||
|
||||
_triton_rope[(n_row,)](
|
||||
q,
|
||||
q.stride(0),
|
||||
k,
|
||||
k.stride(0),
|
||||
cos,
|
||||
cos.stride(0),
|
||||
sin,
|
||||
sin.stride(0),
|
||||
num_tokens,
|
||||
n_q_head,
|
||||
n_kv_head,
|
||||
head_dim,
|
||||
rope_dim,
|
||||
pad_n_q_head,
|
||||
pad_n_kv_head,
|
||||
pad_rope_dim,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
IS_NEOX_STYLE=is_neox_style,
|
||||
)
|
||||
if cos_sin_cache is not None and positions is not None:
|
||||
assert positions.shape[0] == num_tokens
|
||||
_triton_rope[(n_row,)](
|
||||
q,
|
||||
q.stride(0),
|
||||
k,
|
||||
k.stride(0),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
cos_sin_cache,
|
||||
cos_sin_cache.stride(0),
|
||||
positions,
|
||||
num_tokens,
|
||||
n_q_head,
|
||||
n_kv_head,
|
||||
head_dim,
|
||||
rope_dim,
|
||||
pad_n_q_head,
|
||||
pad_n_kv_head,
|
||||
pad_rope_dim,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
IS_NEOX_STYLE=is_neox_style,
|
||||
USE_COS_SIN=True,
|
||||
)
|
||||
elif cos is not None and sin is not None:
|
||||
assert cos.shape[0] == num_tokens and sin.shape[0] == num_tokens
|
||||
cos = cos.view(num_tokens, -1)
|
||||
sin = sin.view(num_tokens, -1)
|
||||
if rope_dim == -1:
|
||||
# If rope_dim is not specified, we assume that input cos/sin is not
|
||||
# duplicated to rope_dim, which means rope_dim == cos.shape[-1] * 2
|
||||
rope_dim = cos.shape[-1] * 2
|
||||
_triton_rope[(n_row,)](
|
||||
q,
|
||||
q.stride(0),
|
||||
k,
|
||||
k.stride(0),
|
||||
cos,
|
||||
cos.stride(0),
|
||||
sin,
|
||||
sin.stride(0),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
num_tokens,
|
||||
n_q_head,
|
||||
n_kv_head,
|
||||
head_dim,
|
||||
rope_dim,
|
||||
pad_n_q_head,
|
||||
pad_n_kv_head,
|
||||
pad_rope_dim,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
IS_NEOX_STYLE=is_neox_style,
|
||||
USE_COS_SIN=False,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Currently, rope_forward_triton supports passing:\n"
|
||||
"1. positions and original cos_sin_cache.\n"
|
||||
"2. cos and sin which are already selected by positions\n"
|
||||
"Please check whether you call rope_forward_triton correctly."
|
||||
)
|
||||
return q, k
|
||||
|
||||
@@ -39,7 +39,6 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper, update_full_graph_params
from vllm_ascend.ops.rotary_embedding import update_cos_sin
from vllm_ascend.ops.triton.spec_decode.utils import prepare_inputs_padded_kernel
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
from vllm_ascend.utils import enable_sp, lmhead_tp_enable, shared_expert_dp_enabled, vllm_version_is

@@ -299,9 +298,6 @@ class EagleProposer(VllmEagleProposer):
            _,
        ) = self.runner._sync_metadata_across_dp(num_tokens, is_draft_model=True)

        # update global cos, sin
        update_cos_sin(self._get_positions(num_tokens))

        multi_steps_attn_metadata = []
        if not self.use_cuda_graph:
            aclgraph_runtime_mode = CUDAGraphMode.NONE

@@ -471,9 +467,6 @@ class EagleProposer(VllmEagleProposer):
        builder = self.runner.attn_groups[0][0].get_metadata_builder()
        attn_metadata = builder.build(0, common_attn_metadata, self.runner.get_model())

        # update global cos, sin
        update_cos_sin(self._get_positions(num_input_tokens))

        if self.uses_mrope:
            used_update_positions = target_positions[:, last_token_indices]
        else:

@@ -651,9 +644,6 @@ class EagleProposer(VllmEagleProposer):
        input_ids = self.input_ids[:input_batch_size]
        inputs_embeds = None

        # update global cos, sin
        update_cos_sin(self._get_positions(input_batch_size))

        # Run the model.

        # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all

@@ -1197,8 +1197,10 @@ class NPUModelRunner(GPUModelRunner):
            model_kwargs,
            ec_connector_output,
        ) = self._preprocess(scheduler_output, num_tokens_padded, intermediate_tensors)

        # update global cos, sin
        update_cos_sin(positions)

        # Set cudagraph mode to none if calc_kv_scales is true.
        # KV scales calculation involves dynamic operations that are incompatible
        # with CUDA graph capture.