Created
July 22, 2020 19:34
-
-
Save pushpendre/a17ce02e8bc2ee293f93360495070954 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/utils/tabular/features/abstract_feature_generator.py b/utils/tabular/features/abstract_feature_generator.py | |
index 155cc1f..2fbcdf8 100644 | |
--- a/utils/tabular/features/abstract_feature_generator.py | |
+++ b/utils/tabular/features/abstract_feature_generator.py | |
@@ -89,20 +89,23 @@ class AbstractFeatureGenerator: | |
self.features_to_remove += self.banned_features | |
X_index = copy.deepcopy(X.index) | |
X.columns = X.columns.astype(str) # Ensure all column names are strings | |
+ | |
+ # populate self.features_init, self.feature_type_family, self.features_to_remove | |
self.get_feature_types(X) | |
X = X.drop(self.features_to_remove, axis=1, errors='ignore') | |
self.features_init_to_keep = copy.deepcopy(list(X.columns)) | |
- self.features_init_types = X.dtypes.to_dict() | |
+ self.features_init_types = {featname: typ for typ, featname_list in self.feature_type_family.items() for featname in featname_list} | |
self.feature_type_family_init_raw = get_type_groups_df(X) | |
X.reset_index(drop=True, inplace=True) | |
X_features = self.generate_features(X) | |
+ object_column_set = set(self.feature_type_family.get('object', [])) | |
for column in X_features: | |
unique_value_count = len(X_features[column].unique()) | |
if unique_value_count == 1: | |
self.features_to_remove_post.append(column) | |
# TODO: Consider making 0.99 a parameter to FeatureGenerator | |
- elif 'object' in self.feature_type_family and column in self.feature_type_family['object'] and (unique_value_count / X_len > 0.99): | |
+ elif column in object_column_set and (unique_value_count / X_len > 0.99): | |
self.features_to_remove_post.append(column) | |
self.features_binned = list(set(self.features_binned) - set(self.features_to_remove_post)) | |
@@ -439,7 +442,7 @@ class AbstractFeatureGenerator: | |
# TODO: add option for user to specify dtypes on load | |
@staticmethod | |
def get_type_family(dtype): | |
- return get_type_family(dtype=dtype) | |
+ return get_type_family(dtype) | |
@staticmethod | |
def word_count(string): | |
diff --git a/utils/tabular/features/auto_ml_feature_generator.py b/utils/tabular/features/auto_ml_feature_generator.py | |
index 78ebf5c..ec6200f 100644 | |
--- a/utils/tabular/features/auto_ml_feature_generator.py | |
+++ b/utils/tabular/features/auto_ml_feature_generator.py | |
@@ -62,7 +62,7 @@ class AutoMLFeatureGenerator(AbstractFeatureGenerator): | |
self._compute_feature_transformations() | |
X_features = pd.DataFrame(index=X.index) | |
for column in X.columns: | |
- if X[column].dtype.name == 'object': | |
+ if self.features_init_types[column] == 'object': | |
X[column].fillna('', inplace=True) | |
else: | |
X[column].fillna(np.nan, inplace=True) | |
diff --git a/utils/tabular/features/utils.py b/utils/tabular/features/utils.py | |
index 86c9320..765e966 100644 | |
--- a/utils/tabular/features/utils.py | |
+++ b/utils/tabular/features/utils.py | |
@@ -6,27 +6,35 @@ import numpy as np | |
logger = logging.getLogger(__name__) | |
-def get_type_family(dtype): | |
+def get_type_family(dtype_toplevel): | |
"""From dtype, gets the dtype family.""" | |
+ # Check whether dtype is a pandas Sparse extension dtype; if so, classify by its subtype. | |
+ is_sparse = dtype_toplevel.name.startswith('Sparse[') | |
+ dtype = dtype_toplevel.subtype if is_sparse else dtype_toplevel | |
+ ret = None | |
try: | |
if dtype.name is 'category': | |
- return 'category' | |
+ ret = 'category' | |
if 'datetime' in dtype.name: | |
- return 'datetime' | |
+ ret = 'datetime' | |
elif np.issubdtype(dtype, np.integer): | |
- return 'int' | |
+ ret = 'int' | |
elif np.issubdtype(dtype, np.floating): | |
- return 'float' | |
+ ret = 'float' | |
except Exception as err: | |
logger.exception(f'Warning: dtype {dtype} is not recognized as a valid dtype by numpy! AutoGluon may incorrectly handle this feature...') | |
logger.exception(err) | |
- | |
- if dtype.name in ['bool', 'bool_']: | |
- return 'bool' | |
- elif dtype.name in ['str', 'string', 'object']: | |
- return 'object' | |
- else: | |
- return dtype.name | |
+ if ret is None: | |
+ if dtype.name in ['bool', 'bool_']: | |
+ ret = 'bool' | |
+ elif dtype.name in ['str', 'string', 'object']: | |
+ ret = 'object' | |
+ else: | |
+ ret = dtype.name | |
+ # Deliberately drop the sparse marker: sparse storage does not | |
+ # change a feature's semantics, so report the subtype's family. | |
+ # (Rejected alternative: f'Sparse[{ret}]' if is_sparse else ret.) | |
+ return ret | |
def get_type_groups_df(df): |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment