2026-04-02 04:53:13 +00:00
parent 80932c96e5
commit 24df76db9d
1987 changed files with 447445 additions and 0 deletions


@@ -0,0 +1,39 @@
import torch
from typing import Optional


def _apply_w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: list[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    cutlass_block_fp8_supported: bool = True,
    use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
    # A precomputed activation scale is not supported on this path; the last
    # two parameters are unused here and kept to match the upstream signature.
    assert input_scale is None
    assert len(block_size) == 2, "only 2-dim block sizes are supported for now"
    # View the input as a 2D matrix for the fp8 kernels.
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]
    try:
        from torch_vacc.vacc.custom_ops import w8a8_block_fp8_linear
        from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import (
            memory_recycler,
        )

        # Reuse the recycled MLA o_proj output buffer as the kernel's output
        # when its shape matches, avoiding an extra allocation.
        mla_oproj_output = None
        if memory_recycler is not None:
            os1, os2 = memory_recycler.MLA_OPROJ_OUT_BUFFER.shape
            if os1 == input_2d.size(0) and os2 == weight.size(0):
                mla_oproj_output = memory_recycler.MLA_OPROJ_OUT_BUFFER
        output = w8a8_block_fp8_linear(
            input_2d, weight, input_scale, weight_scale, block_size,
            output=mla_oproj_output,
        )
    except Exception as e:
        # Fall back to the unfused CPU ops if the fused fp8 matmul fails.
        print("vacc fused fp8 matmul failed:", e, "; falling back to unfused ops")
        from torch_vacc.vacc.custom_ops_cpu import w8a8_block_fp8_linear

        output = w8a8_block_fp8_linear(
            input_2d, weight, input_scale, weight_scale, block_size
        )
    if bias is not None:
        output = output + bias
    return output.to(dtype=input.dtype).view(*output_shape)
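
For reference, a minimal usage sketch, not part of the commit: it assumes the torch_vacc / vllm_vacc packages are importable (they are platform-specific and not generally available), and the 128x128 block size and tensor shapes are illustrative assumptions, not values taken from this diff.

import torch

M, K, N = 16, 4096, 4096
block_n, block_k = 128, 128  # assumed fp8 quantization block size

x = torch.randn(M, K, dtype=torch.bfloat16)
# Weight stored as fp8, with one scale per (block_n x block_k) block.
w_fp8 = torch.randn(N, K).to(torch.float8_e4m3fn)
w_scale = torch.rand(N // block_n, K // block_k, dtype=torch.float32)

y = _apply_w8a8_block_fp8_linear(x, w_fp8, [block_n, block_k], w_scale)
print(y.shape)  # torch.Size([16, 4096])

Note that the fallback is intentionally broad: any failure in the fused path (missing package, shape mismatch, kernel error) drops to the unfused CPU implementation instead of aborting the forward pass.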