[Refactor] AttentionBuilder inherit from base class in vllm (#5916)

### What this PR does / why we need it?

This PR makes `AscendMLAMetadataBuilder` and `AscendSFAMetadataBuilder`
properly inherit from the base class `MLACommonMetadataBuilder` in vllm
by adding `super().__init__()` calls.

**Changes:**
- Add `super().__init__()` call in `AscendMLAMetadataBuilder.__init__()`
- Add `super().__init__()` call in `AscendSFAMetadataBuilder.__init__()`
- Extract `ascend_chunked_prefill_workspace_size()` to
`vllm_ascend/attention/utils.py` to avoid code duplication
- Override `determine_chunked_prefill_workspace_size()` to support
Ascend-specific 128k tokens workspace size (vs 64k in parent class)
- Update unit tests to mock parent class `__init__` for proper isolation

**Why we need it:**
- Follow proper Python inheritance patterns by calling
`super().__init__()`
- Reduce code duplication by reusing parent class initialization logic
- Better maintainability as parent class changes will be automatically
inherited

Part of issue #5463 item 10

### Does this PR introduce _any_ user-facing change?

No, this is an internal refactoring that does not change any user-facing
behavior.

Signed-off-by: lico67373 <918688502@qq.com>
This commit is contained in:
LICO67373
2026-01-21 10:45:45 +08:00
committed by GitHub
parent 839e03cbc9
commit 12a668b1d9
5 changed files with 158 additions and 40 deletions

View File

@@ -208,11 +208,38 @@ class TestAscendMLAMetadata(TestBase):
class TestAscendMLAMetadataBuilder(TestBase): class TestAscendMLAMetadataBuilder(TestBase):
def setUp(self):
# Mock parent class __init__ to avoid complex initialization,
# but still set the essential attributes that child class needs
def mock_parent_init(self, kv_cache_spec, layer_names, vllm_config,
device, metadata_cls, supports_dcp_with_varlen):
self.metadata_cls = metadata_cls
self.kv_cache_spec = kv_cache_spec
self.model_config = vllm_config.model_config
self.vllm_config = vllm_config
self.device = device
self.chunked_prefill_workspace_size = 128 * 1024
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
vllm_config.model_config.get_head_size()),
dtype=vllm_config.model_config.dtype,
device=device,
)
self.parent_init_patcher = patch(
"vllm.v1.attention.backends.mla.common.MLACommonMetadataBuilder.__init__",
mock_parent_init)
self.parent_init_patcher.start()
def tearDown(self):
self.parent_init_patcher.stop()
def test_ascend_mla_metadata_builder_default(self): def test_ascend_mla_metadata_builder_default(self):
mock_vllm_config = MagicMock() mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024 mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64 mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16 mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
mock_vllm_config.cache_config.block_size = 16 mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4 mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4 mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
@@ -238,6 +265,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
mock_vllm_config.model_config.max_model_len = 1024 mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64 mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16 mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
mock_vllm_config.cache_config.block_size = 16 mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4 mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4 mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
@@ -274,10 +302,12 @@ class TestAscendMLAMetadataBuilder(TestBase):
mock_vllm_config.model_config.max_model_len = 1024 mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64 mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16 mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
mock_vllm_config.cache_config.block_size = 16 mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4 mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4 mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
mock_device = 'cpu' mock_device = 'cpu'
torch.Tensor.pin_memory = lambda x: x # noqa torch.Tensor.pin_memory = lambda x: x # noqa
@@ -314,6 +344,9 @@ class TestAscendMLAMetadataBuilder(TestBase):
mock_vllm_config = MagicMock() mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024 mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
mock_vllm_config.cache_config.block_size = 16 mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4 mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4 mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
@@ -352,10 +385,12 @@ class TestAscendMLAMetadataBuilder(TestBase):
mock_vllm_config.model_config.max_model_len = 1024 mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64 mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16 mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
mock_vllm_config.cache_config.block_size = 16 mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4 mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4 mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
mock_device = 'cpu' mock_device = 'cpu'
mock_vllm_config.speculative_config = None mock_vllm_config.speculative_config = None
@@ -374,10 +409,12 @@ class TestAscendMLAMetadataBuilder(TestBase):
mock_vllm_config.model_config.max_model_len = 1024 mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64 mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16 mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
mock_vllm_config.cache_config.block_size = 16 mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4 mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4 mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
mock_device = 'cpu' mock_device = 'cpu'
mock_vllm_config.speculative_config = None mock_vllm_config.speculative_config = None
@@ -398,11 +435,34 @@ class TestAscendMLAMetadataBuilder(TestBase):
class TestAscendMLAMetadataBuilderBuild(TestBase): class TestAscendMLAMetadataBuilderBuild(TestBase):
def setUp(self): def setUp(self):
# Mock parent class __init__ to avoid complex initialization,
# but still set the essential attributes that child class needs
def mock_parent_init(self, kv_cache_spec, layer_names, vllm_config,
device, metadata_cls, supports_dcp_with_varlen):
self.metadata_cls = metadata_cls
self.kv_cache_spec = kv_cache_spec
self.model_config = vllm_config.model_config
self.vllm_config = vllm_config
self.device = device
self.chunked_prefill_workspace_size = 128 * 1024
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
vllm_config.model_config.get_head_size()),
dtype=vllm_config.model_config.dtype,
device=device,
)
self.parent_init_patcher = patch(
"vllm.v1.attention.backends.mla.common.MLACommonMetadataBuilder.__init__",
mock_parent_init)
self.parent_init_patcher.start()
self.mock_vllm_config = MagicMock(spec=VllmConfig) self.mock_vllm_config = MagicMock(spec=VllmConfig)
self.mock_vllm_config.cache_config = CacheConfig(block_size=32) self.mock_vllm_config.cache_config = CacheConfig(block_size=32)
mock_scheduler_config = MagicMock(spec=SchedulerConfig) mock_scheduler_config = MagicMock(spec=SchedulerConfig)
mock_scheduler_config.max_num_seqs = 8 mock_scheduler_config.max_num_seqs = 8
mock_scheduler_config.chunked_prefill_enabled = True mock_scheduler_config.chunked_prefill_enabled = True
mock_scheduler_config.enable_chunked_prefill = True
self.mock_vllm_config.scheduler_config = mock_scheduler_config self.mock_vllm_config.scheduler_config = mock_scheduler_config
self.mock_vllm_config.speculative_config = None self.mock_vllm_config.speculative_config = None
self.mock_device = torch.device("cpu") self.mock_device = torch.device("cpu")
@@ -423,6 +483,9 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
self.kv_cache_spec.head_size = 64 self.kv_cache_spec.head_size = 64
self.kv_cache_spec.num_heads = 32 self.kv_cache_spec.num_heads = 32
def tearDown(self):
self.parent_init_patcher.stop()
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla") @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm_ascend.attention.attention_mask.get_pcp_group') @patch('vllm_ascend.attention.attention_mask.get_pcp_group')
@patch('vllm.distributed.parallel_state.get_pcp_group') @patch('vllm.distributed.parallel_state.get_pcp_group')

