update
This commit is contained in:
@@ -12,6 +12,16 @@
|
||||
"**GPU doporučeno** (trénink na CPU trvá ~10× déle)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "md-8d392d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Kontrola prostředí\n",
|
||||
"\n",
|
||||
"Zjistíme, zda je k dispozici GPU (CUDA) nebo Apple Silicon (MPS), a vypíšeme verzi PyTorche. Trénink na GPU je přibližně 10× rychlejší než na CPU."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -43,9 +53,19 @@
|
||||
" print(\"MPS (Apple Silicon) dostupné\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "md-af9159",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Načtení předtrénovaného modelu\n",
|
||||
"\n",
|
||||
"Použijeme YOLOv8s (small) s váhami předtrénovanými na ImageNet. Tento model budeme fine-tunovat na VisDrone dataset."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -68,12 +88,20 @@
|
||||
"YAML = Path(\"data/yolo_visdrone/dataset.yaml\")\n",
|
||||
"assert YAML.exists(), f\"Nejprve spusť 01_dataset_prep.ipynb! Chybí: {YAML}\"\n",
|
||||
"\n",
|
||||
"# Zvolíme YOLOv8s — dobrý poměr přesnosti a rychlosti\n",
|
||||
"# Alternativy: yolov8n (nejrychlejší), yolov8m (přesnější)\n",
|
||||
"model = YOLO(\"models/yolov8s.pt\") # stáhne předtrénované ImageNet váhy\n",
|
||||
"model = YOLO(\"models/yolov8s.pt\") \n",
|
||||
"print(\"Model načten:\", model.info())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "md-58715d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Trénování modelu\n",
|
||||
"\n",
|
||||
"Spustíme trénink na 50 epoch s early stoppingem (patience=10). Augmentace (rotace, překlápění, mosaic) pomáhají modelu generalizovat na různé letecké pohledy a světelné podmínky."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -83,7 +111,7 @@
|
||||
"import torch\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"# Detekce dostupného zařízení\n",
|
||||
"\n",
|
||||
"if torch.cuda.is_available():\n",
|
||||
" DEVICE = \"cuda\"\n",
|
||||
"elif torch.backends.mps.is_available():\n",
|
||||
@@ -106,22 +134,30 @@
|
||||
" save_period=10,\n",
|
||||
" val=True,\n",
|
||||
" plots=True,\n",
|
||||
" # Augmentace pro letecké snímky\n",
|
||||
" degrees=15.0, # rotace\n",
|
||||
" flipud=0.5, # vertikální flip\n",
|
||||
" degrees=15.0,\n",
|
||||
" flipud=0.5,\n",
|
||||
" fliplr=0.5,\n",
|
||||
" mosaic=1.0,\n",
|
||||
" scale=0.5,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "md-1fc179",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Výsledky tréninku\n",
|
||||
"\n",
|
||||
"Zobrazíme grafy průběhu trénování (loss, mAP), matici záměn a příklady predikcí na validační sadě."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Zobrazení tréninkových grafů\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib.image as mpimg\n",
|
||||
"from pathlib import Path\n",
|
||||
@@ -142,11 +178,12 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "md-021cf0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Fine-tuning na Vrchlabí (volitelné)\n",
|
||||
"## Validace nejlepšího modelu\n",
|
||||
"\n",
|
||||
"Pokud existuje `dataset_vrchlabi.zip` nebo rozbalená složka `vrchlabi/`, provede se fine-tuning předchozího modelu na vlastních datech z Vrchlabí."
|
||||
"Načteme nejlepší checkpoint z tréninku a spustíme plnou validaci na validační sadě. Sledujeme metriku mAP50, která je standardem pro detekci objektů."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -169,6 +206,16 @@
|
||||
" print(f\" AP50[{cls}]: {ap:.4f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "md-f05d8c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Fine-tuning na vlastních datech (Vrchlabí)\n",
|
||||
"\n",
|
||||
"Pokud je k dispozici lokálně anotovaný dataset z Vrchlabí, provedeme fine-tuning modelu natrénovaného na VisDrone. Tímto krokem model přizpůsobíme konkrétním podmínkám českých měst."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -182,7 +229,6 @@
|
||||
"VRCHLABI_ZIP = Path('dataset_vrchlabi.zip')\n",
|
||||
"VRCHLABI_YAML = Path('vrchlabi/dataset.yaml')\n",
|
||||
"\n",
|
||||
"# Rozbal zip pokud yaml ještě není\n",
|
||||
"if VRCHLABI_ZIP.exists() and not VRCHLABI_YAML.exists():\n",
|
||||
" print('Rozbaluji dataset_vrchlabi.zip ...')\n",
|
||||
" with zipfile.ZipFile(VRCHLABI_ZIP) as zf:\n",
|
||||
@@ -195,12 +241,11 @@
|
||||
"else:\n",
|
||||
" print(f'Fine-tuning na: {VRCHLABI_YAML}')\n",
|
||||
"\n",
|
||||
" # Záloha modelu před fine-tuningem\n",
|
||||
" backup = best_weights.parent / 'best_before_finetune.pt'\n",
|
||||
" shutil.copy(best_weights, backup)\n",
|
||||
" print(f'Záloha uložena: {backup}')\n",
|
||||
"\n",
|
||||
" vrchlabi_model = YOLO(str(best_weights)) # navazuje na VisDrone trénink\n",
|
||||
" vrchlabi_model = YOLO(str(best_weights))\n",
|
||||
"\n",
|
||||
" vrchlabi_results = vrchlabi_model.train(\n",
|
||||
" data=str(VRCHLABI_YAML),\n",
|
||||
@@ -218,16 +263,26 @@
|
||||
" flipud=0.5,\n",
|
||||
" fliplr=0.5,\n",
|
||||
" mosaic=0.5,\n",
|
||||
" lr0=0.001, # nižší LR pro fine-tuning\n",
|
||||
" lr0=0.001,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" best_weights = Path(vrchlabi_results.save_dir) / 'weights' / 'best.pt'\n",
|
||||
" print(f'Vrchlabí fine-tuning hotov. Nejlepší váhy: {best_weights}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "md-805d37",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Uložení konfigurace modelu\n",
|
||||
"\n",
|
||||
"Cestu k nejlepšímu modelu zapíšeme do `models/model_config.json`, odkud ji načtou následující notebooky."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -240,7 +295,6 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Uložení cesty k nejlepšímu modelu pro další notebooky\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"config = {\"best_model\": str(best_weights.resolve())}\n",
|
||||
|
||||
Reference in New Issue
Block a user