[2/N][Refactor][Qwen3-Next] remove redundant methods and patch methods in Qwen3NextGatedDeltaNet (#3082)
### What this PR does / why we need it?
remove redundant methods and patch methods in Qwen3NextGatedDeltaNet
involved causal_conv1d_fn, causal_conv1d_update_npu, fused_gdn_gating,
fused_reccrrent_gated_delta_rule, torch_chunk_gated_delta_rule,
RMSNormGated
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
```
def main():
prompts = [
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=100, temperature=0.6, top_k=40, top_p=0.95)
# Create an LLM.
llm = LLM(
model="Qwen/Qwen3-Next-80B-A3B-Instruct",
tensor_parallel_size=4,
enforce_eager=True,
trust_remote_code=True,
max_model_len=256,
gpu_memory_utilization=0.7,
block_size=64,
)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
CI passed with new added/existing test.
- vLLM version: v0.10.2
- vLLM main:
5aeb925452
---------
Signed-off-by: Icey <1790571317@qq.com>
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -9,109 +9,8 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def rms_norm_ref(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
upcast=True,
|
||||
):
|
||||
dtype = x.dtype
|
||||
#N = x.shape[-1]
|
||||
weight = weight.float()
|
||||
bias = bias.float() if bias is not None else None
|
||||
if upcast:
|
||||
x = x.float()
|
||||
z = z.float() if z is not None else z
|
||||
if z is not None and not norm_before_gate:
|
||||
x = x * F.silu(z)
|
||||
if group_size is None:
|
||||
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
||||
out = (x * rstd * weight) + bias if bias is not None else (x * rstd *
|
||||
weight)
|
||||
else:
|
||||
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
|
||||
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) +
|
||||
eps)
|
||||
out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
if z is not None and norm_before_gate:
|
||||
out *= F.silu(z)
|
||||
return out.to(dtype)
|
||||
|
||||
|
||||
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
||||
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
|
||||
@triton.jit
|
||||
def _layer_norm_fwd_1pass_kernel(
|
||||
X, # pointer to the input
|
||||
Y, # pointer to the output
|
||||
W, # pointer to the weights
|
||||
B, # pointer to the biases
|
||||
Z, # pointer to the other branch
|
||||
Mean, # pointer to the mean
|
||||
Rstd, # pointer to the 1/std
|
||||
stride_x_row, # how much to increase the pointer when moving by 1 row
|
||||
stride_y_row,
|
||||
stride_z_row,
|
||||
M, # number of rows in X
|
||||
N, # number of columns in X
|
||||
eps, # epsilon to avoid division by zero
|
||||
BLOCK_N: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
NORM_BEFORE_GATE: tl.constexpr,
|
||||
IS_RMS_NORM: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
row = tl.program_id(0)
|
||||
group = tl.program_id(1)
|
||||
X += row * stride_x_row + group * N
|
||||
Y += row * stride_y_row + group * N
|
||||
if HAS_Z:
|
||||
Z += row * stride_z_row + group * N
|
||||
if not IS_RMS_NORM:
|
||||
Mean += group * M
|
||||
Rstd += group * M
|
||||
W += group * N
|
||||
if HAS_BIAS:
|
||||
B += group * N
|
||||
# Compute mean and variance
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
||||
if HAS_Z and not NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
|
||||
x *= z * tl.sigmoid(z)
|
||||
if not IS_RMS_NORM:
|
||||
mean = tl.sum(x, axis=0) / N
|
||||
tl.store(Mean + row, mean)
|
||||
xbar = tl.where(cols < N, x - mean, 0.0)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
else:
|
||||
xbar = tl.where(cols < N, x, 0.0)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
tl.store(Rstd + row, rstd)
|
||||
# Normalize and apply linear transformation
|
||||
mask = cols < N
|
||||
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
||||
if HAS_BIAS:
|
||||
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
||||
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
||||
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
||||
if HAS_Z and NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=mask).to(tl.float32)
|
||||
y *= z * tl.sigmoid(z)
|
||||
# Write output
|
||||
tl.store(Y + cols, y, mask=mask)
|
||||
from vllm.model_executor.layers.fla.ops.layernorm_guard import \
|
||||
layer_norm_fwd_kernel
|
||||
|
||||
|
||||
def _layer_norm_fwd(
|
||||
@@ -158,7 +57,7 @@ def _layer_norm_fwd(
|
||||
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
||||
grid = (M, ngroups)
|
||||
with torch.npu.device(x.device.index):
|
||||
_layer_norm_fwd_1pass_kernel[grid](
|
||||
layer_norm_fwd_kernel[grid](
|
||||
x,
|
||||
out,
|
||||
weight,
|
||||
@@ -222,160 +121,98 @@ class LayerNormFn(torch.autograd.Function):
|
||||
return y.reshape(x_shape_og)
|
||||
|
||||
|
||||
def layernorm_fn(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
|
||||
norm_before_gate, is_rms_norm)
|
||||
|
||||
|
||||
def rmsnorm_fn(x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True):
|
||||
return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
|
||||
norm_before_gate, True)
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
eps=1e-5,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
|
||||
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
|
||||
"""
|
||||
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = torch.nn.Parameter(
|
||||
torch.empty(hidden_size, **factory_kwargs))
|
||||
self.bias = torch.nn.Parameter(
|
||||
torch.empty(hidden_size, **factory_kwargs))
|
||||
self.group_size = group_size
|
||||
self.norm_before_gate = norm_before_gate
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
torch.nn.init.ones_(self.weight)
|
||||
torch.nn.init.zeros_(self.bias)
|
||||
|
||||
def forward(self, x, z=None):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
|
||||
return layernorm_fn(
|
||||
x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
z=z,
|
||||
group_size=self.group_size,
|
||||
eps=self.eps,
|
||||
norm_before_gate=self.norm_before_gate,
|
||||
)
|
||||
|
||||
|
||||
class RMSNormGated(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
eps=1e-5,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
|
||||
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
|
||||
"""
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = torch.nn.Parameter(
|
||||
torch.empty(hidden_size, **factory_kwargs))
|
||||
self.register_parameter("bias", None)
|
||||
self.group_size = group_size
|
||||
self.norm_before_gate = norm_before_gate
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
torch.nn.init.ones_(self.weight)
|
||||
|
||||
def forward(self, x, z=None):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
|
||||
return rmsnorm_fn(
|
||||
x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
z=z,
|
||||
eps=self.eps,
|
||||
group_size=self.group_size,
|
||||
norm_before_gate=self.norm_before_gate,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_gdn_gating_kernel(
|
||||
def torch_chunk_gated_delta_rule(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
g,
|
||||
A_log,
|
||||
a,
|
||||
dt_bias,
|
||||
seq_len,
|
||||
NUM_HEADS: tl.constexpr,
|
||||
beta: tl.constexpr,
|
||||
threshold: tl.constexpr,
|
||||
BLK_HEADS: tl.constexpr,
|
||||
beta,
|
||||
chunk_size=64,
|
||||
initial_state=None,
|
||||
output_final_state=False,
|
||||
use_qk_l2norm_in_kernel=False,
|
||||
):
|
||||
i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)
|
||||
off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off
|
||||
mask = head_off < NUM_HEADS
|
||||
blk_A_log = tl.load(A_log + head_off, mask=mask)
|
||||
blk_a = tl.load(a + off, mask=mask)
|
||||
blk_bias = tl.load(dt_bias + head_off, mask=mask)
|
||||
# If the model is loaded in fp16, without the .float() here, A might be -inf
|
||||
x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
|
||||
softplus_x = tl.where(beta * x <= threshold,
|
||||
(1 / beta) * tl.log(1 + tl.exp(beta * x)), x)
|
||||
blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
|
||||
tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
|
||||
initial_dtype = query.dtype
|
||||
if use_qk_l2norm_in_kernel:
|
||||
query = F.normalize(query, p=2, dim=-1)
|
||||
key = F.normalize(key, p=2, dim=-1)
|
||||
query, key, value, beta, g = [
|
||||
x.transpose(1, 2).contiguous().to(torch.float32)
|
||||
for x in (query, key, value, beta, g)
|
||||
]
|
||||
|
||||
batch_size, sequence_length, num_heads, k_head_dim = key.shape
|
||||
v_head_dim = value.shape[-1]
|
||||
pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
|
||||
query = F.pad(query, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1)
|
||||
key = F.pad(key, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1)
|
||||
value = F.pad(value, (0, 0, 0, pad_size))
|
||||
beta = F.pad(beta, (0, pad_size))
|
||||
g = F.pad(g, (0, pad_size))
|
||||
tot_heads = num_heads + pad_size
|
||||
scale = 1 / (query.shape[-1]**0.5)
|
||||
query = query * scale
|
||||
|
||||
def fused_gdn_gating(
|
||||
A_log: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
dt_bias: torch.Tensor,
|
||||
beta: float = 1.0,
|
||||
threshold: float = 20.0,
|
||||
) -> torch.Tensor:
|
||||
batch, num_heads = a.shape
|
||||
seq_len = 1
|
||||
grid = (batch, seq_len, triton.cdiv(num_heads, 8))
|
||||
g = torch.empty_like(a, dtype=torch.float32)
|
||||
fused_gdn_gating_kernel[grid](g,
|
||||
A_log,
|
||||
a,
|
||||
dt_bias,
|
||||
seq_len,
|
||||
num_heads,
|
||||
beta,
|
||||
threshold,
|
||||
8,
|
||||
num_warps=1)
|
||||
return g
|
||||
v_beta = value * beta.unsqueeze(-1)
|
||||
k_beta = key * beta.unsqueeze(-1)
|
||||
# reshape to chunks
|
||||
query, key, value, k_beta, v_beta = [
|
||||
x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
|
||||
for x in (query, key, value, k_beta, v_beta)
|
||||
]
|
||||
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
|
||||
mask = torch.triu(torch.ones(chunk_size,
|
||||
chunk_size,
|
||||
dtype=torch.bool,
|
||||
device=query.device),
|
||||
diagonal=0)
|
||||
|
||||
# chunk decay
|
||||
g = g.cumsum(dim=-1)
|
||||
decay_mask = ((g.unsqueeze(-1) -
|
||||
g.unsqueeze(-2)).tril().exp().float()).tril()
|
||||
attn = -(
|
||||
(k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
|
||||
for i in range(1, chunk_size):
|
||||
row = attn[..., i, :i].clone()
|
||||
sub = attn[..., :i, :i].clone()
|
||||
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
|
||||
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
|
||||
value = attn @ v_beta
|
||||
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
|
||||
|
||||
last_recurrent_state = (torch.zeros(batch_size, sequence_length,
|
||||
k_head_dim, v_head_dim).to(value) if
|
||||
initial_state is None else initial_state.to(value))
|
||||
|
||||
core_attn_out = torch.zeros_like(value)
|
||||
mask = torch.triu(torch.ones(chunk_size,
|
||||
chunk_size,
|
||||
dtype=torch.bool,
|
||||
device=query.device),
|
||||
diagonal=1)
|
||||
|
||||
# for each chunk
|
||||
for i in range(0, tot_heads // chunk_size):
|
||||
q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
|
||||
attn = (q_i @ k_i.transpose(-1, -2) *
|
||||
decay_mask[:, :, i]).masked_fill_(mask, 0)
|
||||
v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
|
||||
v_new = v_i - v_prime
|
||||
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
|
||||
core_attn_out[:, :, i] = attn_inter + attn @ v_new
|
||||
last_recurrent_state = (
|
||||
last_recurrent_state * g[:, :, i, -1, None, None].exp() +
|
||||
(k_i *
|
||||
(g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(
|
||||
-1, -2) @ v_new)
|
||||
|
||||
if not output_final_state:
|
||||
last_recurrent_state = None
|
||||
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0],
|
||||
core_attn_out.shape[1], -1,
|
||||
core_attn_out.shape[-1])
|
||||
core_attn_out = core_attn_out[:, :, :num_heads]
|
||||
core_attn_out = core_attn_out.transpose(1,
|
||||
2).contiguous().to(initial_dtype)
|
||||
return core_attn_out, last_recurrent_state
|
||||
|
||||
@@ -97,16 +97,6 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
o_k = i_k * BK + tl.arange(0, BK)
|
||||
o_v = i_v * BV + tl.arange(0, BV)
|
||||
|
||||
# p_q = q + (bos * H + i_h) * K + o_k
|
||||
# p_k = k + (bos * H + i_h) * K + o_k
|
||||
# p_v = v + (bos * HV + i_hv) * V + o_v
|
||||
# if IS_BETA_HEADWISE:
|
||||
# p_beta = beta + (bos * HV + i_hv) * V + o_v
|
||||
# else:
|
||||
# p_beta = beta + bos * HV + i_hv
|
||||
# p_g = g + bos * HV + i_hv
|
||||
# p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
|
||||
|
||||
mask_k = o_k < K
|
||||
mask_v = o_v < V
|
||||
mask_h = mask_k[:, None] & mask_v[None, :]
|
||||
@@ -170,13 +160,6 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
||||
|
||||
# p_q += H * K
|
||||
# p_k += H * K
|
||||
# p_o += HV * V
|
||||
# p_v += HV * V
|
||||
# p_g += HV
|
||||
# p_beta += HV * (V if IS_BETA_HEADWISE else 1)
|
||||
|
||||
|
||||
def fused_recurrent_gated_delta_rule_fwd(
|
||||
q: torch.Tensor,
|
||||
@@ -342,13 +325,11 @@ def fused_recurrent_gated_delta_rule(
|
||||
Indices to map the input sequences to the initial/final states.
|
||||
num_accepted_tokens (Optional[torch.Tensor]):
|
||||
Number of accepted tokens for each sequence during decoding.
|
||||
|
||||
Returns:
|
||||
o (torch.Tensor):
|
||||
Outputs of shape `[B, T, HV, V]`.
|
||||
final_state (torch.Tensor):
|
||||
Final state of shape `[N, HV, K, V]`.
|
||||
|
||||
Examples::
|
||||
>>> import torch
|
||||
>>> import torch.nn.functional as F
|
||||
@@ -400,4 +381,4 @@ def fused_recurrent_gated_delta_rule(
|
||||
num_accepted_tokens,
|
||||
use_qk_l2norm_in_kernel,
|
||||
)
|
||||
return o, final_state
|
||||
return o, final_state
|
||||
Reference in New Issue
Block a user