AI驱动的图像处理管道

你是否看到过那些令人惊叹的 AI 驱动的图像编辑工具随处可见?比如,想象一下从照片中删除前任或将枯燥的背景换成热带海滩——只需点击几下即可!我对 AI 在多媒体领域的潜力非常着迷,而这么多令人惊叹的工具都是开源的这一事实更是酷毙了。这意味着我们都可以进行实验和创造,而无需花费太多钱!”

因此,我一直在尝试一些很棒的开源模型,例如 SAM、LaMa 和 Stable Diffusion,我想与你分享其中的乐趣。在这篇博客中,我们将一起构建一个小型图像编辑管道。可以把它想象成一个由 AI 驱动的 DIY Photoshop。我们将学习如何遮罩物体,让它们消失(噗!),用 AI 生成的精彩内容填补这些空白(也许是独角兽而不是前任?),然后将所有内容无缝融合以获得自然的外观。不需要花哨的软件或昂贵的订阅——只需要开源 AI 的魔力!让我们发挥创意吧!

1、使用 SAM 进行对象选择和蒙版

SAM 是一种强大的图像分割模型,用于为对象创建蒙版。它以图像和提示(点、框)作为输入,并输出详细的分割蒙版。其主要功能包括可提示性、支持交互式对象选择和零样本泛化,允许分割看不见的对象。这是通过对超过 10 亿个蒙版的庞大数据集进行训练实现的。从架构上讲,SAM 通常使用基于 ViT 的图像编码器来提取特征,然后使用轻量级解码器来生成蒙版。在我们的工作流程中,SAM 对于隔离要移除或替换的对象至关重要,为修复和混合步骤奠定了基础。它生成准确蒙版的能力对于以后实现逼真的图像处理至关重要。

我们将从 Unsplash 中的这张图片开始。我们首先使用 initialize_sam_model 加载预先训练的 SAM 模型。然后,调用 load_image 来加载输入图像。遮罩过程的核心发生在 generate_mask中,其中SAM模型根据提供的输入点预测遮罩。为了改进这个初始遮罩, expand_and_feather_mask应用了扩张和模糊以实现更平滑的羽化边缘。最后,保存生成的遮罩以供以后使用。这一系列函数利用SAM的功能为图像处理任务创建准确而精致的对象遮罩。

!pip install -q opencv-python matplotlib segment-anything jupyter_bbox_widget

import os
import urllib.request
import cv2
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import SamPredictor, sam_model_registry
import requests
from io import BytesIO
from PIL import Image
import base64
import torch
from jupyter_bbox_widget import BBoxWidget

# Encode the image to be used by BBOX - visual identification of coordinates
def encode_image(filepath):
    with open(filepath, 'rb') as f:
        image_bytes = f.read()
    encoded = str(base64.b64encode(image_bytes), 'utf-8')
    return "data:image/jpg;base64,"+encoded

