Skip to content
Permalink
0d82ff1dc4
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
174 lines (151 sloc) 5.02 KB
package kb.howtokb.clustering.sim.w2v;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import kb.howtokb.clustering.ISimilarity;
import kb.howtokb.tools.NormalizationText;
import kb.howtokb.utils.SortedMultiMap;
/**
 * Word-to-word similarity backed by a word2vec model stored in text format.
 *
 * <p>The model is read from the classpath. Its first line is a header of the
 * form {@code "<vocabSize> <vectorSize>"}; every following line holds a token
 * followed by its space-separated vector components. Similarity between two
 * tokens is the angular similarity of their vectors (see
 * {@link #cosine(double[], double[])}).
 *
 * <p>Tokens whose second character is an underscore (e.g. {@code a_beautiful})
 * are additionally indexed by the text after the two-character prefix
 * ({@code beautiful}), so that {@link #simWithoutPOS(String, String)} can look
 * them up without the prefix.
 */
public class Word2VecSimilarity implements ISimilarity<String> {

    /** Returned by {@link #getNeighbors(String)} when nothing is known about a word. */
    private final Map<String, Double> emptyMap = new HashMap<>();

    /** Token -> vector, in model-file order. */
    Map<String, double[]> wordvectors;

    /** Precomputed nearest neighbors; stays {@code null} when precomputation is disabled. */
    private SortedMultiMap<String, String> neighbors;

    /** Vocabulary in model-file order. */
    private String[] words;

    // Map beautiful -> [a_beautiful]
    // Kept per instance: a static table would be overwritten by every newly
    // constructed object and corrupt lookups of earlier instances.
    private Map<String, String> wordPOS;

    /**
     * Loads the model and precomputes the neighborhood of every word.
     *
     * @param modelFile
     *            classpath location of the model, e.g.
     *            "word2vec-data/phrase-norole.model.txt"
     * @param maxK
     *            k nearest neighbors for getNeighbors() function.
     * @throws IOException
     *             : while reading the model.
     * @throws SQLException
     *             : propagated from SortedMultiMap.
     */
    public Word2VecSimilarity(String modelFile, int maxK) throws IOException, SQLException {
        this(modelFile, maxK, true);
    }

    /**
     * @param modelFile
     *            = "word2vec-data/phrase-norole.model.txt"
     * @param maxK
     *            = k nearest neighbors for getNeighbors() function.
     * @param precomputeNeighborhood
     *            when {@code true}, the similarity of every ordered pair of
     *            distinct vocabulary entries is computed up front; when
     *            {@code false}, {@link #getNeighbors(String)} always returns
     *            an empty map.
     * @throws IOException
     *             : while reading the model (missing, empty or malformed).
     * @throws SQLException
     *             : TODO: fix in SortedMultiMap, we don't need this.
     */
    public Word2VecSimilarity(String modelFile, int maxK, boolean precomputeNeighborhood)
            throws IOException, SQLException {
        System.out.println("Word2Vec constructing model... " + modelFile);
        load(modelFile);
        if (precomputeNeighborhood) {
            System.out.println("Word2Vec constructing neighborhood... " + words.length);
            constructNeighborhood(maxK);
        }
        System.out.println("Word2Vec object constructed");
    }

