Unify forward output datastructure (#11124)
This commit is contained in:
@@ -150,7 +150,11 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardBatchOutput,
|
||||
ForwardMode,
|
||||
PPProxyTensors,
|
||||
)
|
||||
from sglang.srt.parser.reasoning_parser import ReasoningParser
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
@@ -175,7 +179,6 @@ from sglang.srt.utils import (
|
||||
get_bool_env_var,
|
||||
get_int_env_var,
|
||||
get_zmq_socket,
|
||||
is_cpu,
|
||||
kill_itself_when_parent_died,
|
||||
numa_bind_to_node,
|
||||
point_to_point_pyobj,
|
||||
@@ -194,24 +197,59 @@ logger = logging.getLogger(__name__)
|
||||
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
|
||||
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
|
||||
|
||||
_is_cpu = is_cpu()
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerationBatchResult:
|
||||
logits_output: Optional[LogitsProcessorOutput]
|
||||
pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
|
||||
pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
|
||||
next_token_ids: Optional[List[int]]
|
||||
can_run_cuda_graph: bool
|
||||
|
||||
# For output processing
|
||||
extend_input_len_per_req: List[int]
|
||||
extend_logprob_start_len_per_req: List[int]
|
||||
bid: int
|
||||
can_run_cuda_graph: bool
|
||||
|
||||
@classmethod
|
||||
def from_forward_batch_output(
|
||||
cls,
|
||||
forward_batch_output: ForwardBatchOutput,
|
||||
extend_input_len_per_req: List[int],
|
||||
extend_logprob_start_len_per_req: List[int],
|
||||
):
|
||||
# TODO(lsyin): remove this workaround logic and try to unify output classes
|
||||
|
||||
return cls(
|
||||
logits_output=forward_batch_output.logits_output,
|
||||
pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
|
||||
next_token_ids=forward_batch_output.next_token_ids,
|
||||
extend_input_len_per_req=extend_input_len_per_req,
|
||||
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
|
||||
can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pp_proxy(
|
||||
cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
|
||||
):
|
||||
# TODO(lsyin): also simplify this logic
|
||||
# Current PP implementation in scheduler is not compatible with ForwardBatchOutput
|
||||
# Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
|
||||
proxy_dict = next_pp_outputs.tensors
|
||||
return cls(
|
||||
logits_output=logits_output,
|
||||
pp_hidden_states_proxy_tensors=None,
|
||||
next_token_ids=next_pp_outputs["next_token_ids"],
|
||||
extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
|
||||
extend_logprob_start_len_per_req=proxy_dict.get(
|
||||
"extend_logprob_start_len_per_req", None
|
||||
),
|
||||
can_run_cuda_graph=can_run_cuda_graph,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingBatchResult:
|
||||
embeddings: torch.Tensor
|
||||
bid: int
|
||||
|
||||
|
||||
class Scheduler(
|
||||
@@ -403,6 +441,12 @@ class Scheduler(
|
||||
else:
|
||||
self.draft_worker = None
|
||||
|
||||
# Dispatch the model worker
|
||||
if self.spec_algorithm.is_none():
|
||||
self.model_worker = self.tp_worker
|
||||
else:
|
||||
self.model_worker = self.draft_worker
|
||||
|
||||
# Get token and memory info from the model worker
|
||||
(
|
||||
self.max_total_num_tokens,
|
||||
@@ -959,7 +1003,6 @@ class Scheduler(
|
||||
self.running_mbs = [
|
||||
ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
|
||||
]
|
||||
bids = [None] * self.pp_size
|
||||
pp_outputs: Optional[PPProxyTensors] = None
|
||||
while True:
|
||||
server_is_idle = True
|
||||
@@ -980,10 +1023,7 @@ class Scheduler(
|
||||
# (last rank) send the outputs to the next step
|
||||
if self.pp_group.is_last_rank:
|
||||
if self.cur_batch:
|
||||
next_token_ids, bids[mb_id] = (
|
||||
result.next_token_ids,
|
||||
result.bid,
|
||||
)
|
||||
next_token_ids = result.next_token_ids
|
||||
if self.cur_batch.return_logprob:
|
||||
pp_outputs = PPProxyTensors(
|
||||
{
|
||||
@@ -1031,17 +1071,10 @@ class Scheduler(
|
||||
logits_output = LogitsProcessorOutput(**logits_output_args)
|
||||
else:
|
||||
logits_output = None
|
||||
output_result = GenerationBatchResult(
|
||||
|
||||
output_result = GenerationBatchResult.from_pp_proxy(
|
||||
logits_output=logits_output,
|
||||
pp_hidden_states_proxy_tensors=None,
|
||||
next_token_ids=next_pp_outputs["next_token_ids"],
|
||||
extend_input_len_per_req=next_pp_outputs.tensors.get(
|
||||
"extend_input_len_per_req", None
|
||||
),
|
||||
extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
|
||||
"extend_logprob_start_len_per_req", None
|
||||
),
|
||||
bid=bids[next_mb_id],
|
||||
next_pp_outputs=next_pp_outputs,
|
||||
can_run_cuda_graph=result.can_run_cuda_graph,
|
||||
)
|
||||
self.process_batch_result(mbs[next_mb_id], output_result)
|
||||
@@ -1049,8 +1082,6 @@ class Scheduler(
|
||||
|
||||
# (not last rank)
|
||||
if not self.pp_group.is_last_rank:
|
||||
if self.cur_batch:
|
||||
bids[mb_id] = result.bid
|
||||
# carry the outputs to the next stage
|
||||
# send the outputs from the last round to let the next stage worker run post processing
|
||||
if pp_outputs:
|
||||
@@ -1072,8 +1103,10 @@ class Scheduler(
|
||||
|
||||
# send out proxy tensors to the next stage
|
||||
if self.cur_batch:
|
||||
# FIXME(lsyin): remove this assert
|
||||
assert result.pp_hidden_states_proxy_tensors.tensors is not None
|
||||
self.pp_group.send_tensor_dict(
|
||||
result.pp_hidden_states_proxy_tensors,
|
||||
result.pp_hidden_states_proxy_tensors.tensors,
|
||||
all_gather_group=self.attn_tp_group,
|
||||
)
|
||||
|
||||
@@ -2016,33 +2049,25 @@ class Scheduler(
|
||||
|
||||
# Run forward
|
||||
if self.is_generation:
|
||||
|
||||
batch_or_worker_batch = batch
|
||||
|
||||
if self.spec_algorithm.is_none():
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
# FIXME(lsyin): remove this if and finally unify the abstraction
|
||||
batch_or_worker_batch = batch.get_model_worker_batch()
|
||||
|
||||
if self.pp_group.is_last_rank:
|
||||
logits_output, next_token_ids, can_run_cuda_graph = (
|
||||
self.tp_worker.forward_batch_generation(model_worker_batch)
|
||||
)
|
||||
else:
|
||||
pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
|
||||
self.tp_worker.forward_batch_generation(model_worker_batch)
|
||||
)
|
||||
bid = model_worker_batch.bid
|
||||
else:
|
||||
(
|
||||
logits_output,
|
||||
next_token_ids,
|
||||
bid,
|
||||
num_accepted_tokens,
|
||||
can_run_cuda_graph,
|
||||
) = self.draft_worker.forward_batch_speculative_generation(batch)
|
||||
bs = batch.batch_size()
|
||||
self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
|
||||
self.spec_num_total_forward_ct += bs
|
||||
self.num_generated_tokens += num_accepted_tokens
|
||||
forward_batch_output = self.model_worker.forward_batch_generation(
|
||||
batch_or_worker_batch
|
||||
)
|
||||
|
||||
if self.pp_group.is_last_rank:
|
||||
batch.output_ids = next_token_ids
|
||||
if not self.spec_algorithm.is_none():
|
||||
# TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
|
||||
self.udpate_spec_metrics(
|
||||
batch.batch_size(), forward_batch_output.num_accepted_tokens
|
||||
)
|
||||
|
||||
# update batch's output ids
|
||||
batch.output_ids = forward_batch_output.next_token_ids
|
||||
|
||||
# These 2 values are needed for processing the output, but the values can be
|
||||
# modified by overlap schedule. So we have to copy them here so that
|
||||
@@ -2051,6 +2076,7 @@ class Scheduler(
|
||||
extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
|
||||
else:
|
||||
extend_input_len_per_req = None
|
||||
|
||||
if batch.return_logprob:
|
||||
extend_logprob_start_len_per_req = [
|
||||
req.extend_logprob_start_len for req in batch.reqs
|
||||
@@ -2058,25 +2084,15 @@ class Scheduler(
|
||||
else:
|
||||
extend_logprob_start_len_per_req = None
|
||||
|
||||
ret = GenerationBatchResult(
|
||||
logits_output=logits_output if self.pp_group.is_last_rank else None,
|
||||
pp_hidden_states_proxy_tensors=(
|
||||
pp_hidden_states_proxy_tensors
|
||||
if not self.pp_group.is_last_rank
|
||||
else None
|
||||
),
|
||||
next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
|
||||
return GenerationBatchResult.from_forward_batch_output(
|
||||
forward_batch_output=forward_batch_output,
|
||||
extend_input_len_per_req=extend_input_len_per_req,
|
||||
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
|
||||
bid=bid,
|
||||
can_run_cuda_graph=can_run_cuda_graph,
|
||||
)
|
||||
else: # embedding or reward model
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
|
||||
ret = EmbeddingBatchResult(
|
||||
embeddings=embeddings, bid=model_worker_batch.bid
|
||||
)
|
||||
ret = EmbeddingBatchResult(embeddings=embeddings)
|
||||
return ret
|
||||
|
||||
def process_batch_result(
|
||||
|
||||
Reference in New Issue
Block a user