[1/2] Refactor LoRA to support backend-specific batch preprocessing. (#10251)

This commit is contained in:
Lifu Huang
2025-09-10 09:58:37 -07:00
committed by GitHub
parent cda7e47ce7
commit 941002945b
6 changed files with 227 additions and 130 deletions

View File

@@ -10,19 +10,19 @@ from sglang.srt.hf_transformers_utils import AutoConfig
@dataclass
class LoRABatchInfo:
# Whether the forward mode is using CUDA Graph.
use_cuda_graph: bool
# Batch size
bs: int
# Lengths of each sequence in shape (bs,)
seg_lens: torch.Tensor
# Number of segments. For triton backend, it is equal to batch size.
num_segments: int
# Indice pointers of each sequence in shape (bs + 1, )
# Index pointers of each segment, in shape (num_segments + 1,)
seg_indptr: torch.Tensor
# Maximum sequence length of current batch
max_len: int
# The index of lora adapter used by each sequence, in shape (bs,)
# The index of lora adapter used by each segment, in shape (num_segments,)
weight_indices: torch.Tensor
# ranks of each lora adapter, in shape (lora_num,)
@@ -31,6 +31,15 @@ class LoRABatchInfo:
# scaling of each lora adapter, in shape (lora_num,)
scalings: torch.Tensor
# Length of each segment, in shape (num_segments,)
seg_lens: Optional[torch.Tensor]
# Maximum segment length of current batch
max_len: Optional[int]
# The logical (re)ordering of input rows (tokens), in shape (num_tokens,)
permutation: Optional[torch.Tensor]
class LoRAType(Enum):
LORA_A = 0