Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 96 additions & 51 deletions web-demos/hugging_face/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@
"""

import sys
sys.path.append('CodeFormer')
import os

# root path
script_dir = os.path.dirname(os.path.abspath(__file__)) # web-demos/hugging_face
project_root = os.path.dirname(os.path.dirname(script_dir))
sys.path.insert(0, project_root)

import cv2
import torch
import torch.nn.functional as F
Expand All @@ -23,41 +28,60 @@
from facelib.utils.face_restoration_helper import FaceRestoreHelper
from facelib.utils.misc import is_gray

# working directory & output directory
ui_dir = os.path.dirname(os.path.abspath(__file__))
output_dir = os.path.join(ui_dir, 'output')
example_dir = os.path.join(ui_dir, 'examples')
os.makedirs(output_dir, exist_ok=True)
os.makedirs(example_dir, exist_ok=True)

os.system("pip freeze")
weights_dir = os.path.join(project_root, 'weights')

pretrain_model_url = {
'codeformer': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
'detection': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
'parsing': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth',
'realesrgan': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth'
}
# download weights
if not os.path.exists('CodeFormer/weights/CodeFormer/codeformer.pth'):
load_file_from_url(url=pretrain_model_url['codeformer'], model_dir='CodeFormer/weights/CodeFormer', progress=True, file_name=None)
if not os.path.exists('CodeFormer/weights/facelib/detection_Resnet50_Final.pth'):
load_file_from_url(url=pretrain_model_url['detection'], model_dir='CodeFormer/weights/facelib', progress=True, file_name=None)
if not os.path.exists('CodeFormer/weights/facelib/parsing_parsenet.pth'):
load_file_from_url(url=pretrain_model_url['parsing'], model_dir='CodeFormer/weights/facelib', progress=True, file_name=None)
if not os.path.exists('CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth'):
load_file_from_url(url=pretrain_model_url['realesrgan'], model_dir='CodeFormer/weights/realesrgan', progress=True, file_name=None)

# download images
torch.hub.download_url_to_file(
'https://replicate.com/api/models/sczhou/codeformer/files/fa3fe3d1-76b0-4ca8-ac0d-0a925cb0ff54/06.png',
'01.png')
torch.hub.download_url_to_file(
'https://replicate.com/api/models/sczhou/codeformer/files/a1daba8e-af14-4b00-86a4-69cec9619b53/04.jpg',
'02.jpg')
torch.hub.download_url_to_file(
'https://replicate.com/api/models/sczhou/codeformer/files/542d64f9-1712-4de7-85f7-3863009a7c3d/03.jpg',
'03.jpg')
torch.hub.download_url_to_file(
'https://replicate.com/api/models/sczhou/codeformer/files/a11098b0-a18a-4c02-a19a-9a7045d68426/010.jpg',
'04.jpg')
torch.hub.download_url_to_file(
'https://replicate.com/api/models/sczhou/codeformer/files/7cf19c2c-e0cf-4712-9af8-cf5bdbb8d0ee/012.jpg',
'05.jpg')

# download weights (check in project weights directory)
print("Checking pre-trained models...")
weights_info = {
'codeformer': (os.path.join(weights_dir, 'CodeFormer/codeformer.pth'), 'CodeFormer'),
'detection': (os.path.join(weights_dir, 'facelib/detection_Resnet50_Final.pth'), 'detection'),
'parsing': (os.path.join(weights_dir, 'facelib/parsing_parsenet.pth'), 'parsing'),
'realesrgan': (os.path.join(weights_dir, 'realesrgan/RealESRGAN_x2plus.pth'), 'realesrgan'),
}

for model_name, (model_path, model_dir) in weights_info.items():
if not os.path.exists(model_path):
print(f"Downloading {model_name}...")
load_file_from_url(url=pretrain_model_url[model_name],
model_dir=os.path.dirname(model_path),
progress=True, file_name=None)
else:
print(f"✓ Found {model_name}")

# download example images (to examples directory)
print("Checking example images...")
example_files = [
('https://replicate.com/api/models/sczhou/codeformer/files/fa3fe3d1-76b0-4ca8-ac0d-0a925cb0ff54/06.png', '01.png'),
('https://replicate.com/api/models/sczhou/codeformer/files/a1daba8e-af14-4b00-86a4-69cec9619b53/04.jpg', '02.jpg'),
('https://replicate.com/api/models/sczhou/codeformer/files/542d64f9-1712-4de7-85f7-3863009a7c3d/03.jpg', '03.jpg'),
('https://replicate.com/api/models/sczhou/codeformer/files/a11098b0-a18a-4c02-a19a-9a7045d68426/010.jpg', '04.jpg'),
('https://replicate.com/api/models/sczhou/codeformer/files/7cf19c2c-e0cf-4712-9af8-cf5bdbb8d0ee/012.jpg', '05.jpg'),
]

for url, filename in example_files:
example_path = os.path.join(example_dir, filename)
if not os.path.exists(example_path):
print(f"Downloading example {filename}...")
try:
torch.hub.download_url_to_file(url, example_path)
except Exception as e:
print(f"Failed to download {filename}: {e}")
else:
print(f"✓ Found example {filename}")

def imread(img_path):
img = cv2.imread(img_path)
Expand All @@ -76,9 +100,10 @@ def set_realesrgan():
num_grow_ch=32,
scale=2,
)
realesrgan_path = os.path.join(weights_dir, "realesrgan/RealESRGAN_x2plus.pth")
upsampler = RealESRGANer(
scale=2,
model_path="CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth",
model_path=realesrgan_path,
model=model,
tile=400,
tile_pad=40,
Expand All @@ -90,19 +115,19 @@ def set_realesrgan():
upsampler = set_realesrgan()
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = get_device()
print("Loading CodeFormer model...")
codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
dim_embd=512,
codebook_size=1024,
n_head=8,
n_layers=9,
connect_list=["32", "64", "128", "256"],
).to(device)
ckpt_path = "CodeFormer/weights/CodeFormer/codeformer.pth"
checkpoint = torch.load(ckpt_path)["params_ema"]
ckpt_path = os.path.join(weights_dir, "CodeFormer/codeformer.pth")
checkpoint = torch.load(ckpt_path, map_location=device)["params_ema"]
codeformer_net.load_state_dict(checkpoint)
codeformer_net.eval()

os.makedirs('output', exist_ok=True)
print("CodeFormer model loaded successfully")

def inference(image, background_enhance, face_upsample, upscale, codeformer_fidelity):
"""Run a single prediction on the model"""
Expand Down Expand Up @@ -204,7 +229,7 @@ def inference(image, background_enhance, face_upsample, upscale, codeformer_fide
)

# save restored img
save_path = f'output/out.png'
save_path = os.path.join(output_dir, 'out.png')
imwrite(restored_img, str(save_path))

restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
Expand Down Expand Up @@ -256,28 +281,48 @@ def inference(image, background_enhance, face_upsample, upscale, codeformer_fide
<center><img src='https://visitor-badge-sczhou.glitch.me/badge?page_id=sczhou/CodeFormer' alt='visitors'></center>
"""

