k-means clustering

Given a set of observations (x1, x2, …, xn), where each observation is a d-dimensional real vector, k-means clustering aims to partition the n observations into k sets (k ≤ n) S = {S1, S2, …, Sk} so as to minimize the within-cluster sum of squares (WCSS):

    arg min over S of  Σ_{i=1}^{k} Σ_{x ∈ Si} ‖x − μi‖²

where μi is the mean of points in Si.



Algorithm

The most common algorithm uses an iterative refinement technique. Due to its ubiquity, it is often called the k-means algorithm; it is also referred to as Lloyd's algorithm, particularly in the computer science community.

Given an initial set of k means, the algorithm proceeds by alternating between two steps:


  • Assignment step
    Assign each observation to the cluster whose mean is closest to it (i.e. partition the observations according to the Voronoi diagram generated by the means).
  • Update step
    Calculate the new means to be the centroids of the observations in the new clusters.


The algorithm has converged when the assignments no longer change.

The "assignment" step is also referred to as expectation step, the "update step" as maximization step, making this algorithm a variant of the generalized expectation-maximization algorithm.



Demonstration of the standard algorithm

(1) k initial "means" (in this case k = 3) are randomly generated within the data domain.

(2) k clusters are created by associating every observation with the nearest mean. The partitions here represent the Voronoi diagram generated by the means.

(3) The centroid of each of the k clusters becomes the new mean.

(4) Steps 2 and 3 are repeated until convergence has been reached.



Example (k = 4)



Weak points of k-means

  • It is not well suited to finding clusters of very different sizes.
  • A few points that lie far from the rest pull the mean toward them, so it is sensitive to noise and outliers (for example, a single point at 100 drags the mean of {1, 2, 3} from 2 to 26.5).
  • The value of k must be chosen in advance.



Variations
  • Fuzzy C-means clustering is a soft version of k-means, where each data point has a fuzzy degree of belonging to each cluster.
  • Gaussian mixture models trained with the expectation-maximization algorithm (EM algorithm) maintain probabilistic assignments to clusters, instead of deterministic assignments, and multivariate Gaussian distributions instead of means.
  • Several methods have been proposed to choose better starting clusters. One recent proposal is k-means++ (see the seeding sketch after this list).
  • The filtering algorithm uses kd-trees to speed up each k-means step.
  • Some methods attempt to speed up each k-means step using coresets or the triangle inequality.
  • Escape local optima by swapping points between clusters.
  • The Spherical k-means clustering algorithm is suitable for directional data.
  • The Minkowski metric weighted k-means deals with the problem of noise features by assigning weights to each feature per cluster.
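
As a sketch of the k-means++ idea mentioned above: the first centre is drawn uniformly from the data, and each later centre is drawn with probability proportional to its squared distance from the nearest centre chosen so far. The class name KMeansPlusPlusSeed and the double[][] point representation are assumptions made for this illustration only.

import java.util.Random;

// Sketch of k-means++ seeding: later centres are biased towards points far from the existing ones.
class KMeansPlusPlusSeed {
    static double[][] seed(double[][] pts, int k, Random rnd) {
        double[][] centres = new double[k][];
        centres[0] = pts[rnd.nextInt(pts.length)];     // first centre: uniform at random
        double[] d2 = new double[pts.length];          // squared distance to the nearest chosen centre
        for (int c = 1; c < k; c++) {
            double total = 0.0;
            for (int p = 0; p < pts.length; p++) {
                d2[p] = Double.MAX_VALUE;
                for (int j = 0; j < c; j++) {
                    double dx = pts[p][0] - centres[j][0];
                    double dy = pts[p][1] - centres[j][1];
                    d2[p] = Math.min(d2[p], dx * dx + dy * dy);
                }
                total += d2[p];
            }
            // draw the next centre with probability proportional to d2
            double r = rnd.nextDouble() * total;
            int pick = 0;
            double acc = d2[0];
            while (acc < r && pick < pts.length - 1) {
                pick++;
                acc += d2[pick];
            }
            centres[c] = pts[pick];
        }
        return centres;
    }
}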



Code
import java.util.*;
import java.math.*;
import java.io.*;
 
public class kmeans {
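    // points : input observations, each stored as a Vector of (x, y, clusterIndex)
    // seeds  : current cluster centres, each stored as a Vector of (x, y)
    // nCluster : the number of clusters k
    // maxDistance : upper bound on any squared distance, used to initialise nearest-seed searches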
    static ArrayList<Vector> points = new ArrayList<Vector>();
    static ArrayList<Vector> seeds = new ArrayList<Vector>();
    static int nCluster;
    static float maxDistance;
     
    public static void printSeed() {
        for(int i=0; i<nCluster; i++)
            System.out.println("Seed #" + i + " = (" + seeds.get(i).elementAt(0) + ", " + seeds.get(i).elementAt(1) + ")");
        System.out.println();
    }
     
    public static void randomSeed() {
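        // place nCluster random seeds inside the bounding box of the data;
        // maxDistance is an upper bound used as "infinity" in the nearest-seed search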
        Random random = new Random();
        float maxX = 0.0f;
        float maxY = 0.0f;
        float tempX, tempY;
        ArrayList<Float> randomList = new ArrayList<Float>();
         
        for(int i=0; i<points.size(); i++) {
            tempX = (float)points.get(i).elementAt(0);
            tempY = (float)points.get(i).elementAt(1);
            if(maxX < tempX) maxX = tempX;
            if(maxY < tempY) maxY = tempY;
        }
         
        maxDistance = (float)Math.pow(maxX, 2) + (float)Math.pow(maxY, 2);
         
        randomList.clear();
        seeds.clear();
        for(int i=0; i<nCluster; i++) {
            // keep drawing until the coordinate has not been used before, so no two seeds coincide
            do {
                tempX = random.nextFloat() * maxX;
            } while(randomList.contains(tempX));
            randomList.add(tempX);
             
            do {
                tempY = random.nextFloat() * maxY;
            } while(randomList.contains(tempY));
            randomList.add(tempY);
             
            Vector<Float> seedPos = new Vector<Float>();
            seedPos.add(tempX);
            seedPos.add(tempY);
             
            seeds.add(seedPos);
        }
    }
     
    public static float calDistance(Vector<Object> a, Vector<Object> b) {
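        // squared Euclidean distance between the 2-D points a and b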
        float distanceX = (float)a.elementAt(0) - (float)b.elementAt(0);
        float distanceY = (float)a.elementAt(1) - (float)b.elementAt(1);
         
        return ((float)Math.pow(distanceX, 2) + (float)Math.pow(distanceY, 2));
    }
     
    public static void doClustering() {
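        // assignment step: label every point with the index of its nearest seed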
        float minDistance = maxDistance;
        float temp;
        int cluster;
         
        for(int j=0; j<points.size(); j++) {
            cluster = -1;
            minDistance = maxDistance;
            for(int i=0; i<nCluster; i++) {
                temp = calDistance(points.get(j), seeds.get(i));
                if(minDistance > temp) {
                    minDistance = temp;
                    cluster = i;
                }
            }
            points.get(j).set(2, cluster);
        }
    }
     
    public static void newSeed() {
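        // update step: move each seed to the centroid of the points currently assigned to it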
        int cnt = 0;
        float newX = 0.0f;
        float newY = 0.0f;
         
        for(int i=0; i<nCluster; i++) {
            cnt = 0;
            newX = 0.0f;
            newY = 0.0f;
            for(int j=0; j<points.size(); j++) {
                if((int)points.get(j).elementAt(2) == i) {
                    newX += (float)points.get(j).elementAt(0);
                    newY += (float)points.get(j).elementAt(1);
                    cnt++;
                }
            }
             
            if(cnt == 0) continue;   // empty cluster: keep its previous seed instead of collapsing it to the origin
            seeds.get(i).set(0, (newX / cnt));
            seeds.get(i).set(1, (newY / cnt));
        }
    }
     
    public static void setData(File data) {
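        // read tab-separated "x<TAB>y" lines from the data file into the points list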
        String oneline = "";
        float x, y;
         
        try {
            BufferedReader br = new BufferedReader(new FileReader(data));
             
            while((oneline = br.readLine()) != null) {
                StringTokenizer token = new StringTokenizer(oneline, "\t");
                 
                x = Float.parseFloat(token.nextToken());
                y = Float.parseFloat(token.nextToken());
                 
                // store the point as (x, y, clusterIndex); -1 marks "not yet assigned"
                Vector<Object> v = new Vector<Object>();
                v.addElement(x);
                v.addElement(y);
                v.addElement(Integer.valueOf(-1));
                points.add(v);
            }
             
            br.close();
        } catch(FileNotFoundException e) {
            e.printStackTrace();
        } catch(IOException e) {
            e.printStackTrace();
        }
    }
 
    public static int unChanged(ArrayList<Vector> temp) {
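        // returns 1 when the current seeds match the given snapshot, 0 otherwise
        // (main() performs its own convergence check and does not call this helper)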
        int result = 1;
         
        for(int i=0; i<nCluster; i++) {
            System.out.println(temp.get(i).elementAt(0) + ", " + temp.get(i).elementAt(1));
        }
         
        for(int i=0; i<nCluster; i++) {
            if((float)seeds.get(i).elementAt(0) != (float)temp.get(i).elementAt(0))
                result *= 0;
            if((float)seeds.get(i).elementAt(1) != (float)temp.get(i).elementAt(1))
                result *= 0;
        }
        return result;
    }
     
    public static void printPoints() {
        for(int i=0; i<points.size(); i++) {
            System.out.println(points.get(i).elementAt(0) + ", " + points.get(i).elementAt(1) + ", " + points.get(i).elementAt(2));
        }
    }
     
    public static void main(String[] args) {
        nCluster = Integer.parseInt(args[0]);
        float[] prevX = new float[nCluster];
        float[] prevY = new float[nCluster];
        int changed;
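        // NOTE: the data file is read from, and the cluster files are written to, this hard-coded directory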
        String path = "C:\\Users\\Park\\workspace\\k-means\\bin\\";
        File data = new File(path + args[1]);
        String[] wfilename = new String[nCluster];
         
        setData(data);
        randomSeed();
        printSeed();
         
        while(true) {
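            // alternate assignment and update steps until no seed moves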
            for(int i=0; i<nCluster; i++) {
                prevX[i] = (float)seeds.get(i).elementAt(0);
                prevY[i] = (float)seeds.get(i).elementAt(1);
            }
             
            doClustering();
            newSeed();
            printSeed();
             
            changed = 0;
            for(int i=0; i<nCluster; i++) {
                if(prevX[i] != (float)seeds.get(i).elementAt(0))
                    changed += 1;
                if(prevY[i] != (float)seeds.get(i).elementAt(1))
                    changed += 1;
            }
             
            if(changed == 0) break;
        }
         
        try{
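            // write the points of each cluster to its own file, cluster1.txt ... clusterN.txt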
            BufferedWriter bw[] = new BufferedWriter[nCluster];
            for(int i=0; i<nCluster; i++){
                wfilename[i] = "cluster" + (i+1) + ".txt";
                bw[i] = new BufferedWriter(new FileWriter(path+wfilename[i]));
                 
                for(int j=0; j<points.size(); j++) {
                    if((int)points.get(j).elementAt(2) == i) {
                        bw[i].write(points.get(j).elementAt(0) + "\t" + points.get(j).elementAt(1));
                        bw[i].newLine();
                    }
                }
                 
                bw[i].close();
            }
        } catch(IOException e) {
            e.printStackTrace();
        }
    }
}
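
To try the program above, compile it and pass the number of clusters and the data file name on the command line, e.g. java kmeans 4 data.txt (the file name data.txt is only an example). The data file is expected to contain one tab-separated "x<TAB>y" pair per line; it is looked up in the hard-coded directory in main(), and the resulting clusters are written back to that directory as cluster1.txt … clusterN.txt.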


