Rename InputMetadata -> ForwardBatch (#1543)
@@ -40,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.loader import DefaultModelLoader

-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode


 class BaseLayerWithLoRA(nn.Module):
@@ -23,7 +23,7 @@ import torch

 from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
 from sglang.srt.lora.lora_config import LoRAConfig
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import is_hip, replace_submodule

 # ROCm: flashinfer available later
@@ -207,9 +207,9 @@ class LoRAManager:
                     if lora_weight_name:
                         self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)

-    def prepare_lora_batch(self, input_metadata: InputMetadata):
+    def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
-        cur_uids = set(input_metadata.lora_paths)
+        cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
         i = 0
         evictable_uids = list(self.active_uids)
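For reference, prepare_lora_batch only reads a handful of fields off the batch object. The sketch below is a minimal stand-in for that interface: the field names (batch_size, lora_paths, forward_mode, extend_seq_lens) and the is_extend() check come from this diff, while the enum values and dataclass layout are assumptions, not the real definitions in forward_batch_info.py.

# Minimal stand-in for the ForwardBatch interface used in this commit.
# Field names are taken from the diff; the layout is an assumption.
import enum
from dataclasses import dataclass
from typing import List, Optional

import torch


class ForwardMode(enum.Enum):
    EXTEND = enum.auto()  # prefill: each request contributes extend_seq_lens tokens
    DECODE = enum.auto()  # decode: each request contributes exactly one token

    def is_extend(self) -> bool:
        return self is ForwardMode.EXTEND


@dataclass
class ForwardBatch:
    batch_size: int                                  # number of requests in the batch
    lora_paths: List[str]                            # one adapter id per request
    forward_mode: ForwardMode
    extend_seq_lens: Optional[torch.Tensor] = None   # per-request prompt lengths (extend only)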
@@ -229,14 +229,14 @@ class LoRAManager:
             return

         # setup lora in forward modules
-        bs = input_metadata.batch_size
+        bs = forward_batch.batch_size
         seg_lens = (
-            input_metadata.extend_seq_lens
-            if input_metadata.forward_mode.is_extend()
+            forward_batch.extend_seq_lens
+            if forward_batch.forward_mode.is_extend()
             else torch.ones(bs)
         )
         weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
-        for i, lora_path in enumerate(input_metadata.lora_paths):
+        for i, lora_path in enumerate(forward_batch.lora_paths):
             weight_indices[i] = self.buffer_id[lora_path]

         for module_name, module in self.lora_modules:
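The hunk above computes two per-request tensors: seg_lens (how many tokens each request contributes this step) and weight_indices (which pool slot holds each request's adapter). Below is a standalone sketch of that computation, reusing the ForwardBatch/ForwardMode stand-ins from the previous sketch; the buffer_id mapping is hypothetical, and the tensors live on CPU here rather than device="cuda".

# Standalone sketch of the seg_lens / weight_indices computation.
# buffer_id is a hypothetical mapping from adapter id to pool slot.
buffer_id = {"adapter_a": 0, "adapter_b": 1}

batch = ForwardBatch(
    batch_size=3,
    lora_paths=["adapter_a", "adapter_b", "adapter_a"],
    forward_mode=ForwardMode.EXTEND,
    extend_seq_lens=torch.tensor([5, 2, 7]),
)

# In extend (prefill) mode each request spans extend_seq_lens tokens;
# in decode mode every request contributes exactly one token.
seg_lens = (
    batch.extend_seq_lens
    if batch.forward_mode.is_extend()
    else torch.ones(batch.batch_size)
)

# One pool slot per request, looked up from its adapter id.
weight_indices = torch.empty((batch.batch_size,), dtype=torch.int64)
for i, lora_path in enumerate(batch.lora_paths):
    weight_indices[i] = buffer_id[lora_path]

print(seg_lens)        # tensor([5, 2, 7])
print(weight_indices)  # tensor([0, 1, 0])

In decode mode the same code yields seg_lens of all ones, since each request generates a single token per step.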