Color palette extraction with K-means clustering | Machine Learning from Scratch (Part IV)

TL;DR Build a K-Means clustering model from scratch in Python. Use your model to find the dominant colors in mobile UI design screenshots.

Choosing a color palette for your next big mobile app (re)design can be a daunting task, especially when you don’t know what the heck you’re doing. How can you make it easier (asking for a friend)?

One approach is to head over to a place where the PROs share their work. Pages like Dribbble, uplabs and Behance have the goods.

After finding mockups you like, you might want to extract colors from them and use those. This might require opening specialized software, manually picking color with some tool(s) and other over-the-counter hacks. Let’s use Machine Learning to make your life easier.

Complete source code notebook on Google Colaboratory

Unsupervised Learning

Up until now we only looked at models that require training data in the form of features and labels. In other words, for each example, we need to have the correct answer, too.

Usually, such training data is hard to obtain and requires many hours of manual work done by humans (yes, we’re already serving “The Terminators”). Can we skip all that?

Yes, at least for some problems we can use example data without knowing the correct answer for it. One such problem is clustering.

What is clustering?

Given some vector of data points X, clustering methods allow you to put each point in a group. In other words, you can categorize a set of entities, based on their properties, automatically.

Yes, that is very useful in practice. Usually, you run the algorithm on a bunch of data points and specify how many groups you want. For example, your inbox contains two main groups of e-mail: spam and not-spam (were you waiting for something else?). You can let the clustering algorithm create two groups from your emails and use your beautiful brain to decide which is which.

More applications of clustering algorithms:

  • Customer segmentation - find groups of users that spend/behave the same way
  • Fraudulent transactions - find bank transactions that fall outside the clusters of typical activity and flag them as potentially fraudulent
  • Document analysis - group similar documents

The Data

source: various authors on https://www.uplabs.com/

This time, our data doesn’t come from some predefined or well-known dataset. Since unsupervised learning does not require labeled data, the Internet can be your oyster.

Here, we’ll use 3 mobile UI designs from various authors. Our model will run on each shot and try to extract the color palette for each.

What is K-Means Clustering?

K-Means Clustering is defined by Wikipedia as:

k-means clustering is a method of vector quantization, originally from signal processing, that is popular for cluster analysis in data mining. k-means clustering aims to partition n observations into k clusters in which each observation belongs to the cluster with the nearest mean, serving as a prototype of the cluster. This results in a partitioning of the data space into Voronoi cells.

Wikipedia also tells us that solving K-Means clustering exactly is difficult (in fact, NP-hard), but heuristics can find locally optimal solutions quickly.

But how does K-Means Clustering work?

Let’s say you have some vector X that contains n data points. Running our algorithm consists of the following steps:

  1. Pick k random points (called centroids) from X
  2. Assign every point to the closest centroid. Each newly formed bunch of points is called a cluster.
  3. For each cluster, find a new centroid by calculating the center (mean) of its points
  4. Repeat steps 2-3 until the centroids stop changing
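
The four steps above can be sketched on a toy one-dimensional dataset before we move on to colors. This is only an illustration: the function name kmeans_1d, the seed parameter and the sample data are made up for this example and are not part of the article's implementation.

```python
import random

def kmeans_1d(values, k, seed=0):
    """Toy 1-D k-means following the four steps above (illustrative only)."""
    rng = random.Random(seed)
    # Step 1: pick k random points as the initial centroids.
    centroids = rng.sample(values, k)
    while True:
        # Step 2: assign every point to its closest centroid.
        clusters = [[] for _ in range(k)]
        for v in values:
            nearest = min(range(k), key=lambda i: abs(v - centroids[i]))
            clusters[nearest].append(v)
        # Step 3: recompute each centroid as the mean of its cluster
        # (an empty cluster keeps its old centroid).
        new_centroids = [
            sum(c) / len(c) if c else centroids[i]
            for i, c in enumerate(clusters)
        ]
        # Step 4: stop once the centroids no longer change.
        if new_centroids == centroids:
            return sorted(new_centroids)
        centroids = new_centroids

data = [1.0, 2.0, 3.0, 10.0, 11.0, 12.0]
print(kmeans_1d(data, k=2))  # prints [2.0, 11.0]
```

With two well-separated groups like these, every starting pair ends up at the same two centroids. On messier data, different starting points can lead to different results.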

Let’s see how we can use it to extract color palettes from mobile UI screenshots.

Data Preprocessing

Given our data is stored in raw pixels (called images), we need a way to convert it to points that our clustering algorithm can use.

Let’s first define two classes that represent a point and cluster:

class Point:

  def __init__(self, coordinates):
    self.coordinates = coordinates

Our Point is just a holder for the coordinates in each dimension of our space.

class Cluster:

  def __init__(self, center, points):
    self.center = center
    self.points = points

The Cluster is defined by its center and all other points it contains.

Given a path to an image file, we can create the points as follows:

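from PIL import Image  # Pillow

def get_points(image_path):
  img = Image.open(image_path)
  # shrink in place to fit within 200x400, keeping the aspect ratio
  img.thumbnail((200, 400))
  # drop the alpha channel, if any
  img = img.convert("RGB")
  w, h = img.size

  points = []
  # getcolors returns (count, color) pairs; w * h is the most colors possible
  for count, color in img.getcolors(w * h):
    for _ in range(count):
      points.append(Point(color))

  return points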

A couple of things are happening here:

  • load the image into memory
  • resize it to a smaller image (mobile UIs use big elements on the screen, so we aren’t losing much color information)
  • drop the alpha (transparency) information

Note that we’re creating a Point for each pixel in our image.
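
To make the "one Point per pixel" idea concrete, here is a small sketch of the expansion step on its own. Pillow's getcolors returns a list of (count, color) pairs (or None when the image has more distinct colors than the limit you pass, which is why the function above passes w * h). The list below is a hand-written stand-in for that output, with made-up values, and plain tuples are used instead of Point objects.

```python
# Hand-written stand-in for the (count, color) pairs that Pillow's
# Image.getcolors() returns; the values are made up for illustration.
color_counts = [
    (3, (255, 255, 255)),  # three white pixels
    (1, (66, 82, 166)),    # one blue pixel
]

# Same expansion as in get_points: repeat each color once per pixel.
pixels = [color for count, color in color_counts for _ in range(count)]

print(len(pixels))  # prints 4
print(pixels[-1])   # prints (66, 82, 166)
```

Repeating each color by its count means frequent colors pull the cluster centers toward them, which is what we want when looking for dominant colors.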

Alright! You can now extract points from an image. But how can we calculate the distance between points in our clusters?

Distance function

Similar to the cost function in our supervised algorithm examples, we need a function that tells us how well we’re doing. The objective of our algorithm is to minimize the distance between the points in each cluster and that cluster’s centroid.

One of the simplest distance functions we can use is the Euclidean distance, defined by:

d(p, q) = \sqrt{\sum_{i=1}^{n} (q_i - p_i)^2}

where p and q are two points from our space and n is the number of dimensions.

Note that while Euclidean distance is simple to implement it might not be the best way to calculate the color difference.

Here is the Python implementation:

from math import sqrt

def euclidean(p, q):
  n_dim = len(p.coordinates)
  # sum the squared per-dimension differences, then take the root
  return sqrt(sum([
      (p.coordinates[i] - q.coordinates[i]) ** 2 for i in range(n_dim)
  ]))
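
As a quick worked example, here is the same formula applied to two RGB colors stored as plain tuples rather than Point objects. The colors are arbitrary picks for illustration, and math.dist (available from Python 3.8) is used only as a cross-check.

```python
from math import dist, sqrt

red = (255, 0, 0)
blue = (0, 0, 255)

# Euclidean distance written out term by term, as in the formula.
manual = sqrt(sum((q - p) ** 2 for p, q in zip(red, blue)))

# math.dist (Python 3.8+) computes the same quantity.
builtin = dist(red, blue)

print(round(manual, 2), round(builtin, 2))  # prints 360.62 360.62
```

The result is 255 times the square root of 2, since two of the three channels differ by the full 255.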

Implementing K-Means clustering

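Now that you have all the pieces of the puzzle, you can implement the K-Means clustering algorithm. The functions below take self because they are methods of a KMeans class that stores n_clusters and min_diff. Let’s start with the method that finds the center of a set of points: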

def calculate_center(self, points):
  n_dim = len(points[0].coordinates)
  vals = [0.0 for i in range(n_dim)]
  for p in points:
    for i in range(n_dim):
      vals[i] += p.coordinates[i]
  coords = [(v / len(points)) for v in vals]
  return Point(coords)

To find the center of a set of points, we add the values for each dimension and divide by the number of points.
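
Here is that averaging worked through on three nearly white pixels. This is a sketch using plain tuples; mean_point is a hypothetical helper name and the pixel values are made up.

```python
def mean_point(points):
    """Average each dimension across a list of equal-length tuples."""
    n = len(points)
    # zip(*points) groups the values dimension by dimension.
    return tuple(sum(dim) / n for dim in zip(*points))

pixels = [(250, 250, 250), (240, 240, 250), (255, 245, 250)]
center = mean_point(pixels)
print(center)
```

The green and blue channels average to exactly 245.0 and 250.0, while red comes out at 745 / 3, a fractional value. That is why the centers are later converted back to integers before being turned into hex colors.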

Now for finding the actual clusters:

def assign_points(self, clusters, points):
  plists = [[] for i in range(self.n_clusters)]

  for p in points:
    smallest_distance = float('inf')

    for i in range(self.n_clusters):
      distance = euclidean(p, clusters[i].center)
      if distance < smallest_distance:
        smallest_distance = distance
        idx = i

    plists[idx].append(p)

  return plists

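# needs `import random` at the top of the module
def fit(self, points):
  # start from n_clusters randomly chosen points as the initial centers
  clusters = [Cluster(center=p, points=[p]) for p in random.sample(points, self.n_clusters)]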

  while True:

    plists = self.assign_points(clusters, points)

    diff = 0

    for i in range(self.n_clusters):
      if not plists[i]:
        continue
      old = clusters[i]
      center = self.calculate_center(plists[i])
      new = Cluster(center, plists[i])
      clusters[i] = new
      diff = max(diff, euclidean(old.center, new.center))

    if diff < self.min_diff:
      break

  return clusters

The implementation follows the description of the algorithm given above. Note that we exit the training loop when the largest distance any center moved during an iteration is lower than min_diff, the threshold set by us.
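
The methods above take self and read self.n_clusters and self.min_diff, so they presumably live on a class like the KMeans used later in get_colors. Here is a minimal, self-contained sketch of such a class. It assumes plain tuples instead of Point objects, a default min_diff of 1.0 and an optional seed; the name KMeansSketch and those defaults are illustrative assumptions, not taken from the original notebook.

```python
import random
from math import dist

class KMeansSketch:
    """Hypothetical container for the steps above; defaults are assumed."""

    def __init__(self, n_clusters, min_diff=1.0, seed=None):
        self.n_clusters = n_clusters
        self.min_diff = min_diff        # stop when no center moves this far
        self.rng = random.Random(seed)  # seed is an addition, for repeatable runs

    def fit(self, points):
        centers = self.rng.sample(points, self.n_clusters)
        while True:
            # Assignment step: group points by nearest center.
            groups = [[] for _ in centers]
            for p in points:
                nearest = min(range(len(centers)),
                              key=lambda i: dist(p, centers[i]))
                groups[nearest].append(p)
            # Update step: move each center to the mean of its group.
            shift = 0.0
            for i, group in enumerate(groups):
                if not group:
                    continue  # keep the old center for an empty group
                new = tuple(sum(d) / len(group) for d in zip(*group))
                shift = max(shift, dist(centers[i], new))
                centers[i] = new
            if shift < self.min_diff:
                return centers, groups

# Five light pixels and three dark ones (made-up data).
points = [(250, 250, 250)] * 5 + [(10, 10, 10)] * 3
centers, groups = KMeansSketch(n_clusters=2, seed=0).fit(points)
print(sorted(centers))
```

On this toy input the two centers settle on the light and the dark color, with five and three points respectively.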

Evaluation

Now that you have an implementation of K-Means clustering you can use it on the UI screenshots. We need a little more glue code to make it easier to extract color palettes:

def rgb_to_hex(rgb):
  return '#%s' % ''.join(('%02x' % p for p in rgb))

def get_colors(filename, n_colors=3):
  points = get_points(filename)
  clusters = KMeans(n_clusters=n_colors).fit(points)
  clusters.sort(key=lambda c: len(c.points), reverse = True)
  rgbs = [map(int, c.center.coordinates) for c in clusters]
  return list(map(rgb_to_hex, rgbs))

The get_colors function takes a path to an image file and the number of colors you want to extract from the image. We sort the clusters obtained from our algorithm by the number of points in each (in descending order), so the most dominant color comes first. Finally, we convert the RGB colors to hexadecimal values.
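
To see the last conversion step in isolation, here is a small sketch that turns one cluster center with fractional channel values into a hex string. The center below is a made-up value, and rgb_to_hex is repeated so the snippet runs on its own.

```python
def rgb_to_hex(rgb):
    # Same formatting as above: two lowercase hex digits per channel.
    return '#%s' % ''.join('%02x' % p for p in rgb)

# Hypothetical cluster center with fractional channel values.
center = (249.7, 248.2, 250.9)

# int() truncates toward zero, so 249.7 becomes 249 (not 250).
channels = [int(c) for c in center]
hex_color = rgb_to_hex(channels)
print(hex_color)  # prints #f9f8fa
```

Because int() truncates, each channel can end up one step darker than the nearest integer. If that matters for your palette, using round() instead is a reasonable alternative.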

Let’s extract the palette for the first UI screenshot from our data:

colors = get_colors('ui1.png', n_colors=5)
", ".join(colors)

and here are the hexadecimal color values:

#f9f8fa, #4252a6, #fc6693, #bdbcd0, #374698

And for a visual representation of our results:

Running the clustering on the next two images, we obtain the following:

The last screenshot:

Well, that looks cool, right? Go on, try it on your own images!

Complete source code notebook on Google Colaboratory

Conclusion

Congratulations, you just implemented your first unsupervised algorithm from scratch! It also appears that it obtains good results when trying to extract a color palette from images.

Next up, we’re going to do sentiment analysis on short phrases and learn a bit more about how we can handle text data.