【小贪】项目实战——Zero-shot根据文字提示分割出图片目标掩码

目标描述

给定RGB视频或图片,目标是分割出图像中的指定目标掩码。我们需要复现两个Zero-shot的开源项目,分别为IDEA研究院的GroundingDINO和Facebook的SAM。首先使用目标检测方法GroundingDINO,输入想检测目标的文字提示,可以获得目标的anchor box。将上一步获得的box信息作为SAM的提示,分割出目标mask。具体效果如下(测试数据来自VolumeDeform数据集):

在这里插入图片描述

其中GroundingDINO根据white shirt的文字输入计算的box信息为:"shirt_000500": "[194.23726, 2.378189, 524.09503, 441.5135]"。项目实测下来单张图片的预测速度GroundingDINO要慢于SAM。GroundingDINO和SAM均会给出多个预测结果,当选择置信度最高的结果时两个模型也会存在预测不准确的情况。

论文简介

GroundingDINO

GroundingDINO extends a closedset detector DINO by performing vision-language modality fusion at multiple phases, including a feature enhancer, a language-guided query selection module, and a cross-modality decoder. Such a deep fusion strategy effectively improves open-set object detection.

在这里插入图片描述

SAM

  • 简介:使用三个组件建立图像分割的foundation model,解决一系列下游分割问题,可zero-shot生成
  • 关键技术:
    1. promptable分割任务:使用prompt engineering,prompt不确定时输出多目标mask
    2. 分割模型:image encoder + prompt encoder -> mask decoder
    3. 数据驱动:SA-1B(1B masks from 11M imgs)手工标注->半自动->全自动
  • Limitation:存在不连贯不精细的mask结果;交互式实时mask生成但是img encoder耗时;text-to-mask任务效果不鲁棒

在这里插入图片描述
在这里插入图片描述

项目实战

两个项目的复现很简单,按照github的readme配置相关环境并运行程序。当然也可以直接使用一站式项目Grounded Segment Anything等。当需要分割的图片较多时,可以修改GroundingDINO的demo.shdemo/inference_on_a_image.py文件将检测结果保存至json文件。

demo/inference_on_a_image.py文件

# 修改plot_boxes_to_image函数输出box信息
image_with_box, mask, box_coor = plot_boxes_to_image(image_pil, pred_dict)
# obj为目标名称,i为当前图片的索引
obj = 'shirt'
data = {f'{obj}_{str(i).zfill(6)}': str(list(box_coor.cpu().detach().numpy()))}
with open("box.json", "r", encoding="utf-8") as f:
    old_data = json.load(f)
    old_data.update(data)
with open("box.json", "w", encoding="utf-8") as f:
    json.dump(old_data, f, indent=4)
    # f.write(json.dumps(old_data, indent=4, ensure_ascii=False))
f.close()

然后SAM再读取json文件获取box信息,将SAM的输入提示改为box。

测试代码

import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
import glob
import json

coords = []

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)


def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))


def on_click(event):
    global coords
    if event.button == 1:
        x, y = event.xdata, event.ydata
        print(f"鼠标左键点击:x={x:.2f}, y={y:.2f}")
        coords.append([x, y])
        # if len(coords) == 2:
        #     fig.canvas.mpl_disconnect(cid)
    elif event.button == 3:
        print("鼠标右键点击")


