Fix lora bench (#6302)
This commit is contained in:
@@ -170,6 +170,7 @@ async def benchmark(
|
|||||||
prompt_len=test_prompt_len,
|
prompt_len=test_prompt_len,
|
||||||
output_len=test_output_len,
|
output_len=test_output_len,
|
||||||
lora_name="dummy", # the lora_name argument will not be used
|
lora_name="dummy", # the lora_name argument will not be used
|
||||||
|
image_data=None,
|
||||||
extra_request_body=extra_request_body,
|
extra_request_body=extra_request_body,
|
||||||
)
|
)
|
||||||
test_output = await request_func(request_func_input=test_input)
|
test_output = await request_func(request_func_input=test_input)
|
||||||
@@ -194,6 +195,7 @@ async def benchmark(
|
|||||||
prompt_len=prompt_len,
|
prompt_len=prompt_len,
|
||||||
output_len=output_len,
|
output_len=output_len,
|
||||||
lora_name="dummy",
|
lora_name="dummy",
|
||||||
|
image_data=None,
|
||||||
extra_request_body=extra_request_body,
|
extra_request_body=extra_request_body,
|
||||||
)
|
)
|
||||||
tasks.append(
|
tasks.append(
|
||||||
|
|||||||
@@ -170,9 +170,7 @@ class LoRAManager:
|
|||||||
dim=0,
|
dim=0,
|
||||||
out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
|
out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
|
||||||
)
|
)
|
||||||
self.cuda_graph_batch_info.max_len = int(
|
self.cuda_graph_batch_info.max_len = 1
|
||||||
torch.max(self.cuda_graph_batch_info.seg_lens[:bs])
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, lora_path in enumerate(forward_batch.lora_paths):
|
for i, lora_path in enumerate(forward_batch.lora_paths):
|
||||||
self.cuda_graph_batch_info.weight_indices[i] = (
|
self.cuda_graph_batch_info.weight_indices[i] = (
|
||||||
|
|||||||
Reference in New Issue
Block a user