Files
su2-img/02_training.ipynb
Lukáš Trkan f4f2352ec3 update
2026-04-29 08:53:16 +02:00

283 lines
9.5 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 02 — Trénování YOLOv8 na VisDrone\n",
"\n",
"Trénujeme YOLOv8s (small) fine-tuned z ImageNet vah na VisDrone dataset.\n",
"Model detekuje 4 třídy vozidel z leteckých snímků.\n",
"\n",
"**GPU doporučeno** (trénink na CPU trvá ~10× déle)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.0.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
"PyTorch: 2.4.1+cu124\n",
"CUDA dostupná: True\n",
"GPU: NVIDIA GeForce RTX 4090\n"
]
}
],
"source": [
"%pip install ultralytics --quiet\n",
"\n",
"import torch\n",
"print(f\"PyTorch: {torch.__version__}\")\n",
"print(f\"CUDA dostupná: {torch.cuda.is_available()}\")\n",
"if torch.cuda.is_available():\n",
" print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
"elif torch.backends.mps.is_available():\n",
" print(\"MPS (Apple Silicon) dostupné\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Creating new Ultralytics Settings v0.0.6 file ✅ \n",
"View Ultralytics Settings with 'yolo settings' or at '/root/.config/Ultralytics/settings.json'\n",
"Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.\n",
"Downloading https://github.com/ultralytics/assets/releases/download/v8.4.0/yolov8s.pt to 'yolov8s.pt': 100% ━━━━━━━━━━━━ 21.5MB 60.0MB/s 0.4s\n",
"YOLOv8s summary: 129 layers, 11,166,560 parameters, 0 gradients, 28.8 GFLOPs\n",
"Model načten: (129, 11166560, 0, 28.816844800000002)\n"
]
}
],
"source": [
"from ultralytics import YOLO\n",
"from pathlib import Path\n",
"\n",
"YAML = Path(\"data/yolo_visdrone/dataset.yaml\")\n",
"assert YAML.exists(), f\"Nejprve spusť 01_dataset_prep.ipynb! Chybí: {YAML}\"\n",
"\n",
"# Zvolíme YOLOv8s — dobrý poměr přesnosti a rychlosti\n",
"# Alternativy: yolov8n (nejrychlejší), yolov8m (přesnější)\n",
"model = YOLO(\"models/yolov8s.pt\") # stáhne předtrénované ImageNet váhy\n",
"print(\"Model načten:\", model.info())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from pathlib import Path\n",
"\n",
"# Detekce dostupného zařízení\n",
"if torch.cuda.is_available():\n",
" DEVICE = \"cuda\"\n",
"elif torch.backends.mps.is_available():\n",
" DEVICE = \"mps\"\n",
"else:\n",
" DEVICE = \"cpu\"\n",
"\n",
"print(f\"Trénink na: {DEVICE}\")\n",
"\n",
"results = model.train(\n",
" data=str(YAML),\n",
" epochs=50,\n",
" imgsz=640,\n",
" batch=16,\n",
" device=DEVICE,\n",
" name=\"visdrone_vehicles\",\n",
" project=str(Path(\"runs/train\").resolve()),\n",
" patience=10, # early stopping\n",
" save=True,\n",
" save_period=10,\n",
" val=True,\n",
" plots=True,\n",
" # Augmentace pro letecké snímky\n",
" degrees=15.0, # rotace\n",
" flipud=0.5, # vertikální flip\n",
" fliplr=0.5,\n",
" mosaic=1.0,\n",
" scale=0.5,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Zobrazení tréninkových grafů\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"from pathlib import Path\n",
"import glob\n",
"\n",
"run_dir = Path(\"runs/train/visdrone_vehicles\")\n",
"\n",
"for plot_name in [\"results.png\", \"confusion_matrix.png\", \"PR_curve.png\", \"val_batch0_pred.jpg\"]:\n",
" p = run_dir / plot_name\n",
" if p.exists():\n",
" fig, ax = plt.subplots(figsize=(12, 6))\n",
" ax.imshow(mpimg.imread(p))\n",
" ax.axis(\"off\")\n",
" ax.set_title(plot_name)\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fine-tuning na Vrchlabí (volitelné)\n",
"\n",
"Pokud existuje `dataset_vrchlabi.zip` nebo rozbalená složka `vrchlabi/`, provede se fine-tuning předchozího modelu na vlastních datech z Vrchlabí."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Validace nejlepšího modelu\n",
"best_weights = Path(\"runs/train/visdrone_vehicles/weights/best.pt\")\n",
"assert best_weights.exists(), \"Trénink ještě neproběhl nebo selhal\"\n",
"\n",
"model_best = YOLO(str(best_weights))\n",
"val_results = model_best.val(data=str(YAML), imgsz=640, split=\"val\")\n",
"\n",
"print(f\"\\nmAP50: {val_results.box.map50:.4f}\")\n",
"print(f\"mAP50-95: {val_results.box.map:.4f}\")\n",
"for i, cls in enumerate([\"car\", \"van\", \"truck\", \"bus\"]):\n",
" ap = val_results.box.ap50[i] if i < len(val_results.box.ap50) else float('nan')\n",
" print(f\" AP50[{cls}]: {ap:.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import zipfile\n",
"import shutil\n",
"from pathlib import Path\n",
"\n",
"VRCHLABI_ZIP = Path('dataset_vrchlabi.zip')\n",
"VRCHLABI_YAML = Path('vrchlabi/dataset.yaml')\n",
"\n",
"# Rozbal zip pokud yaml ještě není\n",
"if VRCHLABI_ZIP.exists() and not VRCHLABI_YAML.exists():\n",
" print('Rozbaluji dataset_vrchlabi.zip ...')\n",
" with zipfile.ZipFile(VRCHLABI_ZIP) as zf:\n",
" zf.extractall('.')\n",
" print('Hotovo.')\n",
"\n",
"if not VRCHLABI_YAML.exists():\n",
" print('Vrchlabí dataset nenalezen — fine-tuning přeskočen.')\n",
" print('Spusť scripts/prepare_dataset.py nebo nahraj dataset_vrchlabi.zip.')\n",
"else:\n",
" print(f'Fine-tuning na: {VRCHLABI_YAML}')\n",
"\n",
" # Záloha modelu před fine-tuningem\n",
" backup = best_weights.parent / 'best_before_finetune.pt'\n",
" shutil.copy(best_weights, backup)\n",
" print(f'Záloha uložena: {backup}')\n",
"\n",
" vrchlabi_model = YOLO(str(best_weights)) # navazuje na VisDrone trénink\n",
"\n",
" vrchlabi_results = vrchlabi_model.train(\n",
" data=str(VRCHLABI_YAML),\n",
" epochs=30,\n",
" imgsz=256,\n",
" batch=32,\n",
" device=DEVICE,\n",
" name='vrchlabi_finetune',\n",
" project=str(Path(\"runs/train\").resolve()),\n",
" patience=10,\n",
" save=True,\n",
" val=True,\n",
" plots=True,\n",
" degrees=10.0,\n",
" flipud=0.5,\n",
" fliplr=0.5,\n",
" mosaic=0.5,\n",
" lr0=0.001, # nižší LR pro fine-tuning\n",
" )\n",
"\n",
" best_weights = Path(vrchlabi_results.save_dir) / 'weights' / 'best.pt'\n",
" print(f'Vrchlabí fine-tuning hotov. Nejlepší váhy: {best_weights}')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Nejlepší model uložen v: /workspace/runs/detect/runs/detect/runs/train/vrchlabi_finetune-2/weights/best.pt\n",
"Konfigurace zapsána do: model_config.json\n"
]
}
],
"source": [
"# Uložení cesty k nejlepšímu modelu pro další notebooky\n",
"import json\n",
"\n",
"config = {\"best_model\": str(best_weights.resolve())}\n",
"with open(\"models/model_config.json\", \"w\") as f:\n",
" json.dump(config, f)\n",
"\n",
"print(f\"Nejlepší model uložen v: {best_weights}\")\n",
"print(f\"Konfigurace zapsána do: models/model_config.json\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv (3.14.4)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.14.4"
}
},
"nbformat": 4,
"nbformat_minor": 4
}