This commit is contained in:
Lukáš Trkan
2026-05-01 21:42:32 +02:00
parent 8df624f568
commit 2bd731620c
5 changed files with 435 additions and 185 deletions

View File

@@ -12,6 +12,16 @@
"**GPU doporučeno** (trénink na CPU trvá ~10× déle)."
]
},
{
"cell_type": "markdown",
"id": "md-8d392d",
"metadata": {},
"source": [
"## Kontrola prostředí\n",
"\n",
"Zjistíme, zda je k dispozici GPU (CUDA) nebo Apple Silicon (MPS), a vypíšeme verzi PyTorche. Trénink na GPU je přibližně 10× rychlejší než na CPU."
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -43,9 +53,19 @@
" print(\"MPS (Apple Silicon) dostupné\")"
]
},
{
"cell_type": "markdown",
"id": "md-af9159",
"metadata": {},
"source": [
"## Načtení předtrénovaného modelu\n",
"\n",
"Použijeme YOLOv8s (small) s váhami předtrénovanými na ImageNet. Tento model budeme fine-tunovat na VisDrone dataset."
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -68,12 +88,20 @@
"YAML = Path(\"data/yolo_visdrone/dataset.yaml\")\n",
"assert YAML.exists(), f\"Nejprve spusť 01_dataset_prep.ipynb! Chybí: {YAML}\"\n",
"\n",
"# Zvolíme YOLOv8s — dobrý poměr přesnosti a rychlosti\n",
"# Alternativy: yolov8n (nejrychlejší), yolov8m (přesnější)\n",
"model = YOLO(\"models/yolov8s.pt\") # stáhne předtrénované ImageNet váhy\n",
"model = YOLO(\"models/yolov8s.pt\") \n",
"print(\"Model načten:\", model.info())"
]
},
{
"cell_type": "markdown",
"id": "md-58715d",
"metadata": {},
"source": [
"## Trénování modelu\n",
"\n",
"Spustíme trénink na 50 epoch s early stoppingem (patience=10). Augmentace (rotace, překlápění, mosaic) pomáhají modelu generalizovat na různé letecké pohledy a světelné podmínky."
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -83,7 +111,7 @@
"import torch\n",
"from pathlib import Path\n",
"\n",
"# Detekce dostupného zařízení\n",
"\n",
"if torch.cuda.is_available():\n",
" DEVICE = \"cuda\"\n",
"elif torch.backends.mps.is_available():\n",
@@ -106,22 +134,30 @@
" save_period=10,\n",
" val=True,\n",
" plots=True,\n",
" # Augmentace pro letecké snímky\n",
" degrees=15.0, # rotace\n",
" flipud=0.5, # vertikální flip\n",
" degrees=15.0,\n",
" flipud=0.5,\n",
" fliplr=0.5,\n",
" mosaic=1.0,\n",
" scale=0.5,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "md-1fc179",
"metadata": {},
"source": [
"## Výsledky tréninku\n",
"\n",
"Zobrazíme grafy průběhu trénování (loss, mAP), matici záměn a příklady predikcí na validační sadě."
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Zobrazení tréninkových grafů\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"from pathlib import Path\n",
@@ -142,11 +178,12 @@
},
{
"cell_type": "markdown",
"id": "md-021cf0",
"metadata": {},
"source": [
"## Fine-tuning na Vrchlabí (volitelné)\n",
"## Validace nejlepšího modelu\n",
"\n",
"Pokud existuje `dataset_vrchlabi.zip` nebo rozbalená složka `vrchlabi/`, provede se fine-tuning předchozího modelu na vlastních datech z Vrchlabí."
"Načteme nejlepší checkpoint z tréninku a spustíme plnou validaci na validační sadě. Sledujeme metriku mAP50, která je standardem pro detekci objektů."
]
},
{
@@ -169,6 +206,16 @@
" print(f\" AP50[{cls}]: {ap:.4f}\")"
]
},
{
"cell_type": "markdown",
"id": "md-f05d8c",
"metadata": {},
"source": [
"## Fine-tuning na vlastních datech (Vrchlabí)\n",
"\n",
"Pokud je k dispozici lokálně anotovaný dataset z Vrchlabí, provedeme fine-tuning modelu natrénovaného na VisDrone. Tímto krokem model přizpůsobíme konkrétním podmínkám českých měst."
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -182,7 +229,6 @@
"VRCHLABI_ZIP = Path('dataset_vrchlabi.zip')\n",
"VRCHLABI_YAML = Path('vrchlabi/dataset.yaml')\n",
"\n",
"# Rozbal zip pokud yaml ještě není\n",
"if VRCHLABI_ZIP.exists() and not VRCHLABI_YAML.exists():\n",
" print('Rozbaluji dataset_vrchlabi.zip ...')\n",
" with zipfile.ZipFile(VRCHLABI_ZIP) as zf:\n",
@@ -195,12 +241,11 @@
"else:\n",
" print(f'Fine-tuning na: {VRCHLABI_YAML}')\n",
"\n",
" # Záloha modelu před fine-tuningem\n",
" backup = best_weights.parent / 'best_before_finetune.pt'\n",
" shutil.copy(best_weights, backup)\n",
" print(f'Záloha uložena: {backup}')\n",
"\n",
" vrchlabi_model = YOLO(str(best_weights)) # navazuje na VisDrone trénink\n",
" vrchlabi_model = YOLO(str(best_weights))\n",
"\n",
" vrchlabi_results = vrchlabi_model.train(\n",
" data=str(VRCHLABI_YAML),\n",
@@ -218,16 +263,26 @@
" flipud=0.5,\n",
" fliplr=0.5,\n",
" mosaic=0.5,\n",
" lr0=0.001, # nižší LR pro fine-tuning\n",
" lr0=0.001,\n",
" )\n",
"\n",
" best_weights = Path(vrchlabi_results.save_dir) / 'weights' / 'best.pt'\n",
" print(f'Vrchlabí fine-tuning hotov. Nejlepší váhy: {best_weights}')"
]
},
{
"cell_type": "markdown",
"id": "md-805d37",
"metadata": {},
"source": [
"## Uložení konfigurace modelu\n",
"\n",
"Cestu k nejlepšímu modelu zapíšeme do `models/model_config.json`, odkud ji načtou následující notebooky."
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -240,7 +295,6 @@
}
],
"source": [
"# Uložení cesty k nejlepšímu modelu pro další notebooky\n",
"import json\n",
"\n",
"config = {\"best_model\": str(best_weights.resolve())}\n",