Split the overlapped version of TpModelWorkerClient into a separate file (#1726)

This commit is contained in:
Lianmin Zheng
2024-10-20 00:29:29 -07:00
committed by GitHub
parent 593b19f29d
commit b48edff67f
7 changed files with 217 additions and 131 deletions

View File

@@ -13,7 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
"""Memory pool."""
"""
Memory pool.
SGLang has two levels of memory pool.
ReqToTokenPool maps a request to its token locations.
BaseTokenToKVPool maps a token location to its KV cache data.
"""
import logging
from typing import List, Tuple, Union
@@ -26,7 +32,7 @@ logger = logging.getLogger(__name__)
class ReqToTokenPool:
"""A memory pool that maps a request to its token locations."""
def __init__(self, size: int, max_context_len: int, device: str):
def __init__(self, size: int, max_context_len: int, device: str, use_records: bool):
self.size = size
self.max_context_len = max_context_len
self.device = device
@@ -34,6 +40,13 @@ class ReqToTokenPool:
(size, max_context_len), dtype=torch.int32, device=device
)
self.free_slots = list(range(size))
self.write_records = []
if use_records:
# records all write operations
self.write = self.write_with_records
else:
self.write = self.write_without_records
def available_size(self):
    """Return the number of entries currently held in ``free_slots``."""
    remaining = len(self.free_slots)
    return remaining
@@ -55,16 +68,27 @@ class ReqToTokenPool:
def clear(self):
    """Reset the pool: every slot index in ``range(size)`` becomes free
    again and the list of recorded writes is emptied."""
    self.free_slots = [*range(self.size)]
    self.write_records = list()
def write(self, indices, values):
def write_without_records(self, indices, values):
    """Store ``values`` at ``indices`` of ``req_to_token`` without
    logging the operation."""
    table = self.req_to_token
    table[indices] = values
def write_with_records(self, indices, values):
    """Store ``values`` at ``indices`` of ``req_to_token`` and append the
    ``(indices, values)`` pair to ``write_records``."""
    entry = (indices, values)
    table = self.req_to_token
    table[indices] = values
    # The record is appended only after the write itself went through.
    self.write_records.append(entry)
def get_write_records(self):
    """Hand over the recorded writes and start a fresh log.

    Returns:
        The list of ``(indices, values)`` tuples accumulated in
        ``write_records`` since the last call. ``write_records`` is
        rebound to a new empty list, so the returned list is not
        affected by later writes.
    """
    drained = self.write_records
    # Rebind instead of clearing in place, so the caller keeps the
    # drained entries.
    self.write_records = []
    return drained
def apply_write_records(self, write_records: List[Tuple]):
    """Replay recorded writes onto ``req_to_token`` in the given order.

    Args:
        write_records: ``(indices, values)`` pairs; each one is applied
            as ``req_to_token[indices] = values``.
    """
    for record in write_records:
        indices, values = record
        self.req_to_token[indices] = values
class BaseTokenToKVPool:
"""A memory pool that maps a token to its kv cache locations"""
"""A memory pool that maps a token location to its kv cache data."""
def __init__(
self,