[feat]dcp pcp support aclgraph (#3731)
### What this PR does / why we need it?
DCP and PCP now support full ACL graph mode, including the MLA attention_v1 backend.
- vLLM version: v0.11.0rc3
- vLLM main:
c9461e05a4
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
@@ -176,16 +176,30 @@ class TestAscendMLAMetadata(TestBase):
 class TestAscendMLAMetadataBuilder(TestBase):

-    def test_ascend_mla_metadata_builder_default(self):
+    @patch('vllm.distributed.parallel_state.get_dcp_group')
+    @patch('vllm.distributed.parallel_state._DCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
+           return_value=1)
+    def test_ascend_mla_metadata_builder_default(self, mock_get_dcp_size,
+                                                 mock_dcp, mock_get_dcp_group):
         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.model_config.get_head_size.return_value = 64
         mock_vllm_config.model_config.dtype = torch.float16
         mock_vllm_config.cache_config.block_size = 16
         mock_vllm_config.scheduler_config.max_num_seqs = 4
+        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
         mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
         mock_device = 'cpu'

+        mock_dcp.world_size = 1
+        dcp_group = MagicMock(spec=GroupCoordinator)
+        dcp_group.rank_in_group = 0
+        dcp_group.world_size = 1
+        dcp_group.device_group = MagicMock()
+        mock_get_dcp_group.return_value = dcp_group
+
         mock_vllm_config.speculative_config = None

         ascend_config = MagicMock()
@@ -200,16 +214,31 @@ class TestAscendMLAMetadataBuilder(TestBase):
             builder.chunked_prefill_enabled,
             mock_vllm_config.scheduler_config.chunked_prefill_enabled)

-    def test_ascend_mla_metadata_builder_spec_decode(self):
+    @patch('vllm.distributed.parallel_state.get_dcp_group')
+    @patch('vllm.distributed.parallel_state._DCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
+           return_value=1)
+    def test_ascend_mla_metadata_builder_spec_decode(self, mock_get_dcp_size,
+                                                     mock_dcp,
+                                                     mock_get_dcp_group):
         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.model_config.get_head_size.return_value = 64
         mock_vllm_config.model_config.dtype = torch.float16
         mock_vllm_config.cache_config.block_size = 16
         mock_vllm_config.scheduler_config.max_num_seqs = 4
+        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
         mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
         mock_device = 'cpu'

+        mock_dcp.world_size = 1
+        dcp_group = MagicMock(spec=GroupCoordinator)
+        dcp_group.rank_in_group = 0
+        dcp_group.world_size = 1
+        dcp_group.device_group = MagicMock()
+        mock_get_dcp_group.return_value = dcp_group
+
         mock_spec_config = MagicMock()
         mock_spec_config.num_speculative_tokens = 3
         mock_vllm_config.speculative_config = mock_spec_config
@@ -226,16 +255,30 @@ class TestAscendMLAMetadataBuilder(TestBase):
             builder.chunked_prefill_enabled,
             mock_vllm_config.scheduler_config.chunked_prefill_enabled)

-    def test_reorder_batch(self):
+    @patch('vllm.distributed.parallel_state.get_dcp_group')
+    @patch('vllm.distributed.parallel_state._DCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
+           return_value=1)
+    def test_reorder_batch(self, mock_get_dcp_size, mock_dcp,
+                           mock_get_dcp_group):
         ascend_config = MagicMock()

         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.cache_config.block_size = 16
         mock_vllm_config.scheduler_config.max_num_seqs = 4
+        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
         mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
         mock_device = 'cpu'

+        mock_dcp.world_size = 1
+        dcp_group = MagicMock(spec=GroupCoordinator)
+        dcp_group.rank_in_group = 0
+        dcp_group.world_size = 1
+        dcp_group.device_group = MagicMock()
+        mock_get_dcp_group.return_value = dcp_group
+
         mock_vllm_config.speculative_config = None

         with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
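The pattern added to all three tests above is the same: patch the decode-context-parallel accessors so `AscendMLAMetadataBuilder` can be constructed on a single, non-distributed process. A minimal standalone sketch of that pattern (the `make_fake_dcp_group` and `run_with_fake_dcp` helpers are illustrative names, not part of the diff):

```python
# Illustrative sketch only: mimics how the tests above stub out DCP state.
from unittest.mock import MagicMock, patch


def make_fake_dcp_group():
    """Stand-in for vLLM's GroupCoordinator with a single DCP rank."""
    dcp_group = MagicMock()
    dcp_group.rank_in_group = 0
    dcp_group.world_size = 1
    dcp_group.device_group = MagicMock()
    return dcp_group


def run_with_fake_dcp(fn):
    """Run `fn` with DCP world size 1 and a mocked group, as the tests do."""
    with patch("vllm.distributed.parallel_state.get_dcp_group",
               return_value=make_fake_dcp_group()), \
         patch("vllm.distributed.get_decode_context_model_parallel_world_size",
               return_value=1):
        return fn()
```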
@@ -865,26 +865,81 @@ class AscendAttentionBackendImpl(AttentionImpl):
         num_heads = self.num_heads

         # 1. Compute out&lse by "npu_fused_infer_attention_score"
-        attn_out, attn_lse = torch.ops.npu.npu_fused_infer_attention_score(
-            query.view(query.shape[0], 1, query.shape[1], query.shape[2]),
-            # [b,num_heads,head_size] -> [b,1,num_heads,head_size]
-            self.key_cache.view(self.key_cache.shape[0],
-                                self.key_cache.shape[1], -1),
-            self.value_cache.view(self.key_cache.shape[0],
-                                  self.key_cache.shape[1], -1),
-            num_heads=num_heads,
-            num_key_value_heads=self.num_kv_heads,
-            input_layout="BSND",
-            atten_mask=None,
-            scale=self.scale,
-            antiquant_mode=0,
-            antiquant_scale=None,
-            softmax_lse_flag=True,
-            block_table=attn_metadata.block_tables,
-            block_size=self.key_cache.shape[1],
-            actual_seq_lengths_kv=attn_metadata.decode_meta.
-            num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
-        )
+        q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
+        # [b,num_heads,head_size] -> [b,1,num_heads,head_size]
+        k_nope = self.key_cache.view(self.key_cache.shape[0],
+                                     self.key_cache.shape[1], -1)
+        value = self.value_cache.view(self.key_cache.shape[0],
+                                      self.key_cache.shape[1], -1)
+        common_kwargs = {
+            'num_heads': num_heads,
+            'num_key_value_heads': self.num_kv_heads,
+            'input_layout': "BSND",
+            'atten_mask': None,
+            'scale': self.scale,
+            'antiquant_mode': 0,
+            'antiquant_scale': None,
+            'softmax_lse_flag': True,
+            'block_table': attn_metadata.block_tables,
+            'block_size': self.key_cache.shape[1],
+            'actual_seq_lengths_kv': attn_metadata.decode_meta.
+            num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
+        }
+        graph_params = get_graph_params()
+        forward_context: ForwardContext = get_forward_context()
+        num_tokens = query.shape[0]
+        if forward_context.capturing:
+            stream = torch_npu.npu.current_stream()
+
+            event = torch.npu.ExternalEvent()
+            event.wait(stream)
+            event.reset(stream)
+            graph_params.events[num_tokens].append(event)
+
+            workspace = graph_params.workspaces.get(num_tokens)
+            if workspace is None:
+                workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                    q_nope, k_nope, value, **common_kwargs)
+                update_graph_params_workspaces(num_tokens,
+                                               weak_ref_tensors(workspace))
+            attn_out = torch.empty_like(q_nope)
+            attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
+                                   dtype=torch.float,
+                                   device=q_nope.device)
+
+            graph_params.attn_params[num_tokens].append(
+                (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
+                 weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
+                 self.scale, attn_metadata.block_tables,
+                 self.key_cache.shape[1], attn_metadata.decode_meta.
+                 num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
+                                                self.dcp_rank],
+                 weak_ref_tensors(attn_out), weak_ref_tensors(attn_lse),
+                 self.pcp_rank, self.dcp_rank, self.dcp_size))
+            torch.npu.graph_task_group_begin(stream)
+            torch_npu.npu_fused_infer_attention_score.out(
+                q_nope,
+                k_nope,
+                value,
+                **common_kwargs,
+                workspace=workspace,
+                out=[attn_out, attn_lse])
+            handle = torch.npu.graph_task_group_end(stream)
+            graph_params.handles[num_tokens].append(handle)
+        else:
+            attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
+                q_nope, k_nope, value, **common_kwargs)

         attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
                                  attn_out.shape[3])
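The structure of the new decode path above is: build the op arguments once, then either call the fused kernel eagerly or, while an ACL graph is being captured, reuse a per-shape workspace, allocate stable output buffers, and record every argument so it can be rewritten at replay time. A toy, device-independent sketch of that bookkeeping (the `GraphParamCache` class here is a stand-in for the real `GraphParams`, not code from the diff):

```python
import torch
from collections import defaultdict


class GraphParamCache:
    """Toy stand-in for GraphParams: one workspace per captured token count,
    plus the recorded per-layer argument tuples that the replay-time
    update_* functions later patch."""

    def __init__(self):
        self.workspaces: dict[int, torch.Tensor] = {}
        self.attn_params: dict[int, list[tuple]] = defaultdict(list)

    def get_or_create_workspace(self, num_tokens: int, nbytes: int) -> torch.Tensor:
        ws = self.workspaces.get(num_tokens)
        if ws is None:
            # In the real code the size comes from the kernel's
            # *_get_max_workspace query; here we just allocate raw bytes.
            ws = torch.empty(nbytes, dtype=torch.uint8)
            self.workspaces[num_tokens] = ws
        return ws

    def record(self, num_tokens: int, params: tuple) -> None:
        self.attn_params[num_tokens].append(params)


cache = GraphParamCache()
q = torch.randn(8, 16, 64)                 # [num_tokens, heads, head_dim]
workspace = cache.get_or_create_workspace(q.shape[0], 1 << 20)
out = torch.empty_like(q)                  # stable output buffer reused at replay
cache.record(q.shape[0], (q, out))
```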
@@ -140,6 +140,9 @@ class AscendMLADecodeMetadata:
     cos: torch.Tensor = None
     num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[
         list[int]]]]]] = None
+    seq_mask_pcp: torch.Tensor = None
+    seq_mask_dcp: torch.Tensor = None
+    cp_seq_len: torch.Tensor = None


 @dataclass
@@ -259,6 +262,24 @@ class AscendMLAMetadataBuilder:
         self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
         self.cos_cache = None
         self.sin_cache = None
+        self.pcp_size = get_prefill_context_model_parallel_world_size(
+        ) if prefill_context_parallel_enable() else 1
+        self.cp_rank = get_prefill_context_model_parallel_rank(
+        ) if self.pcp_size > 1 else 0
+        self.dcp_size = get_decode_context_model_parallel_world_size()
+        self.dcp_rank = get_decode_context_model_parallel_rank(
+        ) if self.dcp_size > 1 else 0
+        decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
+                                      0)
+        max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
+        self.seq_mask_pcp_buf = torch.empty(max_num_seqs,
+                                            self.pcp_size,
+                                            dtype=torch.uint8,
+                                            device=device)
+        self.seq_mask_dcp_buf = torch.empty(max_num_seqs,
+                                            self.dcp_size,
+                                            dtype=torch.uint8,
+                                            device=device)

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -463,6 +484,41 @@ class AscendMLAMetadataBuilder:
             block_table = block_table[:num_decodes, ...]
             seq_lens_list = seq_lens.tolist()

+            if num_computed_tokens_of_pcp_dcp is not None:
+                num_computed_tokens_of_cp_dcp_array = np.array(
+                    num_computed_tokens_of_pcp_dcp
+                )[:num_decodes]  # [bs, pcp_size, dcp_size]
+                seq_mask_pcp = torch.where(
+                    torch.tensor(
+                        num_computed_tokens_of_cp_dcp_array.sum(2)) == 0, 0,
+                    1).to(torch.uint8)
+                self.seq_mask_pcp_buf[:seq_mask_pcp.shape[0],
+                                      :seq_mask_pcp.shape[1]].copy_(
+                                          seq_mask_pcp, non_blocking=True)
+                seq_mask_pcp_shape = (seq_mask_pcp.shape[0],
+                                      seq_mask_pcp.shape[1])
+
+                seq_mask_dcp = torch.where(
+                    torch.tensor(
+                        num_computed_tokens_of_cp_dcp_array[:, self.cp_rank, :])
+                    == 0, 0, 1).to(torch.uint8)
+                self.seq_mask_dcp_buf[:seq_mask_dcp.shape[0],
+                                      :seq_mask_dcp.shape[1]].copy_(
+                                          seq_mask_dcp, non_blocking=True)
+                seq_mask_dcp_shape = (seq_mask_dcp.shape[0],
+                                      seq_mask_dcp.shape[1])
+
+                cp_seq_len = num_computed_tokens_of_cp_dcp_array[:,
+                                                                 self.cp_rank,
+                                                                 self.dcp_rank]
+                cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32)
+                cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
+            else:
+                seq_mask_pcp_shape = (0, 0)
+                seq_mask_dcp_shape = (0, 0)
+                cp_seq_len = None
+
             # TODO: After the fullgraph supports MTP, the if branch needs to deleted
             assert self.cos_cache is not None
             assert self.sin_cache is not None
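A small runnable illustration of what the builder hunk above computes from the `[batch, pcp_size, dcp_size]` token-count table (toy numbers; the real code additionally copies the masks into the persistent `seq_mask_*_buf` buffers so the captured graph sees stable tensor addresses):

```python
import numpy as np
import torch

# Toy [batch=3, pcp_size=2, dcp_size=2] table of computed-token counts.
num_computed = np.array([
    [[5, 3], [0, 0]],   # request 0: tokens only on pcp rank 0
    [[0, 0], [0, 0]],   # request 1: nothing cached anywhere yet
    [[2, 0], [4, 1]],   # request 2: tokens on both pcp ranks
])
cp_rank, dcp_rank = 0, 1

# 1 where a pcp rank holds any tokens for the request, else 0.
seq_mask_pcp = torch.from_numpy((num_computed.sum(2) != 0).astype(np.uint8))
# 1 where a dcp rank (within this pcp rank) holds any tokens, else 0.
seq_mask_dcp = torch.from_numpy((num_computed[:, cp_rank, :] != 0).astype(np.uint8))
# Local KV length for this (pcp, dcp) rank; clamp 0 -> 1 because the fused
# kernel rejects zero-length sequences (the mask corrects the lse afterwards).
cp_seq_len = torch.tensor(num_computed[:, cp_rank, dcp_rank], dtype=torch.int32)
cp_seq_len = torch.where(cp_seq_len == 0, torch.ones_like(cp_seq_len), cp_seq_len)

print(seq_mask_pcp.tolist())  # [[1, 0], [0, 0], [1, 1]]
print(seq_mask_dcp.tolist())  # [[1, 1], [0, 0], [1, 0]]
print(cp_seq_len.tolist())    # [3, 1, 1]
```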
@@ -485,7 +541,14 @@ class AscendMLAMetadataBuilder:
                 sin=sin,
                 cos=cos,
                 num_computed_tokens_of_pcp_dcp=
-                num_computed_tokens_of_pcp_dcp)
+                num_computed_tokens_of_pcp_dcp,
+                seq_mask_pcp=self.seq_mask_pcp_buf[:seq_mask_pcp_shape[0],
+                                                   :seq_mask_pcp_shape[1]],
+                seq_mask_dcp=self.seq_mask_dcp_buf[:seq_mask_dcp_shape[0],
+                                                   :seq_mask_dcp_shape[1]],
+                cp_seq_len=cp_seq_len)
         else:
             cos[:num_decode_tokens,
                 ...] = self.cos_cache[input_positions].unsqueeze(
@@ -505,7 +568,14 @@ class AscendMLAMetadataBuilder:
                 sin=sin[:num_decode_tokens, ...],
                 cos=cos[:num_decode_tokens, ...],
                 num_computed_tokens_of_pcp_dcp=
-                num_computed_tokens_of_pcp_dcp)
+                num_computed_tokens_of_pcp_dcp,
+                seq_mask_pcp=self.seq_mask_pcp_buf[:seq_mask_pcp_shape[0],
+                                                   :seq_mask_pcp_shape[1]],
+                seq_mask_dcp=self.seq_mask_dcp_buf[:seq_mask_dcp_shape[0],
+                                                   :seq_mask_dcp_shape[1]],
+                cp_seq_len=cp_seq_len)

         return self.metadata_cls(  # type: ignore
             num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
@@ -1590,36 +1660,63 @@ class AscendMLAImpl(MLAAttentionImpl):
         q_nope = q_nope.view(num_tokens, num_heads, -1)
         q_pe = q_pe.view(num_tokens, num_heads, -1)
         # use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask
-        num_computed_tokens_of_pcp_dcp = np.array(
-            decode_meta.num_computed_tokens_of_pcp_dcp
-        )[:attn_metadata.num_decodes]  # [bs, pcp_size, dcp_size]
-        seq_mask_pcp = torch.where(
-            torch.tensor(num_computed_tokens_of_pcp_dcp.sum(2)) == 0, 0,
-            1).to(torch.uint8).to(q_pe.device)
-        seq_mask_dcp = torch.where(
-            torch.tensor(
-                num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, :]) == 0, 0,
-            1).to(torch.uint8).to(q_pe.device)
-        seq_len = num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
-                                                 self.dcp_rank]
-        seq_len = torch.tensor(seq_len, dtype=torch.int32)
-        # npu_multi_head_latent_attention does not support seq_len = 0,
-        # update where seq_len == 0 to 1.
-        # This will not influence result, since we will use seq_mask to update lse.
-        seq_len = torch.where(seq_len == 0, 1, seq_len)
-
-        if torch.sum(seq_len).item() == 0:
-            # Case that no kv_cache has been stored on this rank, no need to do following computation.
-            attn_output = torch.zeros(
-                [num_tokens, num_heads, self.kv_lora_rank],
-                dtype=q_nope.dtype,
-                device=q_nope.device)
-            softmax_lse = torch.full((num_tokens, num_heads, 1),
-                                     float('-inf'),
-                                     dtype=q_nope.dtype,
-                                     device=q_nope.device)
+        seq_mask_pcp = decode_meta.seq_mask_pcp
+        seq_mask_dcp = decode_meta.seq_mask_dcp
+        seq_len = decode_meta.cp_seq_len
+
+        common_kwargs = {
+            "return_lse": True,
+            "calc_type": "calc_type_ring",
+        }
+        graph_params = get_graph_params()
+        forward_context: ForwardContext = get_forward_context()
+        if forward_context.capturing:
+            stream = torch_npu.npu.current_stream()
+            event = torch.npu.ExternalEvent()
+            event.wait(stream)
+            event.reset(stream)
+            graph_params.events[num_tokens].append(event)
+            workspace = graph_params.workspaces.get(num_tokens)
+            if workspace is None:
+                workspace = torch_npu.atb._npu_multi_head_latent_attention_get_workspace(
+                    q_nope, q_pe, k_nope, k_pe, decode_meta.block_table,
+                    seq_len, num_heads, self.scale, self.num_kv_heads,
+                    **common_kwargs)
+                update_graph_params_workspaces(num_tokens,
+                                               weak_ref_tensors(workspace))
+            attn_output = torch.empty_like(q_nope)
+            softmax_lse = torch.empty((num_tokens, num_heads, 1),
+                                      dtype=q_nope.dtype,
+                                      device=q_nope.device)
+            graph_params.attn_params[num_tokens].append(
+                (weak_ref_tensors(q_nope), weak_ref_tensors(q_pe),
+                 weak_ref_tensors(k_nope), weak_ref_tensors(k_pe),
+                 decode_meta.block_table, seq_len, num_heads, self.scale,
+                 self.num_kv_heads, weak_ref_tensors(attn_output),
+                 weak_ref_tensors(softmax_lse)))
+            torch.npu.graph_task_group_begin(stream)
+            torch_npu.atb.npu_multi_head_latent_attention(
+                q_nope,
+                q_pe,
+                k_nope,
+                k_pe,
+                decode_meta.block_table,
+                seq_len,
+                num_heads,
+                self.scale,
+                self.num_kv_heads,
+                **common_kwargs,
+                workspace=workspace,
+                output=attn_output,
+                lse=softmax_lse)
+            handle = torch.npu.graph_task_group_end(stream)
+            graph_params.handles[num_tokens].append(handle)
         else:
-            attn_output, softmax_lse = torch_npu.atb.npu_multi_head_latent_attention(
+            attn_output = torch.empty_like(q_nope)
+            softmax_lse = torch.empty((num_tokens, num_heads, 1),
+                                      dtype=q_nope.dtype,
+                                      device=q_nope.device)
+            torch_npu.atb.npu_multi_head_latent_attention(
                 q_nope,
                 q_pe,
                 k_nope,
@@ -1630,7 +1727,9 @@ class AscendMLAImpl(MLAAttentionImpl):
                 self.scale,
                 self.num_kv_heads,
                 return_lse=True,
-                calc_type="calc_type_ring")
+                calc_type="calc_type_ring",
+                output=attn_output,
+                lse=softmax_lse)

         if self.dcp_size > 1:
             # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
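Two details of the decode path above are worth spelling out. First, `cp_seq_len` is clamped so no rank passes a zero-length sequence to the kernel; the per-rank masks then neutralize those dummy contributions. Second, each rank only attends over its local KV slice and returns `lse` alongside the output so partial results can be merged across ranks. The merge below is the standard log-sum-exp combine, shown purely for illustration; it is not claimed to be the exact kernel-side implementation:

```python
import torch


def lse_combine(outs: list[torch.Tensor], lses: list[torch.Tensor]) -> torch.Tensor:
    """Merge per-rank partial attention outputs using their log-sum-exp values.
    outs[i]: [bs, heads, v_dim], lses[i]: [bs, heads, 1]. A rank that holds no
    KV for a request contributes lse = -inf, so its weight is exactly zero."""
    lse = torch.stack(lses, dim=0)                 # [ranks, bs, heads, 1]
    out = torch.stack(outs, dim=0)                 # [ranks, bs, heads, v_dim]
    lse_max = lse.max(dim=0, keepdim=True).values
    weights = torch.exp(lse - lse_max)             # [ranks, bs, heads, 1]
    return (out * weights).sum(0) / weights.sum(0)


bs, heads, v_dim = 2, 4, 8
outs = [torch.randn(bs, heads, v_dim) for _ in range(2)]
lses = [torch.randn(bs, heads, 1), torch.full((bs, heads, 1), float("-inf"))]
merged = lse_combine(outs, lses)        # equals outs[0]: rank 1 is masked out
print(torch.allclose(merged, outs[0]))  # True
```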
@@ -7,6 +7,7 @@ from dataclasses import dataclass
 from typing import Any, Callable, Optional
 from unittest.mock import patch

+import numpy as np
 import torch
 import torch_npu
 import vllm.envs as envs
@@ -300,6 +301,105 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
         event.record(update_stream)


+def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
+    graph_params = get_graph_params()
+    # FIXME: Behold! We are using a temporary hack here to update the args
+    # for each layer's attention op in the graph.
+    for key, param, handle, event in zip(
+            forward_context.attn_metadata,
+            graph_params.attn_params[runtime_shape],
+            graph_params.handles[runtime_shape],
+            graph_params.events[runtime_shape],
+    ):
+        (q_nope, k_nope, value, num_heads, num_kv_heads, scale, block_table,
+         block_size, actual_seq_lengths_kv, attn_output, softmax_lse, cp_rank,
+         dcp_rank, dcp_size) = param
+        actual_seq_lengths_kv = forward_context.attn_metadata[
+            key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
+                                                            dcp_rank]
+        pad_length = runtime_shape - len(actual_seq_lengths_kv)
+        pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
+        actual_seq_lengths_kv = np.concatenate(
+            [actual_seq_lengths_kv, pad_tensor])
+        if dcp_size > 1:
+            num_heads = num_heads * dcp_size
+
+        with torch.npu.stream(update_stream):
+            torch.npu.graph_task_update_begin(update_stream, handle)
+
+            torch_npu.npu_fused_infer_attention_score.out(
+                q_nope,
+                k_nope,
+                value,
+                num_heads=num_heads,
+                num_key_value_heads=num_kv_heads,
+                input_layout="BSND",
+                atten_mask=None,
+                scale=scale,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                softmax_lse_flag=True,
+                block_table=block_table,
+                block_size=block_size,
+                actual_seq_lengths_kv=actual_seq_lengths_kv,
+                workspace=graph_params.workspaces.get(runtime_shape),
+                out=[attn_output, softmax_lse])
+            torch.npu.graph_task_update_end(update_stream)
+
+        event.record(update_stream)
+
+
+def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
+                                   runtime_shape, speculative_config):
+    graph_params = get_graph_params()
+    # FIXME: Behold! We are using a temporary hack here to update the args
+    # for each layer's attention op in the graph.
+    for key, param, handle, event in zip(
+            forward_context.attn_metadata,
+            graph_params.attn_params[runtime_shape],
+            graph_params.handles[runtime_shape],
+            graph_params.events[runtime_shape],
+    ):
+        (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads, scale,
+         num_kv_heads, attn_output, softmax_lse) = param
+
+        decode_meta = forward_context.attn_metadata[key].decode
+        seq_len = decode_meta.cp_seq_len
+
+        if speculative_config and speculative_config.method == "deepseek_mtp":
+            spec_multiple = speculative_config.num_speculative_tokens + 1
+            seq_len = seq_len + [0] * (runtime_shape // spec_multiple -
+                                       len(seq_len))
+        else:
+            pad_length = runtime_shape - len(seq_len)
+            pad_tensor = torch.zeros(pad_length,
+                                     dtype=seq_len.dtype,
+                                     device=seq_len.device)
+            seq_len = torch.cat([seq_len, pad_tensor], dim=0)
+
+        with torch.npu.stream(update_stream):
+            torch.npu.graph_task_update_begin(update_stream, handle)
+
+            torch_npu.atb.npu_multi_head_latent_attention(
+                q_nope,
+                q_pe,
+                k_nope,
+                k_pe,
+                block_table,
+                seq_len,
+                num_heads,
+                scale,
+                num_kv_heads,
+                return_lse=True,
+                calc_type="calc_type_ring",
+                workspace=graph_params.workspaces.get(runtime_shape),
+                output=attn_output,
+                lse=softmax_lse)
+            torch.npu.graph_task_update_end(update_stream)
+
+        event.record(update_stream)
+
+
 @dataclass
 class GraphParams:
     events: dict[int, list[torch.npu.ExternalEvent]]
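The two update functions added above share one detail worth calling out: a graph is captured for a fixed `runtime_shape`, so at replay the per-request sequence lengths must be padded up to that shape (with an extra stride when MTP speculative decoding gives each request several token slots). A runnable, tensor-only toy version of just the padding step (simplified from the diff; the real MTP branch operates on a Python list and assumes the `deepseek_mtp` convention shown above):

```python
from typing import Optional

import torch


def pad_seq_len_for_replay(seq_len: torch.Tensor, runtime_shape: int,
                           num_speculative_tokens: Optional[int] = None) -> torch.Tensor:
    """Pad per-request KV lengths up to the captured graph's token count."""
    if num_speculative_tokens is not None:
        # With MTP, each request occupies (k + 1) token slots in the graph.
        spec_multiple = num_speculative_tokens + 1
        target = runtime_shape // spec_multiple
    else:
        target = runtime_shape
    pad = torch.zeros(target - seq_len.shape[0], dtype=seq_len.dtype)
    return torch.cat([seq_len, pad], dim=0)


print(pad_seq_len_for_replay(torch.tensor([7, 3], dtype=torch.int32), 8))
# tensor([7, 3, 0, 0, 0, 0, 0, 0], dtype=torch.int32)
print(pad_seq_len_for_replay(torch.tensor([7, 3], dtype=torch.int32), 8,
                             num_speculative_tokens=1))
# tensor([7, 3, 0, 0], dtype=torch.int32)
```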
@@ -110,10 +110,14 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          AscendPrefillContextParallelMetadata)
+# yapf: disable
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                set_graph_params,
+                                               update_attn_dcp_pcp_params,
                                                update_attn_params,
+                                               update_mla_attn_dcp_pcp_params,
                                                update_mla_attn_params)
+# yapf: enable
 from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
 from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
     D2DExpertWeightLoader
@@ -1649,6 +1653,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         slot_mapping = blk_table.slot_mapping[:slot_mapping_size]
         blk_table.slot_mapping[slot_mapping_size:].fill_(0)
         if self.pcp_size > 1:
+            slot_mapping_for_pcp = blk_table.slot_mapping[
+                :long_seq_metadata.num_actual_tokens_pcp_padded]
+            slot_mapping_for_pcp[slot_mapping_size:].fill_(-1)
             assert pcp_unpad_mask is not None
             pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:
                                                                    pcp_unpad_mask
@@ -1657,10 +1666,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                                                    0]]
             pcp_padded_slot_mapping.fill_(-1)
             pcp_padded_slot_mapping[
-                pcp_unpad_mask] = blk_table.slot_mapping[:slot_mapping_size]
-            blk_table.slot_mapping[:long_seq_metadata.
-                                   num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping
+                pcp_unpad_mask] = slot_mapping_for_pcp[:slot_mapping_size]
+            slot_mapping_for_pcp[:long_seq_metadata.
+                                 num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping
+            slot_mapping = slot_mapping_for_pcp

         # Make AscendCommonAttentionMetadata
         common_attn_metadata = AscendCommonAttentionMetadata(
@@ -1749,13 +1759,25 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
             # TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead
             if self.vllm_config.model_config.use_mla:
-                # FIXME: Try using `auto_dispatch_capture=True`
-                update_mla_attn_params(self.update_stream, forward_context,
-                                       maybe_padded_num_tokens,
-                                       self.speculative_config)
+                if self.pcp_size * self.dcp_size > 1:
+                    # FIXME: Try using `auto_dispatch_capture=True`
+                    update_mla_attn_dcp_pcp_params(self.update_stream,
+                                                   forward_context,
+                                                   maybe_padded_num_tokens,
+                                                   self.speculative_config)
+                else:
+                    # FIXME: Try using `auto_dispatch_capture=True`
+                    update_mla_attn_params(self.update_stream, forward_context,
+                                           maybe_padded_num_tokens,
+                                           self.speculative_config)
             else:
-                update_attn_params(self.update_stream, forward_context,
-                                   maybe_padded_num_tokens)
+                if self.pcp_size * self.dcp_size > 1:
+                    update_attn_dcp_pcp_params(self.update_stream,
+                                               forward_context,
+                                               maybe_padded_num_tokens)
+                else:
+                    update_attn_params(self.update_stream, forward_context,
+                                       maybe_padded_num_tokens)

         if get_forward_context().sp_enabled:
             hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
@@ -2488,6 +2510,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             kv_cache_group_id].get_device_tensor()
         slot_mapping = self.input_batch.block_table[
             kv_cache_group_id].slot_mapping
+        self.cp_kv_recover_idx = torch.zeros(self.max_num_tokens,
+                                             dtype=torch.int32,
+                                             device=self.device)
+        long_seq_metadata = self._generate_pcp_metadata(
+            num_tokens, self.seq_lens_cpu)
+        if long_seq_metadata is not None:
+            pcp_world_size = get_pcp_group(
+            ).world_size if prefill_context_parallel_enable() else 1
+            dcp_world_size = get_dcp_group().world_size
+            num_computed_tokens_of_pcp_dcp = [[
+                [0] * dcp_world_size for _ in range(pcp_world_size)
+            ] for _ in range(num_tokens)]
+            long_seq_metadata.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=torch.tensor(
                 [0] + self.actual_seq_lengths_q[:num_reqs],
@@ -2511,6 +2546,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             decode_token_per_req=self.decode_token_per_req,
             cos=self.cos,
             sin=self.sin,
+            prefill_context_parallel_metadata=long_seq_metadata,
         )
         attn_state = AscendAttentionState.DecodeOnly
         if self.speculative_config and \
@@ -2540,12 +2576,25 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 not forward_context.capturing:
             if self.vllm_config.model_config.use_mla:
                 # FIXME: Try using `auto_dispatch_capture=True`
-                update_mla_attn_params(self.update_stream, forward_context,
-                                       positions.shape[0],
-                                       self.speculative_config)
+                if self.pcp_size * self.dcp_size > 1:
+                    # FIXME: Try using `auto_dispatch_capture=True`
+                    update_mla_attn_dcp_pcp_params(self.update_stream,
+                                                   forward_context,
+                                                   positions.shape[0],
+                                                   self.speculative_config)
+                else:
+                    # FIXME: Try using `auto_dispatch_capture=True`
+                    update_mla_attn_params(self.update_stream, forward_context,
+                                           positions.shape[0],
+                                           self.speculative_config)
             else:
-                update_attn_params(self.update_stream, forward_context,
-                                   positions.shape[0])
+                if self.pcp_size * self.dcp_size > 1:
+                    update_attn_dcp_pcp_params(self.update_stream,
+                                               forward_context,
+                                               positions.shape[0])
+                else:
+                    update_attn_params(self.update_stream, forward_context,
+                                       positions.shape[0])

         if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
             hidden_states, _ = hidden_states
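The model-runner changes boil down to a four-way dispatch when refreshing captured graphs: MLA vs. standard attention, and context-parallel (pcp_size * dcp_size > 1) vs. not. A condensed restatement of that selection logic (the returned names are the real functions from the diff; the surrounding runner state is stubbed out):

```python
def select_graph_update_fn(use_mla: bool, pcp_size: int, dcp_size: int) -> str:
    """Mirror of the dispatch added to NPUModelRunner: pick which update_*
    function rewrites the captured attention args for the current step."""
    context_parallel = pcp_size * dcp_size > 1
    if use_mla:
        return ("update_mla_attn_dcp_pcp_params" if context_parallel
                else "update_mla_attn_params")
    return ("update_attn_dcp_pcp_params" if context_parallel
            else "update_attn_params")


assert select_graph_update_fn(True, 1, 2) == "update_mla_attn_dcp_pcp_params"
assert select_graph_update_fn(False, 1, 1) == "update_attn_params"
```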