nnInteractive安装和应用实战

今天将分享nnInteractive安装和应用实战完整实现版本,为了方便大家学习理解整个流程,将整个流程步骤进行了整理,并给出详细的步骤结果。感兴趣的朋友赶紧动手试一试吧。 一、nnInteractive原理 nnInteractive 是首个能够处理多种交互提示(比如正负点、框选、涂鸦和套索/lasso)并实现完整 3D 分割的工具,用户只需在 2D 层面进行操作,模型即可生成体积级分割。“Early prompting” 提示方式:与常见将提示与图像一起输入后在高层融合的方式不同,nnInteractive 将提示以额外输入通道的形式,在网络开始特征提取的最初阶段就融合进来。这保留了提示与图像间的最直接空间关系,使网络更多能力集中于实际的分割,而非提示融合。 nnInteractive 沿用了 nnU-Net 的最佳实践,通常采用变体如 Residual Encoder U-Net。网络在编码器的输入阶段就接收图像和提示的通道输入(early prompting)。支持正提示(positive)和负提示(negative),每种提示类型对应独立输入通道,支持点(points)、框选(boxes)、涂鸦(scribbles)、以及套索(lasso)等交互方式。 nnInteractive 在超过 120 个多样化体积数据集上训练,数据总量包括 64,518 个体积,涵盖 CT、MRI、PET、3D 显微图像等多种模态,确保其具有极强的泛化能力和开放集分割能力 二、nnInteractive安装 创建虚拟环境
conda create -n nnInteractive python=3.12conda activate nnInteractive
安装 PyTorch2.6.0
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
安装nnInteractive
pip install nninteractive
三、nnInteractive实例 官方提供了nnInteractive例子,详情可以访问: https://github.com/MIC-DKFZ/nnInteractive?tab=readme-ov-file nnInteractive是直接可以3D分割,代码如下所示,输入图像和肿瘤的points。模型权重可以在这里下载:

https://huggingface.co/nnInteractive/nnInteractive/tree/main

