暂无图片
暂无图片
暂无图片
暂无图片
暂无图片

⌈ 传知代码 ⌋ KAN 卷积:医学图像分割新前沿

原创 Dream-Y.ocean 2024-10-13
378

💛前情提要💛

本文是传知代码平台中的相关前沿知识与技术的分享~

接下来我们即将进入一个全新的空间,对技术有一个全新的视角~

本文所涉及所有资源均在传知代码平台可获取

以下的内容一定会让你对AI 赋能时代有一个颠覆性的认识哦!!!

以下内容干货满满,跟上步伐吧~


📌导航小助手📌


💡本章重点

  • KAN 卷积:医学图像分割新前沿

🍞一. 论文概述

在本文中深入探讨KAN卷积在医学图像分割领域的创新应用,特别是通过引入Tokenized KAN Block(Tok Kan)这一突破性设计,将深度学习中的图像分割技术推向了新的高度。KAN作为一种能够替代传统MLP(多层感知机)的网络结构,以其独特的优势在多个领域展现出强大的潜力。而在医学图像分割这一复杂且关键的领域,KAN卷积更是凭借其高效处理图像特征的能力,成为了研究的热点。

本文将U-Net结构中的卷积部分替换成了KAN卷积,将MLP部分用KANLinear取代,同时融入了类似Vision Transformer(VIT)的移位思想,使得模型在捕捉图像全局信息的同时,也能精准定位局部细节。


🍞二. 核心创新点

  1. 将KAN卷积引入分割网络中

Kolmogorov–Arnold Networks(KAN)通常不是直接指代一种具体的卷积神经网络架构,但在这里我们可以理解为一种特殊的卷积或特征提取机制,可能基于Kolmogorov-Arnold表示定理(也称为超位置定理),该定理提供了多变量函数可以通过一系列一元函数和固定二元函数的组合来表示的理论基础。因此,KAN卷积可能意味着一种高度非线性的、能够捕捉复杂依赖关系的卷积操作。将其引入U-Net中,可以显著提升模型对图像特征的提取能力,特别是那些需要高级抽象和复杂交互的特征。

  1. KANLinear替换MLP

将传统的MLP层替换为KANLinear层,可能意味着这一层结合了KAN的某些特性(如非线性处理能力或复杂的函数逼近能力)和线性变换的简洁性。这种替换可能使模型在保持高效计算的同时,能够更灵活地处理特征之间的复杂关系,进一步增强模型的特征表示能力,同时能够减少模型复杂度。KANLinear层可能通过其独特的机制,更好地整合和转换来自不同层级的特征,从而有助于模型在全局和局部信息之间做出更精准的权衡。

  1. 融入移位思想

虽然Vision Transformer(VIT)本身并不直接包含“移位”操作,但其自注意力机制能够捕捉图像中的全局依赖关系,这一点与移位操作在促进信息流动和增强全局感受野方面的作用相似。在U-Net中融入类似VIT的思想,可能意味着引入了一种能够跨越空间位置直接交互特征的机制(如自注意力模块),或者通过某种形式的特征重排(类似于移位但更灵活)来增强模型的全局理解能力。这种设计使得模型在保持对局部细节敏感的同时,也能够有效地整合全局信息,从而在处理复杂图像任务时展现出更高的性能。本文中设计了沿着两个方向的移位。


🍞三.模块介绍

KAN

  • KAN模型由数学定理Kolmogorov–Arnold启发得出,该定理由前苏联的两位数学家Vladimir Arnold和Andrey Kolmogorov提出。定理表明,任何多元连续函数都可以表示为单变量连续函数的两层嵌套叠加(一个单一变量的连续函数和一系列连续的双变量函数的组合)。这为多维函数的分解提供了理论基础,也是KAN模型设计的核心思想。

在这里插入图片描述
UNext模块

UNext模块一种基于卷积和多层感知器(MLP)的医学图像分割网络,旨在解决现有模型如UNet和Transformer版本在计算复杂度、参数量以及推理速度上的不足。本文是在此基础上将卷积用KAN卷积取代,将MLP用KAN模块代替。

在这里插入图片描述


🍞四.本文主要结构

本文受启发与上述的两种模块,将UNext模块中的上下采样中的卷积提取特征阶段用KAN卷积取代,将Tok MLP阶段用Tok Kan取代,从而增加模型提取特征的能力和渐少模型的参数量,同时提高模型非线性表达的能力。

整体架构如下:

在这里插入图片描述


🍞五.主要代码

KAN卷积:

class KANConvs(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, enable_standalone_scale_spline=True): """ 定义KAN卷积层,类似于nn.Conv2d,但包括KAN样条插值权重 :param in_channels: 输入通道数 :param out_channels: 输出通道数 :param kernel_size: 卷积核大小 :param stride: 步长 :param padding: 填充 :param dilation: 扩张 :param groups: 组卷积 :param bias: 偏置项 :param enable_standalone_scale_spline: 是否启用独立缩放样条插值 """ super(KANConvs, self).__init__() # 基本的卷积层初始化 self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size) self.stride = stride self.padding = padding self.dilation = dilation self.groups = groups self.bias = bias # 标准卷积层参数 self.weight = Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) if bias: self.bias = Parameter(torch.Tensor(out_channels)) else: self.register_parameter('bias', None) # 样条插值权重 self.spline_weight = Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size)) # 是否启用独立缩放样条插值 self.enable_standalone_scale_spline = enable_standalone_scale_spline if enable_standalone_scale_spline: self.spline_scaler = Parameter(torch.ones(out_channels, 1)) # 初始化权重 self.reset_parameters() class ConvLayer(nn.Module): def __init__(self, in_ch, out_ch): super(ConvLayer, self).__init__() self.conv = nn.Sequential( KANConvs(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), KANConvs(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True) ) def forward(self, input): return self.conv(input)

KANLinear层:

class KANBlock(nn.Module): def __init__(self, dim, drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, no_kan=False): super().__init__() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim) self.layer = KANLayer(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, no_kan=no_kan) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) elif isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels fan_out //= m.groups m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) if m.bias is not None: m.bias.data.zero_()

🫓总结

综上,我们基本了解了“一项全新的技术啦” 🍭 ~~

恭喜你的内功又双叒叕得到了提高!!!

感谢你们的阅读😆

后续还会继续更新💓,欢迎持续关注📌哟~

💫如果有错误❌,欢迎指正呀💫

✨如果觉得收获满满,可以点点赞👍支持一下哟~✨

【传知科技 – 了解更多新知识】

「喜欢这篇文章,您的关注和赞赏是给作者最好的鼓励」
关注作者
【版权声明】本文为墨天轮用户原创内容,转载时必须标注文章的来源(墨天轮),文章链接,文章作者等基本信息,否则作者和墨天轮有权追究责任。如果您发现墨天轮中有涉嫌抄袭或者侵权的内容,欢迎发送邮件至:contact@modb.pro进行举报,并提供相关证据,一经查实,墨天轮将立刻删除相关内容。

文章被以下合辑收录

评论