From 441cca773d28b2147d9fd14c6e699f29fe9754e7 Mon Sep 17 00:00:00 2001
From: Chen Xuechen Li
Date: Wed, 3 Jul 2024 12:23:30 -0700
Subject: [PATCH] support gptj style rope in llama

---
 python/sglang/srt/models/llama2.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py
index e60b036bd..7bdab0f5d 100644
--- a/python/sglang/srt/models/llama2.py
+++ b/python/sglang/srt/models/llama2.py
@@ -1,6 +1,7 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""
+
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -76,6 +77,7 @@ class LlamaAttention(nn.Module):
         layer_id: int = 0,
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
+        rope_is_neox_style: bool = True,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -123,6 +125,7 @@ class LlamaAttention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            is_neox_style=rope_is_neox_style,
         )
         self.attn = RadixAttention(
             self.num_heads,
@@ -160,9 +163,10 @@ class LlamaDecoderLayer(nn.Module):
         if rope_scaling is not None and getattr(
             config, "original_max_position_embeddings", None
         ):
-            rope_scaling[
-                "original_max_position_embeddings"
-            ] = config.original_max_position_embeddings
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
+        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
@@ -171,6 +175,7 @@ class LlamaDecoderLayer(nn.Module):
             layer_id=layer_id,
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
+            rope_is_neox_style=rope_is_neox_style,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
         )
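
Note for reviewers: the is_neox_style flag forwarded here selects between the two common rotary-embedding layouts. NeoX style rotates the first and second halves of the rotary dimension together, while GPT-J style rotates interleaved (even, odd) pairs. The sketch below is only a minimal standalone illustration of that difference, not the rotary implementation this patch calls into; the helpers rotate_neox, rotate_gptj, and apply_rope are hypothetical names used for illustration.

    import torch

    def rotate_neox(x: torch.Tensor) -> torch.Tensor:
        # NeoX style: pair element i of the first half with element i of the second half.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
        # GPT-J style: pair adjacent (even, odd) elements of the rotary dimension.
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        return torch.stack((-x2, x1), dim=-1).flatten(-2)

    def apply_rope(x, cos, sin, is_neox_style=True):
        # x: [..., rotary_dim]. cos/sin must be laid out to match the style:
        # duplicated across the two halves for NeoX, repeated per adjacent pair for GPT-J.
        rotate = rotate_neox if is_neox_style else rotate_gptj
        return x * cos + rotate(x) * sin

With this patch, a checkpoint whose HF config sets rope_is_neox_style = False gets the GPT-J style rotation from the rotary helper instead of the LLaMA/NeoX default.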