diff --git a/web-demos/hugging_face/app.py b/web-demos/hugging_face/app.py index c614e7c8..5037288d 100644 --- a/web-demos/hugging_face/app.py +++ b/web-demos/hugging_face/app.py @@ -4,8 +4,13 @@ """ import sys -sys.path.append('CodeFormer') import os + +# root path +script_dir = os.path.dirname(os.path.abspath(__file__)) # web-demos/hugging_face +project_root = os.path.dirname(os.path.dirname(script_dir)) +sys.path.insert(0, project_root) + import cv2 import torch import torch.nn.functional as F @@ -23,8 +28,14 @@ from facelib.utils.face_restoration_helper import FaceRestoreHelper from facelib.utils.misc import is_gray +# working directory & output directory +ui_dir = os.path.dirname(os.path.abspath(__file__)) +output_dir = os.path.join(ui_dir, 'output') +example_dir = os.path.join(ui_dir, 'examples') +os.makedirs(output_dir, exist_ok=True) +os.makedirs(example_dir, exist_ok=True) -os.system("pip freeze") +weights_dir = os.path.join(project_root, 'weights') pretrain_model_url = { 'codeformer': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth', @@ -32,32 +43,45 @@ 'parsing': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth', 'realesrgan': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth' } -# download weights -if not os.path.exists('CodeFormer/weights/CodeFormer/codeformer.pth'): - load_file_from_url(url=pretrain_model_url['codeformer'], model_dir='CodeFormer/weights/CodeFormer', progress=True, file_name=None) -if not os.path.exists('CodeFormer/weights/facelib/detection_Resnet50_Final.pth'): - load_file_from_url(url=pretrain_model_url['detection'], model_dir='CodeFormer/weights/facelib', progress=True, file_name=None) -if not os.path.exists('CodeFormer/weights/facelib/parsing_parsenet.pth'): - load_file_from_url(url=pretrain_model_url['parsing'], model_dir='CodeFormer/weights/facelib', progress=True, file_name=None) -if not os.path.exists('CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth'): - load_file_from_url(url=pretrain_model_url['realesrgan'], model_dir='CodeFormer/weights/realesrgan', progress=True, file_name=None) - -# download images -torch.hub.download_url_to_file( - 'https://replicate.com/api/models/sczhou/codeformer/files/fa3fe3d1-76b0-4ca8-ac0d-0a925cb0ff54/06.png', - '01.png') -torch.hub.download_url_to_file( - 'https://replicate.com/api/models/sczhou/codeformer/files/a1daba8e-af14-4b00-86a4-69cec9619b53/04.jpg', - '02.jpg') -torch.hub.download_url_to_file( - 'https://replicate.com/api/models/sczhou/codeformer/files/542d64f9-1712-4de7-85f7-3863009a7c3d/03.jpg', - '03.jpg') -torch.hub.download_url_to_file( - 'https://replicate.com/api/models/sczhou/codeformer/files/a11098b0-a18a-4c02-a19a-9a7045d68426/010.jpg', - '04.jpg') -torch.hub.download_url_to_file( - 'https://replicate.com/api/models/sczhou/codeformer/files/7cf19c2c-e0cf-4712-9af8-cf5bdbb8d0ee/012.jpg', - '05.jpg') + +# download weights (check in project weights directory) +print("Checking pre-trained models...") +weights_info = { + 'codeformer': (os.path.join(weights_dir, 'CodeFormer/codeformer.pth'), 'CodeFormer'), + 'detection': (os.path.join(weights_dir, 'facelib/detection_Resnet50_Final.pth'), 'detection'), + 'parsing': (os.path.join(weights_dir, 'facelib/parsing_parsenet.pth'), 'parsing'), + 'realesrgan': (os.path.join(weights_dir, 'realesrgan/RealESRGAN_x2plus.pth'), 'realesrgan'), +} + +for model_name, (model_path, model_dir) in weights_info.items(): + if not os.path.exists(model_path): + print(f"Downloading {model_name}...") + load_file_from_url(url=pretrain_model_url[model_name], + model_dir=os.path.dirname(model_path), + progress=True, file_name=None) + else: + print(f"✓ Found {model_name}") + +# download example images (to examples directory) +print("Checking example images...") +example_files = [ + ('https://replicate.com/api/models/sczhou/codeformer/files/fa3fe3d1-76b0-4ca8-ac0d-0a925cb0ff54/06.png', '01.png'), + ('https://replicate.com/api/models/sczhou/codeformer/files/a1daba8e-af14-4b00-86a4-69cec9619b53/04.jpg', '02.jpg'), + ('https://replicate.com/api/models/sczhou/codeformer/files/542d64f9-1712-4de7-85f7-3863009a7c3d/03.jpg', '03.jpg'), + ('https://replicate.com/api/models/sczhou/codeformer/files/a11098b0-a18a-4c02-a19a-9a7045d68426/010.jpg', '04.jpg'), + ('https://replicate.com/api/models/sczhou/codeformer/files/7cf19c2c-e0cf-4712-9af8-cf5bdbb8d0ee/012.jpg', '05.jpg'), +] + +for url, filename in example_files: + example_path = os.path.join(example_dir, filename) + if not os.path.exists(example_path): + print(f"Downloading example {filename}...") + try: + torch.hub.download_url_to_file(url, example_path) + except Exception as e: + print(f"Failed to download {filename}: {e}") + else: + print(f"✓ Found example {filename}") def imread(img_path): img = cv2.imread(img_path) @@ -76,9 +100,10 @@ def set_realesrgan(): num_grow_ch=32, scale=2, ) + realesrgan_path = os.path.join(weights_dir, "realesrgan/RealESRGAN_x2plus.pth") upsampler = RealESRGANer( scale=2, - model_path="CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth", + model_path=realesrgan_path, model=model, tile=400, tile_pad=40, @@ -90,6 +115,7 @@ def set_realesrgan(): upsampler = set_realesrgan() # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = get_device() +print("Loading CodeFormer model...") codeformer_net = ARCH_REGISTRY.get("CodeFormer")( dim_embd=512, codebook_size=1024, @@ -97,12 +123,11 @@ def set_realesrgan(): n_layers=9, connect_list=["32", "64", "128", "256"], ).to(device) -ckpt_path = "CodeFormer/weights/CodeFormer/codeformer.pth" -checkpoint = torch.load(ckpt_path)["params_ema"] +ckpt_path = os.path.join(weights_dir, "CodeFormer/codeformer.pth") +checkpoint = torch.load(ckpt_path, map_location=device)["params_ema"] codeformer_net.load_state_dict(checkpoint) codeformer_net.eval() - -os.makedirs('output', exist_ok=True) +print("CodeFormer model loaded successfully") def inference(image, background_enhance, face_upsample, upscale, codeformer_fidelity): """Run a single prediction on the model""" @@ -204,7 +229,7 @@ def inference(image, background_enhance, face_upsample, upscale, codeformer_fide ) # save restored img - save_path = f'output/out.png' + save_path = os.path.join(output_dir, 'out.png') imwrite(restored_img, str(save_path)) restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB) @@ -256,28 +281,48 @@ def inference(image, background_enhance, face_upsample, upscale, codeformer_fide
visitors
""" +# load examples +examples = [] +for i in range(1, 6): + img_name = f'0{i}.png' if i <= 1 else f'0{i}.jpg' if i <= 4 else f'0{i}.jpg' + example_path = os.path.join(example_dir, img_name) + if os.path.exists(example_path): + fidelity = 0.7 if i <= 3 else 0.1 + examples.append([example_path, True, True, 2, fidelity]) + demo = gr.Interface( inference, [ - gr.inputs.Image(type="filepath", label="Input"), - gr.inputs.Checkbox(default=True, label="Background_Enhance"), - gr.inputs.Checkbox(default=True, label="Face_Upsample"), - gr.inputs.Number(default=2, label="Rescaling_Factor (up to 4)"), - gr.Slider(0, 1, value=0.5, step=0.01, label='Codeformer_Fidelity (0 for better quality, 1 for better identity)') + gr.Image(type="filepath", label="Input Image"), + gr.Checkbox(value=True, label="Background Enhancement"), + gr.Checkbox(value=True, label="Face Upsampling"), + gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Upscale Factor"), + gr.Slider(minimum=0, maximum=1, value=0.5, step=0.01, label='Fidelity (0=high quality, 1=preserve identity)') ], [ - gr.outputs.Image(type="numpy", label="Output"), - gr.outputs.File(label="Download the output") + gr.Image(type="numpy", label="Output Image"), + gr.File(label="Download Result") ], title=title, description=description, - article=article, - examples=[ - ['01.png', True, True, 2, 0.7], - ['02.jpg', True, True, 2, 0.7], - ['03.jpg', True, True, 2, 0.7], - ['04.jpg', True, True, 2, 0.1], - ['05.jpg', True, True, 2, 0.1] - ] - ) + article=article, + examples=examples if examples else None +) + +print("\n" + "="*60) +print("CodeFormer UI is starting...") +print("="*60) +print(f"Output directory: {output_dir}") +print(f"Example directory: {example_dir}") +print("="*60 + "\n") + +# New Gradio API +try: + demo.queue(max_threads=10) +except TypeError: + # try old API + try: + demo.queue() + except DeprecationWarning: + # ignore + pass -demo.queue(concurrency_count=2) demo.launch() \ No newline at end of file