-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
145 lines (117 loc) · 5.96 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""
This script demonstrates how to use MINLP solvers to solve common computer vision problems.
Here, we model the Maximum Consensus (MC) and Truncated Least Squares (TLS) problems using Pyomo, either as
mixed-integer (MILP/MINLP) programs or, for the TLS variants, as purely continuous nonlinear programs.
We can then evaluate globally optimal branch-and-bound solvers, SCIP by default.
The instances of the optimization problems are written to files in GAMS (.gms) and AMPL (.nl) format, so that
more solvers can be evaluated. For example, these files can be uploaded to the NEOS server to evaluate
commercial solvers such as BARON.
"""
import numpy as np
from generate_synthetic_instances import create_point_cloud_registration_problem, RegistrationProblemParams
import fire
from functools import partial
from pyomo.environ import *
def create_instance(rotation, *, num_points=100, epsilon=0.5, outlier_rate=0.1,
                    data_scale=10.0, adversarial_suboptimality=0.0):
    """Generate a synthetic 3-D point-cloud registration instance.

    The keyword arguments default to the values this script has always used,
    so ``create_instance(rotation)`` behaves as before; they can be overridden
    to produce larger/smaller or noisier instances.

    Args:
        rotation: If true, the instance is generated with a rotation and no
            translation; otherwise with a translation and no rotation.
        num_points: Number of point correspondences N.
        epsilon: Threshold returned alongside the data; the problem builders
            use it as the inlier/truncation threshold.
        outlier_rate: Forwarded to ``RegistrationProblemParams``
            (presumably the fraction of outlier correspondences — confirm).
        data_scale: Forwarded to ``RegistrationProblemParams``.
        adversarial_suboptimality: Forwarded to ``RegistrationProblemParams``.

    Returns:
        ``(P, Q, epsilon)`` when ``rotation`` is true, otherwise
        ``(P - Q, epsilon)``. The difference is the data consumed by the
        translation-only problems.
    """
    # Positional order must match RegistrationProblemParams' signature.
    problem_params = RegistrationProblemParams(
        num_points, adversarial_suboptimality, epsilon, outlier_rate, data_scale)
    instance = create_point_cloud_registration_problem(
        problem_params, dims=3, add_rotation=rotation, add_translation=not rotation)
    if rotation:
        return instance.P, instance.Q, epsilon
    return instance.P - instance.Q, epsilon
def create_max_consensus_problem(y, epsilon):
    """Create the maximum consensus (MC) problem with binary inlier indicators.

    The formulation follows the consensus-maximization idea of
    "Consensus Set Maximization with Guaranteed Global Optimality for Robust
    Geometry Estimation", Hongdong Li, ICCV 2009. It finds x in R^d that
    maximizes the number of rows i whose per-coordinate residuals
    |y[i, j] - x[j]| are all within ``epsilon``.

    Note: each constraint multiplies the binary z_i by an absolute value, so
    the model written here is a nonsmooth mixed-integer *nonlinear* program,
    not a binary linear program. A big-M reformulation would be needed to
    obtain a pure MILP. For z_i = 1 a constraint enforces
    |y[i, j] - x[j]| <= epsilon; for z_i = 0 it reduces to 0 <= 0.

    Args:
        y: 2-D array of shape (N, d); row i is the i-th measurement.
        epsilon: Inlier threshold on each coordinate's absolute residual.

    Returns:
        The Pyomo ``ConcreteModel``. As a side effect, the model is also
        written to ``maximum_consensus_{d}d_model_binary.nl`` (AMPL) and
        ``maximum_consensus_{d}d_model_binary.gms`` (GAMS), relative to the
        current working directory.
    """
    num_points, dims = y.shape[0], y.shape[1]
    model = ConcreteModel(name="ConsensusMaximizationBinary")
    model.x = Var(range(dims))
    model.z = Var(range(num_points), within=Binary)
    model.constraints = ConstraintList()
    for i in range(num_points):
        for j in range(dims):
            # z_i switches the inlier constraint on (z_i = 1) or off (z_i = 0).
            model.constraints.add(model.z[i] * abs(y[i, j] - model.x[j]) <= model.z[i] * epsilon)
    model.objective = Objective(expr=sum(model.z[i] for i in range(num_points)), sense=maximize)
    for extension in ("nl", "gms"):
        # A fresh options dict per call, so no writer can see another's mutations.
        model.write(f"maximum_consensus_{dims}d_model_binary.{extension}",
                    io_options={"symbolic_solver_labels": True})
    return model
def create_tls_translation_problem(y, epsilon):
    """Create the 1-D Truncated Least Squares (TLS) translation estimation problem.

    Only the first column of ``y`` is used: a scalar translation x is estimated
    by minimizing sum_i min((y[i, 0] - x)^2, epsilon^2). The truncation is
    written as the difference of convex functions r - max(r - epsilon^2, 0),
    so the model needs no binary variables.

    Args:
        y: 2-D array whose first column holds the 1-D measurements.
        epsilon: Truncation threshold; squared residuals are capped at epsilon**2.

    Returns:
        The Pyomo ``ConcreteModel``. As a side effect, it is also written to
        ``tls_1d_dc_smu.nl`` (AMPL) and ``tls_1d_dc_smu.gms`` (GAMS).
    """
    num_points = y.shape[0]
    model = ConcreteModel(name="TLS-SMU")
    model.x = Var(range(1))

    def relu(v):
        # Exact max(v, 0) written with abs(). This is the non-smoothed limit of
        # the SMU approximation (https://arxiv.org/pdf/2111.04682); it is not
        # itself smooth, since abs() has a kink at 0.
        return (v + abs(v)) / 2

    def tls_dc(r):
        # min(r, epsilon^2) as a difference of convex functions.
        return r - relu(r - epsilon**2)

    def objective(m):
        return sum(tls_dc((y[i, 0] - m.x[0])**2) for i in range(num_points))

    model.objective = Objective(rule=objective, sense=minimize)
    for extension in ("nl", "gms"):
        model.write(f"tls_1d_dc_smu.{extension}", io_options={"symbolic_solver_labels": True})
    return model
def quat_mul(a, b):
    """Return the Hamilton product a * b of two quaternions stored as [w, x, y, z].

    Both operands are only ever indexed (never iterated), so any object that
    supports integer indexing 0..3 works. The result is a new 4-element list.
    Reference: Eigen's Quaternion implementation,
    https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Geometry/Quaternion.h
    """
    aw, ax, ay, az = a[0], a[1], a[2], a[3]
    bw, bx, by, bz = b[0], b[1], b[2], b[3]
    w = aw * bw - ax * bx - ay * by - az * bz
    x = aw * bx + ax * bw + ay * bz - az * by
    y = aw * by + ay * bw + az * bx - ax * bz
    z = aw * bz + az * bw + ax * by - ay * bx
    return [w, x, y, z]
def inv_quat(q):
    """Return the conjugate [w, -x, -y, -z] of the quaternion q = [w, x, y, z].

    The conjugate equals the inverse only for unit quaternions; a non-unit
    quaternion would additionally need division by its squared norm.
    """
    w = q[0]
    return [w] + [-q[k] for k in (1, 2, 3)]
def point_to_quat(a):
    """Embed the 3-D point a as the pure quaternion [0, a[0], a[1], a[2]]."""
    xyz = [a[k] for k in range(3)]
    return [0] + xyz
def quat_to_point(a):
    """Return the vector part [x, y, z] of a quaternion stored as [w, x, y, z]."""
    return [a[k] for k in range(1, 4)]
def rotate_point_by_quat(q, p):
    """Rotate the 3-D point p by the quaternion q, computing q * (0, p) * conj(q).

    The result is a proper rotation of p only when q has unit norm, because
    inv_quat returns the conjugate rather than the true inverse.
    """
    embedded = point_to_quat(p)
    left = quat_mul(q, embedded)
    rotated = quat_mul(left, inv_quat(q))
    return quat_to_point(rotated)
def sq_norm(a):
    """Return the squared Euclidean norm a[0]**2 + a[1]**2 + a[2]**2 of a 3-vector."""
    x, y, z = a[0], a[1], a[2]
    return x**2 + y**2 + z**2
def sq_norm_q(a):
    """Return the squared norm w**2 + x**2 + y**2 + z**2 of a quaternion [w, x, y, z]."""
    w, x, y, z = a[0], a[1], a[2], a[3]
    return w**2 + x**2 + y**2 + z**2
def create_tls_rotation_problem(P, Q, epsilon):
    """Create the TLS rotation-estimation problem between two 3-D point clouds.

    The rotation is parametrized by a quaternion x = [w, x, y, z] constrained
    to unit norm. The cost is the Truncated Least Squares (TLS) objective from
    "TEASER", H. Yang et al., T-RO 2020:

        sum_i min(||Q[i] - R(x) P[i]||^2, epsilon^2)

    Here the min is written as the difference of convex functions
    r - max(r - epsilon^2, 0).

    Args:
        P: Source points; P[i] is the i-th point (3 coordinates).
        Q: Target points; must have the same shape as P.
        epsilon: Truncation threshold; squared residuals are capped at epsilon**2.

    Returns:
        The Pyomo ``ConcreteModel``. As a side effect, it is also written to
        ``tls_rotation_quat.nl`` (AMPL) and ``tls_rotation_quat.gms`` (GAMS).

    Raises:
        ValueError: If P and Q do not have the same shape.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if P.shape != Q.shape:
        raise ValueError("Both point clouds must have the same shape")
    num_points = P.shape[0]
    model = ConcreteModel(name="TLSRotationQuat")
    # The quaternion. The unit-norm constraint below already implies that every
    # component lies in [-1, 1]. Stating the bounds explicitly leaves the
    # feasible set unchanged but gives spatial branch-and-bound solvers finite
    # variable domains.
    model.x = Var(range(4), bounds=(-1, 1))

    def relu(v):
        # Exact max(v, 0) via abs(). This is the non-smoothed limit of SMU
        # (https://arxiv.org/pdf/2111.04682), so it has a kink at 0.
        return (v + abs(v)) / 2

    def tls_dc(r):
        # min(r, epsilon^2) as a difference of convex functions.
        return r - relu(r - epsilon**2)

    model.constraints = ConstraintList()
    # Unit-norm constraint: only unit quaternions represent rotations, and the
    # rotation helper uses the conjugate as the inverse.
    model.constraints.add(sq_norm_q(model.x) == 1)

    def objective(m):
        # Q[i] is expected to be a NumPy row, so subtracting the list returned
        # by rotate_point_by_quat is elementwise.
        return sum(tls_dc(sq_norm(Q[i] - rotate_point_by_quat(m.x, P[i])))
                   for i in range(num_points))

    model.objective = Objective(rule=objective, sense=minimize)
    for extension in ("nl", "gms"):
        model.write(f"tls_rotation_quat.{extension}", io_options={"symbolic_solver_labels": True})
    return model
