[1/2] Refactor LoRA to support backend-specific batch preprocessing. (#10251)
This commit is contained in:
@@ -10,19 +10,19 @@ from sglang.srt.hf_transformers_utils import AutoConfig
|
||||
|
||||
@dataclass
|
||||
class LoRABatchInfo:
|
||||
# The forward mode is using CUDA Graph.
|
||||
use_cuda_graph: bool
|
||||
|
||||
# Batch size
|
||||
bs: int
|
||||
|
||||
# Lengths of each sequence in shape (bs,)
|
||||
seg_lens: torch.Tensor
|
||||
# Number of segments. For the Triton backend, this equals the batch size.
|
||||
num_segments: int
|
||||
|
||||
# Indice pointers of each sequence in shape (bs + 1, )
|
||||
# Index pointers of each segment, in shape (num_segments + 1,)
|
||||
seg_indptr: torch.Tensor
|
||||
|
||||
# Maximum sequence length of current batch
|
||||
max_len: int
|
||||
|
||||
# The index of lora adapter used by each sequence, in shape (bs,)
|
||||
# The index of lora adapter used by each segment, in shape (num_segments,)
|
||||
weight_indices: torch.Tensor
|
||||
|
||||
# ranks of each lora adapter, in shape (lora_num,)
|
||||
@@ -31,6 +31,15 @@ class LoRABatchInfo:
|
||||
# scaling of each lora adapter, in shape (lora_num,)
|
||||
scalings: torch.Tensor
|
||||
|
||||
# Length of each segment, in shape (num_segments,)
|
||||
seg_lens: Optional[torch.Tensor]
|
||||
|
||||
# Maximum segment length of current batch
|
||||
max_len: Optional[int]
|
||||
|
||||
# The logical (re)ordering of input rows (tokens), in shape (num_tokens,)
|
||||
permutation: Optional[torch.Tensor]
|
||||
|
||||
|
||||
class LoRAType(Enum):
|
||||
LORA_A = 0
|
||||
|
||||
Reference in New Issue
Block a user