Home » Что должен знать каждый разработчик ML/AI о ONNX
Что должен знать каждый разработчик ML/AI о ONNX

Что должен знать каждый разработчик ML/AI о ONNX

Если вы деплоите модели машинного обучения на серверах, то наверняка сталкивались с классической проблемой: модель работает на одном стеке, а продакшн-сервер на другом. Или еще хуже — модель написана на Python, а заказчик просит интегрировать её в C++ приложение. Вот тут-то и приходит на помощь ONNX (Open Neural Network Exchange) — формат обмена нейросетевыми моделями, который превратился в настоящий швейцарский нож для ML-инженеров.

Эта статья поможет разобраться с ONNX на практике: от понимания принципов работы до развертывания на собственном сервере. Мы рассмотрим как настроить среду, конвертировать модели, и что самое главное — как интегрировать всё это в ваш production pipeline с минимальными потерями производительности.

Как работает ONNX: под капотом

ONNX — это, по сути, протокол описания вычислительных графов нейросетей. Представьте его как JSON для моделей машинного обучения. Формат использует Protocol Buffers для сериализации и описывает граф как набор операторов (nodes) и тензоров (tensors).

Ключевые компоненты ONNX:

  • Graph — основной граф вычислений
  • Node — операция в графе (Conv, MatMul, Relu и т.д.)
  • ValueInfo — описание входов/выходов
  • TensorProto — сериализованные веса модели
  • Opset — версия набора операторов

Магия в том, что ONNX Runtime может выполнять эти графы на разных устройствах и с разными провайдерами выполнения (CPU, CUDA, OpenVINO, TensorRT и др.).

Пошаговая настройка среды

Начнем с чистого Ubuntu сервера. Если у вас еще нет подходящего сервера, можете заказать VPS или выделенный сервер для более серьезных задач.

Установка базовых зависимостей

# Обновляем систему
sudo apt update && sudo apt upgrade -y

# Устанавливаем Python и pip
sudo apt install python3 python3-pip python3-venv git -y

# Создаем виртуальную среду
python3 -m venv onnx_env
source onnx_env/bin/activate

# Устанавливаем основные пакеты
pip install onnx onnxruntime numpy torch torchvision tensorflow

# Для GPU поддержки (если есть CUDA)
pip install onnxruntime-gpu

# Полезные утилиты
pip install onnx-tools netron matplotlib pillow

Проверка установки

# Проверяем версии
python -c "import onnx; print(f'ONNX version: {onnx.__version__}')"
python -c "import onnxruntime; print(f'ONNX Runtime version: {onnxruntime.__version__}')"

# Проверяем доступные провайдеры
python -c "import onnxruntime; print('Available providers:', onnxruntime.get_available_providers())"

Практические примеры конвертации

Конвертация из PyTorch

import torch
import torch.nn as nn
import torch.onnx

# Простая модель для демонстрации
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(10, 1)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        return self.sigmoid(self.linear(x))

# Создаем модель
model = SimpleModel()
model.eval()

# Создаем тестовый вход
dummy_input = torch.randn(1, 10)

# Экспортируем в ONNX
torch.onnx.export(
    model,                     # модель
    dummy_input,               # примерный вход
    "model.onnx",              # имя файла
    export_params=True,        # экспортировать параметры
    opset_version=11,          # версия операторов
    do_constant_folding=True,  # оптимизация
    input_names=['input'],     # имена входов
    output_names=['output'],   # имена выходов
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }
)

Конвертация из TensorFlow

import tensorflow as tf
from tf2onnx import convert

# Загружаем модель TensorFlow
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])

# Сохраняем в формате SavedModel
model.save('tf_model')

# Конвертируем в ONNX
onnx_model, _ = convert.from_keras(model, opset=13)

