Refactor previews into one command line argument.

Clean up a few things.
This commit is contained in:
comfyanonymous
2023-06-06 01:26:52 -04:00
parent 081134f5c8
commit a3a713b6c5
8 changed files with 107 additions and 100 deletions

View File

@@ -45,11 +45,12 @@ parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If th
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
class LatentPreviewMethod(enum.Enum):
NoPreviews = "none"
Auto = "auto"
Latent2RGB = "latent2rgb"
TAESD = "taesd"
parser.add_argument("--disable-previews", action="store_true", help="Disable showing node previews.")
parser.add_argument("--default-preview-method", type=str, default=LatentPreviewMethod.Auto, metavar="PREVIEW_METHOD", help="Default preview method for sampler nodes.")
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")

View File

@@ -50,9 +50,9 @@ class TAESD(nn.Module):
self.encoder = Encoder()
self.decoder = Decoder()
if encoder_path is not None:
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu"))
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True))
if decoder_path is not None:
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu"))
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True))
@staticmethod
def scale_latents(x):

View File

@@ -1,7 +1,6 @@
import torch
import math
import struct
import comfy.model_management
def load_torch_file(ckpt, safe_load=False):
if ckpt.lower().endswith(".safetensors"):
@@ -167,8 +166,6 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
for y in range(0, s.shape[2], tile_y - overlap):
for x in range(0, s.shape[3], tile_x - overlap):
comfy.model_management.throw_exception_if_processing_interrupted()
s_in = s[:,:,y:y+tile_y,x:x+tile_x]
ps = function(s_in).cpu()