[CI] Return output logprobs in unit test (#1361)
@@ -50,6 +50,12 @@ def get_dtype_str(torch_dtype):
         raise NotImplementedError()
 
 
+def get_top_logprobs(logits, k):
+    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+    logprobs, top_indices = torch.topk(logprobs, k=k, dim=-1)
+    return logprobs
+
+
 @dataclass
 class ModelOutput:
     output_strs: List[str] = None
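The get_top_logprobs helper added here is just F.log_softmax followed by torch.topk over the vocabulary dimension. Below is a minimal standalone sketch of the same pattern on dummy logits (shapes are illustrative assumptions, not taken from the test):

import torch
import torch.nn.functional as F

def get_top_logprobs(logits, k):
    # log-softmax over the vocabulary dimension, then keep the k largest logprobs
    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    logprobs, top_indices = torch.topk(logprobs, k=k, dim=-1)
    return logprobs

dummy_logits = torch.randn(4, 32000)        # (seq_len, vocab_size); sizes are illustrative
top5 = get_top_logprobs(dummy_logits, k=5)  # (seq_len, 5), sorted from highest to lowest
print(top5.shape)                           # torch.Size([4, 5])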
@@ -108,7 +114,8 @@ class HFRunner:
             if prompts is not None:
                 if self.is_generation:
                     output_strs = []
-                    prefill_logprobs = []
+                    top_input_logprobs = []
+                    top_output_logprobs = []
                     for p in prompts:
                         if isinstance(p, str):
                             input_ids = self.tokenizer.encode(
@@ -117,32 +124,43 @@ class HFRunner:
                         else:
                             input_ids = torch.tensor([p], device="cuda")
 
-                        output_ids = self.model.generate(
-                            input_ids, do_sample=False, max_new_tokens=max_new_tokens
+                        outputs = self.model.generate(
+                            input_ids,
+                            do_sample=False,
+                            temperature=None,
+                            top_p=None,
+                            max_new_tokens=max_new_tokens,
+                            return_dict_in_generate=True,
+                            output_scores=True,
                         )
                         output_strs.append(
-                            self.tokenizer.decode(output_ids[0][len(input_ids[0]) :])
+                            self.tokenizer.decode(outputs[0][0][len(input_ids[0]) :])
                         )
+                        # outputs.scores: (num_token, 1, vocab_size)
+                        top_output_logprobs.append(
+                            [
+                                get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
+                                for logits in outputs.scores
+                            ]
+                        )
+                        del outputs
 
-                        logits = self.model.forward(input_ids).logits[0]
-                        logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
-                        logprobs, top_indices = torch.topk(
-                            logprobs, k=NUM_TOP_LOGPROBS, dim=-1
+                        input_logits = self.model.forward(input_ids).logits[0]
+                        top_input_logprobs.append(
+                            get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
                         )
-                        # print("index", top_indices)
-                        prefill_logprobs.append(logprobs.tolist())
-                        del logits
-                        del logprobs
+                        del input_logits
 
                     out_queue.put(
                         ModelOutput(
-                            output_strs=output_strs, top_input_logprobs=prefill_logprobs
+                            output_strs=output_strs,
+                            top_input_logprobs=top_input_logprobs,
+                            top_output_logprobs=top_output_logprobs,
                         )
                     )
 
                 else:
                     logits = self.model.encode(prompts).tolist()
 
                     out_queue.put(ModelOutput(embed_logits=logits))
 
     def forward(
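For context on the indexing in this hunk: with return_dict_in_generate=True and output_scores=True, Hugging Face generate returns an output object whose first field (outputs[0], i.e. outputs.sequences) holds the generated token ids, and whose .scores attribute is a tuple with one (batch, vocab_size) logits tensor per generated token. Below is a minimal sketch of that structure, assuming a small causal LM such as "gpt2" purely for illustration (not the model used by the test):

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tok.encode("The capital of France is", return_tensors="pt")
out = model.generate(
    input_ids,
    do_sample=False,
    max_new_tokens=4,
    return_dict_in_generate=True,
    output_scores=True,
)
# out[0] is out.sequences; out[0][0] selects the first (only) sequence in the batch.
new_tokens = out.sequences[0][len(input_ids[0]):]
# out.scores[t] has shape (batch, vocab_size); [0] selects the single batch element,
# mirroring get_top_logprobs(logits[0], ...) in the diff.
step_top_logprobs = [
    torch.topk(F.log_softmax(s[0], dim=-1, dtype=torch.float32), k=5).values
    for s in out.scores
]
print(tok.decode(new_tokens), [lp.shape for lp in step_top_logprobs])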
@@ -194,6 +212,7 @@ class SRTRunner:
             # the return value contains logprobs from prefill
             output_strs = []
             top_input_logprobs = []
+            top_output_logprobs = []
             sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
             for prompt in prompts:
                 response = self.runtime.generate(
@@ -219,9 +238,17 @@ class SRTRunner:
                         ]
                     ]
                 )
+                top_output_logprobs.append(
+                    [
+                        [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
+                        for x in response["meta_info"]["output_top_logprobs"]
+                    ]
+                )
 
             return ModelOutput(
-                output_strs=output_strs, top_input_logprobs=top_input_logprobs
+                output_strs=output_strs,
+                top_input_logprobs=top_input_logprobs,
+                top_output_logprobs=top_output_logprobs,
             )
         else:
             response = self.runtime.encode(prompts)
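On the SRTRunner side, the new list comprehension keeps only the logprob value from each top-k entry. Below is a small self-contained sketch with a mocked response dict; it assumes the tuple layout (logprob, token_id) implied by tup[0] in the diff, and the mock values are invented for illustration (the real runtime response may carry extra fields per tuple):

NUM_TOP_LOGPROBS = 5

response = {
    "meta_info": {
        "output_top_logprobs": [
            # one list of (logprob, token_id) pairs per generated token, best first
            [(-0.01, 11), (-4.7, 52), (-5.2, 9), (-6.0, 77), (-6.3, 3)],
            [(-0.20, 88), (-1.9, 12), (-3.4, 40), (-4.1, 7), (-4.4, 101)],
        ]
    }
}

top_output_logprobs = [
    [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]  # keep only the logprob values
    for x in response["meta_info"]["output_top_logprobs"]
]
print(top_output_logprobs)  # [[-0.01, -4.7, -5.2, -6.0, -6.3], [-0.2, -1.9, -3.4, -4.1, -4.4]]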