diff --git a/jars/javatools-1.0.0.jar b/jars/javatools-1.0.0.jar new file mode 100644 index 0000000..2664384 Binary files /dev/null and b/jars/javatools-1.0.0.jar differ diff --git a/jars/json-simple-1.1.1.jar b/jars/json-simple-1.1.1.jar new file mode 100644 index 0000000..66347a6 Binary files /dev/null and b/jars/json-simple-1.1.1.jar differ diff --git a/src/main/java/kb/howtokb/clustering/ActivityCachedSim.java b/src/main/java/kb/howtokb/clustering/ActivityCachedSim.java new file mode 100644 index 0000000..d0b9a84 --- /dev/null +++ b/src/main/java/kb/howtokb/clustering/ActivityCachedSim.java @@ -0,0 +1,37 @@ +package kb.howtokb.clustering; + +import edu.stanford.nlp.util.Pair; +import kb.howtokb.utils.BijectiveMap; + +public class ActivityCachedSim { + + private BijectiveMap> ids; + private ActivityComponentSim vSim, nSim; + + public ActivityCachedSim(BijectiveMap> ids, ActivityComponentSim vSim, + ActivityComponentSim nSim) { + this.ids = ids; + this.vSim = vSim; + this.nSim = nSim; + } + + private boolean sim(Pair e1, Pair e2) { +// return +// // v1 v2 are similar +// vSim.simFromCache(e1.first, e2.first) && +// // n1 v2 are similar +// nSim.simFromCache(e1.second, e2.second); + + if (!vSim.simFromCache(e1.first, e2.first) && + // n1 v2 are similar + !nSim.simFromCache(e1.second, e2.second)) + return false; + return true; + } + + public boolean sim(String a1, String a2) { + Pair e1 = this.ids.getValueFromKey(a1); + Pair e2 = this.ids.getValueFromKey(a2); + return sim(e1, e2); + } +} diff --git a/src/main/java/kb/howtokb/clustering/ActivityComponentSim.java b/src/main/java/kb/howtokb/clustering/ActivityComponentSim.java new file mode 100644 index 0000000..b6e778d --- /dev/null +++ b/src/main/java/kb/howtokb/clustering/ActivityComponentSim.java @@ -0,0 +1,86 @@ +package kb.howtokb.clustering; + +import java.util.Map; + +import kb.howtokb.clustering.sim.ActivityWordCategorySim.SparseSims; +import kb.howtokb.utils.IDMap; + +public class ActivityComponentSim 
implements ISimilarity { + + private double threshold = 0.0; + private SparseSims sims; // cache. + private IDMap ids; + private ISimilarity word2vecSim; + + public ActivityComponentSim(double threshold, IDMap ids, ISimilarity word2vecSim) { + this.threshold = threshold; + this.sims = new SparseSims((float) this.threshold); + this.ids = ids; + this.word2vecSim = word2vecSim; + computeAllPairsSim(); + } + + private void computeAllPairsSim() { + //Progress p = new Progress(1); + Integer[] ids = this.ids.values().toArray(new Integer[0]); + System.out.println("\n" + ids.length + " activity components. (one dot per activity neighborhood)"); + for (int i = 0; i < ids.length; i++) { + //p.next(); + int e1 = ids[i]; + for (int j = i; j < ids.length; j++) { + // cache before returning the result. + int e2 = ids[j]; + sims.set(e1, e2, (float) sim(e1, e2)); + } + } + } + + @Override + public double sim(Integer e1, Integer e2) { + String word1 = ids.getKeyFromValue(e1); + String word2 = ids.getKeyFromValue(e2); + return word2vecSim.sim(word1, word2); + } + + @Override + public Map getNeighbors(Integer e) { + System.err.println( + "neighborhood for an integer " + "activity (noun or verb) not yet implemented."); + return null; + } + + @Override + public boolean simThreshold(Integer e1, Integer e2, double minthreshold) throws Exception { + return sim(e1, e2) >= minthreshold; + } + + // Use this function once the object is completely constructed. + /** + * Also checks for e2,e1 + * + * @param e1 + * @param e2 + * @return if (e1,e2) are similar, return true. otherwise returns false + */ + public boolean simFromCache(int e1, int e2) { + return sims.get(e1, e2); + } + + /** + * Also checks for e2,e1
+ * NOTE: ensure that word1 and word2 are either both verbs or both noun. + * And modified according to the ISim func. provided to the constructor + * + * @param word1 + * go away => v_go (we only lookup "go" in word2vec) + * @param word2 + * move => v_move + * @return if (e1,e2) are similar, return true. otherwise returns false + */ + public boolean simFromCache(String word1, String word2) { + int e1 = ids.getValueFromKey(word1); + int e2 = ids.getValueFromKey(word2); + return sims.get(e1, e2); + } + +} diff --git a/src/main/java/kb/howtokb/clustering/DataForClustering.java b/src/main/java/kb/howtokb/clustering/DataForClustering.java new file mode 100644 index 0000000..dda8e0c --- /dev/null +++ b/src/main/java/kb/howtokb/clustering/DataForClustering.java @@ -0,0 +1,202 @@ +package kb.howtokb.clustering; + +import java.util.List; + +import kb.howtokb.clustering.sim.CategorySimilarity; +import kb.howtokb.clustering.sim.SimilarityComputation; +import kb.howtokb.taskframe.WikiHowTaskFrame; +import kb.howtokb.utils.AdjacencyBackedSparseMatrix; +import kb.howtokb.utils.SparseSimMatrix; + + +public class DataForClustering { + + /*Get similarity matrix of a cluster + * s(i,j) is similarity between i and j + * 0 at diagonal + */ + public static double[][] getSimilarityMatrix(List list, CategorySimilarity cs) throws Exception{ + double [][] res = new double[list.size()][list.size()]; + for (int i=0; i list, CategorySimilarity cs) throws Exception{ + double [][] res = new double[list.size()][list.size()]; + for (int i=0; i list, CategorySimilarity cs, double threshold) throws Exception{ + SparseSimMatrix res = new SparseSimMatrix((float) threshold); + for (int i=0; i list, CategorySimilarity cs, double threshold) throws Exception{ + int n = list.size(); + AdjacencyBackedSparseMatrix res = new AdjacencyBackedSparseMatrix((float) threshold, n); + for (int i=0; i= (U.length - 1) - k + 1; j--){ +// res[i][U.length - 1 - j] = U[i][j]; +// } +// } +// +//// //Normalize rows of Y 
to use for k-means on row of Y +//// for (int i=0; i list, boolean unnormalized, int k) throws Exception{ +// //Get similarity matrix +// double[][] simMatrix = getSimilarityMatrix(list, cs); +// //Get laplacian matrix +// double[][] laplacian; +// if (unnormalized){ +// laplacian = getUnnormalizedLaplacianMatrix(simMatrix); +// }else laplacian = getNormalizedLaplacianMatrix(simMatrix); +// +// return getMatrixForKMean(laplacian, k); +// } +} diff --git a/src/main/java/kb/howtokb/clustering/HeuristicBottomupClustering.java b/src/main/java/kb/howtokb/clustering/HeuristicBottomupClustering.java new file mode 100644 index 0000000..4c3fd59 --- /dev/null +++ b/src/main/java/kb/howtokb/clustering/HeuristicBottomupClustering.java @@ -0,0 +1,199 @@ +package kb.howtokb.clustering; + +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +import javatools.administrative.Announce; +import kb.howtokb.clustering.basicobj.ActivityWordsCategory; +import kb.howtokb.clustering.basicobj.CSKCluster; +import kb.howtokb.utils.AutoMap; + +// Merging clusters. +public class HeuristicBottomupClustering implements IBottomUpClustering, ActivityWordsCategory> { + + // As we do hard-clustering, each cluster can be part of at-most one super + // cluster. Therefore, while merging, maintain a visited flag + private Map, Boolean> visited; + private List> potentiallyMergeableClusters; + // TODO a very costly map. + private Map elems; + + public HeuristicBottomupClustering(String activityTb) throws IOException { + // just store the pointer, we don't need to modify this data. + elems = new HashMap<>(); + loadElemsFromDb(activityTb); + Announce.message("# initial instances: " + elems.size()); + // Group on the strong norm form (e.g. 
paint a wall and paint a long wall) + AutoMap> lexicalClusters = new AutoMap<>(); + for (Entry e : elems.entrySet()) + lexicalClusters.addArrayValue(e.getValue(), e.getValue()); + + // Construct clusters from automap. + this.potentiallyMergeableClusters = new ArrayList<>(); + for (Entry> e : lexicalClusters.entrySet()) { + CSKCluster smallCluster = new CSKCluster( + e.getKey().getId()); + for (ActivityWordsCategory clusterMember : e.getValue()) + smallCluster.addClusterMember(clusterMember.getId()); + potentiallyMergeableClusters.add(smallCluster); + } + + // by default none will be unvisited or visited because map is empty. + visited = new HashMap<>(); + for (CSKCluster smallCluster : potentiallyMergeableClusters) { + visited.put(smallCluster, false); + } + + } + + public HeuristicBottomupClustering(Map activityTb) throws IOException { + // just store the pointer, we don't need to modify this data. + elems = new HashMap<>(); + elems = activityTb; + Announce.message("# initial instances: " + elems.size()); + // Group on the strong norm form (e.g. paint a wall and paint a long wall) + AutoMap> lexicalClusters = new AutoMap<>(); + for (Entry e : elems.entrySet()) + lexicalClusters.addArrayValue(e.getValue(), e.getValue()); + + // Construct clusters from automap. + this.potentiallyMergeableClusters = new ArrayList<>(); + for (Entry> e : lexicalClusters.entrySet()) { + CSKCluster smallCluster = new CSKCluster( + e.getKey().getId()); + for (ActivityWordsCategory clusterMember : e.getValue()) + smallCluster.addClusterMember(clusterMember.getId()); + potentiallyMergeableClusters.add(smallCluster); + } + + // by default none will be unvisited or visited because map is empty. 
+ visited = new HashMap<>(); + for (CSKCluster smallCluster : potentiallyMergeableClusters) { + visited.put(smallCluster, false); + } + + } + + private void loadElemsFromDb(String activityTb) throws IOException { + try (BufferedReader br = new BufferedReader(new FileReader(activityTb))) { + String sCurrentLine; + while ((sCurrentLine = br.readLine()) != null) { + String [] line = sCurrentLine.split("\t"); + elems.put(Integer.parseInt(line[0]), + new ActivityWordsCategory(Integer.parseInt(line[0]), Integer.parseInt(line[1]), line[2])); + } + } + } + + // Are these clusters brothers? + @Override + public boolean canMergeWith(CSKCluster c1, CSKCluster c2, + ISimilarity simFunc, double simThreshold) throws Exception { + // Cluster key = representative of the cluster. + return simFunc.simThreshold(elems.get(c1.getClusterKey()), elems.get(c2.getClusterKey()), simThreshold); + } + + // create a super cluster. + // visit and merge. + public List cluster(ISimilarity simFunc, double simThreshold) throws Exception { + List merged = new ArrayList<>(); + Announce.message("#initial clusters = " + potentiallyMergeableClusters.size()); + //Progress p = new Progress(10); + for (CSKCluster c1 : potentiallyMergeableClusters) { + //p.next(); + // already associated with another supercluster? + if (visited.get(c1)) + continue; + visited.put(c1, true); + + // initialize the cluster. + ActivitySuperCluster unmerged = new ActivitySuperCluster(); + unmerged.addCluster(c1); + + for (CSKCluster c2 : potentiallyMergeableClusters) { + if (visited.get(c2)) + continue; + + // Should c1 absorb the little cluster c2? + // TODO: think about a thresholding algorithm? + // TODO: think about early pruning to avoid canMergeWith for + // every pair. + if (canMergeWith(c1, c2, simFunc, simThreshold)) { + unmerged.addCluster(c2); + visited.put(c2, true); + } + } + + // We are done with c1. 
No other cluster wants to associate with it + merged.add(unmerged); + } + + return merged; + } + + public static class ActivitySuperCluster { + + // bunch of strongly normalized triples make the key. + private List superClusterKeys; + // a list of all members of the supercluster + private List superClusterMembers; + + public List getSuperClusterKeys() { + return superClusterKeys; + } + + public List getSuperClusterMembers() { + return superClusterMembers; + } + + public ActivitySuperCluster() { + this.superClusterKeys = new ArrayList<>(); + this.superClusterMembers = new ArrayList<>(); + } + + public void addCluster(CSKCluster c) { + this.superClusterKeys.add(c.getClusterKey()); + for (Integer smallClusterMember : c.getClusterMembers()) + this.superClusterMembers.add(smallClusterMember); + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((superClusterKeys == null) ? 0 : superClusterKeys.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + ActivitySuperCluster other = (ActivitySuperCluster) obj; + if (superClusterKeys == null) { + if (other.superClusterKeys != null) + return false; + } else if (!superClusterKeys.equals(other.superClusterKeys)) + return false; + return true; + } + + @Override + public String toString() { + return superClusterKeys.size() + "\tCSKSuperCluster [superClusterKeys=" + superClusterKeys + + ", superClusterMembers=" + superClusterMembers + "]"; + } + + } + +} diff --git a/src/main/java/kb/howtokb/clustering/HeuristicTopDownClustering.java b/src/main/java/kb/howtokb/clustering/HeuristicTopDownClustering.java new file mode 100644 index 0000000..991e52b --- /dev/null +++ b/src/main/java/kb/howtokb/clustering/HeuristicTopDownClustering.java @@ -0,0 +1,147 @@ +package kb.howtokb.clustering; + +import java.sql.SQLException; +import 
java.util.ArrayList; +import java.util.List; + +import kb.howtokb.clustering.basicobj.CSKSimpleCluster; +import kb.howtokb.clustering.basicobj.Instance; +import kb.howtokb.clustering.sim.CategorySimilarity; +import kb.howtokb.taskframe.WikiHowTaskFrame; + +public class HeuristicTopDownClustering + implements ITopDownClustering, Instance> { + + private IStopping> stopper; + private double[][] simMatrix; + private CSKSimpleCluster inputCluster; + public HeuristicTopDownClustering(List pts, boolean isNormalizedCut, double threshold, int k) throws SQLException, Exception { + List> ints = new ArrayList<>(); + for (int i=0; i(i, pts.get(i))); + inputCluster = new CSKSimpleCluster<>(-1, ints); + simMatrix = DataForClustering.getFullSimilarityMatrix(pts, new CategorySimilarity()); + this.stopper = new SimpleSimilarityStopping(threshold, simMatrix, isNormalizedCut); + } + + public CSKSimpleCluster getInputCluster() { + return inputCluster; + } + + @Override + public boolean canSplitFrom(CSKSimpleCluster c1, CSKSimpleCluster c2, + ISimilarity> simFunc, double simThreshold) { + // TODO Auto-generated method stub + return false; + } + + @Override + public List> splitACluster(CSKSimpleCluster c, int k) + throws Exception { + List> res = new ArrayList<>(); + + + List> members = new ArrayList<>(c.getClusterMembers()); +// c.getClusterMembers(); + Instance farthestPt = getFarthestPt(members); + + CSKSimpleCluster initCluster = new CSKSimpleCluster<>(-1, farthestPt); + members.remove(farthestPt); + CSKSimpleCluster leftCluster = new CSKSimpleCluster<>(-1, members); + + assignCluster(initCluster, leftCluster); + + List> children = new ArrayList<>(); + children.add(initCluster); children.add(leftCluster); + if (stopper.split(c, children)){ + for (int i=0; i 1) + res.addAll(splitACluster(children.get(i), k)); + else res.add(children.get(i)); + } + }else{ +// System.out.println(c.getClusterMembers().size()); + res.add(c); + } + return res; + } + + public Instance getFarthestPt(List> 
members){ + Instance farthestPt = new Instance<>(); + double min = Double.MAX_VALUE; + for (int i=0; i sum){ + min = sum; + farthestPt = members.get(i); + } + } + return farthestPt; + } + + public void assignCluster(CSKSimpleCluster initCluster, CSKSimpleCluster leftCluster){ + List> members = leftCluster.getClusterMembers(); + for (int i=0; i simPtToCluster(members.get(i), leftCluster)){ + initCluster.addClusterMember(members.get(i)); + leftCluster.removeMember(members.get(i)); + i--; + } + } + } + + public double simPtToCluster(Instance pt, CSKSimpleCluster cluster){ + List> members = cluster.getClusterMembers(); + double sum = 0.0; + for (int i=0; i + */ + public static class SimpleSimilarityStopping implements + IStopping> { + + private double threshold; + private double[][] simMatrix; + private boolean isNormalizedCut; + public SimpleSimilarityStopping(double thres, double [][] simMatrix, boolean isNormalizedCut) { + this.threshold = thres; + this.simMatrix = simMatrix; + this.isNormalizedCut = isNormalizedCut; + } + + + @Override + public boolean split(CSKSimpleCluster parent, List> children) { + if (parent.getClusterMembers().size() <= 1) return false; + if (isNormalizedCut) return split(children.get(0), children.get(1)); + return SimpleClusterSimilarity.averageIntraClusterSimilarity(parent, simMatrix) < this.threshold; + } + + //Similar to minimizing min cut + public boolean split(CSKSimpleCluster c1, CSKSimpleCluster c2) { + + double volC1 = SimpleClusterSimilarity.volOfCluster(c1, simMatrix); + double volC2 = SimpleClusterSimilarity.volOfCluster(c2, simMatrix); + double interC = SimpleClusterSimilarity.volOfInterClusters(c1, c2, simMatrix); + if (volC1 == 1 && volC2 == 1) return false; //Don't split into two singletons + if (volC1 == 1) return interC/volC2 < this.threshold; + if (volC2 == 1) return interC/volC1 < this.threshold; + System.out.println(volC1 + " " + volC2 + " " + interC); + return (interC/volC1 + interC/volC2) < this.threshold; + } + + } + 
+} diff --git a/src/main/java/kb/howtokb/clustering/HeuristicTopDownClusteringDynamicSparse.java b/src/main/java/kb/howtokb/clustering/HeuristicTopDownClusteringDynamicSparse.java new file mode 100644 index 0000000..fef5f97 --- /dev/null +++ b/src/main/java/kb/howtokb/clustering/HeuristicTopDownClusteringDynamicSparse.java @@ -0,0 +1,154 @@ +package kb.howtokb.clustering; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import kb.howtokb.clustering.basicobj.CSKSimpleCluster; +import kb.howtokb.clustering.basicobj.Instance; +import kb.howtokb.clustering.sim.CategorySimilarity; +import kb.howtokb.taskframe.WikiHowTaskFrame; +import kb.howtokb.utils.AdjacencyBackedSparseMatrix; + +public class HeuristicTopDownClusteringDynamicSparse + implements ITopDownClustering, Instance> { + + private IStopping> stopper; + private AdjacencyBackedSparseMatrix simMatrix; + private CSKSimpleCluster inputCluster; + public HeuristicTopDownClusteringDynamicSparse(List pts, boolean isNormalizedCut, double threshold, int k, double thresforSparse) throws SQLException, Exception { + List> ints = new ArrayList<>(); + for (int i=0; i(i, pts.get(i))); + inputCluster = new CSKSimpleCluster<>(-1, ints); + simMatrix = DataForClustering.getAdjacencyBackedSparseSimilarityMatrix(pts, new CategorySimilarity(), thresforSparse); + System.out.println("================================================"); + this.stopper = new SimpleSimilarityStopping(threshold, simMatrix, isNormalizedCut); + } + + public CSKSimpleCluster getInputCluster() { + return inputCluster; + } + + @Override + public boolean canSplitFrom(CSKSimpleCluster c1, CSKSimpleCluster c2, + ISimilarity> simFunc, double simThreshold) { + // TODO Auto-generated method stub + return false; + } + + @Override + public List> splitACluster(CSKSimpleCluster c, int k) + throws Exception { + List> res = new ArrayList<>(); + + + List> members = new ArrayList<>(c.getClusterMembers()); +// c.getClusterMembers(); + 
Instance farthestPt = getFarthestPt(members); + + CSKSimpleCluster initCluster = new CSKSimpleCluster<>(-1, farthestPt); + members.remove(farthestPt); + CSKSimpleCluster leftCluster = new CSKSimpleCluster<>(-1, members); + + assignCluster(initCluster, leftCluster); + //System.out.println(initCluster.getClusterMembers().size() + " " + leftCluster.getClusterMembers().size()); + + List> children = new ArrayList<>(); + children.add(initCluster); children.add(leftCluster); + if (stopper.split(c, children)){ + for (int i=0; i 1) + res.addAll(splitACluster(children.get(i), k)); + else res.add(children.get(i)); + } + }else{ +// System.out.println(c.getClusterMembers().size()); + res.add(c); + } + return res; + } + + public Instance getFarthestPt(List> members){ + System.out.println("Find the farthest Pt...."); + Instance farthestPt = new Instance<>(); + double min = Double.MAX_VALUE; + for (int i=0; i sum){ + min = sum; + farthestPt = members.get(i); + } + } + System.out.println("Finishing finding the farthest pt..."); + return farthestPt; + } + + public void assignCluster(CSKSimpleCluster initCluster, CSKSimpleCluster leftCluster){ + System.out.println("Assign cluster....."); + List> members = leftCluster.getClusterMembers(); + for (int i=0; i simPtToCluster(members.get(i), leftCluster)){ + initCluster.addClusterMember(members.get(i)); + leftCluster.removeMember(members.get(i)); + i--; + } + } + System.out.println("Done assign....."); + } + + public double simPtToCluster(Instance pt, CSKSimpleCluster cluster){ + List> members = cluster.getClusterMembers(); + double sum = 0.0; + for (int i=0; i + */ + public static class SimpleSimilarityStopping implements + IStopping> { + + private double threshold; + private AdjacencyBackedSparseMatrix simMatrix; + private boolean isNormalizedCut; + public SimpleSimilarityStopping(double thres, AdjacencyBackedSparseMatrix simMatrix, boolean isNormalizedCut) { + this.threshold = thres; + this.simMatrix = simMatrix; + this.isNormalizedCut 
= isNormalizedCut; + } + + + @Override + public boolean split(CSKSimpleCluster parent, List> children) { + if (parent.getClusterMembers().size() <= 1) return false; + if (isNormalizedCut) return split(children.get(0), children.get(1)); + return SimpleClusterSimilarity.averageIntraClusterSimilarity(parent, simMatrix) < this.threshold; + } + + //Similar to minimizing min cut + public boolean split(CSKSimpleCluster c1, CSKSimpleCluster c2) { + + double volC1 = SimpleClusterSimilarity.volOfCluster(c1, simMatrix); + double volC2 = SimpleClusterSimilarity.volOfCluster(c2, simMatrix); + double interC = SimpleClusterSimilarity.volOfInterClusters(c1, c2, simMatrix); + if (volC1 == 1 && volC2 == 1) return false; //Don't split into two singletons + if (volC1 == 1) return interC/volC2 < this.threshold; + if (volC2 == 1) return interC/volC1 < this.threshold; + System.out.println(volC1 + " " + volC2 + " " + interC); + return (interC/volC1 + interC/volC2) < this.threshold; + } + + } + +} diff --git a/src/main/java/kb/howtokb/clustering/IBottomUpClustering.java b/src/main/java/kb/howtokb/clustering/IBottomUpClustering.java new file mode 100644 index 0000000..514581f --- /dev/null +++ b/src/main/java/kb/howtokb/clustering/IBottomUpClustering.java @@ -0,0 +1,8 @@ +package kb.howtokb.clustering; + + +public interface IBottomUpClustering { + + public boolean canMergeWith(ClusterType c1, ClusterType c2, + ISimilarity simFunc, double simThreshold) throws Exception; +} diff --git a/src/main/java/kb/howtokb/clustering/ISimilarity.java b/src/main/java/kb/howtokb/clustering/ISimilarity.java new file mode 100644 index 0000000..a720a47 --- /dev/null +++ b/src/main/java/kb/howtokb/clustering/ISimilarity.java @@ -0,0 +1,17 @@ +package kb.howtokb.clustering; + +import java.util.Map; + +/** Similarity between either pair of words (tiger, dog); +* or pair of senses (e.g. 
tiger#n#2, cat#n#1) +* +* @author ntandon +*/ +public interface ISimilarity { + + public double sim(T e1, T e2); + + public Map getNeighbors(T e); + + public boolean simThreshold(T e1, T e2, double minthreshold) throws Exception; +} diff --git a/src/main/java/kb/howtokb/clustering/ITopDownClustering.java b/src/main/java/kb/howtokb/clustering/ITopDownClustering.java new file mode 100644 index 0000000..08d7bf6 --- /dev/null +++ b/src/main/java/kb/howtokb/clustering/ITopDownClustering.java @@ -0,0 +1,28 @@ +package kb.howtokb.clustering; + +import java.util.List; + +public interface ITopDownClustering { + + /** + * Can we split a cluster into two. + * @param c + * @return + */ + public boolean canSplitFrom(ClusterType c1, ClusterType c2, + ISimilarity simFunc, double simThreshold); + + /** + * Supports k-way splits. This can be achieved via graph cut algorithms, + * or the simplest method being k-means. Question is how to compute the + * median in k-means over activity frames. + * @param c + * @return + */ + abstract public List splitACluster(ClusterType c, int k) throws Exception; + + public static interface IStopping { + + public boolean split(ClusterType parent, List children); + } +} diff --git a/src/main/java/kb/howtokb/clustering/SimpleClusterSimilarity.java b/src/main/java/kb/howtokb/clustering/SimpleClusterSimilarity.java new file mode 100644 index 0000000..d774ba2 --- /dev/null +++ b/src/main/java/kb/howtokb/clustering/SimpleClusterSimilarity.java @@ -0,0 +1,171 @@ +package kb.howtokb.clustering; + +import java.util.List; + +import kb.howtokb.clustering.basicobj.CSKSimpleCluster; +import kb.howtokb.clustering.basicobj.Instance; +import kb.howtokb.taskframe.WikiHowTaskFrame; +import kb.howtokb.utils.AdjacencyBackedSparseMatrix; +import kb.howtokb.utils.SparseSimMatrix; + +public class SimpleClusterSimilarity { + + /** + * Vol of cluster with input is a 2-D array + * @param c + * @param simMatrix + * @return + */ + public static double 
volOfCluster(CSKSimpleCluster c, double[][] simMatrix){ + List> members = c.getClusterMembers(); + double sum = 0.0; + if (members.size() == 1) return 1.0; + for (int i=0; i c, SparseSimMatrix simMatrix){ + List> members = c.getClusterMembers(); + double sum = 0.0; + if (members.size() == 1) return 1.0; + for (int i=0; i c, AdjacencyBackedSparseMatrix simMatrix){ + List> members = c.getClusterMembers(); + double sum = 0.0; + if (members.size() == 1) return 1.0; + for (int i=0; i c, double[][] simMatrix){ + int n = c.getClusterMembers().size(); + int total = n*(n-1)/2; + System.out.println(volOfCluster(c, simMatrix)/total); + return volOfCluster(c, simMatrix)/total; + } + /** + * SparseMatrix + * @param c + * @param simMatrix + * @return + */ + public static double averageIntraClusterSimilarity(CSKSimpleCluster c, SparseSimMatrix simMatrix){ + int n = c.getClusterMembers().size(); + int total = n*(n-1)/2; + double vol = volOfCluster(c, simMatrix)/total; + System.out.println(vol); + return vol; + } + + /** + * DynamicSparseMatrix + * @param c + * @param simMatrix + * @return + */ + public static double averageIntraClusterSimilarity(CSKSimpleCluster c, AdjacencyBackedSparseMatrix simMatrix){ + int n = c.getClusterMembers().size(); + int total = n*(n-1)/2; + double vol = volOfCluster(c, simMatrix)/total; + System.out.println(vol); + return vol; + } + + /** + * 2-d double array + * @param c1 + * @param c2 + * @param simMatrix + * @return + */ + public static double volOfInterClusters(CSKSimpleCluster c1, + CSKSimpleCluster c2, double[][] simMatrix){ + List> members1 = c1.getClusterMembers(); + List> members2 = c2.getClusterMembers(); + double sum = 0.0; + for (int i=0; i c1, + CSKSimpleCluster c2, SparseSimMatrix simMatrix){ + List> members1 = c1.getClusterMembers(); + List> members2 = c2.getClusterMembers(); + double sum = 0.0; + for (int i=0; i c1, + CSKSimpleCluster c2, AdjacencyBackedSparseMatrix simMatrix){ + List> members1 = c1.getClusterMembers(); + List> 
members2 = c2.getClusterMembers(); + double sum = 0.0; + for (int i=0; i c1, + CSKSimpleCluster c2, double[][] simMatrix){ + int total = c1.getClusterMembers().size() * c2.getClusterMembers().size(); + return volOfInterClusters(c1, c2, simMatrix)/total; + } +} diff --git a/src/main/java/kb/howtokb/clustering/SimplePruningSimilarity.java b/src/main/java/kb/howtokb/clustering/SimplePruningSimilarity.java new file mode 100644 index 0000000..0fe7f74 --- /dev/null +++ b/src/main/java/kb/howtokb/clustering/SimplePruningSimilarity.java @@ -0,0 +1,141 @@ +package kb.howtokb.clustering; + +import java.io.FileNotFoundException; +import java.io.IOException; +import java.sql.SQLException; +import java.util.Map; +import java.util.Set; + +import edu.stanford.nlp.util.Pair; +import kb.howtokb.clustering.basicobj.ActivityWordsCategory; +import kb.howtokb.clustering.sim.CategorySimilarity; +import kb.howtokb.clustering.sim.Coefficient; +import kb.howtokb.clustering.sim.w2v.Word2VecSimilarity; +import kb.howtokb.utils.BijectiveMap; +import kb.howtokb.utils.FileLines; +import kb.howtokb.utils.IDMap; + +public class SimplePruningSimilarity implements ISimilarity{ + + private static CategorySimilarity cate; + private double threshold; + private String model; + private String allAct; + private Set allActList; + private ActivityCachedSim aSim; + public SimplePruningSimilarity(double threshold, String model, String allAct) throws SQLException, IOException, ClassNotFoundException { + cate = new CategorySimilarity(); + + this.threshold = threshold; + this.model = model; + this.allAct = allAct; + ISimilarity word2vecSim = new Word2VecSimilarity(this.model, 0, false); + IDMap vIDs = new IDMap<>(0); + IDMap nIDs = new IDMap<>(0); + BijectiveMap> aIDs = new BijectiveMap<>(); + loadIDs(this.allAct, vIDs, nIDs, aIDs); + + //Using separate threshold for verb and noun v: 0.747, n: 0.67 + //Or can use same threshold, the bottom neck: vvnn: 0.5 + ActivityComponentSim vSim = new 
ActivityComponentSim(Coefficient.V_THRES, vIDs, word2vecSim); + ActivityComponentSim nSim = new ActivityComponentSim(Coefficient.O_THRES, nIDs, word2vecSim); + this.aSim = new ActivityCachedSim(aIDs, vSim, nSim); + } + + public SimplePruningSimilarity(double threshold, String model, Set allAct) throws SQLException, IOException, ClassNotFoundException { + cate = new CategorySimilarity(); + + this.threshold = threshold; + this.model = model; + this.allActList = allAct; + ISimilarity word2vecSim = new Word2VecSimilarity(this.model, 0, false); + IDMap vIDs = new IDMap<>(0); + IDMap nIDs = new IDMap<>(0); + BijectiveMap> aIDs = new BijectiveMap<>(); + loadIDs(this.allActList, vIDs, nIDs, aIDs); + + //Using separate threshold for verb and noun v: 0.747, n: 0.67 + //Or can use same threshold, the bottom neck: vvnn: 0.5 + ActivityComponentSim vSim = new ActivityComponentSim(Coefficient.V_THRES, vIDs, word2vecSim); + ActivityComponentSim nSim = new ActivityComponentSim(Coefficient.O_THRES, nIDs, word2vecSim); + this.aSim = new ActivityCachedSim(aIDs, vSim, nSim); + } + + @Override + public double sim(ActivityWordsCategory e1, ActivityWordsCategory e2) { + // TODO we don't have function to compute in this case + return 0; + } + + @Override + public Map getNeighbors(ActivityWordsCategory e) { + // TODO We don't need this function right now + return null; + } + /** + * Check similar between two activities in a simple way + * if similarity between two category is less than a threshold + * and similarity between two strong activities is less than a threshold + * then they are dissimilar, otw not sure and return true + * @param two ActivityWordsCategory objects + * @return true/false + */ + @Override + public boolean simThreshold(ActivityWordsCategory e1, ActivityWordsCategory e2, double minthreshold) throws Exception { + if (!cate.isSim(e1.getCatID(), e2.getCatID())) + if (!aSim.sim(e1.getActivityStrong(), e2.getActivityStrong())) + return false; + return true; + + } + /** + * 
loadID from a file including a list of activities + * @param activityList + * @param vIDs + * @param nIDs + * @param aIDs + * @throws FileNotFoundException + */ + private static void loadIDs(String activityList, IDMap vIDs, IDMap nIDs, + BijectiveMap> aIDs) throws FileNotFoundException { + for (String a : new FileLines(activityList)) { + // return,from,work + String[] vn = a.split(";"); + String verb = "v_" + (vn[0].contains(" ")?vn[0].split(" ")[0]:vn[0]); + String noun = "n_" + (vn[1].contains(" ")?vn[1].split(" ")[vn[1].split(" ").length - 1]:vn[1]); + int vid = vIDs.getAvailableGlobalID(); + vIDs.add(verb); + int nid = nIDs.getAvailableGlobalID(); + nIDs.add(noun); + aIDs.put(a, new Pair(vid, nid)); + } + System.out.println("Load map verb to id done! Size: " + vIDs.size()); + System.out.println("Load map noun to id done! Size: " + nIDs.size()); + } + + /** + * loadID from list of activities + * @param activityList + * @param vIDs + * @param nIDs + * @param aIDs + * @throws FileNotFoundException + */ + private static void loadIDs(Set activityList, IDMap vIDs, IDMap nIDs, + BijectiveMap> aIDs) throws FileNotFoundException { + for (String a : activityList) { + // return,from,work + String[] vn = a.split(";"); + String verb = "v_" + (vn[0].contains(" ")?vn[0].split(" ")[0]:vn[0]); + String noun = "n_" + (vn[1].contains(" ")?vn[1].split(" ")[vn[1].split(" ").length - 1]:vn[1]); + int vid = vIDs.getAvailableGlobalID(); + vIDs.add(verb); + int nid = nIDs.getAvailableGlobalID(); + nIDs.add(noun); + aIDs.put(a, new Pair(vid, nid)); + } + System.out.println("Load map verb to id done! Size: " + vIDs.size()); + System.out.println("Load map noun to id done! 
Size: " + nIDs.size()); + } + +} diff --git a/src/main/java/kb/howtokb/clustering/basicobj/ActivityWordsCategory.java b/src/main/java/kb/howtokb/clustering/basicobj/ActivityWordsCategory.java new file mode 100644 index 0000000..c2d75e0 --- /dev/null +++ b/src/main/java/kb/howtokb/clustering/basicobj/ActivityWordsCategory.java @@ -0,0 +1,58 @@ +package kb.howtokb.clustering.basicobj; + +public class ActivityWordsCategory { + + private int id; + private int catID; + private String activityStrong; + + public ActivityWordsCategory(int id, int catID, String activityStrong) { + super(); + this.id = id; + this.catID = catID; + this.activityStrong = activityStrong; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((activityStrong == null) ? 0 : activityStrong.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + ActivityWordsCategory other = (ActivityWordsCategory) obj; + if (activityStrong == null) { + if (other.activityStrong != null) + return false; + } else if (!activityStrong.equals(other.activityStrong)) + return false; + return true; + } + + public int getId() { + return id; + } + + public int getCatID() { + return catID; + } + + public String getActivityStrong() { + return activityStrong; + } + + @Override + public String toString() { + return "ActivityWordsCategory [id=" + id + ", activityStrong=" + activityStrong + "]"; + } + +} diff --git a/src/main/java/kb/howtokb/clustering/basicobj/BasicDataPt.java b/src/main/java/kb/howtokb/clustering/basicobj/BasicDataPt.java new file mode 100644 index 0000000..003d8ed --- /dev/null +++ b/src/main/java/kb/howtokb/clustering/basicobj/BasicDataPt.java @@ -0,0 +1,6 @@ +package kb.howtokb.clustering.basicobj; + + +public interface BasicDataPt { + public int getID(); +} diff --git 
/**
 * A cluster identified by a key and holding an ordered list of members.
 *
 * Equality and hashing depend on the cluster key only; the member list is
 * not compared.
 *
 * @param <TClusterID> type of the cluster key
 * @param <TClusterMember> type of the cluster members
 */
public class CSKCluster<TClusterID, TClusterMember> {

	protected TClusterID clusterKey;

	protected List<TClusterMember> clusterMembers;

	/** Creates an empty cluster with the given key. */
	public CSKCluster(TClusterID clusterKey) {
		this.clusterKey = clusterKey;
		this.clusterMembers = new ArrayList<>();
	}

	/**
	 * Creates a cluster backed by the given list. The list is stored as-is
	 * (not copied), so later changes to it are visible through this cluster.
	 */
	public CSKCluster(TClusterID id, List<TClusterMember> members) {
		this.clusterKey = id;
		this.clusterMembers = members;
	}

	public TClusterID getClusterKey() {
		return clusterKey;
	}

	/** Returns the live member list (not a copy). */
	public List<TClusterMember> getClusterMembers() {
		return clusterMembers;
	}

	public void setClusterMembers(List<TClusterMember> clusterMembers) {
		this.clusterMembers = clusterMembers;
	}

	/** Appends one member and returns this cluster for chaining. */
	public CSKCluster<TClusterID, TClusterMember> addClusterMember(
			TClusterMember clusterMember) {
		clusterMembers.add(clusterMember);
		return this;
	}

	/** Appends all given members and returns this cluster for chaining. */
	public CSKCluster<TClusterID, TClusterMember> addClusterMemberSet(
			List<TClusterMember> clusterMember) {
		clusterMembers.addAll(clusterMember);
		return this;
	}

	/** Removes the first occurrence of the member (if any) and returns this cluster. */
	public CSKCluster<TClusterID, TClusterMember> removeMember(TClusterMember pt) {
		clusterMembers.remove(pt);
		return this;
	}

	@Override
	public int hashCode() {
		final int prime = 31;
		int result = 1;
		result = prime * result + ((clusterKey == null) ? 0 : clusterKey.hashCode());
		return result;
	}

	@Override
	public boolean equals(Object obj) {
		if (this == obj)
			return true;
		if (obj == null)
			return false;
		if (getClass() != obj.getClass())
			return false;
		// wildcard cast: only the key is compared, so no type arguments are needed
		CSKCluster<?, ?> other = (CSKCluster<?, ?>) obj;
		if (clusterKey == null) {
			return other.clusterKey == null;
		}
		return clusterKey.equals(other.clusterKey);
	}

	@Override
	public String toString() {
		return "CSKCluster [clusterKey=" + clusterKey + "]";
	}

	/**
	 * Empties the cluster by switching to a fresh list. The previous list is
	 * left untouched, which matters when it was supplied by the caller.
	 */
	public void clear() {
		this.clusterMembers = new ArrayList<>();
	}
}
{ + this.id = -1; + this.frame = null; + } + + public void setId(int id) { + this.id = id; + } + + public void setFrame(T frame) { + this.frame = frame; + } + + @Override + public int getID() { + return id; + } + + public T getData(){ + return frame; + } + +} diff --git a/src/main/java/kb/howtokb/clustering/sim/ActivityWordCategorySim.java b/src/main/java/kb/howtokb/clustering/sim/ActivityWordCategorySim.java new file mode 100644 index 0000000..4ae629c --- /dev/null +++ b/src/main/java/kb/howtokb/clustering/sim/ActivityWordCategorySim.java @@ -0,0 +1,83 @@ +package kb.howtokb.clustering.sim; + +import gnu.trove.TLongHashSet; + +public class ActivityWordCategorySim { + + /* + public static void main(String[] args) throws Exception { + SparseSims distSparse = new SparseSims(0.5f); + // need: ids of activity frames and ActivityWordsCategory + //Test load all activity frame + ActivityWordsCategory[] arr = new ActivityWordsCategory[1292250]; + + String inputfile = "/var/tmp/cxchu/data-wordnet/act-frame.json"; + JSONParser parser = new JSONParser(); + + try (BufferedReader br = new BufferedReader(new FileReader(inputfile))) { + + String sCurrentLine; + int i=0; + + while ((sCurrentLine = br.readLine()) != null) { + + Object obj = parser.parse(sCurrentLine); + JSONObject jsonObject = (JSONObject) obj; + ActivityFrame newframe = JsonToActivityFrame.jsonToActivityFrame(jsonObject); + ActivityWordsCategory tmp = new ActivityWordsCategory(newframe.getID(), + Integer.parseInt(newframe.getActivity().getCategoryID()), newframe.getActivity().getVerb() + ";" + newframe.getActivity().getObject()); + System.out.println(i); + arr[i++] = tmp; + } + } + + for (int i = 0; i < arr.length; i++){ + ActivityWordsCategory e1 = arr[i]; + // Only store the upper triangular matrix. 
+ for (int j = i + 1; j < arr.length; j++) + distSparse.set(i, j, (float) Word2VecRunner.simPair(e1.getActivityStrong(), arr[j].getActivityStrong())); + } + System.out.println("Done!"); + + + } + */ + + /** + * Only stores the upper triangular matrix. + * + * @author cxchu + * + */ + public static class SparseSims { + + private TLongHashSet simPairs; + private float threshold; + + public SparseSims(float threshold) { + simPairs = new TLongHashSet(); + this.threshold = threshold; + } + + public void set(int x, int y, float value) { + if (value < threshold) + return; + long key = intpairToLong(x, y); + simPairs.add(key); + } + + public boolean get(int x, int y) { + long key = intpairToLong(x, y); + if (simPairs.contains(key)) + return true; + + return simPairs.contains(intpairToLong(y, x)); + } + + private long intpairToLong(int l, int r) { + return ((long) l << 32) + r; + } + + } + +} diff --git a/src/main/java/kb/howtokb/clustering/sim/CategorySimilarity.java b/src/main/java/kb/howtokb/clustering/sim/CategorySimilarity.java new file mode 100644 index 0000000..d54d78e --- /dev/null +++ b/src/main/java/kb/howtokb/clustering/sim/CategorySimilarity.java @@ -0,0 +1,174 @@ +package kb.howtokb.clustering.sim; + +import java.io.IOException; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +import edu.stanford.nlp.util.Pair; +import kb.howtokb.utils.AutoMap; +import kb.howtokb.utils.SQLiteJDBCConnector; +import kb.howtokb.wkhobject.Category_Json; + +public class CategorySimilarity { + + /** + * Compute similarity between two categories using + * sim (c1, c2) = (2*height(lca(c1,c2)) + 1)/(height(c1) + height(c2) + 1) + * with additive smoothing + * + */ + private static Map> parentChains; + private static Map, Double> preCate; + private static List allCate; + private static double threshold = 
Coefficient.CATE_THRES; + + public CategorySimilarity() throws SQLException, ClassNotFoundException, IOException{ + loadParentChains(); + loadAllCate(); + preComputedCate(); + } + + //Compute similarity between two categories + public double simWUP(int c1, int c2) { + int c0 = firstCommAncestor(c1, c2); + double pc0 = path2root(c0).size() -1; + double pc1 = path2root(c1).size() -1; + double pc2 = path2root(c2).size() -1; + //System.out.println(pc0 + ", " + pc1 + ", " + pc2); + return (2.0 * pc0 + 1) / (pc1 + pc2 + 1); + } + + //Compute similarity between two categories + public double sim(int c1, int c2) { + try{ + Pair pair = new Pair(c1, c2); + Pair ipair = new Pair(c2, c1); + if (preCate.containsKey(pair)) + return preCate.get(pair); + + return preCate.get(ipair); + + }catch(Exception e){ + return simWUP(c1,c2); + } + } + + /** + * Check if similarity between two categories is greater than a threshold + * @return + */ + public boolean isSim(int c1, int c2){ + Pair pair = new Pair(c1, c2); + if (preCate.containsKey(pair)) + return true; + + return preCate.containsKey(new Pair(c2,c1)); + } + + public Map, Double> getPreCate() { + return preCate; + } + + private void loadParentChains() throws SQLException, ClassNotFoundException, IOException { + parentChains = new AutoMap<>(); + // "rootpath":[57,54,52,150,1] + ResultSet rs = SQLiteJDBCConnector.q("select id, json from categoryjson"); + while (rs.next()) { + try { + parentChains.put(rs.getInt(1), Category_Json.fromJson(rs.getString(2)).getRootpath()); + } catch (Exception e) { + System.out.print("\n---- JSONException in category: " + rs.getInt(1)); + } + } + } + + private void loadAllCate(){ + allCate = new ArrayList<>(); + for (Entry> e: parentChains.entrySet()){ + allCate.add(e.getKey()); + } + } + + private void preComputedCate(){ + System.out.println("Pre-Computing similarity between two categories....."); + preCate = new HashMap<>(); + for (int i=0; i= threshold){ + preCate.put(new Pair(c1, c2), sim); + } + } 
/**
 * Fixed weights and thresholds used by the similarity code.
 *
 * All fields are declared {@code final} so the values cannot be reassigned by
 * accident from other classes. The arrays themselves are still mutable;
 * callers must treat them as read-only.
 */
public class Coefficient {

	/*=======Coefficient================
	 * verb
	 * object
	 * category
	 * location
	 * time
	 * part object
	 * part agent
	 * weak verb
	 * weak object
	 * vvnn
	 */
	/** One weight per feature, in the order listed above (10 entries). */
	public static final double[] ALL_COEF = new double[]{5.4, 0.16, 3.074, -1.45, -2.08, 0.893,
			-0.586, 0.796, 2.388, 6.763};
	/** Intercept that goes with {@link #ALL_COEF}. */
	public static final double INTERCEPT = -11.414;
	/**
	 * Weights for the reduced feature set (8 entries).
	 * NOTE(review): presumably the list above minus "weak verb" and "weak object" - confirm.
	 */
	public static final double[] COEF_WITHOUT_WEAK = new double[]{5.289, 1.3, 3.173, -1.44, -2.047, 1.386,
			-0.52, 7.011};
	/** Intercept that goes with {@link #COEF_WITHOUT_WEAK}. */
	public static final double INTERCEPT_WITHOUT_WEAK = -10.638;

	//threshold
	/** Minimum category similarity for a pair to be kept in the precomputed category cache. */
	public static final double CATE_THRES = 0.362;
	/** Threshold for the combined verb/noun ("vvnn") similarity. The misspelt name is kept for compatibility. */
	public static final double VVNN_TRHES = 0.5;
	/** Threshold for verb similarity. */
	public static final double V_THRES = 0.747;
	/** Threshold for object (noun) similarity. */
	public static final double O_THRES = 0.67;


	//Coefficient with parent, prev/next/ sub-act
	/*=======Full Coefficient================
	 * verb
	 * object
	 * category
	 * location
	 * time
	 * part object
	 * part agent
	 * weak verb
	 * weak object
	 * vvnn
	 * parent
	 * prev
	 * next
	 * sub-act
	 */

	/** One weight per feature of the full list above (14 entries). */
	public static final double[] FULL_COEF = new double[]{5.2, 0.41, 3.006, -2.12, -2.69, 1.45,
			-0.328, 0.426, 2.452, 6.983, -0.706, -0.963, 0.378, -0.564};
	/** Intercept that goes with {@link #FULL_COEF}. */
	public static final double INTERCEPT_FULL = -11.455;

}
transfer to GENERAL_CATEGORY + c1 = c1<1?1:c1; c2 = c2<1?1:c2; + return Double.parseDouble(NormalizationText.format(cs.sim(c1, c2))); + } + //similarity between two words: verb + public static double simVerb(String word1, String word2) throws NumberFormatException, IOException, SQLException{ + return Double.parseDouble(NormalizationText.format(StringSimilarity.simOfVerb(word1, word2))); + } + + //similarity between two words: object + public static double simNoun(String word1, String word2) throws NumberFormatException, IOException, SQLException{ + return Double.parseDouble(NormalizationText.format(StringSimilarity.simOfNoun(word1, word2))); + } + + //similarity between two activities using word2vec + public static double simActW2V(String word1, String word2) throws Exception{ + return Double.parseDouble(NormalizationText.format(StringSimilarity.simOfPairW2V(word1, word2))); + } + + //similarity between two list of word (don't have to pre-process data): agents + //jaccard + public static double simList(List l1, List l2) throws NumberFormatException, IOException, SQLException{ + return Double.parseDouble(NormalizationText.format(StringSimilarity.simOfListWord(l1, l2))); + } + + //similarity between two list of words: location, time, objects (have to pre-process data: pick head word) + //jaccard + public static double simLocationTimeAndObject(List l1, List l2) throws SQLException, IOException{ + l1 = NormalizationText.normList(l1); + l2 = NormalizationText.normList(l2); + return Double.parseDouble(NormalizationText.format(simList(l1, l2))); + } + + //similarity between two list of activity surface + //jaccard + public static double simActSurfaceList(List l1, List l2) throws Exception{ + return Double.parseDouble(NormalizationText.format(StringSimilarity.simOfListActivity(l1, l2))); + } + + //similarity between two phrases: ori_object, ori-verb + //Until now, using jaccard + public static double simPhrase(String s1, String s2) throws NumberFormatException, 
SQLException, IOException{ + return Double.parseDouble(NormalizationText.format(Double.parseDouble(NormalizationText.format(StringSimilarity.simOfPhrase(s1, s2))))); + } + + //========================Old part, with jaccard have threshold and lookup db + + /*//similarity between two list of word (don't have to pre-process data): agents + //jaccard + public static double simList(List l1, List l2, double threshold) throws NumberFormatException, IOException, SQLException{ + return Double.parseDouble(Util.format(StringSimilarity.simOfListNoun(l1, l2, threshold))); + } + + //similarity between two list of words: location, time, objects (have to pre-process data: pick head word) + //jaccard + public static double simLocationTimeAndObject(List l1, List l2, double threshold) throws SQLException, IOException{ + l1 = NormalizationText.normList(l1); + l2 = NormalizationText.normList(l2); + return Double.parseDouble(Util.format(simList(l1, l2, threshold))); + } + + //similarity between two phrases: ori_object + //Until now, using jaccard + public static double simNounPhrase(String s1, String s2, double threshold) throws NumberFormatException, SQLException, IOException{ + return Double.parseDouble(Util.format(Double.parseDouble(Util.format(StringSimilarity.simOfNounPhrase(s1, s2, threshold))))); + } + + //similarity between two phrases: ori_verb + //Until now, using jaccard + public static double simVerbPhrase(String s1, String s2, double threshold) throws NumberFormatException, SQLException, IOException{ + return Double.parseDouble(Util.format(Double.parseDouble(Util.format(StringSimilarity.simOfVerbPhrase(s1, s2, threshold))))); + }*/ + //=========================================================== + + //similarity vector between two activities + public static double[] getSimilarVector(CategorySimilarity cs, WikiHowTaskFrame f1, WikiHowTaskFrame f2) throws Exception{ + double[] res = new double[10]; + + String v1 = f1.getActivity().getVerb(); + + String v2 = 
f2.getActivity().getVerb(); + + double v2v = simVerb(v1, v2); + res[0] = v2v; + + String n1 = f1.getActivity().getObject(); + + String n2 = f2.getActivity().getObject(); + + double o2o = simNoun(n1, n2); + res[1] = o2o; + + double c2c = simCategory(cs,Integer.parseInt(f1.getActivity().getCategoryID()), + Integer.parseInt(f2.getActivity().getCategoryID())); + res[2] = c2c; + + double l2l = simLocationTimeAndObject(f1.getLocations(), f2.getLocations()); + res[3] = l2l; + + double t2t = simLocationTimeAndObject(f1.getTemporal(), f2.getTemporal()); + res[4] = t2t; + + double parto2parto = simLocationTimeAndObject(f1.getParticipatingObject(), f2.getParticipatingObject()); + res[5] = parto2parto; + + double parta2parta = simList(f1.getParticipatingAgent(), f2.getParticipatingAgent()); + res[6] = parta2parta; + + double ov2ov = simPhrase(f1.getActivity().getOriVerb(), f2.getActivity().getOriObject()); + res[7] = ov2ov; + + double oo2oo = simPhrase(f1.getActivity().getOriObject(), f2.getActivity().getOriObject()); + res[8] = oo2oo; + //v1*v2*n1*n2 + String a1 = f1.getActivity().getVerb() + ";"+ f1.getActivity().getObject(); + String a2 = f2.getActivity().getVerb() + ";" + f2.getActivity().getObject(); + double a1a2 = simActW2V(a1, a2); + res[9] = a1a2; + + return res; + } + + //similarity vector between two activities + //Include context: parent, sub, prev/next + public static double[] getFullSimilarVector(CategorySimilarity cs, WikiHowTaskFrame f1, WikiHowTaskFrame f2) throws Exception{ + double[] res = new double[14]; + + String v1 = f1.getActivity().getVerb(); + + String v2 = f2.getActivity().getVerb(); + + double v2v = simVerb(v1, v2); + res[0] = v2v; + + String n1 = f1.getActivity().getObject(); + + String n2 = f2.getActivity().getObject(); + + double o2o = simNoun(n1, n2); + res[1] = o2o; + + double c2c = simCategory(cs,Integer.parseInt(f1.getActivity().getCategoryID()), + Integer.parseInt(f2.getActivity().getCategoryID())); + res[2] = c2c; + + double l2l = 
simLocationTimeAndObject(f1.getLocations(), f2.getLocations()); + res[3] = l2l; + + double t2t = simLocationTimeAndObject(f1.getTemporal(), f2.getTemporal()); + res[4] = t2t; + + double parto2parto = simLocationTimeAndObject(f1.getParticipatingObject(), f2.getParticipatingObject()); + res[5] = parto2parto; + + double parta2parta = simList(f1.getParticipatingAgent(), f2.getParticipatingAgent()); + res[6] = parta2parta; + + double ov2ov = simPhrase(f1.getActivity().getOriVerb(), f2.getActivity().getOriObject()); + res[7] = ov2ov; + + double oo2oo = simPhrase(f1.getActivity().getOriObject(), f2.getActivity().getOriObject()); + res[8] = oo2oo; + //v1*v2*n1*n2 + String a1 = f1.getActivity().getVerb() + ";"+ f1.getActivity().getObject(); + String a2 = f2.getActivity().getVerb() + ";" + f2.getActivity().getObject(); + double a1a2 = simActW2V(a1, a2); + res[9] = a1a2; + + //How about parent, prev/next, sub-act + List parent1 = InformationExtraction.getListofActivitySurfaceFromDb(StructureConverter.stringToList(f1.getActivity().getParent())); + List parent2 = InformationExtraction.getListofActivitySurfaceFromDb(StructureConverter.stringToList(f2.getActivity().getParent())); + double parent = simActSurfaceList(parent1, parent2); + res[10] = parent; + + List prev1 = InformationExtraction.getListofActivitySurfaceFromDb(StructureConverter.stringToList(f1.getActivity().getPrev())); + List prev2 = InformationExtraction.getListofActivitySurfaceFromDb(StructureConverter.stringToList(f2.getActivity().getPrev())); + double prev = simActSurfaceList(prev1, prev2); + res[11] = prev; + + List next1 = InformationExtraction.getListofActivitySurfaceFromDb(StructureConverter.stringToList(f1.getActivity().getNext())); + List next2 = InformationExtraction.getListofActivitySurfaceFromDb(StructureConverter.stringToList(f2.getActivity().getNext())); + double next = simActSurfaceList(next1, next2); + res[12] = next; + + List sub1 = 
InformationExtraction.getListofActivitySurfaceFromDb(f1.getActivity().getSubActivities()); + List sub2 = InformationExtraction.getListofActivitySurfaceFromDb(f2.getActivity().getSubActivities()); + double sub = simActSurfaceList(sub1, sub2); + res[13] = sub; + + return res; + } + + //Get the final similarity value between two activity frames + public static double getSimilarity(CategorySimilarity cs, WikiHowTaskFrame f1, WikiHowTaskFrame f2) throws Exception{ + double [] vector = getSimilarVector(cs, f1, f2); + double res = 0; + for (int i=0; i l1 = new ArrayList<>(); + l1.add("book"); l1.add("dictionary"); l1.add("laptop"); + List l2 = new ArrayList<>(); + l2.add("book"); l2.add("bottle"); l2.add("computer"); l2.add("diary"); + System.out.println("Similarity of two lists of string: " + simList(l1, l2, 0.7)); + + //Test sim of two string + String s1 = "play football with friends"; + String s2 = "paint room with friends"; + System.out.println(simPhrase(s1, s2, 0.7)); + + }*/ + +} diff --git a/src/main/java/kb/howtokb/clustering/sim/StringSimilarity.java b/src/main/java/kb/howtokb/clustering/sim/StringSimilarity.java new file mode 100644 index 0000000..6ffebad --- /dev/null +++ b/src/main/java/kb/howtokb/clustering/sim/StringSimilarity.java @@ -0,0 +1,224 @@ +package kb.howtokb.clustering.sim; + +import java.io.IOException; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import kb.howtokb.clustering.sim.w2v.Word2VecRunner; +import kb.howtokb.tools.NormalizationText; + +public class StringSimilarity { + +// private static ILexicalDatabase db = new NictWordNet(); +// +// //Similarity using wordnet: WUP measure +// public static double simOfWord( String word1, String word2 ) { +// WS4JConfiguration.getInstance().setMFS(true); +// +// double s = new WuPalmer(db).calcRelatednessOfWords(word1, word2); +// if (s >=1) return 1; +// return s; +// } + + //Similarity using database/word2vec + public static double simOfVerb( String w1, 
String w2 ) throws IOException, SQLException { + + //Lookup db +// if (w1.equals(w2)) return 1.0; +// ResultSet rs = +// DBConnector.q("select sim from sim.v2v where (w1='" + w1 + "' and w2='" + w2 + "') or "+ +// "(w1='"+w2 + "' and w2='"+w1 + "') limit 1"); +// if (rs.next()){ +// double sim = rs.getDouble(1); +// return sim; +// } + + try { + return Word2VecRunner.simVerbs(w1,w2); + }catch(Exception e){ + return 0.0; + } + } + + //Similarity using database/word2vec + public static double simOfNoun( String w1, String w2 ) throws IOException, SQLException { +// //Lookup db +// if (w1.equals(w2)) return 1.0; +// ResultSet rs = +// DBConnector.q("select p_sim from sim.n2n where (w1='" + w1 + "' and w2='" + w2 + "') or "+ +// "(w1='"+w2 + "' and w2='"+w1 + "') limit 1"); +// if (rs.next()){ +// double sim = rs.getDouble(1); +// return sim; +// } +// return 0.0; + try { + return Word2VecRunner.simNouns(w1,w2); + }catch(Exception e){ + return 0.0; + } + + } + + //Similar using word2vec without knowing POS + public static double simOfWord(String w1, String w2){ + try { + return Word2VecRunner.getSim().simWithoutPOS(w1, w2); + }catch(Exception e){ + return 0.0; + } + } + + //Similarity using between two activity surfaces + public static double simOfPairW2V( String w1, String w2 ) throws Exception { + return Word2VecRunner.simPair(w1, w2); + } + + //Similarity of two list: weight jaccard + public static double simOfListWord(List l1, List l2) throws IOException, SQLException{ + + if (l1.size() == 0 || l2.size() == 0) return 0; + else{ + double sim = 0; + double total = 0; + for (int i=0; i l1, List l2) throws Exception{ + + if (l1.size() == 0 || l2.size() == 0) return 0; + else{ + double sim = 0; + double total = 0; + for (int i=0; i l1 = new ArrayList<>(); + List l2 = new ArrayList<>(); + for (int i=0; i l1, List l2, double threshold) throws IOException, SQLException{ + double total = l1.size() + l2.size(); + double inter = 0; + List temp = new ArrayList<>(); + if 
(l1.size() == 0 && l2.size() == 0) return 0; + else if (l1.size() == 0) return 0/(l2.size()+1); + else if (l2.size() == 0) return 0/(l1.size()+1); + else{ + for (int i=0; i= threshold){ + check = true; + } + } + for (int j=0; j= threshold){ + inter++; + check = true; + temp.add(l2.get(j)); + l2.remove(j); + } + } + if (check == true) inter ++; + } + } + + return inter/total; + }*/ + + /*//Jaccard + //Input: two lists of verb + public static double simOfListVerb(List l1, List l2, double threshold) throws IOException, SQLException{ + double total = l1.size() + l2.size(); + double inter = 0; + List temp = new ArrayList<>(); + if (l1.size() == 0 && l2.size() == 0) return 0; + else if (l1.size() == 0) return 0/(l2.size()+1); + else if (l2.size() == 0) return 0/(l1.size()+1); + else{ + for (int i=0; i= threshold){ + check = true; + } + } + for (int j=0; j= threshold){ + inter++; + check = true; + temp.add(l2.get(j)); + l2.remove(j); + } + } + if (check == true) inter ++; + } + } + + return inter/total; + }*/ + +// //Jaccard +// //Input: two arrays of noun +// public static double simOfListNoun(String[] s1, String[] s2, double threshold) throws IOException, SQLException{ +// List l1 = new ArrayList<>(); +// List l2 = new ArrayList<>(); +// for (int i=0; i, Pair> pairNeighbors; +// private static List> activities; + private static Word2VecSimilarity sim; +// private static final Map, Double> emptyPairMap = +// new HashMap<>(); + + //precomputation +// private static Set> simAct; +// private static Map, Integer> actToID = new HashMap<>(); + + private static Set> simVerbPair; + private static Map verbToID; + private static Set> simNounPair; + private static Map nounToID; + +// private static double threshold = Coefficient.VVNN_TRHES; + // //////////////////////////////////////////////////////// + // TODO: Cuong -- for POSLevelWord2vec code begins here. 
+ // ////////////////////////////////////////////////////// +// private static int topK = 20; +// private static boolean isDesc = true; + public static void prepareData() throws Exception { + + //load all strong activities + //not necessary +// String input = "all-strong-activities.txt"; +// +// loadActivities(input); + + //load word2vec model + sim = + new Word2VecSimilarity( + "articles-word2vec-word-pos.model.txt", + 25, false); +// System.out.println("\n\n========================================== [" +// + activities.size() + " activities for neighborhood]\n"); + + //preComputeVerbs(); + //preComputeNouns(); + //preComputeActivity(sim); + +// pairNeighbors = +// new SortedMultiMap, Pair>(topK, +// isDesc); + + } + +// public static Set> getSimilarAct(String input) throws Exception{ +// if (sim == null) +// prepareData(); +// Pair a1 = activityToPair(input); +// Set> res = new HashSet<>(); +// for (Pair a2 : activities) { +// //String s = activity.first + ";" + activity.second; +// +// if (a1.equals(a2)) +// continue; +// +// double simScore = simPair(a1, a2, sim); +// if (simScore > 0) +// pairNeighbors.put(a1, a2, simScore); +// } +// for (Pair activity : pairNeighbors.keyset()) { +// String s = activity.first + ";" + activity.second; +// if (input.equals(s)){ +// +// for (Entry, Double> e1 : pairNeighbors +// .getAsMap(activity).entrySet()) { +// res.add(e1.getKey()); +// } +// } +// } +// return res; +// } +// +// public static Set getSimilarActString(String input) throws Exception{ +// Set> res = getSimilarAct(input); +// Set set = new HashSet<>(); +// for (Pair e: res){ +// set.add(e.first + ";" + e.second); +// } +// return set; +// } + +// private static void loadActivities(String input) throws IOException{ +// if (activities == null) +// activities = new ArrayList<>(); +// ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); +// InputStream inputs = classLoader.getResourceAsStream(input); +// try (BufferedReader reader = new 
BufferedReader(new InputStreamReader(inputs, "UTF-8"))) { +// +// String sCurrentLine; +// while ((sCurrentLine = reader.readLine()) != null) { +// Pair act = activityToPair(sCurrentLine); +// //if (!allVerbs.contains(act.first)) allVerbs.add(act.first); +// //if (!allNouns.contains(act.second)) allNouns.add(act.second); +// activities.add(act); +// +// } +// } +// } + /** + * Load all (activities,id) from file + * @param input + * @throws IOException + */ +// private static void loadActToID(String input) throws IOException{ +// if (actToID == null) +// actToID = new HashMap<>(); +// try (BufferedReader br = new BufferedReader(new FileReader(input))) { +// +// String sCurrentLine; +// while ((sCurrentLine = br.readLine()) != null) { +// String [] act_id = sCurrentLine.split("\t"); +// Pair act = activityToPair(act_id[0]); +// actToID.put(act, Integer.parseInt(act_id[1])); +// } +// } +// } + + /** + * load all pair (verb, id) from file + * @param input + * @throws IOException + */ +// private static void loadVerbToID(String input) throws IOException{ +// System.out.println("Load all strong verb and id........"); +// if (verbToID == null) +// verbToID = new HashMap<>(); +// try (BufferedReader br = new BufferedReader(new FileReader(input))) { +// +// String sCurrentLine; +// while ((sCurrentLine = br.readLine()) != null) { +// String [] v_id = sCurrentLine.split("\t"); +// verbToID.put(v_id[0], Integer.parseInt(v_id[1])); +// } +// } +// System.out.println("Successfully! 
Number of verbs: " + verbToID.size()); +// } + + /** + * load all pair (noun, id) from file + * @param input + * @throws IOException + */ +// private static void loadNounToID(String input) throws IOException{ +// System.out.println("Load all strong noun and id........"); +// if (nounToID == null) +// nounToID = new HashMap<>(); +// try (BufferedReader br = new BufferedReader(new FileReader(input))) { +// +// String sCurrentLine; +// while ((sCurrentLine = br.readLine()) != null) { +// String [] n_id = sCurrentLine.split("\t"); +// nounToID.put(n_id[0], Integer.parseInt(n_id[1])); +// } +// } +// System.out.println("Successfully! Number of nouns: " + nounToID.size()); +// } + + private static Pair activityToPair(String activity) { + String[] vn = activity.split(";"); + return new Pair(vn[0], vn[1]); + } + +// private static Pair stringIntToPair(String s) { +// String[] vn = s.split(";"); +// return new Pair(Integer.parseInt(vn[0]), Integer.parseInt(vn[1])); +// } +// +// public static Map, Double> getPairNeighbors( +// Pair w) { +// return !pairNeighbors.containsKey(w) ? 
emptyPairMap : pairNeighbors +// .getAsMap(w); +// } + + /** + * Precompute similarity between two activities + * @return + * @throws Exception + */ +// public static void preComputeActivity(Word2VecSimilarity sim){ +// System.out.println("Pre-computing similar pair of activities......"); +// if (simAct == null) +// simAct = new HashSet<>(); +// int count = 0; +// for (int i=0; i= threshold){ +// simAct.add(new Pair(actToID.get(activities.get(i)), actToID.get(activities.get(j)))); +// count++; +// System.out.println(activities.get(i) + "\t" + activities.get(j)); +// } +// } +// } +// System.out.println("Total of similar activity pairs: " + count); +// } + + /** + * load all similar activity pairs + * @throws IOException + * @throws FileNotFoundException + * + */ +// private static void loadSimActPair(String input) throws FileNotFoundException, IOException{ +// if (simAct == null) +// simAct = new HashSet<>(); +// try (BufferedReader br = new BufferedReader(new FileReader(input))) { +// +// String sCurrentLine; +// while ((sCurrentLine = br.readLine()) != null) { +// simAct.add(stringIntToPair(sCurrentLine)); +// } +// } +// } + + /** + * load all similar verb pair + * @param input + * @throws FileNotFoundException + * @throws IOException + */ +// private static void loadSimVerbPair(String input) throws FileNotFoundException, IOException{ +// System.out.println("Load similar pair of verb............."); +// if (simVerbPair == null) +// simVerbPair = new HashSet<>(); +// try (BufferedReader br = new BufferedReader(new FileReader(input))) { +// +// String sCurrentLine; +// while ((sCurrentLine = br.readLine()) != null) { +// simVerbPair.add(stringIntToPair(sCurrentLine)); +// } +// } +// System.out.println("Done! 
Number of pairs: " + simVerbPair.size()); +// } + + /** + * load all similar noun pair + * @param input + * @throws FileNotFoundException + * @throws IOException + */ +// private static void loadSimNounPair(String input) throws FileNotFoundException, IOException{ +// System.out.println("Load similar pair of noun............."); +// if (simNounPair == null) +// simNounPair = new HashSet<>(); +// try (BufferedReader br = new BufferedReader(new FileReader(input))) { +// +// String sCurrentLine; +// while ((sCurrentLine = br.readLine()) != null) { +// simNounPair.add(stringIntToPair(sCurrentLine)); +// } +// } +// System.out.println("Done! Number of pairs: " + simNounPair.size()); +// } + + /** + * Check whether similarity between two activities are greater than a given threhold. + * @param two strings "paint;wall" and "color;ceiling" + * @return true/false + * @throws Exception + */ +// public static boolean isSim(String a1, String a2) throws Exception{ +// if (sim == null) +// prepareData(); +// Pair activity1 = activityToPair(a1); +// Pair activity2 = activityToPair(a2); +// int id1 = actToID.get(activity1); +// int id2 = actToID.get(activity2); +// if (simAct.contains(new Pair(id1, id2))) +// return true; +// return simAct.contains(new Pair(id2, id1)); +// } + + /** + * Check whether similarity between two activities are greater than a given threhold. 
+ * in a simpler way to prune false negative + * if two verbs are dissimilar and two noun are dissimilar, then two acts are dissimilar + * else we are not sure and return true + * @param two strings "paint;wall" and "color;ceiling" + * @return true/false + * @throws Exception + */ + public static boolean isSim(String a1, String a2) throws Exception{ + if (sim == null) + prepareData(); + Pair activity1 = activityToPair(a1); + Pair activity2 = activityToPair(a2); + if (!isSimVerb(activity1.first, activity2.first)) + return false; + if (!isSimNoun(activity1.second, activity2.second)) + return false; + return true; + } + + /** + * check whether similarity between two verbs is greater than a given threshold + * @param v1 + * @param v2 + * @return + * @throws Exception + */ + public static boolean isSimVerb(String v1, String v2) throws Exception{ + if (sim == null) + prepareData(); + + int id1 = verbToID.get(v1); + int id2 = verbToID.get(v2); + if (simVerbPair.contains(new Pair(id1, id2))) + return true; + return simVerbPair.contains(new Pair(id2, id1)); + } + + /** + * check whether similarity between two verbs is greater than a given threshold + * @param v1 + * @param v2 + * @return + * @throws Exception + */ + public static boolean isSimNoun(String n1, String n2) throws Exception{ + if (sim == null) + prepareData(); + + int id1 = nounToID.get(n1); + int id2 = nounToID.get(n2); + if (simNounPair.contains(new Pair(id1, id2))) + return true; + return simNounPair.contains(new Pair(id2, id1)); + } + + public static double simVerbs(String v1, String v2) throws Exception{ + if (sim == null) + prepareData(); + v1 = v1.contains(" ")?v1.split(" ")[0]:v1; + v2 = v2.contains(" ")?v2.split(" ")[0]:v2; + return sim.sim("v_" + v1, "v_" + v2); + + } + + public static double simNouns(String n1, String n2) throws Exception{ + if (sim == null) + prepareData(); + String n1_tmp = n1.contains(" ")?n1.split(" ")[n1.split(" ").length - 1]:n1; + String n2_tmp = n2.contains(" ")?n2.split(" 
")[n2.split(" ").length - 1]:n2; + return sim.sim("n_" + n1_tmp, "n_" + n2_tmp); + } + + public static double simPair(Pair activity1, + Pair activity2, Word2VecSimilarity sim) { + try { + + double vv = simVerbs(activity1.first, activity2.first); + if (vv == 0) + return 0.0; + double nn = simNouns(activity1.second, activity2.second); + return combinePairScore(vv, nn); + } catch (Exception e) { + return 0.0; + } + } + + public static double simPair(Pair activity1, + Pair activity2) throws Exception { + if (sim == null) + prepareData(); + try { + double vv = simVerbs(activity1.first, activity2.first); + if (vv == 0) + return 0.0; + double nn = simNouns(activity1.second, activity2.second); + return combinePairScore(vv, nn); + } catch (Exception e) { + return 0.0; + } + } + + public static double simPair(String a1, + String a2, Word2VecSimilarity sim) { + Pair activity1 = activityToPair(a1); + Pair activity2 = activityToPair(a2); + return simPair(activity1, activity2, sim); + } + + public static double simPair(String a1, + String a2) throws Exception { + if (sim == null) + prepareData(); + Pair activity1 = activityToPair(a1); + Pair activity2 = activityToPair(a2); + return simPair(activity1, activity2, sim); + } + + public static Word2VecSimilarity getSim() throws Exception { + if (sim == null) + prepareData(); + return sim; + } + + + private static double combinePairScore(double cosine1, double cosine2) { + return cosine1 * cosine2; + } + + public static void main(String[] args) throws Exception { + +// Word2VecSimilarity sim = +// new Word2VecSimilarity( +// "/var/tmp/cxchu/data-server/articles-word2vec-word-pos.model.txt", +// 50, false); +// +// System.out.println(simPair("watch;movie", "watch;film", sim)); +// System.out.println(simPair("watch;movie", "eat;popcorn", sim)); +// System.out.println(simPair("watch;film", "eat;popcorn", sim)); + prepareData(); + } + + +} diff --git a/src/main/java/kb/howtokb/clustering/sim/w2v/Word2VecSimilarity.java 
b/src/main/java/kb/howtokb/clustering/sim/w2v/Word2VecSimilarity.java new file mode 100644 index 0000000..3e865c3 --- /dev/null +++ b/src/main/java/kb/howtokb/clustering/sim/w2v/Word2VecSimilarity.java @@ -0,0 +1,174 @@ +package kb.howtokb.clustering.sim.w2v; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.sql.SQLException; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +import kb.howtokb.clustering.ISimilarity; +import kb.howtokb.tools.NormalizationText; +import kb.howtokb.utils.SortedMultiMap; + + +public class Word2VecSimilarity implements ISimilarity { + + private final Map emptyMap = new HashMap<>(); + + Map wordvectors; + private SortedMultiMap neighbors; + + private String[] words; + + // Map beautiful -> [a_beautiful] + private static Map wordPOS; + + public Word2VecSimilarity(String modelFile, int maxK) throws IOException, SQLException { + this(modelFile, maxK, true); + } + + /** + * @param modelFile + * = "word2vec-data/phrase-norole.model.txt" + * @param maxK + * = k nearest neighbors for getNeighbors() function. + * @throws IOException + * : while reading the model. + * @throws SQLException + * : TODO: fix in SortedMultiMap, we don't need this. + */ + public Word2VecSimilarity(String modelFile, int maxK, boolean precomputeNeighborhood) + throws IOException, SQLException { + System.out.println("Word2Vec constructing model... " + modelFile); + load(modelFile); + if (precomputeNeighborhood) { + System.out.println("Word2Vec constructing neighborhood... 
" + words.length); + constructNeighborhood(maxK); + } + + System.out.println("Word2Vec object constructed"); + } + + private void load(String model) throws IOException { + if (wordvectors == null) { + + boolean isHeader = false; + int vocabElemIndex = 0; + int vecSize = 1; + ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); + InputStream inputs = classLoader.getResourceAsStream(model); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputs, "UTF-8"))) { + + String sCurrentLine; + while ((sCurrentLine = reader.readLine()) != null) { + if (!isHeader) { + // 1435 200 i.e. 1435 tokens, each 200 size vector. + int vocabSize = Integer.parseInt(sCurrentLine.split(" ")[0]); + words = new String[vocabSize]; + wordvectors = new LinkedHashMap<>(vocabSize); + wordPOS = new HashMap<>(); + vecSize = Integer.parseInt(sCurrentLine.split(" ")[1]); + isHeader = true; + continue; + } + + String[] w = sCurrentLine.split(" "); + double[] vec = new double[vecSize]; + for (int i = 1; i < w.length; i++) { + vec[i - 1] = Double.parseDouble(w[i]); + } + words[vocabElemIndex++] = w[0]; + wordvectors.put(w[0], vec); + if (w[0].indexOf("_") == 1) { + // System.out.println(w[0].substring(2)); + wordPOS.put(w[0].substring(2), w[0]); + } + } + } + } + } + + double cosine(double[] w1, double[] w2) { + double result = 0; + double t = 0; + double m1 = 0, m2 = 0; + for (int i = 0; i < w1.length; i++) { + m1 += w1[i] * w1[i]; + m2 += w2[i] * w2[i]; + t += w1[i] * w2[i]; + } + result = t / (Math.sqrt(m1) * Math.sqrt(m2)); + // result = (result > 1) ? 
1 : result; + // Use angular distance to make similarity bounded between [0,1] + result = Math.acos(result) / Math.PI; + result = 1 - result; + return Double.parseDouble(NormalizationText.format(result)); + } + + private void constructNeighborhood(int maxK) throws SQLException { + if (neighbors == null) { + neighbors = new SortedMultiMap<>(maxK, true); + + // Upper triangular matrix is insufficient because + // we need to store all the m x m similarity matrix + // which is prohibitive, and so we use Heap + // which has a small memory footprint + // But, heap requires m x m computations not just upper triangular + // Anyways, cosine is very cheap so no need of futher optimizations + for (int i = 0; i < words.length; i++) { + for (int j = 0; j < words.length; j++) { + if (j != i) + neighbors.put(words[i], words[j], sim(words[i], words[j])); + } + } + } + } + + @Override + public double sim(String w1, String w2) { + if (w1.equals(w2)) + return 1.0; + if (!wordvectors.containsKey(w1)) { + System.out.println(w1 + " is out of dictionary"); + return 0; + } + if (!wordvectors.containsKey(w2)) { + System.out.println(w2 + " is out of dictionary"); + return 0; + } + return cosine(wordvectors.get(w1), wordvectors.get(w2)); + } + + public double simWithoutPOS(String w1, String w2) { + if (w1.equals(w2)) + return 1.0; + if (!wordPOS.containsKey(w1)) { + System.out.println(w1 + " is out of dictionary"); + return 0; + } + if (!wordPOS.containsKey(w2)) { + System.out.println(w2 + " is out of dictionary"); + return 0; + } + w1 = wordPOS.get(w1); + w2 = wordPOS.get(w2); + return cosine(wordvectors.get(w1), wordvectors.get(w2)); + } + + @Override + public Map getNeighbors(String w) { + return !neighbors.containsKey(w) ? 
emptyMap : neighbors.getAsMap(w); + } + + public String[] vocab() { + return words; + } + + @Override + public boolean simThreshold(String e1, String e2, double minthreshold) throws Exception { + return sim(e1, e2) >= minthreshold; + } +} \ No newline at end of file diff --git a/src/main/java/kb/howtokb/global/Global.java b/src/main/java/kb/howtokb/global/Global.java index 24f7889..f6308c6 100644 --- a/src/main/java/kb/howtokb/global/Global.java +++ b/src/main/java/kb/howtokb/global/Global.java @@ -5,71 +5,68 @@ public class Global { public static final String GENERAL_CATEGORY = "GENERAL_CATEGORY"; public static final String GENERAL_THING = "GENERAL_THING"; public static final String GENERAL_INGREDIENT = "GENERAL_INGREDIENT"; - public static final String REMOVE_LINK = "/home/cxchu/workspace/Github/data-extraction/DataExtraction/data/removelinks"; - public static final String NOUN_VERB = "/home/cxchu/workspace/Github/howtokb/data/noun-verb.txt"; public static final String DUMMY_SUBJECT = "You"; - public static final String VERBPHRASES_WN_FILE = "/var/tmp/cxchu/data-needed/db-text/verbphrases"; - public static final String VERBPHRASES_WN_DB = "wikihow.verbphrases"; - - public static final String WNTIME = "wikihow.wordnettime"; - public static final String WNVERB = "wikihow.wordnetverb"; - public static final String WNNOUN = "wikihow.wordnetnoun"; - public static final String WNAGENT = "wikihow.wordnetagent"; +// public static final String VERBPHRASES_WN_DB = "wikihow.verbphrases"; +// +// public static final String WNTIME = "wikihow.wordnettime"; +// public static final String WNVERB = "wikihow.wordnetverb"; +// public static final String WNNOUN = "wikihow.wordnetnoun"; +// public static final String WNAGENT = "wikihow.wordnetagent"; //Ground truth data - public static final String ALL_TRAIN_DATA_FILE = "/var/tmp/cxchu/groundtruth-data/New-Data/all-train-data"; - public static final String TRAIN_DATA = "/var/tmp/cxchu/groundtruth-data/New-Data/train"; - public static 
final String TEST_DATA = "/var/tmp/cxchu/groundtruth-data/New-Data/test"; - - public static final String DATA_GROUND_TRUTH_SIM_FILE = "/var/tmp/cxchu/groundtruth-data/act-frame-sim.json"; - public static final String DATA_GROUND_TRUTH_DISSIM_FILE = "/var/tmp/cxchu/groundtruth-data/act-frame-dissim.json"; - public static final String DATA_GROUND_TRUTH_SIM_FILE_CHECK = "/var/tmp/cxchu/groundtruth-data/act-frame-sim-check.txt"; - public static final String DATA_GROUND_TRUTH_DISSIM_FILE_CHECK = "/var/tmp/cxchu/groundtruth-data/act-frame-dissim-check.txt"; - - public static final String DATA_GROUND_TRUTH_CATE_W2V_SIM_FILE = "/var/tmp/cxchu/groundtruth-data/act-frame-cate-w2v-sim.json"; - public static final String DATA_GROUND_TRUTH_CATE_W2V_SIM_CHECK_FILE = "/var/tmp/cxchu/groundtruth-data/act-frame-cate-w2v-sim-check.txt"; - - public static final String DATA_TEST = "/var/tmp/cxchu/groundtruth-data/test.json"; - public static final String DATA_TEST_CHECK = "/var/tmp/cxchu/groundtruth-data/test.txt"; - -// //For server with check in wordnet - - public static final String JSON_FILE = "/var/tmp/cxchu/articles.json"; - public static final String DATA_FILE = "/var/tmp/cxchu/articles.txt"; - - - public static final String URL_FILE = "/var/tmp/cxchu/wikihow-url"; - public static final String CRAWL_DIRECTORY = "/var/tmp/cxchu/WikiHow2"; -// public static final String ID_CATEGORY_FILE_TEXT = "/var/tmp/cxchu/wikihow-id-category.txt"; - public static final String ID_CATEGORY_FILE_JSON = "/var/tmp/cxchu/wikihow-id-category.json"; - public static final String RAW_CATEGORY = "/var/tmp/cxchu/raw_categories"; - - public static final String ACT_FRAME_JSON_FILE = "/var/tmp/cxchu/data-wordnet/act-frame.json"; - public static final String ID_ACT_FRAME_JSON_FILE = "/var/tmp/cxchu/data-wordnet/id-act-frame.json"; - public static final String AGGRE_WEAK_ACT_FRAME_JSON_FILE = "/var/tmp/cxchu/data-wordnet/act-frame-aggre-weak.json"; - public static final String ID_AGGRE_WEAK_ACT_FRAME_JSON_FILE 
= "/var/tmp/cxchu/data-wordnet/id-act-frame-aggre-weak.json"; - public static final String AGGRE_STRONG_ACT_FRAME_JSON_FILE = "/var/tmp/cxchu/data-wordnet/act-frame-aggre-strong.json"; - public static final String ID_AGGRE_STRONG_ACT_FRAME_JSON_FILE = "/var/tmp/cxchu/data-wordnet/id-act-frame-aggre-strong.json"; - - - public static final String WEAK_ID_ACT = "/var/tmp/cxchu/data-wordnet/weak-id-act"; - public static final String WEAK_ID_LINENUMBER = "/var/tmp/cxchu/data-wordnet/weak-id-line"; - public static final String STRONG_ID_ACT = "/var/tmp/cxchu/data-wordnet/strong-id-act"; - public static final String STRONG_ID_LINENUMBER = "/var/tmp/cxchu/data-wordnet/strong-id-line"; - public static final String ALL_ID_LINENUMBER = "/var/tmp/cxchu/data-wordnet/all-id-line"; - public static final String WEAK_OLD_ID_TO_ID = "/var/tmp/cxchu/data-wordnet/id-weak-id"; - public static final String STRONG_OLD_ID_TO_ID = "/var/tmp/cxchu/data-wordnet/id-strong-id"; - - //create data sample for testing - public static final String ACT_FRAME_DOMAIN_JSON_FILE = "/var/tmp/cxchu/data-wordnet/act-frame-domain.json"; - - public static final String ALL_STRONG_ACT_FILE = "/var/tmp/cxchu/data-wordnet/all-strong-activities.txt"; - public static final String ALL_WEAK_ACT_FILE = "/var/tmp/cxchu/data-wordnet/all-weak-activities.txt"; - - //Data for word2vec - public static final String DATA_TEXT_WORD2VEC_WORD_FILE = "/var/tmp/cxchu/articles-word2vec-word.txt"; - public static final String DATA_TEXT_WORD2VEC_PHRASE_FILE = "/var/tmp/cxchu/articles-word2vec-phrase.txt"; +// public static final String ALL_TRAIN_DATA_FILE = "/var/tmp/cxchu/groundtruth-data/New-Data/all-train-data"; +// public static final String TRAIN_DATA = "/var/tmp/cxchu/groundtruth-data/New-Data/train"; +// public static final String TEST_DATA = "/var/tmp/cxchu/groundtruth-data/New-Data/test"; +// +// public static final String DATA_GROUND_TRUTH_SIM_FILE = "/var/tmp/cxchu/groundtruth-data/act-frame-sim.json"; +// public static 
final String DATA_GROUND_TRUTH_DISSIM_FILE = "/var/tmp/cxchu/groundtruth-data/act-frame-dissim.json"; +// public static final String DATA_GROUND_TRUTH_SIM_FILE_CHECK = "/var/tmp/cxchu/groundtruth-data/act-frame-sim-check.txt"; +// public static final String DATA_GROUND_TRUTH_DISSIM_FILE_CHECK = "/var/tmp/cxchu/groundtruth-data/act-frame-dissim-check.txt"; +// +// public static final String DATA_GROUND_TRUTH_CATE_W2V_SIM_FILE = "/var/tmp/cxchu/groundtruth-data/act-frame-cate-w2v-sim.json"; +// public static final String DATA_GROUND_TRUTH_CATE_W2V_SIM_CHECK_FILE = "/var/tmp/cxchu/groundtruth-data/act-frame-cate-w2v-sim-check.txt"; +// +// public static final String DATA_TEST = "/var/tmp/cxchu/groundtruth-data/test.json"; +// public static final String DATA_TEST_CHECK = "/var/tmp/cxchu/groundtruth-data/test.txt"; +// +//// //For server with check in wordnet +// +// public static final String JSON_FILE = "/var/tmp/cxchu/articles.json"; +// public static final String DATA_FILE = "/var/tmp/cxchu/articles.txt"; +// +// +// public static final String URL_FILE = "/var/tmp/cxchu/wikihow-url"; +// public static final String CRAWL_DIRECTORY = "/var/tmp/cxchu/WikiHow2"; +//// public static final String ID_CATEGORY_FILE_TEXT = "/var/tmp/cxchu/wikihow-id-category.txt"; +// public static final String ID_CATEGORY_FILE_JSON = "/var/tmp/cxchu/wikihow-id-category.json"; +// public static final String RAW_CATEGORY = "/var/tmp/cxchu/raw_categories"; +// +// public static final String ACT_FRAME_JSON_FILE = "/var/tmp/cxchu/data-wordnet/act-frame.json"; +// public static final String ID_ACT_FRAME_JSON_FILE = "/var/tmp/cxchu/data-wordnet/id-act-frame.json"; +// public static final String AGGRE_WEAK_ACT_FRAME_JSON_FILE = "/var/tmp/cxchu/data-wordnet/act-frame-aggre-weak.json"; +// public static final String ID_AGGRE_WEAK_ACT_FRAME_JSON_FILE = "/var/tmp/cxchu/data-wordnet/id-act-frame-aggre-weak.json"; +// public static final String AGGRE_STRONG_ACT_FRAME_JSON_FILE = 
"/var/tmp/cxchu/data-wordnet/act-frame-aggre-strong.json"; +// public static final String ID_AGGRE_STRONG_ACT_FRAME_JSON_FILE = "/var/tmp/cxchu/data-wordnet/id-act-frame-aggre-strong.json"; +// +// +// public static final String WEAK_ID_ACT = "/var/tmp/cxchu/data-wordnet/weak-id-act"; +// public static final String WEAK_ID_LINENUMBER = "/var/tmp/cxchu/data-wordnet/weak-id-line"; +// public static final String STRONG_ID_ACT = "/var/tmp/cxchu/data-wordnet/strong-id-act"; +// public static final String STRONG_ID_LINENUMBER = "/var/tmp/cxchu/data-wordnet/strong-id-line"; +// public static final String ALL_ID_LINENUMBER = "/var/tmp/cxchu/data-wordnet/all-id-line"; +// public static final String WEAK_OLD_ID_TO_ID = "/var/tmp/cxchu/data-wordnet/id-weak-id"; +// public static final String STRONG_OLD_ID_TO_ID = "/var/tmp/cxchu/data-wordnet/id-strong-id"; +// +// //create data sample for testing +// public static final String ACT_FRAME_DOMAIN_JSON_FILE = "/var/tmp/cxchu/data-wordnet/act-frame-domain.json"; +// +// public static final String ALL_STRONG_ACT_FILE = "/var/tmp/cxchu/data-wordnet/all-strong-activities.txt"; +// public static final String ALL_WEAK_ACT_FILE = "/var/tmp/cxchu/data-wordnet/all-weak-activities.txt"; +// +// //Data for word2vec +// public static final String DATA_TEXT_WORD2VEC_WORD_FILE = "/var/tmp/cxchu/articles-word2vec-word.txt"; +// public static final String DATA_TEXT_WORD2VEC_PHRASE_FILE = "/var/tmp/cxchu/articles-word2vec-phrase.txt"; // //For server @@ -85,7 +82,7 @@ public class Global { // // public static final String URL_FILE = "/var/tmp/cxchu/wikihow-url"; // public static final String CRAWL_DIRECTORY = "/var/tmp/cxchu/WikiHow2"; - public static final String ID_CATEGORY_FILE_TEXT = "/var/tmp/cxchu/wikihow-id-category.txt"; +// public static final String ID_CATEGORY_FILE_TEXT = "/var/tmp/cxchu/wikihow-id-category.txt"; // public static final String ID_CATEGORY_FILE_JSON = "/var/tmp/cxchu/wikihow-id-category.json"; // public static final 
String RAW_CATEGORY = "/var/tmp/cxchu/raw_categories"; // diff --git a/src/main/java/kb/howtokb/taskframe/extractor/TextToOpenIEResult.java b/src/main/java/kb/howtokb/taskframe/extractor/TextToOpenIEResult.java index 8794bf1..30208c1 100644 --- a/src/main/java/kb/howtokb/taskframe/extractor/TextToOpenIEResult.java +++ b/src/main/java/kb/howtokb/taskframe/extractor/TextToOpenIEResult.java @@ -3,6 +3,8 @@ import java.io.BufferedReader; import java.io.FileReader; import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; import java.util.ArrayList; import java.util.List; import java.util.regex.Pattern; @@ -190,8 +192,11 @@ public static String nounToVerb(String noun) throws IOException{ String verb = noun; //Read dictionary file if (nvList.size() == 0){ - //nvList = new ArrayList<>(); - BufferedReader br = new BufferedReader(new FileReader(Global.NOUN_VERB)); + + ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); + InputStream inputs = classLoader.getResourceAsStream("noun-verb.txt"); + + BufferedReader br = new BufferedReader(new InputStreamReader(inputs, "UTF-8")); String sCurrentLine; diff --git a/src/main/java/kb/howtokb/taskframe/extractor/TextToActivity.java b/src/main/java/kb/howtokb/taskframe/extractor/TextToWikiHowTaskFrame.java similarity index 98% rename from src/main/java/kb/howtokb/taskframe/extractor/TextToActivity.java rename to src/main/java/kb/howtokb/taskframe/extractor/TextToWikiHowTaskFrame.java index 8e528ea..a778084 100644 --- a/src/main/java/kb/howtokb/taskframe/extractor/TextToActivity.java +++ b/src/main/java/kb/howtokb/taskframe/extractor/TextToWikiHowTaskFrame.java @@ -1,8 +1,9 @@ package kb.howtokb.taskframe.extractor; import java.io.BufferedReader; -import java.io.FileReader; import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; @@ -10,7 +11,6 @@ import 
edu.knowitall.openie.Extraction; import edu.stanford.nlp.util.Pair; -import kb.howtokb.global.Global; import kb.howtokb.taskframe.WikiHowTask; import kb.howtokb.taskframe.WikiHowTaskFrame; import kb.howtokb.wkhobject.Category; @@ -21,7 +21,7 @@ import kb.howtokb.wkhobject.Step; import kb.howtokb.wkhobject.Things; -public class TextToActivity { +public class TextToWikiHowTaskFrame { static Pattern pattern = Pattern.compile("[a-zA-Z]"); static TextToOpenIEResult txtOpenIE = new TextToOpenIEResult(); @@ -860,8 +860,11 @@ public int getNum_Article(){ public int getCategoryID(ArrayList cate) throws NumberFormatException, IOException{ if (catetoID == null){ catetoID = new HashMap<>(); - BufferedReader br = new BufferedReader(new FileReader(Global.ID_CATEGORY_FILE_TEXT)); - + ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); + InputStream inputs = classLoader.getResourceAsStream("wikihow-id-category.txt"); + + BufferedReader br = new BufferedReader(new InputStreamReader(inputs, "UTF-8")); + String sCurrentLine; while ((sCurrentLine = br.readLine()) != null) { @@ -888,7 +891,10 @@ public int getCategoryID(ArrayList cate) throws NumberFormatException, public int getLinkID(String url) throws NumberFormatException, IOException{ if (linktoID == null){ linktoID = new HashMap<>(); - BufferedReader br = new BufferedReader(new FileReader(Global.URL_FILE)); + ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); + InputStream inputs = classLoader.getResourceAsStream("wikihow-id-url"); + + BufferedReader br = new BufferedReader(new InputStreamReader(inputs, "UTF-8")); String sCurrentLine; diff --git a/src/main/java/kb/howtokb/tools/InformationExtraction.java b/src/main/java/kb/howtokb/tools/InformationExtraction.java index 2b2a840..8ab8a3d 100644 --- a/src/main/java/kb/howtokb/tools/InformationExtraction.java +++ b/src/main/java/kb/howtokb/tools/InformationExtraction.java @@ -5,6 +5,8 @@ import java.io.IOException; import 
java.io.InputStream; import java.io.InputStreamReader; +import java.sql.ResultSet; +import java.sql.SQLException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -16,15 +18,22 @@ import kb.howtokb.reader.TaskFrameReader; import kb.howtokb.taskframe.WikiHowTaskFrame; +import kb.howtokb.utils.SQLiteJDBCConnector; public class InformationExtraction { private static Map idtoCate; private static Map catetoID; - private static Map> parentChains; private static Map idtoWikiURL; // Get category id + /** + * Get category string + * @param category id + * @return category + * @throws NumberFormatException + * @throws IOException + */ public static String getCategory(int id) throws NumberFormatException, IOException { if (idtoCate == null) { idtoCate = new HashMap<>(); @@ -48,6 +57,13 @@ public static String getCategory(int id) throws NumberFormatException, IOExcepti } // Get category id + /** + * get category id + * @param cate + * @return id + * @throws NumberFormatException + * @throws IOException + */ public static int getCategoryID(String cate) throws NumberFormatException, IOException { if (catetoID == null) { catetoID = new HashMap<>(); @@ -70,6 +86,13 @@ public static int getCategoryID(String cate) throws NumberFormatException, IOExc } // Get url from id by reading file + /** + * get wiki url + * @param id + * @return url + * @throws NumberFormatException + * @throws IOException + */ public static String getWikiURLStringFromFile(int id) throws NumberFormatException, IOException { if (idtoWikiURL == null) { idtoWikiURL = new HashMap<>(); @@ -152,4 +175,22 @@ public static String linkToTitle(String s){ } return s; } + + /** + * get all children of a category + * @param ids + * @return list of task name in surface form + * @throws SQLException + * @throws ClassNotFoundException + * @throws IOException + */ + public static List getListofActivitySurfaceFromDb(List ids) throws SQLException, ClassNotFoundException, IOException { + List res = 
/** Conversions between the string encodings used in the data files and Java collections. */
public class StructureConverter {

    /**
     * Converts a {@code ';'}-separated list of integers into a list.
     *
     * <p>The value {@code -1} marks "no entry" and is left out of the result. Surrounding
     * spaces are ignored and empty fields (as in {@code "1;;2"}) are skipped; both used
     * to abort the conversion with a NumberFormatException.
     *
     * @param s encoded list such as {@code "3;17;-1"}; may be {@code null} or empty
     * @return the decoded integers in input order; empty for {@code null} or empty input
     * @throws NumberFormatException if a non-empty field is not an integer
     */
    public static List<Integer> stringToList(String s) {
        List<Integer> listInt = new ArrayList<>();
        if (s == null || s.length() == 0)
            return listInt;

        // split() returns the whole string when there is no ';', so single values need
        // no separate branch.
        for (String field : s.split(";")) {
            String value = field.trim();
            if (value.isEmpty() || value.equals("-1"))
                continue;
            listInt.add(Integer.parseInt(value));
        }
        return listInt;
    }
}
/**
 * A hash map that additionally maintains the reverse lookup value -> key, for data in
 * which every value belongs to exactly one key.
 *
 * <p>The reverse index is kept in step by {@link #put}, {@link #putAll}, {@link #remove}
 * and {@link #clear}. Other mutators inherited from {@link HashMap} are not overridden
 * and do not update it.
 */
public class BijectiveMap<K, V> extends HashMap<K, V> {

    /** Reverse index: value -> the key it is stored under. */
    private Map<V, K> v2k;

    public BijectiveMap() {
        super();
        v2k = new HashMap<>();
    }

    /**
     * Stores {@code key -> value} and {@code value -> key}. If {@code key} was already
     * mapped, the reverse entry of its previous value is dropped first, so a replaced
     * value can no longer be resolved back to the key.
     */
    @Override
    public V put(K key, V value) {
        forgetReverse(key);
        v2k.put(value, key);
        return super.put(key, value);
    }

    @Override
    public V remove(Object key) {

        // removing a key, implies removing the corresponding 1:1 value
        forgetReverse(key);

        return super.remove(key);
    }

    public V getValueFromKey(K key) {
        return super.get(key);
    }

    public K getKeyFromValue(V key) {
        return v2k.get(key);
    }

    /**
     * Adds every entry through {@link #put}. HashMap's own putAll stores entries without
     * calling put, which left the reverse index without the copied entries.
     */
    @Override
    public void putAll(Map<? extends K, ? extends V> m) {
        for (Map.Entry<? extends K, ? extends V> entry : m.entrySet()) {
            put(entry.getKey(), entry.getValue());
        }
    }

    @Override
    public void clear() {
        super.clear();
        v2k.clear();
    }

    /** Drops the reverse entry of the value currently stored under {@code key}, if any. */
    private void forgetReverse(Object key) {
        if (!super.containsKey(key)) {
            return;
        }
        V current = super.get(key);
        if (!v2k.containsKey(current)) {
            return;
        }
        // Only unlink when the reverse entry really points back at this key.
        K owner = v2k.get(current);
        boolean pointsBack = (owner == null) ? key == null : owner.equals(key);
        if (pointsBack) {
            v2k.remove(current);
        }
    }
}
kb.howtokb.utils; + +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.IOException; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; + +public class SQLiteJDBCConnector { + + public static Connection c; + public static Statement st; + + public static boolean check = false; + + public static String db = "wikihowDB"; + + public static ResultSet q(String sql) throws SQLException, ClassNotFoundException, IOException { + try{ + if (check == false) + createDB(); + + System.out.println("Create db successfully!"); + try { + return st.executeQuery(sql); + } catch (Error | Exception e1) { + System.out.println("Exception (" + e1.getMessage() + ") " + "while initializing the DB. \n " + + "Trying to automatically resolve..."); + + // returning empty result set for problematic query + return q("SELECT NULL LIMIT 0;"); + } + } catch (Error | Exception e){ + System.out.println("Database is already created!"); + try { + return st.executeQuery(sql); + } catch (Error | Exception e1) { + System.out.println("Exception (" + e.getMessage() + ") " + "while initializing the DB. 
\n " + + "Trying to automatically resolve..."); + + // returning empty result set for problematic query + return q("SELECT NULL LIMIT 0;"); + } + } + } + + public static void createDB() throws SQLException, ClassNotFoundException, IOException { + try { + Class.forName("org.sqlite.JDBC"); + c = DriverManager.getConnection("jdbc:sqlite:" + db); + st = c.createStatement(); + + String sql = "CREATE TABLE frameidtostrongactsurface " + "(ID INT PRIMARY KEY NOT NULL," + + " task TEXT NOT NULL);"; + st.executeUpdate(sql); + + sql = "CREATE TABLE categoryjson " + "(ID INT PRIMARY KEY NOT NULL," + + " json TEXT NOT NULL);"; + st.executeUpdate(sql); + + String input = "/var/tmp/cxchu/clustering-result/for-database/frame-id-to-strong-surface"; + update(st, "frameidtostrongactsurface", input); + + input = "/var/tmp/cxchu/data-server/For-Database/wikihow-id-category.json"; + update(st, "categoryjson", input); + + check = true; + } catch (SQLException e) { + check = true; + + } + + } + + public static void update(Statement st, String table, String file) throws IOException, SQLException { + BufferedReader br = new BufferedReader(new FileReader(file)); + String line; + while ((line = br.readLine()) != null) { + String[] values = line.split("\t"); // your seperator + + // Convert String to right type. Integer, double, date etc. 
+ st.executeUpdate("INSERT INTO " + table + " VALUES(" + values[0] + ",'" + values[1] + "');"); + // Use a PeparedStatemant, it´s easier and safer + } + br.close(); + } + +} diff --git a/src/main/java/kb/howtokb/utils/SortedMultiMap.java b/src/main/java/kb/howtokb/utils/SortedMultiMap.java new file mode 100644 index 0000000..de888cb --- /dev/null +++ b/src/main/java/kb/howtokb/utils/SortedMultiMap.java @@ -0,0 +1,356 @@ +package kb.howtokb.utils; + +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Comparator; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Map.Entry; +import java.util.PriorityQueue; +import java.util.Set; + +import kb.howtokb.tools.NormalizationText; + + + + /** Usage: + public static void main(String[] args) throws SQLException { + SortedMultiMap m = new SortedMultiMap<>(3, true); + m.put("a", "a1", 1.0); + m.put("b", "a2", 0.9); + m.put("a", "a3", 0.95); + m.put("a", "a1", 0.96); + m.put("a", "a3", 0.96); + + SortedMultiMap m2 = new SortedMultiMap<>(3, true); + m2.put("a", "a1", 0.9); + m2.put("b", "a1", 1.0); + m2.put("b", "a2", 0.95); + m2.put("a", "a3", 1.0); + + m.putAll(m2); + + for (String e : m.keyset()) { + System.out.println(e + "\t" + Arrays.toString(m.get(e))); + } + + }*/ + +public class SortedMultiMap { + + public static void main(String[] args) throws SQLException { + SortedMultiMap m = new SortedMultiMap<>(3, true); + m.put(1, 1, 1.0); + m.put(2, 2, 0.9); + m.put(3, 3, 0.95); + m.put(4, 4, 0.96); + m.put(5, 5, 0.96); + + + for (Integer e : m.keyset()) { + System.out.println(e + "\t" + Arrays.toString(m.get(e))); + } + } + + // ////////////////////////////////////////////////////////////// + // //////////////////// Functionality///////////////////// + // //////////////////////////////////////////////////////////// + + private int maxK; + private final boolean inDescOrder; + private AutoMap>> m; + + /** + * + * @param maxK How many top "k" elements are required + * @param inDescOrder 
if set to false sorts in ascending order + * @throws SQLException + */ + public SortedMultiMap(int maxK, boolean inDescOrder) throws SQLException { + this.maxK = maxK; + this.inDescOrder = inDescOrder; + m = new AutoMap<>(); + } + + /** + *
+     *         boolean isDesc = true;
+     *     
+     *     AutoMap&lt;String, PriorityQueue&lt;QEntry&lt;String&gt;&gt;&gt; m = new AutoMap&lt;&gt;();
+     * 
+     *     expander.putSorted("niket", "c", 40, m, isDesc);
+     * 
+     *     expander.putSorted("niket", "p", 10, m, isDesc);
+     *     expander.putSorted("niket", "b", 4, m, isDesc);
+     *     expander.putSorted("niket", "m", 50, m, isDesc);
+     * 
+     *     expander.putSorted("anjali", "p", 90, m, isDesc);
+     *     expander.putSorted("anjali", "c", 40, m, isDesc);
+     * 
+     *     expander.putSorted("anjali", "m", 20, m, isDesc);
+     *     expander.putSorted("anjali", "b", 100, m, isDesc);
+     * 
+     *     for (Entry&lt;String, PriorityQueue&lt;QEntry&lt;String&gt;&gt;&gt; e : m.entrySet())
+     *         for (QEntry&lt;String&gt; e2 : e.getValue())
+     *         System.out.println(e.getKey() + "\t" + e2);
+     * @author ntandon
+     * 
+     * @param <V> type of the value wrapped by an entry
+     */
+    public static class QEntry {
+
+        private V v;
+
+        public V getV() {
+            return v;
+        }
+
+        public double getN() {
+            return n;
+        }
+
+        public void setN(double nNew) {
+            n = nNew;
+        }
+
+        public void addToN(double nNew) {
+            n += nNew;
+        }
+
+        private double n;
+
+        public QEntry(V v, double n) {
+            this.v = v;
+            this.n = n;
+        }
+
+        @Override
+        public String toString() {
+            return new StringBuilder().append(v.toString()).append('\t')
+                .append(NormalizationText.format(n)).toString();
+        }
+
+        @Override
+        public int hashCode() {
+            final int prime = 31;
+            int result = 1;
+            result = prime * result + ((v == null) ? 0 : v.hashCode());
+            return result;
+        }
+
+        @Override
+        public boolean equals(Object obj) {
+            if (this == obj)
+                return true;
+            if (obj == null)
+                return false;
+            if (getClass() != obj.getClass())
+                return false;
+            QEntry other = (QEntry) obj;
+            if (v == null) {
+                if (other.v != null)
+                    return false;
+            } else if (!v.equals(other.v))
+                return false;
+            return true;
+        }
+
+    }
+
+    /**
+     * @deprecated because update on( "a 80, b 40" maxk=1) with (b, 50) would
+     *             have already dropped b in the first round.
+     * @param key
+     * @param value
+     * @param n
+     */
+    public void update(K key, V value, double n) {
+
+        if (!m.containsKey(key))
+            m.put(key, new PriorityQueue>(maxK,
+                new Comparator>() {
+
+                    @Override
+                    public int compare(QEntry o1, QEntry o2) {
+
+                        return inDescOrder ? (o1.n > o2.n ? 1 : -1)
+                            : (o2.n > o1.n ? 1 : -1);
+
+                    }
+                }));
+
+        // key=animal; value=tiger, n=10.
+        // m: animal=> dog;50, tiger;45, cow;10
+        // Update if existing==> m: animal=> dog;50, tiger;55, cow;10
+        QEntry qEntry = updatePriorityQueue(m.get(key), value, n);
+
+        PriorityQueue> qExisting = m.get(key);
+
+        /*
+         * Does the worst existing entry compare unfavorably to the entrytoadd?
+         */
+        qExisting.add(qEntry);
+        if (qExisting.size() > maxK)
+            qExisting.poll(); // remove the head (contains the worst value).
+    }
+
+    /**
+     * Raises the score of {@code value} inside {@code q} to {@code n} when the
+     * value is present with a strictly smaller score.
+     *
+     * Example: value=tiger, n=50 on "dog;50, tiger;45, cow;10" gives
+     * "dog;50, tiger;50, cow;10". A value that is absent, or present with a
+     * score that is not smaller than {@code n}, leaves the queue untouched.
+     *
+     * @param q queue to modify
+     * @param value value to look for
+     * @param n candidate score
+     * @return true if an entry was replaced, false otherwise
+     */
+    private boolean replaceInPriorityQueue(PriorityQueue<QEntry<V>> q, V value,
+        double n) {
+
+        QEntry<V> stale = null;
+
+        for (QEntry<V> e : q) {
+            // if key already exists with a larger value, no operation
+            // if key already exists with a smaller value, replace
+            if (e.getV().equals(value) && e.getN() < n) {
+                stale = e;
+                break;
+            }
+        }
+
+        if (stale == null)
+            return false;
+
+        // Remove the queue entry itself. Passing the bare value would never
+        // match, because QEntry.equals requires the argument to be a QEntry.
+        q.remove(stale);
+        // stale.getN() < n held above, so n is the larger of the two scores.
+        q.add(new QEntry<V>(value, n));
+        return true;
+
+    }
+
+    // Builds the entry to insert for value. If q already holds value (e.g. tiger;45 and n=10),
+    // the old entry is taken out of q and a new entry with the summed score (tiger;55) is
+    // returned; otherwise a fresh entry value;n is returned. Nothing is added to q here.
+    private QEntry updatePriorityQueue(PriorityQueue> q, V value,
+        double n) {
+        for (QEntry e : q) {
+            if (e.getV().equals(value)) {
+                double newVal = e.getN() + n;
+                q.remove(e); // removing while iterating is fine only because we return right away
+                return new QEntry(value, newVal);
+            }
+        }
+
+        return new QEntry(value, n); // value not present yet
+    }
+
+    public void put(K key, V value, double n) {
+
+        if (!m.containsKey(key))
+            m.put(key, new PriorityQueue>(maxK,
+                new Comparator>() {
+
+                    @Override
+                    public int compare(QEntry o1, QEntry o2) {
+
+                        return inDescOrder ? (o1.n > o2.n ? 1 : -1)
+                            : (o2.n > o1.n ? 1 : -1);
+
+                    }
+                }));
+
+        PriorityQueue> qExisting = m.get(key);
+
+        // if qExisting already has a QEntry containing the same key, retain the
+        // max value as n
+        boolean isValueUpdated = false;
+        for (QEntry entry : qExisting) {
+            if (entry.v.equals(value)) {
+                entry.n = n > entry.n ? n : entry.n;
+                isValueUpdated = true;
+                break;
+            }
+        }
+
+        if (!isValueUpdated) {
+            QEntry qEntry = new QEntry(value, n);
+            /*
+             * Does the worst existing entry compare unfavorably to the
+             * entrytoadd?
+             */
+            qExisting.add(qEntry);
+            if (qExisting.size() > maxK)
+                qExisting.poll(); // remove the head (contains the worst value).
+        }
+    }
+
+    /**
+     * Removes {@code value} from the queue stored under {@code key}; does
+     * nothing when the key is absent.
+     *
+     * @param key map key
+     * @param value value to remove
+     */
+    public void remove(K key, V value) {
+        if (containsKey(key))
+            // The queue holds QEntry objects, so a bare value never matches.
+            // QEntry equality ignores the score, hence the probe entry.
+            m.get(key).remove(new QEntry<V>(value, 0.0));
+    }
+
+    /**
+     * Tells whether the backing map has a queue for the given key.
+     *
+     * @param key key to look up
+     * @return the result of the backing map's containsKey for that key
+     */
+    public boolean containsKey(K key) {
+        final boolean present = m.containsKey(key);
+        return present;
+    }
+
+    /**
+     * Tells whether {@code value} is stored under {@code key}.
+     *
+     * @param key map key
+     * @param value value to look for; null always yields false
+     * @return true if the key exists and its queue holds an entry for value
+     */
+    public boolean containsKey(K key, V value) {
+        // The queue holds QEntry objects, so a bare value never matches.
+        // QEntry equality ignores the score, hence the probe entry.
+        return value != null && containsKey(key)
+            && m.get(key).contains(new QEntry<V>(value, 0.0));
+    }
+
+    public void remove(K key) {
+        if (m.containsKey(key))
+            m.remove(key);
+    }
+
+    private QEntry[] priorityQToOrderedArr(PriorityQueue> q) { // drains q into an array, head of the queue last
+        QEntry[] result = new QEntry[0]; // placeholder, replaced below when q is non-null
+        if (q == null)
+            return null; // null in, null out
+        else
+            result = new QEntry[q.size()];
+
+        int reverseIndex = result.length; // fill from the back
+
+        QEntry e = q.poll(); // poll() empties q, which is why get/getAsMap hand in a copy
+        while (e != null) {
+            result[--reverseIndex] = e;
+            e = q.poll();
+        }
+
+        return result;
+    }
+
+    /**
+     * Returns the entries stored under {@code key}, ordered from the last
+     * entry the queue would poll to the first (best to worst).
+     *
+     * @param key map key
+     * @return the ordered entries, or null when the key is null or absent
+     */
+    public QEntry<V>[] get(K key) {
+
+        // Copying a null queue with new PriorityQueue<>(null) throws, so
+        // answer directly with what priorityQToOrderedArr yields for null.
+        if (key == null || !m.containsKey(key))
+            return null;
+
+        // Work on a copy: the conversion drains the queue it is given.
+        return priorityQToOrderedArr(new PriorityQueue<>(m.get(key)));
+    }
+    
+    /**
+     * Returns the values stored under {@code key} mapped to their scores, in
+     * the same order as {@link #get}.
+     *
+     * @param key map key
+     * @return an insertion-ordered map; empty when the key is null or absent
+     */
+    public Map<V, Double> getAsMap(K key) {
+        Map<V, Double> mResult = new LinkedHashMap<>();
+
+        // Copying a null queue with new PriorityQueue<>(null) throws, so
+        // return the empty map right away for a null or unknown key.
+        if (key == null || !m.containsKey(key))
+            return mResult;
+
+        // Work on a copy: the conversion drains the queue it is given.
+        for (QEntry<V> e : priorityQToOrderedArr(new PriorityQueue<>(m.get(key))))
+            mResult.put(e.getV(), e.getN());
+
+        return mResult;
+    }
+
+    public Set keyset() { // keys that currently have a queue in the backing map
+        return m.keySet(); // handed out as-is, no copy made here (presumably a live view; confirm against AutoMap)
+    }
+
+    public Set>>> entrySet() { // key -> queue pairs of the backing map
+        return m.entrySet(); // exposes the internal queues directly; callers that poll them change this map
+    }
+
+//    public void putAll(SortedMultiMap addme) {
+//
+//        for (Entry>> e : Util.nullableIter(addme
+//            .entrySet()))
+//            for (QEntry e2 : e.getValue())
+//                put(e.getKey(), e2.getV(), e2.getN());
+//    }
+
+    /**
+     * Number of keys held by the backing map.
+     *
+     * @return the backing map's size, or 0 when there is no backing map
+     */
+    public int size(){
+        if (m == null) {
+            return 0;
+        }
+        return m.size();
+    }
+}
+
+
diff --git a/src/main/java/kb/howtokb/utils/SparseSimMatrix.java b/src/main/java/kb/howtokb/utils/SparseSimMatrix.java
new file mode 100644
index 0000000..1483244
--- /dev/null
+++ b/src/main/java/kb/howtokb/utils/SparseSimMatrix.java
@@ -0,0 +1,34 @@
+package kb.howtokb.utils;
+
+import gnu.trove.TLongFloatHashMap;
+
+/**
+ * Sparse symmetric similarity matrix keyed by a pair of int ids.
+ *
+ * Only the upper triangle is stored: both {@link #set} and {@link #get} order
+ * the pair as (smaller, larger) before packing it into a single long key, so
+ * (x, y) and (y, x) always address the same cell.
+ */
+public class SparseSimMatrix {
+
+	// Packed (smaller id, larger id) -> similarity value.
+	TLongFloatHashMap matrix;
+	// Values strictly below this are not stored.
+	float threshold;
+
+	/**
+	 * @param thres minimum value a similarity must have to be stored
+	 */
+	public SparseSimMatrix(float thres) {
+		matrix = new TLongFloatHashMap();
+		this.threshold = thres;
+	}
+
+	/**
+	 * Stores {@code value} for the unordered pair (x, y); values below the
+	 * threshold are ignored.
+	 *
+	 * @param x first id
+	 * @param y second id
+	 * @param value similarity of the two ids
+	 */
+	public void set(int x, int y, float value) {
+		if (value < threshold)
+			return;
+		// Normalise the pair exactly as get() does; otherwise a value written
+		// with x > y could never be read back.
+		matrix.put(intpairToLong(Math.min(x, y), Math.max(x, y)), value);
+	}
+
+	/**
+	 * Looks up the value stored for the unordered pair (x, y).
+	 *
+	 * @param x first id
+	 * @param y second id
+	 * @return whatever the backing map returns for the pair; for a pair that
+	 *         was never stored this is the map's no-entry value (presumably
+	 *         0 - confirm against the Trove version in use)
+	 */
+	// Matrix is symmetric then only store the upper triangle part
+	public float get(int x, int y) {
+		return matrix.get(intpairToLong(Math.min(x, y), Math.max(x, y)));
+	}
+
+	/**
+	 * Packs two ints into one long: {@code l} in the high 32 bits, {@code r}
+	 * in the low 32 bits.
+	 */
+	private long intpairToLong(int l, int r) {
+		// Mask r so that a negative r is not sign-extended into the high word
+		// (plain addition would then corrupt l's half of the key). For
+		// non-negative r the result is identical to ((long) l << 32) + r.
+		return ((long) l << 32) | (r & 0xFFFFFFFFL);
+	}
+
+}
\ No newline at end of file
diff --git a/src/test/java/kb/howtokb/TextToActivityTest.java b/src/test/java/kb/howtokb/TextToWikiHowTaskFrameTest.java
similarity index 59%
rename from src/test/java/kb/howtokb/TextToActivityTest.java
rename to src/test/java/kb/howtokb/TextToWikiHowTaskFrameTest.java
index 831293c..0c1de14 100644
--- a/src/test/java/kb/howtokb/TextToActivityTest.java
+++ b/src/test/java/kb/howtokb/TextToWikiHowTaskFrameTest.java
@@ -13,67 +13,67 @@
 
 import kb.howtokb.reader.WikiHowArticleReader;
 import kb.howtokb.taskframe.WikiHowTaskFrame;
-import kb.howtokb.taskframe.extractor.TextToActivity;
+import kb.howtokb.taskframe.extractor.TextToWikiHowTaskFrame;
 import kb.howtokb.wkhobject.Question;
 
-public class TextToActivityTest {
-	
+public class TextToWikiHowTaskFrameTest {
+
 	public static void main(String[] args) throws ClassNotFoundException, IOException, ParseException {
 		System.setOut(new PrintStream(new FileOutputStream("log.txt")));
-		
-		TextToActivity extract = new TextToActivity();
-		//Extract all question
+
+		TextToWikiHowTaskFrame extract = new TextToWikiHowTaskFrame();
+		// Read every question (article) from the JSON input file
 		System.out.println("Reading json data file.....");
 		String input = "/var/tmp/cxchu/articles_test.json";
 		ArrayList allQuestions = WikiHowArticleReader.WikiHowArticleReaderFromJSONFile(input);
 		int frames = 0;
-		try{
+		try {
 			Writer textout = new BufferedWriter(new OutputStreamWriter(
-		              new FileOutputStream("/var/tmp/cxchu/data-for-test-code/act-frame.json"), "utf-8"));
+					new FileOutputStream("/var/tmp/cxchu/data-for-test-code/act-frame.json"), "utf-8"));
 			Writer idtextout = new BufferedWriter(new OutputStreamWriter(
-		              new FileOutputStream("/var/tmp/cxchu/data-for-test-code/id-act-frame.json"), "utf-8"));
-			
-			int i=1;
-			
-			for (Question ques: allQuestions){
-				try{
-				//if (i++ > 50){
-					//break;
-				//}
-				ArrayList listframe = extract.articleToListWikiHowTaskFrame(ques.setNormalized());
-				frames += listframe.size();
-				for (WikiHowTaskFrame frame: listframe){
-					//frame = frame.setNormalized();
-					textout.write(frame.toJsonObject().toJSONString() + "\n");
-					idtextout.write(frame.getID() + "\t" + frame.toJsonObject().toJSONString() + "\n");
-				}
-				
-				System.out.println(ques.getLink());
-				}catch (IOException e){
+					new FileOutputStream("/var/tmp/cxchu/data-for-test-code/id-act-frame.json"), "utf-8"));
+
+			int i = 1; // NOTE(review): only referenced by the commented-out article limit below
+
+			for (Question ques : allQuestions) {
+				try {
+					// if (i++ > 50){
+					// break;
+					// }
+					ArrayList listframe = extract.articleToListWikiHowTaskFrame(ques.setNormalized());
+					frames += listframe.size();
+					for (WikiHowTaskFrame frame : listframe) {
+						// frame = frame.setNormalized();
+						textout.write(frame.toJsonObject().toJSONString() + "\n");
+						idtextout.write(frame.getID() + "\t" + frame.toJsonObject().toJSONString() + "\n"); // same JSON, prefixed with the frame id and a tab
+					}
+
+					System.out.println(ques.getLink());
+				} catch (IOException e) { // per-article failure: print the trace and go on with the next article
 					e.printStackTrace();
 					continue;
-				}catch (NullPointerException e){
+				} catch (NullPointerException e) {
 					e.printStackTrace();
 					continue;
-				}catch (IndexOutOfBoundsException e){
+				} catch (IndexOutOfBoundsException e) {
 					e.printStackTrace();
 					continue;
-				}catch (PatternSyntaxException e){
+				} catch (PatternSyntaxException e) {
 					e.printStackTrace();
 					continue;
 				}
 			}
 			textout.close();
 			idtextout.close();
-		}catch (IOException e) {
-			
+		} catch (IOException e) { // NOTE(review): swallowed silently - nothing is logged or rethrown
+
 		}
 		System.out.println("Number of articles: " + extract.getNum_Article() + "\n");
 		System.out.println("Number of sentences: " + extract.getNum_Sent() + "\n");
 		System.out.println("Number of extractions: " + extract.getNum_Ext_Wt_Thres() + "\n");
 		System.out.println("Number of extractions with conf > 0.45: " + extract.getNum_Ext_Gt_Thres() + "\n");
 		System.out.println("Number of activity frames: " + frames + "\n");
-		
+
 	}
-	
+
 }
diff --git a/src/test/java/kb/howtokb/clustering/HeuristicBottomUpClusteringTest.java b/src/test/java/kb/howtokb/clustering/HeuristicBottomUpClusteringTest.java
new file mode 100644
index 0000000..10eae6c
--- /dev/null
+++ b/src/test/java/kb/howtokb/clustering/HeuristicBottomUpClusteringTest.java
@@ -0,0 +1,58 @@
+package kb.howtokb.clustering;
+
+import java.io.BufferedWriter;
+import java.io.FileOutputStream;
+import java.io.OutputStreamWriter;
+import java.io.Writer;
+import java.util.List;
+
+import kb.howtokb.clustering.HeuristicBottomupClustering.ActivitySuperCluster;
+import kb.howtokb.clustering.sim.Coefficient;
+import kb.howtokb.taskframe.WikiHowTaskFrame;
+import kb.howtokb.tools.InformationExtraction;
+
+public class HeuristicBottomUpClusteringTest {
+	public static void main(String[] args) throws Exception {
+		
+		long startTime = System.currentTimeMillis();
+		
+		
+		String activityTb = "/var/tmp/cxchu/clustering-pre-computation/all-words-category.txt";
+		HeuristicBottomupClustering cluster = new HeuristicBottomupClustering(activityTb);
+		
+		double threshold = Coefficient.VVNN_TRHES;
+		String model = "/var/tmp/cxchu/w2v-model/articles-word2vec-word-pos.model.txt";
+		String allAct = "/var/tmp/cxchu/groundtruth-data/all-strong-activities.txt";
+		SimplePruningSimilarity simFunc = new SimplePruningSimilarity(threshold, model, allAct);
+		List results = cluster.cluster(simFunc, Coefficient.VVNN_TRHES);
+		System.out.println("Number of clusters: " + results.size());
+		String output = "/var/tmp/cxchu/clustering-result/bottom-up-cluster-";
+		
+		String input = ""; //original data point file
+		List allframe = InformationExtraction.getAllFrame(input);
+		int total = 0;
+		for (int i = 0; i < results.size(); i++) {
+			System.out.println("Cluster " + i + ": " + results.get(i).getSuperClusterMembers().size());
+			Writer out = new BufferedWriter(new OutputStreamWriter(
+					new FileOutputStream(output+i+".json"), "utf-8"));
+			 List actitiviesID =
+					 results.get(i).getSuperClusterMembers();
+			 for (int j=0; j allframe = InformationExtraction.getAllFrame(input);
+
+		System.out.println("Initializing simple topdown clustering.....");
+		
+		HeuristicTopDownClusteringDynamicSparse cluster = new HeuristicTopDownClusteringDynamicSparse(allframe, false, 0.9, 5, 0.5);
+		
+		System.out.println("Done! Start clustering.......");
+		int k=2;
+		List> res = cluster.splitACluster(cluster.getInputCluster(), k);
+		
+		System.out.println("Done! Results..............");
+		for (int i=0; i> frames = res.get(i).getClusterMembers();
+			for (int j=0; j allframe = InformationExtraction.getAllFrame(input);
+
+		System.out.println("Initializing simple topdown clustering.....");
+		
+//		InstanceInLeafSimpleStopping stopper = new InstanceInLeafSimpleStopping<>(5);
+		HeuristicTopDownClustering cluster = new HeuristicTopDownClustering(allframe, false, 0.9, 5);
+		
+		System.out.println("Done! Start clustering.......");
+		int k=2;
+		List> res = cluster.splitACluster(cluster.getInputCluster(), k);
+		
+		System.out.println("Done! Results..............");
+//		out.write("Number of clusters: " + res.size() + "\n");
+		for (int i=0; i> frames = res.get(i).getClusterMembers();
+			for (int j=0; j