# Сохраняем
with open("tf_model.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())

Работа с ONNX Runtime

Базовый инференс

import onnxruntime as ort
import numpy as np

# Создаем сессию
session = ort.InferenceSession("model.onnx")

# Получаем информацию о входах и выходах
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

print(f"Input name: {input_name}")
print(f"Output name: {output_name}")

# Подготавливаем данные
input_data = np.random.randn(1, 10).astype(np.float32)

# Запускаем инференс
result = session.run([output_name], {input_name: input_data})
print(f"Result: {result[0]}")

Продвинутая конфигурация

import onnxruntime as ort

# Настройки сессии
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL

# Настройки провайдера
providers = [
    ('CUDAExecutionProvider', {
        'device_id': 0,
        'arena_extend_strategy': 'kNextPowerOfTwo',
        'gpu_mem_limit': 2 * 1024 * 1024 * 1024,  # 2GB
        'cudnn_conv_algo_search': 'EXHAUSTIVE',
        'do_copy_in_default_stream': True,
    }),
    'CPUExecutionProvider',
]

# Создаем оптимизированную сессию
session = ort.InferenceSession(
    "model.onnx",
    sess_options=session_options,
    providers=providers
)

Оптимизация и производительность

Квантизация модели

from onnxruntime.quantization import quantize_dynamic, QuantType

# Динамическая квантизация (уменьшает размер модели)
quantize_dynamic(
    model_input="model.onnx",
    model_output="model_quantized.onnx",
    weight_type=QuantType.QUInt8
)

# Статическая квантизация (требует калибровочных данных)
from onnxruntime.quantization import quantize_static, CalibrationDataReader

class DataReader(CalibrationDataReader):
    def __init__(self, calibration_data):
        self.data = calibration_data
        self.counter = 0
    
    def get_next(self):
        if self.counter < len(self.data):
            result = {"input": self.data[self.counter]}
            self.counter += 1
            return result
        return None

# Подготавливаем калибровочные данные
calibration_data = [np.random.randn(1, 10).astype(np.float32) for _ in range(100)]
data_reader = DataReader(calibration_data)

quantize_static(
    model_input="model.onnx",
    model_output="model_static_quantized.onnx",
    calibration_data_reader=data_reader
)

Сравнение производительности

Метод Размер модели Скорость инференса Точность Потребление памяти
Оригинальная модель 100% 1x 100% 100%
Динамическая квантизация ~25% 2-3x ~99% ~50%
Статическая квантизация ~25% 3-4x ~98% ~30%
TensorRT (GPU) ~30% 5-10x ~99% ~40%

Создание производственного сервиса

FastAPI сервер с ONNX

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import onnxruntime as ort
import numpy as np
from typing import List
import uvicorn

app = FastAPI(title="ONNX Model Server")

# Глобальная сессия (загружается один раз)
session = None

class PredictionRequest(BaseModel):
    data: List[List[float]]

class PredictionResponse(BaseModel):
    predictions: List[List[float]]
    processing_time: float

@app.on_event("startup")
async def load_model():
    global session
    try:
        session = ort.InferenceSession("model.onnx")
        print("Model loaded successfully")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    import time
    start_time = time.time()
    
    try:
        # Подготавливаем данные
        input_data = np.array(request.data, dtype=np.float32)
        
        # Получаем имена входов/выходов
        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name
        
        # Запускаем инференс
        result = session.run([output_name], {input_name: input_data})
        
        processing_time = time.time() - start_time
        
        return PredictionResponse(
            predictions=result[0].tolist(),
            processing_time=processing_time
        )
    
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    return {"status": "healthy", "model_loaded": session is not None}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

Docker контейнер

# Dockerfile
FROM python:3.9-slim

WORKDIR /app

# Устанавливаем зависимости
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Копируем код и модель
COPY . .

# Экспонируем порт
EXPOSE 8000

# Запускаем приложение
CMD ["python", "server.py"]
# requirements.txt
fastapi==0.68.0
uvicorn==0.15.0
onnxruntime==1.15.1
numpy==1.21.0
pydantic==1.8.2
# Команды для сборки и запуска
docker build -t onnx-server .
docker run -p 8000:8000 onnx-server

# Тестируем
curl -X POST "http://localhost:8000/predict" \
     -H "Content-Type: application/json" \
     -d '{"data": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]]}'

Альтернативные решения и сравнение

Решение Языки Производительность Поддерживаемые фреймворки Сложность
ONNX Runtime Python, C++, C#, Java Высокая PyTorch, TensorFlow, Scikit-learn Низкая
TensorFlow Serving Python, C++ Высокая Только TensorFlow Средняя
TorchServe Python, Java Средняя Только PyTorch Средняя
TensorRT C++, Python Очень высокая Через ONNX Высокая
OpenVINO C++, Python Высокая Через ONNX Высокая

Полезные утилиты и инструменты

Анализ модели с помощью Netron

# Устанавливаем Netron
pip install netron

# Запускаем веб-интерфейс
netron model.onnx

# Или из командной строки
python -c "import netron; netron.start('model.onnx')"

Валидация и оптимизация

import onnx
from onnx import checker, optimizer

# Загружаем модель
model = onnx.load("model.onnx")

# Проверяем корректность
checker.check_model(model)

# Оптимизируем
optimized_model = optimizer.optimize(model)

# Сохраняем оптимизированную модель
onnx.save(optimized_model, "model_optimized.onnx")

Бенчмаркинг

import time
import numpy as np
import onnxruntime as ort

def benchmark_model(model_path, input_shape, num_runs=1000):
    session = ort.InferenceSession(model_path)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    
    # Прогреваем модель
    dummy_input = np.random.randn(*input_shape).astype(np.float32)
    for _ in range(10):
        session.run([output_name], {input_name: dummy_input})
    
    # Измеряем время
    times = []
    for _ in range(num_runs):
        start = time.time()
        session.run([output_name], {input_name: dummy_input})
        times.append(time.time() - start)
    
    return {
        'mean_time': np.mean(times),
        'std_time': np.std(times),
        'min_time': np.min(times),
        'max_time': np.max(times),
        'throughput': 1.0 / np.mean(times)
    }

# Тестируем
results = benchmark_model("model.onnx", (1, 10))
print(f"Average inference time: {results['mean_time']:.4f}s")
print(f"Throughput: {results['throughput']:.2f} requests/sec")

Интеграция с CI/CD

GitHub Actions workflow

# .github/workflows/model-deploy.yml
name: Deploy ONNX Model

on:
  push:
    branches: [ main ]
    paths: [ 'models/**' ]

jobs:
  validate-and-deploy:
    runs-on: ubuntu-latest
    
    steps:
    - uses: actions/checkout@v3
    
    - name: Setup Python
      uses: actions/setup-python@v4
      with:
        python-version: '3.9'
    
    - name: Install dependencies
      run: |
        pip install onnx onnxruntime numpy
    
    - name: Validate ONNX model
      run: |
        python -c "
        import onnx
        model = onnx.load('models/model.onnx')
        onnx.checker.check_model(model)
        print('Model validation passed')
        "
    
    - name: Run performance benchmark
      run: |
        python benchmark.py
    
    - name: Deploy to server
      run: |
        scp models/model.onnx user@server:/opt/models/
        ssh user@server "sudo systemctl restart onnx-service"

Мониторинг и логирование

import logging
import time
from functools import wraps

# Настройка логирования
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

def log_inference_time(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        
        logger.info(f"Inference completed in {end_time - start_time:.4f}s")
        return result
    return wrapper

class ModelMetrics:
    def __init__(self):
        self.request_count = 0
        self.total_time = 0
        self.error_count = 0
    
    def record_request(self, processing_time, error=False):
        self.request_count += 1
        self.total_time += processing_time
        if error:
            self.error_count += 1
    
    def get_stats(self):
        if self.request_count == 0:
            return {"requests": 0, "avg_time": 0, "error_rate": 0}
        
        return {
            "requests": self.request_count,
            "avg_time": self.total_time / self.request_count,
            "error_rate": self.error_count / self.request_count
        }

# Использование в FastAPI
metrics = ModelMetrics()

@app.post("/predict")
async def predict(request: PredictionRequest):
    start_time = time.time()
    error = False
    
    try:
        # ... код инференса ...
        pass
    except Exception as e:
        error = True
        logger.error(f"Prediction error: {e}")
        raise
    finally:
        processing_time = time.time() - start_time
        metrics.record_request(processing_time, error)

@app.get("/metrics")
async def get_metrics():
    return metrics.get_stats()

Нестандартные способы использования

Batch-обработка с очередями

import asyncio
import numpy as np
from collections import deque
from datetime import datetime, timedelta

class BatchProcessor:
    def __init__(self, model_path, batch_size=32, timeout=0.1):
        self.session = ort.InferenceSession(model_path)
        self.batch_size = batch_size
        self.timeout = timeout
        self.queue = deque()
        self.futures = deque()
        
    async def process_batch(self):
        while True:
            if len(self.queue) >= self.batch_size:
                await self._run_batch()
            else:
                await asyncio.sleep(0.01)
    
    async def _run_batch(self):
        if not self.queue:
            return
        
        # Собираем batch
        batch_data = []
        batch_futures = []
        
        for _ in range(min(self.batch_size, len(self.queue))):
            data, future = self.queue.popleft()
            batch_data.append(data)
            batch_futures.append(future)
        
        # Запускаем инференс
        input_array = np.array(batch_data)
        input_name = self.session.get_inputs()[0].name
        output_name = self.session.get_outputs()[0].name
        
        results = self.session.run([output_name], {input_name: input_array})
        
        # Возвращаем результаты
        for i, future in enumerate(batch_futures):
            future.set_result(results[0][i])
    
    async def predict(self, data):
        future = asyncio.Future()
        self.queue.append((data, future))
        return await future

# Использование
processor = BatchProcessor("model.onnx")
asyncio.create_task(processor.process_batch())

Кеширование результатов

import hashlib
import pickle
from functools import wraps

class ResultCache:
    def __init__(self, max_size=1000):
        self.cache = {}
        self.max_size = max_size
        self.access_order = deque()
    
    def _hash_input(self, data):
        return hashlib.md5(pickle.dumps(data)).hexdigest()
    
    def get(self, key):
        if key in self.cache:
            # Обновляем порядок доступа
            self.access_order.remove(key)
            self.access_order.append(key)
            return self.cache[key]
        return None
    
    def set(self, key, value):
        if len(self.cache) >= self.max_size:
            # Удаляем самый старый элемент
            oldest = self.access_order.popleft()
            del self.cache[oldest]
        
        self.cache[key] = value
        self.access_order.append(key)

cache = ResultCache()

def cached_inference(func):
    @wraps(func)
    def wrapper(input_data):
        cache_key = cache._hash_input(input_data)
        
        # Проверяем кеш
        cached_result = cache.get(cache_key)
        if cached_result is not None:
            return cached_result
        
        # Выполняем инференс
        result = func(input_data)
        
        # Сохраняем в кеш
        cache.set(cache_key, result)
        return result
    
    return wrapper

Автоматизация и скрипты

Автоматическое тестирование модели

#!/bin/bash
# test_model.sh

MODEL_PATH="model.onnx"
TEST_DATA_PATH="test_data.npy"

echo "Testing ONNX model: $MODEL_PATH"

# Проверяем существование файлов
if [ ! -f "$MODEL_PATH" ]; then
    echo "Error: Model file not found"
    exit 1
fi

# Запускаем Python тест
python3 << EOF import onnx import onnxruntime as ort import numpy as np try: # Валидация модели model = onnx.load("$MODEL_PATH") onnx.checker.check_model(model) print("✓ Model validation passed") # Тест инференса session = ort.InferenceSession("$MODEL_PATH") input_name = session.get_inputs()[0].name output_name = session.get_outputs()[0].name # Тестовые данные test_input = np.random.randn(1, 10).astype(np.float32) result = session.run([output_name], {input_name: test_input}) print(f"✓ Inference test passed, output shape: {result[0].shape}") # Бенчмарк import time times = [] for _ in range(100): start = time.time() session.run([output_name], {input_name: test_input}) times.append(time.time() - start) avg_time = np.mean(times) print(f"✓ Average inference time: {avg_time:.4f}s") if avg_time > 0.1:
        print("⚠ Warning: Inference time is high")
        exit(1)
    
    print("✓ All tests passed")
    
except Exception as e:
    print(f"✗ Test failed: {e}")
    exit(1)
EOF

Скрипт для автоматического развертывания

#!/bin/bash
# deploy_model.sh

set -e

MODEL_NAME="$1"
MODEL_PATH="$2"
SERVER_HOST="$3"

if [ $# -ne 3 ]; then
    echo "Usage: $0   "
    exit 1
fi

echo "Deploying model: $MODEL_NAME"

# Создаем временную директорию
TEMP_DIR=$(mktemp -d)
cd "$TEMP_DIR"

# Копируем модель
cp "$MODEL_PATH" .

# Создаем Docker образ
cat > Dockerfile << EOF FROM python:3.9-slim WORKDIR /app COPY requirements.txt . RUN pip install -r requirements.txt COPY . . EXPOSE 8000 CMD ["python", "server.py"] EOF cat > requirements.txt << EOF fastapi==0.68.0 uvicorn==0.15.0 onnxruntime==1.15.1 numpy==1.21.0 EOF # Создаем базовый сервер cat > server.py << EOF
from fastapi import FastAPI
import onnxruntime as ort
import numpy as np

app = FastAPI()
session = ort.InferenceSession("$(basename $MODEL_PATH)")

@app.post("/predict")
async def predict(data: dict):
    input_data = np.array(data["input"], dtype=np.float32)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    result = session.run([output_name], {input_name: input_data})
    return {"output": result[0].tolist()}

@app.get("/health")
async def health():
    return {"status": "healthy"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
EOF

# Собираем и развертываем
docker build -t "$MODEL_NAME" .
docker save "$MODEL_NAME" | ssh "$SERVER_HOST" "docker load"

# Запускаем на сервере
ssh "$SERVER_HOST" << EOF docker stop "$MODEL_NAME" 2>/dev/null || true
    docker rm "$MODEL_NAME" 2>/dev/null || true
    docker run -d --name "$MODEL_NAME" -p 8000:8000 "$MODEL_NAME"
EOF

echo "Deployment completed successfully"
echo "Model available at: http://$SERVER_HOST:8000"

# Очищаем временные файлы
rm -rf "$TEMP_DIR"

Статистика и интересные факты

  • Размер экосистемы: ONNX поддерживает более 100 операторов и постоянно расширяется
  • Производительность: В среднем ONNX Runtime на 1.5-3x быстрее нативных фреймворков благодаря оптимизациям
  • Поддержка платформ: Работает на x86, ARM, и специализированных чипах (NPU, VPU)
  • Индустрия: Используется в Microsoft, Meta, Amazon, NVIDIA и многих других компаниях
  • Размер сообщества: Более 15,000 звезд на GitHub, активное сообщество разработчиков

Интересно, что ONNX Runtime может автоматически выбирать оптимальный провайдер выполнения в зависимости от архитектуры процессора. Например, на серверах с Intel CPU будет использоваться DNNL, на ARM — ARM Compute Library, а на GPU — CUDA или ROCm.

Заключение и рекомендации

ONNX — это не просто еще один формат моделей, а полноценная экосистема для развертывания машинного обучения в продакшене. Основные преимущества:

  • Универсальность: Одна модель работает везде
  • Производительность: Оптимизации на уровне графа и аппаратные ускорения
  • Простота интеграции: Минимальные зависимости, понятный API
  • Активное развитие: Регулярные обновления и новые возможности

Когда использовать ONNX:

  • Нужна кроссплатформенность
  • Требуется высокая производительность
  • Модели разрабатываются в разных фреймворках
  • Планируется развертывание на edge-устройствах

Когда стоит поискать альтернативы:

  • Используете экзотические операторы
  • Нужна нативная интеграция с фреймворком
  • Требуется онлайн-обучение

Для начала рекомендую попробовать ONNX на простых моделях, постепенно переходя к более сложным задачам. Обязательно настройте мониторинг производительности и тестируйте на реальных данных. И помните — правильная настройка провайдеров выполнения может дать прирост производительности в разы.

Полезные ссылки:


В этой статье собрана информация и материалы из различных интернет-источников. Мы признаем и ценим работу всех оригинальных авторов, издателей и веб-сайтов. Несмотря на то, что были приложены все усилия для надлежащего указания исходного материала, любая непреднамеренная оплошность или упущение не являются нарушением авторских прав. Все упомянутые товарные знаки, логотипы и изображения являются собственностью соответствующих владельцев. Если вы считаете, что какой-либо контент, использованный в этой статье, нарушает ваши авторские права, немедленно свяжитесь с нами для рассмотрения и принятия оперативных мер.

Данная статья предназначена исключительно для ознакомительных и образовательных целей и не ущемляет права правообладателей. Если какой-либо материал, защищенный авторским правом, был использован без должного упоминания или с нарушением законов об авторском праве, это непреднамеренно, и мы исправим это незамедлительно после уведомления. Обратите внимание, что переиздание, распространение или воспроизведение части или всего содержимого в любой форме запрещено без письменного разрешения автора и владельца веб-сайта. Для получения разрешений или дополнительных запросов, пожалуйста, свяжитесь с нами.

Leave a reply

Your email address will not be published. Required fields are marked