Created May 20, 2019 12:58
-
-
Save SilvaEmerson/23b7d54f72fd8f470f1b3bfdd60b6a98 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Stratified K-Fold implementation | |
""" | |
def stratified_k_fold(arr, k=None, class_ratio=.5):
    """Split a binary-labelled dataset into ``k`` folds with a target class ratio.

    Each element of *arr* is a sequence whose last item is its label: a truthy
    label counts as the positive ("ones") class, a falsy one as the negative
    ("zeros") class.

    Every fold receives ``len(arr) // k`` elements. Of these,
    ``int(class_ratio * (len(arr) // k))`` are taken from the positive class
    and the rest from the negative class. When one class runs out, the fold is
    topped up from the other class, so folds stay full and no element is
    dropped. The ``len(arr) % k`` elements left over at the end are appended
    to the last fold. Within each class, elements keep their order from *arr*.

    Args:
        arr: Sequence of samples, each indexable with ``[-1]`` for its label.
        k: Number of folds. ``None`` (the default) makes the function
            return ``None``.
        class_ratio: Target fraction of positive samples per fold, in [0, 1].

    Returns:
        A list of ``k`` lists of samples, or ``None`` if *k* is ``None``.

    Raises:
        ValueError: If ``k`` is not between 1 and ``len(arr)``, or if
            ``class_ratio`` is outside ``[0, 1]``.
    """
    if k is None:
        return None
    if not 1 <= k <= len(arr):
        raise ValueError(f"k must be between 1 and len(arr)={len(arr)}, got {k}")
    if not 0 <= class_ratio <= 1:
        raise ValueError(f"class_ratio must be within [0, 1], got {class_ratio}")

    fold_size = len(arr) // k
    target_ones = int(class_ratio * fold_size)

    ones = [el for el in arr if el[-1]]
    zeros = [el for el in arr if not el[-1]]
    # Read cursors into each class list; avoids repeated O(n) `del` from the front.
    next_one = next_zero = 0

    folds = []
    for _ in range(k):
        avail_ones = len(ones) - next_one
        avail_zeros = len(zeros) - next_zero
        n_ones = min(target_ones, avail_ones)
        # Fill the rest with zeros (topping up if ones were short)...
        n_zeros = min(fold_size - n_ones, avail_zeros)
        # ...and if zeros were short, top up with ones so the fold is full size.
        n_ones = min(fold_size - n_zeros, avail_ones)
        folds.append([*ones[next_one:next_one + n_ones],
                      *zeros[next_zero:next_zero + n_zeros]])
        next_one += n_ones
        next_zero += n_zeros

    # Whatever remains (the len(arr) % k remainder) joins the last fold,
    # so every input element appears in exactly one fold.
    folds[-1].extend([*ones[next_one:], *zeros[next_zero:]])
    return folds
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
from random import shuffle | |
from functools import reduce | |
from collections import Counter | |
import StratifiedKFold as SKF | |
class MainTest(unittest.TestCase):
    """Unit tests for ``StratifiedKFold.stratified_k_fold``.

    The fixture is 21 ``(index, label)`` samples: 11 positive and 10 negative
    labels in shuffled order. Most tests split it into 4 folds.
    """

    def setUp(self):
        self.class_ratio = .5
        labels = [*[1] * 11, *[0] * 10]
        shuffle(labels)
        self.arr = [*zip(range(21), labels)]

    def test_should_return_None(self):
        # Omitting k must yield None rather than folds.
        self.assertIsNone(SKF.stratified_k_fold(self.arr))

    def test_should_return_4_folds(self):
        result = SKF.stratified_k_fold(self.arr, k=4)
        self.assertEqual(len(result), 4)

    def test_sould_not_return_even_one_empty_fold(self):
        result = SKF.stratified_k_fold(self.arr, k=4)
        for index, fold in enumerate(result):
            with self.subTest(fold=index):
                self.assertTrue(fold)

    def test_sould_return_same_amount_of_elements(self):
        result = SKF.stratified_k_fold(self.arr, k=4)
        flattened = [el for fold in result for el in fold]
        self.assertEqual(sum(len(fold) for fold in result), len(self.arr))
        # Same samples, each exactly once: catches duplicated or dropped items.
        self.assertCountEqual(flattened, self.arr)

    def test_self_class_ratio_should_be_equal_as_passed(self):
        k = 4
        result = SKF.stratified_k_fold(self.arr, k=k, class_ratio=self.class_ratio)
        ratios = [round(sum(el[-1] for el in fold) / len(fold), 1) for fold in result]
        most_common_ratio = Counter(ratios).most_common(1)[0][0]
        # A fold holds int(class_ratio * fold_size) positives, so the
        # achievable ratio is quantized (here 2/5 = 0.4 for a 0.5 target).
        fold_size = len(self.arr) // k
        expected = round(int(self.class_ratio * fold_size) / fold_size, 1)
        self.assertEqual(most_common_ratio, expected)


if __name__ == '__main__':
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.