Fix linear.py and improve weight loading (#2851)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
@@ -27,6 +27,7 @@ from enum import IntEnum
|
||||
from functools import wraps
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
@@ -35,6 +36,8 @@ from sglang.srt.utils import debug_timing, get_compiler_backend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GB = 1024 * 1024 * 1024
|
||||
|
||||
|
||||
class ReqToTokenPool:
|
||||
"""A memory pool that maps a request to its token locations."""
|
||||
@@ -193,6 +196,11 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
||||
self.layer_num = layer_num
|
||||
self._create_buffers()
|
||||
|
||||
k_size, v_size = self.get_kv_size_bytes()
|
||||
logger.info(
|
||||
f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB."
|
||||
)
|
||||
|
||||
def _create_buffers(self):
|
||||
# [size, head_num, head_dim] for each layer
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
@@ -217,6 +225,17 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
||||
del self.k_buffer
|
||||
del self.v_buffer
|
||||
|
||||
def get_kv_size_bytes(self):
    """Return the total number of bytes held by the K and V cache buffers.

    Sums the byte size of every entry in ``self.k_buffer`` and
    ``self.v_buffer`` separately.

    Returns:
        A ``(k_size_bytes, v_size_bytes)`` tuple of plain Python ints.

    Raises:
        AssertionError: if ``k_buffer`` or ``v_buffer`` has not been set on
            the instance.
    """
    assert hasattr(self, "k_buffer")
    assert hasattr(self, "v_buffer")

    # numel() * element_size() avoids relying on ``torch.dtype.itemsize``
    # (not available on older PyTorch releases) and yields plain Python ints
    # rather than NumPy scalars.
    # NOTE(review): assumes every buffer entry is a torch.Tensor — confirm
    # against _create_buffers.
    k_size_bytes = sum(
        k_cache.numel() * k_cache.element_size() for k_cache in self.k_buffer
    )
    v_size_bytes = sum(
        v_cache.numel() * v_cache.element_size() for v_cache in self.v_buffer
    )
    return k_size_bytes, v_size_bytes
|
||||
|
||||
# TODO: support a different memory layout
|
||||
def get_flat_data(self, indices):
|
||||
# prepare a large chunk of contiguous data for efficient transfer
|
||||
|
||||
Reference in New Issue
Block a user