Support GPT-J style RoPE (rotary position embeddings) in LLaMA
This commit is contained in:
committed by
Ying Sheng
parent
c7709d3abe
commit
441cca773d
@@ -1,6 +1,7 @@
|
|||||||
# Adapted from
|
# Adapted from
|
||||||
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
|
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
|
||||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||||
|
|
||||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -76,6 +77,7 @@ class LlamaAttention(nn.Module):
|
|||||||
layer_id: int = 0,
|
layer_id: int = 0,
|
||||||
rope_theta: float = 10000,
|
rope_theta: float = 10000,
|
||||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
|
rope_is_neox_style: bool = True,
|
||||||
max_position_embeddings: int = 8192,
|
max_position_embeddings: int = 8192,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -123,6 +125,7 @@ class LlamaAttention(nn.Module):
|
|||||||
max_position=max_position_embeddings,
|
max_position=max_position_embeddings,
|
||||||
base=rope_theta,
|
base=rope_theta,
|
||||||
rope_scaling=rope_scaling,
|
rope_scaling=rope_scaling,
|
||||||
|
is_neox_style=rope_is_neox_style,
|
||||||
)
|
)
|
||||||
self.attn = RadixAttention(
|
self.attn = RadixAttention(
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
@@ -160,9 +163,10 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
if rope_scaling is not None and getattr(
|
if rope_scaling is not None and getattr(
|
||||||
config, "original_max_position_embeddings", None
|
config, "original_max_position_embeddings", None
|
||||||
):
|
):
|
||||||
rope_scaling[
|
rope_scaling["original_max_position_embeddings"] = (
|
||||||
"original_max_position_embeddings"
|
config.original_max_position_embeddings
|
||||||
] = config.original_max_position_embeddings
|
)
|
||||||
|
rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
|
||||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||||
self.self_attn = LlamaAttention(
|
self.self_attn = LlamaAttention(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
@@ -171,6 +175,7 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
rope_scaling=rope_scaling,
|
rope_scaling=rope_scaling,
|
||||||
|
rope_is_neox_style=rope_is_neox_style,
|
||||||
max_position_embeddings=max_position_embeddings,
|
max_position_embeddings=max_position_embeddings,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user