# enginex-vastai-va16-vllm/torch_vacc/vacc/custom_ops_cpu.py
from typing import List, Optional

import torch
from torch.nn import functional as F

def split_last_two_dims_into_blocks(x, h, w):
    """Split the last two dims (H, W) into non-overlapping (h, w) blocks.

    Returns a tensor of shape (*leading, H // h, W // w, h, w).
    """
    leading_dims = x.shape[:-2]
    H, W = x.shape[-2:]
    assert (
        H % h == 0 and W % w == 0
    ), "The last two dimensions must be divisible by the block size."
    x_reshaped = x.view(-1, 1, H, W)
    # Extract non-overlapping (h, w) patches; unfold returns (N, h*w, L)
    # with L = (H // h) * (W // w) patches in row-major order.
    unfolded = F.unfold(x_reshaped, kernel_size=(h, w), stride=(h, w))
    unfolded = unfolded.view(-1, 1, h, w, H // h, W // w)
    unfolded = unfolded.permute(0, 1, 4, 5, 2, 3)
    final_shape = leading_dims + (H // h, W // w, h, w)
    # reshape (not view): the permuted tensor is no longer contiguous.
    result = unfolded.reshape(final_shape)
    return result

def merge_blocks_to_original_layout(x, h, w):
    """Inverse of split_last_two_dims_into_blocks.

    Takes (*leading, H // h, W // w, h, w) and returns (*leading, H, W).
    """
    leading_dims = x.shape[:-4]
    H_div_h, W_div_w, bh, bw = x.shape[-4:]
    assert (bh, bw) == (h, w), "Trailing dims must match the block size."
    H = H_div_h * h
    W = W_div_w * w
    x_reshaped = x.reshape(-1, 1, H_div_h, W_div_w, h, w)
    x_reshaped = x_reshaped.permute(0, 1, 4, 5, 2, 3)
    # reshape (not view): the permuted tensor is no longer contiguous.
    x_reshaped = x_reshaped.reshape(-1, h * w, H_div_h * W_div_w)
    # fold scatters the non-overlapping patches back onto the (H, W) canvas.
    folded = F.fold(x_reshaped, output_size=(H, W), kernel_size=(h, w), stride=(h, w))
    final_shape = leading_dims + (H, W)
    result = folded.view(final_shape)
    return result

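# A minimal round-trip sketch for the two helpers above (hedged; the shapes
# here are assumptions for illustration, not part of the original module):
#
#   x = torch.arange(96, dtype=torch.float32).reshape(2, 8, 6)
#   blocks = split_last_two_dims_into_blocks(x, 4, 3)   # (2, 2, 2, 4, 3)
#   assert torch.equal(merge_blocks_to_original_layout(blocks, 4, 3), x)
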
def w8a8_block_fp8_matmul(
    input: torch.Tensor,
    weight: torch.Tensor,
    input_scale: Optional[torch.Tensor],
    weight_scale: Optional[torch.Tensor],
    block_size: List[int],
    is_linear_weight: bool = False,
    output_opt: Optional[torch.Tensor] = None,
    **kwargs
):
    """Dequantize a block-quantized weight with per-block scales, then matmul.

    `weight_scale` holds one scale per (b0, b1) block; `weight` is zero-padded
    up to a whole number of blocks before dequantization. `input_scale` is
    unused (activations are not quantized on this CPU path).
    """
    b0, b1 = block_size
    dim0, dim1 = weight.shape
    # Pad each dim up to a multiple of its block size.
    dim0pad, dim1pad = 0, 0
    if dim0 % b0 != 0:
        dim0pad = b0 - dim0 % b0
    if dim1 % b1 != 0:
        dim1pad = b1 - dim1 % b1
    dim0_origin, dim1_origin = dim0, dim1
    dim0 += dim0pad
    dim1 += dim1pad
    bs0, bs1 = dim0 // b0, dim1 // b1
    weight_dequant = F.pad(weight, (0, dim1pad, 0, dim0pad), value=0)
    # Regroup into (bs0, bs1, b0 * b1) blocks and apply one scale per block.
    weight_dequant = weight_dequant.cpu().view(bs0, b0, bs1, b1).permute(
        0, 2, 1, 3
    ).reshape(bs0, bs1, -1).float().to(input.device) * weight_scale.unsqueeze(-1)
    # Undo the blocking and drop the padding.
    weight_dequant = (
        weight_dequant.reshape(bs0, bs1, b0, b1)
        .permute(0, 2, 1, 3)
        .reshape(dim0, dim1)
        .to(input.dtype)
    )
    weight_dequant = weight_dequant[:dim0_origin, :dim1_origin]
    # F.linear-style weights are stored (out_features, in_features).
    output = torch.matmul(
        input, weight_dequant.T if is_linear_weight else weight_dequant
    )
    if output_opt is not None:
        output = output_opt.copy_(output)
    return output

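# Usage sketch (hedged; the sizes are assumptions for illustration): a
# (256, 512) weight in 128x128 blocks takes a (2, 4) grid of scales.
#
#   w_q = torch.randn(256, 512)       # stands in for fp8-stored values
#   scales = torch.rand(2, 4)         # one scale per 128x128 block
#   x = torch.randn(8, 256)
#   y = w8a8_block_fp8_matmul(x, w_q, None, scales, [128, 128])  # (8, 512)
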
def w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    input_scale: Optional[torch.Tensor],
    weight_scale: Optional[torch.Tensor],
    block_size: List[int],
    **kwargs
):
    assert input_scale is None, "w8a8_block_fp8_linear only supports weight quantization for now"
    return w8a8_block_fp8_matmul(
        input, weight, None, weight_scale, block_size, is_linear_weight=True
    )

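# Usage sketch (hedged): with an F.linear-style (out_features, in_features)
# weight of shape (512, 256) in 128x128 blocks, the scale grid is (4, 2):
#
#   w_q = torch.randn(512, 256)
#   scales = torch.rand(4, 2)
#   y = w8a8_block_fp8_linear(torch.randn(8, 256), w_q, None, scales, [128, 128])
#   # y has shape (8, 512)
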
def fused_experts(
    hidden_states: torch.Tensor,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    use_fp8_w8a8: bool = True,
    w13_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a13_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
    decode_with_batch: bool = False,
) -> torch.Tensor:
    """CPU reference implementation of the fused-MoE expert computation.

    Each token is routed to its top-k experts; every expert applies a gated
    (SiLU) MLP built from block-dequantized w13 / w2 weights, and the
    per-expert outputs are combined with the routing weights.
    """
    batch_seq_all, hidden_dims = hidden_states.shape
    intermediate_size = w2_weight.shape[-1]
    num_experts = w13_weight.shape[0]
    w13_weight = w13_weight.contiguous()
    w2_weight = w2_weight.contiguous()
    w13_scale = w13_scale.contiguous()
    w2_scale = w2_scale.contiguous()
    final_hidden_states = torch.zeros_like(hidden_states)
    # w13 is the fused gate/up ("w1/w3") projection.
    w1_scale = w13_scale
    _, bs0_w13, bs1_w13 = w1_scale.shape
    _, bs0_w2, bs1_w2 = w2_scale.shape
    sel_experts = topk_ids.shape[1]

    def dequant_expert_weight(weight_q, scale):
        """Block-dequantize one expert weight to the activation dtype.

        `scale` has shape (bs0, bs1, 1): one scale per (b0, b1) block.
        """
        bs0, bs1 = scale.shape[:2]
        dim0, dim1 = weight_q.shape
        b0, b1 = dim0 // bs0, dim1 // bs1
        w = (
            weight_q.view(bs0, b0, bs1, b1)
            .permute(0, 2, 1, 3)
            .reshape(bs0, bs1, -1)
            .float()
            .to(hidden_states.device)
            * scale
        )
        return (
            w.reshape(bs0, bs1, b0, b1)
            .permute(0, 2, 1, 3)
            .reshape(dim0, dim1)
            .to(hidden_states.dtype)
        )

    if hidden_states.shape[0] == 1:
        # Single-token decode: loop over this token's selected experts.
        for k in range(sel_experts):
            expert_idx = topk_ids[0][k]
            expert_w1 = dequant_expert_weight(
                w13_weight[expert_idx].contiguous(),
                w1_scale[expert_idx].unsqueeze(2).contiguous(),
            )
            expert_w2 = dequant_expert_weight(
                w2_weight[expert_idx].contiguous(),
                w2_scale[expert_idx].unsqueeze(2).contiguous(),
            )
            expert_weights = topk_weights[0][k].to(hidden_states.dtype)
            x = hidden_states
            x = F.linear(x, expert_w1)
            # Gated MLP: first half is the (SiLU) gate, second half the up projection.
            gate = F.silu(x[:, :intermediate_size])
            x = x[:, intermediate_size:] * gate
            x = F.linear(x, expert_w2)
            current_hidden_states = x * expert_weights
            current_hidden_states = current_hidden_states.to(x.dtype)
            final_hidden_states += current_hidden_states
    else:
        # Batched path: loop over experts and gather the tokens routed to each.
        for expert_idx in range(num_experts):
            # topk_ids / expert_mask: [tokens, top_k], e.g. [10, 8]
            expert_mask = topk_ids == expert_idx
            idx = torch.where(expert_mask)[0]
            if idx.numel() == 0:
                continue
            expert_w1 = dequant_expert_weight(
                w13_weight[expert_idx].contiguous(),
                w1_scale[expert_idx].unsqueeze(2).contiguous(),
            )
            expert_w2 = dequant_expert_weight(
                w2_weight[expert_idx].contiguous(),
                w2_scale[expert_idx].unsqueeze(2).contiguous(),
            )
            # Routing weight of each selected token for this expert: [tokens_e, 1]
            expert_weights = (
                topk_weights.masked_select(expert_mask)
                .unsqueeze(1)
                .to(hidden_states.dtype)
            )
            x = hidden_states[idx]
            x = F.linear(x, expert_w1)
            # Gated MLP: first half is the (SiLU) gate, second half the up projection.
            gate = F.silu(x[:, :intermediate_size])
            x = x[:, intermediate_size:] * gate
            x = F.linear(x, expert_w2)
            current_hidden_states = x * expert_weights
            current_hidden_states = current_hidden_states.to(x.dtype)
            # index_add_ accumulates safely under duplicate indices,
            # unlike an advanced-index += on final_hidden_states.
            final_hidden_states.index_add_(0, idx, current_hidden_states)
    final_hidden_states = final_hidden_states.reshape(batch_seq_all, hidden_dims)
    return final_hidden_states

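# Shape sketch (hedged; all sizes below are assumptions for illustration):
#
#   T, H, I, E, K = 4, 256, 128, 8, 2
#   hs = torch.randn(T, H)
#   w13 = torch.randn(E, 2 * I, H)                 # fused gate/up weights
#   w2 = torch.randn(E, H, I)
#   w13_s = torch.rand(E, 2 * I // 128, H // 128)  # per-block scales
#   w2_s = torch.rand(E, H // 128, I // 128)
#   weights = torch.rand(T, K)
#   ids = torch.randint(0, E, (T, K))
#   out = fused_experts(hs, w13, w2, weights, ids,
#                       w13_scale=w13_s, w2_scale=w2_s)  # (T, H)
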
def fused_mlp_mm_fp8(
    hidden_states: torch.Tensor,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    use_fp8_w8a8: bool = True,
    w13_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a13_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape_w13: Optional[List[int]] = None,
    block_shape_w2: Optional[List[int]] = None,
):
    """Gated (SiLU) MLP over block-quantized fp8 weights, dequantized on the fly."""

    def fp8_to_fp16(inp, scale, block_size, trans_type):
        # Cast to the activation dtype, scale each (b0, b1) block, then
        # restore the original (dim0, dim1) layout.
        inp_t = inp.to(trans_type)
        inp_t = split_last_two_dims_into_blocks(inp_t, block_size[0], block_size[1])
        assert scale.size(0) == inp_t.size(-4)
        assert scale.size(1) == inp_t.size(-3)
        inp_t = inp_t * scale.unsqueeze(-1).unsqueeze(-1)
        inp_t = merge_blocks_to_original_layout(inp_t, block_size[0], block_size[1])
        return inp_t.to(trans_type)

    w13_weight = w13_weight.contiguous()
    w2_weight = w2_weight.contiguous()
    w13_scale = w13_scale.contiguous()
    w2_scale = w2_scale.contiguous()
    w13_fp = fp8_to_fp16(w13_weight, w13_scale, block_shape_w13, hidden_states.dtype)
    w2_fp = fp8_to_fp16(w2_weight, w2_scale, block_shape_w2, hidden_states.dtype)
    out = hidden_states @ w13_fp
    # Gated activation: first chunk is the (SiLU) gate, second the up projection.
    out = torch.chunk(out, 2, dim=-1)
    out = F.silu(out[0]) * out[1]
    out = out @ w2_fp
    return out

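# Shape sketch (hedged; sizes are illustrative assumptions). Note that here
# the weights are laid out for plain matmul, i.e. w13 is (H, 2 * I):
#
#   H, I = 256, 128
#   out = fused_mlp_mm_fp8(
#       torch.randn(4, H),
#       torch.randn(H, 2 * I), torch.randn(I, H),
#       w13_scale=torch.rand(H // 128, 2 * I // 128),
#       w2_scale=torch.rand(I // 128, H // 128),
#       block_shape_w13=[128, 128], block_shape_w2=[128, 128],
#   )  # (4, H)
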
def mla_matmul_scale(input: torch.Tensor, weight: torch.Tensor, scale: float):
    """Matmul followed by a scalar scale, cast back to the input dtype."""
    output = torch.matmul(input, weight)
    output = output * scale
    output = output.to(input.dtype)
    return output


def mla_matmul(input: torch.Tensor, weight: torch.Tensor):
    """Matmul cast back to the input dtype."""
    output = torch.matmul(input, weight)
    output = output.to(input.dtype)
    return output
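
# Usage sketch (hedged; the shapes and the 1/sqrt(d) factor are illustrative):
#
#   q = torch.randn(4, 64, dtype=torch.float16)
#   w = torch.randn(64, 64, dtype=torch.float16)
#   scores = mla_matmul_scale(q, w, 64 ** -0.5)   # (4, 64), fp16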