[Bugfix]KV pool rank 0 consumes more HBM (#6113)
### What this PR does / why we need it?
Before adding `set_device`:
<img width="2354" height="674" alt="image"
src="https://github.com/user-attachments/assets/8b81ab5f-b9ba-4fd2-8546-8f36ac15d32b"
/>
After adding `set_device`:
<img width="1044" height="156" alt="image"
src="https://github.com/user-attachments/assets/996d845a-8abd-4aae-b894-4a9832b1f742"
/>
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
d68209402d
---------
Signed-off-by: baxingpiaochong <771405853@qq.com>
This commit is contained in:
@@ -2,6 +2,8 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import torch
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
|
||||
@@ -32,6 +34,7 @@ class MooncakeBackend(Backend):
|
||||
"to run vLLM with MooncakeConnector.") from e
|
||||
self.config = MooncakeStoreConfig.load_from_env()
|
||||
self.store = MooncakeDistributedStore()
|
||||
self.rank = parallel_config.rank
|
||||
if self.config.protocol == "ascend":
|
||||
local_hostname = get_ip()
|
||||
transfer_engine = global_te.get_transfer_engine(local_hostname,
|
||||
@@ -50,6 +53,10 @@ class MooncakeBackend(Backend):
|
||||
logger.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
def set_device(self):
    """Make the NPU whose index equals ``self.rank`` the current device.

    Builds the device handle ``npu:<rank>`` and hands it to
    ``torch.npu.set_device``. Nothing is returned.
    """
    # NOTE(review): self.rank is used directly as the device index. If it
    # is a global rank rather than a node-local one, this would presumably
    # need a local index on multi-node deployments -- confirm with callers.
    target = torch.device(f"npu:{self.rank}")
    torch.npu.set_device(target)
|
||||
|
||||
def register_buffer(self, ptrs: list[int], lengths: list[int]):
|
||||
global_te.register_buffer(ptrs, lengths)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user