372 changes: 189 additions & 183 deletions nodes.py
@@ -11,6 +11,7 @@
import dill
import yaml
from ultralytics import YOLO
import functools

current_file_path = os.path.abspath(__file__)
current_directory = os.path.dirname(current_file_path)
@@ -27,6 +28,7 @@
from collections import OrderedDict

cur_device = None
@functools.lru_cache
def get_device():
global cur_device
if cur_device == None:
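
The hunk above is cut off by the diff view, but the purpose of the added @functools.lru_cache decorator is visible: get_device() is memoized, so its body (including the cur_device global check) only runs on the first call and every later call gets the cached result. A minimal sketch of that pattern, assuming a conventional CUDA/MPS/CPU fallback for the truncated body:

import functools

import torch

@functools.lru_cache
def get_device() -> torch.device:
    # Resolve the compute device once; lru_cache returns the same
    # torch.device object on every subsequent call.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")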
@@ -725,110 +727,112 @@ def parsing_command(self, command, motoin_link):

def run(self, retargeting_eyes, retargeting_mouth, turn_on, tracking_src_vid, animate_without_vid, command, crop_factor,
src_images=None, driving_images=None, motion_link=None):
if turn_on == False: return (None,None)
src_length = 1

if src_images == None:
if motion_link != None:
self.psi_list = [motion_link[0]]
else: return (None,None)

if src_images != None:
src_length = len(src_images)
if id(src_images) != id(self.src_images) or self.crop_factor != crop_factor:
self.crop_factor = crop_factor
self.src_images = src_images
if 1 < src_length:
self.psi_list = g_engine.prepare_source(src_images, crop_factor, True, tracking_src_vid)
else:
self.psi_list = [g_engine.prepare_source(src_images, crop_factor)]


cmd_list, cmd_length = self.parsing_command(command, motion_link)
if cmd_list == None: return (None,None)
cmd_idx = 0

driving_length = 0
if driving_images is not None:
if id(driving_images) != id(self.driving_images):
self.driving_images = driving_images
self.driving_values = g_engine.prepare_driving_video(driving_images)
driving_length = len(self.driving_values)

total_length = max(driving_length, src_length)

if animate_without_vid:
total_length = max(total_length, cmd_length)

c_i_es = ExpressionSet()
c_o_es = ExpressionSet()
d_0_es = None
out_list = []

psi = None
pipeline = g_engine.get_pipeline()
for i in range(total_length):

if i < src_length:
psi = self.psi_list[i]
s_info = psi.x_s_info
s_es = ExpressionSet(erst=(s_info['kp'] + s_info['exp'], torch.Tensor([0, 0, 0]), s_info['scale'], s_info['t']))

new_es = ExpressionSet(es = s_es)

if i < cmd_length:
cmd = cmd_list[cmd_idx]
if 0 < cmd.change:
cmd.change -= 1
c_i_es.add(cmd.es)
c_i_es.sub(c_o_es)
elif 0 < cmd.keep:
cmd.keep -= 1

new_es.add(c_i_es)

if cmd.change == 0 and cmd.keep == 0:
cmd_idx += 1
if cmd_idx < len(cmd_list):
c_o_es = ExpressionSet(es = c_i_es)
cmd = cmd_list[cmd_idx]
c_o_es.div(cmd.change)
elif 0 < cmd_length:
new_es.add(c_i_es)

if i < driving_length:
d_i_info = self.driving_values[i]
d_i_r = torch.Tensor([d_i_info['pitch'], d_i_info['yaw'], d_i_info['roll']])#.float().to(device="cuda:0")

if d_0_es is None:
d_0_es = ExpressionSet(erst = (d_i_info['exp'], d_i_r, d_i_info['scale'], d_i_info['t']))

retargeting(s_es.e, d_0_es.e, retargeting_eyes, (11, 13, 15, 16))
retargeting(s_es.e, d_0_es.e, retargeting_mouth, (14, 17, 19, 20))

new_es.e += d_i_info['exp'] - d_0_es.e
new_es.r += d_i_r - d_0_es.r
new_es.t += d_i_info['t'] - d_0_es.t

r_new = get_rotation_matrix(
s_info['pitch'] + new_es.r[0], s_info['yaw'] + new_es.r[1], s_info['roll'] + new_es.r[2])
d_new = new_es.s * (new_es.e @ r_new) + new_es.t
d_new = pipeline.stitching(psi.x_s_user, d_new)
crop_out = pipeline.warp_decode(psi.f_s_user, psi.x_s_user, d_new)
crop_out = pipeline.parse_output(crop_out['out'])[0]

crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb),
cv2.INTER_LINEAR)
out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(
np.uint8)
out_list.append(out)

self.pbar.update_absolute(i+1, total_length, ("PNG", Image.fromarray(crop_out), None))

if len(out_list) == 0: return (None,)

out_imgs = torch.cat([pil2tensor(img_rgb) for img_rgb in out_list])
return (out_imgs,)
with torch.autocast(device_type=get_device().type, enabled=get_device().type == "cuda"):

if turn_on == False: return (None,None)
src_length = 1

if src_images == None:
if motion_link != None:
self.psi_list = [motion_link[0]]
else: return (None,None)

if src_images != None:
src_length = len(src_images)
if id(src_images) != id(self.src_images) or self.crop_factor != crop_factor:
self.crop_factor = crop_factor
self.src_images = src_images
if 1 < src_length:
self.psi_list = g_engine.prepare_source(src_images, crop_factor, True, tracking_src_vid)
else:
self.psi_list = [g_engine.prepare_source(src_images, crop_factor)]


cmd_list, cmd_length = self.parsing_command(command, motion_link)
if cmd_list == None: return (None,None)
cmd_idx = 0

driving_length = 0
if driving_images is not None:
if id(driving_images) != id(self.driving_images):
self.driving_images = driving_images
self.driving_values = g_engine.prepare_driving_video(driving_images)
driving_length = len(self.driving_values)

total_length = max(driving_length, src_length)

if animate_without_vid:
total_length = max(total_length, cmd_length)

c_i_es = ExpressionSet()
c_o_es = ExpressionSet()
d_0_es = None
out_list = []

psi = None
pipeline = g_engine.get_pipeline()
for i in range(total_length):

if i < src_length:
psi = self.psi_list[i]
s_info = psi.x_s_info
s_es = ExpressionSet(erst=(s_info['kp'] + s_info['exp'], torch.Tensor([0, 0, 0]), s_info['scale'], s_info['t']))

new_es = ExpressionSet(es = s_es)

if i < cmd_length:
cmd = cmd_list[cmd_idx]
if 0 < cmd.change:
cmd.change -= 1
c_i_es.add(cmd.es)
c_i_es.sub(c_o_es)
elif 0 < cmd.keep:
cmd.keep -= 1

new_es.add(c_i_es)

if cmd.change == 0 and cmd.keep == 0:
cmd_idx += 1
if cmd_idx < len(cmd_list):
c_o_es = ExpressionSet(es = c_i_es)
cmd = cmd_list[cmd_idx]
c_o_es.div(cmd.change)
elif 0 < cmd_length:
new_es.add(c_i_es)

if i < driving_length:
d_i_info = self.driving_values[i]
d_i_r = torch.Tensor([d_i_info['pitch'], d_i_info['yaw'], d_i_info['roll']])

if d_0_es is None:
d_0_es = ExpressionSet(erst = (d_i_info['exp'], d_i_r, d_i_info['scale'], d_i_info['t']))

retargeting(s_es.e, d_0_es.e, retargeting_eyes, (11, 13, 15, 16))
retargeting(s_es.e, d_0_es.e, retargeting_mouth, (14, 17, 19, 20))

new_es.e += d_i_info['exp'] - d_0_es.e
new_es.r += d_i_r - d_0_es.r
new_es.t += d_i_info['t'] - d_0_es.t

r_new = get_rotation_matrix(
s_info['pitch'] + new_es.r[0], s_info['yaw'] + new_es.r[1], s_info['roll'] + new_es.r[2])
d_new = new_es.s * (new_es.e @ r_new) + new_es.t
d_new = pipeline.stitching(psi.x_s_user, d_new)
crop_out = pipeline.warp_decode(psi.f_s_user, psi.x_s_user, d_new)
crop_out = pipeline.parse_output(crop_out['out'])[0]

crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb),
cv2.INTER_LINEAR)
out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(
np.uint8)
out_list.append(out)

self.pbar.update_absolute(i+1, total_length, ("PNG", Image.fromarray(crop_out), None))

if len(out_list) == 0: return (None,)

out_imgs = torch.cat([pil2tensor(img_rgb) for img_rgb in out_list])
return (out_imgs,)
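
The substantive change in this hunk is that the whole body of run now executes inside a torch.autocast context whose device type comes from the cached get_device(), and which is enabled only on CUDA so CPU and MPS execution keeps full precision. A minimal sketch of that wrapper in isolation, where heavy_inference is a hypothetical stand-in for the stitching / warp_decode pipeline calls:

import torch

def heavy_inference(frame: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the per-frame pipeline work.
    return frame * 2.0

def run_frames(frames: list[torch.Tensor]) -> list[torch.Tensor]:
    device = get_device()  # cached torch.device from the memoized helper above
    # enabled=False turns the context into a no-op on CPU/MPS, matching the diff.
    with torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
        return [heavy_inference(f.to(device)) for f in frames]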

class ExpressionEditor:
def __init__(self):
@@ -883,85 +887,87 @@ def INPUT_TYPES(s):

def run(self, rotate_pitch, rotate_yaw, rotate_roll, blink, eyebrow, wink, pupil_x, pupil_y, aaa, eee, woo, smile,
src_ratio, sample_ratio, sample_parts, crop_factor, src_image=None, sample_image=None, motion_link=None, add_exp=None):
rotate_yaw = -rotate_yaw

new_editor_link = None
if motion_link != None:
self.psi = motion_link[0]
new_editor_link = motion_link.copy()
elif src_image != None:
if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor:
self.crop_factor = crop_factor
self.psi = g_engine.prepare_source(src_image, crop_factor)
self.src_image = src_image
new_editor_link = []
new_editor_link.append(self.psi)
else:
return (None,None)

pipeline = g_engine.get_pipeline()

psi = self.psi
s_info = psi.x_s_info
#delta_new = copy.deepcopy()
s_exp = s_info['exp'] * src_ratio
s_exp[0, 5] = s_info['exp'][0, 5]
s_exp += s_info['kp']

es = ExpressionSet()

if sample_image != None:
if id(self.sample_image) != id(sample_image):
self.sample_image = sample_image
d_image_np = (sample_image * 255).byte().numpy()
d_face = g_engine.crop_face(d_image_np[0], 1.7)
i_d = g_engine.prepare_src_image(d_face)
self.d_info = pipeline.get_kp_info(i_d)
self.d_info['exp'][0, 5, 0] = 0
self.d_info['exp'][0, 5, 1] = 0

# "OnlyExpression", "OnlyRotation", "OnlyMouth", "OnlyEyes", "All"
if sample_parts == "OnlyExpression" or sample_parts == "All":
es.e += self.d_info['exp'] * sample_ratio
if sample_parts == "OnlyRotation" or sample_parts == "All":
rotate_pitch += self.d_info['pitch'] * sample_ratio
rotate_yaw += self.d_info['yaw'] * sample_ratio
rotate_roll += self.d_info['roll'] * sample_ratio
elif sample_parts == "OnlyMouth":
retargeting(es.e, self.d_info['exp'], sample_ratio, (14, 17, 19, 20))
elif sample_parts == "OnlyEyes":
retargeting(es.e, self.d_info['exp'], sample_ratio, (1, 2, 11, 13, 15, 16))

es.r = g_engine.calc_fe(es.e, blink, eyebrow, wink, pupil_x, pupil_y, aaa, eee, woo, smile,
rotate_pitch, rotate_yaw, rotate_roll)

if add_exp != None:
es.add(add_exp)

new_rotate = get_rotation_matrix(s_info['pitch'] + es.r[0], s_info['yaw'] + es.r[1],
s_info['roll'] + es.r[2])
x_d_new = (s_info['scale'] * (1 + es.s)) * ((s_exp + es.e) @ new_rotate) + s_info['t']

x_d_new = pipeline.stitching(psi.x_s_user, x_d_new)

crop_out = pipeline.warp_decode(psi.f_s_user, psi.x_s_user, x_d_new)
crop_out = pipeline.parse_output(crop_out['out'])[0]

crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb), cv2.INTER_LINEAR)
out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(np.uint8)

out_img = pil2tensor(out)

filename = g_engine.get_temp_img_name() #"fe_edit_preview.png"
folder_paths.get_save_image_path(filename, folder_paths.get_temp_directory())
img = Image.fromarray(crop_out)
img.save(os.path.join(folder_paths.get_temp_directory(), filename), compress_level=1)
results = list()
results.append({"filename": filename, "type": "temp"})

new_editor_link.append(es)

return {"ui": {"images": results}, "result": (out_img, new_editor_link, es)}
with torch.autocast(device_type=get_device().type, enabled=get_device().type == "cuda"):

rotate_yaw = -rotate_yaw

new_editor_link = None
if motion_link != None:
self.psi = motion_link[0]
new_editor_link = motion_link.copy()
elif src_image != None:
if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor:
self.crop_factor = crop_factor
self.psi = g_engine.prepare_source(src_image, crop_factor)
self.src_image = src_image
new_editor_link = []
new_editor_link.append(self.psi)
else:
return (None,None)

pipeline = g_engine.get_pipeline()

psi = self.psi
s_info = psi.x_s_info
#delta_new = copy.deepcopy()
s_exp = s_info['exp'] * src_ratio
s_exp[0, 5] = s_info['exp'][0, 5]
s_exp += s_info['kp']

es = ExpressionSet()

if sample_image != None:
if id(self.sample_image) != id(sample_image):
self.sample_image = sample_image
d_image_np = (sample_image * 255).byte().numpy()
d_face = g_engine.crop_face(d_image_np[0], 1.7)
i_d = g_engine.prepare_src_image(d_face)
self.d_info = pipeline.get_kp_info(i_d)
self.d_info['exp'][0, 5, 0] = 0
self.d_info['exp'][0, 5, 1] = 0

# "OnlyExpression", "OnlyRotation", "OnlyMouth", "OnlyEyes", "All"
if sample_parts == "OnlyExpression" or sample_parts == "All":
es.e += self.d_info['exp'] * sample_ratio
if sample_parts == "OnlyRotation" or sample_parts == "All":
rotate_pitch += self.d_info['pitch'] * sample_ratio
rotate_yaw += self.d_info['yaw'] * sample_ratio
rotate_roll += self.d_info['roll'] * sample_ratio
elif sample_parts == "OnlyMouth":
retargeting(es.e, self.d_info['exp'], sample_ratio, (14, 17, 19, 20))
elif sample_parts == "OnlyEyes":
retargeting(es.e, self.d_info['exp'], sample_ratio, (1, 2, 11, 13, 15, 16))

es.r = g_engine.calc_fe(es.e, blink, eyebrow, wink, pupil_x, pupil_y, aaa, eee, woo, smile,
rotate_pitch, rotate_yaw, rotate_roll)

if add_exp != None:
es.add(add_exp)

new_rotate = get_rotation_matrix(s_info['pitch'] + es.r[0], s_info['yaw'] + es.r[1],
s_info['roll'] + es.r[2])
x_d_new = (s_info['scale'] * (1 + es.s)) * ((s_exp + es.e) @ new_rotate) + s_info['t']

x_d_new = pipeline.stitching(psi.x_s_user, x_d_new)

crop_out = pipeline.warp_decode(psi.f_s_user, psi.x_s_user, x_d_new)
crop_out = pipeline.parse_output(crop_out['out'])[0]

crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb), cv2.INTER_LINEAR)
out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(np.uint8)

out_img = pil2tensor(out)

filename = g_engine.get_temp_img_name() #"fe_edit_preview.png"
folder_paths.get_save_image_path(filename, folder_paths.get_temp_directory())
img = Image.fromarray(crop_out)
img.save(os.path.join(folder_paths.get_temp_directory(), filename), compress_level=1)
results = list()
results.append({"filename": filename, "type": "temp"})

new_editor_link.append(es)

return {"ui": {"images": results}, "result": (out_img, new_editor_link, es)}

NODE_CLASS_MAPPINGS = {
"AdvancedLivePortrait": AdvancedLivePortrait,