[bugfix] A2 Environment Pooling for Memcache Compatibility (#5601)
### What this PR does / why we need it?
When running memcache in the A2 environment, the logic for registering
memory needs to be added. Additionally, there is a link establishment
conflict between memcache and HCCS during initialization in A2, so the
link should be established in advance.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
7157596103
---------
Signed-off-by: fangjianwei <f30058701@china.huawei.com>
Co-authored-by: fangjianwei <f30058701@china.huawei.com>
This commit is contained in:
@@ -6,6 +6,7 @@ from vllm.config import ParallelConfig
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.distributed.kvpool.backend.backend import Backend
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
||||
|
||||
|
||||
class MmcDirect(Enum):
|
||||
@@ -26,10 +27,28 @@ class MemcacheBackend(Backend):
|
||||
"https://gitee.com/ascend/memfabric_hybrid " # noqa: E501
|
||||
"to run vLLM with MemcacheConnector.") from e
|
||||
try:
|
||||
self.rank = parallel_config.rank
|
||||
self.store = DistributedObjectStore()
|
||||
res = self.store.init(self.rank)
|
||||
assert res == 0
|
||||
soc_version = get_ascend_device_type()
|
||||
if soc_version in {AscendDeviceType.A2}:
|
||||
import torch
|
||||
from vllm.distributed import get_world_group
|
||||
tmp_tensor = torch.zeros(1, device="npu")
|
||||
output_tensor_list = [
|
||||
torch.empty_like(tmp_tensor)
|
||||
for _ in range(torch.distributed.get_world_size())
|
||||
]
|
||||
torch.distributed.all_gather(
|
||||
output_tensor_list,
|
||||
tmp_tensor,
|
||||
group=get_world_group().device_group)
|
||||
self.rank = parallel_config.rank
|
||||
self.store = DistributedObjectStore()
|
||||
res = self.store.init(self.rank)
|
||||
assert res == 0
|
||||
else:
|
||||
self.rank = parallel_config.rank
|
||||
self.store = DistributedObjectStore()
|
||||
res = self.store.init(self.rank)
|
||||
assert res == 0
|
||||
except ValueError as e:
|
||||
logger.error("Configuration loading failed: %s", e)
|
||||
raise
|
||||
@@ -43,7 +62,12 @@ class MemcacheBackend(Backend):
|
||||
torch.npu.set_device(device)
|
||||
|
||||
def register_buffer(self, ptrs: list[int], sizes: list[int]):
|
||||
pass
|
||||
soc_version = get_ascend_device_type()
|
||||
if soc_version in {AscendDeviceType.A2}:
|
||||
for ptr, size in zip(ptrs, sizes):
|
||||
self.store.register_buffer(ptr, size)
|
||||
else:
|
||||
pass
|
||||
|
||||
def exists(self, keys: list[str]) -> list[int]:
|
||||
return self.store.batch_is_exist(keys)
|
||||
|
||||
@@ -82,7 +82,10 @@ class KVPoolScheduler:
|
||||
if num_external_hit_tokens == request.num_tokens:
|
||||
num_external_hit_tokens -= 1
|
||||
|
||||
need_to_allocate = num_external_hit_tokens - num_computed_tokens
|
||||
if num_external_hit_tokens < num_computed_tokens:
|
||||
need_to_allocate = 0
|
||||
else:
|
||||
need_to_allocate = num_external_hit_tokens - num_computed_tokens
|
||||
|
||||
logger.info(
|
||||
"Reqid: %s, Total tokens %d, kvpool hit tokens: %d, need to load: %d",
|
||||
|
||||
@@ -87,7 +87,7 @@ class KVPoolWorker:
|
||||
self.put_step = 1
|
||||
|
||||
self.metadata = KeyMetadata(
|
||||
model_config.model.split('/')[-1],
|
||||
model_config.model.rstrip('/').split('/')[-1],
|
||||
self.head_or_tp_rank,
|
||||
self.pcp_rank,
|
||||
self.dcp_rank,
|
||||
|
||||
Reference in New Issue
Block a user