milesial/Pytorch-UNet
PyTorch implementation of the U-Net for image semantic segmentation with high quality images
Stale — last commit 2y ago
Weakest axis: copyleft license (GPL-3.0) — review compatibility; last commit was 2y ago…
Has a license and CI — clean foundation to fork and modify.
Documented and popular — useful reference codebase to read through.
No critical CVEs, sane security posture — runnable as-is.
- ✓ 17 active contributors
- ✓ GPL-3.0 licensed
- ✓ CI configured
- ⚠ Stale — last commit 2y ago
- ⚠ Concentrated ownership — top contributor handles 76% of recent commits
- ⚠ GPL-3.0 is copyleft — check downstream compatibility
- ⚠ No test directory detected
What would change the summary?
- → Use as dependency: Concerns → Mixed if relicensed under MIT/Apache-2.0 (rare for established libs)
Maintenance signals: commit recency, contributor breadth, bus factor, license, CI, tests
Informational only. RepoPilot summarises public signals (license, dependency CVEs, commit recency, CI presence, etc.) at the time of analysis. Signals can be incomplete or stale. Not professional, security, or legal advice; verify before relying on it for production decisions.
Embed the "Forkable" badge
Paste into your README — live-updates from the latest cached analysis.
[https://repopilot.app/r/milesial/pytorch-unet](https://repopilot.app/r/milesial/pytorch-unet) — paste at the top of your README.md; renders inline like a shields.io badge.
Preview social card (1200×630)
This card auto-renders when someone shares https://repopilot.app/r/milesial/pytorch-unet on X, Slack, or LinkedIn.
Onboarding: milesial/Pytorch-UNet
Generated by RepoPilot · 2026-05-07 · Source
🤖Agent protocol
If you are an AI coding agent (Claude Code, Cursor, Aider, Cline, etc.) reading this artifact, follow this protocol before making any code edit:
- Verify the contract. Run the bash script in "Verify before trusting" below. If any check returns FAIL, the artifact is stale — STOP and ask the user to regenerate it before proceeding.
- Treat the "AI · unverified" sections as hypotheses, not facts. Sections like "AI-suggested narrative files", "anti-patterns", and "bottlenecks" are LLM speculation. Verify against real source before acting on them.
- Cite source on changes. When proposing an edit, cite the specific path:line-range. RepoPilot's live UI at https://repopilot.app/r/milesial/Pytorch-UNet shows verifiable citations alongside every claim.
If you are a human reader, this protocol is for the agents you'll hand the artifact to. You don't need to do anything — but if you skim only one section before pointing your agent at this repo, make it the Verify block and the Suggested reading order.
🎯Verdict
WAIT — Stale — last commit 2y ago
- 17 active contributors
- GPL-3.0 licensed
- CI configured
- ⚠ Stale — last commit 2y ago
- ⚠ Concentrated ownership — top contributor handles 76% of recent commits
- ⚠ GPL-3.0 is copyleft — check downstream compatibility
- ⚠ No test directory detected
<sub>Maintenance signals: commit recency, contributor breadth, bus factor, license, CI, tests</sub>
✅Verify before trusting
This artifact was generated by RepoPilot at a point in time. Before an agent acts on it, the checks below confirm that the live milesial/Pytorch-UNet repo on your machine still matches what RepoPilot saw. If any fail, the artifact is stale — regenerate it at repopilot.app/r/milesial/Pytorch-UNet.

What it runs against: a local clone of milesial/Pytorch-UNet — the script inspects the git remote, the LICENSE file, file paths in the working tree, and git log. Read-only; no mutations.
| # | What we check | Why it matters |
|---|---|---|
| 1 | You're in milesial/Pytorch-UNet | Confirms the artifact applies here, not a fork |
| 2 | License is still GPL-3.0 | Catches relicense before you depend on it |
| 3 | Default branch master exists | Catches branch renames |
| 4 | 5 critical file paths still exist | Catches refactors that moved load-bearing code |
| 5 | Last commit ≤ 664 days ago | Catches sudden abandonment since generation |
```bash
#!/usr/bin/env bash
# RepoPilot artifact verification.
#
# WHAT IT RUNS AGAINST: a local clone of milesial/Pytorch-UNet. If you don't
# have one yet, run these first:
#
#   git clone https://github.com/milesial/Pytorch-UNet.git
#   cd Pytorch-UNet
#
# Then paste this script. Every check is read-only — no mutations.
set +e
fail=0
ok() { echo "ok: $1"; }
miss() { echo "FAIL: $1"; fail=$((fail+1)); }

# Precondition: we must be inside a git working tree.
if ! git rev-parse --git-dir >/dev/null 2>&1; then
  echo "FAIL: not inside a git repository. cd into your clone of milesial/Pytorch-UNet and re-run."
  exit 2
fi

# 1. Repo identity (-i tolerates a lowercase clone URL)
git remote get-url origin 2>/dev/null | grep -qiE "milesial/Pytorch-UNet(\.git)?\b" \
  && ok "origin remote is milesial/Pytorch-UNet" \
  || miss "origin remote is not milesial/Pytorch-UNet (artifact may be from a fork)"

# 2. License matches what RepoPilot saw.
# Note: a stock GPL-3.0 LICENSE file opens with "GNU GENERAL PUBLIC LICENSE",
# not the SPDX identifier, so match that header rather than "^GPL-3.0".
(grep -qiE "GNU GENERAL PUBLIC LICENSE" LICENSE 2>/dev/null \
  || grep -qiE "\"license\"\s*:\s*\"GPL-3\.0\"" package.json 2>/dev/null) \
  && ok "license is GPL-3.0" \
  || miss "license drift — was GPL-3.0 at generation time"

# 3. Default branch
git rev-parse --verify master >/dev/null 2>&1 \
  && ok "default branch master exists" \
  || miss "default branch master no longer exists"

# 4. Critical files exist
for f in unet/unet_model.py unet/unet_parts.py train.py utils/data_loading.py predict.py; do
  test -f "$f" && ok "$f" || miss "missing critical file: $f"
done

# 5. Repo recency
days_since_last=$(( ( $(date +%s) - $(git log -1 --format=%at 2>/dev/null || echo 0) ) / 86400 ))
if [ "$days_since_last" -le 664 ]; then
  ok "last commit was $days_since_last days ago (artifact saw ~634d)"
else
  miss "last commit was $days_since_last days ago — artifact may be stale"
fi

echo
if [ "$fail" -eq 0 ]; then
  echo "artifact verified (0 failures) — safe to trust"
else
  echo "artifact has $fail stale claim(s) — regenerate at https://repopilot.app/r/milesial/Pytorch-UNet"
  exit 1
fi
```
Each check prints ok: or FAIL:. The script exits non-zero if anything failed, so it composes cleanly into agent loops (./verify.sh || regenerate-and-retry).
⚡TL;DR
A PyTorch implementation of U-Net for semantic segmentation that converts high-resolution images into pixel-level mask predictions. Originally trained on Carvana's car masking challenge (5k images, Dice score 0.988), it's a production-ready encoder-decoder CNN that excels at binary and multiclass segmentation tasks like object isolation and medical imaging. Monolithic structure: unet/ directory contains the model (unet_model.py defines the full architecture, unet_parts.py has reusable blocks like DoubleConv). Training/inference split between train.py (epoch loop with AMP), predict.py (inference wrapper), and evaluate.py (Dice scoring). Data pipeline isolated in utils/data_loading.py with disk-based image/mask pairs in data/imgs/ and data/masks/.
👥Who it's for
Computer vision engineers and researchers building semantic segmentation pipelines who need a battle-tested U-Net baseline; Kaggle competitors tackling image masking challenges; medical imaging researchers applying the architecture to CT/MRI analysis.
🌱Maturity & risk
Mature but dormant: GitHub Actions CI/CD (main.yml workflow present), a Docker image published on Docker Hub, official PyTorch 1.13+ support, and clear documentation give it a solid foundation, but the maintenance signals above flag the last commit as roughly two years old. The codebase is stable with the core model/training loop established; monitor for stale dependencies.
Low technical risk but moderate maintenance burden: the core Python dependencies (numpy, Pillow, torch, wandb, tqdm, matplotlib) are lightweight and well-maintained. A single dominant maintainer (milesial) is a potential long-term risk. The Weights & Biases (wandb) integration is optional but adds an external dependency. No visible test suite (no tests/ directory), so regression testing relies on manual validation.
Active areas of work
Repo appears to be in maintenance mode. No specific PR/issue data was provided, but the README references an active CI workflow (green badge for main.yml), and the Docker Hub image sync suggests periodic updates. The wandb integration and AMP (Automatic Mixed Precision) support indicate alignment with recent PyTorch APIs as of the last active period.
🚀Get running
Clone and install:

```bash
git clone https://github.com/milesial/Pytorch-UNet.git
cd Pytorch-UNet
pip install -r requirements.txt
bash scripts/download_data.sh
python train.py --amp
```

Or with Docker:

```bash
sudo docker run --rm --gpus all -it milesial/unet
```

Daily commands:

Training:

```bash
python train.py --amp --epochs 5 --batch-size 1 --learning-rate 0.0001 --scale 0.5
```

Prediction on a new image:

```bash
python predict.py -i test_image.jpg -o output_mask.png -m checkpoints/checkpoint_epoch5.pth
```

Evaluation:

```bash
python evaluate.py --model checkpoints/model.pth --input data/imgs --output masks_predicted
```
🗺️Map of the codebase
- unet/unet_model.py — Core U-Net architecture definition; all semantic segmentation models instantiate from this file, making it essential for understanding the neural network structure.
- unet/unet_parts.py — Defines reusable building blocks (DoubleConv, Down, Up) for U-Net; modifications here cascade across all model variants.
- train.py — Main training loop entry point; orchestrates data loading, model training, validation, and checkpoint management — required reading for training workflows.
- utils/data_loading.py — Custom dataset and dataloader implementation for image-mask pairs; critical for understanding the data pipeline and augmentation strategy.
- predict.py — Inference entry point that loads trained models and generates segmentation predictions; key interface for end-to-end workflows.
- requirements.txt — Pins exact dependency versions including PyTorch and data-handling libraries; mismatches here cause silent failures in image loading and tensor operations.
- evaluate.py — Model evaluation metrics computation; essential for understanding how the Dice score and other performance metrics are calculated.
🧩Components & responsibilities
- UNet (unet/unet_model.py) (PyTorch nn.Module, Conv2d, MaxPool2d, ReLU) — Encodes input image to low-res feature map, then decodes with skip connections to generate per-pixel segmentation logits
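A minimal smoke test of that encode/decode flow, as a sketch: it assumes the package re-exports UNet and that the constructor takes the n_channels/n_classes/bilinear parameters this doc references elsewhere; verify both against unet/unet_model.py before relying on it.

```python
# Sketch: instantiate the model and check the per-pixel logits shape.
# Assumes `from unet import UNet` works and the constructor signature is
# (n_channels, n_classes, bilinear); confirm in unet/unet_model.py.
import torch
from unet import UNet

net = UNet(n_channels=3, n_classes=2, bilinear=False)
x = torch.randn(1, 3, 256, 256)   # one RGB image; side divisible by 16
logits = net(x)                   # per-pixel logits: (batch, n_classes, H, W)
print(logits.shape)               # expected: torch.Size([1, 2, 256, 256])
```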
🛠️How to make changes
Add a Custom U-Net Variant (Different Depth/Channels)
- Create a new class inheriting from UNet in unet/unet_model.py, overriding __init__ with custom n_channels, n_classes, and bilinear interpolation settings (unet/unet_model.py)
- Instantiate the custom model in train.py (where the model is created) or add a --model_variant argument to argparse (train.py)
- Test with a forward pass and verify the output shape matches (batch_size, n_classes, H, W); see the sketch after this list (train.py)
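A hedged sketch of steps 1 and 3: CarvanaUNet is a hypothetical name, and a genuinely different depth or channel layout would mean rebuilding the Down/Up stack from unet/unet_parts.py rather than just re-parameterizing the constructor.

```python
# Hypothetical preset variant built on the stock UNet. A different
# depth/width would need its own Down/Up stack from unet/unet_parts.py.
import torch
from unet import UNet

class CarvanaUNet(UNet):
    """U-Net preset: 3-channel RGB input, binary (2-class) output."""
    def __init__(self, bilinear: bool = True):
        super().__init__(n_channels=3, n_classes=2, bilinear=bilinear)

model = CarvanaUNet()
out = model(torch.randn(1, 3, 128, 128))
assert out.shape == (1, 2, 128, 128)   # step 3: verify the output shape
```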
Add Image Preprocessing or Augmentation
- Open utils/data_loading.py and extend the BasicDataset class's __getitem__ method to apply additional transforms (rotation, flip, brightness) (utils/data_loading.py)
- Import torchvision.transforms or albumentations at the top and compose your augmentation pipeline before returning the img and mask tensors (utils/data_loading.py)
- Run train.py with a small batch to verify augmented images load without errors and shapes remain consistent; a joint-flip sketch follows this list (train.py)
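A sketch of the first two steps. It assumes BasicDataset.__getitem__ returns a dict with 'image' and 'mask' tensors; check the actual return type in utils/data_loading.py before copying.

```python
# Joint augmentation sketch: flip image and mask together so labels
# stay pixel-aligned. The return-dict keys are assumptions; verify them
# in utils/data_loading.py.
import random
import torch
from utils.data_loading import BasicDataset

class AugmentedDataset(BasicDataset):
    def __getitem__(self, idx):
        sample = super().__getitem__(idx)
        img, mask = sample['image'], sample['mask']
        if random.random() < 0.5:              # 50% horizontal flip
            img = torch.flip(img, dims=[-1])
            mask = torch.flip(mask, dims=[-1])
        return {'image': img, 'mask': mask}
```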
Add a New Evaluation Metric
- Create a new metric function in utils/dice_score.py (e.g., iou_score, precision_recall) following the same signature as multiclass_dice_coeff (utils/dice_score.py)
- Import and call the new metric in evaluate.py alongside the existing Dice evaluation in the validation loop (evaluate.py)
- Log the metric to wandb by adding it to the wandb.log() call in train.py if using the Weights & Biases integration; an IoU sketch follows this list (train.py)
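One possible shape for step 1. This is a plain binary-mask IoU, not a drop-in clone of multiclass_dice_coeff; mirror the real signature in utils/dice_score.py when porting it over.

```python
# Binary IoU sketch over {0,1} masks of shape (N, H, W); eps avoids 0/0.
import torch

def iou_score(pred: torch.Tensor, target: torch.Tensor,
              eps: float = 1e-6) -> torch.Tensor:
    pred, target = pred.float(), target.float()
    inter = (pred * target).sum(dim=(-1, -2))
    union = pred.sum(dim=(-1, -2)) + target.sum(dim=(-1, -2)) - inter
    return ((inter + eps) / (union + eps)).mean()

# Sanity check: identical masks score 1.0.
m = (torch.rand(2, 16, 16) > 0.5).long()
assert torch.isclose(iou_score(m, m), torch.tensor(1.0))
```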
Deploy Model to Docker for Production Inference
- Update the Dockerfile COPY and CMD directives to include your trained checkpoint, and modify predict.py to load the specific model path (Dockerfile)
- Build the Docker image (docker build -t unet:latest .) and test locally with docker run -v /path/to/images:/data unet:latest (Dockerfile)
- Verify predict.py outputs segmentation masks correctly when invoked inside the container (predict.py)
🔧Why these technologies
- PyTorch 1.13+ — Modern deep learning framework with autograd, GPU acceleration, and strong community support for semantic segmentation tasks
- Pillow for image I/O — Efficient, non-blocking image loading and resizing for high-resolution image datasets typical of Carvana challenge
- NumPy for tensor operations — Vectorized operations for mask preprocessing and metric calculation before PyTorch conversion
- Weights & Biases (wandb) for experiment tracking — Optional integration for logging hyperparameters, metrics, and model checkpoints across training runs for reproducibility
- Docker for containerization — Ensures reproducible inference environment across different hardware/OS platforms for production deployment
⚖️Trade-offs already made
- Bilinear upsampling vs. transposed convolution in decoder
  - Why: Bilinear is configurable (set via bilinear=True/False in UNet) to avoid checkerboard artifacts in some use cases
  - Consequence: Bilinear adds slight interpolation error but provides training stability; transposed conv is more learnable but can introduce artifacts (shape comparison in the sketch after this list)
- Single-file data loading (BasicDataset) without prefetching/caching
  - Why: Simplicity and memory efficiency for datasets that fit on disk; no overhead of maintaining an in-memory cache
  - Consequence: I/O can become the bottleneck for very large datasets or slow storage; users must implement custom caching if needed
- BCEWithLogitsLoss for binary/multiclass segmentation
  - Why: Numerically stable and handles class imbalance well for real-world masking tasks
  - Consequence: Assumes binary or independent class predictions; not suitable for mutually exclusive multi-class segmentation without modification
- Checkpoint saving only on validation improvement (if Dice increases)
  - Why: Prevents storage bloat and keeps automatic recovery of the best model
  - Consequence: Early stopping requires external logic; only the best checkpoint persists, losing training history
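To make the first trade-off concrete, here is a shape comparison of the two upsampling choices. These are generic PyTorch modules with illustrative channel counts, not the repo's exact Up block wiring; see unet/unet_parts.py for that.

```python
# The two upsampling options behind the `bilinear` flag, compared on shape.
import torch
import torch.nn as nn

x = torch.randn(1, 128, 32, 32)
bilinear = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
learned = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)

print(bilinear(x).shape)  # [1, 128, 64, 64]: no parameters, channels kept
print(learned(x).shape)   # [1, 64, 64, 64]: learned filters, channels halved
```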
🚫Non-goals (don't propose these)
- Real-time inference (no optimization for latency)
- Multi-GPU distributed training (single-GPU only)
- Data augmentation library (relies on external transforms; minimal built-in augmentation)
- 3D volumetric data (2D image segmentation only)
- Instance segmentation (semantic segmentation only, no per-object ID tracking)
- Post-processing or CRF refinement (raw model predictions only)
🪤Traps & gotchas
- AMP stability: train.py uses torch.cuda.amp, which requires careful loss scaling; it runs on CUDA devices only (CPU fallback not tested).
- Data paths: scripts/download_data.sh populates data/imgs/ and data/masks/ with the exact directory structure data_loading.py assumes; manual data placement must match this layout.
- Checkpoint format: train.py saves state_dict(), but predict.py expects specific checkpoint keys; model versioning across commits can break inference.
- wandb integration: optional but enabled by default in train.py; requires an API key or the --no-log flag to disable.
- Input normalization: data_loading.py applies fixed normalization (assumes ImageNet stats); custom datasets may need retraining or normalization adjustment.
🏗️Architecture
💡Concepts to learn
- U-Net Architecture (Encoder-Decoder with Skip Connections) — The entire repo is built around this design; understanding how DoubleConv blocks, MaxPool2d down-sampling, ConvTranspose2d up-sampling, and skip concatenation work is essential to modifying the model.
- Dice Coefficient (Sørensen–Dice Index) — The primary evaluation metric in utils/dice_score.py; the model was tuned to maximize this (0.988 on test set), so understanding how it penalizes false positives/negatives is critical for interpreting results.
- Automatic Mixed Precision (AMP) — train.py uses torch.cuda.amp context managers and GradScaler for a ~3x training speedup; misunderstanding loss scaling can cause NaN divergence (see the sketch after this list).
- Skip Connections (Feature Concatenation) — unet_model.py concatenates encoder features onto decoder layers to preserve spatial detail; this is why U-Net excels at dense prediction vs. a naive encoder-decoder.
- Transpose Convolution (Deconvolution) — unet_parts.py uses ConvTranspose2d for up-sampling; understanding stride, padding, and output_padding is essential for controlling spatial resolution in the decoder.
- Gradient Accumulation — train.py supports --batch-size 1 on memory-constrained hardware; batch normalization and loss computation depend on effective batch size, which can be tuned via accumulation steps.
- Cross-Entropy Loss vs. Dice Loss — train.py combines both losses for balanced optimization; understanding their different penalties (pixel-wise class imbalance vs. region overlap) helps debug convergence issues.
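A self-contained AMP training step in the pattern train.py follows, as a sketch: the model, optimizer, and data are placeholders, not the repo's. The unscale-before-clip ordering matches the "Unscale gradients before clipping" commit listed under Recent commits.

```python
# Minimal AMP step; model/optimizer/data are stand-ins, not the repo's.
# Runs on CPU too (AMP is simply disabled there).
import torch
from torch.cuda.amp import GradScaler, autocast

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.nn.Conv2d(3, 2, kernel_size=3, padding=1).to(device)  # stand-in for UNet
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()
scaler = GradScaler(enabled=(device.type == 'cuda'))

images = torch.randn(2, 3, 64, 64, device=device)        # dummy batch
masks = torch.randint(0, 2, (2, 64, 64), device=device)  # dummy labels

optimizer.zero_grad(set_to_none=True)
with autocast(enabled=(device.type == 'cuda')):           # fp16 forward + loss
    loss = criterion(model(images), masks)
scaler.scale(loss).backward()                             # scaled backward
scaler.unscale_(optimizer)                                # unscale before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)                                    # skipped on inf/NaN grads
scaler.update()
```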
🔗Related repos
- facebookresearch/detectron2 — Production-grade PyTorch vision library with U-Net and other segmentation backbones; if you need multi-task detection + segmentation, use this.
- qubvel/segmentation_models.pytorch — Comprehensive PyTorch segmentation model zoo (U-Net, DeepLab, etc.) with pretrained encoders; this repo is lighter-weight for custom training on Carvana-style tasks.
- miccaia/Pytorch-UNet — Alternative U-Net fork with 3D support and multiclass extensions; if you need volumetric medical segmentation, compare this one.
- pytorch/pytorch — Core PyTorch framework; required dependency and source for torch.cuda.amp, F.interpolate, and nn modules used throughout.
- wandb/wandb — Experiment tracking platform integrated into train.py for logging metrics and model checkpoints; essential for reproducing the reported 0.988 Dice score.
🪄PR ideas
To work on one of these in Claude Code or Cursor, paste:
Implement the "<title>" PR idea from CLAUDE.md, working through the checklist as the task list.
Add unit tests for utils/ modules with pytest
The repo has no test directory despite containing critical utility functions (data_loading.py, dice_score.py, utils.py). These functions handle data loading, metric computation, and preprocessing — core functionality that should be tested. Adding a tests/ directory with pytest would catch regressions and help new contributors understand expected behavior. A starter test sketch follows the checklist.
- [ ] Create tests/ directory with __init__.py
- [ ] Add tests/test_dice_score.py with tests for dice computation edge cases (empty masks, perfect overlap, no overlap)
- [ ] Add tests/test_data_loading.py with tests for BasicDataset class initialization and augmentation logic in utils/data_loading.py
- [ ] Add tests/test_utils.py for any utility functions in utils/utils.py
- [ ] Update requirements.txt to include pytest and pytest-cov
- [ ] Add pytest configuration to .github/workflows/main.yml if not already present
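A starter for the dice edge-case tests. The exact signature and reduction behavior of dice_coeff should be confirmed in utils/dice_score.py before committing anything like this.

```python
# tests/test_dice_score.py (sketch). Assumes dice_coeff(input, target)
# accepts (N, H, W) float tensors and returns a 0-dim tensor; verify the
# real signature in utils/dice_score.py first.
import pytest
import torch
from utils.dice_score import dice_coeff

def test_perfect_overlap_scores_one():
    m = torch.ones(1, 8, 8)
    assert float(dice_coeff(m, m)) == pytest.approx(1.0, abs=1e-4)

def test_disjoint_masks_score_zero():
    a = torch.zeros(1, 8, 8)
    b = torch.ones(1, 8, 8)
    assert float(dice_coeff(a, b)) == pytest.approx(0.0, abs=1e-3)
```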
Add type hints to unet/ and utils/ modules
The codebase has no type hints in core modules (unet_model.py, unet_parts.py, data_loading.py, utils.py). This makes it harder for contributors to understand function signatures and limits IDE support. Adding comprehensive type hints would improve code quality and maintainability with minimal refactoring; an illustrative example follows the checklist.
- [ ] Add type hints to unet/unet_parts.py for all Conv classes and DoubleConv forward() methods
- [ ] Add type hints to unet/unet_model.py for the UNet class __init__() and forward() methods
- [ ] Add type hints to utils/data_loading.py for BasicDataset.__init__(), __getitem__(), and __len__()
- [ ] Add type hints to utils/dice_score.py for dice_coeff() and multiclass_dice_coeff() functions
- [ ] Add typing imports and ensure Python 3.6+ compatibility (Union, List, Tuple, Optional)
- [ ] Update README to mention type hint coverage
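An illustrative target style on a DoubleConv-shaped block. The name mirrors the repo's, but the body here is a generic reimplementation for demonstration, not copied source; the real DoubleConv in unet/unet_parts.py may differ.

```python
# Illustrative type-hinted module in the style of unet/unet_parts.py.
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    """(Conv2d => BatchNorm => ReLU) twice, with hints on the API surface."""
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)
```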
Add integration tests for train.py and predict.py workflows
The repo has training (train.py) and prediction (predict.py) scripts but no integration tests to verify end-to-end workflows work correctly. This is critical since these are the user-facing entry points. Adding minimal integration tests would catch breaking changes and help new contributors understand the expected data flow.
- [ ] Create tests/integration/ directory
- [ ] Add tests/integration/test_training_workflow.py that: trains a small model on dummy data (from data/imgs, data/masks), saves checkpoint, and verifies loss decreases
- [ ] Add tests/integration/test_prediction_workflow.py that: loads a trained model from hubconf.py, runs prediction on a test image, and verifies output shape matches input
- [ ] Create a small dummy dataset fixture in tests/conftest.py (4-5 small sample images/masks) for reproducible testing
- [ ] Add integration test step to .github/workflows/main.yml
- [ ] Document in README how to run integration tests for contributors
🌿Good first issues
- Add unit tests for unet_parts.py (DoubleConv, Down, Up blocks) with mock tensors to verify forward-pass shapes and parameter counts; currently no tests/ directory exists. Why it helps: ensures model components don't regress and helps new contributors understand tensor dimensions.
- Document hyperparameter sensitivity in the README with specific examples (e.g., "learning_rate=0.0001 achieves 0.988 Dice, learning_rate=0.001 diverges"); train.py has many flags but no guidance. Why it helps: reduces onboarding friction and prevents common training mistakes.
- Add a CPU fallback in predict.py and train.py (e.g., device='cuda' if torch.cuda.is_available() else 'cpu') and test on CPU; currently the code assumes GPU-only. Why it helps: enables local testing and inference on machines without CUDA (a minimal sketch follows this list).
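The device-fallback pattern from the last item, sketched; the checkpoint path is a placeholder.

```python
# CPU fallback sketch for predict.py/train.py. map_location keeps
# CUDA-saved weights loadable on a CPU-only machine.
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
state = torch.load('checkpoints/checkpoint_epoch5.pth', map_location=device)
```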
⭐Top contributors
- @milesial — 76 commits
- @Gouvernathor — 6 commits
- @qslia — 2 commits
- @Arka161 — 2 commits
- @laclouis5 — 2 commits
📝Recent commits
- 21d7850 — Merge pull request #475 from IshanG97/download-data-bat-script (milesial)
- de41eaa — Merge pull request #470 from yelboudouri/master (milesial)
- c478dc9 — Fix issue #474: Windows .bat setup script (IshanG97)
- 52b4f14 — Unscale gradients before clipping (yelboudouri)
- 2f62e6b — Update docs to PyT 1.13 (milesial)
- c10f0d1 — Fix typo when checking nans (milesial)
- c04a07c — Check nans before logging to wandb (milesial)
- 40d5ba7 — Merge pull request #421 from vitorrussi/patch-1 (milesial)
- d4c389f — Add n_classes argument for predict.py (vitorrussi)
- 6b7f354 — Switch to base dataset on IndexError (milesial)
🔒Security observations
The PyTorch-UNet repository has a moderate security posture. The primary concerns are outdated dependencies with potential known vulnerabilities, lack of Docker security hardening (non-root user execution), and the use of pre-release pip flags. The project is a research/ML implementation without apparent injection risks or hardcoded credentials, but dependency management should be improved. No critical SQL injection, XSS, or authentication bypass issues were identified in the visible code structure. Immediate actions should focus on updating dependencies and applying Docker security best practices.
- High · Outdated dependencies with known vulnerabilities — requirements.txt. The project pins outdated package versions that may contain known security vulnerabilities: Pillow 9.3.0 (current: 10.x+), numpy 1.23.5 (current: 1.26+), and matplotlib 3.6.2 (current: 3.8+). These older versions may have unpatched CVEs. Fix: update all dependencies to their latest stable versions, run 'pip install --upgrade' per package, and test compatibility. Consider 'pip-audit' to identify known vulnerabilities in the dependency tree.
- Medium · Missing security hardening in the Docker configuration — Dockerfile. The Dockerfile uses a base image from nvcr.io but does not follow best practices such as running as a non-root user, using multi-stage builds, or minimizing layers. Fix: add 'RUN useradd -m -u 1000 unet' and 'USER unet' to run the container as a non-root user; minimize layers and consider a minimal base-image variant.
- Medium · Pip installation with pre-release flag — Dockerfile. The Dockerfile installs packages with '--no-cache-dir --upgrade --pre'. The '--pre' flag pulls pre-release versions, which may be unstable and untested, increasing the risk of bugs and security issues. Fix: remove '--pre' and use stable releases only; pin specific versions in requirements.txt for reproducible builds.
- Low · Unrestricted workspace cleanup — Dockerfile. The Dockerfile executes 'rm -rf /workspace/*' without validation, which could be problematic in certain deployment scenarios or if the image is extended. Fix: ensure the command removes only intended temporary files, or target specific directories explicitly.
- Low · Unverified .gitignore coverage — .gitignore. A .gitignore file exists, but its content was not inspected; it may not exclude sensitive files such as configuration files, API keys, or model weights. Fix: include common sensitive patterns (*.env, *.key, *.pem, *.pkl, model_weights/, .aws/) and audit git history for accidentally committed secrets using 'git log' or tools like truffleHog.
LLM-derived; treat as a starting point, not a security audit.
👉Where to read next
- Open issues — current backlog
- Recent PRs — what's actively shipping
- Source on GitHub
Generated by RepoPilot. Verdict based on maintenance signals — see the live page for receipts. Re-run on a new commit to refresh.