Files
su2-img/02_training.ipynb
Lukáš Trkan 3f072b862e first commit
2026-04-20 23:46:50 +02:00

167 lines
4.8 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 02 — Trénování YOLOv8 na VisDrone\n",
"\n",
"Dotrénujeme (fine-tuning) model YOLOv8s (small), předtrénovaný na datasetu COCO, na datasetu VisDrone.\n",
"Model detekuje 4 třídy vozidel na leteckých snímcích.\n",
"\n",
"**GPU doporučeno** (trénink na CPU trvá ~10× déle)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# %pip (unlike !pip) installs into the environment of the running kernel.\n",
"# NOTE(review): pin a tested version (ultralytics==X.Y.Z) for reproducible runs.\n",
"%pip install ultralytics --quiet\n",
"\n",
"import torch\n",
"\n",
"# Report the PyTorch version and which accelerator (CUDA / Apple MPS) is available.\n",
"print(f\"PyTorch: {torch.__version__}\")\n",
"print(f\"CUDA dostupná: {torch.cuda.is_available()}\")\n",
"if torch.cuda.is_available():\n",
"    print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
"elif torch.backends.mps.is_available():\n",
"    print(\"MPS (Apple Silicon) dostupné\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from ultralytics import YOLO\n",
"from pathlib import Path\n",
"\n",
"YAML = Path(\"data/yolo_visdrone/dataset.yaml\")\n",
"assert YAML.exists(), f\"Nejprve spusť 01_dataset_prep.ipynb! Chybí: {YAML}\"\n",
"\n",
"# YOLOv8s - a good accuracy/speed trade-off.\n",
"# Alternatives: yolov8n (fastest), yolov8m (more accurate).\n",
"# yolov8s.pt is downloaded on first use; these are COCO-pretrained detection\n",
"# weights (not ImageNet), fine-tuned below on the VisDrone vehicle classes.\n",
"model = YOLO(\"yolov8s.pt\")\n",
"\n",
"# model.info() logs the layer/parameter summary itself and returns summary\n",
"# statistics; printing that return value only duplicated the numbers.\n",
"print(\"Model načten.\")\n",
"model.info();"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Pick the best available device: CUDA GPU, then Apple MPS, then CPU.\n",
"if torch.cuda.is_available():\n",
"    DEVICE = \"cuda\"\n",
"elif torch.backends.mps.is_available():\n",
"    DEVICE = \"mps\"\n",
"else:\n",
"    DEVICE = \"cpu\"\n",
"\n",
"print(f\"Trénink na: {DEVICE}\")\n",
"\n",
"results = model.train(\n",
"    data=str(YAML),\n",
"    epochs=50,\n",
"    imgsz=640,\n",
"    batch=16,\n",
"    device=DEVICE,\n",
"    name=\"visdrone_vehicles\",\n",
"    project=\"runs/train\",\n",
"    # The following cells read runs/train/visdrone_vehicles directly. Without\n",
"    # exist_ok=True a re-run is saved to visdrone_vehicles2, ... and those cells\n",
"    # would silently plot/validate the stale first run.\n",
"    exist_ok=True,\n",
"    seed=0,  # explicit (Ultralytics default) seed for reproducible runs\n",
"    patience=10,  # early stopping after 10 epochs without improvement\n",
"    save=True,\n",
"    save_period=10,  # additionally keep a checkpoint every 10 epochs\n",
"    val=True,\n",
"    plots=True,\n",
"    # Augmentation for aerial imagery, which has no canonical up direction\n",
"    degrees=15.0,  # random rotation\n",
"    flipud=0.5,  # vertical flip\n",
"    fliplr=0.5,  # horizontal flip\n",
"    mosaic=1.0,\n",
"    scale=0.5,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Show the training plots that Ultralytics wrote into the run directory.\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"from pathlib import Path\n",
"\n",
"run_dir = Path(\"runs/train/visdrone_vehicles\")\n",
"\n",
"# Each entry lists alternative file names for one plot: newer Ultralytics\n",
"# releases prefix detection curves with Box (BoxPR_curve.png), older ones\n",
"# write PR_curve.png.\n",
"PLOT_CANDIDATES = [\n",
"    (\"results.png\",),\n",
"    (\"confusion_matrix.png\",),\n",
"    (\"BoxPR_curve.png\", \"PR_curve.png\"),\n",
"    (\"val_batch0_pred.jpg\",),\n",
"]\n",
"\n",
"for candidates in PLOT_CANDIDATES:\n",
"    plot_path = next((run_dir / n for n in candidates if (run_dir / n).exists()), None)\n",
"    if plot_path is None:\n",
"        # Report missing plots instead of silently skipping them.\n",
"        print(f\"Graf nenalezen: {' / '.join(candidates)} v {run_dir}\")\n",
"        continue\n",
"    fig, ax = plt.subplots(figsize=(12, 6))\n",
"    ax.imshow(mpimg.imread(plot_path))\n",
"    ax.axis(\"off\")\n",
"    ax.set_title(plot_path.name)\n",
"    plt.tight_layout()\n",
"    plt.show()\n",
"    plt.close(fig)  # free the figure; several large images are shown in a loop"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Validate the best checkpoint of the training run on the validation split.\n",
"best_weights = Path(\"runs/train/visdrone_vehicles/weights/best.pt\")\n",
"assert best_weights.exists(), \"Trénink ještě neproběhl nebo selhal\"\n",
"\n",
"model_best = YOLO(str(best_weights))\n",
"val_results = model_best.val(data=str(YAML), imgsz=640, split=\"val\")\n",
"\n",
"print(f\"\\nmAP50: {val_results.box.map50:.4f}\")\n",
"print(f\"mAP50-95: {val_results.box.map:.4f}\")\n",
"\n",
"# box.ap50 is aligned with box.ap_class_index (classes that occur in the\n",
"# validation labels), not with class ids 0..nc-1, so map it back explicitly.\n",
"# Class names come from the checkpoint instead of a hard-coded list that could\n",
"# drift from dataset.yaml; classes without an AP value are reported as nan.\n",
"ap50_by_class = {\n",
"    int(cls_id): float(ap)\n",
"    for cls_id, ap in zip(val_results.box.ap_class_index, val_results.box.ap50)\n",
"}\n",
"for cls_id, cls_name in model_best.names.items():\n",
"    ap = ap50_by_class.get(cls_id, float(\"nan\"))\n",
"    print(f\"  AP50[{cls_name}]: {ap:.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"# Record the absolute path of the best checkpoint in model_config.json so the\n",
"# following notebooks can load it without knowing the run directory.\n",
"best_model_path = str(best_weights.resolve())\n",
"with open(\"model_config.json\", \"w\") as config_file:\n",
"    json.dump({\"best_model\": best_model_path}, config_file)\n",
"\n",
"print(f\"Nejlepší model uložen v: {best_weights}\")\n",
"print(\"Konfigurace zapsána do: model_config.json\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}