Support new flux model variants.
This commit is contained in:
@@ -16,9 +16,9 @@ class Output:
|
||||
def __setitem__(self, key, item):
|
||||
setattr(self, key, item)
|
||||
|
||||
def clip_preprocess(image, size=224):
|
||||
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]):
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
if not (image.shape[2] == size and image.shape[3] == size):
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
@@ -35,6 +35,8 @@ class ClipVisionModel():
|
||||
config = json.load(f)
|
||||
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
@@ -51,7 +53,7 @@ class ClipVisionModel():
|
||||
|
||||
def encode_image(self, image):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
|
||||
|
||||
outputs = Output()
|
||||
@@ -94,7 +96,9 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
|
||||
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
|
||||
if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
|
||||
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
||||
else:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||
|
||||
Reference in New Issue
Block a user