Home » AlexNet в PyTorch — создание сверточной нейронной сети
AlexNet в PyTorch — создание сверточной нейронной сети

AlexNet в PyTorch — создание сверточной нейронной сети

На прошлой неделе мне пришлось деплоить систему классификации изображений на production сервер, и я понял, что многие из нас до сих пор боятся нейросетей как огня. А зря! AlexNet — это идеальная стартовая точка для понимания сверточных нейронных сетей. Она простая, хорошо документированная, и что самое главное — работает быстро даже на средних серверах.

Если вы деплоите приложения, настраиваете инфраструктуру или просто хотите разобраться с машинным обучением без многомесячного погружения в теорию, эта статья поможет вам развернуть AlexNet на вашем сервере за пару часов. Я покажу конкретные команды, разберу типичные ошибки и дам готовые конфиги.

Как работает AlexNet под капотом

AlexNet — это классическая архитектура сверточной нейронной сети, которая произвела революцию в компьютерном зрении в 2012 году. Она состоит из 8 слоев: 5 сверточных и 3 полносвязных. Главная фишка — использование ReLU активации и dropout для борьбы с переобучением.

Архитектура выглядит примерно так:

  • Входной слой: 224×224×3 (RGB изображение)
  • Первый сверточный слой: 96 фильтров 11×11 с шагом 4
  • Max pooling 3×3 с шагом 2
  • Второй сверточный слой: 256 фильтров 5×5
  • Max pooling 3×3 с шагом 2
  • Третий сверточный слой: 384 фильтра 3×3
  • Четвертый сверточный слой: 384 фильтра 3×3
  • Пятый сверточный слой: 256 фильтров 3×3
  • Max pooling 3×3 с шагом 2
  • Три полносвязных слоя: 4096 → 4096 → 1000 нейронов

Быстрая настройка окружения

Для комфортной работы с AlexNet вам понадобится сервер с минимум 8GB RAM и желательно GPU. Если нет своего железа, можно взять VPS с GPU или выделенный сервер.

Устанавливаем необходимые пакеты:


# Обновляем систему
sudo apt update && sudo apt upgrade -y

# Устанавливаем Python и pip
sudo apt install python3 python3-pip python3-venv -y

# Создаем виртуальное окружение
python3 -m venv alexnet_env
source alexnet_env/bin/activate

# Устанавливаем PyTorch (для CPU)
pip install torch torchvision torchaudio

# Или для GPU (если есть CUDA)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Дополнительные пакеты
pip install matplotlib numpy pillow jupyter

Создаем AlexNet с нуля

Теперь самое интересное — код. Создаем файл `alexnet.py` и пишем нашу сеть:


import torch
import torch.nn as nn
import torch.nn.functional as F

class AlexNet(nn.Module):
def __init__(self, num_classes=1000):
super(AlexNet, self).__init__()

# Сверточные слои
self.features = nn.Sequential(
# Первый блок
nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),

# Второй блок
nn.Conv2d(96, 256, kernel_size=5, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),

# Третий блок
nn.Conv2d(256, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),

# Четвертый блок
nn.Conv2d(384, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),

# Пятый блок
nn.Conv2d(384, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
)

# Полносвязные слои
self.classifier = nn.Sequential(
nn.Dropout(),
nn.Linear(256 * 6 * 6, 4096),
nn.ReLU(inplace=True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, num_classes),
)

