Rename lora_path to lora_id in batches (#8437)
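Rename the internal batch fields `lora_path` / `lora_paths` to `lora_id` /
`lora_ids` across Req, ScheduleBatch, ModelWorkerBatch, ForwardBatch, the
scheduler, the tokenizer manager, and the CUDA graph runner, since these
fields carry unique LoRA IDs rather than filesystem paths. The user-facing
`lora_path` parameter in `GenerateReqInput` is kept for backward
compatibility; the tokenizer manager now resolves it to the new `lora_id`
field via the LoRA registry.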
@@ -191,11 +191,7 @@ class LoRAManager:
 
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # Load active loras into lora memory pool
-        # TODO (lifuhuang): The naming of `forward_batch.lora_paths` is confusing. It actually contains a set of unique
-        # LoRA IDs, not LoRA paths. While unfortunately we cannot change the name in API for backward compatibility, we
-        # should consider (1) renaming the incorrect usage within the system, and (2) deprecating the parameter name in
-        # the current API schema and introducing a better request schema in the future (e.g., use `model_name`).
-        cur_uids = set(forward_batch.lora_paths)
+        cur_uids = set(forward_batch.lora_ids)
         assert len(cur_uids) <= self.max_loras_per_batch
         self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
 
@@ -211,10 +207,10 @@ class LoRAManager:
         Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
         to device (CUDA) asynchronously.
         """
-        weight_indices = [0] * len(forward_batch.lora_paths)
+        weight_indices = [0] * len(forward_batch.lora_ids)
         lora_ranks = [0] * self.max_loras_per_batch
         scalings = [0] * self.max_loras_per_batch
-        for i, uid in enumerate(forward_batch.lora_paths):
+        for i, uid in enumerate(forward_batch.lora_ids):
             weight_indices[i] = self.memory_pool.get_buffer_id(uid)
             if uid is not None:
                 lora = self.loras[uid]
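Note: the transfer itself is outside this hunk. A minimal sketch of the
host-to-device pattern the docstring describes, with `transfer_adapter_info`
as a hypothetical helper name (not the actual LoRAManager code):

import torch

def transfer_adapter_info(
    weight_indices: list,
    lora_ranks: list,
    scalings: list,
    device: str = "cuda",
):
    # Stage the metadata in pinned host memory (requires a CUDA build) so the
    # copy can be issued asynchronously on the current CUDA stream.
    idx_host = torch.tensor(weight_indices, dtype=torch.int64, pin_memory=True)
    rank_host = torch.tensor(lora_ranks, dtype=torch.int64, pin_memory=True)
    scale_host = torch.tensor(scalings, dtype=torch.float32, pin_memory=True)
    # non_blocking=True overlaps the H2D copies with other work on the stream.
    return (
        idx_host.to(device, non_blocking=True),
        rank_host.to(device, non_blocking=True),
        scale_host.to(device, non_blocking=True),
    )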
@@ -101,8 +101,10 @@ class GenerateReqInput:
 
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
-    # The path to the LoRA
+    # The path to the LoRA adaptors
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    # The uid of LoRA adaptors, should be initialized by tokenizer manager
+    lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
     # Session info for continual prompting
     session_params: Optional[Union[List[Dict], Dict]] = None
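Note: the intended division of labor between the two fields, as a minimal
usage sketch (the `text` field and all values are illustrative):

# The caller still supplies the user-friendly adapter name:
req = GenerateReqInput(text="Hello", lora_path="my-adapter")

# `lora_id` stays unset until the tokenizer manager resolves the name,
# roughly: obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
assert req.lora_id is None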
@@ -500,7 +502,7 @@ class TokenizedGenerateReqInput:
     stream: bool
 
     # LoRA related
-    lora_path: Optional[str] = None  # None means just use the base model
+    lora_id: Optional[str] = None  # None means just use the base model
     # The input embeds
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
 
@@ -423,7 +423,7 @@ class Req:
         token_ids_logprob: List[int] = None,
         stream: bool = False,
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
-        lora_path: Optional[str] = None,
+        lora_id: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         token_type_ids: List[int] = None,
         session_id: Optional[str] = None,
@@ -467,7 +467,7 @@ class Req:
         self.sampling_params = sampling_params
         self.custom_logit_processor = custom_logit_processor
         self.return_hidden_states = return_hidden_states
-        self.lora_path = lora_path
+        self.lora_id = lora_id
 
         # Memory pool info
         self.req_pool_idx: Optional[int] = None
@@ -1750,7 +1750,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             encoder_lens=self.encoder_lens,
             encoder_lens_cpu=self.encoder_lens_cpu,
             encoder_out_cache_loc=self.encoder_out_cache_loc,
-            lora_paths=[req.lora_path for req in self.reqs],
+            lora_ids=[req.lora_id for req in self.reqs],
             sampling_info=self.sampling_info,
             input_embeds=self.input_embeds,
             token_type_ids=self.token_type_ids,
@@ -1891,7 +1891,7 @@ class ModelWorkerBatch:
     encoder_out_cache_loc: Optional[torch.Tensor]
 
     # For LoRA
-    lora_paths: Optional[List[str]]
+    lora_ids: Optional[List[str]]
 
     # Sampling info
     sampling_info: SamplingBatchInfo
@@ -1090,7 +1090,7 @@ class Scheduler(
             top_logprobs_num=recv_req.top_logprobs_num,
             token_ids_logprob=recv_req.token_ids_logprob,
             stream=recv_req.stream,
-            lora_path=recv_req.lora_path,
+            lora_id=recv_req.lora_id,
             input_embeds=recv_req.input_embeds,
             custom_logit_processor=recv_req.custom_logit_processor,
             return_hidden_states=recv_req.return_hidden_states,
@@ -1534,7 +1534,7 @@ class Scheduler(
             self.chunked_req = adder.add_chunked_req(self.chunked_req)
 
         if self.enable_lora:
-            lora_set = set([req.lora_path for req in self.running_batch.reqs])
+            lora_set = set([req.lora_id for req in self.running_batch.reqs])
 
         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
@@ -1542,8 +1542,8 @@ class Scheduler(
                 self.enable_lora
                 and len(
                     lora_set
-                    | set([req.lora_path for req in adder.can_run_list])
-                    | set([req.lora_path])
+                    | set([req.lora_id for req in adder.can_run_list])
+                    | set([req.lora_id])
                 )
                 > self.max_loras_per_batch
             ):
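Note: the admission rule in the two hunks above, restated as a standalone
predicate for clarity (a sketch, not scheduler code):

from typing import Optional, Set

def fits_lora_budget(
    running_ids: Set[Optional[str]],
    admitted_ids: Set[Optional[str]],
    candidate_id: Optional[str],
    max_loras_per_batch: int,
) -> bool:
    # A request joins the prefill batch only if the union of adapters already
    # running, adapters admitted earlier in this pass, and its own adapter
    # still fits the per-batch budget; None (base model) occupies one slot
    # in the set, mirroring the set arithmetic above.
    return len(running_ids | admitted_ids | {candidate_id}) <= max_loras_per_batch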
@@ -556,7 +556,7 @@ class TokenizerManager:
         if self.server_args.enable_lora and obj.lora_path:
             # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
             # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
-            obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
+            obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
 
         self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
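Note: `LoRARegistry` itself is untouched by this diff. A minimal sketch of
the acquire/release contract the calls here (and in the hunk at line 773
below) imply; reference counting is an assumption, and `LoRARegistrySketch`
is a hypothetical name:

import asyncio
import uuid

class LoRARegistrySketch:
    """Maps user-facing adapter names to unique IDs and counts in-flight use."""

    def __init__(self):
        self._name_to_id: dict = {}
        self._inflight: dict = {}
        self._lock = asyncio.Lock()

    async def acquire(self, lora_path: str) -> str:
        # Resolve the friendly name to a stable unique ID, minting one on
        # first use, and record one more in-flight request for it.
        async with self._lock:
            uid = self._name_to_id.setdefault(lora_path, uuid.uuid4().hex)
            self._inflight[uid] = self._inflight.get(uid, 0) + 1
            return uid

    async def release(self, lora_id: str) -> None:
        # Called with the resolved ID once the request finishes.
        async with self._lock:
            self._inflight[lora_id] -= 1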
@@ -665,7 +665,7 @@ class TokenizerManager:
             bootstrap_host=obj.bootstrap_host,
             bootstrap_port=obj.bootstrap_port,
             bootstrap_room=obj.bootstrap_room,
-            lora_path=obj.lora_path,
+            lora_id=obj.lora_id,
             input_embeds=input_embeds,
             session_params=session_params,
             custom_logit_processor=obj.custom_logit_processor,
@@ -773,7 +773,7 @@ class TokenizerManager:
 
         # Mark ongoing LoRA request as finished.
         if self.server_args.enable_lora and obj.lora_path:
-            await self.lora_registry.release(obj.lora_path)
+            await self.lora_registry.release(obj.lora_id)
 
         # Check if this was an abort/error created by scheduler
         if isinstance(out["meta_info"].get("finish_reason"), dict):
@@ -576,11 +576,11 @@ class CudaGraphRunner:
         )
 
         if self.model_runner.server_args.enable_lora:
-            # It is safe to capture CUDA graph using empty LoRA path, as the LoRA kernels will always be launched whenever
-            # `--enable-lora` is set to True (and return immediately if the LoRA path is empty for perf optimization).
-            lora_paths = [None] * bs
+            # It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever
+            # `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization).
+            lora_ids = [None] * bs
         else:
-            lora_paths = None
+            lora_ids = None
 
         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
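Note: the safety argument in the comment assumes the LoRA kernels tolerate an
all-None batch. A sketch of that early-return pattern (an illustrative
wrapper, not the actual kernel code):

from typing import List, Optional

import torch

def apply_lora_delta(
    hidden: torch.Tensor, lora_ids: List[Optional[str]]
) -> torch.Tensor:
    # Launched unconditionally whenever --enable-lora is on, so the CUDA graph
    # captured with lora_ids = [None] * bs records the same call sequence as a
    # real batch.
    if all(uid is None for uid in lora_ids):
        # No active adapter: return immediately (the perf optimization the
        # comment above refers to).
        return hidden
    # Otherwise gather the adapters' A/B matrices and add the low-rank update.
    raise NotImplementedError("adapter application elided in this sketch")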
@@ -607,11 +607,11 @@ class CudaGraphRunner:
             capture_hidden_mode=self.capture_hidden_mode,
             num_token_non_padded=self.num_token_non_padded,
             global_forward_mode=self.capture_forward_mode,
-            lora_paths=lora_paths,
+            lora_ids=lora_ids,
         )
         self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)
 
-        if lora_paths is not None:
+        if lora_ids is not None:
             self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
 
         # Attention backend
@@ -248,7 +248,7 @@ class ForwardBatch:
     encoder_out_cache_loc: Optional[torch.Tensor] = None
 
     # For LoRA
-    lora_paths: Optional[List[str]] = None
+    lora_ids: Optional[List[str]] = None
 
     # For input embeddings
     input_embeds: Optional[torch.Tensor] = None
@@ -327,7 +327,7 @@ class ForwardBatch:
             is_extend_in_batch=batch.is_extend_in_batch,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             global_forward_mode=batch.global_forward_mode,
-            lora_paths=batch.lora_paths,
+            lora_ids=batch.lora_ids,
             sampling_info=batch.sampling_info,
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
@@ -468,7 +468,7 @@ class TboForwardBatchPreparer:
             "extend_prefix_lens_cpu",
             "extend_seq_lens_cpu",
             "extend_logprob_start_lens_cpu",
-            "lora_paths",
+            "lora_ids",
         ]:
             old_value = getattr(batch, key)
             if old_value is None:
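Note: `lora_ids` joins the per-request CPU fields that the two-batch-overlap
preparer splits between sub-batches. A sketch of what that slicing amounts to
(assumed semantics, simplified from the actual preparer):

from typing import List, Optional, Tuple

def split_per_request_field(
    values: Optional[List], split_index: int
) -> Tuple[Optional[List], Optional[List]]:
    # Per-request lists such as lora_ids are cut at the request boundary so
    # each overlapped sub-batch keeps the adapters of its own requests.
    if values is None:
        return None, None
    return values[:split_index], values[split_index:]

# Example: split_per_request_field(["a", None, "b"], 2) -> (["a", None], ["b"])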