CUDA-graph-compatible releasing and resuming KV cache and model weight memory (#2630)

fzyzcjy
2025-01-14 03:38:51 +08:00
committed by GitHub
parent d08c77c434
commit 923f518337
12 changed files with 406 additions and 60 deletions


@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+
 """
 Memory pool.
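
This import is the pools' only coupling to the release/resume machinery. The adapter is expected to hide the optional torch_memory_saver dependency behind a no-op fallback, so callers can always write "with adapter.region(): ..." whether or not the feature is enabled. Below is a minimal sketch of that pattern; aside from the create() and region() entry points used by this diff, every name here is an illustrative assumption, not the committed code:

    from contextlib import contextmanager


    class TorchMemorySaverAdapter:
        """Dispatches to torch_memory_saver when enabled, else to a no-op."""

        @staticmethod
        def create(enable: bool) -> "TorchMemorySaverAdapter":
            return _AdapterReal() if enable else _AdapterNoop()


    class _AdapterReal(TorchMemorySaverAdapter):
        # Assumption: the torch_memory_saver package exposes a process-wide
        # singleton with region()/pause()/resume().
        def region(self):
            from torch_memory_saver import torch_memory_saver
            return torch_memory_saver.region()

        def pause(self):
            from torch_memory_saver import torch_memory_saver
            torch_memory_saver.pause()

        def resume(self):
            from torch_memory_saver import torch_memory_saver
            torch_memory_saver.resume()


    class _AdapterNoop(TorchMemorySaverAdapter):
        @contextmanager
        def region(self):
            yield  # allocate through the normal CUDA caching allocator

        def pause(self):
            pass  # nothing was made releasable, so nothing to release

        def resume(self):
            pass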
@@ -42,13 +44,25 @@ GB = 1024 * 1024 * 1024
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
-    def __init__(self, size: int, max_context_len: int, device: str, use_records: bool):
+    def __init__(
+        self,
+        size: int,
+        max_context_len: int,
+        device: str,
+        use_records: bool,
+        enable_memory_saver: bool,
+    ):
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        self.req_to_token = torch.zeros(
-            (size, max_context_len), dtype=torch.int32, device=device
-        )
+        with memory_saver_adapter.region():
+            self.req_to_token = torch.zeros(
+                (size, max_context_len), dtype=torch.int32, device=device
+            )
         self.free_slots = list(range(size))
         self.write_records = []
         self.use_records = use_records
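
With this change the request-to-token map is allocated inside a saver region, which makes it releasable later. A hedged usage sketch, assuming any enabled adapter instance delegates to the same process-wide saver, so pause/resume can be driven from outside the pool:

    pool = ReqToTokenPool(
        size=1024,
        max_context_len=8192,
        device="cuda",
        use_records=False,
        enable_memory_saver=True,
    )

    saver = TorchMemorySaverAdapter.create(enable=True)
    saver.pause()   # release the physical pages behind pool.req_to_token
    # ... the freed GPU memory can be used by something else here ...
    saver.resume()  # re-map the pages; pool.req_to_token is usable again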
@@ -189,8 +203,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_dim: int,
         layer_num: int,
         device: str,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
+
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
         self.head_num = head_num
         self.head_dim = head_dim
         self.layer_num = layer_num
@@ -202,24 +222,25 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         )
 
     def _create_buffers(self):
-        # [size, head_num, head_dim] for each layer
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.k_buffer = [
-            torch.empty(
-                (self.size + 1, self.head_num, self.head_dim),
-                dtype=self.store_dtype,
-                device=self.device,
-            )
-            for _ in range(self.layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty(
-                (self.size + 1, self.head_num, self.head_dim),
-                dtype=self.store_dtype,
-                device=self.device,
-            )
-            for _ in range(self.layer_num)
-        ]
+        with self.memory_saver_adapter.region():
+            # [size, head_num, head_dim] for each layer
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.k_buffer = [
+                torch.empty(
+                    (self.size + 1, self.head_num, self.head_dim),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
+            self.v_buffer = [
+                torch.empty(
+                    (self.size + 1, self.head_num, self.head_dim),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
 
     def _clear_buffers(self):
         del self.k_buffer
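
Allocating the K/V buffers inside region() is what makes releasing them CUDA-graph-compatible: a captured graph bakes in raw device pointers, so freeing the buffers through the ordinary allocator and re-allocating them would invalidate every captured graph. torch_memory_saver instead keeps the virtual address range reserved and only unmaps and re-maps the physical pages, so tensor addresses survive a pause/resume cycle. A small sketch of that invariant (illustrative, and it assumes the saver backend is active):

    import torch

    saver = TorchMemorySaverAdapter.create(enable=True)
    with saver.region():
        buf = torch.empty((1024,), device="cuda")

    ptr = buf.data_ptr()
    saver.pause()    # physical pages released; do not touch buf here
    saver.resume()   # pages re-mapped at the SAME virtual address, so any
                     # CUDA graph that captured ptr remains valid to replay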
@@ -307,19 +328,26 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         qk_rope_head_dim: int,
         layer_num: int,
         device: str,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
 
         self.kv_lora_rank = kv_lora_rank
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.kv_buffer = [
-            torch.empty(
-                (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
-                dtype=self.store_dtype,
-                device=device,
-            )
-            for _ in range(layer_num)
-        ]
+
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
+        with memory_saver_adapter.region():
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.kv_buffer = [
+                torch.empty(
+                    (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+                    dtype=self.store_dtype,
+                    device=device,
+                )
+                for _ in range(layer_num)
+            ]
 
     def get_key_buffer(self, layer_id: int):
         if self.store_dtype != self.dtype:
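
The buffer shape makes the MLA layout visible: per token and layer there is a single compressed latent of width kv_lora_rank + qk_rope_head_dim instead of per-head K and V tensors. Back-of-envelope sizing under DeepSeek-V2-style numbers (the concrete values below are assumptions for illustration):

    # MHA stores K and V for every head; MLA stores one shared latent.
    head_num, head_dim = 128, 128          # assumed MHA-equivalent config
    kv_lora_rank, qk_rope_head_dim = 512, 64

    mha = 2 * head_num * head_dim          # 32768 elements / token / layer
    mla = kv_lora_rank + qk_rope_head_dim  # 576 elements / token / layer
    print(f"KV cache shrinks ~{mha / mla:.0f}x")  # ~57x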
@@ -360,26 +388,32 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
         layer_num: int,
         device: str,
         heavy_channel_num: int,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
 
-        # [size, head_num, head_dim] for each layer
-        self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
-            for _ in range(layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
-            for _ in range(layer_num)
-        ]
-
-        # [size, head_num, heavy_channel_num] for each layer
-        self.label_buffer = [
-            torch.empty(
-                (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
-            )
-            for _ in range(layer_num)
-        ]
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
+        with memory_saver_adapter.region():
+            # [size, head_num, head_dim] for each layer
+            self.k_buffer = [
+                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                for _ in range(layer_num)
+            ]
+            self.v_buffer = [
+                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                for _ in range(layer_num)
+            ]
+
+            # [size, head_num, heavy_channel_num] for each layer
+            self.label_buffer = [
+                torch.empty(
+                    (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
+                )
+                for _ in range(layer_num)
+            ]
 
     def get_key_buffer(self, layer_id: int):
         return self.k_buffer[layer_id]
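
Putting the pieces together, the headline property of this commit is that a pool can drop its device memory and get it back without re-capturing CUDA graphs. An end-to-end sketch follows; the keyword arguments are taken from the hunks above, the rest is assumed, and the real saver backend additionally needs its allocator hook loaded (e.g. via LD_PRELOAD):

    import torch

    pool = MHATokenToKVPool(
        size=4096,
        dtype=torch.float16,
        head_num=8,
        head_dim=128,
        layer_num=2,
        device="cuda",
        enable_memory_saver=True,
    )

    # Capture a graph that reads from the pool. The captured device pointers
    # stay valid across pause()/resume() because the buffers were allocated
    # inside the saver region.
    out = torch.empty_like(pool.k_buffer[0][:16])
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out.copy_(pool.k_buffer[0][:16])

    pool.memory_saver_adapter.pause()   # KV memory handed back to the GPU
    # ... e.g. load new weights or run another job ...
    pool.memory_saver_adapter.resume()  # same virtual addresses, re-backed
    g.replay()                          # graph still valid; no re-capture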