Files
su2-img/03_inference_tiles.ipynb
Lukáš Trkan 3f072b862e first commit
2026-04-20 23:46:50 +02:00

258 lines
8.6 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 03 — Inference na dlaždicích Hradce Králové\n",
"\n",
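"Spustíme natrénovaný YOLOv8 model (z `02_training.ipynb`) na všech dlaždicích leteckého snímku Hradce Králové (HK).\n",
"Anotované dlaždice uložíme do `tiles_annotated/`, detekce do `detections.json` a nakonec vykreslíme ukázky detekcí a tepelnou mapu hustoty vozidel.\n",
"\n",
"Každý typ vozidla dostane jinak barevný rámeček:\n",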
"\n",
"| Třída | Barva |\n",
"|-------|-------|\n",
"| car | zelená |\n",
"| van | žlutá |\n",
"| truck | červená |\n",
"| bus | modrá |"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Imports: stdlib first, then third-party.\n",
"import json\n",
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from PIL import Image, ImageDraw, ImageFont\n",
"# tqdm.auto uses the Jupyter widget bar when ipywidgets is available and falls\n",
"# back to a plain text bar otherwise (nbconvert, terminal); tqdm.notebook\n",
"# requires ipywidgets and fails without it.\n",
"from tqdm.auto import tqdm\n",
"from ultralytics import YOLO\n",
"\n",
"# Load the model configuration written by the training notebook.\n",
"CONFIG_PATH = Path(\"model_config.json\")\n",
"assert CONFIG_PATH.exists(), \"Nejprve spusť 02_training.ipynb!\"\n",
"\n",
"config = json.loads(CONFIG_PATH.read_text(encoding=\"utf-8\"))\n",
"\n",
"# Path to the best checkpoint recorded by the training notebook.\n",
"MODEL_PATH = Path(config[\"best_model\"])\n",
"assert MODEL_PATH.exists(), f\"Model nenalezen: {MODEL_PATH}\"\n",
"\n",
"model = YOLO(str(MODEL_PATH))\n",
"print(f\"Model: {MODEL_PATH}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"TILES_DIR = Path(\"tiles\")\n",
"ANNOTATED_DIR = Path(\"tiles_annotated\")\n",
"ANNOTATED_DIR.mkdir(exist_ok=True)\n",
"\n",
"# Box colour per class id (RGB); ids index into CLASS_NAMES below.\n",
"CLASS_COLORS = {\n",
"    0: (0, 220, 0),    # car — green\n",
"    1: (255, 220, 0),  # van — yellow\n",
"    2: (220, 0, 0),    # truck — red\n",
"    3: (0, 120, 255),  # bus — blue\n",
"}\n",
"CLASS_NAMES = [\"car\", \"van\", \"truck\", \"bus\"]\n",
"\n",
"# Device selection: CUDA, then Apple MPS, then CPU. `torch.backends.mps` is\n",
"# missing in older PyTorch builds, hence the getattr guard.\n",
"_mps_backend = getattr(torch.backends, \"mps\", None)\n",
"if torch.cuda.is_available():\n",
"    DEVICE = \"cuda\"\n",
"elif _mps_backend is not None and _mps_backend.is_available():\n",
"    DEVICE = \"mps\"\n",
"else:\n",
"    DEVICE = \"cpu\"\n",
"\n",
"CONF = 0.25  # minimum detection confidence\n",
"IOU = 0.45   # IoU threshold for non-maximum suppression\n",
"\n",
"# Tiles are named 18_<x>_<y>.jpg (presumably zoom level 18 — confirm against the tiling notebook).\n",
"tile_files = sorted(TILES_DIR.glob(\"18_*.jpg\"))\n",
"# Fail fast: with no tiles every later cell would break with a less obvious error.\n",
"assert tile_files, f\"Nenalezeny žádné dlaždice 18_*.jpg ve složce {TILES_DIR}\"\n",
"print(f\"Dlaždice celkem: {len(tile_files)}\")\n",
"print(f\"Zařízení: {DEVICE}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import shutil\n",
"from collections import defaultdict\n",
"\n",
"LABEL_H = 10  # height of the label strip drawn with each box, in pixels\n",
"\n",
"\n",
"def class_name(cls_id: int) -> str:\n",
"    \"\"\"Return the class name for `cls_id`, or the id as text if it is not in CLASS_NAMES.\"\"\"\n",
"    return CLASS_NAMES[cls_id] if 0 <= cls_id < len(CLASS_NAMES) else str(cls_id)\n",
"\n",
"\n",
"def draw_detections(img: Image.Image, boxes, labels, scores) -> Image.Image:\n",
"    \"\"\"Draw class-coloured boxes with `<class> <score>` labels onto `img`.\n",
"\n",
"    Args:\n",
"        img: PIL image to draw on; it is modified in place and also returned.\n",
"        boxes: iterable of (x1, y1, x2, y2) pixel coordinates.\n",
"        labels: iterable of integer class ids (keys of CLASS_COLORS).\n",
"        scores: iterable of confidence scores.\n",
"\n",
"    Returns:\n",
"        The same `img` object.\n",
"    \"\"\"\n",
"    draw = ImageDraw.Draw(img)\n",
"    for box, label, score in zip(boxes, labels, scores):\n",
"        label = int(label)\n",
"        x1, y1, x2, y2 = (int(v) for v in box)\n",
"        color = CLASS_COLORS.get(label, (255, 255, 255))  # white for unknown classes\n",
"        draw.rectangle([x1, y1, x2, y2], outline=color, width=2)\n",
"\n",
"        # Label strip filled with the class colour, black text on top. Boxes that\n",
"        # touch the top edge get the label inside the box so it is not clipped.\n",
"        text = f\"{class_name(label)} {score:.2f}\"\n",
"        text_w = draw.textlength(text)\n",
"        ty = y1 - LABEL_H if y1 >= LABEL_H else y1\n",
"        draw.rectangle([x1, ty, x1 + text_w + 2, ty + LABEL_H], fill=color)\n",
"        draw.text((x1 + 1, ty), text, fill=(0, 0, 0))\n",
"    return img\n",
"\n",
"\n",
"# Batched inference over all tiles.\n",
"BATCH_SIZE = 32\n",
"IMG_SIZE = 256  # inference resolution passed to YOLO as `imgsz`\n",
"all_detections = {}  # {tile_name: [{\"cls\": int, \"box\": [x1, y1, x2, y2], \"score\": float}]}\n",
"vehicle_counts = defaultdict(int)\n",
"\n",
"for start in tqdm(range(0, len(tile_files), BATCH_SIZE), desc=\"Inference\"):\n",
"    batch_paths = tile_files[start:start + BATCH_SIZE]\n",
"    results = model.predict(\n",
"        [str(p) for p in batch_paths],\n",
"        conf=CONF,\n",
"        iou=IOU,\n",
"        device=DEVICE,\n",
"        verbose=False,\n",
"        imgsz=IMG_SIZE,\n",
"    )\n",
"\n",
"    for path, result in zip(batch_paths, results):\n",
"        dets = []\n",
"        if result.boxes is not None and len(result.boxes) > 0:\n",
"            boxes = result.boxes.xyxy.cpu().numpy()\n",
"            clss = result.boxes.cls.cpu().numpy().astype(int)\n",
"            confs = result.boxes.conf.cpu().numpy()\n",
"\n",
"            for box, cls, conf in zip(boxes, clss, confs):\n",
"                # Cast numpy scalars to built-in types: json.dump cannot serialise\n",
"                # numpy.int64, so saving detections.json would fail otherwise.\n",
"                cls = int(cls)\n",
"                dets.append({\"cls\": cls, \"box\": box.tolist(), \"score\": float(conf)})\n",
"                vehicle_counts[class_name(cls)] += 1\n",
"\n",
"            # Save the annotated tile (the context manager closes the source file).\n",
"            with Image.open(path) as src:\n",
"                img = src.convert(\"RGB\")\n",
"            draw_detections(img, boxes, clss, confs).save(ANNOTATED_DIR / path.name)\n",
"        else:\n",
"            # Tiles without detections are copied unchanged.\n",
"            shutil.copy(path, ANNOTATED_DIR / path.name)\n",
"\n",
"        all_detections[path.name] = dets\n",
"\n",
"print(\"\\nCelkový počet detekovaných vozidel:\")\n",
"for cls_name, cnt in sorted(vehicle_counts.items()):\n",
"    print(f\" {cls_name:10s}: {cnt:5d}\")\n",
"print(f\" {'CELKEM':10s}: {sum(vehicle_counts.values()):5d}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Save the detections as JSON for the next notebook.\n",
"DETECTIONS_PATH = Path(\"detections.json\")\n",
"\n",
"\n",
"def _json_default(obj):\n",
"    \"\"\"json.dump fallback: convert numpy scalars/arrays (e.g. numpy.int64 class ids) to built-in types.\"\"\"\n",
"    if isinstance(obj, np.generic):\n",
"        return obj.item()\n",
"    if isinstance(obj, np.ndarray):\n",
"        return obj.tolist()\n",
"    raise TypeError(f\"Object of type {type(obj).__name__} is not JSON serializable\")\n",
"\n",
"\n",
"with DETECTIONS_PATH.open(\"w\", encoding=\"utf-8\") as f:\n",
"    json.dump(all_detections, f, default=_json_default)\n",
"\n",
"print(f\"Detekce uloženy: detections.json ({len(all_detections)} dlaždic)\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualisation: up to 6 random annotated tiles that contain detections.\n",
"import random\n",
"\n",
"SAMPLE_SEED = 42  # fixed seed so Restart & Run All shows the same tiles\n",
"\n",
"tiles_with_detections = [\n",
"    name for name, dets in all_detections.items() if len(dets) > 0\n",
"]\n",
"print(f\"Dlaždice s detekcemi: {len(tiles_with_detections)} / {len(all_detections)}\")\n",
"\n",
"fig, axes = plt.subplots(2, 3, figsize=(15, 10))\n",
"rng = random.Random(SAMPLE_SEED)\n",
"samples = rng.sample(tiles_with_detections, min(axes.size, len(tiles_with_detections)))\n",
"\n",
"# Hide every panel first so panels left empty (fewer samples than panels) stay blank.\n",
"for ax in axes.flat:\n",
"    ax.axis(\"off\")\n",
"\n",
"for ax, name in zip(axes.flat, samples):\n",
"    with Image.open(ANNOTATED_DIR / name) as img:\n",
"        ax.imshow(np.asarray(img))\n",
"    ax.set_title(f\"{name}\\n{len(all_detections[name])} vozidel\", fontsize=8)\n",
"\n",
"fig.suptitle(\"Ukázkové detekce na dlaždicích HK\", fontsize=14)\n",
"fig.tight_layout()\n",
"fig.savefig(\"sample_detections.png\", dpi=150, bbox_inches=\"tight\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Vehicle-density heatmap: one pixel per tile, placed by the tile's x/y index.\n",
"import re\n",
"\n",
"TILE_RE = re.compile(r\"18_(\\d+)_(\\d+)\\.jpg\")\n",
"\n",
"# Parse (x, y, vehicle count) once per tile; names that do not match are skipped.\n",
"tile_counts = []\n",
"for name, dets in all_detections.items():\n",
"    m = TILE_RE.match(name)\n",
"    if m:\n",
"        tile_counts.append((int(m.group(1)), int(m.group(2)), len(dets)))\n",
"assert tile_counts, \"Žádné dlaždice s názvem 18_<x>_<y>.jpg — nelze sestavit mapu.\"\n",
"\n",
"xs = [x for x, _, _ in tile_counts]\n",
"ys = [y for _, y, _ in tile_counts]\n",
"x_min, x_max = min(xs), max(xs)\n",
"y_min, y_max = min(ys), max(ys)\n",
"rows = y_max - y_min + 1\n",
"cols = x_max - x_min + 1\n",
"\n",
"heatmap = np.zeros((rows, cols), dtype=np.int32)\n",
"has_tile = np.zeros((rows, cols), dtype=bool)\n",
"for x, y, count in tile_counts:\n",
"    heatmap[y - y_min, x - x_min] = count\n",
"    has_tile[y - y_min, x - x_min] = True\n",
"\n",
"fig, ax = plt.subplots(figsize=(14, 12))\n",
"# Grid positions without a tile are masked (left blank) so they are not\n",
"# confused with tiles that have zero detections.\n",
"im = ax.imshow(np.ma.masked_array(heatmap, mask=~has_tile), cmap=\"hot\", interpolation=\"nearest\")\n",
"fig.colorbar(im, ax=ax, label=\"Počet vozidel\")\n",
"ax.set_title(\"Hustota vozidel — Hradec Králové (každý pixel = 1 dlaždice)\", fontsize=14)\n",
"ax.set_xlabel(\"X (západ → východ)\")\n",
"ax.set_ylabel(\"Y (sever → jih)\")\n",
"fig.tight_layout()\n",
"fig.savefig(\"heatmap_vehicles.png\", dpi=150, bbox_inches=\"tight\")\n",
"plt.show()\n",
"\n",
"# Per-tile averages: over all tiles, and over tiles with at least one vehicle.\n",
"counts = np.array([c for _, _, c in tile_counts])\n",
"nonempty = counts[counts > 0]\n",
"print(f\"\\nNejhustší dlaždice: {heatmap.max()} vozidel\")\n",
"print(f\"Průměr na dlaždici: {counts.mean():.1f} vozidel\")\n",
"if nonempty.size:\n",
"    print(f\"Průměr na dlaždici s vozidly: {nonempty.mean():.1f} vozidel\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}