"""Utility functions for working with PyTorch."""

import torch
from typing import Callable, Union


def get_activation(fn: Union[Callable, str]) -> Callable:
  """Get a PyTorch activation function, specified either directly or as a string.

    This function simplifies allowing users to specify activation functions by name.
    If a string is provided, the attribute of that name in torch.nn.functional is
    looked up and returned.  Anything that is not a string is returned unchanged.

    Args:
      fn: either the activation itself, or the name of a function in
        torch.nn.functional (for example "relu").

    Returns:
      The object passed in when `fn` is not a string, otherwise the callable
      found under that name in torch.nn.functional.

    Raises:
      AttributeError: if `fn` is a string and torch.nn.functional has no
        attribute with that name.
      ValueError: if `fn` is a string naming an attribute of
        torch.nn.functional that is not callable.
    """
  if not isinstance(fn, str):
    return fn
  try:
    activation = getattr(torch.nn.functional, fn)
  except AttributeError:
    # Keep the exception type callers may already catch, but say what was asked for.
    raise AttributeError(
        "torch.nn.functional has no activation function named {!r}".format(fn)
    ) from None
  # The module namespace also holds non-function attributes (e.g. its own
  # imports); fail here rather than when the "activation" is first applied.
  if not callable(activation):
    raise ValueError(
        "torch.nn.functional.{} is not callable and cannot be used as an "
        "activation function".format(fn))
  return activation
