Skip to content

Commit 3b7d08c

Browse files
crmymhcopybara-github
authored andcommitted
Update KMeans n_init to be backwards compatible with 1.0.2.
PiperOrigin-RevId: 744022645 Change-Id: Ief51379ef54fe943c6d7b6651fff73bd6cdb29ac
1 parent 2198dc3 commit 3b7d08c

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

qkeras/codebook.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ def activation_compression(model, compile_config, activation_indexes, bits,
6969
"""
7070
assert len(activation_indexes) > 0
7171
assert 0.0 < sample_size <= 1.0
72-
km_models = [KMeans(2**bits)] * len(activation_indexes)
72+
# n_init=10 maintains the same behavior as legacy versions of sklearn. This
73+
# was changed to "auto" in sklearn 1.4.
74+
km_models = [KMeans(2**bits, n_init=10)] * len(activation_indexes)
7375
cb_tables = [[]] * len(activation_indexes)
7476
models = []
7577
x = x_in = model.layers[0].output
@@ -139,7 +141,7 @@ def weight_compression(weights, bits, axis=0, quantizer=None):
139141
for i, w in tqdm(enumerate(np.split(weights, weights.shape[axis], axis))):
140142
original_shape = w.shape
141143
w = w.ravel()
142-
km = KMeans(n)
144+
km = KMeans(n, n_init=10)
143145
km.fit(w.reshape(-1, 1))
144146
if quantizer:
145147
km.cluster_centers_ = quantizer(km.cluster_centers_).numpy()
@@ -178,7 +180,7 @@ def two_tier_embedding_compression(embeddings, bits, quantizer=None):
178180
cluster_index_table = np.zeros(index_table.shape[0], dtype=np.uint8)
179181
codebook_table = np.zeros((n, n))
180182

181-
km1 = KMeans(n)
183+
km1 = KMeans(n, n_init=10)
182184
km1.fit(embeddings)
183185
tier1 = km1.predict(embeddings)
184186

@@ -188,7 +190,7 @@ def two_tier_embedding_compression(embeddings, bits, quantizer=None):
188190
mask = block_label == tier1
189191
indices = np.arange(embeddings.shape[0])[mask]
190192
block = embeddings[mask]
191-
km2 = KMeans(n)
193+
km2 = KMeans(n, n_init=10)
192194
km2.fit(block.flatten().reshape(-1, 1))
193195
if quantizer:
194196
km2.cluster_centers_ = quantizer(km2.cluster_centers_).numpy()

0 commit comments

Comments
 (0)