Skip to content
Permalink
master
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
//
// main.cpp
// non-neg-mat-fac
//
// Created by Peter Arndt on 04/04/2016.
// Copyright © 2016 Peter Arndt. All rights reserved.
//
#include <iostream>
#include "common.h"
#include "random.h"
#include "matrix.h"
using namespace std;
// Builds an I x J matrix whose entries are 1 + (rng.int64() % max), stored as
// doubles. Assuming Ran::int64() yields an unsigned value and max > 0, every
// entry is an integer in [1, max]. A fresh generator seeded with initrand()
// is created on each call.
MatDoub generate_random_matrix(int I,int J,int max)
{
    Ran rng(initrand());
    MatDoub result(I,J,0.0);
    int row = 0;
    while (row < I) {
        for (int col = 0; col < J; ++col) {
            result[row][col] = 1 + rng.int64() % (max);
        }
        ++row;
    }
    return result;
}
// Sum of squared element-wise differences between A and B, i.e. the squared
// Frobenius distance ||A - B||^2 (no square root is taken).
// Iterates over A's shape; B is assumed to be at least as large as A
// (dimensions are not checked here).
double distance(MatDoub_I &A, MatDoub_I &B)
{
    double sum_sq=0;
    for (int row=0; row<A.nrows(); ++row) {
        for (int col=0; col<A.ncols(); ++col) {
            const double diff = A[row][col]-B[row][col];
            sum_sq += SQR(diff);
        }
    }
    return sum_sq;
}
// Squared Frobenius reconstruction error ||V - W*H||^2.
// The product W*H is formed in a temporary; NRmatrix_dgemm is presumably the
// BLAS-style C = alpha*op(A)*op(B) + beta*C (here alpha = 1, beta = 0).
double distance(MatDoub_I &V, MatDoub_I &W, MatDoub_I &H)
{
    MatDoub product(W.nrows(),H.ncols(),0.0);
    NRmatrix_dgemm(false, false, 1, W, H, 0, product);
    const double err = distance(V, product);
    return err;
}
// Generalized Kullback-Leibler divergence of V from the reconstruction W*H:
//   D(V || WH) = sum_ab [ V_ab * log(V_ab / (WH)_ab) - V_ab + (WH)_ab ]
// Entries with V_ab == 0 use the standard convention 0*log(0/x) = 0, so they
// contribute just (WH)_ab. Evaluating that term literally gives 0*(-inf) = NaN,
// which would make the whole divergence NaN for any V containing zeros.
// NRmatrix_dgemm is presumably the BLAS-style C = alpha*op(A)*op(B) + beta*C.
double divergence(MatDoub_I &V, MatDoub_I &W, MatDoub_I &H)
{
    MatDoub WH(W.nrows(),H.ncols(),0.0);
    NRmatrix_dgemm(false, false, 1, W, H, 0, WH);
    double d=0;
    for (int a=0; a<V.nrows(); a++)
        for (int b=0; b<V.ncols(); b++){
            const double v  = V[a][b];
            const double wh = WH[a][b];
            if (v == 0)
                d += wh;                           // limit of v*log(v/wh) - v + wh as v -> 0
            else
                d += v * log( v / wh ) - v + wh;   // same expression/order as before
        }
    return d;
}
// Multiplies every entry of A in place by its own random factor
// exp(log(max_factor) * (2u - 1)), where u is a fresh rng.doub() draw
// (presumably uniform on [0,1)). The factor therefore lies between
// 1/max_factor and max_factor and is log-uniformly distributed.
// A fresh generator seeded with initrand() is created on each call.
void perturbe_matrix(MatDoub &A,double max_factor=2)
{
    Ran rng(initrand());
    const double log_range = log(max_factor);   // loop-invariant
    for (int row=0; row<A.nrows(); ++row) {
        for (int col=0; col<A.ncols(); ++col) {
            A[row][col] *= exp( log_range * (2*rng.doub()-1) );
        }
    }
}
// One multiplicative-update step for the squared Euclidean objective
// ||V - WH||^2 (Lee & Seung rules):
//   H_ab <- H_ab * (W^T V)_ab   / (W^T W H)_ab
//   W_ab <- W_ab * (V H^T)_ab   / (W H H^T)_ab
// Both factor matrices are built from the incoming W and H, so this is a
// simultaneous rather than an alternating update. The factors are applied
// via NRmatrix_elementwise_mult.
// Afterwards each column b of W is divided by its sum lambda_b and row b of H
// is multiplied by lambda_b, which leaves the product W*H unchanged (up to
// rounding). A column that sums to zero is left as is.
// Returns (min, max) over all entries of both factor matrices, as reported by
// NRmatrix_minmax. Values near 1 indicate the step barely changed W and H.
pair<double,double> NMF_update1(MatDoub_I &V, MatDoub &W, MatDoub &H)
{
    MatDoub WD(W.nrows(),W.ncols(),1.0);
    MatDoub HD(H.nrows(),H.ncols(),1.0);
    // NRmatrix_dgemm(transA, transB, alpha, A, B, beta, C) is presumably the
    // BLAS-style C = alpha*op(A)*op(B) + beta*C.
    MatDoub WH(W.nrows(),H.ncols(),0.0);
    NRmatrix_dgemm(false, false, 1, W, H, 0, WH);
    MatDoub WtV(W.ncols(),V.ncols(),0.0);
    NRmatrix_dgemm(true, false, 1, W, V, 0, WtV);
    MatDoub WtWH(W.ncols(),WH.ncols(),0.0);
    NRmatrix_dgemm(true, false, 1, W, WH, 0, WtWH);
    MatDoub VHt(V.nrows(),H.nrows(),0.0);
    NRmatrix_dgemm(false, true, 1, V, H, 0, VHt);
    MatDoub WHHt(WH.nrows(),H.nrows(),0.0);
    NRmatrix_dgemm(false, true, 1, WH, H, 0, WHHt);
    // H factors
    for (int a=0; a<HD.nrows(); a++){
        for (int b=0; b<HD.ncols(); b++){
            HD[a][b]=WtV[a][b]/WtWH[a][b];
        }
    }
    // W factors
    for (int a=0; a<WD.nrows(); a++){
        for (int b=0; b<WD.ncols(); b++){
            WD[a][b]=VHt[a][b]/WHHt[a][b];
        }
    }
    double HDmin,HDmax;
    NRmatrix_minmax(HD,HDmin,HDmax);
    double WDmin,WDmax;
    NRmatrix_minmax(WD,WDmin,WDmax);
    NRmatrix_elementwise_mult(W,WD);
    NRmatrix_elementwise_mult(H,HD);
    // Use the scaling freedom W -> W*D^-1, H -> D*H to let the columns of W sum to 1.
    for (int b=0; b<W.ncols(); b++){
        double lambda_b=0;
        for (int a=0; a<W.nrows(); a++)
            lambda_b += W[a][b];
        // An all-zero column cannot be normalised: dividing by 0 would turn it
        // into NaN (0/0) and wipe out row b of H.
        if (lambda_b == 0)
            continue;
        for (int a=0; a<W.nrows(); a++)
            W[a][b] /= lambda_b;
        for (int a=0; a<H.ncols(); a++)
            H[b][a] *= lambda_b;
    }
    return make_pair(min(HDmin,WDmin),max(HDmax,WDmax));
}
// One multiplicative-update step for the generalized Kullback-Leibler
// divergence D(V || WH):
//   H_ab <- H_ab * [sum_i W_ia V_ib / (WH)_ib] / [sum_i W_ia]
//   W_ia <- W_ia * [sum_k H_ak V_ik / (WH)_ik] / [sum_k H_ak]
// Both factor matrices are computed from the incoming W and H (and the single
// product WH formed at entry) before either matrix is modified. The factors
// are then applied via NRmatrix_elementwise_mult.
// Returns (min, max) over all entries of both factor matrices, as reported by
// NRmatrix_minmax.
pair<double,double> NMF_update2(MatDoub_I &V, MatDoub &W, MatDoub &H)
{
    MatDoub WH(W.nrows(),H.ncols(),0.0);
    NRmatrix_dgemm(false, false, 1, W, H, 0, WH);

    // Per-entry factors for W.
    MatDoub w_scale(W.nrows(),W.ncols(),1.0);
    for (int row=0; row<w_scale.nrows(); ++row) {
        for (int comp=0; comp<w_scale.ncols(); ++comp) {
            double numer=0;
            double denom=0;
            for (int col=0; col<H.ncols(); ++col) {
                numer += H[comp][col]*V[row][col]/WH[row][col];
                denom += H[comp][col];
            }
            w_scale[row][comp] = numer/denom;
        }
    }

    // Per-entry factors for H.
    MatDoub h_scale(H.nrows(),H.ncols(),1.0);
    for (int comp=0; comp<h_scale.nrows(); ++comp) {
        for (int col=0; col<h_scale.ncols(); ++col) {
            double numer=0;
            double denom=0;
            for (int row=0; row<W.nrows(); ++row) {
                numer += W[row][comp]*V[row][col]/WH[row][col];
                denom += W[row][comp];
            }
            h_scale[comp][col] = numer/denom;
        }
    }

    double w_lo, w_hi;
    NRmatrix_minmax(w_scale, w_lo, w_hi);
    double h_lo, h_hi;
    NRmatrix_minmax(h_scale, h_lo, h_hi);

    NRmatrix_elementwise_mult(W, w_scale);
    NRmatrix_elementwise_mult(H, h_scale);

    // Argument order (H first) kept so NaN handling of min/max is unchanged.
    return make_pair(min(h_lo, w_lo), max(h_hi, w_hi));
}
// Multiplicative KL-divergence update in unnormalised form. Compared with
// NMF_update2, the division by sum_i W_ia / sum_k H_ak is dropped:
//   H_ab <- H_ab * sum_i W_ia V_ib / (WH)_ib
//   W_ia <- W_ia * sum_k H_ak V_ik / (WH)_ik,  then each column of W is
//                                               divided by its sum.
// Dropping sum_i W_ia from the H step matches NMF_update2 only when the
// incoming columns of W already sum to 1, as they do if W was last produced
// by this function. Dropping sum_k H_ak from the W step only rescales whole
// columns of W, and the final normalisation removes that scale.
// Both factor matrices are computed from the incoming W, H and the single WH
// product formed at entry. The W column normalisation is NOT compensated in H.
// Returns (min, max) over all entries of both raw factor matrices, taken
// before the W normalisation, as reported by NRmatrix_minmax.
pair<double,double> NMF_update3(MatDoub_I &V, MatDoub &W, MatDoub &H)
{
    MatDoub WH(W.nrows(),H.ncols(),0.0);
    NRmatrix_dgemm(false, false, 1, W, H, 0, WH);

    // Per-entry factors for W.
    MatDoub w_scale(W.nrows(),W.ncols(),1.0);
    for (int row=0; row<w_scale.nrows(); ++row) {
        for (int comp=0; comp<w_scale.ncols(); ++comp) {
            double acc=0;
            for (int col=0; col<H.ncols(); ++col)
                acc += H[comp][col]*V[row][col]/WH[row][col];
            w_scale[row][comp] = acc;
        }
    }

    // Per-entry factors for H.
    MatDoub h_scale(H.nrows(),H.ncols(),1.0);
    for (int comp=0; comp<h_scale.nrows(); ++comp) {
        for (int col=0; col<h_scale.ncols(); ++col) {
            double acc=0;
            for (int row=0; row<W.nrows(); ++row)
                acc += W[row][comp]*V[row][col]/WH[row][col];
            h_scale[comp][col] = acc;
        }
    }

    double w_lo, w_hi;
    NRmatrix_minmax(w_scale, w_lo, w_hi);
    double h_lo, h_hi;
    NRmatrix_minmax(h_scale, h_lo, h_hi);

    NRmatrix_elementwise_mult(W, w_scale);
    NRmatrix_elementwise_mult(H, h_scale);

    // Rescale every column of W to unit sum.
    for (int comp=0; comp<W.ncols(); ++comp) {
        double colsum=0;
        for (int row=0; row<W.nrows(); ++row)
            colsum += W[row][comp];
        for (int row=0; row<W.nrows(); ++row)
            W[row][comp] /= colsum;
    }

    // Argument order (H first) kept so NaN handling of min/max is unchanged.
    return make_pair(min(h_lo, w_lo), max(h_hi, w_hi));
}
pair<MatDoub,MatDoub> NMF(MatDoub_I &V,int K)
{
int N = V.nrows();
int M = V.ncols();
MatDoub W=generate_random_matrix(N, K,10);
MatDoub H=generate_random_matrix(K, M,10);
for (int j=0; j<H.ncols();j++){
double h=0;
for (int i=0;i<H.nrows();i++)
h+=H[i][j];
for (int i=0;i<H.nrows();i++)
H[i][j]/=h;
}
// cout << W << H;
double last_d=divergence(V,W,H);
int iter=0;
int max_iter=2000;
while(iter < max_iter){
iter++;
auto minmax = NMF_update3(V, W, H);
// cout << iter;
// cout << "\t"<< distance(V,W,H)<<"\t"<<divergence(V,W,H);
// cout <<endl;
double d=divergence(V,W,H);
if (fabs(d-last_d)<last_d*0.001)
break;
//cout << W << H<<endl;
}
return make_pair(W, H);
}
// Convergence experiment. Reads V<A>.tsv, W<A>.tsv and H<A>.tsv from the data
// directory and starts from randomly perturbed copies of the reference W and H.
// It then applies NMF_update3 N times, re-perturbing H every P steps.
// Each step's ||V - WiHi||^2, ||W - Wi||^2 and ||H - Hi||^2 are written to
// conv<A>.tsv. Finally it prints the W with the lowest reconstruction distance
// seen.
int main(int argc, const char * argv[]) {
    StartUp(argc,argv);
    cout << "data dir:" << data_dir(parent_dir(__FILE__)+"data") << endl;
    int A=1;                      // data-set index used in the file names
    auto V = NRmatrix_read_tsv(data_dir()+"V"+toa(A)+".tsv");
    auto W = NRmatrix_read_tsv(data_dir()+"W"+toa(A)+".tsv");
    auto H = NRmatrix_read_tsv(data_dir()+"H"+toa(A)+".tsv");
    MatDoub bestW,bestH;
    double bestDist=-1;           // -1 means "nothing recorded yet"
    MatDoub Wi=W;
    MatDoub Hi=H;
    const long N=10000;           // number of update steps
    const long P=500;             // re-perturb H every P steps
    const double f=1.1;           // max factor for the periodic re-perturbation
    // Start point: every entry scaled by a random factor in [1/2, 2].
    perturbe_matrix(Hi,2);
    perturbe_matrix(Wi,2);
    const auto conv_path = data_dir()+"conv"+toa(A)+".tsv";
    fstream fout(conv_path,ios::out);
    if (!fout) {
        cerr << "cannot open " << conv_path << " for writing" << endl;
        return 1;
    }
    fout <<"i\tVdist\tWdist\tHdist\n";
    for (int i=0;i<N;i++){
        NMF_update3(V, Wi, Hi);
        const double dist=distance(V,Wi,Hi);
        cout << i;
        cout << "\tVdist="<<dist;
        fout << i;
        fout << "\t"<<dist;
        fout << "\t"<<distance(W,Wi);
        fout << "\t"<<distance(H,Hi);
        fout << '\n';             // no per-line flush; the file is closed below
        if ((bestDist<0) || (bestDist > dist)){
            bestW=Wi;
            bestH=Hi;
            bestDist=dist;
            cout << " * ";
        }
        if (i% P ==0){
            perturbe_matrix(Hi,f);
            cout << "P";
        }
        cout << endl;
    }
    fout.close();
    cout << bestW<<endl;
    return 0;
}