k-means clustering

Given a set of observations (x1, x2, …, xn), where each observation is a d-dimensional real vector, k-means clustering aims to partition the n observations into k (≤ n) sets S = {S1, S2, …, Sk} so as to minimize the within-cluster sum of squares (WCSS):

    \arg\min_{S} \sum_{i=1}^{k} \sum_{x \in S_i} \lVert x - \mu_i \rVert^2

where μi is the mean of points in Si.



Algorithm

The most common algorithm uses an iterative refinement technique. Due to its ubiquity it is often called the k-means algorithm; it is also referred to as Lloyd's algorithm, particularly in the computer science community.

Given an initial set of k means, the algorithm proceeds by alternating between two steps:


  • Assignment step
    Assign each observation to the cluster whose mean is closest to it (i.e. partition the observations according to the Voronoi diagram generated by the means).
  • Update step
    Calculate the new means to be the centroids of the observations in the new clusters.


The algorithm has converged when the assignments no longer change.

The "assignment" step is also referred to as expectation step, the "update step" as maximization step, making this algorithm a variant of the generalized expectation-maximization algorithm.



Demonstration of the standard algorithm

(1) k initial "means" (in this case k = 3) are randomly generated within the data domain.

(2) k clusters are created by associating every observation with the nearest mean. The partitions here represent the Voronoi diagram generated by the means.

(3) The centroid of each of the k clusters becomes the new mean.

(4) Steps 2 and 3 are repeated until convergence has been reached.



Example (k = 4)



Weak points of k-means

  • It is not well suited to finding clusters of very different sizes.
  • It is sensitive to noise and outliers, because a few isolated points can strongly influence the computed means.
  • The value of k must be specified in advance.



Variations
  • Fuzzy C-Means Clustering is a soft version of K-means, where each data point has a fuzzy degree of belonging to each cluster.
  • Gaussian mixture models trained with the expectation-maximization algorithm (EM algorithm) maintain probabilistic assignments to clusters, instead of deterministic assignments, and multivariate Gaussian distributions instead of means.
  • Several methods have been proposed to choose better starting clusters. One recent proposal is k-means++.
  • The filtering algorithm uses kd-trees to speed up each k-means step.
  • Some methods attempt to speed up each k-means step using coresets or the triangle inequality.
  • Escape local optima by swapping points between clusters.
  • The Spherical k-means clustering algorithm is suitable for directional data.
  • The Minkowski metric weighted k-means deals with the problem of noise features by assigning weights to each feature per cluster.



Code
import java.util.*;
import java.math.*;
import java.io.*;

public class kmeans {
	static ArrayList<Vector> points = new ArrayList<Vector>();
	static ArrayList<Vector> seeds = new ArrayList<Vector>();
	static int nCluster;
	static float maxDistance;
	
	public static void printSeed() {
		for(int i=0; i<nCluster; i++)
			System.out.println("Seed #" + i + " = (" + seeds.get(i).elementAt(0) + ", " + seeds.get(i).elementAt(1) + ")");
		System.out.println();
	}
	
	public static void randomSeed() {
		Random random = new Random();
		float maxX = 0.0f;
		float maxY = 0.0f;
		float tempX, tempY;
		ArrayList<Float> randomList = new ArrayList<Float>();
		
		for(int i=0; i<points.size(); i++) {
			tempX = (float)points.get(i).elementAt(0);
			tempY = (float)points.get(i).elementAt(1);
			if(maxX < tempX) maxX = tempX;
			if(maxY < tempY) maxY = tempY;
		}
		
		maxDistance = (float)Math.pow(maxX, 2) + (float)Math.pow(maxY, 2);
		
		randomList.clear();
		seeds.clear();
		for(int i=0; i<nCluster; i++) {
			do {
				tempX = random.nextFloat() * maxX;
			}while(randomList.contains(tempX));
			
			do {
				tempY = random.nextFloat() * maxY;
			}while(randomList.contains(tempY));
			
			Vector<Float> seedPos = new Vector<Float>();
			seedPos.add(tempX);
			seedPos.add(tempY);
			
			seeds.add(seedPos);
		}
	}
	
	public static float calDistance(Vector<Object> a, Vector<Object> b) {
		float distanceX = (float)a.elementAt(0) - (float)b.elementAt(0);
		float distanceY = (float)a.elementAt(1) - (float)b.elementAt(1);
		
		return ((float)Math.pow(distanceX, 2) + (float)Math.pow(distanceY, 2));
	}
	
