Fix retraction + overlap (#1860)
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
4
.github/workflows/pr-test.yml
vendored
4
.github/workflows/pr-test.yml
vendored
@@ -50,7 +50,7 @@ jobs:
|
|||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 run_suite.py --suite minimal --range-begin 0 --range-end 4
|
python3 run_suite.py --suite minimal --range-begin 0 --range-end 6
|
||||||
|
|
||||||
unit-test-backend-part-2:
|
unit-test-backend-part-2:
|
||||||
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
||||||
@@ -67,7 +67,7 @@ jobs:
|
|||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 run_suite.py --suite minimal --range-begin 4 --range-end 14
|
python3 run_suite.py --suite minimal --range-begin 6 --range-end 14
|
||||||
|
|
||||||
unit-test-backend-part-3:
|
unit-test-backend-part-3:
|
||||||
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
||||||
|
|||||||
@@ -211,9 +211,6 @@ class Req:
|
|||||||
# this does not include the jump forward tokens.
|
# this does not include the jump forward tokens.
|
||||||
self.completion_tokens_wo_jump_forward = 0
|
self.completion_tokens_wo_jump_forward = 0
|
||||||
|
|
||||||
# The number of cached tokens, that were already cached in the KV store
|
|
||||||
self.cached_tokens = 0
|
|
||||||
|
|
||||||
# For vision inputs
|
# For vision inputs
|
||||||
self.image_inputs: Optional[ImageInputs] = None
|
self.image_inputs: Optional[ImageInputs] = None
|
||||||
|
|
||||||
@@ -223,6 +220,9 @@ class Req:
|
|||||||
self.last_node = None
|
self.last_node = None
|
||||||
self.is_being_chunked = 0
|
self.is_being_chunked = 0
|
||||||
|
|
||||||
|
# For retraction
|
||||||
|
self.is_retracted = False
|
||||||
|
|
||||||
# Logprobs (arguments)
|
# Logprobs (arguments)
|
||||||
self.return_logprob = False
|
self.return_logprob = False
|
||||||
self.logprob_start_len = 0
|
self.logprob_start_len = 0
|
||||||
@@ -242,12 +242,15 @@ class Req:
|
|||||||
# The relative logprob_start_len in an extend batch
|
# The relative logprob_start_len in an extend batch
|
||||||
self.extend_logprob_start_len = 0
|
self.extend_logprob_start_len = 0
|
||||||
|
|
||||||
# Embedding
|
# Embedding (return values)
|
||||||
self.embedding = None
|
self.embedding = None
|
||||||
|
|
||||||
# Constrained decoding
|
# Constrained decoding
|
||||||
self.grammar: Optional[Grammar] = None
|
self.grammar: Optional[Grammar] = None
|
||||||
|
|
||||||
|
# The number of cached tokens, that were already cached in the KV cache
|
||||||
|
self.cached_tokens = 0
|
||||||
|
|
||||||
# For Qwen2-VL
|
# For Qwen2-VL
|
||||||
self.mrope_position_delta = [] # use mutable object
|
self.mrope_position_delta = [] # use mutable object
|
||||||
|
|
||||||
@@ -561,7 +564,7 @@ class ScheduleBatch:
|
|||||||
seq_lens[i] -= encoder_len
|
seq_lens[i] -= encoder_len
|
||||||
|
|
||||||
if len(req.prefix_indices) < encoder_len:
|
if len(req.prefix_indices) < encoder_len:
|
||||||
# NOTE: the encoder part should considered as a whole
|
# NOTE: the encoder part should be considered as a whole
|
||||||
assert len(req.prefix_indices) == 0
|
assert len(req.prefix_indices) == 0
|
||||||
input_ids[i] = input_ids[i][encoder_len:]
|
input_ids[i] = input_ids[i][encoder_len:]
|
||||||
encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
|
encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
|
||||||
@@ -648,6 +651,7 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
req.extend_logprob_start_len = extend_logprob_start_len
|
req.extend_logprob_start_len = extend_logprob_start_len
|
||||||
pt += req.extend_input_len
|
pt += req.extend_input_len
|
||||||
|
req.is_retracted = False
|
||||||
|
|
||||||
# Set fields
|
# Set fields
|
||||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
|
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
|
||||||
@@ -780,6 +784,7 @@ class ScheduleBatch:
|
|||||||
req.prefix_indices = []
|
req.prefix_indices = []
|
||||||
req.last_node = None
|
req.last_node = None
|
||||||
req.extend_input_len = 0
|
req.extend_input_len = 0
|
||||||
|
req.is_retracted = True
|
||||||
|
|
||||||
# For incremental logprobs
|
# For incremental logprobs
|
||||||
req.last_update_decode_tokens = 0
|
req.last_update_decode_tokens = 0
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ from sglang.utils import get_exception_traceback
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Crash on warning if we are running CI tests
|
# Crash on warning if we are running CI tests
|
||||||
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
|
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
|
||||||
|
|
||||||
@@ -831,9 +832,10 @@ class Scheduler:
|
|||||||
# Check finish conditions
|
# Check finish conditions
|
||||||
logprob_pt = 0
|
logprob_pt = 0
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
if req.is_being_chunked > 0:
|
if req.is_retracted:
|
||||||
req.is_being_chunked -= 1
|
continue
|
||||||
else:
|
|
||||||
|
if req.is_being_chunked <= 0:
|
||||||
# Inflight reqs' prefill is not finished
|
# Inflight reqs' prefill is not finished
|
||||||
req.completion_tokens_wo_jump_forward += 1
|
req.completion_tokens_wo_jump_forward += 1
|
||||||
req.output_ids.append(next_token_ids[i])
|
req.output_ids.append(next_token_ids[i])
|
||||||
@@ -851,12 +853,18 @@ class Scheduler:
|
|||||||
logprob_pt += self.add_logprob_return_values(
|
logprob_pt += self.add_logprob_return_values(
|
||||||
i, req, logprob_pt, next_token_ids, logits_output
|
i, req, logprob_pt, next_token_ids, logits_output
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
req.is_being_chunked -= 1
|
||||||
|
|
||||||
else: # embedding or reward model
|
else: # embedding or reward model
|
||||||
embeddings, bid = result
|
embeddings, bid = result
|
||||||
embeddings = embeddings.tolist()
|
embeddings = embeddings.tolist()
|
||||||
|
|
||||||
# Check finish conditions
|
# Check finish conditions
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
|
if req.is_retracted:
|
||||||
|
continue
|
||||||
|
|
||||||
req.embedding = embeddings[i]
|
req.embedding = embeddings[i]
|
||||||
if req.is_being_chunked > 0:
|
if req.is_being_chunked > 0:
|
||||||
req.is_being_chunked -= 1
|
req.is_being_chunked -= 1
|
||||||
@@ -893,7 +901,12 @@ class Scheduler:
|
|||||||
|
|
||||||
# Check finish condition
|
# Check finish condition
|
||||||
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
||||||
if self.server_args.enable_overlap_schedule and req.finished():
|
if req.is_retracted:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if self.server_args.enable_overlap_schedule and (
|
||||||
|
req.finished()
|
||||||
|
):
|
||||||
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
|
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -1015,6 +1028,7 @@ class Scheduler:
|
|||||||
is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
|
is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
|
||||||
|
|
||||||
for req in reqs:
|
for req in reqs:
|
||||||
|
# TODO(lianmin): revisit this for overlap + retract + stream
|
||||||
if req.finished() or (
|
if req.finished() or (
|
||||||
req.stream and (is_stream_iter or len(req.output_ids) == 1)
|
req.stream and (is_stream_iter or len(req.output_ids) == 1)
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -107,6 +107,27 @@ class TestRadixCacheLPM(TestRadixCacheFCFS):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRadixCacheOverlapLPM(TestRadixCacheFCFS):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[
|
||||||
|
"--enable-overlap-schedule",
|
||||||
|
"--chunked-prefill-size",
|
||||||
|
"128",
|
||||||
|
"--max-total-tokens",
|
||||||
|
"20000",
|
||||||
|
"--schedule-policy",
|
||||||
|
"lpm",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
os.environ["SGLANG_TEST_RETRACT"] = "true"
|
os.environ["SGLANG_TEST_RETRACT"] = "true"
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user