Commit a176d894 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

print bucket sizes for training without resizing images #6620

fix an error when generating a picture with embedding in it
parent 486bda9b
Loading
Loading
Loading
Loading
+16 −0
Original line number Diff line number Diff line
@@ -118,6 +118,12 @@ class PersonalizedBase(Dataset):
        self.gradient_step = min(gradient_step, self.length // self.batch_size)
        self.latent_sampling_method = latent_sampling_method

        if len(groups) > 1:
            print("Buckets:")
            for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
                print(f"  {w}x{h}: {len(ids)}")
            print()

    def create_text(self, filename_text):
        text = random.choice(self.lines)
        tags = filename_text.split(',')
@@ -140,8 +146,11 @@ class PersonalizedBase(Dataset):
            entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
        return entry


class GroupedBatchSampler(Sampler):
    def __init__(self, data_source: PersonalizedBase, batch_size: int):
        super().__init__(data_source)

        n = len(data_source)
        self.groups = data_source.groups
        self.len = n_batch = n // batch_size
@@ -150,21 +159,28 @@ class GroupedBatchSampler(Sampler):
        self.n_rand_batches = nrb = n_batch - sum(self.base)
        self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
        self.batch_size = batch_size

    def __len__(self):
        return self.len

    def __iter__(self):
        b = self.batch_size

        for g in self.groups:
            shuffle(g)

        batches = []
        for g in self.groups:
            batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
        for _ in range(self.n_rand_batches):
            rand_group = choices(self.groups, self.probs)[0]
            batches.append(choices(rand_group, k=b))

        shuffle(batches)

        yield from batches


class PersonalizedDataLoader(DataLoader):
    def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
        super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
+2 −2
Original line number Diff line number Diff line
@@ -76,10 +76,10 @@ def insert_image_data_embed(image, data):
    next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
    next_size = next_size + ((h*d)-(next_size % (h*d)))

    data_np_low.resize(next_size)
    data_np_low = np.resize(data_np_low, next_size)
    data_np_low = data_np_low.reshape((h, -1, d))

    data_np_high.resize(next_size)
    data_np_high = np.resize(data_np_high, next_size)
    data_np_high = data_np_high.reshape((h, -1, d))

    edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
+1 −1

File changed.

Contains only whitespace changes.