[Fix] Fix bugs and refactor codes in lora for better scalability. (#3652)

Co-authored-by: ShenAo1111 <1377693092@qq.com>
Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
This commit is contained in:
aoshen524
2025-02-20 14:51:57 -05:00
committed by GitHub
parent ac05310098
commit e79f7420be
11 changed files with 459 additions and 200 deletions

View File

@@ -189,9 +189,17 @@ class HFRunner:
return_dict_in_generate=True,
output_scores=(not self.output_str_only),
)
output_strs.append(
self.tokenizer.decode(outputs[0][0][len(input_ids[0]) :])
text = self.tokenizer.decode(
outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True
)
# Check if the text is empty or only whitespace.
if not text.strip():
raise ValueError(
"Received an empty text response. Please verify your input or model configuration."
)
output_strs.append(text)
if not self.output_str_only:
# outputs.scores: (num_token, 1, vocab_size)
top_output_logprobs.append(
@@ -275,6 +283,7 @@ class SRTRunner:
lora_backend: str = "triton",
disable_cuda_graph: bool = False,
disable_radix_cache: bool = False,
mem_fraction_static: float = 0.65,
):
self.model_type = model_type
self.is_generation = model_type == "generation"
@@ -283,7 +292,7 @@ class SRTRunner:
tp_size=tp_size,
dtype=get_dtype_str(torch_dtype),
port=port,
mem_fraction_static=0.65,
mem_fraction_static=mem_fraction_static,
trust_remote_code=False,
is_embedding=not self.is_generation,
lora_paths=lora_paths,
@@ -315,7 +324,15 @@ class SRTRunner:
logprob_start_len=0,
top_logprobs_num=NUM_TOP_LOGPROBS,
)
output_strs.append(response["text"])
text = response["text"]
# Check if the text is empty or only whitespace.
if not text.strip():
raise ValueError(
"Received an empty text response. Please verify your input or model configuration."
)
output_strs.append(text)
top_input_logprobs.append(
[
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]