diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 5584d01ad..cf5045fda 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -208,6 +208,11 @@ class LogitsProcessor(nn.Module):
             all_logits = tensor_model_parallel_all_gather(all_logits)
             all_logits = all_logits[:, : self.config.vocab_size].float()
 
+            if hasattr(self.config, "final_logit_softcapping"):
+                all_logits /= self.config.final_logit_softcapping
+                all_logits = torch.tanh(all_logits)
+                all_logits *= self.config.final_logit_softcapping
+
             all_logprobs = all_logits
             del all_logits, hidden_states
             all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py
index e619d58ca..e5ad3ea9d 100644
--- a/python/sglang/test/runners.py
+++ b/python/sglang/test/runners.py
@@ -26,9 +26,11 @@ from sglang.srt.server import Runtime
 from sglang.srt.utils import is_generation_model
 
 DEFAULT_PROMPTS = [
-    "The capital of France is",
+    # the output of gemma-2-2b from SRT is unstable on the commented prompt
+    # "The capital of France is",
     "The capital of the United Kindom is",
     "Today is a sunny day and I like",
+    "AI is a field of computer science focused on",
 ]
 
 NUM_TOP_LOGPROBS = 5
@@ -43,10 +45,11 @@ def get_dtype_str(torch_dtype):
 
 @dataclass
 class ModelOutput:
-    output_strs: str = None
-    top_input_logprobs: torch.Tensor = None
-    top_output_logprobs: torch.Tensor = None
-    embed_logits: torch.Tensor = None
+    output_strs: List[str] = None
+    output_ids: List[int] = None
+    top_input_logprobs: List[torch.Tensor] = None
+    top_output_logprobs: List[torch.Tensor] = None
+    embed_logits: List[torch.Tensor] = None
 
 
 class HFRunner:
@@ -117,7 +120,9 @@ class HFRunner:
                         output_ids = self.model.generate(
                             input_ids, do_sample=False, max_new_tokens=max_new_tokens
                         )
-                        output_strs.append(self.tokenizer.decode(output_ids[0]))
+                        output_strs.append(
+                            self.tokenizer.decode(output_ids[0][len(input_ids[0]) :])
+                        )
 
                         logits = self.model.forward(input_ids).logits[0]
                         logprobs = F.log_softmax(
@@ -145,7 +150,7 @@ class HFRunner:
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
-        max_new_tokens=64,
+        max_new_tokens=8,
     ):
         self.in_queue.put((prompts, max_new_tokens))
         return self.out_queue.get()
@@ -184,7 +189,7 @@ class SRTRunner:
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
-        max_new_tokens=64,
+        max_new_tokens=8,
     ):
         if self.is_generation_model:
             # the return value contains logprobs from prefill
diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py
index f05764802..ca4f096e3 100644
--- a/test/srt/models/test_generation_models.py
+++ b/test/srt/models/test_generation_models.py
@@ -21,23 +21,25 @@ from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
 
 MODELS = [
     ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1),
+    ("google/gemma-2-2b", 1),
 ]
 TORCH_DTYPES = [torch.float16]
 
 
-class TestCausalModels(unittest.TestCase):
+class TestGenerationModels(unittest.TestCase):
 
-    def assert_close_prefill_logits(
+    def assert_close_prefill_logits_and_output_strs(
         self,
         prompts,
         model_path,
         tp_size,
         torch_dtype,
+        max_new_tokens,
     ) -> None:
         with HFRunner(
             model_path, torch_dtype=torch_dtype, is_generation_model=True
         ) as hf_runner:
-            hf_outputs = hf_runner.forward(prompts)
+            hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)
 
         with SRTRunner(
             model_path,
@@ -45,7 +47,7 @@ class TestCausalModels(unittest.TestCase):
             torch_dtype=torch_dtype,
             is_generation_model=True,
         ) as srt_runner:
-            srt_outputs = srt_runner.forward(prompts)
+            srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
 
         for i in range(len(prompts)):
             hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
@@ -56,11 +58,18 @@ class TestCausalModels(unittest.TestCase):
                 abs(hf_logprobs - srt_logprobs) < tolerance
             ), f"prefill logprobs not all close"
 
+        assert hf_outputs.output_strs == srt_outputs.output_strs
+
     def test_prefill_logits(self):
         for model, tp_size in MODELS:
             for torch_dtype in TORCH_DTYPES:
-                self.assert_close_prefill_logits(
-                    DEFAULT_PROMPTS, model, tp_size, torch_dtype
+                max_new_tokens = 8
+                self.assert_close_prefill_logits_and_output_strs(
+                    DEFAULT_PROMPTS,
+                    model,
+                    tp_size,
+                    torch_dtype,
+                    max_new_tokens,
                 )
 
 
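A note on the softcapping hunk in python/sglang/srt/layers/logits_processor.py: the added statements implement Gemma 2's final-logit soft-capping, logits = cap * tanh(logits / cap), which smoothly bounds every logit to the open interval (-cap, cap) before log_softmax is taken. Below is a minimal standalone sketch of the same transform; the softcap helper and the sample values are illustrative and not part of this PR (the cap of 30.0 follows the final_logit_softcapping value in gemma-2's Hugging Face config).

    import torch

    def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
        # cap * tanh(logits / cap): near-identity for small logits,
        # saturating smoothly toward +/-cap for large magnitudes.
        return cap * torch.tanh(logits / cap)

    logits = torch.randn(2, 8) * 100.0  # exaggerated magnitudes to show saturation
    capped = softcap(logits, cap=30.0)  # illustrative cap taken from gemma-2's config
    assert capped.abs().max() <= 30.0   # all values are bounded by the cap

The divide / tanh / multiply sequence in the hunk is algebraically the same computation, written so that the division and multiplication reuse the existing tensor in place.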