diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 3a740abed..d702faab6 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -156,18 +156,15 @@ class LoRAManager: # set up batch info shared by all lora modules bs = forward_batch.batch_size - if hasattr(self, "max_bs_in_cuda_graph") and bs <= self.max_bs_in_cuda_graph: - # Do in-place updates when CUDA graph is enabled. Note that - # if CUDA graph is enabled, the batch whose bs <= max_bs_in_cuda_graph - # will also use these preallocated buffers, no matter whether - # the batch can use CUDA graph or not. + if ( + hasattr(self, "max_bs_in_cuda_graph") + and bs <= self.max_bs_in_cuda_graph + and forward_batch.forward_mode.is_cuda_graph() + ): + # Do in-place updates when CUDA graph is enabled and the batch forward mode + # could use CUDA graph. self.cuda_graph_batch_info.bs = bs - if forward_batch.forward_mode.is_extend(): - self.cuda_graph_batch_info.seg_lens[:bs].copy_( - forward_batch.extend_seq_lens - ) - else: - self.cuda_graph_batch_info.seg_lens[:bs].fill_(1) + self.cuda_graph_batch_info.seg_lens[:bs].fill_(1) torch.cumsum( self.cuda_graph_batch_info.seg_lens[:bs], dim=0, @@ -201,10 +198,10 @@ class LoRAManager: max_len = int(torch.max(seg_lens)) weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device) - lora_ranks = torch.empty( + lora_ranks = torch.zeros( (self.max_loras_per_batch,), dtype=torch.int64, device="cuda" ) - scalings = torch.empty( + scalings = torch.zeros( (self.max_loras_per_batch,), dtype=torch.float, device="cuda" ) for i, lora_path in enumerate(forward_batch.lora_paths): diff --git a/test/srt/models/lora/test_lora.py b/test/srt/models/lora/test_lora.py index 37571fd5d..3c27d3d57 100644 --- a/test/srt/models/lora/test_lora.py +++ b/test/srt/models/lora/test_lora.py @@ -13,66 +13,177 @@ # ============================================================================== import multiprocessing as mp +import os +import random import unittest +from typing import List -from utils import TORCH_DTYPES, LoRAAdaptor, LoRAModelCase, run_lora_test_by_batch +from utils import ( + ALL_OTHER_MULTI_LORA_MODELS, + CI_MULTI_LORA_MODELS, + TORCH_DTYPES, + LoRAModelCase, +) -from sglang.test.test_utils import CustomTestCase +from sglang.test.runners import HFRunner, SRTRunner +from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci -PROMPTS = [ +TEST_MULTIPLE_BATCH_PROMPTS = [ """ -### Instruction: -Write a poem about the transformers Python library. -Mention the word "large language models" in that poem. -### Response: -The Transformers are large language models, -They're used to make predictions on text. -""", + ### Instruction: + Tell me about llamas and alpacas + ### Response: + Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing. + ### Question 2: + What do you know about llamas? + ### Answer: + """, + """ + ### Instruction: + Write a poem about the transformers Python library. + Mention the word "large language models" in that poem. + ### Response: + The Transformers are large language models, + They're used to make predictions on text. + """, "AI is a field of computer science focused on", -] - -LORA_MODELS_WITH_NONE = [ - LoRAModelCase( - base="meta-llama/Llama-3.1-8B-Instruct", - adaptors=[ - LoRAAdaptor( - name="algoprog/fact-generation-llama-3.1-8b-instruct-lora", - ), - LoRAAdaptor( - name=None, - ), - ], - max_loras_per_batch=2, - ), - LoRAModelCase( - base="meta-llama/Llama-3.1-8B-Instruct", - adaptors=[ - LoRAAdaptor( - name=None, - ), - LoRAAdaptor( - name="algoprog/fact-generation-llama-3.1-8b-instruct-lora", - ), - ], - max_loras_per_batch=2, - ), + "Computer science is the study of", + "Write a short story.", + "What are the main components of a computer?", ] class TestLoRA(CustomTestCase): - def test_lora_batch_with_none(self): - for model_case in LORA_MODELS_WITH_NONE: - prompts = PROMPTS + + def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]): + for model_case in model_cases: for torch_dtype in TORCH_DTYPES: - run_lora_test_by_batch( - prompts, - model_case, - torch_dtype, - max_new_tokens=32, - backend="triton", - test_tag="test_lora_batch_with_none", + max_new_tokens = 32 + backend = "triton" + base_path = model_case.base + lora_adapter_paths = [a.name for a in model_case.adaptors] + assert len(lora_adapter_paths) >= 2 + + batches = [ + ( + [ + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + ], + [ + None, + lora_adapter_paths[0], + lora_adapter_paths[1], + ], + ), + ( + [ + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + ], + [ + lora_adapter_paths[0], + None, + lora_adapter_paths[1], + ], + ), + ( + [ + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + ], + [lora_adapter_paths[0], lora_adapter_paths[1], None], + ), + ( + [ + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + ], + [None, lora_adapter_paths[1], None], + ), + ( + [ + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + ], + [None, None, None], + ), + ] + + print( + f"\n========== Testing multiple batches on base '{base_path}' with backend={backend}, dtype={torch_dtype} ---" ) + # Initialize runners + srt_runner = SRTRunner( + base_path, + torch_dtype=torch_dtype, + model_type="generation", + lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]], + max_loras_per_batch=len(lora_adapter_paths) + 1, + lora_backend=backend, + disable_radix_cache=True, + ) + hf_runner = HFRunner( + base_path, torch_dtype=torch_dtype, model_type="generation" + ) + + with srt_runner, hf_runner: + for i, (prompts, lora_paths) in enumerate(batches): + print( + f"\n--- Running Batch {i+1} --- prompts: {prompts}, lora_paths: {lora_paths}" + ) + + srt_outputs = srt_runner.batch_forward( + prompts, + max_new_tokens=max_new_tokens, + lora_paths=lora_paths, + ) + + hf_outputs = hf_runner.forward( + prompts, + max_new_tokens=max_new_tokens, + lora_paths=lora_paths, + ) + + print("SRT outputs:", [s for s in srt_outputs.output_strs]) + print("HF outputs:", [s for s in hf_outputs.output_strs]) + + for srt_out, hf_out in zip( + srt_outputs.output_strs, hf_outputs.output_strs + ): + srt_str = srt_out.strip() + hf_str = hf_out.strip() + rouge_tol = model_case.rouge_l_tolerance + rouge_score = calculate_rouge_l([srt_str], [hf_str])[0] + if rouge_score < rouge_tol: + raise AssertionError( + f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} " + f"for base '{base_path}', adaptor '{lora_paths}', backend '{backend}', prompt: '{prompts}...'" + ) + + print(f"--- Batch {i+1} Comparison Passed --- ") + + def test_ci_lora_models(self): + self._run_lora_multiple_batch_on_model_cases(CI_MULTI_LORA_MODELS) + + def test_all_lora_models(self): + if is_in_ci(): + return + + filtered_models = [] + for model_case in ALL_OTHER_MULTI_LORA_MODELS: + if "ONLY_RUN" in os.environ and os.environ["ONLY_RUN"] != model_case.base: + continue + filtered_models.append(model_case) + + self._run_lora_multiple_batch_on_model_cases(filtered_models) + if __name__ == "__main__": try: diff --git a/test/srt/models/lora/test_multi_lora_backend.py b/test/srt/models/lora/test_multi_lora_backend.py index eb338dfaf..310ce72e4 100644 --- a/test/srt/models/lora/test_multi_lora_backend.py +++ b/test/srt/models/lora/test_multi_lora_backend.py @@ -18,50 +18,16 @@ import unittest from typing import List from utils import ( + ALL_OTHER_MULTI_LORA_MODELS, BACKENDS, + CI_MULTI_LORA_MODELS, TORCH_DTYPES, - LoRAAdaptor, LoRAModelCase, run_lora_test_one_by_one, ) from sglang.test.test_utils import CustomTestCase, is_in_ci -CI_MULTI_LORA_MODELS = [ - # multi-rank case - LoRAModelCase( - base="meta-llama/Llama-2-7b-hf", - adaptors=[ - LoRAAdaptor( - name="winddude/wizardLM-LlaMA-LoRA-7B", - prefill_tolerance=1e-1, - ), - LoRAAdaptor( - name="RuterNorway/Llama-2-7b-chat-norwegian-LoRa", - prefill_tolerance=3e-1, - ), - ], - max_loras_per_batch=2, - ), -] - -ALL_OTHER_MULTI_LORA_MODELS = [ - LoRAModelCase( - base="meta-llama/Llama-3.1-8B-Instruct", - adaptors=[ - LoRAAdaptor( - name="algoprog/fact-generation-llama-3.1-8b-instruct-lora", - prefill_tolerance=1e-1, - ), - LoRAAdaptor( - name="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", - prefill_tolerance=1e-1, - ), - ], - max_loras_per_batch=2, - ), -] - # All prompts are used at once in a batch. PROMPTS = [ "AI is a field of computer science focused on", diff --git a/test/srt/models/lora/utils.py b/test/srt/models/lora/utils.py index 0dd638100..642b8731e 100644 --- a/test/srt/models/lora/utils.py +++ b/test/srt/models/lora/utils.py @@ -93,6 +93,41 @@ ALL_OTHER_LORA_MODELS = [ ), ] +CI_MULTI_LORA_MODELS = [ + # multi-rank case + LoRAModelCase( + base="meta-llama/Llama-2-7b-hf", + adaptors=[ + LoRAAdaptor( + name="winddude/wizardLM-LlaMA-LoRA-7B", + prefill_tolerance=1e-1, + ), + LoRAAdaptor( + name="RuterNorway/Llama-2-7b-chat-norwegian-LoRa", + prefill_tolerance=3e-1, + ), + ], + max_loras_per_batch=2, + ), +] + +ALL_OTHER_MULTI_LORA_MODELS = [ + LoRAModelCase( + base="meta-llama/Llama-3.1-8B-Instruct", + adaptors=[ + LoRAAdaptor( + name="algoprog/fact-generation-llama-3.1-8b-instruct-lora", + prefill_tolerance=1e-1, + ), + LoRAAdaptor( + name="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + prefill_tolerance=1e-1, + ), + ], + max_loras_per_batch=2, + ), +] + def run_lora_test_one_by_one( prompts: List[str],