Commit 0387e092 authored by GaoYunshu1

update

parent 281bf9bb

# Restoration-of-Cataract-Images-via-Domain-Adaptation

There is little access to large datasets of cataract images paired with their corresponding clear ones. Therefore, it is unlikely that a restoration model for cataract images can be built through supervised learning.

Here, we propose an unsupervised restoration method via cataract-like image simulation and domain adaptation, as well as an annotation-free restoration network for cataractous fundus images. The source code for both has been released.

## Results

Li H, Liu H, Hu Y, et al. Restoration of Cataract Fundus Images via Unsupervised Domain Adaptation[C]// 2021 IEEE 18th International Symposium on Biomedical Imaging (ISBI). IEEE, 2021.

**Result:**

![Output](images/Output.png)

A comparison of the restored fundus images. (a) cataract image. (b) clear fundus image after surgery. (c) dark channel prior. (d) SGRIF [2]. (e) pix2pix [4]. (f) CycleGAN [5]. (g) the proposed method [8].

Li H, Liu H, Hu Y, et al. An Annotation-free Restoration Network for Cataractous Fundus Images[J]. IEEE Transactions on Medical Imaging, 2022.

![arcnet](./images/arcnet.png)

<div align="left">
    <img src="./images/comparison.png" alt="RCDG" style="zoom: 100%;" />
</div>

Visual comparison of images restored from cataract ones. (a) cataract image. (b) clear fundus image after surgery. (c) Mitra et al. [1]. (d) SGRIF [2]. (e) Cao et al. [3]. (f) pix2pix [4]. (g) CycleGAN [5]. (h) Luo et al. [6]. (i) CofeNet [7]. (j) ArcNet [9].

# Prerequisites

- CPU or NVIDIA GPU + CUDA CuDNN

# Environment (Using conda)

```
conda install numpy pyyaml mkl mkl-include setuptools cmake cffi typing opencv-python
conda install pytorch torchvision -c pytorch # add cuda90 if CUDA 9
conda install visdom dominate -c conda-forge # install visdom and dominate
```

# Simulate cataract-like images

Use the script in ./utils/cataract_simulation.py.
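As a rough, self-contained illustration of what cataract-like simulation does (blurring plus a bright haze layer), here is a sketch in NumPy. The function names, parameters, and the exact degradation model are simplified assumptions for orientation only, not the repository's script:

```python
import numpy as np

def gaussian_kernel(sigma: float, radius: int) -> np.ndarray:
    """1-D Gaussian kernel, normalized to sum to 1."""
    x = np.arange(-radius, radius + 1, dtype=np.float64)
    k = np.exp(-x ** 2 / (2 * sigma ** 2))
    return k / k.sum()

def blur(img: np.ndarray, sigma: float = 3.0) -> np.ndarray:
    """Separable Gaussian blur along H and W of an H x W x C float image."""
    radius = int(3 * sigma)
    k = gaussian_kernel(sigma, radius)
    out = np.pad(img, [(radius, radius), (0, 0), (0, 0)], mode="edge")
    out = np.apply_along_axis(lambda m: np.convolve(m, k, mode="valid"), 0, out)
    out = np.pad(out, [(0, 0), (radius, radius), (0, 0)], mode="edge")
    out = np.apply_along_axis(lambda m: np.convolve(m, k, mode="valid"), 1, out)
    return out

def simulate_cataract(img: np.ndarray, alpha: float = 0.8,
                      beta: float = 0.2, sigma: float = 3.0) -> np.ndarray:
    """Blend a blurred copy of a clear fundus image with a uniform haze layer."""
    hazy = alpha * blur(img, sigma) + beta * np.ones_like(img)
    return np.clip(hazy, 0.0, 1.0)
```

The real script models the cataract-specific transmittance more carefully; this sketch only shows why a single clear image can spawn many degraded variants by varying `alpha`, `beta`, and `sigma`.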
# Visualization when training

```shell
python -m visdom.server
```

Then, open this link in the browser:

http://localhost:8097/

# Dataset preparation

Set up your own dataset with the same structure as images/cataract_dataset. Note that the number of source images should be larger than the number of target images; alternatively, you can design your own data loader (e.g., a data/xx_dataset.py imitating data/cataract_guide_padding_dataset.py).
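The source/target count constraint comes from how the loaders iterate the two folders. A minimal sketch of one way such pairing can work (a hypothetical helper, not the repository's loader):

```python
from typing import List, Tuple

def pair_source_target(source_files: List[str],
                       target_files: List[str]) -> List[Tuple[str, str]]:
    """Pair every source image with a target image, cycling through the
    (smaller) target list, so len(source) >= len(target) is required."""
    if len(source_files) < len(target_files):
        raise ValueError("expected at least as many source images as target images")
    return [(src, target_files[i % len(target_files)])
            for i, src in enumerate(source_files)]
```

A loader built this way visits every source image once per epoch while reusing target images as needed.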

## Trained model's weight

For the model of "Restoration of Cataract Fundus Images via Unsupervised Domain Adaptation", please download the pretrained model from this link:

https://drive.google.com/file/d/1Ystqt3RQVfIPPukE7ZdzzFM_hBqB0lr0/view?usp=sharing

or use link: https://pan.baidu.com/s/1Ax18-10dpJDToieqvcXGxQ , code: ak7c

Then, place the file in project_root/checkpoints/pixDA_sobel, so that you get project_root/checkpoints/pixDA_sobel/latest_net_G.pth.

For the model of "An Annotation-free Restoration Network for Cataractous Fundus Images", please download the pretrained model from this link:

https://drive.google.com/file/d/1VJ-_W7rRmy90AcgeAJtt_z7fgeBpC4Id/view?usp=share_link

or use link: https://pan.baidu.com/s/1hFt0bMpBb5V0Gj0ogYHGbA , code: 3xg0

Then, place the file in project_root/checkpoints/arcnet, so that you get project_root/checkpoints/arcnet/latest_net_G.pth.
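To double-check where the weights should live before running the test scripts, the layout above can be reproduced with a small helper (hypothetical, mirroring the paths described here):

```python
import os

def checkpoint_path(project_root: str, name: str,
                    epoch: str = "latest", net: str = "G") -> str:
    """Build <project_root>/checkpoints/<name>/<epoch>_net_<net>.pth."""
    return os.path.join(project_root, "checkpoints", name,
                        f"{epoch}_net_{net}.pth")
```

For example, `checkpoint_path(".", "arcnet")` should point at the file the ArcNet test command loads.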

# Command to run

Please note that all commands are run from the project root directory.

## Train

```
python train.py --dataroot ./datasets/cataract_dataset --name train_pixDA_sobel --model pixDA_sobel --netG unet_256 --direction AtoB --dataset_mode cataract --norm batch --batch_size 8 --n_epochs 150 --n_epochs_decay 50 --input_nc 6 --output_nc 3
```

or

```
python train.py --dataroot ./images/cataract_dataset --name train_arcnet --model arcnet --netG unet_256 --input_nc 6 --direction AtoB --dataset_mode cataract_guide_padding --norm batch --batch_size 8 --lr_policy step --n_epochs 100 --n_epochs_decay 0 --lr_decay_iters 80
```
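In the pix2pix-style codebase these scripts build on, `--n_epochs` and `--n_epochs_decay` typically mean: keep the learning rate constant for `n_epochs`, then decay it linearly toward zero over `n_epochs_decay`. A sketch of that rule, stated as an assumption about this repository for orientation:

```python
def lr_multiplier(epoch: int, n_epochs: int, n_epochs_decay: int) -> float:
    """Factor applied to the base learning rate at a 0-based epoch:
    1.0 for the first n_epochs, then linear decay toward 0."""
    return 1.0 - max(0, epoch + 1 - n_epochs) / float(n_epochs_decay + 1)
```

With `--n_epochs 150 --n_epochs_decay 50`, the rate stays at its base value for 150 epochs and reaches (almost) zero by epoch 200.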

## Test & Visualization

```
python test.py --dataroot ./datasets/cataract_dataset --name pixDA_sobel --model pixDA_sobel --netG unet_256 --direction AtoB --dataset_mode cataract --norm batch --input_nc 6 --output_nc 3
```

or

```
python test.py --dataroot ./images/cataract_dataset --name arcnet --model arcnet --netG unet_256 --input_nc 6 --direction AtoB --dataset_mode cataract_guide_padding --norm batch
```
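For a quick quantitative check of a restored image against the post-surgery ground truth, PSNR is a common first metric. This is an illustrative helper only; the papers report their own evaluation protocols:

```python
import numpy as np

def psnr(a: np.ndarray, b: np.ndarray, max_val: float = 1.0) -> float:
    """Peak signal-to-noise ratio between two float images in [0, max_val]."""
    mse = float(np.mean((a.astype(np.float64) - b.astype(np.float64)) ** 2))
    if mse == 0.0:
        return float("inf")
    return 10.0 * np.log10(max_val ** 2 / mse)
```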

# Reference

[1] Mitra A, Roy S, Roy S, et al. Enhancement and restoration of non-uniform illuminated fundus image of retina obtained through thin layer of cataract[J]. Computer Methods and Programs in Biomedicine, 2018, 156: 169-178.

[2] Cheng J, Li Z, Gu Z, et al. Structure-Preserving Guided Retinal Image Filtering and Its Application for Optic Disk Analysis[J]. IEEE Transactions on Medical Imaging, 2018.

[3] Cao L, Li H, Zhang Y. Retinal image enhancement using low-pass filtering and α-rooting[J]. Signal Processing, 2020, 170: 107445.

[4] Isola P, Zhu J Y, Zhou T, et al. Image-to-Image Translation with Conditional Adversarial Networks[C]// IEEE Conference on Computer Vision and Pattern Recognition. IEEE, 2017.

[5] Zhu J Y, Park T, Isola P, et al. Unpaired Image-to-Image Translation Using Cycle-Consistent Adversarial Networks[C]// IEEE International Conference on Computer Vision. IEEE, 2017.

[6] Luo Y, Chen K, Liu L, et al. Dehaze of Cataractous Retinal Images Using an Unpaired Generative Adversarial Network[J]. IEEE Journal of Biomedical and Health Informatics, 2020.

[7] Shen Z, Fu H, Shen J, et al. Modeling and Enhancing Low-Quality Retinal Fundus Images[J]. IEEE Transactions on Medical Imaging, 2021, 40(3): 996-1006.

[8] Li H, Liu H, Hu Y, et al. Restoration of Cataract Fundus Images via Unsupervised Domain Adaptation[C]// 2021 IEEE 18th International Symposium on Biomedical Imaging (ISBI). IEEE, 2021: 516-520.

[9] Li H, Liu H, Hu Y, et al. An Annotation-free Restoration Network for Cataractous Fundus Images[J]. IEEE Transactions on Medical Imaging, 2022.

# Citation

See our template dataset class 'template_dataset.py' for more details.
"""
import importlib
import torch.utils.data
from predict_models.res.data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name):
    be instantiated. It has to be a subclass of BaseDataset,
    and it is case-insensitive.
    """
    dataset_filename = "predict_models.res.data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        if name.lower() == target_dataset_name.lower() \
           and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
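The loop above matches a module-level class to the dataset name case-insensitively ('cataract_guide_padding' resolves to CataractGuidePaddingDataset). A self-contained illustration of the lookup rule, using stand-in classes rather than the repository's real ones:

```python
class BaseDataset:
    """Stand-in for the real base class."""

class CataractGuidePaddingDataset(BaseDataset):
    """Stand-in whose name follows the '<name without underscores>dataset' rule."""

def find_in_namespace(dataset_name, namespace):
    """Return the class whose lowercase name equals the normalized dataset name."""
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    found = None
    for name, cls in namespace.items():
        if isinstance(cls, type) and name.lower() == target_dataset_name.lower() \
           and issubclass(cls, BaseDataset):
            found = cls
    return found
```

The `issubclass` guard is what the commit restores: without it, any object whose name happens to match would be picked up.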
def create_dataset(opt):
    """
        This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    data_loader = CustomDatasetDataLoader(opt)
    dataset = data_loader.load_data()
class CustomDatasetDataLoader():
        Step 2: create a multi-threaded data loader.
        """
        self.opt = opt
        dataset_class = find_dataset_using_name('cataract_guide_padding')
        self.dataset = dataset_class(opt)
        print("dataset [%s] was created" % type(self.dataset).__name__)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.num_threads))

def get_params(opt, size, is_source=True):
    new_h = h
    new_w = w
    if opt.preprocess == 'resize_and_crop':
        # TODO: add an array of sizes to randomly pick from; the target is only randomly cropped at 286
        if opt.source_size_count == 1:
            new_h = new_w = opt.load_size
        else:
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    if 'resize' in opt.preprocess:
        # TODO: the resize logic needs optimization; here, when params is None, the target size is taken directly
        if params is None:
            osize = [opt.load_size, opt.load_size]
        else:
        self.kernel = torch.tensor([])
        self.kernel = torch.empty(size=(1, 3, 1, 1), dtype=torch.float32, device=device)
        self.kernel.requires_grad = False
        # TODO: confirm that the input is RGB
        self.kernel[0, 0, 0, 0] = R_rate
        self.kernel[0, 1, 0, 0] = G_rate
        self.kernel[0, 2, 0, 0] = B_rate
import os.path
import random
import torch
from predict_models.res.data.base_dataset import BaseDataset, get_params, get_transform_six_channel
from predict_models.res.data.image_folder import make_dataset
from PIL import Image

