日志

K-Means聚类算法(实践篇)– 基于Spark Mlib的图像压缩案例

Spark Mlib 机器学习库集成了许多常用的机器学习算法,本文以K-Means算法为例结合图像压缩案例,简单介绍K-Means的应用。关于K-Means算法理论可以参考 → K-Means聚类算法(理论篇)

案例介绍

图像压缩

1)一张图由一系列像素组成,每个像素的颜色都由R、G、B值构成(不考虑Alpha),即R、G、B构成了颜色的三个基本特征,例如一个白色的像素点可以表示为(255,255,255)。

2)一张800×600的图片有480000个颜色数据,通过K-Means算法将这些颜色数据归类到K种颜色中,通过训练模型计算原始颜色对应的颜色分类,替换后生成新的图片。

Spark Mlib K-Means应用(Java + Python)

Java实现(详细解释见代码注释) 基于Spark1.3.1

  1. import org.apache.spark.SparkConf;
  2. import org.apache.spark.api.java.JavaSparkContext;
  3. import org.apache.spark.mllib.clustering.KMeans;
  4. import org.apache.spark.mllib.clustering.KMeansModel;
  5. import org.apache.spark.mllib.linalg.Vector;
  6. import org.apache.spark.mllib.linalg.Vectors;
  7. import org.apache.spark.rdd.RDD;
  8.  
  9. import javax.imageio.ImageIO;
  10. import java.awt.image.BufferedImage;
  11. import java.io.File;
  12. import java.io.IOException;
  13. import java.util.ArrayList;
  14. import java.util.HashMap;
  15. import java.util.Map;
  16.  
  17. /**
  18. * Spark MLib Demo
  19. */
  20. public class Demo {
  21. public static void main(String[] args) throws IOException {
  22. //加载图片
  23. BufferedImage bi = ImageIO.read(new File("/Users/felix/Desktop/1.jpg"));
  24. HashMap<int[], Vector> rgbs = new HashMap<>();
  25. //提取图片像素
  26. for (int x = 0; x < bi.getWidth(); x++) {
  27. for (int y = 0; y < bi.getHeight(); y++) {
  28. int[] pixel = bi.getRaster().getPixel(x, y, new int[3]);
  29. int[] point = new int[]{x, y};
  30. int r = pixel[0];
  31. int g = pixel[1];
  32. int b = pixel[2];
  33. //key为像素坐标, r,g,b特征构建密集矩阵
  34. rgbs.put(point, Vectors.dense((double) r, (double) g, (double) b));
  35. }
  36. }
  37. //初始化Spark
  38. SparkConf conf = new SparkConf().setAppName("Kmeans").setMaster("local[4]");
  39. JavaSparkContext sc = new JavaSparkContext(conf);
  40. RDD<Vector> data = sc.parallelize(new ArrayList<>(rgbs.values())).rdd();
  41. data.cache();
  42.  
  43. //聚类的K值
  44. int k = 4;
  45. //runs参数代表并行训练runs个模型,返回聚类效果最好的那个模型作为最终的训练结果
  46. int runs = 10;
  47. //K-Means算法最大迭代次数,当迭代次数达到最大值时不管聚类中心是否收敛算法也会结束
  48. int maxIterations = 20;
  49. //训练模型
  50. KMeansModel model = KMeans.train(data, k, maxIterations, runs);
  51. //训练模型的聚类中心
  52. ArrayList<double[]> cluster = new ArrayList<>();
  53. for (Vector c : model.clusterCenters()) {
  54. cluster.add(c.toArray());
  55. }
  56. //图像各个像素的新RGB值(即把原图的颜色值替换为对应所属聚类的聚类中心颜色)
  57. HashMap<int[], double[]> newRgbs = new HashMap<>();
  58. for (Map.Entry rgb : rgbs.entrySet()) {
  59. int clusterKey = model.predict((Vector) rgb.getValue());
  60. newRgbs.put((int[]) rgb.getKey(), cluster.get(clusterKey));
  61. }
  62. //生成新的图片
  63. BufferedImage image = new BufferedImage(bi.getWidth(), bi.getHeight(), BufferedImage.TYPE_INT_RGB);
  64. for (Map.Entry rgb : newRgbs.entrySet()) {
  65. int[] point = (int[]) rgb.getKey();
  66. double[] vRGB = (double[]) rgb.getValue();
  67. int RGB = ((int) vRGB[0] << 16) | ((int) vRGB[1] << 8) | ((int) vRGB[2]);
  68. image.setRGB(point[0], point[1], RGB);
  69. }
  70. ImageIO.write(image, "jpg", new File("/Users/felix/Desktop/" + k + "_" + runs + "pic.jpg"));
  71. sc.stop();
  72. }
  73. }

