Support compressed tensors fp8w8a8 (#4743)
@@ -15,6 +15,11 @@ else:
 from vllm import _custom_ops as vllm_ops
 
 
+def is_fp8_fnuz() -> bool:
+    # only device 0 is checked; this assumes MI300 platforms are homogeneous
+    return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
+
+
 def is_layer_skipped(
     prefix: str,
     ignored_layers: List[str],
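On ROCm, gfx94x (MI300-series) GPUs implement the fnuz fp8 encoding (no infinities, a single NaN bit pattern, no negative zero) rather than the OCP `e4m3fn` encoding, so callers typically branch on `is_fp8_fnuz()` to pick the matching torch dtype. A minimal sketch of that pattern; the `fp8_dtype` helper below is illustrative and not part of this diff:

```python
import torch

def fp8_dtype() -> torch.dtype:
    # Illustrative helper (not in this diff): choose the fp8 dtype that
    # matches the hardware's native encoding.
    if is_fp8_fnuz():
        # MI300 (gfx94x) uses the fnuz variant.
        return torch.float8_e4m3fnuz
    # NVIDIA and other platforms use the OCP fn variant.
    return torch.float8_e4m3fn
```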
@@ -120,3 +125,29 @@ def requantize_with_max_scale(
         start = end
 
     return max_w_scale, weight
+
+
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/layer_utils.py
+# Newly generated tensors need to replace existing tensors that are
+# already registered as parameters by vLLM (and won't be freed).
+def replace_parameter(
+    mod: torch.nn.Module, name: str, new: Union[torch.Tensor, torch.nn.Parameter]
+) -> None:
+
+    old = getattr(mod, name)
+    if (
+        type(old) is type(new)
+        and old.dtype == new.dtype
+        and old.untyped_storage().nbytes() == new.untyped_storage().nbytes()
+    ):
+        # If we can, update in-place to avoid re-registering; this is
+        # faster when the underlying storage can be reused.
+        update_tensor_inplace(old, new)
+    else:
+        # Fallback: re-register the parameter, converting to Parameter if
+        # necessary. This not only ensures we don't register a plain tensor
+        # as a parameter, but also re-registers Parameter subclasses as
+        # plain Parameters for `torch.compile` compatibility.
+        if not isinstance(new, torch.nn.Parameter):
+            new = torch.nn.Parameter(new, requires_grad=False)
+        mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
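Note: the in-place branch calls `update_tensor_inplace`, which this diff does not define. In the vLLM `layer_utils.py` this was adapted from, the helper looks roughly like the sketch below; treat it as an approximation of the upstream code, not a verbatim copy:

```python
import torch

def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor) -> None:
    assert dst.dtype == src.dtype, "Tensors must have the same dtype"

    # Adopt the new shape/stride without reallocating dst's storage.
    dst.as_strided_(src.shape, src.stride())

    # If src lives in different storage, copy its data over.
    if dst.data_ptr() != src.data_ptr():
        dst.copy_(src)
```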
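A sketch of how `replace_parameter` is meant to be used, e.g. swapping in the output of `requantize_with_max_scale` during weight post-processing. Here `layer` and `logical_widths` are placeholders for the caller's module and per-shard widths, not names from this diff:

```python
# Requantize the fused fp8 weight to a single max scale, then replace the
# registered parameters so the old tensors don't stay pinned in the module.
max_w_scale, weight = requantize_with_max_scale(
    layer.weight, layer.weight_scale, logical_widths
)
replace_parameter(layer, "weight", weight)
replace_parameter(layer, "weight_scale", max_w_scale)
```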