CUDA-graph-compatible releasing and resuming KV cache and model weight memory (#2630)
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
+
 """
 Memory pool.
 
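The adapter itself is only imported here; its definition is not part of this diff. A minimal sketch of the interface the rest of the change relies on could look like the following. This is a hypothetical stand-in written for illustration, not the actual sglang implementation; only create() and region() appear in the hunks below.

from contextlib import contextmanager


class TorchMemorySaverAdapterSketch:
    """Hypothetical stand-in showing only the interface this diff uses."""

    def __init__(self, enable: bool):
        self._enable = enable

    @staticmethod
    def create(enable: bool) -> "TorchMemorySaverAdapterSketch":
        # Presumably returns a no-op adapter when disabled, so call sites
        # never need an explicit `if enable_memory_saver:` branch.
        return TorchMemorySaverAdapterSketch(enable)

    @contextmanager
    def region(self):
        # In the real adapter, tensors allocated inside this context are
        # registered with the memory saver, which can later release their
        # physical backing and restore it at the same virtual addresses.
        # Stable virtual addresses are what keep previously captured
        # CUDA graphs valid across a release/resume cycle.
        yield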
@@ -42,13 +44,25 @@ GB = 1024 * 1024 * 1024
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
-    def __init__(self, size: int, max_context_len: int, device: str, use_records: bool):
+    def __init__(
+        self,
+        size: int,
+        max_context_len: int,
+        device: str,
+        use_records: bool,
+        enable_memory_saver: bool,
+    ):
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        self.req_to_token = torch.zeros(
-            (size, max_context_len), dtype=torch.int32, device=device
-        )
+        with memory_saver_adapter.region():
+            self.req_to_token = torch.zeros(
+                (size, max_context_len), dtype=torch.int32, device=device
+            )
         self.free_slots = list(range(size))
        self.write_records = []
         self.use_records = use_records
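Only the GPU-resident req_to_token tensor moves inside region(); CPU-side bookkeeping such as free_slots and write_records stays outside, since there is nothing to release. Construction would then look roughly like this (argument values are illustrative, not taken from the diff):

pool = ReqToTokenPool(
    size=1024,
    max_context_len=4096,
    device="cuda",
    use_records=False,
    enable_memory_saver=True,  # route the req_to_token allocation through region()
)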
@@ -189,8 +203,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_dim: int,
         layer_num: int,
         device: str,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
+
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
         self.head_num = head_num
         self.head_dim = head_dim
         self.layer_num = layer_num
@@ -202,24 +222,25 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         )
 
     def _create_buffers(self):
-        # [size, head_num, head_dim] for each layer
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.k_buffer = [
-            torch.empty(
-                (self.size + 1, self.head_num, self.head_dim),
-                dtype=self.store_dtype,
-                device=self.device,
-            )
-            for _ in range(self.layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty(
-                (self.size + 1, self.head_num, self.head_dim),
-                dtype=self.store_dtype,
-                device=self.device,
-            )
-            for _ in range(self.layer_num)
-        ]
+        with self.memory_saver_adapter.region():
+            # [size, head_num, head_dim] for each layer
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.k_buffer = [
+                torch.empty(
+                    (self.size + 1, self.head_num, self.head_dim),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
+            self.v_buffer = [
+                torch.empty(
+                    (self.size + 1, self.head_num, self.head_dim),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
 
     def _clear_buffers(self):
         del self.k_buffer
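Unlike ReqToTokenPool, where the adapter is a throwaway local, MHATokenToKVPool keeps it as self.memory_saver_adapter because _create_buffers() is a separate method that can run again after _clear_buffers(). A release/resume cycle might then look like the sketch below; the pause()/resume() names are an assumption for illustration, as this diff only shows create() and region():

kv_pool.memory_saver_adapter.pause()   # hypothetically releases the physical
                                       # memory behind k_buffer/v_buffer
do_other_gpu_work()                    # placeholder for whatever runs while freed
kv_pool.memory_saver_adapter.resume()  # re-backs the same virtual addresses, so
                                       # captured CUDA graphs still replay safely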
@@ -307,19 +328,26 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         qk_rope_head_dim: int,
         layer_num: int,
         device: str,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
 
         self.kv_lora_rank = kv_lora_rank
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.kv_buffer = [
-            torch.empty(
-                (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
-                dtype=self.store_dtype,
-                device=device,
-            )
-            for _ in range(layer_num)
-        ]
+
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
+        with memory_saver_adapter.region():
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.kv_buffer = [
+                torch.empty(
+                    (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+                    dtype=self.store_dtype,
+                    device=device,
+                )
+                for _ in range(layer_num)
+            ]
 
     def get_key_buffer(self, layer_id: int):
         if self.store_dtype != self.dtype:
@@ -360,26 +388,32 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
         layer_num: int,
         device: str,
         heavy_channel_num: int,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
 
-        # [size, head_num, head_dim] for each layer
-        self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
-            for _ in range(layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
-            for _ in range(layer_num)
-        ]
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
 
-        # [size, head_num, heavy_channel_num] for each layer
-        self.label_buffer = [
-            torch.empty(
-                (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
-            )
-            for _ in range(layer_num)
-        ]
+        with memory_saver_adapter.region():
+            # [size, head_num, head_dim] for each layer
+            self.k_buffer = [
+                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                for _ in range(layer_num)
+            ]
+            self.v_buffer = [
+                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                for _ in range(layer_num)
+            ]
+
+            # [size, head_num, heavy_channel_num] for each layer
+            self.label_buffer = [
+                torch.empty(
+                    (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
+                )
+                for _ in range(layer_num)
+            ]
 
     def get_key_buffer(self, layer_id: int):
         return self.k_buffer[layer_id]