The Spark MLlib machine learning library ships with many common machine learning algorithms. This article takes K-Means as an example and, through an image compression case study, briefly shows how K-Means is applied. For the theory behind the algorithm, see → K-Means Clustering Algorithm (Theory)
Case Introduction
Image Compression
1) An image is made up of pixels, and each pixel's color is given by its R, G, B values (ignoring alpha). R, G, and B thus form the three basic features of a color; a white pixel, for example, is (255, 255, 255).
2) An 800×600 image contains 480,000 pixels' worth of color data. K-Means groups these colors into K clusters; the trained model then maps each original color to its cluster, and replacing every pixel's color with its cluster's representative color produces the new, compressed image.
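To make the idea concrete before bringing in Spark, here is a minimal, Spark-free sketch (an illustration only, not the article's code) that clusters a few RGB values into k representative colors with a hand-rolled K-Means in NumPy:

```python
import numpy as np

def kmeans_quantize(pixels, k, iterations=20, seed=0):
    """Return (centers, labels): k representative colors and, for each
    pixel, the index of the center that replaces it."""
    rng = np.random.default_rng(seed)
    # Start from k distinct pixels chosen at random
    centers = pixels[rng.choice(len(pixels), k, replace=False)].astype(float)
    for _ in range(iterations):
        # Assign each pixel to its nearest center (Euclidean distance in RGB space)
        dists = np.linalg.norm(pixels[:, None, :] - centers[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        # Move each center to the mean of the pixels assigned to it
        for c in range(k):
            if (labels == c).any():
                centers[c] = pixels[labels == c].mean(axis=0)
    return centers, labels

# Six pixels that are clearly either "bright" or "dark":
pixels = np.array([[250, 250, 250], [245, 240, 250], [255, 255, 255],
                   [10, 5, 0], [0, 0, 0], [5, 10, 5]], dtype=float)
centers, labels = kmeans_quantize(pixels, k=2)
compressed = centers[labels]  # every pixel replaced by its cluster's color
```

After clustering, `compressed` contains only two distinct colors, which is exactly the compression effect the case study applies to a full image.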
Spark MLlib K-Means Application (Java + Python)
Java implementation (see the code comments for details), based on Spark 1.3.1
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.clustering.KMeans;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.rdd.RDD;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * Spark MLlib demo: image compression via K-Means color quantization.
 */
public class Demo {
    public static void main(String[] args) throws IOException {
        // Load the image
        BufferedImage bi = ImageIO.read(new File("/Users/felix/Desktop/1.jpg"));
        HashMap<int[], Vector> rgbs = new HashMap<>();
        // Extract the pixels
        for (int x = 0; x < bi.getWidth(); x++) {
            for (int y = 0; y < bi.getHeight(); y++) {
                int[] pixel = bi.getRaster().getPixel(x, y, new int[3]);
                int[] point = new int[]{x, y};
                int r = pixel[0];
                int g = pixel[1];
                int b = pixel[2];
                // Key is the pixel coordinate; the r, g, b features form a dense vector
                rgbs.put(point, Vectors.dense((double) r, (double) g, (double) b));
            }
        }

        // Initialize Spark
        SparkConf conf = new SparkConf().setAppName("Kmeans").setMaster("local[4]");
        JavaSparkContext sc = new JavaSparkContext(conf);
        RDD<Vector> data = sc.parallelize(new ArrayList<>(rgbs.values())).rdd();
        data.cache();

        // K value for clustering
        int k = 4;
        // runs trains that many models in parallel and keeps the one
        // with the best clustering as the final result
        int runs = 10;
        // Maximum number of K-Means iterations; the algorithm stops when this is
        // reached, whether or not the cluster centers have converged
        int maxIterations = 20;
        // Train the model
        KMeansModel model = KMeans.train(data, k, maxIterations, runs);

        // Cluster centers of the trained model
        ArrayList<double[]> cluster = new ArrayList<>();
        for (Vector c : model.clusterCenters()) {
            cluster.add(c.toArray());
        }

        // New RGB value for each pixel (i.e. replace each original color with
        // the center color of the cluster it belongs to)
        HashMap<int[], double[]> newRgbs = new HashMap<>();
        for (Map.Entry<int[], Vector> rgb : rgbs.entrySet()) {
            int clusterKey = model.predict(rgb.getValue());
            newRgbs.put(rgb.getKey(), cluster.get(clusterKey));
        }

        // Generate the new image
        BufferedImage image = new BufferedImage(bi.getWidth(), bi.getHeight(),
                BufferedImage.TYPE_INT_RGB);
        for (Map.Entry<int[], double[]> rgb : newRgbs.entrySet()) {
            int[] point = rgb.getKey();
            double[] vRGB = rgb.getValue();
            // Pack the R, G, B components into a single int pixel value
            int RGB = ((int) vRGB[0] << 16) | ((int) vRGB[1] << 8) | ((int) vRGB[2]);
            image.setRGB(point[0], point[1], RGB);
        }
        ImageIO.write(image, "jpg",
                new File("/Users/felix/Desktop/" + k + "_" + runs + "pic.jpg"));

        sc.stop();
    }
}
Python implementation (see the code comments for details), based on Spark 1.3.1
# encoding:utf-8
from pyspark import SparkContext, SparkConf
from pyspark.mllib.clustering import KMeans
from numpy import array, zeros
from scipy import misc
from PIL import Image

# Load the image
img = misc.imread('/Users/felix/Desktop/1.JPG')
height = len(img)
width = len(img[0])
rgbs = []
points = []

# Record the r, g, b values of every pixel
for x in xrange(height):
    for y in xrange(width):
        r = img[x, y, 0]
        g = img[x, y, 1]
        b = img[x, y, 2]
        rgbs.extend([r, g, b])
        points.append(([x, y], [r, g, b]))

# Initialize Spark
conf = SparkConf().setAppName("KMeans").setMaster('local[4]')
sc = SparkContext(conf=conf)
data = sc.parallelize(array(rgbs).reshape(height * width, 3))
data.cache()

# K value for clustering
k = 5
# runs trains that many models in parallel and keeps the one
# with the best clustering as the final result
runs = 10
# Maximum number of K-Means iterations; the algorithm stops when this is
# reached, whether or not the cluster centers have converged
max_iterations = 20
# Train the model
model = KMeans.train(data, k, max_iterations, runs)

new_points = {}
# Replace each original color with the center color of the cluster it belongs to
for point in points:
    x = point[0][0]
    y = point[0][1]
    rgb = point[1]
    c = model.predict(rgb)
    key = str(x) + '-' + str(y)
    new_points[key] = model.centers[c]

# Generate the new image
pixels = zeros((height, width, 3), 'uint8')
for point, rgb in new_points.iteritems():
    rgb = list(rgb)
    point = point.split('-')
    x = int(point[0])
    y = int(point[1])
    pixels[x, y, 0] = rgb[0]
    pixels[x, y, 1] = rgb[1]
    pixels[x, y, 2] = rgb[2]
img = Image.fromarray(pixels)
img.save('%d_%dpic_py.jpeg' % (k, runs))

sc.stop()
Result Comparison
Below are the compression results produced by models trained with different parameters.
Original image ↓
k=2 | runs=10 ↓
k=4 | runs=10 ↓
k=5 | runs=1 ↓
k=5 | runs=10 ↓
k=16 | runs=10 ↓
As the comparison shows, the image retains more detail as K increases (K is effectively the number of colors, and more colors naturally means a result closer to the original). Also, with K fixed at 5, a larger runs value produces a better image: the larger runs is, the more models Spark trains in parallel, and since the best of them is kept, the result naturally improves.
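What runs buys can also be sketched outside Spark: train several models from different random starts and keep the one with the lowest within-cluster sum of squared errors (WSSSE). A hand-rolled NumPy illustration follows (the `kmeans` helper is hypothetical, not Spark's API; Spark performs this selection internally):

```python
import numpy as np

def kmeans(pixels, k, iterations=20, seed=0):
    """One K-Means run from a seeded random start; returns (centers, cost)."""
    rng = np.random.default_rng(seed)
    centers = pixels[rng.choice(len(pixels), k, replace=False)].astype(float)
    for _ in range(iterations):
        # Assign each pixel to its nearest center, then recompute the centers
        labels = np.linalg.norm(pixels[:, None] - centers[None], axis=2).argmin(axis=1)
        for c in range(k):
            if (labels == c).any():
                centers[c] = pixels[labels == c].mean(axis=0)
    # Final assignment and WSSSE: total squared distance to assigned centers
    labels = np.linalg.norm(pixels[:, None] - centers[None], axis=2).argmin(axis=1)
    cost = ((pixels - centers[labels]) ** 2).sum()
    return centers, cost

rng = np.random.default_rng(42)
pixels = rng.integers(0, 256, size=(200, 3)).astype(float)

# Emulate runs=5: train five models and keep the lowest-cost one
models = [kmeans(pixels, k=5, seed=s) for s in range(5)]
best_centers, best_cost = min(models, key=lambda m: m[1])
```

Because the best of five starts is never worse than any single start, a larger runs can only lower (or match) the final WSSSE, which is why the runs=10 images look better than runs=1.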