fix: Fix returned prefill logits and add output str test (#1046)

This commit is contained in:
Ying Sheng
2024-08-11 23:13:45 -07:00
committed by GitHub
parent fb1f28cbbb
commit 32f6144323
3 changed files with 33 additions and 14 deletions

View File

@@ -208,6 +208,11 @@ class LogitsProcessor(nn.Module):
all_logits = tensor_model_parallel_all_gather(all_logits)
all_logits = all_logits[:, : self.config.vocab_size].float()
if hasattr(self.config, "final_logit_softcapping"):
all_logits /= self.config.final_logit_softcapping
all_logits = torch.tanh(all_logits)
all_logits *= self.config.final_logit_softcapping
all_logprobs = all_logits
del all_logits, hidden_states
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)

View File

@@ -26,9 +26,11 @@ from sglang.srt.server import Runtime
from sglang.srt.utils import is_generation_model
DEFAULT_PROMPTS = [
"The capital of France is",
# the output of gemma-2-2b from SRT is unstable on the commented prompt
# "The capital of France is",
"The capital of the United Kindom is",
"Today is a sunny day and I like",
"AI is a field of computer science focused on",
]
NUM_TOP_LOGPROBS = 5
@@ -43,10 +45,11 @@ def get_dtype_str(torch_dtype):
@dataclass
class ModelOutput:
output_strs: str = None
top_input_logprobs: torch.Tensor = None
top_output_logprobs: torch.Tensor = None
embed_logits: torch.Tensor = None
output_strs: List[str] = None
output_ids: List[int] = None
top_input_logprobs: List[torch.Tensor] = None
top_output_logprobs: List[torch.Tensor] = None
embed_logits: List[torch.Tensor] = None
class HFRunner:
@@ -117,7 +120,9 @@ class HFRunner:
output_ids = self.model.generate(
input_ids, do_sample=False, max_new_tokens=max_new_tokens
)
output_strs.append(self.tokenizer.decode(output_ids[0]))
output_strs.append(
self.tokenizer.decode(output_ids[0][len(input_ids[0]) :])
)
logits = self.model.forward(input_ids).logits[0]
logprobs = F.log_softmax(
@@ -145,7 +150,7 @@ class HFRunner:
def forward(
self,
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
max_new_tokens=64,
max_new_tokens=8,
):
self.in_queue.put((prompts, max_new_tokens))
return self.out_queue.get()
@@ -184,7 +189,7 @@ class SRTRunner:
def forward(
self,
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
max_new_tokens=64,
max_new_tokens=8,
):
if self.is_generation_model:
# the return value contains logprobs from prefill