Pull in latest upscale model code from chainner.
This commit is contained in:
133
comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py
Normal file
133
comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py
Normal file
@@ -0,0 +1,133 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: OmniSR.py
|
||||
# Created Date: Tuesday April 28th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 23rd April 2023 3:06:36 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .OSAG import OSAG
|
||||
from .pixelshuffle import pixelshuffle_block
|
||||
|
||||
|
||||
class OmniSR(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
state_dict,
|
||||
**kwargs,
|
||||
):
|
||||
super(OmniSR, self).__init__()
|
||||
self.state = state_dict
|
||||
|
||||
bias = True # Fine to assume this for now
|
||||
block_num = 1 # Fine to assume this for now
|
||||
ffn_bias = True
|
||||
pe = True
|
||||
|
||||
num_feat = state_dict["input.weight"].shape[0] or 64
|
||||
num_in_ch = state_dict["input.weight"].shape[1] or 3
|
||||
num_out_ch = num_in_ch # we can just assume this for now. pixelshuffle smh
|
||||
|
||||
pixelshuffle_shape = state_dict["up.0.weight"].shape[0]
|
||||
up_scale = math.sqrt(pixelshuffle_shape / num_out_ch)
|
||||
if up_scale - int(up_scale) > 0:
|
||||
print(
|
||||
"out_nc is probably different than in_nc, scale calculation might be wrong"
|
||||
)
|
||||
up_scale = int(up_scale)
|
||||
res_num = 0
|
||||
for key in state_dict.keys():
|
||||
if "residual_layer" in key:
|
||||
temp_res_num = int(key.split(".")[1])
|
||||
if temp_res_num > res_num:
|
||||
res_num = temp_res_num
|
||||
res_num = res_num + 1 # zero-indexed
|
||||
|
||||
residual_layer = []
|
||||
self.res_num = res_num
|
||||
|
||||
self.window_size = 8 # we can just assume this for now, but there's probably a way to calculate it (just need to get the sqrt of the right layer)
|
||||
self.up_scale = up_scale
|
||||
|
||||
for _ in range(res_num):
|
||||
temp_res = OSAG(
|
||||
channel_num=num_feat,
|
||||
bias=bias,
|
||||
block_num=block_num,
|
||||
ffn_bias=ffn_bias,
|
||||
window_size=self.window_size,
|
||||
pe=pe,
|
||||
)
|
||||
residual_layer.append(temp_res)
|
||||
self.residual_layer = nn.Sequential(*residual_layer)
|
||||
self.input = nn.Conv2d(
|
||||
in_channels=num_in_ch,
|
||||
out_channels=num_feat,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=bias,
|
||||
)
|
||||
self.output = nn.Conv2d(
|
||||
in_channels=num_feat,
|
||||
out_channels=num_feat,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=bias,
|
||||
)
|
||||
self.up = pixelshuffle_block(num_feat, num_out_ch, up_scale, bias=bias)
|
||||
|
||||
# self.tail = pixelshuffle_block(num_feat,num_out_ch,up_scale,bias=bias)
|
||||
|
||||
# for m in self.modules():
|
||||
# if isinstance(m, nn.Conv2d):
|
||||
# n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
# m.weight.data.normal_(0, sqrt(2. / n))
|
||||
|
||||
# chaiNNer specific stuff
|
||||
self.model_arch = "OmniSR"
|
||||
self.sub_type = "SR"
|
||||
self.in_nc = num_in_ch
|
||||
self.out_nc = num_out_ch
|
||||
self.num_feat = num_feat
|
||||
self.scale = up_scale
|
||||
|
||||
self.supports_fp16 = True # TODO: Test this
|
||||
self.supports_bfp16 = True
|
||||
self.min_size_restriction = 16
|
||||
|
||||
self.load_state_dict(state_dict, strict=False)
|
||||
|
||||
def check_image_size(self, x):
|
||||
_, _, h, w = x.size()
|
||||
# import pdb; pdb.set_trace()
|
||||
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
|
||||
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
|
||||
# x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
|
||||
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "constant", 0)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
H, W = x.shape[2:]
|
||||
x = self.check_image_size(x)
|
||||
|
||||
residual = self.input(x)
|
||||
out = self.residual_layer(residual)
|
||||
|
||||
# origin
|
||||
out = torch.add(self.output(out), residual)
|
||||
out = self.up(out)
|
||||
|
||||
out = out[:, :, : H * self.up_scale, : W * self.up_scale]
|
||||
return out
|
||||
Reference in New Issue
Block a user