Fix linear.py and improve weight loading (#2851)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
Lianmin Zheng
2025-01-13 01:39:14 -08:00
committed by GitHub
parent 4093aa4660
commit 72c7776355
12 changed files with 113 additions and 125 deletions

View File

@@ -27,6 +27,7 @@ from enum import IntEnum
from functools import wraps
from typing import List, Tuple, Union
import numpy as np
import psutil
import torch
@@ -35,6 +36,8 @@ from sglang.srt.utils import debug_timing, get_compiler_backend
logger = logging.getLogger(__name__)
GB = 1024 * 1024 * 1024
class ReqToTokenPool:
"""A memory pool that maps a request to its token locations."""
@@ -193,6 +196,11 @@ class MHATokenToKVPool(BaseTokenToKVPool):
self.layer_num = layer_num
self._create_buffers()
k_size, v_size = self.get_kv_size_bytes()
logger.info(
f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB."
)
def _create_buffers(self):
# [size, head_num, head_dim] for each layer
# The padded slot 0 is used for writing dummy outputs from padded tokens.
@@ -217,6 +225,17 @@ class MHATokenToKVPool(BaseTokenToKVPool):
del self.k_buffer
del self.v_buffer
def get_kv_size_bytes(self):
    """Return the total byte footprint of the KV cache as a (k_bytes, v_bytes) pair.

    Both buffers must already be allocated (i.e. ``_create_buffers`` has run).
    """
    assert hasattr(self, "k_buffer")
    assert hasattr(self, "v_buffer")
    # Per-layer byte count = number of elements * size of one element.
    k_size_bytes = sum(
        np.prod(buf.shape) * buf.dtype.itemsize for buf in self.k_buffer
    )
    v_size_bytes = sum(
        np.prod(buf.shape) * buf.dtype.itemsize for buf in self.v_buffer
    )
    return k_size_bytes, v_size_bytes
# Todo: different memory layout
def get_flat_data(self, indices):
# prepare a large chunk of contiguous data for efficient transfer