最近因为项目需求,需要在weka上实现流形距离计算。因为weka没有提供流形学习的包,而smile提供了,于是我参考smile的等度量映射(Isomap)实现,重写了一个可在weka上使用的流形距离计算类。
欧氏距离是最常用的距离度量,但是在数据集不具有全局线性结构时,欧氏距离就不是一种合理的数据距离度量,一般使用拓扑流形结构来度量高维度的非线性数据。这种方法通常用来对数据进行降维,也被称为流形学习。
定义1:
流形两点间x1, x2的线段长度定义为 L(x1, x2) = exp(d(x1, x2) / σ) -1
定义2:
将数据点看作无向有权图G=(V, E)的顶点,V是顶点集合,E是边集合,Pij表示图上连接数据点xi, xj的所有路径的集合,则xi, xj的流形距离为 MD(xi, xj) = min_{p∈Pij} ∑_{k=1}^{|p|-1} L(pk, pk+1),其中|p|为路径p上的顶点个数,pk为路径p上的第k个顶点
算法流程:
for i = 1,2,3...m do
确定xi的k个最近邻
将xi与k个最近邻的距离设为定义的距离公式,与自己的距离设为0,与其他点距离设为-1
将这些数值添加进入邻接矩阵
end
根据邻接矩阵构建一个有权无向图的对象
使用Dijkstra最短路径算法求出图上任意两点之间的最短距离
ManifoldDistance.java
import weka.core.EuclideanDistance;
import weka.core.Instances;
import java.util.*;
/**
 * Manifold (shortest-path) distance between the instances of a Weka data set.
 *
 * <p>Every instance is linked to its {@code k} nearest neighbours by an undirected edge of
 * length {@code exp(d / sigma) - 1}, where {@code d} is the value returned by Weka's
 * {@link EuclideanDistance}. The manifold distance between two instances is the length of the
 * shortest path between them in that neighbourhood graph.
 *
 * <p>Call {@link #build()} once before querying {@link #getDistance(String, String)}.
 *
 * Created by Administrator on 2017/3/15.
 */
public class ManifoldDistance {
    /** Adjacency-matrix marker for "these two instances are not directly connected". */
    private static final double NO_EDGE = -1.0;

    private final Instances data;
    private final int k;
    private final double sigma;
    private double[][] matrix;
    private Graph graph = new Graph();

