[Feature] Support Tensor Parallelism and Weight Slicing for Lora (#4274)
Co-authored-by: ShenAo1111 <1377693092@qq.com> Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
This commit is contained in:
@@ -39,16 +39,9 @@ class LoRALayer(nn.Module):
|
||||
super().__init__()
|
||||
self.config: LoRAConfig = config
|
||||
self.base_hf_config: AutoConfig = base_hf_config
|
||||
|
||||
# lora weights in cpu. The weights are loaded from checkpoint.
|
||||
self.weights: Dict[str, torch.Tensor] = {}
|
||||
self.weight_gpu: Dict[str, torch.Tensor] = {}
|
||||
|
||||
def load_to_gpu(self):
|
||||
for name, weight in self.weights.items():
|
||||
self.weight_gpu[name] = weight.to(torch.float16).to("cuda")
|
||||
|
||||
def offload_from_gpu(self):
|
||||
for name, weight in self.weights.items():
|
||||
self.weight_gpu[name] = None
|
||||
|
||||
|
||||
class LoRAAdapter(nn.Module):
|
||||
@@ -77,19 +70,6 @@ class LoRAAdapter(nn.Module):
|
||||
)
|
||||
|
||||
self.weights: Dict[str, torch.Tensor] = {}
|
||||
self.weights_gpu: Dict[str, torch.Tensor] = {}
|
||||
|
||||
def load_to_gpu(self):
|
||||
for name, weight in self.weights.items():
|
||||
self.weights_gpu[name] = weight.to(torch.float16).to("cuda")
|
||||
for layer in self.layers:
|
||||
layer.load_to_gpu()
|
||||
|
||||
def offload_from_gpu(self):
|
||||
for name, weight in self.weights.items():
|
||||
self.weights_gpu[name] = None
|
||||
for layer in self.layers:
|
||||
layer.offload_from_gpu()
|
||||
|
||||
# initialize the LoRA weights to cpu
|
||||
def initialize_weights(self):
|
||||
|
||||
Reference in New Issue
Block a user