[Feature] Initial support for multi-LoRA serving (#1307)
This commit is contained in:
@@ -41,6 +41,7 @@ from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.sampler import SampleOutput
|
||||
from sglang.srt.lora.lora_manager import LoRAManager
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPool,
|
||||
@@ -107,6 +108,8 @@ class ModelRunner:
|
||||
# Init components
|
||||
min_per_gpu_memory = self.init_torch_distributed()
|
||||
self.load_model()
|
||||
if server_args.lora_paths is not None:
|
||||
self.init_lora_manager()
|
||||
self.init_memory_pool(
|
||||
min_per_gpu_memory,
|
||||
server_args.max_running_requests,
|
||||
@@ -312,6 +315,17 @@ class ModelRunner:
|
||||
logger.info("Update weights end.")
|
||||
return True, "Succeeded to update model weights"
|
||||
|
||||
def init_lora_manager(self):
    """Build the LoRA manager and attach it to this runner.

    Constructs a ``LoRAManager`` from the already-loaded base model, the
    LoRA paths and per-batch adapter limit taken from the server
    arguments, the base model's HF config, the load config, and the
    runner's dtype. The result is stored on ``self.lora_manager``.

    Must run after ``self.model`` has been set, since the model is passed
    to the manager as ``base_model``.
    """
    # Collect the constructor arguments in one place so the call site
    # stays short; attributes are read in the same order as they are
    # passed to the manager.
    manager_kwargs = {
        "base_model": self.model,
        "lora_paths": self.server_args.lora_paths,
        "base_hf_config": self.model_config.hf_config,
        "max_loras_per_batch": self.server_args.max_loras_per_batch,
        "load_config": self.load_config,
        "dtype": self.dtype,
    }
    self.lora_manager = LoRAManager(**manager_kwargs)
    logger.info("LoRA manager ready.")
|
||||
|
||||
def profile_max_num_token(self, total_gpu_memory: int):
|
||||
available_gpu_memory = get_available_gpu_memory(
|
||||
self.gpu_id, distributed=self.tp_size > 1
|
||||
@@ -450,6 +464,8 @@ class ModelRunner:
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward_decode(self, batch: ScheduleBatch):
|
||||
if self.server_args.lora_paths is not None:
|
||||
self.lora_manager.prepare_lora_batch(batch)
|
||||
if (
|
||||
self.cuda_graph_runner
|
||||
and self.cuda_graph_runner.can_run(len(batch.reqs))
|
||||
@@ -466,6 +482,9 @@ class ModelRunner:
|
||||
@torch.inference_mode()
|
||||
def forward_extend(self, batch: ScheduleBatch):
|
||||
input_metadata = InputMetadata.from_schedule_batch(self, batch)
|
||||
if self.server_args.lora_paths is not None:
|
||||
self.lora_manager.prepare_lora_batch(batch, input_metadata.extend_seq_lens)
|
||||
|
||||
if self.is_generation:
|
||||
return self.model.forward(
|
||||
batch.input_ids, input_metadata.positions, input_metadata
|
||||
|
||||
Reference in New Issue
Block a user