def get_mask(image, mask_id=1, click_coords=False, choose_mask=False, box=None):
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # plt.figure(figsize=(10, 10))
    # plt.imshow(image)
    # plt.axis('on')

    if click_coords:
        global coords
        fig, ax = plt.subplots()  # 创建画布和子图对象
        fig.set_size_inches(30, 20)  # 设置宽度和高度,单位为英寸(inch)
        ax.imshow(image)
        cid = fig.canvas.mpl_connect('button_press_event', on_click)
        plt.show()
    else:  # 如果使用 必须全局
        coords = []

    from segment_anything import SamPredictor, sam_model_registry
    sam_checkpoint = "sam_vit_h_4b8939.pth"
    model_type = "vit_h"
    device = "cuda"
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    predictor = SamPredictor(sam)
    predictor.set_image(image)

    input_point = np.array(coords)
    input_label = np.array([1] * len(coords))

    # plt.figure(figsize=(10, 10))
    # plt.imshow(image)
    # show_points(input_point, input_label, plt.gca())
    # plt.axis('on')
    # plt.show()

    input_box = box
    if len(coords) == 0:
        input_point = None
        input_label = None
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box[None, :],
        multimask_output=True)

    if choose_mask:
        plt.figure(figsize=(60, 20))
        plt.subplot(1, 3, 1)
        plt.imshow(image)
        show_mask(masks[0], plt.gca())
        # show_points(input_point, input_label, plt.gca())
        plt.title(f"Mask 0, Score: {scores[0]:.3f}", fontsize=18)
        plt.subplot(1, 3, 2)
        plt.imshow(image)
        show_mask(masks[1], plt.gca())
        # show_points(input_point, input_label, plt.gca())
        plt.title(f"Mask 1, Score: {scores[1]:.3f}", fontsize=18)
        plt.subplot(1, 3, 3)
        plt.imshow(image)
        show_mask(masks[2], plt.gca())
        # show_points(input_point, input_label, plt.gca())
        plt.title(f"Mask 2, Score: {scores[1]:.3f}", fontsize=18)
        plt.show()
        mask_id = int(input())  # 通过输入idx或者设置特定的idx输出

    mask = masks[mask_id]
    mask = np.tile(np.expand_dims(mask, axis=-1), 3)
    mask_data = np.where(mask, 255, 0)
    # mask_image = np.where(mask, image/255, 0.)
    # plt.figure(figsize=(10, 10))
    # plt.imshow(mask_image)
    # plt.show()
    if click_coords: coords.clear()
    return mask_data


