Fix the lora adapter when lora path is none (#4799)
Co-authored-by: Beichen Ma <mabeichen12@gmail.com>
This commit is contained in:
@@ -133,10 +133,6 @@ class LoRAManager:
|
|||||||
assert len(cur_uids) <= self.max_loras_per_batch
|
assert len(cur_uids) <= self.max_loras_per_batch
|
||||||
self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
|
self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
|
||||||
|
|
||||||
# FIXME: Handle lora uid with None more safely
|
|
||||||
if cur_uids == set([None]):
|
|
||||||
return
|
|
||||||
|
|
||||||
# set up batch info shared by all lora moruldes
|
# set up batch info shared by all lora moruldes
|
||||||
bs = forward_batch.batch_size
|
bs = forward_batch.batch_size
|
||||||
seg_lens = (
|
seg_lens = (
|
||||||
|
|||||||
@@ -163,7 +163,7 @@ class LoRAMemoryPool:
|
|||||||
if uid is None:
|
if uid is None:
|
||||||
for i in range(self.num_layer):
|
for i in range(self.num_layer):
|
||||||
for k in self.A_buffer.keys():
|
for k in self.A_buffer.keys():
|
||||||
self.A_buffer[k][i][buffer_id] *= 0
|
self.A_buffer[k][i][buffer_id] = 0
|
||||||
return
|
return
|
||||||
|
|
||||||
assert lora_adapter is not None
|
assert lora_adapter is not None
|
||||||
|
|||||||
@@ -96,6 +96,11 @@ class TestLoRA(CustomTestCase):
|
|||||||
srt_outputs = srt_runner.forward(
|
srt_outputs = srt_runner.forward(
|
||||||
prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
|
prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
|
||||||
)
|
)
|
||||||
|
srt_outputs_lora_path_none = srt_runner.forward(
|
||||||
|
prompts,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
lora_paths=[None] * len(prompts),
|
||||||
|
)
|
||||||
|
|
||||||
with HFRunner(
|
with HFRunner(
|
||||||
base_path, torch_dtype=torch_dtype, model_type="generation"
|
base_path, torch_dtype=torch_dtype, model_type="generation"
|
||||||
@@ -169,18 +174,20 @@ class TestLoRA(CustomTestCase):
|
|||||||
print(f"{srt_outputs.output_strs=}")
|
print(f"{srt_outputs.output_strs=}")
|
||||||
print(f"{hf_no_lora_outputs.output_strs=}")
|
print(f"{hf_no_lora_outputs.output_strs=}")
|
||||||
print(f"{srt_no_lora_outputs.output_strs=}")
|
print(f"{srt_no_lora_outputs.output_strs=}")
|
||||||
|
print(f"{srt_outputs_lora_path_none.output_strs=}")
|
||||||
for i in range(len(prompts)):
|
for i in range(len(prompts)):
|
||||||
assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], (
|
assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], (
|
||||||
srt_outputs.output_strs[i].strip(" "),
|
srt_outputs.output_strs[i].strip(" "),
|
||||||
hf_outputs.output_strs[i],
|
hf_outputs.output_strs[i],
|
||||||
)
|
)
|
||||||
# assert (
|
assert (
|
||||||
# srt_no_lora_outputs.output_strs[i].strip(" ")
|
srt_no_lora_outputs.output_strs[i].strip(" ")
|
||||||
# == hf_no_lora_outputs.output_strs[i]
|
== hf_no_lora_outputs.output_strs[i]
|
||||||
# ), (
|
), (
|
||||||
# srt_no_lora_outputs.output_strs[i].strip(" "),
|
srt_no_lora_outputs.output_strs[i].strip(" "),
|
||||||
# hf_no_lora_outputs.output_strs[i],
|
hf_no_lora_outputs.output_strs[i],
|
||||||
# )
|
)
|
||||||
|
assert srt_outputs_lora_path_none == srt_no_lora_outputs
|
||||||
|
|
||||||
def serving(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
|
def serving(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
|
||||||
print("=================== testing serving =======================")
|
print("=================== testing serving =======================")
|
||||||
@@ -257,7 +264,7 @@ class TestLoRA(CustomTestCase):
|
|||||||
srt_no_lora_logprobs = torch.Tensor(
|
srt_no_lora_logprobs = torch.Tensor(
|
||||||
srt_no_lora_outputs.top_input_logprobs[i]
|
srt_no_lora_outputs.top_input_logprobs[i]
|
||||||
)
|
)
|
||||||
srt_logprobs = torch.uensor(srt_outputs.top_input_logprobs[i])
|
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
|
||||||
print("max_diff", torch.max(abs(srt_no_lora_logprobs - srt_logprobs)))
|
print("max_diff", torch.max(abs(srt_no_lora_logprobs - srt_logprobs)))
|
||||||
|
|
||||||
print(f"{srt_no_lora_outputs.output_strs=}")
|
print(f"{srt_no_lora_outputs.output_strs=}")
|
||||||
@@ -280,7 +287,7 @@ class TestLoRA(CustomTestCase):
|
|||||||
tp_size = 1
|
tp_size = 1
|
||||||
max_new_tokens = 32
|
max_new_tokens = 32
|
||||||
self.inference(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
|
self.inference(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
|
||||||
# self.serving(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
|
self.serving(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
|
||||||
# self.base_inference(
|
# self.base_inference(
|
||||||
# PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens
|
# PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens
|
||||||
# )
|
# )
|
||||||
|
|||||||
Reference in New Issue
Block a user