Home » Написание ResNet с нуля на PyTorch — учебник
Написание ResNet с нуля на PyTorch — учебник

Написание ResNet с нуля на PyTorch — учебник

Если вы думаете, что написание ResNet с нуля — это только для академических кругов, то вы глубоко заблуждаетесь. Эта архитектура сетей — настоящий прорыв в глубоком обучении, который решил проблему исчезающего градиента и позволил создавать действительно глубокие сети. Сегодня мы детально разберём, как реализовать ResNet с нуля на PyTorch, причём сделаем это так, чтобы вы могли запустить всё на своём сервере и сразу начать эксперименты.

Зачем это нужно? Во-первых, понимание архитектуры изнутри даёт вам полный контроль над моделью. Во-вторых, это отличная база для кастомизации под ваши задачи. В-третьих, когда вы развернёте это на своём VPS или выделенном сервере, вы получите мощный инструмент для решения задач компьютерного зрения.

Как работает ResNet: архитектура без магии

ResNet (Residual Neural Network) — это не просто глубокая сеть, это революционный подход к обучению глубоких моделей. Основная идея: вместо того чтобы каждый слой изучал полное отображение H(x), мы заставляем его изучать остаточное отображение F(x) = H(x) – x. Звучит сложно? На самом деле всё гениально просто.

Ключевые компоненты ResNet:

  • Residual Block — основной строительный блок с skip connections
  • Skip connections — прямые соединения, которые “перепрыгивают” через слои
  • Batch Normalization — нормализация для стабилизации обучения
  • ReLU активация — стандартная функция активации

Проблема, которую решает ResNet: в очень глубоких сетях градиенты либо исчезают, либо взрываются. Skip connections позволяют градиентам “протекать” напрямую через сеть, обеспечивая стабильное обучение даже в сетях с сотнями слоёв.

Пошаговая реализация: от блока к сети

Начнём с создания базового окружения. Для экспериментов вам понадобится сервер с GPU — на CPU обучение будет мучительно медленным.

pip install torch torchvision torchaudio
pip install matplotlib numpy pillow

Теперь создадим основные компоненты. Начнём с импортов и базовых блоков:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Проверяем доступность GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Используем устройство: {device}")

Создаём базовый residual block — сердце архитектуры:

class BasicBlock(nn.Module):
    expansion = 1
    
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)  # Вот она, магия skip connection!
        out = F.relu(out)
        return out

Для более глубоких сетей нужен Bottleneck блок:

class Bottleneck(nn.Module):
    expansion = 4
    
    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

Теперь собираем полную сеть:

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64
        
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)
    
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

# Создаём популярные конфигурации
def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])

def ResNet34():
    return ResNet(BasicBlock, [3, 4, 6, 3])

def ResNet50():
    return ResNet(Bottleneck, [3, 4, 6, 3])

def ResNet101():
    return ResNet(Bottleneck, [3, 4, 23, 3])

def ResNet152():
    return ResNet(Bottleneck, [3, 8, 36, 3])

Практические примеры и кейсы

Давайте протестируем нашу реализацию на CIFAR-10 — классическом датасете для бенчмарков:

# Подготовка данных
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

# Создаём модель
net = ResNet18().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

Функция обучения:

def train_epoch(epoch):
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {train_loss/(batch_idx+1):.3f}, Acc: {100.*correct/total:.3f}%')

def test():
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
    print(f'Test Loss: {test_loss/len(testloader):.3f}, Test Acc: {100.*correct/total:.3f}%')
    return 100.*correct/total

Запуск обучения:

# Обучение
for epoch in range(200):
    train_epoch(epoch)
    if epoch % 10 == 0:
        test()
    scheduler.step()

Сравнение архитектур: что выбрать?

Архитектура Параметры (млн) CIFAR-10 точность Время обучения Память GPU Рекомендации
ResNet18 11.7 ~95.0% Быстро ~2GB Идеал для начала
ResNet34 21.8 ~95.5% Средне ~3GB Хороший компромисс
ResNet50 25.6 ~96.0% Медленно ~4GB Для серьёзных задач
ResNet101 44.5 ~96.3% Очень медленно ~6GB Только для больших датасетов

Оптимизация и продвинутые техники

Несколько полезных трюков для улучшения производительности:

