CheXpert MAE-DenseNet-FPN
A deep learning framework for multi-label chest X-ray classification using a hybrid architecture combining Masked Autoencoders (MAE), DenseNet with CBAM attention, and Feature Pyramid Networks (FPN) with bidirectional cross-attention fusion.
Architecture Overview
This project implements a hybrid fusion architecture for medical image classification, combining a transformer branch and a convolutional branch over the same input image:
- MAE Encoder: Vision Transformer-based masked autoencoder for self-supervised feature extraction
- DenseNet-169: Dense convolutional network with Channel and Spatial Attention (CBAM)
- Feature Pyramid Network: Multi-scale feature extraction at 4 different resolutions
- Bidirectional Cross-Attention: Fusion mechanism allowing MAE and DenseNet features to attend to each other
- Learned Logit Ensemble: Weighted combination of 7 prediction heads with learnable temperature scaling
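The learned logit ensemble can be pictured with a framework-free sketch. It assumes the heads are mixed with softmax-normalised weights and that a single temperature divides the combined logits before the sigmoid; the function name, the weighting scheme, and the single shared temperature are illustrative assumptions, not the repository's actual implementation.

```python
import math


def ensemble_probabilities(head_logits, head_weights, temperature):
    """Combine per-head logits into per-class probabilities (hypothetical helper).

    head_logits:  list of H lists, each holding C raw logits (one list per head).
    head_weights: H unnormalised scores; a softmax turns them into mixing weights.
    temperature:  positive scalar that divides the combined logits.
    """
    # Softmax over the head scores so the mixing weights are positive and sum to 1.
    peak = max(head_weights)
    exps = [math.exp(w - peak) for w in head_weights]
    total = sum(exps)
    mix = [e / total for e in exps]

    num_classes = len(head_logits[0])
    probs = []
    for c in range(num_classes):
        combined = sum(m * logits[c] for m, logits in zip(mix, head_logits))
        # Independent sigmoid per class, since the task is multi-label.
        probs.append(1.0 / (1.0 + math.exp(-combined / temperature)))
    return probs
```

In a trainable model the head scores and the temperature would be parameters updated by the optimizer; a temperature above 1 pulls the probabilities towards 0.5.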
Key Components
Input Image (384×384)
     │
     ├─────────────────────────────┐
     │                             │
     ▼                             ▼
MAE Encoder                  DenseNet-169
(ViT-based)                   (with CBAM)
     │                             │
     │         ┌───────────────────┤
     │         │                   │
     │    FPN Pyramid       Dense Features
     │      (P1-P4)          (Multi-scale)
     │         │                   │
     └─────────┼───────────────────┘
               │
 Bidirectional Cross-Attention
               │
     ┌─────────┴──────────┐
     │                    │
 MAE Head     Dense Head + 4 FPN Heads
     │                    │
     └─────────┬──────────┘
               │
  Learned Ensemble (7 heads)
               │
               ▼
     14-class Predictions
Features
- Hybrid Architecture: Combines transformer-based and convolutional approaches
- Multi-scale Learning: FPN extracts features at 4 different resolutions
- Advanced Fusion: Bidirectional cross-attention between MAE and DenseNet features
- Optimized Training:
  - Mixed precision training (FP16)
  - Gradient accumulation
  - Weighted sampling for class imbalance
  - Cosine annealing with linear warmup
  - Gradient checkpointing for memory efficiency
- Smart Data Loading:
  - ZIP file reader with LRU caching
  - On-the-fly augmentation using Albumentations
  - Multi-worker data loading with persistent workers
- Comprehensive Evaluation:
  - Per-class AUC metrics
  - Optimal threshold computation per class
  - Macro and Micro AUC tracking
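The per-class threshold computation listed above could look like the following sketch. It assumes the threshold for each class is chosen by maximising Youden's J statistic (true positive rate minus false positive rate) on validation scores; the README does not say which criterion the project uses, so both the criterion and the function names are assumptions.

```python
def best_threshold(y_true, y_score):
    """Return the score threshold that maximises TPR - FPR for one class."""
    positives = sum(y_true)
    negatives = len(y_true) - positives
    if positives == 0 or negatives == 0:
        return 0.5  # J is undefined when only one class is present; fall back to 0.5
    best_t, best_j = 0.5, -1.0
    for t in sorted(set(y_score)):
        tp = sum(1 for y, s in zip(y_true, y_score) if y == 1 and s >= t)
        fp = sum(1 for y, s in zip(y_true, y_score) if y == 0 and s >= t)
        j = tp / positives - fp / negatives
        if j > best_j:
            best_t, best_j = t, j
    return best_t


def per_class_thresholds(labels, scores):
    """labels and scores are lists of N rows with C columns; returns C thresholds."""
    num_classes = len(labels[0])
    return [
        best_threshold([row[c] for row in labels], [row[c] for row in scores])
        for c in range(num_classes)
    ]
```

The loop is quadratic in the number of samples, which is fine for illustration; a production version would sort the scores once and sweep.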
Requirements
- Python 3.8+
- CUDA-capable GPU (recommended: 16GB+ VRAM)
- CheXpert dataset
Installation
- Clone the repository
git clone https://github.com/adelelsayed/chexpert-mae-densenet-fpn.git
cd chexpert-mae-densenet-fpn
- Create a virtual environment
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
- Install dependencies
pip install -r requirements.txt
Dataset Setup
- Download the CheXpert dataset
  - Visit: https://stanfordmlgroup.github.io/competitions/chexpert/
  - Download CheXpert-v1.0-small
- Prepare the dataset
# Extract the dataset
unzip CheXpert-v1.0-small.zip
# Optionally, create a ZIP archive for faster loading
cd CheXpert-v1.0-small
zip -r chexpert.zip train/ valid/
- Update configuration
  - Edit configs/configs.py
  - Update the root variable to point to your dataset location
  - Update all paths accordingly
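To illustrate why the optional ZIP archive can speed up loading, here is a minimal sketch of a ZIP reader with an LRU cache, built only on the standard library. The class name `CachedZipReader` and the `max_items` parameter are hypothetical; the project's own reader in data/dataset.py may be structured differently.

```python
import zipfile
from collections import OrderedDict


class CachedZipReader:
    """Read members of a ZIP archive, keeping the most recently used bytes in memory."""

    def __init__(self, zip_path, max_items=1024):
        self.zip_path = zip_path
        self.max_items = max_items
        self._cache = OrderedDict()
        self._zip = None  # opened lazily, on the first read

    def read(self, member):
        if member in self._cache:
            self._cache.move_to_end(member)  # mark as most recently used
            return self._cache[member]
        if self._zip is None:
            self._zip = zipfile.ZipFile(self.zip_path, "r")
        data = self._zip.read(member)
        self._cache[member] = data
        if len(self._cache) > self.max_items:
            self._cache.popitem(last=False)  # evict the least recently used entry
        return data
```

A dataset class would call `read()` with the member path of an image and decode the returned bytes. Opening the archive lazily avoids sharing one file handle across processes when the reader is created before data-loading workers start.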
Configuration
Edit configs/configs.py to customize:
# Example: Update paths
root = "/path/to/your/data"
mae_config = {
"lr": 1e-4,
"num_epochs": 200,
"batch_size": 96,
# ... other parameters
}
config = {
"lr": 1e-4,
"num_epochs": 200,
"batch_size": 36,
# ... other parameters
}
Training
Phase 1: Pre-train MAE
python trainer/trainer.py
# When prompted, type: mae
The MAE pre-training learns robust feature representations through masked image reconstruction.
Phase 2: Train Classifier
python trainer/trainer.py
# When prompted, type: classifier
This loads the pre-trained MAE encoder and trains the full classification pipeline.
Training Configuration
MAE Training:
- Batch size: 96
- Mask ratio: 0.75 (masks 75% of patches)
- Reconstruction loss on masked patches
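The MAE settings above can be made concrete with a small sketch of random patch masking and a loss restricted to masked patches. It uses plain Python lists rather than tensors, and it assumes a mean-squared-error reconstruction loss; the helper names are hypothetical and the project's loss may differ in detail.

```python
import random


def random_masking(num_patches, mask_ratio=0.75, seed=None):
    """Split patch indices into (visible, masked) index lists for one image."""
    rng = random.Random(seed)
    order = list(range(num_patches))
    rng.shuffle(order)
    num_keep = int(num_patches * (1 - mask_ratio))
    visible = sorted(order[:num_keep])
    masked = sorted(order[num_keep:])
    return visible, masked


def masked_mse(pred, target, masked):
    """Mean squared error averaged over the masked patches only."""
    total = 0.0
    for i in masked:
        total += sum((p - t) ** 2 for p, t in zip(pred[i], target[i])) / len(pred[i])
    return total / len(masked)


# A 384x384 input with 16x16 patches gives (384 // 16) ** 2 = 576 patches.
num_patches = (384 // 16) ** 2
visible, masked = random_masking(num_patches, mask_ratio=0.75, seed=0)
```

With a mask ratio of 0.75 the encoder sees 144 of the 576 patches, and the remaining 432 are the ones the reconstruction loss is computed on.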
Classifier Training:
- Batch size: 36 with gradient accumulation (8 steps)
- Effective batch size: 288
- Asymmetric loss with class weights
- Per-class threshold optimization
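As a rough illustration of an asymmetric loss with class weights, the sketch below follows the widely used formulation for multi-label classification, in which negatives are focused more strongly than positives and very easy negatives are zeroed out by a probability margin. The values `gamma_pos=0.0`, `gamma_neg=4.0`, and `margin=0.05` are illustrative assumptions, not settings taken from this repository, and the way class weights enter the loss is also assumed.

```python
import math


def asymmetric_loss(probs, targets, class_weights,
                    gamma_pos=0.0, gamma_neg=4.0, margin=0.05, eps=1e-8):
    """Asymmetric loss for one sample; probs are sigmoid outputs, targets are 0 or 1."""
    total = 0.0
    for p, y, w in zip(probs, targets, class_weights):
        if y == 1:
            # Positive term: focal-style down-weighting controlled by gamma_pos.
            loss = -((1.0 - p) ** gamma_pos) * math.log(max(p, eps))
        else:
            # Negative term: shift the probability down by the margin (clipped at 0)
            # so very easy negatives contribute nothing, then focus with gamma_neg.
            p_m = max(p - margin, 0.0)
            loss = -(p_m ** gamma_neg) * math.log(max(1.0 - p_m, eps))
        total += w * loss
    return total
```

With `gamma_pos=0` the positive term reduces to ordinary cross-entropy, while a confident negative (probability below the margin) contributes exactly zero loss.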
Testing
from trainer.utils import Trainer
from configs.configs import config
# Initialize trainer
trainer = Trainer(config)
# Run evaluation on test set
macro_auc, micro_auc, per_class = trainer.test(
model_path="path/to/checkpoint.pth"
)
print(f"Macro AUC: {macro_auc:.4f}")
print(f"Micro AUC: {micro_auc:.4f}")
Project Structure
chexpert-mae-densenet-fpn/
├── configs/
│   ├── __init__.py
│   └── configs.py                          # Configuration parameters
├── data/
│   ├── __init__.py
│   ├── dataset.py                          # CheXpert dataset with ZIP caching
│   └── splitter.py                         # Data splitting utilities
├── loss/
│   ├── __init__.py
│   └── assymetric.py                       # Asymmetric loss for imbalanced data
├── models/
│   ├── __init__.py
│   ├── mae.py                              # Masked Autoencoder implementation
│   ├── densenet.py                         # DenseNet-169 with CBAM
│   └── classifier.py                       # Full classification architecture
├── trainer/
│   ├── __init__.py
│   ├── trainer.py                          # Main training script
│   ├── utils.py                            # Training utilities and loops
│   └── test.py                             # Testing utilities
├── notebooks/
│   ├── chexpert_mae.ipynb                  # MAE experiments
│   └── chexpert_mae_mask_classifier.ipynb  # Full pipeline experiments
├── requirements.txt
└── README.md
Model Architecture Details
MAE Encoder
- Patch size: 16×16
- Embedding dim: 768
- Depth: 12 transformer blocks
- Heads: 8 attention heads
- MLP ratio: 4×
DenseNet-169
- Growth rate (k): 64
- Layers: [6, 12, 24, 16]
- CBAM: Channel + Spatial attention at each stage
- Dropout: Progressive (0.05 → 0.1 → 0.1 → 0.1)
Cross-Attention Fusion
- 12 bidirectional cross-attention layers
- Projection dim: 512
- Attention heads: 8
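The idea behind one bidirectional cross-attention layer can be sketched without any framework. This simplified version is single-head and omits the learned query/key/value projections, layer normalisation, and feed-forward blocks that a real layer would have, so it is only an outline of the data flow under those assumptions, not the project's fusion module.

```python
import math


def softmax(row):
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attend(queries, context):
    """Single-head attention: each query token becomes a weighted mix of context tokens."""
    scale = 1.0 / math.sqrt(len(queries[0]))
    out = []
    for q in queries:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in context]
        weights = softmax(scores)
        mixed = [sum(w * tok[d] for w, tok in zip(weights, context))
                 for d in range(len(q))]
        out.append(mixed)
    return out


def bidirectional_cross_attention(mae_tokens, dense_tokens):
    """One fusion layer: each stream attends to the other, with residual connections."""
    mae_update = cross_attend(mae_tokens, dense_tokens)
    dense_update = cross_attend(dense_tokens, mae_tokens)
    mae_out = [[x + u for x, u in zip(tok, upd)]
               for tok, upd in zip(mae_tokens, mae_update)]
    dense_out = [[x + u for x, u in zip(tok, upd)]
                 for tok, upd in zip(dense_tokens, dense_update)]
    return mae_out, dense_out
```

Both updates are computed from the layer's inputs, so neither stream sees the other's updated tokens within the same layer. Token counts may differ between the two streams, but both must share the same feature dimension, which is what a common projection dimension provides.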
FPN
- Feature levels: P1 (192×192), P2 (96×96), P3 (48×48), P4 (24×24)
- Channel unification: 256 channels per level
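The listed resolutions correspond to strides of 2, 4, 8, and 16 relative to the 384×384 input. The sketch below checks that arithmetic and shows the standard FPN top-down step (2× nearest-neighbour upsampling plus an element-wise add with the lateral feature map) on small 2D lists. Whether this project merges levels in exactly this way is an assumption; the helper names are hypothetical.

```python
def fpn_level_sizes(input_size=384, strides=(2, 4, 8, 16)):
    """Map each pyramid level name to the side length of its square feature map."""
    return {f"P{i}": input_size // s for i, s in enumerate(strides, start=1)}


def top_down_merge(coarse, lateral):
    """Upsample `coarse` 2x by nearest neighbour and add it to `lateral` (2D lists)."""
    return [
        [lateral[r][c] + coarse[r // 2][c // 2] for c in range(len(lateral[0]))]
        for r in range(len(lateral))
    ]


sizes = fpn_level_sizes()
merged = top_down_merge([[1]], [[1, 2], [3, 4]])
```

Unifying every level to the same channel count is what makes the element-wise add in the merge step possible.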
CheXpert Labels
The model predicts the 14 CheXpert observations:
- No Finding
- Enlarged Cardiomediastinum
- Cardiomegaly
- Lung Opacity
- Lung Lesion
- Edema
- Consolidation
- Pneumonia
- Atelectasis
- Pneumothorax
- Pleural Effusion
- Pleural Other
- Fracture
- Support Devices
Data Augmentation
Training augmentations (conservative for medical images):
- Horizontal flip (p=0.5)
- Random affine (translation, scale, rotation ±10°)
- Random brightness/contrast
- CLAHE histogram equalization
- Gaussian blur and noise
Checkpoints
The training automatically saves:
- Best MAE checkpoint: Based on validation reconstruction loss
- Best classifier checkpoint: Based on validation AUC (macro/micro)
- Training history: JSON file with all metrics
- Per-epoch metrics plots: Loss and AUC curves
Monitoring
Training logs are saved to:
- training_log.txt: Training progress with live metrics
- val_log.txt: Validation results
- test_log.txt: Test evaluation results
- history.json: All metrics across epochs
- metrics.png: Visualization plots
Performance Tips
Memory Optimization:
- Use gradient checkpointing (already enabled)
- Reduce batch size if OOM occurs
- Increase gradient accumulation steps
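When reducing the batch size to avoid OOM errors, the accumulation steps can be raised to keep the effective batch size unchanged. The helper below is a hypothetical convenience function for that arithmetic, not part of the repository.

```python
import math


def accumulation_steps(target_effective_batch, per_step_batch):
    """Smallest number of accumulation steps that reaches the target effective batch."""
    return math.ceil(target_effective_batch / per_step_batch)


# Documented classifier setting: 36 samples per step x 8 steps = 288.
default_steps = accumulation_steps(288, 36)
# Halving the per-step batch to 18 requires twice as many steps.
steps_after_halving = accumulation_steps(288, 18)
```

If the target is not a multiple of the per-step batch, the result rounds up, so the effective batch size ends up slightly larger than the target.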
Speed Optimization:
- Use persistent workers (already enabled)
- Enable cuDNN benchmark (already enabled)
- Use ZIP caching for faster data loading
Training Stability:
- Gradient clipping at norm 1.0
- Mixed precision with dynamic loss scaling
- Warmup learning rate schedule
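A warmup schedule followed by cosine annealing can be written as a pure function of the step index. In this sketch `base_lr=1e-4` matches the learning rate in the configuration example, while `warmup_steps`, `total_steps`, and `min_lr` are assumed parameters whose actual values are not stated in this README.

```python
import math


def lr_at_step(step, total_steps, warmup_steps, base_lr=1e-4, min_lr=0.0):
    """Learning rate with linear warmup followed by cosine annealing."""
    if step < warmup_steps:
        # Linear ramp that reaches base_lr on the last warmup step.
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

The ramp avoids large updates while optimizer statistics and the mixed-precision loss scale are still settling, and the cosine phase then decays the rate smoothly towards `min_lr`.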
Troubleshooting
Q: Out of memory errors?
- Reduce batch size in configs.py
- Increase gradient accumulation steps
- Confirm gradient checkpointing is still enabled
Q: Slow training?
- Check if ZIP caching is enabled
- Verify persistent workers are active
- Monitor GPU utilization
Q: Poor convergence?
- Ensure MAE is properly pre-trained first
- Check learning rate and warmup settings
- Verify class weights are computed correctly
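One way to sanity-check class weights is to recompute them independently from the label matrix and compare. The sketch below assumes an inverse-frequency convention (negatives divided by positives, capped at `max_weight`); the project may use a different formula, so treat the function and its cap as illustrative assumptions.

```python
def positive_class_weights(labels, max_weight=10.0):
    """Per-class weight = negatives / positives, capped at max_weight.

    labels: list of N rows, each with C entries equal to 0 or 1.
    """
    num_samples = len(labels)
    weights = []
    for c in range(len(labels[0])):
        positives = sum(row[c] for row in labels)
        if positives == 0:
            weights.append(max_weight)  # class never appears; use the cap
        else:
            weights.append(min((num_samples - positives) / positives, max_weight))
    return weights
```

Weights that are all close to 1, or a rare class with a weight below 1, usually point to a label-parsing problem rather than a modelling one.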
Citation
If you use this code in your research, please cite:
@misc{chexpert-mae-densenet-fpn,
author = {Adel Elsayed},
title = {CheXpert Classification with MAE-DenseNet-FPN},
year = {2025},
publisher = {GitHub},
url = {https://github.com/adelelsayed/chexpert-mae-densenet-fpn}
}
Acknowledgments
- CheXpert Dataset: Stanford ML Group
- Masked Autoencoders: Meta AI Research (He et al., 2021)
- DenseNet: Huang et al., 2017
- CBAM: Woo et al., 2018
- Feature Pyramid Networks: Lin et al., 2017
License
This project is licensed under the MIT License.
Contact
https://www.linkedin.com/in/adel-elsayed-a5260246/
Note: This is a research project. For clinical use, please ensure proper validation and regulatory approval.
