- Home »

Написание ResNet с нуля на PyTorch — учебник
Если вы думаете, что написание ResNet с нуля — это только для академических кругов, то вы глубоко заблуждаетесь. Эта архитектура сетей — настоящий прорыв в глубоком обучении, который решил проблему исчезающего градиента и позволил создавать действительно глубокие сети. Сегодня мы детально разберём, как реализовать ResNet с нуля на PyTorch, причём сделаем это так, чтобы вы могли запустить всё на своём сервере и сразу начать эксперименты.
Зачем это нужно? Во-первых, понимание архитектуры изнутри даёт вам полный контроль над моделью. Во-вторых, это отличная база для кастомизации под ваши задачи. В-третьих, когда вы развернёте это на своём VPS или выделенном сервере, вы получите мощный инструмент для решения задач компьютерного зрения.
Как работает ResNet: архитектура без магии
ResNet (Residual Neural Network) — это не просто глубокая сеть, это революционный подход к обучению глубоких моделей. Основная идея: вместо того чтобы каждый слой изучал полное отображение H(x), мы заставляем его изучать остаточное отображение F(x) = H(x) – x. Звучит сложно? На самом деле всё гениально просто.
Ключевые компоненты ResNet:
- Residual Block — основной строительный блок с skip connections
- Skip connections — прямые соединения, которые “перепрыгивают” через слои
- Batch Normalization — нормализация для стабилизации обучения
- ReLU активация — стандартная функция активации
Проблема, которую решает ResNet: в очень глубоких сетях градиенты либо исчезают, либо взрываются. Skip connections позволяют градиентам “протекать” напрямую через сеть, обеспечивая стабильное обучение даже в сетях с сотнями слоёв.
Пошаговая реализация: от блока к сети
Начнём с создания базового окружения. Для экспериментов вам понадобится сервер с GPU — на CPU обучение будет мучительно медленным.
pip install torch torchvision torchaudio
pip install matplotlib numpy pillow
Теперь создадим основные компоненты. Начнём с импортов и базовых блоков:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
# Проверяем доступность GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Используем устройство: {device}")
Создаём базовый residual block — сердце архитектуры:
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x) # Вот она, магия skip connection!
out = F.relu(out)
return out
Для более глубоких сетей нужен Bottleneck блок:
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, in_planes, planes, stride=1):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(self.expansion * planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
out += self.shortcut(x)
out = F.relu(out)
return out
Теперь собираем полную сеть:
class ResNet(nn.Module):
def __init__(self, block, num_blocks, num_classes=10):
super(ResNet, self).__init__()
self.in_planes = 64
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.linear = nn.Linear(512 * block.expansion, num_classes)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = F.avg_pool2d(out, 4)
out = out.view(out.size(0), -1)
out = self.linear(out)
return out
# Создаём популярные конфигурации
def ResNet18():
return ResNet(BasicBlock, [2, 2, 2, 2])
def ResNet34():
return ResNet(BasicBlock, [3, 4, 6, 3])
def ResNet50():
return ResNet(Bottleneck, [3, 4, 6, 3])
def ResNet101():
return ResNet(Bottleneck, [3, 4, 23, 3])
def ResNet152():
return ResNet(Bottleneck, [3, 8, 36, 3])
Практические примеры и кейсы
Давайте протестируем нашу реализацию на CIFAR-10 — классическом датасете для бенчмарков:
# Подготовка данных
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
# Создаём модель
net = ResNet18().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
Функция обучения:
def train_epoch(epoch):
net.train()
train_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
if batch_idx % 100 == 0:
print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {train_loss/(batch_idx+1):.3f}, Acc: {100.*correct/total:.3f}%')
def test():
net.eval()
test_loss = 0
correct = 0
total = 0
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(testloader):
inputs, targets = inputs.to(device), targets.to(device)
outputs = net(inputs)
loss = criterion(outputs, targets)
test_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
print(f'Test Loss: {test_loss/len(testloader):.3f}, Test Acc: {100.*correct/total:.3f}%')
return 100.*correct/total
Запуск обучения:
# Обучение
for epoch in range(200):
train_epoch(epoch)
if epoch % 10 == 0:
test()
scheduler.step()
Сравнение архитектур: что выбрать?
Архитектура | Параметры (млн) | CIFAR-10 точность | Время обучения | Память GPU | Рекомендации |
---|---|---|---|---|---|
ResNet18 | 11.7 | ~95.0% | Быстро | ~2GB | Идеал для начала |
ResNet34 | 21.8 | ~95.5% | Средне | ~3GB | Хороший компромисс |
ResNet50 | 25.6 | ~96.0% | Медленно | ~4GB | Для серьёзных задач |
ResNet101 | 44.5 | ~96.3% | Очень медленно | ~6GB | Только для больших датасетов |
Оптимизация и продвинутые техники
Несколько полезных трюков для улучшения производительности:
# Инициализация весов по Kaiming He
def init_weights(m):
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
net.apply(init_weights)
# Использование смешанной точности для экономии памяти
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
def train_with_mixed_precision(epoch):
net.train()
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
with autocast():
outputs = net(inputs)
loss = criterion(outputs, targets)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
Полезные модификации архитектуры:
# ResNet с Squeeze-and-Excitation блоками
class SEBlock(nn.Module):
def __init__(self, channels, reduction=16):
super(SEBlock, self).__init__()
self.squeeze = nn.AdaptiveAvgPool2d(1)
self.excitation = nn.Sequential(
nn.Linear(channels, channels // reduction),
nn.ReLU(inplace=True),
nn.Linear(channels // reduction, channels),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.squeeze(x).view(b, c)
y = self.excitation(y).view(b, c, 1, 1)
return x * y.expand_as(x)
# Интеграция в BasicBlock
class SEBasicBlock(BasicBlock):
def __init__(self, in_planes, planes, stride=1):
super(SEBasicBlock, self).__init__(in_planes, planes, stride)
self.se = SEBlock(planes)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out = self.se(out) # Добавляем SE блок
out += self.shortcut(x)
out = F.relu(out)
return out
Интеграция с другими инструментами
Современный ML-пайплайн требует интеграции с множеством инструментов. Вот несколько полезных примеров:
# Интеграция с TensorBoard
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('runs/resnet_experiment')
def log_training(epoch, loss, accuracy):
writer.add_scalar('Loss/Train', loss, epoch)
writer.add_scalar('Accuracy/Train', accuracy, epoch)
# Визуализация весов
for name, param in net.named_parameters():
writer.add_histogram(name, param, epoch)
# Сохранение и загрузка модели
def save_checkpoint(epoch, model, optimizer, loss):
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}, f'checkpoint_epoch_{epoch}.pth')
def load_checkpoint(checkpoint_path):
checkpoint = torch.load(checkpoint_path)
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
return epoch, loss
Интеграция с ONNX для деплоя:
# Экспорт в ONNX
def export_to_onnx():
net.eval()
dummy_input = torch.randn(1, 3, 32, 32).to(device)
torch.onnx.export(net, dummy_input, "resnet.onnx",
export_params=True,
opset_version=11,
do_constant_folding=True,
input_names=['input'],
output_names=['output'])
Альтернативы и похожие решения
В экосистеме глубокого обучения есть множество альтернатив ResNet:
- DenseNet — каждый слой соединяется со всеми предыдущими слоями
- EfficientNet — оптимизированная архитектура с compound scaling
- RegNet — архитектура от Facebook AI с фокусом на эффективность
- Vision Transformer (ViT) — трансформеры для компьютерного зрения
Сравнение с готовыми решениями:
Решение | Плюсы | Минусы | Когда использовать |
---|---|---|---|
Своя реализация | Полный контроль, кастомизация | Время разработки, возможные баги | Исследования, специфичные задачи |
torchvision.models | Готовые предобученные модели | Ограниченная кастомизация | Быстрый прототип, transfer learning |
timm библиотека | Множество архитектур | Дополнительная зависимость | Эксперименты с архитектурами |
Развёртывание на сервере
Для эффективного обучения вам понадобится мощный сервер. Рекомендую взять VPS с GPU или выделенный сервер для серьёзных экспериментов.
Пример докерфайла для развёртывания:
# Dockerfile
FROM pytorch/pytorch:1.12.0-cuda11.3-cudnn8-runtime
WORKDIR /app
COPY requirements.txt .
RUN pip install -r requirements.txt
COPY . .
CMD ["python", "train.py"]
Скрипт для мониторинга обучения:
#!/bin/bash
# monitor.sh
while true; do
nvidia-smi
echo "=== GPU Memory Usage ==="
nvidia-smi --query-gpu=memory.used,memory.total --format=csv
echo "=== Training Progress ==="
tail -n 5 training.log
sleep 30
done
Интересные факты и нестандартные применения
Несколько любопытных фактов о ResNet:
- ResNet-152 имеет меньше параметров, чем VGG-16, но работает намного лучше
- Skip connections можно интерпретировать как ensemble множества более мелких сетей
- ResNet можно использовать не только для изображений, но и для временных рядов
- Градиенты в ResNet-1000 всё ещё остаются стабильными благодаря skip connections
Нестандартные применения:
# ResNet для временных рядов
class ResNet1D(nn.Module):
def __init__(self, input_size, num_classes):
super(ResNet1D, self).__init__()
self.conv1 = nn.Conv1d(input_size, 64, kernel_size=7, stride=2, padding=3)
self.bn1 = nn.BatchNorm1d(64)
# Дальше аналогично 2D версии
def forward(self, x):
# x shape: (batch, channels, sequence_length)
out = F.relu(self.bn1(self.conv1(x)))
# ...
return out
# ResNet для обработки графов с PyTorch Geometric
import torch_geometric.nn as gnn
class GraphResNet(nn.Module):
def __init__(self, input_dim, hidden_dim, num_classes):
super(GraphResNet, self).__init__()
self.conv1 = gnn.GCNConv(input_dim, hidden_dim)
self.conv2 = gnn.GCNConv(hidden_dim, hidden_dim)
self.classifier = nn.Linear(hidden_dim, num_classes)
def forward(self, x, edge_index):
identity = x
out = F.relu(self.conv1(x, edge_index))
out = self.conv2(out, edge_index)
out = out + identity # Skip connection
out = F.relu(out)
return self.classifier(out)
Автоматизация и скрипты
Создадим полезные скрипты для автоматизации экспериментов:
# experiment_runner.py
import argparse
import json
import os
from datetime import datetime
def run_experiment(config):
"""Запуск эксперимента с заданной конфигурацией"""
# Создаём уникальную папку для эксперимента
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
exp_dir = f"experiments/resnet_{config['architecture']}_{timestamp}"
os.makedirs(exp_dir, exist_ok=True)
# Сохраняём конфигурацию
with open(f"{exp_dir}/config.json", "w") as f:
json.dump(config, f, indent=2)
# Инициализируем модель согласно конфигурации
if config['architecture'] == 'resnet18':
net = ResNet18()
elif config['architecture'] == 'resnet50':
net = ResNet50()
# Логирование
log_file = f"{exp_dir}/training.log"
# Запуск обучения
train_model(net, config, exp_dir)
# Пример конфигурации
config = {
"architecture": "resnet18",
"batch_size": 128,
"learning_rate": 0.1,
"epochs": 200,
"weight_decay": 5e-4,
"momentum": 0.9,
"dataset": "cifar10"
}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
args = parser.parse_args()
with open(args.config, "r") as f:
config = json.load(f)
run_experiment(config)
Скрипт для гиперпараметрического поиска:
# hyperparameter_search.py
import itertools
import subprocess
import json
def grid_search():
"""Простой grid search по гиперпараметрам"""
param_grid = {
'learning_rate': [0.01, 0.1, 0.2],
'weight_decay': [1e-4, 5e-4, 1e-3],
'batch_size': [64, 128, 256],
'architecture': ['resnet18', 'resnet34']
}
keys = param_grid.keys()
values = param_grid.values()
for combo in itertools.product(*values):
config = dict(zip(keys, combo))
config['epochs'] = 100 # Меньше эпох для быстрого поиска
config_file = f"temp_config_{hash(str(config))}.json"
with open(config_file, "w") as f:
json.dump(config, f)
# Запуск эксперимента
subprocess.run(["python", "experiment_runner.py", "--config", config_file])
# Очистка
os.remove(config_file)
if __name__ == "__main__":
grid_search()
Заключение и рекомендации
ResNet — это не просто ещё одна архитектура нейронных сетей, это фундаментальный прорыв, который изменил подход к проектированию глубоких сетей. Написание ResNet с нуля даёт вам несколько ключевых преимуществ:
- Глубокое понимание архитектуры — вы знаете каждый компонент и можете его модифицировать
- Гибкость кастомизации — легко адаптировать под специфичные задачи
- Оптимизация под железо — можете оптимизировать модель под ваш конкретный сервер
- Исследовательские возможности — база для экспериментов с новыми идеями
Где и как использовать:
- Для обучения и понимания — начните с ResNet18 на CIFAR-10
- Для production — используйте ResNet50 или готовые предобученные модели
- Для исследований — модифицируйте архитектуру под ваши гипотезы
- Для специфичных доменов — адаптируйте под 1D данные, графы и т.д.
Рекомендации по развёртыванию:
- Для экспериментов используйте VPS с GPU
- Для серьёзных проектов арендуйте выделенный сервер
- Используйте докер для изоляции окружения
- Настройте мониторинг GPU и логирование
- Делайте чекпоинты каждые несколько эпох
Помните: ResNet — это не магия, а инженерное решение конкретной проблемы. Понимание принципов работы поможет вам создавать более эффективные архитектуры и решать сложные задачи компьютерного зрения. Удачных экспериментов!
Полезные ссылки для дальнейшего изучения:
В этой статье собрана информация и материалы из различных интернет-источников. Мы признаем и ценим работу всех оригинальных авторов, издателей и веб-сайтов. Несмотря на то, что были приложены все усилия для надлежащего указания исходного материала, любая непреднамеренная оплошность или упущение не являются нарушением авторских прав. Все упомянутые товарные знаки, логотипы и изображения являются собственностью соответствующих владельцев. Если вы считаете, что какой-либо контент, использованный в этой статье, нарушает ваши авторские права, немедленно свяжитесь с нами для рассмотрения и принятия оперативных мер.
Данная статья предназначена исключительно для ознакомительных и образовательных целей и не ущемляет права правообладателей. Если какой-либо материал, защищенный авторским правом, был использован без должного упоминания или с нарушением законов об авторском праве, это непреднамеренно, и мы исправим это незамедлительно после уведомления. Обратите внимание, что переиздание, распространение или воспроизведение части или всего содержимого в любой форме запрещено без письменного разрешения автора и владельца веб-сайта. Для получения разрешений или дополнительных запросов, пожалуйста, свяжитесь с нами.