# Инициализация весов по Kaiming He
def init_weights(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

net.apply(init_weights)

# Использование смешанной точности для экономии памяти
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

def train_with_mixed_precision(epoch):
    net.train()
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        
        with autocast():
            outputs = net(inputs)
            loss = criterion(outputs, targets)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

Полезные модификации архитектуры:

# ResNet с Squeeze-and-Excitation блоками
class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.squeeze(x).view(b, c)
        y = self.excitation(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

# Интеграция в BasicBlock
class SEBasicBlock(BasicBlock):
    def __init__(self, in_planes, planes, stride=1):
        super(SEBasicBlock, self).__init__(in_planes, planes, stride)
        self.se = SEBlock(planes)
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.se(out)  # Добавляем SE блок
        out += self.shortcut(x)
        out = F.relu(out)
        return out

Интеграция с другими инструментами

Современный ML-пайплайн требует интеграции с множеством инструментов. Вот несколько полезных примеров:

# Интеграция с TensorBoard
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/resnet_experiment')

def log_training(epoch, loss, accuracy):
    writer.add_scalar('Loss/Train', loss, epoch)
    writer.add_scalar('Accuracy/Train', accuracy, epoch)
    
    # Визуализация весов
    for name, param in net.named_parameters():
        writer.add_histogram(name, param, epoch)

# Сохранение и загрузка модели
def save_checkpoint(epoch, model, optimizer, loss):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, f'checkpoint_epoch_{epoch}.pth')

def load_checkpoint(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    net.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    return epoch, loss

Интеграция с ONNX для деплоя:

# Экспорт в ONNX
def export_to_onnx():
    net.eval()
    dummy_input = torch.randn(1, 3, 32, 32).to(device)
    torch.onnx.export(net, dummy_input, "resnet.onnx", 
                      export_params=True, 
                      opset_version=11,
                      do_constant_folding=True,
                      input_names=['input'],
                      output_names=['output'])

Альтернативы и похожие решения

В экосистеме глубокого обучения есть множество альтернатив ResNet:

  • DenseNet — каждый слой соединяется со всеми предыдущими слоями
  • EfficientNet — оптимизированная архитектура с compound scaling
  • RegNet — архитектура от Facebook AI с фокусом на эффективность
  • Vision Transformer (ViT) — трансформеры для компьютерного зрения

Сравнение с готовыми решениями:

Решение Плюсы Минусы Когда использовать
Своя реализация Полный контроль, кастомизация Время разработки, возможные баги Исследования, специфичные задачи
torchvision.models Готовые предобученные модели Ограниченная кастомизация Быстрый прототип, transfer learning
timm библиотека Множество архитектур Дополнительная зависимость Эксперименты с архитектурами

Развёртывание на сервере

Для эффективного обучения вам понадобится мощный сервер. Рекомендую взять VPS с GPU или выделенный сервер для серьёзных экспериментов.

Пример докерфайла для развёртывания:

# Dockerfile
FROM pytorch/pytorch:1.12.0-cuda11.3-cudnn8-runtime

WORKDIR /app

COPY requirements.txt .
RUN pip install -r requirements.txt

COPY . .

CMD ["python", "train.py"]

Скрипт для мониторинга обучения:

#!/bin/bash
# monitor.sh
while true; do
    nvidia-smi
    echo "=== GPU Memory Usage ==="
    nvidia-smi --query-gpu=memory.used,memory.total --format=csv
    echo "=== Training Progress ==="
    tail -n 5 training.log
    sleep 30
done

Интересные факты и нестандартные применения

Несколько любопытных фактов о ResNet:

  • ResNet-152 имеет меньше параметров, чем VGG-16, но работает намного лучше
  • Skip connections можно интерпретировать как ensemble множества более мелких сетей
  • ResNet можно использовать не только для изображений, но и для временных рядов
  • Градиенты в ResNet-1000 всё ещё остаются стабильными благодаря skip connections

Нестандартные применения:

# ResNet для временных рядов
class ResNet1D(nn.Module):
    def __init__(self, input_size, num_classes):
        super(ResNet1D, self).__init__()
        self.conv1 = nn.Conv1d(input_size, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        # Дальше аналогично 2D версии
        
    def forward(self, x):
        # x shape: (batch, channels, sequence_length)
        out = F.relu(self.bn1(self.conv1(x)))
        # ...
        return out

# ResNet для обработки графов с PyTorch Geometric
import torch_geometric.nn as gnn

class GraphResNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes):
        super(GraphResNet, self).__init__()
        self.conv1 = gnn.GCNConv(input_dim, hidden_dim)
        self.conv2 = gnn.GCNConv(hidden_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, num_classes)
    
    def forward(self, x, edge_index):
        identity = x
        out = F.relu(self.conv1(x, edge_index))
        out = self.conv2(out, edge_index)
        out = out + identity  # Skip connection
        out = F.relu(out)
        return self.classifier(out)

Автоматизация и скрипты

Создадим полезные скрипты для автоматизации экспериментов:

# experiment_runner.py
import argparse
import json
import os
from datetime import datetime

def run_experiment(config):
    """Запуск эксперимента с заданной конфигурацией"""
    
    # Создаём уникальную папку для эксперимента
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    exp_dir = f"experiments/resnet_{config['architecture']}_{timestamp}"
    os.makedirs(exp_dir, exist_ok=True)
    
    # Сохраняём конфигурацию
    with open(f"{exp_dir}/config.json", "w") as f:
        json.dump(config, f, indent=2)
    
    # Инициализируем модель согласно конфигурации
    if config['architecture'] == 'resnet18':
        net = ResNet18()
    elif config['architecture'] == 'resnet50':
        net = ResNet50()
    
    # Логирование
    log_file = f"{exp_dir}/training.log"
    
    # Запуск обучения
    train_model(net, config, exp_dir)

# Пример конфигурации
config = {
    "architecture": "resnet18",
    "batch_size": 128,
    "learning_rate": 0.1,
    "epochs": 200,
    "weight_decay": 5e-4,
    "momentum": 0.9,
    "dataset": "cifar10"
}

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    args = parser.parse_args()
    
    with open(args.config, "r") as f:
        config = json.load(f)
    
    run_experiment(config)

Скрипт для гиперпараметрического поиска:

# hyperparameter_search.py
import itertools
import subprocess
import json

def grid_search():
    """Простой grid search по гиперпараметрам"""
    
    param_grid = {
        'learning_rate': [0.01, 0.1, 0.2],
        'weight_decay': [1e-4, 5e-4, 1e-3],
        'batch_size': [64, 128, 256],
        'architecture': ['resnet18', 'resnet34']
    }
    
    keys = param_grid.keys()
    values = param_grid.values()
    
    for combo in itertools.product(*values):
        config = dict(zip(keys, combo))
        config['epochs'] = 100  # Меньше эпох для быстрого поиска
        
        config_file = f"temp_config_{hash(str(config))}.json"
        with open(config_file, "w") as f:
            json.dump(config, f)
        
        # Запуск эксперимента
        subprocess.run(["python", "experiment_runner.py", "--config", config_file])
        
        # Очистка
        os.remove(config_file)

if __name__ == "__main__":
    grid_search()

Заключение и рекомендации

ResNet — это не просто ещё одна архитектура нейронных сетей, это фундаментальный прорыв, который изменил подход к проектированию глубоких сетей. Написание ResNet с нуля даёт вам несколько ключевых преимуществ:

  • Глубокое понимание архитектуры — вы знаете каждый компонент и можете его модифицировать
  • Гибкость кастомизации — легко адаптировать под специфичные задачи
  • Оптимизация под железо — можете оптимизировать модель под ваш конкретный сервер
  • Исследовательские возможности — база для экспериментов с новыми идеями

Где и как использовать:

  • Для обучения и понимания — начните с ResNet18 на CIFAR-10
  • Для production — используйте ResNet50 или готовые предобученные модели
  • Для исследований — модифицируйте архитектуру под ваши гипотезы
  • Для специфичных доменов — адаптируйте под 1D данные, графы и т.д.

Рекомендации по развёртыванию:

  • Для экспериментов используйте VPS с GPU
  • Для серьёзных проектов арендуйте выделенный сервер
  • Используйте докер для изоляции окружения
  • Настройте мониторинг GPU и логирование
  • Делайте чекпоинты каждые несколько эпох

Помните: ResNet — это не магия, а инженерное решение конкретной проблемы. Понимание принципов работы поможет вам создавать более эффективные архитектуры и решать сложные задачи компьютерного зрения. Удачных экспериментов!

Полезные ссылки для дальнейшего изучения:


В этой статье собрана информация и материалы из различных интернет-источников. Мы признаем и ценим работу всех оригинальных авторов, издателей и веб-сайтов. Несмотря на то, что были приложены все усилия для надлежащего указания исходного материала, любая непреднамеренная оплошность или упущение не являются нарушением авторских прав. Все упомянутые товарные знаки, логотипы и изображения являются собственностью соответствующих владельцев. Если вы считаете, что какой-либо контент, использованный в этой статье, нарушает ваши авторские права, немедленно свяжитесь с нами для рассмотрения и принятия оперативных мер.

Данная статья предназначена исключительно для ознакомительных и образовательных целей и не ущемляет права правообладателей. Если какой-либо материал, защищенный авторским правом, был использован без должного упоминания или с нарушением законов об авторском праве, это непреднамеренно, и мы исправим это незамедлительно после уведомления. Обратите внимание, что переиздание, распространение или воспроизведение части или всего содержимого в любой форме запрещено без письменного разрешения автора и владельца веб-сайта. Для получения разрешений или дополнительных запросов, пожалуйста, свяжитесь с нами.

Leave a reply

Your email address will not be published. Required fields are marked