Unverified Commit 486bda9b authored by AUTOMATIC1111, committed by GitHub

Merge pull request #6620 from guaneec/varsize_batch

Enable batch_size>1 for mixed-sized training
parents d7aec59c 6be644fa
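
Stacking images into a single tensor requires every image in a batch to share the same dimensions, which is why training on mixed-size images previously forced batch_size=1. This change lifts that restriction: while preparing the dataset, PersonalizedBase records entry indices grouped by image size, and a new GroupedBatchSampler builds each batch entirely from one size group, distributing batches across the groups in proportion to their sizes. A worked example of the allocation arithmetic follows the diff below.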
+32 −4
@@ -3,8 +3,10 @@ import numpy as np
import PIL
import torch
from PIL import Image
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
+from collections import defaultdict
+from random import shuffle, choices

import random
import tqdm
@@ -45,12 +47,12 @@ class PersonalizedBase(Dataset):
        assert data_root, 'dataset directory not specified'
        assert os.path.isdir(data_root), "Dataset directory doesn't exist"
        assert os.listdir(data_root), "Dataset directory is empty"
-        assert batch_size == 1 or not varsize, 'variable img size must have batch size 1'

        self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]

        self.shuffle_tags = shuffle_tags
        self.tag_drop_out = tag_drop_out
+        groups = defaultdict(list)

        print("Preparing dataset...")
        for path in tqdm.tqdm(self.image_paths):
@@ -103,13 +105,14 @@ class PersonalizedBase(Dataset):
            if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
                with devices.autocast():
                    entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)

+            groups[image.size].append(len(self.dataset))
            self.dataset.append(entry)
            del torchdata
            del latent_dist
            del latent_sample

        self.length = len(self.dataset)
+        self.groups = list(groups.values())
        assert self.length > 0, "No images have been found in the dataset."
        self.batch_size = min(batch_size, self.length)
        self.gradient_step = min(gradient_step, self.length // self.batch_size)
@@ -137,9 +140,34 @@ class PersonalizedBase(Dataset):
            entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
        return entry

+class GroupedBatchSampler(Sampler):
+    def __init__(self, data_source: PersonalizedBase, batch_size: int):
+        n = len(data_source)
+        self.groups = data_source.groups
+        self.len = n_batch = n // batch_size
+        # expected number of images drawn from each size group under an
+        # exactly proportional allocation
+        expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
+        # full batches guaranteed to each group
+        self.base = [int(e) // batch_size for e in expected]
+        # the remaining batches go to groups at random, weighted by each
+        # group's leftover (sub-batch) share
+        self.n_rand_batches = nrb = n_batch - sum(self.base)
+        self.probs = [e % batch_size / nrb / batch_size if nrb > 0 else 0 for e in expected]
+        self.batch_size = batch_size
+
+    def __len__(self):
+        return self.len
+
+    def __iter__(self):
+        b = self.batch_size
+        for g in self.groups:
+            shuffle(g)
+        # guaranteed batches: consecutive slices of each shuffled group
+        batches = []
+        for g in self.groups:
+            batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
+        # leftover batches: pick a group by weight, then sample it with replacement
+        for _ in range(self.n_rand_batches):
+            rand_group = choices(self.groups, self.probs)[0]
+            batches.append(choices(rand_group, k=b))
+        shuffle(batches)
+        yield from batches

class PersonalizedDataLoader(DataLoader):
    def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
-        super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
+        super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
        if latent_sampling_method == "random":
            self.collate_fn = collate_wrapper_random
        else:
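
To sanity-check the allocation arithmetic, here is a minimal standalone sketch. The numbers (ten 512x512 images and six 640x448 images with batch_size=4) are hypothetical, and DummyDataset is an illustrative stand-in for PersonalizedBase, not part of the PR; GroupedBatchSampler and its imports are assumed available from the code added above.

from collections import defaultdict

class DummyDataset:
    # hypothetical stand-in for PersonalizedBase: exposes only what
    # GroupedBatchSampler actually uses, len() and .groups
    def __init__(self, sizes):
        groups = defaultdict(list)
        for index, size in enumerate(sizes):
            groups[size].append(index)
        self.groups = list(groups.values())
        self.length = len(sizes)

    def __len__(self):
        return self.length

# 10 images at 512x512 and 6 at 640x448 (illustrative numbers)
ds = DummyDataset([(512, 512)] * 10 + [(640, 448)] * 6)
sampler = GroupedBatchSampler(ds, batch_size=4)
# n=16, n_batch=4, expected=[10.0, 6.0], base=[2, 1],
# n_rand_batches=1, probs=[0.5, 0.5]
for batch in sampler:
    print(batch)  # four indices per batch, all drawn from one size group

Each iteration yields lists of indices, which is exactly what DataLoader's batch_sampler argument consumes; that is also why PersonalizedDataLoader no longer passes shuffle, drop_last, or batch_size, since those options are mutually exclusive with batch_sampler.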