diff --git a/pyproject.toml b/pyproject.toml
index 5e7e2de4..cad82e5e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -59,7 +59,20 @@ exclude = [
     "vllm_ascend/kv_offload",
     "vllm_ascend/lora",
     "vllm_ascend/model_loader",
-    "vllm_ascend/ops",
+    "vllm_ascend/ops/fused_moe",
+    "vllm_ascend/ops/activation.py",
+    "vllm_ascend/ops/flashcomm2_oshard_manager.py",
+    "vllm_ascend/ops/layer_shard_linear.py",
+    "vllm_ascend/ops/layernorm.py",
+    "vllm_ascend/ops/linear_op.py",
+    "vllm_ascend/ops/linear.py",
+    "vllm_ascend/ops/mla.py",
+    "vllm_ascend/ops/mm_encoder_attention.py",
+    "vllm_ascend/ops/register_custom_ops.py",
+    "vllm_ascend/ops/rotary_embedding.py",
+    "vllm_ascend/ops/vocab_parallel_embedding.py",
+    "vllm_ascend/ops/weight_prefetch.py",
+    "vllm_ascend/ops/__init__.py",
     "vllm_ascend/patch",
     "vllm_ascend/quantization",
     "vllm_ascend/sample",
diff --git a/vllm_ascend/ops/triton/activation/swiglu_quant.py b/vllm_ascend/ops/triton/activation/swiglu_quant.py
index 7ec2cbaf..b0ef7813 100644
--- a/vllm_ascend/ops/triton/activation/swiglu_quant.py
+++ b/vllm_ascend/ops/triton/activation/swiglu_quant.py
@@ -26,8 +26,7 @@ def _swiglu_quant_kernel(
     else:
         gl_offsets = tl.arange(0, NUM_EXPERTS_ALGIN)
         gl_mask = gl_offsets < NUM_EXPERTS
-        group_list = tl.load(group_list_ptr + gl_offsets, gl_mask,
-                             other=0).to(tl.int32)
+        group_list = tl.load(group_list_ptr + gl_offsets, gl_mask, other=0).to(tl.int32)

     total_rows = tl.sum(group_list)
     block_size = (total_rows - 1) // NUM_CORES + 1
@@ -41,14 +40,8 @@ def _swiglu_quant_kernel(
         # swiglu
         x_offsets = row_idx * TOTAL_COLS + tl.arange(0, TOTAL_COLS)
         cur_x = tl.load(x_ptr + x_offsets)
-        x1 = tl.extract_slice(cur_x,
-                              offsets=(0, ),
-                              sizes=(HALF_COLS, ),
-                              strides=(1, ))
-        x2 = tl.extract_slice(cur_x,
-                              offsets=(HALF_COLS, ),
-                              sizes=(HALF_COLS, ),
-                              strides=(1, ))
+        x1 = tl.extract_slice(cur_x, offsets=(0,), sizes=(HALF_COLS,), strides=(1,))
+        x2 = tl.extract_slice(cur_x, offsets=(HALF_COLS,), sizes=(HALF_COLS,), strides=(1,))
         out = x1 * tl.sigmoid(x1) * x2

         # quant
@@ -57,20 +50,13 @@ def _swiglu_quant_kernel(
             # store scale
             tl.store(scale_ptr + row_idx, scale.to(scale_ptr.dtype.element_ty))
             for col_blk_idx in range(0, HALF_COLS, COL_BLOCK_SIZE):
-                tmp_out = tl.extract_slice(out,
-                                           offsets=(col_blk_idx, ),
-                                           sizes=(COL_BLOCK_SIZE, ),
-                                           strides=(1, ))
-                tmp_out = (tmp_out.to(tl.float32) / scale).to(
-                    x_ptr.dtype.element_ty)
+                tmp_out = tl.extract_slice(out, offsets=(col_blk_idx,), sizes=(COL_BLOCK_SIZE,), strides=(1,))
+                tmp_out = (tmp_out.to(tl.float32) / scale).to(x_ptr.dtype.element_ty)
                 tmp_out = tmp_out.cast(tl.int8, overflow_mode="saturate")
-                o_offsets = (row_idx * HALF_COLS + col_blk_idx +
-                             tl.arange(0, COL_BLOCK_SIZE))
+                o_offsets = row_idx * HALF_COLS + col_blk_idx + tl.arange(0, COL_BLOCK_SIZE)
                 mask = (col_blk_idx + tl.arange(0, COL_BLOCK_SIZE)) < HALF_COLS
-                tl.store(out_ptr + o_offsets,
-                         tmp_out.to(out_ptr.dtype.element_ty),
-                         mask=mask)
+                tl.store(out_ptr + o_offsets, tmp_out.to(out_ptr.dtype.element_ty), mask=mask)
         else:
             # store out
             o_offsets = row_idx * HALF_COLS + tl.arange(0, HALF_COLS)
@@ -80,12 +66,11 @@ def _swiglu_quant_kernel(
 def swiglu_quant(x, group_list, group_list_type, need_quant=True):
     # group_list_type must be 0 cusum or 1 count
     if group_list_type not in [0, 1]:
-        raise ValueError(
-            f"group_list_type must be 0 or 1, but got {group_list_type}")
+        raise ValueError(f"group_list_type must be 0 or 1, but got {group_list_type}")
     s, h = x.shape
     out_dtype = torch.int8 if need_quant else x.dtype
     out = torch.empty((s, h // 2), dtype=out_dtype, device=x.device)
-    scale = torch.empty((s, ), dtype=torch.float32, device=x.device)
+    scale = torch.empty((s,), dtype=torch.float32, device=x.device)
     num_experts = group_list.shape[0]
     # ub must be 32-byte aligned on npu
     if group_list.dtype == torch.int64:
@@ -93,12 +78,10 @@ def swiglu_quant(x, group_list, group_list_type, need_quant=True):
     elif group_list.dtype == torch.int32:
         num_experts_algin = (num_experts + 15) // 16 * 16
     else:
-        raise ValueError(
-            f"group_list dtype must be torch.int32 or torch.int64, but got {group_list.dtype}"
-        )
+        raise ValueError(f"group_list dtype must be torch.int32 or torch.int64, but got {group_list.dtype}")

     num_vectorcore = get_vectorcore_num()
-    _swiglu_quant_kernel[(num_vectorcore, )](
+    _swiglu_quant_kernel[(num_vectorcore,)](
         x,
         group_list,
         out,
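For orientation, a hedged usage sketch of the `swiglu_quant` entry point reformatted above. Shapes and the return convention are assumptions for illustration, not taken from this diff:

```python
# Illustrative only: swiglu_quant fuses the SwiGLU activation with optional
# per-row dynamic int8 quantization over grouped expert rows, assuming it
# returns the output tensor and per-row scales it allocates above.
import torch
from vllm_ascend.ops.triton.activation.swiglu_quant import swiglu_quant

x = torch.randn(1024, 4096, dtype=torch.bfloat16, device="npu")  # [s, h]; h holds the gate|up halves
group_list = torch.tensor([256, 256, 512], device="npu")         # int64 per-expert row counts
# group_list_type=1 means per-expert counts; 0 would mean cumulative sums.
out, scale = swiglu_quant(x, group_list, group_list_type=1)      # out: int8 [s, h // 2]; scale: float32 [s]
```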
diff --git a/vllm_ascend/ops/triton/batch_invariant/matmul.py b/vllm_ascend/ops/triton/batch_invariant/matmul.py
index 0b7934ad..06a4b412 100644
--- a/vllm_ascend/ops/triton/batch_invariant/matmul.py
+++ b/vllm_ascend/ops/triton/batch_invariant/matmul.py
@@ -67,11 +67,9 @@ def matmul_bias_persistent_kernel(
         for k in range(0, tl.cdiv(K, BLOCK_K)):
             k_start = k * BLOCK_K
             # Calculate pointer offsets for x (row-major)
-            x_ptrs = x_ptr + rm[:, None] * stride_xm + (rk[None, :] +
-                                                        k_start) * stride_xk
+            x_ptrs = x_ptr + rm[:, None] * stride_xm + (rk[None, :] + k_start) * stride_xk
             # Calculate pointer offsets for y (row-major)
-            y_ptrs = y_ptr + (rk[:, None] +
-                              k_start) * stride_yk + rn[None, :] * stride_yn
+            y_ptrs = y_ptr + (rk[:, None] + k_start) * stride_yk + rn[None, :] * stride_yn

             # Create masks to prevent out-of-bounds access
             x_mask = (rm[:, None] < M) & ((rk[None, :] + k_start) < K)
@@ -89,14 +87,12 @@ def matmul_bias_persistent_kernel(
             # Load bias values (broadcast to all rows)
             bias_ptrs = bias_ptr + rn * stride_bias
             bias_mask = rn < N
-            bias_vals = tl.load(bias_ptrs, mask=bias_mask,
-                                other=0.0).to(tl.float32)
+            bias_vals = tl.load(bias_ptrs, mask=bias_mask, other=0.0).to(tl.float32)
             # Add bias to accumulator (automatic broadcasting)
             acc += bias_vals[None, :]

         # Calculate output pointer positions
-        out_ptrs = output_ptr + rm[:,
-                                   None] * stride_outm + rn[None, :] * stride_outn
+        out_ptrs = output_ptr + rm[:, None] * stride_outm + rn[None, :] * stride_outn
         out_mask = (rm[:, None] < M) & (rn[None, :] < N)

         # Store result to global memory
@@ -106,28 +102,28 @@ def matmul_bias_persistent_kernel(
 def matmul_persistent(x, y, bias=None):
     """
     Implement matrix multiplication with optional bias using Triton: x @ y + bias (if bias is not None)
-    
+
     Parameters:
         x: torch.Tensor, shape [M, K]
         y: torch.Tensor, shape [K, N]
         bias: torch.Tensor, shape [N] or None
-    
+
     Returns:
         output: torch.Tensor, shape [M, N]
     """
     # Validate input shapes
     assert x.dim() == 2, "x must be a 2D tensor"
     assert y.dim() == 2, "y must be a 2D tensor"
-    assert x.shape[1] == y.shape[
-        0], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[0]={y.shape[0]}"
+    assert x.shape[1] == y.shape[0], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[0]={y.shape[0]}"

     M, K = x.shape
     _, N = y.shape

     # Validate bias shape (if not None)
     if bias is not None:
         assert bias.dim() == 1, "bias must be a 1D tensor"
-        assert y.shape[1] == bias.shape[
-            0], f"Bias dimension mismatch: y.shape[1]={y.shape[1]}, bias.shape[0]={bias.shape[0]}"
+        assert y.shape[1] == bias.shape[0], (
+            f"Bias dimension mismatch: y.shape[1]={y.shape[1]}, bias.shape[0]={bias.shape[0]}"
+        )

     # Allocate output tensor (same data type as x)
     output = torch.empty((M, N), dtype=x.dtype, device=x.device)
@@ -176,24 +172,24 @@ def matmul_persistent(x, y, bias=None):

 @triton.jit
 def linear_persistent_kernel(
-        a_ptr,  # Pointer to tensor a, shape [M, K]
-        b_ptr,  # Pointer to tensor b, shape [N, K]
-        c_ptr,  # Pointer to output tensor c, shape [M, N]
-        M,  # Number of rows in tensor a
-        N,  # Number of rows in tensor b (number of columns in output c)
-        K,  # Number of columns in both tensor a and tensor b
-        stride_am,  # Stride of tensor a along dimension M (typically K)
-        stride_ak,  # Stride of tensor a along dimension K (typically 1)
-        stride_bn,  # Stride of tensor b along dimension N (typically K)
-        stride_bk,  # Stride of tensor b along dimension K (typically 1)
-        stride_cm,  # Stride of tensor c along dimension M (typically N)
-        stride_cn,  # Stride of tensor c along dimension N (typically 1)
-        BLOCK_M: tl.constexpr,  # Block size for M dimension
-        BLOCK_N: tl.constexpr,  # Block size for N dimension
-        BLOCK_K: tl.constexpr,  # Block size for K dimension
-        NUM_BLOCKS_M: tl.constexpr,  # New: Number of blocks in M dimension
-        NUM_BLOCKS_N: tl.constexpr,  # New: Number of blocks in N dimension
-        GRID_SIZE: tl.constexpr,  # New: Fixed 1D grid size
+    a_ptr,  # Pointer to tensor a, shape [M, K]
+    b_ptr,  # Pointer to tensor b, shape [N, K]
+    c_ptr,  # Pointer to output tensor c, shape [M, N]
+    M,  # Number of rows in tensor a
+    N,  # Number of rows in tensor b (number of columns in output c)
+    K,  # Number of columns in both tensor a and tensor b
+    stride_am,  # Stride of tensor a along dimension M (typically K)
+    stride_ak,  # Stride of tensor a along dimension K (typically 1)
+    stride_bn,  # Stride of tensor b along dimension N (typically K)
+    stride_bk,  # Stride of tensor b along dimension K (typically 1)
+    stride_cm,  # Stride of tensor c along dimension M (typically N)
+    stride_cn,  # Stride of tensor c along dimension N (typically 1)
+    BLOCK_M: tl.constexpr,  # Block size for M dimension
+    BLOCK_N: tl.constexpr,  # Block size for N dimension
+    BLOCK_K: tl.constexpr,  # Block size for K dimension
+    NUM_BLOCKS_M: tl.constexpr,  # New: Number of blocks in M dimension
+    NUM_BLOCKS_N: tl.constexpr,  # New: Number of blocks in N dimension
+    GRID_SIZE: tl.constexpr,  # New: Fixed 1D grid size
 ):
     # Get current program's 1D index (1D grid)
     pid = tl.program_id(0)
@@ -226,18 +222,12 @@ def linear_persistent_kernel(
             k_mask = k_indices < K

             # Load block of tensor a: shape [BLOCK_M, BLOCK_K]
-            a_ptrs = a_ptr + m_indices[:, None] * stride_am + k_indices[
-                None, :] * stride_ak
-            a_vals = tl.load(a_ptrs,
-                             mask=m_mask[:, None] & k_mask[None, :],
-                             other=0.0)
+            a_ptrs = a_ptr + m_indices[:, None] * stride_am + k_indices[None, :] * stride_ak
+            a_vals = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)

             # Load block of tensor b: shape [BLOCK_N, BLOCK_K]
-            b_ptrs = b_ptr + n_indices[:, None] * stride_bn + k_indices[
-                None, :] * stride_bk
-            b_vals = tl.load(b_ptrs,
-                             mask=n_mask[:, None] & k_mask[None, :],
-                             other=0.0)
+            b_ptrs = b_ptr + n_indices[:, None] * stride_bn + k_indices[None, :] * stride_bk
+            b_vals = tl.load(b_ptrs, mask=n_mask[:, None] & k_mask[None, :], other=0.0)

             # Explicitly transpose b matrix using tl.trans: shape becomes [BLOCK_K, BLOCK_N]
             b_vals_transposed = tl.trans(b_vals)
@@ -246,8 +236,7 @@ def linear_persistent_kernel(
             product = tl.dot(a_vals, b_vals_transposed)
             acc += product
         # Store result to output tensor c
-        c_ptrs = c_ptr + m_indices[:, None] * stride_cm + n_indices[
-            None, :] * stride_cn
+        c_ptrs = c_ptr + m_indices[:, None] * stride_cm + n_indices[None, :] * stride_cn
         tl.store(c_ptrs, acc, mask=m_mask[:, None] & n_mask[None, :])

@@ -255,19 +244,18 @@ def linear_persistent(x, y):
     """
     Implement matrix multiplication with Triton: x @ y^T
     Uses a fixed-size 1D grid
-    
+
     Parameters:
         x: torch.Tensor, shape [M, K]
         y: torch.Tensor, shape [N, K]
-    
+
     Returns:
         output: torch.Tensor, shape [M, N]
     """
     # Validate input shapes
     assert x.dim() == 2, "x must be a 2D tensor"
     assert y.dim() == 2, "y must be a 2D tensor"
-    assert x.shape[1] == y.shape[
-        1], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[1]={y.shape[1]}"
+    assert x.shape[1] == y.shape[1], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[1]={y.shape[1]}"

     M, K = x.shape
     N, _ = y.shape
@@ -283,9 +271,8 @@ def linear_persistent(x, y):
     num_blocks_n = triton.cdiv(N, BLOCK_N)

     # Set fixed 1D grid size
-    grid_size = driver.active.utils.get_device_properties(
-        torch.npu.current_device())["num_vectorcore"] // 2
-    grid = (grid_size, )
+    grid_size = driver.active.utils.get_device_properties(torch.npu.current_device())["num_vectorcore"] // 2
+    grid = (grid_size,)

     # Launch kernel
     linear_persistent_kernel[grid](
@@ -330,8 +317,7 @@ def bmm_batch_invariant(a, b, *, out=None):
             return out
         return result
     else:
-        raise ValueError(f"bmm_batch_invariant expects 3D tensors, "
-                         f"got shapes {a.shape} and {b.shape}")
+        raise ValueError(f"bmm_batch_invariant expects 3D tensors, got shapes {a.shape} and {b.shape}")


 def addmm_batch_invariant(bias, a, b):
@@ -392,7 +378,8 @@ def matmul_batch_invariant(a, b, *, out=None):
         raise ValueError(
             f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, "
             f"3D x 2D, 2D x 3D, and 4D x 4D, "
-            f"got shapes {a.shape} and {b.shape}")
+            f"got shapes {a.shape} and {b.shape}"
+        )


 def linear_batch_invariant(input_, weight, bias=None):
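The two persistent entry points differ only in how the second operand is laid out; a hedged sketch with invented shapes:

```python
# matmul_persistent computes x @ y (+ bias) with y as [K, N];
# linear_persistent computes x @ y^T with y as [N, K].
import torch
from vllm_ascend.ops.triton.batch_invariant.matmul import linear_persistent, matmul_persistent

x = torch.randn(128, 512, dtype=torch.float16, device="npu")  # [M, K]
w = torch.randn(256, 512, dtype=torch.float16, device="npu")  # [N, K]
bias = torch.randn(256, dtype=torch.float16, device="npu")    # [N]

out1 = matmul_persistent(x, w.t().contiguous(), bias=bias)    # [M, N]
out2 = linear_persistent(x, w)                                # [M, N]
```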
diff --git a/vllm_ascend/ops/triton/batch_invariant/mean.py b/vllm_ascend/ops/triton/batch_invariant/mean.py
index 0a13f734..19594838 100644
--- a/vllm_ascend/ops/triton/batch_invariant/mean.py
+++ b/vllm_ascend/ops/triton/batch_invariant/mean.py
@@ -86,8 +86,7 @@ def mean_dim(
         Tensor with mean values along specified dimension
     """
     # Validate inputs
-    assert -input_.ndim <= dim < input_.ndim, (
-        f"Invalid dimension {dim} for tensor with {input_.ndim} dimensions")
+    assert -input_.ndim <= dim < input_.ndim, f"Invalid dimension {dim} for tensor with {input_.ndim} dimensions"

     # Handle negative dim
     if dim < 0:
@@ -123,7 +122,7 @@ def mean_dim(
         output_shape = shape.copy()
         output_shape[dim] = 1
     else:
-        output_shape = shape[:dim] + shape[dim + 1:]
+        output_shape = shape[:dim] + shape[dim + 1 :]

     # Create output tensor
     output = torch.empty(output_shape, dtype=dtype, device=input_.device)
@@ -135,7 +134,7 @@ def mean_dim(
     output_2d = output.reshape(M, K)

     # Launch kernel
-    grid = (M * K, )
+    grid = (M * K,)
     BLOCK_SIZE = 1024

     mean_kernel[grid](
@@ -165,13 +164,10 @@ def mean_batch_invariant(
     if len(dim) == 1:
         return mean_dim(input_, dim[0], keepdim=keepdim)
     else:
-        assert input_.dtype in {torch.float16, torch.bfloat16, torch.float32
-                                }, ("only float types supported for now")
+        assert input_.dtype in {torch.float16, torch.bfloat16, torch.float32}, "only float types supported for now"
         if len(dim) == 0:
             dim = list(range(input_.ndim))

         n_elems = 1
         for d in dim:
             n_elems *= input_.shape[d]
-        return torch.sum(input_, dim=dim, keepdim=keepdim,
-                         dtype=torch.float32).to(dtype
-                                                 or input_.dtype) / n_elems
+        return torch.sum(input_, dim=dim, keepdim=keepdim, dtype=torch.float32).to(dtype or input_.dtype) / n_elems
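`mean_batch_invariant` dispatches on how many dimensions are reduced; a hedged sketch (assuming `dim` is passed as a list, as the body suggests):

```python
import torch
from vllm_ascend.ops.triton.batch_invariant.mean import mean_batch_invariant

x = torch.randn(4, 8, 16, dtype=torch.bfloat16, device="npu")
m1 = mean_batch_invariant(x, dim=[1])     # one dim -> Triton mean_dim kernel
m2 = mean_batch_invariant(x, dim=[0, 2])  # several dims -> fp32 torch.sum / n_elems
```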
diff --git a/vllm_ascend/ops/triton/batch_invariant/rmsnorm.py b/vllm_ascend/ops/triton/batch_invariant/rmsnorm.py
index f4aa78a3..767854c2 100644
--- a/vllm_ascend/ops/triton/batch_invariant/rmsnorm.py
+++ b/vllm_ascend/ops/triton/batch_invariant/rmsnorm.py
@@ -100,8 +100,8 @@ def rms_norm(
     """
     assert weight.dim() == 1, "Weight must be 1-dimensional"
     assert input_.shape[-1] == weight.shape[0], (
-        f"Input last dimension ({input_.shape[-1]}) must match "
-        f"weight dimension ({weight.shape[0]})")
+        f"Input last dimension ({input_.shape[-1]}) must match weight dimension ({weight.shape[0]})"
+    )

     # Flatten all dimensions except the last one
     original_shape = input_.shape
@@ -113,10 +113,9 @@ def rms_norm(
     output = torch.empty_like(input_2d, dtype=input_.dtype)

     BLOCK_SIZE = 1024
-    max_grid_size = driver.active.utils.get_device_properties(
-        torch.npu.current_device())["num_vectorcore"]
+    max_grid_size = driver.active.utils.get_device_properties(torch.npu.current_device())["num_vectorcore"]

-    grid = (min(n_rows, max_grid_size), )
+    grid = (min(n_rows, max_grid_size),)

     _rms_norm_kernel[grid](
         input_2d,
diff --git a/vllm_ascend/ops/triton/fla/chunk.py b/vllm_ascend/ops/triton/fla/chunk.py
index 03d2d6cd..58b5cc72 100644
--- a/vllm_ascend/ops/triton/fla/chunk.py
+++ b/vllm_ascend/ops/triton/fla/chunk.py
@@ -9,7 +9,6 @@
 # ruff: noqa: E501
 # mypy: ignore-errors
 import warnings
-from typing import Optional

 import torch
 from einops import rearrange
@@ -25,22 +24,20 @@ from .utils import input_guard
 from .wy_fast import recompute_w_u_fwd


-def chunk_gated_delta_rule_fwd(q: torch.Tensor,
-                               k: torch.Tensor,
-                               v: torch.Tensor,
-                               g: torch.Tensor,
-                               beta: torch.Tensor,
-                               scale: float,
-                               initial_state: torch.Tensor,
-                               output_final_state: bool,
-                               cu_seqlens: Optional[torch.LongTensor] = None):
+def chunk_gated_delta_rule_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    beta: torch.Tensor,
+    scale: float,
+    initial_state: torch.Tensor,
+    output_final_state: bool,
+    cu_seqlens: torch.LongTensor | None = None,
+):
     g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
     # obtain WY representation. u is actually the new v.
-    A = chunk_scaled_dot_kkt_fwd(k=k,
-                                 beta=beta,
-                                 g_cumsum=g,
-                                 cu_seqlens=cu_seqlens,
-                                 output_dtype=torch.float32)
+    A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32)
     A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
     w, u = recompute_w_u_fwd(
         k=k,
@@ -75,20 +72,21 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,


 class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
-
     @staticmethod
     @input_guard
-    def forward(ctx,
-                q: torch.Tensor,
-                k: torch.Tensor,
-                v: torch.Tensor,
-                g: torch.Tensor,
-                beta: torch.Tensor,
-                scale: float,
-                initial_state: torch.Tensor,
-                output_final_state: bool,
-                cu_seqlens: Optional[torch.LongTensor] = None,
-                use_qk_l2norm_in_kernel: bool = False):
+    def forward(
+        ctx,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        g: torch.Tensor,
+        beta: torch.Tensor,
+        scale: float,
+        initial_state: torch.Tensor,
+        output_final_state: bool,
+        cu_seqlens: torch.LongTensor | None = None,
+        use_qk_l2norm_in_kernel: bool = False,
+    ):
         if use_qk_l2norm_in_kernel:
             q = l2norm_fwd(q)
             k = l2norm_fwd(k)
@@ -110,17 +108,19 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):


 @torch.compiler.disable
-def chunk_gated_delta_rule(q: torch.Tensor,
-                           k: torch.Tensor,
-                           v: torch.Tensor,
-                           g: torch.Tensor,
-                           beta: torch.Tensor,
-                           scale: float = None,
-                           initial_state: torch.Tensor = None,
-                           output_final_state: bool = False,
-                           cu_seqlens: Optional[torch.LongTensor] = None,
-                           head_first: bool = False,
-                           use_qk_l2norm_in_kernel: bool = False):
+def chunk_gated_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    beta: torch.Tensor,
+    scale: float = None,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False,
+    cu_seqlens: torch.LongTensor | None = None,
+    head_first: bool = False,
+    use_qk_l2norm_in_kernel: bool = False,
+):
     r"""
     Args:
         q (torch.Tensor):
@@ -186,41 +186,39 @@ def chunk_gated_delta_rule(q: torch.Tensor,
     """
     assert q.dtype == k.dtype == v.dtype
     assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
-    assert len(
-        beta.shape
-    ) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
+    assert len(beta.shape) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."

     if head_first:
         raise DeprecationWarning(
             "head_first is deprecated and will be removed in a future version. "
             "Please use head_first=False for now instead.",
-            stacklevel=2)
-        q, k, v, beta, g = map(
-            lambda x: rearrange(x, 'b h t ... -> b t h ...'),
-            (q, k, v, beta, g))
+            stacklevel=2,
+        )
+        q, k, v, beta, g = map(lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g))
     if not head_first and q.shape[1] < q.shape[2]:
         warnings.warn(
             f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
             "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
             "when head_first=False was specified. "
             "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
-            stacklevel=2)
+            stacklevel=2,
+        )
     if cu_seqlens is not None:
         if q.shape[0] != 1:
             raise ValueError(
                 f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
-                f"Please flatten variable-length inputs before processing.")
-        if initial_state is not None and initial_state.shape[0] != len(
-                cu_seqlens) - 1:
+                f"Please flatten variable-length inputs before processing."
+            )
+        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
             raise ValueError(
                 f"The number of initial states is expected to be equal to the number of input sequences, "
                 f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
             )
     if scale is None:
-        scale = k.shape[-1]**-0.5
+        scale = k.shape[-1] ** -0.5
     o, final_state = ChunkGatedDeltaRuleFunction.apply(
-        q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
-        use_qk_l2norm_in_kernel)
+        q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, use_qk_l2norm_in_kernel
+    )
     if head_first:
-        o = rearrange(o, 'b t h ... -> b h t ...')
-    return o, final_state
\ No newline at end of file
+        o = rearrange(o, "b t h ... -> b h t ...")
+    return o, final_state
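`chunk_gated_delta_rule` is the public entry point reformatted above; a hedged call sketch (all dims invented; float32 inputs are rejected by the assertion):

```python
import torch
from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule

B, T, H, K, V = 1, 1024, 8, 128, 128
q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device="npu")
k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device="npu")
v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device="npu")
g = torch.randn(B, T, H, device="npu")                          # log-space decay gates
beta = torch.rand(B, T, H, dtype=torch.bfloat16, device="npu")  # [B, T, H] when head_first=False

# scale=None falls back to K ** -0.5; cu_seqlens enables varlen batches (B must then be 1).
o, final_state = chunk_gated_delta_rule(q, k, v, g, beta, output_final_state=True)
```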
diff --git a/vllm_ascend/ops/triton/fla/chunk_delta_h.py b/vllm_ascend/ops/triton/fla/chunk_delta_h.py
index 846623ad..d08fe9aa 100644
--- a/vllm_ascend/ops/triton/fla/chunk_delta_h.py
+++ b/vllm_ascend/ops/triton/fla/chunk_delta_h.py
@@ -8,23 +8,24 @@
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
 # ruff: noqa: E501
 # mypy: ignore-errors
-from typing import Optional

 import torch
 from vllm.triton_utils import tl, triton

 from .utils import prepare_chunk_indices, prepare_chunk_offsets, safe_exp

-_CONDITIONS = ("seq7168", )
+_CONDITIONS = ("seq7168",)


-@triton.heuristics({
-    "USE_G": lambda args: args["g"] is not None,
-    "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
-    "STORE_FINAL_STATE": lambda args: args["ht"] is not None,
-    "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
-    "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
-})
+@triton.heuristics(
+    {
+        "USE_G": lambda args: args["g"] is not None,
+        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
+        "STORE_FINAL_STATE": lambda args: args["ht"] is not None,
+        "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
+        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
+    }
+)
 @triton.jit(do_not_specialize=["T"])
 def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
     k,
@@ -85,28 +86,20 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
     if USE_INITIAL_STATE:
         h0_ptr = h0 + i_nh * K * V
         ptr_h0_bv1 = h0_ptr + offs_k * V + offs_v1 * 1
-        b_h1_bv1 += tl.load(ptr_h0_bv1, mask=mask_kv1,
-                            other=0.0).to(tl.float32)
+        b_h1_bv1 += tl.load(ptr_h0_bv1, mask=mask_kv1, other=0.0).to(tl.float32)
         ptr_h0_bv2 = h0_ptr + offs_k * V + offs_v2 * 1
-        b_h1_bv2 += tl.load(ptr_h0_bv2, mask=mask_kv2,
-                            other=0.0).to(tl.float32)
+        b_h1_bv2 += tl.load(ptr_h0_bv2, mask=mask_kv2, other=0.0).to(tl.float32)

     # main recurrence
     for i_t in range(NT):
         h_base = h + (boh + i_t) * H * K * V + i_h * K * V
-        p_h1_bv1 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start1),
-                                     (128, 64), (1, 0))
-        tl.store(p_h1_bv1,
-                 b_h1_bv1.to(p_h1_bv1.dtype.element_ty),
-                 boundary_check=(0, 1))
+        p_h1_bv1 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start1), (128, 64), (1, 0))
+        tl.store(p_h1_bv1, b_h1_bv1.to(p_h1_bv1.dtype.element_ty), boundary_check=(0, 1))

-        p_h1_bv2 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start2),
-                                     (128, 64), (1, 0))
-        tl.store(p_h1_bv2,
-                 b_h1_bv2.to(p_h1_bv2.dtype.element_ty),
-                 boundary_check=(0, 1))
+        p_h1_bv2 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start2), (128, 64), (1, 0))
+        tl.store(p_h1_bv2, b_h1_bv2.to(p_h1_bv2.dtype.element_ty), boundary_check=(0, 1))

         offs_t_wv = (i_t * BT + tl.arange(0, BT))[:, None]
         offs_k_wv = tl.arange(0, 128)[None, :]
@@ -117,8 +110,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
         b_w = tl.load(ptr_w, mask=mask_w, other=0.0)

         k_base = k + bos * Hg * K + (i_h // (H // Hg)) * K
-        p_k = tl.make_block_ptr(k_base, (K, T), (1, stride_k), (0, i_t * BT),
-                                (128, BT), (0, 1))
+        p_k = tl.make_block_ptr(k_base, (K, T), (1, stride_k), (0, i_t * BT), (128, BT), (0, 1))
         b_k = tl.load(p_k, boundary_check=(0, 1))

         v_new_base = v_new + bos * H * V + i_h * V
@@ -144,12 +136,8 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
         b_v_new1 -= tl.dot(b_w, b_h1_bv1.to(b_w.dtype))

         if SAVE_NEW_VALUE:
-            p_v_new1 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1),
-                                         (i_t * BT, v_start1), (BT, 64),
-                                         (1, 0))
-            tl.store(p_v_new1,
-                     b_v_new1.to(p_v_new1.dtype.element_ty),
-                     boundary_check=(0, 1))
+            p_v_new1 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1), (i_t * BT, v_start1), (BT, 64), (1, 0))
+            tl.store(p_v_new1, b_v_new1.to(p_v_new1.dtype.element_ty), boundary_check=(0, 1))

         if USE_G:
             b_v_new1 = b_v_new1 * b_g[:, None]
@@ -165,12 +153,8 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
         b_v_new2 -= tl.dot(b_w, b_h1_bv2.to(b_w.dtype))

         if SAVE_NEW_VALUE:
-            p_v_new2 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1),
-                                         (i_t * BT, v_start2), (BT, 64),
-                                         (1, 0))
-            tl.store(p_v_new2,
-                     b_v_new2.to(p_v_new2.dtype.element_ty),
-                     boundary_check=(0, 1))
+            p_v_new2 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1), (i_t * BT, v_start2), (BT, 64), (1, 0))
+            tl.store(p_v_new2, b_v_new2.to(p_v_new2.dtype.element_ty), boundary_check=(0, 1))

         if USE_G:
             b_v_new2 = b_v_new2 * b_g[:, None]
@@ -183,29 +167,23 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(

     if STORE_FINAL_STATE:
         ht_ptr = ht + i_nh * K * V
-        p_ht1_bv1 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start1),
-                                      (128, 64), (1, 0))
-        tl.store(p_ht1_bv1,
-                 b_h1_bv1.to(p_ht1_bv1.dtype.element_ty),
-                 boundary_check=(0, 1))
+        p_ht1_bv1 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start1), (128, 64), (1, 0))
+        tl.store(p_ht1_bv1, b_h1_bv1.to(p_ht1_bv1.dtype.element_ty), boundary_check=(0, 1))

-        p_ht1_bv2 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start2),
-                                      (128, 64), (1, 0))
-        tl.store(p_ht1_bv2,
-                 b_h1_bv2.to(p_ht1_bv2.dtype.element_ty),
-                 boundary_check=(0, 1))
+        p_ht1_bv2 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start2), (128, 64), (1, 0))
+        tl.store(p_ht1_bv2, b_h1_bv2.to(p_ht1_bv2.dtype.element_ty), boundary_check=(0, 1))


 def chunk_gated_delta_rule_fwd_h(
     k: torch.Tensor,
     w: torch.Tensor,
     u: torch.Tensor,
-    g: Optional[torch.Tensor] = None,
-    initial_state: Optional[torch.Tensor] = None,
+    g: torch.Tensor | None = None,
+    initial_state: torch.Tensor | None = None,
     output_final_state: bool = False,
     chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
     save_new_value: bool = True,
-    cu_seqlens: Optional[torch.LongTensor] = None,
+    cu_seqlens: torch.LongTensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     # This kernel is slightly different from fla to support Q/K with different head numbers.
     # In fla, Q/K always have the same head number, so Hg is always equal to H.
@@ -213,8 +191,7 @@ def chunk_gated_delta_rule_fwd_h(
     H = u.shape[-2]
     BT = chunk_size

-    chunk_indices = (prepare_chunk_indices(cu_seqlens, chunk_size)
-                     if cu_seqlens is not None else None)
+    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
     # N: the actual number of sequences in the batch with either equal or variable lengths
     if cu_seqlens is None:
         N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
@@ -227,8 +204,7 @@ def chunk_gated_delta_rule_fwd_h(
     assert K <= 256, "current kernel does not support head dimension larger than 256."

     h = k.new_empty(B, NT, H, K, V)
-    final_state = (k.new_empty(N, H, K, V, dtype=torch.float32)
-                   if output_final_state else None)
+    final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
     v_new = torch.empty_like(u) if save_new_value else None

     g = g.transpose(1, 2).contiguous()
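The varlen paths above all lean on `prepare_chunk_indices`; a hedged pure-PyTorch reference of its apparent contract, judged only from how the kernels consume it (one `(sequence, local_chunk)` pair per chunk):

```python
import torch

def prepare_chunk_indices_reference(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Illustrative stand-in, not the real helper from .utils."""
    pairs = []
    for i in range(len(cu_seqlens) - 1):
        seq_len = int(cu_seqlens[i + 1] - cu_seqlens[i])
        for j in range(-(-seq_len // chunk_size)):  # ceil(seq_len / chunk_size) chunks
            pairs.append((i, j))                    # kernels read these back as (i_n, i_t)
    return torch.tensor(pairs, dtype=torch.int32)

# e.g. cu_seqlens = [0, 100, 164], chunk_size = 64 -> [(0, 0), (0, 1), (1, 0)]
```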
diff --git a/vllm_ascend/ops/triton/fla/chunk_o.py b/vllm_ascend/ops/triton/fla/chunk_o.py
index 5a3578a8..258b40eb 100644
--- a/vllm_ascend/ops/triton/fla/chunk_o.py
+++ b/vllm_ascend/ops/triton/fla/chunk_o.py
@@ -9,7 +9,6 @@
 # ruff: noqa: E501
 # mypy: ignore-errors

-from typing import Optional

 import torch
 from vllm.triton_utils import tl, triton
@@ -17,11 +16,13 @@ from vllm.triton_utils import tl, triton

 from .utils import prepare_chunk_offsets, safe_exp


-@triton.heuristics({
-    'USE_G': lambda args: args['g'] is not None,
-    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
-})
-@triton.jit(do_not_specialize=['T'])
+@triton.heuristics(
+    {
+        "USE_G": lambda args: args["g"] is not None,
+        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
+    }
+)
+@triton.jit(do_not_specialize=["T"])
 def chunk_fwd_kernel_o(
     q,
     k,
@@ -48,8 +49,7 @@ def chunk_fwd_kernel_o(

     T_max = T
     if IS_VARLEN:
-        bos, eos = tl.load(cu_seqlens + i_n).to(
-            tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
         T = eos - bos
         NT = tl.cdiv(T, BT)
         boh = tl.load(chunk_offsets + i_n).to(tl.int64)
@@ -71,12 +71,9 @@ def chunk_fwd_kernel_o(
     b_A = tl.zeros([BT, BT], dtype=tl.float32)

     for i_k in range(tl.cdiv(K, BK)):
-        p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1),
-                                (i_t * BT, i_k * BK), (BT, BK), (1, 0))
-        p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K),
-                                (i_k * BK, i_t * BT), (BK, BT), (0, 1))
-        p_h = tl.make_block_ptr(h_base, (K, V), (V, 1),
-                                (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+        p_h = tl.make_block_ptr(h_base, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
         # [BT, BK]
         b_q = tl.load(p_q, boundary_check=(0, 1))
         # [BK, BT]
@@ -102,10 +99,8 @@ def chunk_fwd_kernel_o(
         m_A = o_i[:, None] >= o_i[None, :]
         b_A = tl.where(m_A, b_A, 0)

-    p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
-                            (BT, BV), (1, 0))
-    p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
-                            (BT, BV), (1, 0))
+    p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
     b_v = tl.load(p_v, boundary_check=(0, 1))

     # to fix mma -> mma layout conversion
@@ -119,9 +114,9 @@ def chunk_fwd_o(
     k: torch.Tensor,
     v: torch.Tensor,
     h: torch.Tensor,
-    g: Optional[torch.Tensor] = None,
-    scale: Optional[float] = None,
-    cu_seqlens: Optional[torch.LongTensor] = None,
+    g: torch.Tensor | None = None,
+    scale: float | None = None,
+    cu_seqlens: torch.LongTensor | None = None,
     chunk_size: int = 64,
 ) -> torch.Tensor:
     B, T, Hg, K, V = *q.shape, v.shape[-1]
@@ -129,7 +124,7 @@ def chunk_fwd_o(
     BT = chunk_size

     if scale is None:
-        scale = k.shape[-1]**-0.5
+        scale = k.shape[-1] ** -0.5
     o = torch.empty_like(v)

     if cu_seqlens is None:
@@ -141,7 +136,7 @@ def chunk_fwd_o(
     )

     def grid(meta):
-        return (triton.cdiv(V, meta['BV']), N * H)
+        return (triton.cdiv(V, meta["BV"]), N * H)

     g = g.transpose(1, 2).contiguous()
     chunk_fwd_kernel_o[grid](
diff --git a/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py b/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py
index aa183149..1ad1aead 100644
--- a/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py
+++ b/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py
@@ -8,7 +8,6 @@
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
 # ruff: noqa: E501
 # mypy: ignore-errors
-from typing import Optional

 import torch
 from vllm.triton_utils import tl, triton
@@ -16,11 +15,13 @@ from vllm.triton_utils import tl, triton

 from .utils import prepare_chunk_indices, safe_exp


-@triton.heuristics({
-    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
-    'USE_G': lambda args: args['g_cumsum'] is not None,
-})
-@triton.jit(do_not_specialize=['T'])
+@triton.heuristics(
+    {
+        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
+        "USE_G": lambda args: args["g_cumsum"] is not None,
+    }
+)
+@triton.jit(do_not_specialize=["T"])
 def chunk_scaled_dot_kkt_fwd_kernel(
     k,
     beta,  # [H, B, T]
@@ -44,10 +45,11 @@ def chunk_scaled_dot_kkt_fwd_kernel(
     for i_bh in range(B * H):
         i_b, i_h = i_bh // H, i_bh % H
         if IS_VARLEN:
-            i_n, i_t = tl.load(chunk_indices + i_t_i * 2).to(
-                tl.int32), tl.load(chunk_indices + i_t_i * 2 + 1).to(tl.int32)
-            bos, eos = tl.load(cu_seqlens + i_n).to(
-                tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+            i_n, i_t = (
+                tl.load(chunk_indices + i_t_i * 2).to(tl.int32),
+                tl.load(chunk_indices + i_t_i * 2 + 1).to(tl.int32),
+            )
+            bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
             T = eos - bos
         else:
             bos, eos = i_b * T, i_b * T + T
@@ -55,39 +57,37 @@ def chunk_scaled_dot_kkt_fwd_kernel(

         o_t = tl.arange(0, BT)
         o_t_fp32 = o_t.to(tl.float32)
-        p_beta = tl.make_block_ptr(beta + i_h * bt_stride + bos, (T, ), (1, ),
-                                   (i_t * BT, ), (BT, ), (0, ))
-        b_beta = tl.load(p_beta, boundary_check=(0, ))
+        p_beta = tl.make_block_ptr(beta + i_h * bt_stride + bos, (T,), (1,), (i_t * BT,), (BT,), (0,))
+        b_beta = tl.load(p_beta, boundary_check=(0,))

         b_A = tl.zeros([BT, BT], dtype=tl.float32)
         for i_k in range(tl.cdiv(K, BK)):
-            p_k = tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K,
-                                    (T, K), (Hg * K, 1), (i_t * BT, i_k * BK),
-                                    (BT, BK), (1, 0))
+            p_k = tl.make_block_ptr(
+                k + (bos * Hg + i_h // (H // Hg)) * K, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
+            )
             b_k = tl.load(p_k, boundary_check=(0, 1))
             b_A += tl.dot(b_k, tl.trans(b_k))

         if USE_G:
-            p_g = tl.make_block_ptr(g_cumsum + i_h * bt_stride + bos, (T, ),
-                                    (1, ), (i_t * BT, ), (BT, ), (0, ))
-            b_g = tl.load(p_g, boundary_check=(0, ))
+            p_g = tl.make_block_ptr(g_cumsum + i_h * bt_stride + bos, (T,), (1,), (i_t * BT,), (BT,), (0,))
+            b_g = tl.load(p_g, boundary_check=(0,))
             b_g_diff = b_g[:, None] - b_g[None, :]
             b_A *= safe_exp(b_g_diff)

         b_A *= b_beta[:, None]
         b_A = tl.where(o_t_fp32[:, None] > o_t_fp32[None, :], b_A, 0)
-        p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1),
-                                (i_t * BT, 0), (BT, BT), (1, 0))
+        p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0))
         tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))


 def chunk_scaled_dot_kkt_fwd(
-        k: torch.Tensor,
-        beta: torch.Tensor,
-        g_cumsum: Optional[torch.Tensor] = None,
-        cu_seqlens: Optional[torch.LongTensor] = None,
-        chunk_size: int = 64,
-        output_dtype: torch.dtype = torch.float32) -> torch.Tensor:
+    k: torch.Tensor,
+    beta: torch.Tensor,
+    g_cumsum: torch.Tensor | None = None,
+    cu_seqlens: torch.LongTensor | None = None,
+    chunk_size: int = 64,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
     r"""
     Compute beta * K * K^T.

@@ -117,8 +117,7 @@ def chunk_scaled_dot_kkt_fwd(
     BT = chunk_size
     if cu_seqlens is not None:
         cu_seqlens = cu_seqlens.cpu()
-        chunk_indices = (prepare_chunk_indices(cu_seqlens, BT)
-                         if cu_seqlens is not None else None)
+        chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
         chunk_indices = chunk_indices.npu()
         cu_seqlens = cu_seqlens.npu()
     else:
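How these pieces compose, mirroring `chunk_gated_delta_rule_fwd` earlier in this diff (all shapes illustrative):

```python
# The strictly lower-triangular beta * K @ K^T chunk blocks feed solve_tril,
# whose inverse yields the WY representation used by the delta-rule kernels.
A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g_cumsum=g, cu_seqlens=None, output_dtype=torch.float32)
A = solve_tril(A=A, cu_seqlens=None, output_dtype=k.dtype)  # A: [B, T, H, chunk_size]
```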
diff --git a/vllm_ascend/ops/triton/fla/cumsum.py b/vllm_ascend/ops/triton/fla/cumsum.py
index e93a2438..da7bf8c9 100644
--- a/vllm_ascend/ops/triton/fla/cumsum.py
+++ b/vllm_ascend/ops/triton/fla/cumsum.py
@@ -8,7 +8,6 @@
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
 # ruff: noqa: E501
 # mypy: ignore-errors
-from typing import Optional

 import torch
 from vllm.triton_utils import tl, triton
@@ -16,11 +15,10 @@ from vllm.triton_utils import tl, triton

 from .utils import prepare_chunk_indices


-@triton.heuristics({
-    'HAS_SCALE': lambda args: args['scale'] is not None,
-    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
-})
-@triton.jit(do_not_specialize=['T'])
+@triton.heuristics(
+    {"HAS_SCALE": lambda args: args["scale"] is not None, "IS_VARLEN": lambda args: args["cu_seqlens"] is not None}
+)
+@triton.jit(do_not_specialize=["T"])
 def chunk_local_cumsum_scalar_kernel(
     s,
     o,
@@ -41,20 +39,19 @@ def chunk_local_cumsum_scalar_kernel(
     N_CHUNKS: tl.constexpr = BLOCK_T // CHUNK_SIZE

     if IS_VARLEN:
-        i_s, i_block = tl.load(chunk_indices + i_block * 2).to(
-            tl.int32), tl.load(chunk_indices + i_block * 2 + 1).to(tl.int32)
-        bos, eos = tl.load(cu_seqlens + i_s).to(
-            tl.int32), tl.load(cu_seqlens + i_s + 1).to(tl.int32)
+        i_s, i_block = (
+            tl.load(chunk_indices + i_block * 2).to(tl.int32),
+            tl.load(chunk_indices + i_block * 2 + 1).to(tl.int32),
+        )
+        bos, eos = tl.load(cu_seqlens + i_s).to(tl.int32), tl.load(cu_seqlens + i_s + 1).to(tl.int32)
         T = eos - bos
     else:
         bos, eos = i_b * T, i_b * T + T

     if HEAD_FIRST:
-        ptr_s = tl.make_block_ptr(s + bos * H, (H, T), (T, 1),
-                                  (0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0))
-        ptr_o = tl.make_block_ptr(o + bos * H, (H, T), (T, 1),
-                                  (0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0))
-        b_s = tl.load(ptr_s, boundary_check=(0, )).to(tl.float32)
+        ptr_s = tl.make_block_ptr(s + bos * H, (H, T), (T, 1), (0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0))
+        ptr_o = tl.make_block_ptr(o + bos * H, (H, T), (T, 1), (0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0))
+        b_s = tl.load(ptr_s, boundary_check=(0,)).to(tl.float32)
         b_s = tl.reshape(b_s, (H, N_CHUNKS, CHUNK_SIZE))
         b_s = tl.trans(b_s, (2, 0, 1))
         b_o = tl.cumsum(b_s, axis=0, reverse=REVERSE)
@@ -63,11 +60,9 @@ def chunk_local_cumsum_scalar_kernel(
         b_o = tl.trans(b_o, (2, 0, 1))
         b_o = tl.reshape(b_o, (H, BLOCK_T))
     else:
-        ptr_s = tl.make_block_ptr(s + bos * H, (T, H), (H, 1),
-                                  (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
-        ptr_o = tl.make_block_ptr(o + bos * H, (T, H), (H, 1),
-                                  (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
-        b_s = tl.load(ptr_s, boundary_check=(0, )).to(tl.float32)
+        ptr_s = tl.make_block_ptr(s + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
+        ptr_o = tl.make_block_ptr(o + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
+        b_s = tl.load(ptr_s, boundary_check=(0,)).to(tl.float32)
         b_s = tl.reshape(b_s, (N_CHUNKS, CHUNK_SIZE, H))
         b_s = tl.trans(b_s, (1, 0, 2))
         b_o = tl.cumsum(b_s, axis=0, reverse=REVERSE)
@@ -76,7 +71,7 @@ def chunk_local_cumsum_scalar_kernel(
         b_o = tl.trans(b_o, (1, 0, 2))
         b_o = tl.reshape(b_o, (BLOCK_T, H))

-    tl.store(ptr_o, b_o.to(s.dtype.element_ty), boundary_check=(0, ))
+    tl.store(ptr_o, b_o.to(s.dtype.element_ty), boundary_check=(0,))
     return


@@ -85,61 +80,64 @@ def chunk_local_cumsum_scalar(
     chunk_size,
     reverse: bool = False,
     scale: float = None,
-    cu_seqlens: Optional[torch.Tensor] = None,
+    cu_seqlens: torch.Tensor | None = None,
     head_first: bool = False,
-    output_dtype: Optional[torch.Tensor] = torch.float,
+    output_dtype: torch.Tensor | None = torch.float,
 ):
     if head_first:
         B, H, T = g.shape
     else:
         B, T, H = g.shape
-    assert chunk_size == 2**(chunk_size.bit_length() -
-                             1), "chunk_size must be a power of 2"
+    assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2"

     OPTIM_BLOCK_SIZE = triton.next_power_of_2((2**18) // (H * chunk_size))
-    block_indices = prepare_chunk_indices(
-        cu_seqlens,
-        chunk_size=OPTIM_BLOCK_SIZE) if cu_seqlens is not None else None
-    num_blocks = len(block_indices) if cu_seqlens is not None else triton.cdiv(
-        T, OPTIM_BLOCK_SIZE)
+    block_indices = prepare_chunk_indices(cu_seqlens, chunk_size=OPTIM_BLOCK_SIZE) if cu_seqlens is not None else None
+    num_blocks = len(block_indices) if cu_seqlens is not None else triton.cdiv(T, OPTIM_BLOCK_SIZE)
     g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
     grid = (num_blocks, B)
-    chunk_local_cumsum_scalar_kernel[grid](s=g_org,
-                                           o=g,
-                                           scale=scale,
-                                           cu_seqlens=cu_seqlens,
-                                           chunk_indices=block_indices,
-                                           T=T,
-                                           B=B,
-                                           H=H,
-                                           BLOCK_T=OPTIM_BLOCK_SIZE,
-                                           CHUNK_SIZE=chunk_size,
-                                           HEAD_FIRST=head_first,
-                                           REVERSE=reverse,
-                                           num_warps=8,
-                                           num_stages=3)
+    chunk_local_cumsum_scalar_kernel[grid](
+        s=g_org,
+        o=g,
+        scale=scale,
+        cu_seqlens=cu_seqlens,
+        chunk_indices=block_indices,
+        T=T,
+        B=B,
+        H=H,
+        BLOCK_T=OPTIM_BLOCK_SIZE,
+        CHUNK_SIZE=chunk_size,
+        HEAD_FIRST=head_first,
+        REVERSE=reverse,
+        num_warps=8,
+        num_stages=3,
+    )
     return g


-def chunk_local_cumsum(g: torch.Tensor,
-                       chunk_size: int,
-                       reverse: bool = False,
-                       scale: float = None,
-                       cu_seqlens: Optional[torch.Tensor] = None,
-                       head_first: bool = False,
-                       output_dtype: Optional[torch.dtype] = torch.float,
-                       **kwargs) -> torch.Tensor:
+def chunk_local_cumsum(
+    g: torch.Tensor,
+    chunk_size: int,
+    reverse: bool = False,
+    scale: float = None,
+    cu_seqlens: torch.Tensor | None = None,
+    head_first: bool = False,
+    output_dtype: torch.dtype | None = torch.float,
+    **kwargs,
+) -> torch.Tensor:
     if cu_seqlens is not None:
-        assert g.shape[
-            0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
+        assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
     if len(g.shape) == 3:
-        return chunk_local_cumsum_scalar(g=g,
-                                         chunk_size=chunk_size,
-                                         reverse=reverse,
-                                         scale=scale,
-                                         cu_seqlens=cu_seqlens,
-                                         head_first=head_first,
-                                         output_dtype=output_dtype)
+        return chunk_local_cumsum_scalar(
+            g=g,
+            chunk_size=chunk_size,
+            reverse=reverse,
+            scale=scale,
+            cu_seqlens=cu_seqlens,
+            head_first=head_first,
+            output_dtype=output_dtype,
+        )
     else:
-        raise ValueError(f"Unsupported input shape {g.shape}, "
-                         f"which should be (B, T, H, D) if `head_first=False` "
-                         f"or (B, H, T, D) otherwise")
+        raise ValueError(
+            f"Unsupported input shape {g.shape}, "
+            f"which should be (B, T, H, D) if `head_first=False` "
+            f"or (B, H, T, D) otherwise"
+        )
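A minimal usage sketch of `chunk_local_cumsum` (values invented), matching the call in chunk.py above:

```python
# Per-chunk cumulative sum of the log-decay gates. chunk_size must be a power
# of two; varlen inputs must be flattened to batch size 1 before passing cu_seqlens.
import torch
from vllm_ascend.ops.triton.fla.cumsum import chunk_local_cumsum

g = torch.randn(1, 4096, 8, device="npu")     # [B, T, H]
g_cum = chunk_local_cumsum(g, chunk_size=64)  # float output by default
```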
diff --git a/vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py b/vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py
index d809dcd4..e58e79bb 100644
--- a/vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py
+++ b/vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py
@@ -31,29 +31,30 @@ def fused_qkvzba_split_reshape_cat_kernel(
     BA_DIM_T: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK * 2
     QKV_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
     q_end: tl.constexpr = HEAD_QK
-    blk_q_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
-                 i_qk * QKVZ_DIM_T + tl.arange(0, q_end))
+    blk_q_ptr = mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(0, q_end)
     k_end: tl.constexpr = q_end + HEAD_QK
-    blk_k_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
-                 i_qk * QKVZ_DIM_T + tl.arange(q_end, k_end))
+    blk_k_ptr = mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(q_end, k_end)
     v_end: tl.constexpr = k_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
-    blk_v_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
-                 i_qk * QKVZ_DIM_T + tl.arange(k_end, v_end))
+    blk_v_ptr = mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(k_end, v_end)
     z_end: tl.constexpr = v_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
-    blk_z_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
-                 i_qk * QKVZ_DIM_T + tl.arange(v_end, z_end))
-    blk_q_st_ptr = (mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T +
-                    i_qk * HEAD_QK + tl.arange(0, HEAD_QK))
-    blk_k_st_ptr = (mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T +
-                    NUM_HEADS_QK * HEAD_QK + i_qk * HEAD_QK +
-                    tl.arange(0, HEAD_QK))
-    blk_v_st_ptr = (mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T +
-                    NUM_HEADS_QK * HEAD_QK * 2 +
-                    i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK +
-                    tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK))
-    blk_z_st_ptr = (z + i_bs * NUM_HEADS_V * HEAD_V +
-                    i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK +
-                    tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK))
+    blk_z_ptr = mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(v_end, z_end)
+    blk_q_st_ptr = mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T + i_qk * HEAD_QK + tl.arange(0, HEAD_QK)
+    blk_k_st_ptr = (
+        mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T + NUM_HEADS_QK * HEAD_QK + i_qk * HEAD_QK + tl.arange(0, HEAD_QK)
+    )
+    blk_v_st_ptr = (
+        mixed_qkv
+        + i_bs * NUM_HEADS_QK * QKV_DIM_T
+        + NUM_HEADS_QK * HEAD_QK * 2
+        + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK
+        + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK)
+    )
+    blk_z_st_ptr = (
+        z
+        + i_bs * NUM_HEADS_V * HEAD_V
+        + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK
+        + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK)
+    )
     tl.store(blk_q_st_ptr, tl.load(blk_q_ptr))
     tl.store(blk_k_st_ptr, tl.load(blk_k_ptr))
     tl.store(blk_v_st_ptr, tl.load(blk_v_ptr))
@@ -66,8 +67,7 @@ def fused_qkvzba_split_reshape_cat_kernel(
         tl.store(blk_b_st_ptr, tl.load(blk_b_ptr))
     for i in tl.static_range(b_end, a_end):
         blk_a_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
-        blk_a_st_ptr = (a + i_bs * NUM_HEADS_V +
-                        i_qk * NUM_HEADS_V // NUM_HEADS_QK + (i - b_end))
+        blk_a_st_ptr = a + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + (i - b_end)
         tl.store(blk_a_st_ptr, tl.load(blk_a_ptr))
diff --git a/vllm_ascend/ops/triton/fla/l2norm.py b/vllm_ascend/ops/triton/fla/l2norm.py
index 82c83247..9ba89faa 100644
--- a/vllm_ascend/ops/triton/fla/l2norm.py
+++ b/vllm_ascend/ops/triton/fla/l2norm.py
@@ -15,8 +15,7 @@ from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num

 @triton.jit
-def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr,
-                            MBLOCK: tl.constexpr, NUM_CHUNKS: tl.constexpr):
+def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr, NUM_CHUNKS: tl.constexpr):
     base_row = tl.program_id(0) * (NUM_CHUNKS * MBLOCK)
     rindex = tl.arange(0, N)[None, :]

@@ -24,8 +23,7 @@ def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr,
         row_idx = base_row + chunk * MBLOCK + tl.arange(0, MBLOCK)[:, None]
         xmask = row_idx < M

-        xs = tl.load(X + (rindex + N * row_idx), mask=xmask,
-                     other=0.0).to(tl.float32)
+        xs = tl.load(X + (rindex + N * row_idx), mask=xmask, other=0.0).to(tl.float32)
         square = xs * xs
         square_sum = tl.sum(square, 1)[:, None]
         rsqrt = tl.rsqrt(square_sum + eps)
@@ -33,9 +31,7 @@ def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr,
         tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)


-def l2norm_fwd(x: torch.Tensor,
-               eps: float = 1e-6,
-               output_dtype: torch.dtype | None = None):
+def l2norm_fwd(x: torch.Tensor, eps: float = 1e-6, output_dtype: torch.dtype | None = None):
     x_shape_og = x.shape
     x = x.reshape(-1, x.shape[-1])
     # allocate output
@@ -56,7 +52,7 @@ def l2norm_fwd(x: torch.Tensor,
     num_core = get_vectorcore_num()
     main_bs = triton.cdiv(T, num_core)
     num_sub_blocks = triton.cdiv(main_bs, MBLOCK)
-    grid = (num_core, )
+    grid = (num_core,)
     l2norm_fwd_kernel2_loop[grid](
         X=x,
         Y=y,
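A hedged usage sketch of `l2norm_fwd` (shape invented), as applied to the q/k projections when `use_qk_l2norm_in_kernel=True` in chunk.py:

```python
import torch
from vllm_ascend.ops.triton.fla.l2norm import l2norm_fwd

q = torch.randn(1, 1024, 8, 128, dtype=torch.bfloat16, device="npu")
q_normed = l2norm_fwd(q, eps=1e-6)  # same shape; each last-dim row scaled by rsqrt(sum(x^2) + eps)
```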
- constexpr, # whether beta is headwise vector or scalar, + IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar, USE_QK_L2NORM_IN_KERNEL: tl.constexpr, IS_VARLEN: tl.constexpr, IS_CONTINUOUS_BATCHING: tl.constexpr, @@ -82,8 +79,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( i_n, i_hv = i_nh // HV, i_nh % HV i_h = i_hv // (HV // H) if IS_VARLEN: - bos, eos = tl.load(cu_seqlens + i_n).to( - tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) all = T T = eos - bos else: @@ -108,8 +104,9 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 else: i_t = 0 - p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + - i_t).to(tl.int64) * stride_init_state_token + p_h0 = ( + h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) * stride_init_state_token + ) else: p_h0 = h0 + bos * HV * K * V p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] @@ -164,18 +161,21 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( # keep the states for multi-query tokens if INPLACE_FINAL_STATE: - p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq + - i_t).to(tl.int64) * stride_final_state_token + p_ht = ( + ht + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) * stride_final_state_token + ) else: p_ht = ht + (bos + i_t) * stride_final_state_token p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) -@triton.heuristics({ - "USE_INITIAL_STATE": lambda args: args["h0_source"] is not None, - "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, -}) +@triton.heuristics( + { + "USE_INITIAL_STATE": lambda args: args["h0_source"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) @triton.jit(do_not_specialize=["T"]) def fused_sigmoid_gating_delta_rule_update_kernel( A_log, @@ -245,8 +245,7 @@ def fused_sigmoid_gating_delta_rule_update_kernel( idx = tl.load(h0_indices + i_n) # if idx >= 0: tmp0 = tl.where(idx < 0, 0, idx) - p_h0 = (h0_source + tmp0 * HV * K * V + i_hv * K * V + - o_k[:, None] * V + o_v[None, :]) + p_h0 = h0_source + tmp0 * HV * K * V + i_hv * K * V + o_k[:, None] * V + o_v[None, :] temp1 = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) temp2 = tl.zeros_like(temp1) value0 = tl.where(idx < 0, temp2, temp1) @@ -314,8 +313,7 @@ def fused_sigmoid_gating_delta_rule_update_kernel( if USE_INITIAL_STATE: idx = tl.load(h0_indices + i_n) if idx >= 0: - p_h0 = (h0_source + idx * HV * K * V + i_hv * K * V + - o_k[:, None] * V + o_v[None, :]) + p_h0 = h0_source + idx * HV * K * V + i_hv * K * V + o_k[:, None] * V + o_v[None, :] tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h) @@ -350,7 +348,7 @@ def fused_sigmoid_gating_delta_rule_update( num_warps = 1 if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 else: assert scale > 0, "scale must be positive" diff --git a/vllm_ascend/ops/triton/fla/solve_tril.py b/vllm_ascend/ops/triton/fla/solve_tril.py index a8000320..62a943fb 100644 --- a/vllm_ascend/ops/triton/fla/solve_tril.py +++ b/vllm_ascend/ops/triton/fla/solve_tril.py @@ -8,7 +8,6 @@ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang # ruff: noqa: E501 # mypy: ignore-errors -from typing import Optional import torch from vllm.triton_utils import tl, triton @@ -66,9 +65,13 @@ def solve_tril_16x16_kernel( offs_cols_in_block = 
tl.arange(0, 16) # 2 Calculate the pointer of each element - ptr_A_subrec16 = (A + row_start_o * H * BT + col_start_o + - offs_rows_in_block[:, None] * H * BT + - offs_cols_in_block[None, :]) + ptr_A_subrec16 = ( + A + + row_start_o * H * BT + + col_start_o + + offs_rows_in_block[:, None] * H * BT + + offs_cols_in_block[None, :] + ) # 3 Create a mask to prevent out-of-bounds access global_rows = row_start_o + offs_rows_in_block[:, None] @@ -76,14 +79,14 @@ def solve_tril_16x16_kernel( load_mask = (global_rows < T) & (global_cols < BT) # 4 Use mask to safely load data - b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask, - other=0.0).to(tl.float32) + b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask, other=0.0).to(tl.float32) b_A = tl.insert_slice( ful=b_A, sub=b_A_subrec16[None, :, :], # (1, 16, 16) offsets=[blkid, 0, 0], sizes=[1, 16, 16], - strides=[1, 1, 1]) + strides=[1, 1, 1], + ) local_ori_A = tl.trans(b_A, (1, 0, 2)) local_ori_A = tl.reshape(local_ori_A, (16, 16 * N_BLOCKS)) @@ -97,9 +100,7 @@ def solve_tril_16x16_kernel( # for loop to update N_BLOCKS row vector for i in range(1, 16): - nblks_vec16 = -tl.extract_slice(local_ori_A, (i, 0), - (1, 16 * N_BLOCKS), - (16 * N_BLOCKS, 1)) + nblks_vec16 = -tl.extract_slice(local_ori_A, (i, 0), (1, 16 * N_BLOCKS), (16 * N_BLOCKS, 1)) b_a = tl.reshape(nblks_vec16, (N_BLOCKS, 16)) dot_tmp = tl.trans(b_a[:, :, None] * b_A, (1, 0, 2)) @@ -107,34 +108,27 @@ def solve_tril_16x16_kernel( b_a = b_a + dot_product b_a_new_expanded = b_a[:, None, :] - b_A = tl.insert_slice(ful=b_A, - sub=b_a_new_expanded, - offsets=[0, i, 0], - sizes=[N_BLOCKS, 1, 16], - strides=[1, 1, 1]) + b_A = tl.insert_slice( + ful=b_A, sub=b_a_new_expanded, offsets=[0, i, 0], sizes=[N_BLOCKS, 1, 16], strides=[1, 1, 1] + ) - on_diagonal = (rows == cols) + on_diagonal = rows == cols b_A = tl.where(on_diagonal, b_A + 1.0, b_A) b_A = tl.reshape(b_A, (N_BLOCKS * 16, 16)) - p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (base_t, 0), - (N_BLOCKS * 16, 16), (1, 0)) + p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (base_t, 0), (N_BLOCKS * 16, 16), (1, 0)) # 1 Create in-block offset offs_rows_to_store = tl.arange(0, N_BLOCKS * 16) offs_cols_to_store = tl.arange(0, 16) # 2 Calculate the pointer of each element - p_Ai = (Ad + base_t * H * 16 + 0 + - offs_rows_to_store[:, None] * H * 16 + - offs_cols_to_store[None, :]) + p_Ai = Ad + base_t * H * 16 + 0 + offs_rows_to_store[:, None] * H * 16 + offs_cols_to_store[None, :] # 3 Create a mask to prevent out-of-bounds access, only check rows global_store_rows = base_t + offs_rows_to_store[:, None] store_mask = global_store_rows < T # 4 use mask to save data safely - tl.store(p_Ai, - b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), - mask=store_mask) + tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=store_mask) @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) @@ -169,18 +163,12 @@ def merge_16x16_to_32x32_inverse_kernel( Ad += (bos * H + i_h) * 16 Ai += (bos * H + i_h) * 32 - p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), - (16, 16), (1, 0)) - p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), - (16, 16), (1, 0)) - p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), - (16, 16), (1, 0)) - p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), - (16, 16), (1, 0)) - p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), - (16, 16), (1, 0)) - p_Ai_21 = 
tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), - (16, 16), (1, 0)) + p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) + p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0)) + p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) + p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32) @@ -313,26 +301,20 @@ def merge_16x16_to_64x64_inverse_kernel( offs_n = tl.arange(0, 32) mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64) ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :] - tl.store(ptr_Ai, - Ai_11_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), - mask=mask_store) + tl.store(ptr_Ai, Ai_11_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=mask_store) # store Ai_22_32 to (i_t * 64 + 32, 32) offs_m = i_t * 64 + 32 + tl.arange(0, 32) offs_n = 32 + tl.arange(0, 32) mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64) ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :] - tl.store(ptr_Ai, - Ai_22_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), - mask=mask_store) + tl.store(ptr_Ai, Ai_22_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=mask_store) # store Ai_21_32 to (i_t * 64 + 32, 32) offs_n = tl.arange(0, 32) mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64) ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :] - tl.store(ptr_Ai, - Ai_21_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), - mask=mask_store) + tl.store(ptr_Ai, Ai_21_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=mask_store) # zero out the upper-right 32 * 32 block (rows 0 ~ 31, cols 32 ~ 63) offs_m = i_t * 64 + tl.arange(0, 32) @@ -345,7 +327,7 @@ def merge_16x16_to_64x64_inverse_kernel( def solve_tril( A: torch.Tensor, - cu_seqlens: Optional[torch.Tensor] = None, + cu_seqlens: torch.Tensor | None = None, output_dtype: torch.dtype = torch.float, ) -> torch.Tensor: """ @@ -367,19 +349,12 @@ def solve_tril( assert A.shape[-1] in [16, 32, 64] B, T, H, BT = A.shape - Ad = torch.empty(B, - T, - H, - 16, - device=A.device, - dtype=torch.float if BT != 16 else output_dtype) + Ad = torch.empty(B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype) LARGE_BLOCK_T = 608 * 2 - chunk_indices = (prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T) - if cu_seqlens is not None else None) - NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv( - T, LARGE_BLOCK_T) + chunk_indices = prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T) if cu_seqlens is not None else None + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, LARGE_BLOCK_T) solve_tril_16x16_kernel[NT, B * H]( A=A, @@ -398,10 +373,8 @@ def solve_tril( return Ad Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype) - merge_fn = (merge_16x16_to_32x32_inverse_kernel - if BT == 32 else merge_16x16_to_64x64_inverse_kernel) - chunk_indices = (prepare_chunk_indices(cu_seqlens, BT) - if cu_seqlens is not None else None) + merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel + chunk_indices 
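merge_16x16_to_32x32_inverse_kernel stitches two diagonal-block inverses together with the block formula for a lower block-triangular matrix: if A = [[A11, 0], [A21, A22]], then A^{-1} = [[Ai11, 0], [-Ai22 @ A21 @ Ai11, Ai22]]. A quick check of that identity with hypothetical well-conditioned blocks:

import torch

torch.manual_seed(0)
n = 16
A11 = torch.eye(n, dtype=torch.float64) + torch.tril(torch.randn(n, n, dtype=torch.float64), -1)
A22 = torch.eye(n, dtype=torch.float64) + torch.tril(torch.randn(n, n, dtype=torch.float64), -1)
A21 = torch.randn(n, n, dtype=torch.float64)

Ai11, Ai22 = torch.linalg.inv(A11), torch.linalg.inv(A22)
Ai21 = -Ai22 @ A21 @ Ai11                  # the off-diagonal merge step

A = torch.zeros(2 * n, 2 * n, dtype=torch.float64)
A[:n, :n], A[n:, :n], A[n:, n:] = A11, A21, A22
Ai = torch.zeros_like(A)
Ai[:n, :n], Ai[n:, :n], Ai[n:, n:] = Ai11, Ai21, Ai22
assert torch.allclose(A @ Ai, torch.eye(2 * n, dtype=torch.float64), atol=1e-8)

The 64x64 merge applies the same identity once more, on 32x32 blocks built from the 16x16 inverses.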
= prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT) merge_fn[NT, B * H]( diff --git a/vllm_ascend/ops/triton/fla/utils.py b/vllm_ascend/ops/triton/fla/utils.py index 4d2cd135..fa23c2af 100644 --- a/vllm_ascend/ops/triton/fla/utils.py +++ b/vllm_ascend/ops/triton/fla/utils.py @@ -9,7 +9,7 @@ # ruff: noqa: E501 import contextlib import functools -from typing import Callable +from collections.abc import Callable import torch from vllm.triton_utils import tl, triton @@ -19,38 +19,24 @@ def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: return cu_seqlens[1:] - cu_seqlens[:-1] -def prepare_chunk_indices(cu_seqlens: torch.LongTensor, - chunk_size: int) -> torch.LongTensor: - indices = torch.cat([ - torch.arange(n) - for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist() - ]) - return torch.stack([indices.eq(0).cumsum(0) - 1, indices], - 1).to(cu_seqlens) +def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()]) + return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) -def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, - chunk_size: int) -> torch.LongTensor: - return torch.cat([ - cu_seqlens.new_tensor([0]), - triton.cdiv(prepare_lens(cu_seqlens), chunk_size) - ]).cumsum(-1) +def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: + return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1) -def input_guard( - fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: +def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: """ A decorator to make sure all input tensors are contiguous and set the device based on input tensors. 
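Concretely, prepare_chunk_indices enumerates (sequence index, chunk index) pairs for a varlen batch; a worked example with hypothetical lengths (pure torch, ceil division in place of triton.cdiv):

import torch

cu_seqlens = torch.tensor([0, 5, 12])             # two sequences: 5 and 7 tokens
chunk_size = 4
lens = cu_seqlens[1:] - cu_seqlens[:-1]           # prepare_lens -> [5, 7]
n_chunks = (lens + chunk_size - 1) // chunk_size  # [2, 2]
indices = torch.cat([torch.arange(n) for n in n_chunks.tolist()])
out = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1)
print(out.tolist())  # [[0, 0], [0, 1], [1, 0], [1, 1]]

prepare_chunk_offsets on the same input would give the cumulative chunk counts [0, 2, 4].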
""" @functools.wraps(fn) def wrapper(*args, **kwargs): - contiguous_args = (i if not isinstance(i, torch.Tensor) else - i.contiguous() for i in args) - contiguous_kwargs = { - k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) - for k, v in kwargs.items() - } + contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args) + contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()} tensor = None for arg in args: diff --git a/vllm_ascend/ops/triton/fla/wy_fast.py b/vllm_ascend/ops/triton/fla/wy_fast.py index 1d4c2955..d6e24075 100644 --- a/vllm_ascend/ops/triton/fla/wy_fast.py +++ b/vllm_ascend/ops/triton/fla/wy_fast.py @@ -9,7 +9,6 @@ # ruff: noqa: E501 # mypy: ignore-errors -from typing import Optional, Tuple import torch from vllm.triton_utils import tl, triton @@ -17,23 +16,39 @@ from vllm.triton_utils import tl, triton from .utils import prepare_chunk_indices -@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None}) -@triton.jit(do_not_specialize=['T']) -def recompute_w_u_fwd_kernel(k, v, beta, w, u, A, g, cu_seqlens, chunk_indices, - T, H: tl.constexpr, Hg: tl.constexpr, - K: tl.constexpr, V: tl.constexpr, - BT: tl.constexpr, BK: tl.constexpr, - BV: tl.constexpr, IS_VARLEN: tl.constexpr): +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def recompute_w_u_fwd_kernel( + k, + v, + beta, + w, + u, + A, + g, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): T_max = T i_t_o = tl.program_id(0) for i_bh in range(H): i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: - i_n, i_t = tl.load(chunk_indices + i_t_o * 2).to( - tl.int32), tl.load(chunk_indices + i_t_o * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to( - tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + i_n, i_t = ( + tl.load(chunk_indices + i_t_o * 2).to(tl.int32), + tl.load(chunk_indices + i_t_o * 2 + 1).to(tl.int32), + ) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) T = eos - bos else: bos, eos = i_b * T, i_b * T + T @@ -44,7 +59,7 @@ def recompute_w_u_fwd_kernel(k, v, beta, w, u, A, g, cu_seqlens, chunk_indices, offs_t_2d = global_offs_t[:, None] offs_bt = tl.arange(0, BT)[None, :] - ptr_A = (A + (bos * H + i_h) * BT + offs_t_2d * (H * BT) + offs_bt * 1) + ptr_A = A + (bos * H + i_h) * BT + offs_t_2d * (H * BT) + offs_bt * 1 mask_A = mask_t[:, None] b_A = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32) @@ -58,29 +73,25 @@ def recompute_w_u_fwd_kernel(k, v, beta, w, u, A, g, cu_seqlens, chunk_indices, offs_v = i_v * BV + tl.arange(0, BV)[None, :] mask_v = (mask_t[:, None]) & (offs_v < V) - ptr_v = (v + (bos * H + i_h) * V + offs_t_2d * (H * V) + - offs_v * 1) + ptr_v = v + (bos * H + i_h) * V + offs_t_2d * (H * V) + offs_v * 1 b_v = tl.load(ptr_v, mask=mask_v, other=0.0).to(tl.float32) - b_vb = (b_v * b_beta[:, None]) + b_vb = b_v * b_beta[:, None] b_u = tl.dot(b_A, b_vb, allow_tf32=False) - ptr_u = (u + (bos * H + i_h) * V + offs_t_2d * (H * V) + - offs_v * 1) + ptr_u = u + (bos * H + i_h) * V + offs_t_2d * (H * V) + offs_v * 1 tl.store(ptr_u, b_u.to(ptr_u.dtype.element_ty), mask=mask_v) for i_k in range(tl.cdiv(K, BK)): offs_k = i_k * BK + tl.arange(0, BK)[None, :] mask_k = (mask_t[:, None]) & (offs_k < K) - ptr_k = (k + (bos * Hg 
+ i_h // (H // Hg)) * K + offs_t_2d * - (Hg * K) + offs_k * 1) + ptr_k = k + (bos * Hg + i_h // (H // Hg)) * K + offs_t_2d * (Hg * K) + offs_k * 1 b_k = tl.load(ptr_k, mask=mask_k, other=0.0).to(tl.float32) - b_kb = (b_k * b_beta[:, None] * b_g[:, None]) + b_kb = b_k * b_beta[:, None] * b_g[:, None] b_w = tl.dot(b_A, b_kb) - ptr_w = (w + (bos * H + i_h) * K + offs_t_2d * (H * K) + - offs_k * 1) + ptr_w = w + (bos * H + i_h) * K + offs_t_2d * (H * K) + offs_k * 1 tl.store(ptr_w, b_w.to(ptr_w.dtype.element_ty), mask=mask_k) @@ -90,14 +101,13 @@ def recompute_w_u_fwd( beta: torch.Tensor, g_cumsum: torch.Tensor, A: torch.Tensor, - cu_seqlens: Optional[torch.LongTensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: + cu_seqlens: torch.LongTensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: B, T, Hg, K, V = *k.shape, v.shape[-1] H = v.shape[-2] BT = A.shape[-1] - chunk_indices = prepare_chunk_indices(cu_seqlens, BT) \ - if cu_seqlens is not None else None + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) BK = 64 diff --git a/vllm_ascend/ops/triton/fused_gdn_gating.py b/vllm_ascend/ops/triton/fused_gdn_gating.py index dfd5dde2..b3b05706 100644 --- a/vllm_ascend/ops/triton/fused_gdn_gating.py +++ b/vllm_ascend/ops/triton/fused_gdn_gating.py @@ -30,15 +30,12 @@ def fused_gdn_gating_kernel( ): i_b, i_s = tl.program_id(0), tl.program_id(1) for row_idx in range(0, ROW_ITER): - batch_off = i_b * ROW_ITER * BLK_BATCHES + row_idx * BLK_BATCHES + tl.arange( - 0, BLK_BATCHES) + batch_off = i_b * ROW_ITER * BLK_BATCHES + row_idx * BLK_BATCHES + tl.arange(0, BLK_BATCHES) for col_idx in range(0, COL_ITER): head_off = col_idx * BLK_HEADS + tl.arange(0, BLK_HEADS) - off = batch_off[:, - None] * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off[ - None, :] + off = batch_off[:, None] * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off[None, :] head_mask = head_off < NUM_HEADS mask = head_mask[None, :] & (batch_off[:, None] < NUM_BATCHES) @@ -48,17 +45,14 @@ def fused_gdn_gating_kernel( blk_bias = tl.load(dt_bias + head_off, mask=head_mask) x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)[None, :] - softplus_x = tl.where(beta * x <= threshold, - (1 / beta) * tl.log(1 + tl.exp(beta * x)), x) + softplus_x = tl.where(beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x) blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) # compute beta_output = sigmoid(b) blk_beta_output = tl.sigmoid(blk_b.to(tl.float32)) - tl.store(beta_output + off, - blk_beta_output.to(beta_output.dtype.element_ty), - mask=mask) + tl.store(beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask) def fused_gdn_gating_patch( @@ -85,17 +79,13 @@ def fused_gdn_gating_patch( progs = num_cores FACTOR = 8 * num_heads row_per_core = triton.cdiv(batch, num_cores) - BLK_BATCHES = triton.next_power_of_2( - triton.cdiv(UNIFIED_BUFFER_SIZE, FACTOR * BLK_HEADS) // - a.element_size()) // 2 + BLK_BATCHES = ( + triton.next_power_of_2(triton.cdiv(UNIFIED_BUFFER_SIZE, FACTOR * BLK_HEADS) // a.element_size()) // 2 + ) ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES) g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device) - beta_output = torch.empty(1, - batch, - num_heads, - dtype=b.dtype, - device=b.device) + beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device) grid = (progs, seq_len) 
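An eager reference for what fused_gdn_gating_kernel computes per (batch, head) element, using the same thresholded softplus (a sketch; the kernel additionally tiles over batch and broadcasts dt_bias across batches):

import torch

def gdn_gating_ref(A_log, a, b, dt_bias, beta=1.0, threshold=20.0):
    # a, b: (batch, num_heads); A_log, dt_bias: (num_heads,)
    x = a.float() + dt_bias.float()
    softplus_x = torch.where(
        beta * x <= threshold,
        (1.0 / beta) * torch.log1p(torch.exp(beta * x)),  # softplus branch
        x,                                                # linear tail
    )
    g = -torch.exp(A_log.float()) * softplus_x
    beta_output = torch.sigmoid(b.float())
    return g, beta_output

Note that beta here is the softplus temperature, distinct from the sigmoid gate beta_output the kernel also writes out.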
fused_gdn_gating_kernel[grid]( diff --git a/vllm_ascend/ops/triton/layernorm_gated.py b/vllm_ascend/ops/triton/layernorm_gated.py index 48ca46f5..76c418f2 100644 --- a/vllm_ascend/ops/triton/layernorm_gated.py +++ b/vllm_ascend/ops/triton/layernorm_gated.py @@ -9,6 +9,7 @@ import torch from vllm.triton_utils import tl, triton + @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None}) @triton.jit @@ -131,11 +132,7 @@ def layer_norm_fwd_npu( else: out = torch.empty_like(x) assert out.stride(-1) == 1 - mean = ( - torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) - if not is_rms_norm - else None - ) + mean = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) MAX_FUSED_SIZE = 65536 // x.element_size() @@ -168,4 +165,4 @@ def layer_norm_fwd_npu( IS_RMS_NORM=is_rms_norm, # Remove multibuffer if not needed ) - return out, mean, rstd \ No newline at end of file + return out, mean, rstd diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py index 6bc2d373..14bae3b7 100644 --- a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py +++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py @@ -14,7 +14,6 @@ # limitations under the License. # This file is a part of the vllm-ascend project. # -from typing import Optional import torch import triton # type: ignore @@ -61,22 +60,19 @@ def split_qkv_rmsnorm_rope_kernel( for row_idx in tl.range(row_pid, batch_size, row_step): col_indices = col_pid * Q_BLOCK_SIZE + tl.arange(0, Q_BLOCK_SIZE) valid_mask = col_indices < q_hidden_size - input_values = (tl.load(input_ptr + input_offset + col_indices, - mask=valid_mask, - other=0.0).to(tl.float32).reshape( - Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)) + input_values = ( + tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0) + .to(tl.float32) + .reshape(Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM) + ) squares = input_values * input_values variances = tl.sum(squares, axis=1) / HEAD_DIM - reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape( - Q_BLOCK_SIZE // HEAD_DIM, 1) - normalized_values = (input_values * reciprocal_std - ) # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM) + reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(Q_BLOCK_SIZE // HEAD_DIM, 1) + normalized_values = input_values * reciprocal_std # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM) if BIAS: - normalized_values = (normalized_values * weight_values + - bias_values).to(tl.bfloat16) + normalized_values = (normalized_values * weight_values + bias_values).to(tl.bfloat16) else: - normalized_values = (normalized_values * weight_values).to( - tl.bfloat16) + normalized_values = (normalized_values * weight_values).to(tl.bfloat16) sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM) sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM) @@ -93,8 +89,7 @@ def split_qkv_rmsnorm_rope_kernel( sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - cat_x = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), - dtype=tl.bfloat16) + cat_x = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16) cat_x = tl.insert_slice( cat_x, -x2, @@ -127,22 +122,19 @@ def split_qkv_rmsnorm_rope_kernel( for row_idx in tl.range(row_pid, batch_size, row_step): col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE) valid_mask = col_indices < 
kv_hidden_size - input_values = (tl.load(input_ptr + input_offset + col_indices, - mask=valid_mask, - other=0.0).to(tl.float32).reshape( - KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)) + input_values = ( + tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0) + .to(tl.float32) + .reshape(KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM) + ) squares = input_values * input_values variances = tl.sum(squares, axis=1) / HEAD_DIM - reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape( - KV_BLOCK_SIZE // HEAD_DIM, 1) - normalized_values = (input_values * reciprocal_std - ) # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM) + reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(KV_BLOCK_SIZE // HEAD_DIM, 1) + normalized_values = input_values * reciprocal_std # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM) if BIAS: - normalized_values = (normalized_values * weight_values + - bias_values).to(tl.bfloat16) + normalized_values = (normalized_values * weight_values + bias_values).to(tl.bfloat16) else: - normalized_values = (normalized_values * weight_values).to( - tl.bfloat16) + normalized_values = (normalized_values * weight_values).to(tl.bfloat16) sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM) sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM) cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM) @@ -158,8 +150,7 @@ def split_qkv_rmsnorm_rope_kernel( sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - cat_x = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), - dtype=tl.bfloat16) + cat_x = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16) cat_x = tl.insert_slice( cat_x, -x2, @@ -189,12 +180,8 @@ def split_qkv_rmsnorm_rope_kernel( for _ in tl.range(row_pid, batch_size, row_step): col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE) valid_mask = col_indices < kv_hidden_size - input_values = tl.load(input_ptr + input_offset + col_indices, - mask=valid_mask, - other=0.0) - tl.store(v_ptr + output_offset + col_indices, - input_values, - mask=valid_mask) + input_values = tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0) + tl.store(v_ptr + output_offset + col_indices, input_values, mask=valid_mask) input_offset += input_offset_step output_offset += output_offset_step @@ -209,27 +196,18 @@ def split_qkv_rmsnorm_rope_impl( kv_hidden_size: int, head_dim: int, eps: float, - q_bias: Optional[torch.Tensor] = None, - k_bias: Optional[torch.Tensor] = None, + q_bias: torch.Tensor | None = None, + k_bias: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: KV_BLOCK_SIZE = triton.next_power_of_2(head_dim) - assert KV_BLOCK_SIZE == head_dim + assert head_dim == KV_BLOCK_SIZE assert q_hidden_size % kv_hidden_size == 0 Q_BLOCK_SIZE = q_hidden_size // kv_hidden_size * head_dim batch_size = input.shape[0] total_hidden_size = q_hidden_size + kv_hidden_size * 2 - q_output = torch.empty(batch_size, - q_hidden_size, - device=input.device, - dtype=input.dtype) - k_output = torch.empty(batch_size, - kv_hidden_size, - device=input.device, - dtype=input.dtype) - v_output = torch.empty(batch_size, - kv_hidden_size, - device=input.device, - dtype=input.dtype) + q_output = torch.empty(batch_size, q_hidden_size, device=input.device, dtype=input.dtype) + k_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype) + v_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype) n_cols = kv_hidden_size // KV_BLOCK_SIZE num_vectorcore = get_vectorcore_num() assert 
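The per-head math in split_qkv_rmsnorm_rope_kernel is head-wise RMSNorm followed by rotate-half RoPE; an eager sketch for one token (hypothetical shapes, with the q/k bias path and bfloat16 casts omitted):

import torch

def rmsnorm_rope_head_ref(x, weight, sin, cos, eps=1e-6):
    # x: (num_heads, head_dim); weight, sin, cos: (head_dim,)
    var = (x.float() ** 2).mean(dim=-1, keepdim=True)
    normed = x.float() * torch.rsqrt(var + eps) * weight
    half = x.shape[-1] // 2
    x1, x2 = normed[..., :half], normed[..., half:]
    rotated = torch.cat([-x2, x1], dim=-1)    # rotate-half, as built via insert_slice
    return normed * cos + rotated * sin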
num_vectorcore % n_cols == 0 @@ -271,8 +249,8 @@ def split_qkv_rmsnorm_rope_impl_fake( kv_hidden_size: int, head_dim: int, eps: float, - q_bias: Optional[torch.Tensor] = None, - k_bias: Optional[torch.Tensor] = None, + q_bias: torch.Tensor | None = None, + k_bias: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # Fake implementation for shape inference during Dynamo/AOT tracing. # Note: sin and cos are not used in shape computation, but must be present in signature. @@ -298,8 +276,10 @@ def split_qkv_rmsnorm_rope_impl_fake( return q_output, k_output, v_output -direct_register_custom_op(op_name="qkv_rmsnorm_rope", - op_func=split_qkv_rmsnorm_rope_impl, - fake_impl=split_qkv_rmsnorm_rope_impl_fake, - mutates_args=[], - dispatch_key="PrivateUse1") +direct_register_custom_op( + op_name="qkv_rmsnorm_rope", + op_func=split_qkv_rmsnorm_rope_impl, + fake_impl=split_qkv_rmsnorm_rope_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1", +) diff --git a/vllm_ascend/ops/triton/mamba/causal_conv1d.py b/vllm_ascend/ops/triton/mamba/causal_conv1d.py index 4a304d99..4080d6c3 100644 --- a/vllm_ascend/ops/triton/mamba/causal_conv1d.py +++ b/vllm_ascend/ops/triton/mamba/causal_conv1d.py @@ -7,24 +7,23 @@ # and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py # mypy: ignore-errors -from typing import Any, Optional +from typing import Any import torch import torch.nn.functional as F import triton import triton.language as tl - from vllm.v1.attention.backends.utils import PAD_SLOT_ID # type: ignore def causal_conv1d_ref( x: torch.Tensor, weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - initial_states: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, + initial_states: torch.Tensor | None = None, return_final_states: bool = False, - final_states_out: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", + final_states_out: torch.Tensor | None = None, + activation: str | None = "silu", ): """ x: (batch, dim, seqlen) @@ -42,19 +41,14 @@ def causal_conv1d_ref( dim, width = weight.shape if initial_states is None: - out = F.conv1d(x, - weight.unsqueeze(1), - bias, - padding=width - 1, - groups=dim) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) else: x = torch.cat([initial_states, x], dim=-1) out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) out = out[..., :seqlen] if return_final_states: - final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( - dtype_in) # (batch, dim, width - 1) + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(dtype_in) # (batch, dim, width - 1) if final_states_out is not None: final_states_out.copy_(final_states) else: @@ -66,13 +60,13 @@ def causal_conv1d_ref( def causal_conv1d_fn( x: torch.Tensor, weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", - conv_states: Optional[torch.Tensor] = None, - has_initial_state: Optional[torch.Tensor] = None, - cache_indices: Optional[torch.Tensor] = None, - query_start_loc: Optional[torch.Tensor] = None, - metadata: Optional[Any] = None, + bias: torch.Tensor | None = None, + activation: str | None = "silu", + conv_states: torch.Tensor | None = None, + has_initial_state: torch.Tensor | None = None, + cache_indices: torch.Tensor | None = None, + query_start_loc: torch.Tensor | None = None, + metadata: Any | None = None, pad_slot_id: int = PAD_SLOT_ID, ): """ @@ -126,10 +120,10 @@ def causal_conv1d_fn( bias, 
activation=activation, return_final_states=True, - final_states_out=conv_states[cache_indices[i]][..., :( - width - 1)].unsqueeze(0), - initial_states=conv_states[cache_indices[i]][..., :(width - 1)] - if has_initial_state[i] else None)) + final_states_out=conv_states[cache_indices[i]][..., : (width - 1)].unsqueeze(0), + initial_states=conv_states[cache_indices[i]][..., : (width - 1)] if has_initial_state[i] else None, + ) + ) out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1)) out_ref_tensor = torch.cat(out_ref, dim=0) return out_ref_tensor @@ -137,54 +131,50 @@ def causal_conv1d_fn( @triton.jit def _causal_conv1d_update_kernel_npu_tiled( - # Pointers - x_ptr, # (batch, dim, seqlen) OR (num_tokens, dim) for varlen - w_ptr, # (dim, width) - bias_ptr, - conv_state_ptr, # (num_cache_lines, dim, state_len) - conv_state_indices_ptr, - num_accepted_tokens_ptr, - query_start_loc_ptr, # (batch + 1) - block_idx_last_scheduled_token, # (batch,) - initial_state_idx, # (batch,) - o_ptr, # same shape as x_ptr - batch: tl.int32, - dim: tl.constexpr, - seqlen: tl.constexpr, # max seqlen for varlen, or exact seqlen - state_len: tl.constexpr, # effective state_len computed in wrapper - num_cache_lines: tl.constexpr, - - # Strides - stride_x_seq: tl.constexpr, - stride_x_dim: tl.constexpr, - stride_x_token: tl.constexpr, - stride_w_dim: tl.constexpr, - stride_w_width: tl.constexpr, - stride_conv_state_seq: tl.constexpr, - stride_conv_state_dim: tl.constexpr, - stride_conv_state_tok: tl.constexpr, - stride_state_indices: tl.constexpr, - stride_o_seq: tl.constexpr, - stride_o_dim: tl.constexpr, - stride_o_token: tl.constexpr, - - # others - pad_slot_id: tl.constexpr, - - # Meta - HAS_BIAS: tl.constexpr, - KERNEL_WIDTH: tl.constexpr, # <= 6 - SILU_ACTIVATION: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_APC_ENABLED: tl.constexpr, - IS_SPEC_DECODING: tl.constexpr, - NP2_STATELEN: tl.constexpr, - USE_PAD_SLOT: tl.constexpr, - - # tiling - BLOCK_N: tl.constexpr, # channel tile (C_TILE) - B_TILE: tl.constexpr, # batch tile - T_CHUNK: tl.constexpr, # token chunk for state update + # Pointers + x_ptr, # (batch, dim, seqlen) OR (num_tokens, dim) for varlen + w_ptr, # (dim, width) + bias_ptr, + conv_state_ptr, # (num_cache_lines, dim, state_len) + conv_state_indices_ptr, + num_accepted_tokens_ptr, + query_start_loc_ptr, # (batch + 1) + block_idx_last_scheduled_token, # (batch,) + initial_state_idx, # (batch,) + o_ptr, # same shape as x_ptr + batch: tl.int32, + dim: tl.constexpr, + seqlen: tl.constexpr, # max seqlen for varlen, or exact seqlen + state_len: tl.constexpr, # effective state_len computed in wrapper + num_cache_lines: tl.constexpr, + # Strides + stride_x_seq: tl.constexpr, + stride_x_dim: tl.constexpr, + stride_x_token: tl.constexpr, + stride_w_dim: tl.constexpr, + stride_w_width: tl.constexpr, + stride_conv_state_seq: tl.constexpr, + stride_conv_state_dim: tl.constexpr, + stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, # <= 6 + SILU_ACTIVATION: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_APC_ENABLED: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + NP2_STATELEN: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + # tiling + BLOCK_N: tl.constexpr, # channel tile (C_TILE) + B_TILE: tl.constexpr, # batch tile + T_CHUNK: tl.constexpr, # token chunk for state update ): # program ids 
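causal_conv1d_ref realizes the causal depthwise convolution by left-padding with width - 1 zeros and truncating to seqlen, so output token t never sees inputs beyond t. A quick property check with hypothetical sizes:

import torch
import torch.nn.functional as F

batch, dim, seqlen, width = 2, 4, 8, 3
x = torch.randn(batch, dim, seqlen)
w = torch.randn(dim, width)
out = F.conv1d(x, w.unsqueeze(1), padding=width - 1, groups=dim)[..., :seqlen]
x2 = x.clone()
x2[..., -1] += 1.0          # perturb only the last token
out2 = F.conv1d(x2, w.unsqueeze(1), padding=width - 1, groups=dim)[..., :seqlen]
assert torch.allclose(out[..., :-1], out2[..., :-1])  # earlier outputs unchanged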
pid_b = tl.program_id(0) # batch-tile id @@ -197,37 +187,30 @@ def _causal_conv1d_update_kernel_npu_tiled( # preload weights once per program (shared by B_TILE sequences) w_base = w_ptr + idx_feats * stride_w_dim # define to avoid "undefined" in branches - w_col0 = tl.zeros((BLOCK_N, ), dtype=tl.float32) - w_col1 = tl.zeros((BLOCK_N, ), dtype=tl.float32) - w_col2 = tl.zeros((BLOCK_N, ), dtype=tl.float32) - w_col3 = tl.zeros((BLOCK_N, ), dtype=tl.float32) - w_col4 = tl.zeros((BLOCK_N, ), dtype=tl.float32) - w_col5 = tl.zeros((BLOCK_N, ), dtype=tl.float32) + w_col0 = tl.zeros((BLOCK_N,), dtype=tl.float32) + w_col1 = tl.zeros((BLOCK_N,), dtype=tl.float32) + w_col2 = tl.zeros((BLOCK_N,), dtype=tl.float32) + w_col3 = tl.zeros((BLOCK_N,), dtype=tl.float32) + w_col4 = tl.zeros((BLOCK_N,), dtype=tl.float32) + w_col5 = tl.zeros((BLOCK_N,), dtype=tl.float32) if KERNEL_WIDTH >= 1: - w_col0 = tl.load(w_base + 0 * stride_w_width, mask=mask_w, - other=0.0).to(tl.float32) + w_col0 = tl.load(w_base + 0 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32) if KERNEL_WIDTH >= 2: - w_col1 = tl.load(w_base + 1 * stride_w_width, mask=mask_w, - other=0.0).to(tl.float32) + w_col1 = tl.load(w_base + 1 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32) if KERNEL_WIDTH >= 3: - w_col2 = tl.load(w_base + 2 * stride_w_width, mask=mask_w, - other=0.0).to(tl.float32) + w_col2 = tl.load(w_base + 2 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32) if KERNEL_WIDTH >= 4: - w_col3 = tl.load(w_base + 3 * stride_w_width, mask=mask_w, - other=0.0).to(tl.float32) + w_col3 = tl.load(w_base + 3 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32) if KERNEL_WIDTH >= 5: - w_col4 = tl.load(w_base + 4 * stride_w_width, mask=mask_w, - other=0.0).to(tl.float32) + w_col4 = tl.load(w_base + 4 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32) if KERNEL_WIDTH >= 6: - w_col5 = tl.load(w_base + 5 * stride_w_width, mask=mask_w, - other=0.0).to(tl.float32) + w_col5 = tl.load(w_base + 5 * stride_w_width, mask=mask_w, other=0.0).to(tl.float32) # bias vector once per program if HAS_BIAS: - acc_bias = tl.load(bias_ptr + idx_feats, mask=mask_w, - other=0.0).to(tl.float32) + acc_bias = tl.load(bias_ptr + idx_feats, mask=mask_w, other=0.0).to(tl.float32) else: - acc_bias = tl.zeros((BLOCK_N, ), dtype=tl.float32) + acc_bias = tl.zeros((BLOCK_N,), dtype=tl.float32) # token index vector for chunked copy tok_vec = tl.arange(0, T_CHUNK) # [T_CHUNK] @@ -241,36 +224,26 @@ def _causal_conv1d_update_kernel_npu_tiled( # APC mapping (optional) # ------------------------- if IS_APC_ENABLED: - conv_state_init = tl.load(initial_state_idx + b, - mask=lane_active, - other=0).to(tl.int32) - current_last_index = tl.load(block_idx_last_scheduled_token + b, - mask=lane_active, - other=0).to(tl.int32) + conv_state_init = tl.load(initial_state_idx + b, mask=lane_active, other=0).to(tl.int32) + current_last_index = tl.load(block_idx_last_scheduled_token + b, mask=lane_active, other=0).to(tl.int32) else: conv_state_init = tl.full((), 0, tl.int32) current_last_index = tl.full((), 0, tl.int32) # input cache line - conv_states_input_coord = tl.load(conv_state_indices_ptr + - b * stride_state_indices + - conv_state_init, - mask=lane_active, - other=0).to(tl.int64) + conv_states_input_coord = tl.load( + conv_state_indices_ptr + b * stride_state_indices + conv_state_init, mask=lane_active, other=0 + ).to(tl.int64) if USE_PAD_SLOT: - lane_active = lane_active & (conv_states_input_coord - != pad_slot_id) + lane_active = lane_active & 
(conv_states_input_coord != pad_slot_id) # ------------------------- # varlen (optional): revise seqlen_run and state_len_run like original kernel does # ------------------------- if IS_VARLEN: - qs = tl.load(query_start_loc_ptr + b, mask=lane_active, - other=0).to(tl.int64) - qe = tl.load(query_start_loc_ptr + (b + 1), - mask=lane_active, - other=0).to(tl.int64) + qs = tl.load(query_start_loc_ptr + b, mask=lane_active, other=0).to(tl.int64) + qe = tl.load(query_start_loc_ptr + (b + 1), mask=lane_active, other=0).to(tl.int64) seqlen_run = (qe - qs).to(tl.int32) # revise effective state_len for shorter sequences (same formula as original) state_len_run = (state_len - (seqlen - seqlen_run)).to(tl.int32) @@ -289,9 +262,7 @@ def _causal_conv1d_update_kernel_npu_tiled( # spec decoding offset (optional) # ------------------------- if IS_SPEC_DECODING: - conv_state_token_offset = ( - tl.load(num_accepted_tokens_ptr + b, mask=lane_active, - other=1).to(tl.int64) - 1) + conv_state_token_offset = tl.load(num_accepted_tokens_ptr + b, mask=lane_active, other=1).to(tl.int64) - 1 shift = tl.full((), 1, tl.int32) # sliding by 1 in spec mode else: conv_state_token_offset = tl.full((), 0, tl.int64) @@ -300,37 +271,37 @@ def _causal_conv1d_update_kernel_npu_tiled( # ------------------------- # STEP 1: read initial history cols BEFORE state update (out==x safe) # ------------------------- - conv_states_base = (conv_state_ptr + - conv_states_input_coord * stride_conv_state_seq + - idx_feats * stride_conv_state_dim) + conv_states_base = ( + conv_state_ptr + conv_states_input_coord * stride_conv_state_seq + idx_feats * stride_conv_state_dim + ) prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok # define history vectors as zeros then load conditionally - col0 = tl.zeros((BLOCK_N, ), dtype=tl.float16) - col1 = tl.zeros((BLOCK_N, ), dtype=tl.float16) - col2 = tl.zeros((BLOCK_N, ), dtype=tl.float16) - col3 = tl.zeros((BLOCK_N, ), dtype=tl.float16) - col4 = tl.zeros((BLOCK_N, ), dtype=tl.float16) + col0 = tl.zeros((BLOCK_N,), dtype=tl.float16) + col1 = tl.zeros((BLOCK_N,), dtype=tl.float16) + col2 = tl.zeros((BLOCK_N,), dtype=tl.float16) + col3 = tl.zeros((BLOCK_N,), dtype=tl.float16) + col4 = tl.zeros((BLOCK_N,), dtype=tl.float16) if KERNEL_WIDTH >= 2: - col0 = tl.load(prior_tokens + 0 * stride_conv_state_tok, - mask=lane_active & mask_w, - other=0.0).to(tl.float16) + col0 = tl.load(prior_tokens + 0 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to( + tl.float16 + ) if KERNEL_WIDTH >= 3: - col1 = tl.load(prior_tokens + 1 * stride_conv_state_tok, - mask=lane_active & mask_w, - other=0.0).to(tl.float16) + col1 = tl.load(prior_tokens + 1 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to( + tl.float16 + ) if KERNEL_WIDTH >= 4: - col2 = tl.load(prior_tokens + 2 * stride_conv_state_tok, - mask=lane_active & mask_w, - other=0.0).to(tl.float16) + col2 = tl.load(prior_tokens + 2 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to( + tl.float16 + ) if KERNEL_WIDTH >= 5: - col3 = tl.load(prior_tokens + 3 * stride_conv_state_tok, - mask=lane_active & mask_w, - other=0.0).to(tl.float16) + col3 = tl.load(prior_tokens + 3 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to( + tl.float16 + ) if KERNEL_WIDTH >= 6: - col4 = tl.load(prior_tokens + 4 * stride_conv_state_tok, - mask=lane_active & mask_w, - other=0.0).to(tl.float16) + col4 = tl.load(prior_tokens + 4 * stride_conv_state_tok, mask=lane_active & mask_w, other=0.0).to( + tl.float16 + 
) # ------------------------- # STEP 2: chunked state update (replaces original NP2_STATELEN x BLOCK_N big block) @@ -340,29 +311,25 @@ def _causal_conv1d_update_kernel_npu_tiled( # dst[0:keep] = src[shift : shift+keep], dst[keep:keep+seqlen_run] = x[0:seqlen_run] # ------------------------- # output cache line - conv_states_offset = tl.load(conv_state_indices_ptr + - b * stride_state_indices + - current_last_index, - mask=lane_active, - other=0).to(tl.int64) + conv_states_offset = tl.load( + conv_state_indices_ptr + b * stride_state_indices + current_last_index, mask=lane_active, other=0 + ).to(tl.int64) - use_shift = (seqlen_run < state_len_run) - use_tail = (seqlen_run >= state_len_run) + use_shift = seqlen_run < state_len_run + use_tail = seqlen_run >= state_len_run zero_i32 = tl.full((), 0, tl.int32) - keep_shift = tl.where(use_shift, (state_len_run - seqlen_run), - zero_i32).to(tl.int32) - tail_start = tl.where(use_tail, (seqlen_run - state_len_run), - zero_i32).to(tl.int32) + keep_shift = tl.where(use_shift, (state_len_run - seqlen_run), zero_i32).to(tl.int32) + tail_start = tl.where(use_tail, (seqlen_run - state_len_run), zero_i32).to(tl.int32) # base pointers - state_src_base = (conv_state_ptr + - conv_states_input_coord * stride_conv_state_seq + - conv_state_token_offset * stride_conv_state_tok + - idx_feats * stride_conv_state_dim) - state_dst_base = (conv_state_ptr + - conv_states_offset * stride_conv_state_seq + - idx_feats * stride_conv_state_dim) + state_src_base = ( + conv_state_ptr + + conv_states_input_coord * stride_conv_state_seq + + conv_state_token_offset * stride_conv_state_tok + + idx_feats * stride_conv_state_dim + ) + state_dst_base = conv_state_ptr + conv_states_offset * stride_conv_state_seq + idx_feats * stride_conv_state_dim x_base = x_ptr + x_offset + idx_feats * stride_x_dim @@ -370,16 +337,16 @@ def _causal_conv1d_update_kernel_npu_tiled( for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK): dst_tok = (t0 + tok_vec).to(tl.int32) # [T_CHUNK] src_tok = (dst_tok + shift).to(tl.int32) # [T_CHUNK] - m_tok = use_shift & (dst_tok < keep_shift) & ( - src_tok < state_len_run) & (dst_tok < state_len_run) - m = (lane_active & m_tok)[:, None] & mask_w[None, :] & ( - conv_states_input_coord - < num_cache_lines) & (conv_states_offset < num_cache_lines) + m_tok = use_shift & (dst_tok < keep_shift) & (src_tok < state_len_run) & (dst_tok < state_len_run) + m = ( + (lane_active & m_tok)[:, None] + & mask_w[None, :] + & (conv_states_input_coord < num_cache_lines) + & (conv_states_offset < num_cache_lines) + ) - src_ptrs = state_src_base[ - None, :] + src_tok[:, None] * stride_conv_state_tok - dst_ptrs = state_dst_base[ - None, :] + dst_tok[:, None] * stride_conv_state_tok + src_ptrs = state_src_base[None, :] + src_tok[:, None] * stride_conv_state_tok + dst_ptrs = state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok vals = tl.load(src_ptrs, mask=m, other=0.0) tl.store(dst_ptrs, vals, mask=m) @@ -387,14 +354,11 @@ def _causal_conv1d_update_kernel_npu_tiled( for t0 in tl.static_range(0, seqlen, T_CHUNK): x_tok = (t0 + tok_vec).to(tl.int32) # [T_CHUNK] dst_tok = (keep_shift + x_tok).to(tl.int32) # [T_CHUNK] - m_tok = use_shift & (x_tok < seqlen_run) & (dst_tok - < state_len_run) - m = (lane_active & m_tok)[:, None] & mask_w[None, :] & ( - conv_states_offset < num_cache_lines) + m_tok = use_shift & (x_tok < seqlen_run) & (dst_tok < state_len_run) + m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (conv_states_offset < num_cache_lines) x_ptrs = x_base[None, :] + 
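STEP 2 is a rolling-window update of the cached conv state: keep the last state_len tokens of (old state ++ new tokens), split into a shift path (seqlen < state_len) and a tail path (seqlen >= state_len). A host-side sketch per channel of the intended result (hypothetical values; the kernel does this with masked chunked copies, and spec decoding additionally offsets the source window by the accepted-token count):

import torch

state_len = 6
old_state = torch.arange(float(state_len))   # [0, 1, 2, 3, 4, 5]
x_new = torch.tensor([100.0, 101.0])         # seqlen = 2 new tokens
seqlen = x_new.numel()
if seqlen < state_len:                       # "shift" path
    new_state = torch.cat([old_state[seqlen:], x_new])
else:                                        # "tail" path
    new_state = x_new[seqlen - state_len:]
assert new_state.tolist() == [2.0, 3.0, 4.0, 5.0, 100.0, 101.0]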
x_tok[:, None] * stride_x_token - dst_ptrs = state_dst_base[ - None, :] + dst_tok[:, None] * stride_conv_state_tok + dst_ptrs = state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok x_vals = tl.load(x_ptrs, mask=m, other=0.0) tl.store(dst_ptrs, x_vals, mask=m) @@ -403,12 +367,10 @@ def _causal_conv1d_update_kernel_npu_tiled( dst_tok = (t0 + tok_vec).to(tl.int32) # [T_CHUNK] x_tok = (tail_start + dst_tok).to(tl.int32) # [T_CHUNK] m_tok = use_tail & (dst_tok < state_len_run) & (x_tok < seqlen_run) - m = (lane_active & m_tok)[:, None] & mask_w[None, :] & ( - conv_states_offset < num_cache_lines) + m = (lane_active & m_tok)[:, None] & mask_w[None, :] & (conv_states_offset < num_cache_lines) x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token - dst_ptrs = state_dst_base[ - None, :] + dst_tok[:, None] * stride_conv_state_tok + dst_ptrs = state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok x_vals = tl.load(x_ptrs, mask=m, other=0.0) tl.store(dst_ptrs, x_vals, mask=m) @@ -433,17 +395,13 @@ def _causal_conv1d_update_kernel_npu_tiled( if KERNEL_WIDTH == 1: # only x[t] * w0 x_ptrs_1d = x_base_1d + idx_token * stride_x_token - matrix_x = tl.load(x_ptrs_1d, - mask=lane_active & mask_w, - other=0.0).to(tl.float16) + matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16) matrix_w = w_col0 elif KERNEL_WIDTH == 2: if j == 1: matrix_w = w_col1 x_ptrs_1d = x_base_1d + idx_token * stride_x_token - matrix_x = tl.load(x_ptrs_1d, - mask=lane_active & mask_w, - other=0.0).to(tl.float16) + matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16) elif KERNEL_WIDTH == 3: if j == 1: matrix_w = w_col1 @@ -451,9 +409,7 @@ def _causal_conv1d_update_kernel_npu_tiled( elif j == 2: matrix_w = w_col2 x_ptrs_1d = x_base_1d + idx_token * stride_x_token - matrix_x = tl.load(x_ptrs_1d, - mask=lane_active & mask_w, - other=0.0).to(tl.float16) + matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16) elif KERNEL_WIDTH == 4: if j == 1: matrix_w = w_col1 @@ -464,9 +420,7 @@ def _causal_conv1d_update_kernel_npu_tiled( elif j == 3: matrix_w = w_col3 x_ptrs_1d = x_base_1d + idx_token * stride_x_token - matrix_x = tl.load(x_ptrs_1d, - mask=lane_active & mask_w, - other=0.0).to(tl.float16) + matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16) elif KERNEL_WIDTH == 5: if j == 1: matrix_w = w_col1 @@ -480,9 +434,7 @@ def _causal_conv1d_update_kernel_npu_tiled( elif j == 4: matrix_w = w_col4 x_ptrs_1d = x_base_1d + idx_token * stride_x_token - matrix_x = tl.load(x_ptrs_1d, - mask=lane_active & mask_w, - other=0.0).to(tl.float16) + matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16) elif KERNEL_WIDTH == 6: if j == 1: matrix_w = w_col1 @@ -499,9 +451,7 @@ def _causal_conv1d_update_kernel_npu_tiled( elif j == 5: matrix_w = w_col5 x_ptrs_1d = x_base_1d + idx_token * stride_x_token - matrix_x = tl.load(x_ptrs_1d, - mask=lane_active & mask_w, - other=0.0).to(tl.float16) + matrix_x = tl.load(x_ptrs_1d, mask=lane_active & mask_w, other=0.0).to(tl.float16) acc += matrix_x.to(tl.float32) * matrix_w # [BLOCK_N] @@ -606,7 +556,7 @@ def causal_conv1d_update_npu( x = x.unsqueeze(1) if query_start_loc is None: - batch, seqlen, dim = x.shape + batch, seqlen, dim = x.shape else: assert conv_state_indices is not None batch = conv_state_indices.size(0) @@ -614,14 +564,14 @@ def causal_conv1d_update_npu( seqlen = max_query_len width, _ = weight.shape - num_cache_lines, 
state_len_total,_ = conv_state.size() + num_cache_lines, state_len_total, _ = conv_state.size() # overwrite-on-x strategy same as original out = x stride_w_width, stride_w_dim = weight.stride() if query_start_loc is None: - stride_x_seq, stride_x_token,stride_x_dim = x.stride() + stride_x_seq, stride_x_token, stride_x_dim = x.stride() stride_o_seq, stride_o_token, stride_o_dim = out.stride() else: stride_x_token, stride_x_dim = x.stride() @@ -629,10 +579,8 @@ def causal_conv1d_update_npu( stride_o_token, stride_o_dim = out.stride() stride_o_seq = 0 - stride_istate_seq, stride_istate_token, stride_istate_dim = conv_state.stride( - ) - stride_state_indices = conv_state_indices.stride( - 0) if conv_state_indices is not None else 0 + stride_istate_seq, stride_istate_token, stride_istate_dim = conv_state.stride() + stride_state_indices = conv_state_indices.stride(0) if conv_state_indices is not None else 0 # effective state_len exactly as original if num_accepted_tokens is not None: @@ -642,10 +590,10 @@ def causal_conv1d_update_npu( np2_statelen = triton.next_power_of_2(eff_state_len) # -------- tiling heuristic-------- - #keep program count around ~[80..160] + # keep program count around ~[80..160] # vector core 40 # TODO: use driver to get the vector core num - CORE_HINT = 40 + CORE_HINT = 40 # channel tile: 512 when dim large (reduce tasks), else 256 block_n = 512 if dim >= 512 else 256 g = triton.cdiv(dim, block_n) @@ -669,6 +617,7 @@ def causal_conv1d_update_npu( triton.cdiv(batch, META["B_TILE"]), triton.cdiv(dim, META["BLOCK_N"]), ) + _causal_conv1d_update_kernel_npu_tiled[grid]( x, weight, diff --git a/vllm_ascend/ops/triton/reject_sample.py b/vllm_ascend/ops/triton/reject_sample.py index 14281557..65f81b74 100644 --- a/vllm_ascend/ops/triton/reject_sample.py +++ b/vllm_ascend/ops/triton/reject_sample.py @@ -59,8 +59,8 @@ def rejection_greedy_sample_spec_len_1_triton( tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask) for pos in tl.range(0, BLOCK_SIZE): - draft_token_id1 = tl.get_element(draft_token_id, (pos, )) - target_argmax1 = tl.get_element(target_argmax_id, (pos, )) + draft_token_id1 = tl.get_element(draft_token_id, (pos,)) + target_argmax1 = tl.get_element(target_argmax_id, (pos,)) position = block_idx * BLOCK_SIZE + pos if draft_token_id1 == target_argmax1: bonus_renew_1( @@ -79,9 +79,7 @@ def bonus_renew( num_tokens1, ): bonus_token_id = tl.load(bonus_token_ids_ptr + position) - tl.store( - output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens1, - bonus_token_id) + tl.store(output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens1, bonus_token_id) @triton.jit(do_not_specialize=["max_spec_len"]) @@ -106,17 +104,15 @@ def rejection_greedy_sample_triton( is_greedy = tl.load(is_greedy_ptr + offset, mask=mask, other=0) is_greedy_mask = mask & (is_greedy != 0) - start_idx = tl.where( - offset == 0, 0, - tl.load(cu_num_draft_tokens_ptr + offset - 1, is_greedy_mask)) + start_idx = tl.where(offset == 0, 0, tl.load(cu_num_draft_tokens_ptr + offset - 1, is_greedy_mask)) end_idx = tl.load(cu_num_draft_tokens_ptr + offset, is_greedy_mask) num_draft_tokens = end_idx - start_idx for pos in tl.range(0, BLOCK_SIZE): - num_tokens1 = tl.get_element(num_draft_tokens, (pos, )) + num_tokens1 = tl.get_element(num_draft_tokens, (pos,)) rejected = False - start_idx1 = tl.get_element(start_idx, (pos, )) - is_greedy_mask1 = tl.get_element(is_greedy_mask, (pos, )) + start_idx1 = tl.get_element(start_idx, (pos,)) + is_greedy_mask1 = tl.get_element(is_greedy_mask, 
(pos,)) position = block_idx * BLOCK_SIZE + pos for i in range(num_tokens1): if not rejected: @@ -142,50 +138,44 @@ def rejection_greedy_sample_triton( @triton.jit(do_not_specialize=["max_spec_len"]) def rejection_random_sample_kernel( - output_token_ids_ptr, # [batch_size, max_spec_len + 1] - cu_num_draft_tokens_ptr, # [batch_size] - draft_token_ids_ptr, # [num_tokens] - draft_probs_ptr, # [num_tokens, vocab_size] or None - target_probs_ptr, # [num_tokens, vocab_size] - bonus_token_ids_ptr, # [batch_size] - recovered_token_ids_ptr, # [num_tokens] - uniform_probs_ptr, # [num_tokens] - is_greedy_ptr, # [batch_size] - max_spec_len, - vocab_size, - vec_len, - NO_DRAFT_PROBS: tl.constexpr, - BLOCK_SIZE: tl.constexpr): + output_token_ids_ptr, # [batch_size, max_spec_len + 1] + cu_num_draft_tokens_ptr, # [batch_size] + draft_token_ids_ptr, # [num_tokens] + draft_probs_ptr, # [num_tokens, vocab_size] or None + target_probs_ptr, # [num_tokens, vocab_size] + bonus_token_ids_ptr, # [batch_size] + recovered_token_ids_ptr, # [num_tokens] + uniform_probs_ptr, # [num_tokens] + is_greedy_ptr, # [batch_size] + max_spec_len, + vocab_size, + vec_len, + NO_DRAFT_PROBS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): block_idx = tl.program_id(0) offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < vec_len is_greedy = tl.load(is_greedy_ptr + offsets, mask, other=1) not_greedy_mask = is_greedy == 0 - start_idxs = tl.where( - offsets == 0, 0, - tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask)) + start_idxs = tl.where(offsets == 0, 0, tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask)) end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask) n_num_draft_tokens = end_idxs - start_idxs for req_i in range(BLOCK_SIZE): - not_greedy = tl.get_element(not_greedy_mask, (req_i, )) + not_greedy = tl.get_element(not_greedy_mask, (req_i,)) if not_greedy: rejected = False - start_idx = tl.get_element(start_idxs, (req_i, )) + start_idx = tl.get_element(start_idxs, (req_i,)) req_idx = block_idx * BLOCK_SIZE + req_i - num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i, )) + num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i,)) for pos in range(num_draft_tokens): if not rejected: - draft_token_id = tl.load(draft_token_ids_ptr + start_idx + - pos) + draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) if NO_DRAFT_PROBS: draft_prob = 1 else: - draft_prob = tl.load(draft_probs_ptr + - (start_idx + pos) * vocab_size + - draft_token_id) - target_prob = tl.load(target_probs_ptr + - (start_idx + pos) * vocab_size + - draft_token_id) + draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id) + target_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id) uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos) # NOTE(woosuk): While the draft probability should never be 0, # we check it to avoid NaNs. If it happens to be 0, we reject. @@ -195,17 +185,13 @@ def rejection_random_sample_kernel( else: # Reject. Use recovered token. rejected = True - token_id = tl.load(recovered_token_ids_ptr + - start_idx + pos) - tl.store( - output_token_ids_ptr + req_idx * (max_spec_len + 1) + - pos, token_id) + token_id = tl.load(recovered_token_ids_ptr + start_idx + pos) + tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id) if not rejected: # If all tokens are accepted, append the bonus token. 
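For one non-greedy request, the accept/reject loop above reduces to the following eager reference (a sketch; NO_DRAFT_PROBS corresponds to draft_probs=None, i.e. the draft probability is taken as 1):

def rejection_sample_ref(draft_ids, target_probs, draft_probs,
                         uniform, recovered_ids, bonus_id):
    # target_probs/draft_probs: (num_draft, vocab); uniform: (num_draft,)
    out = []
    for pos, tok in enumerate(draft_ids):
        p_draft = 1.0 if draft_probs is None else draft_probs[pos][tok]
        p_target = target_probs[pos][tok]
        if p_draft > 0 and p_target / p_draft >= uniform[pos]:
            out.append(tok)                  # accept the draft token
        else:
            out.append(recovered_ids[pos])   # reject: emit the recovered token
            return out
    out.append(bonus_id)                     # all accepted: append bonus token
    return out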
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) tl.store( - output_token_ids_ptr + req_idx * (max_spec_len + 1) + - num_draft_tokens, + output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens, bonus_token_id, ) @@ -225,8 +211,7 @@ def expand_kernel( offset = req_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) len_mask = offset < vec_len - start_idx = tl.where(offset == 0, 0, - tl.load(cu_num_tokens_ptr + offset - 1, len_mask)) + start_idx = tl.where(offset == 0, 0, tl.load(cu_num_tokens_ptr + offset - 1, len_mask)) end_idx = tl.load(cu_num_tokens_ptr + offset, len_mask) num_tokens = end_idx - start_idx @@ -234,13 +219,11 @@ def expand_kernel( src_val = tl.where(src_val == replace_from, replace_to, src_val) for i in tl.range(0, BLOCK_SIZE): - num_tokens1 = tl.get_element(num_tokens, (i, )) - start_idx1 = tl.get_element(start_idx, (i, )) - src_val1 = tl.get_element(src_val, (i, )) + num_tokens1 = tl.get_element(num_tokens, (i,)) + start_idx1 = tl.get_element(start_idx, (i,)) + src_val1 = tl.get_element(src_val, (i,)) offset1 = tl.arange(0, MAX_NUM_TOKENS) - tl.store(output_ptr + start_idx1 + offset1, - src_val1, - mask=offset1 < num_tokens1) + tl.store(output_ptr + start_idx1 + offset1, src_val1, mask=offset1 < num_tokens1) @triton.jit @@ -257,8 +240,7 @@ def sample_recovered_tokens_kernel( SUB_BLOCK: tl.constexpr, ): req_idx = tl.program_id(0) - start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + - req_idx - 1) + start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1) end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) num_draft_tokens = end_idx - start_idx @@ -272,27 +254,25 @@ def sample_recovered_tokens_kernel( global_max_p = -1.0 if NO_DRAFT_PROBS: draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) - orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + - draft_token_id) + orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id) # Temporarily zero out the probability of the draft token. # This is essentially the same as target_prob - draft_prob, except that # n-gram does not have draft_prob. We regard it as 1. 
- tl.store( - target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, - 0) + tl.store(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, 0) for loop_i in range(loop): vocab_start = loop_i * SUB_BLOCK vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK) - prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + - vocab_offset, - mask=vocab_offset < vocab_size, - other=0) - q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset, - mask=vocab_offset < vocab_size, - other=float("-inf")) + prob = tl.load( + target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size, + other=0, + ) + q = tl.load( + q_ptr + req_idx * vocab_size + vocab_offset, mask=vocab_offset < vocab_size, other=float("-inf") + ) new_p = prob / q recovered_id = tl.argmax(new_p, axis=-1) - max_p = tl.get_element(new_p, (recovered_id, )) + max_p = tl.get_element(new_p, (recovered_id,)) if max_p > global_max_p: global_max_p = max_p global_recovered_id = vocab_start + recovered_id @@ -300,25 +280,24 @@ def sample_recovered_tokens_kernel( for loop_i in range(loop): vocab_start = loop_i * SUB_BLOCK vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK) - draft_prob = tl.load(draft_probs_ptr + - (start_idx + pos) * vocab_size + vocab_offset, - mask=vocab_offset < vocab_size, - other=0) - target_prob = tl.load(target_probs_ptr + - (start_idx + pos) * vocab_size + - vocab_offset, - mask=vocab_offset < vocab_size, - other=0) + draft_prob = tl.load( + draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, mask=vocab_offset < vocab_size, other=0 + ) + target_prob = tl.load( + target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size, + other=0, + ) prob = tl.maximum(target_prob - draft_prob, 0) # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because # `tl.argmax` will select the maximum value. - q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset, - mask=vocab_offset < vocab_size, - other=float("-inf")) + q = tl.load( + q_ptr + req_idx * vocab_size + vocab_offset, mask=vocab_offset < vocab_size, other=float("-inf") + ) new_p = prob / q recovered_id = tl.argmax(new_p, axis=-1) - max_p = tl.get_element(new_p, (recovered_id, )) + max_p = tl.get_element(new_p, (recovered_id,)) if max_p > global_max_p: global_max_p = max_p global_recovered_id = vocab_start + recovered_id @@ -327,21 +306,25 @@ def sample_recovered_tokens_kernel( if NO_DRAFT_PROBS: # Restore the original probability. 
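The argmax(prob / q) pattern in sample_recovered_tokens_kernel is the exponential-race trick: with q_i ~ Exp(1) i.i.d., argmax_i p_i / q_i is a Categorical(p) sample, so the chunked loop only has to carry a running (max_p, argmax) and never needs to normalize prob. A quick empirical check:

import torch

torch.manual_seed(0)
p = torch.tensor([0.1, 0.2, 0.7])
counts = torch.zeros(3)
for _ in range(20_000):
    q = torch.empty(3).exponential_()   # q_i ~ Exp(1)
    counts[torch.argmax(p / q)] += 1
print(counts / counts.sum())            # approximately [0.1, 0.2, 0.7]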
- tl.store( - target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, - orig_prob) + tl.store(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, orig_prob) -def rejection_greedy_sample_with_triton(output_token_ids, num_draft_tokens, - cu_num_draft_tokens, draft_token_ids, - target_argmax, bonus_token_ids, - is_greedy, max_spec_len, grid, - block_size): +def rejection_greedy_sample_with_triton( + output_token_ids, + num_draft_tokens, + cu_num_draft_tokens, + draft_token_ids, + target_argmax, + bonus_token_ids, + is_greedy, + max_spec_len, + grid, + block_size, +): vec_len = output_token_ids.shape[0] - if min(num_draft_tokens) == 1 and max( - num_draft_tokens) == 1 and is_greedy is None: - rejection_greedy_sample_spec_len_1_triton[(grid, )]( + if min(num_draft_tokens) == 1 and max(num_draft_tokens) == 1 and is_greedy is None: + rejection_greedy_sample_spec_len_1_triton[(grid,)]( output_token_ids, draft_token_ids, target_argmax, @@ -350,7 +333,7 @@ def rejection_greedy_sample_with_triton(output_token_ids, num_draft_tokens, BLOCK_SIZE=block_size, ) else: - rejection_greedy_sample_triton[(grid, )]( + rejection_greedy_sample_triton[(grid,)]( output_token_ids, cu_num_draft_tokens, draft_token_ids, @@ -363,12 +346,11 @@ def rejection_greedy_sample_with_triton(output_token_ids, num_draft_tokens, ) -def expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from, - replace_to, max_num_tokens): +def expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from, replace_to, max_num_tokens): vec_len = batch_size grid, block_size = cal_grid_and_block_size(batch_size) - expand_kernel[(grid, )]( + expand_kernel[(grid,)]( expanded_x, x, cu_num_tokens, @@ -382,56 +364,50 @@ def expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from, @triton.jit(do_not_specialize=["max_spec_len"]) def rejection_random_sample_block_verify_kernel( - output_token_ids_ptr, # [batch_size, max_spec_len + 1] - cu_num_draft_tokens_ptr, # [batch_size] - draft_token_ids_ptr, # [num_tokens] - draft_probs_ptr, # [num_tokens, vocab_size] or None - target_probs_ptr, # [num_tokens, vocab_size] - bonus_token_ids_ptr, # [batch_size] - recovered_token_ids_ptr, # [num_tokens] - uniform_probs_ptr, # [num_tokens] - is_greedy_ptr, # [batch_size] - max_spec_len, - vocab_size, - vec_len, - NO_DRAFT_PROBS: tl.constexpr, - BLOCK_SIZE: tl.constexpr): + output_token_ids_ptr, # [batch_size, max_spec_len + 1] + cu_num_draft_tokens_ptr, # [batch_size] + draft_token_ids_ptr, # [num_tokens] + draft_probs_ptr, # [num_tokens, vocab_size] or None + target_probs_ptr, # [num_tokens, vocab_size] + bonus_token_ids_ptr, # [batch_size] + recovered_token_ids_ptr, # [num_tokens] + uniform_probs_ptr, # [num_tokens] + is_greedy_ptr, # [batch_size] + max_spec_len, + vocab_size, + vec_len, + NO_DRAFT_PROBS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): block_idx = tl.program_id(0) offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < vec_len is_greedy = tl.load(is_greedy_ptr + offsets, mask, other=1) not_greedy_mask = is_greedy == 0 - start_idxs = tl.where( - offsets == 0, 0, - tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask)) + start_idxs = tl.where(offsets == 0, 0, tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask)) end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask) n_num_draft_tokens = end_idxs - start_idxs for req_i in range(BLOCK_SIZE): - not_greedy = tl.get_element(not_greedy_mask, (req_i, )) + not_greedy = tl.get_element(not_greedy_mask, 
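Relative to plain rejection sampling, the block-verify kernel accepts the longest prefix whose running, clamped probability ratio stays above the running product of uniforms, instead of stopping at the first per-token rejection; a per-request sketch with hypothetical inputs (draft_prob > 0 assumed):

def block_verify_ref(ratios, uniforms):
    # ratios[i]   = target_prob / draft_prob of draft token i
    # uniforms[i] = the per-position uniform draw
    pi, u, last_accepted = 1.0, 1.0, -1
    for i, (r, u_i) in enumerate(zip(ratios, uniforms)):
        u *= u_i
        pi = min(pi * r, 1.0)
        if pi >= u:
            last_accepted = i      # the whole prefix 0..i is accepted
    return last_accepted           # -1 means every draft token was rejected

The accepted prefix is then copied to the output, followed by the recovered token at position last_accepted + 1 on rejection, or the bonus token when everything was accepted.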
(req_i,)) if not_greedy: - rejected = False pi = 1.0 uniform_prob = 1.0 last_accepted_token_pos = -1 - start_idx = tl.get_element(start_idxs, (req_i, )) + start_idx = tl.get_element(start_idxs, (req_i,)) req_idx = block_idx * BLOCK_SIZE + req_i - num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i, )) + num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i,)) for pos in range(num_draft_tokens): draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) - target_prob = tl.load(target_probs_ptr + - (start_idx + pos) * vocab_size + - draft_token_id) + target_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id) tmp_uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos) uniform_prob = uniform_prob * tmp_uniform_prob if NO_DRAFT_PROBS: draft_prob = 1 else: - draft_prob = tl.load(draft_probs_ptr + - (start_idx + pos) * vocab_size + - draft_token_id) + draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id) pi = min(pi * target_prob / draft_prob, 1.0) if draft_prob > 0 and pi >= uniform_prob: @@ -443,19 +419,14 @@ def rejection_random_sample_block_verify_kernel( if last_accepted_token_pos > -1: for pos in range(last_accepted_token_pos + 1): token_id = tl.load(draft_token_ids_ptr + start_idx + pos) - tl.store( - output_token_ids_ptr + req_idx * (max_spec_len + 1) + - pos, token_id) + tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id) if rejected: - recovered_token_id = tl.load(recovered_token_ids_ptr + - start_idx + - last_accepted_token_pos + 1) + recovered_token_id = tl.load(recovered_token_ids_ptr + start_idx + last_accepted_token_pos + 1) tl.store( - output_token_ids_ptr + req_idx * (max_spec_len + 1) + - last_accepted_token_pos + 1, recovered_token_id) + output_token_ids_ptr + req_idx * (max_spec_len + 1) + last_accepted_token_pos + 1, + recovered_token_id, + ) else: bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) - tl.store( - output_token_ids_ptr + req_idx * (max_spec_len + 1) + - num_draft_tokens, bonus_token_id) + tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens, bonus_token_id) diff --git a/vllm_ascend/ops/triton/rope.py b/vllm_ascend/ops/triton/rope.py index 8eecac8f..9e61e051 100644 --- a/vllm_ascend/ops/triton/rope.py +++ b/vllm_ascend/ops/triton/rope.py @@ -14,9 +14,10 @@ # limitations under the License. # This file is a part of the vllm-ascend project. # -from vllm.triton_utils import tl, triton + import torch -from typing import Tuple +from vllm.triton_utils import tl, triton + from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num @@ -48,16 +49,16 @@ def _triton_rope( This triton kernel applies rotary embedding on q and k. It supports rope_dim != head_dim scenario. It supports both neox style and non-neox style rope computation. 
diff --git a/vllm_ascend/ops/triton/rope.py b/vllm_ascend/ops/triton/rope.py
index 8eecac8f..9e61e051 100644
--- a/vllm_ascend/ops/triton/rope.py
+++ b/vllm_ascend/ops/triton/rope.py
@@ -14,9 +14,10 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
-from vllm.triton_utils import tl, triton
+
 import torch
-from typing import Tuple
+from vllm.triton_utils import tl, triton
+
 from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num


@@ -48,16 +49,16 @@ def _triton_rope(
     This triton kernel applies rotary embedding on q and k. It supports
     rope_dim != head_dim scenario. It supports both neox style and non-neox
     style rope computation.
-    
+
     Input tensor layout assumptions:
-    
+
     q size: (num_tokens, num_q_heads, head_dim)
     q stride: (num_q_heads * head_dim, head_dim, 1)
     k size: (num_tokens, num_kv_heads, head_dim)
     k stride: (num_kv_heads * head_dim, head_dim, 1)
     cos/sin size: (num_tokens, rope_dim/2)
     cos/sin stride: (rope_dim/2, 1)
-    
+
     Different compute pattern of IS_NEOX_STYLE:

     if IS_NEOX_STYLE:
@@ -88,10 +89,8 @@ def _triton_rope(
     cos_offsets = tl.arange(0, pad_rope_dim // 2)
     cos_mask = cos_offsets < (rope_dim // 2)
-    cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask,
-                      other=0).to(tl.float32)
-    sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask,
-                      other=0).to(tl.float32)
+    cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
+    sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)

     # ####################################################################
     # Load the left and right half of q and k for the current
@@ -99,28 +98,20 @@ def _triton_rope(
     # ####################################################################
     # left half of the head
     if IS_NEOX_STYLE:
-        first_half_q_offsets = tl.arange(
-            0, pad_n_qh)[:, None] * hd + tl.arange(
-                0, pad_rope_dim // 2)[None, :]
-        first_half_k_offsets = tl.arange(
-            0, pad_n_kh)[:, None] * hd + tl.arange(
-                0, pad_rope_dim // 2)[None, :]
+        first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_rope_dim // 2)[None, :]
+        first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_rope_dim // 2)[None, :]
     else:
-        first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + (
-            2 * tl.arange(0, pad_rope_dim // 2)[None, :])
-        first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + (
-            2 * tl.arange(0, pad_rope_dim // 2)[None, :])
+        first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + (2 * tl.arange(0, pad_rope_dim // 2)[None, :])
+        first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + (2 * tl.arange(0, pad_rope_dim // 2)[None, :])

-    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(
-        0, pad_rope_dim // 2)[None, :] < (rope_dim // 2))
-    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(
-        0, pad_rope_dim // 2)[None, :] < (rope_dim // 2))
-    q_tile_1 = tl.load(q_start_ptr + first_half_q_offsets,
-                       mask=first_q_mask,
-                       other=0).to(sin_row.dtype)
-    k_tile_1 = tl.load(k_start_ptr + first_half_k_offsets,
-                       mask=first_k_mask,
-                       other=0).to(sin_row.dtype)
+    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
+        tl.arange(0, pad_rope_dim // 2)[None, :] < (rope_dim // 2)
+    )
+    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
+        tl.arange(0, pad_rope_dim // 2)[None, :] < (rope_dim // 2)
+    )
+    q_tile_1 = tl.load(q_start_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
+    k_tile_1 = tl.load(k_start_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)

     # right half of the head
     if IS_NEOX_STYLE:
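As context for the half-splitting above: in neox style the two halves of each head are contiguous, while non-neox style interleaves even and odd lanes (hence the stride-2 offsets in the else branch). A minimal PyTorch sketch of the neox-style math the kernel applies in place; the function name is an assumption, and the shapes come from the docstring's layout notes.

import torch

def rope_neox_reference(x, cos, sin):
    # x: (num_tokens, num_heads, rope_dim); cos/sin: (num_tokens, rope_dim // 2).
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    cos = cos[:, None, :]  # broadcast over the head dimension
    sin = sin[:, None, :]
    # Matches the kernel: y1 = x1 * cos - x2 * sin, y2 = x2 * cos + x1 * sin.
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)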
@@ -131,41 +122,29 @@ def _triton_rope(
         second_half_k_offsets = first_half_k_offsets + 1
         second_q_mask = first_q_mask
         second_k_mask = first_k_mask
-    q_tile_2 = tl.load(q_start_ptr + second_half_q_offsets,
-                       mask=second_q_mask,
-                       other=0).to(sin_row.dtype)
-    k_tile_2 = tl.load(k_start_ptr + second_half_k_offsets,
-                       mask=second_k_mask,
-                       other=0).to(sin_row.dtype)
+    q_tile_2 = tl.load(q_start_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
+    k_tile_2 = tl.load(k_start_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)

     # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
     new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
-    tl.store(q_start_ptr + first_half_q_offsets,
-             new_q_tile_1,
-             mask=first_q_mask)
+    tl.store(q_start_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
     new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
-    tl.store(q_start_ptr + second_half_q_offsets,
-             new_q_tile_2,
-             mask=second_q_mask)
+    tl.store(q_start_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

     new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
-    tl.store(k_start_ptr + first_half_k_offsets,
-             new_k_tile_1,
-             mask=first_k_mask)
+    tl.store(k_start_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
     new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
-    tl.store(k_start_ptr + second_half_k_offsets,
-             new_k_tile_2,
-             mask=second_k_mask)
+    tl.store(k_start_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)


 def rope_forward_triton(
-        q: torch.Tensor,
-        k: torch.Tensor,
-        cos: torch.Tensor,
-        sin: torch.Tensor,
-        rope_dim: int = -1,
-        is_neox_style: bool = True
-) -> Tuple[torch.Tensor, torch.Tensor]:
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    rope_dim: int = -1,
+    is_neox_style: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor]:
     if not q.is_contiguous():
         q = q.contiguous()
     if not k.is_contiguous():
@@ -187,7 +166,7 @@ def rope_forward_triton(
     num_vectorcore = get_vectorcore_num()
     n_row = min(num_tokens, num_vectorcore)

-    _triton_rope[(n_row, )](
+    _triton_rope[(n_row,)](
         q,
         q.stride(0),
         k,
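A hedged usage sketch for the reformatted entry point. Shapes follow the kernel's layout assumptions; the NPU device string, the dtypes, and the reading of rope_dim=-1 as "rotate the full head_dim" are assumptions for illustration only, and the call requires an Ascend build with Triton available.

import torch
from vllm_ascend.ops.triton.rope import rope_forward_triton

num_tokens, num_q_heads, num_kv_heads, head_dim = 16, 8, 2, 128
q = torch.randn(num_tokens, num_q_heads, head_dim, dtype=torch.float16, device="npu")
k = torch.randn(num_tokens, num_kv_heads, head_dim, dtype=torch.float16, device="npu")
cos = torch.randn(num_tokens, head_dim // 2, dtype=torch.float16, device="npu")
sin = torch.randn(num_tokens, head_dim // 2, dtype=torch.float16, device="npu")

# Applies neox-style rotary embedding to q and k on the NPU vector cores.
q_rot, k_rot = rope_forward_triton(q, k, cos, sin, rope_dim=-1, is_neox_style=True)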
diff --git a/vllm_ascend/ops/triton/spec_decode/utils.py b/vllm_ascend/ops/triton/spec_decode/utils.py
index c0588502..a66566da 100644
--- a/vllm_ascend/ops/triton/spec_decode/utils.py
+++ b/vllm_ascend/ops/triton/spec_decode/utils.py
@@ -49,20 +49,15 @@ def prepare_inputs_padded_kernel(
         other=0,
     )

-    num_draft_tokens = tl.where(has_prev, cu_draft_curr - cu_draft_prev,
-                                cu_draft_curr)
+    num_draft_tokens = tl.where(has_prev, cu_draft_curr - cu_draft_prev, cu_draft_curr)

-    valid_count = tl.load(valid_sampled_tokens_count_ptr + offsets,
-                          mask=mask)
+    valid_count = tl.load(valid_sampled_tokens_count_ptr + offsets, mask=mask)
     num_rejected = num_draft_tokens + 1 - valid_count
     num_rejected = tl.where(num_draft_tokens > 0, num_rejected, 0)

     # query_start_loc[req_idx + 1] is the start position of the next request,
     # which is one past the last token of this request.
-    q_last_tok_idx = tl.load(query_start_loc_gpu_ptr + offsets + 1,
-                             mask=mask) - 1
+    q_last_tok_idx = tl.load(query_start_loc_gpu_ptr + offsets + 1, mask=mask) - 1
     index_to_sample = q_last_tok_idx - num_rejected

-    tl.store(token_indices_to_sample_ptr + offsets,
-             index_to_sample,
-             mask=mask)
+    tl.store(token_indices_to_sample_ptr + offsets, index_to_sample, mask=mask)
diff --git a/vllm_ascend/ops/triton/triton_utils.py b/vllm_ascend/ops/triton/triton_utils.py
index 6b0ac964..72bdde4a 100644
--- a/vllm_ascend/ops/triton/triton_utils.py
+++ b/vllm_ascend/ops/triton/triton_utils.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict
+from typing import Any

 import torch
 from vllm.triton_utils import HAS_TRITON, triton
@@ -10,9 +10,9 @@ _NUM_VECTORCORE = -1
 def init_device_properties_triton():
     global _NUM_AICORE, _NUM_VECTORCORE
     if _NUM_AICORE == -1 and HAS_TRITON:
-        device_properties: Dict[str, Any] = (
-            triton.runtime.driver.active.utils.get_device_properties(
-                torch.npu.current_device()))
+        device_properties: dict[str, Any] = triton.runtime.driver.active.utils.get_device_properties(
+            torch.npu.current_device()
+        )
         _NUM_AICORE = device_properties.get("num_aicore", -1)
         _NUM_VECTORCORE = device_properties.get("num_vectorcore", -1)
         assert _NUM_AICORE > 0 and _NUM_VECTORCORE > 0, "Failed to detect device properties."
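Finally, the index arithmetic in prepare_inputs_padded_kernel reduces to the per-request scalar reference below. The helper name is hypothetical, and query_start_loc is assumed to be the usual exclusive prefix sum over query lengths, as the in-kernel comment describes.

def token_index_to_sample_reference(req_idx, cu_num_draft_tokens,
                                    valid_sampled_tokens_count, query_start_loc):
    prev = cu_num_draft_tokens[req_idx - 1] if req_idx > 0 else 0
    num_draft_tokens = cu_num_draft_tokens[req_idx] - prev
    num_rejected = num_draft_tokens + 1 - valid_sampled_tokens_count[req_idx]
    if num_draft_tokens == 0:
        num_rejected = 0  # nothing speculated, nothing rejected
    # query_start_loc[req_idx + 1] is one past this request's last token.
    q_last_tok_idx = query_start_loc[req_idx + 1] - 1
    return q_last_tok_idx - num_rejected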