# Source code for jycm.km_matcher

# from: https://github.com/mayorx/hungarian-algorithm

import numpy as np


# max weight assignment
class KMMatcher:
    """Maximum-weight bipartite assignment (Kuhn-Munkres / Hungarian algorithm).

    Given an ``n x m`` weight matrix, every index on the smaller side is
    matched to a distinct index on the larger side so that the total weight of
    the chosen entries is maximal.

    Adapted from https://github.com/mayorx/hungarian-algorithm.
    """

    # Absolute tolerance below which a reduced cost
    # ``label_x[x] + label_y[y] - weights[x, y]`` is treated as zero, i.e. the
    # edge (x, y) is "tight". It is deliberately absolute (not relative to the
    # weight magnitude) so that large weights cannot make clearly non-tight
    # edges look tight; both phases of find_augment_path use the same test.
    _TIGHT_ATOL = 1e-8

    def __init__(self, weights):
        """Build a matcher for an ``n x m`` weight matrix.

        Args:
            weights: 2-D array-like of finite numbers. If it has more rows
                than columns it is transposed internally (``self.reverted``)
                and :meth:`solve` maps the result back to the caller's
                orientation. An empty sequence yields an empty assignment.

        Raises:
            ValueError: if ``weights`` is not 2-D or contains NaN/inf.
        """
        # float64 (not float32) so large or finely spaced weights keep their
        # precision during the label arithmetic.
        weights = np.array(weights, dtype=np.float64)
        if weights.ndim == 1 and weights.size == 0:
            # np.array([]) has shape (0,); treat it as a 0x0 matrix.
            weights = weights.reshape(0, 0)
        if weights.ndim != 2:
            raise ValueError(
                'weights must be a 2-D matrix, got shape {}'.format(weights.shape))
        if not np.all(np.isfinite(weights)):
            # NaN/inf would poison the labels and stall the search.
            raise ValueError('weights must be finite (no NaN or inf)')
        n, m = weights.shape
        self.reverted = False
        if n > m:
            # The algorithm needs n <= m; solve the transposed problem instead.
            self.reverted = True
            weights = weights.transpose()
            n, m = m, n
        self.weights = weights
        self.n, self.m = n, m
        assert self.n <= self.m
        # Initial feasible labelling: label_x[x] + label_y[y] >= weights[x, y].
        # When m == 0 then n == 0 as well, and np.max would fail on the empty
        # reduction axis, so fall back to an empty label vector.
        if m:
            self.label_x = np.max(weights, axis=1)
        else:
            self.label_x = np.zeros((n,), dtype=np.float64)
        self.label_y = np.zeros((self.m,), dtype=np.float64)
        self.max_match = 0
        # xy[x] / yx[y]: partner of row x / column y, or -1 while unmatched.
        self.xy = -np.ones((self.n,), dtype=int)
        self.yx = -np.ones((self.m,), dtype=int)

    def do_augment(self, x, y):
        """Flip the augmenting path that ends with the edge (x, y).

        Walks back through ``self.prev`` (filled by find_augment_path and
        add_to_tree) until the root sentinel -2, re-pairing each row with its
        new column. Increases ``max_match`` by one.
        """
        self.max_match += 1
        while x != -2:
            self.yx[y] = x
            ty = self.xy[x]
            self.xy[x] = y
            x, y = self.prev[x], ty

    def find_augment_path(self):
        """Grow an alternating tree from a free row until a free column is hit.

        Must only be called while some row is still unmatched.

        Returns:
            (x, y): a tree row ``x`` and an unmatched column ``y`` joined by a
            tight edge; passing them to do_augment enlarges the matching.
        """
        self.S = np.zeros((self.n,), bool)  # rows in the tree
        self.T = np.zeros((self.m,), bool)  # columns in the tree
        # slack[y] = min over tree rows x of label_x[x] + label_y[y] - w[x, y];
        # slackyx[y] is the row achieving that minimum.
        self.slackyx = -np.ones((self.m,), dtype=int)
        self.prev = -np.ones((self.n,), int)

        queue, st = [], 0
        root = -1
        for x in range(self.n):
            if self.xy[x] == -1:
                queue.append(x)
                root = x
                self.prev[x] = -2  # sentinel that stops do_augment's walk
                self.S[x] = True
                break

        self.slack = self.label_y + self.label_x[root] - self.weights[root]
        self.slackyx[:] = root

        while True:
            # Phase 1: breadth-first search along tight edges from queued rows.
            while st < len(queue):
                x = queue[st]
                st += 1
                reduced = self.label_x[x] + self.label_y - self.weights[x]
                is_in_graph = np.abs(reduced) <= self._TIGHT_ATOL
                nonzero_inds = np.nonzero(
                    np.logical_and(is_in_graph, np.logical_not(self.T)))[0]
                for y in nonzero_inds:
                    if self.yx[y] == -1:
                        return x, y
                    self.T[y] = True
                    queue.append(self.yx[y])
                    self.add_to_tree(self.yx[y], x)

            # Phase 2: no tight edge leaves the tree. Relax the labels so that
            # at least one column outside T gets zero slack, then pick those up.
            self.update_labels()
            queue, st = [], 0
            is_in_graph = np.abs(self.slack) <= self._TIGHT_ATOL
            nonzero_inds = np.nonzero(
                np.logical_and(is_in_graph, np.logical_not(self.T)))[0]
            for y in nonzero_inds:
                x = self.slackyx[y]
                if self.yx[y] == -1:
                    return x, y
                self.T[y] = True
                if not self.S[self.yx[y]]:
                    # Queue the newly reached row (not x, which has already
                    # been scanned) so its own tight edges are explored next.
                    queue.append(self.yx[y])
                    self.add_to_tree(self.yx[y], x)

    def solve(self, verbose=False):
        """Compute a maximum-weight assignment.

        Args:
            verbose: if true, print each matched pair (in the internal,
                possibly transposed, orientation) and the total weight.

        Returns:
            (total, pairs): the total weight and a list of ``(row, column)``
            index pairs in the orientation of the matrix given to __init__.
        """
        while self.max_match < self.n:
            x, y = self.find_augment_path()
            self.do_augment(x, y)

        sum_ = 0.
        pairs = []
        for x in range(self.n):
            if verbose:
                print('match {} to {}, weight {:.4f}'.format(x, self.xy[x], self.weights[x, self.xy[x]]))
            pairs.append((x, self.xy[x]))
            sum_ += self.weights[x, self.xy[x]]
        self.best = sum_
        if verbose:
            print('ans: {:.4f}'.format(sum_))
        if self.reverted:
            return sum_, [(y, x) for x, y in pairs]
        return sum_, pairs

    def add_to_tree(self, x, prevx):
        """Add row ``x`` to the tree (reached from row ``prevx``).

        Also lowers ``slack``/``slackyx`` wherever ``x`` offers a smaller
        reduced cost than the current tree rows.
        """
        self.S[x] = True
        self.prev[x] = prevx
        reduced = self.label_x[x] + self.label_y - self.weights[x]
        better_slack_idx = reduced < self.slack
        self.slack[better_slack_idx] = reduced[better_slack_idx]
        self.slackyx[better_slack_idx] = x

    def update_labels(self):
        """Shift the labels by the smallest slack outside the tree.

        Rows in S lose ``delta`` and columns in T gain it, so tree edges stay
        tight. Slacks of columns outside T shrink by ``delta``, which makes
        the minimal one exactly zero.
        """
        outside_t = np.logical_not(self.T)
        delta = self.slack[outside_t].min()
        self.label_x[self.S] -= delta
        self.label_y[self.T] += delta
        self.slack[outside_t] -= delta
if __name__ == '__main__':
    # Demo: one wide matrix (3x4) and one tall matrix (4x3, solved transposed).
    _demo_inputs = (
        [
            [2., 3., 0., 3.],
            [0., 4., 0., 100.],
            [5., 6., 0., 0.],
        ],
        [
            [2., 0., 5., ],
            [3., 4., 6., ],
            [0., 0., 0., ],
            [3., 100., 0., ],
        ],
    )
    for _weights in _demo_inputs:
        matcher = KMMatcher(_weights)
        _, _pairs = matcher.solve(verbose=True)
        print(_pairs)