### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/triton/activation/swiglu_quant.py` |
| `vllm_ascend/ops/triton/batch_invariant/matmul.py` |
| `vllm_ascend/ops/triton/batch_invariant/mean.py` |
| `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` |
| `vllm_ascend/ops/triton/fla/chunk.py` |
| `vllm_ascend/ops/triton/fla/chunk_delta_h.py` |
| `vllm_ascend/ops/triton/fla/chunk_o.py` |
| `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` |
| `vllm_ascend/ops/triton/fla/cumsum.py` |
| `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` |
| `vllm_ascend/ops/triton/fla/l2norm.py` |
| `vllm_ascend/ops/triton/fla/layernorm_guard.py` |
| `vllm_ascend/ops/triton/fla/sigmoid_gating.py` |
| `vllm_ascend/ops/triton/fla/solve_tril.py` |
| `vllm_ascend/ops/triton/fla/utils.py` |
| `vllm_ascend/ops/triton/fla/wy_fast.py` |
| `vllm_ascend/ops/triton/fused_gdn_gating.py` |
| `vllm_ascend/ops/triton/layernorm_gated.py` |
| `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` |
| `vllm_ascend/ops/triton/mamba/causal_conv1d.py` |
| `vllm_ascend/ops/triton/reject_sample.py` |
| `vllm_ascend/ops/triton/rope.py` |
| `vllm_ascend/ops/triton/spec_decode/utils.py` |
| `vllm_ascend/ops/triton/triton_utils.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.0
- vLLM main:
d68209402d
Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
@@ -30,15 +30,12 @@ def fused_gdn_gating_kernel(
|
||||
):
|
||||
i_b, i_s = tl.program_id(0), tl.program_id(1)
|
||||
for row_idx in range(0, ROW_ITER):
|
||||
batch_off = i_b * ROW_ITER * BLK_BATCHES + row_idx * BLK_BATCHES + tl.arange(
|
||||
0, BLK_BATCHES)
|
||||
batch_off = i_b * ROW_ITER * BLK_BATCHES + row_idx * BLK_BATCHES + tl.arange(0, BLK_BATCHES)
|
||||
|
||||
for col_idx in range(0, COL_ITER):
|
||||
head_off = col_idx * BLK_HEADS + tl.arange(0, BLK_HEADS)
|
||||
|
||||
off = batch_off[:,
|
||||
None] * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off[
|
||||
None, :]
|
||||
off = batch_off[:, None] * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off[None, :]
|
||||
head_mask = head_off < NUM_HEADS
|
||||
mask = head_mask[None, :] & (batch_off[:, None] < NUM_BATCHES)
|
||||
|
||||
@@ -48,17 +45,14 @@ def fused_gdn_gating_kernel(
|
||||
blk_bias = tl.load(dt_bias + head_off, mask=head_mask)
|
||||
|
||||
x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)[None, :]
|
||||
softplus_x = tl.where(beta * x <= threshold,
|
||||
(1 / beta) * tl.log(1 + tl.exp(beta * x)), x)
|
||||
softplus_x = tl.where(beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x)
|
||||
|
||||
blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
|
||||
tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
|
||||
|
||||
# compute beta_output = sigmoid(b)
|
||||
blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
|
||||
tl.store(beta_output + off,
|
||||
blk_beta_output.to(beta_output.dtype.element_ty),
|
||||
mask=mask)
|
||||
tl.store(beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask)
|
||||
|
||||
|
||||
def fused_gdn_gating_patch(
|
||||
@@ -85,17 +79,13 @@ def fused_gdn_gating_patch(
|
||||
progs = num_cores
|
||||
FACTOR = 8 * num_heads
|
||||
row_per_core = triton.cdiv(batch, num_cores)
|
||||
BLK_BATCHES = triton.next_power_of_2(
|
||||
triton.cdiv(UNIFIED_BUFFER_SIZE, FACTOR * BLK_HEADS) //
|
||||
a.element_size()) // 2
|
||||
BLK_BATCHES = (
|
||||
triton.next_power_of_2(triton.cdiv(UNIFIED_BUFFER_SIZE, FACTOR * BLK_HEADS) // a.element_size()) // 2
|
||||
)
|
||||
ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)
|
||||
|
||||
g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
|
||||
beta_output = torch.empty(1,
|
||||
batch,
|
||||
num_heads,
|
||||
dtype=b.dtype,
|
||||
device=b.device)
|
||||
beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device)
|
||||
|
||||
grid = (progs, seq_len)
|
||||
fused_gdn_gating_kernel[grid](
|
||||
|
||||
Reference in New Issue
Block a user