
[ASP.net Tutorial] Data Mining: The ID3 Decision Tree Algorithm (C# Implementation)


A decision tree is a classic classifier whose working principle is a bit like a guessing game. Take guessing an animal, for example:

Q: Is this animal a land animal?

A: Yes.

Q: Does this animal have gills?

A: No.

The order of these two questions is somewhat backwards, because land animals generally do not have gills (as far as I recall; corrections welcome), so the second question adds almost no new information. In this kind of game, the order of the questions matters a great deal: each question should be chosen to gain as much information as possible.

AllElectronics customer database: class-labeled training tuples

RID  age          income  student  credit_rating  Class: buys_computer
1    youth        high    no       fair           no
2    youth        high    no       excellent      no
3    middle_aged  high    no       fair           yes
4    senior       medium  no       fair           yes
5    senior       low     yes      fair           yes
6    senior       low     yes      excellent      no
7    middle_aged  low     yes      excellent      yes
8    youth        medium  no       fair           no
9    youth        low     yes      fair           yes
10   senior       medium  yes      fair           yes
11   youth        medium  yes      excellent      yes
12   middle_aged  medium  no       excellent      yes
13   middle_aged  high    yes      fair           yes
14   senior       medium  no       excellent      no

Take the class-labeled training tuples of the AllElectronics customer database above as an example. We want to use these samples as a training set for our decision tree model, in order to mine the decision pattern of whether a customer will buy a computer.

In the ID3 decision tree algorithm, the expected information (entropy) needed to classify a tuple in data set $D$ with $m$ classes is

$$Info(D) = -\sum_{i=1}^m p_i log_2(p_i)$$

where $p_i$ is the proportion of tuples in $D$ belonging to class $i$. The expected information still required after partitioning $D$ on attribute $A$ (with $v$ distinct values) is:

$$Info_A(D) = \sum_{j=1}^v\frac{|D_j|}{|D|} \times Info(D_j)$$

The information gain is then computed as:

$$Gain(A) = Info(D) - Info_A(D)$$

According to the formula, the class attribute to be predicted contains 9 "yes" tuples and 5 "no" tuples, so the expected information is:

$$Info(D)=-\frac{9}{14}log_2\frac{9}{14}-\frac{5}{14}log_2\frac{5}{14}=0.940$$

First, compute the expected information for the attribute age:

$$Info_{age}(D)=\frac{5}{14} \times (-\frac{2}{5}log_2\frac{2}{5} - \frac{3}{5}log_2\frac{3}{5})+\frac{4}{14} \times (-\frac{4}{4}log_2\frac{4}{4} - \frac{0}{4}log_2\frac{0}{4})+\frac{5}{14} \times (-\frac{3}{5}log_2\frac{3}{5} - \frac{2}{5}log_2\frac{2}{5})=0.694$$

Therefore, if we partition on age, the information gain obtained is:

$$Gain(age) = Info(D)-Info_{age}(D) = 0.940-0.694=0.246$$
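To double-check these numbers, here is a small standalone C# snippet (not part of the original post; the class name GainCheck and the helper Entropy are purely illustrative) that recomputes them from the class counts in the table: 9 "yes" and 5 "no" overall, and within age the yes/no counts are 2/3 for youth, 4/0 for middle_aged and 3/2 for senior.

using System;

class GainCheck
{
    // Entropy of a class distribution given raw counts, with the convention 0 * log2(0) = 0.
    static double Entropy(params int[] counts)
    {
        double total = 0;
        foreach (var c in counts) total += c;
        double h = 0.0;
        foreach (var c in counts)
        {
            if (c == 0) continue;
            double p = c / total;
            h -= p * Math.Log(p, 2.0);
        }
        return h;
    }

    static void Main()
    {
        double infoD = Entropy(9, 5);                         // class counts of the whole data set
        double infoAge = 5 / 14.0 * Entropy(2, 3)             // youth:       2 yes, 3 no
                       + 4 / 14.0 * Entropy(4, 0)             // middle_aged: 4 yes, 0 no
                       + 5 / 14.0 * Entropy(3, 2);            // senior:      3 yes, 2 no
        Console.WriteLine("Info(D)     = {0:F3}", infoD);           // 0.940
        Console.WriteLine("Info_age(D) = {0:F3}", infoAge);         // 0.694
        Console.WriteLine("Gain(age)   = {0:F3}", infoD - infoAge); // 0.246
    }
}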

In the same way, compute the information gain for splitting on income, student and credit_rating: Gain(income) = 0.029, Gain(student) = 0.151 and Gain(credit_rating) = 0.048. The attribute that brings the largest information gain, here age, is selected, and the current node is split on its values. Executing this recursively generates the decision tree. For more details, see:

https://en.wikipedia.org/wiki/Decision_tree

The C# implementation is as follows:

using System;
using System.Collections.Generic;
using System.Linq;

namespace MachineLearning.DecisionTree
{
    public class DecisionTreeID3<T> where T : IEquatable<T>
    {
        T[,] Data;
        string[] Names;
        int Category;
        T[] CategoryLabels;
        DecisionTreeNode<T> Root;

        public DecisionTreeID3(T[,] data, string[] names, T[] categoryLabels)
        {
            Data = data;
            Names = names;
            Category = data.GetLength(1) - 1; // the class label must be in the last column
            CategoryLabels = categoryLabels;
        }

        public void Learn()
        {
            int nRows = Data.GetLength(0);
            int nCols = Data.GetLength(1);
            int[] rows = new int[nRows];
            int[] cols = new int[nCols];
            for (int i = 0; i < nRows; i++) rows[i] = i;
            for (int i = 0; i < nCols; i++) cols[i] = i;
            Root = new DecisionTreeNode<T>(-1, default(T));
            Learn(rows, cols, Root);
            DisplayNode(Root);
        }

        public void DisplayNode(DecisionTreeNode<T> Node, int depth = 0)
        {
            if (Node.Label != -1)
                Console.WriteLine("{0} {1}: {2}", new string('-', depth * 3), Names[Node.Label], Node.Value);
            foreach (var item in Node.Children)
                DisplayNode(item, depth + 1);
        }

        private void Learn(int[] pnRows, int[] pnCols, DecisionTreeNode<T> Root)
        {
            var categoryValues = GetAttribute(Data, Category, pnRows);
            var categoryCount = categoryValues.Distinct().Count();
            if (categoryCount == 1)
            {
                // all remaining tuples belong to the same class: create a leaf
                var node = new DecisionTreeNode<T>(Category, categoryValues.First());
                Root.Children.Add(node);
            }
            else
            {
                if (pnRows.Length == 0) return;
                else if (pnCols.Length == 1)
                {
                    // no attributes left to split on: label the leaf by majority vote
                    var Vote = categoryValues.GroupBy(i => i).OrderByDescending(i => i.Count()).First();
                    var node = new DecisionTreeNode<T>(Category, Vote.First());
                    Root.Children.Add(node);
                }
                else
                {
                    // split on the attribute with the largest information gain
                    var maxCol = MaxEntropy(pnRows, pnCols);
                    var attributes = GetAttribute(Data, maxCol, pnRows).Distinct();
                    string currentPrefix = Names[maxCol];
                    foreach (var attr in attributes)
                    {
                        int[] rows = pnRows.Where(irow => Data[irow, maxCol].Equals(attr)).ToArray();
                        int[] cols = pnCols.Where(i => i != maxCol).ToArray();
                        var node = new DecisionTreeNode<T>(maxCol, attr);
                        Root.Children.Add(node);
                        Learn(rows, cols, node); // recursively grow the subtree
                    }
                }
            }
        }

        // Info_A(D): expected information still needed after splitting the given rows on attribute attrCol.
        public double AttributeInfo(int attrCol, int[] pnRows)
        {
            var tuples = AttributeCount(attrCol, pnRows);
            var sum = (double)pnRows.Length;
            double Entropy = 0.0;
            foreach (var tuple in tuples)
            {
                int[] count = new int[CategoryLabels.Length];
                foreach (var irow in pnRows)
                    if (Data[irow, attrCol].Equals(tuple.Item1))
                    {
                        int index = Array.IndexOf(CategoryLabels, Data[irow, Category]);
                        count[index]++;
                    }
                double k = 0.0;
                for (int i = 0; i < count.Length; i++)
                {
                    double frequency = count[i] / (double)tuple.Item2;
                    double t = -frequency * Log2(frequency);
                    k += t;
                }
                double freq = tuple.Item2 / sum;
                Entropy += freq * k;
            }
            return Entropy;
        }

        // Info(D): entropy of the class distribution over the given rows.
        public double CategoryInfo(int[] pnRows)
        {
            var tuples = AttributeCount(Category, pnRows);
            var sum = (double)pnRows.Length;
            double Entropy = 0.0;
            foreach (var tuple in tuples)
            {
                double frequency = tuple.Item2 / sum;
                double t = -frequency * Log2(frequency);
                Entropy += t;
            }
            return Entropy;
        }

        private static IEnumerable<T> GetAttribute(T[,] data, int col, int[] pnRows)
        {
            foreach (var irow in pnRows)
                yield return data[irow, col];
        }

        private static double Log2(double x)
        {
            return x == 0.0 ? 0.0 : Math.Log(x, 2.0);
        }

        // Returns the column index of the attribute with the largest gain Gain(A) = Info(D) - Info_A(D).
        public int MaxEntropy(int[] pnRows, int[] pnCols)
        {
            double cateEntropy = CategoryInfo(pnRows);
            int maxAttr = 0;
            double max = double.MinValue;
            foreach (var icol in pnCols)
                if (icol != Category)
                {
                    double Gain = cateEntropy - AttributeInfo(icol, pnRows);
                    if (max < Gain)
                    {
                        max = Gain;
                        maxAttr = icol;
                    }
                }
            return maxAttr;
        }

        // Counts how many times each distinct value occurs in the given column, restricted to the given rows.
        public IEnumerable<Tuple<T, int>> AttributeCount(int col, int[] pnRows)
        {
            var tuples = from n in GetAttribute(Data, col, pnRows)
                         group n by n into i
                         select Tuple.Create(i.First(), i.Count());
            return tuples;
        }
    }
}
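The listing above references a DecisionTreeNode<T> class that is not included in the post. A minimal sketch of what such a node class might look like, assuming it only needs the Label, Value and Children members used above:

using System.Collections.Generic;

namespace MachineLearning.DecisionTree
{
    // Hypothetical node type matching the members used by DecisionTreeID3<T>:
    // Label is the column index of the attribute (or -1 for the root),
    // Value is the attribute value on the incoming branch,
    // Children holds the subtrees grown under this node.
    public class DecisionTreeNode<T>
    {
        public int Label { get; set; }
        public T Value { get; set; }
        public List<DecisionTreeNode<T>> Children { get; private set; }

        public DecisionTreeNode(int label, T value)
        {
            Label = label;
            Value = value;
            Children = new List<DecisionTreeNode<T>>();
        }
    }
}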

It can be called as follows:

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using MachineLearning.DecisionTree;

namespace MachineLearning
{
    class Program
    {
        static void Main(string[] args)
        {
            var data = new string[,]
            {
                {"youth","high","no","fair","no"},
                {"youth","high","no","excellent","no"},
                {"middle_aged","high","no","fair","yes"},
                {"senior","medium","no","fair","yes"},
                {"senior","low","yes","fair","yes"},
                {"senior","low","yes","excellent","no"},
                {"middle_aged","low","yes","excellent","yes"},
                {"youth","medium","no","fair","no"},
                {"youth","low","yes","fair","yes"},
                {"senior","medium","yes","fair","yes"},
                {"youth","medium","yes","excellent","yes"},
                {"middle_aged","medium","no","excellent","yes"},
                {"middle_aged","high","yes","fair","yes"},
                {"senior","medium","no","excellent","no"}
            };
            var names = new string[] { "age", "income", "student", "credit_rating", "Class: buys_computer" };
            var tree = new DecisionTreeID3<string>(data, names, new string[] { "yes", "no" });
            tree.Learn();
            Console.ReadKey();
        }
    }
}

 

Running result: the learned tree is printed to the console, one line per node, indented with dashes according to the node's depth.


 

Note: the author is still learning and of limited ability; corrections of any errors or omissions are very welcome. Please credit the author when reposting.