[Bugfix] Use hf_text_config instead of hf_config to support multimodal PD-Disaggregated (#5205)
### What this PR does / why we need it?
In code files such as `mooncake_connector.py`,
`vllm_config.model_config.hf_config` is used to get the LLM configs.
This approach works for LLMs, but not for multi-modal models. For
multi-modal models, `vllm_config.model_config.hf_text_config` must be
used instead to get the LLM configs.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing unit tests (UT).
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: ApsarasX <apsarax@outlook.com>
This commit is contained in:
@@ -1195,7 +1195,7 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
|||||||
"prefill": {"tp_size": prefill_tp_size, "dp_size": 1, "pp_size": prefill_pp_size},
|
"prefill": {"tp_size": prefill_tp_size, "dp_size": 1, "pp_size": prefill_pp_size},
|
||||||
"decode": {"tp_size": decode_tp_size, "dp_size": 1, "pp_size": 1}
|
"decode": {"tp_size": decode_tp_size, "dp_size": 1, "pp_size": 1}
|
||||||
}.get(k, d)):
|
}.get(k, d)):
|
||||||
self.vllm_config.model_config.hf_config.num_key_value_heads = num_kv_heads
|
self.vllm_config.model_config.hf_text_config.num_key_value_heads = num_kv_heads
|
||||||
self.vllm_config.model_config.is_deepseek_mla = is_deepseek_mla
|
self.vllm_config.model_config.is_deepseek_mla = is_deepseek_mla
|
||||||
worker = MooncakeConnectorWorker(self.vllm_config,
|
worker = MooncakeConnectorWorker(self.vllm_config,
|
||||||
self.engine_id)
|
self.engine_id)
|
||||||
|
|||||||
@@ -319,14 +319,14 @@ class KVCacheRecvingThread(threading.Thread):
|
|||||||
self.block_size = self.vllm_config.cache_config.block_size
|
self.block_size = self.vllm_config.cache_config.block_size
|
||||||
if not is_vl_model(vllm_config):
|
if not is_vl_model(vllm_config):
|
||||||
if self.use_mla:
|
if self.use_mla:
|
||||||
self.k_head_dim = self.model_config.hf_config.kv_lora_rank
|
self.k_head_dim = self.model_config.hf_text_config.kv_lora_rank
|
||||||
self.v_head_dim = self.model_config.hf_config.qk_rope_head_dim
|
self.v_head_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
||||||
self.num_kv_heads = 1
|
self.num_kv_heads = 1
|
||||||
else:
|
else:
|
||||||
self.k_head_dim = self.model_config.hf_config.head_dim
|
self.k_head_dim = self.model_config.hf_text_config.head_dim
|
||||||
self.v_head_dim = self.model_config.hf_config.head_dim
|
self.v_head_dim = self.model_config.hf_text_config.head_dim
|
||||||
self.num_kv_heads = max(
|
self.num_kv_heads = max(
|
||||||
self.model_config.hf_config.num_key_value_heads //
|
self.model_config.hf_text_config.num_key_value_heads //
|
||||||
self.tp_size, 1)
|
self.tp_size, 1)
|
||||||
|
|
||||||
def add_request(self, request_id: str, local_block_ids: list[int],
|
def add_request(self, request_id: str, local_block_ids: list[int],
|
||||||
@@ -438,7 +438,7 @@ class KVCacheRecvingThread(threading.Thread):
|
|||||||
|
|
||||||
remote_kv_caches_base_addrs = \
|
remote_kv_caches_base_addrs = \
|
||||||
self.kv_caches_base_addr[remote_engine_id][remote_handshake_port]
|
self.kv_caches_base_addr[remote_engine_id][remote_handshake_port]
|
||||||
num_layers = self.model_config.hf_config.num_hidden_layers
|
num_layers = self.model_config.hf_text_config.num_hidden_layers
|
||||||
first_layer_index, end_layer_index = get_pp_indices(
|
first_layer_index, end_layer_index = get_pp_indices(
|
||||||
num_layers, prefill_pp_rank, self._prefill_pp_size)
|
num_layers, prefill_pp_rank, self._prefill_pp_size)
|
||||||
num_cache_per_layer = len(list(
|
num_cache_per_layer = len(list(
|
||||||
@@ -503,10 +503,11 @@ class KVCacheRecvingThread(threading.Thread):
|
|||||||
k_cache = list(self.kv_caches.values())[0][0]
|
k_cache = list(self.kv_caches.values())[0][0]
|
||||||
dtype = k_cache.dtype
|
dtype = k_cache.dtype
|
||||||
device = k_cache.device
|
device = k_cache.device
|
||||||
head_dim = self.model_config.hf_config.head_dim
|
head_dim = self.model_config.hf_text_config.head_dim
|
||||||
block_size = self.vllm_config.cache_config.block_size
|
block_size = self.vllm_config.cache_config.block_size
|
||||||
num_kv_head = max(
|
num_kv_head = max(
|
||||||
self.model_config.hf_config.num_key_value_heads // self.tp_size, 1)
|
self.model_config.hf_text_config.num_key_value_heads //
|
||||||
|
self.tp_size, 1)
|
||||||
|
|
||||||
flat_block_ids = [item for sublist in block_ids for item in sublist]
|
flat_block_ids = [item for sublist in block_ids for item in sublist]
|
||||||
block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int32)
|
block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int32)
|
||||||
@@ -1010,7 +1011,7 @@ class MooncakeConnectorWorker:
|
|||||||
|
|
||||||
self.max_device_id = self.tp_size * self.dp_size * self.pcp_size * self.pp_size
|
self.max_device_id = self.tp_size * self.dp_size * self.pcp_size * self.pp_size
|
||||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||||
self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
|
self.num_key_value_heads = self.vllm_config.model_config.hf_text_config.num_key_value_heads
|
||||||
|
|
||||||
# Handshake base port
|
# Handshake base port
|
||||||
self.side_channel_port = (
|
self.side_channel_port = (
|
||||||
|
|||||||
Reference in New Issue
Block a user