001package org.clafer.ir.analysis;
002
003import org.clafer.ir.IrBoolExpr;
004import org.clafer.ir.IrCompare;
005import static org.clafer.ir.IrCompare.Op.Equal;
006import static org.clafer.ir.IrCompare.Op.NotEqual;
007import org.clafer.ir.IrDomain;
008import org.clafer.ir.IrImplies;
009import org.clafer.ir.IrIntExpr;
010import org.clafer.ir.IrLone;
011import org.clafer.ir.IrModule;
012import org.clafer.ir.IrOffset;
013import org.clafer.ir.IrOr;
014import org.clafer.ir.IrRewriter;
015import org.clafer.ir.IrSetExpr;
016import org.clafer.ir.IrUtil;
017import static org.clafer.ir.Irs.*;
018
019/**
020 *
021 * @author jimmy
022 */
023public class Optimizer {
024
025    private Optimizer() {
026    }
027
028    /**
029     * Optimize the module.
030     *
031     * @param module the module to optimize
032     * @return the optimized module
033     */
034    public static IrModule optimize(IrModule module) {
035        return optimizer.rewrite(module, null);
036    }
037    private static final IrRewriter<Void> optimizer = new IrRewriter<Void>() {
038        @Override
039        public IrBoolExpr visit(IrLone ir, Void a) {
040            IrBoolExpr[] operands = rewrite(ir.getOperands(), a);
041            if (operands.length == 2) {
042                if (operands[0] instanceof IrCompare) {
043                    IrBoolExpr antecedent = operands[1];
044                    IrCompare compare = (IrCompare) operands[0];
045                    if (compare.getOp().isEquality()) {
046                        IrBoolExpr opt = optimizeLoneCompare(antecedent, compare.getLeft(), compare.getOp(), compare.getRight());
047                        if (opt == null) {
048                            opt = optimizeLoneCompare(antecedent, compare.getRight(), compare.getOp(), compare.getLeft());
049                        }
050                        if (opt != null) {
051                            return opt;
052                        }
053                    }
054                }
055                if (operands[1] instanceof IrCompare) {
056                    IrBoolExpr antecedent = operands[0];
057                    IrCompare compare = (IrCompare) operands[1];
058                    if (compare.getOp().isEquality()) {
059                        IrBoolExpr opt = optimizeLoneCompare(antecedent, compare.getLeft(), compare.getOp(), compare.getRight());
060                        if (opt == null) {
061                            opt = optimizeLoneCompare(antecedent, compare.getRight(), compare.getOp(), compare.getLeft());
062                        }
063                        if (opt != null) {
064                            return opt;
065                        }
066                    }
067                }
068            }
069            return changed(ir.getOperands(), operands)
070                    ? lone(operands)
071                    : ir;
072        }
073
074        @Override
075        public IrBoolExpr visit(IrOr ir, Void a) {
076            IrBoolExpr[] operands = rewrite(ir.getOperands(), a);
077            if (operands.length == 2) {
078                if (operands[0] instanceof IrCompare) {
079                    IrBoolExpr antecedent = operands[1];
080                    IrCompare compare = (IrCompare) operands[0];
081                    if (compare.getOp().isEquality()) {
082                        IrBoolExpr opt = optimizeOrCompare(antecedent, compare.getLeft(), compare.getOp(), compare.getRight());
083                        if (opt == null) {
084                            opt = optimizeOrCompare(antecedent, compare.getRight(), compare.getOp(), compare.getLeft());
085                        }
086                        if (opt != null) {
087                            return opt;
088                        }
089                    }
090                }
091                if (operands[1] instanceof IrCompare) {
092                    IrBoolExpr antecedent = operands[0];
093                    IrCompare compare = (IrCompare) operands[1];
094                    if (compare.getOp().isEquality()) {
095                        IrBoolExpr opt = optimizeOrCompare(antecedent, compare.getLeft(), compare.getOp(), compare.getRight());
096                        if (opt == null) {
097                            opt = optimizeOrCompare(antecedent, compare.getRight(), compare.getOp(), compare.getLeft());
098                        }
099                        if (opt != null) {
100                            return opt;
101                        }
102                    }
103                }
104            }
105            return changed(ir.getOperands(), operands)
106                    ? or(operands)
107                    : ir;
108        }
109
110        @Override
111        public IrBoolExpr visit(IrImplies ir, Void a) {
112            // Rewrite
113            //     !a => !b
114            // to
115            //     b => a
116            if (ir.getAntecedent().isNegative() && ir.getConsequent().isNegative()) {
117                return rewrite(implies(not(ir.getConsequent()), not(ir.getAntecedent())), a);
118            }
119            // Rewrite
120            //     !a => b
121            // to
122            //     a or b
123            if (ir.getAntecedent().isNegative()) {
124                return rewrite(or(not(ir.getAntecedent()), ir.getConsequent()), a);
125            }
126            // Rewrite
127            //     a => !b
128            // to
129            //     a + b <= 1
130            if (ir.getConsequent().isNegative()) {
131                return rewrite(lone(ir.getAntecedent(), not(ir.getConsequent())), a);
132            }
133            IrBoolExpr antecedent = rewrite(ir.getAntecedent(), a);
134            IrBoolExpr consequent = rewrite(ir.getConsequent(), a);
135            if (consequent instanceof IrCompare) {
136                IrCompare compare = (IrCompare) consequent;
137                if (compare.getOp().isEquality()) {
138                    IrBoolExpr opt = optimizeImplicationCompare(antecedent, compare.getLeft(), compare.getOp(), compare.getRight());
139                    if (opt == null) {
140                        opt = optimizeImplicationCompare(antecedent, compare.getRight(), compare.getOp(), compare.getLeft());
141                    }
142                    if (opt != null) {
143                        return opt;
144                    }
145                }
146            }
147            return changed(ir.getAntecedent(), antecedent)
148                    || changed(ir.getConsequent(), consequent)
149                    ? implies(antecedent, consequent)
150                    : ir;
151        }
152
153        @Override
154        public IrSetExpr visit(IrOffset ir, Void a) {
155            if (ir.getSet() instanceof IrOffset) {
156                // Rewrite
157                //    offset(offset(set, a), b)
158                // to
159                //    offset(set, a + b)
160                // This optimization is important for going multiple steps up the
161                // hierarchy.
162                IrOffset innerOffset = (IrOffset) ir.getSet();
163                return rewrite(offset(innerOffset.getSet(),
164                        ir.getOffset() + innerOffset.getOffset()), a);
165            }
166            return super.visit(ir, a);
167        }
168    };
169
170    /**
171     * Optimize {@code lone(antecedent, left `op` right)} where `op` is = or !=.
172     */
173    private static IrBoolExpr optimizeLoneCompare(IrBoolExpr antecedent, IrIntExpr left, IrCompare.Op op, IrIntExpr right) {
174        IrDomain domain = left.getDomain();
175        Integer constant = IrUtil.getConstant(right);
176        if (domain.size() == 2 && constant != null) {
177            switch (op) {
178                case Equal:
179                    // Rewrite
180                    //     lone(bool, int = 888)
181                    //         where dom(int) = {-3, 888}
182                    // to
183                    //     asInt(bool) <= 888 - int
184                    //     asInt(bool) + int <= 888
185                    if (domain.getHighBound() == constant.intValue()) {
186                        return lessThanEqual(add(antecedent, left),
187                                domain.getHighBound());
188                    }
189                    // Rewrite
190                    //     lone(bool, int = -3)
191                    //         where dom(int) = {-3, 888}
192                    // to
193                    //     asInt(bool) <= int - (-3)
194                    if (domain.getLowBound() == constant.intValue()) {
195                        return lessThanEqual(antecedent,
196                                sub(left, domain.getLowBound()));
197                    }
198                    break;
199                case NotEqual:
200                    // Rewrite
201                    //     lone(bool, int != 888)
202                    //         where dom(int) = {-3, 888}
203                    // to
204                    //     asInt(bool) <= int - (-3)
205                    if (domain.getHighBound() == constant.intValue()) {
206                        return lessThanEqual(antecedent,
207                                sub(left, domain.getLowBound()));
208                    }
209                    // Rewrite
210                    //     lone(bool, int != -3)
211                    //         where dom(int) = {-3, 888}
212                    // to
213                    //     asInt(bool) <= 888 - int
214                    //     asInt(bool) + int <= 888
215                    if (domain.getLowBound() == constant.intValue()) {
216                        return lessThanEqual(add(antecedent, left),
217                                domain.getHighBound());
218                    }
219                    break;
220            }
221        }
222        return null;
223    }
224
225    /**
226     * Optimize {@code antecedent or (left `op` right)} where `op` is = or !=.
227     */
228    private static IrBoolExpr optimizeOrCompare(IrBoolExpr antecedent, IrIntExpr left, IrCompare.Op op, IrIntExpr right) {
229        IrDomain domain = left.getDomain();
230        Integer constant = IrUtil.getConstant(right);
231        if (domain.size() == 2 && constant != null) {
232            switch (op) {
233                case Equal:
234                    // Rewrite
235                    //     bool or int = 888
236                    //         where dom(int) = {-3, 888}
237                    // to
238                    //     asInt(bool) > (-3) - int
239                    //     asInt(bool) + int > (-3)
240                    if (domain.getHighBound() == constant.intValue()) {
241                        return greaterThan(add(antecedent, left),
242                                domain.getLowBound());
243                    }
244                    // Rewrite
245                    //     bool or int = -3
246                    //         where dom(int) = {-3, 888}
247                    // to
248                    //     asInt(bool) > int - 888
249                    if (domain.getLowBound() == constant.intValue()) {
250                        return greaterThan(antecedent,
251                                sub(left, domain.getHighBound()));
252                    }
253                    break;
254                case NotEqual:
255                    // Rewrite
256                    //     bool or int != 888
257                    //         where dom(int) = {-3, 888}
258                    // to
259                    //     asInt(bool) > int - 888
260                    if (domain.getHighBound() == constant.intValue()) {
261                        return greaterThan(antecedent,
262                                sub(left, domain.getHighBound()));
263                    }
264                    // Rewrite
265                    //     bool or int != -3
266                    //         where dom(int) = {-3, 888}
267                    // to
268                    //     asInt(bool) > (-3) - int
269                    //     asInt(bool) + int > (-3)
270                    if (domain.getLowBound() == constant.intValue()) {
271                        return greaterThan(add(antecedent, left),
272                                domain.getLowBound());
273                    }
274                    break;
275            }
276        }
277        return null;
278    }
279
280    /**
281     * Optimize {@code antecedent => (left `op` right)} where `op` is = or !=.
282     */
283    private static IrBoolExpr optimizeImplicationCompare(IrBoolExpr antecedent, IrIntExpr left, IrCompare.Op op, IrIntExpr right) {
284        IrDomain domain = left.getDomain();
285        Integer constant = IrUtil.getConstant(right);
286        if (domain.size() == 2 && constant != null) {
287            switch (op) {
288                case Equal:
289                    // Rewrite
290                    //     bool => int = 888
291                    //         where dom(int) = {-3, 888}
292                    // to
293                    //     asInt(bool) <= int - (-3)
294                    if (domain.getHighBound() == constant.intValue()) {
295                        return lessThanEqual(antecedent,
296                                sub(left, domain.getLowBound()));
297                    }
298                    // Rewrite
299                    //     bool => int = -3
300                    //         where dom(int) = {-3, 888}
301                    // to
302                    //     asInt(bool) <= 888 - int
303                    //     asInt(bool) + int <= 888
304                    if (domain.getLowBound() == constant.intValue()) {
305                        return lessThanEqual(add(antecedent, left),
306                                domain.getHighBound());
307                    }
308                    break;
309                case NotEqual:
310                    // Rewrite
311                    //     bool => int != 888
312                    //         where dom(int) = {-3, 888}
313                    // to
314                    //     asInt(bool) <= 888 - int
315                    //     asInt(bool) + int <= 888
316                    if (domain.getHighBound() == constant.intValue()) {
317                        return lessThanEqual(add(antecedent, left),
318                                domain.getHighBound());
319                    }
320                    // Rewrite
321                    //     bool => int != -3
322                    //         where dom(int) = {-3, 888}
323                    // to
324                    //     asInt(bool) <= int - (-3)
325                    if (domain.getLowBound() == constant.intValue()) {
326                        return lessThanEqual(antecedent,
327                                sub(left, domain.getLowBound()));
328                    }
329                    break;
330            }
331        }
332        return null;
333    }
334}