View File

@@ -99,13 +99,44 @@ class TestAscendSFAMetadataBuilder(TestBase):
return_value=self.mock_cfg) return_value=self.mock_cfg)
self.patcher.start() self.patcher.start()
# Mock parent class __init__ to avoid complex initialization,
# but still set the essential attributes that child class needs
def mock_parent_init(self, kv_cache_spec, layer_names, vllm_config,
device, metadata_cls, supports_dcp_with_varlen):
self.metadata_cls = metadata_cls
self.kv_cache_spec = kv_cache_spec
self.model_config = vllm_config.model_config
self.vllm_config = vllm_config
self.device = device
self.chunked_prefill_workspace_size = 128 * 1024
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
vllm_config.model_config.get_head_size()),
dtype=vllm_config.model_config.dtype,
device=device,
)
self.parent_init_patcher = patch(
"vllm.v1.attention.backends.mla.common.MLACommonMetadataBuilder.__init__",
mock_parent_init)
self.parent_init_patcher.start()
if hasattr(enable_dsa_cp, "cache_clear"): if hasattr(enable_dsa_cp, "cache_clear"):
enable_dsa_cp.cache_clear() enable_dsa_cp.cache_clear()
def tearDown(self):
self.patcher.stop()
self.parent_init_patcher.stop()
def test_ascend_sfa_metadata_builder_default(self): def test_ascend_sfa_metadata_builder_default(self):
kv_cache_spec = MagicMock() kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"] layer_names = ["layer1", "layer2"]
vllm_config = MagicMock() vllm_config = MagicMock()
vllm_config.cache_config.block_size = 16
vllm_config.model_config.max_model_len = 1024
vllm_config.model_config.get_head_size.return_value = 64
vllm_config.model_config.dtype = torch.float16
vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
speculative_config = MagicMock() speculative_config = MagicMock()
speculative_config.num_speculative_tokens = 4 speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config vllm_config.speculative_config = speculative_config
@@ -138,6 +169,11 @@ class TestAscendSFAMetadataBuilder(TestBase):
kv_cache_spec = MagicMock() kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"] layer_names = ["layer1", "layer2"]
vllm_config = MagicMock() vllm_config = MagicMock()
vllm_config.cache_config.block_size = 16
vllm_config.model_config.max_model_len = 1024
vllm_config.model_config.get_head_size.return_value = 64
vllm_config.model_config.dtype = torch.float16
vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
speculative_config = MagicMock() speculative_config = MagicMock()
speculative_config.num_speculative_tokens = 4 speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config vllm_config.speculative_config = speculative_config
@@ -190,6 +226,11 @@ class TestAscendSFAMetadataBuilder(TestBase):
kv_cache_spec = MagicMock() kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"] layer_names = ["layer1", "layer2"]
vllm_config = MagicMock() vllm_config = MagicMock()
vllm_config.cache_config.block_size = 16
vllm_config.model_config.max_model_len = 1024
vllm_config.model_config.get_head_size.return_value = 64
vllm_config.model_config.dtype = torch.float16
vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
speculative_config = MagicMock() speculative_config = MagicMock()
speculative_config.num_speculative_tokens = 4 speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config vllm_config.speculative_config = speculative_config

