Files
su2-img/scripts/prepare_dataset.py
Lukáš Trkan e8c1c3f2ee update
2026-05-01 22:06:01 +02:00

120 lines
4.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Prepare the YOLO training dataset for upload to the training server.

Sources:
    always:      data/vrchlabi_custom/train/{images,labels}
    if present:  data/yolo_visdrone/train/{images,labels}
                 (its labels are expected to already use class IDs 0–3;
                 this script copies label files verbatim and does not
                 filter or remap class IDs)

NOTE(review): data/hk_custom/ used to be documented here as an optional
source, but nothing in this script reads it — TODO add it to the optional
sources or drop it for good.

Output: dataset_vrchlabi.zip containing
    vrchlabi/train/images + labels
    vrchlabi/val/images + labels
    vrchlabi/dataset.yaml
    best.pt   (only if ./best.pt exists when the script runs)
"""
import random, shutil, zipfile
from pathlib import Path
# Fraction of annotated tiles held out for validation.
VAL_RATIO = 0.2
# Maximum number of background (empty-label) tiles put into each split.
BG_PER_SPLIT = 50
# Seed for the module-level `random` generator so the split is reproducible.
SEED = 42
# Class names in class-ID order (ID 0 = "car", ID 1 = "van", ...).
CLASS_NAMES = "car van truck bus".split()
OUT = Path("dataset_vrchlabi")     # staging directory, rebuilt on every run
ZIP_OUT = OUT.with_suffix(".zip")  # final archive: dataset_vrchlabi.zip
random.seed(SEED)
# ---------------------------------------------------------------------------
# Data sources: (image dir, label dir, filename prefix that keeps copied
# files unique across sources)
# ---------------------------------------------------------------------------
SOURCES = [
    (
        "data/vrchlabi_custom/train/images",
        "data/vrchlabi_custom/train/labels",
        "vrchlabi",
    ),
]

# Optional sources — same YOLO format, same classes. Each one is added only
# when both its images/ and labels/ directories exist.
for source_name, source_root in (("visdrone", "data/yolo_visdrone/train"),):
    source_images = Path(source_root, "images")
    source_labels = Path(source_root, "labels")
    if not (source_images.exists() and source_labels.exists()):
        print(f"- {source_name}: nenalezeno, přeskočeno")
        continue
    # Non-empty label files = tiles with at least one box.
    annotated_count = sum(
        label_file.stat().st_size > 0 for label_file in source_labels.glob("*.txt")
    )
    print(f"+ {source_name}: {annotated_count} anotovaných dlaždic")
    SOURCES.append((str(source_images), str(source_labels), source_name))
# ---------------------------------------------------------------------------
# Collect label files from every source and split them into train/val
# ---------------------------------------------------------------------------
# Image extensions tried for each label stem, in order of preference
# (.jpg first, then .png, as before; further suffixes are only fallbacks).
IMAGE_SUFFIXES = (".jpg", ".png", ".jpeg")


def _find_image(img_dir, stem):
    """Return the first existing image named *stem* + suffix in *img_dir*, or None."""
    for suffix in IMAGE_SUFFIXES:
        candidate = img_dir / (stem + suffix)
        if candidate.exists():
            return candidate
    return None


all_annotated = []  # (lbl_path, img_path, prefix) for non-empty label files
all_background = []  # same tuple shape for empty label files (no boxes)
for img_dir, lbl_dir, prefix in SOURCES:
    img_dir = Path(img_dir)
    lbl_dir = Path(lbl_dir)
    n_orphans = 0
    # sorted() fixes the order before the seeded shuffle, so runs are reproducible.
    for lbl in sorted(lbl_dir.glob("*.txt")):
        img = _find_image(img_dir, lbl.stem)
        if img is None:
            # A label without its image cannot be trained on; keeping it would
            # also inflate the counts the train/val split is computed from.
            n_orphans += 1
            continue
        entry = (lbl, img, prefix)
        if lbl.stat().st_size > 0:
            all_annotated.append(entry)
        else:
            all_background.append(entry)
    if n_orphans:
        print(f"! {prefix}: {n_orphans} labelů bez obrázku, přeskočeno")

# Both splits need at least one annotated tile, otherwise the dataset is useless.
if len(all_annotated) < 2:
    raise SystemExit(
        f"Potřeba aspoň 2 anotované dlaždice s obrázkem (train + val), "
        f"nalezeno {len(all_annotated)}."
    )

random.shuffle(all_annotated)
random.shuffle(all_background)
# At least one annotated tile goes to val; with >= 2 tiles train stays non-empty.
n_val = max(1, int(len(all_annotated) * VAL_RATIO))
val_ann = all_annotated[:n_val]
train_ann = all_annotated[n_val:]
# Backgrounds: the first BG_PER_SPLIT go to train, the next BG_PER_SPLIT to val.
train_bg = all_background[:BG_PER_SPLIT]
val_bg = all_background[BG_PER_SPLIT : BG_PER_SPLIT * 2]
print(f"\nAnotované: {len(all_annotated)} (train {len(train_ann)}, val {len(val_ann)})")
print(f"Background: {len(all_background)} (train {len(train_bg)}, val {len(val_bg)})")
# ---------------------------------------------------------------------------
# Copy the selected files into OUT/<split>/{images,labels}
# ---------------------------------------------------------------------------
if OUT.exists():
    shutil.rmtree(OUT)  # start every run from a clean staging directory

splits = {
    "train": train_ann + train_bg,
    "val": val_ann + val_bg,
}
for split_name, entries in splits.items():
    images_dir = OUT / split_name / "images"
    labels_dir = OUT / split_name / "labels"
    for target_dir in (images_dir, labels_dir):
        target_dir.mkdir(parents=True, exist_ok=True)
    for src_label, src_image, prefix in entries:
        # The source prefix keeps stems from different sources from colliding.
        new_stem = prefix + "_" + src_label.stem
        shutil.copy2(src_label, labels_dir / (new_stem + ".txt"))
        if src_image is not None:
            shutil.copy2(src_image, images_dir / (new_stem + src_image.suffix))
# ---------------------------------------------------------------------------
# dataset.yaml
# ---------------------------------------------------------------------------
# No `path:` key on purpose. Ultralytics resolves a relative `path` (such as
# ".") against the current working directory, not against the YAML file, so
# `path: .` breaks the suggested `yolo train data=vrchlabi/dataset.yaml` run
# from the unzip directory (it looks for ./train/images).
# NOTE(review): without `path:`, current Ultralytics falls back to the YAML's
# own directory (= vrchlabi/ in the archive); very old 8.0.x releases used
# `datasets_dir` instead — confirm against the server's installed version.
# `names` is written as a Python list repr, a valid YAML flow sequence here.
(OUT / "dataset.yaml").write_text(
    f"train: train/images\nval: val/images\nnc: {len(CLASS_NAMES)}\nnames: {CLASS_NAMES}\n",
    encoding="utf-8",
)
# ---------------------------------------------------------------------------
# ZIP
# ---------------------------------------------------------------------------
if ZIP_OUT.exists():
    ZIP_OUT.unlink()
weights = Path("best.pt")
has_weights = weights.exists()
with zipfile.ZipFile(ZIP_OUT, "w", zipfile.ZIP_DEFLATED) as zf:
    # Everything under OUT is stored under the top-level "vrchlabi/" folder;
    # as_posix() keeps member names "/"-separated on every platform.
    for file_path in sorted(OUT.rglob("*")):
        if file_path.is_file():
            zf.write(file_path, "vrchlabi/" + file_path.relative_to(OUT).as_posix())
    # Optional starting weights go to the archive root, next to vrchlabi/.
    if has_weights:
        zf.write(weights, "best.pt")
        print("Přidáno: best.pt")
size_mb = ZIP_OUT.stat().st_size / 1e6
print(f"\nHotovo: {ZIP_OUT} ({size_mb:.1f} MB)")
if not has_weights:
    # The command printed below uses model=best.pt, which is not in the archive.
    print(
        f"\nUpozornění: {weights} nenalezen, v ZIPu chybí — "
        "nahraj váhy na server ručně, nebo uprav model= v příkazu níže."
    )
print("\nNa serveru spusť:")
print(f" unzip {ZIP_OUT.name}")
print(" pip install ultralytics")
print(" yolo train model=best.pt data=vrchlabi/dataset.yaml epochs=50 imgsz=256 batch=32")