Commit 4d5f1691 authored by brkirch's avatar brkirch
Browse files

Use devices.autocast instead of torch.autocast

parent 21effd62
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -495,7 +495,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
                if shared.state.interrupted:
                    break

                with torch.autocast("cuda"):
                with devices.autocast():
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                    if tag_drop_out != 0 or shuffle_tags:
                        shared.sd_model.cond_stage_model.to(devices.device)
+1 −2
Original line number Diff line number Diff line
@@ -148,8 +148,7 @@ class InterrogateModels:

            clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)

            precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
            with torch.no_grad(), precision_scope("cuda"):
            with torch.no_grad(), devices.autocast():
                image_features = self.clip_model.encode_image(clip_image).type(self.dtype)

                image_features /= image_features.norm(dim=-1, keepdim=True)
+1 −5
Original line number Diff line number Diff line
@@ -13,10 +13,6 @@ from modules.swinir_model_arch import SwinIR as net
from modules.swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData

precision_scope = (
    torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
)


class UpscalerSwinIR(Upscaler):
    def __init__(self, dirname):
@@ -112,7 +108,7 @@ def upscale(
    img = np.moveaxis(img, 2, 0) / 255
    img = torch.from_numpy(img).float()
    img = img.unsqueeze(0).to(devices.device_swinir)
    with torch.no_grad(), precision_scope("cuda"):
    with torch.no_grad(), devices.autocast():
        _, _, h_old, w_old = img.size()
        h_pad = (h_old // window_size + 1) * window_size - h_old
        w_pad = (w_old // window_size + 1) * window_size - w_old
+2 −2
Original line number Diff line number Diff line
@@ -82,7 +82,7 @@ class PersonalizedBase(Dataset):
            torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
            latent_sample = None

            with torch.autocast("cuda"):
            with devices.autocast():
                latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))

            if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
@@ -101,7 +101,7 @@ class PersonalizedBase(Dataset):
                entry.cond_text = self.create_text(filename_text)

            if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
                with torch.autocast("cuda"):
                with devices.autocast():
                    entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)

            self.dataset.append(entry)
+1 −1
Original line number Diff line number Diff line
@@ -316,7 +316,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
                if shared.state.interrupted:
                    break

                with torch.autocast("cuda"):
                with devices.autocast():
                    # c = stack_conds(batch.cond).to(devices.device)
                    # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
                    # print(mask)