[Feature] Support Tensor Parallelism and Weight Slicing for Lora (#4274)
Co-authored-by: ShenAo1111 <1377693092@qq.com> Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
This commit is contained in:
@@ -49,7 +49,7 @@ ALL_OTHER_LORA_MODELS = [
     LoRAModelCase(
         base="meta-llama/Llama-2-7b-hf",
         adaptors=[LoRAAdaptor(name="winddude/wizardLM-LlaMA-LoRA-7B")],
-        max_loras_per_batch=1,
+        max_loras_per_batch=2,
     ),
 ]
|
||||
@@ -96,6 +96,7 @@ class TestLoRABackend(unittest.TestCase):
             disable_cuda_graph=True,
             disable_radix_cache=True,
             mem_fraction_static=0.88,
+            disable_custom_all_reduce=False,
         ) as srt_runner:
             srt_outputs = srt_runner.forward(
                 [prompt], max_new_tokens=max_new_tokens, lora_paths=[adaptor.name]
@@ -114,6 +115,7 @@ class TestLoRABackend(unittest.TestCase):
             model_type="generation",
             tp_size=model_case.tp_size,
             mem_fraction_static=0.88,
+            disable_custom_all_reduce=False,
         ) as srt_runner:
             srt_no_lora_outputs = srt_runner.forward(
                 [prompt], max_new_tokens=max_new_tokens
Reference in New Issue
Block a user