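Spark MLlib, Spark's machine learning library, bundles many commonly used machine learning algorithms. This article takes K-Means as an example and uses an image compression case to give a brief introduction to applying it. For the theory behind K-Means, see → K-Means Clustering Algorithm (Theory)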
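Case overview
Image compression
1) An image is made up of a series of pixels, and each pixel's color is composed of R, G, and B values (ignoring alpha). R, G, and B are therefore the three basic features of a color; for example, a white pixel can be written as (255, 255, 255).
2) An 800×600 image contains 480,000 pixels, that is, 480,000 color values. K-Means groups these colors into K clusters; the trained model then gives the cluster each original color belongs to, and replacing every color with the center color of its cluster produces the new image.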
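The two points above can be sketched in plain Python without Spark. This is only an illustration under stated assumptions: the helper names `image_to_features`, `nearest_center`, and `replace_with_centers` are hypothetical, the image is a tiny nested list rather than a real file, and the centers are picked by hand instead of being learned by K-Means.

```python
def image_to_features(image):
    """Flatten a height x width grid of (r, g, b) tuples into one feature row per pixel."""
    return [list(pixel) for row in image for pixel in row]


def nearest_center(color, centers):
    """Return the index of the center with the smallest squared Euclidean distance."""
    def sq_dist(center):
        return sum((a - b) ** 2 for a, b in zip(color, center))
    return min(range(len(centers)), key=lambda i: sq_dist(centers[i]))


def replace_with_centers(image, centers):
    """Build a new image in which every pixel takes the color of its nearest center."""
    return [[tuple(centers[nearest_center(p, centers)]) for p in row] for row in image]


# A 2 x 3 toy image: three near-white pixels and three near-black ones.
image = [
    [(255, 255, 255), (250, 248, 251), (3, 2, 1)],
    [(0, 0, 0), (252, 255, 249), (5, 5, 5)],
]
# Hand-picked centers standing in for the K colors a trained model would return (assumption).
centers = [(0, 0, 0), (255, 255, 255)]

features = image_to_features(image)
compressed = replace_with_centers(image, centers)
print(len(features))   # 6: one feature row per pixel
print(800 * 600)       # 480000: rows an 800x600 image would produce
print(compressed[0])   # [(255, 255, 255), (255, 255, 255), (0, 0, 0)]
```

After the replacement the image contains at most K distinct colors, which is what makes it cheaper to store.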
Applying Spark MLlib K-Means (Java + Python)
Java implementation (see the code comments for details), based on Spark 1.3.1
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.clustering.KMeans;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.rdd.RDD;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
/**
 * Spark MLlib demo: image compression with K-Means
 */
public class Demo {
    public static void main(String[] args) throws IOException {
        // Load the image
        BufferedImage bi = ImageIO.read(new File("/Users/felix/Desktop/1.jpg"));
        HashMap<int[], Vector> rgbs = new HashMap<>();
        // Extract the pixels of the image
        for (int x = 0; x < bi.getWidth(); x++) {
            for (int y = 0; y < bi.getHeight(); y++) {
                int[] pixel = bi.getRaster().getPixel(x, y, new int[3]);
                int[] point = new int[]{x, y};
                int r = pixel[0];
                int g = pixel[1];
                int b = pixel[2];
                // Key: pixel coordinates; value: dense vector built from the r, g, b features
                rgbs.put(point, Vectors.dense((double) r, (double) g, (double) b));
            }
        }
        // Initialize Spark
        SparkConf conf = new SparkConf().setAppName("Kmeans").setMaster("local[4]");
        JavaSparkContext sc = new JavaSparkContext(conf);
        RDD<Vector> data = sc.parallelize(new ArrayList<Vector>(rgbs.values())).rdd();
        data.cache();
        // Number of clusters K
        int k = 4;
        // runs: train this many models in parallel and return the one with the best clustering
        int runs = 10;
        // Maximum number of K-Means iterations; once it is reached the algorithm stops,
        // whether or not the cluster centers have converged
        int maxIterations = 20;
        // Train the model
        KMeansModel model = KMeans.train(data, k, maxIterations, runs);
        // Cluster centers of the trained model
        ArrayList<double[]> cluster = new ArrayList<>();
        for (Vector c : model.clusterCenters()) {
            cluster.add(c.toArray());
        }
        // New RGB value of every pixel (the original color is replaced by the center of its cluster)
        HashMap<int[], double[]> newRgbs = new HashMap<>();
        for (Map.Entry<int[], Vector> rgb : rgbs.entrySet()) {
            int clusterKey = model.predict(rgb.getValue());
            newRgbs.put(rgb.getKey(), cluster.get(clusterKey));
        }
        // Generate the new image
        BufferedImage image = new BufferedImage(bi.getWidth(), bi.getHeight(), BufferedImage.TYPE_INT_RGB);
        for (Map.Entry<int[], double[]> rgb : newRgbs.entrySet()) {
            int[] point = rgb.getKey();
            double[] vRGB = rgb.getValue();
            // Pack the three channels into a single 0xRRGGBB integer
            int RGB = ((int) vRGB[0] << 16) | ((int) vRGB[1] << 8) | ((int) vRGB[2]);
            image.setRGB(point[0], point[1], RGB);
        }
        ImageIO.write(image, "jpg", new File("/Users/felix/Desktop/" + k + "_" + runs + "pic.jpg"));
        sc.stop();
    }
}
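The last step of the Java listing packs the three channel values of a cluster center into one integer before calling `setRGB`. The following Python sketch shows that packing in isolation. The function names `pack_rgb` and `unpack_rgb` are hypothetical, and the sketch rounds and clamps each channel, whereas the Java code simply truncates with an `(int)` cast; treat the rounding as an assumption about what is desirable, not as what the listing does.

```python
def pack_rgb(r, g, b):
    """Round and clamp each channel to 0..255, then pack the result as 0xRRGGBB."""
    def channel(value):
        return max(0, min(255, int(round(value))))
    return (channel(r) << 16) | (channel(g) << 8) | channel(b)


def unpack_rgb(value):
    """Split a packed 0xRRGGBB integer back into an (r, g, b) tuple."""
    return ((value >> 16) & 0xFF, (value >> 8) & 0xFF, value & 0xFF)


# Cluster centers are averages, so their channels are usually fractional.
packed = pack_rgb(12.7, 200.2, 99.4)
print(unpack_rgb(packed))                # (13, 200, 99)
print(hex(pack_rgb(255, 255, 255)))      # 0xffffff
```

Red occupies bits 16 to 23, green bits 8 to 15, and blue bits 0 to 7, which matches the layout `BufferedImage.TYPE_INT_RGB` expects.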
Python implementation (see the code comments for details), based on Spark 1.3.1
# encoding:utf-8
from pyspark import SparkContext, SparkConf
from pyspark.mllib.clustering import KMeans
from numpy import array, zeros
from scipy import misc
from PIL import Image
# Load the image
img = misc.imread('/Users/felix/Desktop/1.JPG')
height = len(img)
width = len(img[0])
rgbs = []
points = []
# Record the r, g, b values of every pixel (x is the row index, y the column index)
for x in xrange(height):
    for y in xrange(width):
        r = img[x, y, 0]
        g = img[x, y, 1]
        b = img[x, y, 2]
        rgbs.extend([r, g, b])
        points.append(([x, y], [r, g, b]))
# Initialize Spark
conf = SparkConf().setAppName("KMeans").setMaster('local[4]')
sc = SparkContext(conf=conf)
data = sc.parallelize(array(rgbs).reshape(height * width, 3))
data.cache()
# Number of clusters K
k = 5
# runs: train this many models in parallel and return the one with the best clustering
runs = 10
# Maximum number of K-Means iterations; once it is reached the algorithm stops,
# whether or not the cluster centers have converged
max_iterations = 20
# Train the model
model = KMeans.train(data, k, max_iterations, runs)
new_points = {}
# Replace each original color with the center color of the cluster it belongs to
for point in points:
    x = point[0][0]
    y = point[0][1]
    rgb = point[1]
    c = model.predict(rgb)
    key = str(x) + '-' + str(y)
    new_points[key] = model.centers[c]
# Generate the new image
pixels = zeros((height, width, 3), 'uint8')
for point, rgb in new_points.iteritems():
    rgb = list(rgb)
    point = point.split('-')
    x = int(point[0])
    y = int(point[1])
    pixels[x, y, 0] = rgb[0]
    pixels[x, y, 1] = rgb[1]
    pixels[x, y, 2] = rgb[2]
img = Image.fromarray(pixels)
img.save('%d_%dpic_py.jpeg' % (k, runs))
sc.stop()
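Comparison of results
The compressed images produced by models trained with different parameters are listed below.
Original image ↓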
k=2 | runs=10 ↓
k=4 | runs=10 ↓
k=5 | runs=1 ↓
k=5 | runs=10 ↓
k=16 | runs=10 ↓
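The comparison shows that as K increases, the image retains more and more detail: K is effectively the number of colors, so more colors naturally bring the result closer to the original. In addition, with K fixed at 5, the image produced with runs=10 looks better than the one produced with runs=1. The larger runs is, the more models Spark trains in parallel before selecting the best one, so the chance of ending up with a good clustering is higher.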
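The effect of `runs` can be illustrated with a small pure-Python sketch: train several K-Means models from different random starting points and keep the one with the lowest cost (the sum of squared distances from each point to its nearest center). This is not Spark's implementation. The functions `kmeans` and `best_of_runs` are hypothetical, the runs execute one after another rather than in parallel, and initial centers are simply sampled from the data.

```python
import random


def sq_dist(a, b):
    """Squared Euclidean distance between two color vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans(points, k, max_iterations, seed):
    """One run of Lloyd's algorithm from a random start; returns (centers, cost)."""
    rng = random.Random(seed)
    centers = [list(p) for p in rng.sample(points, k)]
    for _ in range(max_iterations):
        # Assignment step: group the points by their nearest center.
        groups = [[] for _ in range(k)]
        for p in points:
            idx = min(range(k), key=lambda i: sq_dist(p, centers[i]))
            groups[idx].append(p)
        # Update step: move each center to the mean of its group (unchanged if the group is empty).
        new_centers = [
            [sum(dim) / len(g) for dim in zip(*g)] if g else c
            for g, c in zip(groups, centers)
        ]
        if new_centers == centers:
            break  # converged before reaching max_iterations
        centers = new_centers
    cost = sum(min(sq_dist(p, c) for c in centers) for p in points)
    return centers, cost


def best_of_runs(points, k, max_iterations, runs):
    """Train `runs` models with seeds 0..runs-1 and keep the one with the lowest cost."""
    results = [kmeans(points, k, max_iterations, seed) for seed in range(runs)]
    return min(results, key=lambda result: result[1])


# Toy data: two near-white, two near-black, and two reddish colors.
colors = [(250, 250, 250), (245, 255, 250), (10, 5, 0),
          (0, 10, 5), (200, 30, 30), (210, 20, 25)]
single_centers, single_cost = kmeans(colors, 3, 20, seed=0)
best_centers, best_cost = best_of_runs(colors, 3, 20, runs=10)
print(best_cost <= single_cost)  # True: the seed-0 run is one of the ten candidates
```

More runs can never make the selected model worse by this cost measure, but they do not guarantee a visible improvement either. Note that in Spark 2.0 and later the `runs` argument no longer has any effect.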