Fix commandr import; format code
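Replace the CohereConfig type annotations in the Command-R model with the generic transformers PretrainedConfig, so the module imports even when the installed transformers release does not provide CohereConfig. The remaining hunks are formatting only (import sorting and black-style reflow) in the Command-R and DBRX model files, the ModelRpcServer decode path, and a benchmark batch helper.

A minimal sketch of the compatibility pattern, with illustrative names that are not part of this diff:

    from transformers import PretrainedConfig

    def read_rope_theta(config: PretrainedConfig) -> float:
        # Model-specific fields remain accessible as plain attributes,
        # so the generic base-class annotation loses nothing at runtime.
        return getattr(config, "rope_theta", 10000.0)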
@@ -510,9 +510,13 @@ class ModelRpcServer:
         batch.prepare_for_decode()
 
         # Forward
-        logits, (_, _, decode_top_logprobs, _, last_logprobs) = (
-            self.model_runner.forward(batch, ForwardMode.DECODE)
-        )
+        logits, (
+            _,
+            _,
+            decode_top_logprobs,
+            _,
+            last_logprobs,
+        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
         next_token_ids, _ = batch.sample(logits)
         next_token_ids = next_token_ids.cpu().tolist()
 
@@ -24,27 +24,30 @@ from typing import List, Optional, Tuple
 import torch
 import torch.utils.checkpoint
-from torch import nn
-from torch.nn.parameter import Parameter
-from transformers import CohereConfig
-
-from vllm.model_executor.parallel_utils.parallel_state import (get_tensor_model_parallel_rank,
-                                                               get_tensor_model_parallel_world_size)
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.linear import (LinearMethodBase,
-                                               MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               RowParallelLinear)
-from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding)
-from vllm.model_executor.utils import set_weight_attrs
-from vllm.model_executor.weight_utils import (default_weight_loader,
-                                              hf_model_weights_iterator)
-
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from torch import nn
+from torch.nn.parameter import Parameter
+from transformers import PretrainedConfig
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.linear import (
+    LinearMethodBase,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.parallel_utils.parallel_state import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.weight_utils import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
 
 
 @torch.compile
@@ -53,14 +56,12 @@ def layer_norm_func(hidden_states, weight, variance_epsilon):
     hidden_states = hidden_states.to(torch.float32)
     mean = hidden_states.mean(-1, keepdim=True)
     variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
-    hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
-                                                         variance_epsilon)
+    hidden_states = (hidden_states - mean) * torch.rsqrt(variance + variance_epsilon)
     hidden_states = weight.to(torch.float32) * hidden_states
     return hidden_states.to(input_dtype)
 
 
 class LayerNorm(nn.Module):
-
     def __init__(self, param_shape=None, eps=1e-5):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(param_shape))
@@ -68,8 +69,9 @@ class LayerNorm(nn.Module):
         set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
 
     def forward(self, hidden_states, residuals=None):
-        hidden_states = layer_norm_func(hidden_states, self.weight,
-                                        self.variance_epsilon)
+        hidden_states = layer_norm_func(
+            hidden_states, self.weight, self.variance_epsilon
+        )
         return hidden_states, residuals
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
@@ -79,15 +81,13 @@ class LayerNorm(nn.Module):
         if shard_dim is not None:
             shard_size = param_data.shape[shard_dim]
             start_idx = tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
-                                                 shard_size)
+            loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
 class CohereMLP(nn.Module):
-
     def __init__(
         self,
         config,
@@ -119,10 +119,9 @@ class CohereMLP(nn.Module):
 
 
 class CohereAttention(nn.Module):
-
     def __init__(
         self,
-        config: CohereConfig,
+        config: PretrainedConfig,
         layer_id: int = 0,
         linear_method: Optional[LinearMethodBase] = None,
     ):
@@ -148,8 +147,8 @@ class CohereAttention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.max_position_embeddings = getattr(
-            config, "model_max_length", None) or getattr(
-                config, "max_position_embeddings", 8192)
+            config, "model_max_length", None
+        ) or getattr(config, "max_position_embeddings", 8192)
         self.rope_theta = config.rope_theta
         self.rope_scaling = getattr(config, "rope_scaling", None)
         self.use_qk_norm = getattr(config, "use_qk_norm", False)
@@ -183,12 +182,13 @@ class CohereAttention(nn.Module):
             layer_id=layer_id,
         )
         if self.use_qk_norm:
-            self.q_norm = LayerNorm(param_shape=(self.num_heads,
-                                                 self.head_dim),
-                                    eps=config.layer_norm_eps)
-            self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
-                                                 self.head_dim),
-                                    eps=config.layer_norm_eps)
+            self.q_norm = LayerNorm(
+                param_shape=(self.num_heads, self.head_dim), eps=config.layer_norm_eps
+            )
+            self.k_norm = LayerNorm(
+                param_shape=(self.num_kv_heads, self.head_dim),
+                eps=config.layer_norm_eps,
+            )
 
     def _apply_qk_norm(self, q, k):
         q = q.view(*q.shape[:-1], -1, self.head_dim)
@@ -216,19 +216,23 @@ class CohereAttention(nn.Module):
 
 
 class CohereDecoderLayer(nn.Module):
-
-    def __init__(self,
-                 config: CohereConfig,
-                 layer_id: int = 0,
-                 linear_method: Optional[LinearMethodBase] = None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int = 0,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.hidden_size = config.hidden_size
 
-        self.self_attn = CohereAttention(config, layer_id=layer_id, linear_method=linear_method)
+        self.self_attn = CohereAttention(
+            config, layer_id=layer_id, linear_method=linear_method
+        )
 
         self.mlp = CohereMLP(config, linear_method=linear_method)
-        self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
-                                         eps=config.layer_norm_eps)
+        self.input_layernorm = LayerNorm(
+            param_shape=(config.hidden_size), eps=config.layer_norm_eps
+        )
 
     def forward(
         self,
@@ -253,23 +257,26 @@ class CohereDecoderLayer(nn.Module):
 
 
 class CohereModel(nn.Module):
-
     def __init__(
         self,
-        config: CohereConfig,
+        config: PretrainedConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
        self.config = config
         self.vocab_size = config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
-                                                   config.hidden_size)
-        self.layers = nn.ModuleList([
-            CohereDecoderLayer(config, i, linear_method=linear_method)
-            for i in range(config.num_hidden_layers)
-        ])
-        self.norm = LayerNorm(param_shape=(config.hidden_size),
-                              eps=config.layer_norm_eps)
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size, config.hidden_size
+        )
+        self.layers = nn.ModuleList(
+            [
+                CohereDecoderLayer(config, i, linear_method=linear_method)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = LayerNorm(
+            param_shape=(config.hidden_size), eps=config.layer_norm_eps
+        )
 
     def forward(
         self,
@@ -292,10 +299,9 @@ class CohereModel(nn.Module):
 
 
 class CohereForCausalLM(nn.Module):
-
     def __init__(
         self,
-        config: CohereConfig,
+        config: PretrainedConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
@@ -311,7 +317,11 @@ class CohereForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata,)
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            input_metadata,
+        )
         return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
@@ -334,7 +344,8 @@ class CohereForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         loaded_params = set()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision):
+            model_name_or_path, cache_dir, load_format, revision
+        ):
             for param_name, shard_name, shard_id in stacked_params_mapping:
                 if shard_name not in name:
                     continue
@@ -355,8 +366,7 @@ class CohereForCausalLM(nn.Module):
                 if name.endswith(".bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
                 loaded_params.add(name)
 
@@ -171,7 +171,6 @@ class DbrxExperts(nn.Module):
 
 
 class DbrxAttention(nn.Module):
-
     def __init__(
         self,
         config: DbrxConfig,
@@ -251,7 +250,6 @@ class DbrxAttention(nn.Module):
 
 
 class DbrxFusedNormAttention(nn.Module):
-
     def __init__(
         self,
         config: DbrxConfig,
@@ -284,7 +282,6 @@ class DbrxFusedNormAttention(nn.Module):
 
 
 class DbrxBlock(nn.Module):
-
     def __init__(
         self,
         config: DbrxConfig,
@@ -312,7 +309,6 @@ class DbrxBlock(nn.Module):
 
 
 class DbrxModel(nn.Module):
-
     def __init__(
         self,
         config: DbrxConfig,
@@ -351,7 +347,6 @@ class DbrxModel(nn.Module):
 
 
 class DbrxForCausalLM(nn.Module):
-
     def __init__(
         self,
         config: DbrxConfig,
@@ -66,9 +66,9 @@ class BenchBatch:
             p_idx = prefix_req_idx[i // fork_num].item()
             n_idx = self.req_pool_indices[i].item()
             req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len]
-            req_to_token[n_idx, prefix_len : prefix_len + extend_len] = (
-                self.out_cache_loc[i * extend_len : (i + 1) * extend_len]
-            )
+            req_to_token[
+                n_idx, prefix_len : prefix_len + extend_len
+            ] = self.out_cache_loc[i * extend_len : (i + 1) * extend_len]
 
     def update_decode(self, predict_ids, batch_size):
         assert predict_ids.shape[0] == batch_size
@@ -81,9 +81,9 @@ class BenchBatch:
             self.out_cache_cont_start,
             self.out_cache_cont_end,
         ) = self.token_to_kv_pool.alloc_contiguous(batch_size)
-        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
-            self.out_cache_loc
-        )
+        self.req_to_token_pool.req_to_token[
+            self.req_pool_indices, self.seq_lens
+        ] = self.out_cache_loc
         self.seq_lens.add_(1)
 
 