Rename InputMetadata -> ForwardBatch (#1543)

This commit is contained in:
Lianmin Zheng
2024-09-30 02:41:11 -07:00
committed by GitHub
parent 3f0fe08d37
commit 36d5acfca5
44 changed files with 435 additions and 433 deletions

View File

@@ -40,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
class BaseLayerWithLoRA(nn.Module):

View File

@@ -23,7 +23,7 @@ import torch
from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_hip, replace_submodule
# ROCm: flashinfer available later
@@ -207,9 +207,9 @@ class LoRAManager:
if lora_weight_name:
self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)
def prepare_lora_batch(self, input_metadata: InputMetadata):
def prepare_lora_batch(self, forward_batch: ForwardBatch):
# load active loras into lora memory pool
cur_uids = set(input_metadata.lora_paths)
cur_uids = set(forward_batch.lora_paths)
assert len(cur_uids) <= self.max_loras_per_batch
i = 0
evictable_uids = list(self.active_uids)
@@ -229,14 +229,14 @@ class LoRAManager:
return
# setup lora in forward modules
bs = input_metadata.batch_size
bs = forward_batch.batch_size
seg_lens = (
input_metadata.extend_seq_lens
if input_metadata.forward_mode.is_extend()
forward_batch.extend_seq_lens
if forward_batch.forward_mode.is_extend()
else torch.ones(bs)
)
weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
for i, lora_path in enumerate(input_metadata.lora_paths):
for i, lora_path in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.buffer_id[lora_path]
for module_name, module in self.lora_modules: