Bringing SVGs to Life with K-Means Clustering: A Deep Dive into Intelligent Graphic Design

Tags: Python, Pyodide, Computer Graphics, SVG, Thought Experiments

Introduction

During my undergraduate studies, I took a course on computer graphics that introduced me to concepts like aliasing, anti-aliasing, interpolation, and extrapolation. These topics fascinated me and sparked a deep interest in the field of graphics.
This interest led me to explore the world of visualizations professionally. As I delved into various graphic formats, I developed a particular love for SVGs. SVGs offer unique advantages, including resolution independence and ease of manipulation with CSS and JavaScript.
Logo I got from DALLE in WEBP format
One day, while trying to modify an image online, I converted it to SVG format. However, the result was disappointing: no colors, and all parts were broken into small, ungrouped paths. Frustrated with the lack of software to efficiently color and group these paths, I recalled my work with bounding boxes in Geometry. Initially, I thought about applying the same concept, but then I realized there was a more sophisticated approach: using KMeans clustering to intelligently group and color the SVG paths.
This approach would not only make the SVGs more visually appealing but also more organized and easier to manipulate. Thus began my journey into the fascinating intersection of machine learning and graphic design, combining my love for graphics with the power of artificial intelligence. And the results were fascinating.
 
Broadly, here are the steps to accomplish that:
  1. Calculate Path Centers:
      • Determine a representative point for each SVG path using methods like bounding box center, centroid, or visual center.
  2. Apply KMeans Clustering:
      • Use KMeans clustering to group paths based on their calculated centers.
      • This involves initializing K centroids, assigning each path center to the nearest centroid, and updating the centroid positions iteratively until they stabilize.
  3. Determine Optimal Number of Clusters:
      • Use metrics like silhouette score or the elbow method to find the optimal number of clusters (K).
      • This ensures that the grouping is efficient and meaningful.
  4. Assign Colors to Clusters:
      • Assign distinct colors to each cluster using a color map from libraries like matplotlib.
      • This visually differentiates the groups in the SVG.
  5. Update the SVG:
      • Apply the calculated transformations and colors to the SVG paths.
      • This involves modifying the 'transform' attributes and updating the color properties of each path based on its cluster assignment.
 
The rest of this post covers what I did, how I did it, and why I did it that way. 🤷🏻‍♂️

Understanding SVGs and KMeans Clustering

What are SVGs?

SVG stands for Scalable Vector Graphics. Unlike raster images (like JPEGs or PNGs), which are made up of a fixed grid of pixels, SVGs describe shapes mathematically as paths, curves, and coordinates. This means they can be scaled to any size without losing quality, a crucial feature in our multi-device world. The format is also very widely used on the web.
Here's a simple example of an SVG path:
<path d="M10 10 H 90 V 90 H 10 L 10 10"/>
This code defines a square path starting at point (10,10), drawing a horizontal line to (90,10), a vertical line to (90,90), a horizontal line back to (10,90), and finally a line back to the starting point (10,10).
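To make the walk through the path concrete, here is a minimal sketch that traces the points such a path visits. It is not a full SVG path parser: the helper `trace_simple_path` is a hypothetical name, and it assumes only absolute M, L, H, and V commands with explicit command letters.

```python
import re

def trace_simple_path(d):
    """Return the points visited by a path that uses only absolute M, L, H, V."""
    tokens = re.findall(r"[MLHV]|-?\d+(?:\.\d+)?", d)
    points = []
    x = y = 0.0
    i = 0
    while i < len(tokens):
        cmd = tokens[i]
        if cmd in ("M", "L"):      # move/line to an absolute (x, y)
            x, y = float(tokens[i + 1]), float(tokens[i + 2])
            i += 3
        elif cmd == "H":           # horizontal line: only x changes
            x = float(tokens[i + 1])
            i += 2
        elif cmd == "V":           # vertical line: only y changes
            y = float(tokens[i + 1])
            i += 2
        else:
            raise ValueError(f"unsupported token: {cmd}")
        points.append((x, y))
    return points

print(trace_simple_path("M10 10 H 90 V 90 H 10 L 10 10"))
# [(10.0, 10.0), (90.0, 10.0), (90.0, 90.0), (10.0, 90.0), (10.0, 10.0)]
```

Real-world paths also use relative commands, curves, and arcs, which is why the rest of this post leans on a library for parsing.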

Why Use SVGs?

  1. Resolution Independence: SVGs look crisp at any size, from tiny icons to massive billboards.
  2. Small File Size: SVGs are often smaller in file size compared to high-quality raster images.
  3. Accessibility: SVGs can include metadata and are readable by screen readers, improving web accessibility.
  4. Interactivity: SVGs can be manipulated with CSS and JavaScript, allowing for dynamic and interactive graphics.

The Challenge: Unorganized SVG Paths

 
When converting a complex raster image to SVG, particularly using automated tools, we often end up with hundreds or even thousands of individual paths. These paths are typically unorganized and lack any meaningful grouping or coloring. This presents several challenges:
  1. Difficult to Edit: Manipulating individual paths is time-consuming and impractical.
  2. Lack of Visual Coherence: Related parts of the image may have different colors or styles.
  3. Performance Issues: Rendering and manipulating numerous ungrouped paths can be computationally expensive.

KMeans Clustering: Bringing Order to Chaos

KMeans is an unsupervised machine learning algorithm that groups similar data points into clusters. In our case, we're using it to group SVG paths based on their spatial proximity.

How KMeans Works

  1. Initialization: The algorithm starts by choosing K initial centroids in the data space. The classic algorithm picks them at random; scikit-learn defaults to the more careful k-means++ seeding.
  2. Assignment: Each data point is assigned to the nearest centroid.
  3. Update: The centroids are moved to the average position of all points assigned to them.
  4. Repeat: Steps 2 and 3 are repeated until the centroids no longer move significantly.
In the context of our SVG paths, each "data point" is the center of a path, and we're grouping paths that are close to each other.
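The four steps can be sketched in plain Python. This is only an illustration of the algorithm on 2D points, not the scikit-learn implementation used later in this post; the function names `nearest` and `kmeans` are hypothetical, and the seeding here is the simple "pick K input points at random" variant.

```python
import random

def nearest(point, centroids):
    """Index of the centroid closest to point (squared Euclidean distance)."""
    return min(range(len(centroids)),
               key=lambda c: (point[0] - centroids[c][0]) ** 2
                           + (point[1] - centroids[c][1]) ** 2)

def kmeans(points, k, iterations=100, seed=0):
    """Plain k-means on 2D points; returns (labels, centroids)."""
    rng = random.Random(seed)
    centroids = rng.sample(points, k)   # initialization: k distinct input points
    labels = [0] * len(points)
    for _ in range(iterations):
        # Assignment: each point goes to its nearest centroid
        labels = [nearest(p, centroids) for p in points]
        # Update: move each centroid to the mean of its members
        new_centroids = []
        for c in range(k):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:
                new_centroids.append((sum(p[0] for p in members) / len(members),
                                      sum(p[1] for p in members) / len(members)))
            else:
                new_centroids.append(centroids[c])  # keep an empty cluster in place
        if new_centroids == centroids:  # converged: nothing moved
            break
        centroids = new_centroids
    return labels, centroids

# Two small groups of "path centers", one near the origin and one near (10, 10)
centers = [(0, 0), (1, 0), (0, 1), (10, 10), (11, 10), (10, 11)]
labels, centroids = kmeans(centers, k=2)
print(labels)
```

The first three points end up sharing one label and the last three the other; which group gets label 0 depends on the random seeding.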

Why KMeans for SVG Path Grouping?

  1. Spatial Grouping: KMeans naturally groups paths that are close together, which often correspond to related parts of the image.
  2. Customizable Clustering: By adjusting the number of clusters (K), we can control the level of detail in our grouping.
  3. Efficiency: KMeans is relatively fast and can handle large numbers of paths effectively.

Implementation

Step 1: Calculate Path Centers

The first step is to calculate a representative point for each path. We have several options here:
  1. Bounding Box Center: Takes the midpoint between the minimum and maximum x and y coordinates of the path.
  2. Centroid: Uses the shoelace formula to calculate the geometric center of a closed path.
  3. Visual Center: Averages the points that define the path (here, the segment endpoints).
import numpy as np

def calculate_bounding_box_center(path):
    # svgpathtools returns the bounding box as (xmin, xmax, ymin, ymax)
    xmin, xmax, ymin, ymax = path.bbox()
    return ((xmin + xmax) / 2, (ymin + ymax) / 2)

def calculate_centroid(path):
    # Treat the segment start points as polygon vertices (curves are approximated)
    x = np.array([seg.start.real for seg in path])
    y = np.array([seg.start.imag for seg in path])
    x_next, y_next = np.roll(x, -1), np.roll(y, -1)
    cross = x * y_next - x_next * y      # per-edge shoelace terms
    area = 0.5 * np.sum(cross)           # signed area
    if np.isclose(area, 0):
        # Degenerate (open or zero-area) path: fall back to the vertex average
        return (np.mean(x), np.mean(y))
    cx = np.sum((x + x_next) * cross) / (6.0 * area)
    cy = np.sum((y + y_next) * cross) / (6.0 * area)
    return (cx, cy)

def calculate_visual_center(path):
    # Average of the segment start points
    x = [seg.start.real for seg in path]
    y = [seg.start.imag for seg in path]
    return (np.mean(x), np.mean(y))
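To see how the three definitions differ, here is a small worked example on an L-shaped polygon. It is a sketch in plain Python on a list of (x, y) vertices rather than on svgpathtools paths, and the function names are illustrative only.

```python
def bounding_box_center(pts):
    xs, ys = zip(*pts)
    return ((min(xs) + max(xs)) / 2, (min(ys) + max(ys)) / 2)

def vertex_mean(pts):
    xs, ys = zip(*pts)
    return (sum(xs) / len(pts), sum(ys) / len(pts))

def polygon_centroid(pts):
    """Area-weighted centroid of a simple closed polygon (shoelace formula)."""
    doubled_area = cx = cy = 0.0
    # Pair every vertex with the next one, wrapping around to close the polygon
    for (x0, y0), (x1, y1) in zip(pts, pts[1:] + pts[:1]):
        cross = x0 * y1 - x1 * y0
        doubled_area += cross
        cx += (x0 + x1) * cross
        cy += (y0 + y1) * cross
    if doubled_area == 0:
        return vertex_mean(pts)  # degenerate polygon
    # Dividing by 6 * area is the same as dividing by 3 * doubled_area
    return (cx / (3 * doubled_area), cy / (3 * doubled_area))

# An "L" made of a 4x1 horizontal bar and a 1x3 vertical bar
l_shape = [(0, 0), (4, 0), (4, 1), (1, 1), (1, 4), (0, 4)]
print(bounding_box_center(l_shape))  # (2.0, 2.0)
print(vertex_mean(l_shape))
print(polygon_centroid(l_shape))
```

The bounding box center lands at (2, 2), the vertex average at roughly (1.67, 1.67), and the area-weighted centroid at roughly (1.36, 1.36). All three fall in the empty corner of the L, outside the shape itself, which is worth keeping in mind for concave paths.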

Step 2: Apply KMeans Clustering

Once we have our path centers, we can apply KMeans clustering:
from sklearn.cluster import KMeans

def cluster_paths(centers, n_clusters):
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    labels = kmeans.fit_predict(centers)
    return labels

Step 3: Determine Optimal Number of Clusters

To find the optimal number of clusters, we can use metrics like the silhouette score or the elbow method:
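from sklearn.metrics import silhouette_score

def find_optimal_clusters(centers, max_clusters):
    # The silhouette score is only defined for 2 to n_samples - 1 clusters
    upper = min(max_clusters, len(centers) - 1)
    ks = list(range(2, upper + 1))
    if not ks:
        raise ValueError("need at least 3 path centers to compare cluster counts")
    silhouette_scores = []
    for k in ks:
        labels = cluster_paths(centers, k)
        silhouette_scores.append(silhouette_score(centers, labels))
    # Return the k with the highest silhouette score
    return ks[silhouette_scores.index(max(silhouette_scores))]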
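The function above covers the silhouette score; the elbow method instead looks at inertia, the within-cluster sum of squared distances, and picks the K after which it stops dropping sharply. scikit-learn exposes this as `kmeans.inertia_`. The following plain-Python sketch (with a hypothetical helper named `inertia`) shows what that number measures:

```python
def inertia(points, labels):
    """Within-cluster sum of squared distances to each cluster's mean."""
    clusters = {}
    for p, lab in zip(points, labels):
        clusters.setdefault(lab, []).append(p)
    total = 0.0
    for members in clusters.values():
        mx = sum(p[0] for p in members) / len(members)
        my = sum(p[1] for p in members) / len(members)
        total += sum((p[0] - mx) ** 2 + (p[1] - my) ** 2 for p in members)
    return total

points = [(0, 0), (2, 0), (10, 0), (12, 0)]
print(inertia(points, [0, 0, 0, 0]))  # 104.0 (one cluster)
print(inertia(points, [0, 0, 1, 1]))  # 4.0   (two clusters)
print(inertia(points, [0, 0, 1, 2]))  # 2.0   (three clusters)
```

Going from one cluster to two removes almost all of the inertia, while a third cluster barely helps. That bend in the curve is the "elbow", here at K = 2.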

Step 4: Assign Colors to Clusters

With our paths grouped, we can now assign colors to each cluster. We'll use a color map from matplotlib to ensure visually distinct colors:
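import matplotlib.colors as mcolors
import matplotlib.pyplot as plt

def assign_colors(labels, n_clusters):
    cmap = plt.get_cmap('tab20')
    # tab20 is a listed colormap with cmap.N (20) entries; index it by integer
    # and wrap around if there are more clusters than colors
    colors = [mcolors.rgb2hex(cmap(i % cmap.N)) for i in range(n_clusters)]
    return [colors[label] for label in labels]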
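The tab20 color map only has 20 entries, so colors repeat once there are more clusters than that. As an alternative (not what this project uses), a palette of any size can be generated by spacing hues evenly around the color wheel with the standard library's `colorsys` module. The helper name `hsv_palette` and the default saturation and brightness values are assumptions chosen for illustration.

```python
import colorsys

def hsv_palette(n_colors, saturation=0.65, value=0.9):
    """Return n_colors hex strings with evenly spaced hues."""
    palette = []
    for i in range(n_colors):
        # Hue runs over [0, 1); saturation and value stay fixed
        r, g, b = colorsys.hsv_to_rgb(i / n_colors, saturation, value)
        palette.append("#{:02x}{:02x}{:02x}".format(
            round(r * 255), round(g * 255), round(b * 255)))
    return palette

print(hsv_palette(5))
```

Neighboring hues get harder to tell apart as the count grows, so this trades guaranteed uniqueness for weaker visual contrast than a hand-tuned palette like tab20.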

Step 5: Update the SVG

Finally, we update the SVG by writing each path's cluster color back onto it (here `svg` is the parsed BeautifulSoup document):
def update_svg(svg, labels, colors):
    paths = svg.find_all('path')
    for path, label, color in zip(paths, labels, colors):
        path['fill'] = color
        path['stroke'] = 'black'
        path['stroke-width'] = '0.5'
    return svg
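The function above recolors paths in place but leaves them ungrouped. One way to also group them is to move each path into a `<g>` element per cluster. The sketch below uses the standard library's `xml.etree.ElementTree` rather than BeautifulSoup; the function name `group_paths_by_cluster` and the `palette` argument (one color per cluster label) are assumptions for this example. It also simplifies things by moving paths directly under the root, so any transform on a path's former parent is lost, and a path that carries its own `fill` attribute will still override the group color.

```python
import xml.etree.ElementTree as ET

SVG_NS = "http://www.w3.org/2000/svg"
ET.register_namespace("", SVG_NS)  # write plain <svg>, <g>, <path> tags

def group_paths_by_cluster(svg_text, labels, palette):
    """Move each <path> into a <g> for its cluster; palette is indexed by label."""
    root = ET.fromstring(svg_text)
    # ElementTree has no parent pointers, so build a child -> parent map first
    parent_of = {child: parent for parent in root.iter() for child in parent}
    paths = list(root.iter(f"{{{SVG_NS}}}path"))  # document order
    groups = {}
    for path, label in zip(paths, labels):
        if label not in groups:
            groups[label] = ET.SubElement(
                root, f"{{{SVG_NS}}}g",
                {"id": f"cluster-{label}", "fill": palette[label]})
        parent_of[path].remove(path)
        groups[label].append(path)
    return ET.tostring(root, encoding="unicode")

svg = ('<svg xmlns="http://www.w3.org/2000/svg">'
       '<path d="M0 0 L1 1"/><path d="M5 5 L6 6"/><path d="M0 1 L1 0"/></svg>')
print(group_paths_by_cluster(svg, [0, 1, 0], ["#ff0000", "#0000ff"]))
```

With groups in place, a whole cluster can be restyled or selected in an editor by touching a single element.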

Results and Demo

The Process: From Unorganized Paths to Beautifully Grouped SVGs

The pipeline follows the five steps described in the Implementation section above: calculate path centers, cluster them with KMeans, pick the number of clusters, assign colors, and update the SVG.

The main choice to make is the center used in Step 1. The bounding box center is simple and works well for most paths. The centroid is mathematically accurate for closed paths but can fall outside the shape when the shape is concave (an L or a ring, for example). The visual center often gives a good approximation of where the "center of mass" of the path appears to be.

Bringing It All Together with Pyodide

To make this tool accessible and easy to use, I implemented it as a web application using Pyodide. Pyodide allows us to run Python code directly in the browser, eliminating the need for server-side processing.
Here's a simplified version of the main processing function:
from bs4 import BeautifulSoup
from svgpathtools import parse_path

def process_svg(svg_content):
    # Parse SVG (the 'xml' parser requires lxml)
    soup = BeautifulSoup(svg_content, 'xml')
    paths = soup.find_all('path')

    # Calculate centers. Assumption: each path's 'd' attribute is parsed with
    # svgpathtools so the center helpers from Step 1 can be reused.
    centers = [calculate_bounding_box_center(parse_path(p['d'])) for p in paths]

    # Find optimal number of clusters
    optimal_k = find_optimal_clusters(centers, max_clusters=20)

    # Cluster paths
    labels = cluster_paths(centers, optimal_k)

    # Assign colors
    colors = assign_colors(labels, optimal_k)

    # Update SVG
    updated_svg = update_svg(soup, labels, colors)

    return updated_svg.prettify()

The Results: Before and After

The results of this process can be quite striking. Let's look at a before and after example:
Before Image: Complex SVG with many ungrouped paths
After Image: Same SVG with paths grouped and colored
In the "before" image, we see a complex SVG with numerous paths, all in the same color. It's difficult to discern the structure of the image or edit specific parts.
In the "after" image, we see the same SVG, but now the paths are grouped and colored based on their spatial relationships. This makes the structure of the image immediately apparent and makes editing much more manageable.

Conclusion

The application of KMeans clustering to SVG processing has proven to be an interesting and practical approach to organizing and coloring complex vector graphics. This project demonstrates how we can take an unstructured SVG and transform it into a more visually coherent and manageable graphic using machine learning techniques.
While this approach has its limitations and may not be suitable for all use cases, it provides an interesting example of how machine learning concepts can be applied to graphic design tasks. The combination of KMeans clustering with SVG processing opens up new possibilities for handling complex vector graphics.
For those interested in the intersection of data science and design, this project offers a practical example of how these fields can complement each other. Whether you're a designer looking for new ways to handle complex SVGs, a developer interested in optimization techniques, or simply curious about novel applications of machine learning, this project might provide some useful insights.
I enjoyed working on this project and exploring the potential of applying clustering techniques to SVG processing. It's been a fun learning experience, and I hope that sharing it might inspire others to experiment with similar ideas in their own work. As we continue to explore the possibilities of applying machine learning to graphic design, who knows what amazing tools and techniques we'll discover? The canvas is blank, and the possibilities are endless. Happy clustering!
Python snippet
import argparse
import time
import pandas as pd
import matplotlib.colors
import xml.etree.ElementTree as ET
import random
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score, davies_bouldin_score
import matplotlib.pyplot as plt
from svgpathtools import svg2paths, wsvg, Line, CubicBezier, QuadraticBezier


class SVGProcessor:
    """Class to process SVG files and calculate centers of paths."""
    def __init__(self, config):
        self.config = config

    def calculate_center(self, path):
        """Calculate the center of a path based on the specified method."""
        center_type = self.config['center_type']
        if center_type == 'bounding_box':
            return self.calculate_bounding_box_center(path)
        elif center_type == 'centroid':
            return self.calculate_polygon_centroid(path)
        elif center_type == 'visual':
            return self.calculate_visual_center(path)
        else:
            return self.calculate_path_center(path)

    def calculate_bounding_box_center(self, path):
        """Calculate the center of the bounding box of a path."""
        all_points = [
            seg.start for seg in path] + [
            seg.end for seg in path if isinstance(
                seg, (Line, CubicBezier, QuadraticBezier))]
        x_coords = [p.real for p in all_points]
        y_coords = [p.imag for p in all_points]
        min_x, max_x = min(x_coords), max(x_coords)
        min_y, max_y = min(y_coords), max(y_coords)
        center_x = (min_x + max_x) / 2
        center_y = (min_y + max_y) / 2
        return center_x, center_y

    def calculate_polygon_centroid(self, path):
        """Calculate the centroid of a polygon using the shoelace formula."""
        points = [seg.start for seg in path]
        x = np.array([p.real for p in points])
        y = np.array([p.imag for p in points])
        x_next, y_next = np.roll(x, -1), np.roll(y, -1)
        # Per-edge shoelace terms and the signed area
        cross = x * y_next - x_next * y
        area = 0.5 * np.sum(cross)
        if np.isclose(area, 0):
            # Degenerate path: fall back to the vertex average
            return np.mean(x), np.mean(y)
        cx = np.sum((x + x_next) * cross) / (6.0 * area)
        cy = np.sum((y + y_next) * cross) / (6.0 * area)
        return cx, cy

    def calculate_visual_center(self, path):
        """Calculate the visual center of a path based on the average of all points."""
        all_points = [
            seg.start for seg in path] + [
            seg.end for seg in path if isinstance(
                seg, (Line, CubicBezier, QuadraticBezier))]
        x_coords = [p.real for p in all_points]
        y_coords = [p.imag for p in all_points]
        center_x = np.mean(x_coords)
        center_y = np.mean(y_coords)
        return center_x, center_y

    def calculate_path_center(self, path):
        """Calculate the center of a path based on the average of all points."""
        all_points = [
            seg.start for seg in path] + [
            seg.end for seg in path if isinstance(
                seg, (Line, CubicBezier, QuadraticBezier))]
        x_coords = [p.real for p in all_points]
        y_coords = [p.imag for p in all_points]
        center_x = np.mean(x_coords)
        center_y = np.mean(y_coords)
        return center_x, center_y

    def process_svg(self):
        """Process the input SVG file and calculate centers of paths."""
        input_filepath = self.config['input_filepath']
        paths, attributes = svg2paths(input_filepath)
        centers_data = []

        for i, path in enumerate(paths):
            center = self.calculate_center(path)
            centers_data.append({'Index': i, 'x': center[0], 'y': center[1]})

        # Save to DataFrame and then to CSV
        centers_file = f"{self.config['center_type']}_centers.csv"
        df_centers = pd.DataFrame(centers_data)
        # df_centers.to_csv(centers_file, index=False)
        # print(f"Centers saved to {centers_file}.")
        return df_centers

    def evaluate_clusters(self, X):
        max_centers = self.config['max_centers']
        inertia = []
        silhouette = []
        db_index = []
        # silhouette_score needs at most n_samples - 1 clusters
        k_values = range(10, min(max_centers, X.shape[0] - 1) + 1)

        for k in k_values:
            kmeans = KMeans(n_clusters=k, random_state=10).fit(X)
            labels = kmeans.labels_
            inertia.append(kmeans.inertia_)
            silhouette.append(silhouette_score(X, labels))
            db_index.append(davies_bouldin_score(X, labels))

        # Plotting the evaluation metrics
        plt.figure(figsize=(15, 5))
        plt.subplot(131)
        plt.plot(k_values, inertia, 'bo-')
        plt.title('Elbow Method (Inertia)')
        plt.xlabel('Number of clusters')
        plt.ylabel('Inertia')

        plt.subplot(132)
        plt.plot(k_values, silhouette, 'go-')
        plt.title('Silhouette Score')
        plt.xlabel('Number of clusters')
        plt.ylabel('Silhouette Score')

        plt.subplot(133)
        plt.plot(k_values, db_index, 'ro-')
        plt.title('Davies-Bouldin Index')
        plt.xlabel('Number of clusters')
        plt.ylabel('Davies-Bouldin Index')
        plt.gca().invert_yaxis()  # Lower is better for DB index

        plt.tight_layout()
        plt.savefig(f'evaluation_metrics_{time.time()}.png')
        return k_values, inertia, silhouette, db_index

    def perform_k_means(self, centers_df, optimal_k):
        X = centers_df[['x', 'y']].values
        kmeans = KMeans(n_clusters=optimal_k, random_state=X.shape[0]).fit(X)
        centers_df['Center'] = kmeans.labels_

        # Saving the results
        result_df = pd.DataFrame({
            'Original Index': centers_df['Index'],
            'x': centers_df['x'],
            'y': centers_df['y'],
            'Center': centers_df['Center']
        })
        output_filepath = 'clustered_data.csv'
        result_df.to_csv(output_filepath, index=False)
        print(f"Clustering results saved to '{output_filepath}'.")
        return output_filepath


def process_svg_and_group_by_center(
        csv_path,
        svg_path,
        output_svg_path,
        color_scheme):
    """Reads the centers data from a CSV file and groups paths in the SVG by center."""
    # Build the palette once; '.colors' exists on listed colormaps such as 'tab20'
    palette = [matplotlib.colors.to_hex(c)
               for c in plt.get_cmap(color_scheme).colors]

    def get_color(center):
        # Deterministic: cycle through the palette by cluster label
        return palette[int(center) % len(palette)]

    # Read centers data from CSV
    centers_data = pd.read_csv(csv_path)

    # Map centers to colors
    centers_data['color'] = centers_data['Center'].apply(get_color)
    center_color_map = dict(zip(centers_data['Center'], centers_data['color']))

    # Load the SVG file
    ET.register_namespace('', "http://www.w3.org/2000/svg")
    original_paths, attributes = svg2paths(svg_path)

    # Calculate the bounding box of the original SVG
    min_x = min_y = float('inf')
    max_x = max_y = float('-inf')
    for path in original_paths:
        # svgpathtools returns the bounding box as (xmin, xmax, ymin, ymax)
        xmin, xmax, ymin, ymax = path.bbox()
        # Update min and max coordinates
        min_x = min(min_x, xmin)  # Leftmost point
        min_y = min(min_y, ymin)  # Topmost point
        max_x = max(max_x, xmax)  # Rightmost point
        max_y = max(max_y, ymax)  # Bottommost point

    # Define the dimensions and scaling of the output SVG
    output_width, output_height = 600, 600
    width, height = max_x - min_x, max_y - min_y
    scale = min(output_width / width, output_height / height)

    # Calculate the translation needed to center the SVG
    translate_x = output_width / 2 - (min_x + width / 2) * scale
    translate_y = output_height / 2 - (min_y + height / 2) * scale

    # Prepare a new SVG root element
    root = ET.Element('svg', attrib={
        'xmlns': "http://www.w3.org/2000/svg",
        'viewBox': f"{0} {0} {output_width} {output_height}",
        'width': str(output_width),
        'height': str(output_height)
    })

    # Group paths by center and apply transformations
    center_groups = {}
    for i, (path, attr) in enumerate(zip(original_paths, attributes)):
        center = centers_data.loc[i, 'Center']
        if center not in center_groups:
            center_groups[center] = ET.SubElement(
                root,
                'g',
                attrib={
                    'fill': center_color_map[center],
                    'stroke': 'black',
                    'stroke-width': '1',
                    'transform': f'translate({translate_x}, {translate_y}) scale({scale}, {scale})'})
        ET.SubElement(
            center_groups[center],
            'path',
            attrib={
                'd': path.d(),
                **attr})

    # Write the new SVG file
    tree = ET.ElementTree(root)
    tree.write(output_svg_path)
    print(f"SVG processing complete. Output saved to {output_svg_path}")



def flip_svg_vertically(input_svg_path, output_svg_path):
    """Flip an SVG vertically by scaling it with a negative factor on the Y axis."""
    tree = ET.parse(input_svg_path)
    root = tree.getroot()

    # Check if there is an existing 'transform' attribute
    transform = root.get('transform', '')

    # Set up the vertical flip transformation (scale by -1 on the Y axis)
    # Read the canvas height from the root element, falling back to 100 when
    # it is missing (assumes a unitless value such as "600")
    height = float(root.get('height', 100))
    vertical_flip_transform = f"translate(0, {height}) scale(1, -1)"

    # Append or modify the existing transform attribute
    if transform:
        # Combine with existing transforms
        new_transform = f"{transform} {vertical_flip_transform}"
    else:
        new_transform = vertical_flip_transform

    # Update the transform attribute
    root.set('transform', new_transform)

    # Write the modified SVG to a new file
    tree.write(output_svg_path)

    print(f"SVG flipped vertically and saved to '{output_svg_path}'.")


def main(config):
    """Main function to process the input SVG file and cluster the paths."""
    svg_processor = SVGProcessor(config)
    centers_df = svg_processor.process_svg()

    k_values, inertia, silhouette, db_index = svg_processor.evaluate_clusters(
        centers_df[['x', 'y']])
    # +10 because range starts from 10
    optimal_k = silhouette.index(max(silhouette)) + 10
    print(f"Optimal number of clusters: {optimal_k}")

    # Perform K-Means clustering
    clustered_data_path = svg_processor.perform_k_means(centers_df, optimal_k)

    # Color the SVG by cluster and save the output
    process_svg_and_group_by_center(
        clustered_data_path,
        config['input_filepath'],
        config['output_filepath'],
        config['color_scheme'])
    print("Processing complete.")

    # Update viewBox of the output SVG
    # set_svg_viewbox(config['output_filepath'])


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Cluster SVG paths based on their centers.')
    parser.add_argument(
        '--center_type',
        type=str,
        default='bounding_box',
        help='Type of center to calculate (bounding_box, centroid, or visual)')
    parser.add_argument(
        '--max_centers',
        type=int,
        default=12,
        help='Maximum number of centers to evaluate')
    parser.add_argument(
        '--input_filepath',
        type=str,
        help='Input SVG file path')
    parser.add_argument(
        '--output_filepath',
        type=str,
        help='Output SVG file path')
    parser.add_argument(
        '--color_scheme',
        type=str,
        default='tab20',
        help='Color scheme for the output SVG')

    args = parser.parse_args()

    config = {
        'center_type': args.center_type,
        'max_centers': args.max_centers,
        'input_filepath': args.input_filepath or '/Users/srinivasvaddi/Downloads/img.svg',
        'output_filepath': args.output_filepath or f'output{time.time()}.svg',
        'color_scheme': args.color_scheme or 'tab20',
    }

    main(config)
    flip_svg_vertically(config['output_filepath'], config['output_filepath'])
 
Feel free to try it!