# load examples
examples = []
for i in range(1, 6):
img_name = f'0{i}.png' if i <= 1 else f'0{i}.jpg' if i <= 4 else f'0{i}.jpg'
example_path = os.path.join(example_dir, img_name)
if os.path.exists(example_path):
fidelity = 0.7 if i <= 3 else 0.1
examples.append([example_path, True, True, 2, fidelity])

demo = gr.Interface(
inference, [
gr.inputs.Image(type="filepath", label="Input"),
gr.inputs.Checkbox(default=True, label="Background_Enhance"),
gr.inputs.Checkbox(default=True, label="Face_Upsample"),
gr.inputs.Number(default=2, label="Rescaling_Factor (up to 4)"),
gr.Slider(0, 1, value=0.5, step=0.01, label='Codeformer_Fidelity (0 for better quality, 1 for better identity)')
gr.Image(type="filepath", label="Input Image"),
gr.Checkbox(value=True, label="Background Enhancement"),
gr.Checkbox(value=True, label="Face Upsampling"),
gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Upscale Factor"),
gr.Slider(minimum=0, maximum=1, value=0.5, step=0.01, label='Fidelity (0=high quality, 1=preserve identity)')
], [
gr.outputs.Image(type="numpy", label="Output"),
gr.outputs.File(label="Download the output")
gr.Image(type="numpy", label="Output Image"),
gr.File(label="Download Result")
],
title=title,
description=description,
article=article,
examples=[
['01.png', True, True, 2, 0.7],
['02.jpg', True, True, 2, 0.7],
['03.jpg', True, True, 2, 0.7],
['04.jpg', True, True, 2, 0.1],
['05.jpg', True, True, 2, 0.1]
]
)
article=article,
examples=examples if examples else None
)

print("\n" + "="*60)
print("CodeFormer UI is starting...")
print("="*60)
print(f"Output directory: {output_dir}")
print(f"Example directory: {example_dir}")
print("="*60 + "\n")

# New Gradio API
try:
demo.queue(max_threads=10)
except TypeError:
# try old API
try:
demo.queue()
except DeprecationWarning:
# ignore
pass

demo.queue(concurrency_count=2)
demo.launch()