@aredden
Created September 23, 2023 01:04
PyTorch utility for wrapping functions that require specific floating-point dtypes (e.g. Flash Attention 2): float tensor arguments are cast to the required dtype and the output is cast back to the largest float dtype found among the inputs.
from functools import wraps
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, TypeVar

import torch

T = TypeVar("T", bound=Callable)
# For each float dtype, the float dtypes it is considered "larger" than.
ref_map = {
    torch.float64: [torch.float16, torch.float32, torch.bfloat16, torch.half],
    torch.float32: [torch.float16, torch.bfloat16, torch.half],
    torch.float16: [],
    torch.bfloat16: [],
    torch.half: [],
}
def dtype_larger_than_dtype(dtype, ref_dtype):
    return ref_dtype in ref_map[dtype]

def dtype_smaller_than_dtype(dtype, ref_dtype):
    return dtype in ref_map[ref_dtype]
class WrappedTensorDtype:
    def __init__(
        self,
        dtype: torch.dtype,
        tensor: Optional[torch.Tensor] = None,
        float_index_args=None,
        float_kwarg_idx=None,
    ) -> None:
        super().__init__()
        assert dtype.is_floating_point, "WrappedTensorDtype only supports floating point dtypes"
        self.is_complex = dtype.is_complex
        self.is_floating_point = dtype.is_floating_point
        self.is_signed = dtype.is_signed
        self.itemsize = dtype.itemsize
        self.to_complex = dtype.to_complex
        self.to_real = dtype.to_real
        self.__repr__ = dtype.__repr__
        self.__str__ = dtype.__str__
        self.__hash__ = dtype.__hash__
        self.__eq__ = dtype.__eq__
        self.__ne__ = dtype.__ne__
        self.inner_dtype = dtype
        self.tensor = tensor
        self.float_index_args = float_index_args
        self.float_kwarg_idx = float_kwarg_idx
        self.is_arg = float_index_args is not None
        self.is_kwarg = float_kwarg_idx is not None
    def __gt__(self, other):
        assert other.is_floating_point and self.is_floating_point
        return dtype_larger_than_dtype(self.inner_dtype, other.inner_dtype)

    def __lt__(self, other):
        assert other.is_floating_point and self.is_floating_point
        return dtype_smaller_than_dtype(self.inner_dtype, other.inner_dtype)

    def __ge__(self, other):
        assert other.is_floating_point and self.is_floating_point
        return dtype_larger_than_dtype(self.inner_dtype, other.inner_dtype) or self.inner_dtype == other.inner_dtype

    def __le__(self, other):
        assert other.is_floating_point and self.is_floating_point
        return dtype_smaller_than_dtype(self.inner_dtype, other.inner_dtype) or self.inner_dtype == other.inner_dtype

    @staticmethod
    def from_dtype(dtype) -> "WrappedTensorDtype":
        return WrappedTensorDtype(dtype)

    @staticmethod
    def from_tensor(tensor, float_index_args=None, float_kwarg_idx=None) -> "WrappedTensorDtype":
        return WrappedTensorDtype(tensor.dtype, tensor, float_index_args, float_kwarg_idx)
def largest_float_of_list(tensor_list: Iterable, default=torch.float32) -> torch.dtype:
    dtypes = [WrappedTensorDtype.from_dtype(tensor.dtype) for tensor in tensor_list if (isinstance(tensor, torch.Tensor) and tensor.is_floating_point())]
    if len(dtypes) == 0:
        return default
    return max(dtypes).inner_dtype
def largest_float_of_args_kwargs_cast_to(*args, cast_to=None, **kwargs) -> Tuple[torch.dtype, Tuple, Dict[Any, Any]]:
    dtypes = [WrappedTensorDtype.from_tensor(arg, float_index_args=arg_idx) for arg_idx, arg in enumerate(args) if (isinstance(arg, torch.Tensor) and arg.is_floating_point())]
    dtypes += [WrappedTensorDtype.from_tensor(v, float_kwarg_idx=kwarg_idx) for kwarg_idx, v in enumerate(kwargs.values()) if (isinstance(v, torch.Tensor) and v.is_floating_point())]
    if len(dtypes) == 0:
        return torch.float32, args, kwargs
    if cast_to is not None:
        args = list(args)
        kwarg_keys = list(kwargs.keys())
        for _dtype in dtypes:
            if _dtype.is_arg:
                args[_dtype.float_index_args] = args[_dtype.float_index_args].to(cast_to)
            elif _dtype.is_kwarg:
                kwargs[kwarg_keys[_dtype.float_kwarg_idx]] = kwargs[kwarg_keys[_dtype.float_kwarg_idx]].to(cast_to)
    return max(dtypes).inner_dtype, tuple(args), kwargs
def wrap_to_dtype(func: T, dtype: str = "fp16") -> T:
    """
    Wraps a function to cast all floating-point torch.Tensor arguments to the specified dtype
    and cast the output back to the largest float dtype of the inputs.
    """
    wrap_dtype = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float64": torch.float64,
        "float": torch.float32,
        "double": torch.float64,
        "half": torch.float16,
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp64": torch.float64,
    }.get(dtype, torch.float16)

    @wraps(func)
    def wrapped(*args, **kwargs):
        largest_dtype, args, kwargs = largest_float_of_args_kwargs_cast_to(*args, cast_to=wrap_dtype, **kwargs)
        o = func(*args, **kwargs)
        if largest_dtype != wrap_dtype:
            o = o.to(largest_dtype)
        return o

    return wrapped
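A minimal usage sketch, assuming the gist is saved as wrap_dtype.py and the flash-attn package is installed (flash_attn_func, the module name, and the tensor shapes below are only illustrative; any function that accepts only fp16/bf16 tensors can be wrapped the same way):

import torch
from flash_attn import flash_attn_func  # assumed available; accepts only fp16/bf16 inputs

from wrap_dtype import wrap_to_dtype  # hypothetical module name for this gist

# Wrap flash attention so float32 tensors are cast down to fp16 on the way in
# and the output is cast back to the largest input float dtype afterwards.
attention = wrap_to_dtype(flash_attn_func, dtype="fp16")

q = torch.randn(2, 1024, 8, 64, device="cuda")  # float32, (batch, seqlen, heads, headdim)
k = torch.randn(2, 1024, 8, 64, device="cuda")
v = torch.randn(2, 1024, 8, 64, device="cuda")

out = attention(q, k, v)
print(out.dtype)  # torch.float32 -- cast back from the fp16 result

Because the target dtype is a keyword argument with a default of "fp16", wrap_to_dtype can also be applied as a bare decorator (@wrap_to_dtype) when fp16 is acceptable.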