su2-img/02_training.ipynb
Lukáš Trkan e8c1c3f2ee update
2026-05-01 22:06:01 +02:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 02: Training YOLOv8 on VisDrone\n",
"\n",
"We fine-tune YOLOv8s (small) on the VisDrone dataset, starting from COCO-pretrained weights.\n",
"The model detects 4 vehicle classes (car, van, truck, bus) in aerial images.\n",
"\n",
"**GPU recommended** (training on a CPU takes roughly 10× longer)."
]
},
{
"cell_type": "markdown",
"id": "md-8d392d",
"metadata": {},
"source": [
"## Environment check\n",
"\n",
"We check whether a GPU (CUDA) or Apple Silicon (MPS) is available and print the PyTorch version. Training on a GPU is roughly 10× faster than on a CPU."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch: 2.4.1+cu124\n",
"CUDA available: True\n",
"GPU: NVIDIA GeForce RTX 4090\n"
]
}
],
"source": [
"%pip install ultralytics --quiet\n",
"\n",
"import torch\n",
"print(f\"PyTorch: {torch.__version__}\")\n",
"print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
"if torch.cuda.is_available():\n",
"    print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
"elif torch.backends.mps.is_available():\n",
"    print(\"MPS (Apple Silicon) available\")"
]
},
{
"cell_type": "markdown",
"id": "md-af9159",
"metadata": {},
"source": [
"## Loading the pretrained model\n",
"\n",
"We use YOLOv8s (small) with weights pretrained on COCO and fine-tune it on the VisDrone dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Creating new Ultralytics Settings v0.0.6 file ✅ \n",
"View Ultralytics Settings with 'yolo settings' or at '/root/.config/Ultralytics/settings.json'\n",
"Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.\n",
"Downloading https://github.com/ultralytics/assets/releases/download/v8.4.0/yolov8s.pt to 'yolov8s.pt': 100% ━━━━━━━━━━━━ 21.5MB 60.0MB/s 0.4s\n",
"YOLOv8s summary: 129 layers, 11,166,560 parameters, 0 gradients, 28.8 GFLOPs\n",
"Model loaded: (129, 11166560, 0, 28.816844800000002)\n"
]
}
],
"source": [
"from ultralytics import YOLO\n",
"from pathlib import Path\n",
"\n",
"YAML = Path(\"data/yolo_visdrone/dataset.yaml\")\n",
"assert YAML.exists(), f\"Run 01_dataset_prep.ipynb first! Missing: {YAML}\"\n",
"\n",
"model = YOLO(\"models/yolov8s.pt\")\n",
"print(\"Model loaded:\", model.info())"
]
},
{
"cell_type": "markdown",
"id": "md-58715d",
"metadata": {},
"source": [
"## Training the model\n",
"\n",
"We train for up to 50 epochs with early stopping (patience=10). The augmentations set below (rotation, flipping, mosaic, scaling) help the model generalize across different aerial viewpoints and object sizes; lighting variation is left to the Ultralytics default HSV augmentation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from pathlib import Path\n",
"\n",
"\n",
"if torch.cuda.is_available():\n",
"    DEVICE = \"cuda\"\n",
"elif torch.backends.mps.is_available():\n",
"    DEVICE = \"mps\"\n",
"else:\n",
"    DEVICE = \"cpu\"\n",
"\n",
"print(f\"Training on: {DEVICE}\")\n",
"\n",
"results = model.train(\n",
"    data=str(YAML),\n",
"    epochs=50,\n",
"    imgsz=640,\n",
"    batch=16,\n",
"    device=DEVICE,\n",
"    name=\"visdrone_vehicles\",\n",
"    project=str(Path(\"runs/train\").resolve()),\n",
"    patience=10,  # early stopping\n",
"    save=True,\n",
"    save_period=10,  # also save a checkpoint every 10 epochs\n",
"    val=True,\n",
"    plots=True,\n",
"    degrees=15.0,  # random rotation of up to ±15°\n",
"    flipud=0.5,  # probability of a vertical flip\n",
"    fliplr=0.5,  # probability of a horizontal flip\n",
"    mosaic=1.0,  # probability of mosaic augmentation\n",
"    scale=0.5,  # random scaling gain\n",
")"
]
},
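{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make `patience=10` more concrete, the next cell is a minimal, library-free sketch of patience-based early stopping. It is only an illustration under stated assumptions, not the Ultralytics implementation, which tracks its own fitness score. The helper `early_stop_epoch` and the sample scores are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def early_stop_epoch(scores, patience):\n",
"    \"\"\"Return the epoch at which training would stop, or None if it runs to the end.\"\"\"\n",
"    best_score = float(\"-inf\")\n",
"    best_epoch = 0\n",
"    for epoch, score in enumerate(scores):\n",
"        if score > best_score:\n",
"            # New best validation score: remember it and reset the wait.\n",
"            best_score, best_epoch = score, epoch\n",
"        elif epoch - best_epoch >= patience:\n",
"            # No improvement for `patience` epochs in a row: stop here.\n",
"            return epoch\n",
"    return None\n",
"\n",
"# Hypothetical validation scores: improvement stalls after epoch 2.\n",
"demo_scores = [0.10, 0.20, 0.30, 0.29, 0.30, 0.28, 0.27]\n",
"print(early_stop_epoch(demo_scores, patience=3))"
]
},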
{
"cell_type": "markdown",
"id": "md-1fc179",
"metadata": {},
"source": [
"## Training results\n",
"\n",
"We display the training curves (loss, mAP), the confusion matrix, and example predictions on the validation set."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"from pathlib import Path\n",
"\n",
"run_dir = Path(\"runs/train/visdrone_vehicles\")\n",
"\n",
"# Show each plot saved in the run directory, skipping any that are missing.\n",
"for plot_name in [\"results.png\", \"confusion_matrix.png\", \"PR_curve.png\", \"val_batch0_pred.jpg\"]:\n",
"    p = run_dir / plot_name\n",
"    if p.exists():\n",
"        fig, ax = plt.subplots(figsize=(12, 6))\n",
"        ax.imshow(mpimg.imread(p))\n",
"        ax.axis(\"off\")\n",
"        ax.set_title(plot_name)\n",
"        plt.tight_layout()\n",
"        plt.show()"
]
},
{
"cell_type": "markdown",
"id": "md-021cf0",
"metadata": {},
"source": [
"## Validating the best model\n",
"\n",
"We load the best checkpoint from training and run a full validation on the validation set. We report mAP50 and mAP50-95, the standard metrics for object detection."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Validate the best checkpoint\n",
"best_weights = Path(\"runs/train/visdrone_vehicles/weights/best.pt\")\n",
"assert best_weights.exists(), \"Training has not run yet or failed\"\n",
"\n",
"model_best = YOLO(str(best_weights))\n",
"val_results = model_best.val(data=str(YAML), imgsz=640, split=\"val\")\n",
"\n",
"print(f\"\\nmAP50: {val_results.box.map50:.4f}\")\n",
"print(f\"mAP50-95: {val_results.box.map:.4f}\")\n",
"# Per-class AP50; assumes the class order in dataset.yaml is car, van, truck, bus.\n",
"for i, cls in enumerate([\"car\", \"van\", \"truck\", \"bus\"]):\n",
"    ap = val_results.box.ap50[i] if i < len(val_results.box.ap50) else float('nan')\n",
"    print(f\"  AP50[{cls}]: {ap:.4f}\")"
]
},
{
"cell_type": "markdown",
"id": "md-f05d8c",
"metadata": {},
"source": [
"## Fine-tuning on custom data (Vrchlabí)\n",
"\n",
"If a locally annotated dataset from Vrchlabí is available, we fine-tune the VisDrone-trained model on it. This step adapts the model to the specific conditions of Czech towns."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import zipfile\n",
"import shutil\n",
"from pathlib import Path\n",
"\n",
"VRCHLABI_ZIP = Path('data/dataset_vrchlabi.zip')\n",
"VRCHLABI_YAML = Path('data/vrchlabi/dataset.yaml')\n",
"\n",
"if VRCHLABI_ZIP.exists() and not VRCHLABI_YAML.exists():\n",
"    print('Extracting dataset_vrchlabi.zip ...')\n",
"    with zipfile.ZipFile(VRCHLABI_ZIP) as zf:\n",
"        # Assumes the archive has a top-level vrchlabi/ folder, so it unpacks to data/vrchlabi/.\n",
"        zf.extractall('data/')\n",
"    print('Done.')\n",
"\n",
"if not VRCHLABI_YAML.exists():\n",
"    print('Vrchlabí dataset not found, skipping fine-tuning.')\n",
"    print('Run scripts/prepare_dataset.py or upload dataset_vrchlabi.zip.')\n",
"else:\n",
"    print(f'Fine-tuning on: {VRCHLABI_YAML}')\n",
"\n",
"    # Keep a copy of the VisDrone-only weights before fine-tuning.\n",
"    backup = best_weights.parent / 'best_before_finetune.pt'\n",
"    shutil.copy(best_weights, backup)\n",
"    print(f'Backup saved: {backup}')\n",
"\n",
"    vrchlabi_model = YOLO(str(best_weights))\n",
"\n",
"    vrchlabi_results = vrchlabi_model.train(\n",
"        data=str(VRCHLABI_YAML),\n",
"        epochs=30,\n",
"        imgsz=256,\n",
"        batch=32,\n",
"        device=DEVICE,\n",
"        name='vrchlabi_finetune',\n",
"        project=str(Path(\"runs/train\").resolve()),\n",
"        patience=10,\n",
"        save=True,\n",
"        val=True,\n",
"        plots=True,\n",
"        degrees=10.0,\n",
"        flipud=0.5,\n",
"        fliplr=0.5,\n",
"        mosaic=0.5,\n",
"        lr0=0.001,\n",
"    )\n",
"\n",
"    # From here on, best_weights points at the fine-tuned checkpoint.\n",
"    best_weights = Path(vrchlabi_results.save_dir) / 'weights' / 'best.pt'\n",
"    print(f'Vrchlabí fine-tuning done. Best weights: {best_weights}')"
]
},
{
"cell_type": "markdown",
"id": "md-805d37",
"metadata": {},
"source": [
"## Saving the model configuration\n",
"\n",
"We write the path to the best model to `models/model_config.json`, where the following notebooks read it from."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best model saved at: /workspace/runs/detect/runs/detect/runs/train/vrchlabi_finetune-2/weights/best.pt\n",
"Configuration written to: model_config.json\n"
]
}
],
"source": [
"import json\n",
"\n",
"config = {\"best_model\": str(best_weights.resolve())}\n",
"with open(\"models/model_config.json\", \"w\") as f:\n",
"    json.dump(config, f)\n",
"\n",
"print(f\"Best model saved at: {best_weights}\")\n",
"print(\"Configuration written to: models/model_config.json\")"
]
}
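,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sketch of how a later notebook might read this file back, the cell below loads `models/model_config.json` and returns the stored path. The helper `load_best_model_path` is hypothetical and assumes the file contains the single `best_model` key written in the previous cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from pathlib import Path\n",
"\n",
"def load_best_model_path(config_path=\"models/model_config.json\"):\n",
"    \"\"\"Read the checkpoint path stored under the 'best_model' key.\"\"\"\n",
"    with open(config_path) as f:\n",
"        config = json.load(f)\n",
"    return Path(config[\"best_model\"])\n",
"\n",
"best_model_path = load_best_model_path()\n",
"print(f\"Best model: {best_model_path} (exists: {best_model_path.exists()})\")"
]
}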
],
"metadata": {
"kernelspec": {
"display_name": ".venv (3.14.4)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.14.4"
}
},
"nbformat": 4,
"nbformat_minor": 4
}