Rename lora_path to lora_id in batches (#8437)
This commit is contained in:
@@ -191,11 +191,7 @@ class LoRAManager:
|
||||
|
||||
def prepare_lora_batch(self, forward_batch: ForwardBatch):
|
||||
# Load active loras into lora memory pool
|
||||
# TODO (lifuhuang): The naming of `forward_batch.lora_paths` is confusing. It actually contains a set of unique
|
||||
# LoRA IDs, not LoRA paths. While unfortunately we cannot change the name in API for backward compatibility, we
|
||||
# should consider (1) renaming the incorrect usage within the system, and (2) deprecating the parameter name in
|
||||
# the current API schema and introducing a better request schema in the future (e.g., use `model_name`).
|
||||
cur_uids = set(forward_batch.lora_paths)
|
||||
cur_uids = set(forward_batch.lora_ids)
|
||||
assert len(cur_uids) <= self.max_loras_per_batch
|
||||
self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
|
||||
|
||||
@@ -211,10 +207,10 @@ class LoRAManager:
|
||||
Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
|
||||
to device (CUDA) asynchronously.
|
||||
"""
|
||||
weight_indices = [0] * len(forward_batch.lora_paths)
|
||||
weight_indices = [0] * len(forward_batch.lora_ids)
|
||||
lora_ranks = [0] * self.max_loras_per_batch
|
||||
scalings = [0] * self.max_loras_per_batch
|
||||
for i, uid in enumerate(forward_batch.lora_paths):
|
||||
for i, uid in enumerate(forward_batch.lora_ids):
|
||||
weight_indices[i] = self.memory_pool.get_buffer_id(uid)
|
||||
if uid is not None:
|
||||
lora = self.loras[uid]
|
||||
|
||||
@@ -101,8 +101,10 @@ class GenerateReqInput:
|
||||
|
||||
# The modalities of the image data [image, multi-images, video]
|
||||
modalities: Optional[List[str]] = None
|
||||
# The path to the LoRA
|
||||
# The path(s) to the LoRA adapters
|
||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
# The uid(s) of the LoRA adapters; should be initialized by the tokenizer manager
|
||||
lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
|
||||
# Session info for continual prompting
|
||||
session_params: Optional[Union[List[Dict], Dict]] = None
|
||||
@@ -500,7 +502,7 @@ class TokenizedGenerateReqInput:
|
||||
stream: bool
|
||||
|
||||
# LoRA related
|
||||
lora_path: Optional[str] = None # None means just use the base model
|
||||
lora_id: Optional[str] = None # None means just use the base model
|
||||
# The input embeds
|
||||
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
|
||||
|
||||
|
||||
@@ -423,7 +423,7 @@ class Req:
|
||||
token_ids_logprob: List[int] = None,
|
||||
stream: bool = False,
|
||||
origin_input_ids_unpadded: Optional[Tuple[int]] = None,
|
||||
lora_path: Optional[str] = None,
|
||||
lora_id: Optional[str] = None,
|
||||
input_embeds: Optional[List[List[float]]] = None,
|
||||
token_type_ids: List[int] = None,
|
||||
session_id: Optional[str] = None,
|
||||
@@ -467,7 +467,7 @@ class Req:
|
||||
self.sampling_params = sampling_params
|
||||
self.custom_logit_processor = custom_logit_processor
|
||||
self.return_hidden_states = return_hidden_states
|
||||
self.lora_path = lora_path
|
||||
self.lora_id = lora_id
|
||||
|
||||
# Memory pool info
|
||||
self.req_pool_idx: Optional[int] = None
|
||||
@@ -1750,7 +1750,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
encoder_lens=self.encoder_lens,
|
||||
encoder_lens_cpu=self.encoder_lens_cpu,
|
||||
encoder_out_cache_loc=self.encoder_out_cache_loc,
|
||||
lora_paths=[req.lora_path for req in self.reqs],
|
||||
lora_ids=[req.lora_id for req in self.reqs],
|
||||
sampling_info=self.sampling_info,
|
||||
input_embeds=self.input_embeds,
|
||||
token_type_ids=self.token_type_ids,
|
||||
@@ -1891,7 +1891,7 @@ class ModelWorkerBatch:
|
||||
encoder_out_cache_loc: Optional[torch.Tensor]
|
||||
|
||||
# For LoRA
|
||||
lora_paths: Optional[List[str]]
|
||||
lora_ids: Optional[List[str]]
|
||||
|
||||
# Sampling info
|
||||
sampling_info: SamplingBatchInfo
|
||||
|
||||
@@ -1090,7 +1090,7 @@ class Scheduler(
|
||||
top_logprobs_num=recv_req.top_logprobs_num,
|
||||
token_ids_logprob=recv_req.token_ids_logprob,
|
||||
stream=recv_req.stream,
|
||||
lora_path=recv_req.lora_path,
|
||||
lora_id=recv_req.lora_id,
|
||||
input_embeds=recv_req.input_embeds,
|
||||
custom_logit_processor=recv_req.custom_logit_processor,
|
||||
return_hidden_states=recv_req.return_hidden_states,
|
||||
@@ -1534,7 +1534,7 @@ class Scheduler(
|
||||
self.chunked_req = adder.add_chunked_req(self.chunked_req)
|
||||
|
||||
if self.enable_lora:
|
||||
lora_set = set([req.lora_path for req in self.running_batch.reqs])
|
||||
lora_set = set([req.lora_id for req in self.running_batch.reqs])
|
||||
|
||||
# Get requests from the waiting queue to a new prefill batch
|
||||
for req in self.waiting_queue:
|
||||
@@ -1542,8 +1542,8 @@ class Scheduler(
|
||||
self.enable_lora
|
||||
and len(
|
||||
lora_set
|
||||
| set([req.lora_path for req in adder.can_run_list])
|
||||
| set([req.lora_path])
|
||||
| set([req.lora_id for req in adder.can_run_list])
|
||||
| set([req.lora_id])
|
||||
)
|
||||
> self.max_loras_per_batch
|
||||
):
|
||||
|
||||
@@ -556,7 +556,7 @@ class TokenizerManager:
|
||||
if self.server_args.enable_lora and obj.lora_path:
|
||||
# Start tracking ongoing requests for LoRA adapters and resolve the user-friendly LoRA names in
|
||||
# `lora_path` to their corresponding unique LoRA IDs (stored in `lora_id`), as required for internal processing.
|
||||
obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
|
||||
obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
|
||||
|
||||
self._validate_one_request(obj, input_ids)
|
||||
return self._create_tokenized_object(
|
||||
@@ -665,7 +665,7 @@ class TokenizerManager:
|
||||
bootstrap_host=obj.bootstrap_host,
|
||||
bootstrap_port=obj.bootstrap_port,
|
||||
bootstrap_room=obj.bootstrap_room,
|
||||
lora_path=obj.lora_path,
|
||||
lora_id=obj.lora_id,
|
||||
input_embeds=input_embeds,
|
||||
session_params=session_params,
|
||||
custom_logit_processor=obj.custom_logit_processor,
|
||||
@@ -773,7 +773,7 @@ class TokenizerManager:
|
||||
|
||||
# Mark ongoing LoRA request as finished.
|
||||
if self.server_args.enable_lora and obj.lora_path:
|
||||
await self.lora_registry.release(obj.lora_path)
|
||||
await self.lora_registry.release(obj.lora_id)
|
||||
|
||||
# Check if this was an abort/error created by scheduler
|
||||
if isinstance(out["meta_info"].get("finish_reason"), dict):
|
||||
|
||||
@@ -576,11 +576,11 @@ class CudaGraphRunner:
|
||||
)
|
||||
|
||||
if self.model_runner.server_args.enable_lora:
|
||||
# It is safe to capture CUDA graph using empty LoRA path, as the LoRA kernels will always be launched whenever
|
||||
# `--enable-lora` is set to True (and return immediately if the LoRA path is empty for perf optimization).
|
||||
lora_paths = [None] * bs
|
||||
# It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever
|
||||
# `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization).
|
||||
lora_ids = [None] * bs
|
||||
else:
|
||||
lora_paths = None
|
||||
lora_ids = None
|
||||
|
||||
forward_batch = ForwardBatch(
|
||||
forward_mode=self.capture_forward_mode,
|
||||
@@ -607,11 +607,11 @@ class CudaGraphRunner:
|
||||
capture_hidden_mode=self.capture_hidden_mode,
|
||||
num_token_non_padded=self.num_token_non_padded,
|
||||
global_forward_mode=self.capture_forward_mode,
|
||||
lora_paths=lora_paths,
|
||||
lora_ids=lora_ids,
|
||||
)
|
||||
self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)
|
||||
|
||||
if lora_paths is not None:
|
||||
if lora_ids is not None:
|
||||
self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
|
||||
|
||||
# Attention backend
|
||||
|
||||
@@ -248,7 +248,7 @@ class ForwardBatch:
|
||||
encoder_out_cache_loc: Optional[torch.Tensor] = None
|
||||
|
||||
# For LoRA
|
||||
lora_paths: Optional[List[str]] = None
|
||||
lora_ids: Optional[List[str]] = None
|
||||
|
||||
# For input embeddings
|
||||
input_embeds: Optional[torch.Tensor] = None
|
||||
@@ -327,7 +327,7 @@ class ForwardBatch:
|
||||
is_extend_in_batch=batch.is_extend_in_batch,
|
||||
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
|
||||
global_forward_mode=batch.global_forward_mode,
|
||||
lora_paths=batch.lora_paths,
|
||||
lora_ids=batch.lora_ids,
|
||||
sampling_info=batch.sampling_info,
|
||||
req_to_token_pool=model_runner.req_to_token_pool,
|
||||
token_to_kv_pool=model_runner.token_to_kv_pool,
|
||||
|
||||
@@ -468,7 +468,7 @@ class TboForwardBatchPreparer:
|
||||
"extend_prefix_lens_cpu",
|
||||
"extend_seq_lens_cpu",
|
||||
"extend_logprob_start_lens_cpu",
|
||||
"lora_paths",
|
||||
"lora_ids",
|
||||
]:
|
||||
old_value = getattr(batch, key)
|
||||
if old_value is None:
|
||||
|
||||
Reference in New Issue
Block a user