def solve_with_pyomo(model, solver_name="scip", tee=True):
    """Solve ``model`` with a Pyomo solver and display the values of ``model.x``.

    Args:
        model: A Pyomo model that has a variable component named ``x``.
        solver_name: Solver name passed to ``SolverFactory``; defaults to SCIP.
        tee: If true, stream the solver's log to stdout.

    Returns:
        The results object returned by ``solver.solve``. Callers can inspect
        the termination status (e.g. ``results.solver.termination_condition``)
        before trusting the displayed values.
    """
    solver = SolverFactory(solver_name)
    results = solver.solve(model, tee=tee)
    # Display the solution
    model.x.display()
    return results
def solve_problem(problem: str):
    """Build the requested optimization problem and solve it with Pyomo.

    Args:
        problem: One of ``"maximum_consensus"``, ``"tls_translation"`` or
            ``"tls_rotation"``. For an unknown name, the available choices are
            printed and the function returns without generating any data.
    """
    # Each entry is a zero-argument builder. The synthetic instance is created
    # only inside the builder that is actually called. Unpacking
    # create_instance(...) directly in the dict literal would generate all
    # three instances on every call and discard two of them.
    problem_factory = {
        "maximum_consensus": lambda: create_max_consensus_problem(*create_instance(rotation=False)),
        "tls_translation": lambda: create_tls_translation_problem(*create_instance(rotation=False)),
        "tls_rotation": lambda: create_tls_rotation_problem(*create_instance(rotation=True)),
    }
    if problem not in problem_factory:
        print(f"Unknown problem: '{problem}', available ones: {list(problem_factory.keys())}")
        return
    solve_with_pyomo(problem_factory[problem]())
if __name__ == "__main__":
    # Command-line entry point via python-fire, e.g.
    # `python main.py solve_problem tls_rotation`.
    # NOTE(review): Fire() is called without a component. Per the python-fire
    # docs this exposes the calling module's global names as commands,
    # including everything pulled in by `from pyomo.environ import *`. Keep
    # the call at module level (it depends on the caller's frame), and
    # consider passing an explicit component listing the intended commands.
    fire.Fire()