Python实现(详细解释见代码注释) 基于Spark1.3.1

  1. # encoding:utf-8
  2. from pyspark import SparkContext, SparkConf
  3. from pyspark.mllib.clustering import KMeans
  4. from numpy import array, zeros
  5. from scipy import misc
  6. from PIL import Image, ImageDraw, ImageColor
  7.  
  8. # 加载图片
  9. img = misc.imread('/Users/felix/Desktop/1.JPG')
  10. height = len(img)
  11. width = len(img[0])
  12. rgbs = []
  13. points = []
  14.  
  15. # 记录每个像素对应的r,g,b值
  16. for x in xrange(height):
  17. for y in xrange(width):
  18. r = img[x, y, 0]
  19. g = img[x, y, 1]
  20. b = img[x, y, 2]
  21. rgbs.extend([r, g, b])
  22. points.append(([x, y], [r, g, b]))
  23. # 初始化spark
  24. conf = SparkConf().setAppName("KMeans").setMaster('local[4]')
  25. sc = SparkContext(conf=conf)
  26. data = sc.parallelize(array(rgbs).reshape(height * width, 3))
  27. data.cache()
  28. # 聚类的K值
  29. k = 5
  30. # runs参数代表并行训练runs个模型,返回聚类效果最好的那个模型作为最终的训练结果
  31. runs = 10
  32. # K-Means算法最大迭代次数,当迭代次数达到最大值时不管聚类中心是否收敛算法也会结束
  33. max_iterations = 20
  34. # 训练模型
  35. model = KMeans.train(data, k, max_iterations, runs)
  36. new_points = {}
  37.  
  38. # 将原图的颜色值替换为对应所属聚类的聚类中心颜色
  39. for point in points:
  40. x = point[0][0]
  41. y = point[0][1]
  42. rgb = point[1]
  43. c = model.predict(rgb)
  44. key = str(x) + '-' + str(y)
  45. new_points[key] = model.centers[c]
  46.  
  47. # 生成新图像
  48. pixels = zeros((height, width, 3), 'uint8')
  49. for point, rgb in new_points.iteritems():
  50. rgb = list(rgb)
  51. point = point.split('-')
  52. x = int(point[0])
  53. y = int(point[1])
  54. pixels[x, y, 0] = rgb[0]
  55. pixels[x, y, 1] = rgb[1]
  56. pixels[x, y, 2] = rgb[2]
  57.  
  58. img = Image.fromarray(pixels)
  59. img.save('%d_%dpic_py.jpeg' % (k, runs))
  60. sc.stop()

结果对比

以下列出不同参数的训练模型得出的压缩结果

原图 ↓

k=2 | runs=10 ↓

k=4 | runs=10 ↓

k=5 | runs=1 ↓

k=5 | runs=10 ↓

k=16 | runs=10 ↓

对比可见,随着K值的增加图片的细节越来越丰富(K相当于颜色数、颜色越多自然越接近原图),另外当K同为5时,runs越大图片效果越好(runs参数越大,spark并行训练得到的模型越多,最终会选取最优模型,自然图片效果越好)

转载请注明出处:

© http://hejunhao.me