Fix depending on asserts to raise an exception in BatchedBrownianTree and Flash attn module (#9884)

Correctly handle the case where w0 is passed by kwargs in BatchedBrownianTree
This commit is contained in:
blepping
2025-09-15 18:05:03 -06:00
committed by GitHub
parent 47a9cde5d3
commit 1a85483da1
2 changed files with 19 additions and 19 deletions

View File

@@ -600,7 +600,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
mask = mask.unsqueeze(1)
try:
assert mask is None
if mask is not None:
raise RuntimeError("Mask must not be set for Flash attention")
out = flash_attn_wrapper(
q.transpose(1, 2),
k.transpose(1, 2),