Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
non-constant function ensured, non-overlapping noise checked
  • Loading branch information
kbudhath committed Sep 25, 2017
1 parent c40fce0 commit 46b9efe
Showing 1 changed file with 41 additions and 10 deletions.
51 changes: 41 additions & 10 deletions test_synthetic.py
Expand Up @@ -44,6 +44,11 @@ def map_randomly(Xd, fXd):
for x in Xd:
y = random.choice(fXd)
f[x] = y

# ensure that f is not a constant function
if len(set(f.values())) == 1:
f = map_randomly(Xd, fXd)
assert len(set(f.values())) != 1
return f


Expand Down Expand Up @@ -181,30 +186,55 @@ def _decision_rate(srcX):
# "CISC", "DC", "DR"], "decision rate", "accuracy", "decision rate versus accuracy", "dec_rate_%sX.png" % srcX)


def are_disjoint(sets):
    """Return True when no element occurs more than once across ``sets``.

    Args:
        sets: an iterable of iterables of hashable elements (typically
            ``set`` objects).

    Returns:
        ``True`` if every element encountered is unique over the whole
        collection, ``False`` as soon as a repeated element is found.
        An empty collection, or a collection of empty sets, is disjoint.

    Note:
        Elements are checked one by one, so a member that is not a real
        ``set`` and repeats a value within itself (e.g. ``[1, 1]``) is
        reported as not disjoint.
    """
    seen = set()
    for s in sets:
        for x in s:
            if x in seen:
                # One repeated element settles the answer; return right
                # away instead of scanning the remaining sets.
                return False
            seen.add(x)
    return True


def test_accuracy():
nsim = 1000
size = 5000
size = 1000
level = 0.01
suppfX = range(-7, 8)
srcsX = ["uniform", "binomial", "negativeBinomial",
"geometric", "hypergeometric", "poisson", "multinomial"]
print "-" * 64
print "-" * 70
print "%18s%10s%10s%10s%10s%10s" % ("TYPE_X", "DC", "DR", "CISC", "CRISPE", "CRISP")
print "-" * 64
print "-" * 70
sys.stdout.flush()

fp = open("results/acc-dtype.dat", "w")
fp.write("%s\t%s\t%s\t%s\t%s\t%s\n" %
("dtype", "dc", "dr", "cisc", "crispe", "crisp"))
for srcX in srcsX:
nsamples = 0
nc_dc, nc_dr, nc_cisc, nc_crispe, nc_crisp = 0, 0, 0, 0, 0
for k in xrange(nsim):
while nsamples < nsim:
X = generate_X(srcX, size)
suppX = list(set(X))
f = map_randomly(suppX, suppfX)
N = generate_additive_N(size)
Y = [f[X[i]] + N[i] for i in xrange(size)]

# check if f(x) + supp N are disjoint for x in domx
suppN = set(N)
decomps = []
for x in suppX:
fx = f[x]
sum_fx_suppN = set([fx + n for n in suppN])
decomps.append(sum_fx_suppN)

non_overlapping_noise = are_disjoint(decomps)
if non_overlapping_noise:
continue

nsamples += 1
dc_score = dc(dc_compat(X), dc_compat(Y))
dr_score = dr(X, Y, level)
cisc_score = cisc(X, Y)
Expand All @@ -217,6 +247,8 @@ def test_accuracy():
nc_crispe += int(crispe_score[0] < crispe_score[1])
nc_crisp += int(crisp_score[0] < crisp_score[1])

assert nsamples == nsim

acc_dc = nc_dc * 100 / nsim
acc_dr = nc_dr * 100 / nsim
acc_cisc = nc_cisc * 100 / nsim
Expand All @@ -226,7 +258,7 @@ def test_accuracy():
sys.stdout.flush()
fp.write("%s\t%.2f\t%.2f\t%.2f\t%.2f\t%.2f\n" %
(srcX, acc_dc, acc_dr, acc_cisc, acc_crispe, acc_crisp))
print "-" * 58
print "-" * 70
sys.stdout.flush()
fp.close()

Expand Down Expand Up @@ -529,18 +561,17 @@ def test_hypercompression():


def test_sample_size():
nsim = 5000
nsim = 500
level = 0.05
sizes = [50, 100, 500, 1000, 2500, 5000]
sizes = [100, 500, 1000, 2500, 5000]
suppfX = range(-7, 8)
srcX = "geometric"

fp = open("results/acc-size.dat", "w")
diffs = []
fp.write("%s\t%s\t%s\t%s\t%s\t%s\n" %
("size", "dc", "dr", "cisc", "crispe", "crisp"))
print "%s\t%s\t%s\t%s\t%s\t%s\n" %
("size", "dc", "dr", "cisc", "crispe", "crisp")
print "%s\t%s\t%s\t%s\t%s\t%s" % ("size", "dc", "dr", "cisc", "crispe", "crisp")
sys.stdout.flush()
# progress(0, len(sizes))
for k, size in enumerate(sizes):
Expand Down

0 comments on commit 46b9efe

Please sign in to comment.