View File

@@ -19,12 +19,10 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.context_parallel.common_cp import ( from vllm_ascend.attention.context_parallel.common_cp import (
AscendPCPMetadata, CPChunkedContextMetadata) AscendPCPMetadata, CPChunkedContextMetadata)
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, from vllm_ascend.attention.utils import (
enable_cp, AscendCommonAttentionMetadata, ascend_chunked_prefill_workspace_size,
maybe_save_kv_layer_to_connector, enable_cp, maybe_save_kv_layer_to_connector, split_decodes_and_prefills,
split_decodes_and_prefills, trans_rope_weight, transdata, wait_for_kv_layer_from_connector)
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
from vllm_ascend.compilation.acl_graph import ( from vllm_ascend.compilation.acl_graph import (
get_draft_graph_params, get_graph_params, get_draft_graph_params, get_graph_params,
update_draft_graph_params_workspaces, update_graph_params_workspaces) update_draft_graph_params_workspaces, update_graph_params_workspaces)
@@ -215,11 +213,11 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
metadata_cls: type[AscendMLAMetadata] | None = None, metadata_cls: type[AscendMLAMetadata] | None = None,
supports_dcp_with_varlen: bool = False, supports_dcp_with_varlen: bool = False,
): ):
self.metadata_cls = (metadata_cls if metadata_cls is not None else super().__init__(
AscendMLAMetadata) kv_cache_spec, layer_names, vllm_config, device,
self.vllm_config = vllm_config metadata_cls if metadata_cls is not None else AscendMLAMetadata,
self.model_config = vllm_config.model_config supports_dcp_with_varlen)
self.device = device
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
self.block_size = vllm_config.cache_config.block_size self.block_size = vllm_config.cache_config.block_size
self.max_blocks = (vllm_config.model_config.max_model_len + self.max_blocks = (vllm_config.model_config.max_model_len +
@@ -236,29 +234,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
got {self.decode_threshold}" got {self.decode_threshold}"
self.reorder_batch_threshold = self.decode_threshold self.reorder_batch_threshold = self.decode_threshold
if self.chunked_prefill_enabled:
self.chunked_prefill_workspace_size = min(
# Max sure there is enough for 8 full length request or at least
# 4 pages of cache per request
max(8 * self.model_config.max_model_len,
4 * scheduler_config.max_num_seqs * self.block_size),
# For long-context models try not to over-allocate limiting
# kv-cache space, limiting it to 64k tokens,
# which would result in the workspace being:
# 2*(576)*(64*1024) = 144mb
# (assuming 576 MLA head dim, and fp16)
# which would result in up-projected context being
# 2*(192*128)*(64*1024) = 3gb
# (assuming 192 QK head dim, 128 heads, and fp16)
128 * 1024)
assert self.chunked_prefill_workspace_size >= \
scheduler_config.max_num_seqs * self.block_size
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
self.model_config.get_head_size()),
dtype=self.model_config.dtype,
device=device,
)
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
self.cos_cache = None self.cos_cache = None
self.sin_cache = None self.sin_cache = None
@@ -280,6 +256,11 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
self.seq_lens: torch.Tensor = None self.seq_lens: torch.Tensor = None
self.attn_mask_builder = AttentionMaskBuilder(self.device) self.attn_mask_builder = AttentionMaskBuilder(self.device)
@staticmethod
def determine_chunked_prefill_workspace_size(
vllm_config: VllmConfig) -> int:
return ascend_chunked_prefill_workspace_size(vllm_config)
@classmethod @classmethod
def get_cudagraph_support( def get_cudagraph_support(
cls: type["AscendMLAMetadataBuilder"], cls: type["AscendMLAMetadataBuilder"],

View File

@@ -20,6 +20,7 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE, MLAPO_MAX_SUPPORTED_TOKENS from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE, MLAPO_MAX_SUPPORTED_TOKENS
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
ascend_chunked_prefill_workspace_size,
maybe_save_kv_layer_to_connector, maybe_save_kv_layer_to_connector,
trans_rope_weight, transdata, trans_rope_weight, transdata,
wait_for_kv_layer_from_connector) wait_for_kv_layer_from_connector)
@@ -131,7 +132,6 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
understand this class understand this class
""" """
# _attn_mask_builder = None
def __init__( def __init__(
self, self,
kv_cache_spec, kv_cache_spec,
@@ -141,11 +141,11 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
metadata_cls: type[AscendSFAMetadata] | None = None, metadata_cls: type[AscendSFAMetadata] | None = None,
supports_dcp_with_varlen: bool = False, supports_dcp_with_varlen: bool = False,
): ):
self.metadata_cls = (metadata_cls if metadata_cls is not None else super().__init__(
AscendSFAMetadata) kv_cache_spec, layer_names, vllm_config, device,
self.vllm_config = vllm_config metadata_cls if metadata_cls is not None else AscendSFAMetadata,
self.model_config = vllm_config.model_config supports_dcp_with_varlen)
self.device = device
self.block_size = vllm_config.cache_config.block_size self.block_size = vllm_config.cache_config.block_size
self.max_blocks = (vllm_config.model_config.max_model_len + self.max_blocks = (vllm_config.model_config.max_model_len +
self.block_size - 1) // self.block_size self.block_size - 1) // self.block_size
@@ -169,6 +169,11 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1." ), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
self.attn_mask_builder = AttentionMaskBuilder(self.device) self.attn_mask_builder = AttentionMaskBuilder(self.device)
@staticmethod
def determine_chunked_prefill_workspace_size(
vllm_config: VllmConfig) -> int:
return ascend_chunked_prefill_workspace_size(vllm_config)
@classmethod @classmethod
def get_cudagraph_support( def get_cudagraph_support(
cls: type["AscendSFAMetadataBuilder"], cls: type["AscendSFAMetadataBuilder"],

View File

@@ -12,6 +12,34 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm_ascend.utils import AscendDeviceType, get_ascend_config, get_ascend_device_type from vllm_ascend.utils import AscendDeviceType, get_ascend_config, get_ascend_device_type
def ascend_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
scheduler_config = vllm_config.scheduler_config
cache_config = vllm_config.cache_config
model_config = vllm_config.model_config
chunked_prefill_workspace_size = min(
# Make sure there is enough for 8 full length request or at least
# 4 pages of cache per request
max(8 * model_config.max_model_len, 4 * scheduler_config.max_num_seqs * cache_config.block_size),
# For long-context models try not to over-allocate limiting
# kv-cache space, limiting it to 128k tokens,
# which would result in the workspace being:
# 2*(576)*(128*1024) = 288mb
# (assuming 576 MLA head dim, and fp16)
# which would result in up-projected context being
# 2*(192*128)*(128*1024) = 6gb
# (assuming 192 QK head dim, 128 heads, and fp16)
128 * 1024,
)
chunked_prefill_workspace_size = max(
chunked_prefill_workspace_size,
scheduler_config.max_num_seqs * cache_config.block_size,
)
return chunked_prefill_workspace_size
def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool: def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
if vllm_config.speculative_config is not None: if vllm_config.speculative_config is not None:
return False return False