# Crop the image so that the images aligns with the standard 512 size
def crop_image(image_path, x, y, crop_size=512):
    image = cv2.imread(image_path)
    height, width, _ = image.shape
    start_x = max(0, x - crop_size // 2)
    start_y = max(0, y - crop_size // 2)
    end_x = min(width, start_x + crop_size)
    end_y = min(height, start_y + crop_size)
    start_x = max(0, end_x - crop_size)
    start_y = max(0, end_y - crop_size)
    cropped_image = image[start_y:end_y, start_x:end_x]
    return cropped_image

# Initialize the SAM model (you can change the model type if needed)
def initialize_sam_model(model_type="vit_h"):
    sam_checkpoint = "/content/models/sam_vit_h_4b8939.pth"  
    sam_model = sam_model_registry[model_type](checkpoint=sam_checkpoint).to("cuda")
    return SamPredictor(sam_model)

# Load an image
def load_image(image_path):
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image

# Generate masks based on input points
def generate_mask(image, predictor, points):
    predictor.set_image(image)
    input_points = np.array(points)  
    input_labels = np.ones(len(points)) 
    masks, scores, _ = predictor.predict(point_coords=input_points, point_labels=input_labels, multimask_output=True)
    best_mask = masks[np.argmax(scores)]
    return best_mask

# Expand and feather the mask
def expand_and_feather_mask(mask, dilation_iterations=10, blur_kernel_size=21):
    kernel = np.ones((3, 3), np.uint8)  
    expanded_mask = cv2.dilate(mask.astype(np.uint8), kernel, iterations=dilation_iterations)
    feathered_mask = cv2.GaussianBlur(expanded_mask.astype(np.float32), (blur_kernel_size, blur_kernel_size), 0)
    feathered_mask = np.clip(feathered_mask, 0, 1)
    return feathered_mask

# Save the mask as an image
def save_mask(mask, output_path):
    cv2.imwrite(output_path, (mask * 255).astype(np.uint8))


os.makedirs("models", exist_ok = True)
!wget -c  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -O models/sam_vit_h_4b8939.pth

os.makedirs("input", exist_ok = True)
downloaded_image = "input/raw_image1.png"
url = "https://unsplash.com/photos/N5H3CL-AZjs/download?ixid=M3wxMjA3fDB8MXxzZWFyY2h8MTZ8fGh1bWFucyUyMGluJTIwbmF0dXJlfGVufDB8MHx8fDE3MzMxNDExOTJ8Mg&force=true&w=640"
urllib.request.urlretrieve(url, downloaded_image)

bbox_widget = BBoxWidget(image= encode_image(downloaded_image))
bbox_widget # Reveals the image with a clickable interface to identify the coordinates
bbox_widget.bboxes # -> [{'x': 275, 'y': 248, 'width': 0, 'height': 0, 'label': ''}]

x, y = 275, 248  
cropped_image = crop_image(downloaded_image, x, y)
cv2.imwrite('input/image1.png', cropped_image)
input_image = "input/image1.png"

predictor = initialize_sam_model()
image_path = input_image
image = load_image(image_path)
points = [(275, 248)]

#Create a mask that sticks to the target image boundaries
mask = generate_mask(image, predictor, points)
save_mask(mask, "/content/input/image1_mask001.png")

# Create two more masks for choice that has expanded area around the object boundary
feathered_mask1 = expand_and_feather_mask(mask)
feathered_mask2 = expand_and_feather_mask(mask, dilation_iterations=15, blur_kernel_size=31)
save_mask(feathered_mask1, "/content/input/image1_mask001_fm1.png")
save_mask(feathered_mask2, "/content/input/image1_mask001_fm2.png")

plt.figure(figsize=(24, 8))
plt.subplot(1, 3, 1)
plt.imshow(Image.open("/content/input/image1.png"))
plt.imshow(Image.open("/content/input/image1_mask001.png"), alpha=0.5, cmap="jet")
plt.axis("off")
plt.subplot(1, 3, 2)
plt.imshow(Image.open("/content/input/image1.png"))
plt.imshow(Image.open("/content/input/image1_mask001_fm1.png"), alpha=0.5, cmap="jet")
plt.axis("off")
plt.subplot(1, 3, 3)
plt.imshow(Image.open("/content/input/image1.png"))
plt.imshow(Image.open("/content/input/image1_mask001_fm2.png"), alpha=0.5, cmap="jet")
plt.axis("off")
plt.show()

我们现在有了蒙版!

2、使用 LaMa 进行图像修复

LaMa 是一种用于填充蒙版区域(已移除对象)的图像修复模型。它擅长处理大型蒙版并使用快速傅里叶卷积 (FFC) 保留图像分辨率,从而允许在修复过程中提供更广泛的图像上下文。LaMa 经过各种蒙版图像的训练,可以很好地推广到各种修复任务。其架构通常涉及编码器-解码器网络,解码器中带有 FFC。在我们的案例中,LaMa 无缝移除蒙版对象并填充区域,为混合做准备。此功能是在整个图像处理工作流程中实现自然合成的关键。

现在我们将研究如何使用 LaMa 修复模型从图像中移除蒙版对象并用合理的内容填充区域。我们使用 SimpleLama 类初始化修复过程。然后使用 PIL 库中的 Image.open 函数加载输入图像和蒙版。然后调用 simple_lama 执行修复,将原始图像和蒙版作为输入,重建缺失或被遮罩的区域。使用 result.save 保存最终修复的图像,生成输出图像,其中目标对象被移除并替换为背景内容。本质上,本节使用 LaMa 智能填充被遮罩的区域,从而无缝地从图像中移除所选对象。

!pip install simple-lama-inpainting

from simple_lama_inpainting import SimpleLama
from PIL import Image

simple_lama = SimpleLama()

img_path = "/content/input/image1.png"
mask_path = "/content/input/image1_mask001_fm1.png"

image = Image.open(img_path)
mask = Image.open(mask_path).convert('L')

os.makedirs("output", exist_ok="True")

result = simple_lama(image, mask)
result.save("output/inpainted.png")

plt.figure(figsize=(16, 8))
plt.subplot(1, 3, 1)
plt.imshow(image)
plt.axis("off")
plt.subplot(1, 3, 2)
plt.imshow(mask)
plt.axis("off")
plt.subplot(1, 3, 3)
plt.imshow(result)
plt.axis("off")
plt.show()

我们在这里得到了几乎令人满意的输出!

3、使用稳定扩散填充蒙版区域

稳定扩散是一种文本到图像模型,用于根据文本提示用新内容填充蒙版区域。它采用文本描述和蒙版图像,并在蒙版区域内生成新的图像内容。使用扩散过程,它在输入提示的引导下逐渐将图像转换为噪声并转回。其架构通常包括 VAE、文本编码器和基于 U-Net 的扩散模型。稳定扩散经过大量图像文本数据训练,创造性地填充蒙版区域,为图像处理增加了新的维度。

我们将使用稳定扩散以文本提示为指导填充图像的蒙版区域。我们通过使用 StableDiffusionInpaintPipeline.from_pretrained 加载预训练模型来初始化该过程,具体模型路径为 runwayml/stable-diffusion-inpainting。原始图像和掩码

使用 Image.open 加载已编辑的图像。然后使用 pipe 调用管道,使用文本提示来指导在掩码区域内生成新的图像内容,由参数(如 guide_scalegenerator)控制。然后存储生成的图像以供可视化。本节旨在根据用户指定的文本提示创造性地用新内容填充掩码区域,利用 Stable Diffusion 的强大功能。

我在这里使用的 runwayml/stable-diffusion-inpainting模型是 Stable Diffusion 的一个相对轻量级的版本,之所以选择它,是因为它的内存要求更低、推理时间更快,适合 Colab 环境中 GPU 资源有限的用户。但是,可以使用更强大的 GPU 的用户可以利用更大、更复杂的修复模型(如 stableai/stable-diffusion-2-inpainting),从而有可能提高图像质量和更精细的细节。这允许根据单个硬件功能和期望结果实现灵活性。

!pip install -q diffusers accelerate

from diffusers import StableDiffusionInpaintPipeline

device = "cuda"
model_path = "runwayml/stable-diffusion-inpainting"

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
).to(device)


org_image = Image.open("/content/input/image1.png")
mask_image = Image.open("/content/input/image1_mask001_fm2.png")

prompt = "A big tiger walking through the trees looking at the camera"

guidance_scale=7.5
num_samples = 3
generator = torch.Generator(device="cuda").manual_seed(0) # change the seed to get different results

images = pipe(
    prompt=prompt,
    image=image,
    mask_image=mask_image,
    guidance_scale=guidance_scale,
    generator=generator,
    num_images_per_prompt=num_samples,
).images

plt.figure(figsize=(36, 12))
plt.subplot(1, 4, 1)
plt.imshow(org_image)
plt.axis("off")
plt.subplot(1, 4, 2)
plt.imshow(images[0])
plt.axis("off")
plt.subplot(1, 4, 3)
plt.imshow(images[1])
plt.axis("off")
plt.subplot(1, 4, 4)
plt.imshow(images[2])
plt.axis("off")
plt.show()

同样,对于小模型,我们得到了令人满意的结果。我非常确定 FLUX 模型将产生出色的结果!

4、使用 OpenCV 进行无缝混合

OpenCV (cv2) 是一个用于计算机视觉任务的综合库,主要用于图像混合和处理。它提供了广泛的图像处理功能,包括读取、写入、调整大小、颜色转换、过滤、特征检测和对象跟踪。在我们的案例中,OpenCV 对于将前面步骤的输出(使用 LaMa 进行修复、使用 Stable Diffusion 进行内容生成)无缝集成到目标图像中至关重要。它利用泊松混合 ( seamlessClone) 等技术,通过基于蒙版平滑合并源图像和目标图像来创建逼真的合成图像。OpenCV 的多功能性和广泛的功能使其成为图像混合的宝贵工具,可确保在此图像处理工作流程中获得流畅自然的最终输出。

现在让我们尝试使用 OpenCV 的功能将源图像无缝混合到目标图像中。我显然不会使用第 2 步中的绘制图像,因为那不适合新的构图。我们将改用 Unsplash 中的新图像作为目标,并使用另一幅图像作为源。我们将首先使用 SAM 为源图像生成羽化蒙版,使用诸如 initialize_sam_modelgenerate_maskexpand_and_feather_mask 之类的函数。核心混合由 poisson_blend 函数执行,该函数利用 OpenCV 的 seamlessClone 根据蒙版和指定位置组合源图像和目标图像。在混合之前, exposure.match_histograms 尝试对齐源图像和目标图像的颜色分布。最后,保存混合图像。本节致力于通过将源对象无缝集成到目标场景中来创建逼真的合成图像。

target_image = "/content/input/bg.png"
source_image = "/content/input/source.png"
url = "https://unsplash.com/photos/0R9kCkqILQE/download?ixid=M3wxMjA3fDB8MXxzZWFyY2h8ODh8fGVtcHR5JTIwcm9hZHxlbnwwfDB8fHwxNzMzMzE2ODQ5fDI&force=true&w=640"
urllib.request.urlretrieve(url, target_image)
url = "https://unsplash.com/photos/nyvR6wbU1ho/download?ixid=M3wxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNzMzMzE0ODcwfA&force=true&w=640"
urllib.request.urlretrieve(url, source_image)
plt.figure(figsize=(16, 8))
plt.subplot(1, 2, 1)
plt.imshow(Image.open(target_image))
plt.axis("off")
plt.subplot(1, 2, 2)
plt.imshow(Image.open(source_image))
plt.axis("off")
plt.show()
predictor = initialize_sam_model()
image_path = source_image
source_image_v2 = load_image(image_path)
points = [(460, 151)] 
mask = generate_mask(source_image_v2, predictor, points)
feathered_mask = expand_and_feather_mask(mask)
mask_image = "/content/input/blend_image1_mask001.png"
save_mask(feathered_mask, mask_image)
no_feather_mask_image = "/content/input/blend_image1_mask002.png"
no_feather_mask = generate_mask(source_image_v2, predictor, points)
save_mask(no_feather_mask, no_feather_mask_image)

plt.figure(figsize=(16, 8))
plt.subplot(1, 2, 1)
plt.imshow(Image.open(source_image))
plt.imshow(Image.open(no_feather_mask_image), alpha=0.5, cmap="jet")
plt.axis("off")
plt.subplot(1, 2,2)
plt.imshow(Image.open(source_image))
plt.imshow(Image.open(mask_image), alpha=0.5, cmap="jet")
plt.axis("off")
plt.show()
bbox_widget = BBoxWidget(image= encode_image(target_image))
bbox_widget # Identify where the new image needs to be placed in the target image
bbox_widget.bboxes #--> [{'x': 314, 'y': 169, 'width': 0, 'height': 1, 'label': ''}]
import cv2
import numpy as np
from skimage import exposure

# Function to apply Poisson Blending
def poisson_blend(background, foreground, mask, position):
    x, y = position
    bg_h, bg_w = background.shape[:2]
    fg_h, fg_w = foreground.shape[:2]
    x = min(x, bg_w - fg_w)  
    y = min(y, bg_h - fg_h) 
    x = max(0, x)  
    y = max(0, y) 


    center = (x + foreground.shape[1] // 2, y + foreground.shape[0] // 2)
    output = cv2.seamlessClone(foreground, background, mask, center, cv2.NORMAL_CLONE)
    return output

background = cv2.imread(target_image)
foreground = cv2.imread(source_image)
mask = cv2.imread(mask_image, cv2.IMREAD_GRAYSCALE)
mask = cv2.resize(mask, (foreground.shape[1], foreground.shape[0]))
matched_foreground = exposure.match_histograms(foreground, background, channel_axis=-1)
position = (314, 169) 
blended_image = poisson_blend(background, matched_foreground, mask, position)
cv2.imwrite('/content/output/blended_image.png', blended_image)

plt.figure(figsize=(24, 8))
plt.subplot(1,3, 1)
plt.imshow(Image.open(target_image))
plt.axis("off")
plt.subplot(1,3, 2)
plt.imshow(Image.open(source_image))
plt.axis("off")
plt.subplot(1,3, 3)
plt.imshow(Image.open("/content/output/blended_image.png"))
plt.axis("off")
plt.show()

如上所示,输出质量并不理想。可能是由于颜色差异、光照不一致和混合边界处的伪影等因素。以下是一些可能改善结果的策略:

  • 精炼蒙版:尝试不同的蒙版羽化参数,甚至使用替代蒙版技术,可以实现更平滑的过渡和更少的明显接缝。
  • 高级色彩校正:除了简单的直方图匹配之外,实施更复杂的色彩校正方法可能有助于更好地对齐源图像和目标图像的颜色配置文件。
  • 混合技术:探索 OpenCV 提供的其他混合技术,如多波段混合或梯度域混合,可能会为复杂场景产生更好的结果。
  • 内容感知填充:利用内容感知填充算法去除混合区域周围的任何剩余伪影或不一致之处,可以进一步增强合成的真实感。
  • 高分辨率图像:对源图像和目标图像使用更高分辨率的图像也可以提高混合输出的整体质量和细节。
  • 微调混合参数:仔细调整 seamlessClone 函数中的参数,例如混合模式和标志,可能会产生更好的混合结果。

原文链接:Open-source AI advanced image processing workflow using SAM, LaMa, Stable Diffusion and OpenCV2

汇智网翻译整理,转载请标明出处