    /**
     * Creates a manifold distance calculator.
     *
     * @param data  the data set whose instances are compared
     * @param k     number of nearest neighbours each instance is linked to; must be at least 1
     * @param sigma scaling parameter of the edge-length formula; must be greater than 0
     * @throws IllegalArgumentException if {@code data} is null, {@code k < 1}, or {@code sigma}
     *                                  is not a positive number
     */
    public ManifoldDistance(Instances data, int k, double sigma) {
        if (data == null) {
            throw new IllegalArgumentException("data must not be null");
        }
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, got " + k);
        }
        // Written as !(sigma > 0) so that NaN is rejected as well. A non-positive sigma would
        // yield NaN or negative edge lengths, which collide with the NO_EDGE marker.
        if (!(sigma > 0)) {
            throw new IllegalArgumentException("sigma must be greater than 0, got " + sigma);
        }
        this.data = data;
        this.k = k;
        this.sigma = sigma;
    }

    public Instances getData() {
        return data;
    }

    public int getK() {
        return k;
    }

    public double getSigma() {
        return sigma;
    }

    /**
     * Returns the adjacency matrix produced by the last {@link #build()} call.
     *
     * @return the internal matrix (not a copy): 0 on the diagonal, the edge length for connected
     *         pairs and -1 for unconnected pairs; {@code null} if {@link #build()} has not run
     */
    public double[][] getMatrix() {
        return matrix;
    }

    /**
     * Builds the symmetric adjacency matrix of the k-nearest-neighbour graph.
     *
     * @return a matrix holding 0 on the diagonal, the edge length for connected pairs and
     *         -1 for pairs that are not directly connected
     */
    private double[][] constructWeightMatrix() {
        int num = this.data.numInstances();
        double[][] weightMatrix = new double[num][num];
        // Start from "no edge anywhere" so that a genuine zero-length edge (duplicate
        // instances) can never be mistaken for an unset cell.
        for (int i = 0; i < num; i++) {
            Arrays.fill(weightMatrix[i], NO_EDGE);
            weightMatrix[i][i] = 0.0;
        }

        EuclideanDistance calculateDistance = new EuclideanDistance(this.data);
        for (int i = 0; i < num; i++) {
            // Edge length from i to every OTHER instance; i itself is left out so that it
            // cannot occupy one of the k neighbour slots.
            HashMap<Integer, Double> lengths = new HashMap<>();
            for (int j = 0; j < num; j++) {
                if (i != j) {
                    double dist = calculateDistance.distance(
                            this.data.instance(i), this.data.instance(j));
                    lengths.put(j, Math.exp(dist / this.sigma) - 1);
                }
            }
            for (Integer neighbor : nearestNeighbor(lengths)) {
                double length = lengths.get(neighbor);
                // The graph is undirected: an edge chosen by either endpoint is stored both ways.
                weightMatrix[i][neighbor] = length;
                weightMatrix[neighbor][i] = length;
            }
        }
        return weightMatrix;
    }

    /**
     * Selects the k nearest neighbours.
     *
     * @param temp edge length from the current instance to each candidate, keyed by the
     *             candidate's index
     * @return indices of the (at most) k candidates with the smallest lengths, nearest first
     */
    private ArrayList<Integer> nearestNeighbor(HashMap<Integer, Double> temp) {
        List<Map.Entry<Integer, Double>> candidates = new ArrayList<>(temp.entrySet());
        // Ascending by length (nearest first); equal lengths are ordered by index so the
        // selection is deterministic.
        candidates.sort((a, b) -> {
            int byLength = a.getValue().compareTo(b.getValue());
            return byLength != 0 ? byLength : a.getKey().compareTo(b.getKey());
        });

        int count = Math.min(this.k, candidates.size());
        ArrayList<Integer> index = new ArrayList<>(count);
        for (int n = 0; n < count; n++) {
            index.add(candidates.get(n).getKey());
        }
        return index;
    }

    /**
     * Builds the adjacency matrix and the matching undirected weighted graph.
     */
    public void build() {
        this.matrix = constructWeightMatrix();
        int num = this.matrix.length;

        // Build into a fresh graph so that repeated calls never mix old and new edges.
        Graph built = new Graph();
        for (int i = 0; i < num; i++) {
            List<Vertex> neighbors = new ArrayList<>();
            for (int j = 0; j < num; j++) {
                // The matrix is symmetric, so reading row i yields every edge of vertex i
                // exactly once. ">= 0" keeps zero-length edges and drops -1 and NaN cells.
                if (i != j && this.matrix[i][j] >= 0) {
                    neighbors.add(new Vertex(Integer.toString(j), this.matrix[i][j]));
                }
            }
            built.addVertex(Integer.toString(i), neighbors);
        }
        this.graph = built;
    }

    /**
     * Returns the Dijkstra shortest-path distance between two instances of the graph.
     *
     * @param start index of the first instance, as a decimal string
     * @param end   index of the second instance, as a decimal string
     * @return the length of the shortest path, 0 when both indices are equal, or
     *         {@link Double#POSITIVE_INFINITY} when the two instances are not connected
     * @throws IllegalStateException    if {@link #build()} has not been called
     * @throws IllegalArgumentException if an argument is not a valid instance index
     */
    public double getDistance(String start, String end) {
        if (this.matrix == null) {
            throw new IllegalStateException("build() must be called before getDistance()");
        }
        int from = parseIndex(start, "start");
        int to = parseIndex(end, "end");
        // Use the canonical form of the index as graph key, so "010" and "10" name the same vertex.
        String fromId = Integer.toString(from);
        String toId = Integer.toString(to);

        // The graph reports the route from the end vertex back towards, but excluding, the
        // start vertex. Copy it so that the list can be extended safely.
        List<String> path = new ArrayList<>(this.graph.getShortestPath(fromId, toId));
        if (path.isEmpty() && from != to) {
            // No route at all: the instances lie in different components of the graph.
            return Double.POSITIVE_INFINITY;
        }
        path.add(fromId);
        Collections.reverse(path);

        double mDist = 0.0;
        for (int i = 0; i < path.size() - 1; i++) {
            int m = Integer.parseInt(path.get(i));
            int n = Integer.parseInt(path.get(i + 1));
            mDist += this.matrix[m][n];
        }
        System.out.println("shortest path:" + path);
        return mDist;
    }

    /** Parses an instance index and checks that it lies inside the built matrix. */
    private int parseIndex(String id, String role) {
        final int index;
        try {
            index = Integer.parseInt(id);
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException(role + " is not an instance index: " + id, e);
        }
        if (index < 0 || index >= this.matrix.length) {
            throw new IllegalArgumentException(role + " index " + index
                    + " is outside 0.." + (this.matrix.length - 1));
        }
        return index;
    }
}
Graph.java
import java.util.*;
/**
 * Directed weighted graph stored as adjacency lists, with a Dijkstra shortest-path query.
 *
 * <p>An undirected edge is represented by listing each endpoint in the other's adjacency list.
 * Edge weights are expected to be non-negative.
 *
 * Created by Administrator on 2017/3/14.
 */
