The Spark MLlib machine learning library ships with many commonly used machine learning algorithms. This article takes the K-Means algorithm, combined with an image compression case study, to briefly introduce K-Means in practice. For the theory behind K-Means, see → K-Means Clustering Algorithm (Theory)
Case Study
Image Compression
1) An image consists of a series of pixels, and each pixel's color is defined by its R, G, and B values (ignoring the alpha channel); that is, R, G, and B form the three basic color features. For example, a white pixel can be represented as (255, 255, 255).
2) An 800×600 image contains 480,000 color values. The K-Means algorithm groups these colors into K clusters; the trained model then maps each original color to its cluster, and replacing every pixel with its cluster's color produces the new, compressed image.
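To make the saving concrete, here is a rough back-of-the-envelope estimate (a sketch added for illustration, not part of the original pipeline): instead of 3 bytes per pixel, a quantized image only needs the K palette colors plus one palette index per pixel.

```python
# Rough storage estimate for palette-based compression of an 800x600 image.
width, height, k = 800, 600, 16
pixels = width * height                        # 480000 pixels
raw_bytes = pixels * 3                         # 3 bytes (R, G, B) per pixel
# After K-Means: store the K palette colors plus one palette index per pixel.
bits_per_index = max(1, (k - 1).bit_length())  # 4 bits suffice for k=16
compressed_bytes = k * 3 + pixels * bits_per_index // 8
print(raw_bytes, compressed_bytes)  # 1440000 240048
```

For k=16 this is roughly a 6x reduction before any further entropy coding.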
Spark MLlib K-Means in Practice (Java + Python)
Java implementation (see code comments for details), based on Spark 1.3.1
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.clustering.KMeans;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.rdd.RDD;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * Spark MLlib Demo
 */
public class Demo {
    public static void main(String[] args) throws IOException {
        // Load the image
        BufferedImage bi = ImageIO.read(new File("/Users/felix/Desktop/1.jpg"));
        HashMap<int[], Vector> rgbs = new HashMap<>();
        // Extract the image's pixels
        for (int x = 0; x < bi.getWidth(); x++) {
            for (int y = 0; y < bi.getHeight(); y++) {
                int[] pixel = bi.getRaster().getPixel(x, y, new int[3]);
                int[] point = new int[]{x, y};
                int r = pixel[0];
                int g = pixel[1];
                int b = pixel[2];
                // Key is the pixel coordinate; build a dense vector from the r, g, b features
                rgbs.put(point, Vectors.dense((double) r, (double) g, (double) b));
            }
        }
        // Initialize Spark
        SparkConf conf = new SparkConf().setAppName("Kmeans").setMaster("local[4]");
        JavaSparkContext sc = new JavaSparkContext(conf);
        RDD<Vector> data = sc.parallelize(new ArrayList<>(rgbs.values())).rdd();
        data.cache();
        // K, the number of clusters
        int k = 4;
        // runs: train `runs` models in parallel and keep the one with the best clustering
        int runs = 10;
        // Maximum number of K-Means iterations; the algorithm stops at this limit
        // even if the cluster centers have not converged
        int maxIterations = 20;
        // Train the model
        KMeansModel model = KMeans.train(data, k, maxIterations, runs);
        // The trained model's cluster centers
        ArrayList<double[]> cluster = new ArrayList<>();
        for (Vector c : model.clusterCenters()) {
            cluster.add(c.toArray());
        }
        // New RGB value for each pixel (i.e. replace each original color with
        // the color of the cluster center it belongs to)
        HashMap<int[], double[]> newRgbs = new HashMap<>();
        for (Map.Entry<int[], Vector> rgb : rgbs.entrySet()) {
            int clusterKey = model.predict(rgb.getValue());
            newRgbs.put(rgb.getKey(), cluster.get(clusterKey));
        }
        // Generate the new image
        BufferedImage image = new BufferedImage(bi.getWidth(), bi.getHeight(), BufferedImage.TYPE_INT_RGB);
        for (Map.Entry<int[], double[]> rgb : newRgbs.entrySet()) {
            int[] point = rgb.getKey();
            double[] vRGB = rgb.getValue();
            int RGB = ((int) vRGB[0] << 16) | ((int) vRGB[1] << 8) | ((int) vRGB[2]);
            image.setRGB(point[0], point[1], RGB);
        }
        ImageIO.write(image, "jpg", new File("/Users/felix/Desktop/" + k + "_" + runs + "pic.jpg"));
        sc.stop();
    }
}
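The bit-packing used in the setRGB call above merges the three 8-bit channels into the single 0xRRGGBB int that BufferedImage expects; a quick sanity check of the same expression (in Python for brevity):

```python
# Pack R, G, B (0-255 each) into one int, as image.setRGB expects: 0xRRGGBB.
r, g, b = 255, 128, 64
rgb = (r << 16) | (g << 8) | b
print(hex(rgb))  # 0xff8040
# Unpacking with shifts and a 0xFF mask recovers the original channels.
assert (rgb >> 16) & 0xFF == r and (rgb >> 8) & 0xFF == g and rgb & 0xFF == b
```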
Python implementation (see code comments for details), based on Spark 1.3.1
# encoding:utf-8
# Python 2 syntax (xrange, iteritems), matching the Spark 1.3.1 era
from pyspark import SparkContext, SparkConf
from pyspark.mllib.clustering import KMeans
from numpy import array, zeros
from scipy import misc
from PIL import Image

# Load the image
img = misc.imread('/Users/felix/Desktop/1.JPG')
height = len(img)
width = len(img[0])
rgbs = []
points = []
# Record the r, g, b values of every pixel
for x in xrange(height):
    for y in xrange(width):
        r = img[x, y, 0]
        g = img[x, y, 1]
        b = img[x, y, 2]
        rgbs.extend([r, g, b])
        points.append(([x, y], [r, g, b]))
# Initialize Spark
conf = SparkConf().setAppName("KMeans").setMaster('local[4]')
sc = SparkContext(conf=conf)
data = sc.parallelize(array(rgbs).reshape(height * width, 3))
data.cache()
# K, the number of clusters
k = 5
# runs: train `runs` models in parallel and keep the one with the best clustering
runs = 10
# Maximum number of K-Means iterations; the algorithm stops at this limit
# even if the cluster centers have not converged
max_iterations = 20
# Train the model
model = KMeans.train(data, k, max_iterations, runs)
new_points = {}
# Replace each original color with the color of the cluster center it belongs to
for point in points:
    x = point[0][0]
    y = point[0][1]
    rgb = point[1]
    c = model.predict(rgb)
    key = str(x) + '-' + str(y)
    new_points[key] = model.centers[c]
# Generate the new image
pixels = zeros((height, width, 3), 'uint8')
for point, rgb in new_points.iteritems():
    rgb = list(rgb)
    point = point.split('-')
    x = int(point[0])
    y = int(point[1])
    pixels[x, y, 0] = rgb[0]
    pixels[x, y, 1] = rgb[1]
    pixels[x, y, 2] = rgb[2]
img = Image.fromarray(pixels)
img.save('%d_%dpic_py.jpeg' % (k, runs))
sc.stop()
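For readers without a Spark setup at hand, the same quantization idea can be sketched with plain NumPy (a simplified Lloyd's-algorithm illustration, not the MLlib implementation; `kmeans_quantize` and the toy data are made up for this example):

```python
import numpy as np

def kmeans_quantize(pixels, k, init, iterations=20):
    """Cluster Nx3 RGB rows into k colors (Lloyd's algorithm) and
    replace every pixel with its nearest cluster center."""
    centers = init.astype(float).copy()
    for _ in range(iterations):
        # Assignment step: nearest center by squared Euclidean distance.
        d = ((pixels[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = d.argmin(axis=1)
        # Update step: move each center to the mean of its assigned pixels.
        for c in range(k):
            if (labels == c).any():
                centers[c] = pixels[labels == c].mean(axis=0)
    return centers[labels]

# Toy "image": two reddish and two bluish pixels; seed the centers with one of each.
img = np.array([[250.0, 0, 0], [255, 5, 5], [0, 0, 250], [5, 5, 255]])
out = kmeans_quantize(img, k=2, init=img[[0, 2]])
print(out)  # reddish pixels share one center, bluish pixels the other
```

MLlib runs the same assignment/update loop, just distributed over RDD partitions (and, with runs > 1, over several initializations at once).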
Results Comparison
Below are the compressed results produced by models trained with different parameters.
Original image ↓

k=2 | runs=10 ↓

k=4 | runs=10 ↓

k=5 | runs=1 ↓

k=5 | runs=10 ↓

k=16 | runs=10 ↓

The comparison shows that as K increases, the image retains more and more detail (K is effectively the number of colors, and more colors naturally brings the result closer to the original). Also, with K fixed at 5, a larger runs value yields a better-looking image: the larger runs is, the more models Spark trains in parallel, and since the best-scoring model is kept, the final image quality improves.
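This K-versus-quality trade-off can also be quantified rather than eyeballed, using the within-set sum of squared errors (WSSSE; later Spark versions expose it as KMeansModel.computeCost). A NumPy sketch on made-up synthetic colors, with hand-picked centers just to show the trend:

```python
import numpy as np

def wssse(pixels, centers):
    """Within-set sum of squared errors: total squared distance
    from each pixel to its nearest center."""
    d = ((pixels[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    return d.min(axis=1).sum()

rng = np.random.default_rng(42)
# Synthetic "image": 300 pixels scattered around 3 base colors.
base = np.array([[200.0, 30, 30], [30, 200, 30], [30, 30, 200]])
pixels = np.vstack([b + rng.normal(0, 10, (100, 3)) for b in base])

# Using the first k base colors as centers: the cost drops as k grows,
# mirroring the richer detail seen in the images above.
costs = [wssse(pixels, base[:k]) for k in (1, 2, 3)]
print(costs)
```

Plotting WSSSE against K and looking for the "elbow" where the curve flattens is a common way to pick K without relying on visual inspection alone.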