[Refactor] AttentionBuilder inherit from base class in vllm (#5916)
### What this PR does / why we need it?

This PR makes `AscendMLAMetadataBuilder` and `AscendSFAMetadataBuilder` properly inherit from the base class `MLACommonMetadataBuilder` in vllm by adding `super().__init__()` calls.

**Changes:**

- Add a `super().__init__()` call in `AscendMLAMetadataBuilder.__init__()`
- Add a `super().__init__()` call in `AscendSFAMetadataBuilder.__init__()`
- Extract `ascend_chunked_prefill_workspace_size()` to `vllm_ascend/attention/utils.py` to avoid code duplication
- Override `determine_chunked_prefill_workspace_size()` to support the Ascend-specific 128k-token workspace size (vs. 64k in the parent class)
- Update unit tests to mock the parent class `__init__` for proper isolation

**Why we need it:**

- Follow proper Python inheritance patterns by calling `super().__init__()`
- Reduce code duplication by reusing the parent class initialization logic
- Improve maintainability, since parent class changes will be inherited automatically

Part of issue #5463, item 10.

### Does this PR introduce _any_ user-facing change?

No, this is an internal refactoring that does not change any user-facing behavior.

Signed-off-by: lico67373 <918688502@qq.com>
This commit is contained in:
@@ -208,11 +208,38 @@ class TestAscendMLAMetadata(TestBase):
|
|||||||
|
|
||||||
class TestAscendMLAMetadataBuilder(TestBase):
|
class TestAscendMLAMetadataBuilder(TestBase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
# Mock parent class __init__ to avoid complex initialization,
|
||||||
|
# but still set the essential attributes that child class needs
|
||||||
|
def mock_parent_init(self, kv_cache_spec, layer_names, vllm_config,
|
||||||
|
device, metadata_cls, supports_dcp_with_varlen):
|
||||||
|
self.metadata_cls = metadata_cls
|
||||||
|
self.kv_cache_spec = kv_cache_spec
|
||||||
|
self.model_config = vllm_config.model_config
|
||||||
|
self.vllm_config = vllm_config
|
||||||
|
self.device = device
|
||||||
|
self.chunked_prefill_workspace_size = 128 * 1024
|
||||||
|
self.chunked_prefill_workspace = torch.empty(
|
||||||
|
(self.chunked_prefill_workspace_size,
|
||||||
|
vllm_config.model_config.get_head_size()),
|
||||||
|
dtype=vllm_config.model_config.dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.parent_init_patcher = patch(
|
||||||
|
"vllm.v1.attention.backends.mla.common.MLACommonMetadataBuilder.__init__",
|
||||||
|
mock_parent_init)
|
||||||
|
self.parent_init_patcher.start()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.parent_init_patcher.stop()
|
||||||
|
|
||||||
def test_ascend_mla_metadata_builder_default(self):
|
def test_ascend_mla_metadata_builder_default(self):
|
||||||
mock_vllm_config = MagicMock()
|
mock_vllm_config = MagicMock()
|
||||||
mock_vllm_config.model_config.max_model_len = 1024
|
mock_vllm_config.model_config.max_model_len = 1024
|
||||||
mock_vllm_config.model_config.get_head_size.return_value = 64
|
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||||
mock_vllm_config.model_config.dtype = torch.float16
|
mock_vllm_config.model_config.dtype = torch.float16
|
||||||
|
mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
|
||||||
mock_vllm_config.cache_config.block_size = 16
|
mock_vllm_config.cache_config.block_size = 16
|
||||||
mock_vllm_config.scheduler_config.max_num_seqs = 4
|
mock_vllm_config.scheduler_config.max_num_seqs = 4
|
||||||
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
|
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
|
||||||
@@ -238,6 +265,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
|||||||
mock_vllm_config.model_config.max_model_len = 1024
|
mock_vllm_config.model_config.max_model_len = 1024
|
||||||
mock_vllm_config.model_config.get_head_size.return_value = 64
|
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||||
mock_vllm_config.model_config.dtype = torch.float16
|
mock_vllm_config.model_config.dtype = torch.float16
|
||||||
|
mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
|
||||||
mock_vllm_config.cache_config.block_size = 16
|
mock_vllm_config.cache_config.block_size = 16
|
||||||
mock_vllm_config.scheduler_config.max_num_seqs = 4
|
mock_vllm_config.scheduler_config.max_num_seqs = 4
|
||||||
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
|
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
|
||||||
@@ -274,10 +302,12 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
|||||||
mock_vllm_config.model_config.max_model_len = 1024
|
mock_vllm_config.model_config.max_model_len = 1024
|
||||||
mock_vllm_config.model_config.get_head_size.return_value = 64
|
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||||
mock_vllm_config.model_config.dtype = torch.float16
|
mock_vllm_config.model_config.dtype = torch.float16
|
||||||
|
mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
|
||||||
mock_vllm_config.cache_config.block_size = 16
|
mock_vllm_config.cache_config.block_size = 16
|
||||||
mock_vllm_config.scheduler_config.max_num_seqs = 4
|
mock_vllm_config.scheduler_config.max_num_seqs = 4
|
||||||
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
|
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
|
||||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||||
|
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
|
||||||
mock_device = 'cpu'
|
mock_device = 'cpu'
|
||||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||||
|
|
||||||
@@ -314,6 +344,9 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
|||||||
|
|
||||||
mock_vllm_config = MagicMock()
|
mock_vllm_config = MagicMock()
|
||||||
mock_vllm_config.model_config.max_model_len = 1024
|
mock_vllm_config.model_config.max_model_len = 1024
|
||||||
|
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||||
|
mock_vllm_config.model_config.dtype = torch.float16
|
||||||
|
mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
|
||||||
mock_vllm_config.cache_config.block_size = 16
|
mock_vllm_config.cache_config.block_size = 16
|
||||||
mock_vllm_config.scheduler_config.max_num_seqs = 4
|
mock_vllm_config.scheduler_config.max_num_seqs = 4
|
||||||
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
|
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
|
||||||
@@ -352,10 +385,12 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
|||||||
mock_vllm_config.model_config.max_model_len = 1024
|
mock_vllm_config.model_config.max_model_len = 1024
|
||||||
mock_vllm_config.model_config.get_head_size.return_value = 64
|
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||||
mock_vllm_config.model_config.dtype = torch.float16
|
mock_vllm_config.model_config.dtype = torch.float16
|
||||||
|
mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
|
||||||
mock_vllm_config.cache_config.block_size = 16
|
mock_vllm_config.cache_config.block_size = 16
|
||||||
mock_vllm_config.scheduler_config.max_num_seqs = 4
|
mock_vllm_config.scheduler_config.max_num_seqs = 4
|
||||||
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
|
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
|
||||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||||
|
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
|
||||||
mock_device = 'cpu'
|
mock_device = 'cpu'
|
||||||
mock_vllm_config.speculative_config = None
|
mock_vllm_config.speculative_config = None
|
||||||
|
|
||||||
@@ -374,10 +409,12 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
|||||||
mock_vllm_config.model_config.max_model_len = 1024
|
mock_vllm_config.model_config.max_model_len = 1024
|
||||||
mock_vllm_config.model_config.get_head_size.return_value = 64
|
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||||
mock_vllm_config.model_config.dtype = torch.float16
|
mock_vllm_config.model_config.dtype = torch.float16
|
||||||
|
mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
|
||||||
mock_vllm_config.cache_config.block_size = 16
|
mock_vllm_config.cache_config.block_size = 16
|
||||||
mock_vllm_config.scheduler_config.max_num_seqs = 4
|
mock_vllm_config.scheduler_config.max_num_seqs = 4
|
||||||
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
|
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
|
||||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||||
|
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
|
||||||
mock_device = 'cpu'
|
mock_device = 'cpu'
|
||||||
mock_vllm_config.speculative_config = None
|
mock_vllm_config.speculative_config = None
|
||||||
|
|
||||||
@@ -398,11 +435,34 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
|||||||
class TestAscendMLAMetadataBuilderBuild(TestBase):
|
class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
# Mock parent class __init__ to avoid complex initialization,
|
||||||
|
# but still set the essential attributes that child class needs
|
||||||
|
def mock_parent_init(self, kv_cache_spec, layer_names, vllm_config,
|
||||||
|
device, metadata_cls, supports_dcp_with_varlen):
|
||||||
|
self.metadata_cls = metadata_cls
|
||||||
|
self.kv_cache_spec = kv_cache_spec
|
||||||
|
self.model_config = vllm_config.model_config
|
||||||
|
self.vllm_config = vllm_config
|
||||||
|
self.device = device
|
||||||
|
self.chunked_prefill_workspace_size = 128 * 1024
|
||||||
|
self.chunked_prefill_workspace = torch.empty(
|
||||||
|
(self.chunked_prefill_workspace_size,
|
||||||
|
vllm_config.model_config.get_head_size()),
|
||||||
|
dtype=vllm_config.model_config.dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.parent_init_patcher = patch(
|
||||||
|
"vllm.v1.attention.backends.mla.common.MLACommonMetadataBuilder.__init__",
|
||||||
|
mock_parent_init)
|
||||||
|
self.parent_init_patcher.start()
|
||||||
|
|
||||||
self.mock_vllm_config = MagicMock(spec=VllmConfig)
|
self.mock_vllm_config = MagicMock(spec=VllmConfig)
|
||||||
self.mock_vllm_config.cache_config = CacheConfig(block_size=32)
|
self.mock_vllm_config.cache_config = CacheConfig(block_size=32)
|
||||||
mock_scheduler_config = MagicMock(spec=SchedulerConfig)
|
mock_scheduler_config = MagicMock(spec=SchedulerConfig)
|
||||||
mock_scheduler_config.max_num_seqs = 8
|
mock_scheduler_config.max_num_seqs = 8
|
||||||
mock_scheduler_config.chunked_prefill_enabled = True
|
mock_scheduler_config.chunked_prefill_enabled = True
|
||||||
|
mock_scheduler_config.enable_chunked_prefill = True
|
||||||
self.mock_vllm_config.scheduler_config = mock_scheduler_config
|
self.mock_vllm_config.scheduler_config = mock_scheduler_config
|
||||||
self.mock_vllm_config.speculative_config = None
|
self.mock_vllm_config.speculative_config = None
|
||||||
self.mock_device = torch.device("cpu")
|
self.mock_device = torch.device("cpu")
|
||||||
@@ -423,6 +483,9 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
|||||||
self.kv_cache_spec.head_size = 64
|
self.kv_cache_spec.head_size = 64
|
||||||
self.kv_cache_spec.num_heads = 32
|
self.kv_cache_spec.num_heads = 32
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.parent_init_patcher.stop()
|
||||||
|
|
||||||
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
||||||
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
|
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
|
||||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||||
|
|||||||
@@ -99,13 +99,44 @@ class TestAscendSFAMetadataBuilder(TestBase):
|
|||||||
return_value=self.mock_cfg)
|
return_value=self.mock_cfg)
|
||||||
self.patcher.start()
|
self.patcher.start()
|
||||||
|
|
||||||
|
# Mock parent class __init__ to avoid complex initialization,
|
||||||
|
# but still set the essential attributes that child class needs
|
||||||
|
def mock_parent_init(self, kv_cache_spec, layer_names, vllm_config,
|
||||||
|
device, metadata_cls, supports_dcp_with_varlen):
|
||||||
|
self.metadata_cls = metadata_cls
|
||||||
|
self.kv_cache_spec = kv_cache_spec
|
||||||
|
self.model_config = vllm_config.model_config
|
||||||
|
self.vllm_config = vllm_config
|
||||||
|
self.device = device
|
||||||
|
self.chunked_prefill_workspace_size = 128 * 1024
|
||||||
|
self.chunked_prefill_workspace = torch.empty(
|
||||||
|
(self.chunked_prefill_workspace_size,
|
||||||
|
vllm_config.model_config.get_head_size()),
|
||||||
|
dtype=vllm_config.model_config.dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.parent_init_patcher = patch(
|
||||||
|
"vllm.v1.attention.backends.mla.common.MLACommonMetadataBuilder.__init__",
|
||||||
|
mock_parent_init)
|
||||||
|
self.parent_init_patcher.start()
|
||||||
|
|
||||||
if hasattr(enable_dsa_cp, "cache_clear"):
|
if hasattr(enable_dsa_cp, "cache_clear"):
|
||||||
enable_dsa_cp.cache_clear()
|
enable_dsa_cp.cache_clear()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.patcher.stop()
|
||||||
|
self.parent_init_patcher.stop()
|
||||||
|
|
||||||
def test_ascend_sfa_metadata_builder_default(self):
|
def test_ascend_sfa_metadata_builder_default(self):
|
||||||
kv_cache_spec = MagicMock()
|
kv_cache_spec = MagicMock()
|
||||||
layer_names = ["layer1", "layer2"]
|
layer_names = ["layer1", "layer2"]
|
||||||
vllm_config = MagicMock()
|
vllm_config = MagicMock()
|
||||||
|
vllm_config.cache_config.block_size = 16
|
||||||
|
vllm_config.model_config.max_model_len = 1024
|
||||||
|
vllm_config.model_config.get_head_size.return_value = 64
|
||||||
|
vllm_config.model_config.dtype = torch.float16
|
||||||
|
vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
|
||||||
speculative_config = MagicMock()
|
speculative_config = MagicMock()
|
||||||
speculative_config.num_speculative_tokens = 4
|
speculative_config.num_speculative_tokens = 4
|
||||||
vllm_config.speculative_config = speculative_config
|
vllm_config.speculative_config = speculative_config
|
||||||
@@ -138,6 +169,11 @@ class TestAscendSFAMetadataBuilder(TestBase):
|
|||||||
kv_cache_spec = MagicMock()
|
kv_cache_spec = MagicMock()
|
||||||
layer_names = ["layer1", "layer2"]
|
layer_names = ["layer1", "layer2"]
|
||||||
vllm_config = MagicMock()
|
vllm_config = MagicMock()
|
||||||
|
vllm_config.cache_config.block_size = 16
|
||||||
|
vllm_config.model_config.max_model_len = 1024
|
||||||
|
vllm_config.model_config.get_head_size.return_value = 64
|
||||||
|
vllm_config.model_config.dtype = torch.float16
|
||||||
|
vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
|
||||||
speculative_config = MagicMock()
|
speculative_config = MagicMock()
|
||||||
speculative_config.num_speculative_tokens = 4
|
speculative_config.num_speculative_tokens = 4
|
||||||
vllm_config.speculative_config = speculative_config
|
vllm_config.speculative_config = speculative_config
|
||||||
@@ -190,6 +226,11 @@ class TestAscendSFAMetadataBuilder(TestBase):
|
|||||||
kv_cache_spec = MagicMock()
|
kv_cache_spec = MagicMock()
|
||||||
layer_names = ["layer1", "layer2"]
|
layer_names = ["layer1", "layer2"]
|
||||||
vllm_config = MagicMock()
|
vllm_config = MagicMock()
|
||||||
|
vllm_config.cache_config.block_size = 16
|
||||||
|
vllm_config.model_config.max_model_len = 1024
|
||||||
|
vllm_config.model_config.get_head_size.return_value = 64
|
||||||
|
vllm_config.model_config.dtype = torch.float16
|
||||||
|
vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
|
||||||
speculative_config = MagicMock()
|
speculative_config = MagicMock()
|
||||||
speculative_config.num_speculative_tokens = 4
|
speculative_config.num_speculative_tokens = 4
|
||||||
vllm_config.speculative_config = speculative_config
|
vllm_config.speculative_config = speculative_config
|
||||||
|
|||||||
@@ -19,12 +19,10 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
|||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
from vllm_ascend.attention.context_parallel.common_cp import (
|
from vllm_ascend.attention.context_parallel.common_cp import (
|
||||||
AscendPCPMetadata, CPChunkedContextMetadata)
|
AscendPCPMetadata, CPChunkedContextMetadata)
|
||||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
from vllm_ascend.attention.utils import (
|
||||||
enable_cp,
|
AscendCommonAttentionMetadata, ascend_chunked_prefill_workspace_size,
|
||||||
maybe_save_kv_layer_to_connector,
|
enable_cp, maybe_save_kv_layer_to_connector, split_decodes_and_prefills,
|
||||||
split_decodes_and_prefills,
|
trans_rope_weight, transdata, wait_for_kv_layer_from_connector)
|
||||||
trans_rope_weight, transdata,
|
|
||||||
wait_for_kv_layer_from_connector)
|
|
||||||
from vllm_ascend.compilation.acl_graph import (
|
from vllm_ascend.compilation.acl_graph import (
|
||||||
get_draft_graph_params, get_graph_params,
|
get_draft_graph_params, get_graph_params,
|
||||||
update_draft_graph_params_workspaces, update_graph_params_workspaces)
|
update_draft_graph_params_workspaces, update_graph_params_workspaces)
|
||||||
@@ -215,11 +213,11 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
|||||||
metadata_cls: type[AscendMLAMetadata] | None = None,
|
metadata_cls: type[AscendMLAMetadata] | None = None,
|
||||||
supports_dcp_with_varlen: bool = False,
|
supports_dcp_with_varlen: bool = False,
|
||||||
):
|
):
|
||||||
self.metadata_cls = (metadata_cls if metadata_cls is not None else
|
super().__init__(
|
||||||
AscendMLAMetadata)
|
kv_cache_spec, layer_names, vllm_config, device,
|
||||||
self.vllm_config = vllm_config
|
metadata_cls if metadata_cls is not None else AscendMLAMetadata,
|
||||||
self.model_config = vllm_config.model_config
|
supports_dcp_with_varlen)
|
||||||
self.device = device
|
|
||||||
scheduler_config = vllm_config.scheduler_config
|
scheduler_config = vllm_config.scheduler_config
|
||||||
self.block_size = vllm_config.cache_config.block_size
|
self.block_size = vllm_config.cache_config.block_size
|
||||||
self.max_blocks = (vllm_config.model_config.max_model_len +
|
self.max_blocks = (vllm_config.model_config.max_model_len +
|
||||||
@@ -236,29 +234,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
|||||||
got {self.decode_threshold}"
|
got {self.decode_threshold}"
|
||||||
|
|
||||||
self.reorder_batch_threshold = self.decode_threshold
|
self.reorder_batch_threshold = self.decode_threshold
|
||||||
if self.chunked_prefill_enabled:
|
|
||||||
self.chunked_prefill_workspace_size = min(
|
|
||||||
# Max sure there is enough for 8 full length request or at least
|
|
||||||
# 4 pages of cache per request
|
|
||||||
max(8 * self.model_config.max_model_len,
|
|
||||||
4 * scheduler_config.max_num_seqs * self.block_size),
|
|
||||||
# For long-context models try not to over-allocate limiting
|
|
||||||
# kv-cache space, limiting it to 64k tokens,
|
|
||||||
# which would result in the workspace being:
|
|
||||||
# 2*(576)*(64*1024) = 144mb
|
|
||||||
# (assuming 576 MLA head dim, and fp16)
|
|
||||||
# which would result in up-projected context being
|
|
||||||
# 2*(192*128)*(64*1024) = 3gb
|
|
||||||
# (assuming 192 QK head dim, 128 heads, and fp16)
|
|
||||||
128 * 1024)
|
|
||||||
assert self.chunked_prefill_workspace_size >= \
|
|
||||||
scheduler_config.max_num_seqs * self.block_size
|
|
||||||
self.chunked_prefill_workspace = torch.empty(
|
|
||||||
(self.chunked_prefill_workspace_size,
|
|
||||||
self.model_config.get_head_size()),
|
|
||||||
dtype=self.model_config.dtype,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
||||||
self.cos_cache = None
|
self.cos_cache = None
|
||||||
self.sin_cache = None
|
self.sin_cache = None
|
||||||
@@ -280,6 +256,11 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
|||||||
self.seq_lens: torch.Tensor = None
|
self.seq_lens: torch.Tensor = None
|
||||||
self.attn_mask_builder = AttentionMaskBuilder(self.device)
|
self.attn_mask_builder = AttentionMaskBuilder(self.device)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def determine_chunked_prefill_workspace_size(
|
||||||
|
vllm_config: VllmConfig) -> int:
|
||||||
|
return ascend_chunked_prefill_workspace_size(vllm_config)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_cudagraph_support(
|
def get_cudagraph_support(
|
||||||
cls: type["AscendMLAMetadataBuilder"],
|
cls: type["AscendMLAMetadataBuilder"],
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
|||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE, MLAPO_MAX_SUPPORTED_TOKENS
|
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE, MLAPO_MAX_SUPPORTED_TOKENS
|
||||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||||
|
ascend_chunked_prefill_workspace_size,
|
||||||
maybe_save_kv_layer_to_connector,
|
maybe_save_kv_layer_to_connector,
|
||||||
trans_rope_weight, transdata,
|
trans_rope_weight, transdata,
|
||||||
wait_for_kv_layer_from_connector)
|
wait_for_kv_layer_from_connector)
|
||||||
@@ -131,7 +132,6 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
|||||||
understand this class
|
understand this class
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# _attn_mask_builder = None
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
kv_cache_spec,
|
kv_cache_spec,
|
||||||
@@ -141,11 +141,11 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
|||||||
metadata_cls: type[AscendSFAMetadata] | None = None,
|
metadata_cls: type[AscendSFAMetadata] | None = None,
|
||||||
supports_dcp_with_varlen: bool = False,
|
supports_dcp_with_varlen: bool = False,
|
||||||
):
|
):
|
||||||
self.metadata_cls = (metadata_cls if metadata_cls is not None else
|
super().__init__(
|
||||||
AscendSFAMetadata)
|
kv_cache_spec, layer_names, vllm_config, device,
|
||||||
self.vllm_config = vllm_config
|
metadata_cls if metadata_cls is not None else AscendSFAMetadata,
|
||||||
self.model_config = vllm_config.model_config
|
supports_dcp_with_varlen)
|
||||||
self.device = device
|
|
||||||
self.block_size = vllm_config.cache_config.block_size
|
self.block_size = vllm_config.cache_config.block_size
|
||||||
self.max_blocks = (vllm_config.model_config.max_model_len +
|
self.max_blocks = (vllm_config.model_config.max_model_len +
|
||||||
self.block_size - 1) // self.block_size
|
self.block_size - 1) // self.block_size
|
||||||
@@ -169,6 +169,11 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
|||||||
), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
|
), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
|
||||||
self.attn_mask_builder = AttentionMaskBuilder(self.device)
|
self.attn_mask_builder = AttentionMaskBuilder(self.device)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def determine_chunked_prefill_workspace_size(
|
||||||
|
vllm_config: VllmConfig) -> int:
|
||||||
|
return ascend_chunked_prefill_workspace_size(vllm_config)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_cudagraph_support(
|
def get_cudagraph_support(
|
||||||
cls: type["AscendSFAMetadataBuilder"],
|
cls: type["AscendSFAMetadataBuilder"],
|
||||||
|
|||||||
@@ -12,6 +12,34 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
|||||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_config, get_ascend_device_type
|
from vllm_ascend.utils import AscendDeviceType, get_ascend_config, get_ascend_device_type
|
||||||
|
|
||||||
|
|
||||||
|
def ascend_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
|
||||||
|
scheduler_config = vllm_config.scheduler_config
|
||||||
|
cache_config = vllm_config.cache_config
|
||||||
|
model_config = vllm_config.model_config
|
||||||
|
|
||||||
|
chunked_prefill_workspace_size = min(
|
||||||
|
# Make sure there is enough for 8 full length request or at least
|
||||||
|
# 4 pages of cache per request
|
||||||
|
max(8 * model_config.max_model_len, 4 * scheduler_config.max_num_seqs * cache_config.block_size),
|
||||||
|
# For long-context models try not to over-allocate limiting
|
||||||
|
# kv-cache space, limiting it to 128k tokens,
|
||||||
|
# which would result in the workspace being:
|
||||||
|
# 2*(576)*(128*1024) = 288mb
|
||||||
|
# (assuming 576 MLA head dim, and fp16)
|
||||||
|
# which would result in up-projected context being
|
||||||
|
# 2*(192*128)*(128*1024) = 6gb
|
||||||
|
# (assuming 192 QK head dim, 128 heads, and fp16)
|
||||||
|
128 * 1024,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunked_prefill_workspace_size = max(
|
||||||
|
chunked_prefill_workspace_size,
|
||||||
|
scheduler_config.max_num_seqs * cache_config.block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
return chunked_prefill_workspace_size
|
||||||
|
|
||||||
|
|
||||||
def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
|
def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
|
||||||
if vllm_config.speculative_config is not None:
|
if vllm_config.speculative_config is not None:
|
||||||
return False
|
return False
|
||||||
|
|||||||
Reference in New Issue
Block a user