diff --git a/python/sglang/srt/models/mimo_mtp.py b/python/sglang/srt/models/mimo_mtp.py
index 6c81d8d85..89e8c02cd 100644
--- a/python/sglang/srt/models/mimo_mtp.py
+++ b/python/sglang/srt/models/mimo_mtp.py
@@ -7,33 +7,17 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig
 
-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-    split_tensor_along_last_dim,
-    tensor_model_parallel_all_gather,
-)
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.models.mimo import MiMoForCausalLM
-from sglang.srt.models.qwen2 import (
-    Qwen2Attention,
-    Qwen2DecoderLayer,
-    Qwen2MLP,
-    Qwen2Model,
-)
-from sglang.srt.utils import add_prefix
+from sglang.srt.models.qwen2 import Qwen2DecoderLayer
 
 
 class MiMoMultiTokenPredictorLayer(nn.Module):