Skip to content

Commit

Permalink
Fixed mistakenly commited inference.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ohayonguy committed Oct 14, 2024
1 parent f8cbdf2 commit 3caa3b5
Showing 1 changed file with 1 addition and 55 deletions.
56 changes: 1 addition & 55 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,73 +3,21 @@
from torchvision.transforms.functional import to_tensor
from torchvision.utils import save_image
import os
import cv2
from torch.utils.data import DataLoader
from torch_datasets.image_folder_dataset import ImageFolderDataset
from tqdm import tqdm
import numpy as np
import argparse
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils import img2tensor, tensor2img
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from realesrgan.utils import RealESRGANer


torch.set_float32_matmul_precision('high')
torch.set_grad_enabled(False)



realesrgan_folder = "checkpoints"
os.makedirs(realesrgan_folder, exist_ok=True)
realesr_model_path = f"{realesrgan_folder}/RealESRGAN_x4plus.pth"
if not os.path.exists(realesr_model_path):
os.system(
f"wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -O {realesrgan_folder}/RealESRGAN_x4plus.pth"
)


def set_realesrgan():
use_half = False
if torch.cuda.is_available(): # set False in CPU/MPS mode
no_half_gpu_list = ["1650", "1660"] # set False for GPUs that don't support f16
if not True in [
gpu in torch.cuda.get_device_name(0) for gpu in no_half_gpu_list
]:
use_half = True

model = RRDBNet(
num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2,
)
upsampler = RealESRGANer(
scale=2,
model_path="https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth",
model=model,
tile=400,
tile_pad=40,
pre_pad=0,
half=use_half,
)
return upsampler


upsampler = set_realesrgan()
def resize(img, size):
# From https://github.com/sczhou/CodeFormer/blob/master/facelib/utils/face_restoration_helper.py
h, w = img.shape[0:2]
scale = size / min(h, w)
h, w = int(h * scale), int(w * scale)
interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
return cv2.resize(img, (w, h), interpolation=interp)

def main(args):
torch.manual_seed(args.seed)
np.random.seed(args.seed)

def identity(x):
return x, None
ds = ImageFolderDataset(args.lq_data_path, degradation=identity, transform=to_tensor)
dl = DataLoader(ds, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=0)

output_path = os.path.join(args.output_dir, 'restored_images')
os.makedirs(output_path, exist_ok=True)

Expand Down Expand Up @@ -128,7 +76,5 @@ def identity(x):
help='Number of flow steps to use for inference.')
parser.add_argument('--batch_size', type=int, required=False, default=64,
help='Batch size for inference.')
parser.add_argument('--seed', type=int, required=False, default=0,
help='The input random seed.')

main(parser.parse_args())

0 comments on commit 3caa3b5

Please sign in to comment.