fix black in pre-commit (#1940)

Author: Chayenne
Committed by: GitHub
Date: 2024-11-07 15:42:47 -08:00
Parent: dca87ec348
Commit: c77c1e05ba
29 changed files with 641 additions and 508 deletions
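For context, the black hook referenced in the title is normally wired up through the repository's .pre-commit-config.yaml. That file is not part of the hunks shown below, so the entry here is only a minimal sketch of the usual psf/black hook; the pinned rev is an assumption, not taken from this commit.

repos:
  - repo: https://github.com/psf/black
    rev: 24.8.0  # assumption: pin to whichever release the project standardizes on
    hooks:
      - id: black

With an entry like this in place, pre-commit run black --all-files rewrites the tree the same way the hunks below do; the changes themselves are mechanical black output (arguments collapsed onto one line where they fit the line length, and argument lists ending in a trailing comma exploded one per line).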


@@ -28,7 +28,7 @@ from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-#from sglang.srt.layers.activation import get_act_fn
+# from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -47,15 +47,14 @@ class GPT2Attention(nn.Module):
         self,
         layer_id: int,
         config: GPT2Config,
-        cache_config = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = config.hidden_size
         total_num_heads = config.num_attention_heads
-        tensor_model_parallel_world_size = (
-            get_tensor_model_parallel_world_size())
+        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
         assert total_num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = total_num_heads // tensor_model_parallel_world_size
         self.head_dim = self.hidden_size // total_num_heads
@@ -76,11 +75,13 @@ class GPT2Attention(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.c_proj",
         )
-        self.attn = RadixAttention(self.num_heads,
-                                   self.head_dim,
-                                   scaling=self.scale,
-                                   num_kv_heads=total_num_heads,
-                                   layer_id=layer_id)
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            scaling=self.scale,
+            num_kv_heads=total_num_heads,
+            layer_id=layer_id,
+        )

     def forward(
         self,
@@ -119,10 +120,14 @@ class GPT2MLP(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.c_proj",
         )
-        self.act = get_act_fn(config.activation_function, quant_config,
-                              intermediate_size)
+        self.act = get_act_fn(
+            config.activation_function, quant_config, intermediate_size
+        )

-    def forward(self, hidden_states: torch.Tensor,) -> torch.Tensor:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states, _ = self.c_proj(hidden_states)
@@ -135,27 +140,20 @@ class GPT2Block(nn.Module):
         self,
         layer_id: int,
         config: GPT2Config,
-        cache_config = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
-        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
-                     hidden_size)
+        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = GPT2Attention(layer_id,
-                                  config,
-                                  cache_config,
-                                  quant_config,
-                                  prefix=f"{prefix}.attn")
+        self.attn = GPT2Attention(
+            layer_id, config, cache_config, quant_config, prefix=f"{prefix}.attn"
+        )
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = GPT2MLP(inner_dim,
-                           config,
-                           quant_config,
-                           prefix=f"{prefix}.mlp")
+        self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")

     def forward(
         self,
@@ -179,13 +177,12 @@ class GPT2Block(nn.Module):
         return hidden_states


 class GPT2Model(nn.Module):

     def __init__(
         self,
         config: GPT2Config,
-        cache_config = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -229,16 +226,15 @@ class GPT2LMHeadModel(nn.Module):
     def __init__(
         self,
         config: GPT2Config,
-        cache_config = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = GPT2Model(config,
-                                     cache_config,
-                                     quant_config,
-                                     prefix="transformer")
+        self.transformer = GPT2Model(
+            config, cache_config, quant_config, prefix="transformer"
+        )
         self.lm_head = self.transformer.wte
         self.logits_processor = LogitsProcessor(config)
@@ -254,8 +250,6 @@ class GPT2LMHeadModel(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
@@ -280,8 +274,8 @@ class GPT2LMHeadModel(nn.Module):
                 if not name.endswith(".weight"):
                     continue
                 loaded_weight = loaded_weight.t()
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)


 EntryClass = GPT2LMHeadModel