# Created September 23, 2023 01:04
# PyTorch function that performs wrapped casting for functions which require specific dtypes (e.g. FlashAttention-2).
from functools import wraps
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, TypeVar, Union

import torch
# Type variable for the callable handed to ``wrap_to_dtype``, so the wrapper is
# annotated with the same type as the function it wraps. A TypeVar bound must
# be a type: ``Callable[..., Any]`` is one, whereas the builtin ``callable`` is
# a function and is rejected by static type checkers.
T = TypeVar("T", bound=Callable[..., Any])
# Maps a floating-point dtype to the list of dtypes it is considered strictly
# larger than. ``torch.half`` is an alias of ``torch.float16``, so it is covered
# by the ``torch.float16`` entries and needs none of its own. float16 and
# bfloat16 do not list each other: neither is larger than the other.
ref_map = {
    torch.float64: [torch.float16, torch.float32, torch.bfloat16],
    torch.float32: [torch.float16, torch.bfloat16],
    torch.float16: [],
    torch.bfloat16: [],
}
def dtype_larger_than_dtype(dtype, ref_dtype):
    """Return True if ``dtype`` is strictly larger than ``ref_dtype``.

    The ordering is read from ``ref_map``: ``dtype`` is larger when
    ``ref_dtype`` appears in the list stored under ``dtype``.

    Args:
        dtype: The dtype being tested.
        ref_dtype: The dtype it is compared against.

    Returns:
        bool: True when ``ref_map`` lists ``ref_dtype`` under ``dtype``.
        A ``dtype`` with no ``ref_map`` entry is larger than nothing, so the
        result is False instead of a ``KeyError``.
    """
    return ref_dtype in ref_map.get(dtype, ())
def dtype_smaller_than_dtype(dtype, ref_dtype):
    """Return True if ``dtype`` is strictly smaller than ``ref_dtype``.

    The ordering is read from ``ref_map``: ``dtype`` is smaller when it
    appears in the list stored under ``ref_dtype``.

    Args:
        dtype: The dtype being tested.
        ref_dtype: The dtype it is compared against.

    Returns:
        bool: True when ``ref_map`` lists ``dtype`` under ``ref_dtype``.
        A ``ref_dtype`` with no ``ref_map`` entry has nothing smaller than
        it, so the result is False instead of a ``KeyError``.
    """
    return dtype in ref_map.get(ref_dtype, ())
class WrappedTensorDtype:
def __init__(self,dtype:torch.dtype,tensor: Optional[torch.Tensor] = None,float_index_args=None, float_kwarg_idx=None) -> None:
assert dtype.is_floating_point, "WrappedTensorDtype only supports floating point dtypes"
self.is_complex = dtype.is_complex
self.is_floating_point = dtype.is_floating_point
self.is_signed = dtype.is_signed
self.itemsize = dtype.itemsize
self.to_complex = dtype.to_complex
self.to_real = dtype.to_real
self.__repr__ = dtype.__repr__
self.__str__ = dtype.__str__
self.__hash__ = dtype.__hash__
self.__eq__ = dtype.__eq__
self.__ne__ = dtype.__ne__
self.inner_dtype = dtype
self.tensor = tensor
self.float_index_args = float_index_args
self.float_kwarg_idx = float_kwarg_idx
self.is_arg = float_index_args is not None
self.is_kwarg = float_kwarg_idx is not None
def __gt__(self, other):
assert other.is_floating_point and self.is_floating_point
return dtype_larger_than_dtype(self.inner_dtype, other.inner_dtype)
def __lt__(self, other):
assert other.is_floating_point and self.is_floating_point
return dtype_smaller_than_dtype(self.inner_dtype, other.inner_dtype)
def __ge__(self, other):
assert other.is_floating_point and self.is_floating_point
return dtype_larger_than_dtype(self.inner_dtype, other.inner_dtype) or self.inner_dtype == other.inner_dtype
def __le__(self, other):
assert other.is_floating_point and self.is_floating_point
return dtype_smaller_than_dtype(self.inner_dtype, other.inner_dtype) or self.inner_dtype == other.inner_dtype
def from_dtype(dtype) -> "WrappedTensorDtype":
return WrappedTensorDtype(dtype)
def from_tensor(tensor, float_index_args=None, float_kwarg_idx=None) -> "WrappedTensorDtype":
return WrappedTensorDtype(tensor.dtype, tensor, float_index_args, float_kwarg_idx)
def largest_float_of_list(tensor_list: Iterable, default=torch.float32) -> torch.dtype:
    """Return the largest floating-point dtype among the tensors in ``tensor_list``.

    Items that are not ``torch.Tensor`` instances, and tensors that are not
    floating point, are ignored.

    Args:
        tensor_list: Iterable of arbitrary objects.
        default: Value returned when no floating-point tensor is found.

    Returns:
        The dtype selected by ``max()`` over the ``WrappedTensorDtype``
        wrappers of the floating-point tensors, or ``default`` if there are
        none.
    """
    wrapped = [
        WrappedTensorDtype.from_dtype(item.dtype)
        for item in tensor_list
        if isinstance(item, torch.Tensor) and item.is_floating_point()
    ]
    if not wrapped:
        return default
    return max(wrapped).inner_dtype
