Add support for Intern-S1-mini (#9299)
This commit is contained in:
@@ -21,6 +21,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.models.internvl import InternVisionModel
|
from sglang.srt.models.internvl import InternVisionModel
|
||||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||||
|
from sglang.srt.models.qwen3 import Qwen3ForCausalLM
|
||||||
from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
|
from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
|
||||||
from sglang.utils import logger
|
from sglang.utils import logger
|
||||||
|
|
||||||
@@ -70,6 +71,10 @@ class InternS1ForConditionalGeneration(nn.Module):
|
|||||||
self.language_model = Qwen3MoeForCausalLM(
|
self.language_model = Qwen3MoeForCausalLM(
|
||||||
config=config.text_config, quant_config=quant_config
|
config=config.text_config, quant_config=quant_config
|
||||||
)
|
)
|
||||||
|
elif config.text_config.architectures[0] == "Qwen3ForCausalLM":
|
||||||
|
self.language_model = Qwen3ForCausalLM(
|
||||||
|
config=config.text_config, quant_config=quant_config
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"{config.text_config.architectures[0]} is not implemented."
|
f"{config.text_config.architectures[0]} is not implemented."
|
||||||
|
|||||||
@@ -327,8 +327,8 @@ class Qwen3ForCausalLM(nn.Module):
|
|||||||
# For EAGLE3 support
|
# For EAGLE3 support
|
||||||
self.capture_aux_hidden_states = False
|
self.capture_aux_hidden_states = False
|
||||||
|
|
||||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
def get_input_embeddings(self) -> nn.Embedding:
|
||||||
return self.model.get_input_embeddings(input_ids)
|
return self.model.get_input_embeddings()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
Reference in New Issue
Block a user