[perf][refactor] Refactor and optimize sfa_v1.py for dsv3.2/glm5 (#6874)
### What this PR does / why we need it?
This PR refactors sfa_v1.py to improve code readability and usability,
fixes a code bug, and enhances performance through the replacement of
certain operators.
### changes
- **improve code readability**: Optimizes parts of the code structure in
sfa_v1.py, supplementary comments for key code blocks, removes some
unused variables, and improves the naming of certain functions and
variables.
- **resolved a duplicated double write to k_cache**: Fixed redundant
double writes of k_cache in the indexer_select module (in both the
`forward` function and `indexer_select_post_process`), improving
performance to some extent.
- **replace `scatter` ops with `reshape_and_cache`**: This optimization
replaces two separate cache storage operations on `k_nope` and `k_pe`
with a single call to the `reshape_and_cache` operator, improving
performance. The original `scatter` operator involves reordering
slot_mapping for generality, introducing significant scalar
computations. In contrast, the `reshape_and_cache` operator eliminates
this redundant reordering step, thus reducing unnecessary computation
time and enhancing the operator's performance.
### performance comparison
4*A3, 1P1D, P dp2tp16, D dp8tp4, input/output: 64K/3K
origin:
TTFT: **28s**, TPOT: 26ms, TPS: **820 token/s**
fixed redundant double writes of k_cache:
TTFT: **24s**, TPOT: 26ms, TPS: **840 token/s**
replace scatter ops with reshape_and_cache:
TTFT: **24s**, TPOT: 26ms, TPS: **850 token/s**
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with new added/existing test.
- vLLM version: v0.16.0
- vLLM main:
15d76f74e2
---------
Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -5,10 +5,12 @@ import torch
|
||||
import torch_npu
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_dcp_group, get_pcp_group
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata
|
||||
from vllm_ascend.attention.sfa_v1 import AscendSFAImpl, AscendSFAMetadata, AscendSFAMetadataBuilder
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, enabling_mlapo, split_decodes_and_prefills
|
||||
from vllm_ascend.ops.triton.rope import rope_forward_triton_siso
|
||||
|
||||
M = TypeVar("M", bound=AscendSFAMetadata)
|
||||
|
||||
@@ -299,42 +301,33 @@ class AscendSFACPImpl(AscendSFAImpl):
|
||||
def indexer_select_post_process(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
qr: torch.Tensor,
|
||||
q: torch.Tensor | None,
|
||||
k: torch.Tensor,
|
||||
q_c: torch.Tensor,
|
||||
kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
attn_metadata: M,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
actual_seq_lengths_query: torch.Tensor,
|
||||
actual_seq_lengths_key: torch.Tensor,
|
||||
need_gather_q_kv: bool = False,
|
||||
):
|
||||
if q is None:
|
||||
q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
|
||||
q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
|
||||
cos_q, sin_q = cos, sin
|
||||
weights, _ = self.weights_proj(x)
|
||||
|
||||
q_pe, q_nope = torch.split(
|
||||
q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1
|
||||
q_li, _ = self.wq_b(q_c) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
|
||||
q_li = q_li.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
|
||||
if HAS_TRITON:
|
||||
q_li = rope_forward_triton_siso(
|
||||
q_li, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style
|
||||
)
|
||||
else:
|
||||
q_li_pe, q_li_nope = torch.split(
|
||||
q_li, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1
|
||||
) # [b,s,64,64+64]
|
||||
|
||||
q_pe = q_pe.unsqueeze(2)
|
||||
q_pe = torch_npu.npu_rotary_mul(q_pe, cos_q, sin_q)
|
||||
q_pe = q_pe.squeeze(2)
|
||||
q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128]
|
||||
q_li_pe = q_li_pe.unsqueeze(2)
|
||||
q_li_pe = torch_npu.npu_rotary_mul(q_li_pe, cos, sin)
|
||||
q_li_pe = q_li_pe.squeeze(2)
|
||||
q_li = torch.cat([q_li_pe, q_li_nope], dim=-1) # [b*s,64,128]
|
||||
|
||||
if kv_cache is not None:
|
||||
if self.is_kv_producer:
|
||||
attn_metadata.reshape_cache_event = torch.npu.Event()
|
||||
torch_npu.npu_scatter_nd_update_(
|
||||
kv_cache[2].view(-1, k.shape[-1]), attn_metadata.slot_mapping.view(-1, 1), k.view(-1, k.shape[-1])
|
||||
) # b, s, n, d
|
||||
if self.is_kv_producer:
|
||||
attn_metadata.reshape_cache_event.record()
|
||||
|
||||
weights, _ = self.weights_proj(x)
|
||||
weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(weights, need_gather_q_kv)
|
||||
q = q_li
|
||||
|
||||
key = kv_cache[2]
|
||||
assert attn_metadata.sfa_cp_metadata is not None
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -146,6 +146,79 @@ def _triton_rope(
|
||||
tl.store(k_start_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _triton_rope_siso(
|
||||
qk_ptr,
|
||||
qk_row_stride,
|
||||
cos_ptr,
|
||||
cos_row_stride,
|
||||
sin_ptr,
|
||||
sin_row_stride,
|
||||
cos_sin_ptr,
|
||||
cos_sin_row_stride,
|
||||
pos_ptr,
|
||||
num_tokens,
|
||||
n_h: tl.constexpr,
|
||||
hd: tl.constexpr,
|
||||
rope_dim: tl.constexpr,
|
||||
pad_n_h: tl.constexpr,
|
||||
pad_rope_dim: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
IS_NEOX_STYLE: tl.constexpr,
|
||||
USE_COS_SIN: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0).to(tl.int64)
|
||||
row_block_size = tl.num_programs(0)
|
||||
|
||||
for row_idx in tl.range(pid, num_tokens, row_block_size):
|
||||
qk_start_ptr = qk_ptr + row_idx * qk_row_stride
|
||||
|
||||
# ####################################################################
|
||||
# get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
|
||||
# m of this program instance
|
||||
# ####################################################################
|
||||
cos_offsets = tl.arange(0, pad_rope_dim // 2)
|
||||
sin_offsets = tl.arange(pad_rope_dim // 2, pad_rope_dim)
|
||||
cos_mask = cos_offsets < (rope_dim // 2)
|
||||
if USE_COS_SIN:
|
||||
pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64)
|
||||
cos_start_ptr = cos_sin_ptr + pos_idx * cos_sin_row_stride
|
||||
cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
|
||||
sin_row = tl.load(cos_start_ptr + sin_offsets, mask=cos_mask, other=0).to(tl.float32)
|
||||
else:
|
||||
cos_start_ptr = cos_ptr + row_idx * cos_row_stride
|
||||
sin_start_ptr = sin_ptr + row_idx * sin_row_stride
|
||||
cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
|
||||
sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
|
||||
|
||||
# ####################################################################
|
||||
# Load the left and right half of q and k for the current
|
||||
# program instance (i.e. for the current token) separately
|
||||
# ####################################################################
|
||||
# left half of the head
|
||||
if IS_NEOX_STYLE:
|
||||
first_half_offsets = tl.arange(0, pad_n_h)[:, None] * hd + tl.arange(0, pad_rope_dim // 2)[None, :]
|
||||
else:
|
||||
first_half_offsets = tl.arange(0, pad_n_h)[:, None] * hd + (2 * tl.arange(0, pad_rope_dim // 2)[None, :])
|
||||
|
||||
first_mask = (tl.arange(0, pad_n_h)[:, None] < n_h) & (
|
||||
tl.arange(0, pad_rope_dim // 2)[None, :] < (rope_dim // 2)
|
||||
)
|
||||
qk_tile_1 = tl.load(qk_start_ptr + first_half_offsets, mask=first_mask, other=0).to(sin_row.dtype)
|
||||
|
||||
# right half of the head
|
||||
if IS_NEOX_STYLE:
|
||||
second_half_offsets = first_half_offsets + (rope_dim // 2)
|
||||
else:
|
||||
second_half_offsets = first_half_offsets + 1
|
||||
second_mask = first_mask
|
||||
qk_tile_2 = tl.load(qk_start_ptr + second_half_offsets, mask=second_mask, other=0).to(sin_row.dtype)
|
||||
|
||||
# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
|
||||
new_qk_tile_1 = qk_tile_1 * cos_row - qk_tile_2 * sin_row
|
||||
tl.store(qk_start_ptr + first_half_offsets, new_qk_tile_1, mask=first_mask)
|
||||
|
||||
|
||||
def rope_forward_triton(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
@@ -237,3 +310,83 @@ def rope_forward_triton(
|
||||
"Please check whether you call rope_forward_triton correctly."
|
||||
)
|
||||
return q, k
|
||||
|
||||
|
||||
def rope_forward_triton_siso(
|
||||
qk: torch.Tensor,
|
||||
cos: torch.Tensor = None,
|
||||
sin: torch.Tensor = None,
|
||||
cos_sin_cache: torch.Tensor = None,
|
||||
positions: torch.Tensor = None,
|
||||
rope_dim: int = -1,
|
||||
is_neox_style: bool = True,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if not qk.is_contiguous():
|
||||
qk = qk.contiguous()
|
||||
|
||||
num_tokens, n_head, head_dim = qk.shape
|
||||
assert rope_dim <= head_dim
|
||||
pad_rope_dim = triton.next_power_of_2(rope_dim)
|
||||
pad_n_head = triton.next_power_of_2(n_head)
|
||||
BLOCK_SIZE = pad_n_head
|
||||
num_vectorcore = get_vectorcore_num()
|
||||
n_row = min(num_tokens, num_vectorcore)
|
||||
|
||||
if cos_sin_cache is not None and positions is not None:
|
||||
assert positions.shape[0] == num_tokens
|
||||
_triton_rope_siso[(n_row,)](
|
||||
qk,
|
||||
qk.stride(0),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
cos_sin_cache,
|
||||
cos_sin_cache.stride(0),
|
||||
positions,
|
||||
num_tokens,
|
||||
n_head,
|
||||
head_dim,
|
||||
rope_dim,
|
||||
pad_n_head,
|
||||
pad_rope_dim,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
IS_NEOX_STYLE=is_neox_style,
|
||||
USE_COS_SIN=True,
|
||||
)
|
||||
elif cos is not None and sin is not None:
|
||||
assert cos.shape[0] == num_tokens and sin.shape[0] == num_tokens
|
||||
cos = cos.view(num_tokens, -1)
|
||||
sin = sin.view(num_tokens, -1)
|
||||
if rope_dim == -1:
|
||||
# If rope_dim is not specified, we assume that input cos/sin is not
|
||||
# duplicated to rope_dim, which means rope_dim == cos.shape[-1] * 2
|
||||
rope_dim = cos.shape[-1] * 2
|
||||
_triton_rope_siso[(n_row,)](
|
||||
qk,
|
||||
qk.stride(0),
|
||||
cos,
|
||||
cos.stride(0),
|
||||
sin,
|
||||
sin.stride(0),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
num_tokens,
|
||||
n_head,
|
||||
head_dim,
|
||||
rope_dim,
|
||||
pad_n_head,
|
||||
pad_rope_dim,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
IS_NEOX_STYLE=is_neox_style,
|
||||
USE_COS_SIN=False,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Currently, rope_forward_triton supports passing:\n"
|
||||
"1. positions and original cos_sin_cache.\n"
|
||||
"2. cos and sin which are already selected by positions\n"
|
||||
"Please check whether you call rope_forward_triton correctly."
|
||||
)
|
||||
return qk
|
||||
|
||||
@@ -1138,10 +1138,24 @@ def enable_dsa_cp_with_layer_shard() -> bool:
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
# because the broadcast in layer sharding needs to be overlapped with a heavy compute stream to be
|
||||
# effectively hidden, it is enabled only during the prefill stage.
|
||||
is_prefill_instance = vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_producer
|
||||
return is_prefill_instance
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def enable_dsa_cp_with_o_proj_tp() -> bool:
|
||||
if not enable_dsa_cp():
|
||||
return False
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
# if is PD mix stage, using original TP o_proj weight, and also need to
|
||||
# full gather for o_proj weight for prefill stage.
|
||||
return vllm_config.kv_transfer_config is None
|
||||
|
||||
|
||||
def check_gdn_layer(vllm_config) -> bool:
|
||||
"""
|
||||
gdn layer is marked with `linear_attention`.
|
||||
|
||||
Reference in New Issue
Block a user