你的位置:首页 > Java教程

[Java教程]Coursera Algorithms Programming Assignment 5: Kd-Trees

题目地址:http://coursera.cs.princeton.edu/algs4/assignments/kdtree.html

分析:

Brute-force implementation. 蛮力实现的方法比较简单,就是逐个遍历每个point进行比较,实现下述API就可以了,没有什么难度。

 

 1 import java.util.ArrayList; 2 import java.util.TreeSet; 3 import edu.princeton.cs.algs4.Point2D; 4 import edu.princeton.cs.algs4.RectHV; 5 import edu.princeton.cs.algs4.StdDraw; 6 /** 7  * @author evasean www.cnblogs.com/evasean/ 8 */ 9 public class PointSET {10   private TreeSet<Point2D> points;11   public PointSET() {12     // construct an empty set of points13     points = new TreeSet<Point2D>();14   }15 16   public boolean isEmpty() {17     // is the set empty?18     return points.isEmpty();19   }20 21   public int size() {22     // number of points in the set23     return points.size();24   }25 26   public void insert(Point2D p) {27     // add the point to the set (if it is not already in the set)28     if(p==null)29       throw new IllegalArgumentException("Point2D p is not illegal!");30     if(!points.contains(p))31       points.add(p);32   }33 34   public boolean contains(Point2D p) {35     // does the set contain point p?36     if(p==null)37       throw new IllegalArgumentException("Point2D p is not illegal!");38     return points.contains(p);39   }40 41   public void draw() {42     // draw all points to standard draw43     for (Point2D p : points) {44       p.draw();45     }46     StdDraw.show();47   }48 49   public Iterable<Point2D> range(RectHV rect) {50     // all points that are inside the rectangle (or on the boundary)51     if(rect==null)52       throw new IllegalArgumentException("RectHV rect is not illegal!");53     ArrayList<Point2D> list = new ArrayList<Point2D>();54     for(Point2D point : points){55       if(rect.contains(point)) list.add(point);56     }57     return list;58   }59 60   public Point2D nearest(Point2D p) {61     // a nearest neighbor in the set to point p; null if the set is empty62     if(p==null)63       throw new IllegalArgumentException("Point2D p is not illegal!");64     if(points.size() == 0) return null;65     double neareatDistance = Double.POSITIVE_INFINITY;66     Point2D nearest = null;67     for(Point2D point : 
points){68       double tmp = p.distanceTo(point);69       if(Double.compare(neareatDistance, tmp) == 1){70         neareatDistance = tmp;71         nearest = point;72       }73         74     }75     return nearest;76   }77 78   public static void main(String[] args) {79     // unit testing of the methods (optional)80   }81 }

2d-tree implementation.

kd-tree插入时要交替以x坐标和y坐标作为判断依据,比如root节点处比较依据为x坐标,那么当要查找或插入一个新节点point时,比较root节点的x坐标和point的x坐标,如果后者比前者小,那么下一次要比较的就是root->left, 相反下一次要比较的就是root->right。进入下一层级之后,就要使用y坐标作为比较依据。示例如下图:

 

区域搜索:查找落在给定矩形区域范围内(含边界)的所有points。从root开始递归查找,如果给定的矩形与当前节点对应的矩形区域不相交,那么该节点及其子树中都不可能有符合条件的点,就没有必要继续查找该节点及其子树了。

最近节点搜索:查找与给定point距离最近的节点。从root开始递归查找其左右子树,如果已经查找到的最近节点与给定point的距离不大于该point到当前遍历节点对应矩形区域的距离,那么该节点及其子树中不可能存在更近的点,就没必要遍历这个当前节点及其子树了。此外,作业说明建议:当两棵子树都需要查找时,优先递归查找与给定point位于分割线同一侧的子树,这样能更早找到较近的点,从而剪掉更多分支。

 1 import java.util.ArrayList; 2 import edu.princeton.cs.algs4.Point2D; 3 import edu.princeton.cs.algs4.RectHV; 4 import edu.princeton.cs.algs4.StdDraw; 5 /** 6  * @author evasean www.cnblogs.com/evasean/ 7 */ 8 public class KdTree { 9   private Node root; 10   private class Node { 11     private Point2D p; 12     /* 13      * 节点的value就是包含该节点的矩形空间 其左右子树的矩形空间,就是通过该节点进行水平切分或垂直切分的两个子空间 14     */ 15     private RectHV rect; 16     private Node left, right; 17     private int size; 18     private boolean xCoordinate; // 标识该节点是否是以x坐标垂直切分 19  20     public Node(Point2D point, RectHV rect, int size, boolean xCoordinate) { 21       this.p = point; 22       this.rect = rect; 23       this.size = size; 24       this.xCoordinate = xCoordinate; 25     } 26   } 27  28   public KdTree() { 29     // construct an empty set of points 30   } 31  32   public boolean isEmpty() { 33     // is the set empty? 34     return size() == 0; 35   } 36  37   public int size() { 38     // number of points in the set 39     return size(root); 40   } 41  42   private int size(Node x) { 43     if (x == null) 44       return 0; 45     else 46       return x.size; 47   } 48  49   public void insert(Point2D p) { 50     // add the point to the set (if it is not already in the set) 51     if (p == null) 52       throw new IllegalArgumentException("Point2D p is not illegal!"); 53     if (root == null) 54       root = new Node(p, new RectHV(0.0, 0.0, 1.0, 1.0), 1, true); 55     else 56       insert(root, p); 57     // System.out.println("size="+root.size); 58   } 59  60   private void insert(Node x, Point2D p) { 61     if (x.xCoordinate == true) { // x的切分标志是x坐标 62       int cmp = Double.compare(p.x(), x.p.x()); 63       if (cmp == -1) { 64         if (x.left != null) 65           insert(x.left, p); 66         else { 67           RectHV parent = x.rect; 68           // 将节点x的矩形空间进行垂直切分后的左侧部分 69           double newXmin = parent.xmin(); 70           double newYmin = parent.ymin(); 71           double newXmax 
= x.p.x(); 72           double newYmax = parent.ymax(); 73           x.left = new Node(p, new RectHV(newXmin, newYmin, newXmax, newYmax), 1, false); 74         } 75       } else if (cmp == 1) { 76         if (x.right != null) 77           insert(x.right, p); 78         else { 79           RectHV parent = x.rect; 80           // 将节点x的矩形空间进行垂直切分后的右侧部分 81           double newXmin = x.p.x(); 82           double newYmin = parent.ymin(); 83           double newXmax = parent.xmax(); 84           double newYmax = parent.ymax(); 85           x.right = new Node(p, new RectHV(newXmin, newYmin, newXmax, newYmax), 1, false); 86         } 87       } else { // x.key.x() 与 p.x() 相等 88         int cmp2 = Double.compare(p.y(), x.p.y()); 89         if (cmp2 == -1) { 90           if (x.left != null) 91             insert(x.left, p); 92           else { 93             x.left = new Node(p, x.rect, 1, false); 94           } 95         } else if (cmp2 == 1) { 96           if (x.right != null) 97             insert(x.right, p); 98           else { 99             x.right = new Node(p, x.rect, 1, false);100           }101         }102       }103     } else { // x的切分标志是y坐标104       int cmp = Double.compare(p.y(), x.p.y());105       if (cmp == -1) {106         if (x.left != null)107           insert(x.left, p);108         else {109           RectHV parent = x.rect;110           // 将节点x的矩形空间进行垂直切分后的左侧部分111           double newXmin = parent.xmin();112           double newYmin = parent.ymin();113           double newXmax = parent.xmax();114           double newYmax = x.p.y();115           x.left = new Node(p, new RectHV(newXmin, newYmin, newXmax, newYmax), 1, true);116         }117       } else if (cmp == 1) {118         if (x.right != null)119           insert(x.right, p);120         else {121           RectHV parent = x.rect;122           // 将节点x的矩形空间进行垂直切分后的左侧部分123           double newXmin = parent.xmin();124           double newYmin = x.p.y();125           double newXmax = parent.xmax();126   
        double newYmax = parent.ymax();127           x.right = new Node(p, new RectHV(newXmin, newYmin, newXmax, newYmax), 1, true);128         }129       } else { // x.key.y() 与 p.y()相等130         int cmp2 = Double.compare(p.x(), x.p.x());131         if (cmp2 == -1) {132           if (x.left != null)133             insert(x.left, p);134           else {135             x.left = new Node(p, x.rect, 1, true);136           }137         } else if (cmp2 == 1) {138           if (x.right != null)139             insert(x.right, p);140           else {141             x.right = new Node(p, x.rect, 1, true);142           }143         }144       }145     }146     x.size = 1 + size(x.left) + size(x.right);147   }148 149   public boolean contains(Point2D p) {150     // does the set contain point p?151     if (p == null)152       throw new IllegalArgumentException("Point2D p is not illegal!");153     return contains(root, p);154   }155 156   private boolean contains(Node x, Point2D p) {157     if(x == null ) return false;158     if (x.p.equals(p))159       return true;160     else {161       if(x.xCoordinate == true){162         int cmp = Double.compare(p.x(), x.p.x());163         if(cmp == -1) return contains(x.left,p);164         else if(cmp == 1 ) return contains(x.right,p);165         else{166           int cmp2 = Double.compare(p.y(), x.p.y());167           if(cmp2 == -1) return contains(x.left,p);168           else if(cmp2 == 1 ) return contains(x.right,p);169           else return true;170         }171       }else{172         int cmp = Double.compare(p.y(), x.p.y());173         if(cmp == -1) return contains(x.left,p);174         else if(cmp == 1 ) return contains(x.right,p);175         else{176           int cmp2 = Double.compare(p.x(), x.p.x());177           if(cmp2 == -1) return contains(x.left,p);178           else if(cmp2 == 1 ) return contains(x.right,p);179           else return true;180         }  181       }182     }183   }184 185   public void draw() {186     // 
draw all points to standard draw187     StdDraw.setXscale(0, 1);188     StdDraw.setYscale(0, 1);189     draw(root);190   }191 192   private void draw(Node x) {193     if (x == null)194       return;195     StdDraw.setPenColor(StdDraw.BLACK);196     StdDraw.setPenRadius(0.01);197     x.p.draw();198     if (x.xCoordinate == true) {199       StdDraw.setPenColor(StdDraw.RED);200       StdDraw.setPenRadius();201       Point2D start = new Point2D(x.p.x(), x.rect.ymin());202       Point2D end = new Point2D(x.p.x(), x.rect.ymax());203       start.drawTo(end);204     } else {205       StdDraw.setPenColor(StdDraw.BLUE);206       StdDraw.setPenRadius();207       Point2D start = new Point2D(x.rect.xmin(), x.p.y());208       Point2D end = new Point2D(x.rect.xmax(), x.p.y());209       start.drawTo(end);210     }211     draw(x.left);212     draw(x.right);213   }214 215   public Iterable<Point2D> range(RectHV rect) {216     // all points that are inside the rectangle (or on the boundary)217     if (rect == null)218       throw new IllegalArgumentException("RectHV rect is not illegal!");219     if (root != null)220       return range(root, rect);221     else222       return new ArrayList<Point2D>();223   }224 225   private ArrayList<Point2D> range(Node x, RectHV rect) {226     ArrayList<Point2D> points = new ArrayList<Point2D>();227     if (x.rect.intersects(rect)) {228       if (rect.contains(x.p))229         points.add(x.p);230       if (x.left != null)231         points.addAll(range(x.left, rect));232       if (x.right != null)233         points.addAll(range(x.right, rect));234     }235     return points;236   }237 238   public Point2D nearest(Point2D p) {239     // a nearest neighbor in the set to point p; null if the set is empty240     if (p == null)241       throw new IllegalArgumentException("Point2D p is not illegal!");242     if (root != null)243       return nearest(root, p, root.p);244     return null;245   }246 247   /**248    * 
作业提交提示nearest的时间复杂度偏高,导致作业只有98分,我觉得这样写比较清晰明了,就懒得继续优化249    * @param x250    * @param p251    * @param currNearPoint252    * @return253   */254   private Point2D nearest(Node x, Point2D p, Point2D currNearPoint) {255     if(x.p.equals(p)) return x.p;256     double currMinDistance = currNearPoint.distanceTo(p);257     if (Double.compare(x.rect.distanceTo(p), currMinDistance) >= 0)258       return currNearPoint;259     else {260       double distance = x.p.distanceTo(p);261       if (Double.compare(x.p.distanceTo(p), currMinDistance) == -1) {262         currNearPoint = x.p;263         currMinDistance = distance;264       }265       if (x.left != null)266         currNearPoint = nearest(x.left, p, currNearPoint);267       if (x.right != null)268         currNearPoint = nearest(x.right, p, currNearPoint);269     }270     return currNearPoint;271   }272 273   public static void main(String[] args) {274     // unit testing of the methods (optional)275   }276 }