Split the overlapped version of TpModelWorkerClient into a separate file (#1726)
This commit is contained in:
@@ -13,7 +13,13 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""Memory pool."""
|
||||
"""
|
||||
Memory pool.
|
||||
|
||||
SGLang has two levels of memory pool.
|
||||
ReqToTokenPool maps a request to its token locations.
|
||||
BaseTokenToKVPool maps a token location to its KV cache data.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Tuple, Union
|
||||
@@ -26,7 +32,7 @@ logger = logging.getLogger(__name__)
|
||||
class ReqToTokenPool:
|
||||
"""A memory pool that maps a request to its token locations."""
|
||||
|
||||
def __init__(self, size: int, max_context_len: int, device: str):
|
||||
def __init__(self, size: int, max_context_len: int, device: str, use_records: bool):
|
||||
self.size = size
|
||||
self.max_context_len = max_context_len
|
||||
self.device = device
|
||||
@@ -34,6 +40,13 @@ class ReqToTokenPool:
|
||||
(size, max_context_len), dtype=torch.int32, device=device
|
||||
)
|
||||
self.free_slots = list(range(size))
|
||||
self.write_records = []
|
||||
|
||||
if use_records:
|
||||
# records all write operations
|
||||
self.write = self.write_with_records
|
||||
else:
|
||||
self.write = self.write_without_records
|
||||
|
||||
def available_size(self):
|
||||
return len(self.free_slots)
|
||||
@@ -55,16 +68,27 @@ class ReqToTokenPool:
|
||||
|
||||
def clear(self):
|
||||
self.free_slots = list(range(self.size))
|
||||
self.write_records = []
|
||||
|
||||
def write(self, indices, values):
|
||||
def write_without_records(self, indices, values):
|
||||
self.req_to_token[indices] = values
|
||||
|
||||
def write_with_records(self, indices, values):
|
||||
self.req_to_token[indices] = values
|
||||
self.write_records.append((indices, values))
|
||||
|
||||
def get_write_records(self):
|
||||
return None
|
||||
ret = self.write_records
|
||||
self.write_records = []
|
||||
return ret
|
||||
|
||||
def apply_write_records(self, write_records: List[Tuple]):
|
||||
for indices, values in write_records:
|
||||
self.req_to_token[indices] = values
|
||||
|
||||
|
||||
class BaseTokenToKVPool:
|
||||
"""A memory pool that maps a token to its kv cache locations"""
|
||||
"""A memory pool that maps a token location to its kv cache data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user