[Bugfix] KV pool rank 0 consumes more HBM (#6113)

### What this PR does / why we need it?

Before adding `set_device`:
<img width="2354" height="674" alt="image"
src="https://github.com/user-attachments/assets/8b81ab5f-b9ba-4fd2-8546-8f36ac15d32b"
/>
After adding `set_device`:
<img width="1044" height="156" alt="image"
src="https://github.com/user-attachments/assets/996d845a-8abd-4aae-b894-4a9832b1f742"
/>

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
d68209402d

---------

Signed-off-by: baxingpiaochong <771405853@qq.com>
This commit is contained in:
baxingpiaochong
2026-01-23 19:47:33 +08:00
committed by GitHub
parent bdf65e6bd3
commit 8786412f5c

View File

@@ -2,6 +2,8 @@
import json import json
import os import os
import re import re
import torch
from dataclasses import dataclass from dataclasses import dataclass
from typing import Union from typing import Union
@@ -32,6 +34,7 @@ class MooncakeBackend(Backend):
"to run vLLM with MooncakeConnector.") from e "to run vLLM with MooncakeConnector.") from e
self.config = MooncakeStoreConfig.load_from_env() self.config = MooncakeStoreConfig.load_from_env()
self.store = MooncakeDistributedStore() self.store = MooncakeDistributedStore()
self.rank = parallel_config.rank
if self.config.protocol == "ascend": if self.config.protocol == "ascend":
local_hostname = get_ip() local_hostname = get_ip()
transfer_engine = global_te.get_transfer_engine(local_hostname, transfer_engine = global_te.get_transfer_engine(local_hostname,
@@ -50,6 +53,10 @@ class MooncakeBackend(Backend):
logger.error(msg) logger.error(msg)
raise RuntimeError(msg) raise RuntimeError(msg)
def set_device(self):
    """Select the NPU whose index equals ``self.rank`` as the current device.

    Builds the device string ``"npu:<rank>"`` from ``self.rank`` and passes
    the resulting ``torch.device`` to ``torch.npu.set_device``.
    """
    # NOTE(review): self.rank is copied from parallel_config.rank in
    # __init__; confirm it is always a valid device index on this host
    # (e.g. multi-node runs where the rank may exceed the local NPU count).
    npu_name = f"npu:{self.rank}"
    torch.npu.set_device(torch.device(npu_name))
def register_buffer(self, ptrs: list[int], lengths: list[int]): def register_buffer(self, ptrs: list[int], lengths: list[int]):
global_te.register_buffer(ptrs, lengths) global_te.register_buffer(ptrs, lengths)