def forward(self, x):
x = self.features(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x

# Создаем модель
model = AlexNet(num_classes=10) # Для CIFAR-10
print(model)

Тренировка модели — практический пример

Создаем скрипт для тренировки на датасете CIFAR-10:


import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Настройка устройства
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Используем устройство: {device}")

# Подготовка данных
transform = transforms.Compose([
transforms.Resize((224, 224)), # AlexNet ожидает 224x224
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Загрузка CIFAR-10
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

# Создаем модель и переносим на устройство
model = AlexNet(num_classes=10).to(device)

# Функция потерь и оптимизатор
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Тренировка
def train_model(epochs=10):
model.train()
for epoch in range(epochs):
running_loss = 0.0
for i, (inputs, labels) in enumerate(trainloader):
inputs, labels = inputs.to(device), labels.to(device)

optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

running_loss += loss.item()

if i % 100 == 99: # Вывод каждые 100 батчей
print(f'Эпоха {epoch + 1}, батч {i + 1}: потеря {running_loss / 100:.3f}')
running_loss = 0.0

# Запускаем тренировку
train_model(epochs=5)

Сравнение производительности и требований

Параметр AlexNet VGG16 ResNet50
Параметры 61M 138M 25M
Время обучения (эпоха) ~10 мин ~25 мин ~15 мин
Память GPU 2-4 GB 6-8 GB 4-6 GB
Точность на ImageNet ~57% ~71% ~76%

Готовые решения и альтернативы

PyTorch предоставляет готовую реализацию AlexNet:


import torchvision.models as models

# Загрузка предобученной модели
model = models.alexnet(pretrained=True)

# Или создание с нуля
model = models.alexnet(pretrained=False)

# Для fine-tuning замораживаем веса
for param in model.features.parameters():
param.requires_grad = False

# Заменяем последний слой под свою задачу
model.classifier[6] = nn.Linear(4096, 10) # Для 10 классов

Альтернативные фреймворки:

  • TensorFlow/Keras — более простой синтаксис, но менее гибкий
  • ONNX — для кроссплатформенного деплоя
  • TorchScript — для production-серверов

Деплой на production сервер

Для деплоя создаем простой Flask API:


from flask import Flask, request, jsonify
import torch
import torchvision.transforms as transforms
from PIL import Image
import io

app = Flask(__name__)

# Загружаем модель
model = torch.load('alexnet_model.pth', map_location='cpu')
model.eval()

# Трансформации для входных данных
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

@app.route('/predict', methods=['POST'])
def predict():
if 'image' not in request.files:
return jsonify({'error': 'No image provided'}), 400

file = request.files['image']
image = Image.open(io.BytesIO(file.read()))

# Предобработка
input_tensor = transform(image).unsqueeze(0)

# Предсказание
with torch.no_grad():
outputs = model(input_tensor)
_, predicted = torch.max(outputs, 1)

return jsonify({'class': predicted.item()})

if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000)

Типичные ошибки и решения

CUDA out of memory — уменьшите batch_size или используйте gradient accumulation:


# Вместо batch_size=128 используйте 32 или 64
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)

# Или используйте gradient accumulation
accumulation_steps = 4
for i, (inputs, labels) in enumerate(trainloader):
outputs = model(inputs)
loss = criterion(outputs, labels) / accumulation_steps
loss.backward()

if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()

Slow training — используйте DataLoader с несколькими workers:


trainloader = DataLoader(trainset, batch_size=128, shuffle=True,
num_workers=4, pin_memory=True)

Интересные факты и нестандартные применения

AlexNet можно использовать не только для классификации изображений:

  • Transfer learning — используйте предобученные веса как feature extractor
  • Стиль-трансфер — извлекайте features для переноса стиля
  • Детекция аномалий — обучите автоэнкодер на базе AlexNet
  • Embeddings для поиска — используйте предпоследний слой для получения векторных представлений

Забавный факт: AlexNet изначально тренировался на двух GTX 580, а сейчас можно обучить на смартфоне!

Автоматизация и скрипты

Создаем скрипт для автоматического мониторинга обучения:


import wandb # pip install wandb
import matplotlib.pyplot as plt

# Инициализация W&B
wandb.init(project="alexnet-training")

def train_with_logging():
model.train()
for epoch in range(epochs):
epoch_loss = 0
correct = 0
total = 0

for batch_idx, (data, target) in enumerate(trainloader):
data, target = data.to(device), target.to(device)

optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()

epoch_loss += loss.item()
_, predicted = torch.max(output.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()

# Логирование каждые 100 батчей
if batch_idx % 100 == 0:
wandb.log({
'batch_loss': loss.item(),
'batch_accuracy': 100. * correct / total
})

# Логирование эпохи
wandb.log({
'epoch': epoch,
'epoch_loss': epoch_loss / len(trainloader),
'epoch_accuracy': 100. * correct / total
})

Заключение и рекомендации

AlexNet — это отличная точка входа в мир сверточных нейронных сетей. Она достаточно простая для понимания, но при этом демонстрирует все ключевые концепции. Для production-окружений рекомендую:

  • Используйте официальную реализацию PyTorch для стабильности
  • Для серьезных задач рассмотрите более современные архитектуры (ResNet, EfficientNet)
  • Обязательно используйте предобученные веса для transfer learning
  • Мониторьте потребление ресурсов и оптимизируйте batch_size
  • Настройте автоматическое логирование и алерты

Если планируете серьезно заниматься ML, возьмите VPS с GPU или выделенный сервер — на CPU современные модели обучаются мучительно долго.

Главное — не бойтесь экспериментировать. AlexNet может показаться устаревшей, но она отлично подходит для изучения основ и быстрого прототипирования. А понимание её архитектуры поможет разобраться с более сложными моделями.


В этой статье собрана информация и материалы из различных интернет-источников. Мы признаем и ценим работу всех оригинальных авторов, издателей и веб-сайтов. Несмотря на то, что были приложены все усилия для надлежащего указания исходного материала, любая непреднамеренная оплошность или упущение не являются нарушением авторских прав. Все упомянутые товарные знаки, логотипы и изображения являются собственностью соответствующих владельцев. Если вы считаете, что какой-либо контент, использованный в этой статье, нарушает ваши авторские права, немедленно свяжитесь с нами для рассмотрения и принятия оперативных мер.

Данная статья предназначена исключительно для ознакомительных и образовательных целей и не ущемляет права правообладателей. Если какой-либо материал, защищенный авторским правом, был использован без должного упоминания или с нарушением законов об авторском праве, это непреднамеренно, и мы исправим это незамедлительно после уведомления. Обратите внимание, что переиздание, распространение или воспроизведение части или всего содержимого в любой форме запрещено без письменного разрешения автора и владельца веб-сайта. Для получения разрешений или дополнительных запросов, пожалуйста, свяжитесь с нами.

Leave a reply

Your email address will not be published. Required fields are marked