Commit 21642000 authored by Shondoit's avatar Shondoit
Browse files

Add PNG alpha channel as weight maps to data entries

parent c4bfd20f
Loading
Loading
Loading
Loading
+38 −13
Original line number Diff line number Diff line
@@ -19,9 +19,10 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*")


class DatasetEntry:
    def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
    def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
        self.filename = filename
        self.filename_text = filename_text
        self.weight = weight
        self.latent_dist = latent_dist
        self.latent_sample = latent_sample
        self.cond = cond
@@ -56,10 +57,16 @@ class PersonalizedBase(Dataset):

        print("Preparing dataset...")
        for path in tqdm.tqdm(self.image_paths):
            alpha_channel = None
            if shared.state.interrupted:
                raise Exception("interrupted")
            try:
                image = Image.open(path).convert('RGB')
                image = Image.open(path)
                #Currently does not work for single color transparency
                #We would need to read image.info['transparency'] for that
                if 'A' in image.getbands():
                    alpha_channel = image.getchannel('A')
                image = image.convert('RGB')
                if not varsize:
                    image = image.resize((width, height), PIL.Image.BICUBIC)
            except Exception:
@@ -87,17 +94,33 @@ class PersonalizedBase(Dataset):
            with devices.autocast():
                latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))

            if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
                latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
                latent_sampling_method = "once"
                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
            elif latent_sampling_method == "deterministic":
            #Perform latent sampling, even for random sampling.
            #We need the sample dimensions for the weights
            if latent_sampling_method == "deterministic":
                if isinstance(latent_dist, DiagonalGaussianDistribution):
                    # Works only for DiagonalGaussianDistribution
                    latent_dist.std = 0
                else:
                    latent_sampling_method = "once"
            latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
            elif latent_sampling_method == "random":
                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)

            if alpha_channel is not None:
                channels, *latent_size = latent_sample.shape
                weight_img = alpha_channel.resize(latent_size)
                npweight = np.array(weight_img).astype(np.float32)
                #Repeat for every channel in the latent sample
                weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
                #Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default.
                weight -= weight.min()
                weight /= weight.mean()
            else:
                #If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later
                weight = torch.ones([channels] + latent_size)
            
            if latent_sampling_method == "random":
                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
            else:
                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)

            if not (self.tag_drop_out != 0 or self.shuffle_tags):
                entry.cond_text = self.create_text(filename_text)
@@ -110,6 +133,7 @@ class PersonalizedBase(Dataset):
            del torchdata
            del latent_dist
            del latent_sample
            del weight

        self.length = len(self.dataset)
        self.groups = list(groups.values())
@@ -195,6 +219,7 @@ class BatchLoader:
        self.cond_text = [entry.cond_text for entry in data]
        self.cond = [entry.cond for entry in data]
        self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
        self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
        #self.emb_index = [entry.emb_index for entry in data]
        #print(self.latent_sample.device)