torch.nn.Module
是 PyTorch 中一个重要的基类,用于构建神经网络模型。它提供了一种方便的方式来组织和管理模型参数、定义前向传播等功能。继承自 torch.nn.Module
的类可以被视为一个可训练的参数集合,可以包含其他模块,从而形成层次化的模型结构。
关键功能和属性
参数管理
torch.nn.Module
可以追踪并管理所有注册的参数。通过 parameters()
方法,可以方便地获取模型中的所有参数。
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(10, 5)
model = MyModel()
print(list(model.parameters()))
子模块管理
通过将其他 torch.nn.Module
的实例注册为当前模块的属性,可以形成层次化的模型结构。这使得模型可以以更模块化的方式进行定义。
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)
model = MyModel()
前向传播定义
在继承 torch.nn.Module
的子类中,需要实现 forward
方法来定义模型的前向传播过程。
import torch.nn as nn
import torch
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
return self.fc(x)
model = MyModel()
input_data = torch.randn(3, 10)
output = model(input_data)
模型保存和加载
模型可以方便地保存到文件并在需要时加载。这是通过 torch.save
和 torch.load
函数来实现的。
torch.save(model.state_dict(), 'my_model.pth')
loaded_model = MyModel()
loaded_model.load_state_dict(torch.load('my_model.pth'))
模型训练
由于继承了 torch.nn.Module
,模型可以使用 PyTorch 的优化器进行训练。
import torch.optim as optim
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()
# 在训练循环中使用 optimizer 和 criterion
使用案例
定义模型
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
定义不更新的参数
self.register_buffer定义一组参数,参数不会变化,只可人为地改变它们的值),但是保存模型时,该组参数又作为模型参数不可或缺的一部分被保存。
import torch
import torch.nn as nn
from collections import OrderedDict
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
# (1)常见定义模型时的操作
self.param_nn = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(1, 1, 3, bias=False)),
('fc', nn.Linear(1, 2, bias=False))
]))
# (2)使用register_buffer()定义一组参数
self.register_buffer('param_buf', torch.randn(1, 2))
# (3)使用形式类似的register_parameter()定义一组参数
self.register_parameter('param_reg', nn.Parameter(torch.randn(1, 2)))
# (4)按照类的属性形式定义一组变量
self.param_attr = torch.randn(1, 2)
def forward(self, x):
return x
net = Model()
内置函数
add_module
将子模块添加到当前模块。
apply
对当前模块及其所有子模块递归地应用函数 fn
。
@torch.no_grad()
def init_weights(m):
print(m)
if type(m) == nn.Linear:
m.weight.fill_(1.0)
print(m.weight)
net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)
bfloat16
将所有浮点型参数和缓冲区(buffers)的数据类型转换为bfloat16。
**buffers
返回一个模块缓冲区的迭代器。
参数: recurse
(可选)表示是否递归地遍历所有子模块的缓冲区。用法: model.buffers()
返回模块及其所有子模块的缓冲区。
children
返回一个模块的直接子模块的迭代器。
compile
使用 torch.compile()
编译模块的前向传播。使用 TorchDynamo 和指定的后端来优化给定的模型或函数。
cpu
将模型的所有参数和缓冲区移动到 CPU 上。
cuda
将模型的所有参数和缓冲区移动到 GPU 上。
double
将所有浮点数参数和缓冲区转换为双精度数据类型(double
)。
eval
使用 eval()
将模块设置为评估(evaluation)模式。在评估模式下,模型中的某些层(例如,Dropout)的行为可能会有所不同,通常用于模型推断阶段。
extra_repr
通过调用 extra_repr()
方法,可以设置模块的额外表示(extra representation)。这通常用于提供模块的自定义描述,方便在打印模型时显示有关模块的更多信息。
float
通过调用 float()
方法,可以将模块中所有的浮点参数和缓冲区(buffers)转换为浮点数据类型(float datatype)。
forward
forward(*input)
定义了模块在每次调用时执行的计算。这是 torch.nn.Module
类中需要用户实现的方法,用于定义模型的前向传播过程。
get_buffer
get_buffer(target)
方法用于返回给定目标(target)的缓冲区(buffer),如果存在的话;否则,会抛出一个错误。这对于获取模块中特定缓冲区的引用很有用。
get_extra_state
get_extra_state()
方法用于返回任何额外的状态,以便包含在模块的 state_dict
中。如果需要存储额外的状态,需要实现此方法,并相应地实现 set_extra_state()
。通常在构建模块的状态字典时调用该方法。
get_parameter
使用 get_parameter(target)
函数可以获取给定目标的参数,如果存在则返回该参数,否则抛出错误。
get_submodule
使用 get_submodule(target)
函数可以获取给定目标的子模块,如果存在则返回该子模块,否则抛出错误。
half()
half()
函数将模块中所有的浮点参数和缓冲区转换为半精度数据类型。
ipu
ipu(device=None)
函数将模块中所有的模型参数和缓冲区移动到 IPU(Intel Processing Unit)。
load_state_dict
load_state_dict(state_dict, strict=True, assign=False)
函数将参数和缓冲区从给定的 state_dict
复制到当前模块及其子模块中。
modules
返回网络中所有模块的迭代器。这允许对模型进行深度遍历,获取所有子模块。
named_buffers
返回模块缓冲区的迭代器,包括缓冲区的名称和缓冲区本身。缓冲区通常用于存储非学习参数的状态信息。
named_children
返回模块的直接子模块的迭代器,包括子模块的名称和子模块本身。
named_modules
返回网络中所有模块的迭代器,包括模块的名称和模块本身。这与 modules
方法类似,但额外提供了模块的名称。
named_parameters
返回模块参数的迭代器,包括参数的名称和参数本身。参数是模型中的可学习参数,可以通过优化算法进行更新。
parameters
返回模块参数的迭代器,只包括模块中的可学习参数,不包括缓冲区等非学习参数。
register_backward_hook
在模块上注册一个反向传播钩子。反向传播钩子是在模型反向传播过程中执行的函数,用于在梯度计算过程中执行自定义操作。
register_buffer
向模块添加一个缓冲区。缓冲区通常用于存储不需要梯度更新的状态信息,如运行统计或固定权重。该函数通过名称将一个张量注册为模块的缓冲区,这意味着它不会参与到模型的训练过程中,并且在保存和加载模型时会一并处理。
register_forward_hook
在模块上注册一个前向钩子。前向钩子将在每次模块的前向传播完成后被调用,即在 forward()
计算完输出之后执行。这允许用户在模型的每一步前向传播过程中插入自定义操作。
register_forward_pre_hook
在模块上注册一个前向预处理钩子。前向预处理钩子将在每次模块的前向传播开始之前被调用,即在调用 forward()
之前执行。这提供了在前向传播开始前修改输入或执行其他自定义操作的机会。
register_full_backward_hook
在模块上注册一个完全的反向钩子。完全的反向钩子将在计算与模块相关的梯度时被调用,即只有当计算模块输出的梯度时才执行。该钩子的函数签名需要满足特定的要求,可以在梯度计算过程中执行自定义操作。
register_full_backward_pre_hook
在模块上注册一个完全的反向预处理钩子。完全的反向预处理钩子将在计算模块的梯度之前被调用。这提供了在梯度计算开始前执行自定义操作的机会。
set_extra_state(state)
在 load_state_dict()
中被调用,用于处理在 state_dict 中找到的任何额外状态。如果需要在模块内存储额外的状态,实现此函数并相应地实现 get_extra_state()
。
share_memory()
将底层存储移动到共享内存。如果底层存储已经在共享内存中或是 CUDA 张量,则此操作无效。共享内存中的张量无法调整大小。
state_dict
返回包含模块整个状态的字典引用。
to
移动和/或转换参数和缓冲区。
to_empty
将参数和缓冲区移动到指定设备,而不复制存储。
train
将模块设置为训练模式。
type
将所有参数和缓冲区转换为指定的类型。
xpu
将所有模型参数和缓冲区移动到 XPU。
zero_grad
重置所有模型参数的梯度。






