diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_rope.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_rope.py index 6af45a68..7e68bf96 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_rope.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_rope.py @@ -8,6 +8,7 @@ from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton IS_NEOX_STYLE = [True, False] DTYPES = [torch.bfloat16, torch.float16] +MAX_POSITION_EMBEDDINGS = [262144] HEAD_SIZES = [64, 128] ROTARY_DIMS = [32, 64] NUM_Q_HEADS = [64] @@ -139,3 +140,83 @@ def test_rotary_embedding_triton_kernel( gc.collect() torch.npu.empty_cache() torch.npu.reset_peak_memory_stats() + + + +@pytest.mark.parametrize("max_position_embeddings", MAX_POSITION_EMBEDDINGS) +@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_q_heads", NUM_Q_HEADS) +@pytest.mark.parametrize("num_k_heads", NUM_K_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) +@torch.inference_mode() +def test_rotary_embedding_triton_kernel_with_cos_sin_cache( + max_position_embeddings: int, + is_neox_style: bool, + num_tokens: int, + num_q_heads: int, + num_k_heads: int, + head_size: int, + rotary_dim: int, + dtype: torch.dtype, + seed: int, + device: str, +) -> None: + torch.manual_seed(seed) + torch.set_default_device(device) + init_device_properties_triton() + if rotary_dim == -1: + rotary_dim = head_size + cos_sin_cache = torch.randn(max_position_embeddings, rotary_dim, dtype=dtype, device=device) + positions = torch.randint(low=0, high=max_position_embeddings, size=(num_tokens,), dtype=torch.int64, device=device) + q_trt = torch.randn(num_tokens, + num_q_heads, + head_size, + dtype=dtype, + device=device) + k_trt = torch.randn(num_tokens, + num_k_heads, + head_size, + dtype=dtype, + device=device) + q_gold = torch.randn(num_tokens, + num_q_heads, + head_size, + dtype=dtype, + device=device) + k_gold = torch.randn(num_tokens, + num_k_heads, + head_size, + dtype=dtype, + device=device) + q_trt.copy_(q_gold) + k_trt.copy_(k_gold) + q_trt, k_trt = rope_forward_triton(q_trt, + k_trt, + cos_sin_cache=cos_sin_cache, + positions=positions, + rope_dim=rotary_dim, + is_neox_style=is_neox_style) + cos, sin = cos_sin_cache.index_select(0, positions).chunk(2, dim=-1) + q_gold, k_gold = _rope_pytorch_native(q_gold, + k_gold, + cos, + sin, + rope_dim=rotary_dim, + is_neox_style=is_neox_style) + # Compare the results. + torch.testing.assert_close(q_trt.view(q_gold.size()), + q_gold, + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) + torch.testing.assert_close(k_trt.view(k_gold.size()), + k_gold, + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_rmsnorm_rope.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_rmsnorm_rope.py index 9a3b2524..7336e515 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_rmsnorm_rope.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_rmsnorm_rope.py @@ -7,6 +7,7 @@ import torch import vllm_ascend.ops.register_custom_ops # noqa from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton +MAX_POSITION_EMBEDDINGS = [262144] NUM_TOKENS = [1, 4, 8, 16, 1024] NUM_QKV_HEADS = [(12, 1), (16, 1), (32, 4), (64, 4)] HEAD_SIZES = [128] @@ -55,6 +56,7 @@ def rms_norm( return out +@pytest.mark.parametrize("max_position_embeddings", MAX_POSITION_EMBEDDINGS) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -63,7 +65,7 @@ def rms_norm( @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", DEVICES) @torch.inference_mode() -def test_split_qkv_rmsnorm_rope(num_tokens, num_q_heads, num_kv_heads, +def test_split_qkv_rmsnorm_rope(max_position_embeddings, num_tokens, num_q_heads, num_kv_heads, head_size, eps, dtype, seed, device): torch.manual_seed(seed) torch.set_default_device(device) @@ -77,12 +79,10 @@ def test_split_qkv_rmsnorm_rope(num_tokens, num_q_heads, num_kv_heads, device=device) q_weight = torch.randn(head_size, dtype=dtype, device=device) k_weight = torch.randn(head_size, dtype=dtype, device=device) - sin = torch.from_numpy( + cos_sin_cache = torch.from_numpy( np.random.uniform(0, 1, - [num_tokens, 1, 1, head_size])).to(dtype).npu() - cos = torch.from_numpy( - np.random.uniform(0, 1, - [num_tokens, 1, 1, head_size])).to(dtype).npu() + [max_position_embeddings, head_size])).to(dtype).npu() + positions = torch.randint(low=0, high=max_position_embeddings, size=(num_tokens,), dtype=torch.int64, device=device) # fused kernel q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv, q_weight=q_weight, @@ -91,8 +91,12 @@ def test_split_qkv_rmsnorm_rope(num_tokens, num_q_heads, num_kv_heads, kv_hidden_size=kv_hidden_size, head_dim=head_size, eps=eps, - cos=cos, - sin=sin) + cos_sin_cache=cos_sin_cache, + positions=positions) + + cos, sin = cos_sin_cache.index_select(0, positions).view(num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2) + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) # split _q, _k, v_gold = qkv.cpu().split( @@ -129,6 +133,7 @@ def test_split_qkv_rmsnorm_rope(num_tokens, num_q_heads, num_kv_heads, torch.npu.reset_peak_memory_stats() +@pytest.mark.parametrize("max_position_embeddings", MAX_POSITION_EMBEDDINGS) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -137,7 +142,7 @@ def test_split_qkv_rmsnorm_rope(num_tokens, num_q_heads, num_kv_heads, @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", DEVICES) @torch.inference_mode() -def test_split_qkv_rmsnorm_rope_with_bias(num_tokens, num_q_heads, +def test_split_qkv_rmsnorm_rope_with_bias(max_position_embeddings, num_tokens, num_q_heads, num_kv_heads, head_size, eps, dtype, seed, device): torch.manual_seed(seed) @@ -154,12 +159,10 @@ def test_split_qkv_rmsnorm_rope_with_bias(num_tokens, num_q_heads, k_weight = torch.randn(head_size, dtype=dtype, device=device) q_bias = torch.randn(head_size, dtype=dtype, device=device) k_bias = torch.randn(head_size, dtype=dtype, device=device) - sin = torch.from_numpy( + cos_sin_cache = torch.from_numpy( np.random.uniform(0, 1, - [num_tokens, 1, 1, head_size])).to(dtype).npu() - cos = torch.from_numpy( - np.random.uniform(0, 1, - [num_tokens, 1, 1, head_size])).to(dtype).npu() + [max_position_embeddings, head_size])).to(dtype).npu() + positions = torch.randint(low=0, high=max_position_embeddings, size=(num_tokens,), dtype=torch.int64, device=device) # fused kernel q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv, q_weight=q_weight, @@ -170,8 +173,12 @@ def test_split_qkv_rmsnorm_rope_with_bias(num_tokens, num_q_heads, eps=eps, q_bias=q_bias, k_bias=k_bias, - cos=cos, - sin=sin) + cos_sin_cache=cos_sin_cache, + positions=positions) + + cos, sin = cos_sin_cache.index_select(0, positions).view(num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2) + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) # split _q, _k, v_gold = qkv.cpu().split( diff --git a/tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py b/tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py index f696ffbc..51cb76fa 100644 --- a/tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py +++ b/tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py @@ -60,7 +60,7 @@ class ModelQKNormRopeWithoutBias(nn.Module): self.q_weight = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device)) self.k_weight = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device)) - def forward(self, qkv, cos, sin): + def forward(self, qkv, cos_sin_cache, positions): """ Args: qkv: [T, q_size + 2*kv_size] @@ -82,13 +82,12 @@ class ModelQKNormRopeWithoutBias(nn.Module): # Reshape for RoPE: [T, num_heads, head_dim] -> [1, T, num_heads, head_dim] q_flat = q_norm_out.view(q.shape) - q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim) - k_flat = k_norm_out.view(k.shape) - k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim) # Apply RoPE - q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin) + q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding( + positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True + ) return q_rope, k_rope, v @@ -116,7 +115,7 @@ class ModelQKNormRopeWithBias(nn.Module): self.q_bias = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device)) self.k_bias = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device)) - def forward(self, qkv, cos, sin): + def forward(self, qkv, cos_sin_cache, positions): # Split QKV q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -132,13 +131,12 @@ class ModelQKNormRopeWithBias(nn.Module): # Reshape for RoPE q_flat = q_normed.view(q.shape) - q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim) - k_flat = k_normed.view(k.shape) - k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim) # Apply RoPE - q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin) + q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding( + positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True + ) return q_rope, k_rope, v @@ -147,7 +145,7 @@ def assert_qknorm_rope_fusion(after_gm, expect_fused=True, use_bias=False): check_rules = [ (torch.ops.vllm.qkv_rmsnorm_rope.default, expect_fused), (torch.ops.npu.npu_rms_norm.default, not expect_fused), - (torch.ops.npu.npu_apply_rotary_pos_emb.default, not expect_fused), + (torch.ops.vllm.npu_rotary_embedding.default, not expect_fused), ] if use_bias: check_rules.append((torch.ops.aten.add.Tensor, not expect_fused)) diff --git a/tests/e2e/singlecard/test_aclgraph_accuracy.py b/tests/e2e/singlecard/test_aclgraph_accuracy.py index 90266933..1ebd7b21 100644 --- a/tests/e2e/singlecard/test_aclgraph_accuracy.py +++ b/tests/e2e/singlecard/test_aclgraph_accuracy.py @@ -25,10 +25,10 @@ CASE_QWEN_ACLGRAPH = LLMTestCase( model="Qwen/Qwen3-0.6B", prompts=PROMPTS_SHORT, golden_answers=[ - " Lina. I'm a 22-year-old student from China. I'm interested in studying in the US. I want to know if there are any", + " Lina. I'm a 22-year-old student from China. I'm interested in studying in the US. I'm looking for a job in the", ' the same as the president of the United Nations. This is because the president of the United States is the same as the president of the United Nations. The president', ' Paris. The capital of France is also the capital of the Republic of France. The capital of France is also the capital of the European Union. The capital of', - ' not just a technological frontier but a profound transformation of how we live, work, and interact with the world. As we stand at the intersection of artificial intelligence and' + ' not just a technological challenge but a profound transformation of how we live, work, and interact with the world. As we stand at the intersection of artificial intelligence and' ], ) @@ -48,10 +48,10 @@ CASE_QWEN_FULL = LLMTestCase( model="Qwen/Qwen3-0.6B", prompts=PROMPTS_SHORT, golden_answers=[ - " Lina. I'm a 22-year-old student from China. I'm interested in studying in the US. I want to know if there are any", + " Lina. I'm a 22-year-old student from China. I'm interested in studying in the US. I'm looking for a job in the", ' the same as the president of the United Nations. This is because the president of the United States is the same as the president of the United Nations. The president', ' Paris. The capital of France is also the capital of the Republic of France. The capital of France is also the capital of the European Union. The capital of', - ' not just a technological frontier but a profound transformation of how we live, work, and interact with the world. As we stand at the intersection of artificial intelligence and' + ' not just a technological challenge but a profound transformation of how we live, work, and interact with the world. As we stand at the intersection of artificial intelligence and' ], ) @@ -72,8 +72,8 @@ CASE_QWEN_FULL_DECODE_ONLY = LLMTestCase( prompts=PROMPTS_LONG, golden_answers=[ ' \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the', - " \n\nTo solve this problem, we can use the following approach: Let $P$ be the perimeter of the square. Then, the expected value of the area", - ' \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations $x^2 +' + " \n\nTo solve this problem, we can use the fact that the expected value of the area of a triangle with vertices on a square can be calculated by integrating over", + ' \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can' ]) CASE_DS_FULL_DECODE_ONLY = LLMTestCase( @@ -91,8 +91,8 @@ CASE_QWEN_EX = LLMTestCase( prompts=PROMPTS_LONG, golden_answers=[ ' \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the', - " \n\nTo solve this problem, we can use the following approach: Let $P$ be the perimeter of the square. Then, the expected value of the area", - ' \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations $x^2 +' + " \n\nTo solve this problem, we can use the fact that the expected value of the area of a triangle with vertices on a square can be calculated by integrating over", + ' \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can' ]) CASE_DS_EX = LLMTestCase(model="vllm-ascend/DeepSeek-V2-Lite-W8A8", diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index f28bb0f4..dcd53535 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -164,8 +164,6 @@ def set_ascend_forward_context( _mc2_tokens_capacity: int | None = None _reserved_mc2_mask: torch.Tensor | None = None -_sin: torch.Tensor | None = None -_cos: torch.Tensor | None = None def set_mc2_tokens_capacity(vllm_config, max_num_reqs, uniform_decode_query_len): diff --git a/vllm_ascend/compilation/npugraph_ex_passes/graphex_qknorm_rope_fusion_pass.py b/vllm_ascend/compilation/npugraph_ex_passes/graphex_qknorm_rope_fusion_pass.py index 984a0579..bdafed32 100644 --- a/vllm_ascend/compilation/npugraph_ex_passes/graphex_qknorm_rope_fusion_pass.py +++ b/vllm_ascend/compilation/npugraph_ex_passes/graphex_qknorm_rope_fusion_pass.py @@ -47,17 +47,21 @@ class GraphEXQKNormRopeFusionPattern: def get_inputs(self): T = 5 + max_position_embeddings = 16384 qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu") q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") - cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu") - sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu") - return [qkv, q_weight, k_weight, cos, sin] + cos_sin_cache = torch.empty(max_position_embeddings, self.head_dim, dtype=torch.bfloat16, device="npu") + positions = torch.ones(T, dtype=torch.int64, device="npu") + return [qkv, q_weight, k_weight, cos_sin_cache, positions] - # The replacement registered here will be actually executed after AOT. def register(self): def pattern( - qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, ): q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -68,17 +72,19 @@ class GraphEXQKNormRopeFusionPattern: k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps) q_flat = q_norm_out.view(q.shape) - q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim) - k_flat = k_norm_out.view(k.shape) - k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim) - - q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin) + q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding( + positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True + ) return q_rope, k_rope, v def replacement( - qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, ): results = torch.ops.vllm.qkv_rmsnorm_rope( input=qkv, @@ -90,8 +96,8 @@ class GraphEXQKNormRopeFusionPattern: eps=self.eps, q_bias=None, k_bias=None, - sin=sin, - cos=cos, + cos_sin_cache=cos_sin_cache, + positions=positions, ) return results @@ -117,16 +123,17 @@ class GraphEXQKNormRopeFusionPatternWithBias: def get_inputs(self): T = 5 + max_position_embeddings = 16384 qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu") q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") - cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu") - sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu") - return [qkv, q_weight, k_weight, q_bias, k_bias, cos, sin] + cos_sin_cache = torch.empty(max_position_embeddings, self.head_dim, dtype=torch.bfloat16, device="npu") + positions = torch.ones(T, dtype=torch.int64, device="npu") + + return [qkv, q_weight, k_weight, q_bias, k_bias, cos_sin_cache, positions] - # The replacement registered here will be actually executed after AOT. def register(self): def pattern( qkv: torch.Tensor, @@ -134,8 +141,8 @@ class GraphEXQKNormRopeFusionPatternWithBias: k_weight: torch.Tensor, q_bias: torch.Tensor, k_bias: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, ): q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -148,12 +155,10 @@ class GraphEXQKNormRopeFusionPatternWithBias: k_normed = k_norm_out + k_bias q_flat = q_normed.view(q.shape) - q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim) - k_flat = k_normed.view(k.shape) - k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim) - - q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin) + q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding( + positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True + ) return q_rope, k_rope, v @@ -163,8 +168,8 @@ class GraphEXQKNormRopeFusionPatternWithBias: k_weight: torch.Tensor, q_bias: torch.Tensor, k_bias: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, ): results = torch.ops.vllm.qkv_rmsnorm_rope( input=qkv, @@ -176,10 +181,9 @@ class GraphEXQKNormRopeFusionPatternWithBias: eps=self.eps, q_bias=q_bias, k_bias=k_bias, - sin=sin, - cos=cos, + cos_sin_cache=cos_sin_cache, + positions=positions, ) - return results torchair.register_replacement( @@ -197,7 +201,7 @@ class GraphEXQKNormRopeFusionPass: def __init__(self, vllm_config: VllmConfig): dtype = vllm_config.model_config.dtype - if dtype not in (torch.bfloat16, torch.float16): + if dtype not in (torch.bfloat16,): logger.debug("QKNorm and Rope fusion not enabled: unsupported dtype %s", dtype) return # use one attn layer to get meta (such as head_dim) for QKNormRopeFusionPattern diff --git a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py index f03fc3c1..7eec4d92 100644 --- a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py +++ b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py @@ -45,16 +45,21 @@ class QKNormRopeFusionPattern: def get_inputs(self): T = 5 + max_position_embeddings = 16384 qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu") q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") - cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu") - sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu") - return [qkv, q_weight, k_weight, cos, sin] + cos_sin_cache = torch.empty(max_position_embeddings, self.head_dim, dtype=torch.bfloat16, device="npu") + positions = torch.ones(T, dtype=torch.int64, device="npu") + return [qkv, q_weight, k_weight, cos_sin_cache, positions] def register(self, pm_pass: PatternMatcherPass): def pattern( - qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, ): q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -65,17 +70,19 @@ class QKNormRopeFusionPattern: k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps) q_flat = q_norm_out.view(q.shape) - q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim) - k_flat = k_norm_out.view(k.shape) - k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim) - - q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin) + q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding( + positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True + ) return q_rope, k_rope, v def replacement( - qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, ): results = torch.ops.vllm.qkv_rmsnorm_rope( input=qkv, @@ -87,8 +94,8 @@ class QKNormRopeFusionPattern: eps=self.eps, q_bias=None, k_bias=None, - sin=sin, - cos=cos, + cos_sin_cache=cos_sin_cache, + positions=positions, ) return results @@ -109,15 +116,16 @@ class QKNormRopeFusionPatternWithBias: def get_inputs(self): T = 5 + max_position_embeddings = 16384 qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu") q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") - cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu") - sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu") + cos_sin_cache = torch.empty(max_position_embeddings, self.head_dim, dtype=torch.bfloat16, device="npu") + positions = torch.ones(T, dtype=torch.int64, device="npu") - return [qkv, q_weight, k_weight, q_bias, k_bias, cos, sin] + return [qkv, q_weight, k_weight, q_bias, k_bias, cos_sin_cache, positions] def register(self, pm_pass: PatternMatcherPass): def pattern( @@ -126,8 +134,8 @@ class QKNormRopeFusionPatternWithBias: k_weight: torch.Tensor, q_bias: torch.Tensor, k_bias: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, ): q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -140,12 +148,10 @@ class QKNormRopeFusionPatternWithBias: k_normed = k_norm_out + k_bias q_flat = q_normed.view(q.shape) - q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim) - k_flat = k_normed.view(k.shape) - k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim) - - q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin) + q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding( + positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True + ) return q_rope, k_rope, v @@ -155,8 +161,8 @@ class QKNormRopeFusionPatternWithBias: k_weight: torch.Tensor, q_bias: torch.Tensor, k_bias: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, ): results = torch.ops.vllm.qkv_rmsnorm_rope( input=qkv, @@ -168,8 +174,8 @@ class QKNormRopeFusionPatternWithBias: eps=self.eps, q_bias=q_bias, k_bias=k_bias, - cos=cos, - sin=sin, + cos_sin_cache=cos_sin_cache, + positions=positions, ) return results @@ -186,7 +192,7 @@ class QKNormRopeFusionPass(VllmInductorPass): self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="qknorm_rope_fusion_pass") dtype = vllm_config.model_config.dtype - if dtype not in (torch.bfloat16, torch.float16): + if dtype not in (torch.bfloat16,): logger.debug("QKNorm and Rope fusion not enabled: unsupported dtype %s", dtype) return diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py index 7dcebc3b..f9404f0b 100644 --- a/vllm_ascend/ops/register_custom_ops.py +++ b/vllm_ascend/ops/register_custom_ops.py @@ -14,7 +14,7 @@ from vllm.forward_context import get_forward_context from vllm.utils.torch_utils import direct_register_custom_op from vllm_ascend.ascend_forward_context import MoECommType -from vllm_ascend.ops.triton.rope import rope_forward_triton +from vllm_ascend.ops.rotary_embedding import rope_forward_oot from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.utils import npu_stream_switch, prefetch_stream @@ -188,15 +188,16 @@ def _quantize_impl_fake( return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal, input_offset, torch.qint8, -1, False) -def _rope_forward_triton_fake( - q: torch.Tensor, - k: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - rope_dim: int = -1, +def _rope_forward_oot_impl_fake( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + cos_sin_cache: torch.Tensor, + head_dim: int, + rotary_dim: int, is_neox_style: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: - return torch.empty_like(q), torch.empty_like(k) + return query, key direct_register_custom_op( @@ -262,10 +263,11 @@ direct_register_custom_op( mutates_args=[], dispatch_key="PrivateUse1", ) + direct_register_custom_op( - op_name="rope_forward_triton", - op_func=rope_forward_triton, - fake_impl=_rope_forward_triton_fake, + op_name="npu_rotary_embedding", + op_func=rope_forward_oot, + fake_impl=_rope_forward_oot_impl_fake, mutates_args=[], dispatch_key="PrivateUse1", ) diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 15bda163..ec5aa665 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -29,11 +29,13 @@ from vllm.model_executor.layers.rotary_embedding import ( from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb from vllm.triton_utils import HAS_TRITON +from vllm_ascend.platform import NPUPlatform +from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, has_rope, is_vl_model + if HAS_TRITON: from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope -from vllm_ascend.platform import NPUPlatform -from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, has_rope, is_vl_model + from vllm_ascend.ops.triton.rope import rope_forward_triton # Currently, rope ops used on npu requires detached cos && sin as inputs. # However, RotaryEmbedding in vllm use cos_sin_cache as a whole variable. @@ -146,54 +148,38 @@ def get_cos_and_sin_slice(): return _cos_slice, _sin_slice -def _rope_forward_oot( - self, +def rope_forward_oot( positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, + cos_sin_cache: torch.Tensor, + head_size: int, + rotary_dim: int, is_neox_style: bool, offsets: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: query_shape, key_shape = query.shape, key.shape - if self.cos_sin_cache.device != query.device: - self.cos_sin_cache = self.cos_sin_cache.to(query.device) - if self.cos_sin_cache.dtype != query.dtype: - self.cos_sin_cache = self.cos_sin_cache.to(query.dtype) - cos, sin = get_cos_and_sin_slice() if offsets is not None: raise NotImplementedError("Batched rotary embedding is currently not supported on NPU.") - if ( - is_neox_style - and self.head_size == 128 - and self.cos_sin_cache.shape[-1] == 128 - and cos is not None - and sin is not None - ): - # If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation. - # This method requires head_size and rotary_dim equal 128 and neox_style is True - query = query.contiguous().view(1, query.shape[0], -1, self.head_size) - key = key.contiguous().view(1, key.shape[0], -1, self.head_size) - # Although this function modifies in-place, please retain the function's return value. - # Otherwise, the graph fusion operation may fail. - query, key = torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin) - elif self.rotary_dim < self.head_size: - if HAS_TRITON: - cos = cos.view(-1, self.rotary_dim) - sin = sin.view(-1, self.rotary_dim) - q = query.contiguous().view(query.shape[0], -1, self.head_size) - k = key.contiguous().view(key.shape[0], -1, self.head_size) - query, key = torch.ops.vllm.rope_forward_triton( - q, k, cos, sin, rope_dim=self.rotary_dim, is_neox_style=True - ) - return query.view(query_shape), key.view(key_shape) - else: + if HAS_TRITON: + num_tokens = query.shape[0] + query, key = rope_forward_triton( + query.view(num_tokens, -1, head_size), + key.view(num_tokens, -1, head_size), + cos_sin_cache=cos_sin_cache, + positions=positions, + rope_dim=rotary_dim, + is_neox_style=is_neox_style, + ) + else: + if rotary_dim < head_size: num_tokens = query.shape[0] - query = query.view(num_tokens, -1, self.head_size) - key = key.view(num_tokens, -1, self.head_size) - q_rot = query[..., : self.rotary_dim] - q_pass = query[..., self.rotary_dim :] - k_rot = key[..., : self.rotary_dim] - k_pass = key[..., self.rotary_dim :] + query = query.view(num_tokens, -1, head_size) + key = key.view(num_tokens, -1, head_size) + q_rot = query[..., :rotary_dim] + q_pass = query[..., rotary_dim:] + k_rot = key[..., :rotary_dim] + k_pass = key[..., rotary_dim:] q_rot = q_rot.contiguous().view(num_tokens, -1) k_rot = k_rot.contiguous().view(num_tokens, -1) # only the rotary part is processed here, @@ -202,27 +188,26 @@ def _rope_forward_oot( positions, q_rot, k_rot, - self.rotary_dim, - self.cos_sin_cache, + rotary_dim, + cos_sin_cache, + is_neox_style, + ) + q_rot = q_rot.view(num_tokens, -1, rotary_dim) + k_rot = k_rot.view(num_tokens, -1, rotary_dim) + query = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape) + key = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape) + else: + # TODO: Remove the contiguous in the future. + query = query.contiguous().view(query.shape[0], -1) + key = key.contiguous().view(key.shape[0], -1) + torch_npu._npu_rotary_embedding( + positions, + query, + key, + head_size, + cos_sin_cache, is_neox_style, ) - q_rot = q_rot.view(num_tokens, -1, self.rotary_dim) - k_rot = k_rot.view(num_tokens, -1, self.rotary_dim) - q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape) - k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape) - return q, k - else: - # TODO: Remove the contiguous in the future. - query = query.contiguous().view(query.shape[0], -1) - key = key.contiguous().view(key.shape[0], -1) - torch_npu._npu_rotary_embedding( - positions, - query, - key, - self.head_size, - self.cos_sin_cache, - is_neox_style, - ) return query.view(query_shape), key.view(key_shape) @@ -251,7 +236,9 @@ class AscendRotaryEmbedding(RotaryEmbedding): is_neox_style = self.is_neox_style if is_neox_style_override is not None: is_neox_style = is_neox_style_override - return _rope_forward_oot(self, positions, query, key, is_neox_style, offsets) + return torch.ops.vllm.npu_rotary_embedding( + positions, query, key, self.cos_sin_cache, self.head_size, self.rotary_dim, is_neox_style + ) class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding): @@ -460,7 +447,9 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding): query = query.view(b, h_q, d // 2, 2).transpose(3, 2).reshape(b, h_q, d) b, h_k, d = key.shape key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d) - q_pe, k_pe = _rope_forward_oot(self, positions, query, key, is_neox_style, offsets) + q_pe, k_pe = torch.ops.vllm.npu_rotary_embedding( + positions, query, key, self.cos_sin_cache, self.head_size, self.rotary_dim, is_neox_style + ) return q_pe, k_pe diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py index 14bae3b7..18135aa7 100644 --- a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py +++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py @@ -26,8 +26,8 @@ from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num @triton.jit def split_qkv_rmsnorm_rope_kernel( input_ptr, - sin_ptr, - cos_ptr, + cos_sin_ptr, + pos_ptr, q_ptr, k_ptr, v_ptr, @@ -74,9 +74,11 @@ def split_qkv_rmsnorm_rope_kernel( else: normalized_values = (normalized_values * weight_values).to(tl.bfloat16) - sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM) - sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM) - cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM) + pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64) + cos_offsets = pos_idx * HEAD_DIM + tl.arange(0, HALF_HEAD_DIM) + sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM) + cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM) + sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM) x1 = tl.extract_slice( normalized_values, offsets=(0, 0), @@ -89,22 +91,24 @@ def split_qkv_rmsnorm_rope_kernel( sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - cat_x = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16) - cat_x = tl.insert_slice( - cat_x, - -x2, + roped_q1 = x1 * cos - x2 * sin + roped_q2 = x2 * cos + x1 * sin + + roped_q = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16) + roped_q = tl.insert_slice( + roped_q, + roped_q1, offsets=(0, 0), sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - cat_x = tl.insert_slice( - cat_x, - x1, + roped_q = tl.insert_slice( + roped_q, + roped_q2, offsets=(0, HALF_HEAD_DIM), sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - roped_q = cat_x * sin + normalized_values * cos tl.store( q_ptr + output_offset + col_indices, roped_q.reshape(Q_BLOCK_SIZE).to(q_ptr.dtype.element_ty), @@ -135,9 +139,12 @@ def split_qkv_rmsnorm_rope_kernel( normalized_values = (normalized_values * weight_values + bias_values).to(tl.bfloat16) else: normalized_values = (normalized_values * weight_values).to(tl.bfloat16) - sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM) - sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM) - cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM) + + pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64) + cos_offsets = pos_idx * HEAD_DIM + tl.arange(0, HALF_HEAD_DIM) + sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM) + cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM) + sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM) x1 = tl.extract_slice( normalized_values, offsets=(0, 0), @@ -150,23 +157,24 @@ def split_qkv_rmsnorm_rope_kernel( sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - cat_x = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16) - cat_x = tl.insert_slice( - cat_x, - -x2, + roped_k1 = x1 * cos - x2 * sin + roped_k2 = x2 * cos + x1 * sin + + roped_k = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16) + roped_k = tl.insert_slice( + roped_k, + roped_k1, offsets=(0, 0), sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - cat_x = tl.insert_slice( - cat_x, - x1, + roped_k = tl.insert_slice( + roped_k, + roped_k2, offsets=(0, HALF_HEAD_DIM), sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - roped_k = cat_x * sin + normalized_values * cos - tl.store( k_ptr + output_offset + col_indices, roped_k.to(tl.bfloat16).reshape(KV_BLOCK_SIZE), @@ -188,8 +196,8 @@ def split_qkv_rmsnorm_rope_kernel( def split_qkv_rmsnorm_rope_impl( input: torch.Tensor, - sin: torch.Tensor, - cos: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, q_hidden_size: int, @@ -216,8 +224,8 @@ def split_qkv_rmsnorm_rope_impl( split_qkv_rmsnorm_rope_kernel[(n_rows, n_cols, 1)]( input, - sin, - cos, + cos_sin_cache, + positions, q_output, k_output, v_output, @@ -241,8 +249,8 @@ def split_qkv_rmsnorm_rope_impl( def split_qkv_rmsnorm_rope_impl_fake( input: torch.Tensor, - sin: torch.Tensor, - cos: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, q_hidden_size: int, diff --git a/vllm_ascend/ops/triton/rope.py b/vllm_ascend/ops/triton/rope.py index 9e61e051..ad863e40 100644 --- a/vllm_ascend/ops/triton/rope.py +++ b/vllm_ascend/ops/triton/rope.py @@ -14,7 +14,6 @@ # limitations under the License. # This file is a part of the vllm-ascend project. # - import torch from vllm.triton_utils import tl, triton @@ -30,10 +29,13 @@ def _triton_rope( q_row_stride, k_ptr, k_row_stride, - cos, + cos_ptr, cos_row_stride, - sin, + sin_ptr, sin_row_stride, + cos_sin_ptr, + cos_sin_row_stride, + pos_ptr, num_tokens, n_qh: tl.constexpr, n_kh: tl.constexpr, @@ -44,6 +46,7 @@ def _triton_rope( pad_rope_dim: tl.constexpr, BLOCK_SIZE: tl.constexpr, IS_NEOX_STYLE: tl.constexpr, + USE_COS_SIN: tl.constexpr, ): """ This triton kernel applies rotary embedding on q and k. @@ -84,13 +87,19 @@ def _triton_rope( # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position # m of this program instance # #################################################################### - cos_start_ptr = cos + row_idx * cos_row_stride - sin_start_ptr = sin + row_idx * sin_row_stride - cos_offsets = tl.arange(0, pad_rope_dim // 2) + sin_offsets = tl.arange(pad_rope_dim // 2, pad_rope_dim) cos_mask = cos_offsets < (rope_dim // 2) - cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32) - sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32) + if USE_COS_SIN: + pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64) + cos_start_ptr = cos_sin_ptr + pos_idx * cos_sin_row_stride + cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32) + sin_row = tl.load(cos_start_ptr + sin_offsets, mask=cos_mask, other=0).to(tl.float32) + else: + cos_start_ptr = cos_ptr + row_idx * cos_row_stride + sin_start_ptr = sin_ptr + row_idx * sin_row_stride + cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32) + sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32) # #################################################################### # Load the left and right half of q and k for the current @@ -140,8 +149,10 @@ def _triton_rope( def rope_forward_triton( q: torch.Tensor, k: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, + cos: torch.Tensor = None, + sin: torch.Tensor = None, + cos_sin_cache: torch.Tensor = None, + positions: torch.Tensor = None, rope_dim: int = -1, is_neox_style: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -152,12 +163,6 @@ def rope_forward_triton( num_tokens, n_q_head, head_dim = q.shape n_kv_head = k.shape[1] - cos = cos.view(num_tokens, -1) - sin = sin.view(num_tokens, -1) - if rope_dim == -1: - # If rope_dim is not specified, we assume that input cos/sin is not - # duplicated to rope_dim, which means rope_dim == cos.shape[-1] * 2 - rope_dim = cos.shape[-1] * 2 assert rope_dim <= head_dim pad_rope_dim = triton.next_power_of_2(rope_dim) pad_n_q_head = triton.next_power_of_2(n_q_head) @@ -166,24 +171,69 @@ def rope_forward_triton( num_vectorcore = get_vectorcore_num() n_row = min(num_tokens, num_vectorcore) - _triton_rope[(n_row,)]( - q, - q.stride(0), - k, - k.stride(0), - cos, - cos.stride(0), - sin, - sin.stride(0), - num_tokens, - n_q_head, - n_kv_head, - head_dim, - rope_dim, - pad_n_q_head, - pad_n_kv_head, - pad_rope_dim, - BLOCK_SIZE=BLOCK_SIZE, - IS_NEOX_STYLE=is_neox_style, - ) + if cos_sin_cache is not None and positions is not None: + assert positions.shape[0] == num_tokens + _triton_rope[(n_row,)]( + q, + q.stride(0), + k, + k.stride(0), + None, + None, + None, + None, + cos_sin_cache, + cos_sin_cache.stride(0), + positions, + num_tokens, + n_q_head, + n_kv_head, + head_dim, + rope_dim, + pad_n_q_head, + pad_n_kv_head, + pad_rope_dim, + BLOCK_SIZE=BLOCK_SIZE, + IS_NEOX_STYLE=is_neox_style, + USE_COS_SIN=True, + ) + elif cos is not None and sin is not None: + assert cos.shape[0] == num_tokens and sin.shape[0] == num_tokens + cos = cos.view(num_tokens, -1) + sin = sin.view(num_tokens, -1) + if rope_dim == -1: + # If rope_dim is not specified, we assume that input cos/sin is not + # duplicated to rope_dim, which means rope_dim == cos.shape[-1] * 2 + rope_dim = cos.shape[-1] * 2 + _triton_rope[(n_row,)]( + q, + q.stride(0), + k, + k.stride(0), + cos, + cos.stride(0), + sin, + sin.stride(0), + None, + None, + None, + num_tokens, + n_q_head, + n_kv_head, + head_dim, + rope_dim, + pad_n_q_head, + pad_n_kv_head, + pad_rope_dim, + BLOCK_SIZE=BLOCK_SIZE, + IS_NEOX_STYLE=is_neox_style, + USE_COS_SIN=False, + ) + else: + raise ValueError( + "Currently, rope_forward_triton supports passing:\n" + "1. positions and original cos_sin_cache.\n" + "2. cos and sin which are already selected by positions\n" + "Please check whether you call rope_forward_triton correctly." + ) return q, k diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 054469c5..7ec749b2 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -39,7 +39,6 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.compilation.acl_graph import ACLGraphWrapper, update_full_graph_params -from vllm_ascend.ops.rotary_embedding import update_cos_sin from vllm_ascend.ops.triton.spec_decode.utils import prepare_inputs_padded_kernel from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num from vllm_ascend.utils import enable_sp, lmhead_tp_enable, shared_expert_dp_enabled, vllm_version_is @@ -299,9 +298,6 @@ class EagleProposer(VllmEagleProposer): _, ) = self.runner._sync_metadata_across_dp(num_tokens, is_draft_model=True) - # update global cos, sin - update_cos_sin(self._get_positions(num_tokens)) - multi_steps_attn_metadata = [] if not self.use_cuda_graph: aclgraph_runtime_mode = CUDAGraphMode.NONE @@ -471,9 +467,6 @@ class EagleProposer(VllmEagleProposer): builder = self.runner.attn_groups[0][0].get_metadata_builder() attn_metadata = builder.build(0, common_attn_metadata, self.runner.get_model()) - # update global cos, sin - update_cos_sin(self._get_positions(num_input_tokens)) - if self.uses_mrope: used_update_positions = target_positions[:, last_token_indices] else: @@ -651,9 +644,6 @@ class EagleProposer(VllmEagleProposer): input_ids = self.input_ids[:input_batch_size] inputs_embeds = None - # update global cos, sin - update_cos_sin(self._get_positions(input_batch_size)) - # Run the model. # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 7da147e2..c8b8a94e 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1197,8 +1197,10 @@ class NPUModelRunner(GPUModelRunner): model_kwargs, ec_connector_output, ) = self._preprocess(scheduler_output, num_tokens_padded, intermediate_tensors) + # update global cos, sin update_cos_sin(positions) + # Set cudagraph mode to none if calc_kv_scales is true. # KV scales calculation involves dynamic operations that are incompatible # with CUDA graph capture.