Fix a bug in BatchTokenIDOut & Misc style and dependency updates (#7457)

Lianmin Zheng
2025-06-23 06:20:39 -07:00
committed by GitHub
parent 8aa68ed5c4
commit 55e03b10c4
9 changed files with 37 additions and 32 deletions


@@ -788,6 +788,7 @@ class Req:
         self.multimodal_inputs = None
         self.grammar = None
         self.origin_input_ids = [0]  # set it to one token to skip the long prefill
+        self.return_logprob = False
         self.finished_reason = FINISH_ABORT(
             error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
         )

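Resetting return_logprob when a request is aborted keeps the output path from expecting log-probabilities that were never computed for a prompt that has just been collapsed to a single placeholder token. A minimal sketch of the same pattern, assuming a trimmed-down request object (MiniReq and FinishReason are illustrative stand-ins, not the real Req / FINISH_ABORT classes):

from dataclasses import dataclass
from http import HTTPStatus
from typing import List, Optional


@dataclass
class FinishReason:
    """Simplified stand-in for the FINISH_ABORT object in the diff."""
    error_msg: str
    status_code: HTTPStatus
    err_type: str


@dataclass
class MiniReq:
    """Hypothetical, trimmed-down request object used only for illustration."""
    origin_input_ids: List[int]
    return_logprob: bool = False
    finished_reason: Optional[FinishReason] = None

    def set_finish_with_abort(self, error_msg: str) -> None:
        # Shrink the prompt to a single placeholder token so no long prefill
        # runs for the aborted request, and turn off logprob output so the
        # batch output path never looks for logprobs that do not exist.
        self.origin_input_ids = [0]
        self.return_logprob = False
        self.finished_reason = FinishReason(
            error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
        )


req = MiniReq(origin_input_ids=[1, 2, 3], return_logprob=True)
req.set_finish_with_abort("prompt too long")
assert req.return_logprob is False and req.finished_reason is not None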

@@ -1374,7 +1374,14 @@ class Scheduler(
             )
             raise ValueError(msg)
 
-        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            req_total_size = (
+                self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
+            )
+        else:
+            req_total_size = self.req_to_token_pool.size
+
+        if len(self.req_to_token_pool.free_slots) != req_total_size:
             msg = (
                 "req_to_token_pool memory leak detected!"
                 f"available_size={len(self.req_to_token_pool.free_slots)}, "

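On the decode side of a disaggregated deployment, the req_to_token_pool is evidently sized with pre_alloc_size extra slots, so comparing free_slots against the bare size would raise false leak warnings; the hunk folds the pre-allocated slots into the expected total. A small sketch of that accounting, with hypothetical names (Pool, expected_free_slots, check_leak):

from dataclasses import dataclass
from typing import List


@dataclass
class Pool:
    """Hypothetical stand-in exposing only the fields the leak check reads."""
    size: int            # slots available for normal allocation
    pre_alloc_size: int  # extra slots reserved for pre-allocated decode requests
    free_slots: List[int]


def expected_free_slots(pool: Pool, decode_disaggregation: bool) -> int:
    # On the decode side of a prefill/decode split the pool is built with
    # extra pre-allocated slots, so an idle pool should report all of them free.
    if decode_disaggregation:
        return pool.size + pool.pre_alloc_size
    return pool.size


def check_leak(pool: Pool, decode_disaggregation: bool) -> None:
    if len(pool.free_slots) != expected_free_slots(pool, decode_disaggregation):
        raise RuntimeError(
            "req_to_token_pool memory leak detected! "
            f"available_size={len(pool.free_slots)}"
        )


# 8 regular + 2 pre-allocated slots, all free: this passes only when the
# decode-mode accounting is used.
check_leak(Pool(size=8, pre_alloc_size=2, free_slots=list(range(10))), True)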

@@ -1226,7 +1226,7 @@ class TokenizerManager:
                 state.last_output_offset = len(state.output_ids)
             else:
                 state.output_ids.extend(recv_obj.output_ids[i])
-                output_token_ids = state.output_ids
+                output_token_ids = state.output_ids.copy()
 
             out_dict = {
                 "output_ids": output_token_ids,

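This is the BatchTokenIDOut fix from the commit title: state.output_ids keeps growing as later batches arrive, so storing the list itself in out_dict lets earlier results mutate retroactively; .copy() snapshots it at emission time. A self-contained demonstration of the aliasing bug and the fix (the helper collect_outputs is made up for illustration):

def collect_outputs(chunks, snapshot: bool):
    accumulated = []  # plays the role of state.output_ids
    results = []      # plays the role of the emitted out_dicts
    for chunk in chunks:
        accumulated.extend(chunk)
        output_token_ids = accumulated.copy() if snapshot else accumulated
        results.append({"output_ids": output_token_ids})
    return results


chunks = [[1, 2], [3], [4, 5]]
buggy = collect_outputs(chunks, snapshot=False)
fixed = collect_outputs(chunks, snapshot=True)

# Without the copy, the first emitted dict silently grew to the full sequence.
assert buggy[0]["output_ids"] == [1, 2, 3, 4, 5]
# With the copy, each dict keeps the tokens it held at emission time.
assert fixed[0]["output_ids"] == [1, 2]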

@@ -1723,9 +1723,8 @@ class PortArgs:
             dist_init_host, dist_init_port = dist_init_addr
             port_base = int(dist_init_port) + 1
             if dp_rank is None:
-                scheduler_input_port = (
-                    port_base + 3
-                )  # TokenizerManager to DataParallelController
+                # TokenizerManager to DataParallelController
+                scheduler_input_port = port_base + 3
             else:
                 scheduler_input_port = port_base + 3 + 1 + dp_rank

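The PortArgs hunk is a pure style change: moving the trailing comment above the assignment lets the expression fit on one line. For clarity, the same port arithmetic restated as a standalone, hypothetical helper:

from typing import Optional


def scheduler_input_port(port_base: int, dp_rank: Optional[int]) -> int:
    # Without data parallelism the TokenizerManager talks to the
    # DataParallelController at a fixed offset; with it, each data-parallel
    # rank gets its own port above that.
    if dp_rank is None:
        # TokenizerManager to DataParallelController
        return port_base + 3
    return port_base + 3 + 1 + dp_rank


assert scheduler_input_port(30001, None) == 30004
assert scheduler_input_port(30001, 0) == 30005
assert scheduler_input_port(30001, 2) == 30007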

@@ -1917,13 +1917,6 @@ def configure_ipv6(dist_init_addr):
     return port, host
 
 
-def rank0_log(msg: str):
-    from sglang.srt.distributed import get_tensor_model_parallel_rank
-
-    if get_tensor_model_parallel_rank() == 0:
-        logger.info(msg)
-
-
 def rank0_print(msg: str):
     from sglang.srt.distributed import get_tensor_model_parallel_rank
 
@@ -1931,6 +1924,9 @@ def rank0_print(msg: str):
         print(msg, flush=True)
 
 
+rank0_log = rank0_print
+
+
 def get_cuda_version():
     if torch.version.cuda:
         return tuple(map(int, torch.version.cuda.split(".")))
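The utils hunks delete the separate rank0_log implementation and alias the name to rank0_print, whose only difference was printing instead of calling logger.info. A hedged sketch of that consolidation plus the version-tuple parsing shown in get_cuda_version (the rank parameter below stands in for the real tensor-model-parallel rank lookup):

from typing import Tuple


def rank0_print(msg: str, rank: int = 0) -> None:
    # The real helper queries the tensor-model-parallel rank; a plain
    # parameter stands in for it in this sketch.
    if rank == 0:
        print(msg, flush=True)


# The old name survives as an alias instead of a second, near-identical copy.
rank0_log = rank0_print


def parse_cuda_version(version: str) -> Tuple[int, ...]:
    # "12.4" -> (12, 4); tuples compare element-wise, so checks like
    # parse_cuda_version(v) >= (11, 8) behave as expected.
    return tuple(map(int, version.split(".")))


rank0_log("hello from rank 0")
assert parse_cuda_version("12.4") >= (11, 8)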