Rename InputMetadata -> ForwardBatch (#1543)
@@ -50,8 +50,8 @@ from sglang.srt.managers.schedule_batch import (
     Req,
     ScheduleBatch,
 )
-from sglang.srt.managers.scheduler_policy import PrefillAdder, SchedulerPolicy
-from sglang.srt.managers.tp_worker import ModelTpWorker
+from sglang.srt.managers.schedule_policy import PrefillAdder, SchedulePolicy
+from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -134,7 +134,7 @@ class Scheduler:
         )
 
         # Launch a tensor parallel worker
-        self.tp_worker = ModelTpWorker(
+        self.tp_worker = TpModelWorker(
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             server_args=server_args,
@@ -179,7 +179,7 @@ class Scheduler:
             disable=server_args.disable_radix_cache,
         )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
-        self.policy = SchedulerPolicy(self.schedule_policy, self.tree_cache)
+        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
 
         # Init running status
         self.waiting_queue: List[Req] = []
@@ -575,9 +575,9 @@ class Scheduler:
         if self.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
-                input_metadata = batch.get_input_metadata()
+                forward_batch = batch.get_forward_batch()
                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-                    input_metadata, batch
+                    forward_batch, batch
                 )
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids
@@ -641,8 +641,8 @@ class Scheduler:
             )
         else:
             assert batch.extend_num_tokens != 0
-            input_metadata = batch.get_input_metadata()
-            embeddings = self.tp_worker.forward_batch_embedding(input_metadata)
+            forward_batch = batch.get_forward_batch()
+            embeddings = self.tp_worker.forward_batch_embedding(forward_batch)
 
         # Check finish conditions
         for i, req in enumerate(batch.reqs):
@@ -771,9 +771,9 @@ class Scheduler:
         batch.prepare_for_decode()
 
         # Forward and sample the next tokens
-        input_metadata = batch.get_input_metadata()
+        forward_batch = batch.get_forward_batch()
         logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-            input_metadata, batch
+            forward_batch, batch
         )
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids
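
Taken together, the hunks above rename InputMetadata to ForwardBatch (ScheduleBatch.get_input_metadata() becomes get_forward_batch()), SchedulerPolicy to SchedulePolicy, and ModelTpWorker to TpModelWorker. A minimal sketch of the renamed generation call path follows; run_generation_step and its tp_worker and batch parameters are illustrative stand-ins for the Scheduler's own tp_worker attribute and a ScheduleBatch, and are not part of this commit.

    # Illustrative sketch only: tp_worker and batch stand in for
    # Scheduler.tp_worker (a TpModelWorker) and a ScheduleBatch.
    def run_generation_step(tp_worker, batch):
        # Old: input_metadata = batch.get_input_metadata()
        forward_batch = batch.get_forward_batch()
        # Old: tp_worker.forward_batch_generation(input_metadata, batch)
        logits_output, next_token_ids = tp_worker.forward_batch_generation(
            forward_batch, batch
        )
        return logits_output, next_token_ids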