statmlben/rankseg


RankSEG: A Consistent Ranking-based Framework for Segmentation (JMLR 2023)

About This Repository

Note: This is the experiment reproduction repository for our JMLR paper.
For the RankSEG software package (installation & usage), please visit:
https://github.com/rankseg/rankseg

RankSEG is a Python module designed for segmentation tasks, aiming to maximize Dice or IoU metrics based on estimated probabilities.

Key Features

Most segmentation methods traditionally rely on IoU and Dice as evaluation metrics. During inference and prediction, these methods typically use a threshold of 0.5 or apply argmax to the estimated probabilities to generate segmentation predictions. However, this approach does not directly optimize the IoU or Dice metrics.

Our method, RankDice, directly optimizes IoU and Dice metrics.

  • Nearly always improves Dice and IoU over thresholding and argmax (see the results below).
  • Seamlessly integrates with any pretrained segmentation neural network (no need to retrain the models).
  • A well-developed Python function rank_dice is available for use.
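To make concrete why thresholding at 0.5 does not directly optimize Dice, the toy sketch below enumerates every possible label mask for three pixels and compares the exact expected Dice of the thresholded prediction with that of each "top-k by probability" prediction. It is a brute-force illustration only, not the `rank_dice` implementation. The independence of pixels, the convention that Dice is 1 when both masks are empty, and the helper names `dice`, `expected_dice` and `top_k_mask` are all assumptions made for this example.

```python
from itertools import product


def dice(pred, truth):
    """Dice = 2|A∩B| / (|A| + |B|); set to 1.0 when both masks are empty (assumed convention)."""
    inter = sum(p and t for p, t in zip(pred, truth))
    total = sum(pred) + sum(truth)
    return 1.0 if total == 0 else 2.0 * inter / total


def expected_dice(pred, probs):
    """Exact expected Dice of `pred` under independent Bernoulli(probs) labels (assumption)."""
    value = 0.0
    for truth in product([0, 1], repeat=len(probs)):
        weight = 1.0
        for t, q in zip(truth, probs):
            weight *= q if t else 1.0 - q
        value += weight * dice(pred, truth)
    return value


def top_k_mask(probs, k):
    """Mask that selects the k pixels with the largest probabilities."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    chosen = set(order[:k])
    return [1 if i in chosen else 0 for i in range(len(probs))]


# Three pixels, each foreground with probability 0.4 (hypothetical values).
probs = [0.4, 0.4, 0.4]

# Thresholding at 0.5 predicts an empty mask here.
threshold_pred = [1 if q > 0.5 else 0 for q in probs]
threshold_score = expected_dice(threshold_pred, probs)

# Expected Dice of every top-k prediction, k = 0..3.
scores = {k: expected_dice(top_k_mask(probs, k), probs) for k in range(len(probs) + 1)}
best_k = max(scores, key=scores.get)

print(f"threshold: {threshold_score:.4f}")
print(f"best top-k: k={best_k}, expected Dice {scores[best_k]:.4f}")
```

In this toy setting the thresholded (empty) mask has expected Dice 0.216, while selecting all three pixels gives 0.5104, so a prediction chosen by ranking pixels and picking the volume can beat the 0.5 cut even though no single pixel exceeds it.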

Installation

git clone https://github.com/statmlben/rankseg.git
cd rankseg
pip install -r requirements.txt

How-to-Use (on a batch)

RankDice

## `out_prob` (batch_size, num_class, width, height) is the output probability for each pixel based on a trained neural network
from rankseg import rank_dice
predict_rd, tau_rd, cutpoint_rd = rank_dice(out_prob, app=2, device='cuda')

Other existing frameworks (Threshold and Argmax)

import torch

## `out_prob` (batch_size, num_class, width, height) is the output probability for each pixel based on a trained neural network

## Threshold: a pixel is assigned to every class whose probability exceeds 0.5
predict_T = out_prob > 0.5

## Argmax: a pixel is assigned to its single most probable class (one-hot along dim 1)
idx = torch.argmax(out_prob, dim=1, keepdim=True)
predict_max = torch.zeros_like(out_prob, dtype=torch.bool).scatter_(1, idx, True)
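
The two baseline rules can disagree on a pixel where no class probability exceeds 0.5. The plain-Python sketch below applies both rules to one pixel's class probabilities, to show the decision logic without needing PyTorch. The probability values and the helper names `threshold_mask` and `argmax_mask` are hypothetical.

```python
def threshold_mask(pixel_probs, cut=0.5):
    """Per-class decision: class c is predicted when its probability exceeds `cut`."""
    return [q > cut for q in pixel_probs]


def argmax_mask(pixel_probs):
    """One-hot decision: only the most probable class is predicted."""
    best = max(range(len(pixel_probs)), key=lambda c: pixel_probs[c])
    return [c == best for c in range(len(pixel_probs))]


# Hypothetical class probabilities for two pixels of a 3-class problem.
flat_pixel = [0.40, 0.35, 0.25]       # no class exceeds 0.5
confident_pixel = [0.80, 0.15, 0.05]  # one dominant class

for name, probs in [("flat", flat_pixel), ("confident", confident_pixel)]:
    print(name, "threshold:", threshold_mask(probs), "argmax:", argmax_mask(probs))
```

On the flat pixel, thresholding predicts no class at all while argmax still assigns class 0. On the confident pixel the two rules agree.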

Usage in pytorch-segmentation-rankseg (in subfolder)

## rankdice
$ python test.py -r saved/cityscapes/PSPNet/CrossEntropyLoss2d/T/05-04_13-08/checkpoint-epoch300.pth -p "rankdice"

TEST, Pred (rankdice) | Loss: 0.159, PixelAcc: 0.99, Mean IoU: 0.51, Mean Dice 0.59 |: 100%|######| 84/84 [01:03<00:00, 1.33it/s]

## TESTING Restuls for Model: PSPNet + Loss: CrossEntropyLoss2d + predict: rankdice ##
test_loss : 0.15925
Pixel_Accuracy : 0.9879999756813049
Mean_IoU : 0.5099999904632568
Mean_Dice : 0.5929999947547913
Class_IoU : {0: 0.771, 1: 0.508, 2: 0.767, 3: 0.164, 4: 0.117, 5: 0.317, 6: 0.283, 7: 0.401, 8: 0.841, 9: 0.231, 10: 0.778, 11: 0.4, 12: 0.292, 13: 0.766, 14: 0.233, 15: 0.465, 16: 0.315, 17: 0.177, 18: 0.326}
Class_Dice : {0: 0.856, 1: 0.608, 2: 0.851, 3: 0.21, 4: 0.158, 5: 0.46, 6: 0.374, 7: 0.514, 8: 0.903, 9: 0.294, 10: 0.845, 11: 0.495, 12: 0.372, 13: 0.84, 14: 0.265, 15: 0.513, 16: 0.358, 17: 0.222, 18: 0.419}

## max
$ python test.py -r saved/cityscapes/PSPNet/CrossEntropyLoss2d/T/05-04_13-08/checkpoint-epoch300.pth -p "max"

TEST, Pred (max) | Loss: 0.159, PixelAcc: 0.99, Mean IoU: 0.49, Mean Dice 0.56 |: 100%|###########| 84/84 [00:12<00:00, 6.52it/s]

## TESTING Restuls for Model: PSPNet + Loss: CrossEntropyLoss2d + predict: max ##
test_loss : 0.15925
Pixel_Accuracy : 0.9879999756813049
Mean_IoU : 0.48500001430511475
Mean_Dice : 0.5649999976158142
Class_IoU : {0: 0.768, 1: 0.489, 2: 0.759, 3: 0.133, 4: 0.099, 5: 0.295, 6: 0.257, 7: 0.387, 8: 0.836, 9: 0.208, 10: 0.769, 11: 0.372, 12: 0.272, 13: 0.751, 14: 0.204, 15: 0.395, 16: 0.268, 17: 0.152, 18: 0.303}
Class_Dice : {0: 0.854, 1: 0.585, 2: 0.844, 3: 0.172, 4: 0.136, 5: 0.428, 6: 0.341, 7: 0.498, 8: 0.9, 9: 0.268, 10: 0.835, 11: 0.464, 12: 0.351, 13: 0.826, 14: 0.233, 15: 0.437, 16: 0.308, 17: 0.193, 18: 0.392}


## threshold at 0.5
$ python test.py -r saved/cityscapes/PSPNet/CrossEntropyLoss2d/T/05-04_13-08/checkpoint-epoch300.pth -p "T"

TEST, Pred (T) | Loss: 0.159, PixelAcc: 0.99, Mean IoU: 0.50, Mean Dice 0.57 |: 100%|#############| 84/84 [00:13<00:00, 6.45it/s]

## TESTING Restuls for Model: PSPNet + Loss: CrossEntropyLoss2d + predict: T ##
test_loss : 0.15925
Pixel_Accuracy : 0.9890000224113464
Mean_IoU : 0.4959999918937683
Mean_Dice : 0.574999988079071
Class_IoU : {0: 0.772, 1: 0.478, 2: 0.762, 3: 0.136, 4: 0.109, 5: 0.29, 6: 0.265, 7: 0.39, 8: 0.841, 9: 0.201, 10: 0.77, 11: 0.363, 12: 0.273, 13: 0.769, 14: 0.219, 15: 0.422, 16: 0.307, 17: 0.158, 18: 0.325}
Class_Dice : {0: 0.857, 1: 0.573, 2: 0.846, 3: 0.174, 4: 0.147, 5: 0.419, 6: 0.349, 7: 0.499, 8: 0.902, 9: 0.257, 10: 0.836, 11: 0.451, 12: 0.351, 13: 0.841, 14: 0.247, 15: 0.468, 16: 0.349, 17: 0.197, 18: 0.414}
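
As a quick reading aid for the three logs above, the sketch below counts, class by class, how often the `rankdice` Class_IoU is higher than, equal to, or lower than each baseline. The numbers are copied from the Class_IoU lines of the logs (list index = class id). The helper `compare` is a hypothetical name introduced for this example and is not part of `test.py`.

```python
# Class_IoU per class id 0..18, copied from the three test.py logs above.
rankdice_iou = [0.771, 0.508, 0.767, 0.164, 0.117, 0.317, 0.283, 0.401, 0.841, 0.231,
                0.778, 0.4, 0.292, 0.766, 0.233, 0.465, 0.315, 0.177, 0.326]
argmax_iou = [0.768, 0.489, 0.759, 0.133, 0.099, 0.295, 0.257, 0.387, 0.836, 0.208,
              0.769, 0.372, 0.272, 0.751, 0.204, 0.395, 0.268, 0.152, 0.303]
threshold_iou = [0.772, 0.478, 0.762, 0.136, 0.109, 0.29, 0.265, 0.39, 0.841, 0.201,
                 0.77, 0.363, 0.273, 0.769, 0.219, 0.422, 0.307, 0.158, 0.325]


def compare(ours, baseline):
    """Return (wins, ties, losses) of `ours` against `baseline`, class by class."""
    wins = sum(a > b for a, b in zip(ours, baseline))
    ties = sum(a == b for a, b in zip(ours, baseline))
    losses = sum(a < b for a, b in zip(ours, baseline))
    return wins, ties, losses


vs_argmax = compare(rankdice_iou, argmax_iou)
vs_threshold = compare(rankdice_iou, threshold_iou)

print("rankdice vs max (wins, ties, losses):", vs_argmax)
print("rankdice vs T   (wins, ties, losses):", vs_threshold)
```

For this checkpoint, `rankdice` has the higher Class_IoU on all 19 classes against `max`, and on 16 of 19 classes against `T` (one tie, two slightly lower).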

Jupyter Notebook

Illustrative results

Results in Fine-annotated Cityscapes dataset

  • Threshold, Argmax and RankDice are applied to the same network (Model column) trained with the same loss (Loss column).
  • Averaged mDice and mIoU metrics based on state-of-the-art models/losses on Fine-annotated CityScapes val set. '/' indicates not applicable since the proposed RankDice/mRankDice requires a strictly proper loss. The best performance in each model/loss is bold-faced.
  • All trained neural networks and their config.json with different network and loss are saved in this link (12G folder: network/loss/.../*.pth + config.json)

| Model | Loss | Threshold (at 0.5) (mDice, mIoU) ($\times .01$) | Argmax (mDice, mIoU) ($\times .01$) | mRankDice (our) (mDice, mIoU) ($\times .01$) |
|---|---|---|---|---|
| DeepLab-V3+ (resnet101) | CE | (56.00, 48.40) | (54.20, 46.60) | (57.80, 49.80) |
| | Focal | (54.10, 46.60) | (53.30, 45.60) | (56.50, 48.70) |
| | BCE | (49.80, 24.90) | (44.20, 22.10) | (54.00, 27.00) |
| | Soft-Dice | (39.50, 35.90) | (39.50, 35.90) | / |
| | B-Soft-Dice | (41.00, 20.50) | (27.60, 13.80) | / |
| | LovaszSoftmax | (55.20, 47.60) | (52.30, 45.10) | / |
| PSPNet (resnet50) | CE | (57.50, 49.60) | (56.50, 48.50) | (59.30, 51.00) |
| | Focal | (56.00, 48.20) | (55.80, 47.70) | (58.20, 50.00) |
| | BCE | (51.40, 25.70) | (47.60, 23.80) | (55.10, 27.60) |
| | Soft-Dice | (49.10, 43.50) | (48.70, 43.20) | / |
| | B-Soft-Dice | (46.30, 23.10) | (32.70, 16.40) | / |
| | LovaszSoftmax | (56.80, 48.90) | (55.40, 47.70) | / |
| FCN8 (resnet101) | CE | (51.40, 43.70) | (50.50, 42.60) | (53.50, 45.30) |
| | Focal | (48.50, 41.20) | (49.60, 41.60) | (51.50, 43.70) |
| | BCE | (39.40, 19.70) | (39.40, 19.70) | (41.30, 20.60) |
| | Soft-Dice | (28.30, 24.30) | (28.30, 24.30) | / |
| | B-Soft-Dice | (29.10, 14.60) | (29.10, 14.60) | / |
| | LovaszSoftmax | (48.10, 40.40) | (42.90, 35.80) | / |

Results in PASCAL VOC 2012 dataset

  • Threshold, Argmax and RankDice are applied to the same network (Model column) trained with the same loss (Loss column).
  • Averaged mDice and mIoU based on state-of-the-art models/losses on PASCAL VOC 2012 val set. '---' indicates that either the performance is significantly worse or the training is unstable, and '/' indicates not applicable since the proposed RankDice/mRankDice requires a strictly proper loss. The best performance in each model-loss pair is bold-faced.
  • All trained neural networks with different network and loss are saved in this link (22G folder: network/loss/.../*.pth)

| Model | Loss | Threshold (at 0.5) (mDice, mIoU) ($\times .01$) | Argmax (mDice, mIoU) ($\times .01$) | mRankDice (our) (mDice, mIoU) ($\times .01$) |
|---|---|---|---|---|
| DeepLab-V3+ (resnet101) | CE | (63.60, 56.70) | (61.90, 55.30) | (64.01, 57.01) |
| | Focal | (62.70, 55.01) | (60.50, 53.20) | (62.90, 55.10) |
| | BCE | (63.30, 31.70) | (59.90, 29.90) | (64.60, 32.30) |
| | Soft-Dice | --- | --- | / |
| | B-Soft-Dice | --- | --- | / |
| | LovaszSoftmax | (57.70, 51.60) | (56.20, 50.30) | / |
| PSPNet (resnet50) | CE | (64.60, 57.10) | (63.20, 55.90) | (65.40, 57.80) |
| | Focal | (64.00, 56.10) | (63.90, 56.10) | (66.60, 58.50) |
| | BCE | (64.20, 32.10) | (65.20, 32.60) | (67.10, 33.50) |
| | Soft-Dice | (59.60, 54.00) | (58.80, 53.20) | / |
| | B-Soft-Dice | (63.30, 31.60) | (54.00, 27.00) | / |
| | LovaszSoftmax | (62.00, 55.20) | (60.80, 54.10) | / |
| FCN8 (resnet101) | CE | (49.50, 41.90) | (45.30, 38.40) | (50.40, 42.70) |
| | Focal | (50.40, 41.80) | (47.20, 39.30) | (51.50, 42.50) |
| | BCE | (46.20, 23.10) | (44.20, 22.10) | (47.70, 23.80) |
| | Soft-Dice | --- | --- | / |
| | B-Soft-Dice | --- | --- | / |
| | LovaszSoftmax | (39.80, 34.30) | (37.30, 32.20) | / |

Results in Kvasir-SEG dataset

  • Threshold, Argmax and RankDice are applied to the same network (Model column) trained with the same loss (Loss column).
  • Threshold and Argmax are exactly the same in binary segmentation.
  • Averaged mDice and mIoU based on state-of-the-art models/losses on the Kvasir-SEG dataset. '---' indicates that either the performance is significantly worse or the training is unstable, and '/' indicates not applicable since the proposed RankDice/mRankDice requires a strictly proper loss. The best performance in each model-loss pair is bold-faced.

| Model | Loss | Threshold/Argmax (Dice, IoU) ($\times .01$) | mRankDice (our) (Dice, IoU) ($\times .01$) |
|---|---|---|---|
| DeepLab-V3+ (resnet101) | CE | (87.9, 80.7) | (88.3, 80.9) |
| | Focal | (86.5, 87.3) | / |
| | Soft-Dice | (85.7, 77.8) | / |
| | LovaszSoftmax | (84.3, 77.3) | / |
| PSPNet (resnet50) | CE | (86.3, 79.2) | (87.1, 79.8) |
| | Focal | (83.8, 75.4) | / |
| | Soft-Dice | (83.5, 75.9) | / |
| | LovaszSoftmax | (86.0, 79.2) | / |
| FCN8 (resnet101) | CE | (81.9, 73.5) | (82.1, 73.6) |
| | Focal | (78.5, 69.0) | / |
| | Soft-Dice | --- | --- |
| | LovaszSoftmax | (82.0, 73.4) | / |
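
The note that Threshold and Argmax coincide in binary segmentation can be checked with a short sketch. It assumes a two-channel output whose probabilities sum to 1 (background at index 0, foreground at index 1) and that argmax ties resolve to the lower index. The function names `threshold_fg` and `argmax_fg` are hypothetical.

```python
def threshold_fg(p_fg, cut=0.5):
    """Foreground decision by thresholding the foreground probability."""
    return p_fg > cut


def argmax_fg(p_fg):
    """Foreground decision by argmax over [background, foreground] probabilities."""
    probs = [1.0 - p_fg, p_fg]  # index 0 = background, index 1 = foreground (assumed layout)
    best = max(range(2), key=lambda c: probs[c])  # a tie resolves to index 0
    return best == 1


# Sweep foreground probabilities 0.00, 0.01, ..., 1.00 and collect any disagreement.
grid = [i / 100 for i in range(101)]
mismatches = [p for p in grid if threshold_fg(p) != argmax_fg(p)]
print("disagreements:", mismatches)
```

Under these assumptions the two rules never disagree on the grid, because with two classes the foreground is the argmax exactly when its probability is above 0.5.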

More results

  • All empirical results on different losses and models can be found here

Replication

If you want to replicate the experiments in our paper, please check the folder ./pytorch-segmentation-rankseg and its README file (Pytorch-segmentation-rankseg).

Citation

If you like RankSEG, please star the repository and cite the following paper:

@article{dai2023rankseg,
title={RankSEG: A Consistent Ranking-based Framework for Segmentation},
author={Dai, Ben and Li, Chunlin},
journal={Journal of Machine Learning Research},
volume={24},
number={224},
pages={1--50},
year={2023}
}

Thank you

If you find this repository helpful, please star our repo.

Thank you so much for the support from our stargazers.
