From 65501a9cf1dc9e73bba24f35b88988f5633866a9 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 16 Apr 2024 18:10:12 +0000 Subject: [PATCH] Fix commandr import; format code --- .../sglang/srt/managers/router/model_rpc.py | 10 +- python/sglang/srt/models/commandr.py | 130 ++++++++++-------- python/sglang/srt/models/dbrx.py | 5 - test/srt/model/bench_llama_low_api.py | 12 +- 4 files changed, 83 insertions(+), 74 deletions(-) diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index 75c152610..cbf504b99 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -510,9 +510,13 @@ class ModelRpcServer: batch.prepare_for_decode() # Forward - logits, (_, _, decode_top_logprobs, _, last_logprobs) = ( - self.model_runner.forward(batch, ForwardMode.DECODE) - ) + logits, ( + _, + _, + decode_top_logprobs, + _, + last_logprobs, + ) = self.model_runner.forward(batch, ForwardMode.DECODE) next_token_ids, _ = batch.sample(logits) next_token_ids = next_token_ids.cpu().tolist() diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index c78fb222a..6f53bf2ec 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -24,27 +24,30 @@ from typing import List, Optional, Tuple import torch import torch.utils.checkpoint -from torch import nn -from torch.nn.parameter import Parameter -from transformers import CohereConfig - -from vllm.model_executor.parallel_utils.parallel_state import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) -from vllm.model_executor.utils import set_weight_attrs -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) - from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.router.model_runner import InputMetadata +from torch import nn +from torch.nn.parameter import Parameter +from transformers import PretrainedConfig +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.linear import ( + LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) @torch.compile @@ -53,14 +56,12 @@ def layer_norm_func(hidden_states, weight, variance_epsilon): hidden_states = hidden_states.to(torch.float32) mean = hidden_states.mean(-1, keepdim=True) variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) - hidden_states = (hidden_states - mean) * torch.rsqrt(variance + - variance_epsilon) + hidden_states = (hidden_states - mean) * torch.rsqrt(variance + variance_epsilon) hidden_states = weight.to(torch.float32) * hidden_states return hidden_states.to(input_dtype) class LayerNorm(nn.Module): - def __init__(self, param_shape=None, eps=1e-5): super().__init__() self.weight = nn.Parameter(torch.ones(param_shape)) @@ -68,8 +69,9 @@ class LayerNorm(nn.Module): set_weight_attrs(self.weight, {"weight_loader": self.weight_loader}) def forward(self, hidden_states, residuals=None): - hidden_states = layer_norm_func(hidden_states, self.weight, - self.variance_epsilon) + hidden_states = layer_norm_func( + hidden_states, self.weight, self.variance_epsilon + ) return hidden_states, residuals def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): @@ -79,15 +81,13 @@ class LayerNorm(nn.Module): if shard_dim is not None: shard_size = param_data.shape[shard_dim] start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(shard_dim, start_idx, - shard_size) + loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere class CohereMLP(nn.Module): - def __init__( self, config, @@ -119,10 +119,9 @@ class CohereMLP(nn.Module): class CohereAttention(nn.Module): - def __init__( self, - config: CohereConfig, + config: PretrainedConfig, layer_id: int = 0, linear_method: Optional[LinearMethodBase] = None, ): @@ -148,8 +147,8 @@ class CohereAttention(nn.Module): self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.max_position_embeddings = getattr( - config, "model_max_length", None) or getattr( - config, "max_position_embeddings", 8192) + config, "model_max_length", None + ) or getattr(config, "max_position_embeddings", 8192) self.rope_theta = config.rope_theta self.rope_scaling = getattr(config, "rope_scaling", None) self.use_qk_norm = getattr(config, "use_qk_norm", False) @@ -183,12 +182,13 @@ class CohereAttention(nn.Module): layer_id=layer_id, ) if self.use_qk_norm: - self.q_norm = LayerNorm(param_shape=(self.num_heads, - self.head_dim), - eps=config.layer_norm_eps) - self.k_norm = LayerNorm(param_shape=(self.num_kv_heads, - self.head_dim), - eps=config.layer_norm_eps) + self.q_norm = LayerNorm( + param_shape=(self.num_heads, self.head_dim), eps=config.layer_norm_eps + ) + self.k_norm = LayerNorm( + param_shape=(self.num_kv_heads, self.head_dim), + eps=config.layer_norm_eps, + ) def _apply_qk_norm(self, q, k): q = q.view(*q.shape[:-1], -1, self.head_dim) @@ -216,19 +216,23 @@ class CohereAttention(nn.Module): class CohereDecoderLayer(nn.Module): - - def __init__(self, - config: CohereConfig, - layer_id: int = 0, - linear_method: Optional[LinearMethodBase] = None): + def __init__( + self, + config: PretrainedConfig, + layer_id: int = 0, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = CohereAttention(config, layer_id=layer_id, linear_method=linear_method) + self.self_attn = CohereAttention( + config, layer_id=layer_id, linear_method=linear_method + ) self.mlp = CohereMLP(config, linear_method=linear_method) - self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), - eps=config.layer_norm_eps) + self.input_layernorm = LayerNorm( + param_shape=(config.hidden_size), eps=config.layer_norm_eps + ) def forward( self, @@ -253,23 +257,26 @@ class CohereDecoderLayer(nn.Module): class CohereModel(nn.Module): - def __init__( self, - config: CohereConfig, + config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.config = config self.vocab_size = config.vocab_size - self.embed_tokens = VocabParallelEmbedding(config.vocab_size, - config.hidden_size) - self.layers = nn.ModuleList([ - CohereDecoderLayer(config, i, linear_method=linear_method) - for i in range(config.num_hidden_layers) - ]) - self.norm = LayerNorm(param_shape=(config.hidden_size), - eps=config.layer_norm_eps) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) + self.layers = nn.ModuleList( + [ + CohereDecoderLayer(config, i, linear_method=linear_method) + for i in range(config.num_hidden_layers) + ] + ) + self.norm = LayerNorm( + param_shape=(config.hidden_size), eps=config.layer_norm_eps + ) def forward( self, @@ -292,10 +299,9 @@ class CohereModel(nn.Module): class CohereForCausalLM(nn.Module): - def __init__( self, - config: CohereConfig, + config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() @@ -311,7 +317,11 @@ class CohereForCausalLM(nn.Module): positions: torch.Tensor, input_metadata: InputMetadata, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata,) + hidden_states = self.model( + input_ids, + positions, + input_metadata, + ) return self.logits_processor( input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata ) @@ -334,7 +344,8 @@ class CohereForCausalLM(nn.Module): params_dict = dict(self.named_parameters()) loaded_params = set() for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + model_name_or_path, cache_dir, load_format, revision + ): for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue @@ -355,8 +366,7 @@ class CohereForCausalLM(nn.Module): if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index 7242a0d37..4742982cf 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -171,7 +171,6 @@ class DbrxExperts(nn.Module): class DbrxAttention(nn.Module): - def __init__( self, config: DbrxConfig, @@ -251,7 +250,6 @@ class DbrxAttention(nn.Module): class DbrxFusedNormAttention(nn.Module): - def __init__( self, config: DbrxConfig, @@ -284,7 +282,6 @@ class DbrxFusedNormAttention(nn.Module): class DbrxBlock(nn.Module): - def __init__( self, config: DbrxConfig, @@ -312,7 +309,6 @@ class DbrxBlock(nn.Module): class DbrxModel(nn.Module): - def __init__( self, config: DbrxConfig, @@ -351,7 +347,6 @@ class DbrxModel(nn.Module): class DbrxForCausalLM(nn.Module): - def __init__( self, config: DbrxConfig, diff --git a/test/srt/model/bench_llama_low_api.py b/test/srt/model/bench_llama_low_api.py index 973907274..34c64cd6c 100644 --- a/test/srt/model/bench_llama_low_api.py +++ b/test/srt/model/bench_llama_low_api.py @@ -66,9 +66,9 @@ class BenchBatch: p_idx = prefix_req_idx[i // fork_num].item() n_idx = self.req_pool_indices[i].item() req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len] - req_to_token[n_idx, prefix_len : prefix_len + extend_len] = ( - self.out_cache_loc[i * extend_len : (i + 1) * extend_len] - ) + req_to_token[ + n_idx, prefix_len : prefix_len + extend_len + ] = self.out_cache_loc[i * extend_len : (i + 1) * extend_len] def update_decode(self, predict_ids, batch_size): assert predict_ids.shape[0] == batch_size @@ -81,9 +81,9 @@ class BenchBatch: self.out_cache_cont_start, self.out_cache_cont_end, ) = self.token_to_kv_pool.alloc_contiguous(batch_size) - self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = ( - self.out_cache_loc - ) + self.req_to_token_pool.req_to_token[ + self.req_pool_indices, self.seq_lens + ] = self.out_cache_loc self.seq_lens.add_(1)