[Bugfix]KV pool rank 0 consumes more HBM (#6113)
### What this PR does / why we need it?
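This PR adds a `set_device()` method to `MooncakeBackend` that sets the current NPU to `npu:{rank}`.
Before adding `set_device`: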
<img width="2354" height="674" alt="image"
src="https://github.com/user-attachments/assets/8b81ab5f-b9ba-4fd2-8546-8f36ac15d32b"
/>
After adding `set_device`:
<img width="1044" height="156" alt="image"
src="https://github.com/user-attachments/assets/996d845a-8abd-4aae-b894-4a9832b1f742"
/>
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
d68209402d
---------
Signed-off-by: baxingpiaochong <771405853@qq.com>
This commit is contained in:
@@ -2,6 +2,8 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import torch
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@@ -32,6 +34,7 @@ class MooncakeBackend(Backend):
|
|||||||
"to run vLLM with MooncakeConnector.") from e
|
"to run vLLM with MooncakeConnector.") from e
|
||||||
self.config = MooncakeStoreConfig.load_from_env()
|
self.config = MooncakeStoreConfig.load_from_env()
|
||||||
self.store = MooncakeDistributedStore()
|
self.store = MooncakeDistributedStore()
|
||||||
|
self.rank = parallel_config.rank
|
||||||
if self.config.protocol == "ascend":
|
if self.config.protocol == "ascend":
|
||||||
local_hostname = get_ip()
|
local_hostname = get_ip()
|
||||||
transfer_engine = global_te.get_transfer_engine(local_hostname,
|
transfer_engine = global_te.get_transfer_engine(local_hostname,
|
||||||
@@ -50,6 +53,10 @@ class MooncakeBackend(Backend):
|
|||||||
logger.error(msg)
|
logger.error(msg)
|
||||||
raise RuntimeError(msg)
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
|
def set_device(self):
|
||||||
|
device = torch.device(f"npu:{self.rank}")
|
||||||
|
torch.npu.set_device(device)
|
||||||
|
|
||||||
def register_buffer(self, ptrs: list[int], lengths: list[int]):
|
def register_buffer(self, ptrs: list[int], lengths: list[int]):
|
||||||
global_te.register_buffer(ptrs, lengths)
|
global_te.register_buffer(ptrs, lengths)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user