# Balancing the training dataset to the positive-negative class ratio reported for the unseen dataset

A Python gist for balancing (re-sampling) a training dataset to match the positive-negative class ratio reported for the unseen dataset.

When we know the unseen dataset’s positive-negative class ratio (or “guess” it from the leaderboard, LB), it is worth trying to re-balance the training dataset to reflect it.
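
To make the arithmetic concrete, here is a minimal Python 3 sketch of how many rows would have to be duplicated to reach a target positive rate, assuming we only ever add rows and never drop any. The function name `rows_to_add` and the example counts (300 positives, 700 negatives) are illustrative assumptions, not taken from a real dataset.

```python
def rows_to_add(p, n, target_pos_rate):
    """Return (label, count): which class to oversample, and by how many rows.

    p, n: current positive / negative row counts.
    target_pos_rate: desired share of positives, strictly between 0 and 1.
    """
    if not 0.0 < target_pos_rate < 1.0:
        raise ValueError("target_pos_rate must be strictly between 0 and 1")
    current = p / (p + n)
    if target_pos_rate < current:
        # Too many positives: add x negatives.
        # p / (p + n + x) = r  =>  x = p / r - (p + n)
        return "neg", int(p / target_pos_rate - (p + n))
    if target_pos_rate > current:
        # Too few positives: add x positives.
        # (p + x) / (p + x + n) = r  =>  x = (r * (p + n) - p) / (1 - r)
        return "pos", int((target_pos_rate * (p + n) - p) / (1 - target_pos_rate))
    return "none", 0


print(rows_to_add(300, 700, 0.25))  # ('neg', 200)
print(rows_to_add(300, 700, 0.5))   # ('pos', 400)
```

Because the count is truncated with `int`, the resulting rate can land slightly off the target whenever the exact answer is not a whole number.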

I wrote a Python gist for it using pandas; here’s the code:

```python
import pandas as pd


def balance_train_ds(df_train, unseen_pos_rate, train_y_field):
    """Oversample df_train so that its positive rate matches unseen_pos_rate."""
    df_train_pos = df_train[df_train[train_y_field] == 1]
    df_train_neg = df_train[df_train[train_y_field] == 0]
    p = df_train_pos.shape[0]  # number of positive rows
    n = df_train_neg.shape[0]  # number of negative rows
    train_pos_rate = float(p) / float(df_train.shape[0])
    print('train ds pos rate {0}, unseen ds reported pos rate {1}'.format(train_pos_rate, unseen_pos_rate))

    # pos_rate = r1 = p / (p + n)
    # solve for r2, whether r2 > r1 or r2 < r1; we only add samples (pos or neg), never drop any
    # the pandas "sample" function with "replace=True" can draw more rows than the frame holds, if needed
    r1 = train_pos_rate
    r2 = unseen_pos_rate

    if r2 < r1:
        # solving balance for r2, where r2 < r1
        # p / (p + n + balance) = r2
        balance = int((p - (r2 * p) - (r2 * n)) / r2)
        print('duplicating {0} random negatives'.format(balance))
        df_train = pd.concat([df_train, df_train_neg.sample(n=balance, replace=True)])
    elif r2 > r1:
        # solving balance for r2, where r2 > r1
        # (p + balance) / (p + balance + n) = r2
        balance = int(((r2 * p) - p + (r2 * n)) / (1 - r2))
        print('duplicating {0} random positives'.format(balance))
        df_train = pd.concat([df_train, df_train_pos.sample(n=balance, replace=True)])

    # re-check
    df_train_pos = df_train[df_train[train_y_field] == 1]
    df_train_neg = df_train[df_train[train_y_field] == 0]
    train_pos_rate = float(df_train_pos.shape[0]) / float(df_train.shape[0])
    print('train ds re-balanced to {0}'.format(train_pos_rate))
    return df_train


if __name__ == '__main__':
    # set the reported unseen positive ratio, and try
    unseen_pos_rate = 0.12