分割一切,SAM模型详解和使用(论文复现)
V本文所涉及所有资源均在传知代码平台可获取
Table of Contents
概述
随着深度学习在计算机视觉领域的应用,图像分割技术得到了显著的发展。然而,大多数模型都是针对特定任务或数据集设计的,缺乏泛用性。Meta AI的研究团队开发了一个新的模型架构,该模型能够接收各种提示(如点、框、文本等)来执行图像分割任务。这跳脱了传统的分割任务的局限性,使得分割任务变得更加简单。 SAM模型使得高级图像分割技术变得更加易于使用和普及,让非专业人士也能轻松处理复杂的图像分割任务。
演示效果

任务介绍
传统的分割任务
传统的图像分割任务通常具有较为明确的边界和限制。在这个过程中,输入数据是一张图片,而模型的任务是根据在训练过程中学习到的特定分类标准来对这张图片进行分割。这种分割方式的核心在于模型的训练数据集中已经包含了所有需要识别和分割的对象类别。换句话说,模型的分割能力是建立在它所接触过的、有限的类别之上的。
当面对一个新的分割任务,尤其是涉及到模型未曾见过的分类时,这种传统的分割模型就会显得力不从心。由于模型在训练过程中并未接触过该类别的数据,因此它缺乏对该类别特征的理解和识别能力。在这种情况下,模型无法准确地对新类别进行分割,因为它既没有学习过这些类别的特征,也没有掌握如何区分这些类别与其他类别的界限。这种局限性导致了当需要分割出一个模型未曾见过的分类时,我们必须重新开始整个训练过程。
可提示分割

可提示分割任务是一种图像分割任务,它允许用户通过提供特定的提示(如点、框、文本等)来指导模型进行分割。
用户交互性:
用户可以通过点击图像上的点、绘制边界框或输入文本描述来提供分割提示。
SAM模型能够理解这些提示并据此生成分割掩码。
通用性:
SAM模型在训练时使用了大量多样化的数据集,这使得它能够泛化到多种不同的分割任务上。
模型不需要针对每个新类别进行专门的训练。
零样本学习能力:
即使是模型未曾见过的类别,SAM也能够根据提示进行有效的分割。
模型介绍
传统的分割任务(以U-net模型为例)

网络结构:
U-Net模型的结构特点是它的“U”形设计,这包括了收缩路径(编码器)和扩张路径(解码器)。
收缩路径: 这个部分类似于传统的卷积神经网络,包括多次卷积和池化操作。每次池化操作后,特征图的分辨率减半,而特征图的深度加倍,以捕获更多的上下文信息。
扩张路径: 在这个部分,通过上采样操作增加特征图的分辨率。同时,它通过跳跃连接将收缩路径中相同分辨率的特征图与上采样后的特征图拼接起来,这样可以保留位置信息。
Segment Anything模型

模块组成
图像编码器(Image Encoder): 这个模块用于从输入图像中提取特征。它通常是基于卷积神经网络(CNN)或Transformer架构,能够捕捉图像中的细节和上下文信息。
提示编码器(Prompt Encoder): 这个模块处理用户提供的提示。它将用户的点、框或文本提示转换为编码形式,使其能够与图像特征兼容。
任务解码器(Task Decoder): 结合图像特征和提示信息,任务解码器负责生成分割掩码。这个模块通常包含多个子网络,用于处理不同类型的提示和特征融合。
微调模型
from segment_anything import sam_model_registry
sam_checkpoint = "./sam_vit_b_01ec64.pth"
model_type = "vit_b"
device = "cuda" # or "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
"""
for epoch in range(epoches):
# 加载符合模型数据集标准的数据:每一次输入是一个列表,每一项是一个字典
# 字典包括:"image",("point_coords","point_label"),("boxes"),("mask_inputs")
for data in train_dataloader:
output=sam(data)
loss=loss_function(output,mask)
optimiezer.zero_gard()
loss.backward()
optimiezer.step()
"""
使用方式
下载模型和python库
库文件地址:https://github.com/facebookresearch/segment-anything?tab=readme-ov-file#model-checkpoints
cd segment-anything-main
pip install .
使用单点分割
import numpy as np
import matplotlib.pyplot as plt
import cv2
from segment_anything import sam_model_registry, SamPredictor
# 在图像上显示分割的结果
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30/255, 144/255, 255/255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=375):
pos_points = coords[labels==1]
neg_points = coords[labels==0]
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
image = cv2.imread('./image.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
sam_checkpoint = "./sam_vit_b_01ec64.pth"
model_type = "vit_b"
device = "cuda" # or "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
predictor.set_image(image)
input_point = np.array([[500, 375]]) # 标记点
input_label = np.array([1]) # 点所对应的标签
mask, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=False,
)
plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(mask[0], plt.gca())
show_points(input_point, input_label, plt.gca())
plt.title(f"Score: {scores[0]:.3f}", fontsize=18)
plt.axis('off')
plt.show()

多点分割
input_point = np.array([[500, 375], [184, 359]])
input_label = np.array([1, 1]) # 1表示正向,0表示负向
mask, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=False,
)

画框分割
input_box = np.array([44, 160, 1011, 504]) mask, scores, logits = predictor.predict( box=input_box, multimask_output=False, )

点框分割
input_box = np.array([250, 370, 410, 510]) input_point = np.array([[334, 438]]) input_label = np.array([0]) mask, scores, logits = predictor.predict( point_coords=input_point, point_labels=input_label, box=input_box, multimask_output=False, )

多框分割
input_boxes = torch.tensor([ [24, 170, 998, 472], [830, 338, 966, 458], [720, 209, 772, 258], ], device=predictor.device) transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2]) masks, _, _ = predictor.predict_torch( point_coords=None, point_labels=None, boxes=transformed_boxes, multimask_output=False, )

文章代码资源点击附件获取




