package ml.package4;
import org.junit.Test;
import java.util.*;
import static java.util.stream.Collectors.groupingBy;
/**
 * JUnit test for a simple ID3 decision-tree builder that splits on information gain.
 * Created on 2017/12/8.
 */
public class test {

    /**
     * Builds an ID3 tree from the 15-row sample set below (features 年龄, 有工作, 有自己的房子,
     * 信贷情况; label 类别) and checks its shape.
     *
     * <p>Expected tree: the root splits on "有自己的房子"; its "是" branch is a pure leaf "是",
     * and its "否" branch splits on "有工作" into the leaves "是" and "否".
     */
    @Test
    public void testId3() {
        String[][] data = {
                {"青年", "否", "否", "一般", "否"},
                {"青年", "否", "否", "好", "否"},
                {"青年", "是", "否", "好", "是"},
                {"青年", "是", "是", "一般", "是"},
                {"青年", "否", "否", "一般", "否"},
                {"中年", "否", "否", "一般", "否"},
                {"中年", "否", "否", "好", "否"},
                {"中年", "是", "是", "好", "是"},
                {"中年", "否", "是", "非常好", "是"},
                {"中年", "否", "是", "非常好", "是"},
                {"老年", "否", "是", "非常好", "是"},
                {"老年", "否", "是", "好", "是"},
                {"老年", "是", "否", "好", "是"},
                {"老年", "是", "否", "非常好", "是"},
                {"老年", "否", "否", "一般", "否"}
        };
        List<String> title = Arrays.asList("年龄", "有工作", "有自己的房子", "信贷情况", "类别");
        ID3Node node = buildId3Tree(data, title);
        System.out.println(node);

        expectEquals("有自己的房子", node.title, "root split feature");
        expectEquals("是", node.children.get("是").val, "leaf under 有自己的房子=是");
        ID3Node noHouse = node.children.get("否");
        expectEquals("有工作", noHouse.title, "split feature under 有自己的房子=否");
        expectEquals("是", noHouse.children.get("是").val, "leaf under 有工作=是");
        expectEquals("否", noHouse.children.get("否").val, "leaf under 有工作=否");
    }

    /**
     * Throws {@link AssertionError} when {@code actual} differs from {@code expected}.
     * A plain AssertionError is reported as a test failure by JUnit and, unlike the
     * {@code assert} keyword, does not depend on the JVM running with {@code -ea}.
     */
    private static void expectEquals(Object expected, Object actual, String what) {
        if (!Objects.equals(expected, actual)) {
            throw new AssertionError(what + ": expected <" + expected + "> but was <" + actual + ">");
        }
    }

    /**
     * Returns the rows whose {@code column} equals {@code val}, each with that column removed.
     *
     * @param data   rows to filter
     * @param column index of the column to match on and drop
     * @param val    value the column must equal
     * @return a new array of narrowed rows (empty if nothing matches)
     */
    private static String[][] calcData(String[][] data, int column, String val) {
        return Arrays.stream(data)
                .filter(row -> row[column].equals(val))
                .map(row -> withoutColumn(row, column))
                .toArray(String[][]::new);
    }

    /** Returns a copy of {@code row} with the element at index {@code column} removed. */
    private static String[] withoutColumn(String[] row, int column) {
        String[] narrowed = new String[row.length - 1];
        System.arraycopy(row, 0, narrowed, 0, column);
        System.arraycopy(row, column + 1, narrowed, column, row.length - column - 1);
        return narrowed;
    }

