Commit aa004018 authored by Bucky Lee's avatar Bucky Lee

feat: add generalization model

parent e7b4c911
# Restoration-of-Cataract-Images-via-Domain-Generalization
Code for Domain Generalization in Restoration of Cataract Fundus Images via High-frequency Components [1]. 

This code is inherited from [our previous work [7]](https://github.com/liamheng/Restoration-of-Cataract-Images-via-Domain-Adaptation)

Unlike the previous work, this model is based on domain generalization and requires no target-domain data during training.

## Domain Generalization in Restoration of Cataract Fundus Images via High-frequency Components

<div align="left">
    <img src="./images/introduction.png" alt="RCDG" style="zoom: 33%;" />
</div>

Fig. 1. Overview of the proposed model. The bottoms of (a) and (b) show that $k$ cataract-like fundus images $s'_i$ are randomly synthesized from an identical clear image $s$ to cover the potential target domain $T$. In the top part, HFCs are extracted from the images by $H(\cdot)$ to reduce the domain shift and achieve domain alignment. Finally, the clear image is reconstructed from the aligned HFCs.

<div align="left">
    <img src="./images/structure.png" alt="RCDG" style="width: 80%;" />
</div>

Fig. 2. Overview of the proposed model. Cataract-like images $s'$ are synthesized from the clear image $s$ using DR to construct source domains. $H(\cdot)$ and $L(\cdot)$ denote the extraction of HFCs and LFCs, respectively. DIFs are then acquired by domain alignment using the HFCs and generator $G_H$. Finally, generator $G_R$ reconstructs the clear fundus image from the aligned HFCs.
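As a rough illustration of $H(\cdot)$ (an assumption for exposition; see the paper and the model code for the exact operator and its scale), the high-frequency components can be taken as the residual between an image and a Gaussian-blurred copy of it:

```python
import numpy as np

def gaussian_blur(image: np.ndarray, sigma: float = 5.0) -> np.ndarray:
    """Separable Gaussian low-pass filter L(x) over the spatial axes."""
    radius = int(3 * sigma)
    x = np.arange(-radius, radius + 1, dtype=np.float64)
    kernel = np.exp(-0.5 * (x / sigma) ** 2)
    kernel /= kernel.sum()

    def blur_axis(arr: np.ndarray, axis: int) -> np.ndarray:
        # convolve every 1-D slice along `axis` with the Gaussian kernel
        return np.apply_along_axis(
            lambda v: np.convolve(np.pad(v, radius, mode='reflect'),
                                  kernel, mode='valid'),
            axis, arr)

    return blur_axis(blur_axis(image, 0), 1)  # rows, then columns

def high_frequency_components(image: np.ndarray, sigma: float = 5.0) -> np.ndarray:
    """H(x) = x - L(x): keep only the high-frequency residual.

    `image` is a float array in [0, 1], shape (H, W) or (H, W, C);
    `sigma` is an assumed filter scale, tuned per dataset.
    """
    return image - gaussian_blur(image, sigma)
```

A flat region has no high-frequency content, so its residual is close to zero; vessel edges and other illumination-invariant structures survive, which is what makes HFCs useful for aligning domains with different cataract-induced color shifts.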

<div align="left">
    <img src="./images/comparison.png" alt="RCDG" style="zoom: 100%;" />
</div>

Fig. 3. Comparison of cataract restoration algorithms. (a) Cataract fundus image. (b) SGRIF [2]. (c) pix2pix [3]. (d) Luo et al. [4]. (e) CofeNet [5]. (f) Li et al. [6]. (g) The proposed method [1]. (h) Clear image after surgery.

# Prerequisites

- Windows 10
- Python 3
- CPU or NVIDIA GPU + CUDA cuDNN

## Environment (Using conda)

```shell
conda install numpy pyyaml mkl mkl-include setuptools cmake cffi typing opencv-python

conda install pytorch torchvision -c pytorch # add cuda90 if CUDA 9

conda install visdom dominate -c conda-forge # install visdom and dominate
```

## Data preparation

Go to the root directory of this project and run the following commands:

### Preparing the simulated images

```shell
python util/cataract_simulation.py
```
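`util/cataract_simulation.py` implements the actual synthesis used to build the source domains. Purely to illustrate the idea of fogging a clear image (the blending and parameters below are assumptions for exposition, not the repository's algorithm):

```python
import numpy as np

def cataract_like(image: np.ndarray, haze: float = 0.5,
                  veil_brightness: float = 0.9) -> np.ndarray:
    """Illustrative only: approximate the light scattering of a cataract
    by alpha-blending the clear image with a uniform bright veil.

    image: float array in [0, 1]; haze in [0, 1] controls severity.
    Randomizing `haze` per sample would yield the k synthesized images s'_i.
    """
    veil = np.full_like(image, veil_brightness)
    return (1.0 - haze) * image + haze * veil
```
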

### Get the masks of the source and target images

Get the masks of the source images:

```shell
python util/get_mask.py --image_dir ./images/drive_cataract/source --output_dir ./images/drive_cataract/source_mask --mode pair
```

Copy the target images into './images/drive_cataract/target', and run the following command:

```shell
python ./util/get_mask.py --image_dir ./images/drive_cataract/target --output_dir ./images/drive_cataract/target_mask --mode single
```
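`util/get_mask.py` is the authoritative mask generator. As a sketch of the usual approach (thresholding is an assumption; the script may differ), the circular field of view of a fundus photograph can be separated from the near-black border by its intensity:

```python
import numpy as np

def fov_mask(image: np.ndarray, threshold: float = 0.06) -> np.ndarray:
    """Binary field-of-view mask for a fundus image.

    image: float RGB array in [0, 1], shape (H, W, 3).
    threshold: assumed intensity cutoff; pixels brighter than it are
    considered inside the fundus disc, the dark border is outside.
    Returns a uint8 mask of shape (H, W) with values in {0, 1}.
    """
    intensity = image.mean(axis=2)  # per-pixel mean over RGB channels
    return (intensity > threshold).astype(np.uint8)
```
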

### Dataset and dataloader

You can also design a dataset class in data/xx_dataset.py for your own data format by imitating the script data/cataract_guide_padding_dataset.py.

Note that a mask is required by the model.
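A minimal skeleton for a custom dataset class (e.g. a hypothetical data/dummy_dataset.py) might look as follows. The `BaseDataset` stand-in below only makes the sketch self-contained; in the repository you would subclass `data.base_dataset.BaseDataset`, and the class is located automatically via `--dataset_mode dummy`:

```python
from abc import ABC

class BaseDataset(ABC):
    """Stand-in for data.base_dataset.BaseDataset (demo only)."""
    def __init__(self, opt):
        self.opt = opt
        self.root = opt.dataroot

    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

class DummyDataset(BaseDataset):
    """Skeleton dataset: the file must be named dummy_dataset.py and the
    class DummyDataset so that '--dataset_mode dummy' can locate it."""

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.paths = []  # collect image (and mask) paths here

    def __len__(self):
        """Size of the dataset."""
        return len(self.paths)

    def __getitem__(self, index):
        """Return one sample; the model also expects a mask entry."""
        return {'A_paths': self.paths[index]}
```
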


## Visualization during training

```shell
python -m visdom.server
```

Then, open this link in the browser:

http://localhost:8097/

## Trained model weights

For the model of "Domain Generalization in Restoration of Cataract Fundus Images via High-frequency Components", please download the pretrained model from this link:

https://drive.google.com/file/d/1ejnisgBh8aolGd5qcglWW-RBfc1QqLdj/view?usp=sharing

Then, place the directory at project_root/checkpoints/RCDG_drive, so that the weight file is located at project_root/checkpoints/RCDG_drive/latest_net_GH.pth.

With these trained weights, run the following command for inference:

```shell
python test.py --dataroot ./images/drive_cataract --name RCDG_drive_trained --model RCDG --dataset_mode cataract_guide_padding --eval
```

# Model training, testing and inference

## Train

```shell
python train.py --dataroot ./images/drive_cataract --name RCDG_drive --model RCDG --dataset_mode cataract_guide_padding --batch_size 8 --n_epochs 150 --n_epochs_decay 50
```

## Test & inference

```shell
python test.py --dataroot ./images/drive_cataract --name RCDG_drive --model RCDG --dataset_mode cataract_guide_padding --eval
```

# Reference

[1] Liu H, Li H, Ou M, et al. Domain Generalization in Restoration of Cataract Fundus Images via High-frequency Components[C]// 2022 IEEE 19th International Symposium on Biomedical Imaging (ISBI). IEEE, 2022.

[2] Cheng J, Li Z, Gu Z, et al. Structure-Preserving Guided Retinal Image Filtering and Its Application for Optic Disk Analysis[J]. IEEE Transactions on Medical Imaging, 2018.

[3] Isola P, Zhu J Y, Zhou T, et al. Image-to-Image Translation with Conditional Adversarial Networks[C]// IEEE Conference on Computer Vision and Pattern Recognition (CVPR). IEEE, 2017.

[4] Luo Y, Chen K, Liu L, et al. Dehaze of Cataractous Retinal Images Using an Unpaired Generative Adversarial Network[J]. IEEE Journal of Biomedical and Health Informatics, 2020.

[5] Shen Z, Fu H, Shen J, et al. Modeling and Enhancing Low-quality Retinal Fundus Images[J]. IEEE Transactions on Medical Imaging, 2021, 40(3): 996-1006.

[6] Li H, Liu H, Hu Y, et al. Restoration of Cataract Fundus Images via Unsupervised Domain Adaptation[C]// 2021 IEEE 18th International Symposium on Biomedical Imaging (ISBI). IEEE, 2021: 516-520.

[7] Li H, Liu H, Hu Y, et al. An Annotation-free Restoration Network for Cataractous Fundus Images[J]. IEEE Transactions on Medical Imaging, 2022.

# Citation

```
@article{li2022annotation,
  title={An Annotation-free Restoration Network for Cataractous Fundus Images},
  author={Li, Heng and Liu, Haofeng and Hu, Yan and Fu, Huazhu and Zhao, Yitian and Miao, Hanpei and Liu, Jiang},
  journal={IEEE Transactions on Medical Imaging},
  year={2022},
  publisher={IEEE}
}
@inproceedings{li2021restoration,
  title={Restoration Of Cataract Fundus Images Via Unsupervised Domain Adaptation},
  author={Li, Heng and Liu, Haofeng and Hu, Yan and Higashita, Risa and Zhao, Yitian and Qi, Hong and Liu, Jiang},
  booktitle={2021 IEEE 18th International Symposium on Biomedical Imaging (ISBI)},
  pages={516--520},
  year={2021},
  organization={IEEE}
}
```
"""This package includes all the modules related to data loading and preprocessing

To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
You need to implement four functions:
    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>:                       return the size of dataset.
    -- <__getitem__>:                   get a data point from data loader.
    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.

Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import importlib
import torch.utils.data
from data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name):
    """Import the module "data/[dataset_name]_dataset.py".

    In the file, the class called DatasetNameDataset() will
    be looked up and returned. It has to be a subclass of BaseDataset,
    and the name match is case-insensitive.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        if name.lower() == target_dataset_name.lower() \
           and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

    return dataset


def get_option_setter(dataset_name):
    """Return the static method <modify_commandline_options> of the dataset class."""
    dataset_class = find_dataset_using_name(dataset_name)
    return dataset_class.modify_commandline_options


def create_dataset(opt):
    """Create a dataset given the option.

    This function wraps the class CustomDatasetDataLoader.
        This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    data_loader = CustomDatasetDataLoader(opt)
    dataset = data_loader.load_data()
    return dataset



class CustomDatasetDataLoader():
    """Wrapper class of Dataset class that performs multi-threaded data loading"""

    def __init__(self, opt):
        """Initialize this class

        Step 1: create a dataset instance given the name [dataset_mode]
        Step 2: create a multi-threaded data loader.
        """
        self.opt = opt
        dataset_class = find_dataset_using_name(opt.dataset_mode)
        self.dataset = dataset_class(opt)
        print("dataset [%s] was created" % type(self.dataset).__name__)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.num_threads))

    def load_data(self):
        return self

    def __len__(self):
        """Return the number of data in the dataset"""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Return a batch of data"""
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data
import os.path
from data.base_dataset import BaseDataset, get_params, get_transform
from data.image_folder import make_dataset
from PIL import Image


class AlignedDataset(BaseDataset):
    """A dataset class for paired image dataset.

    It assumes that the directory '/path/to/data/train' contains image pairs in the form of {A,B}.
    During test time, you need to prepare a directory '/path/to/data/test'.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)
        self.dir_AB = os.path.join(opt.dataroot, opt.phase)  # get the image directory
        self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size))  # get image paths
        assert(self.opt.load_size >= self.opt.crop_size)   # crop_size should be smaller than the size of loaded image
        self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
        self.output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index - - a random integer for data indexing

        Returns a dictionary that contains A, B, A_paths and B_paths
            A (tensor) - - an image in the input domain
            B (tensor) - - its corresponding image in the target domain
            A_paths (str) - - image paths
            B_paths (str) - - image paths (same as A_paths)
        """
        # read an image given a random integer index
        AB_path = self.AB_paths[index]
        AB = Image.open(AB_path).convert('RGB')
        # split AB image into A and B
        w, h = AB.size
        w2 = int(w / 2)
        A = AB.crop((0, 0, w2, h))
        B = AB.crop((w2, 0, w, h))

        # apply the same transform to both A and B
        transform_params = get_params(self.opt, A.size)
        A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1))
        B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1))

        A = A_transform(A)
        B = B_transform(B)

        return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.AB_paths)
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.

It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
"""
import random
import numpy as np
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
from torch import nn
from torch.nn import functional as F
import torch
from abc import ABC, abstractmethod


class BaseDataset(data.Dataset, ABC):
    """This class is an abstract base class (ABC) for datasets.

    To create a subclass, you need to implement the following four functions:
    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>:                       return the size of dataset.
    -- <__getitem__>:                   get a data point.
    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the class; save the options in the class

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        self.opt = opt
        self.root = opt.dataroot

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def __len__(self):
        """Return the total number of images in the dataset."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index - - a random integer for data indexing

        Returns:
            a dictionary of data with their names. It usually contains the data itself and its metadata information.
        """
        pass


def get_params(opt, size, is_source=True):
    w, h = size
    new_h = h
    new_w = w
    if opt.preprocess == 'resize_and_crop':
        if opt.source_size_count == 1:
            new_h = new_w = opt.load_size
        else:
            if not opt.isTrain:
                new_h = new_w = opt.load_size
            else:
                new_h = new_w = random.choice([286, 306, 326, 346])
            # new_h = new_w = random.choice([opt.load_source_size, opt.load_target_size])
    elif opt.preprocess == 'scale_width_and_crop':
        new_w = opt.load_size
        new_h = opt.load_size * h // w

    x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
    y = random.randint(0, np.maximum(0, new_h - opt.crop_size))

    flip = random.random() > 0.5
    flip_vertical = random.random() > 0.5

    return {'load_size': new_h, 'crop_pos': (x, y), 'flip': flip, 'flip_vertical': flip_vertical}


def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    if 'resize' in opt.preprocess:
        if params is None:
            osize = [opt.load_size, opt.load_size]
        else:
            load_size = params['load_size']
            osize = [load_size, load_size]
        transform_list.append(transforms.Resize(osize, method))
    elif 'scale_width' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))

    if 'crop' in opt.preprocess:
        if params is None:
            transform_list.append(transforms.RandomCrop(opt.crop_size))
        else:
            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))

    if opt.preprocess == 'none':
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))

    if not opt.no_flip:
        if params is None:
            transform_list.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    # add a vertical (top-bottom) flip
    if not opt.no_flip:
        if params is None:
            transform_list.append(transforms.RandomVerticalFlip())
        elif params['flip_vertical']:
            transform_list.append(transforms.Lambda(lambda img: __flip_vertical(img, params['flip_vertical'])))

    if convert:
        transform_list += [transforms.ToTensor()]
        if grayscale:
            transform_list += [transforms.Normalize((0.5,), (0.5,))]
        else:
            transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def get_transform_six_channel(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    transform_list = []
    mask_transform_list = []
    if 'resize' in opt.preprocess:
        if params is None:
            osize = [opt.load_size, opt.load_size]
        else:
            load_size = params['load_size']
            osize = [load_size, load_size]
        transform_list.append(transforms.Resize(osize, method))
        mask_transform_list.append(transforms.Resize(osize, method))
    elif 'scale_width' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))
        mask_transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))

    if 'crop' in opt.preprocess:
        if params is None:
            transform_list.append(transforms.RandomCrop(opt.crop_size))
            mask_transform_list.append(transforms.RandomCrop(opt.crop_size))
        else:
            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
            mask_transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))

    if opt.preprocess == 'none':
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
        mask_transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))

    if not opt.no_flip:
        if params is None:
            transform_list.append(transforms.RandomHorizontalFlip())
            mask_transform_list.append(transforms.RandomHorizontalFlip())

        elif params['flip']:
            transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
            mask_transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    # add a vertical (top-bottom) flip
    if not opt.no_flip:
        if params is None:
            transform_list.append(transforms.RandomVerticalFlip())
            mask_transform_list.append(transforms.RandomVerticalFlip())
        elif params['flip_vertical']:
            transform_list.append(transforms.Lambda(lambda img: __flip_vertical(img, params['flip_vertical'])))
            mask_transform_list.append(transforms.Lambda(lambda img: __flip_vertical(img, params['flip_vertical'])))
    if convert:
        transform_list += [transforms.ToTensor()]
        mask_transform_list += [transforms.ToTensor()]
        if grayscale:
            transform_list += [transforms.Normalize((0.5,), (0.5,))]
        else:
            transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]

    return transforms.Compose(transform_list), transforms.Compose(mask_transform_list)


def get_gray_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    transform_list = []
    gray_transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))

    if 'resize' in opt.preprocess:
        if params is None:
            osize = [opt.load_size, opt.load_size]
        else:
            load_size = params['load_size']
            osize = [load_size, load_size]
        transform_list.append(transforms.Resize(osize, method))
        gray_transform_list.append(transforms.Resize(osize, method))
    elif 'scale_width' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))

    if 'crop' in opt.preprocess:
        if params is None:
            transform_list.append(transforms.RandomCrop(opt.crop_size))
        else:
            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))

    if opt.preprocess == 'none':
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))

    if not opt.no_flip:
        if params is None:
            transform_list.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    # add a vertical (top-bottom) flip
    if not opt.no_flip:
        if params is None:
            transform_list.append(transforms.RandomVerticalFlip())
        elif params['flip_vertical']:
            transform_list.append(transforms.Lambda(lambda img: __flip_vertical(img, params['flip_vertical'])))
    gray_transform_list.append(transforms.Grayscale(1))
    gray_transform_list += transform_list
    if convert:
        transform_list += [transforms.ToTensor()]
        gray_transform_list += [transforms.ToTensor()]
        if grayscale:
            transform_list += [transforms.Normalize((0.5,), (0.5,))]
            # gray_transform_list += [transforms.Normalize((0.5,), (0.5,))]
        else:
            transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
            # gray_transform_list += [transforms.Normalize((0.5,), (0.5,))]
    return transforms.Compose(transform_list), transforms.Compose(gray_transform_list)


class TensorToGrayTensor(nn.Module):
    def __init__(self, device, R_rate=0.299, G_rate=0.587, B_rate=0.114):
        super(TensorToGrayTensor, self).__init__()
        self.kernel = torch.empty(size=(1, 3, 1, 1), dtype=torch.float32, device=device)
        self.kernel.requires_grad = False
        self.kernel[0, 0, 0, 0] = R_rate
        self.kernel[0, 1, 0, 0] = G_rate
        self.kernel[0, 2, 0, 0] = B_rate

    def forward(self, x):
        output = F.conv2d(x, self.kernel)
        return output


def __make_power_2(img, base, method=Image.BICUBIC):
    ow, oh = img.size
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if h == oh and w == ow:
        return img

    __print_size_warning(ow, oh, w, h)
    return img.resize((w, h), method)


def __scale_width(img, target_size, crop_size, method=Image.BICUBIC):
    ow, oh = img.size
    if ow == target_size and oh >= crop_size:
        return img
    w = target_size
    h = int(max(target_size * oh / ow, crop_size))
    return img.resize((w, h), method)


def __crop(img, pos, size):
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    if (ow > tw or oh > th):
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img


def __flip(img, flip):
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img


def __flip_vertical(img, flip):
    if flip:
        return img.transpose(Image.FLIP_TOP_BOTTOM)
    return img


def __print_size_warning(ow, oh, w, h):
    """Print warning information about the image size (printed only once)"""
    if not hasattr(__print_size_warning, 'has_printed'):
        print("The image size needs to be a multiple of 4. "
              "The loaded image size was (%d, %d), so it was adjusted to "
              "(%d, %d). This adjustment will be done to all images "
              "whose sizes are not multiples of 4" % (ow, oh, w, h))
        __print_size_warning.has_printed = True