[v0.18.0][Bugfix][EAGLE] Fix FIA pad bug under max concurrency (#7754)
cherry picked from https://github.com/vllm-project/vllm-ascend/pull/7740
Fixes padding problems of FIA op under max concurrency.
- vLLM version: v0.18.0
- vLLM main:
35141a7eed
Signed-off-by: Wangbingjie <wangbj1207@126.com>
This commit is contained in:
@@ -534,3 +534,32 @@ def test_parallel_drafting_acceptance(
|
|||||||
print(f"golden: {golden}")
|
print(f"golden: {golden}")
|
||||||
|
|
||||||
assert match
|
assert match
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("method", MODELS.keys())
|
||||||
|
@pytest.mark.parametrize("num_speculative_tokens", [3])
|
||||||
|
def test_eagle3_fia_pad_under_max_concurrency(
|
||||||
|
method: str,
|
||||||
|
num_speculative_tokens: int,
|
||||||
|
):
|
||||||
|
main_model_name = MODELS[method]["main"]
|
||||||
|
spec_model_name = MODELS[method]["spec"]
|
||||||
|
prompts = [
|
||||||
|
"Hello, I am",
|
||||||
|
]
|
||||||
|
speculative_config = {
|
||||||
|
"method": method,
|
||||||
|
"num_speculative_tokens": num_speculative_tokens,
|
||||||
|
"model": spec_model_name,
|
||||||
|
}
|
||||||
|
max_num_tokens = 1 + num_speculative_tokens
|
||||||
|
compilation_config = CompilationConfig(cudagraph_mode="FULL_DECODE_ONLY",cudagraph_capture_sizes=[max_num_tokens])
|
||||||
|
with VllmRunner(
|
||||||
|
main_model_name,
|
||||||
|
max_model_len=2048,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
speculative_config=speculative_config,
|
||||||
|
max_num_batched_tokens=max_num_tokens,
|
||||||
|
compilation_config=compilation_config,
|
||||||
|
) as llm:
|
||||||
|
_ = llm.generate_greedy(prompts, max_tokens=10)
|
||||||
|
|||||||
@@ -168,6 +168,8 @@ class SpecDecodeBaseProposer(EagleProposer):
|
|||||||
# RoPE need (max_num_tokens,)
|
# RoPE need (max_num_tokens,)
|
||||||
self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=device)
|
self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=device)
|
||||||
|
|
||||||
|
self.token_arange_np = np.arange(self.max_num_tokens + 1)
|
||||||
|
|
||||||
def _get_model(self) -> nn.Module:
|
def _get_model(self) -> nn.Module:
|
||||||
"""
|
"""
|
||||||
Default method to call get_model(). Can be overridden by subclasses which
|
Default method to call get_model(). Can be overridden by subclasses which
|
||||||
|
|||||||
Reference in New Issue
Block a user