- Home »

AlexNet в PyTorch — создание сверточной нейронной сети
На прошлой неделе мне пришлось деплоить систему классификации изображений на production сервер, и я понял, что многие из нас до сих пор боятся нейросетей как огня. А зря! AlexNet — это идеальная стартовая точка для понимания сверточных нейронных сетей. Она простая, хорошо документированная, и что самое главное — работает быстро даже на средних серверах.
Если вы деплоите приложения, настраиваете инфраструктуру или просто хотите разобраться с машинным обучением без многомесячного погружения в теорию, эта статья поможет вам развернуть AlexNet на вашем сервере за пару часов. Я покажу конкретные команды, разберу типичные ошибки и дам готовые конфиги.
Как работает AlexNet под капотом
AlexNet — это классическая архитектура сверточной нейронной сети, которая произвела революцию в компьютерном зрении в 2012 году. Она состоит из 8 слоев: 5 сверточных и 3 полносвязных. Главная фишка — использование ReLU активации и dropout для борьбы с переобучением.
Архитектура выглядит примерно так:
- Входной слой: 224×224×3 (RGB изображение)
- Первый сверточный слой: 96 фильтров 11×11 с шагом 4
- Max pooling 3×3 с шагом 2
- Второй сверточный слой: 256 фильтров 5×5
- Max pooling 3×3 с шагом 2
- Третий сверточный слой: 384 фильтра 3×3
- Четвертый сверточный слой: 384 фильтра 3×3
- Пятый сверточный слой: 256 фильтров 3×3
- Max pooling 3×3 с шагом 2
- Три полносвязных слоя: 4096 → 4096 → 1000 нейронов
Быстрая настройка окружения
Для комфортной работы с AlexNet вам понадобится сервер с минимум 8GB RAM и желательно GPU. Если нет своего железа, можно взять VPS с GPU или выделенный сервер.
Устанавливаем необходимые пакеты:
# Обновляем систему
sudo apt update && sudo apt upgrade -y
# Устанавливаем Python и pip
sudo apt install python3 python3-pip python3-venv -y
# Создаем виртуальное окружение
python3 -m venv alexnet_env
source alexnet_env/bin/activate
# Устанавливаем PyTorch (для CPU)
pip install torch torchvision torchaudio
# Или для GPU (если есть CUDA)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# Дополнительные пакеты
pip install matplotlib numpy pillow jupyter
Создаем AlexNet с нуля
Теперь самое интересное — код. Создаем файл `alexnet.py` и пишем нашу сеть:
import torch
import torch.nn as nn
import torch.nn.functional as F
class AlexNet(nn.Module):
def __init__(self, num_classes=1000):
super(AlexNet, self).__init__()
# Сверточные слои
self.features = nn.Sequential(
# Первый блок
nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
# Второй блок
nn.Conv2d(96, 256, kernel_size=5, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
# Третий блок
nn.Conv2d(256, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
# Четвертый блок
nn.Conv2d(384, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
# Пятый блок
nn.Conv2d(384, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
)
# Полносвязные слои
self.classifier = nn.Sequential(
nn.Dropout(),
nn.Linear(256 * 6 * 6, 4096),
nn.ReLU(inplace=True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, num_classes),
)
def forward(self, x):
x = self.features(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
# Создаем модель
model = AlexNet(num_classes=10) # Для CIFAR-10
print(model)
Тренировка модели — практический пример
Создаем скрипт для тренировки на датасете CIFAR-10:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# Настройка устройства
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Используем устройство: {device}")
# Подготовка данных
transform = transforms.Compose([
transforms.Resize((224, 224)), # AlexNet ожидает 224x224
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# Загрузка CIFAR-10
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
# Создаем модель и переносим на устройство
model = AlexNet(num_classes=10).to(device)
# Функция потерь и оптимизатор
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Тренировка
def train_model(epochs=10):
model.train()
for epoch in range(epochs):
running_loss = 0.0
for i, (inputs, labels) in enumerate(trainloader):
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 100 == 99: # Вывод каждые 100 батчей
print(f'Эпоха {epoch + 1}, батч {i + 1}: потеря {running_loss / 100:.3f}')
running_loss = 0.0
# Запускаем тренировку
train_model(epochs=5)
Сравнение производительности и требований
Параметр | AlexNet | VGG16 | ResNet50 |
---|---|---|---|
Параметры | 61M | 138M | 25M |
Время обучения (эпоха) | ~10 мин | ~25 мин | ~15 мин |
Память GPU | 2-4 GB | 6-8 GB | 4-6 GB |
Точность на ImageNet | ~57% | ~71% | ~76% |
Готовые решения и альтернативы
PyTorch предоставляет готовую реализацию AlexNet:
import torchvision.models as models
# Загрузка предобученной модели
model = models.alexnet(pretrained=True)
# Или создание с нуля
model = models.alexnet(pretrained=False)
# Для fine-tuning замораживаем веса
for param in model.features.parameters():
param.requires_grad = False
# Заменяем последний слой под свою задачу
model.classifier[6] = nn.Linear(4096, 10) # Для 10 классов
Альтернативные фреймворки:
- TensorFlow/Keras — более простой синтаксис, но менее гибкий
- ONNX — для кроссплатформенного деплоя
- TorchScript — для production-серверов
Деплой на production сервер
Для деплоя создаем простой Flask API:
from flask import Flask, request, jsonify
import torch
import torchvision.transforms as transforms
from PIL import Image
import io
app = Flask(__name__)
# Загружаем модель
model = torch.load('alexnet_model.pth', map_location='cpu')
model.eval()
# Трансформации для входных данных
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
@app.route('/predict', methods=['POST'])
def predict():
if 'image' not in request.files:
return jsonify({'error': 'No image provided'}), 400
file = request.files['image']
image = Image.open(io.BytesIO(file.read()))
# Предобработка
input_tensor = transform(image).unsqueeze(0)
# Предсказание
with torch.no_grad():
outputs = model(input_tensor)
_, predicted = torch.max(outputs, 1)
return jsonify({'class': predicted.item()})
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000)
Типичные ошибки и решения
CUDA out of memory — уменьшите batch_size или используйте gradient accumulation:
# Вместо batch_size=128 используйте 32 или 64
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
# Или используйте gradient accumulation
accumulation_steps = 4
for i, (inputs, labels) in enumerate(trainloader):
outputs = model(inputs)
loss = criterion(outputs, labels) / accumulation_steps
loss.backward()
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
Slow training — используйте DataLoader с несколькими workers:
trainloader = DataLoader(trainset, batch_size=128, shuffle=True,
num_workers=4, pin_memory=True)
Интересные факты и нестандартные применения
AlexNet можно использовать не только для классификации изображений:
- Transfer learning — используйте предобученные веса как feature extractor
- Стиль-трансфер — извлекайте features для переноса стиля
- Детекция аномалий — обучите автоэнкодер на базе AlexNet
- Embeddings для поиска — используйте предпоследний слой для получения векторных представлений
Забавный факт: AlexNet изначально тренировался на двух GTX 580, а сейчас можно обучить на смартфоне!
Автоматизация и скрипты
Создаем скрипт для автоматического мониторинга обучения:
import wandb # pip install wandb
import matplotlib.pyplot as plt
# Инициализация W&B
wandb.init(project="alexnet-training")
def train_with_logging():
model.train()
for epoch in range(epochs):
epoch_loss = 0
correct = 0
total = 0
for batch_idx, (data, target) in enumerate(trainloader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
_, predicted = torch.max(output.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()
# Логирование каждые 100 батчей
if batch_idx % 100 == 0:
wandb.log({
'batch_loss': loss.item(),
'batch_accuracy': 100. * correct / total
})
# Логирование эпохи
wandb.log({
'epoch': epoch,
'epoch_loss': epoch_loss / len(trainloader),
'epoch_accuracy': 100. * correct / total
})
Заключение и рекомендации
AlexNet — это отличная точка входа в мир сверточных нейронных сетей. Она достаточно простая для понимания, но при этом демонстрирует все ключевые концепции. Для production-окружений рекомендую:
- Используйте официальную реализацию PyTorch для стабильности
- Для серьезных задач рассмотрите более современные архитектуры (ResNet, EfficientNet)
- Обязательно используйте предобученные веса для transfer learning
- Мониторьте потребление ресурсов и оптимизируйте batch_size
- Настройте автоматическое логирование и алерты
Если планируете серьезно заниматься ML, возьмите VPS с GPU или выделенный сервер — на CPU современные модели обучаются мучительно долго.
Главное — не бойтесь экспериментировать. AlexNet может показаться устаревшей, но она отлично подходит для изучения основ и быстрого прототипирования. А понимание её архитектуры поможет разобраться с более сложными моделями.
В этой статье собрана информация и материалы из различных интернет-источников. Мы признаем и ценим работу всех оригинальных авторов, издателей и веб-сайтов. Несмотря на то, что были приложены все усилия для надлежащего указания исходного материала, любая непреднамеренная оплошность или упущение не являются нарушением авторских прав. Все упомянутые товарные знаки, логотипы и изображения являются собственностью соответствующих владельцев. Если вы считаете, что какой-либо контент, использованный в этой статье, нарушает ваши авторские права, немедленно свяжитесь с нами для рассмотрения и принятия оперативных мер.
Данная статья предназначена исключительно для ознакомительных и образовательных целей и не ущемляет права правообладателей. Если какой-либо материал, защищенный авторским правом, был использован без должного упоминания или с нарушением законов об авторском праве, это непреднамеренно, и мы исправим это незамедлительно после уведомления. Обратите внимание, что переиздание, распространение или воспроизведение части или всего содержимого в любой форме запрещено без письменного разрешения автора и владельца веб-сайта. Для получения разрешений или дополнительных запросов, пожалуйста, свяжитесь с нами.