Improve gemma and documentations (#278)
This commit is contained in:
@@ -7,7 +7,7 @@ import torch
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from torch import nn
|
||||
from transformers import GemmaConfig
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import GeluAndMul
|
||||
@@ -136,7 +136,7 @@ class GemmaAttention(nn.Module):
|
||||
class GemmaDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GemmaConfig,
|
||||
config: PretrainedConfig,
|
||||
layer_id: int = 0,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
@@ -190,7 +190,7 @@ class GemmaDecoderLayer(nn.Module):
|
||||
class GemmaModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GemmaConfig,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -213,12 +213,12 @@ class GemmaModel(nn.Module):
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
skip_embed: bool = False,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if not skip_embed:
|
||||
if input_embeds is None:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
else:
|
||||
hidden_states = input_ids
|
||||
hidden_states = input_embeds
|
||||
|
||||
# Normalize the embedding by sqrt(hidden_size)
|
||||
hidden_states *= self.config.hidden_size**0.5
|
||||
@@ -262,7 +262,7 @@ class GemmaForCausalLM(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GemmaConfig,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
) -> None:
|
||||
@@ -279,9 +279,9 @@ class GemmaForCausalLM(nn.Module):
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
skip_embed: bool = False,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
|
||||
)
|
||||
|
||||
@@ -233,9 +233,7 @@ class LlavaLlamaForCausalLM(nn.Module):
|
||||
input_ids, positions, input_metadata, input_embeds=input_embeds
|
||||
)
|
||||
elif input_metadata.forward_mode == ForwardMode.DECODE:
|
||||
return self.language_model(
|
||||
input_ids, positions, input_metadata
|
||||
)
|
||||
return self.language_model(input_ids, positions, input_metadata)
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user