Belief networks

In this project, you will implement a Bayesian belief network and an MCMC technique that uses Gibbs sampling to infer the parameters of the network.

  1. Skim Section 3.4 of the book.


  2. Implement a belief network with support for categorical distributions. Here are some optional implementation suggestions:
    • A graph is made up of nodes. Here is a little Java code for some basic Node structures that you may want to use to get started:
      import java.util.ArrayList;
      
      abstract class Node
      {
      	boolean observed;
      	double observedValue; // only valid when "observed" is true
      	ArrayList<Categorical> categoricalParents;
      	ArrayList<Node> children;
      
      	// General-purpose constructor
      	Node()
      	{
      		categoricalParents = new ArrayList<Categorical>();
      		children = new ArrayList<Node>();
      	}
      
      	// Computes the conditional likelihood of this node
      	// given the current values of all parent nodes.
      	abstract double likelihood(double x);
      
      	// Draws a new sample value for this node, given
      	// the current values of all nodes within its Markov blanket.
      	// (For categorical nodes, this implements Gibbs Sampling.
      	// For continuous nodes, it uses Metropolis.)
      	abstract void resample();
      
      	// If this node is observed, returns the observed value.
      	// Otherwise, returns the value obtained from the most recent
      	// call to "resample".
      	abstract double getCurrentValue();
      }
      
      
      /// A simple node with a value that is always observed.
      class Constant extends Node
      {
      	Constant(double val)
      	{
      		observed = true;
      		observedValue = val;
      	}
      
      	double likelihood(double x)
      	{
      		return (x == observedValue ? 1.0 : 0.0);
      	}
      
      	void resample()
      	{
      	}
      
      	double getCurrentValue()
      	{
      		return observedValue;
      	}
      }
      
      
      /// A node that supports a finite number of categorical values.
      class Categorical extends Node
      {
      	// todo: write me
      	//
      	// You will need a two-dimensional array of Node references to
      	// store the probabilities of this distribution.
      	// The number of columns should equal the number of categories
      	// in this distribution. The number of rows should equal the
      	// number of distinct combinations of parent values.
      	// In other words, each row is a separate categorical
      	// distribution, one per combination of parent values.
      	// (For example, if there are three categorical parent nodes
      	// that respectively support 2, 3, and 5 values, then you will
      	// need 30 rows. If there are no categorical parent nodes,
      	// then you will need only one row.)
      }
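      The row lookup described in the comments above can be sketched as follows. This is only one possible layout, assuming each parent's current value is an integer category index starting at 0; the names ParentRowIndex, rowCount, and rowIndex are hypothetical and not part of the starter code.

```java
class ParentRowIndex {
    // Number of rows needed: the product of the parents' category counts.
    // With no parents the product is empty, so one row is needed.
    static int rowCount(int[] parentSizes) {
        int rows = 1;
        for (int size : parentSizes)
            rows *= size;
        return rows;
    }

    // Maps one combination of parent values to a row by treating the
    // values as digits of a mixed-radix number (first parent is the
    // most significant digit).
    static int rowIndex(int[] parentValues, int[] parentSizes) {
        int row = 0;
        for (int i = 0; i < parentValues.length; i++) {
            if (parentValues[i] < 0 || parentValues[i] >= parentSizes[i])
                throw new IllegalArgumentException("parent value out of range");
            row = row * parentSizes[i] + parentValues[i];
        }
        return row;
    }
}
```

      For parents that support 2, 3, and 5 values, rowCount gives 30, and the last combination (1, 2, 4) maps to row 29.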
      


    • Implement Gibbs sampling for categorical nodes. Pseudo-code for this step is given in the book. Implement the simple test in Section 3.4.3.1 to make sure your Gibbs sampler works as expected.
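      The book's pseudo-code is the reference for this step. As a rough sketch of the final draw only, assume you have already filled an array with one unnormalized weight per category (typically the node's own likelihood of that category multiplied by the likelihoods of its children given that category). The class and method names here are hypothetical.

```java
import java.util.Random;

class CategoricalDraw {
    // Draws an index with probability proportional to weights[i].
    // The weights do not need to sum to 1.
    static int draw(double[] weights, Random rng) {
        double total = 0.0;
        for (double w : weights) {
            if (w < 0.0 || Double.isNaN(w))
                throw new IllegalArgumentException("weights must be non-negative");
            total += w;
        }
        if (total <= 0.0)
            throw new IllegalArgumentException("at least one weight must be positive");

        // Pick a point in [0, total) and walk the cumulative sum to find
        // the category that contains it.
        double u = rng.nextDouble() * total;
        for (int i = 0; i < weights.length; i++) {
            u -= weights[i];
            if (u < 0.0)
                return i;
        }

        // Rounding can leave u at exactly 0 after the loop; fall back to
        // the last category that has positive weight.
        for (int i = weights.length - 1; i >= 0; i--)
            if (weights[i] > 0.0)
                return i;
        throw new IllegalStateException("unreachable");
    }
}
```

      Normalizing implicitly (by scaling the random number instead of the weights) avoids a second pass over the array.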


    • Add a class for continuous distributions that inherits from Node. Implement its "resample" method with the Metropolis technique. Classes that inherit from this one should implement the "likelihood" method, so you should leave that method unimplemented (abstract) in this class.
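      A single Metropolis update might look like the sketch below. It assumes a symmetric Gaussian proposal and a caller-supplied function that returns the log of the unnormalized target density (for a node, that would be the log of its likelihood times the likelihoods of its children). The names MetropolisStep, step, and proposalStdDev are hypothetical, and working in log space is a design choice rather than something the project requires.

```java
import java.util.Random;
import java.util.function.DoubleUnaryOperator;

class MetropolisStep {
    // Performs one Metropolis update and returns the new value
    // (which equals the current value if the candidate is rejected).
    static double step(double current, DoubleUnaryOperator logDensity,
                       double proposalStdDev, Random rng) {
        // Propose a candidate from a Gaussian centered on the current value.
        double candidate = current + proposalStdDev * rng.nextGaussian();

        // Log of the ratio target(candidate) / target(current).
        double logRatio = logDensity.applyAsDouble(candidate)
                        - logDensity.applyAsDouble(current);

        // Accept with probability min(1, ratio). A candidate with zero
        // density has logRatio = -infinity and is always rejected.
        if (Math.log(rng.nextDouble()) < logRatio)
            return candidate;
        return current;
    }
}
```

      Comparing logs instead of raw densities helps avoid underflow when many small likelihoods are multiplied together.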


    • Implement nodes for the Normal and inverse-gamma distributions that inherit from your continuous distribution node. Implement the likelihood functions for these nodes by looking up their probability density functions on Wikipedia. (If you want your belief network to be really useful, you can implement support for several common distributions, but these two will be sufficient for this project.)

      The PDF of the inverse gamma distribution uses the gamma function. I had a little trouble finding an implementation of the gamma function on Windows. So, in case you need it, here is a cross-platform approximation of the gamma function:
      #define _USE_MATH_DEFINES // makes <math.h> define M_PI on MSVC
      #include <math.h>
      
      double gamma_function(double x)
      {
      #ifdef WINDOWS
      	int i, k, m;
      	double ga, gr, z;
      	double r = 0;
      
      	static double g[] =
      	{
              1.0,
              0.5772156649015329,
             -0.6558780715202538,
             -0.420026350340952e-1,
              0.1665386113822915,
             -0.421977345555443e-1,
             -0.9621971527877e-2,
              0.7218943246663e-2,
             -0.11651675918591e-2,
             -0.2152416741149e-3,
              0.1280502823882e-3,
             -0.201348547807e-4,
             -0.12504934821e-5,
              0.1133027232e-5,
             -0.2056338417e-6,
              0.6116095e-8,
              0.50020075e-8,
             -0.11812746e-8,
              0.1043427e-9,
              0.77823e-11,
             -0.36968e-11,
              0.51e-12,
             -0.206e-13,
             -0.54e-14,
              0.14e-14
      	};
      
      	if(x > 171.0)
      		return 1e308; // This value is an overflow flag.
      	if(x == (int)x)
      	{
      		if(x > 0.0)
      		{
      			ga = 1.0; // use factorial
      			for (i = 2; i < x; i++)
      				ga *= i;
      		}
      		else
      			ga = 1e308;
      	}
      	else
      	{
      		if(fabs(x) > 1.0)
      		{
      			z = fabs(x);
      			m = (int)z;
      			r = 1.0;
      			for (k = 1; k <= m; k++)
      				r *= (z - k);
      			z -= m;
      		}
      		else
      			z = x;
      		gr = g[24];
      		for (k = 23; k >= 0; k--)
      			gr = gr * z + g[k];
      		ga = 1.0 / (gr*z);
      		if(fabs(x) > 1.0)
      		{
      			ga *= r;
      			if (x < 0.0)
      				ga = -M_PI / (x * ga * sin(M_PI * x));
      		}
      	}
      	return ga;
      #else
      	return tgamma(x);
      #endif
      }
      
      And here is some Java code that does the same thing for x > 0, using the Lanczos approximation of the log-gamma function:
      public class Gamma {
         static double logGamma(double x) {
            double tmp = (x - 0.5) * Math.log(x + 4.5) - (x + 4.5);
            double ser = 1.0 + 76.18009173    / (x + 0)   - 86.50532033    / (x + 1)
                             + 24.01409822    / (x + 2)   -  1.231739516   / (x + 3)
                             +  0.00120858003 / (x + 4)   -  0.00000536382 / (x + 5);
            return tmp + Math.log(ser * Math.sqrt(2 * Math.PI));
         }
      
         static double gamma(double x) { return Math.exp(logGamma(x)); }
      
         public static void main(String[] args) {
            double x = Double.parseDouble(args[0]);
            System.out.println("Gamma(" + x + ") = " + gamma(x));
            System.out.println("log Gamma(" + x + ") = " + logGamma(x));
         }
      
      }
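      With a log-gamma function in hand, the two density functions can be sketched as below. This assumes the Normal is parameterized by mean and variance and the inverse-gamma by shape (alpha) and scale (beta), as in the Wikipedia articles; check which parameterization your network uses. The class name Densities is hypothetical, and logGamma repeats the formula from the Gamma class above so that the block stands alone.

```java
class Densities {
    // Lanczos approximation of log Gamma(x), valid for x > 0
    // (same formula as the Gamma class above).
    static double logGamma(double x) {
        double tmp = (x - 0.5) * Math.log(x + 4.5) - (x + 4.5);
        double ser = 1.0 + 76.18009173    / (x + 0) - 86.50532033    / (x + 1)
                         + 24.01409822    / (x + 2) -  1.231739516   / (x + 3)
                         +  0.00120858003 / (x + 4) -  0.00000536382 / (x + 5);
        return tmp + Math.log(ser * Math.sqrt(2 * Math.PI));
    }

    // Normal density with the given mean and variance (not standard deviation).
    static double normalPdf(double x, double mean, double variance) {
        double d = x - mean;
        return Math.exp(-d * d / (2.0 * variance))
             / Math.sqrt(2.0 * Math.PI * variance);
    }

    // Inverse-gamma density:
    //   beta^alpha / Gamma(alpha) * x^(-alpha - 1) * exp(-beta / x)
    // computed in log space so that Gamma(alpha) cannot overflow.
    static double inverseGammaPdf(double x, double alpha, double beta) {
        if (x <= 0.0)
            return 0.0; // the distribution has no mass at or below zero
        double logPdf = alpha * Math.log(beta) - logGamma(alpha)
                      - (alpha + 1.0) * Math.log(x) - beta / x;
        return Math.exp(logPdf);
    }
}
```

      For example, inverseGammaPdf(1, 1, 1) should come out close to exp(-1), since Gamma(1) = 1.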


    • Download this Golf dataset. It reports the number of swings used by 604 professional golfers at 42 tournaments. Your task is to rank the golfers according to their demonstrated abilities (greatest ability first). Naive Bayes is not sufficient for this problem. What makes the problem interesting is that not all tournaments are equally difficult, and not all golfers participated in every tournament. Hence, simply counting the total number of swings does not tell you much about a golfer's ability. Here is a belief network, designed for this purpose by someone who knows more about golf than I do:

      It explains the number of swings by each golfer as being the sum of tournament difficulty and golfer error. The magic-number priors for the Normal distributions are derived from measurements of the data. I am not sure where the priors for the inverse Gamma distributions came from. Presumably, someone who understands the domain used his or her intuition to come up with those values.

      Since this network is rather large, it will take more iterations to get good results. I used 50000 burn-in iterations, followed by 50000 samples. (When debugging, you can get reasonable approximate rankings with far fewer iterations.) I collected all the samples for the 604 "Golfer error" nodes, then computed the median error for each golfer from these samples. (50000 samples of 604 double-precision values will occupy about 0.25 GB of RAM--pfft, my phone could easily handle that.)
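      The burn-in and collection loop can be sketched as below. It is written against two hypothetical callbacks rather than any particular network class: sweep resamples every unobserved node once, and read returns the current values of the nodes you want to track (here, the "Golfer error" nodes). The name ChainRunner and the [node][iteration] layout are assumptions chosen so that each node's samples sit in one array, ready for a median.

```java
import java.util.function.Supplier;

class ChainRunner {
    // Runs burnIn discarded sweeps, then sampleCount recorded sweeps.
    // Returns samples[node][iteration].
    static double[][] run(Runnable sweep, Supplier<double[]> read,
                          int burnIn, int sampleCount) {
        // Burn-in: let the chain settle, keep nothing.
        for (int i = 0; i < burnIn; i++)
            sweep.run();

        double[][] samples = null;
        for (int s = 0; s < sampleCount; s++) {
            sweep.run();
            double[] values = read.get();
            if (samples == null) // size the storage on the first recorded sweep
                samples = new double[values.length][sampleCount];
            for (int n = 0; n < values.length; n++)
                samples[n][s] = values[n];
        }
        return samples == null ? new double[0][0] : samples;
    }
}
```

      With 604 tracked nodes and 50000 samples, the returned array holds 604 x 50000 doubles at 8 bytes each, which is about 242 MB and matches the estimate above.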

      (In C++, you can use the "nth_element" function to compute a median in O(n) time. The Java class library does not provide an equivalent function by default, so here's one I made:
      class Main
      {
      	static void nth_element_helper2(double[] arr, int beg, int end)
      	{
      		for(int i = beg + 1; i < end; i++)
      		{
      			for(int j = i; j > beg; j--)
      			{
      				if(arr[j - 1] < arr[j])
      					break;
      				double t = arr[j];
      				arr[j] = arr[j - 1];
      				arr[j - 1] = t;
      			}
      		}
      	}
      
      	static void nth_element_helper(double[] arr, int beg, int end,
      	int index)
      	{
      		if(beg + 4 >= end)
      		{
      			nth_element_helper2(arr, beg, end);
      			return;
      		}
      		int initial_beg = beg;
      		int initial_end = end;
      
      		// Pick a pivot (using the median of 3 technique)
      		double pivA = arr[beg];
      		double pivB = arr[(beg + end) / 2];
      		double pivC = arr[end - 1];
      		double pivot;
      		if(pivA < pivB)
      		{
      			if(pivB < pivC)
      				pivot = pivB;
      			else if(pivA < pivC)
      				pivot = pivC;
      			else
      				pivot = pivA;
      		}
      		else
      		{
      			if(pivA < pivC)
      				pivot = pivA;
      			else if(pivB < pivC)
      				pivot = pivC;
      			else
      				pivot = pivB;
      		}
      
      		// Divide values about the pivot
      		while(true)
      		{
      			while(beg + 1 < end && arr[beg] < pivot)
      				beg++;
      			while(end > beg + 1 && arr[end - 1] > pivot)
      				end--;
      			if(beg + 1 >= end)
      				break;
      
      			// Swap values
      			double t = arr[beg];
      			arr[beg] = arr[end - 1];
      			arr[end - 1] = t;
      
      			beg++;
      			end--;
      		}
      		if(arr[beg] < pivot)
      			beg++;
      
      		// Recurse
      		if(beg == initial_beg || end == initial_end)
      			throw new RuntimeException("No progress. Bad pivot");
      		if(index < beg) // recurse on only one side
      			nth_element_helper(arr, initial_beg, beg, index);
      		else
      			nth_element_helper(arr, beg, initial_end, index);
      	}
      
      	static double nth_element(double[] arr, int index)
      	{
      		nth_element_helper(arr, 0, arr.length, index);
      		return arr[index];
      	}
      
      	public static void main(String[] args)
      	{
       		double[] arr = { 9, 7, 1, 5, 6, 4, 3, 2, 8, 0, 10 };
      		if(nth_element(arr, 5) == 5)
      			System.out.println("seems to work");
      		else
      			System.out.println("broken");
      	}
      }
      
      Basically, it works just like QuickSort, except you only recurse on the side of the pivot that contains the nth element, and you leave the other side unsorted. On average, this works out to be an O(n) algorithm, which means you can compute a median asymptotically as fast as you can compute a mean.)

      I sorted the golfers by their median error (smallest first) and printed each name and median error. Here are the first 10 lines in my results:
      VijaySingh	-3.91185
      TigerWoods	-3.78926
      ErnieEls	-3.40822
      PhilMickelson	-3.3914
      StewartCink	-3.06153
      JayHaas	-2.9455
      SergioGarcia	-2.89813
      ScottVerplank	-2.83843
      RetiefGoosen	-2.78884
      PadraigHarrington	-2.77575
      ...
      
      Some Wikipedia searches will confirm that these guys are pretty good golfers.
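      The ranking step itself might be sketched like this. It assumes the samples are stored as samples[golfer][iteration] and, for brevity, takes each median by sorting a copy (O(n log n)) instead of using nth_element. The class name GolferRanking and its methods are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

class GolferRanking {
    // Median by sorting a copy, so the caller's array is left untouched.
    // For an even count this returns the upper of the two middle values.
    static double median(double[] samples) {
        double[] sorted = samples.clone();
        Arrays.sort(sorted);
        return sorted[sorted.length / 2];
    }

    // Returns one "name<TAB>median" line per golfer, smallest median first.
    static List<String> rank(String[] names, double[][] samples) {
        double[] medians = new double[names.length];
        Integer[] order = new Integer[names.length];
        for (int i = 0; i < names.length; i++) {
            medians[i] = median(samples[i]);
            order[i] = i;
        }

        // Sort golfer indices by their median error.
        Arrays.sort(order, Comparator.comparingDouble(i -> medians[i]));

        List<String> lines = new ArrayList<>();
        for (int i : order)
            lines.add(names[i] + "\t" + medians[i]);
        return lines;
    }
}
```

      Writing the returned lines to golf.txt, one per line, gives output in the same shape as the listing above.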


    • Make two text files: simple.txt should contain some output verifying that the simple categorical test works. golf.txt should contain your ranked ordering of all the golfers with their median error values. Zip these up with your source code and submit in the usual manner.