[Disaggregated Prefill] P2P Disaggregated Prefill based on llm_datadist (#694)
### What this PR does / why we need it? - This PR proposes a P2P version of Disaggregated Prefill based on llm_datadist which manages data transfer. - This solution reconstructs previous offline single-node Disaggregated Prefill solution, and supports multi-node and online serveing now. - Currently this solution supports 1P1D situation of Deepseek hybrid parallelism (P: TP+EP, D: DP+EP). Note that xPyD situation is considered in the solution design, and will be supported soon within v1 engine. --------- Signed-off-by: hw_whx <wanghexiang7@huawei.com> Signed-off-by: ganyi <pleaplusone.gy@gmail.com> Co-authored-by: hw_whx <wanghexiang7@huawei.com> Co-authored-by: ganyi <pleaplusone.gy@gmail.com>
This commit is contained in:
@@ -1,6 +1,27 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import \
|
||||
KVConnectorFactory
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"AscendHcclConnector", "vllm_ascend.distributed.llmdatadist_connector",
|
||||
"LLMDataDistConnector")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"AscendSimpleConnector",
|
||||
"vllm_ascend.distributed.kv_transfer.simple_connector", "SimpleConnector")
|
||||
|
||||
0
vllm_ascend/distributed/kv_transfer/__init__.py
Normal file
0
vllm_ascend/distributed/kv_transfer/__init__.py
Normal file
209
vllm_ascend/distributed/kv_transfer/simple_buffer.py
Normal file
209
vllm_ascend/distributed/kv_transfer/simple_buffer.py
Normal file
@@ -0,0 +1,209 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import zlib
|
||||
from typing import List, Optional
|
||||
|
||||
import llm_datadist # type: ignore
|
||||
import torch
|
||||
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import \
|
||||
KVLookupBufferBase
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
|
||||
from vllm_ascend.distributed.kv_transfer.utils import TORCH_DTYPE_TO_NPU_DTYPE
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# Hash a string into a int32 value.
|
||||
def int32_hash(data):
|
||||
assert isinstance(data, str)
|
||||
data = data.encode("utf-8")
|
||||
return zlib.adler32(data)
|
||||
|
||||
|
||||
class SimpleBuffer(KVLookupBufferBase):
|
||||
|
||||
def __init__(self, data_pipe: SimplePipe):
|
||||
self.data_pipe = data_pipe
|
||||
# Consumer buffer need these information to construct receiving buffer.
|
||||
self.num_layers = None
|
||||
self.num_heads = None
|
||||
self.head_size = None
|
||||
self.dtype = None
|
||||
self.hidden_size = None
|
||||
self.key_buffer = None
|
||||
self.value_buffer = None
|
||||
self.hidden_buffer = None
|
||||
|
||||
def insert(
|
||||
self,
|
||||
input_tokens: torch.Tensor,
|
||||
roi: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
hidden: torch.Tensor,
|
||||
req_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
seq_len: num_tokens of current request.
|
||||
input_tokens: [seq_len]
|
||||
roi: [seq_len]
|
||||
key: [num_layers, seq_len, num_kv_heads, head_size]
|
||||
value: [num_layers, seq_len, num_kv_heads, head_size]
|
||||
hidden: [seq_len, hidden_size]
|
||||
"""
|
||||
orig_k_shape = key.shape
|
||||
num_layers = orig_k_shape[0]
|
||||
|
||||
# unsequeeze all tensors to make first dim to 1.
|
||||
# This is because D node can only pull one batch data from P.
|
||||
# So we make first dim to 1 here in order to pull full data.
|
||||
key = key.view(num_layers, -1).unsqueeze(0)
|
||||
value = value.view(num_layers, -1).unsqueeze(0)
|
||||
hidden = hidden.unsqueeze(0)
|
||||
|
||||
hidden_dtype = key.dtype
|
||||
# initialize LLMDatadist data structure
|
||||
key_desc = llm_datadist.CacheDesc(
|
||||
1,
|
||||
key.shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[hidden_dtype],
|
||||
seq_len_dim_index=1,
|
||||
)
|
||||
value_desc = llm_datadist.CacheDesc(
|
||||
1,
|
||||
value.shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[hidden_dtype],
|
||||
seq_len_dim_index=1,
|
||||
)
|
||||
hidden_desc = llm_datadist.CacheDesc(
|
||||
1,
|
||||
hidden.shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[hidden_dtype],
|
||||
seq_len_dim_index=-1,
|
||||
)
|
||||
|
||||
req_id = int32_hash(req_id)
|
||||
key_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
|
||||
req_id, 1)
|
||||
value_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
|
||||
req_id, 2)
|
||||
hidden_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
|
||||
req_id, 3)
|
||||
|
||||
# Currently we use hash value of request id as key, so no need to send input_tokens
|
||||
self.key_buffer = self.data_pipe.send_tensor(key, key_desc,
|
||||
key_cache_key)
|
||||
self.value_buffer = self.data_pipe.send_tensor(value, value_desc,
|
||||
value_cache_key)
|
||||
self.hidden_buffer = self.data_pipe.send_tensor(
|
||||
hidden, hidden_desc, hidden_cache_key)
|
||||
|
||||
def drop_select(
|
||||
self,
|
||||
input_tokens: torch.Tensor,
|
||||
roi: Optional[torch.Tensor],
|
||||
req_id: str,
|
||||
) -> List[Optional[torch.Tensor]]:
|
||||
"""Select and *drop* KV cache entries from the lookup buffer.
|
||||
|
||||
The functionality is similar to the following python statements
|
||||
```
|
||||
ret = buffer.pop(input_tokens, roi)
|
||||
return ret
|
||||
```
|
||||
|
||||
Args:
|
||||
input_tokens (torch.Tensor): token IDs.
|
||||
roi (torch.Tensor): A binary mask on top of the input tokens
|
||||
|
||||
Returns:
|
||||
A list of tensors including:
|
||||
key: [num_layers, num_tokens, num_heads, head_size]
|
||||
value: [num_layers, num_tokens, num_heads, head_size]
|
||||
hidden_or_intermediate_states: [num_tokens, hidden_size]
|
||||
roi: None (Currently we don't supported roi)
|
||||
"""
|
||||
orig_req_id = req_id
|
||||
req_id = int32_hash(req_id)
|
||||
num_tokens = input_tokens.shape[0]
|
||||
kv_shape = (
|
||||
1,
|
||||
self.num_layers,
|
||||
num_tokens * self.num_heads * self.head_size,
|
||||
)
|
||||
hidden_shape = (1, num_tokens, self.hidden_size)
|
||||
key_desc = llm_datadist.CacheDesc(
|
||||
1,
|
||||
kv_shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[self.dtype],
|
||||
seq_len_dim_index=-1,
|
||||
)
|
||||
value_desc = llm_datadist.CacheDesc(
|
||||
1,
|
||||
kv_shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[self.dtype],
|
||||
seq_len_dim_index=-1,
|
||||
)
|
||||
hidden_desc = llm_datadist.CacheDesc(
|
||||
1,
|
||||
hidden_shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[self.dtype],
|
||||
seq_len_dim_index=-1,
|
||||
)
|
||||
|
||||
key_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
|
||||
req_id, 1)
|
||||
value_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
|
||||
req_id, 2)
|
||||
hidden_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
|
||||
req_id, 3)
|
||||
|
||||
# Deallocate buffer allocated in last round.
|
||||
if self.key_buffer:
|
||||
try:
|
||||
self.data_pipe.deallocate_buffer(self.key_buffer)
|
||||
self.data_pipe.deallocate_buffer(self.value_buffer)
|
||||
self.data_pipe.deallocate_buffer(self.hidden_buffer)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to free kv cache buffer, Error code: {str(e)}")
|
||||
|
||||
try:
|
||||
self.key_buffer, key = self.data_pipe.recv_tensor(
|
||||
key_desc, key_cache_key)
|
||||
self.value_buffer, value = self.data_pipe.recv_tensor(
|
||||
value_desc, value_cache_key)
|
||||
self.hidden_buffer, hidden = self.data_pipe.recv_tensor(
|
||||
hidden_desc, hidden_cache_key)
|
||||
key = key.view(self.num_layers, num_tokens, self.num_heads,
|
||||
self.head_size)
|
||||
value = value.view(self.num_layers, num_tokens, self.num_heads,
|
||||
self.head_size)
|
||||
hidden = hidden.view(num_tokens, self.hidden_size)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Faile to receive kv cache and hidden states of request: {orig_req_id} "
|
||||
f"Error is {str(e)}")
|
||||
return [None, None, None, None]
|
||||
|
||||
return [key, value, hidden, roi]
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
376
vllm_ascend/distributed/kv_transfer/simple_connector.py
Normal file
376
vllm_ascend/distributed/kv_transfer/simple_connector.py
Normal file
@@ -0,0 +1,376 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
import vllm.envs as vllm_envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.logger import logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.distributed.kv_transfer.simple_buffer import SimpleBuffer
|
||||
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
|
||||
|
||||
class SimpleConnector(KVConnectorBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
local_rank: int,
|
||||
config: VllmConfig,
|
||||
):
|
||||
self.config = config
|
||||
self.model_config = config.model_config.hf_config
|
||||
self.tp_size = config.parallel_config.tensor_parallel_size
|
||||
self.rank = rank
|
||||
self.local_rank = local_rank
|
||||
self.is_deepseek_mla = config.model_config.is_deepseek_mla
|
||||
self.use_mla_opt = not vllm_envs.VLLM_MLA_DISABLE
|
||||
self.n_layer = self.config.model_config.get_num_layers(
|
||||
self.config.parallel_config)
|
||||
|
||||
self.producer_data_pipe: Optional[SimplePipe]
|
||||
self.consumer_data_pipe: Optional[SimplePipe]
|
||||
|
||||
self.producer_buffer: Optional[SimpleBuffer]
|
||||
self.consumer_buffer: Optional[SimpleBuffer]
|
||||
|
||||
if self.config.kv_transfer_config.is_kv_producer:
|
||||
self.producer_data_pipe = SimplePipe(
|
||||
rank=rank,
|
||||
local_rank=local_rank,
|
||||
kv_transfer_config=config.kv_transfer_config,
|
||||
hostname="",
|
||||
port_offset=rank,
|
||||
)
|
||||
self.producer_buffer = SimpleBuffer(self.producer_data_pipe)
|
||||
else:
|
||||
self.consumer_data_pipe = SimplePipe(
|
||||
rank=rank,
|
||||
local_rank=local_rank,
|
||||
kv_transfer_config=config.kv_transfer_config,
|
||||
hostname="",
|
||||
port_offset=rank,
|
||||
)
|
||||
self.consumer_buffer = SimpleBuffer(self.consumer_data_pipe)
|
||||
|
||||
def select(
|
||||
self,
|
||||
input_tokens: Optional[torch.Tensor],
|
||||
roi: Optional[torch.Tensor],
|
||||
req_id: str,
|
||||
) -> List[Optional[torch.Tensor]]:
|
||||
|
||||
assert self.consumer_buffer is not None, (
|
||||
"Please initialize the "
|
||||
"consumer buffer before calling select.")
|
||||
return self.consumer_buffer.drop_select(input_tokens, roi, req_id)
|
||||
|
||||
def insert(
|
||||
self,
|
||||
input_tokens: torch.Tensor,
|
||||
roi: torch.Tensor,
|
||||
keys: torch.Tensor,
|
||||
values: torch.Tensor,
|
||||
hidden: torch.Tensor,
|
||||
req_id: str,
|
||||
) -> None:
|
||||
|
||||
assert self.producer_buffer is not None, (
|
||||
"Please initialize the "
|
||||
"producer buffer before calling insert.")
|
||||
self.producer_buffer.insert(input_tokens, roi, keys, values, hidden,
|
||||
req_id)
|
||||
|
||||
def send_kv_caches_and_hidden_states(
|
||||
self,
|
||||
model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: List[torch.Tensor],
|
||||
hidden_or_intermediate_states: Union[torch.Tensor,
|
||||
IntermediateTensors],
|
||||
) -> None:
|
||||
input_tokens_tensor = model_input.input_tokens
|
||||
seq_lens = model_input.attn_metadata.seq_lens
|
||||
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
|
||||
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
|
||||
start_layer = model_executable.model.start_layer
|
||||
end_layer = model_executable.model.end_layer
|
||||
|
||||
model_config = self.model_config
|
||||
num_heads = int(model_config.num_key_value_heads / self.tp_size)
|
||||
hidden_size = model_config.hidden_size
|
||||
num_attention_heads = model_config.num_attention_heads
|
||||
|
||||
# Deepseek's MLA (Multi-head Latent Attention) uses two different
|
||||
# kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
|
||||
# When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
|
||||
# resulting in a kv_cache shape of [num_blks, blk_size, 1,
|
||||
# kv_lora_rank + qk_rope_head_dim].
|
||||
# When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
|
||||
# to a kv_cache shape of [2, num_blks, blk_size,
|
||||
# num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
|
||||
# For more details, see vllm/attention/backends/mla/common.py.
|
||||
if self.is_deepseek_mla and self.use_mla_opt:
|
||||
head_size = (model_config.kv_lora_rank +
|
||||
model_config.qk_rope_head_dim)
|
||||
num_heads = 1
|
||||
elif self.is_deepseek_mla and not self.use_mla_opt:
|
||||
head_size = (model_config.qk_nope_head_dim +
|
||||
model_config.qk_rope_head_dim)
|
||||
else:
|
||||
head_size = getattr(
|
||||
model_config,
|
||||
"head_dim",
|
||||
int(hidden_size // num_attention_heads),
|
||||
)
|
||||
# Enumerate over all requests and insert them one by one.
|
||||
for idx, slen in enumerate(seq_lens):
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
|
||||
if start_pos >= num_prefill_tokens:
|
||||
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
|
||||
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
|
||||
# - input_tokens[num_prefill_tokens:] contains decode tokens.
|
||||
logger.warning("You have some decode requests while using "
|
||||
"SimpleConnector. Their KVCache won't be sent.")
|
||||
break
|
||||
|
||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
||||
|
||||
keys, values = [], []
|
||||
|
||||
for layer_id in range(start_layer, end_layer):
|
||||
kv_cache = kv_caches[layer_id - start_layer]
|
||||
|
||||
if self.is_deepseek_mla and self.use_mla_opt:
|
||||
key_cache = kv_cache.reshape(-1, num_heads, head_size)
|
||||
value_cache = kv_cache.reshape(-1, num_heads, head_size)
|
||||
else:
|
||||
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
|
||||
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
|
||||
|
||||
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
|
||||
|
||||
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
|
||||
values.append(value_cache[current_slot_mapping].unsqueeze(0))
|
||||
|
||||
# shape: [num_layers, num_tokens, num_heads, head_size]
|
||||
keys = torch.cat(keys, dim=0)
|
||||
values = torch.cat(values, dim=0)
|
||||
cur_req_id = list(model_input.request_ids_to_seq_ids.keys())[idx]
|
||||
# Currently we haven't considered situation of roi, pass None here.
|
||||
self.insert(
|
||||
current_tokens,
|
||||
None,
|
||||
keys,
|
||||
values,
|
||||
hidden_or_intermediate_states[start_pos:end_pos],
|
||||
cur_req_id,
|
||||
)
|
||||
|
||||
logger.info("[rank%d][P]: KV send DONE.", torch.distributed.get_rank())
|
||||
|
||||
def recv_kv_caches_and_hidden_states(
|
||||
self,
|
||||
model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: List[torch.Tensor],
|
||||
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
||||
"ModelInputForGPUWithSamplingMetadata", ]:
|
||||
bypass_model_exec = True
|
||||
|
||||
model_config = self.model_config
|
||||
|
||||
# get model config
|
||||
start_layer = model_executable.model.start_layer
|
||||
end_layer = model_executable.model.end_layer
|
||||
num_heads, head_dim = kv_caches[0].shape[-2:]
|
||||
hidden_size = model_config.hidden_size
|
||||
num_attention_heads = model_config.num_attention_heads
|
||||
num_layers = end_layer - start_layer
|
||||
if self.is_deepseek_mla and self.use_mla_opt:
|
||||
head_size = (model_config.kv_lora_rank +
|
||||
model_config.qk_rope_head_dim)
|
||||
num_heads = 1
|
||||
elif self.is_deepseek_mla and not self.use_mla_opt:
|
||||
head_size = (model_config.qk_nope_head_dim +
|
||||
model_config.qk_rope_head_dim)
|
||||
else:
|
||||
head_size = getattr(
|
||||
model_config,
|
||||
"head_dim",
|
||||
int(hidden_size // num_attention_heads),
|
||||
)
|
||||
self.consumer_buffer.num_heads = num_heads # type: ignore
|
||||
self.consumer_buffer.num_layers = num_layers # type: ignore
|
||||
self.consumer_buffer.head_size = head_size # type: ignore
|
||||
self.consumer_buffer.dtype = kv_caches[0].dtype # type: ignore
|
||||
self.consumer_buffer.hidden_size = hidden_size # type: ignore
|
||||
|
||||
input_tokens_tensor = model_input.input_tokens
|
||||
seq_lens = model_input.attn_metadata.seq_lens
|
||||
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
|
||||
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
|
||||
|
||||
total_tokens = model_input.attn_metadata.num_prefill_tokens + model_input.attn_metadata.num_decode_tokens
|
||||
hidden_or_intermediate_states_for_one_req = []
|
||||
|
||||
input_tokens_list = []
|
||||
num_computed_tokens_list = []
|
||||
start_pos_list = []
|
||||
|
||||
# enumerate different requests
|
||||
for idx, slen in enumerate(seq_lens):
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
|
||||
if start_pos >= num_prefill_tokens:
|
||||
logger.warning("You should set --enable_chunked_prefill=False "
|
||||
"and --max_num_batched_tokens "
|
||||
"should be equal to --max_seq_len_to_capture")
|
||||
bypass_model_exec = False
|
||||
assert start_pos == num_prefill_tokens
|
||||
break
|
||||
|
||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
||||
num_tokens = slen
|
||||
|
||||
# collecting data for rebuilding the input
|
||||
input_tokens_list.append(current_tokens)
|
||||
start_pos_list.append(start_pos)
|
||||
|
||||
cur_req_id = list(model_input.request_ids_to_seq_ids.keys())[idx]
|
||||
|
||||
ret = self.select(
|
||||
current_tokens,
|
||||
torch.ones_like(current_tokens, dtype=bool),
|
||||
cur_req_id,
|
||||
)
|
||||
if ret[0] is None:
|
||||
# didn't find any match.
|
||||
bypass_model_exec = False
|
||||
num_computed_tokens_list.append(0)
|
||||
continue
|
||||
|
||||
keys: torch.Tensor = ret[0]
|
||||
values: torch.Tensor = ret[1]
|
||||
hidden: torch.Tensor = ret[2]
|
||||
|
||||
num_computed_tokens = keys.shape[1]
|
||||
num_computed_tokens_list.append(num_computed_tokens)
|
||||
|
||||
# check if both KV cache and the hidden states are received
|
||||
# If not, need to redo the forwarding to compute missing states
|
||||
if not all([(num_computed_tokens == num_tokens), hidden is not None
|
||||
]):
|
||||
bypass_model_exec = False
|
||||
|
||||
# update the end position based on how many tokens are cached.
|
||||
end_pos = start_pos + num_computed_tokens
|
||||
|
||||
# put received KV caches into paged memory
|
||||
for i in range(
|
||||
model_executable.model.start_layer,
|
||||
model_executable.model.end_layer,
|
||||
):
|
||||
|
||||
kv_cache = kv_caches[i - model_executable.model.start_layer]
|
||||
layer = model_executable.model.layers[i]
|
||||
|
||||
if self.is_deepseek_mla and self.use_mla_opt:
|
||||
layer.self_attn.attn = layer.self_attn.mla_attn
|
||||
key_cache = kv_cache
|
||||
slots = slot_mapping[start_pos:end_pos]
|
||||
sliced_key = keys[i - model_executable.model.start_layer]
|
||||
torch_npu._npu_reshape_and_cache_siso(key=sliced_key,
|
||||
key_cache=key_cache,
|
||||
slot_indices=slots)
|
||||
else:
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
sliced_key = keys[i - model_executable.model.start_layer]
|
||||
sliced_value = values[i -
|
||||
model_executable.model.start_layer]
|
||||
torch_npu._npu_reshape_and_cache(
|
||||
key=sliced_key,
|
||||
value=sliced_value,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
slot_indices=slot_mapping[start_pos:end_pos],
|
||||
)
|
||||
|
||||
hidden_or_intermediate_states_for_one_req.append(hidden)
|
||||
|
||||
if not bypass_model_exec:
|
||||
# Some of the KV cache is not retrieved
|
||||
# Here we will fall back to normal model forwarding
|
||||
# But optionally you can adjust model_input so that you only do
|
||||
# prefilling on those tokens that are missing KV caches.
|
||||
if get_dp_group().world_size > 1:
|
||||
bypass_model_exec = True
|
||||
hidden_or_intermediate_states = torch.empty(
|
||||
[total_tokens, hidden_size],
|
||||
dtype=kv_caches[0].dtype,
|
||||
device=kv_caches[0].device)
|
||||
logger.warning(
|
||||
"[Detect there is more one DP rank in this decode node, in this scenario, no recompute is expected when kv cache dose not received.]"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"[rank%d]: Failed to receive all KVs and hidden "
|
||||
"states, redo model forwarding.",
|
||||
torch.distributed.get_rank())
|
||||
hidden_or_intermediate_states = None
|
||||
else:
|
||||
logger.debug(
|
||||
"[rank%d]: Successfully received all KVs and hidden "
|
||||
"states, skip model forwarding.",
|
||||
torch.distributed.get_rank(),
|
||||
)
|
||||
# Can't directly concat here which might cause error when bs = 1.
|
||||
# hidden_or_intermediate_states = torch.empty(total_num_tokens, hidden_size, dtype=kv_caches[0].dtype, device=kv_caches[0].device)
|
||||
if len(hidden_or_intermediate_states_for_one_req) == 1:
|
||||
hidden = hidden_or_intermediate_states_for_one_req[0]
|
||||
tmp_indice = torch.tensor([0] * hidden.shape[0],
|
||||
dtype=torch.int64).npu()
|
||||
hidden_or_intermediate_states = torch.empty_like(hidden)
|
||||
torch_npu.scatter_update_(
|
||||
hidden_or_intermediate_states,
|
||||
tmp_indice,
|
||||
hidden,
|
||||
axis=-1,
|
||||
)
|
||||
else:
|
||||
hidden_or_intermediate_states = torch.cat(
|
||||
hidden_or_intermediate_states_for_one_req, dim=0)
|
||||
|
||||
return hidden_or_intermediate_states, bypass_model_exec, model_input
|
||||
|
||||
def close(self):
|
||||
self.producer_data_pipe.close() # type: ignore
|
||||
self.consumer_data_pipe.close() # type: ignore
|
||||
self.producer_buffer.close() # type: ignore
|
||||
self.consumer_buffer.close() # type: ignore
|
||||
209
vllm_ascend/distributed/kv_transfer/simple_pipe.py
Normal file
209
vllm_ascend/distributed/kv_transfer/simple_pipe.py
Normal file
@@ -0,0 +1,209 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import llm_datadist # type: ignore
|
||||
import msgpack # type: ignore
|
||||
import torch
|
||||
import torch_npu
|
||||
import torchair # type: ignore
|
||||
import zmq # type: ignore
|
||||
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import get_ip
|
||||
|
||||
import vllm_ascend.envs as envs
|
||||
from vllm_ascend.distributed.kv_transfer.utils import NPU_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SimplePipe(KVPipeBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank,
|
||||
local_rank,
|
||||
kv_transfer_config,
|
||||
hostname: str = "",
|
||||
port_offset: int = 0, # NPU offset in current P/D instance.
|
||||
):
|
||||
self.rank = rank
|
||||
self.local_rank = local_rank
|
||||
# Currently for 1P1D situation, we use cluster_id=0 for both Prefill and Decode
|
||||
# Will change here in the future to support xPyD.
|
||||
self.cluster_id = 0
|
||||
self.config = kv_transfer_config
|
||||
kv_connector_extra_config = kv_transfer_config.kv_connector_extra_config
|
||||
kv_role = kv_transfer_config.kv_role
|
||||
if kv_role == "kv_producer":
|
||||
self.role = llm_datadist.LLMRole.PROMPT
|
||||
elif kv_role == "kv_consumer":
|
||||
self.role = llm_datadist.LLMRole.DECODER
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"kv_role should be inside [kv_producer, kv_consumer]")
|
||||
|
||||
prompt_device_ips = kv_connector_extra_config.get(
|
||||
"prompt_device_ips", None)
|
||||
decode_device_ips = kv_connector_extra_config.get(
|
||||
"decode_device_ips", None)
|
||||
if prompt_device_ips is None or decode_device_ips is None:
|
||||
raise ValueError(
|
||||
"Please specify prompt_device_ips and decode_device_ips"
|
||||
"in kv_transfer_config.kv_connector_extra_config")
|
||||
p_device_num = len(prompt_device_ips)
|
||||
d_device_num = len(decode_device_ips)
|
||||
# When number of devices in P and D is not equal,
|
||||
# we assume that device in D can be mapped to any device in P.
|
||||
self.p_device_rank = self.rank % p_device_num
|
||||
self.d_device_rank = self.rank % d_device_num
|
||||
|
||||
self.prompt_ip_list = prompt_device_ips
|
||||
self.decode_ip_list = decode_device_ips
|
||||
self.llmdatadist_comm_port = kv_connector_extra_config.get(
|
||||
"llmdatadist_comm_port", 26000)
|
||||
# LLMDataDist initializing.
|
||||
self.data_dist = llm_datadist.LLMDataDist(self.role, self.cluster_id)
|
||||
self._prepare_data_dist()
|
||||
# Decoder needs to initialize and link cluster
|
||||
if self.role == llm_datadist.LLMRole.DECODER:
|
||||
self.cluster = self._make_cluster()
|
||||
_, ret = self.data_dist.link_clusters([self.cluster], 20000)
|
||||
logger.info(
|
||||
f"rank {self.rank}, local_rank {self.local_rank} link, ret={ret}"
|
||||
)
|
||||
|
||||
# If `proxy_ip` or `proxy_port` is `""`,
|
||||
# then the ping thread will not be enabled.
|
||||
proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
|
||||
proxy_port = self.config.get_from_extra_config("proxy_port", "")
|
||||
if proxy_ip == "" or proxy_port == "":
|
||||
self.proxy_address = ""
|
||||
else:
|
||||
self.proxy_address = proxy_ip + ":" + proxy_port
|
||||
|
||||
self._register_thread = None
|
||||
if port_offset == 0 and self.proxy_address != "":
|
||||
# Initialize zmq socket and register to proxy.
|
||||
# Note that only NPU 0 of each P/D instance register to proxy.
|
||||
if not hostname:
|
||||
hostname = get_ip() # Get ip of current host.
|
||||
port = kv_transfer_config.kv_port + port_offset
|
||||
if port == 0:
|
||||
raise ValueError("Port cannot be 0")
|
||||
self._hostname = hostname
|
||||
self._port = port
|
||||
# Each card corresponds to a ZMQ address.
|
||||
self.zmq_address = f"{self._hostname}:{self._port}"
|
||||
|
||||
self.context = zmq.Context() # type: ignore
|
||||
self.router_socket = self.context.socket(
|
||||
zmq.ROUTER) # type: ignore
|
||||
self.router_socket.bind(f"tcp://{self.zmq_address}")
|
||||
# The `http_port` must be consistent with the serving port of OpenAI.
|
||||
self.http_address = (
|
||||
f"{self._hostname}:"
|
||||
f"{self.config.kv_connector_extra_config['http_port']}")
|
||||
self._register_thread = threading.Thread(
|
||||
target=self._register_to_proxy, daemon=True)
|
||||
self._register_thread.start()
|
||||
|
||||
def _prepare_data_dist(self):
|
||||
options = {
|
||||
"llm.SyncKvCacheWaitTime": envs.LLMDATADIST_SYNC_CACHE_WAIT_TIME,
|
||||
}
|
||||
if self.role == llm_datadist.LLMRole.PROMPT:
|
||||
options["ge.exec.deviceId"] = str(self.local_rank)
|
||||
options["llm.listenIpInfo"] = (
|
||||
f"{self.prompt_ip_list[self.p_device_rank]}:{self.llmdatadist_comm_port}"
|
||||
)
|
||||
else:
|
||||
options["ge.exec.deviceId"] = str(self.local_rank)
|
||||
print(f"prepare datadist, options: {options}")
|
||||
self.data_dist.init(options)
|
||||
self.kv_transfer = self.data_dist.kv_cache_manager
|
||||
print(f"{self.rank} rank data dist is ready")
|
||||
|
||||
def _make_cluster(self):
|
||||
cluster = llm_datadist.LLMClusterInfo()
|
||||
cluster.remote_cluster_id = self.cluster_id
|
||||
local_ip = self.decode_ip_list[self.d_device_rank]
|
||||
remote_ip = self.prompt_ip_list[self.p_device_rank]
|
||||
cluster.append_local_ip_info(local_ip, 0)
|
||||
cluster.append_remote_ip_info(remote_ip, self.llmdatadist_comm_port)
|
||||
return cluster
|
||||
|
||||
def _register_to_proxy(self):
|
||||
sock = self.context.socket(zmq.DEALER) # type: ignore
|
||||
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) # type: ignore
|
||||
logger.debug("ping start, zmq_address:%s", self.zmq_address)
|
||||
sock.connect(f"tcp://{self.proxy_address}")
|
||||
data = {
|
||||
"type": "P" if self.config.is_kv_producer else "D",
|
||||
"http_address": self.http_address,
|
||||
"zmq_address": self.zmq_address,
|
||||
}
|
||||
while True:
|
||||
sock.send(msgpack.dumps(data))
|
||||
time.sleep(3)
|
||||
|
||||
def send_tensor(
|
||||
self,
|
||||
tensor: Optional[torch.Tensor],
|
||||
tensor_desc: llm_datadist.CacheDesc,
|
||||
tensor_key: llm_datadist.CacheKey,
|
||||
) -> llm_datadist.Cache:
|
||||
buffer = self.kv_transfer.allocate_cache(tensor_desc, [tensor_key])
|
||||
buffer_addr = buffer.per_device_tensor_addrs[0]
|
||||
data_tensor = torchair.llm_datadist.create_npu_tensors(
|
||||
tensor_desc.shape, tensor.dtype, buffer_addr)[0] # type: ignore
|
||||
update_indices = torch.tensor(
|
||||
[0] * tensor.shape[0], # type: ignore
|
||||
dtype=torch.int64).npu()
|
||||
torch_npu.scatter_update_(data_tensor, update_indices, tensor, axis=-1)
|
||||
# Free cache_id of buffer, actual deallocate will happen after consumer performing pull_cache.
|
||||
self.kv_transfer.deallocate_cache(buffer)
|
||||
return buffer
|
||||
|
||||
def recv_tensor(
|
||||
self,
|
||||
tensor_desc: llm_datadist.CacheDesc,
|
||||
tensor_key: llm_datadist.CacheKey,
|
||||
) -> llm_datadist.Cache:
|
||||
"""Note that this function only creates empty tensor on buffer addr and returns it."""
|
||||
tmp_buffer = self.kv_transfer.allocate_cache(tensor_desc)
|
||||
buffer_addr = tmp_buffer.per_device_tensor_addrs[0]
|
||||
data_tensor = torchair.llm_datadist.create_npu_tensors(
|
||||
tensor_desc.shape,
|
||||
NPU_DTYPE_TO_TORCH_DTYPE[tensor_desc.data_type],
|
||||
buffer_addr,
|
||||
)[0]
|
||||
self.kv_transfer.pull_cache(tensor_key, tmp_buffer, 0)
|
||||
# tmp_buffer is allocated without key and will be deallocated here immediately.
|
||||
# Free buffer here will cause accuracy problem.
|
||||
# self.kv_transfer.deallocate_cache(tmp_buffer)
|
||||
return tmp_buffer, data_tensor
|
||||
|
||||
def deallocate_buffer(self, buffer: llm_datadist.Cache):
|
||||
self.kv_transfer.deallocate_cache(buffer)
|
||||
|
||||
def close(self):
|
||||
self.data_dist.unlink_clusters([self.cluster], 5000)
|
||||
40
vllm_ascend/distributed/kv_transfer/utils.py
Normal file
40
vllm_ascend/distributed/kv_transfer/utils.py
Normal file
@@ -0,0 +1,40 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import llm_datadist # type: ignore
|
||||
import torch
|
||||
|
||||
TORCH_DTYPE_TO_NPU_DTYPE = {
|
||||
torch.half: llm_datadist.DataType.DT_FLOAT16,
|
||||
torch.float16: llm_datadist.DataType.DT_FLOAT16,
|
||||
torch.bfloat16: llm_datadist.DataType.DT_BF16,
|
||||
torch.float: llm_datadist.DataType.DT_FLOAT,
|
||||
torch.float32: llm_datadist.DataType.DT_FLOAT,
|
||||
torch.int8: llm_datadist.DataType.DT_INT8,
|
||||
torch.int64: llm_datadist.DataType.DT_INT64,
|
||||
torch.int32: llm_datadist.DataType.DT_INT32,
|
||||
}
|
||||
|
||||
NPU_DTYPE_TO_TORCH_DTYPE = {
|
||||
llm_datadist.DataType.DT_FLOAT16: torch.half,
|
||||
llm_datadist.DataType.DT_FLOAT16: torch.float16,
|
||||
llm_datadist.DataType.DT_BF16: torch.bfloat16,
|
||||
llm_datadist.DataType.DT_FLOAT: torch.float,
|
||||
llm_datadist.DataType.DT_FLOAT: torch.float32,
|
||||
llm_datadist.DataType.DT_INT8: torch.int8,
|
||||
llm_datadist.DataType.DT_INT64: torch.int64,
|
||||
llm_datadist.DataType.DT_INT32: torch.int32,
|
||||
}
|
||||
@@ -33,7 +33,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.distributed import get_dp_group, get_pp_group
|
||||
from vllm.distributed.kv_transfer import get_kv_transfer_group
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
||||
@@ -1343,6 +1343,17 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
kv_caches=kv_caches
|
||||
)
|
||||
|
||||
if get_dp_group().world_size > 1:
|
||||
bypass_model_exec_tensor = torch.tensor(
|
||||
1, dtype=torch.int32) if bypass_model_exec else torch.tensor(
|
||||
0, dtype=torch.int32)
|
||||
torch.distributed.all_reduce(bypass_model_exec_tensor,
|
||||
op=torch.distributed.ReduceOp.MIN,
|
||||
group=get_dp_group().cpu_group)
|
||||
# If there is any group have not receive the necessary hidden states or kv_cache, we force all the dp group execute.
|
||||
if bypass_model_exec_tensor.item() == 0:
|
||||
bypass_model_exec = False
|
||||
|
||||
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
|
||||
seqlen_agnostic_kwargs = {
|
||||
"finished_requests_ids": model_input.finished_requests_ids,
|
||||
@@ -1399,10 +1410,21 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
torch.tensor(model_forward_time +
|
||||
orig_model_forward_time))
|
||||
return hidden_or_intermediate_states
|
||||
# TODO: remove the synchronize here
|
||||
torch.npu.synchronize()
|
||||
logits = self.model.compute_logits(hidden_or_intermediate_states,
|
||||
model_input.sampling_metadata)
|
||||
|
||||
logits = self.model.compute_logits(hidden_or_intermediate_states,
|
||||
model_input.sampling_metadata)
|
||||
|
||||
# Sending KV cache in distributed KV cache transfer setting
|
||||
if self.need_send_kv(model_input, kv_caches):
|
||||
get_kv_transfer_group().send_kv_caches_and_hidden_states(
|
||||
# model_executable is used to know which layer the current
|
||||
# worker is working on, so that we can send KV for only those
|
||||
# layers.
|
||||
model_executable,
|
||||
model_input,
|
||||
kv_caches,
|
||||
hidden_or_intermediate_states,
|
||||
)
|
||||
|
||||
if not self.is_driver_worker:
|
||||
return []
|
||||
|
||||
@@ -18,10 +18,13 @@
|
||||
#
|
||||
|
||||
import gc
|
||||
import os
|
||||
from typing import Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import msgpack # type: ignore
|
||||
import torch
|
||||
import torch.distributed
|
||||
import zmq
|
||||
from torch import nn
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
@@ -37,7 +40,7 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
|
||||
SequenceGroupMetadata, SequenceGroupMetadataDelta)
|
||||
from vllm.utils import GiB_bytes, bind_kv_cache
|
||||
from vllm.utils import GiB_bytes, bind_kv_cache, get_ip
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
|
||||
from vllm.worker.model_runner_base import ModelRunnerBase
|
||||
@@ -157,6 +160,33 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
else:
|
||||
self.profiler = None
|
||||
|
||||
self.enable_dummy_run = False
|
||||
if os.getenv("VLLM_DP_PROXY_IP", None):
|
||||
logger.warning("enable dummy run for the DP")
|
||||
self.enable_dummy_run = True
|
||||
# dp_rank = os.environ["VLLM_DP_RANK"]
|
||||
dp_master_ip = os.environ["VLLM_DP_PROXY_IP"]
|
||||
dp_proxy_listener_port = os.environ["VLLM_DP_PROXY_PORT"]
|
||||
dp_proxy_monitor_port = os.environ["VLLM_DP_MONITOR_PORT"]
|
||||
dp_proxy_listener_addr = f"{dp_master_ip}:{dp_proxy_listener_port}"
|
||||
self.dp_proxy_monitor_addr = f"{dp_master_ip}:{dp_proxy_monitor_port}"
|
||||
http_ip = get_ip()
|
||||
port = os.environ["VLLM_HTTP_PORT"]
|
||||
self.http_addr = f"{http_ip}:{port}"
|
||||
context = zmq.Context() # type: ignore
|
||||
sock = context.socket(zmq.DEALER) # type: ignore
|
||||
|
||||
logger.debug("ping dp proxy start, DP_RANK:%s", 0)
|
||||
# logger.debug("ping dp proxy start, DP_RANK:%s", dp_rank)
|
||||
|
||||
sock.connect(f"tcp://{dp_proxy_listener_addr}")
|
||||
data = {"type": "DP", "http_address": self.http_addr}
|
||||
for _ in range(10):
|
||||
sock.send(msgpack.dumps(data))
|
||||
|
||||
self.notify_socket = context.socket(zmq.PUSH) # type: ignore
|
||||
self.notify_socket.connect(f"tcp://{self.dp_proxy_monitor_addr}")
|
||||
|
||||
def sleep(self, level: int = 1) -> None:
|
||||
NPUPlatform.set_device(self.device)
|
||||
free_bytes_before_sleep = NPUPlatform.mem_get_info()[0]
|
||||
@@ -375,6 +405,11 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_worker(self, worker_input: WorkerInput) -> None:
|
||||
if self.enable_dummy_run:
|
||||
logger.debug(
|
||||
f"send notify to the dp proxy: {self.dp_proxy_monitor_addr}")
|
||||
data = {"info": "notify_step", "http_address": self.http_addr}
|
||||
self.notify_socket.send(msgpack.dumps(data))
|
||||
virtual_engine = worker_input.virtual_engine
|
||||
# Issue cache operations.
|
||||
if (worker_input.blocks_to_swap_in is not None
|
||||
|
||||
Reference in New Issue
Block a user