fix black in pre-commit (#1940)
This commit is contained in:
@@ -28,7 +28,7 @@ from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
#from sglang.srt.layers.activation import get_act_fn
|
||||
# from sglang.srt.layers.activation import get_act_fn
|
||||
from sglang.srt.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
@@ -47,15 +47,14 @@ class GPT2Attention(nn.Module):
|
||||
self,
|
||||
layer_id: int,
|
||||
config: GPT2Config,
|
||||
cache_config = None,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
total_num_heads = config.num_attention_heads
|
||||
tensor_model_parallel_world_size = (
|
||||
get_tensor_model_parallel_world_size())
|
||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
|
||||
assert total_num_heads % tensor_model_parallel_world_size == 0
|
||||
self.num_heads = total_num_heads // tensor_model_parallel_world_size
|
||||
self.head_dim = self.hidden_size // total_num_heads
|
||||
@@ -76,11 +75,13 @@ class GPT2Attention(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.c_proj",
|
||||
)
|
||||
self.attn = RadixAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
scaling=self.scale,
|
||||
num_kv_heads=total_num_heads,
|
||||
layer_id=layer_id)
|
||||
self.attn = RadixAttention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
scaling=self.scale,
|
||||
num_kv_heads=total_num_heads,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -119,10 +120,14 @@ class GPT2MLP(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.c_proj",
|
||||
)
|
||||
self.act = get_act_fn(config.activation_function, quant_config,
|
||||
intermediate_size)
|
||||
self.act = get_act_fn(
|
||||
config.activation_function, quant_config, intermediate_size
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
hidden_states, _ = self.c_fc(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states, _ = self.c_proj(hidden_states)
|
||||
@@ -135,27 +140,20 @@ class GPT2Block(nn.Module):
|
||||
self,
|
||||
layer_id: int,
|
||||
config: GPT2Config,
|
||||
cache_config = None,
|
||||
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
|
||||
hidden_size)
|
||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPT2Attention(layer_id,
|
||||
config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
self.attn = GPT2Attention(
|
||||
layer_id, config, cache_config, quant_config, prefix=f"{prefix}.attn"
|
||||
)
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = GPT2MLP(inner_dim,
|
||||
config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -179,13 +177,12 @@ class GPT2Block(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class GPT2Model(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GPT2Config,
|
||||
cache_config = None,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
@@ -229,16 +226,15 @@ class GPT2LMHeadModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPT2Config,
|
||||
cache_config = None,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.transformer = GPT2Model(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix="transformer")
|
||||
self.transformer = GPT2Model(
|
||||
config, cache_config, quant_config, prefix="transformer"
|
||||
)
|
||||
self.lm_head = self.transformer.wte
|
||||
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
@@ -254,8 +250,6 @@ class GPT2LMHeadModel(nn.Module):
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in weights:
|
||||
@@ -280,8 +274,8 @@ class GPT2LMHeadModel(nn.Module):
|
||||
if not name.endswith(".weight"):
|
||||
continue
|
||||
loaded_weight = loaded_weight.t()
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
EntryClass = GPT2LMHeadModel
|
||||
|
||||
Reference in New Issue
Block a user