PyTorch utility to wrap functions that require specific dtypes (e.g. fp16 for Flash Attention 2): float tensor arguments are cast to the required dtype before the call, and the output is cast back to the widest input dtype afterwards.
from functools import wraps
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, TypeVar

import torch

# Bound to typing.Callable; the callable() builtin is not a valid TypeVar bound.
T = TypeVar("T", bound=Callable)
# Maps each float dtype to the float dtypes it is strictly wider than.
# torch.half is an alias of torch.float16, so it needs no separate entry.
ref_map = {
    torch.float64: [torch.float16, torch.float32, torch.bfloat16],
    torch.float32: [torch.float16, torch.bfloat16],
    torch.float16: [],
    torch.bfloat16: [],
}
def dtype_larger_than_dtype(dtype, ref_dtype):
    """True if `dtype` is strictly wider than `ref_dtype`."""
    return ref_dtype in ref_map[dtype]


def dtype_smaller_than_dtype(dtype, ref_dtype):
    """True if `dtype` is strictly narrower than `ref_dtype`."""
    return dtype in ref_map[ref_dtype]
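# For example, per ref_map above:
#   dtype_larger_than_dtype(torch.float32, torch.float16)    # -> True
#   dtype_smaller_than_dtype(torch.bfloat16, torch.float64)  # -> True
#   dtype_larger_than_dtype(torch.float16, torch.bfloat16)   # -> False (fp16 and bf16 are unordered)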
class WrappedTensorDtype:
    """Wraps a torch.dtype (and optionally the tensor it came from) so dtypes
    can be ordered by width with max()/min() via ref_map."""

    def __init__(
        self,
        dtype: torch.dtype,
        tensor: Optional[torch.Tensor] = None,
        float_index_args=None,
        float_kwarg_idx=None,
    ) -> None:
        super().__init__()
        assert dtype.is_floating_point, "WrappedTensorDtype only supports floating point dtypes"
        # Mirror the introspection attributes of torch.dtype.
        self.is_complex = dtype.is_complex
        self.is_floating_point = dtype.is_floating_point
        self.is_signed = dtype.is_signed
        self.itemsize = dtype.itemsize
        self.to_complex = dtype.to_complex
        self.to_real = dtype.to_real
        self.inner_dtype = dtype
        self.tensor = tensor
        # Where the tensor sits in the wrapped call (positional index or
        # keyword-order index), so it can be replaced after casting.
        self.float_index_args = float_index_args
        self.float_kwarg_idx = float_kwarg_idx
        self.is_arg = float_index_args is not None
        self.is_kwarg = float_kwarg_idx is not None

    # Dunder methods must live on the class: Python resolves them on the type,
    # so assigning e.g. self.__repr__ = dtype.__repr__ inside __init__ would
    # never be picked up by repr(), hash(), or the comparison operators.
    def __repr__(self):
        return repr(self.inner_dtype)

    def __str__(self):
        return str(self.inner_dtype)

    def __hash__(self):
        return hash(self.inner_dtype)

    def __eq__(self, other):
        other_dtype = other.inner_dtype if isinstance(other, WrappedTensorDtype) else other
        return self.inner_dtype == other_dtype

    def __ne__(self, other):
        return not self.__eq__(other)

    def __gt__(self, other):
        assert other.is_floating_point and self.is_floating_point
        return dtype_larger_than_dtype(self.inner_dtype, other.inner_dtype)

    def __lt__(self, other):
        assert other.is_floating_point and self.is_floating_point
        return dtype_smaller_than_dtype(self.inner_dtype, other.inner_dtype)

    def __ge__(self, other):
        assert other.is_floating_point and self.is_floating_point
        return dtype_larger_than_dtype(self.inner_dtype, other.inner_dtype) or self.inner_dtype == other.inner_dtype

    def __le__(self, other):
        assert other.is_floating_point and self.is_floating_point
        return dtype_smaller_than_dtype(self.inner_dtype, other.inner_dtype) or self.inner_dtype == other.inner_dtype

    @staticmethod
    def from_dtype(dtype) -> "WrappedTensorDtype":
        return WrappedTensorDtype(dtype)

    @staticmethod
    def from_tensor(tensor, float_index_args=None, float_kwarg_idx=None) -> "WrappedTensorDtype":
        return WrappedTensorDtype(tensor.dtype, tensor, float_index_args, float_kwarg_idx)
def largest_float_of_list(tensor_list: Iterable, default=torch.float32) -> torch.dtype:
    """Return the widest float dtype among the tensors in `tensor_list`,
    or `default` if the list contains no floating point tensors."""
    dtypes = [
        WrappedTensorDtype.from_dtype(tensor.dtype)
        for tensor in tensor_list
        if isinstance(tensor, torch.Tensor) and tensor.is_floating_point()
    ]
    if len(dtypes) == 0:
        return default
    return max(dtypes).inner_dtype
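# For example, non-float tensors are ignored when picking the widest dtype:
#   largest_float_of_list([
#       torch.ones(1, dtype=torch.float16),
#       torch.ones(1, dtype=torch.float64),
#       torch.ones(1, dtype=torch.int64),
#   ])  # -> torch.float64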
def largest_float_of_args_kwargs_cast_to(*args, cast_to=None, **kwargs) -> Tuple[torch.dtype, Tuple, Dict[Any, Any]]:
    """Find the widest float dtype among tensor args/kwargs, optionally cast
    every float tensor to `cast_to`, and return (widest_dtype, args, kwargs)."""
    dtypes = [
        WrappedTensorDtype.from_tensor(arg, float_index_args=arg_idx)
        for arg_idx, arg in enumerate(args)
        if isinstance(arg, torch.Tensor) and arg.is_floating_point()
    ]
    dtypes += [
        WrappedTensorDtype.from_tensor(v, float_kwarg_idx=kwarg_idx)
        for kwarg_idx, v in enumerate(kwargs.values())
        if isinstance(v, torch.Tensor) and v.is_floating_point()
    ]
    if len(dtypes) == 0:
        # No float tensors found, so there is nothing to cast.
        return torch.float32, args, kwargs
    if cast_to is not None:
        args = list(args)
        kwarg_keys = list(kwargs.keys())
        for _dtype in dtypes:
            if _dtype.is_arg:
                args[_dtype.float_index_args] = args[_dtype.float_index_args].to(cast_to)
            elif _dtype.is_kwarg:
                kwargs[kwarg_keys[_dtype.float_kwarg_idx]] = kwargs[kwarg_keys[_dtype.float_kwarg_idx]].to(cast_to)
    return max(dtypes).inner_dtype, tuple(args), kwargs
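# For example, casting mixed-precision inputs down to fp16:
#   a = torch.ones(2, dtype=torch.float32)
#   b = torch.ones(2, dtype=torch.float16)
#   largest, new_args, new_kwargs = largest_float_of_args_kwargs_cast_to(
#       a, b, cast_to=torch.float16
#   )
#   # largest == torch.float32; both tensors in new_args are now fp16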
def wrap_to_dtype(func: T, dtype: str = "fp16") -> T:
    """
    Wraps a function so that every floating point torch.Tensor argument is cast
    to the specified dtype before the call, and the output is cast back to the
    widest float dtype found among the inputs.
    """
    # Unrecognized dtype strings silently fall back to torch.float16.
    wrap_dtype = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float64": torch.float64,
        "float": torch.float32,
        "double": torch.float64,
        "half": torch.float16,
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp64": torch.float64,
    }.get(dtype, torch.float16)

    @wraps(func)
    def wrapped(*args, **kwargs):
        largest_dtype, args, kwargs = largest_float_of_args_kwargs_cast_to(*args, cast_to=wrap_dtype, **kwargs)
        o = func(*args, **kwargs)
        # Assumes `func` returns a single tensor.
        if largest_dtype != wrap_dtype:
            o = o.to(largest_dtype)
        return o

    return wrapped
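A minimal usage sketch (assuming a CUDA device is available; torch.nn.functional.scaled_dot_product_attention stands in here for a kernel that requires half precision inputs, such as Flash Attention 2):

import torch
import torch.nn.functional as F

# Cast fp32 inputs to fp16 for the kernel call, then back to fp32 afterwards.
sdpa_fp16 = wrap_to_dtype(F.scaled_dot_product_attention, dtype="fp16")

q, k, v = (torch.randn(1, 8, 128, 64, device="cuda") for _ in range(3))
out = sdpa_fp16(q, k, v)
print(out.dtype)  # torch.float32 -- restored to the widest input dtype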