日志

K-Means聚类算法(实践篇)– 基于Spark Mlib的图像压缩案例

Spark Mlib 机器学习库集成了许多常用的机器学习算法,本文以K-Means算法为例结合图像压缩案例,简单介绍K-Means的应用。关于K-Means算法理论可以参考 → K-Means聚类算法(理论篇)

案例介绍

图像压缩

1)一张图由一系列像素组成,每个像素的颜色都由R、G、B值构成(不考虑Alpha),即R、G、B构成了颜色的三个基本特征,例如一个白色的像素点可以表示为(255,255,255)。

2)一张800×600的图片有480000个颜色数据,通过K-Means算法将这些颜色数据归类到K种颜色中,通过训练模型计算原始颜色对应的颜色分类,替换后生成新的图片。

Spark Mlib K-Means应用(Java + Python)

Java实现(详细解释见代码注释) 基于Spark1.3.1

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.clustering.KMeans;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.rdd.RDD;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * Spark MLib Demo
 */
public class Demo {
    public static void main(String[] args) throws IOException {
        //加载图片
        BufferedImage bi = ImageIO.read(new File("/Users/felix/Desktop/1.jpg"));
        HashMap<int[], Vector> rgbs = new HashMap<>();
        //提取图片像素
        for (int x = 0; x < bi.getWidth(); x++) {
            for (int y = 0; y < bi.getHeight(); y++) {
                int[] pixel = bi.getRaster().getPixel(x, y, new int[3]);
                int[] point = new int[]{x, y};
                int r = pixel[0];
                int g = pixel[1];
                int b = pixel[2];
                //key为像素坐标, r,g,b特征构建密集矩阵
                rgbs.put(point, Vectors.dense((double) r, (double) g, (double) b));
            }
        }
        //初始化Spark
        SparkConf conf = new SparkConf().setAppName("Kmeans").setMaster("local[4]");
        JavaSparkContext sc = new JavaSparkContext(conf);
        RDD<Vector> data = sc.parallelize(new ArrayList<>(rgbs.values())).rdd();
        data.cache();

        //聚类的K值
        int k = 4;
        //runs参数代表并行训练runs个模型,返回聚类效果最好的那个模型作为最终的训练结果
        int runs = 10;
        //K-Means算法最大迭代次数,当迭代次数达到最大值时不管聚类中心是否收敛算法也会结束
        int maxIterations = 20;
        //训练模型
        KMeansModel model = KMeans.train(data, k, maxIterations, runs);
        
        //训练模型的聚类中心
        ArrayList<double[]> cluster = new ArrayList<>();
        for (Vector c : model.clusterCenters()) {
            cluster.add(c.toArray());
        }
        //图像各个像素的新RGB值(即把原图的颜色值替换为对应所属聚类的聚类中心颜色)
        HashMap<int[], double[]> newRgbs = new HashMap<>();
        
        for (Map.Entry rgb : rgbs.entrySet()) {
            int clusterKey = model.predict((Vector) rgb.getValue());
            newRgbs.put((int[]) rgb.getKey(), cluster.get(clusterKey));
        }
        
        //生成新的图片
        BufferedImage image = new BufferedImage(bi.getWidth(), bi.getHeight(), BufferedImage.TYPE_INT_RGB);
        
        for (Map.Entry rgb : newRgbs.entrySet()) {
            int[] point = (int[]) rgb.getKey();
            double[] vRGB = (double[]) rgb.getValue();
            int RGB = ((int) vRGB[0] << 16) | ((int) vRGB[1] << 8) | ((int) vRGB[2]);
            image.setRGB(point[0], point[1], RGB);
        }
        
        ImageIO.write(image, "jpg", new File("/Users/felix/Desktop/" + k + "_" + runs + "pic.jpg"));
        sc.stop();
    }
}

Python实现(详细解释见代码注释) 基于Spark1.3.1

# encoding:utf-8
from pyspark import SparkContext, SparkConf
from pyspark.mllib.clustering import KMeans
from numpy import array, zeros
from scipy import misc
from PIL import Image, ImageDraw, ImageColor

# 加载图片
img = misc.imread('/Users/felix/Desktop/1.JPG')
height = len(img)
width = len(img[0])
rgbs = []
points = []

# 记录每个像素对应的r,g,b值
for x in xrange(height):
    for y in xrange(width):
        r = img[x, y, 0]
        g = img[x, y, 1]
        b = img[x, y, 2]
        rgbs.extend([r, g, b])
        points.append(([x, y], [r, g, b]))
# 初始化spark
conf = SparkConf().setAppName("KMeans").setMaster('local[4]')
sc = SparkContext(conf=conf)
data = sc.parallelize(array(rgbs).reshape(height * width, 3))
data.cache()
# 聚类的K值
k = 5
# runs参数代表并行训练runs个模型,返回聚类效果最好的那个模型作为最终的训练结果
runs = 10
# K-Means算法最大迭代次数,当迭代次数达到最大值时不管聚类中心是否收敛算法也会结束
max_iterations = 20
# 训练模型
model = KMeans.train(data, k, max_iterations, runs)
new_points = {}

# 将原图的颜色值替换为对应所属聚类的聚类中心颜色
for point in points:
    x = point[0][0]
    y = point[0][1]
    rgb = point[1]
    c = model.predict(rgb)
    key = str(x) + '-' + str(y)
    new_points[key] = model.centers[c]

# 生成新图像
pixels = zeros((height, width, 3), 'uint8')
for point, rgb in new_points.iteritems():
    rgb = list(rgb)
    point = point.split('-')
    x = int(point[0])
    y = int(point[1])
    pixels[x, y, 0] = rgb[0]
    pixels[x, y, 1] = rgb[1]
    pixels[x, y, 2] = rgb[2]

img = Image.fromarray(pixels)
img.save('%d_%dpic_py.jpeg' % (k, runs))
sc.stop()

结果对比

以下列出不同参数的训练模型得出的压缩结果

原图 ↓

k=2 | runs=10 ↓

k=4 | runs=10 ↓

k=5 | runs=1 ↓

k=5 | runs=10 ↓

k=16 | runs=10 ↓

对比可见,随着K值的增加图片的细节越来越丰富(K相当于颜色数、颜色越多自然越接近原图),另外当K同为5时,runs越大图片效果越好(runs参数越大,spark并行训练得到的模型越多,最终会选取最优模型,自然图片效果越好)

转载请注明出处:

© http://hejunhao.me