Rename InputMetadata -> ForwardBatch (#1543)

This commit is contained in:
Lianmin Zheng
2024-09-30 02:41:11 -07:00
committed by GitHub
parent 3f0fe08d37
commit 36d5acfca5
44 changed files with 435 additions and 433 deletions

View File

@@ -50,8 +50,8 @@ from sglang.srt.managers.schedule_batch import (
Req,
ScheduleBatch,
)
-from sglang.srt.managers.scheduler_policy import PrefillAdder, SchedulerPolicy
-from sglang.srt.managers.tp_worker import ModelTpWorker
+from sglang.srt.managers.schedule_policy import PrefillAdder, SchedulePolicy
+from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.server_args import PortArgs, ServerArgs
@@ -134,7 +134,7 @@ class Scheduler:
)
# Launch a tensor parallel worker
-self.tp_worker = ModelTpWorker(
+self.tp_worker = TpModelWorker(
gpu_id=gpu_id,
tp_rank=tp_rank,
server_args=server_args,
@@ -179,7 +179,7 @@ class Scheduler:
disable=server_args.disable_radix_cache,
)
self.tree_cache_metrics = {"total": 0, "hit": 0}
-self.policy = SchedulerPolicy(self.schedule_policy, self.tree_cache)
+self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
# Init running status
self.waiting_queue: List[Req] = []
@@ -575,9 +575,9 @@ class Scheduler:
if self.is_generation:
# Forward and sample the next tokens
if batch.extend_num_tokens != 0:
-input_metadata = batch.get_input_metadata()
+forward_batch = batch.get_forward_batch()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-input_metadata, batch
+forward_batch, batch
)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
@@ -641,8 +641,8 @@ class Scheduler:
)
else:
assert batch.extend_num_tokens != 0
-input_metadata = batch.get_input_metadata()
-embeddings = self.tp_worker.forward_batch_embedding(input_metadata)
+forward_batch = batch.get_forward_batch()
+embeddings = self.tp_worker.forward_batch_embedding(forward_batch)
# Check finish conditions
for i, req in enumerate(batch.reqs):
@@ -771,9 +771,9 @@ class Scheduler:
batch.prepare_for_decode()
# Forward and sample the next tokens
-input_metadata = batch.get_input_metadata()
+forward_batch = batch.get_forward_batch()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-input_metadata, batch
+forward_batch, batch
)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids