Eliminate stream sync to speed up LoRA batch init (#6960)
This commit is contained in:
@@ -137,7 +137,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||||||
self.A_buffer_gate_up = A_buffer
|
self.A_buffer_gate_up = A_buffer
|
||||||
if self.lora_backend.fuse_stacked_lora_b:
|
if self.lora_backend.fuse_stacked_lora_b:
|
||||||
# B_buffer_gate_up: (num_lora, 2 * output_dim, r)
|
# B_buffer_gate_up: (num_lora, 2 * output_dim, r)
|
||||||
if not hasattr(self, "B_buffer_gate_up") or self.B_buffer_gate_up is None:
|
if getattr(self, "B_buffer_gate_up", None) is None:
|
||||||
self.B_buffer_gate_up = torch.empty(
|
self.B_buffer_gate_up = torch.empty(
|
||||||
(
|
(
|
||||||
B_buffer[0].shape[0],
|
B_buffer[0].shape[0],
|
||||||
@@ -202,7 +202,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||||||
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
|
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
|
||||||
|
|
||||||
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
|
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
|
||||||
if not hasattr(self, "B_buffer_qkv") or self.B_buffer_qkv is None:
|
if getattr(self, "B_buffer_qkv", None) is None:
|
||||||
self.B_buffer_qkv = torch.empty(
|
self.B_buffer_qkv = torch.empty(
|
||||||
(
|
(
|
||||||
B_buffer_q[0].shape[0],
|
B_buffer_q[0].shape[0],
|
||||||
@@ -221,20 +221,17 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Offsets of q/k/v in output dimension
|
# Offsets of q/k/v in output dimension
|
||||||
if not hasattr(self, "output_offset") or self.output_offset is None:
|
if getattr(self, "output_offset", None) is None:
|
||||||
self.output_offset = torch.empty(
|
self.output_offset = torch.tensor(
|
||||||
4, dtype=torch.int32, device=B_buffer_q.device
|
[
|
||||||
|
0,
|
||||||
|
output_dim_q,
|
||||||
|
output_dim_q + output_dim_kv,
|
||||||
|
output_dim_q + 2 * output_dim_kv,
|
||||||
|
],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=B_buffer_q.device,
|
||||||
)
|
)
|
||||||
self.output_offset[:4] = torch.tensor(
|
|
||||||
[
|
|
||||||
0,
|
|
||||||
output_dim_q,
|
|
||||||
output_dim_q + output_dim_kv,
|
|
||||||
output_dim_q + 2 * output_dim_kv,
|
|
||||||
],
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=B_buffer_q.device,
|
|
||||||
)
|
|
||||||
# For computing number of launched blocks
|
# For computing number of launched blocks
|
||||||
self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
|
self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user