class Graph {
    private final Map<String, List<Vertex>> vertices;

    public Graph() {
        this.vertices = new HashMap<>();
    }

    /**
     * Registers a vertex together with its outgoing edges, replacing any earlier registration.
     *
     * @param character id of the vertex
     * @param vertex    outgoing edges; each element carries the id of the target vertex and the
     *                  weight of the edge leading to it
     */
    public void addVertex(String character, List<Vertex> vertex) {
        this.vertices.put(character, vertex);
    }

    /**
     * Finds the shortest path between two vertices with Dijkstra's algorithm.
     *
     * @param start  id of the vertex the path starts from
     * @param finish id of the vertex the path leads to
     * @return the ids on the shortest path listed from {@code finish} back towards
     *         {@code start}, with {@code start} itself left out; an empty list when
     *         {@code start} equals {@code finish}, when no path exists, or when either id has
     *         not been registered
     */
    public List<String> getShortestPath(String start, String finish) {
        if (!vertices.containsKey(start) || !vertices.containsKey(finish)) {
            return new ArrayList<>();
        }

        final Map<String, Double> distances = new HashMap<>();
        final Map<String, String> previous = new HashMap<>();
        final Set<String> settled = new HashSet<>();
        // Queue entries are (vertex id, tentative distance). A vertex is re-inserted whenever
        // its distance improves; the outdated entries are skipped when they surface.
        final PriorityQueue<Vertex> nodes = new PriorityQueue<>();

        distances.put(start, 0.0);
        nodes.add(new Vertex(start, 0.0));

        while (!nodes.isEmpty()) {
            Vertex smallest = nodes.poll();
            String id = smallest.getId();
            if (!settled.add(id)) {
                continue; // outdated entry of a vertex that is already final
            }
            if (id.equals(finish)) {
                // Walk the predecessor chain; start has no predecessor, which ends the walk.
                final List<String> path = new ArrayList<>();
                for (String step = finish; previous.containsKey(step); step = previous.get(step)) {
                    path.add(step);
                }
                return path;
            }

            List<Vertex> neighbors = vertices.get(id);
            if (neighbors == null) {
                continue; // edge target that was never registered: it has no outgoing edges
            }
            for (Vertex neighbor : neighbors) {
                String neighborId = neighbor.getId();
                double alt = smallest.getDistance() + neighbor.getDistance();
                // Unreached vertices count as infinitely far away. A NaN or infinite
                // alternative never compares smaller, so such edges are ignored.
                if (alt < distances.getOrDefault(neighborId, Double.POSITIVE_INFINITY)) {
                    distances.put(neighborId, alt);
                    previous.put(neighborId, id);
                    nodes.add(new Vertex(neighborId, alt));
                }
            }
        }
        // Every reachable vertex was settled without meeting finish.
        return new ArrayList<>();
    }
}
Vertex.java
/**
 * A vertex id paired with a distance value.
 *
 * <p>Instances order by distance first and by id second, which lets them serve as
 * priority-queue entries.
 *
 * Created by Administrator on 2017/3/14.
 */
