Fix Flux2 reference image mem estimation. (#10905)
This commit is contained in:
@@ -926,7 +926,7 @@ class Flux(BaseModel):
|
|||||||
out = {}
|
out = {}
|
||||||
ref_latents = kwargs.get("reference_latents", None)
|
ref_latents = kwargs.get("reference_latents", None)
|
||||||
if ref_latents is not None:
|
if ref_latents is not None:
|
||||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
class Flux2(Flux):
|
class Flux2(Flux):
|
||||||
|
|||||||
Reference in New Issue
Block a user