Qwen3-Next support (#10233)

Co-authored-by: cao1zhg <114661107+cao1zhg@users.noreply.github.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
Co-authored-by: Binyao Jiang <byjiang1996@gmail.com>
Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
Co-authored-by: Lifu Huang <lifu.hlf@gmail.com>
Co-authored-by: qingquansong <ustcsqq@gmail.com>
Co-authored-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
Co-authored-by: Ke Bao <ISPObaoke@163.com>
Co-authored-by: Minglei Zhu <mingleizhu1122@gmail.com>
This commit is contained in:
Yi Zhang
2025-09-11 19:11:49 +08:00
committed by GitHub
parent bfe01a5eef
commit 30c6e1f569
19 changed files with 3224 additions and 8 deletions

View File

@@ -38,7 +38,7 @@ import threading
from enum import Enum, auto
from http import HTTPStatus
from itertools import chain
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
import numpy as np
import torch
@@ -59,7 +59,7 @@ from sglang.srt.mem_cache.allocator import (
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
@@ -962,8 +962,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def is_empty(self):
return len(self.reqs) == 0
def alloc_req_slots(self, num_reqs: int):
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
else:
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
if req_pool_indices is None:
raise RuntimeError(
"alloc_req_slots runs out of memory. "
@@ -1138,7 +1141,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Allocate req slots
bs = len(self.reqs)
req_pool_indices = self.alloc_req_slots(bs)
req_pool_indices = self.alloc_req_slots(bs, self.reqs)
# Init tensors
reqs = self.reqs