    /**
     * Recursively builds an ID3 decision tree, splitting on the feature with the largest
     * information gain. Prints each candidate feature's gain and the chosen split to stdout.
     *
     * @param data  sample rows of equal width; the last column holds the class label
     * @param title column names aligned with the columns of {@code data}; the last one names the label
     * @return the root of the tree built from {@code data}. A leaf carries the label column's name
     *         in {@code title} and the predicted label in {@code val}.
     * @throws IllegalArgumentException if {@code data} is empty or {@code title} does not match its width
     */
    private ID3Node buildId3Tree(String[][] data, List<String> title) {
        if (data.length == 0) {
            throw new IllegalArgumentException("data must contain at least one row");
        }
        int labelColumn = data[0].length - 1;
        if (title.size() != data[0].length) {
            throw new IllegalArgumentException(
                    "title has " + title.size() + " names but rows have " + data[0].length + " columns");
        }
        List<String[]> rows = Arrays.asList(data);
        ID3Node id3Node = new ID3Node();

        // Leaf: every row has the same label, or no feature columns are left to split on.
        // In the latter case the set may still be mixed, so predict the majority label
        // instead of whatever the first row happens to carry.
        boolean pure = rows.stream().map(row -> row[labelColumn]).distinct().count() == 1;
        if (pure || labelColumn == 0) {
            id3Node.title = title.get(labelColumn);
            id3Node.val = majorityLabel(rows, labelColumn);
            return id3Node;
        }

        double baseEntropy = entropy(rows, labelColumn);
        int bestColumn = -1;
        double bestGain = -Double.MAX_VALUE;
        List<String> bestValues = null;
        for (int col = 0; col < labelColumn; col++) {
            int column = col; // effectively-final copy for the lambda below
            Map<String, List<String[]>> byValue = rows.stream().collect(groupingBy(row -> row[column]));
            // Conditional entropy H(label | feature) = sum over values v of |D_v| / |D| * H(D_v).
            double conditional = 0;
            for (List<String[]> subset : byValue.values()) {
                conditional += entropy(subset, labelColumn) * subset.size() / data.length;
            }
            double gain = baseEntropy - conditional;
            System.out.println(title.get(col) + "->" + gain);
            // Strict '>' keeps the left-most column on ties.
            if (gain > bestGain) {
                bestGain = gain;
                bestColumn = col;
                bestValues = new ArrayList<>(byValue.keySet());
            }
        }
        System.out.println(title.get(bestColumn) + " ->" + bestValues);

        id3Node.title = title.get(bestColumn);
        id3Node.children = new HashMap<>();
        // Drop the split column by position (not by name) so duplicate column names cannot
        // desynchronize title from the narrowed rows produced by calcData.
        List<String> remainingTitle = new ArrayList<>(title);
        remainingTitle.remove(bestColumn);
        for (String value : bestValues) {
            id3Node.children.put(value, buildId3Tree(calcData(data, bestColumn, value), remainingTitle));
        }
        return id3Node;
    }

    /** Returns the base-2 Shannon entropy of the value distribution in {@code column} of {@code rows}. */
    private static double entropy(List<String[]> rows, int column) {
        Map<String, List<String[]>> groups = rows.stream().collect(groupingBy(row -> row[column]));
        double h = 0;
        for (List<String[]> group : groups.values()) {
            double p = 1d * group.size() / rows.size();
            h -= p * log2(p);
        }
        return h;
    }

    /**
     * Returns the most frequent value in {@code column} of {@code rows}. On a tie, the value
     * that appears first in {@code rows} wins, so the result is deterministic.
     */
    private static String majorityLabel(List<String[]> rows, int column) {
        Map<String, Integer> counts = new LinkedHashMap<>();
        for (String[] row : rows) {
            counts.merge(row[column], 1, Integer::sum);
        }
        String best = null;
        int bestCount = 0;
        for (Map.Entry<String, Integer> entry : counts.entrySet()) {
            if (entry.getValue() > bestCount) {
                best = entry.getKey();
                bestCount = entry.getValue();
            }
        }
        return best;
    }

    /** Base-2 logarithm, with log2(0) defined as 0 so a zero probability contributes nothing to entropy. */
    private static double log2(double v) {
        if (v == 0d) {
            return 0d;
        }
        return Math.log(v) / Math.log(2d);
    }

    /** A tree node: an internal node splits on {@code title}; a leaf predicts {@code val}. */
    private static class ID3Node {
        /** Predicted label; set only on leaves. */
        String val;
        /** Split feature name for internal nodes, or the label column's name for leaves. */
        String title;
        /** Subtrees keyed by feature value; {@code null} on leaves. */
        Map<String, ID3Node> children;

        @Override
        public String toString() {
            return children == null ? title + "=" + val : title + children;
        }
    }
}