Add a test case to test retract (#1662)

This commit is contained in:
Lianmin Zheng
2024-10-13 20:32:37 -07:00
committed by GitHub
parent 2725f8da61
commit 869f1c02c4
4 changed files with 50 additions and 1 deletion

View File

@@ -590,9 +590,11 @@ class ScheduleBatch:
retracted_reqs = []
seq_lens_cpu = self.seq_lens.cpu().numpy()
first_iter = True
while (
self.token_to_kv_pool.available_size()
< len(sorted_indices) * global_config.retract_decode_steps
or first_iter
):
if len(sorted_indices) == 1:
# Corner case: only one request left
@@ -601,6 +603,7 @@ class ScheduleBatch:
), "No space left for only one request"
break
first_iter = False
idx = sorted_indices.pop()
req = self.reqs[idx]
retracted_reqs.append(req)

View File

@@ -77,6 +77,9 @@ logger = logging.getLogger(__name__)
# Crash on warning if we are running CI tests
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
# Test-only switch (SGLANG_TEST_RETRACT=true): force retract-decode when the running batch size exceeds 10
test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
class Scheduler:
"""A scheduler that manages a tensor parallel GPU worker."""
@@ -611,10 +614,11 @@ class Scheduler:
return new_batch
def update_running_batch(self):
global test_retract
batch = self.running_batch
# Check if decode is out of memory
if not batch.check_decode_mem():
if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
old_ratio = self.new_token_ratio
retracted_reqs, new_token_ratio = batch.retract_decode()