Use ipc instead of tcp in zmq (#1566)

This commit is contained in:
Lianmin Zheng
2024-10-04 00:45:52 -07:00
committed by GitHub
parent 32eb6e96f2
commit 114bbc8651
9 changed files with 48 additions and 96 deletions

View File

@@ -16,7 +16,6 @@ limitations under the License.
"""Memory pool."""
import logging
from abc import ABC, abstractmethod
from typing import List, Tuple, Union
import numpy as np
@@ -62,9 +61,11 @@ class BaseTokenToKVPool:
self,
size: int,
dtype: torch.dtype,
device: str,
):
self.size = size
self.dtype = dtype
self.device = device
if dtype == torch.float8_e5m2:
# NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
self.store_dtype = torch.uint8
@@ -84,7 +85,7 @@ class BaseTokenToKVPool:
select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:]
return torch.tensor(select_index, dtype=torch.int32, device="cuda")
return torch.tensor(select_index, dtype=torch.int32, device=self.device)
def free(self, free_index: torch.Tensor):
self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))
@@ -123,7 +124,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):
layer_num: int,
device: str,
):
super().__init__(size, dtype)
super().__init__(size, dtype, device)
# [size, head_num, head_dim] for each layer
# The padded slot 0 is used for writing dummy outputs from padded tokens.
@@ -187,7 +188,7 @@ class MLATokenToKVPool(BaseTokenToKVPool):
layer_num: int,
device: str,
):
super().__init__(size, dtype)
super().__init__(size, dtype, device)
self.kv_lora_rank = kv_lora_rank
# The padded slot 0 is used for writing dummy outputs from padded tokens.