import osimport numpy as npimport torchimport SimpleITK as sitkimport timefrom collections import defaultdict# --- Initialize Inference Session ---from nnInteractive.inference.inference_session import nnInteractiveInferenceSessionclass tumornnInteractivewithMutilBoxPoint_Inference:    """    project:https://github.com/MIC-DKFZ/nnInteractive?tab=readme-ov-file,    model:https://huggingface.co/nnInteractive/nnInteractive/tree/main    """    def __init__(self, propagate_with_type='box'):        # propagate_with_type support point,box,mask,lasso        self.session = nnInteractiveInferenceSession(            device=torch.device("cuda:0"),  # Set inference device            use_torch_compile=False,  # Experimental: Not tested yet            verbose=False,            torch_n_threads=os.cpu_count(),  # Use available CPU cores            do_autozoom=True,  # Enables AutoZoom for better patching            use_pinned_memory=True,  # Optimizes GPU memory transfers        )        # Load the trained model        model_path = r"weight\nnInteractive_v1.0"        self.session.initialize_from_trained_model_folder(model_path)        self.propagate_with_type = propagate_with_type    def _seg3d_infer_withpointboxmasklasso(self, input_image, points_pos_list=None, points_neg_list=None,                                           bboxs_list=None, mask_binary=None):        # --- Load Input Image (Example with SimpleITK) ---        # DO NOT preprocess the image in any way. Give it to nnInteractive as it is! DO NOT apply level window, DO NOT normalize        # intensities and never ever convert an image with higher precision (float32, uint16, etc) to uint8!        # The ONLY instance where some preprocesing makes sense is if your original image is too large to be reasonably used.        # This may be the case, for example, for some microCT images. In this case you can consider downsampling.        img = sitk.GetArrayFromImage(input_image)[None]  # Ensure shape (1, x, y, z)        # Validate input dimensions        if img.ndim != 4:            raise ValueError("Input image must be 4D with shape (1, x, y, z)")        self.session.set_image(img)        # --- Define Output Buffer ---        target_tensor = torch.zeros(img.shape[1:], dtype=torch.uint8)  # Must be 3D (x, y, z)        self.session.set_target_buffer(target_tensor)        # --- Interacting with the Model ---        # Interactions can be freely chained and mixed in any order. Each interaction refines the segmentation.        # The model updates the segmentation mask in the target buffer after every interaction.        # Example: Add a **positive** point interaction        # POINT_COORDINATES should be a tuple (x, y, z) specifying the point location.        if self.propagate_with_type == 'point':            if points_pos_list is not None:                for i in range(len(points_pos_list)):                    points = points_pos_list[i]                    POINT_COORDINATEone = (points[2], points[1], points[0])                    POINT_COORDINATEtwo = (points[5], points[4], points[3])                    print(POINT_COORDINATEone)                    print(POINT_COORDINATEtwo)                    self.session.add_point_interaction(POINT_COORDINATEone, include_interaction=True)                    self.session.add_point_interaction(POINT_COORDINATEtwo, include_interaction=True)            if points_neg_list is not None:                for i in range(len(points_neg_list)):                    # Example: Add a **negative** point interaction                    # To make any interaction negative set include_interaction=False                    one_point = points_neg_list[i]                    POINT_COORDINATES = (one_point[2], one_point[1], one_point[0])                    print(POINT_COORDINATES)                    self.session.add_point_interaction(POINT_COORDINATES, include_interaction=False)        if self.propagate_with_type == 'box':            # Example: Add a bounding box interaction            # BBOX_COORDINATES must be specified as [[x1, x2], [y1, y2], [z1, z2]] (half-open intervals).            # Note: nnInteractive pre-trained models currently only support **2D bounding boxes**.            # This means that **one dimension must be [d, d+1]** to indicate a single slice.            # Example of a 2D bounding box in the axial plane (XY slice at depth Z)            # BBOX_COORDINATES = [[30, 80], [40, 100], [10, 11]]  # X: 30-80, Y: 40-100, Z: slice 10            if bboxs_list is not None:                for i in range(len(bboxs_list)):                    one_box = bboxs_list[i]                    BBOX_COORDINATES = [[one_box[2], one_box[5] + 1], [one_box[1], one_box[4]],                                        [one_box[0], one_box[3]]]                    print(BBOX_COORDINATES)                    self.session.add_bbox_interaction(BBOX_COORDINATES, include_interaction=True)        if self.propagate_with_type == 'mask':            # Example: Add a scribble interaction            # - A 3D image of the same shape as img where one slice (any axis-aligned orientation) contains a hand-drawn scribble.            # - Background must be 0, and scribble must be 1.            # - Use session.preferred_scribble_thickness for optimal results.            mask_binary = mask_binary.astype('uint8')            mask_binary[mask_binary != 0] = 1            self.session.add_scribble_interaction(mask_binary, include_interaction=True)        if self.propagate_with_type == 'lasso':            # Example: Add a lasso interaction            # - Similarly to scribble a 3D image with a single slice containing a **closed contour** representing the selection.            mask_binary = mask_binary.astype('uint8')            mask_binary[mask_binary != 0] = 1            self.session.add_lasso_interaction(mask_binary, include_interaction=True)        # You can combine any number of interactions as needed.        # The model refines the segmentation result incrementally with each new interaction.        # --- Retrieve Results ---        # The target buffer holds the segmentation result.        # results = self.session.target_buffer.clone()        # OR (equivalent)        results = target_tensor.clone()        # Cloning is required because the buffer will be **reused** for the next object.        # Alternatively, set a new target buffer for each object:        self.session.set_target_buffer(torch.zeros(img.shape[1:], dtype=torch.uint8))        # --- Start a New Object Segmentation ---        self.session.reset_interactions()  # Clears the target buffer and resets interactions        return results    def network_prediction(self, inputfilepath, unique_labs_list=None, sitk_mask_binary=None):        """        :param inputfilepath: image path        :param unique_labs_list: [[x1,y1,z1,x2,y2,z2,label],[x1,y1,z1,x2,y2,z2,label],[x1,y1,z1,x2,y2,z2,label]]        :return:        """        if not (inputfilepath.endswith('.nii'or inputfilepath.endswith('.nii.gz'or inputfilepath.endswith('.mha')):            print("文件格式不支持,仅支持 .nii, .nii.gz 和 .mha 格式")            return FalseNone        try:            nii_image = sitk.ReadImage(inputfilepath)            array_mask = np.zeros_like(sitk.GetArrayFromImage(nii_image))            if unique_labs_list is not None:                # 先把 unique_labs_list 按照 label 分组,变成字典:                grouped_boxes_dict = defaultdict(list)                for box in unique_labs_list:                    x1, y1, z1, x2, y2, z2, label = box                    grouped_boxes_dict[label].append([x1, y1, z1, x2, y2, z2])                grouped_boxes_dict = dict(sorted(grouped_boxes_dict.items(), key=lambda x: x[0]))                print(grouped_boxes_dict)                if self.propagate_with_type == 'point':                    points_neg_list = None                    for label, bboxes in grouped_boxes_dict.items():                        print(f"类别: {label}")                        if label == 0:                            points_neg_list = bboxes                            continue                        points_pos_list = bboxes                        one_label_array_mask = self._seg3d_infer_withpointboxmasklasso(nii_image,                                                                                       points_pos_list=points_pos_list,                                                                                       points_neg_list=points_neg_list)                        array_mask[one_label_array_mask != 0] = label                if self.propagate_with_type == 'box':                    for label, bboxes in grouped_boxes_dict.items():                        print(f"类别: {label}")                        one_label_array_mask = self._seg3d_infer_withpointboxmasklasso(nii_image, bboxs_list=bboxes)                        array_mask[one_label_array_mask != 0] = label            elif sitk_mask_binary is not None:                mask_binary = sitk.GetArrayFromImage(sitk_mask_binary)                one_label_array_mask = self._seg3d_infer_withpointboxmasklasso(nii_image, mask_binary=mask_binary)                array_mask[one_label_array_mask != 0] = np.max(mask_binary)            else:                print('pleas check input')                return FalseNone            sitk_mask = sitk.GetImageFromArray(array_mask.astype('uint8'))            sitk_mask.CopyInformation(nii_image)            return True, sitk_mask        except Exception as e:            print(f"出现异常:{e}", inputfilepath)            return FalseNonedef box_point_test_demo():    input_image_path = r"D:\liver_image.nii.gz"    output_mask_path = "liver_tumor_nnInteractive_point.nii.gz"    box_list_2d_liver_tumor = [1413543831443603834]  # x1,y1,z1,x2,y2,z2,label    box_list_2d_kidney_tumor = [1583542751613572752]  # x1,y1,z1,x2,y2,z2,label    box_list_2d_vessel = [2632523603153053603]  # x1,y1,z1,x2,y2,z2,label    box_list_2d_liver = [781413802803803801]  # x1,y1,z1,x2,y2,z2,label    box_list_3d_list = [box_list_2d_liver_tumor, box_list_2d_kidney_tumor]    start = time.time()    tumorsam3d = tumornnInteractivewithMutilBoxPoint_Inference(propagate_with_type='point')    _, sitk_mask = tumorsam3d.network_prediction(input_image_path, unique_labs_list=box_list_3d_list)    end = time.time()    print(end - start)    sitk.WriteImage(sitk_mask, output_mask_path)if __name__ == '__main__':    box_point_test_demo()

