fix: Fix returned prefill logits and add output str test (#1046)
This commit is contained in:
@@ -208,6 +208,11 @@ class LogitsProcessor(nn.Module):
|
||||
all_logits = tensor_model_parallel_all_gather(all_logits)
|
||||
all_logits = all_logits[:, : self.config.vocab_size].float()
|
||||
|
||||
if hasattr(self.config, "final_logit_softcapping"):
|
||||
all_logits /= self.config.final_logit_softcapping
|
||||
all_logits = torch.tanh(all_logits)
|
||||
all_logits *= self.config.final_logit_softcapping
|
||||
|
||||
all_logprobs = all_logits
|
||||
del all_logits, hidden_states
|
||||
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
|
||||
|
||||
@@ -26,9 +26,11 @@ from sglang.srt.server import Runtime
|
||||
from sglang.srt.utils import is_generation_model
|
||||
|
||||
DEFAULT_PROMPTS = [
|
||||
"The capital of France is",
|
||||
# The SRT output of gemma-2-2b is unstable on the prompt commented out below, so it is disabled.
|
||||
# "The capital of France is",
|
||||
"The capital of the United Kindom is",
|
||||
"Today is a sunny day and I like",
|
||||
"AI is a field of computer science focused on",
|
||||
]
|
||||
|
||||
NUM_TOP_LOGPROBS = 5
|
||||
@@ -43,10 +45,11 @@ def get_dtype_str(torch_dtype):
|
||||
|
||||
@dataclass
|
||||
class ModelOutput:
|
||||
output_strs: str = None
|
||||
top_input_logprobs: torch.Tensor = None
|
||||
top_output_logprobs: torch.Tensor = None
|
||||
embed_logits: torch.Tensor = None
|
||||
output_strs: List[str] = None
|
||||
output_ids: List[int] = None
|
||||
top_input_logprobs: List[torch.Tensor] = None
|
||||
top_output_logprobs: List[torch.Tensor] = None
|
||||
embed_logits: List[torch.Tensor] = None
|
||||
|
||||
|
||||
class HFRunner:
|
||||
@@ -117,7 +120,9 @@ class HFRunner:
|
||||
output_ids = self.model.generate(
|
||||
input_ids, do_sample=False, max_new_tokens=max_new_tokens
|
||||
)
|
||||
output_strs.append(self.tokenizer.decode(output_ids[0]))
|
||||
output_strs.append(
|
||||
self.tokenizer.decode(output_ids[0][len(input_ids[0]) :])
|
||||
)
|
||||
|
||||
logits = self.model.forward(input_ids).logits[0]
|
||||
logprobs = F.log_softmax(
|
||||
@@ -145,7 +150,7 @@ class HFRunner:
|
||||
def forward(
|
||||
self,
|
||||
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
||||
max_new_tokens=64,
|
||||
max_new_tokens=8,
|
||||
):
|
||||
self.in_queue.put((prompts, max_new_tokens))
|
||||
return self.out_queue.get()
|
||||
@@ -184,7 +189,7 @@ class SRTRunner:
|
||||
def forward(
|
||||
self,
|
||||
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
||||
max_new_tokens=64,
|
||||
max_new_tokens=8,
|
||||
):
|
||||
if self.is_generation_model:
|
||||
# the return value contains logprobs from prefill
|
||||
|
||||
@@ -21,23 +21,25 @@ from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
|
||||
|
||||
MODELS = [
|
||||
("meta-llama/Meta-Llama-3.1-8B-Instruct", 1),
|
||||
("google/gemma-2-2b", 1),
|
||||
]
|
||||
TORCH_DTYPES = [torch.float16]
|
||||
|
||||
|
||||
class TestCausalModels(unittest.TestCase):
|
||||
class TestGenerationModels(unittest.TestCase):
|
||||
|
||||
def assert_close_prefill_logits(
|
||||
def assert_close_prefill_logits_and_output_strs(
|
||||
self,
|
||||
prompts,
|
||||
model_path,
|
||||
tp_size,
|
||||
torch_dtype,
|
||||
max_new_tokens,
|
||||
) -> None:
|
||||
with HFRunner(
|
||||
model_path, torch_dtype=torch_dtype, is_generation_model=True
|
||||
) as hf_runner:
|
||||
hf_outputs = hf_runner.forward(prompts)
|
||||
hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)
|
||||
|
||||
with SRTRunner(
|
||||
model_path,
|
||||
@@ -45,7 +47,7 @@ class TestCausalModels(unittest.TestCase):
|
||||
torch_dtype=torch_dtype,
|
||||
is_generation_model=True,
|
||||
) as srt_runner:
|
||||
srt_outputs = srt_runner.forward(prompts)
|
||||
srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
|
||||
|
||||
for i in range(len(prompts)):
|
||||
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
|
||||
@@ -56,11 +58,18 @@ class TestCausalModels(unittest.TestCase):
|
||||
abs(hf_logprobs - srt_logprobs) < tolerance
|
||||
), f"prefill logprobs not all close"
|
||||
|
||||
assert hf_outputs.output_strs == srt_outputs.output_strs
|
||||
|
||||
def test_prefill_logits(self):
|
||||
for model, tp_size in MODELS:
|
||||
for torch_dtype in TORCH_DTYPES:
|
||||
self.assert_close_prefill_logits(
|
||||
DEFAULT_PROMPTS, model, tp_size, torch_dtype
|
||||
max_new_tokens = 8
|
||||
self.assert_close_prefill_logits_and_output_strs(
|
||||
DEFAULT_PROMPTS,
|
||||
model,
|
||||
tp_size,
|
||||
torch_dtype,
|
||||
max_new_tokens,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user