Messing with t-distributed stochastic neighbor embedding

Recently, Kaggle launched the Scripts project, a board of scripts published by competitors that peers can evaluate.

One of the scripts is a scatterplot obtained through t-distributed stochastic neighbor embedding (t-SNE), which summarises the information in a huge data set.

The scatterplot was so self-explanatory that I wanted to explore the method.

Theory

The theory is not that complicated once you grasp the idea behind it. The original paper is this one.

t-SNE is mainly used to visualise a huge data set as a scatterplot. It reduces the dimensionality of the data to 2 or 3 dimensions, which makes a 2D or 3D plot possible.

t-SNE converts the distances between data points in the original space into conditional probabilities \(p_{j|i}\).

\[p_{j|i} = \frac{\exp\left(-d(\boldsymbol{x}_i, \boldsymbol{x}_j) / (2 \sigma_i^2)\right)}{\sum_{k \neq i} \exp\left(-d(\boldsymbol{x}_i, \boldsymbol{x}_k) / (2 \sigma_i^2)\right)}, \quad p_{i|i} = 0.\]

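\(\sigma_i^2\) is the variance of the Gaussian centered on \(\boldsymbol{x}_i\). Its value is set through the perplexity parameter of the algorithm: for each point, \(\sigma_i\) is chosen so that the conditional distribution around \(\boldsymbol{x}_i\) has the requested perplexity.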
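To make the link between perplexity and \(\sigma_i\) concrete, here is a minimal sketch in base R. It assumes that perplexity is 2 raised to the Shannon entropy (in bits) of one row of conditional probabilities, and it searches \(\sigma_i\) by plain bisection. The helper names (perplexity, row_prob, find_sigma) and the toy distances are made up for illustration; this is not the code used inside Rtsne.

```r
# Perplexity of one row of conditional probabilities: 2^(entropy in bits).
perplexity <- function(p) {
  p <- p[p > 0]
  2^(-sum(p * log2(p)))
}

# Row i of p_{j|i} for a given sigma. d_i holds d(x_i, x_j) for all j != i.
# Subtracting min(d_i) leaves the result unchanged and avoids 0/0 for tiny sigma.
row_prob <- function(d_i, sigma) {
  w <- exp(-(d_i - min(d_i)) / (2 * sigma^2))
  w / sum(w)
}

# Bisection on sigma: the perplexity grows with sigma.
find_sigma <- function(d_i, target, lower = 1e-3, upper = 1e3, n_iter = 100) {
  for (iter in seq_len(n_iter)) {
    sigma <- (lower + upper) / 2
    if (perplexity(row_prob(d_i, sigma)) > target) {
      upper <- sigma   # distribution too flat: shrink sigma
    } else {
      lower <- sigma   # distribution too peaked: grow sigma
    }
  }
  sigma
}

d_i <- c(1, 4, 9, 16, 25)            # toy distances from one point to five others
sigma_i <- find_sigma(d_i, target = 3)
perplexity(row_prob(d_i, sigma_i))   # should be very close to the target
```

A low perplexity gives a small \(\sigma_i\), so only the nearest neighbours get a noticeable probability. A high perplexity spreads the probability over more neighbours.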

If two points are close, \(p_{j|i}\) will be high. If two points are far apart, \(p_{j|i}\) will be low.
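The behaviour described above can be checked on a toy example. The sketch below is base R only and makes two assumptions that are not in the text: \(d\) is taken to be the squared Euclidean distance, and the same hand-picked \(\sigma\) is used for every point (the real algorithm tunes one \(\sigma_i\) per point). The function name cond_prob is made up.

```r
# Toy data: three points on a line. Points 1 and 2 are close, point 3 is far.
x <- matrix(c(0, 1, 10), ncol = 1)

# Squared Euclidean distances, used here as d(x_i, x_j) (assumption).
d <- as.matrix(dist(x))^2

# Conditional probabilities p_{j|i} for one fixed sigma (assumption).
cond_prob <- function(d, sigma) {
  w <- exp(-d / (2 * sigma^2))
  diag(w) <- 0     # p_{i|i} = 0
  w / rowSums(w)   # row i holds p_{j|i} and sums to 1
}

p_cond <- cond_prob(d, sigma = 2)
round(p_cond, 3)
```

In the first row, almost all the probability goes to point 2, the close neighbour of point 1, and very little to point 3.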

The conditional probabilities are used to define the joint probabilities, where \(N\) is the number of data points: \[p_{ij} = \frac{p_{j|i} + p_{i|j}}{2N}.\]
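A short sketch of this symmetrisation in base R, on random toy points and with \(\sigma_i = 1\) for every point (both are assumptions made for illustration). It shows that the resulting matrix is symmetric and sums to 1 over all pairs.

```r
set.seed(1)
x <- matrix(rnorm(10), ncol = 2)   # 5 toy points in 2 dimensions

# Conditional probabilities with sigma_i = 1 for every point (assumption).
w <- exp(-as.matrix(dist(x))^2 / 2)
diag(w) <- 0
p_cond <- w / rowSums(w)           # row i holds p_{j|i}

# Joint probabilities p_ij = (p_{j|i} + p_{i|j}) / (2N).
n <- nrow(p_cond)
p_joint <- (p_cond + t(p_cond)) / (2 * n)

isSymmetric(p_joint)
sum(p_joint)
```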

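The distances in the embedded space, where \(\boldsymbol{y}_i\) is the image of \(\boldsymbol{x}_i\), are described in the same way, but with a Student t-distribution with one degree of freedom instead of a Gaussian: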

\[q_{ij} = \frac{(1 + \|\boldsymbol{y}_i - \boldsymbol{y}_j\|^2)^{-1}}{\sum_{k \neq l} (1 + \|\boldsymbol{y}_k - \boldsymbol{y}_l\|^2)^{-1}}.\]
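As an illustration, \(q_{ij}\) can be computed in a few lines of base R. The embedded points below are random toy values, not the output of any real embedding.

```r
set.seed(1)
y <- matrix(rnorm(10), ncol = 2)      # 5 toy points in the 2D embedded space

w <- 1 / (1 + as.matrix(dist(y))^2)   # (1 + ||y_i - y_j||^2)^-1
diag(w) <- 0                          # pairs with k = l are excluded from the sum
q <- w / sum(w)                       # normalise over all pairs k != l

round(q, 3)
```

Unlike \(p_{j|i}\), the normalisation runs over all pairs at once, so the whole matrix sums to 1 rather than each row.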

The idea now is that, for a good visualisation in the embedded space, \(q_{ij}\) and \(p_{ij}\) should be as close as possible.

The Kullback-Leibler divergence is the measure used to quantify the mismatch between \(q_{ij}\) and \(p_{ij}\).

\[KL(P \| Q) = \sum_{i \neq j} p_{ij} \log \frac{p_{ij}}{q_{ij}}.\]
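A minimal base R sketch of this divergence follows. The function name kl_divergence and the two toy matrices are made up for illustration. Terms with \(p_{ij} = 0\) are skipped, since they contribute nothing to the sum.

```r
kl_divergence <- function(p, q) {
  keep <- p > 0   # terms with p_ij = 0 contribute 0
  sum(p[keep] * log(p[keep] / q[keep]))
}

# Toy joint distribution over the pairs of 3 points (made-up values, sum = 1).
p <- matrix(c(0,   0.3, 0.2,
              0.3, 0,   0,
              0.2, 0,   0), nrow = 3, byrow = TRUE)

# A uniform distribution over the 6 off-diagonal pairs.
q <- matrix(1 / 6, nrow = 3, ncol = 3)
diag(q) <- 0

kl_divergence(p, p)   # identical distributions: no mismatch
kl_divergence(p, q)   # different distributions: positive mismatch
```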

Gradient descent is used to minimise this mismatch.
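To show what this minimisation looks like, here is a bare-bones sketch in base R. It uses the gradient \(\partial KL / \partial \boldsymbol{y}_i = 4 \sum_j (p_{ij} - q_{ij})(\boldsymbol{y}_i - \boldsymbol{y}_j)(1 + \|\boldsymbol{y}_i - \boldsymbol{y}_j\|^2)^{-1}\) with a fixed step size. The toy data, \(\sigma_i = 1\) for every point, the step size eta and the number of iterations are all hand-picked assumptions. Real implementations such as Rtsne add refinements (momentum, approximations for speed) that are left out here.

```r
set.seed(42)
# Toy data: two well-separated groups of 10 points in 5 dimensions.
x <- rbind(matrix(rnorm(50), ncol = 5), matrix(rnorm(50, mean = 6), ncol = 5))
n <- nrow(x)

# Joint probabilities p_ij with sigma_i = 1 for every point (assumption).
w <- exp(-as.matrix(dist(x))^2 / 2)
diag(w) <- 0
p_cond <- w / rowSums(w)
p <- (p_cond + t(p_cond)) / (2 * n)

kl <- function(p, q) {
  keep <- p > 0
  sum(p[keep] * log(p[keep] / q[keep]))
}

# Student-t weights and q_ij for an embedding y.
embed_prob <- function(y) {
  w <- 1 / (1 + as.matrix(dist(y))^2)
  diag(w) <- 0
  list(w = w, q = w / sum(w))
}

y <- matrix(rnorm(2 * n, sd = 1e-2), ncol = 2)   # small random starting positions
kl_start <- kl(p, embed_prob(y)$q)

eta <- 0.5                                       # hand-picked step size
for (iter in 1:1000) {
  e <- embed_prob(y)
  m <- (p - e$q) * e$w                           # (p_ij - q_ij)(1 + ||y_i - y_j||^2)^-1
  grad <- 4 * (rowSums(m) * y - m %*% y)         # sum_j m_ij (y_i - y_j), for every i
  y <- y - eta * grad
}
kl_end <- kl(p, embed_prob(y)$q)

c(start = kl_start, end = kl_end)
```

Comparing the two values shows how much of the mismatch the descent removed on this toy example.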

The function

The Rtsne package has one main function, Rtsne(). Points to note:

  • The function does not accept duplicate rows.
  • The t-SNE result is fairly robust to changes in perplexity. Typical values are between 5 and 50.
  • By default, an initial PCA is run to reduce the number of dimensions before running t-SNE.
  • No normalisation is performed. Consequently, a variable with large values will appear well separated.
  • The algorithm accepts only numeric variables.
    • Consequently, I personally like to divide my dummy variables by the number of modalities.
    • When I want a specific variable to appear well organised, I artificially increase its values.
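The preparation steps implied by the notes above can be sketched in base R as follows. The data frame, its column names and the helper rescale01 are made up for illustration. The final call to Rtsne() is only shown as a comment, because this toy data set is far too small for a meaningful embedding.

```r
# Toy data: one numeric variable and one categorical variable (made up).
df <- data.frame(
  size   = c(1.2, 3.4, 2.2, 1.2),
  colour = factor(c("red", "blue", "green", "red"))
)

# One 0/1 dummy column per modality of the factor.
dummies <- model.matrix(~ colour - 1, data = df)

# Divide the dummies by the number of modalities, as described above.
dummies <- dummies / nlevels(df$colour)

# Bring the numeric variable to [0, 1], since Rtsne() does not do it for us.
rescale01 <- function(v) (v - min(v)) / (max(v) - min(v))
x <- cbind(size = rescale01(df$size), dummies)

# Drop duplicated rows, which Rtsne() rejects.
x <- x[!duplicated(x), , drop = FALSE]
x

# On a real data set, the numeric matrix would then be passed on, for example:
# Rtsne(x, dims = 2, perplexity = 30)
```

Rows 1 and 4 of the toy data are identical, so only three rows remain after deduplication.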

Example

Libraries used:

library(data.table)
library(ggplot2)
library(Rtsne)

Diamonds data set

# Convert to data.table format:
diamonds.dt <- data.table(diamonds)
# Transform the ordinal variables into numeric ones:
diamonds.dt[, cut2 := as.numeric(cut)]
diamonds.dt[, clarity2 := as.numeric(clarity)]
diamonds.dt[, color2 := as.numeric(color)]
# Min-max normalisation of each variable (the original factors are kept for plotting):
num.cols <- c("carat", "cut2", "color2", "clarity2", "depth", "table", "price", "x", "y", "z")
diamonds.dt2 <- diamonds.dt[, c(lapply(.SD, function(v) (v - min(v)) / (max(v) - min(v))),
                                list(color = color, cut = cut, clarity = clarity)),
                            .SDcols = num.cols]
# Deduplication (Rtsne() does not accept duplicate rows):
diamonds.dt3 <- diamonds.dt2[, list(count = .N), by = c(num.cols, "cut", "color", "clarity")]
# Keep only the numeric columns for the embedding:
diamonds.dt4 <- diamonds.dt3[, num.cols, with = FALSE]
# Embedding the data set:
diamonds.2d <- Rtsne(as.matrix(diamonds.dt4), dims = 2, initial_dims = 50, perplexity = 30, max_iter = 600, verbose = TRUE)
# Save the result (the ./DATA directory must already exist):
save(diamonds.2d, file = "./DATA/diamonds_2d.rdata")
# Attach the 2D coordinates to the deduplicated table:
diamonds.dt3[, x.rtsne := diamonds.2d$Y[, 1]]
diamonds.dt3[, y.rtsne := diamonds.2d$Y[, 2]]
# The coordinates x.rtsne and y.rtsne live in diamonds.dt3, so the plots use that table.
# colorless
ggplot(data = diamonds.dt3, aes(x = x.rtsne, y = y.rtsne)) +
  geom_point(color = "black") +
  ggtitle("Raw plot")

# clarity (the original factor gives a readable legend)
ggplot(data = diamonds.dt3, aes(x = x.rtsne, y = y.rtsne, color = clarity)) +
  geom_point() +
  theme_classic() +
  ggtitle("Clarity")

# cut (the original factor gives a readable legend)
ggplot(data = diamonds.dt3, aes(x = x.rtsne, y = y.rtsne, color = cut)) +
  geom_point() +
  theme_classic() +
  ggtitle("Cut")

# price (normalised to [0, 1])
ggplot(data = diamonds.dt3, aes(x = x.rtsne, y = y.rtsne, color = price)) +
  geom_point() +
  theme_classic() +
  ggtitle("Price")