Clean up some Qwen3-Next and deterministic code (#11585)
This commit is contained in:
@@ -70,7 +70,7 @@ class Mamba2StateShape:
|
|||||||
|
|
||||||
# These are not TP-ed as they depend on A, dt_bias, D
|
# These are not TP-ed as they depend on A, dt_bias, D
|
||||||
# - they are typically small
|
# - they are typically small
|
||||||
# e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
|
# e.g., QWen3-Next: (32, 128, 128)
|
||||||
temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
|
temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
|
||||||
return Mamba2StateShape(
|
return Mamba2StateShape(
|
||||||
conv=conv_state_shape,
|
conv=conv_state_shape,
|
||||||
|
|||||||
@@ -27,12 +27,9 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# NOTE: HybridLayerType
|
|
||||||
class HybridLayerType(enum.Enum):
|
class HybridLayerType(enum.Enum):
|
||||||
full_attention = "attention"
|
full_attention = "attention"
|
||||||
swa_attention = "swa_attention"
|
|
||||||
linear_attention = "linear_attention"
|
linear_attention = "linear_attention"
|
||||||
mamba2 = "mamba"
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen3NextConfig(PretrainedConfig):
|
class Qwen3NextConfig(PretrainedConfig):
|
||||||
|
|||||||
@@ -450,13 +450,6 @@ class FalconH1Model(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class HybridLayerType(enum.Enum):
|
|
||||||
full_attention = "attention"
|
|
||||||
swa_attention = "swa_attention"
|
|
||||||
linear_attention = "linear_attention"
|
|
||||||
mamba2 = "mamba"
|
|
||||||
|
|
||||||
|
|
||||||
class FalconH1ForCausalLM(nn.Module):
|
class FalconH1ForCausalLM(nn.Module):
|
||||||
fall_back_to_pt_during_load = False
|
fall_back_to_pt_during_load = False
|
||||||
|
|
||||||
|
|||||||
@@ -226,10 +226,6 @@ def send_prefix(args, batch_size: int, prompts: List[str]):
|
|||||||
|
|
||||||
|
|
||||||
def test_deterministic(args):
|
def test_deterministic(args):
|
||||||
# First do some warmups
|
|
||||||
for i in range(3):
|
|
||||||
send_single(args, 16, args.profile)
|
|
||||||
|
|
||||||
if args.test_mode == "single":
|
if args.test_mode == "single":
|
||||||
# In single mode, we test the deterministic behavior by sending the same prompt in batch sizes ranging from 1 to n_trials.
|
# In single mode, we test the deterministic behavior by sending the same prompt in batch sizes ranging from 1 to n_trials.
|
||||||
texts = []
|
texts = []
|
||||||
|
|||||||
Reference in New Issue
Block a user