Add a test case to test retract (#1662)
This commit is contained in:
@@ -590,9 +590,11 @@ class ScheduleBatch:
|
||||
|
||||
retracted_reqs = []
|
||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||
first_iter = True
|
||||
while (
|
||||
self.token_to_kv_pool.available_size()
|
||||
< len(sorted_indices) * global_config.retract_decode_steps
|
||||
or first_iter
|
||||
):
|
||||
if len(sorted_indices) == 1:
|
||||
# Corner case: only one request left
|
||||
@@ -601,6 +603,7 @@ class ScheduleBatch:
|
||||
), "No space left for only one request"
|
||||
break
|
||||
|
||||
first_iter = False
|
||||
idx = sorted_indices.pop()
|
||||
req = self.reqs[idx]
|
||||
retracted_reqs.append(req)
|
||||
|
||||
@@ -77,6 +77,9 @@ logger = logging.getLogger(__name__)
|
||||
# Crash on warning if we are running CI tests
|
||||
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
|
||||
|
||||
# Test hook: when SGLANG_TEST_RETRACT is "true", force retract decode once the running batch exceeds 10 requests
|
||||
test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
|
||||
|
||||
|
||||
class Scheduler:
|
||||
"""A scheduler that manages a tensor parallel GPU worker."""
|
||||
@@ -611,10 +614,11 @@ class Scheduler:
|
||||
return new_batch
|
||||
|
||||
def update_running_batch(self):
|
||||
global test_retract
|
||||
batch = self.running_batch
|
||||
|
||||
# Check whether decoding would run out of memory
|
||||
- if not batch.check_decode_mem():
|
||||
+ if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
|
||||
old_ratio = self.new_token_ratio
|
||||
|
||||
retracted_reqs, new_token_ratio = batch.retract_decode()
|
||||
|
||||
Reference in New Issue
Block a user