[Refactor] Fix AttentionMaskBuilder singleton and remove redundant pcp_prefill_mask (#4870)
## What this PR does / why we need it?

This PR fixes the `AttentionMaskBuilder` singleton initialization issue introduced in PR #4779 and removes the unused `pcp_prefill_mask` field.

### Background

After PR #4779 made `AttentionMaskBuilder` a singleton with `@singleton` decorator, the class constructor now requires a `device` parameter. However, two initialization sites were still using the old parameterless constructor, causing failures.

### Changes

1. **Fix singleton initialization**
   - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendMLAMetadataBuilder.__init__()`
   - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendAttentionMetadataBuilder.__init__()`
2. **Remove unused field**
   - Removed `pcp_prefill_mask` field from `AscendPrefillContextParallelMetadata` (never used in codebase)
   - Updated related test assertions

### Related

- Issue #5463
- PR #4779 (Unify all mask generation methods)
- PR #5389 (Make AttentionMaskBuilder singleton)

## Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring.

## How was this patch tested?

- ✅ Local testing: No linter errors
- ✅ Unit tests for attention modules verified
- ⏳ CI pipeline

Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
@@ -53,6 +53,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
|
|||||||
self.mock_vllm_config = MagicMock()
|
self.mock_vllm_config = MagicMock()
|
||||||
self.mock_vllm_config.speculative_config = None
|
self.mock_vllm_config.speculative_config = None
|
||||||
self.mock_vllm_config.model_config.max_model_len = 640
|
self.mock_vllm_config.model_config.max_model_len = 640
|
||||||
|
self.mock_vllm_config.model_config.hf_text_config.sliding_window = None
|
||||||
self.mock_vllm_config.cache_config.block_size = 64
|
self.mock_vllm_config.cache_config.block_size = 64
|
||||||
self.mock_vllm_config.compilation_config.cudagraph_mode = None
|
self.mock_vllm_config.compilation_config.cudagraph_mode = None
|
||||||
self.mock_vllm_config.scheduler_config.max_num_seqs = 10
|
self.mock_vllm_config.scheduler_config.max_num_seqs = 10
|
||||||
@@ -89,8 +90,6 @@ class TestAscendAttentionMetadataBuilder(TestBase):
|
|||||||
slot_mapping=torch.tensor(range(20)),
|
slot_mapping=torch.tensor(range(20)),
|
||||||
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
||||||
positions=torch.tensor([10, 10]),
|
positions=torch.tensor([10, 10]),
|
||||||
attn_mask=torch.ones((15, 15)),
|
|
||||||
spec_attn_mask=None,
|
|
||||||
attn_state=AscendAttentionState.ChunkedPrefill,
|
attn_state=AscendAttentionState.ChunkedPrefill,
|
||||||
num_computed_tokens_cpu=None,
|
num_computed_tokens_cpu=None,
|
||||||
seq_lens=None,
|
seq_lens=None,
|
||||||
|
|||||||
@@ -1004,8 +1004,6 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
[chunk_seqlens, chunk_seqlens], dtype=torch.int32)
|
[chunk_seqlens, chunk_seqlens], dtype=torch.int32)
|
||||||
attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens = kv_with_q_head_nomask_seqlens
|
attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens = kv_with_q_head_nomask_seqlens
|
||||||
attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens = kv_with_q_tail_nomask_seqlens
|
attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens = kv_with_q_tail_nomask_seqlens
|
||||||
attn_metadata.prefill.pcp_metadata.pcp_prefill_mask = torch.triu(
|
|
||||||
torch.ones(10, 10, dtype=torch.float16), 1)
|
|
||||||
|
|
||||||
output = self.impl._forward_prefill(q_nope, q_pe, k_nope, k_pe,
|
output = self.impl._forward_prefill(q_nope, q_pe, k_nope, k_pe,
|
||||||
value, kv_c_and_k_pe_cache,
|
value, kv_c_and_k_pe_cache,
|
||||||
|
|||||||
@@ -244,8 +244,15 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
|||||||
mock_vllm_config.scheduler_config.enable_chunked_prefill)
|
mock_vllm_config.scheduler_config.enable_chunked_prefill)
|
||||||
|
|
||||||
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
||||||
|
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
|
||||||
|
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||||
def test_ascend_mla_metadata_builder_build_full_graph(
|
def test_ascend_mla_metadata_builder_build_full_graph(
|
||||||
self, mock_get_cos_and_sin_mla):
|
self, mock_get_pcp_group, mock_get_pcp_group_mask,
|
||||||
|
mock_get_cos_and_sin_mla):
|
||||||
|
pcp_group = MagicMock()
|
||||||
|
pcp_group.world_size = 1
|
||||||
|
mock_get_pcp_group.return_value = pcp_group
|
||||||
|
mock_get_pcp_group_mask.return_value = pcp_group
|
||||||
mock_vllm_config = MagicMock()
|
mock_vllm_config = MagicMock()
|
||||||
mock_vllm_config.model_config.max_model_len = 1024
|
mock_vllm_config.model_config.max_model_len = 1024
|
||||||
mock_vllm_config.model_config.get_head_size.return_value = 64
|
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||||
@@ -400,14 +407,21 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
|||||||
self.kv_cache_spec.num_heads = 32
|
self.kv_cache_spec.num_heads = 32
|
||||||
|
|
||||||
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
||||||
|
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
|
||||||
|
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||||
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
|
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
|
||||||
@patch("torch.Tensor.npu", new=lambda self: self)
|
@patch("torch.Tensor.npu", new=lambda self: self)
|
||||||
@patch("torch.npu.is_available")
|
@patch("torch.npu.is_available")
|
||||||
def test_build_prefix_no_cache_metadata(self, mock_npu_available,
|
def test_build_prefix_no_cache_metadata(self, mock_npu_available,
|
||||||
mock_zeros,
|
mock_zeros, mock_get_pcp_group,
|
||||||
|
mock_get_pcp_group_mask,
|
||||||
mock_get_cos_and_sin_mla):
|
mock_get_cos_and_sin_mla):
|
||||||
mock_npu_available.return_value = False
|
mock_npu_available.return_value = False
|
||||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||||
|
pcp_group = MagicMock()
|
||||||
|
pcp_group.world_size = 1
|
||||||
|
mock_get_pcp_group.return_value = pcp_group
|
||||||
|
mock_get_pcp_group_mask.return_value = pcp_group
|
||||||
|
|
||||||
def zeros_override(*args, **kwargs):
|
def zeros_override(*args, **kwargs):
|
||||||
kwargs.pop('pin_memory', None)
|
kwargs.pop('pin_memory', None)
|
||||||
@@ -426,8 +440,6 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
|||||||
slot_mapping=torch.tensor(range(20)),
|
slot_mapping=torch.tensor(range(20)),
|
||||||
actual_seq_lengths_q=torch.tensor([0, 1]),
|
actual_seq_lengths_q=torch.tensor([0, 1]),
|
||||||
positions=torch.tensor([10, 10]),
|
positions=torch.tensor([10, 10]),
|
||||||
attn_mask=torch.ones((10, 10)),
|
|
||||||
spec_attn_mask=None,
|
|
||||||
attn_state=AscendAttentionState.PrefillNoCache,
|
attn_state=AscendAttentionState.PrefillNoCache,
|
||||||
num_computed_tokens_cpu=None,
|
num_computed_tokens_cpu=None,
|
||||||
seq_lens=None,
|
seq_lens=None,
|
||||||
@@ -458,14 +470,21 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
|||||||
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
||||||
|
|
||||||
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
||||||
|
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
|
||||||
|
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||||
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
|
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
|
||||||
@patch("torch.Tensor.npu", new=lambda self: self)
|
@patch("torch.Tensor.npu", new=lambda self: self)
|
||||||
@patch("torch.npu.is_available")
|
@patch("torch.npu.is_available")
|
||||||
def test_build_chunked_prefix_metadata(self, mock_npu_available,
|
def test_build_chunked_prefix_metadata(self, mock_npu_available,
|
||||||
mock_zeros,
|
mock_zeros, mock_get_pcp_group,
|
||||||
|
mock_get_pcp_group_mask,
|
||||||
mock_get_cos_and_sin_mla):
|
mock_get_cos_and_sin_mla):
|
||||||
mock_npu_available.return_value = False
|
mock_npu_available.return_value = False
|
||||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||||
|
pcp_group = MagicMock()
|
||||||
|
pcp_group.world_size = 1
|
||||||
|
mock_get_pcp_group.return_value = pcp_group
|
||||||
|
mock_get_pcp_group_mask.return_value = pcp_group
|
||||||
|
|
||||||
def zeros_override(*args, **kwargs):
|
def zeros_override(*args, **kwargs):
|
||||||
kwargs.pop('pin_memory', None)
|
kwargs.pop('pin_memory', None)
|
||||||
@@ -485,8 +504,6 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
|||||||
slot_mapping=torch.tensor(range(20)),
|
slot_mapping=torch.tensor(range(20)),
|
||||||
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
||||||
positions=torch.tensor([10, 10]),
|
positions=torch.tensor([10, 10]),
|
||||||
attn_mask=torch.ones((15, 15)),
|
|
||||||
spec_attn_mask=None,
|
|
||||||
attn_state=AscendAttentionState.ChunkedPrefill,
|
attn_state=AscendAttentionState.ChunkedPrefill,
|
||||||
num_computed_tokens_cpu=None,
|
num_computed_tokens_cpu=None,
|
||||||
seq_lens=None,
|
seq_lens=None,
|
||||||
@@ -517,8 +534,16 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
|||||||
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
||||||
|
|
||||||
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
||||||
def test_build_decode_only_metadata(self, mock_get_cos_and_sin_mla):
|
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
|
||||||
|
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||||
|
def test_build_decode_only_metadata(self, mock_get_pcp_group,
|
||||||
|
mock_get_pcp_group_mask,
|
||||||
|
mock_get_cos_and_sin_mla):
|
||||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||||
|
pcp_group = MagicMock()
|
||||||
|
pcp_group.world_size = 1
|
||||||
|
mock_get_pcp_group.return_value = pcp_group
|
||||||
|
mock_get_pcp_group_mask.return_value = pcp_group
|
||||||
|
|
||||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||||
query_start_loc=torch.tensor([0, 1, 2, 3]),
|
query_start_loc=torch.tensor([0, 1, 2, 3]),
|
||||||
@@ -532,8 +557,6 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
|||||||
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
||||||
decode_token_per_req=torch.tensor([1, 1, 1]),
|
decode_token_per_req=torch.tensor([1, 1, 1]),
|
||||||
positions=torch.tensor([10, 10]),
|
positions=torch.tensor([10, 10]),
|
||||||
attn_mask=torch.ones((3, 3)),
|
|
||||||
spec_attn_mask=None,
|
|
||||||
attn_state=AscendAttentionState.DecodeOnly,
|
attn_state=AscendAttentionState.DecodeOnly,
|
||||||
num_computed_tokens_cpu=None,
|
num_computed_tokens_cpu=None,
|
||||||
seq_lens=None,
|
seq_lens=None,
|
||||||
@@ -563,9 +586,16 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
|||||||
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
||||||
|
|
||||||
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
||||||
def test_build_for_graph_capture_decode_only(self,
|
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
|
||||||
|
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||||
|
def test_build_for_graph_capture_decode_only(self, mock_get_pcp_group,
|
||||||
|
mock_get_pcp_group_mask,
|
||||||
mock_get_cos_and_sin_mla):
|
mock_get_cos_and_sin_mla):
|
||||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||||
|
pcp_group = MagicMock()
|
||||||
|
pcp_group.world_size = 1
|
||||||
|
mock_get_pcp_group.return_value = pcp_group
|
||||||
|
mock_get_pcp_group_mask.return_value = pcp_group
|
||||||
|
|
||||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||||
query_start_loc=torch.tensor([0, 1, 2, 3]),
|
query_start_loc=torch.tensor([0, 1, 2, 3]),
|
||||||
@@ -579,8 +609,6 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
|||||||
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
||||||
decode_token_per_req=torch.tensor([1, 1, 1]),
|
decode_token_per_req=torch.tensor([1, 1, 1]),
|
||||||
positions=torch.tensor([10, 10]),
|
positions=torch.tensor([10, 10]),
|
||||||
attn_mask=torch.ones((3, 3)),
|
|
||||||
spec_attn_mask=None,
|
|
||||||
attn_state=AscendAttentionState.DecodeOnly,
|
attn_state=AscendAttentionState.DecodeOnly,
|
||||||
num_computed_tokens_cpu=None,
|
num_computed_tokens_cpu=None,
|
||||||
seq_lens=None,
|
seq_lens=None,
|
||||||
@@ -625,8 +653,6 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
|||||||
slot_mapping=torch.tensor(range(20)),
|
slot_mapping=torch.tensor(range(20)),
|
||||||
actual_seq_lengths_q=torch.tensor([0, 1]),
|
actual_seq_lengths_q=torch.tensor([0, 1]),
|
||||||
positions=torch.tensor([10, 10]),
|
positions=torch.tensor([10, 10]),
|
||||||
attn_mask=torch.ones((10, 10)),
|
|
||||||
spec_attn_mask=None,
|
|
||||||
attn_state=AscendAttentionState.PrefillNoCache,
|
attn_state=AscendAttentionState.PrefillNoCache,
|
||||||
num_computed_tokens_cpu=None,
|
num_computed_tokens_cpu=None,
|
||||||
seq_lens=None,
|
seq_lens=None,
|
||||||
|
|||||||
@@ -291,8 +291,6 @@ class TestMtpProposer:
|
|||||||
|
|
||||||
mock_runner = MagicMock()
|
mock_runner = MagicMock()
|
||||||
mock_runner.actual_seq_lengths_q = MagicMock()
|
mock_runner.actual_seq_lengths_q = MagicMock()
|
||||||
mock_runner.attn_mask = MagicMock()
|
|
||||||
mock_runner.spec_attn_mask = MagicMock()
|
|
||||||
mock_runner.attn_state = MagicMock()
|
mock_runner.attn_state = MagicMock()
|
||||||
mock_runner.graph_pad_size = 0
|
mock_runner.graph_pad_size = 0
|
||||||
mock_runner.decode_token_per_req = MagicMock()
|
mock_runner.decode_token_per_req = MagicMock()
|
||||||
@@ -334,5 +332,3 @@ class TestMtpProposer:
|
|||||||
assert spec_common_attn_metadata.num_actual_tokens == total_num_tokens
|
assert spec_common_attn_metadata.num_actual_tokens == total_num_tokens
|
||||||
assert spec_common_attn_metadata.max_query_len == 8
|
assert spec_common_attn_metadata.max_query_len == 8
|
||||||
assert spec_common_attn_metadata.actual_seq_lengths_q == proposer.runner.actual_seq_lengths_q
|
assert spec_common_attn_metadata.actual_seq_lengths_q == proposer.runner.actual_seq_lengths_q
|
||||||
assert spec_common_attn_metadata.attn_mask == proposer.runner.attn_mask
|
|
||||||
assert spec_common_attn_metadata.spec_attn_mask == proposer.runner.spec_attn_mask
|
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
|
|||||||
input_batch.num_tokens = torch.tensor(num_tokens)
|
input_batch.num_tokens = torch.tensor(num_tokens)
|
||||||
|
|
||||||
query_lens = torch.tensor(query_lens)
|
query_lens = torch.tensor(query_lens)
|
||||||
result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens, None,
|
result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens,
|
||||||
input_batch)
|
input_batch)
|
||||||
|
|
||||||
if not expect_not_none:
|
if not expect_not_none:
|
||||||
@@ -97,21 +97,6 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
|
|||||||
assert hasattr(result, 'head_attn_nomask_seqlens')
|
assert hasattr(result, 'head_attn_nomask_seqlens')
|
||||||
assert hasattr(result, 'tail_attn_nomask_seqlens')
|
assert hasattr(result, 'tail_attn_nomask_seqlens')
|
||||||
|
|
||||||
if hasattr(result, 'pcp_prefill_mask'
|
|
||||||
) and result.pcp_prefill_mask is not None:
|
|
||||||
if use_mla:
|
|
||||||
assert result.pcp_prefill_mask.shape == (512, 512)
|
|
||||||
else:
|
|
||||||
assert result.pcp_prefill_mask.shape == (2048, 2048)
|
|
||||||
else:
|
|
||||||
if hasattr(result, 'pcp_prefill_mask'):
|
|
||||||
if result.pcp_prefill_mask is not None:
|
|
||||||
if use_mla:
|
|
||||||
assert result.pcp_prefill_mask.shape == (512, 512)
|
|
||||||
else:
|
|
||||||
assert result.pcp_prefill_mask.shape == (2048,
|
|
||||||
2048)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"tokens, num_reqs, num_computed_tokens, num_prompt_tokens, pcp_size, pcp_rank, expected_pcp_tokens",
|
"tokens, num_reqs, num_computed_tokens, num_prompt_tokens, pcp_size, pcp_rank, expected_pcp_tokens",
|
||||||
|
|||||||
@@ -13,6 +13,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import torch
|
import torch
|
||||||
|
from vllm.distributed import get_pcp_group
|
||||||
|
|
||||||
|
from vllm_ascend.platform import ModelConfig
|
||||||
|
from vllm_ascend.utils import singleton
|
||||||
|
|
||||||
|
|
||||||
def _generate_attn_mask(max_seq_len, dtype):
|
def _generate_attn_mask(max_seq_len, dtype):
|
||||||
@@ -29,6 +33,7 @@ def _generate_attn_mask(max_seq_len, dtype):
|
|||||||
return attn_mask
|
return attn_mask
|
||||||
|
|
||||||
|
|
||||||
|
@singleton
|
||||||
class AttentionMaskBuilder:
|
class AttentionMaskBuilder:
|
||||||
|
|
||||||
def __init__(self, device: torch.device):
|
def __init__(self, device: torch.device):
|
||||||
@@ -83,3 +88,15 @@ class AttentionMaskBuilder:
|
|||||||
tril_mask = torch.tril(mask, -sliding_window).to(self.device)
|
tril_mask = torch.tril(mask, -sliding_window).to(self.device)
|
||||||
self.swa_mask = triu_mask + tril_mask
|
self.swa_mask = triu_mask + tril_mask
|
||||||
return self.swa_mask
|
return self.swa_mask
|
||||||
|
|
||||||
|
def get_attention_mask(self, model_config: ModelConfig):
|
||||||
|
if model_config.runner_type == "pooling":
|
||||||
|
return self.get_attn_mask(2048, torch.bool)
|
||||||
|
|
||||||
|
return self.get_splitfuse_attn_mask()
|
||||||
|
|
||||||
|
def get_final_mla_mask(self, model_config: ModelConfig):
|
||||||
|
if get_pcp_group().world_size > 1:
|
||||||
|
return self.get_pcp_mla_mask(model_config.dtype)
|
||||||
|
# Prefill stages use 512x512 mask with appropriate dtype
|
||||||
|
return self.get_mla_mask(model_config.dtype)
|
||||||
@@ -34,6 +34,7 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
|||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
|
||||||
|
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||||
from vllm_ascend.attention.context_parallel.common_cp import (
|
from vllm_ascend.attention.context_parallel.common_cp import (
|
||||||
AscendMetadataForDecode, AscendMetadataForPrefill)
|
AscendMetadataForDecode, AscendMetadataForPrefill)
|
||||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||||
@@ -219,6 +220,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
|
|||||||
|
|
||||||
scheduler_config = vllm_config.scheduler_config
|
scheduler_config = vllm_config.scheduler_config
|
||||||
self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill
|
self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill
|
||||||
|
self.attn_mask_builder = AttentionMaskBuilder(self.device)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_cudagraph_support(
|
def get_cudagraph_support(
|
||||||
@@ -253,10 +255,19 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
|
|||||||
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
||||||
|
|
||||||
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
||||||
attn_mask = common_attn_metadata.attn_mask
|
|
||||||
swa_mask = common_attn_metadata.swa_mask
|
|
||||||
attn_state = common_attn_metadata.attn_state
|
attn_state = common_attn_metadata.attn_state
|
||||||
|
|
||||||
|
# Get attn_mask and swa_mask from singleton AttentionMaskBuilder
|
||||||
|
attn_mask = self.attn_mask_builder.get_attention_mask(
|
||||||
|
self.model_config)
|
||||||
|
|
||||||
|
swa_mask = None
|
||||||
|
is_swa = hasattr(self.model_config.hf_text_config, 'sliding_window')
|
||||||
|
if self.model_config is not None and is_swa:
|
||||||
|
swa_mask = self.attn_mask_builder.get_swa_mask(
|
||||||
|
self.model_config.dtype,
|
||||||
|
self.model_config.hf_text_config.sliding_window)
|
||||||
|
|
||||||
# TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
|
# TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
|
||||||
query_start_loc = query_start_loc_cpu.pin_memory().to(
|
query_start_loc = query_start_loc_cpu.pin_memory().to(
|
||||||
self.device, non_blocking=True)
|
self.device, non_blocking=True)
|
||||||
|
|||||||
@@ -121,7 +121,8 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
|
|||||||
|
|
||||||
slot_mapping = common_attn_metadata.slot_mapping[:
|
slot_mapping = common_attn_metadata.slot_mapping[:
|
||||||
num_actual_tokens_pcp_padded]
|
num_actual_tokens_pcp_padded]
|
||||||
attn_mask = common_attn_metadata.attn_mask
|
attn_mask = self.attn_mask_builder.get_attention_mask(
|
||||||
|
self.model_config)
|
||||||
attn_state = common_attn_metadata.attn_state
|
attn_state = common_attn_metadata.attn_state
|
||||||
num_computed_tokens_cpu = (seq_lens - query_lens)
|
num_computed_tokens_cpu = (seq_lens - query_lens)
|
||||||
|
|
||||||
@@ -212,7 +213,6 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
|
|||||||
head_attn_nomask_seqlens=head_attn_nomask_seqlens,
|
head_attn_nomask_seqlens=head_attn_nomask_seqlens,
|
||||||
tail_attn_nomask_seqlens=tail_attn_nomask_seqlens,
|
tail_attn_nomask_seqlens=tail_attn_nomask_seqlens,
|
||||||
q_full_idx=common_long_seq_metadata.q_full_idx,
|
q_full_idx=common_long_seq_metadata.q_full_idx,
|
||||||
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask,
|
|
||||||
pcp_allgather_restore_idx=common_long_seq_metadata.
|
pcp_allgather_restore_idx=common_long_seq_metadata.
|
||||||
pcp_allgather_restore_idx)
|
pcp_allgather_restore_idx)
|
||||||
|
|
||||||
@@ -433,13 +433,12 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
|||||||
attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
|
attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
|
||||||
nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens \
|
nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens \
|
||||||
if is_head else attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
|
if is_head else attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
|
||||||
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
|
|
||||||
output, lse = self._attention_with_nomask_and_mask(
|
output, lse = self._attention_with_nomask_and_mask(
|
||||||
**data,
|
**data,
|
||||||
q_seqlens=attn_mask_seqlens,
|
q_seqlens=attn_mask_seqlens,
|
||||||
kv_seqlens_nomask=nomask_seqlens,
|
kv_seqlens_nomask=nomask_seqlens,
|
||||||
kv_seqlens_mask=attn_mask_seqlens,
|
kv_seqlens_mask=attn_mask_seqlens,
|
||||||
mask=mask,
|
mask=attn_metadata.attn_mask,
|
||||||
attn_metadata=attn_metadata)
|
attn_metadata=attn_metadata)
|
||||||
return output, lse
|
return output, lse
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ class AscendPCPMetadata:
|
|||||||
head_attn_nomask_seqlens: torch.Tensor = None
|
head_attn_nomask_seqlens: torch.Tensor = None
|
||||||
tail_attn_nomask_seqlens: torch.Tensor = None
|
tail_attn_nomask_seqlens: torch.Tensor = None
|
||||||
q_full_idx: torch.Tensor = None
|
q_full_idx: torch.Tensor = None
|
||||||
pcp_prefill_mask: torch.Tensor = None
|
|
||||||
pcp_allgather_restore_idx: Optional[list[int]] = None
|
pcp_allgather_restore_idx: Optional[list[int]] = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -118,7 +118,6 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
|
|||||||
tail_attn_nomask_seqlens=common_long_seq_metadata.
|
tail_attn_nomask_seqlens=common_long_seq_metadata.
|
||||||
tail_attn_nomask_seqlens,
|
tail_attn_nomask_seqlens,
|
||||||
q_full_idx=common_long_seq_metadata.q_full_idx,
|
q_full_idx=common_long_seq_metadata.q_full_idx,
|
||||||
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask,
|
|
||||||
pcp_allgather_restore_idx=common_long_seq_metadata.
|
pcp_allgather_restore_idx=common_long_seq_metadata.
|
||||||
pcp_allgather_restore_idx)
|
pcp_allgather_restore_idx)
|
||||||
|
|
||||||
@@ -195,7 +194,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
|
|||||||
).item()
|
).item()
|
||||||
if build_metadata_step == BUILD_METADATA_STEP_PREFILL:
|
if build_metadata_step == BUILD_METADATA_STEP_PREFILL:
|
||||||
# For pcp + spec decode, we flatten seq_lens and block_table
|
# For pcp + spec decode, we flatten seq_lens and block_table
|
||||||
# to avoid irregular spec_attn_mask shape
|
# to avoid irregular attn_mask shape
|
||||||
return self.num_decodes_flatten + self.num_prefills
|
return self.num_decodes_flatten + self.num_prefills
|
||||||
else:
|
else:
|
||||||
return self.num_decodes_flatten
|
return self.num_decodes_flatten
|
||||||
@@ -420,7 +419,6 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
|||||||
attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
|
attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
|
||||||
head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens
|
head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens
|
||||||
tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
|
tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
|
||||||
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
|
|
||||||
output_head, lse_head = self._attention_with_mask_and_nomask(
|
output_head, lse_head = self._attention_with_mask_and_nomask(
|
||||||
q_nope=torch.index_select(q_nope, 0, q_head_idx),
|
q_nope=torch.index_select(q_nope, 0, q_head_idx),
|
||||||
q_pe=torch.index_select(q_pe, 0, q_head_idx),
|
q_pe=torch.index_select(q_pe, 0, q_head_idx),
|
||||||
@@ -431,7 +429,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
|||||||
kv_nomask_idx=kv_with_q_head_nomask_idx,
|
kv_nomask_idx=kv_with_q_head_nomask_idx,
|
||||||
attn_mask_seqlens=attn_mask_seqlens,
|
attn_mask_seqlens=attn_mask_seqlens,
|
||||||
attn_nomask_seqlens=head_attn_nomask_seqlens,
|
attn_nomask_seqlens=head_attn_nomask_seqlens,
|
||||||
mask=mask)
|
mask=attn_metadata.attn_mask)
|
||||||
|
|
||||||
output_tail, lse_tail = self._attention_with_mask_and_nomask(
|
output_tail, lse_tail = self._attention_with_mask_and_nomask(
|
||||||
q_nope=torch.index_select(q_nope, 0, q_tail_idx),
|
q_nope=torch.index_select(q_nope, 0, q_tail_idx),
|
||||||
@@ -443,7 +441,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
|||||||
kv_nomask_idx=kv_with_q_tail_nomask_idx,
|
kv_nomask_idx=kv_with_q_tail_nomask_idx,
|
||||||
attn_mask_seqlens=attn_mask_seqlens,
|
attn_mask_seqlens=attn_mask_seqlens,
|
||||||
attn_nomask_seqlens=tail_attn_nomask_seqlens,
|
attn_nomask_seqlens=tail_attn_nomask_seqlens,
|
||||||
mask=mask)
|
mask=attn_metadata.attn_mask)
|
||||||
|
|
||||||
q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
|
q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
|
||||||
attn_output = torch.index_select(
|
attn_output = torch.index_select(
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
|
|||||||
|
|
||||||
from vllm_ascend import envs
|
from vllm_ascend import envs
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
|
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
from vllm_ascend.attention.context_parallel.common_cp import (
|
from vllm_ascend.attention.context_parallel.common_cp import (
|
||||||
AscendPCPMetadata, CPChunkedContextMetadata)
|
AscendPCPMetadata, CPChunkedContextMetadata)
|
||||||
@@ -263,6 +264,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
|||||||
self.graph_pad_size = 0
|
self.graph_pad_size = 0
|
||||||
self.query_lens: torch.Tensor = None
|
self.query_lens: torch.Tensor = None
|
||||||
self.seq_lens: torch.Tensor = None
|
self.seq_lens: torch.Tensor = None
|
||||||
|
self.attn_mask_builder = AttentionMaskBuilder(self.device)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_cudagraph_support(
|
def get_cudagraph_support(
|
||||||
@@ -448,7 +450,8 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
|||||||
num_decodes=self.num_decodes,
|
num_decodes=self.num_decodes,
|
||||||
num_decode_tokens=self.num_decode_tokens,
|
num_decode_tokens=self.num_decode_tokens,
|
||||||
num_prefills=self.num_prefills,
|
num_prefills=self.num_prefills,
|
||||||
attn_mask=common_attn_metadata.attn_mask,
|
attn_mask=self.attn_mask_builder.get_final_mla_mask(
|
||||||
|
self.model_config),
|
||||||
attn_state=common_attn_metadata.attn_state,
|
attn_state=common_attn_metadata.attn_state,
|
||||||
prefill=prefill_metadata,
|
prefill=prefill_metadata,
|
||||||
decode=decode_metadata,
|
decode=decode_metadata,
|
||||||
@@ -542,7 +545,8 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
|||||||
prefill_input_positions = input_positions[tokens_start:]
|
prefill_input_positions = input_positions[tokens_start:]
|
||||||
cos, sin = get_cos_and_sin_mla(prefill_input_positions)
|
cos, sin = get_cos_and_sin_mla(prefill_input_positions)
|
||||||
return AscendMLAPrefillMetadata(
|
return AscendMLAPrefillMetadata(
|
||||||
attn_mask=common_attn_metadata.attn_mask,
|
attn_mask=self.attn_mask_builder.get_final_mla_mask(
|
||||||
|
self.model_config),
|
||||||
query_lens=self.query_lens[reqs_start:].to(torch.int32),
|
query_lens=self.query_lens[reqs_start:].to(torch.int32),
|
||||||
seq_lens=self.seq_lens,
|
seq_lens=self.seq_lens,
|
||||||
context_lens=self.seq_lens[reqs_start:],
|
context_lens=self.seq_lens[reqs_start:],
|
||||||
@@ -643,7 +647,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
|||||||
seq_lens=self.seq_lens,
|
seq_lens=self.seq_lens,
|
||||||
seq_lens_list=seq_lens_list,
|
seq_lens_list=seq_lens_list,
|
||||||
max_seq_lens=max_seq_lens,
|
max_seq_lens=max_seq_lens,
|
||||||
attn_mask=common_attn_metadata.spec_attn_mask,
|
attn_mask=self.attn_mask_builder.get_splitfuse_attn_mask(),
|
||||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||||
sin=sin[:self.num_decode_tokens, ...],
|
sin=sin[:self.num_decode_tokens, ...],
|
||||||
cos=cos[:self.num_decode_tokens, ...],
|
cos=cos[:self.num_decode_tokens, ...],
|
||||||
@@ -1197,7 +1201,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
# Output shape: [num_heads, num_tokens, dim]
|
# Output shape: [num_heads, num_tokens, dim]
|
||||||
attn_output_shape = (self.num_heads, num_tokens, self.kv_lora_rank)
|
attn_output_shape = (self.num_heads, num_tokens, self.kv_lora_rank)
|
||||||
sparse_mode = 3
|
sparse_mode = 3
|
||||||
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
|
attn_mask = attn_metadata.decode.attn_mask # type:ignore
|
||||||
actual_seq_lengths = decode_meta.actual_seq_lengths_q
|
actual_seq_lengths = decode_meta.actual_seq_lengths_q
|
||||||
else:
|
else:
|
||||||
# The output layout is set to NBSD to eliminate the need for a
|
# The output layout is set to NBSD to eliminate the need for a
|
||||||
@@ -1218,7 +1222,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
attn_output_shape = (self.num_heads, num_tokens, 1,
|
attn_output_shape = (self.num_heads, num_tokens, 1,
|
||||||
self.kv_lora_rank)
|
self.kv_lora_rank)
|
||||||
sparse_mode = 0
|
sparse_mode = 0
|
||||||
spec_attn_mask = None
|
attn_mask = None
|
||||||
|
|
||||||
common_kwargs = {
|
common_kwargs = {
|
||||||
'query_rope': q_pe,
|
'query_rope': q_pe,
|
||||||
@@ -1226,7 +1230,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
'num_heads': self.num_heads,
|
'num_heads': self.num_heads,
|
||||||
'num_key_value_heads': self.num_kv_heads,
|
'num_key_value_heads': self.num_kv_heads,
|
||||||
'input_layout': input_layout,
|
'input_layout': input_layout,
|
||||||
'atten_mask': spec_attn_mask,
|
'atten_mask': attn_mask,
|
||||||
'sparse_mode': sparse_mode,
|
'sparse_mode': sparse_mode,
|
||||||
'scale': self.scale,
|
'scale': self.scale,
|
||||||
'antiquant_mode': 0,
|
'antiquant_mode': 0,
|
||||||
@@ -1269,8 +1273,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
(weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
|
(weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
|
||||||
weak_ref_tensors(q_pe), weak_ref_tensors(k_pe),
|
weak_ref_tensors(q_pe), weak_ref_tensors(k_pe),
|
||||||
self.num_heads, self.num_kv_heads, input_layout,
|
self.num_heads, self.num_kv_heads, input_layout,
|
||||||
weak_ref_tensors(spec_attn_mask) if spec_attn_mask is not None
|
weak_ref_tensors(attn_mask) if attn_mask is not None else
|
||||||
else None, sparse_mode, self.scale, decode_meta.block_table,
|
None, sparse_mode, self.scale, decode_meta.block_table,
|
||||||
block_size, decode_meta.seq_lens_list, actual_seq_lengths,
|
block_size, decode_meta.seq_lens_list, actual_seq_lengths,
|
||||||
weak_ref_tensors(attn_output), weak_ref_tensors(softmax_lse)))
|
weak_ref_tensors(attn_output), weak_ref_tensors(softmax_lse)))
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec
|
|||||||
|
|
||||||
from vllm_ascend import envs
|
from vllm_ascend import envs
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
|
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
|
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
|
||||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||||
@@ -156,6 +157,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
|||||||
and self.vllm_config.compilation_config.cudagraph_mode
|
and self.vllm_config.compilation_config.cudagraph_mode
|
||||||
== CUDAGraphMode.FULL_DECODE_ONLY
|
== CUDAGraphMode.FULL_DECODE_ONLY
|
||||||
), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
|
), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
|
||||||
|
self.attn_mask_builder = AttentionMaskBuilder(self.device)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_cudagraph_support(
|
def get_cudagraph_support(
|
||||||
@@ -280,7 +282,8 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
|||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
head_dim=self.model_config.get_head_size(),
|
head_dim=self.model_config.get_head_size(),
|
||||||
attn_mask=common_attn_metadata.attn_mask,
|
attn_mask=self.attn_mask_builder.get_attention_mask(
|
||||||
|
self.model_config),
|
||||||
attn_state=common_attn_metadata.attn_state,
|
attn_state=common_attn_metadata.attn_state,
|
||||||
block_tables=block_table,
|
block_tables=block_table,
|
||||||
sin=sin[:num_input_tokens],
|
sin=sin[:num_input_tokens],
|
||||||
|
|||||||
@@ -66,8 +66,6 @@ class AscendPrefillContextParallelMetadata:
|
|||||||
|
|
||||||
q_full_idx: torch.Tensor = None
|
q_full_idx: torch.Tensor = None
|
||||||
|
|
||||||
pcp_prefill_mask: torch.Tensor = None
|
|
||||||
|
|
||||||
# original query_lens before pcp split
|
# original query_lens before pcp split
|
||||||
query_lens_pcp_full_cpu: torch.Tensor = None
|
query_lens_pcp_full_cpu: torch.Tensor = None
|
||||||
|
|
||||||
@@ -93,12 +91,6 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
|
|||||||
|
|
||||||
positions: torch.Tensor = None
|
positions: torch.Tensor = None
|
||||||
|
|
||||||
attn_mask: torch.Tensor = None
|
|
||||||
|
|
||||||
spec_attn_mask: torch.Tensor = None
|
|
||||||
|
|
||||||
swa_mask: torch.Tensor = None
|
|
||||||
|
|
||||||
attn_state: Any = None
|
attn_state: Any = None
|
||||||
|
|
||||||
graph_pad_size: int = -1
|
graph_pad_size: int = -1
|
||||||
@@ -130,9 +122,6 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
|
|||||||
causal=self.causal,
|
causal=self.causal,
|
||||||
actual_seq_lengths_q=self.actual_seq_lengths_q[:num_actual_tokens],
|
actual_seq_lengths_q=self.actual_seq_lengths_q[:num_actual_tokens],
|
||||||
positions=self.positions[:num_actual_tokens],
|
positions=self.positions[:num_actual_tokens],
|
||||||
attn_mask=self.attn_mask,
|
|
||||||
spec_attn_mask=self.spec_attn_mask,
|
|
||||||
swa_mask=self.swa_mask,
|
|
||||||
attn_state=self.attn_state,
|
attn_state=self.attn_state,
|
||||||
graph_pad_size=-1, # It should be -1 when not run in fullgraph mode.
|
graph_pad_size=-1, # It should be -1 when not run in fullgraph mode.
|
||||||
num_input_tokens=num_actual_tokens,
|
num_input_tokens=num_actual_tokens,
|
||||||
|
|||||||
@@ -340,7 +340,7 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
|
|||||||
graph_params.events[runtime_shape],
|
graph_params.events[runtime_shape],
|
||||||
):
|
):
|
||||||
(q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
|
(q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
|
||||||
spec_attn_mask, sparse_mode, scale, block_table, block_size,
|
attn_mask, sparse_mode, scale, block_table, block_size,
|
||||||
seq_lens_list, actual_seq_lengths, attn_output,
|
seq_lens_list, actual_seq_lengths, attn_output,
|
||||||
softmax_lse) = param
|
softmax_lse) = param
|
||||||
seq_lens_list = forward_context.attn_metadata[
|
seq_lens_list = forward_context.attn_metadata[
|
||||||
@@ -380,7 +380,7 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
|
|||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
num_key_value_heads=num_kv_heads,
|
num_key_value_heads=num_kv_heads,
|
||||||
input_layout=input_layout,
|
input_layout=input_layout,
|
||||||
atten_mask=spec_attn_mask,
|
atten_mask=attn_mask,
|
||||||
sparse_mode=sparse_mode,
|
sparse_mode=sparse_mode,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
antiquant_mode=0,
|
antiquant_mode=0,
|
||||||
@@ -480,7 +480,7 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
|
|||||||
seq_len = decode_meta.cp_seq_len
|
seq_len = decode_meta.cp_seq_len
|
||||||
|
|
||||||
# For pcp + spec decode, we flatten seq_lens
|
# For pcp + spec decode, we flatten seq_lens
|
||||||
# to avoid irregular spec_attn_mask shape,
|
# to avoid irregular attn_mask shape,
|
||||||
# so there's no need to divide runtime_shape by spec_multiple
|
# so there's no need to divide runtime_shape by spec_multiple
|
||||||
pad_length = runtime_shape - len(seq_len)
|
pad_length = runtime_shape - len(seq_len)
|
||||||
pad_tensor = torch.zeros(pad_length,
|
pad_tensor = torch.zeros(pad_length,
|
||||||
|
|||||||
@@ -222,8 +222,6 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
slot_mapping=self.runner.input_batch.block_table[0].
|
slot_mapping=self.runner.input_batch.block_table[0].
|
||||||
slot_mapping.gpu,
|
slot_mapping.gpu,
|
||||||
positions=self.runner.positions.gpu,
|
positions=self.runner.positions.gpu,
|
||||||
attn_mask=self.runner.attn_mask,
|
|
||||||
spec_attn_mask=self.runner.spec_attn_mask,
|
|
||||||
attn_state=self.runner.attn_state,
|
attn_state=self.runner.attn_state,
|
||||||
decode_token_per_req=self.runner.decode_token_per_req,
|
decode_token_per_req=self.runner.decode_token_per_req,
|
||||||
max_seq_len=0,
|
max_seq_len=0,
|
||||||
@@ -672,8 +670,6 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
slot_mapping=common_attn_metadata.slot_mapping,
|
slot_mapping=common_attn_metadata.slot_mapping,
|
||||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||||
positions=common_attn_metadata.positions[token_indices],
|
positions=common_attn_metadata.positions[token_indices],
|
||||||
attn_mask=self.runner.attn_mask,
|
|
||||||
spec_attn_mask=self.runner.spec_attn_mask,
|
|
||||||
attn_state=self.runner.attn_state,
|
attn_state=self.runner.attn_state,
|
||||||
decode_token_per_req=self.runner.decode_token_per_req,
|
decode_token_per_req=self.runner.decode_token_per_req,
|
||||||
max_seq_len=0)
|
max_seq_len=0)
|
||||||
@@ -762,8 +758,6 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
block_table_tensor=common_attn_metadata.block_table_tensor,
|
block_table_tensor=common_attn_metadata.block_table_tensor,
|
||||||
slot_mapping=common_attn_metadata.slot_mapping,
|
slot_mapping=common_attn_metadata.slot_mapping,
|
||||||
positions=common_attn_metadata.positions,
|
positions=common_attn_metadata.positions,
|
||||||
attn_mask=self.runner.attn_mask,
|
|
||||||
spec_attn_mask=self.runner.spec_attn_mask,
|
|
||||||
attn_state=self.runner.attn_state,
|
attn_state=self.runner.attn_state,
|
||||||
decode_token_per_req=self.runner.decode_token_per_req,
|
decode_token_per_req=self.runner.decode_token_per_req,
|
||||||
num_computed_tokens_cpu=common_attn_metadata.
|
num_computed_tokens_cpu=common_attn_metadata.
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -73,8 +73,6 @@ class MtpProposer(EagleProposer):
|
|||||||
slot_mapping=self.runner.input_batch.block_table[0].
|
slot_mapping=self.runner.input_batch.block_table[0].
|
||||||
slot_mapping.gpu,
|
slot_mapping.gpu,
|
||||||
positions=self.runner.positions.gpu,
|
positions=self.runner.positions.gpu,
|
||||||
attn_mask=self.runner.attn_mask,
|
|
||||||
spec_attn_mask=self.runner.spec_attn_mask,
|
|
||||||
attn_state=self.runner.attn_state,
|
attn_state=self.runner.attn_state,
|
||||||
decode_token_per_req=self.runner.decode_token_per_req,
|
decode_token_per_req=self.runner.decode_token_per_req,
|
||||||
max_seq_len=0)
|
max_seq_len=0)
|
||||||
|
|||||||
@@ -1150,3 +1150,14 @@ def check_kv_extra_config(vllm_config):
|
|||||||
_check(
|
_check(
|
||||||
"decode",
|
"decode",
|
||||||
vllm_config.kv_transfer_config.get_from_extra_config("decode", {}))
|
vllm_config.kv_transfer_config.get_from_extra_config("decode", {}))
|
||||||
|
|
||||||
|
|
||||||
|
def singleton(cls):
|
||||||
|
instances = {}
|
||||||
|
|
||||||
|
def get_instance(*args, **kwargs):
|
||||||
|
if cls not in instances:
|
||||||
|
instances[cls] = cls(*args, **kwargs)
|
||||||
|
return instances[cls]
|
||||||
|
|
||||||
|
return get_instance
|
||||||
@@ -77,7 +77,6 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
|
|||||||
from vllm.v1.worker.utils import AttentionGroup
|
from vllm.v1.worker.utils import AttentionGroup
|
||||||
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||||
# yapf conflicts with isort for this block
|
# yapf conflicts with isort for this block
|
||||||
@@ -230,7 +229,6 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
self.positions = self._make_buffer(max_buffer_num_tokens,
|
self.positions = self._make_buffer(max_buffer_num_tokens,
|
||||||
dtype=torch.int64)
|
dtype=torch.int64)
|
||||||
self.sampler = AscendSampler()
|
self.sampler = AscendSampler()
|
||||||
self.attn_mask = None
|
|
||||||
self.attn_state = None
|
self.attn_state = None
|
||||||
|
|
||||||
# Ascend-specific configurations
|
# Ascend-specific configurations
|
||||||
@@ -264,19 +262,9 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
use_sparse=self.use_sparse,
|
use_sparse=self.use_sparse,
|
||||||
use_mm_prefix=self.model_config is not None
|
use_mm_prefix=self.model_config is not None
|
||||||
and self.model_config.is_mm_prefix_lm)
|
and self.model_config.is_mm_prefix_lm)
|
||||||
self.attn_mask_builder = AttentionMaskBuilder(self.device)
|
|
||||||
|
|
||||||
self._set_up_drafter()
|
self._set_up_drafter()
|
||||||
|
|
||||||
# sliding window attn mask
|
|
||||||
self.swa_mask = None
|
|
||||||
is_swa = hasattr(self.vllm_config.model_config.hf_text_config,
|
|
||||||
"sliding_window")
|
|
||||||
if self.model_config is not None and is_swa:
|
|
||||||
self.swa_mask = self.attn_mask_builder.get_swa_mask(
|
|
||||||
self.dtype,
|
|
||||||
self.vllm_config.model_config.hf_text_config.sliding_window)
|
|
||||||
|
|
||||||
# kv role
|
# kv role
|
||||||
self.is_kv_producer = False
|
self.is_kv_producer = False
|
||||||
self.is_kv_consumer = False
|
self.is_kv_consumer = False
|
||||||
@@ -370,7 +358,6 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
|
|
||||||
def _set_up_drafter(self):
|
def _set_up_drafter(self):
|
||||||
# Set up speculative decoding.
|
# Set up speculative decoding.
|
||||||
self.spec_attn_mask = None
|
|
||||||
self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer,
|
self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer,
|
||||||
SuffixDecodingProposer]] = None
|
SuffixDecodingProposer]] = None
|
||||||
self.actual_seq_lengths_q: list[int] = []
|
self.actual_seq_lengths_q: list[int] = []
|
||||||
@@ -379,8 +366,6 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
spec_token_num = self.speculative_config.num_speculative_tokens
|
spec_token_num = self.speculative_config.num_speculative_tokens
|
||||||
assert spec_token_num > 0
|
assert spec_token_num > 0
|
||||||
self.decode_token_per_req = 1 + spec_token_num
|
self.decode_token_per_req = 1 + spec_token_num
|
||||||
self.spec_attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
|
|
||||||
)
|
|
||||||
if get_pp_group().is_last_rank:
|
if get_pp_group().is_last_rank:
|
||||||
self.drafter = self._get_drafter()
|
self.drafter = self._get_drafter()
|
||||||
if self.speculative_config.method == "eagle3":
|
if self.speculative_config.method == "eagle3":
|
||||||
@@ -494,22 +479,6 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
return self.model.unwrap()
|
return self.model.unwrap()
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def _make_attention_mask(self, attn_state) -> torch.Tensor:
|
|
||||||
# pcp situation.
|
|
||||||
if self.attn_mask_builder is None:
|
|
||||||
raise ValueError("Attn mask builder is None")
|
|
||||||
# Pooling situation.
|
|
||||||
if self.model_config.runner_type == "pooling":
|
|
||||||
return self.attn_mask_builder.get_attn_mask(2048, torch.bool)
|
|
||||||
|
|
||||||
if self.vllm_config.model_config.use_mla:
|
|
||||||
if self.pcp_size > 1:
|
|
||||||
return self.attn_mask_builder.get_pcp_mla_mask(self.dtype)
|
|
||||||
# mla prefill
|
|
||||||
if attn_state != AscendAttentionState.DecodeOnly:
|
|
||||||
return self.attn_mask_builder.get_mla_mask(self.dtype)
|
|
||||||
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
|
||||||
|
|
||||||
def _prepare_inputs(
|
def _prepare_inputs(
|
||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
@@ -551,7 +520,6 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
with_prefill = attn_state not in [
|
with_prefill = attn_state not in [
|
||||||
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
||||||
]
|
]
|
||||||
self.attn_mask = self._make_attention_mask(attn_state)
|
|
||||||
|
|
||||||
# Get positions.
|
# Get positions.
|
||||||
positions_np = self.positions.np[:total_num_scheduled_tokens]
|
positions_np = self.positions.np[:total_num_scheduled_tokens]
|
||||||
@@ -941,7 +909,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
if self.pcp_size * self.dcp_size > 1:
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
self.long_seq_metadata = self.pcp_manager.generate_pcp_metadata(
|
self.long_seq_metadata = self.pcp_manager.generate_pcp_metadata(
|
||||||
total_num_scheduled_tokens, self.query_lens,
|
total_num_scheduled_tokens, self.query_lens,
|
||||||
self.attn_mask, self.input_batch)
|
self.input_batch)
|
||||||
blk_table.slot_mapping.gpu[maybe_pcp_full_tokens:].fill_(-1)
|
blk_table.slot_mapping.gpu[maybe_pcp_full_tokens:].fill_(-1)
|
||||||
if self.pcp_size > 1:
|
if self.pcp_size > 1:
|
||||||
slot_mapping = self.pcp_manager.get_padded_slot_mapping(
|
slot_mapping = self.pcp_manager.get_padded_slot_mapping(
|
||||||
@@ -997,9 +965,6 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
num_computed_tokens_cpu=self.input_batch.
|
num_computed_tokens_cpu=self.input_batch.
|
||||||
num_computed_tokens_cpu_tensor[:num_reqs],
|
num_computed_tokens_cpu_tensor[:num_reqs],
|
||||||
positions=self.positions.gpu,
|
positions=self.positions.gpu,
|
||||||
attn_mask=self.attn_mask,
|
|
||||||
spec_attn_mask=self.spec_attn_mask,
|
|
||||||
swa_mask=self.swa_mask,
|
|
||||||
attn_state=self.attn_state,
|
attn_state=self.attn_state,
|
||||||
max_query_len=max_num_scheduled_tokens,
|
max_query_len=max_num_scheduled_tokens,
|
||||||
decode_token_per_req=self.decode_token_per_req,
|
decode_token_per_req=self.decode_token_per_req,
|
||||||
@@ -1009,7 +974,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
|
|
||||||
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
|
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
|
||||||
# For pcp + spec decode, we flatten block_table
|
# For pcp + spec decode, we flatten block_table
|
||||||
# to avoid irregular spec_attn_mask shape, e.g.,
|
# to avoid irregular attn_mask shape, e.g.,
|
||||||
# num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1,
|
# num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1,
|
||||||
# ori block_table: # [d0, d1, p0, p1, p2]
|
# ori block_table: # [d0, d1, p0, p1, p2]
|
||||||
# (num_reqs_d + num_reqs_p, max_num_blocks),
|
# (num_reqs_d + num_reqs_p, max_num_blocks),
|
||||||
@@ -1918,7 +1883,6 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
self.query_start_loc.cpu[1:num_reqs +
|
self.query_start_loc.cpu[1:num_reqs +
|
||||||
1] = torch.Tensor(cu_num_tokens)
|
1] = torch.Tensor(cu_num_tokens)
|
||||||
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
||||||
self.attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask()
|
|
||||||
|
|
||||||
num_computed_tokens_cpu = (
|
num_computed_tokens_cpu = (
|
||||||
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
|
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
|
||||||
@@ -1930,8 +1894,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
slot_mapping = self.input_batch.block_table[
|
slot_mapping = self.input_batch.block_table[
|
||||||
kv_cache_group_id].slot_mapping
|
kv_cache_group_id].slot_mapping
|
||||||
long_seq_metadata = None if self.pcp_size * self.dcp_size == 1 else self.pcp_manager.generate_pcp_metadata(
|
long_seq_metadata = None if self.pcp_size * self.dcp_size == 1 else self.pcp_manager.generate_pcp_metadata(
|
||||||
num_tokens, self.query_lens, self.attn_mask,
|
num_tokens, self.query_lens, self.input_batch)
|
||||||
self.input_batch)
|
|
||||||
if long_seq_metadata is not None:
|
if long_seq_metadata is not None:
|
||||||
pcp_world_size = get_pcp_group().world_size
|
pcp_world_size = get_pcp_group().world_size
|
||||||
dcp_world_size = get_dcp_group().world_size
|
dcp_world_size = get_dcp_group().world_size
|
||||||
@@ -1954,9 +1917,6 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
slot_mapping=slot_mapping.gpu,
|
slot_mapping=slot_mapping.gpu,
|
||||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||||
positions=self.positions.gpu,
|
positions=self.positions.gpu,
|
||||||
attn_mask=self.attn_mask,
|
|
||||||
spec_attn_mask=self.spec_attn_mask,
|
|
||||||
swa_mask=self.swa_mask,
|
|
||||||
attn_state=self.attn_state,
|
attn_state=self.attn_state,
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
decode_token_per_req=self.decode_token_per_req,
|
decode_token_per_req=self.decode_token_per_req,
|
||||||
|
|||||||
@@ -498,7 +498,7 @@ class PCPManager:
|
|||||||
torch.float32).argsort().to(torch.int32)
|
torch.float32).argsort().to(torch.int32)
|
||||||
|
|
||||||
def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
|
def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
|
||||||
attn_mask, input_batch):
|
input_batch):
|
||||||
from vllm_ascend.attention.utils import \
|
from vllm_ascend.attention.utils import \
|
||||||
AscendPrefillContextParallelMetadata
|
AscendPrefillContextParallelMetadata
|
||||||
num_reqs = input_batch.num_reqs or query_lens.size(0)
|
num_reqs = input_batch.num_reqs or query_lens.size(0)
|
||||||
@@ -523,7 +523,7 @@ class PCPManager:
|
|||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
)
|
)
|
||||||
# For pcp + spec decode, we flatten seq_lens
|
# For pcp + spec decode, we flatten seq_lens
|
||||||
# to avoid irregular spec_attn_mask shape.
|
# to avoid irregular attn_mask shape.
|
||||||
# Same as block_table, we flatten decode seq_lens to query_lens,
|
# Same as block_table, we flatten decode seq_lens to query_lens,
|
||||||
# and keep prefill seq_lens unchanged.
|
# and keep prefill seq_lens unchanged.
|
||||||
for decode_idx in range(self.decode_threshold):
|
for decode_idx in range(self.decode_threshold):
|
||||||
@@ -657,13 +657,11 @@ class PCPManager:
|
|||||||
split_with_q_head_nomask_idx_reqs,
|
split_with_q_head_nomask_idx_reqs,
|
||||||
split_kv_with_q_tail_nomask_idx_reqs,
|
split_kv_with_q_tail_nomask_idx_reqs,
|
||||||
head_attn_nomask_seqlens, chunk_seqlens)
|
head_attn_nomask_seqlens, chunk_seqlens)
|
||||||
pcp_prefill_mask = attn_mask
|
|
||||||
|
|
||||||
self.extra_long_seq_kwargs = {
|
self.extra_long_seq_kwargs = {
|
||||||
'attn_mask_seqlens': attn_mask_seqlens,
|
'attn_mask_seqlens': attn_mask_seqlens,
|
||||||
'head_attn_nomask_seqlens': head_attn_nomask_seqlens,
|
'head_attn_nomask_seqlens': head_attn_nomask_seqlens,
|
||||||
'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens,
|
'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens
|
||||||
'pcp_prefill_mask': pcp_prefill_mask
|
|
||||||
}
|
}
|
||||||
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[:
|
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[:
|
||||||
num_actual_tokens_pcp_padded]
|
num_actual_tokens_pcp_padded]
|
||||||
@@ -685,8 +683,6 @@ class PCPManager:
|
|||||||
'head_attn_nomask_seqlens']
|
'head_attn_nomask_seqlens']
|
||||||
long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[
|
long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[
|
||||||
'tail_attn_nomask_seqlens']
|
'tail_attn_nomask_seqlens']
|
||||||
long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[
|
|
||||||
'pcp_prefill_mask']
|
|
||||||
if self.vllm_config.model_config.use_mla:
|
if self.vllm_config.model_config.use_mla:
|
||||||
long_seq_metadata.kv_with_q_head_nomask_idx_tensor = split_q_head_nomask_idx_tensor_list
|
long_seq_metadata.kv_with_q_head_nomask_idx_tensor = split_q_head_nomask_idx_tensor_list
|
||||||
long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = split_q_tail_nomask_idx_tensor_list
|
long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = split_q_tail_nomask_idx_tensor_list
|
||||||
|
|||||||
@@ -58,9 +58,6 @@ def build_attn_metadata(
|
|||||||
decode_token_per_req: int,
|
decode_token_per_req: int,
|
||||||
actual_seq_lengths_q: list[int],
|
actual_seq_lengths_q: list[int],
|
||||||
positions: torch.Tensor | None = None,
|
positions: torch.Tensor | None = None,
|
||||||
attn_mask: torch.Tensor
|
|
||||||
| None = None,
|
|
||||||
spec_attn_mask: torch.Tensor | None = None,
|
|
||||||
attn_state: Any | None = None,
|
attn_state: Any | None = None,
|
||||||
graph_pad_size: int = -1,
|
graph_pad_size: int = -1,
|
||||||
num_input_tokens: int = 0,
|
num_input_tokens: int = 0,
|
||||||
@@ -92,8 +89,6 @@ def build_attn_metadata(
|
|||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
attn_mask=attn_mask,
|
|
||||||
spec_attn_mask=spec_attn_mask,
|
|
||||||
attn_state=attn_state,
|
attn_state=attn_state,
|
||||||
graph_pad_size=graph_pad_size,
|
graph_pad_size=graph_pad_size,
|
||||||
num_input_tokens=num_input_tokens,
|
num_input_tokens=num_input_tokens,
|
||||||
|
|||||||
@@ -32,8 +32,7 @@ from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
|||||||
|
|
||||||
from vllm_ascend.worker.v2.aclgraph_utils import AclGraphManager
|
from vllm_ascend.worker.v2.aclgraph_utils import AclGraphManager
|
||||||
from vllm_ascend.worker.v2.attn_utils import (build_attn_metadata,
|
from vllm_ascend.worker.v2.attn_utils import (build_attn_metadata,
|
||||||
build_attn_state,
|
build_attn_state)
|
||||||
make_attention_mask)
|
|
||||||
from vllm_ascend.worker.v2.input_batch import AscendInputBuffers
|
from vllm_ascend.worker.v2.input_batch import AscendInputBuffers
|
||||||
from vllm_ascend.worker.v2.sample.sampler import AscendSampler
|
from vllm_ascend.worker.v2.sample.sampler import AscendSampler
|
||||||
from vllm_ascend.worker.v2.states import AscendRequestState, uva_wrapper
|
from vllm_ascend.worker.v2.states import AscendRequestState, uva_wrapper
|
||||||
@@ -155,12 +154,6 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
num_scheduled_tokens,
|
num_scheduled_tokens,
|
||||||
num_valid_tokens,
|
num_valid_tokens,
|
||||||
)
|
)
|
||||||
attn_mask = make_attention_mask(
|
|
||||||
self.vllm_config,
|
|
||||||
attn_state,
|
|
||||||
self.dtype,
|
|
||||||
self.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
idx_mapping_list = [
|
idx_mapping_list = [
|
||||||
self.req_states.req_id_to_index[req_id] for req_id in req_ids
|
self.req_states.req_id_to_index[req_id] for req_id in req_ids
|
||||||
@@ -284,7 +277,6 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
slot_mappings=slot_mappings.to(torch.int32),
|
slot_mappings=slot_mappings.to(torch.int32),
|
||||||
kv_cache_config=self.kv_cache_config,
|
kv_cache_config=self.kv_cache_config,
|
||||||
decode_token_per_req=self.decode_token_per_req,
|
decode_token_per_req=self.decode_token_per_req,
|
||||||
attn_mask=attn_mask,
|
|
||||||
attn_state=attn_state,
|
attn_state=attn_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user