Unify forward mode (#1360)

This commit is contained in:
Liangsheng Yin
2024-09-09 13:49:29 -07:00
committed by GitHub
parent 689ff588ec
commit 69b3bb9ae1
9 changed files with 54 additions and 58 deletions

View File

@@ -60,7 +60,6 @@ import torch.distributed as dist
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.model_config import ModelConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
@@ -208,14 +207,14 @@ def extend(reqs, model_runner):
tree_cache=None,
)
batch.prepare_for_extend(model_runner.model_config.vocab_size)
sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND)
sample_output, logits_output = model_runner.forward(batch)
next_token_ids = sample_output.batch_next_token_ids.tolist()
return next_token_ids, logits_output.next_token_logits, batch
def decode(input_token_ids, batch, model_runner):
batch.prepare_for_decode(input_token_ids)
sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE)
sample_output, logits_output = model_runner.forward(batch)
next_token_ids = sample_output.batch_next_token_ids.tolist()
return next_token_ids, logits_output.next_token_logits