Skip to content

Commit

Permalink
acid added
Browse files Browse the repository at this point in the history
  • Loading branch information
kbudhath committed Apr 25, 2018
1 parent 23673ee commit 5566b44
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 1 deletion.
91 changes: 91 additions & 0 deletions acid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from collections import Counter, defaultdict
from copy import copy
import random
import sys
import time

from entropy import entropy


__author__ = "Kailash Budhathoki"
__email__ = "kbudhath@mpi-inf.mpg.de"
__copyright__ = "Copyright (c) 2018"
__license__ = "MIT"


def marginals(X, Y):
    """Group the values of Y by the value of X at the same position.

    Args:
        X: iterable of hashable values used as grouping keys.
        Y: sequence indexed by position; ``Y[i]`` is filed under ``X[i]``.

    Returns:
        A ``defaultdict(list)`` mapping each distinct value of X to the list
        of Y values observed alongside it, in order of appearance.
    """
    grouped = defaultdict(list)
    for position, key in enumerate(X):
        bucket = grouped[key]
        bucket.append(Y[position])
    return grouped


def map_to_majority(X, Y):
    """Map every distinct value of X to the most frequent co-occurring Y.

    Args:
        X: iterable of hashable values.
        Y: sequence indexed by position; ``Y[i]`` is the value paired with
            ``X[i]``. Its values must be hashable.

    Returns:
        A plain ``dict`` ``f`` such that ``f[x]`` is the most common value of
        Y among the positions where X equals ``x``. Ties are resolved however
        ``Counter.most_common`` orders equally frequent items.
    """
    subgroups_y = defaultdict(list)
    for i, x in enumerate(X):
        subgroups_y[x].append(Y[i])

    # ``items()`` instead of the Python-2-only ``iteritems()`` so the function
    # runs unchanged on both Python 2 and Python 3.
    return {
        x: Counter(subgroup_y).most_common(1)[0][0]
        for x, subgroup_y in subgroups_y.items()
    }


def regress(X, Y, max_iterations=10000):
    """Fit a function f: X -> Y greedily and return the resulting code length.

    The function starts from the majority mapping (each x mapped to its most
    frequent y) and then repeatedly sweeps over the distinct values of X. For
    each x it tries every other value from the support of Y as the new image
    of x and keeps the candidate with the smallest residual entropy, provided
    that it beats the current residual entropy. Sweeps stop as soon as one
    full sweep brings no improvement, or after ``max_iterations`` sweeps.

    Args:
        X: feature sequence (hashable values).
        Y: target sequence, same length as X. Values must be hashable and
            support subtraction, since the residual is ``y - f[x]``.
        max_iterations: upper bound on the number of sweeps over the support
            of X.

    Returns:
        ``entropy(X)`` plus the entropy of the final residuals ``Y - f(X)``,
        both as computed by the imported ``entropy`` function.
    """
    hx = entropy(X)
    f = map_to_majority(X, Y)

    supp_x = list(set(X))
    supp_y = list(set(Y))

    # The pairs are re-iterated once per candidate below. ``zip`` is a
    # one-shot iterator on Python 3, so it must be materialised; otherwise
    # every residual list after the first would be empty.
    pairs = list(zip(X, Y))
    cur_res_codelen = entropy([y - f[x] for x, y in pairs])

    j = 0
    minimized = True
    while j < max_iterations and minimized:
        minimized = False

        for x_to_map in supp_x:
            best_res_codelen = sys.float_info.max
            best_cand_y = None

            for cand_y in supp_y:
                # The current image of x is already reflected in
                # cur_res_codelen; only alternatives are worth scoring.
                if cand_y == f[x_to_map]:
                    continue

                # Residuals if x_to_map were remapped to cand_y and every
                # other x kept its current image.
                res = [y - (cand_y if x == x_to_map else f[x])
                       for x, y in pairs]
                res_codelen = entropy(res)

                if res_codelen < best_res_codelen:
                    best_res_codelen = res_codelen
                    best_cand_y = cand_y

            # Accept the best alternative only if it strictly improves on the
            # current residual entropy; later x values in this sweep then see
            # the updated mapping.
            if best_res_codelen < cur_res_codelen:
                cur_res_codelen = best_res_codelen
                f[x_to_map] = best_cand_y
                minimized = True
        j += 1
    return hx + cur_res_codelen


def acid(X, Y):
    """Score both directions between X and Y with ``regress``.

    Args:
        X: first sequence.
        Y: second sequence, paired position by position with X.

    Returns:
        A 2-tuple ``(score of X -> Y, score of Y -> X)``, where each score is
        the value returned by ``regress`` with the first argument as feature
        and the second as target.
    """
    # Tuple elements are evaluated left to right, so X -> Y is still
    # computed before Y -> X.
    return (regress(X, Y), regress(Y, X))


if __name__ == "__main__":
    # Imported only when run as a script, so importing this module does not
    # require the benchmark helpers.
    from test_benchmark import load_pair

    # NOTE(review): 99 is presumably a benchmark pair index; confirm against
    # test_benchmark.load_pair.
    X, Y = load_pair(99)
    # Parenthesised form: prints the same tuple on Python 2 and is also valid
    # on Python 3, where the bare ``print`` statement is a syntax error.
    print(acid(X, Y))
8 changes: 7 additions & 1 deletion entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
from math import log


__author__ = "Kailash Budhathoki"
__email__ = "kbudhath@mpi-inf.mpg.de"
__copyright__ = "Copyright (c) 2018"
__license__ = "MIT"


def entropy(sequence):
res = 0
n = len(sequence)
Expand All @@ -15,7 +21,7 @@ def entropy(sequence):
return res


if __name__ == "__main__":
    # Small smoke checks on hand-made sequences. The single-argument
    # ``print(...)`` form behaves identically on Python 2 and Python 3.
    print(entropy([1, 2, 1, 1, 1, 1]))
    print(entropy([1, 1, 1, 1, 1, 1]))
    print(entropy([1, 1, 1, 2, 2, 2]))

0 comments on commit 5566b44

Please sign in to comment.