Fix depending on asserts to raise an exception in BatchedBrownianTree and Flash attn module (#9884)
Correctly handle the case where w0 is passed by kwargs in BatchedBrownianTree
This commit is contained in:
@@ -600,7 +600,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
try:
|
||||
assert mask is None
|
||||
if mask is not None:
|
||||
raise RuntimeError("Mask must not be set for Flash attention")
|
||||
out = flash_attn_wrapper(
|
||||
q.transpose(1, 2),
|
||||
k.transpose(1, 2),
|
||||
|
||||
Reference in New Issue
Block a user