001package org.clafer.ir.analysis;
002
003import java.util.HashSet;
004import java.util.Set;
005import org.clafer.collection.Pair;
006import org.clafer.ir.IrBoolExpr;
007import org.clafer.ir.IrCompare;
008import org.clafer.ir.IrIntExpr;
009import org.clafer.ir.IrModule;
010import org.clafer.ir.IrVar;
011import org.clafer.ir.Irs;
012
013/**
014 * Remove constraints that are duplicated. One example is to remove weaker
015 * constraints that are already enforced by stronger ones. Another example is
016 * due to coalescing, it is possible to end up with multiple cardinality
017 * variables for the same set. Remove the duplicates.
018 *
019 * @author jimmy
020 */
021public class DuplicateConstraints {
022
023    private DuplicateConstraints() {
024    }
025
026    /**
027     * Reduce.
028     *
029     * <ol>
030     * <li>a = b & a ≤ b to a = b</li>
031     * <li>a ≠ b & a &lt; b to a &lt; b</li>
032     * <li>a ≠ b & a ≤ b to a &lt; b</li>
033     * <li>a &lt; b & a ≤ b to a &lt; b</li>
034     * <li>a ≤ b & b ≤ a to a = b</li>
035     * </ol>
036     *
037     * @param module the module to remove duplicates
038     * @return the optimized module
039     */
040    public static IrModule removeDuplicates(IrModule module) {
041        HashSet<Pair<IrIntExpr, IrIntExpr>> equal = new HashSet<>();
042        HashSet<Pair<IrIntExpr, IrIntExpr>> notEqual = new HashSet<>();
043        HashSet<Pair<IrIntExpr, IrIntExpr>> lessThanEqual = new HashSet<>();
044        HashSet<Pair<IrIntExpr, IrIntExpr>> lessThan = new HashSet<>();
045
046        Set<IrBoolExpr> constraints = new HashSet<>(module.getConstraints().size());
047
048        for (IrBoolExpr constraint : module.getConstraints()) {
049            if (constraint instanceof IrCompare) {
050                IrCompare compare = (IrCompare) constraint;
051                IrIntExpr left = compare.getLeft();
052                IrIntExpr right = compare.getRight();
053                Pair<IrIntExpr, IrIntExpr> pair = new Pair<>(left, right);
054                Pair<IrIntExpr, IrIntExpr> converse = new Pair<>(right, left);
055
056                switch (compare.getOp()) {
057                    case Equal:
058                        lessThanEqual.remove(pair);
059                        lessThanEqual.remove(converse);
060                        equal.add(pair);
061                        equal.remove(converse);
062                        break;
063                    case NotEqual:
064                        if (lessThanEqual.remove(pair)) {
065                            lessThan.add(pair);
066                        } else if (lessThanEqual.remove(converse)) {
067                            lessThan.add(converse);
068                        } else if (!lessThan.contains(pair)
069                                && !lessThan.contains(converse)) {
070                            notEqual.add(pair);
071                            notEqual.remove(converse);
072                        }
073                        break;
074                    case LessThan:
075                        notEqual.remove(pair);
076                        notEqual.remove(converse);
077                        lessThanEqual.remove(pair);
078                        lessThan.add(pair);
079                        break;
080                    case LessThanEqual:
081                        if (notEqual.remove(pair)) {
082                            lessThan.add(pair);
083                        } else if (notEqual.remove(converse)) {
084                            lessThan.add(converse);
085                        } else if (lessThanEqual.remove(converse)) {
086                            equal.add(pair);
087                        } else if (!equal.contains(pair)
088                                && !equal.contains(converse)
089                                && !lessThan.contains(pair)) {
090                            lessThanEqual.add(pair);
091                        }
092                        break;
093                    default:
094                        throw new IllegalStateException();
095                }
096            } else {
097                constraints.add(constraint);
098            }
099        }
100
101        for (Pair<IrIntExpr, IrIntExpr> p : equal) {
102            constraints.add(Irs.equal(p.getFst(), p.getSnd()));
103        }
104        for (Pair<IrIntExpr, IrIntExpr> p : notEqual) {
105            constraints.add(Irs.notEqual(p.getFst(), p.getSnd()));
106        }
107        for (Pair<IrIntExpr, IrIntExpr> p : lessThan) {
108            constraints.add(Irs.lessThan(p.getFst(), p.getSnd()));
109        }
110        for (Pair<IrIntExpr, IrIntExpr> p : lessThanEqual) {
111            constraints.add(Irs.lessThanEqual(p.getFst(), p.getSnd()));
112        }
113
114        HashSet<IrVar> variables = new HashSet<>(module.getVariables());
115
116        return new IrModule().addVariables(variables).addConstraints(constraints);
117    }
118}