[Feat]enable sfa cp for dsv3.2 (#4702)
### What this PR does / why we need it?
RFC: https://github.com/vllm-project/vllm/issues/30055
### How was this patch tested?
1. enable flashcommon1
export VLLM_ASCEND_ENABLE_FLASHCOMM1=1
2. enable sfa-cp
--additional-config '{ "enable_sfa_cp": true }' \
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: AlvisGong <gwly0401@163.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: hwhaokun <haokun0405@163.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -5,9 +5,9 @@ import torch
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.model_executor.layers.linear import (LinearBase, ReplicatedLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
@@ -17,10 +17,15 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
wait_for_kv_layer_from_connector)
|
||||
from vllm_ascend.ops.shared_weight_layer import (
|
||||
is_hidden_layer, post_process_after_loading_for_shared_weight_series,
|
||||
reach_layer_for_shared_weight_series,
|
||||
register_layer_to_shared_weight_series)
|
||||
from vllm_ascend.ops.triton.rope import rope_forward_triton
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
is_enable_nz)
|
||||
_round_up, dispose_layer, enable_sp,
|
||||
is_enable_nz, replace_layer)
|
||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -49,6 +54,20 @@ class AscendSFABackend(AttentionBackend):
|
||||
return AscendSFAImpl
|
||||
|
||||
|
||||
@dataclass
|
||||
class SfaCpContext:
|
||||
num_tokens: int
|
||||
num_tokens_pad: int
|
||||
local_start: int
|
||||
local_end: int
|
||||
local_end_with_pad: int
|
||||
pad_size: int
|
||||
local_pad_size: int
|
||||
slot_mapping_cp: torch.Tensor
|
||||
actual_seq_lengths_query: torch.Tensor
|
||||
actual_seq_lengths_key: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendSFAMetadata:
|
||||
"""Metadata for MLACommon.
|
||||
@@ -79,6 +98,7 @@ class AscendSFAMetadata:
|
||||
attn_mask: torch.Tensor = None
|
||||
# chunked prefill by default if no attn_states passed
|
||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||
sfa_cp_context: Optional[SfaCpContext] = None
|
||||
|
||||
|
||||
M = TypeVar("M", bound=AscendSFAMetadata)
|
||||
@@ -122,6 +142,9 @@ class AscendSFAMetadataBuilder:
|
||||
self.cos_cache = None
|
||||
self.sin_cache = None
|
||||
|
||||
self.enable_sfa_cp = enable_sp() and \
|
||||
hasattr(self.model_config.hf_config, "index_topk")
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
# No need to reorder for Ascend SFA
|
||||
@@ -171,6 +194,64 @@ class AscendSFAMetadataBuilder:
|
||||
sin = self.sin_cache[input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
|
||||
sfa_cp_context = None
|
||||
if self.enable_sfa_cp:
|
||||
global_tp_size = get_tp_group().world_size
|
||||
num_tokens = num_actual_tokens
|
||||
num_tokens_pad = _round_up(num_actual_tokens, global_tp_size)
|
||||
num_tokens_per_device = num_tokens_pad // global_tp_size
|
||||
pad_size = num_tokens_pad - num_tokens
|
||||
local_start = get_tp_group().rank_in_group * num_tokens_per_device
|
||||
local_end_with_pad = local_start + num_tokens_per_device
|
||||
local_end = min(local_end_with_pad, num_actual_tokens)
|
||||
local_pad_size = local_end_with_pad - local_end
|
||||
|
||||
if pad_size > 0:
|
||||
cos = nn.functional.pad(cos, (0, 0, 0, 0, 0, 0, 0, pad_size))
|
||||
sin = nn.functional.pad(sin, (0, 0, 0, 0, 0, 0, 0, pad_size))
|
||||
slot_mapping = nn.functional.pad(slot_mapping, (0, pad_size),
|
||||
value=-1)
|
||||
cos = cos[local_start:local_end_with_pad]
|
||||
sin = sin[local_start:local_end_with_pad]
|
||||
slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
|
||||
|
||||
actual_seq_lengths_query = torch.empty_like(cum_query_lens)
|
||||
actual_seq_lengths_key = torch.empty_like(seq_lens)
|
||||
num_segs = cum_query_lens.shape[0]
|
||||
last_token = 0
|
||||
cum = 0
|
||||
for i in range(0, num_segs):
|
||||
global_start = last_token
|
||||
global_end = cum_query_lens[i].item()
|
||||
last_token = global_end
|
||||
|
||||
local_start = max(global_start, local_start)
|
||||
local_end = min(global_end, local_end_with_pad)
|
||||
num_local_tokens = local_end - local_start
|
||||
|
||||
if num_local_tokens > 0:
|
||||
cum += num_local_tokens
|
||||
actual_seq_lengths_query[i] = cum
|
||||
|
||||
offset = global_end - local_end
|
||||
actual_seq_lengths_key[i] = seq_lens[i].item() - offset
|
||||
else:
|
||||
actual_seq_lengths_query[i] = cum
|
||||
actual_seq_lengths_key[i] = 0
|
||||
|
||||
sfa_cp_context = SfaCpContext(
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_pad=num_tokens_pad,
|
||||
local_start=local_start,
|
||||
local_end=local_end,
|
||||
local_end_with_pad=local_end_with_pad,
|
||||
pad_size=pad_size,
|
||||
local_pad_size=local_pad_size,
|
||||
slot_mapping_cp=slot_mapping_cp,
|
||||
actual_seq_lengths_query=actual_seq_lengths_query,
|
||||
actual_seq_lengths_key=actual_seq_lengths_key,
|
||||
)
|
||||
|
||||
return self.metadata_cls( # type: ignore
|
||||
has_prefill=has_prefill,
|
||||
num_input_tokens=common_attn_metadata.num_input_tokens,
|
||||
@@ -183,7 +264,8 @@ class AscendSFAMetadataBuilder:
|
||||
attn_state=common_attn_metadata.attn_state,
|
||||
block_tables=block_table,
|
||||
sin=sin,
|
||||
cos=cos)
|
||||
cos=cos,
|
||||
sfa_cp_context=sfa_cp_context)
|
||||
|
||||
def build_for_graph_capture(
|
||||
self,
|
||||
@@ -251,6 +333,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tp_group().rank_in_group
|
||||
self.num_heads_per_rank = self.num_heads // self.tp_size
|
||||
self.q_b_proj = kwargs['q_b_proj']
|
||||
|
||||
@@ -258,8 +341,32 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
|
||||
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
|
||||
|
||||
self.vllm_config = get_current_vllm_config()
|
||||
assert self.indexer is not None, "Indexer is required for DSA."
|
||||
|
||||
self.enable_sfa_cp = enable_sp()
|
||||
self.local_num_heads = self.num_heads
|
||||
|
||||
if self.enable_sfa_cp:
|
||||
self.local_num_heads = self.num_heads * self.tp_size
|
||||
|
||||
#TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
|
||||
self._replace_linear_class_for_sfa_cp()
|
||||
from vllm_ascend.distributed.parallel_state import \
|
||||
get_shared_weight_group
|
||||
if is_hidden_layer(self.vllm_config, self.q_proj):
|
||||
register_layer_to_shared_weight_series(
|
||||
series_name="q_proj",
|
||||
group=get_shared_weight_group(),
|
||||
layer=self.q_proj,
|
||||
prefetch_step=1)
|
||||
if is_hidden_layer(self.vllm_config, self.o_proj):
|
||||
register_layer_to_shared_weight_series(
|
||||
series_name="o_proj",
|
||||
group=get_shared_weight_group(),
|
||||
layer=self.o_proj,
|
||||
prefetch_step=1)
|
||||
|
||||
# indexer param
|
||||
self.n_head: int = self.indexer.n_head # 64
|
||||
self.head_dim: int = self.indexer.head_dim # 128
|
||||
@@ -306,16 +413,16 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
|
||||
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
||||
assert kv_b_proj_weight.shape == (
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
|
||||
self.kv_lora_rank, self.local_num_heads *
|
||||
(self.qk_nope_head_dim + self.v_head_dim)), (
|
||||
f"{kv_b_proj_weight.shape=}, "
|
||||
f"{self.kv_lora_rank=}, "
|
||||
f"{self.num_heads=}, "
|
||||
f"{self.local_num_heads=}, "
|
||||
f"{self.qk_nope_head_dim=}, "
|
||||
f"{self.v_head_dim=}")
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.local_num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
)
|
||||
|
||||
@@ -336,29 +443,42 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
# Waiting for BMM NZ support
|
||||
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
|
||||
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
|
||||
# Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory
|
||||
dispose_layer(self.kv_b_proj)
|
||||
|
||||
if self.enable_sfa_cp:
|
||||
if is_hidden_layer(self.vllm_config, self.q_proj):
|
||||
post_process_after_loading_for_shared_weight_series(
|
||||
self.q_proj)
|
||||
if is_hidden_layer(self.vllm_config, self.o_proj):
|
||||
post_process_after_loading_for_shared_weight_series(
|
||||
self.o_proj)
|
||||
|
||||
def _v_up_proj(self, x):
|
||||
if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536:
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
x = x.view(-1, self.local_num_heads, self.kv_lora_rank)
|
||||
x = torch_npu.npu_transpose_batchmatmul(x,
|
||||
self.W_UV,
|
||||
perm_x1=[1, 0, 2],
|
||||
perm_x2=[0, 1, 2],
|
||||
perm_y=[1, 0, 2])
|
||||
x = x.reshape(-1, self.num_heads * self.v_head_dim)
|
||||
x = x.reshape(-1, self.local_num_heads * self.v_head_dim)
|
||||
else:
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
x = x.view(-1, self.local_num_heads,
|
||||
self.kv_lora_rank).transpose(0, 1)
|
||||
# # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||
x = torch.bmm(x, self.W_UV)
|
||||
# # Convert from (N, B, V) to (B, N * V)
|
||||
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
||||
x = x.transpose(0,
|
||||
1).reshape(-1,
|
||||
self.local_num_heads * self.v_head_dim)
|
||||
return x
|
||||
|
||||
# Return `ql_nope`, `q_pe`
|
||||
def _q_proj_and_k_up_proj(self, x):
|
||||
q_nope, q_pe = self.q_proj(x)[0]\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)\
|
||||
.view(-1, self.local_num_heads, self.qk_head_dim)\
|
||||
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
# Convert from (B, N, P) to (N, B, P)
|
||||
@@ -375,6 +495,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
sin: torch.Tensor,
|
||||
kv_cache: Tuple,
|
||||
slots: torch.Tensor,
|
||||
slots_cp: Optional[torch.Tensor],
|
||||
):
|
||||
B = kv_no_split.shape[0]
|
||||
N = self.num_kv_heads
|
||||
@@ -383,18 +504,44 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
kv_no_split = kv_no_split.view(
|
||||
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
|
||||
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||
kv_no_split,
|
||||
self.kv_a_layernorm.weight,
|
||||
cos,
|
||||
sin,
|
||||
slots.to(torch.int64),
|
||||
kv_cache[1],
|
||||
kv_cache[0],
|
||||
epsilon=self.kv_a_layernorm.variance_epsilon,
|
||||
cache_mode=cache_mode,
|
||||
)
|
||||
return k_pe, k_nope
|
||||
|
||||
if self.enable_sfa_cp:
|
||||
assert slots_cp is not None
|
||||
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||
kv_no_split,
|
||||
self.kv_a_layernorm.weight,
|
||||
cos,
|
||||
sin,
|
||||
slots_cp.to(torch.int64),
|
||||
kv_cache[1],
|
||||
kv_cache[0],
|
||||
epsilon=self.kv_a_layernorm.variance_epsilon,
|
||||
cache_mode=cache_mode,
|
||||
is_output_kv=True,
|
||||
)
|
||||
#TODO: Temporarily adapt SFA-CP and replace it later with PCP. --clrs97
|
||||
k_pe = get_tp_group().all_gather(k_pe, 0)
|
||||
k_nope = get_tp_group().all_gather(k_nope, 0)
|
||||
|
||||
if kv_cache is not None:
|
||||
torch_npu.npu_scatter_nd_update_(
|
||||
kv_cache[0].view(-1, k_nope.shape[-1]), slots.view(-1, 1),
|
||||
k_nope.view(-1, k_nope.shape[-1]))
|
||||
torch_npu.npu_scatter_nd_update_(
|
||||
kv_cache[1].view(-1, k_pe.shape[-1]), slots.view(-1, 1),
|
||||
k_pe.view(-1, k_pe.shape[-1]))
|
||||
else:
|
||||
torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||
kv_no_split,
|
||||
self.kv_a_layernorm.weight,
|
||||
cos,
|
||||
sin,
|
||||
slots.to(torch.int64),
|
||||
kv_cache[1],
|
||||
kv_cache[0],
|
||||
epsilon=self.kv_a_layernorm.variance_epsilon,
|
||||
cache_mode=cache_mode,
|
||||
)
|
||||
|
||||
def rope_single(
|
||||
self,
|
||||
@@ -420,10 +567,20 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
if self.enable_sfa_cp:
|
||||
from vllm.forward_context import get_forward_context
|
||||
if not get_forward_context().in_profile_run:
|
||||
if is_hidden_layer(self.vllm_config, self.q_proj):
|
||||
reach_layer_for_shared_weight_series(self.q_proj)
|
||||
if is_hidden_layer(self.vllm_config, self.o_proj):
|
||||
reach_layer_for_shared_weight_series(self.o_proj)
|
||||
|
||||
return output.fill_(0)
|
||||
has_prefill = attn_metadata.has_prefill
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
hidden_states = hidden_states[:num_actual_tokens]
|
||||
if self.enable_sfa_cp:
|
||||
need_gather_q_kv = False
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
output_padded = output
|
||||
output = output[:num_actual_tokens]
|
||||
@@ -439,38 +596,61 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
q_c = self.q_a_layernorm(q_c)
|
||||
|
||||
# Process for Flash Comm V1
|
||||
q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
q_c.contiguous(), need_gather_q_kv)
|
||||
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
kv_no_split.contiguous(), need_gather_q_kv)
|
||||
if need_gather_q_kv:
|
||||
q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
q_c.contiguous(), need_gather_q_kv)
|
||||
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
kv_no_split.contiguous(), need_gather_q_kv)
|
||||
|
||||
if has_prefill:
|
||||
wait_for_kv_layer_from_connector(layer_name)
|
||||
|
||||
cos = attn_metadata.cos
|
||||
sin = attn_metadata.sin
|
||||
slot_mapping = attn_metadata.slot_mapping[:num_actual_tokens]
|
||||
ql_nope, q_pe = \
|
||||
self._q_proj_and_k_up_proj(q_c)
|
||||
q_pe = self.rope_single(q_pe, attn_metadata.cos, attn_metadata.sin)
|
||||
k_pe, k_nope = self.exec_kv(kv_no_split, attn_metadata.cos,
|
||||
attn_metadata.sin, kv_cache, slot_mapping)
|
||||
slot_mapping_cp = None
|
||||
actual_seq_lengths_query = attn_metadata.cum_query_lens
|
||||
actual_seq_lengths_key = attn_metadata.seq_lens
|
||||
if self.enable_sfa_cp:
|
||||
assert attn_metadata.sfa_cp_context is not None
|
||||
slot_mapping_cp = attn_metadata.sfa_cp_context.slot_mapping_cp
|
||||
actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
|
||||
actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key
|
||||
|
||||
topk_indices = self.indexer_select(x=hidden_states,
|
||||
qr=q_c,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
need_gather_q_kv=need_gather_q_kv)
|
||||
self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping,
|
||||
slot_mapping_cp)
|
||||
|
||||
if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
|
||||
if is_hidden_layer(self.vllm_config, self.q_proj):
|
||||
reach_layer_for_shared_weight_series(self.q_proj)
|
||||
if is_hidden_layer(self.vllm_config, self.o_proj):
|
||||
reach_layer_for_shared_weight_series(self.o_proj)
|
||||
|
||||
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
|
||||
q_pe = self.rope_single(q_pe, cos, sin)
|
||||
|
||||
topk_indices = self.indexer_select(
|
||||
x=hidden_states,
|
||||
qr=q_c,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
cos=cos,
|
||||
sin=sin,
|
||||
actual_seq_lengths_query=actual_seq_lengths_query,
|
||||
actual_seq_lengths_key=actual_seq_lengths_key,
|
||||
need_gather_q_kv=need_gather_q_kv)
|
||||
attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
|
||||
query=ql_nope,
|
||||
key=k_nope,
|
||||
value=k_nope,
|
||||
key=kv_cache[0],
|
||||
value=kv_cache[0],
|
||||
sparse_indices=topk_indices,
|
||||
scale_value=self.scale,
|
||||
sparse_block_size=1,
|
||||
block_table=attn_metadata.block_tables,
|
||||
actual_seq_lengths_query=attn_metadata.cum_query_lens,
|
||||
actual_seq_lengths_kv=attn_metadata.seq_lens,
|
||||
actual_seq_lengths_query=actual_seq_lengths_query,
|
||||
actual_seq_lengths_kv=actual_seq_lengths_key,
|
||||
query_rope=q_pe,
|
||||
key_rope=k_pe,
|
||||
key_rope=kv_cache[1],
|
||||
layout_query="TND",
|
||||
layout_kv="PA_BSND",
|
||||
sparse_mode=3,
|
||||
@@ -489,11 +669,12 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
qr: torch.Tensor,
|
||||
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
attn_metadata: M,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
actual_seq_lengths_query: torch.Tensor,
|
||||
actual_seq_lengths_key: torch.Tensor,
|
||||
need_gather_q_kv: bool = False,
|
||||
):
|
||||
cos = attn_metadata.cos
|
||||
sin = attn_metadata.sin
|
||||
|
||||
# q process in new stream
|
||||
q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
|
||||
q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
|
||||
@@ -539,6 +720,9 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
|
||||
k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128]
|
||||
|
||||
if self.enable_sfa_cp:
|
||||
k = get_tp_group().all_gather(k, 0)
|
||||
|
||||
if kv_cache is not None:
|
||||
torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]),
|
||||
attn_metadata.slot_mapping.view(
|
||||
@@ -551,18 +735,55 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
weights, need_gather_q_kv)
|
||||
|
||||
block_table = attn_metadata.block_tables
|
||||
seq_lens = attn_metadata.seq_lens
|
||||
cum_query_lens = attn_metadata.cum_query_lens
|
||||
|
||||
topk_indices = torch.ops._C_ascend.npu_lightning_indexer(
|
||||
query=q,
|
||||
key=kv_cache[2],
|
||||
weights=weights,
|
||||
actual_seq_lengths_query=cum_query_lens,
|
||||
actual_seq_lengths_key=seq_lens,
|
||||
actual_seq_lengths_query=actual_seq_lengths_query,
|
||||
actual_seq_lengths_key=actual_seq_lengths_key,
|
||||
block_table=block_table,
|
||||
layout_query="TND",
|
||||
layout_key="PA_BSND",
|
||||
sparse_count=2048,
|
||||
sparse_mode=3)
|
||||
return topk_indices
|
||||
|
||||
def _replace_linear_class_for_sfa_cp(self):
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
# Dispose tensor from the original q_proj
|
||||
dispose_layer(self.q_proj)
|
||||
# Construct the new q_proj using ReplicatedLinear
|
||||
new_q_proj = ReplicatedLinear(self.q_lora_rank,
|
||||
self.local_num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=vllm_config.quant_config,
|
||||
prefix=self.q_proj.prefix)
|
||||
# Replace the q_proj with the new one
|
||||
replace_layer(self.q_proj, new_q_proj)
|
||||
|
||||
# Dispose tensor from the original kv_b_proj
|
||||
dispose_layer(self.kv_b_proj)
|
||||
# Construct the new kv_b_proj using ReplicatedLinear
|
||||
new_kv_b_proj = ReplicatedLinear(
|
||||
self.kv_lora_rank,
|
||||
self.local_num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
quant_config=vllm_config.quant_config,
|
||||
prefix=self.kv_b_proj.prefix)
|
||||
# Replace the kv_b_proj with the new one
|
||||
replace_layer(self.kv_b_proj, new_kv_b_proj)
|
||||
|
||||
# Dispose tensor from the original o_proj
|
||||
dispose_layer(self.o_proj)
|
||||
# Construct the new o_proj using ReplicatedLinear
|
||||
config = vllm_config.model_config.hf_config
|
||||
new_o_proj = ReplicatedLinear(config.num_attention_heads *
|
||||
config.v_head_dim,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=vllm_config.quant_config,
|
||||
prefix=self.o_proj.prefix)
|
||||
# Replace the o_proj with the new one
|
||||
replace_layer(self.o_proj, new_o_proj)
|
||||
|
||||
@@ -9,7 +9,7 @@ from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.utils import flashcomm2_enable
|
||||
from vllm_ascend.utils import enable_sp, flashcomm2_enable
|
||||
|
||||
# Currently, mc2 op need their own group coordinator.
|
||||
_MC2: Optional[GroupCoordinator] = None
|
||||
@@ -19,6 +19,7 @@ _LMTP: Optional[GroupCoordinator] = None
|
||||
_P_TP: Optional[GroupCoordinator] = None
|
||||
_FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
|
||||
_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None
|
||||
_SHARED_WEIGHT: Optional[GroupCoordinator] = None
|
||||
|
||||
|
||||
def get_mc2_group() -> GroupCoordinator:
|
||||
@@ -48,6 +49,13 @@ def get_flashcomm2_odp_group() -> GroupCoordinator:
|
||||
return _FLASHCOMM2_ODP
|
||||
|
||||
|
||||
def get_shared_weight_group() -> GroupCoordinator:
|
||||
assert _SHARED_WEIGHT is not None, (
|
||||
"output shared weight parallel group for flashcomm2 is not initialized"
|
||||
)
|
||||
return _SHARED_WEIGHT
|
||||
|
||||
|
||||
def get_mlp_tp_group() -> GroupCoordinator:
|
||||
assert _MLP_TP is not None, ("mlp group is not initialized")
|
||||
return _MLP_TP
|
||||
@@ -226,6 +234,18 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
|
||||
backend,
|
||||
group_name="flashcomm2_odp")
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
# TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97
|
||||
is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
|
||||
if enable_sp() and is_ds_v32:
|
||||
global _SHARED_WEIGHT
|
||||
group_ranks = [list(range(torch.distributed.get_world_size()))]
|
||||
_SHARED_WEIGHT = init_model_parallel_group(
|
||||
group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
group_name="CP_shared_weight")
|
||||
|
||||
|
||||
def get_mlp_tensor_model_parallel_world_size():
|
||||
"""Return world size for the tensor model parallel group."""
|
||||
@@ -274,3 +294,8 @@ def destroy_ascend_model_parallel():
|
||||
).flashcomm2_oproj_tensor_parallel_size != 1:
|
||||
_FLASHCOMM2_ODP.destroy()
|
||||
_FLASHCOMM2_ODP = None
|
||||
|
||||
global _SHARED_WEIGHT
|
||||
if _SHARED_WEIGHT:
|
||||
_SHARED_WEIGHT.destroy()
|
||||
_SHARED_WEIGHT = None
|
||||
|
||||
252
vllm_ascend/ops/shared_weight_layer.py
Normal file
252
vllm_ascend/ops/shared_weight_layer.py
Normal file
@@ -0,0 +1,252 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
|
||||
|
||||
def dispose_tensor(x: torch.Tensor):
|
||||
x.set_(torch.empty([], device=x.device, dtype=x.dtype))
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerMetadata:
|
||||
"""Metadata for a layer.
|
||||
"""
|
||||
layer_idx: int # The index of the layer.
|
||||
layer: LinearBase # The layer object.
|
||||
post_method: Callable[[
|
||||
torch.nn.Module
|
||||
], None] # The `process_weights_after_loading` method from the quant method.
|
||||
weight: torch.Tensor # The weight tensor.
|
||||
window_idx: int # The index of the window.
|
||||
|
||||
|
||||
@dataclass
|
||||
class SharedWindowMetadata:
|
||||
"""Metadata for a shared window.
|
||||
"""
|
||||
weight: torch.Tensor # The weight tensor to be shared by layers.
|
||||
data_layer_idx: int # The index of the layer this window's weight is equal to.
|
||||
work: Optional[torch.distributed.Work] # The asynchronous broadcast work.
|
||||
|
||||
|
||||
@dataclass
|
||||
class SeriesMetadata:
|
||||
"""Metadata for a weight shared series.
|
||||
"""
|
||||
group: GroupCoordinator
|
||||
start_layer: int
|
||||
end_layer: int
|
||||
num_layers: int
|
||||
prefetch_step: int
|
||||
dummy_weight: torch.Tensor # Dummy weight to replace the loaded weight matrix. All the layers in the series share the same dummy weight tensor.
|
||||
layers: list[LayerMetadata]
|
||||
shared_windows: list[
|
||||
SharedWindowMetadata] # Shared windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored.
|
||||
window_offset: int # The index of the window for the next coming layer.
|
||||
|
||||
def is_source(self, layer_idx) -> bool:
|
||||
return layer_idx % self.group.world_size == self.group.rank_in_group
|
||||
|
||||
def post_process_after_loading(self):
|
||||
# This method only needs to be called once per series.
|
||||
if self.shared_windows:
|
||||
return
|
||||
|
||||
self.layers.sort(key=lambda x: x.layer_idx)
|
||||
self.num_layers = len(self.layers)
|
||||
assert self.num_layers > 0, "No layers in the series"
|
||||
assert self.prefetch_step >= 0 and self.prefetch_step <= max(
|
||||
0, self.num_layers -
|
||||
2), "prefetch_step must be in [0, num_layers - 2]"
|
||||
self.start_layer = self.layers[0].layer_idx
|
||||
self.end_layer = self.layers[-1].layer_idx + 1
|
||||
|
||||
for layer_idx in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[layer_idx - self.start_layer]
|
||||
assert layer.layer_idx == layer_idx, "layer_idx must be consecutive"
|
||||
is_source = self.is_source(layer_idx)
|
||||
# If the weight uses dummy weight, make a copy temporary such that the post method call won't affect other layers which also uses dummy weight.
|
||||
if not is_source:
|
||||
layer.weight.set_(torch.empty_like(self.dummy_weight))
|
||||
# Broadcast to get the true weight.
|
||||
dist.broadcast(layer.weight,
|
||||
src=self.group.ranks[layer_idx %
|
||||
self.group.world_size],
|
||||
group=self.group.device_group)
|
||||
# Call `process_weights_after_loading` from the quant method.
|
||||
layer.post_method(layer.layer)
|
||||
step = layer_idx - self.start_layer
|
||||
if step < self.prefetch_step:
|
||||
# Build the windows for the first `prefetch_step` layers. The weights can be used for the first `prefetch_step` layers in `forward()`, so also clone the weights.
|
||||
self.shared_windows.append(
|
||||
SharedWindowMetadata(
|
||||
weight=layer.weight.clone().detach(),
|
||||
data_layer_idx=layer_idx,
|
||||
work=None,
|
||||
))
|
||||
layer.window_idx = step
|
||||
# When the layer not intended to be stored in this device, link to the corresponding window's tensor.
|
||||
if not is_source:
|
||||
layer.weight.set_(self.shared_windows[-1].weight)
|
||||
else:
|
||||
# Build one more window for prefetch. The weight is useless, so just keep the shape.
|
||||
if step == self.prefetch_step:
|
||||
self.shared_windows.append(
|
||||
SharedWindowMetadata(
|
||||
weight=torch.empty_like(layer.weight),
|
||||
data_layer_idx=-1,
|
||||
work=None,
|
||||
))
|
||||
# When the layer not intended to be stored in this device, dispose the tensor.
|
||||
if not is_source:
|
||||
dispose_tensor(layer.weight)
|
||||
# Dispose the dummy tensor since it's no longer needed.
|
||||
dispose_tensor(self.dummy_weight)
|
||||
|
||||
def reach_layer(self, layer_idx: int):
|
||||
# The index of the layer to be prefetched.
|
||||
next_layer_idx = (layer_idx + self.prefetch_step
|
||||
) % self.num_layers + self.start_layer
|
||||
next_layer = self.layers[next_layer_idx - self.start_layer]
|
||||
# The index of the window to store the weight for the coming layer.
|
||||
next_layer.window_idx = self.window_offset
|
||||
window = self.shared_windows[next_layer.window_idx]
|
||||
# When the layer not intended to be stored in this device, link to the corresponding window's tensor.
|
||||
if not self.is_source(next_layer_idx):
|
||||
next_layer.weight.set_(window.weight)
|
||||
# Update `window_offset` by rolling one step.
|
||||
self.window_offset = (self.window_offset + 1) % (self.prefetch_step +
|
||||
1)
|
||||
assert window.data_layer_idx != next_layer_idx
|
||||
window.data_layer_idx = next_layer_idx
|
||||
# Start asynchronous broadcast work.
|
||||
window.work = dist.broadcast(
|
||||
next_layer.weight,
|
||||
src=self.group.ranks[next_layer_idx % self.group.world_size],
|
||||
group=self.group.device_group,
|
||||
async_op=True)
|
||||
|
||||
def wait_weight(self, layer_idx: int):
|
||||
# Find the asynchronous broadcast work and wait for it.
|
||||
assert self.shared_windows
|
||||
window = self.shared_windows[self.layers[layer_idx -
|
||||
self.start_layer].window_idx]
|
||||
# Make sure the data in the corresponding shared window is for the current layer.
|
||||
assert window.data_layer_idx == layer_idx
|
||||
if window.work is not None:
|
||||
window.work.wait()
|
||||
window.work = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerExternalMetadata:
|
||||
"""External metadata for a layer.
|
||||
"""
|
||||
series: SeriesMetadata
|
||||
layer_idx: int
|
||||
|
||||
|
||||
_series_dict: dict[str, SeriesMetadata] = {}
|
||||
|
||||
_layer_external_dict: dict[int, LayerExternalMetadata] = {}
|
||||
|
||||
|
||||
def _create_forward_wrapper(forward: Callable, series: SeriesMetadata,
|
||||
layer_idx: int) -> Callable:
|
||||
|
||||
def wrapped_forward(*args, **kwargs):
|
||||
# Wait for the weight.
|
||||
series.wait_weight(layer_idx)
|
||||
return forward(*args, **kwargs)
|
||||
|
||||
return wrapped_forward
|
||||
|
||||
|
||||
"""
|
||||
Register linear layers into a shared storage series.
|
||||
|
||||
In a parallel group, each device stores a distinct, non-overlapping subset of layers from the series. All layers in a series must have the same structure (are isomorphic). The weight matrix for the i-th layer is stored on device (i % n), where n is the number of devices.
|
||||
|
||||
After loading the model, you must call `post_process_after_loading_for_shared_weight_series(layer)` on any layer of this series to complete the initialization.
|
||||
|
||||
During execution, each time a new layer is reached, you must call `reach_layer_for_shared_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to `reach_layer_for_shared_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series.
|
||||
|
||||
Note: The layers are managed as a circular buffer. The index of the layer to prefetch is determined by the formula:
|
||||
- start_layer is the index of the first layer in the series (inclusive).
|
||||
- end_layer is the index of the last layer in the series (exclusive). Thus, the series includes all layers with indices in the range [start_layer, end_layer).
|
||||
- total_layers = end_layer - start_layer
|
||||
- prefetch_layer_idx = (layer_idx + prefetch_step) % total_layers + start_layer
|
||||
|
||||
To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shared tensor buffers will be created for this series.
|
||||
|
||||
Arguments:
|
||||
series_name: This name identifies which series this layer belongs to.
|
||||
group: The group coordinator for handling asynchronous communications. It is recommended to create a new group coordinator for each new series.
|
||||
layer: The linear layer object to register.
|
||||
prefetch_step: An integer that manages asynchronous weight prefetching. Setting it to 0 or 1 can cover most cases.
|
||||
"""
|
||||
|
||||
|
||||
def register_layer_to_shared_weight_series(
|
||||
series_name: str,
|
||||
group: GroupCoordinator,
|
||||
layer: LinearBase,
|
||||
prefetch_step: int = 1,
|
||||
):
|
||||
global _series_dict
|
||||
if series_name not in _series_dict:
|
||||
_series_dict[series_name] = SeriesMetadata(
|
||||
group=group,
|
||||
start_layer=0,
|
||||
end_layer=0,
|
||||
num_layers=0,
|
||||
prefetch_step=prefetch_step,
|
||||
dummy_weight=torch.empty_like(layer.weight),
|
||||
layers=[],
|
||||
shared_windows=[],
|
||||
window_offset=prefetch_step,
|
||||
)
|
||||
series = _series_dict[series_name]
|
||||
assert layer.quant_method is not None
|
||||
layer_idx = extract_layer_index(layer.prefix)
|
||||
series.layers.append(
|
||||
LayerMetadata(
|
||||
layer_idx=layer_idx,
|
||||
layer=layer,
|
||||
post_method=layer.quant_method.process_weights_after_loading,
|
||||
weight=layer.weight,
|
||||
window_idx=-1,
|
||||
))
|
||||
# Discard the original `process_weights_after_loading` method such that it won't be called by others.
|
||||
layer.quant_method.process_weights_after_loading = lambda layer: None
|
||||
# When the layer not intended to be stored in this device, dispose the tensor and skip weight loading.
|
||||
if not series.is_source(layer_idx):
|
||||
dispose_tensor(layer.weight)
|
||||
layer.weight.weight_loader = lambda *args, **kwargs: None
|
||||
layer.forward = _create_forward_wrapper(layer.forward, series, layer_idx)
|
||||
global _layer_external_dict
|
||||
_layer_external_dict[id(layer)] = LayerExternalMetadata(
|
||||
series=series,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
|
||||
def post_process_after_loading_for_shared_weight_series(layer: LinearBase):
|
||||
ext = _layer_external_dict[id(layer)]
|
||||
ext.series.post_process_after_loading()
|
||||
|
||||
|
||||
def reach_layer_for_shared_weight_series(layer: LinearBase):
|
||||
ext = _layer_external_dict[id(layer)]
|
||||
ext.series.reach_layer(ext.layer_idx)
|
||||
|
||||
|
||||
def is_hidden_layer(vllm_config, layer: LinearBase) -> bool:
|
||||
num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
|
||||
layer_idx = extract_layer_index(layer.prefix)
|
||||
return layer_idx < num_hidden_layers
|
||||
@@ -1067,3 +1067,15 @@ def refresh_block_size(vllm_config):
|
||||
"Block size is set to 128 if prefix cache or chunked prefill is enabled."
|
||||
)
|
||||
cache_config.block_size = 128
|
||||
|
||||
|
||||
def dispose_layer(layer: Any):
|
||||
for attr_name in dir(layer):
|
||||
attr_value = getattr(layer, attr_name)
|
||||
if isinstance(attr_value, torch.Tensor):
|
||||
dispose_tensor(attr_value)
|
||||
|
||||
|
||||
def replace_layer(original_layer: Any, new_layer: Any):
|
||||
original_layer.__class__ = new_layer.__class__
|
||||
original_layer.__dict__ = new_layer.__dict__
|
||||
|
||||
Reference in New Issue
Block a user