### What this PR does / why we need it?

This reverts commit 7ed9e9de69, which introduced an issue: the patch does not work when the recompute scheduler is enabled.

- vLLM version: v0.17.0
- vLLM main: 4034c3d32e

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
import scipy # type: ignore
|
||||
import torch
|
||||
import torch_npu
|
||||
import vllm.envs as envs_vllm
|
||||
@@ -356,9 +355,6 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
# Supports forward using the all-gather o_proj weight for decode requests when Sharded CP is enabled.
|
||||
o_proj_full_pool: torch.Tensor | None = None
|
||||
|
||||
# qk_hadamard tensor shared when dsa c8 enabled
|
||||
qk_hadamard: torch.Tensor | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
@@ -429,12 +425,6 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
self.is_rope_neox_style = False
|
||||
self.use_torch_npu_lightning_indexer = True
|
||||
|
||||
# dsa c8
|
||||
self.use_sparse_c8_indexer = ascend_config.enable_sparse_c8
|
||||
if self.use_sparse_c8_indexer:
|
||||
self.c8_k_cache_dtype = torch.int8
|
||||
self.c8_k_scale_cache_dtype = torch.float16
|
||||
|
||||
# Effective in SFA when FlashComm is enabled.
|
||||
self.enable_dsa_cp = enable_dsa_cp()
|
||||
|
||||
@@ -525,11 +515,6 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
# if mlapo, W_UK_T can't trans nz
|
||||
self.W_UK_T = maybe_trans_nz(self.W_UK_T)
|
||||
|
||||
if self.use_sparse_c8_indexer and AscendSFAImpl.qk_hadamard is None:
|
||||
AscendSFAImpl.qk_hadamard = torch.tensor(scipy.linalg.hadamard(128), dtype=torch.bfloat16, device="npu") / (
|
||||
128**0.5
|
||||
)
|
||||
|
||||
# Processing the input parameters for MLAPO by reordering and transposing
|
||||
# QKV(and part of Q) weight, applying RoPE-related dimension transformations,
|
||||
# and handling quantization parameters.
|
||||
@@ -889,15 +874,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
|
||||
k_li = torch.cat([k_li_pe, k_li_nope], dim=-1) # [b*s,128]
|
||||
|
||||
if self.use_sparse_c8_indexer:
|
||||
k_li = k_li @ AscendSFAImpl.qk_hadamard
|
||||
k_li, k_li_scale = torch_npu.npu_dynamic_quant(k_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype)
|
||||
k_li_scale = k_li_scale.to(self.c8_k_scale_cache_dtype) # [b*s,]
|
||||
k_li_scale = k_li_scale.unsqueeze(-1) # [b*s,1]
|
||||
else:
|
||||
k_li_scale = None
|
||||
|
||||
return k_li, k_li_scale
|
||||
return k_li
|
||||
|
||||
def indexer_select_post_process(
|
||||
self,
|
||||
@@ -928,35 +905,10 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
q_li_pe = q_li_pe.squeeze(2)
|
||||
q_li = torch.cat([q_li_pe, q_li_nope], dim=-1) # [b*s,64,128]
|
||||
|
||||
if self.use_sparse_c8_indexer:
|
||||
q_li_shape_ori = q_li.shape
|
||||
q_li = q_li @ AscendSFAImpl.qk_hadamard
|
||||
q_li, q_li_scale = torch_npu.npu_dynamic_quant(q_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype)
|
||||
q_li_scale = q_li_scale.to(self.c8_k_scale_cache_dtype)
|
||||
|
||||
# DSV3.2 currently has graph compilation issues when using torch_npu.npu.lightning_indexer.
|
||||
# So two branches are maintained temporarily.
|
||||
# TODO: torch.ops._C_ascend.npu_lightning_indexer needs to be removed.
|
||||
if self.use_sparse_c8_indexer:
|
||||
assert len(kv_cache) == 4
|
||||
weights = weights.to(torch.float16)
|
||||
topk_indices = torch.ops._C_ascend.npu_lightning_indexer_quant(
|
||||
query=q_li.view(q_li_shape_ori),
|
||||
key=kv_cache[2],
|
||||
weights=weights,
|
||||
query_dequant_scale=q_li_scale.view(q_li_shape_ori[:-1]),
|
||||
key_dequant_scale=kv_cache[3].squeeze(2), # B S N D -> B S D
|
||||
actual_seq_lengths_query=actual_seq_lengths_query,
|
||||
actual_seq_lengths_key=actual_seq_lengths_key,
|
||||
block_table=attn_metadata.block_table,
|
||||
query_quant_mode=0,
|
||||
key_quant_mode=0,
|
||||
layout_query="TND",
|
||||
layout_key="PA_BSND",
|
||||
sparse_count=2048,
|
||||
sparse_mode=3,
|
||||
)
|
||||
elif self.use_torch_npu_lightning_indexer:
|
||||
if self.use_torch_npu_lightning_indexer:
|
||||
topk_indices, _ = torch_npu.npu_lightning_indexer(
|
||||
query=q_li,
|
||||
key=kv_cache[2],
|
||||
@@ -1079,7 +1031,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
assert self.q_a_layernorm is not None, "q_a_layernorm must be initialized"
|
||||
q_c = self.q_a_layernorm(q_c)
|
||||
|
||||
k_li, k_li_scale = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin)
|
||||
k_li = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin)
|
||||
|
||||
wait_for_kv_layer_from_connector(layer_name)
|
||||
|
||||
@@ -1092,46 +1044,20 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
if self.enable_dsa_cp:
|
||||
assert k_pe is not None
|
||||
assert k_nope is not None
|
||||
assert k_li is not None
|
||||
async_op = self.enable_dsa_cp_with_layer_shard or full_gather_o_proj_enabled
|
||||
# support all_gather kv async for communication calculation overlap
|
||||
if not self.use_sparse_c8_indexer:
|
||||
fused_kv_no_split, kv_ag_handle = all_gather_async(
|
||||
torch.cat(
|
||||
[
|
||||
k_pe.view(-1, k_pe.shape[-1]),
|
||||
k_nope.view(-1, k_nope.shape[-1]),
|
||||
k_li.view(-1, k_li.shape[-1]),
|
||||
],
|
||||
dim=1,
|
||||
),
|
||||
get_tp_group(),
|
||||
async_op=async_op,
|
||||
)
|
||||
else:
|
||||
# due to different dtypes, we have to split commu pass
|
||||
assert k_li_scale is not None
|
||||
fused_kv_no_split, _ = all_gather_async(
|
||||
torch.cat(
|
||||
[
|
||||
k_pe.view(-1, k_pe.shape[-1]),
|
||||
k_nope.view(-1, k_nope.shape[-1]),
|
||||
],
|
||||
dim=1,
|
||||
),
|
||||
get_tp_group(),
|
||||
async_op=async_op,
|
||||
)
|
||||
k_li, _ = all_gather_async(
|
||||
k_li,
|
||||
get_tp_group(),
|
||||
async_op=async_op,
|
||||
)
|
||||
k_li_scale, kv_ag_handle = all_gather_async(
|
||||
k_li_scale,
|
||||
get_tp_group(),
|
||||
async_op=async_op,
|
||||
)
|
||||
fused_kv_no_split, kv_ag_handle = all_gather_async(
|
||||
torch.cat(
|
||||
[
|
||||
k_pe.view(-1, k_pe.shape[-1]),
|
||||
k_nope.view(-1, k_nope.shape[-1]),
|
||||
k_li.view(-1, k_li.shape[-1]),
|
||||
],
|
||||
dim=1,
|
||||
),
|
||||
get_tp_group(),
|
||||
async_op=async_op,
|
||||
)
|
||||
|
||||
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
|
||||
q_pe = self.rope_single(q_pe, cos, sin)
|
||||
@@ -1151,12 +1077,9 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
|
||||
if kv_cache is not None:
|
||||
assert fused_kv_no_split is not None
|
||||
if not self.use_sparse_c8_indexer:
|
||||
k_pe, k_nope, k_li = fused_kv_no_split.split(
|
||||
[self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1
|
||||
)
|
||||
else:
|
||||
k_pe, k_nope = fused_kv_no_split.split([self.qk_rope_head_dim, self.kv_lora_rank], dim=-1)
|
||||
k_pe, k_nope, k_li = fused_kv_no_split.split(
|
||||
[self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1
|
||||
)
|
||||
k_nope = k_nope.view(k_nope.shape[0], 1, -1)
|
||||
k_pe = k_pe.view(k_pe.shape[0], 1, -1)
|
||||
DeviceOperator.reshape_and_cache(
|
||||
@@ -1175,13 +1098,6 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
torch_npu.npu_scatter_nd_update_(
|
||||
kv_cache[2].view(-1, k_li.shape[-1]), slot_mapping.view(-1, 1), k_li.view(-1, k_li.shape[-1])
|
||||
) # b, s, n, d
|
||||
if self.use_sparse_c8_indexer:
|
||||
assert len(kv_cache) == 4
|
||||
torch_npu.npu_scatter_nd_update_(
|
||||
kv_cache[3].view(-1, k_li_scale.shape[-1]),
|
||||
slot_mapping.view(-1, 1),
|
||||
k_li_scale.view(-1, k_li_scale.shape[-1]),
|
||||
)
|
||||
if self.is_kv_producer:
|
||||
attn_metadata.reshape_cache_event.record()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user