Commit 550256db authored by AUTOMATIC's avatar AUTOMATIC
Browse files

ruff manual fixes

parent 028d3f64
Loading
Loading
Loading
Loading
+5 −5
Original line number Diff line number Diff line
@@ -24,7 +24,7 @@ class VQModel(pl.LightningModule):
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 ignore_keys=None,
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
@@ -62,7 +62,7 @@ class VQModel(pl.LightningModule):
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

@@ -81,11 +81,11 @@ class VQModel(pl.LightningModule):
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=list()):
    def init_from_ckpt(self, path, ignore_keys=None):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
            for ik in ignore_keys or []:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
@@ -270,7 +270,7 @@ class VQModel(pl.LightningModule):

class VQModelInterface(VQModel):
    def __init__(self, embed_dim, *args, **kwargs):
        """Store embed_dim and delegate everything else to VQModel.

        The diff rendering left both the old and the new ``super().__init__``
        calls stacked, which would run the parent initializer twice; only the
        corrected call is kept. ``embed_dim`` is passed as a keyword *after*
        ``*args`` so a caller's positional arguments cannot collide with it
        (keyword-before-*args can raise "got multiple values for argument").
        """
        super().__init__(*args, embed_dim=embed_dim, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
+7 −7
Original line number Diff line number Diff line
@@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule):
                 beta_schedule="linear",
                 loss_type="l2",
                 ckpt_path=None,
                 ignore_keys=[],
                 ignore_keys=None,
                 load_only_unet=False,
                 monitor="val/loss",
                 use_ema=True,
@@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule):
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)

        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
@@ -182,13 +182,13 @@ class DDPMV1(pl.LightningModule):
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
    def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
            for ik in ignore_keys or []:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
@@ -444,7 +444,7 @@ class LatentDiffusionV1(DDPMV1):
            conditioning_key = None
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", [])
        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
        super().__init__(*args, conditioning_key=conditioning_key, **kwargs)
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
@@ -1418,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
    # TODO: move all layout-specific hacks to this class
    def __init__(self, cond_stage_key, *args, **kwargs):
        """Latent diffusion specialised to bounding-box layout conditioning.

        Only ``cond_stage_key == 'coordinates_bbox'`` is supported. The diff
        rendering left both the old and the new ``super().__init__`` calls
        stacked (parent init would run twice); only the corrected call — with
        the keyword placed after ``*args`` — is kept.
        """
        assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
        super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)

    def log_images(self, batch, N=8, *args, **kwargs):
        logs = super().log_images(batch=batch, N=N, *args, **kwargs)
        logs = super().log_images(*args, batch=batch, N=N, **kwargs)

        key = 'train' if self.training else 'validation'
        dset = self.trainer.datamodule.datasets[key]
+5 −1
Original line number Diff line number Diff line
@@ -644,13 +644,17 @@ class SwinIR(nn.Module):
    """

    def __init__(self, img_size=64, patch_size=1, in_chans=3,
                 embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
                 embed_dim=96, depths=None, num_heads=None,
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
                 **kwargs):
        super(SwinIR, self).__init__()

        depths = depths or [6, 6, 6, 6]
        num_heads = num_heads or [6, 6, 6, 6]

        num_in_ch = in_chans
        num_out_ch = in_chans
        num_feat = 64
+9 −2
Original line number Diff line number Diff line
@@ -74,9 +74,12 @@ class WindowAttention(nn.Module):
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
                 pretrained_window_size=[0, 0]):
                 pretrained_window_size=None):

        super().__init__()

        pretrained_window_size = pretrained_window_size or [0, 0]

        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.pretrained_window_size = pretrained_window_size
@@ -698,13 +701,17 @@ class Swin2SR(nn.Module):
    """

    def __init__(self, img_size=64, patch_size=1, in_chans=3,
                 embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
                 embed_dim=96, depths=None, num_heads=None,
                 window_size=7, mlp_ratio=4., qkv_bias=True, 
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
                 **kwargs):
        super(Swin2SR, self).__init__()

        depths = depths or [6, 6, 6, 6]
        num_heads = num_heads or [6, 6, 6, 6]

        num_in_ch = in_chans
        num_out_ch = in_chans
        num_feat = 64
+12 −6
Original line number Diff line number Diff line
@@ -34,14 +34,16 @@ import piexif.helper
def upscaler_to_index(name: str):
    """Return the index of the upscaler called *name* (case-insensitive) in shared.sd_upscalers.

    Raises:
        HTTPException: 400 when no registered upscaler matches *name*.
    """
    try:
        return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
    except Exception as e:
        # Chain the original exception so the root cause survives in the traceback.
        # (The diff rendering had stacked an old un-chained handler above this
        # one, shadowing it; only the chained version is kept.)
        raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e


def script_name_to_index(name, scripts):
    """Return the index in *scripts* of the script whose title matches *name* (case-insensitive).

    Raises:
        HTTPException: 422 when no script in *scripts* has that title.
    """
    try:
        return [script.title().lower() for script in scripts].index(name.lower())
    except Exception as e:
        # Chain the original exception; the diff rendering had stacked an old
        # un-chained handler above this one, making it unreachable.
        raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e


def validate_sampler_name(name):
    config = sd_samplers.all_samplers_map.get(name, None)
@@ -50,20 +52,23 @@ def validate_sampler_name(name):

    return name


def setUpscalers(req: dict):
    """Move the request's generic upscaler fields to their 'extras_' names.

    Mutates and returns the request object's attribute dict: the
    'upscaler_1'/'upscaler_2' entries are removed and re-inserted as
    'extras_upscaler_1'/'extras_upscaler_2' (None when absent).
    """
    attrs = vars(req)
    attrs['extras_upscaler_1'] = attrs.pop('upscaler_1', None)
    attrs['extras_upscaler_2'] = attrs.pop('upscaler_2', None)
    return attrs


def decode_base64_to_image(encoding):
    """Decode a base64 string (optionally a full data URL) into a PIL Image.

    Raises:
        HTTPException: 500 when the payload is not a valid base64-encoded image.
    """
    if encoding.startswith("data:image/"):
        # Strip the "data:image/...;base64," prefix of a data URL.
        encoding = encoding.split(";")[1].split(",")[1]
    try:
        image = Image.open(BytesIO(base64.b64decode(encoding)))
        return image
    except Exception as e:
        # Chain the decoding error; the diff rendering had stacked an old
        # un-chained handler above this one, shadowing it.
        raise HTTPException(status_code=500, detail="Invalid encoded image") from e


def encode_pil_to_base64(image):
    with io.BytesIO() as output_bytes:
@@ -94,6 +99,7 @@ def encode_pil_to_base64(image):

    return base64.b64encode(bytes_data)


def api_middleware(app: FastAPI):
    rich_available = True
    try:
Loading