Let ModelRunner take InputMetadata as input, instead of ScheduleBatch (#1541)

Lianmin Zheng
2024-09-29 20:28:45 -07:00
committed by GitHub
parent 55b974f96f
commit 3f0fe08d37
12 changed files with 142 additions and 157 deletions
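In short, LoRAManager.prepare_lora_batch now reads everything it needs from the forward-time InputMetadata instead of the scheduler-level ScheduleBatch. Below is a minimal, self-contained sketch of the fields the new code relies on; the stand-in classes are illustrative only, and the real InputMetadata in sglang.srt.model_executor.forward_batch_info carries many more fields.

    # Illustrative sketch: fake classes mimic only the InputMetadata fields
    # this diff touches (batch_size, forward_mode, extend_seq_lens, lora_paths).
    from dataclasses import dataclass
    from typing import List, Optional

    import torch


    @dataclass
    class FakeForwardMode:
        extend: bool

        def is_extend(self) -> bool:
            return self.extend


    @dataclass
    class FakeInputMetadata:
        batch_size: int
        forward_mode: FakeForwardMode
        lora_paths: List[Optional[str]]  # one adapter path (or None) per request
        extend_seq_lens: Optional[torch.Tensor] = None


    meta = FakeInputMetadata(
        batch_size=2,
        forward_mode=FakeForwardMode(extend=True),
        lora_paths=["adapters/a", "adapters/b"],
        extend_seq_lens=torch.tensor([5, 3]),
    )

    # Mirrors the new seg_lens selection in prepare_lora_batch below:
    seg_lens = (
        meta.extend_seq_lens
        if meta.forward_mode.is_extend()
        else torch.ones(meta.batch_size)
    )
    print(seg_lens)  # tensor([5, 3])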

@@ -18,13 +18,12 @@ limitations under the License.
 import re
 from dataclasses import dataclass
 import torch
 from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
 from sglang.srt.lora.lora_config import LoRAConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.utils import is_hip, replace_submodule
 # ROCm: flashinfer available later
@@ -208,9 +207,9 @@ class LoRAManager:
                 if lora_weight_name:
                     self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)

-    def prepare_lora_batch(self, batch, extend_seq_lens=None):
+    def prepare_lora_batch(self, input_metadata: InputMetadata):
         # load active loras into lora memory pool
-        cur_uids = set([req.lora_path for req in batch.reqs])
+        cur_uids = set(input_metadata.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
         i = 0
         evictable_uids = list(self.active_uids)
@@ -230,11 +229,15 @@ class LoRAManager:
            return

        # setup lora in forward modules
-        bs = len(batch.reqs)
-        seg_lens = extend_seq_lens if batch.forward_mode.is_extend() else torch.ones(bs)
+        bs = input_metadata.batch_size
+        seg_lens = (
+            input_metadata.extend_seq_lens
+            if input_metadata.forward_mode.is_extend()
+            else torch.ones(bs)
+        )
         weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
-        for i, req in enumerate(batch.reqs):
-            weight_indices[i] = self.buffer_id[req.lora_path]
+        for i, lora_path in enumerate(input_metadata.lora_paths):
+            weight_indices[i] = self.buffer_id[lora_path]
         for module_name, module in self.lora_modules:
             layer_id = get_layer_id(module_name)
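
For call sites, the effect is that the model runner's LoRA hook can pass the same per-forward metadata it already builds, rather than threading the scheduler's batch (plus a separate extend_seq_lens argument) through. A hedged sketch of the call-site change; the variable names are illustrative, and only the two signatures come from the diff above:

    # Call-site shape, old vs. new:
    lora_manager.prepare_lora_batch(batch, extend_seq_lens)  # before #1541
    lora_manager.prepare_lora_batch(input_metadata)          # after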