Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
changes to synthetic data generation
  • Loading branch information
kbudhath committed Feb 14, 2017
1 parent 54991d8 commit fea61e4
Show file tree
Hide file tree
Showing 3 changed files with 311 additions and 138 deletions.
104 changes: 84 additions & 20 deletions test_benchmark.py
Expand Up @@ -11,7 +11,8 @@

from cisc import cisc
from dc import dc
from test_anm import dc_compat
from dr import dr
from utils import dc_compat, plot_multiline, progress, reverse_argsort
from discretizer import *


Expand Down Expand Up @@ -67,15 +68,19 @@ def load_tubingen_pairs():

def test_tubingen_pairs():
epsilon = 0.0
level = 0.05
truths = get_ground_truths_of_tubingen_pairs()
multivariate_pairs = [52, 53, 54, 55, 71]
num_pairs = len(truths) - len(multivariate_pairs)

num_correct = 0
num_wrong = 0
nsample = 0
num_indecisive = 0
res_cisc, res_dc, res_dr = [], [], []
diffs_cisc, diffs_dc, diffs_dr = [], [], []

print "#num\tfound\ttruth\tdelta"
progress(0, 95)
for i, data in enumerate(load_tubingen_pairs()):
if i + 1 in multivariate_pairs:
continue
Expand All @@ -85,32 +90,91 @@ def test_tubingen_pairs():
# discretizer = UnivariateIPDiscretizer(X, Y)
# aX, Xd, aY, Yd = discretizer.discretize()
# cisc_score = cisc(Xd, Yd)
nsample += 1
cisc_score = cisc(X, Y)
# cisc_score = dc(dc_compat(Xd), dc_compat(Yd))
delta = abs(cisc_score[0] - cisc_score[1])
dc_score = dc(dc_compat(X), dc_compat(Y))
dr_score = dr(X.tolist(), Y.tolist(), level)

diffs_cisc.append(abs(cisc_score[0] - cisc_score[1]))
diffs_dc.append(abs(dc_score[0] - dc_score[1]))
diffs_dr.append(abs(dr_score[0] - dr_score[1]))

if cisc_score[0] < cisc_score[1]:
cause = "X"
found = "X → Y"
cause_cisc = "X"
elif cisc_score[0] > cisc_score[1]:
cause = "Y"
found = "Y → X"
cause_cisc = "Y"
else:
cause_cisc = ""

if dc_score[0] > dc_score[1]:
cause_dc = "X"
elif dc_score[0] < dc_score[1]:
cause_dc = "Y"
else:
cause_dc = ""

if dr_score[0] > level and dr_score[1] < level:
cause_dr = "X"
elif dr_score[0] < level and dr_score[1] > level:
cause_dr = "Y"
else:
cause = ""
found = "X ~ Y"

truth_cause = truths[i][0]
truth = "X → Y" if truth_cause == "X" else "Y → X"
if cause == "":
num_indecisive += 1
elif cause == truth_cause:
num_correct += 1
cause_dr = ""

true_cause = truths[i][0]
if cause_cisc == "":
res_cisc.append(random.choice([True, False]))
elif cause_cisc == true_cause:
res_cisc.append(True)
else:
num_wrong += 1
res_cisc.append(False)

print "%4d\t%s\t%s\t%.4f" % (i + 1, found, truth, delta)
if cause_dc == "":
res_dc.append(random.choice([True, False]))
elif cause_dc == true_cause:
res_dc.append(True)
else:
res_dc.append(False)

print "✓ = %3d ✗ = %3d ~ = %3d" % (num_correct, num_wrong, num_indecisive)
if cause_dr == "":
res_dr.append(random.choice([True, False]))
elif cause_dr == true_cause:
res_dr.append(True)
else:
res_dr.append(False)

progress(nsample, 95)

# print "✓ = %3d ✗ = %3d ~ = %3d" % (num_correct, num_wrong,
# num_indecisive)
indices_cisc = reverse_argsort(diffs_cisc)
indices_dc = reverse_argsort(diffs_dc)
indices_dr = reverse_argsort(diffs_dr)

diffs_cisc = [diffs_cisc[i] for i in indices_cisc]
diffs_dc = [diffs_dc[i] for i in indices_dc]
diffs_dr = [diffs_dr[i] for i in indices_dr]

res_cisc = [res_cisc[i] for i in indices_cisc]
res_dc = [res_dc[i] for i in indices_dc]
res_dr = [res_dr[i] for i in indices_dr]

dec_rate = np.arange(0.02, 1.01, 0.01)
accs_cisc, accs_dc, accs_dr = [], [], []
fp = open("results/dec_rate_benchmark.dat", "w")
for r in dec_rate:
maxIdx = int(r * nsample)
rcisc = res_cisc[:maxIdx]
rdc = res_dc[:maxIdx]
rdr = res_dr[:maxIdx]

accs_cisc.append(sum(rcisc) / len(rcisc))
accs_dc.append(sum(rdc) / len(rdc))
accs_dr.append(sum(rdr) / len(rdr))
fp.write("%.2f %.2f %.2f %.2f" % (r, sum(rdc) / len(rdc),
sum(rdr) / len(rdr), sum(rcisc) / len(rcisc)))
fp.close()
plot_multiline([accs_cisc, accs_dc, accs_dr], dec_rate, [
"CISC", "DC", "DR"], "decision rate", "accuracy", "decision rate versus accuracy")


if __name__ == "__main__":
Expand Down

0 comments on commit fea61e4

Please sign in to comment.