Package Bio :: Module kNN
[hide private]
[frames] | no frames]

Source Code for Module Bio.kNN

  1  #!/usr/bin/env python 
  2   
  3  """ 
  4  This module provides code for doing k-nearest-neighbors classification. 
  5   
  6  k Nearest Neighbors is a supervised learning algorithm that classifies 
  7  a new observation based on the classes in its surrounding neighborhood. 
  8   
  9  Glossary: 
 10  distance   The distance between two points in the feature space. 
 11  weight     The importance given to each point for classification.  
 12   
 13   
 14  Classes: 
 15  kNN           Holds information for a nearest neighbors classifier. 
 16   
 17   
 18  Functions: 
 19  train        Train a new kNN classifier. 
 20  calculate    Calculate the probabilities of each class, given an observation. 
 21  classify     Classify an observation into a class. 
 22   
 23      Weighting Functions: 
 24  equal_weight    Every example is given a weight of 1. 
 25   
 26  """ 
 27   
 28  #TODO - Remove this work around once we drop python 2.3 support 
 29  try: 
 30      set = set 
 31  except NameError: 
 32      from sets import Set as set 
 33   
 34  import numpy 
 35   
class kNN:
    """Container for the state of a nearest neighbors classifier.

    Members:
    classes  Set of the possible classes.
    xs       List of the neighbors.
    ys       List of the classes that the neighbors belong to.
    k        Number of neighbors to look at.

    """

    def __init__(self):
        """kNN()

        Create an empty, untrained classifier: no classes, no
        neighbors, and k left unset (None).

        """
        self.k = None
        self.xs = []
        self.ys = []
        self.classes = set()
52
def equal_weight(x, y):
    """equal_weight(x, y) -> 1

    Weighting function that ignores both arguments, so every
    neighbor contributes exactly one vote.

    """
    return 1
57
def train(xs, ys, k, typecode=None):
    """train(xs, ys, k) -> kNN

    Build a k nearest neighbors classifier from a training set.

    xs is a list of observations and ys is the list of class
    assignments for those observations, so both should have the same
    number of elements.  k is the number of neighbors to examine when
    classifying.  typecode, if given, is passed to numpy.asarray as
    the data type used to store the observations.

    The returned kNN object holds xs converted to a numpy array, ys
    itself (not a copy), the set of distinct classes found in ys,
    and k.

    """
    model = kNN()
    model.k = k
    model.ys = ys
    model.xs = numpy.asarray(xs, typecode)
    model.classes = set(ys)
    return model
74
def calculate(knn, x, weight_fn=equal_weight, distance_fn=None):
    """calculate(knn, x[, weight_fn][, distance_fn]) -> weight dict

    Tally the weighted votes of the nearest neighbors of x.

    knn is a kNN object and x is the observed data.  Every training
    point is ranked by its distance to x, computed with
    distance_fn(x, point) when distance_fn is given and with the
    Euclidean distance otherwise.  Each of the knn.k closest points
    then adds weight_fn(x, point) to the total of its own class.

    Returns a dictionary mapping every class in knn.classes to its
    accumulated weight.  Classes with no point among the knn.k
    closest keep a weight of 0.0.  The weights are not normalized.

    """
    x = numpy.asarray(x)

    # Build a list of (distance, index) pairs, one per training point.
    if distance_fn:
        ranked = [(distance_fn(x, point), idx)
                  for idx, point in enumerate(knn.xs)]
    else:
        # Default: Euclidean distance.  The difference vector is
        # written into one preallocated buffer instead of allocating a
        # new array for every training point.
        ranked = []
        diff = numpy.zeros(len(x))
        for idx, point in enumerate(knn.xs):
            diff[:] = x - point
            ranked.append((numpy.sqrt(numpy.dot(diff, diff)), idx))
    # Sorting the pairs orders by distance, ties broken by index.
    ranked.sort()

    # Only the first knn.k entries get to vote.
    votes = dict.fromkeys(knn.classes, 0.0)
    for _, idx in ranked[:knn.k]:
        label = knn.ys[idx]
        votes[label] = votes[label] + weight_fn(x, knn.xs[idx])

    return votes
114
def classify(knn, x, weight_fn=equal_weight, distance_fn=None):
    """classify(knn, x[, weight_fn][, distance_fn]) -> class

    Classify an observation into a class.

    knn is a kNN object and x is the observed data.  If not specified,
    weight_fn will give all neighbors equal weight.  distance_fn is an
    optional function that takes two points and returns the distance
    between them.  If distance_fn is None (the default), the Euclidean
    distance is used.

    Returns the class with the largest total weight as computed by
    calculate().  When several classes share the largest weight, the
    first one encountered while iterating over the weight dictionary
    is returned.  Returns None if the classifier has no classes.

    """
    weights = calculate(
        knn, x, weight_fn=weight_fn, distance_fn=distance_fn)

    most_class = None
    most_weight = None
    for klass, weight in weights.items():
        # "Nothing chosen yet" is detected through most_weight, not
        # most_class: None may be a legitimate class label, and testing
        # most_class would let a later, lighter class overwrite it.
        if most_weight is None or weight > most_weight:
            most_class = klass
            most_weight = weight
    return most_class
133