Skip to content
Commit 27fe3eb6 authored by brkirch's avatar brkirch
Browse files

Add workaround for MPS layer_norm on PyTorch 2.0

On PyTorch 2.0, with MPS layer_norm only accepts float32 inputs. This was fixed shortly after 2.0 was finalized so the workaround can be applied with an exact version match.
parent c5142e2f
Loading
Loading
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment