Let ModelRunner take InputMetadata as input, instead of ScheduleBatch (#1541)
This commit is contained in:
@@ -18,13 +18,12 @@ limitations under the License.
|
||||
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
|
||||
from sglang.srt.lora.lora_config import LoRAConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.utils import is_hip, replace_submodule
|
||||
|
||||
# ROCm: flashinfer available later
|
||||
@@ -208,9 +207,9 @@ class LoRAManager:
|
||||
if lora_weight_name:
|
||||
self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)
|
||||
|
||||
def prepare_lora_batch(self, batch, extend_seq_lens=None):
|
||||
def prepare_lora_batch(self, input_metadata: InputMetadata):
|
||||
# load active loras into lora memory pool
|
||||
cur_uids = set([req.lora_path for req in batch.reqs])
|
||||
cur_uids = set(input_metadata.lora_paths)
|
||||
assert len(cur_uids) <= self.max_loras_per_batch
|
||||
i = 0
|
||||
evictable_uids = list(self.active_uids)
|
||||
@@ -230,11 +229,15 @@ class LoRAManager:
|
||||
return
|
||||
|
||||
# setup lora in forward modules
|
||||
bs = len(batch.reqs)
|
||||
seg_lens = extend_seq_lens if batch.forward_mode.is_extend() else torch.ones(bs)
|
||||
bs = input_metadata.batch_size
|
||||
seg_lens = (
|
||||
input_metadata.extend_seq_lens
|
||||
if input_metadata.forward_mode.is_extend()
|
||||
else torch.ones(bs)
|
||||
)
|
||||
weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
|
||||
for i, req in enumerate(batch.reqs):
|
||||
weight_indices[i] = self.buffer_id[req.lora_path]
|
||||
for i, lora_path in enumerate(input_metadata.lora_paths):
|
||||
weight_indices[i] = self.buffer_id[lora_path]
|
||||
|
||||
for module_name, module in self.lora_modules:
|
||||
layer_id = get_layer_id(module_name)
|
||||
|
||||
Reference in New Issue
Block a user