class Vertex implements Comparable<Vertex> {
    private String id;
    private Double distance;

    /**
     * @param id       id of the vertex
     * @param distance distance value attached to the vertex
     */
    public Vertex(String id, Double distance) {
        super();
        this.id = id;
        this.distance = distance;
    }

    public String getId() {
        return id;
    }

    public Double getDistance() {
        return distance;
    }

    public void setId(String id) {
        this.id = id;
    }

    public void setDistance(Double distance) {
        this.distance = distance;
    }

    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result + (distance == null ? 0 : distance.hashCode());
        result = prime * result + (id == null ? 0 : id.hashCode());
        return result;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        Vertex other = (Vertex) obj;
        boolean sameDistance = distance == null
                ? other.distance == null
                : distance.equals(other.distance);
        boolean sameId = id == null ? other.id == null : id.equals(other.id);
        return sameDistance && sameId;
    }

    @Override
    public String toString() {
        return "Vertex [id=" + id + ", distance=" + distance + "]";
    }

    /**
     * Orders by distance, then by id.
     *
     * @param o the vertex to compare with
     * @return a negative value, zero or a positive value as this vertex sorts before, equal to
     *         or after {@code o}
     * @throws NullPointerException if either distance or either id is null
     */
    @Override
    public int compareTo(Vertex o) {
        // Double.compare defines a total order: NaN sorts after every other value and -0.0
        // before 0.0. Comparing with < and > would treat NaN as "equal" to everything and
        // produce cyclic orderings; it would also disagree with equals() on -0.0 vs 0.0.
        int byDistance = Double.compare(this.distance, o.distance);
        if (byDistance != 0) {
            return byDistance;
        }
        return this.getId().compareTo(o.getId());
    }
}
Demo.java
import weka.core.Instances;
import java.io.FileReader;
import java.io.IOException;
/**
 * Command-line demo: loads an ARFF data set, builds the manifold distance graph, prints the
 * adjacency matrix and then the manifold distance between instances 10 and 71 in both
 * directions.
 *
 * Created by Administrator on 2017/3/15.
 */
public class Demo {
    /** Data set that is loaded when no path is given on the command line. */
    private static final String DEFAULT_DATASET = "Test/Manifold/cpu.arff";

    /**
     * Runs the demo.
     *
     * @param args optional; {@code args[0]} is the path of the ARFF file to load
     * @throws IOException if the data set cannot be read
     */
    public static void main(String[] args) throws IOException {
        String path = args.length > 0 ? args[0] : DEFAULT_DATASET;

        Instances data;
        // Close the file as soon as the data set has been read, even if reading fails.
        try (FileReader reader = new FileReader(path)) {
            data = new Instances(reader);
        }

        ManifoldDistance manifold = new ManifoldDistance(data, 20, 2);
        manifold.build();

        for (double[] row : manifold.getMatrix()) {
            for (double value : row) {
                System.out.print(value + " ");
            }
            System.out.println();
        }

        System.out.println(manifold.getDistance("10", "71"));
        System.out.println(manifold.getDistance("71", "10"));
    }
}