    /**
     * Reads the model from the classpath into {@link #words},
     * {@link #wordvectors} and {@link #wordPOS}. Does nothing if a model has
     * already been loaded.
     *
     * @param model classpath location of the model file
     * @throws IOException if the resource is missing, empty, or inconsistent
     *             with its own header
     */
    private void load(String model) throws IOException {
        if (wordvectors != null) {
            return;
        }
        ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
        InputStream inputs = classLoader.getResourceAsStream(model);
        if (inputs == null) {
            // getResourceAsStream signals "not found" with null, not an exception.
            throw new IOException("Word2Vec model not found on classpath: " + model);
        }

        boolean headerRead = false;
        int loaded = 0;
        int vecSize = 0;
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputs, "UTF-8"))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (!headerRead) {
                    // 1435 200 i.e. 1435 tokens, each 200 size vector.
                    String[] header = line.split(" ");
                    if (header.length < 2) {
                        throw new IOException("Malformed Word2Vec header in " + model + ": " + line);
                    }
                    int vocabSize = Integer.parseInt(header[0]);
                    vecSize = Integer.parseInt(header[1]);
                    words = new String[vocabSize];
                    wordvectors = new LinkedHashMap<>(vocabSize);
                    wordPOS = new HashMap<>();
                    headerRead = true;
                    continue;
                }
                if (line.isEmpty()) {
                    // Tolerate blank (e.g. trailing) lines instead of storing an empty token.
                    continue;
                }

                String[] fields = line.split(" ");
                if (loaded >= words.length) {
                    throw new IOException("Word2Vec model " + model
                            + " contains more entries than the " + words.length
                            + " declared in its header");
                }
                if (fields.length - 1 > vecSize) {
                    throw new IOException("Word2Vec entry '" + fields[0] + "' in " + model
                            + " has " + (fields.length - 1) + " components, expected " + vecSize);
                }

                // Entries with fewer components than vecSize are zero-padded.
                double[] vec = new double[vecSize];
                for (int i = 1; i < fields.length; i++) {
                    vec[i - 1] = Double.parseDouble(fields[i]);
                }

                String token = fields[0];
                words[loaded++] = token;
                wordvectors.put(token, vec);
                if (token.indexOf('_') == 1) {
                    // "a_beautiful" -> key "beautiful"
                    wordPOS.put(token.substring(2), token);
                }
            }
        }

        if (!headerRead) {
            throw new IOException("Word2Vec model is empty: " + model);
        }
        if (loaded < words.length) {
            // The header promised more entries than the file holds; drop the
            // unused null slots so later iterations never see a null word.
            String[] trimmed = new String[loaded];
            System.arraycopy(words, 0, trimmed, 0, loaded);
            words = trimmed;
        }
    }

    /**
     * Angular similarity of two vectors: {@code 1 - acos(cos(w1, w2)) / PI},
     * which is bounded in [0, 1].
     *
     * @param w1 first vector
     * @param w2 second vector; must have at least {@code w1.length} components
     * @return the similarity after passing through
     *         {@code NormalizationText.format} (presumably a fixed-precision
     *         rounding -- confirm against that helper), or {@code 0} when
     *         either vector has zero magnitude and the angle is undefined
     */
    double cosine(double[] w1, double[] w2) {
        double dot = 0;
        double m1 = 0, m2 = 0;
        for (int i = 0; i < w1.length; i++) {
            m1 += w1[i] * w1[i];
            m2 += w2[i] * w2[i];
            dot += w1[i] * w2[i];
        }
        double norm = Math.sqrt(m1) * Math.sqrt(m2);
        if (norm == 0) {
            // Avoid 0/0 = NaN for zero vectors.
            return 0;
        }
        // Rounding can push the ratio marginally outside [-1, 1], for which
        // Math.acos returns NaN; clamp before taking the arc cosine.
        double cos = Math.max(-1.0, Math.min(1.0, dot / norm));
        // Use angular distance to make similarity bounded between [0,1]
        double result = 1 - Math.acos(cos) / Math.PI;
        return Double.parseDouble(NormalizationText.format(result));
    }

    /**
     * Fills {@link #neighbors} with the similarity of every ordered pair of
     * distinct vocabulary positions. Does nothing if it was already built.
     *
     * @param maxK passed to the SortedMultiMap constructor (k nearest neighbors)
     * @throws SQLException propagated from SortedMultiMap
     */
    private void constructNeighborhood(int maxK) throws SQLException {
        if (neighbors != null) {
            return;
        }
        neighbors = new SortedMultiMap<>(maxK, true);
        // Upper triangular matrix is insufficient because
        // we need to store all the m x m similarity matrix
        // which is prohibitive, and so we use Heap
        // which has a small memory footprint
        // But, heap requires m x m computations not just upper triangular
        // Anyways, cosine is very cheap so no need of further optimizations
        for (int i = 0; i < words.length; i++) {
            for (int j = 0; j < words.length; j++) {
                if (j != i) {
                    neighbors.put(words[i], words[j], sim(words[i], words[j]));
                }
            }
        }
    }

    /**
     * Similarity of two model tokens.
     *
     * @return {@code 1.0} for equal strings, {@code 0} (with a console
     *         message) if either token is not in the model, otherwise the
     *         angular similarity of their vectors
     */
    @Override
    public double sim(String w1, String w2) {
        if (w1.equals(w2)) {
            return 1.0;
        }
        if (!wordvectors.containsKey(w1)) {
            System.out.println(w1 + " is out of dictionary");
            return 0;
        }
        if (!wordvectors.containsKey(w2)) {
            System.out.println(w2 + " is out of dictionary");
            return 0;
        }
        return cosine(wordvectors.get(w1), wordvectors.get(w2));
    }

    /**
     * Like {@link #sim(String, String)}, but the arguments are given without
     * the two-character prefix (e.g. {@code beautiful} instead of
     * {@code a_beautiful}) and are resolved to model tokens first.
     *
     * @return {@code 1.0} for equal strings, {@code 0} (with a console
     *         message) if either word has no prefixed token in the model,
     *         otherwise the angular similarity of the resolved tokens
     */
    public double simWithoutPOS(String w1, String w2) {
        if (w1.equals(w2)) {
            return 1.0;
        }
        if (!wordPOS.containsKey(w1)) {
            System.out.println(w1 + " is out of dictionary");
            return 0;
        }
        if (!wordPOS.containsKey(w2)) {
            System.out.println(w2 + " is out of dictionary");
            return 0;
        }
        String token1 = wordPOS.get(w1);
        String token2 = wordPOS.get(w2);
        return cosine(wordvectors.get(token1), wordvectors.get(token2));
    }

    /**
     * @return the precomputed neighbors of {@code w} with their similarities,
     *         or an empty map if {@code w} is unknown or the neighborhood was
     *         not precomputed
     */
    @Override
    public Map<String, Double> getNeighbors(String w) {
        // neighbors is null when the object was built with precomputeNeighborhood == false.
        if (neighbors == null || !neighbors.containsKey(w)) {
            return emptyMap;
        }
        return neighbors.getAsMap(w);
    }

    /** @return the vocabulary in model-file order (the internal array, not a copy) */
    public String[] vocab() {
        return words;
    }

    /** @return whether {@code sim(e1, e2)} is at least {@code minthreshold} */
    @Override
    public boolean simThreshold(String e1, String e2, double minthreshold) throws Exception {
        return sim(e1, e2) >= minthreshold;
    }
}