	public static void doClustering() {
		float minDistance = maxDistance;
		float temp;
		int cluster;
		
		for(int j=0; j<points.size(); j++) {
			cluster = -1;
			minDistance = maxDistance;
			for(int i=0; i<nCluster; i++) {
				temp = calDistance(points.get(j), seeds.get(i));
				if(minDistance > temp) {
					minDistance = temp;
					cluster = i;
				}
			}
			points.get(j).set(2, cluster);
		}
	}
	
	public static void newSeed() {
		int cnt = 0;
		float newX = 0.0f;
		float newY = 0.0f;
		
		for(int i=0; i<nCluster; i++) {
			cnt = 0;
			newX = 0.0f;
			newY = 0.0f;
			for(int j=0; j<points.size(); j++) {
				if((int)points.get(j).elementAt(2) == i) {
					newX += (float)points.get(j).elementAt(0);
					newY += (float)points.get(j).elementAt(1);
					cnt++;
				}
			}
			
			if(cnt == 0) cnt = 1;
			seeds.get(i).set(0, (newX / cnt));
			seeds.get(i).set(1, (newY / cnt));
		}
	}
	
	public static void setData(File data) {
		String oneline = "";
		float x, y;
		
		try {
			BufferedReader br = new BufferedReader(new FileReader(data));
			
			while((oneline = br.readLine()) != null) {
				StringTokenizer token = new StringTokenizer(oneline, "\t");
				
				x = Float.parseFloat(token.nextToken());
				token.hasMoreTokens();
				y = Float.parseFloat(token.nextToken());
				
				Vector<Object> v = new Vector<Object>();
				v.addElement(x);
				v.addElement(y);
				v.addElement(new Integer(-1));
				points.add(v);
			}
			
			br.close();
		} catch(FileNotFoundException e) {
			e.printStackTrace();
		} catch(IOException e) {
			e.printStackTrace();
		}
	}

	public static int unChanged(ArrayList<Vector> temp) {
		int result = 1;
		
		for(int i=0; i<nCluster; i++) {
			System.out.println(temp.get(i).elementAt(0) + ", " + temp.get(i).elementAt(1));
		}
		
		for(int i=0; i<nCluster; i++) {
			if((float)seeds.get(i).elementAt(0) != (float)temp.get(i).elementAt(0))
				result *= 0;
			if((float)seeds.get(i).elementAt(1) != (float)temp.get(i).elementAt(1))
				result *= 0;
		}
		return result;
	}
	
	public static void printPoints() {
		for(int i=0; i<points.size(); i++) {
			System.out.println(points.get(i).elementAt(0) + ", " + points.get(i).elementAt(1) + ", " + points.get(i).elementAt(2));
		}
	}
	
	public static void main(String[] args) {
		nCluster = Integer.parseInt(args[0]);
		float[] prevX = new float[nCluster];
		float[] prevY = new float[nCluster];
		int changed;
		String path = "C:\\Users\\Park\\workspace\\k-means\\bin\\";
		File data = new File(path + args[1]);
		String[] wfilename = new String[nCluster];
		
		setData(data);
		randomSeed();
		printSeed();
		
		while(true) {
			for(int i=0; i<nCluster; i++) {
				prevX[i] = (float)seeds.get(i).elementAt(0);
				prevY[i] = (float)seeds.get(i).elementAt(1);
			}
			
			doClustering();
			newSeed();
			printSeed();
			
			changed = 0;
			for(int i=0; i<nCluster; i++) {
				if(prevX[i] != (float)seeds.get(i).elementAt(0))
					changed += 1;
				if(prevY[i] != (float)seeds.get(i).elementAt(1))
					changed += 1;
			}
			
			if(changed == 0) break;
		}
		
		try{
			BufferedWriter bw[] = new BufferedWriter[nCluster];
			for(int i=0; i<nCluster; i++){
				wfilename[i] = "cluster" + (i+1) + ".txt";
				bw[i] = new BufferedWriter(new FileWriter(path+wfilename[i]));
				
				for(int j=0; j<points.size(); j++) {
					if((int)points.get(j).elementAt(2) == i) {
						bw[i].write(points.get(j).elementAt(0) + "\t" + points.get(j).elementAt(1));
						bw[i].newLine();
					}
				}
				
				bw[i].close();
			}
		} catch(IOException e) {
			e.printStackTrace();
		}
	}
}



References