Deep Learning for Tuberculosis Detection in Chest X-rays
Leveraging EfficientNet and advanced data augmentation techniques to achieve 95.8% accuracy in TB detection from chest radiographs.
Tuberculosis (TB) remains one of the leading infectious disease killers worldwide, with an estimated 10.6 million people falling ill in 2022 (WHO Global Tuberculosis Report 2023). Early and accurate detection is crucial for effective treatment and for preventing transmission. In this post, I'll share our recent work on a deep learning system for automated TB detection in chest X-rays, built on the EfficientNet architecture.
The Challenge
Chest X-ray interpretation for TB detection requires significant expertise and can be subjective, leading to inter-reader variability. In resource-limited settings where TB burden is highest, there's often a shortage of trained radiologists. Computer-aided detection (CAD) systems can help bridge this gap by providing consistent, objective analysis of chest radiographs.
Our Approach
Dataset and Preprocessing
We worked with a large dataset of chest X-rays from multiple international sources:
- Training set: 15,000 images (7,500 TB positive, 7,500 normal)
- Validation set: 3,000 images
- Test set: 2,000 images from different geographic regions
Key preprocessing steps included:
- Image normalization and resizing to 512×512 pixels
- CLAHE (Contrast Limited Adaptive Histogram Equalization) for contrast enhancement
- Lung field segmentation using U-Net to focus on relevant anatomical regions
Data Augmentation Strategy
To improve model robustness and generalization, we implemented advanced augmentation techniques:
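```python
import albumentations as A

# Written against the Albumentations 1.x API; some arguments used here
# (e.g. alpha_affine) were deprecated in later releases.
augmentation_pipeline = A.Compose([
    A.RandomRotate90(p=0.3),
    A.HorizontalFlip(p=0.5),
    # Small shifts, zooms, and rotations (up to 15 degrees)
    A.ShiftScaleRotate(
        shift_limit=0.1,
        scale_limit=0.1,
        rotate_limit=15,
        p=0.7
    ),
    # Non-rigid deformations
    A.ElasticTransform(
        alpha=1,
        sigma=50,
        alpha_affine=50,
        p=0.3
    ),
    A.GridDistortion(p=0.3),
    A.OpticalDistortion(
        distort_limit=0.05,
        shift_limit=0.05,
        p=0.3
    )
])
```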
Model Architecture
We chose EfficientNet-B4 as our backbone architecture due to its excellent balance between accuracy and computational efficiency:
```python
import torch
import torch.nn as nn
from efficientnet_pytorch import EfficientNet


class TBDetectionModel(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        # Load pre-trained EfficientNet-B4
        self.backbone = EfficientNet.from_pretrained('efficientnet-b4')

        # Replace classifier
        num_features = self.backbone._fc.in_features
        self.backbone._fc = nn.Identity()

        # Custom classifier with dropout
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        features = self.backbone(x)
        return self.classifier(features)
```
Training Strategy
We employed a multi-stage training approach:
- Stage 1: Freeze backbone, train classifier only (5 epochs)
- Stage 2: Unfreeze and fine-tune entire network with lower learning rate (15 epochs)
- Stage 3: Final fine-tuning with cosine annealing (10 epochs)
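The freeze/unfreeze mechanics behind these stages can be sketched in PyTorch as below. This is an illustration only: `TinyModel` is a hypothetical stand-in that mirrors the backbone/classifier split, the helper names are made up for this example, and the Stage 2 learning rate of 3e-5 is an assumed value (the post only says "lower").

```python
import torch
import torch.nn as nn


def set_backbone_trainable(model: nn.Module, trainable: bool) -> None:
    """Freeze or unfreeze every parameter of model.backbone."""
    for param in model.backbone.parameters():
        param.requires_grad = trainable


def make_optimizer(model: nn.Module, lr: float) -> torch.optim.Optimizer:
    """AdamW over only the parameters that are currently trainable."""
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(trainable, lr=lr, weight_decay=1e-4)


class TinyModel(nn.Module):
    """Hypothetical stand-in with the same backbone/classifier layout."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.classifier = nn.Linear(4, 2)

    def forward(self, x):
        return self.classifier(self.backbone(x))


model = TinyModel()

# Stage 1: backbone frozen, only the classifier head is optimized.
set_backbone_trainable(model, False)
opt_stage1 = make_optimizer(model, lr=3e-4)

# Stage 2: everything trainable, lower learning rate (assumed value).
set_backbone_trainable(model, True)
opt_stage2 = make_optimizer(model, lr=3e-5)

# Stage 3: cosine annealing over the final 10 epochs;
# call scheduler.step() once per epoch.
opt_stage3 = make_optimizer(model, lr=3e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt_stage3, T_max=10)
```

Rebuilding the optimizer at each stage keeps frozen parameters out of its state. With a real backbone, it is also common to keep batch-norm layers in eval mode while the backbone is frozen.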
Training configuration:
- Optimizer: AdamW with weight decay 1e-4
- Learning rate: 3e-4 with cosine annealing
- Batch size: 16 (limited by GPU memory)
- Loss function: Focal loss to handle class imbalance
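For readers unfamiliar with focal loss, here is a minimal multi-class sketch in PyTorch. The focusing parameter `gamma=2.0` is an assumed default, since the configuration above does not state one, and no per-class weighting term is included.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Focal loss: mean over the batch of (1 - p_t)^gamma * cross-entropy."""

    def __init__(self, gamma: float = 2.0):  # gamma is an assumed value
        super().__init__()
        self.gamma = gamma

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Per-sample cross-entropy equals -log(p_t), where p_t is the
        # predicted probability of the true class.
        ce = F.cross_entropy(logits, targets, reduction="none")
        p_t = torch.exp(-ce)
        # Down-weight well-classified samples (p_t close to 1).
        return ((1.0 - p_t) ** self.gamma * ce).mean()
```

With `gamma=0` this reduces to ordinary cross-entropy; larger values shift the training signal toward hard, misclassified examples.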
Results
On the held-out test set of 2,000 images, the model achieved the following results:
Quantitative Results
- Accuracy: 95.8%
- Sensitivity (Recall): 94.2%
- Specificity: 97.4%
- Precision: 97.1%
- F1-Score: 95.6%
- AUC-ROC: 0.978
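As a reminder of how the threshold-based metrics above relate to each other, the sketch below derives them from the four confusion-matrix counts. The counts in the usage example are hypothetical round numbers chosen for easy arithmetic, not our test-set counts, and `classification_metrics` is a name introduced here for illustration.

```python
def classification_metrics(tp: int, fp: int, tn: int, fn: int) -> dict:
    """Compute standard binary-classification metrics from raw counts."""
    sensitivity = tp / (tp + fn)   # recall: share of TB cases detected
    specificity = tn / (tn + fp)   # share of normal cases correctly cleared
    precision = tp / (tp + fp)     # share of positive calls that are TB
    accuracy = (tp + tn) / (tp + fp + tn + fn)
    f1 = 2 * precision * sensitivity / (precision + sensitivity)
    return {
        "accuracy": accuracy,
        "sensitivity": sensitivity,
        "specificity": specificity,
        "precision": precision,
        "f1": f1,
    }


# Hypothetical counts for illustration only.
metrics = classification_metrics(tp=90, fp=5, tn=95, fn=10)
for name, value in metrics.items():
    print(f"{name}: {value:.3f}")
```

Unlike sensitivity and specificity, precision and accuracy depend on the proportion of TB-positive cases in the evaluated set, so they can shift when the model is applied to a population with a different prevalence.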
Comparison with Radiologists
We compared our model's performance with three experienced radiologists:
- Radiologist 1: 92.3% accuracy
- Radiologist 2: 94.1% accuracy
- Radiologist 3: 93.7% accuracy
- Our Model: 95.8% accuracy
Model Interpretability
Understanding model predictions is crucial for clinical adoption. We implemented Grad-CAM visualization to highlight regions contributing to TB detection:
```python
import cv2
import numpy as np
from pytorch_grad_cam import GradCAM


def generate_gradcam(model, input_tensor, target_layer):
    cam = GradCAM(model=model, target_layers=[target_layer])
    # One activation map per input image, shape (batch, H, W)
    grayscale_cam = cam(input_tensor=input_tensor)

    # Convert the first map to a heatmap (OpenCV returns BGR channel order)
    heatmap = cv2.applyColorMap(
        np.uint8(255 * grayscale_cam[0]),
        cv2.COLORMAP_JET
    )
    return heatmap
```
The Grad-CAM visualizations showed that our model correctly focused on:
- Upper lobe infiltrates
- Cavitary lesions
- Pleural effusions
- Hilar lymphadenopathy
Clinical Validation
We conducted a prospective validation study at two hospitals:
- Site 1 (Urban hospital): 500 cases, 96.2% accuracy
- Site 2 (Rural clinic): 300 cases, 94.8% accuracy
We attribute the slightly lower accuracy at the rural site to variations in image quality and differences in X-ray equipment.
Deployment Considerations
For real-world deployment, we optimized the model:
Model Optimization
- Quantization to INT8 reduced model size by 75%
- TensorRT optimization improved inference speed by 3x
- Final model: 45MB, 120ms inference time on NVIDIA T4
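The 75% figure follows from storage arithmetic: an FP32 weight takes 4 bytes and an INT8 weight takes 1. The short sketch below shows the calculation with a hypothetical parameter count; it ignores quantization scales, zero points, and any layers kept at higher precision, all of which affect the size of a real exported model.

```python
def weight_storage_mb(num_params: int, bytes_per_weight: int) -> float:
    """Approximate weight storage in megabytes (1 MB = 1e6 bytes)."""
    return num_params * bytes_per_weight / 1e6


num_params = 10_000_000  # hypothetical parameter count, for illustration only

fp32_mb = weight_storage_mb(num_params, bytes_per_weight=4)
int8_mb = weight_storage_mb(num_params, bytes_per_weight=1)
reduction = 1 - int8_mb / fp32_mb

print(f"FP32: {fp32_mb:.0f} MB, INT8: {int8_mb:.0f} MB, reduction: {reduction:.0%}")
```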
Integration Workflow
```python
import torch


class TBScreeningPipeline:
    # ImagePreprocessor, load_model, and generate_report are project
    # helpers that are not shown in this post.
    def __init__(self, model_path):
        self.model = self.load_model(model_path)
        self.preprocessor = ImagePreprocessor()

    def screen_xray(self, image_path):
        # Preprocess image
        image = self.preprocessor.process(image_path)

        # Model inference (gradients are not needed at screening time)
        with torch.no_grad():
            prediction = self.model(image)
            confidence = torch.softmax(prediction, dim=1)

        # Generate report
        report = self.generate_report(prediction, confidence)
        return report
```
Future Directions
Multi-Task Learning
We're exploring extensions to detect multiple pathologies:
- Pneumonia
- COVID-19
- Lung cancer
- Pneumothorax
Federated Learning
To address data privacy concerns, we're implementing federated learning approaches that allow model training across institutions without data sharing.
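As one illustration of the idea, the aggregation step of federated averaging (FedAvg) combines model weights trained locally at each site, weighted by the number of local samples, so only parameters leave each institution. Choosing FedAvg here is an assumption made for the example, and `federated_average` and the site data are hypothetical.

```python
import numpy as np


def federated_average(site_weights, site_sizes):
    """Sample-weighted average of per-site parameter dictionaries.

    site_weights: list of {parameter_name: np.ndarray}, one dict per site.
    site_sizes:   number of local training samples at each site.
    """
    total = float(sum(site_sizes))
    return {
        name: sum(w[name] * (n / total) for w, n in zip(site_weights, site_sizes))
        for name in site_weights[0]
    }


# Two hypothetical sites holding 100 and 300 training images.
site_a = {"layer.weight": np.array([1.0, 3.0])}
site_b = {"layer.weight": np.array([3.0, 5.0])}
global_weights = federated_average([site_a, site_b], [100, 300])
```

In a full system this step would run on a coordinating server after each round of local training, typically with secure transport and possibly additional privacy protections.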
Mobile Deployment
We are also converting the model to TensorFlow Lite so that it can run on smartphones in resource-limited settings.
Conclusion
Our EfficientNet-based approach shows that deep learning can match or exceed experienced readers in TB detection from chest X-rays: the model reached 95.8% accuracy, compared with 92.3% to 94.1% for the three radiologists in our comparison. The combination of robust data augmentation, careful architecture selection, and clinical validation shows promise for real-world deployment.
Key takeaways:
- Data quality matters more than quantity for medical imaging tasks
- Domain-specific augmentation significantly improves generalization
- Clinical validation is essential before deployment
- Model interpretability builds trust with healthcare providers
The code and pre-trained models are available on GitHub, and we're actively collaborating with health organizations for wider deployment.
References
- World Health Organization. (2023). Global Tuberculosis Report 2023.
- Tan, M., & Le, Q. (2019). EfficientNet: Rethinking model scaling for convolutional neural networks. ICML.
- Rajpurkar, P., et al. (2017). CheXNet: Radiologist-level pneumonia detection on chest X-rays with deep learning. arXiv:1711.05225.
- Selvaraju, R. R., et al. (2017). Grad-CAM: Visual explanations from deep networks via gradient-based localization. ICCV.
This work was conducted in collaboration with the Department of Medical Imaging at Karolinska Institute and supported by the Swedish Research Council.
Last updated: 2025-05-17 17:35:55 by linhduongtuan