diff --git a/python/sglang/srt/models/interns1.py b/python/sglang/srt/models/interns1.py index 267170301..c7383ed25 100644 --- a/python/sglang/srt/models/interns1.py +++ b/python/sglang/srt/models/interns1.py @@ -21,6 +21,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.internvl import InternVisionModel from sglang.srt.models.qwen2 import Qwen2ForCausalLM +from sglang.srt.models.qwen3 import Qwen3ForCausalLM from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM from sglang.utils import logger @@ -70,6 +71,10 @@ class InternS1ForConditionalGeneration(nn.Module): self.language_model = Qwen3MoeForCausalLM( config=config.text_config, quant_config=quant_config ) + elif config.text_config.architectures[0] == "Qwen3ForCausalLM": + self.language_model = Qwen3ForCausalLM( + config=config.text_config, quant_config=quant_config + ) else: raise NotImplementedError( f"{config.text_config.architectures[0]} is not implemented." diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index 04120e77b..a73d8764a 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -327,8 +327,8 @@ class Qwen3ForCausalLM(nn.Module): # For EAGLE3 support self.capture_aux_hidden_states = False - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def get_input_embeddings(self) -> nn.Embedding: + return self.model.get_input_embeddings() @torch.no_grad() def forward(