Commit f261a4a5 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

use selected device instead of always cuda for UniPC sampler

parent a11ce2b9
Loading
Loading
Loading
Loading
+4 −3
Original line number Diff line number Diff line
@@ -3,7 +3,8 @@
import torch

from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
from modules import shared
from modules import shared, devices


class UniPCSampler(object):
    def __init__(self, model, **kwargs):
@@ -16,8 +17,8 @@ class UniPCSampler(object):

    def register_buffer(self, name, attr):
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
            if attr.device != devices.device:
                attr = attr.to(devices.device)
        setattr(self, name, attr)

    def set_hooks(self, before_sample, after_sample, after_update):