def largest_float_of_args_kwargs_cast_to(*args, cast_to=None, **kwargs) -> Tuple[torch.dtype, Tuple, Dict[Any, Any]]:
    """Find the largest float dtype among call arguments and optionally cast them.

    Only top-level arguments that are floating-point ``torch.Tensor``
    instances are considered; every other argument is passed through
    untouched.

    Args:
        *args: Positional arguments of the call being prepared.
        cast_to: If not None, every floating-point tensor argument is
            converted with ``Tensor.to(cast_to)``. Because this name is taken
            by the helper, a keyword argument called ``cast_to`` cannot be
            forwarded through it.
        **kwargs: Keyword arguments of the call being prepared.

    Returns:
        A ``(largest_dtype, args, kwargs)`` tuple. ``largest_dtype`` is taken
        from the arguments as they were before any cast; it is
        ``torch.float32`` when there is no floating-point tensor argument, in
        which case ``args`` and ``kwargs`` are returned unchanged.
    """

    def _is_float_tensor(value) -> bool:
        return isinstance(value, torch.Tensor) and value.is_floating_point()

    # Collect dtypes before casting: positional arguments first, then keyword
    # arguments in their insertion order.
    original_dtypes = [arg.dtype for arg in args if _is_float_tensor(arg)]
    original_dtypes += [value.dtype for value in kwargs.values() if _is_float_tensor(value)]
    if not original_dtypes:
        return torch.float32, args, kwargs

    largest = max(WrappedTensorDtype.from_dtype(dtype) for dtype in original_dtypes).inner_dtype

    if cast_to is not None:
        args = tuple(arg.to(cast_to) if _is_float_tensor(arg) else arg for arg in args)
        # Cast by key rather than by position in ``kwargs.values()``.
        kwargs = {
            key: value.to(cast_to) if _is_float_tensor(value) else value
            for key, value in kwargs.items()
        }
    return largest, tuple(args), kwargs
# Accepted string spellings for the ``dtype`` argument of ``wrap_to_dtype``.
_WRAP_DTYPE_ALIASES = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float64": torch.float64,
    "float": torch.float32,
    "double": torch.float64,
    "half": torch.float16,
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp64": torch.float64,
}


def _cast_float_outputs(output, dtype):
    """Cast the floating-point tensors in ``output`` to ``dtype``.

    Handles a single tensor and tensors held in (possibly nested) tuples or
    lists. Non-floating tensors and any other object are returned unchanged.
    """
    if isinstance(output, torch.Tensor):
        return output.to(dtype) if output.is_floating_point() else output
    if isinstance(output, (tuple, list)):
        items = [_cast_float_outputs(item, dtype) for item in output]
        # Named tuples take their fields as separate positional arguments.
        if hasattr(output, "_fields"):
            return type(output)(*items)
        return type(output)(items)
    return output


def wrap_to_dtype(func: T, dtype: Union[str, torch.dtype] = "fp16") -> T:
    """
    Wraps a function to cast all float torch.Tensor arguments to the specified dtype
    and cast the output back to the largest float type of the input.

    Args:
        func: The function to wrap.
        dtype: Target dtype for the float tensor arguments. Either a
            ``torch.dtype`` or one of the string aliases in
            ``_WRAP_DTYPE_ALIASES`` (e.g. ``"fp16"``, ``"bf16"``, ``"float32"``).
            An unrecognised string falls back to ``torch.float16``.

    Returns:
        A wrapper with the metadata of ``func`` (via ``functools.wraps``).
        On each call it casts the float tensor arguments, calls ``func``, and,
        if the largest input float dtype differs from the target dtype, casts
        the float tensors of the result to that largest dtype. When no
        argument is a float tensor the largest dtype is reported as
        ``torch.float32``.
    """
    if isinstance(dtype, torch.dtype):
        # Honour an explicit dtype object instead of treating it as an
        # unknown alias.
        wrap_dtype = dtype
    else:
        wrap_dtype = _WRAP_DTYPE_ALIASES.get(dtype, torch.float16)

    @wraps(func)
    def wrapped(*args, **kwargs):
        largest_dtype, args, kwargs = largest_float_of_args_kwargs_cast_to(*args, cast_to=wrap_dtype, **kwargs)
        o = func(*args, **kwargs)
        if largest_dtype != wrap_dtype:
            o = _cast_float_outputs(o, largest_dtype)
        return o

    return wrapped
