Fix commandr import; format code
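In the hunks below, the commit does two things: the Command-R model file stops importing CohereConfig from transformers (only recent transformers releases export that class, so the import could fail) and annotates its constructors with the generic PretrainedConfig instead, reading the Cohere-specific fields through getattr with fallbacks; the remaining hunks are a formatting pass (black-style call wrapping, sorted imports, and dropping the blank line that used to follow each class header). A minimal sketch of the import-fix pattern, assuming a recent transformers install; build_attention_settings is a hypothetical helper for illustration only:

from transformers import PretrainedConfig  # generic config base class, always importable


def build_attention_settings(config: PretrainedConfig) -> dict:
    # Hypothetical helper: Cohere-specific fields are read with getattr and a
    # default, exactly as the diff does, so any PretrainedConfig subclass works.
    return {
        "max_position_embeddings": getattr(config, "model_max_length", None)
        or getattr(config, "max_position_embeddings", 8192),
        "use_qk_norm": getattr(config, "use_qk_norm", False),
    }
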
@@ -510,9 +510,13 @@ class ModelRpcServer:
         batch.prepare_for_decode()

         # Forward
-        logits, (_, _, decode_top_logprobs, _, last_logprobs) = (
-            self.model_runner.forward(batch, ForwardMode.DECODE)
-        )
+        logits, (
+            _,
+            _,
+            decode_top_logprobs,
+            _,
+            last_logprobs,
+        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
         next_token_ids, _ = batch.sample(logits)
         next_token_ids = next_token_ids.cpu().tolist()

@@ -24,27 +24,30 @@ from typing import List, Optional, Tuple

 import torch
 import torch.utils.checkpoint
-from torch import nn
-from torch.nn.parameter import Parameter
-from transformers import CohereConfig
-
-from vllm.model_executor.parallel_utils.parallel_state import (get_tensor_model_parallel_rank,
-                                                               get_tensor_model_parallel_world_size)
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.linear import (LinearMethodBase,
-                                               MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               RowParallelLinear)
-from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding)
-from vllm.model_executor.utils import set_weight_attrs
-from vllm.model_executor.weight_utils import (default_weight_loader,
-                                              hf_model_weights_iterator)
-
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from torch import nn
+from torch.nn.parameter import Parameter
+from transformers import PretrainedConfig
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.linear import (
+    LinearMethodBase,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.parallel_utils.parallel_state import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.weight_utils import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)


 @torch.compile
@@ -53,14 +56,12 @@ def layer_norm_func(hidden_states, weight, variance_epsilon):
     hidden_states = hidden_states.to(torch.float32)
     mean = hidden_states.mean(-1, keepdim=True)
     variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
-    hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
-                                                         variance_epsilon)
+    hidden_states = (hidden_states - mean) * torch.rsqrt(variance + variance_epsilon)
     hidden_states = weight.to(torch.float32) * hidden_states
     return hidden_states.to(input_dtype)


 class LayerNorm(nn.Module):
-
     def __init__(self, param_shape=None, eps=1e-5):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(param_shape))
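For reference, layer_norm_func above is a bias-free LayerNorm computed in float32; the weight may be multi-dimensional (see the q/k norms later in the diff) and broadcasts over the normalized axis. A small standalone sanity check for the 1-D-weight case, written without the @torch.compile decorator and assuming input_dtype is captured from the input as in the full file:

import torch
import torch.nn.functional as F


def layer_norm_func(hidden_states, weight, variance_epsilon):
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    mean = hidden_states.mean(-1, keepdim=True)
    variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
    hidden_states = (hidden_states - mean) * torch.rsqrt(variance + variance_epsilon)
    hidden_states = weight.to(torch.float32) * hidden_states
    return hidden_states.to(input_dtype)


x = torch.randn(2, 8, dtype=torch.float16)
w = torch.randn(8)
ref = F.layer_norm(x.float(), (8,), weight=w, eps=1e-5).to(x.dtype)
# loose tolerance because the result is rounded back to float16
assert torch.allclose(layer_norm_func(x, w, 1e-5), ref, atol=1e-2)
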
@@ -68,8 +69,9 @@ class LayerNorm(nn.Module):
         set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})

     def forward(self, hidden_states, residuals=None):
-        hidden_states = layer_norm_func(hidden_states, self.weight,
-                                        self.variance_epsilon)
+        hidden_states = layer_norm_func(
+            hidden_states, self.weight, self.variance_epsilon
+        )
         return hidden_states, residuals

     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
@@ -79,15 +81,13 @@ class LayerNorm(nn.Module):
         if shard_dim is not None:
             shard_size = param_data.shape[shard_dim]
             start_idx = tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
-                                                 shard_size)
+            loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)


 # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
 class CohereMLP(nn.Module):
-
     def __init__(
         self,
         config,
@@ -119,10 +119,9 @@ class CohereMLP(nn.Module):


 class CohereAttention(nn.Module):
-
     def __init__(
         self,
-        config: CohereConfig,
+        config: PretrainedConfig,
         layer_id: int = 0,
         linear_method: Optional[LinearMethodBase] = None,
     ):
@@ -148,8 +147,8 @@ class CohereAttention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.max_position_embeddings = getattr(
-            config, "model_max_length", None) or getattr(
-                config, "max_position_embeddings", 8192)
+            config, "model_max_length", None
+        ) or getattr(config, "max_position_embeddings", 8192)
         self.rope_theta = config.rope_theta
         self.rope_scaling = getattr(config, "rope_scaling", None)
         self.use_qk_norm = getattr(config, "use_qk_norm", False)
@@ -183,12 +182,13 @@ class CohereAttention(nn.Module):
             layer_id=layer_id,
         )
         if self.use_qk_norm:
-            self.q_norm = LayerNorm(param_shape=(self.num_heads,
-                                                 self.head_dim),
-                                    eps=config.layer_norm_eps)
-            self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
-                                                 self.head_dim),
-                                    eps=config.layer_norm_eps)
+            self.q_norm = LayerNorm(
+                param_shape=(self.num_heads, self.head_dim), eps=config.layer_norm_eps
+            )
+            self.k_norm = LayerNorm(
+                param_shape=(self.num_kv_heads, self.head_dim),
+                eps=config.layer_norm_eps,
+            )

     def _apply_qk_norm(self, q, k):
         q = q.view(*q.shape[:-1], -1, self.head_dim)
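The use_qk_norm branch above attaches per-head LayerNorm weights of shape (num_heads, head_dim) to the query and key projections; _apply_qk_norm (only its first line is visible in this hunk) reshapes q and k to (..., num_heads, head_dim), normalizes over head_dim, and flattens back before rotary embedding and attention. A rough standalone sketch of that idea, with shapes and the function name assumed rather than copied from the file:

import torch


def qk_norm_sketch(q, k, q_weight, k_weight, head_dim, eps=1e-5):
    # q, k: [num_tokens, num_heads * head_dim]; weights: [num_heads, head_dim]
    def _norm(x, w):
        x = x.view(*x.shape[:-1], -1, head_dim).float()
        mean = x.mean(-1, keepdim=True)
        var = (x - mean).pow(2).mean(-1, keepdim=True)
        x = (x - mean) * torch.rsqrt(var + eps) * w.float()
        return x.flatten(-2)

    return _norm(q, q_weight), _norm(k, k_weight)


q = torch.randn(4, 2 * 8)  # 2 heads, head_dim = 8
k = torch.randn(4, 2 * 8)
q_out, k_out = qk_norm_sketch(q, k, torch.ones(2, 8), torch.ones(2, 8), head_dim=8)
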
@@ -216,19 +216,23 @@ class CohereAttention(nn.Module):


 class CohereDecoderLayer(nn.Module):
-
-    def __init__(self,
-                 config: CohereConfig,
-                 layer_id: int = 0,
-                 linear_method: Optional[LinearMethodBase] = None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int = 0,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.hidden_size = config.hidden_size

-        self.self_attn = CohereAttention(config, layer_id=layer_id, linear_method=linear_method)
+        self.self_attn = CohereAttention(
+            config, layer_id=layer_id, linear_method=linear_method
+        )

         self.mlp = CohereMLP(config, linear_method=linear_method)
-        self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
-                                         eps=config.layer_norm_eps)
+        self.input_layernorm = LayerNorm(
+            param_shape=(config.hidden_size), eps=config.layer_norm_eps
+        )

     def forward(
         self,
@@ -253,23 +257,26 @@ class CohereDecoderLayer(nn.Module):


 class CohereModel(nn.Module):
-
     def __init__(
         self,
-        config: CohereConfig,
+        config: PretrainedConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
-                                                   config.hidden_size)
-        self.layers = nn.ModuleList([
-            CohereDecoderLayer(config, i, linear_method=linear_method)
-            for i in range(config.num_hidden_layers)
-        ])
-        self.norm = LayerNorm(param_shape=(config.hidden_size),
-                              eps=config.layer_norm_eps)
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size, config.hidden_size
+        )
+        self.layers = nn.ModuleList(
+            [
+                CohereDecoderLayer(config, i, linear_method=linear_method)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = LayerNorm(
+            param_shape=(config.hidden_size), eps=config.layer_norm_eps
+        )

     def forward(
         self,
@@ -292,10 +299,9 @@ class CohereModel(nn.Module):


 class CohereForCausalLM(nn.Module):
-
     def __init__(
         self,
-        config: CohereConfig,
+        config: PretrainedConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
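With the annotation switched to PretrainedConfig, nothing in this file needs the CohereConfig class at import time. A hypothetical construction path (the model id and surrounding runtime setup are assumptions, not part of this commit):

from transformers import AutoConfig

# AutoConfig returns whatever config class matches the checkpoint; the code above
# only relies on the PretrainedConfig interface plus the getattr fallbacks.
config = AutoConfig.from_pretrained("CohereForAI/c4ai-command-r-v01")  # assumed model id
model = CohereForCausalLM(config)  # linear_method defaults to None
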
@@ -311,7 +317,11 @@ class CohereForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata,)
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            input_metadata,
+        )
         return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
@@ -334,7 +344,8 @@ class CohereForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         loaded_params = set()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision):
+            model_name_or_path, cache_dir, load_format, revision
+        ):
             for param_name, shard_name, shard_id in stacked_params_mapping:
                 if shard_name not in name:
                     continue
@@ -355,8 +366,7 @@ class CohereForCausalLM(nn.Module):
                 if name.endswith(".bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)

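The load_weights loop above follows the usual vLLM-style routing: checkpoint names that match a stacked-parameter mapping go to the fused parameter's weight_loader together with a shard id, and everything else falls through to the parameter's own weight_loader or default_weight_loader. A simplified sketch of that control flow; the mapping entries and route_weight are assumptions for illustration, not code from this hunk:

# Hedged sketch of the weight-routing idea used by load_weights.
stacked_params_mapping = [
    # (fused param name, checkpoint shard name, shard id) -- conventional entries
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def route_weight(params_dict, name, loaded_weight, default_weight_loader):
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name not in name:
            continue
        param = params_dict[name.replace(shard_name, param_name)]
        param.weight_loader(param, loaded_weight, shard_id)  # fused param needs a shard id
        return
    param = params_dict[name]
    getattr(param, "weight_loader", default_weight_loader)(param, loaded_weight)
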
@@ -171,7 +171,6 @@ class DbrxExperts(nn.Module):


 class DbrxAttention(nn.Module):
-
     def __init__(
         self,
         config: DbrxConfig,
@@ -251,7 +250,6 @@ class DbrxAttention(nn.Module):


 class DbrxFusedNormAttention(nn.Module):
-
     def __init__(
         self,
         config: DbrxConfig,
@@ -284,7 +282,6 @@ class DbrxFusedNormAttention(nn.Module):


 class DbrxBlock(nn.Module):
-
     def __init__(
         self,
         config: DbrxConfig,
@@ -312,7 +309,6 @@ class DbrxBlock(nn.Module):


 class DbrxModel(nn.Module):
-
     def __init__(
         self,
         config: DbrxConfig,
@@ -351,7 +347,6 @@ class DbrxModel(nn.Module):


 class DbrxForCausalLM(nn.Module):
-
     def __init__(
         self,
         config: DbrxConfig,