ML Basics: Linear Regression

This is part of the Machine Learning Basics series.

Based on a Twitter poll, I’ll be doing a series on the basics of Machine Learning over the next few months. I plan to publish a post on the first Thursday of every month covering an algorithm or conceptual aspect of Machine Learning.

Machine Learning and Statistics have a lot in common. Both are interested in summarizing and finding patterns in data. To understand most Machine Learning textbooks it is necessary to have a solid grounding in stats. Sadly, most of the folks I know who were educated in the US learned calculus instead of statistics. Because of this, I’m going to start at the ground level where Machine Learning and Statistics overlap and move forward from there. And if all my talk of math is freaking you out a bit, my goal is that this series will not require anything beyond high school Algebra 1.

Descriptive Statistics Library

Several of the algorithms I hope to show use the mean, standard deviation, and other basic descriptive statistics. Statisticians define some of these slightly differently than we learned in school. To simplify things I use the descriptive_statistics gem in Ruby. It extends Enumerable with the common descriptive statistics methods so they are readily available and easy to use. All of my Ruby examples will use this library. In Python I’ll use the scipy library.
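To make those definitions concrete, here is a plain-Ruby sketch of the calculations the gem performs for you (the gem itself adds methods like `mean` and `standard_deviation` directly onto Enumerable; these hand-rolled versions just show what’s happening under the hood, including the population-vs-sample split that trips people up):

```ruby
# Plain-Ruby versions of the basic descriptive statistics.

def mean(values)
  values.sum(0.0) / values.size
end

# The "population" standard deviation divides by n; the "sample"
# version divides by n - 1 -- one of the places where statisticians'
# definitions differ from what many of us learned in school.
def population_standard_deviation(values)
  m = mean(values)
  Math.sqrt(values.sum(0.0) { |v| (v - m)**2 } / values.size)
end

def sample_standard_deviation(values)
  m = mean(values)
  Math.sqrt(values.sum(0.0) { |v| (v - m)**2 } / (values.size - 1))
end

data = [2, 4, 4, 4, 5, 5, 7, 9]
puts mean(data)                          # 5.0
puts population_standard_deviation(data) # 2.0
```

With the gem loaded you would write `data.mean` and `data.standard_deviation` instead of calling these helpers.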

Linear Regression

One of the basic things that Machine Learning tries to do is answer the question “What is the relationship between these pieces of data?” or, perhaps more exactly, “If I know part of the data, can I figure out the rest?”. And the most basic technique for doing this is linear regression.

Regression is finding a line (not necessarily a straight line) that matches up to, or fits, some data relatively well. Once you have that line and an equation for it, you can use algebra to predict one value based on the other(s). This technique is frequently used in math, the sciences, economics, and psychology, and it shows up in newspapers and academic journals alike. I used one of the BigQuery public datasets to retrieve the number of babies born each year between 1969 and 2005 in the US. Here is the data in a basic plot.

Births (in millions) by year

The chart above shows a linear regression. That means that I’ve assumed the relationship between x and y can be modeled with a straight line. Lines are represented with equations of the form y = Mx + B, where M (the slope) and B (the y-intercept) are numbers.
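Once you have M and B, predicting is just plugging in an x. A tiny sketch (the slope and intercept here are made-up illustrative numbers, not the actual fit for the births data):

```ruby
# Predict y from x using a line's slope and intercept: y = Mx + B.
def predict(slope, intercept, x)
  slope * x + intercept
end

m = 0.5 # hypothetical slope
b = 3.0 # hypothetical intercept
puts predict(m, b, 10) # 8.0
```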

Least Squares

Least squares is a straightforward way to calculate the best-fitting line: it picks the M and B that minimize the sum of the squared vertical distances between each data point and the line.
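For a single predictor, that minimization has a well-known closed-form answer, which a few lines of plain Ruby can compute:

```ruby
# Ordinary least squares for one predictor. The best-fitting line
# y = Mx + B is given by:
#   M = sum((x - mean_x) * (y - mean_y)) / sum((x - mean_x)^2)
#   B = mean_y - M * mean_x
def least_squares(xs, ys)
  mean_x = xs.sum(0.0) / xs.size
  mean_y = ys.sum(0.0) / ys.size
  numerator   = xs.zip(ys).sum(0.0) { |x, y| (x - mean_x) * (y - mean_y) }
  denominator = xs.sum(0.0) { |x| (x - mean_x)**2 }
  slope = numerator / denominator
  intercept = mean_y - slope * mean_x
  [slope, intercept]
end

xs = [1, 2, 3, 4, 5]
ys = [2.1, 3.9, 6.2, 7.8, 10.1] # roughly y = 2x
slope, intercept = least_squares(xs, ys)
puts slope      # close to 2
puts intercept  # close to 0
```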

Multiple Linear Regression

Linear regression isn’t limited to two variables. While it is harder to visualize, the exact same techniques work for three or more variables: instead of fitting a line through points on a plane, you fit a plane (or hyperplane) through points in a higher-dimensional space, predicting one value from several inputs.
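As a sketch of how that generalizes, here is multiple linear regression via the normal equations, using only Ruby’s standard-library matrix module. (This is one common way to solve it, fine for a handful of predictors; production libraries use more numerically stable methods.)

```ruby
require "matrix" # Ruby standard library

# Multiple linear regression via the normal equations:
#   coefficients = (X^T X)^-1 X^T y
# where X gets a leading column of 1s so the first coefficient
# is the intercept.
def multiple_regression(rows, ys)
  x = Matrix[*rows.map { |r| [1.0, *r] }]
  y = Vector[*ys]
  ((x.t * x).inverse * x.t * y).to_a
end

# Two predictors; data generated from y = 1 + 2*x1 + 3*x2.
rows = [[1, 1], [2, 1], [1, 2], [3, 2], [2, 3]]
ys   = rows.map { |x1, x2| 1 + 2 * x1 + 3 * x2 }
p multiple_regression(rows, ys) # approximately [1.0, 2.0, 3.0]
```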