[bugfix] fix fp32 trans nz (#5068)
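### What this PR does / why we need it?

Fixes an error in the fp32 NZ-format conversion ("trans nz") by disabling NZ conversion for the fp32 dtype: `torch.float32` is now handled the same way as `torch.float16` and `torch.bfloat16` in `is_enable_nz`.

Signed-off-by: zzzzwwjj <1183291235@qq.com>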
This commit is contained in:
@@ -85,7 +85,7 @@ def is_enable_nz(dtype: Optional[torch.dtype] = torch.int8,
|
||||
and getattr(vllm_config.speculative_config, 'method',
|
||||
None) in ("eagle", "eagle3"))
|
||||
|
||||
if dtype in [torch.float16, torch.bfloat16]:
|
||||
if dtype in [torch.float16, torch.bfloat16, torch.float32]:
|
||||
return _ENABLE_NZ if _IS_EAGLE_MODE else False
|
||||
return _ENABLE_NZ
|
||||
|
||||
|
||||
Reference in New Issue
Block a user