Unverified commit 5abecea3, authored by AUTOMATIC1111, committed by GitHub

Merge pull request #10259 from AUTOMATIC1111/ruff

Ruff
parents f5ea1e9d 3ec7b705
.github/workflows/on_pull_request.yaml  +25 −18
@@ -18,22 +18,29 @@ jobs:
     steps:
       - name: Checkout Code
         uses: actions/checkout@v3
-      - name: Set up Python 3.10
-        uses: actions/setup-python@v4
+      - uses: actions/setup-python@v4
         with:
-          python-version: 3.10.6
-          cache: pip
-          cache-dependency-path: |
-            **/requirements*txt
-      - name: Install PyLint
-        run: |
-          python -m pip install --upgrade pip
-          pip install pylint
-      # This lets PyLint check to see if it can resolve imports
-      - name: Install dependencies
-        run: |
-          export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
-          python launch.py
-      - name: Analysing the code with pylint
-        run: |
-          pylint $(git ls-files '*.py')
+          python-version: 3.11
+          # NB: there's no cache: pip here since we're not installing anything
+          #     from the requirements.txt file(s) in the repository; it's faster
+          #     not to have GHA download an (at the time of writing) 4 GB cache
+          #     of PyTorch and other dependencies.
+      - name: Install Ruff
+        run: pip install ruff==0.0.265
+      - name: Run Ruff
+        run: ruff .
+
+# The rest are currently disabled pending fixing of e.g. installing the torch dependency.
+
+#      - name: Install PyLint
+#        run: |
+#          python -m pip install --upgrade pip
+#          pip install pylint
+#      # This lets PyLint check to see if it can resolve imports
+#      - name: Install dependencies
+#        run: |
+#          export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
+#          python launch.py
+#      - name: Analysing the code with pylint
+#        run: |
+#          pylint $(git ls-files '*.py')
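The replacement job pins ruff==0.0.265 and lints the whole tree with a bare "ruff ." invocation. Unlike the PyLint job it replaces, Ruff does not need the project's dependencies to be importable, which is why the pip cache and its multi-gigabyte PyTorch download can be dropped. Every change in the remaining files fixes a Ruff finding. As a rough sketch, a file like the one below trips the same rule families; the snippet is illustrative, and the bugbear/comprehensions codes (B006, C408) assume those rule sets are enabled in the project's Ruff configuration:

import os                  # F401: 'os' imported but unused

def load_all(paths=[]):    # B006: mutable default argument
    handles = list()       # C408: prefer the [] literal
    try:
        for p in paths:
            handles.append(open(p))
    except:                # E722: bare 'except'
        pass
    return handles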
extensions-builtin/LDSR/ldsr_model_arch.py  +4 −5
@@ -88,7 +88,7 @@ class LDSR:
 
         x_t = None
         logs = None
-        for n in range(n_runs):
+        for _ in range(n_runs):
             if custom_shape is not None:
                 x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
                 x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
@@ -110,7 +110,6 @@ class LDSR:
         diffusion_steps = int(steps)
         eta = 1.0
 
-        down_sample_method = 'Lanczos'
 
         gc.collect()
         if torch.cuda.is_available:
@@ -158,7 +157,7 @@ class LDSR:
 
 
 def get_cond(selected_path):
-    example = dict()
+    example = {}
     up_f = 4
     c = selected_path.convert('RGB')
     c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
@@ -196,7 +195,7 @@ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_s
 @torch.no_grad()
 def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
                               corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
-    log = dict()
+    log = {}
 
     z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
                                         return_first_stage_outputs=True,
@@ -244,7 +243,7 @@ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize
         x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
         log["sample_noquant"] = x_sample_noquant
         log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
-    except:
+    except Exception:
         pass
 
     log["sample"] = x_sample
extensions-builtin/LDSR/scripts/ldsr_model.py  +2 −1
@@ -7,7 +7,8 @@ from basicsr.utils.download_util import load_file_from_url
 from modules.upscaler import Upscaler, UpscalerData
 from ldsr_model_arch import LDSR
 from modules import shared, script_callbacks
-import sd_hijack_autoencoder, sd_hijack_ddpm_v1
+import sd_hijack_autoencoder  # noqa: F401
+import sd_hijack_ddpm_v1  # noqa: F401
 
 
 class UpscalerLDSR(Upscaler):
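The import split is needed because these modules are imported purely for their import-time side effects: each one patches classes back into ldm, as the next two files show. Since the names are never referenced, Ruff's F401 (imported but unused) would flag them, and a noqa comment only covers its own line. The pattern in miniature, with hypothetical module names:

# registry.py
CODECS = {}

# plugin_raw.py -- importing this module registers a codec as a side effect
import registry
registry.CODECS["raw"] = bytes

# app.py -- 'plugin_raw' is never referenced by name, so F401 must be silenced
import plugin_raw  # noqa: F401
import registry

assert "raw" in registry.CODECS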
extensions-builtin/LDSR/sd_hijack_autoencoder.py  +17 −11
@@ -1,14 +1,19 @@
 # The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
 # The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
 # As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
-
+import numpy as np
 import torch
 import pytorch_lightning as pl
 import torch.nn.functional as F
 from contextlib import contextmanager
+
+from torch.optim.lr_scheduler import LambdaLR
+
+from ldm.modules.ema import LitEma
 from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
 from ldm.modules.diffusionmodules.model import Encoder, Decoder
 from ldm.util import instantiate_from_config
 
 import ldm.models.autoencoder
+from packaging import version
 
 class VQModel(pl.LightningModule):
@@ -19,7 +24,7 @@ class VQModel(pl.LightningModule):
                  n_embed,
                  embed_dim,
                  ckpt_path=None,
-                 ignore_keys=[],
+                 ignore_keys=None,
                  image_key="image",
                  colorize_nlabels=None,
                  monitor=None,
@@ -57,7 +62,7 @@ class VQModel(pl.LightningModule):
             print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
 
         if ckpt_path is not None:
-            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
         self.scheduler_config = scheduler_config
         self.lr_g_factor = lr_g_factor
 
@@ -76,11 +81,11 @@ class VQModel(pl.LightningModule):
                 if context is not None:
                     print(f"{context}: Restored training weights")
 
-    def init_from_ckpt(self, path, ignore_keys=list()):
+    def init_from_ckpt(self, path, ignore_keys=None):
         sd = torch.load(path, map_location="cpu")["state_dict"]
         keys = list(sd.keys())
         for k in keys:
-            for ik in ignore_keys:
+            for ik in ignore_keys or []:
                 if k.startswith(ik):
                     print("Deleting key {} from state_dict.".format(k))
                     del sd[k]
@@ -165,7 +170,7 @@ class VQModel(pl.LightningModule):
     def validation_step(self, batch, batch_idx):
         log_dict = self._validation_step(batch, batch_idx)
         with self.ema_scope():
-            log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
+            self._validation_step(batch, batch_idx, suffix="_ema")
         return log_dict
 
     def _validation_step(self, batch, batch_idx, suffix=""):
@@ -232,7 +237,7 @@ class VQModel(pl.LightningModule):
         return self.decoder.conv_out.weight
 
     def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
-        log = dict()
+        log = {}
         x = self.get_input(batch, self.image_key)
         x = x.to(self.device)
         if only_inputs:
@@ -249,7 +254,8 @@ class VQModel(pl.LightningModule):
         if plot_ema:
             with self.ema_scope():
                 xrec_ema, _ = self(x)
-                if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
+                if x.shape[1] > 3:
+                    xrec_ema = self.to_rgb(xrec_ema)
                 log["reconstructions_ema"] = xrec_ema
         return log
 
@@ -264,7 +270,7 @@
 
 class VQModelInterface(VQModel):
     def __init__(self, embed_dim, *args, **kwargs):
-        super().__init__(embed_dim=embed_dim, *args, **kwargs)
+        super().__init__(*args, embed_dim=embed_dim, **kwargs)
         self.embed_dim = embed_dim
 
     def encode(self, x):
@@ -282,5 +288,5 @@ class VQModelInterface(VQModel):
         dec = self.decoder(quant)
         return dec
 
-setattr(ldm.models.autoencoder, "VQModel", VQModel)
-setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
+ldm.models.autoencoder.VQModel = VQModel
+ldm.models.autoencoder.VQModelInterface = VQModelInterface
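The ignore_keys=[] to ignore_keys=None rewrites here (and in the next file) address B006, Python's shared-mutable-default pitfall: a default [] or list() is created once, when the def statement runs, so any call that mutates it leaks state into later calls. These functions only iterate over ignore_keys, so the rule is precautionary in this instance, and the "ignore_keys or []" at each use site preserves the old empty-list default. The pitfall itself, in miniature:

def collect_bad(key, seen=[]):     # one list created at def time, shared by all calls
    seen.append(key)
    return seen

def collect_good(key, seen=None):
    seen = seen if seen is not None else []   # fresh list on each call
    seen.append(key)
    return seen

print(collect_bad("a"), collect_bad("b"))     # ['a', 'b'] ['a', 'b'] -- same object
print(collect_good("a"), collect_good("b"))   # ['a'] ['b']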
extensions-builtin/LDSR/sd_hijack_ddpm_v1.py  +27 −33
@@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule):
                  beta_schedule="linear",
                  loss_type="l2",
                  ckpt_path=None,
-                 ignore_keys=[],
+                 ignore_keys=None,
                  load_only_unet=False,
                  monitor="val/loss",
                  use_ema=True,
@@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule):
         if monitor is not None:
             self.monitor = monitor
         if ckpt_path is not None:
-            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
 
         self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
                                linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
@@ -182,13 +182,13 @@
                 if context is not None:
                     print(f"{context}: Restored training weights")
 
-    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+    def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
         sd = torch.load(path, map_location="cpu")
         if "state_dict" in list(sd.keys()):
             sd = sd["state_dict"]
         keys = list(sd.keys())
         for k in keys:
-            for ik in ignore_keys:
+            for ik in ignore_keys or []:
                 if k.startswith(ik):
                     print("Deleting key {} from state_dict.".format(k))
                     del sd[k]
@@ -375,7 +375,7 @@
 
     @torch.no_grad()
     def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
-        log = dict()
+        log = {}
         x = self.get_input(batch, self.first_stage_key)
         N = min(x.shape[0], N)
         n_row = min(x.shape[0], n_row)
@@ -383,7 +383,7 @@
         log["inputs"] = x
 
         # get diffusion row
-        diffusion_row = list()
+        diffusion_row = []
         x_start = x[:n_row]
 
         for t in range(self.num_timesteps):
@@ -444,13 +444,13 @@ class LatentDiffusionV1(DDPMV1):
             conditioning_key = None
         ckpt_path = kwargs.pop("ckpt_path", None)
         ignore_keys = kwargs.pop("ignore_keys", [])
-        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+        super().__init__(*args, conditioning_key=conditioning_key, **kwargs)
         self.concat_mode = concat_mode
         self.cond_stage_trainable = cond_stage_trainable
         self.cond_stage_key = cond_stage_key
         try:
             self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
-        except:
+        except Exception:
             self.num_downs = 0
         if not scale_by_std:
             self.scale_factor = scale_factor
@@ -877,16 +877,6 @@ class LatentDiffusionV1(DDPMV1):
                 c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
         return self.p_losses(x, c, t, *args, **kwargs)
 
-    def _rescale_annotations(self, bboxes, crop_coordinates):  # TODO: move to dataset
-        def rescale_bbox(bbox):
-            x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
-            y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
-            w = min(bbox[2] / crop_coordinates[2], 1 - x0)
-            h = min(bbox[3] / crop_coordinates[3], 1 - y0)
-            return x0, y0, w, h
-
-        return [rescale_bbox(b) for b in bboxes]
-
     def apply_model(self, x_noisy, t, cond, return_ids=False):
 
         if isinstance(cond, dict):
@@ -1126,7 +1116,7 @@ class LatentDiffusionV1(DDPMV1):
         if cond is not None:
             if isinstance(cond, dict):
                 cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
-                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+                [x[:batch_size] for x in cond[key]] for key in cond}
             else:
                 cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
 
@@ -1157,8 +1147,10 @@
 
             if i % log_every_t == 0 or i == timesteps - 1:
                 intermediates.append(x0_partial)
-            if callback: callback(i)
-            if img_callback: img_callback(img, i)
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(img, i)
         return img, intermediates
 
     @torch.no_grad()
@@ -1205,8 +1197,10 @@
 
             if i % log_every_t == 0 or i == timesteps - 1:
                 intermediates.append(img)
-            if callback: callback(i)
-            if img_callback: img_callback(img, i)
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(img, i)
 
         if return_intermediates:
             return img, intermediates
@@ -1221,7 +1215,7 @@ class LatentDiffusionV1(DDPMV1):
         if cond is not None:
            if isinstance(cond, dict):
                 cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
-                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+                [x[:batch_size] for x in cond[key]] for key in cond}
            else:
                 cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
         return self.p_sample_loop(cond,
@@ -1253,7 +1247,7 @@
 
         use_ddim = ddim_steps is not None
 
-        log = dict()
+        log = {}
         z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
                                            return_first_stage_outputs=True,
                                            force_c_encode=True,
@@ -1280,7 +1274,7 @@ class LatentDiffusionV1(DDPMV1):
 
         if plot_diffusion_rows:
             # get diffusion row
-            diffusion_row = list()
+            diffusion_row = []
             z_start = z[:n_row]
             for t in range(self.num_timesteps):
                 if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
@@ -1322,7 +1316,7 @@ class LatentDiffusionV1(DDPMV1):
 
             if inpaint:
                 # make a simple center square
-                b, h, w = z.shape[0], z.shape[2], z.shape[3]
+                h, w = z.shape[2], z.shape[3]
                 mask = torch.ones(N, h, w).to(self.device)
                 # zeros will be filled in
                 mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
@@ -1424,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
     # TODO: move all layout-specific hacks to this class
     def __init__(self, cond_stage_key, *args, **kwargs):
         assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
-        super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
+        super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
 
     def log_images(self, batch, N=8, *args, **kwargs):
-        logs = super().log_images(batch=batch, N=N, *args, **kwargs)
+        logs = super().log_images(*args, batch=batch, N=N, **kwargs)
 
         key = 'train' if self.training else 'validation'
         dset = self.trainer.datamodule.datasets[key]
@@ -1443,7 +1437,7 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
         logs['bbox_image'] = cond_img
         return logs
 
-setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1)
-setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1)
-setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1)
-setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1)
+ldm.models.diffusion.ddpm.DDPMV1 = DDPMV1
+ldm.models.diffusion.ddpm.LatentDiffusionV1 = LatentDiffusionV1
+ldm.models.diffusion.ddpm.DiffusionWrapperV1 = DiffusionWrapperV1
+ldm.models.diffusion.ddpm.Layout2ImgDiffusionV1 = Layout2ImgDiffusionV1
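Two mechanical rewrites recur in this file. Moving the keyword after the star-args in super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs) is B026 (star-arg unpacking after a keyword argument): Python binds the unpacked positionals before the keyword regardless of how the call is spelled, so the old order misleads the reader and invites "got multiple values for argument" errors. Replacing setattr(module, "Name", value) with plain attribute assignment is B010, since setattr with a constant string is just a slower spelling of module.Name = value. A small demonstration of the B026 point:

def greet(greeting, name):
    return f"{greeting}, {name}"

args = ("hi",)

# Misleading: name="Bob" appears first, but the unpacked positionals bind first.
print(greet(name="Bob", *args))   # hi, Bob

# Same call, with the actual evaluation order visible in the source:
print(greet(*args, name="Bob"))   # hi, Bob

# Either form raises TypeError if the positionals already fill 'name':
# greet(name="Bob", *("hi", "Ann"))  ->  multiple values for argument 'name'

The file's other changes are of the same character: list(map(lambda ...)) becomes a list comprehension (C417), and one-line "if callback: callback(i)" statements are split across two lines (pycodestyle's multiple-statements-on-one-line check).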