Commit
Showing 13 changed files with 693 additions and 612 deletions.
@@ -0,0 +1,172 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""This module implements the paper titled `Accurate Causal Inference on
Discrete Data`. We can also compute the total information content of the
sample by encoding the function and adding the stochastic complexity on top
of the regression model. For more detail, please refer to the manuscript at
http://people.mpi-inf.mpg.de/~kbudhath/manuscript/acid.pdf
"""
from collections import Counter
from math import log
import sys

from formatter import stratify
from measures import DependenceMeasure, DMType

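# Note: `stratify` is not part of this commit; judging from its use below and
# from the former `marginals` helper in the cisc.py diff further down, it is
# assumed to group the values of the second sequence by the co-occurring value
# of the first, e.g. stratify([1, 1, 2], ['a', 'b', 'a']) -> {1: ['a', 'b'], 2: ['a']}.
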
def choose(n, k):
    """Computes the binomial coefficient `n choose k`."""
    if 0 <= k <= n:
        # multiplicative formula: accumulate n * (n - 1) * ... in ntok and
        # k! in ktok, then divide once at the end
        ntok = 1
        ktok = 1
        for t in range(1, min(k, n - k) + 1):
            ntok *= n
            ktok *= t
            n -= 1
        return ntok // ktok
    else:
        return 0

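# For instance, choose(5, 2) == 10, and choose(2, 5) == 0 since k > n.
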
def univ_enc(n):
    """Computes the universal code length of the given integer.

    Reference: J. Rissanen. A Universal Prior for Integers and Estimation by
    Minimum Description Length. Annals of Statistics 11(2) pp.416-431, 1983.
    """
    # log*(n): sum the iterated logarithms log n + log log n + ... while the
    # terms stay >= 1, plus the normalising constant log2(2.86504)
    ucl = log(2.86504, 2)
    previous = n
    while True:
        previous = log(previous, 2)
        if previous < 1.0:
            break
        ucl += previous
    return ucl

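# For instance, univ_enc(1000) ≈ 1.52 + 9.97 + 3.32 + 1.73 ≈ 16.5 bits,
# summing log2 terms until they drop below 1.
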
def encode_func(f):
    """Encodes the function by enumerating the set of all possible functions.

    Args:
        f (dict): function represented as a map from domain to image values

    Returns:
        (float): encoded size of the function
    """
    # nones = len(set(f.values()))
    # return univ_enc(nones) + log(choose(ndom * nimg, nones), 2)
    ndom = len(f.keys())
    nimg = len(set(f.values()))
    # there are nimg ** ndom distinct functions from a domain of size ndom to
    # an image of size nimg, hence log2(nimg ** ndom) bits to index one
    return univ_enc(ndom) + univ_enc(nimg) + log(nimg ** ndom, 2)

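# For instance, with the counting above, a function from a 4-value domain onto
# 3 image values costs univ_enc(4) + univ_enc(3) + log2(3 ** 4) bits, where
# log2(3 ** 4) = log2(81) ≈ 6.34 bits pays for the function index.
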
def map_to_majority(X, Y):
    """Creates a function that maps each x to its most frequently
    co-occurring y.

    Args:
        X (sequence): sequence of discrete outcomes
        Y (sequence): sequence of discrete outcomes

    Returns:
        (dict): map from X-values to the most frequently co-occurring Y-values
    """
    f = dict()
    Y_grps = stratify(X, Y)
    for x, Ys in Y_grps.items():
        frequent_y, _ = Counter(Ys).most_common(1)[0]
        f[x] = frequent_y
    return f

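# For instance, assuming the behaviour of `stratify` noted above,
# map_to_majority([1, 1, 1, 2], [5, 5, 6, 7]) returns {1: 5, 2: 7}.
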
def regress(X, Y, dep_measure, max_niterations, enc_func=False):
    """Performs discrete regression with Y as a dependent variable and X as
    an independent variable.

    Args:
        X (sequence): sequence of discrete outcomes
        Y (sequence): sequence of discrete outcomes
        dep_measure (DependenceMeasure): subclass of DependenceMeasure
        max_niterations (int): maximum number of iterations
        enc_func (bool): whether to encode the function or not

    Returns:
        (float): p-value (or information content) after fitting ANM from X->Y
    """
    # todo: make it work with chi-squared test of independence or G^2 test
    supp_X = list(set(X))
    supp_Y = list(set(Y))

    # start from the majority mapping and greedily update one f(x) at a time
    f = map_to_majority(X, Y)

    pair = list(zip(X, Y))
    res = [y - f[x] for x, y in pair]
    cur_res_inf = dep_measure.measure(res, X)

    j = 0
    minimized = True
    while j < max_niterations and minimized:
        minimized = False

        for x_to_map in supp_X:
            best_res_inf = sys.float_info.max
            best_y = None

            # try every alternative image value for x_to_map and keep the one
            # that makes the residuals least dependent on X
            for cand_y in supp_Y:
                if cand_y == f[x_to_map]:
                    continue

                res = [y - f[x] if x != x_to_map else y - cand_y
                       for x, y in pair]
                res_inf = dep_measure.measure(res, X)

                if res_inf < best_res_inf:
                    best_res_inf = res_inf
                    best_y = cand_y

            if best_res_inf < cur_res_inf:
                cur_res_inf = best_res_inf
                f[x_to_map] = best_y
                minimized = True
        j += 1

    if dep_measure.type == DMType.INFO and not enc_func:
        return dep_measure.measure(X) + cur_res_inf
    elif dep_measure.type == DMType.INFO and enc_func:
        return dep_measure.measure(X) + encode_func(f) + cur_res_inf
    else:
        _, p_value = dep_measure.nhst([y - f[x] for x, y in pair], X)
        return p_value

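# Cost note: each sweep of the loop in regress evaluates the dependence
# measure O(|supp_X| * |supp_Y|) times on n residuals, so one iteration costs
# roughly O(|supp_X| * |supp_Y| * n) plus the cost of the measure itself.
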
def anm(X, Y, dep_measure, max_niterations=1000, enc_func=False):
    """Fits the Additive Noise Model from X to Y and vice versa.

    Args:
        X (sequence): sequence of discrete outcomes
        Y (sequence): sequence of discrete outcomes
        dep_measure (DependenceMeasure): subclass of DependenceMeasure
        max_niterations (int): maximum number of iterations
        enc_func (bool): whether to encode the function or not

    Returns:
        (float, float): p-value (or information content) after fitting ANM
        from X->Y and vice versa.
    """
    assert issubclass(dep_measure, DependenceMeasure), "dependence measure "\
        "must be a subclass of DependenceMeasure abstract class"
    xtoy = regress(X, Y, dep_measure, max_niterations, enc_func)
    ytox = regress(Y, X, dep_measure, max_niterations, enc_func)
    return (xtoy, ytox)

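# Decision rule (assumed from the ANM/MDL literature, not stated in this
# module): with an information-type measure, prefer the direction with the
# smaller total code length; with a significance test, prefer the direction
# with the larger p-value, i.e. residuals that look more independent.
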
if __name__ == "__main__":
    import numpy as np
    from measures import Entropy, StochasticComplexity, ChiSquaredTest

    X = np.random.choice([1, 2, 4, -1], 1000)
    Y = np.random.choice([-2, -1, 0, 1, 2], 1000)

    print(anm(X, Y, Entropy))
    print(anm(X, Y, StochasticComplexity))
    print(anm(X, Y, StochasticComplexity, enc_func=True))
    print(anm(X, Y, ChiSquaredTest))
@@ -1,59 +1,58 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
 # -*- coding: utf-8 -*-

-"""Causal inference on discrete data using stochastic complexity of multinomial.
+"""This module implements the paper titled `MDL for Causal Inference on
+Discrete Data`. For more detail, please refer to the manuscript at
+http://people.mpi-inf.mpg.de/~kbudhath/manuscript/cisc.pdf
 """
 from math import log
-from collections import defaultdict
-from sc import stochastic_complexity
+from formatter import stratify
+from sc import sc


-def marginals(X, Y):
-    Ys = defaultdict(list)
-    for i, x in enumerate(X):
-        Ys[x].append(Y[i])
-    return Ys
-
-
-def cisc(X, Y):
-    scX = stochastic_complexity(X)
-    scY = stochastic_complexity(Y)
+def cisc(X, Y, plain=False):
+    """Computes the total stochastic complexity from X to Y and vice versa.
+
+    Args:
+        X (sequence): sequence of discrete outcomes
+        Y (sequence): sequence of discrete outcomes
+        plain (bool): whether to compute the plain conditional stochastic
+            complexity or not. If not provided, we compute the weighted one.
+
+    Returns:
+        (float, float): the total multinomial stochastic complexity of X and Y
+        in the direction from X to Y, and vice versa.
+    """
+    assert len(X) == len(Y)
+
+    n = len(X)
+
+    scX = sc(X)
+    scY = sc(Y)

-    mYgX = marginals(X, Y)
-    mXgY = marginals(Y, X)
+    YgX = stratify(X, Y)
+    XgY = stratify(Y, X)

-    domX = mYgX.keys()
-    domY = mXgY.keys()
+    domX = YgX.keys()
+    domY = XgY.keys()

-    # plain one
-    # scYgX = sum(stochastic_complexity(Z, len(domY)) for Z in mYgX.itervalues())
-    # scXgY = sum(stochastic_complexity(Z, len(domX)) for Z in mXgY.itervalues())
+    ndomX = len(domX)
+    ndomY = len(domY)

-    # weighted one
-    scYgX = sum((len(Z) * 1.0) / len(X) * stochastic_complexity(Z, len(domY))
-                for Z in mYgX.itervalues())
-    scXgY = sum((len(Z) * 1.0) / len(X) * stochastic_complexity(Z, len(domX))
-                for Z in mXgY.itervalues())
+    if plain:
+        scYgX = sum(sc(Yp, ndomY) for Yp in YgX.values())
+        scXgY = sum(sc(Xp, ndomX) for Xp in XgY.values())
+    else:
+        scYgX = sum(len(Yp) / n * sc(Yp, ndomY) for Yp in YgX.values())
+        scXgY = sum(len(Xp) / n * sc(Xp, ndomX) for Xp in XgY.values())

     ciscXtoY = scX + scYgX
     ciscYtoX = scY + scXgY
-    # print "X=%.2f Ygx=%.2f" % (scX, scYgX)
-    # print "Y=%.2f XgY=%.2f" % (scY, scXgY)

     return (ciscXtoY, ciscYtoX)


 if __name__ == "__main__":
     import random
-    from test_synthetic import map_randomly

-    n = 1000
-    Xd = range(1, 4)
-    fXd = range(1, 4)
-    f = map_randomly(Xd, fXd)
-    N = range(-2, 3)
-
-    X = [random.choice(Xd) for i in xrange(n)]
-    Y = [f[X[i]] + random.choice(N) for i in xrange(n)]
-
-    print cisc(X, Y)
+    n = 100
+    X = [random.randint(0, 10) for i in range(n)]
+    Y = [random.randint(0, 10) for i in range(n)]
+    print(cisc(X, Y))
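Read the cisc score pair the same way as the anm pair above: under the MDL principle of the linked manuscript, the direction with the smaller total stochastic complexity is the preferred causal direction.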