|
import numpy as np |
|
|
|
|
|
def approximate_mode(class_counts, n_draws, rng):
    """Compute an approximate mode of a multivariate hypergeometric.

    This is an approximation to the mode of the multivariate
    hypergeometric distribution given by ``class_counts`` and ``n_draws``,
    i.e. the most likely outcome of drawing ``n_draws`` samples without
    replacement from the population described by ``class_counts``.
    It shouldn't be off by more than one.

    Parameters
    ----------
    class_counts : array-like of int
        Population per class.
    n_draws : int
        Number of draws (samples to draw) from the overall population.
    rng : random state
        Object exposing ``choice(a, size=..., replace=...)``. It is only
        consulted to break ties between classes whose fractional
        remainders are equal.

    Returns
    -------
    sampled_classes : ndarray of int64
        Number of samples drawn from each class, with
        ``np.sum(sampled_classes) == n_draws``.
    """
    # Do the arithmetic in float64: with a narrow integer dtype such as
    # int32 the product ``n_draws * class_counts`` silently wraps around
    # and yields nonsensical (even negative) allocations.
    class_counts = np.asarray(class_counts, dtype=np.float64)

    # Ideal (fractional) share of the draws for every class.
    continuous = n_draws * class_counts / class_counts.sum()

    # Flooring never overshoots n_draws but usually undershoots it.
    floored = np.floor(continuous)

    # Hand out the missing draws to the classes with the largest
    # fractional remainder until the total reaches n_draws.
    need_to_add = int(n_draws - floored.sum())
    if need_to_add > 0:
        remainder = continuous - floored
        # np.unique returns ascending values; visit the largest first.
        values = np.unique(remainder)[::-1]
        for value in values:
            (inds,) = np.where(remainder == value)
            # If fewer draws are missing than there are tied classes, pick
            # the lucky ones at random to avoid a systematic bias;
            # otherwise give one to each and move on to the next value.
            add_now = min(len(inds), need_to_add)
            inds = rng.choice(inds, size=add_now, replace=False)
            floored[inds] += 1
            need_to_add -= add_now
            if need_to_add == 0:
                break
    return floored.astype(np.int64)
|
|
|
|
|
def stratified_shuffle_split_generate_indices(y, n_train, n_test, rng, n_splits=10):
    """Yield stratified, randomly shuffled train/test index splits.

    The approach follows the StratifiedShuffleSplit implementation of the
    scikit-learn library: each split allocates train and test samples to
    the classes roughly in proportion to the class frequencies in ``y``.

    This is a generator, so the validation errors listed below are raised
    when the first split is requested, not when the function is called.

    Parameters
    ----------
    y : array-like
        Class label of every sample; the yielded indices refer to
        positions in ``y``.
    n_train : int
        Absolute number of train samples.
    n_test : int
        Absolute number of test samples.
    rng : random state
        Source of randomness. Must expose ``permutation`` and is also
        passed on to ``approximate_mode`` for tie-breaking.
    n_splits : int, default=10
        Number of re-shuffling & splitting iterations.

    Yields
    ------
    train : ndarray of int
        Shuffled indices of the train samples of one split.
    test : ndarray of int
        Shuffled indices of the test samples of the same split; disjoint
        from ``train``.

    Raises
    ------
    ValueError
        If any class has fewer than 2 members, if ``n_train`` or
        ``n_test`` is smaller than the number of classes, or if
        ``n_train + n_test`` exceeds the number of samples.
    """
    classes, y_indices = np.unique(y, return_inverse=True)
    n_classes = classes.shape[0]
    class_counts = np.bincount(y_indices)
    if np.min(class_counts) < 2:
        raise ValueError("Minimum class count error")
    if n_train < n_classes:
        raise ValueError(
            "The train_size = %d should be greater or equal to the number of classes = %d" % (n_train, n_classes)
        )
    if n_test < n_classes:
        raise ValueError(
            "The test_size = %d should be greater or equal to the number of classes = %d" % (n_test, n_classes)
        )
    # Without this check an oversized request silently produces short or
    # nonsensical splits (the remaining class counts would go negative).
    n_samples = int(class_counts.sum())
    if n_train + n_test > n_samples:
        raise ValueError(
            "The sum of train_size = %d and test_size = %d should be smaller or equal to the number of samples = %d"
            % (n_train, n_test, n_samples)
        )

    # A stable sort groups the sample indices by class; splitting at the
    # cumulative counts gives one index array per class.
    class_indices = np.split(np.argsort(y_indices, kind="mergesort"), np.cumsum(class_counts)[:-1])

    for _ in range(n_splits):
        # Per-class number of train samples, then per-class number of test
        # samples taken from whatever the train allocation left over.
        n_i = approximate_mode(class_counts, n_train, rng)
        class_counts_remaining = class_counts - n_i
        t_i = approximate_mode(class_counts_remaining, n_test, rng)

        train = []
        test = []
        for members, count, n_tr, n_te in zip(class_indices, class_counts, n_i, t_i):
            # Shuffle the members of this class, then take the first n_tr
            # for train and the following n_te for test.
            permutation = rng.permutation(count)
            shuffled_members = members.take(permutation, mode="clip")
            train.extend(shuffled_members[:n_tr])
            test.extend(shuffled_members[n_tr : n_tr + n_te])

        # Mix the classes so the output is not ordered by class.
        train = rng.permutation(train)
        test = rng.permutation(test)

        yield train, test
|
|