### What this PR does / why we need it?
| File Path |
| :--- |
| `tests/e2e/singlecard/compile/backend.py` |
| `tests/e2e/singlecard/compile/test_graphex_norm_quant_fusion.py` |
| `tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py` |
| `tests/e2e/singlecard/compile/test_norm_quant_fusion.py` |
| `tests/e2e/singlecard/model_runner_v2/test_basic.py` |
| `tests/e2e/singlecard/test_aclgraph_accuracy.py` |
| `tests/e2e/singlecard/test_aclgraph_batch_invariant.py` |
| `tests/e2e/singlecard/test_aclgraph_mem.py` |
| `tests/e2e/singlecard/test_async_scheduling.py` |
| `tests/e2e/singlecard/test_auto_fit_max_mode_len.py` |
| `tests/e2e/singlecard/test_batch_invariant.py` |
| `tests/e2e/singlecard/test_camem.py` |
| `tests/e2e/singlecard/test_completion_with_prompt_embeds.py` |
| `tests/e2e/singlecard/test_cpu_offloading.py` |
| `tests/e2e/singlecard/test_guided_decoding.py` |
| `tests/e2e/singlecard/test_ilama_lora.py` |
| `tests/e2e/singlecard/test_llama32_lora.py` |
| `tests/e2e/singlecard/test_models.py` |
| `tests/e2e/singlecard/test_multistream_overlap_shared_expert.py` |
| `tests/e2e/singlecard/test_quantization.py` |
| `tests/e2e/singlecard/test_qwen3_multi_loras.py` |
| `tests/e2e/singlecard/test_sampler.py` |
| `tests/e2e/singlecard/test_vlm.py` |
| `tests/e2e/singlecard/test_xlite.py` |
| `tests/e2e/singlecard/utils.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.15.0
- vLLM main:
9562912cea
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
@@ -70,9 +70,7 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
|
||||
|
||||
if target_words > 50:
|
||||
# For longer prompts, repeat context
|
||||
padding_text = (
|
||||
" This is an interesting topic that deserves more explanation. " *
|
||||
(target_words // 50))
|
||||
padding_text = " This is an interesting topic that deserves more explanation. " * (target_words // 50)
|
||||
base_prompt = base_prompt + padding_text
|
||||
|
||||
return base_prompt
|
||||
@@ -83,10 +81,7 @@ def _extract_step_logprobs(request_output):
|
||||
inner = request_output.outputs[0]
|
||||
if hasattr(inner, "logprobs") and inner.logprobs is not None:
|
||||
t = torch.tensor(
|
||||
[
|
||||
inner.logprobs[i][tid].logprob
|
||||
for i, tid in enumerate(inner.token_ids)
|
||||
],
|
||||
[inner.logprobs[i][tid].logprob for i, tid in enumerate(inner.token_ids)],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
return t, inner.token_ids
|
||||
@@ -95,8 +90,7 @@ def _extract_step_logprobs(request_output):
|
||||
|
||||
|
||||
@pytest.mark.timeout(1000)
|
||||
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
Ensures that the same request (the 'needle' prompt) yields identical output
|
||||
whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64),
|
||||
@@ -184,8 +178,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
|
||||
if i == needle_pos:
|
||||
prompts.append(needle_prompt)
|
||||
else:
|
||||
prompts.append(
|
||||
_random_prompt(min_random_prompt, max_random_prompt))
|
||||
prompts.append(_random_prompt(min_random_prompt, max_random_prompt))
|
||||
|
||||
# Generate with the larger-batch engine
|
||||
outputs = llm.generate(prompts, sampling)
|
||||
@@ -196,27 +189,27 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
|
||||
text = needle_output.outputs[0].text
|
||||
|
||||
if text != baseline_text:
|
||||
print(
|
||||
f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n")
|
||||
print(f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n")
|
||||
mismatches += 1
|
||||
|
||||
passes = num_trials - mismatches
|
||||
# Dump how many passed vs failed
|
||||
print(f"[determinism] total={num_trials}, passed={passes}, "
|
||||
f"failed={mismatches}, max_batch_size={max_batch_size}")
|
||||
print(
|
||||
f"[determinism] total={num_trials}, passed={passes}, failed={mismatches}, max_batch_size={max_batch_size}"
|
||||
)
|
||||
|
||||
if mismatches > 0:
|
||||
pytest.fail(
|
||||
f"Nondeterministic outputs detected: {mismatches} failed out "
|
||||
f"of {num_trials} trials (max_batch_size={max_batch_size}).")
|
||||
f"of {num_trials} trials (max_batch_size={max_batch_size})."
|
||||
)
|
||||
|
||||
finally:
|
||||
del llm
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(monkeypatch: pytest.MonkeyPatch):
|
||||
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
||||
random.seed(seed)
|
||||
model_name = DEFAULT_MODEL
|
||||
@@ -230,9 +223,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||
|
||||
if disable_custom_ar:
|
||||
print(f"\n{'=' * 80}")
|
||||
print(
|
||||
f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})"
|
||||
)
|
||||
print(f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})")
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
llm = LLM(
|
||||
@@ -266,15 +257,12 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||
bs1_logprobs_per_prompt = []
|
||||
bs1_tokens_per_prompt = []
|
||||
for idx, p in enumerate(prompts):
|
||||
print(
|
||||
f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..."
|
||||
)
|
||||
print(f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}...")
|
||||
outs = llm.generate([p], sp, use_tqdm=False)
|
||||
assert len(outs) == 1
|
||||
step_logprobs, token_ids = _extract_step_logprobs(outs[0])
|
||||
if step_logprobs is None:
|
||||
pytest.skip("Logits are not available on RequestOutput; "
|
||||
"enable logprobs return to run this test.")
|
||||
pytest.skip("Logits are not available on RequestOutput; enable logprobs return to run this test.")
|
||||
bs1_logprobs_per_prompt.append(step_logprobs)
|
||||
bs1_tokens_per_prompt.append(token_ids)
|
||||
print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")
|
||||
@@ -296,108 +284,92 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||
print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}")
|
||||
step_logprobs, token_ids = _extract_step_logprobs(o)
|
||||
if step_logprobs is None:
|
||||
pytest.skip("Logits are not available on RequestOutput; "
|
||||
"enable logprobs return to run this test.")
|
||||
pytest.skip("Logits are not available on RequestOutput; enable logprobs return to run this test.")
|
||||
bsN_logprobs_per_prompt.append(step_logprobs)
|
||||
bsN_tokens_per_prompt.append(token_ids)
|
||||
|
||||
# Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
|
||||
failed_prompts = []
|
||||
for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
|
||||
zip(
|
||||
bs1_logprobs_per_prompt,
|
||||
bsN_logprobs_per_prompt,
|
||||
bs1_tokens_per_prompt,
|
||||
bsN_tokens_per_prompt,
|
||||
)):
|
||||
zip(
|
||||
bs1_logprobs_per_prompt,
|
||||
bsN_logprobs_per_prompt,
|
||||
bs1_tokens_per_prompt,
|
||||
bsN_tokens_per_prompt,
|
||||
)
|
||||
):
|
||||
if len(logprobs_bs1) != len(logprobs_bsN):
|
||||
reason = (f"Different number of steps: {len(logprobs_bs1)} (BS=1) "
|
||||
f"vs {len(logprobs_bsN)} (BS=N)")
|
||||
failed_prompts.append({
|
||||
"prompt_idx": i,
|
||||
"step": "all",
|
||||
"reason": reason,
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
})
|
||||
reason = f"Different number of steps: {len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)"
|
||||
failed_prompts.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": "all",
|
||||
"reason": reason,
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Check if tokens match first
|
||||
if tokens_bs1 != tokens_bsN:
|
||||
failed_prompts.append({
|
||||
"prompt_idx":
|
||||
i,
|
||||
"step":
|
||||
"sampling",
|
||||
"reason":
|
||||
"Different tokens sampled",
|
||||
"prompt_preview":
|
||||
prompts[i][:100],
|
||||
"bs1_tokens":
|
||||
tokens_bs1,
|
||||
"bsN_tokens":
|
||||
tokens_bsN,
|
||||
"bs1_all_logprobs":
|
||||
[logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))],
|
||||
"bsN_all_logprobs":
|
||||
[logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))],
|
||||
})
|
||||
failed_prompts.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": "sampling",
|
||||
"reason": "Different tokens sampled",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
"bs1_all_logprobs": [logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))],
|
||||
"bsN_all_logprobs": [logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))],
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
|
||||
if a.shape != b.shape:
|
||||
failed_prompts.append({
|
||||
"prompt_idx": i,
|
||||
"step": t,
|
||||
"reason": f"Shape mismatch: {a.shape} vs {b.shape}",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
})
|
||||
failed_prompts.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": t,
|
||||
"reason": f"Shape mismatch: {a.shape} vs {b.shape}",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
}
|
||||
)
|
||||
break
|
||||
|
||||
if not torch.equal(a, b):
|
||||
max_diff = torch.abs(a - b).max().item()
|
||||
# Print which token failed
|
||||
print(
|
||||
f"\n[DIVERGENCE] Prompt {i}, Token {t}: max_diff={max_diff:.6e}"
|
||||
)
|
||||
print(f"\n[DIVERGENCE] Prompt {i}, Token {t}: max_diff={max_diff:.6e}")
|
||||
bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A"
|
||||
bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A"
|
||||
print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}")
|
||||
print(f" BS=1 logprob: {a.tolist()}")
|
||||
print(f" BS=N logprob: {b.tolist()}")
|
||||
failed_prompts.append({
|
||||
"prompt_idx":
|
||||
i,
|
||||
"step":
|
||||
t,
|
||||
"reason":
|
||||
f"Bitwise mismatch (max_diff={max_diff:.6e})",
|
||||
"prompt_preview":
|
||||
prompts[i][:100],
|
||||
"bs1_tokens":
|
||||
tokens_bs1,
|
||||
"bsN_tokens":
|
||||
tokens_bsN,
|
||||
"bs1_all_logprobs": [
|
||||
logprobs_bs1[s].tolist()
|
||||
for s in range(len(logprobs_bs1))
|
||||
],
|
||||
"bsN_all_logprobs": [
|
||||
logprobs_bsN[s].tolist()
|
||||
for s in range(len(logprobs_bsN))
|
||||
],
|
||||
})
|
||||
failed_prompts.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": t,
|
||||
"reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
"bs1_all_logprobs": [logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))],
|
||||
"bsN_all_logprobs": [logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))],
|
||||
}
|
||||
)
|
||||
break
|
||||
del llm
|
||||
cleanup_dist_env_and_memory()
|
||||
# Print summary of all failures
|
||||
if failed_prompts:
|
||||
print(f"\n{'=' * 80}")
|
||||
fail_msg = (f"BATCH INVARIANCE FAILURES: {len(failed_prompts)}/"
|
||||
f"{len(prompts)} prompts failed")
|
||||
fail_msg = f"BATCH INVARIANCE FAILURES: {len(failed_prompts)}/{len(prompts)} prompts failed"
|
||||
print(fail_msg)
|
||||
print(f"{'=' * 80}")
|
||||
for fail in failed_prompts:
|
||||
@@ -412,21 +384,18 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||
print(f" BS=N tokens: {fail['bsN_tokens']}")
|
||||
|
||||
if "bs1_all_logprobs" in fail:
|
||||
print(
|
||||
f" BS=1 logprobs for all {len(fail['bs1_all_logprobs'])} steps:"
|
||||
)
|
||||
print(f" BS=1 logprobs for all {len(fail['bs1_all_logprobs'])} steps:")
|
||||
for step_idx, logprobs in enumerate(fail["bs1_all_logprobs"]):
|
||||
print(f" Step {step_idx}: {logprobs}")
|
||||
print(
|
||||
f" BS=N logprobs for all {len(fail['bsN_all_logprobs'])} steps:"
|
||||
)
|
||||
print(f" BS=N logprobs for all {len(fail['bsN_all_logprobs'])} steps:")
|
||||
for step_idx, logprobs in enumerate(fail["bsN_all_logprobs"]):
|
||||
print(f" Step {step_idx}: {logprobs}")
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
# Fail the test with summary
|
||||
msg = (f"Batch invariance violated in {len(failed_prompts)}/"
|
||||
f"{len(prompts)} prompts. See output above for details.")
|
||||
msg = (
|
||||
f"Batch invariance violated in {len(failed_prompts)}/{len(prompts)} prompts. See output above for details."
|
||||
)
|
||||
pytest.fail(msg)
|
||||
|
||||
|
||||
@@ -476,8 +445,7 @@ def test_simple_generation(monkeypatch: pytest.MonkeyPatch):
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
def test_logprobs_without_batch_invariance_should_fail(
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
def test_logprobs_without_batch_invariance_should_fail(monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN.
|
||||
It DISABLES batch invariance mode and expects to see non-deterministic behavior
|
||||
@@ -540,15 +508,12 @@ def test_logprobs_without_batch_invariance_should_fail(
|
||||
bs1_logprobs_per_prompt = []
|
||||
bs1_tokens_per_prompt = []
|
||||
for idx, p in enumerate(prompts):
|
||||
print(
|
||||
f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..."
|
||||
)
|
||||
print(f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}...")
|
||||
outs = llm.generate([p], sp, use_tqdm=False)
|
||||
assert len(outs) == 1
|
||||
step_logprobs, token_ids = _extract_step_logprobs(outs[0])
|
||||
if step_logprobs is None:
|
||||
pytest.skip("Logits are not available on RequestOutput; "
|
||||
"enable logprobs return to run this test.")
|
||||
pytest.skip("Logits are not available on RequestOutput; enable logprobs return to run this test.")
|
||||
bs1_logprobs_per_prompt.append(step_logprobs)
|
||||
bs1_tokens_per_prompt.append(token_ids)
|
||||
print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")
|
||||
@@ -569,74 +534,80 @@ def test_logprobs_without_batch_invariance_should_fail(
|
||||
print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}")
|
||||
step_logprobs, token_ids = _extract_step_logprobs(o)
|
||||
if step_logprobs is None:
|
||||
pytest.skip("Logits are not available on RequestOutput; "
|
||||
"enable logprobs return to run this test.")
|
||||
pytest.skip("Logits are not available on RequestOutput; enable logprobs return to run this test.")
|
||||
bsN_logprobs_per_prompt.append(step_logprobs)
|
||||
bsN_tokens_per_prompt.append(token_ids)
|
||||
|
||||
# Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
|
||||
differences_found = []
|
||||
for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
|
||||
zip(
|
||||
bs1_logprobs_per_prompt,
|
||||
bsN_logprobs_per_prompt,
|
||||
bs1_tokens_per_prompt,
|
||||
bsN_tokens_per_prompt,
|
||||
)):
|
||||
zip(
|
||||
bs1_logprobs_per_prompt,
|
||||
bsN_logprobs_per_prompt,
|
||||
bs1_tokens_per_prompt,
|
||||
bsN_tokens_per_prompt,
|
||||
)
|
||||
):
|
||||
if len(logprobs_bs1) != len(logprobs_bsN):
|
||||
reason = (f"Different number of steps: {len(logprobs_bs1)} (BS=1) "
|
||||
f"vs {len(logprobs_bsN)} (BS=N)")
|
||||
differences_found.append({
|
||||
"prompt_idx": i,
|
||||
"step": "all",
|
||||
"reason": reason,
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
})
|
||||
reason = f"Different number of steps: {len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)"
|
||||
differences_found.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": "all",
|
||||
"reason": reason,
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Check if tokens match first
|
||||
if tokens_bs1 != tokens_bsN:
|
||||
differences_found.append({
|
||||
"prompt_idx": i,
|
||||
"step": "sampling",
|
||||
"reason": "Different tokens sampled",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
})
|
||||
differences_found.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": "sampling",
|
||||
"reason": "Different tokens sampled",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
|
||||
if a.shape != b.shape:
|
||||
differences_found.append({
|
||||
"prompt_idx": i,
|
||||
"step": t,
|
||||
"reason": f"Shape mismatch: {a.shape} vs {b.shape}",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
})
|
||||
differences_found.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": t,
|
||||
"reason": f"Shape mismatch: {a.shape} vs {b.shape}",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
}
|
||||
)
|
||||
break
|
||||
|
||||
if not torch.equal(a, b):
|
||||
max_diff = torch.abs(a - b).max().item()
|
||||
print(f"\n[EXPECTED DIVERGENCE FOUND] Prompt {i}, "
|
||||
f"Token {t}: max_diff={max_diff:.6e}")
|
||||
print(f"\n[EXPECTED DIVERGENCE FOUND] Prompt {i}, Token {t}: max_diff={max_diff:.6e}")
|
||||
bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A"
|
||||
bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A"
|
||||
print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}")
|
||||
print(f" BS=1 logprob: {a.tolist()}")
|
||||
print(f" BS=N logprob: {b.tolist()}")
|
||||
differences_found.append({
|
||||
"prompt_idx": i,
|
||||
"step": t,
|
||||
"reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
})
|
||||
differences_found.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": t,
|
||||
"reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
}
|
||||
)
|
||||
break
|
||||
del llm
|
||||
cleanup_dist_env_and_memory()
|
||||
@@ -646,7 +617,8 @@ def test_logprobs_without_batch_invariance_should_fail(
|
||||
success_msg = (
|
||||
f"✓ SUCCESS: Batch invariance is doing something! "
|
||||
f"Found {len(differences_found)}/{len(prompts)} prompts "
|
||||
f"with differences when batch invariance was DISABLED.")
|
||||
f"with differences when batch invariance was DISABLED."
|
||||
)
|
||||
print(success_msg)
|
||||
print(f"{'=' * 80}")
|
||||
for diff in differences_found:
|
||||
@@ -666,7 +638,8 @@ def test_logprobs_without_batch_invariance_should_fail(
|
||||
f"✗ UNEXPECTED: All {len(prompts)} prompts matched "
|
||||
f"between BS=1 and BS=N even with batch invariance DISABLED. "
|
||||
f"This suggests batch invariance might not be necessary, "
|
||||
f"or the test needs more sensitive prompts.")
|
||||
f"or the test needs more sensitive prompts."
|
||||
)
|
||||
print(fail_msg)
|
||||
print(f"{'=' * 80}\n")
|
||||
pytest.fail(fail_msg)
|
||||
|
||||
Reference in New Issue
Block a user