Unverified Commit 62e9fec3 authored by pepe10-gpu's avatar pepe10-gpu Committed by GitHub
Browse files

actual better fix

thanks C43H66N12O12S2
parent 29eff4a1
Loading
Loading
Loading
Loading
+2 −5
Original line number Original line Diff line number Diff line
@@ -39,11 +39,8 @@ def torch_gc():


def enable_tf32():
def enable_tf32():
    if torch.cuda.is_available():
    if torch.cuda.is_available():
        #TODO: make this better; find a way to check if it is a turing card
        turing = ["1630","1650","1660","Quadro RTX 3000","Quadro RTX 4000","Quadro RTX 4000","Quadro RTX 5000","Quadro RTX 5000","Quadro RTX 6000","Quadro RTX 6000","Quadro RTX 8000","Quadro RTX T400","Quadro RTX T400","Quadro RTX T600","Quadro RTX T1000","Quadro RTX T1000","2060","2070","2080","Titan RTX","Tesla T4","MX450","MX550"]
        for devid in range(0,torch.cuda.device_count()):
        for devid in range(0,torch.cuda.device_count()):
            for i in turing:
            if torch.cuda.get_device_capability(devid) == (7, 5):
                if i in torch.cuda.get_device_name(devid):
                shd = True
                shd = True
        if shd:
        if shd:
            torch.backends.cudnn.benchmark = True
            torch.backends.cudnn.benchmark = True