分别在肾肿瘤,肝肿瘤和肺癌的数据进行测试,输入图像和肿瘤points,为了与MedSAM2、Vista3d分割效果进行对比,都是用相同的前景点作为输入。对比分割效果如下所示,第一个图是nnInteractive(左)与MedSAM2(右)结果,第二个图是nnInteractive(左)与Vista3d(右)结果。

点击阅读原文可以访问参考项目,如果大家觉得这个项目还不错,希望大家给个Star并Fork,可以让更多的人学习。如果有任何问题,随时给我留言我会及时回复的。

本文使用 文章同步助手 同步

©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

  • """1.个性化消息: 将用户的姓名存到一个变量中,并向该用户显示一条消息。显示的消息应非常简单,如“Hello ...
    她即我命阅读 5,957评论 0 6
  • 为了让我有一个更快速、更精彩、更辉煌的成长,我将开始这段刻骨铭心的自我蜕变之旅!从今天开始,我将每天坚持阅...
    李薇帆阅读 2,286评论 1 4
  • 似乎最近一直都在路上,每次出来走的时候感受都会很不一样。 1、感恩一直遇到好心人,很幸运。在路上总是...
    时间里的花Lily阅读 1,792评论 1 3
  • 1、expected an indented block 冒号后面是要写上一定的内容的(新手容易遗忘这一点); 缩...
    庵下桃花仙阅读 1,171评论 1 2
  • 一、工具箱(多种工具共用一个快捷键的可同时按【Shift】加此快捷键选取)矩形、椭圆选框工具 【M】移动工具 【V...
    墨雅丫阅读 1,840评论 0 0

友情链接更多精彩内容