Fix OOM issues with FP8 for Llama (#1454)

This commit is contained in:
Lianmin Zheng
2024-09-18 03:45:19 -07:00
committed by GitHub
parent aa2750beb3
commit 1acccb364a
8 changed files with 33 additions and 21 deletions

View File

@@ -305,8 +305,6 @@ class LlamaForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.param_dict = dict(self.named_parameters())
@torch.no_grad()
def forward(
self,
@@ -374,7 +372,7 @@ class LlamaForCausalLM(nn.Module):
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = self.param_dict
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name or "projector" in name:

View File

@@ -36,6 +36,7 @@ class LlamaForClassification(nn.Module):
) -> None:
super().__init__()
self.config = config
self.torchao_config = None
self.quant_config = quant_config
self.model = LlamaModel(config, quant_config=quant_config)
@@ -44,8 +45,6 @@ class LlamaForClassification(nn.Module):
)
self.eos_token_id = config.eos_token_id
self.param_dict = dict(self.named_parameters())
@torch.no_grad()
def forward(
self,
@@ -77,7 +76,7 @@ class LlamaForClassification(nn.Module):
return logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = self.param_dict
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "classification_head" in name:

View File

@@ -307,8 +307,6 @@ class XverseForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.param_dict = dict(self.named_parameters())
@torch.no_grad()
def forward(
self,
@@ -333,7 +331,7 @@ class XverseForCausalLM(nn.Module):
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = self.param_dict
params_dict = dict(self.named_parameters())
def load_weights_per_param(name, loaded_weight):
if "rotary_emb.inv_freq" in name or "projector" in name:

View File

@@ -383,8 +383,6 @@ class XverseMoeForCausalLM(nn.Module):
)
self.logits_processor = LogitsProcessor(config)
self.param_dict = dict(self.named_parameters())
@torch.no_grad()
def forward(
self,
@@ -406,8 +404,7 @@ class XverseMoeForCausalLM(nn.Module):
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = self.param_dict
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:

View File

@@ -22,6 +22,7 @@ from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.utils import kill_child_process
from sglang.utils import get_exception_traceback
DEFAULT_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/Meta-Llama-3.1-8B-FP8"
DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct"
DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600