if __name__ == '__main__':
    obj = 'shirt'
    color_path = f'/Data/VolumeDeformData/{obj}/data/'
    mask_path = f'/Data/VolumeDeformData/{obj}/mask/'
    if not os.path.exists(mask_path):
        os.makedirs(mask_path)

    img_paths = []
    for extension in ["jpg", "png", "jpeg"]:
        img_paths += glob.glob(os.path.join(color_path, "*.{}".format(extension)))

    json_path = 'GroundingDINO-main/box.json'
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
        for i in range(len(img_paths) // 2):
            img_name = f'frame-{str(i).zfill(6)}.color.png'
            img = cv2.imread(color_path + img_name)
            id = f'{obj}_{str(i).zfill(6)}'
            box = np.array(list(map(float, data[id][1:-1].split(','))))
            mask = get_mask(img, mask_id=2, click_coords=False, choose_mask=False, box=box)
            cv2.imwrite(mask_path + str(i).zfill(6) + '.png', mask)
            print(img_name)
    f.close()

相关链接

  • GroundingDINO github arXiv
  • SAM Demo github arXiv
  • Grounded Segment Anything github

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/762465.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

互联网框架五层模型详解

注:机翻,未校对。 What is the Five Layers Model? The Framework of the Internet Explained 五层模型互联网框架解释 Computer Networks are a beautiful, amazing topic. Networks involve so much knowledge from different fields, from physics…

[OHOS_ERROR]: Please call hb utilities inside ohos source directory

当执行hb set报如下错误时:原因时重新拉取了源码,且源码路径被改了 [OHOS_ERROR]: Please call hb utilities inside ohos source directory 【解决办法】 卸载hb并在源码路径下重新安装 python3 -m pip uninstall ohos-build 安装hb python3 -m pi…

python-逻辑语句

if else语句 不同于C:else if range语句: continue continue的作用是: 中断所在循环的当次执行,直接进入下一次 continue在嵌套循环中的应用 break 直接结束所在的循环 break在嵌套循环中的应用 continue和break,在…

力扣:LCR 024. 反转链表(Java)

目录 题目描述:示例 1:示例 2:代码实现: 题目描述: 给定单链表的头节点 head ,请反转链表,并返回反转后的链表的头节点。 示例 1: 输入:head [1,2,3,4,5] 输出&#x…

【嵌入式DIY实例】- LCD ST7735显示DHT11传感器数据

LCD ST7735显示DHT11传感器数据 文章目录 LCD ST7735显示DHT11传感器数据1、硬件准备与接线2、代码实现本文介绍如何将 ESP8266 NodeMCU 板 (ESP-12E) 与 DHT11 (RHT01) 数字湿度和温度传感器连接。 NodeMCU 从 DHT11 传感器读取温度(以 C 为单位)和湿度(以 rH% 为单位)值,…

1.5 Canal 数据同步工具详细教程

欢迎来到我的博客,很高兴能够在这里和您见面!欢迎订阅相关专栏: ⭐️ 全网最全IT互联网公司面试宝典:收集整理全网各大IT互联网公司技术、项目、HR面试真题. ⭐️ AIGC时代的创新与未来:详细讲解AIGC的概念、核心技术、…

【你也能从零基础学会网站开发】关系型数据库中的表(Table)设计结构以及核心组成部分

🚀 个人主页 极客小俊 ✍🏻 作者简介:程序猿、设计师、技术分享 🐋 希望大家多多支持, 我们一起学习和进步! 🏅 欢迎评论 ❤️点赞💬评论 📂收藏 📂加关注 关系型数据库中…

FTP 文件传输协议:概念、工作原理;上传下载操作步骤

目录 FTP 概念 工作原理 匿名用户 授权用户 FTP软件包 匿名用户上传下载实验步骤 环境配置 下载 上传 wget 授权用户上传下载步骤 root用户登录FTP步骤 监听 设置端口号范围 修改用户家目录 匿名用户 授权用户 FTP 概念 FTP(File Transfer Prot…

C语言之线程的学习

线程属于某一个进程 共同点:都能并发 线程共享变量,进程不共享。 多线程任务中,其中某一个线程调用了exit了,其他线程会跟着一起退出 如果是特定的线程就调用pthread_exit 失败返回的是错误号 下面也是

VSCode无法识别 node、npm

一、前提 电脑新安装了node.js,在cmd查看node和npm版本没有问题,但是在VSCode无法识别 1.cmd查看版本: 2.VSCode报错信息: 无法将“npm”项识别为 cmdlet、函数、脚本文件或可运行程序的名称。请检查名称的拼写,如果…

C#——Property属性详情

属性 属性(Property)是类(class)、结构体(structure)和接口(interface)的成员,类或结构体中的成员变量称为字段,属性是字段的扩展,使用访问器&am…

视频转音频:怎样提取视频中的音频?6个提取音频的小技巧(建议收藏)

怎样提取视频中的音频?当我们想从视频中提取出声音时,通常会遇到很多问题。无论是想单独提取出视频里的音频,还是把它转成方便储存或者分享的音频格式,这都会涉及到视频转音频的一个需求。因此,在这篇指南里&#xff0…

如何编写高质量更优雅的代码(Java)

1、函数式接口—FunctionalInterface 好处:高逼格、代码收拢、解藕、统一处理 适用范围:具有共性的接口调用代码 举个栗子: 在我们平时的微服务开发中,调用其他服务的接口,通常要把接口调用部分做异常处理(try catch…

为何交易价格可能超出预期?

当你尝试执行订单时,如果收到“报价超出”的提示,这通常意味着交易无法按你的预期价格成交。对于某些交易者来说,这可能会带来一些困扰,但在外汇等流动性极高的市场中,这种情况是相当常见的。 外汇市场之所以吸引众多…

Python系统教程01

Python 是一门解释性语言,相对更简单、易学,它可以用于解决数学问题、获取与分 析数据、爬虫爬取网络数据、实现复制数学算法等等。 1、print()函数: print()书写时注意所有的符号都是英文符号。print()输出内容时,若要输出字符…

A股低开高走,近3000点,行情要启动了吗?

A股低开高走,近3000点,行情要启动了吗? 今天的A股,让人瞪目结舌了,你们知道是为什么吗?盘面上出现2个重要信号,一起来看看: 1、今天两市低开高走,银行板块护盘指数&…

如何使用AI学习一门编程语言?

无论你是软件开发新手还是拥有几十年的丰富经验,总是需要学习新知识。TIOBE Index追踪50种最受欢迎的编程语言,许多生态系统为职业发展和横向转型提供了机会。鉴于现有技术具有的广度,抽空学习一项新技能并有效运用技能可能困难重重。 最近我…

Linux启动elasticsearch,提示权限不够

Linux启动elasticsearch,提示权限不够,如下图所示: 解决办法: 设置文件所有者,即使用户由权限访问文件 sudo chown -R 用户名[:新组] ./elasticsearch-8.10.4 //切换到elasticsearch-8.10.4目录同级 chown详细格式…

关于vue创建项目失败报错(镜像过期)的解决方案

在新建vue项目时出现以下错误: 原因: npm.taobao.org和registry.npm.taobao.org旧域名于2021年官方公告域名更换事件,已于2022年05月31日零时起停止服务,域名HTTPS证书于2024年1月22日正式到期,不可再用。 解决方案:…

昇思MindSpore学习总结七——模型训练

1、模型训练 模型训练一般分为四个步骤: 构建数据集。定义神经网络模型。定义超参、损失函数及优化器。输入数据集进行训练与评估。 现在我们有了数据集和模型后,可以进行模型的训练与评估。 2、构建数据集 首先从数据集 Dataset加载代码&#xff0…