Add grouped free operations (#1706)
This commit is contained in:
@@ -834,6 +834,8 @@ class Scheduler:
|
|||||||
|
|
||||||
next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
|
next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
|
||||||
|
|
||||||
|
self.token_to_kv_pool.free_group_begin()
|
||||||
|
|
||||||
# Check finish condition
|
# Check finish condition
|
||||||
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
||||||
if self.server_args.enable_overlap_schedule and req.finished():
|
if self.server_args.enable_overlap_schedule and req.finished():
|
||||||
@@ -860,6 +862,8 @@ class Scheduler:
|
|||||||
|
|
||||||
self.stream_output(batch.reqs)
|
self.stream_output(batch.reqs)
|
||||||
|
|
||||||
|
self.token_to_kv_pool.free_group_end()
|
||||||
|
|
||||||
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
|
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
|
||||||
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
|
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
|
||||||
self.print_decode_stats()
|
self.print_decode_stats()
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ limitations under the License.
|
|||||||
import logging
|
import logging
|
||||||
from typing import List, Tuple, Union
|
from typing import List, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -77,6 +76,8 @@ class BaseTokenToKVPool:
|
|||||||
self.store_dtype = dtype
|
self.store_dtype = dtype
|
||||||
|
|
||||||
self.free_slots = None
|
self.free_slots = None
|
||||||
|
self.is_not_in_free_group = True
|
||||||
|
self.free_group = []
|
||||||
self.clear()
|
self.clear()
|
||||||
|
|
||||||
def available_size(self):
|
def available_size(self):
|
||||||
@@ -89,14 +90,28 @@ class BaseTokenToKVPool:
|
|||||||
select_index = self.free_slots[:need_size]
|
select_index = self.free_slots[:need_size]
|
||||||
self.free_slots = self.free_slots[need_size:]
|
self.free_slots = self.free_slots[need_size:]
|
||||||
|
|
||||||
return torch.tensor(select_index, dtype=torch.int32, device=self.device)
|
return select_index.to(self.device)
|
||||||
|
|
||||||
def free(self, free_index: torch.Tensor):
    """Return KV-cache slot indices to the pool.

    Outside a free group the indices are moved to the CPU and appended to
    ``self.free_slots`` immediately. Inside a free group (between
    ``free_group_begin`` and ``free_group_end``) the tensor is only queued
    in ``self.free_group`` and is released when the group ends.

    Args:
        free_index: Tensor of slot indices to release.
    """
    if self.is_not_in_free_group:
        self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
    else:
        # Defer: the queued tensors are concatenated and copied to the CPU
        # once in free_group_end() instead of once per call.
        self.free_group.append(free_index)

def free_group_begin(self):
    """Start a free group: later ``free`` calls are queued, not applied."""
    self.is_not_in_free_group = False
    self.free_group = []

def free_group_end(self):
    """Finish the free group and release every queued index in one call.

    The pending list is detached from ``self.free_group`` before flushing,
    so calling this method again without a new ``free_group_begin`` is a
    no-op rather than freeing the same slots a second time.
    """
    # Must be set before calling free() so that free() actually releases
    # the slots instead of queueing them again.
    self.is_not_in_free_group = True
    pending, self.free_group = self.free_group, []
    if pending:
        self.free(torch.concat(pending))
|
||||||
|
|
||||||
def clear(self):
    """Reset the pool so that every slot except slot 0 is free.

    Also drops any pending free group and leaves the pool outside group
    mode, so the next ``free`` call releases its indices immediately.
    """
    # The padded slot 0 is used for writing dummy outputs from padded tokens.
    self.free_slots = torch.arange(1, self.size + 1, dtype=torch.int32)
    # Reset the flag that free() actually reads. The previous code assigned
    # `is_in_free_group`, a name nothing else uses, so a clear() issued
    # while a group was open left free() queueing indices indefinitely.
    self.is_not_in_free_group = True
    self.free_group = []
|
||||||
|
|
||||||
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
|
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|||||||
Reference in New Issue
Block a user