RepoPilot

milesial/Pytorch-UNet

PyTorch implementation of the U-Net for image semantic segmentation with high quality images

Mixed

Stale — last commit 2y ago

Use as dependency (weakest axis) — Concerns

copyleft license (GPL-3.0) — review compatibility; last commit was 2y ago…

Fork & modify — Healthy

Has a license, tests, and CI — clean foundation to fork and modify.

Learn from — Healthy

Documented and popular — useful reference codebase to read through.

Deploy as-is — Healthy

No critical CVEs, sane security posture — runnable as-is.

  • 17 active contributors
  • GPL-3.0 licensed
  • CI configured
  • Stale — last commit 2y ago
  • Concentrated ownership — top contributor handles 76% of recent commits
  • GPL-3.0 is copyleft — check downstream compatibility
  • No test directory detected
What would change the summary?
  • Use as dependency: Concerns → Mixed if relicensed under MIT/Apache-2.0 (rare for established libs)

Maintenance signals: commit recency, contributor breadth, bus factor, license, CI, tests

Informational only. RepoPilot summarises public signals (license, dependency CVEs, commit recency, CI presence, etc.) at the time of analysis. Signals can be incomplete or stale. Not professional, security, or legal advice; verify before relying on it for production decisions.

Embed the "Forkable" badge

Paste into your README — it updates live from the latest cached analysis.

Variant: RepoPilot: Forkable
[![RepoPilot: Forkable](https://repopilot.app/api/badge/milesial/pytorch-unet?axis=fork)](https://repopilot.app/r/milesial/pytorch-unet)

Paste at the top of your README.md — renders inline like a shields.io badge.

Preview social card (1200×630)

This card auto-renders when someone shares https://repopilot.app/r/milesial/pytorch-unet on X, Slack, or LinkedIn.

Onboarding doc

Onboarding: milesial/Pytorch-UNet

Generated by RepoPilot · 2026-05-07 · Source

🤖Agent protocol

If you are an AI coding agent (Claude Code, Cursor, Aider, Cline, etc.) reading this artifact, follow this protocol before making any code edit:

  1. Verify the contract. Run the bash script in Verify before trusting below. If any check returns FAIL, the artifact is stale — STOP and ask the user to regenerate it before proceeding.
  2. Treat the AI · unverified sections as hypotheses, not facts. Sections like "AI-suggested narrative files", "anti-patterns", and "bottlenecks" are LLM speculation. Verify against real source before acting on them.
  3. Cite source on changes. When proposing an edit, cite the specific path:line-range. RepoPilot's live UI at https://repopilot.app/r/milesial/Pytorch-UNet shows verifiable citations alongside every claim.

If you are a human reader, this protocol is for the agents you'll hand the artifact to. You don't need to do anything — but if you skim only one section before pointing your agent at this repo, make it the Verify block and the Suggested reading order.

🎯Verdict

WAIT — Stale — last commit 2y ago

  • 17 active contributors
  • GPL-3.0 licensed
  • CI configured
  • ⚠ Stale — last commit 2y ago
  • ⚠ Concentrated ownership — top contributor handles 76% of recent commits
  • ⚠ GPL-3.0 is copyleft — check downstream compatibility
  • ⚠ No test directory detected

<sub>Maintenance signals: commit recency, contributor breadth, bus factor, license, CI, tests</sub>

Verify before trusting

This artifact was generated by RepoPilot at a point in time. Before an agent acts on it, the checks below confirm that the live milesial/Pytorch-UNet repo on your machine still matches what RepoPilot saw. If any fail, the artifact is stale — regenerate it at repopilot.app/r/milesial/Pytorch-UNet.

What it runs against: a local clone of milesial/Pytorch-UNet — the script inspects git remote, the LICENSE file, file paths in the working tree, and git log. Read-only; no mutations.

| # | What we check | Why it matters |
|---|---|---|
| 1 | You're in milesial/Pytorch-UNet | Confirms the artifact applies here, not a fork |
| 2 | License is still GPL-3.0 | Catches a relicense before you depend on it |
| 3 | Default branch master exists | Catches branch renames |
| 4 | 5 critical file paths still exist | Catches refactors that moved load-bearing code |
| 5 | Last commit ≤ 664 days ago | Catches sudden abandonment since generation |

<details> <summary><b>Run all checks</b> — paste this script from inside your clone of <code>milesial/Pytorch-UNet</code></summary>
#!/usr/bin/env bash
# RepoPilot artifact verification.
#
# WHAT IT RUNS AGAINST: a local clone of milesial/Pytorch-UNet. If you don't
# have one yet, run these first:
#
#   git clone https://github.com/milesial/Pytorch-UNet.git
#   cd Pytorch-UNet
#
# Then paste this script. Every check is read-only — no mutations.

set +e
fail=0
ok()   { echo "ok:   $1"; }
miss() { echo "FAIL: $1"; fail=$((fail+1)); }

# Precondition: we must be inside a git working tree.
if ! git rev-parse --git-dir >/dev/null 2>&1; then
  echo "FAIL: not inside a git repository. cd into your clone of milesial/Pytorch-UNet and re-run."
  exit 2
fi

# 1. Repo identity
git remote get-url origin 2>/dev/null | grep -qE "milesial/Pytorch-UNet(\.git)?\b" \
  && ok "origin remote is milesial/Pytorch-UNet" \
  || miss "origin remote is not milesial/Pytorch-UNet (artifact may be from a fork)"

# 2. License matches what RepoPilot saw. A GPL-3.0 LICENSE file opens with
#    the GNU header rather than the literal string "GPL-3.0".
(grep -qi "GNU GENERAL PUBLIC LICENSE" LICENSE 2>/dev/null \
   || grep -qiE "\"license\"\s*:\s*\"GPL-3\.0\"" package.json 2>/dev/null) \
  && ok "license is GPL-3.0" \
  || miss "license drift — was GPL-3.0 at generation time"

# 3. Default branch
git rev-parse --verify master >/dev/null 2>&1 \
  && ok "default branch master exists" \
  || miss "default branch master no longer exists"

# 4. Critical files exist
for f in unet/unet_model.py unet/unet_parts.py train.py utils/data_loading.py predict.py; do
  test -f "$f" && ok "$f" || miss "missing critical file: $f"
done

# 5. Repo recency
days_since_last=$(( ( $(date +%s) - $(git log -1 --format=%at 2>/dev/null || echo 0) ) / 86400 ))
if [ "$days_since_last" -le 664 ]; then
  ok "last commit was $days_since_last days ago (artifact saw ~634d)"
else
  miss "last commit was $days_since_last days ago — artifact may be stale"
fi

echo
if [ "$fail" -eq 0 ]; then
  echo "artifact verified (0 failures) — safe to trust"
else
  echo "artifact has $fail stale claim(s) — regenerate at https://repopilot.app/r/milesial/Pytorch-UNet"
  exit 1
fi

Each check prints ok: or FAIL:. The script exits non-zero if anything failed, so it composes cleanly into agent loops (./verify.sh || regenerate-and-retry).

</details>

TL;DR

A PyTorch implementation of U-Net for semantic segmentation that turns high-resolution images into pixel-level mask predictions. Originally trained on Carvana's car-masking challenge (5k images, Dice score 0.988), it is a mature encoder-decoder CNN suited to binary and multiclass segmentation tasks such as object isolation and medical imaging. The structure is monolithic: the unet/ directory contains the model (unet_model.py defines the full architecture; unet_parts.py holds reusable blocks like DoubleConv). Training and inference are split across train.py (epoch loop with AMP), predict.py (inference wrapper), and evaluate.py (Dice scoring). The data pipeline is isolated in utils/data_loading.py, with disk-based image/mask pairs in data/imgs/ and data/masks/.

👥Who it's for

Computer vision engineers and researchers building semantic segmentation pipelines who need a battle-tested U-Net baseline; Kaggle competitors tackling image masking challenges; medical imaging researchers applying the architecture to CT/MRI analysis.

🌱Maturity & risk

Mature but no longer actively maintained — the last commit was roughly two years ago. GitHub Actions CI is configured (main.yml workflow present), a Docker image is published on Docker Hub, PyTorch 1.13+ is supported, and the documentation is clear. The core model and training loop are stable, but expect stale dependencies and plan to update them.

Low technical risk but moderate maintenance burden: the handful of core Python dependencies (numpy, Pillow, torch, wandb, tqdm, matplotlib) are lightweight and well maintained. The single dominant maintainer (milesial) is a long-term bus-factor risk. The Weights & Biases (wandb) integration is optional but adds an external dependency. No visible test suite (no tests/ directory) means regression testing relies on manual validation.

Active areas of work

The repo appears to be in maintenance mode. No specific PR/issue data was available, but the README's CI badge (main.yml) and the synced Docker Hub image suggest periodic upkeep, and the wandb integration and AMP (Automatic Mixed Precision) support indicate past alignment with recent PyTorch APIs.

🚀Get running

Clone and install:

git clone https://github.com/milesial/Pytorch-UNet.git
cd Pytorch-UNet
pip install -r requirements.txt
bash scripts/download_data.sh
python train.py --amp

Or with Docker:

sudo docker run --rm --gpus all -it milesial/unet

Daily commands: Training:

python train.py --amp --epochs 5 --batch-size 1 --learning-rate 0.0001 --scale 0.5

Prediction on new image:

python predict.py -i test_image.jpg -o output_mask.png -m checkpoints/checkpoint_epoch5.pth

Evaluation:

python evaluate.py --model checkpoints/model.pth --input data/imgs --output masks_predicted

🗺️Map of the codebase

  • unet/unet_model.py — Core U-Net architecture definition; all semantic segmentation models instantiate from this file, making it essential for understanding the neural network structure.
  • unet/unet_parts.py — Defines reusable building blocks (DoubleConv, Down, Up) for U-Net; modifications here cascade across all model variants.
  • train.py — Main training loop entry point; orchestrates data loading, model training, validation, and checkpoint management—required reading for training workflows.
  • utils/data_loading.py — Custom dataset and dataloader implementation for image-mask pairs; critical for understanding data pipeline and augmentation strategy.
  • predict.py — Inference entry point that loads trained models and generates segmentation predictions; key interface for end-to-end workflows.
  • requirements.txt — Pins exact dependency versions including PyTorch and data handling libraries; mismatches here cause silent failures in image loading and tensor operations.
  • evaluate.py — Model evaluation metrics computation; essential for understanding how Dice score and other performance metrics are calculated.

🧩Components & responsibilities

  • UNet (unet/unet_model.py) (PyTorch nn.Module, Conv2d, MaxPool2d, ReLU) — Encodes input image to low-res feature map, then decodes with skip connections to generate per-pixel segmentation logits

🛠️How to make changes

Add a Custom U-Net Variant (Different Depth/Channels)

  1. Create a new class inheriting from UNet in unet/unet_model.py; override __init__ with custom n_channels, n_classes, and bilinear interpolation settings (unet/unet_model.py)
  2. Instantiate the custom model in train.py (line where model is created) or add a --model_variant argument to argparse (train.py)
  3. Test with a forward pass and verify output shape matches (batch_size, n_classes, H, W) (train.py)
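Step 2 could be wired up with a variant flag along these lines — a minimal sketch in which the flag name, the registry, and the "small" variant are all illustrative, not part of the repo:

```python
import argparse

# Hypothetical registry mapping a CLI choice to a model name plus its
# constructor kwargs. "UNetSmall" is an illustrative custom variant,
# not a class that exists in this repo.
MODEL_VARIANTS = {
    "base": ("UNet", dict(n_channels=3, n_classes=2, bilinear=False)),
    "small": ("UNetSmall", dict(n_channels=3, n_classes=2, bilinear=True)),
}

def parse_variant(argv):
    """Parse a --model-variant flag and return (name, kwargs)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-variant", choices=MODEL_VARIANTS, default="base")
    args = parser.parse_args(argv)
    return MODEL_VARIANTS[args.model_variant]

name, kwargs = parse_variant(["--model-variant", "small"])
# name == "UNetSmall"; in train.py you would then instantiate the
# matching class with **kwargs where the model is currently created.
```

The registry keeps model selection in one place, so adding a third variant is a one-line change.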

Add Image Preprocessing or Augmentation

  1. Open utils/data_loading.py and extend the BasicDataset class's __getitem__ method to apply additional transforms (rotation, flip, brightness) (utils/data_loading.py)
  2. Import torchvision.transforms or albumentations at the top and compose your augmentation pipeline before returning the img and mask tensors (utils/data_loading.py)
  3. Run train.py with a small batch to verify augmented images load without errors and shapes remain consistent (train.py)
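The crucial detail in step 1 is that the image and mask must receive the same random decision, or the pair falls out of alignment. A framework-free sketch (lists stand in for tensors; in the real code this logic would live inside __getitem__):

```python
import random

def paired_hflip(img, mask, rng, p=0.5):
    """Horizontally flip img and mask together with probability p.

    img and mask are lists of rows here; the key point is that a single
    random draw drives both flips, so the pair stays aligned.
    """
    if rng.random() < p:
        img = [row[::-1] for row in img]
        mask = [row[::-1] for row in mask]
    return img, mask

img = [[1, 2], [3, 4]]
mask = [[0, 1], [1, 0]]
# random.Random(1).random() is below 0.5, so this seed flips both.
flipped_img, flipped_mask = paired_hflip(img, mask, random.Random(1))
```

With torchvision or albumentations the same principle applies: draw the transform parameters once, then apply them to both tensors.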

Add a New Evaluation Metric

  1. Create a new metric function in utils/dice_score.py (e.g., iou_score, precision_recall) following the same signature as multiclass_dice_coeff (utils/dice_score.py)
  2. Import and call the new metric in evaluate.py alongside the existing Dice evaluation in the validation loop (evaluate.py)
  3. Log the metric to wandb by adding it to the wandb.log() call in train.py if using Weights & Biases integration (train.py)
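The IoU math from step 1, sketched framework-free (the real functions in utils/dice_score.py operate on torch tensors, and iou_score is a hypothetical name — adapt the signature to match multiclass_dice_coeff):

```python
def iou_score(pred, target, eps=1e-6):
    """Intersection-over-Union for flat binary masks (0/1 values).

    eps keeps the ratio defined when both masks are empty, mirroring the
    smoothing convention used by the Dice implementation.
    """
    inter = sum(p * t for p, t in zip(pred, target))
    union = sum(p + t - p * t for p, t in zip(pred, target))
    return (inter + eps) / (union + eps)

# Perfect overlap -> 1.0; partial overlap -> intersection / union.
score = iou_score([1, 1, 0, 0], [1, 0, 1, 0])  # inter=1, union=3
```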

Deploy Model to Docker for Production Inference

  1. Update Dockerfile COPY and CMD directives to include your trained checkpoint and modify predict.py to load the specific model path (Dockerfile)
  2. Build the Docker image: docker build -t unet:latest . and test locally with docker run -v /path/to/images:/data unet:latest (Dockerfile)
  3. Verify predict.py outputs segmentation masks correctly when invoked inside the container (predict.py)

🔧Why these technologies

  • PyTorch 1.13+ — Modern deep learning framework with autograd, GPU acceleration, and strong community support for semantic segmentation tasks
  • Pillow for image I/O — Efficient, non-blocking image loading and resizing for high-resolution image datasets typical of Carvana challenge
  • NumPy for tensor operations — Vectorized operations for mask preprocessing and metric calculation before PyTorch conversion
  • Weights & Biases (wandb) for experiment tracking — Optional integration for logging hyperparameters, metrics, and model checkpoints across training runs for reproducibility
  • Docker for containerization — Ensures reproducible inference environment across different hardware/OS platforms for production deployment

⚖️Trade-offs already made

  • Bilinear upsampling vs. transposed convolution in decoder

    • Why: Bilinear is configurable (set via bilinear=True/False in UNet) to avoid checkerboard artifacts in some use cases
    • Consequence: Bilinear adds slight interpolation error but provides training stability; transposed conv is more learnable but can introduce artifacts
  • Single-file data loading (BasicDataset) without prefetching/caching

    • Why: Simplicity and memory efficiency for datasets that fit on disk; no overhead of maintaining in-memory cache
    • Consequence: I/O can become bottleneck for very large datasets or slow storage; users must implement custom caching if needed
  • BCEWithLogitsLoss for binary/multiclass segmentation

    • Why: Numerically stable and handles class imbalance well for real-world masking tasks
    • Consequence: Assumes binary or independent class predictions; not suitable for multi-class exclusive segmentation without modification
  • Checkpoint saving only on validation improvement (if Dice increases)

    • Why: Prevents storage bloat and automatic recovery of best model
    • Consequence: Early stopping requires external logic; only best checkpoint persists, losing training history

🚫Non-goals (don't propose these)

  • Real-time inference (no optimization for latency)
  • Multi-GPU distributed training (single-GPU only)
  • Data augmentation library (relies on external transforms; minimal built-in augmentation)
  • Handles 3D volumetric data (2D image segmentation only)
  • Instance segmentation (semantic segmentation only, no per-object ID tracking)
  • Post-processing or CRF refinement (raw model predictions only)

🪤Traps & gotchas

  • AMP stability: train.py uses torch.cuda.amp, which requires careful loss scaling; it runs on CUDA devices only (the CPU fallback is not tested).
  • Data paths: scripts/download_data.sh populates data/imgs/ and data/masks/ in the exact directory layout data_loading.py assumes; manually placed data must match it.
  • Checkpoint format: train.py saves state_dict(), but predict.py expects specific checkpoint keys; model changes across commits can break inference.
  • wandb integration: optional but enabled by default in train.py; requires an API key or the --no-log flag to disable.
  • Input normalization: data_loading.py applies fixed normalization (assumes ImageNet stats); custom datasets may need retraining or a normalization adjustment.

🏗️Architecture

💡Concepts to learn

  • U-Net Architecture (Encoder-Decoder with Skip Connections) — The entire repo is built around this design; understanding how DoubleConv blocks, MaxPool2d down-sampling, ConvTranspose2d up-sampling, and skip concatenation work is essential to modifying the model.
  • Dice Coefficient (Sørensen–Dice Index) — The primary evaluation metric in utils/dice_score.py; the model was tuned to maximize this (0.988 on test set), so understanding how it penalizes false positives/negatives is critical for interpreting results.
  • Automatic Mixed Precision (AMP) — train.py uses torch.cuda.amp context managers and GradScaler for 3x training speedup; misunderstanding loss scaling can cause NaN divergence.
  • Skip Connections / Residual Learning — unet_model.py concatenates encoder features to decoder layers to preserve spatial detail; this is why U-Net excels at dense predictions vs. a naive encoder-decoder.
  • Transpose Convolution (Deconvolution) — unet_parts.py uses ConvTranspose2d for up-sampling; understanding stride, padding, and output_padding is essential for controlling spatial resolution in the decoder.
  • Gradient Accumulation — train.py supports --batch-size 1 on memory-constrained hardware; batch normalization and loss computation depend on effective batch size, which can be tuned via accumulation steps.
  • Cross-Entropy Loss vs. Dice Loss — train.py combines both losses for balanced optimization; understanding their different penalties (pixel-wise class imbalance vs. region overlap) helps debug convergence issues.
🔗Related projects

  • facebookresearch/detectron2 — Production-grade PyTorch vision library with U-Net and other segmentation backbones; if you need multi-task detection + segmentation, use this.
  • qubvel/segmentation_models.pytorch — Comprehensive PyTorch segmentation model zoo (U-Net, DeepLab, etc.) with pretrained encoders; this repo is lighter-weight for custom training on Carvana-style tasks.
  • miccaia/Pytorch-UNet — Alternative U-Net fork with 3D support and multiclass extensions; if you need volumetric medical segmentation, compare this one.
  • pytorch/pytorch — Core PyTorch framework; required dependency and source for torch.cuda.amp, F.interpolate, and nn modules used throughout.
  • wandb/wandb — Experiment tracking platform integrated into train.py for logging metrics and model checkpoints; essential for reproducing reported Dice 0.988 score.
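The gradient-accumulation bookkeeping mentioned under Concepts to learn can be sketched without any framework — the essential move is dividing each micro-batch loss by the accumulation count and "stepping" only every N micro-batches:

```python
def accumulate_steps(micro_losses, accum_steps):
    """Return the effective per-step losses under gradient accumulation.

    Each micro-batch loss is scaled by 1/accum_steps (as a training loop
    would do before backward()), and an optimizer step fires once every
    accum_steps micro-batches. This is a bookkeeping sketch, not the
    repo's actual training loop.
    """
    steps, running = [], 0.0
    for i, loss in enumerate(micro_losses, start=1):
        running += loss / accum_steps
        if i % accum_steps == 0:
            steps.append(running)
            running = 0.0
    return steps

# Four micro-batches with accumulation 2 behave like two full batches.
effective = accumulate_steps([1.0, 1.0, 2.0, 2.0], accum_steps=2)
```

Note that BatchNorm statistics still see the micro-batch size, which is why accumulation is not a perfect substitute for a genuinely larger batch.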

🪄PR ideas

To work on one of these in Claude Code or Cursor, paste: Implement the "<title>" PR idea from CLAUDE.md, working through the checklist as the task list.

Add unit tests for utils/ modules with pytest

The repo has no test directory despite containing critical utility functions (data_loading.py, dice_score.py, utils.py). These functions handle data loading, metric computation, and preprocessing - core functionality that should be tested. Adding a tests/ directory with pytest would catch regressions and help new contributors understand expected behavior.

  • [ ] Create tests/ directory with __init__.py
  • [ ] Add tests/test_dice_score.py with tests for dice computation edge cases (empty masks, perfect overlap, no overlap)
  • [ ] Add tests/test_data_loading.py with tests for BasicDataset class initialization and augmentation logic in utils/data_loading.py
  • [ ] Add tests/test_utils.py for any utility functions in utils/utils.py
  • [ ] Update requirements.txt to include pytest and pytest-cov
  • [ ] Add pytest configuration to .github/workflows/main.yml if not already present
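The edge-case tests from the checklist might start like this — a sketch using a local reference Dice so it runs standalone (the repo's dice_coeff takes torch tensors; adapt the assertions to that signature):

```python
def dice(pred, target, eps=1e-6):
    """Reference Dice on flat binary masks: 2|A∩B| / (|A| + |B|)."""
    inter = sum(p * t for p, t in zip(pred, target))
    return (2 * inter + eps) / (sum(pred) + sum(target) + eps)

def test_empty_masks_score_one():
    # Both masks empty: eps/eps == 1.0 by the smoothing convention.
    assert abs(dice([0, 0], [0, 0]) - 1.0) < 1e-6

def test_perfect_overlap():
    assert abs(dice([1, 1, 0], [1, 1, 0]) - 1.0) < 1e-3

def test_no_overlap():
    assert dice([1, 0], [0, 1]) < 1e-3

test_empty_masks_score_one()
test_perfect_overlap()
test_no_overlap()
```

Under pytest the three functions would be discovered automatically; the explicit calls at the bottom just make the sketch self-checking.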

Add type hints to unet/ and utils/ modules

The codebase has no type hints in core modules (unet_model.py, unet_parts.py, data_loading.py, utils.py). This makes it harder for contributors to understand function signatures and enables better IDE support. Adding comprehensive type hints would improve code quality and maintainability with minimal refactoring.

  • [ ] Add type hints to unet/unet_parts.py for all Conv classes and DoubleConv forward() methods
  • [ ] Add type hints to unet/unet_model.py for the UNet class __init__() and forward() methods
  • [ ] Add type hints to utils/data_loading.py for BasicDataset.__init__(), __getitem__(), and __len__()
  • [ ] Add type hints to utils/dice_score.py for dice_coeff() and multiclass_dice_coeff() functions
  • [ ] Add typing imports and ensure Python 3.6+ compatibility (Union, List, Tuple, Optional)
  • [ ] Update README to mention type hint coverage
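The dataset hints from the checklist might take this shape — BasicDatasetStub and its fields are illustrative stand-ins; apply the same signature shapes to the real BasicDataset in utils/data_loading.py:

```python
from typing import Dict, List

class BasicDatasetStub:
    """Illustrates the type-hint shapes for a map-style dataset."""

    def __init__(self, ids: List[str], scale: float = 1.0) -> None:
        self.ids = ids
        self.scale = scale

    def __len__(self) -> int:
        return len(self.ids)

    def __getitem__(self, idx: int) -> Dict[str, str]:
        # The real dataset returns a dict of torch tensors
        # ({'image': ..., 'mask': ...}); strings keep this stub runnable.
        return {"id": self.ids[idx]}

ds = BasicDatasetStub(["car_001", "car_002"])
```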

Add integration tests for train.py and predict.py workflows

The repo has training (train.py) and prediction (predict.py) scripts but no integration tests to verify end-to-end workflows work correctly. This is critical since these are the user-facing entry points. Adding minimal integration tests would catch breaking changes and help new contributors understand the expected data flow.

  • [ ] Create tests/integration/ directory
  • [ ] Add tests/integration/test_training_workflow.py that: trains a small model on dummy data (from data/imgs, data/masks), saves checkpoint, and verifies loss decreases
  • [ ] Add tests/integration/test_prediction_workflow.py that: loads a trained model from hubconf.py, runs prediction on a test image, and verifies output shape matches input
  • [ ] Create a small dummy dataset fixture in tests/conftest.py (4-5 small sample images/masks) for reproducible testing
  • [ ] Add integration test step to .github/workflows/main.yml
  • [ ] Document in README how to run integration tests for contributors

🌿Good first issues

  • Add unit tests for unet_parts.py (DoubleConv, Down, Up blocks) with mock tensors to verify forward-pass shapes and parameter counts; currently no tests/ directory exists. Why it matters: ensures model components don't regress and helps new contributors understand tensor dimensions.
  • Document hyperparameter sensitivity in the README with specific examples (e.g., learning_rate=0.0001 achieves 0.988 Dice; learning_rate=0.001 diverges); train.py has many flags but no guidance. Why it matters: reduces onboarding friction and prevents common training mistakes.
  • Add a CPU fallback in predict.py and train.py (e.g., device='cuda' if torch.cuda.is_available() else 'cpu') and test on CPU; currently GPU-only use is assumed. Why it matters: enables local testing and inference on machines without CUDA.
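The fallback logic from the last item, isolated as a pure function so it can be unit-tested without CUDA — pick_device is a hypothetical helper; in train.py/predict.py you would pass torch.cuda.is_available():

```python
def pick_device(cuda_available: bool, force_cpu: bool = False) -> str:
    """Return the device string the training/inference scripts should use.

    force_cpu mirrors a hypothetical --cpu flag for machines where CUDA
    is present but undesired (e.g. debugging OOM issues).
    """
    if force_cpu or not cuda_available:
        return "cpu"
    return "cuda"

# With torch available:
#   device = torch.device(pick_device(torch.cuda.is_available()))
chosen = pick_device(cuda_available=False)
```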

Top contributors


📝Recent commits

  • 21d7850 — Merge pull request #475 from IshanG97/download-data-bat-script (milesial)
  • de41eaa — Merge pull request #470 from yelboudouri/master (milesial)
  • c478dc9 — Fix issue #474: Windows .bat setup script (IshanG97)
  • 52b4f14 — Unscale gradients before clipping (yelboudouri)
  • 2f62e6b — Update docs to PyT 1.13 (milesial)
  • c10f0d1 — Fix typo when checking nans (milesial)
  • c04a07c — Check nans before logging to wandb (milesial)
  • 40d5ba7 — Merge pull request #421 from vitorrussi/patch-1 (milesial)
  • d4c389f — Add n_classes argument for predict.py (vitorrussi)
  • 6b7f354 — Switch to base dataset on IndexError (milesial)

🔒Security observations

The PyTorch-UNet repository has a moderate security posture. The primary concerns are outdated dependencies with potential known vulnerabilities, a Docker image without security hardening (it runs as root), and the use of pre-release pip flags. As a research/ML implementation it shows no apparent injection risks or hardcoded credentials, and no SQL-injection, XSS, or authentication-bypass issues were identified in the visible code structure, but dependency management should be improved. Immediate actions: update dependencies and apply Docker security best practices.

  • High · Outdated Dependencies with Known Vulnerabilities — requirements.txt. The project uses outdated package versions that may contain known security vulnerabilities. Specifically: Pillow 9.3.0 (current: 10.x+), numpy 1.23.5 (current: 1.26+), and matplotlib 3.6.2 (current: 3.8+). These older versions may have unpatched CVEs. Fix: Update all dependencies to their latest stable versions. Run 'pip install --upgrade' for each package and test compatibility. Consider using 'pip-audit' to identify known vulnerabilities in the dependency tree.
  • Medium · Missing Security Headers in Docker Configuration — Dockerfile. The Dockerfile uses a base image from nvcr.io but does not implement security best practices such as running as a non-root user, using multi-stage builds, or applying minimal layer principles. Fix: Add 'RUN useradd -m -u 1000 unet' and 'USER unet' to run the container as a non-root user. Minimize the number of layers and consider using a minimal base image variant.
  • Medium · Insecure Pip Installation Without Verification — Dockerfile. The Dockerfile installs packages with '--no-cache-dir --upgrade --pre' flags. The '--pre' flag installs pre-release versions which may be unstable and untested, increasing risk of bugs and security issues. Fix: Remove the '--pre' flag from pip install commands. Use stable release versions only. Pin specific versions in requirements.txt to ensure reproducible builds.
  • Low · Unrestricted Workspace Cleanup — Dockerfile. The Dockerfile executes 'rm -rf /workspace/*' without validation, which could be problematic in certain deployment scenarios or if the image is extended. Fix: Ensure this command only removes intended temporary files. Consider using a more explicit cleanup approach that targets only specific directories.
  • Low · Missing .gitignore Validation — .gitignore. While a .gitignore file exists, its content is not visible. It may not properly exclude sensitive files like configuration files, API keys, or model weights if they are checked in. Fix: Ensure .gitignore includes common sensitive patterns: *.env, *.key, *.pem, *.pkl, model_weights/, .aws/, etc. Audit git history for accidentally committed sensitive data using 'git log' or tools like 'truffleHog'.

LLM-derived; treat as a starting point, not a security audit.


Generated by RepoPilot. Verdict based on maintenance signals — see the live page for receipts. Re-run on a new commit to refresh.
