001package org.clafer.ir.analysis;
002
003import gnu.trove.iterator.TIntIterator;
004import gnu.trove.set.TIntSet;
005import gnu.trove.set.hash.TIntHashSet;
006import java.util.Arrays;
007import java.util.HashMap;
008import java.util.Iterator;
009import java.util.Map;
010import java.util.Set;
011import org.clafer.collection.DisjointSets;
012import org.clafer.collection.Pair;
013import org.clafer.collection.Triple;
014import org.clafer.ir.IrAdd;
015import org.clafer.ir.IrAllDifferent;
016import org.clafer.ir.IrArrayToSet;
017import org.clafer.ir.IrBoolChannel;
018import org.clafer.ir.IrBoolDomain;
019import org.clafer.ir.IrBoolExpr;
020import org.clafer.ir.IrBoolExprVisitorAdapter;
021import org.clafer.ir.IrBoolVar;
022import org.clafer.ir.IrCard;
023import org.clafer.ir.IrCompare;
024import org.clafer.ir.IrDomain;
025import org.clafer.ir.IrElement;
026import org.clafer.ir.IrFilterString;
027import org.clafer.ir.IrIfOnlyIf;
028import org.clafer.ir.IrIntChannel;
029import org.clafer.ir.IrIntExpr;
030import org.clafer.ir.IrIntVar;
031import org.clafer.ir.IrJoinFunction;
032import org.clafer.ir.IrJoinRelation;
033import org.clafer.ir.IrMember;
034import org.clafer.ir.IrMinus;
035import org.clafer.ir.IrModule;
036import org.clafer.ir.IrNot;
037import org.clafer.ir.IrNotImplies;
038import org.clafer.ir.IrNotMember;
039import org.clafer.ir.IrNotWithin;
040import org.clafer.ir.IrOffset;
041import org.clafer.ir.IrRewriter;
042import org.clafer.ir.IrSelectN;
043import org.clafer.ir.IrSetExpr;
044import org.clafer.ir.IrSetTest;
045import org.clafer.ir.IrSetUnion;
046import org.clafer.ir.IrSetVar;
047import org.clafer.ir.IrSingleton;
048import org.clafer.ir.IrSortSets;
049import org.clafer.ir.IrSortStrings;
050import org.clafer.ir.IrSortStringsChannel;
051import org.clafer.ir.IrSubsetEq;
052import org.clafer.ir.IrUtil;
053import org.clafer.ir.IrWithin;
054import org.clafer.ir.Irs;
055import static org.clafer.ir.Irs.*;
056
057/**
058 * @author jimmy
059 */
060public class Coalescer {
061
062    private Coalescer() {
063    }
064
065    public static Triple<Map<IrIntVar, IrIntVar>, Map<IrSetVar, IrSetVar>, IrModule> coalesce(IrModule module) {
066        Pair<DisjointSets<IrIntVar>, DisjointSets<IrSetVar>> graphs = findEquivalences(module.getConstraints());
067        DisjointSets<IrIntVar> intGraph = graphs.getFst();
068        DisjointSets<IrSetVar> setGraph = graphs.getSnd();
069        Map<IrIntVar, IrIntVar> coalescedInts = new HashMap<>();
070        Map<IrSetVar, IrSetVar> coalescedSets = new HashMap<>();
071
072        for (Set<IrIntVar> component : intGraph.connectedComponents()) {
073            if (component.size() > 1) {
074                Iterator<IrIntVar> iter = component.iterator();
075                IrIntVar var = iter.next();
076                StringBuilder name = new StringBuilder().append(var.getName());
077                IrDomain domain = var.getDomain();
078                while (iter.hasNext()) {
079                    var = iter.next();
080                    name.append(';').append(var.getName());
081                    domain = IrUtil.intersection(domain, var.getDomain());
082                }
083                if (domain.isEmpty()) {
084                    // Model is unsatisfiable. Compile anyways?
085                } else {
086                    IrIntVar coalesced = domainInt(name.toString(), domain);
087                    for (IrIntVar coalesce : component) {
088                        if (!coalesced.equals(coalesce)) {
089                            coalescedInts.put(coalesce, coalesced);
090                        }
091                    }
092                }
093            }
094        }
095        for (Set<IrSetVar> component : setGraph.connectedComponents()) {
096            if (component.size() > 1) {
097                Iterator<IrSetVar> iter = component.iterator();
098                IrSetVar var = iter.next();
099                StringBuilder name = new StringBuilder().append(var.getName());
100                IrDomain env = var.getEnv();
101                IrDomain ker = var.getKer();
102                IrDomain card = var.getCard();
103                while (iter.hasNext()) {
104                    var = iter.next();
105                    name.append(';').append(var.getName());
106                    env = IrUtil.intersection(env, var.getEnv());
107                    ker = IrUtil.union(ker, var.getKer());
108                    card = IrUtil.intersection(card, var.getCard());
109                }
110                IrSetVar coalesced = newSet(name.toString(), env, ker, card);
111                if (coalesced != null) {
112                    for (IrSetVar coalesce : component) {
113                        if (!coalesced.equals(coalesce)) {
114                            coalescedSets.put(coalesce, coalesced);
115                        }
116                    }
117                }
118            }
119        }
120        return new Triple<>(
121                coalescedInts,
122                coalescedSets,
123                new CoalesceRewriter(coalescedInts, coalescedSets).rewrite(module, null));
124    }
125
126    private static IrSetVar newSet(String name, IrDomain env, IrDomain ker, IrDomain card) {
127        if (!IrUtil.isSubsetOf(ker, env) || ker.size() > env.size()) {
128            // Model is unsatisfiable. Compile anyways?
129        } else {
130            IrDomain boundCard = IrUtil.intersection(boundDomain(ker.size(), env.size()), card);
131            if (boundCard.isEmpty()) {
132                // Model is unsatisfiable. Compile anyways?
133            } else {
134                return set(name.toString(), env, ker, boundCard);
135            }
136        }
137        return null;
138    }
139
140    private static Pair<DisjointSets<IrIntVar>, DisjointSets<IrSetVar>> findEquivalences(Iterable<IrBoolExpr> constraints) {
141        DisjointSets<IrIntVar> intGraph = new DisjointSets<>();
142        DisjointSets<IrSetVar> setGraph = new DisjointSets<>();
143        EquivalenceFinder finder = new EquivalenceFinder(intGraph, setGraph);
144        for (IrBoolExpr constraint : constraints) {
145            constraint.accept(finder, null);
146        }
147        return new Pair<>(intGraph, setGraph);
148    }
149
150    private static class EquivalenceFinder extends IrBoolExprVisitorAdapter<Void, Void> {
151
152        private final DisjointSets<IrIntVar> intGraph;
153        private final DisjointSets<IrSetVar> setGraph;
154        private final Map<IrSetVar, IrIntVar> duplicates = new HashMap<>();
155
156        private EquivalenceFinder(DisjointSets<IrIntVar> intGraph, DisjointSets<IrSetVar> setGraph) {
157            this.intGraph = intGraph;
158            this.setGraph = setGraph;
159        }
160
161        @Override
162        public Void visit(IrBoolVar ir, Void a) {
163            if (IrBoolDomain.BoolDomain.equals(ir.getDomain())) {
164                intGraph.union(ir, True);
165            }
166            return null;
167        }
168
169        @Override
170        public Void visit(IrNot ir, Void a) {
171            propagateInt(FalseDomain, ir.getExpr());
172            return null;
173        }
174
175        @Override
176        public Void visit(IrNotImplies ir, Void a) {
177            propagateEqual(ir.getAntecedent(), One);
178            propagateEqual(ir.getConsequent(), Zero);
179            return null;
180        }
181
182        @Override
183        public Void visit(IrIfOnlyIf ir, Void a) {
184            propagateEqual(ir.getLeft(), ir.getRight());
185            return null;
186        }
187
188        @Override
189        public Void visit(IrWithin ir, Void a) {
190            propagateInt(ir.getRange(), ir.getValue());
191            return null;
192        }
193
194        @Override
195        public Void visit(IrNotWithin ir, Void a) {
196            propagateInt(IrUtil.difference(ir.getValue().getDomain(), ir.getRange()), ir.getValue());
197            return null;
198        }
199
200        @Override
201        public Void visit(IrCompare ir, Void a) {
202            IrIntExpr left = ir.getLeft();
203            IrIntExpr right = ir.getRight();
204            switch (ir.getOp()) {
205                case Equal:
206                    propagateEqual(left, right);
207                    break;
208                case NotEqual:
209                    propagateNotEqual(left, right);
210                    break;
211                case LessThan:
212                    propagateLessThan(left, right);
213                    break;
214                case LessThanEqual:
215                    propagateLessThanEqual(left, right);
216                    break;
217            }
218            return null;
219        }
220
221        @Override
222        public Void visit(IrSetTest ir, Void a) {
223            IrSetExpr left = ir.getLeft();
224            IrSetExpr right = ir.getRight();
225            if (IrSetTest.Op.Equal.equals(ir.getOp())) {
226                if (left instanceof IrSetVar && right instanceof IrSetVar) {
227                    setGraph.union((IrSetVar) left, (IrSetVar) right);
228                } else {
229                    propagateSet(new PartialSet(left.getEnv(), left.getKer(), left.getCard()), right);
230                    propagateSet(new PartialSet(right.getEnv(), right.getKer(), right.getCard()), left);
231                }
232            }
233            return null;
234        }
235
236        @Override
237        public Void visit(IrMember ir, Void a) {
238            IrIntExpr element = ir.getElement();
239            IrSetExpr set = ir.getSet();
240            propagateInt(set.getEnv(), element);
241            IrDomain ker = null;
242            Integer constant = IrUtil.getConstant(element);
243            if (constant != null && !set.getKer().contains(constant)) {
244                ker = IrUtil.add(set.getKer(), constant);
245            }
246            IrDomain card = null;
247            if (set.getCard().getLowBound() == 0) {
248                card = IrUtil.remove(set.getCard(), 0);
249            }
250            if (ker != null || card != null) {
251                propagateSet(new PartialSet(null, ker, card), set);
252            }
253            return null;
254        }
255
256        @Override
257        public Void visit(IrNotMember ir, Void a) {
258            IrIntExpr element = ir.getElement();
259            IrSetExpr set = ir.getSet();
260            IrDomain domain = IrUtil.difference(element.getDomain(), set.getKer());
261            propagateInt(domain, element);
262            Integer constant = IrUtil.getConstant(element);
263            if (constant != null && set.getEnv().contains(constant)) {
264                propagateEnv(IrUtil.remove(set.getEnv(), constant), set);
265            }
266            return null;
267        }
268
269        @Override
270        public Void visit(IrSubsetEq ir, Void a) {
271            IrSetExpr sub = ir.getSubset();
272            IrSetExpr sup = ir.getSuperset();
273            propagateSet(new PartialSet(sup.getEnv(), null, sup.getCard()), sub);
274            propagateSet(new PartialSet(null, sub.getKer(),
275                    IrUtil.boundLow(sup.getCard(), sub.getCard().getLowBound())), sub);
276            return null;
277        }
278
279        @Override
280        public Void visit(IrBoolChannel ir, Void a) {
281            IrBoolExpr[] bools = ir.getBools();
282            IrSetExpr set = ir.getSet();
283            IrDomain env = set.getEnv();
284            IrDomain ker = set.getKer();
285            TIntHashSet trues = new TIntHashSet(ker.size());
286            TIntHashSet notFalses = new TIntHashSet(env.size());
287            env.transferTo(notFalses);
288            boolean changed = false;
289            for (int i = 0; i < bools.length; i++) {
290                if (bools[i] instanceof IrBoolVar && !IrUtil.isConstant(bools[i])) {
291                    if (!env.contains(i)) {
292                        intGraph.union((IrBoolVar) bools[i], False);
293                    } else if (ker.contains(i)) {
294                        intGraph.union((IrBoolVar) bools[i], True);
295                    }
296                }
297                if (IrUtil.isTrue(ir)) {
298                    changed |= trues.add(i);
299                }
300                if (IrUtil.isFalse(bools[i])) {
301                    changed |= notFalses.remove(i);
302                }
303            }
304            if (changed) {
305                propagateSet(new PartialSet(enumDomain(notFalses), enumDomain(trues), null), set);
306            }
307            return null;
308        }
309
        /**
         * Propagates an int/set channel between {@code ints} and {@code sets}.
         *
         * <p>The rules below treat {@code j} as belonging to {@code sets[i]} exactly
         * when {@code ints[j] == i}. That is presumably the constraint's meaning;
         * TODO confirm against IrIntChannel. Under that reading, the set
         * cardinalities sum to {@code ints.length}.
         */
        @Override
        public Void visit(IrIntChannel ir, Void a) {
            IrIntExpr[] ints = ir.getInts();
            IrSetExpr[] sets = ir.getSets();

            // Union of every set's kernel, filled in below.
            TIntSet kers = new TIntHashSet();

            // ints[i] may only name a set whose envelope still contains i.
            for (int i = 0; i < ints.length; i++) {
                TIntSet domain = new TIntHashSet();
                for (int j = 0; j < sets.length; j++) {
                    if (sets[j].getEnv().contains(i)) {
                        domain.add(j);
                    }
                }
                propagateInt(enumDomain(domain), ints[i]);
            }
            // Sums of the cardinality bounds over all sets.
            int lowCards = 0;
            int highCards = 0;
            for (IrSetExpr set : sets) {
                set.getKer().transferTo(kers);
                lowCards += set.getCard().getLowBound();
                highCards += set.getCard().getHighBound();
            }
            for (int i = 0; i < sets.length; i++) {
                // env: indices whose int may equal i; ker: indices whose int is fixed to i.
                TIntSet env = new TIntHashSet();
                TIntSet ker = new TIntHashSet();
                for (int j = 0; j < ints.length; j++) {
                    if (ints[j].getDomain().contains(i)) {
                        env.add(j);
                        if (ints[j].getDomain().size() == 1) {
                            ker.add(j);
                        }
                    }
                }
                // Elements already in some kernel cannot be in this envelope unless
                // they are in this set's own kernel, so remove all and add its own back.
                env.removeAll(kers);
                sets[i].getKer().transferTo(env);
                // |sets[i]| = ints.length - sum of the other cardinalities.
                // NOTE(review): with inconsistent bounds, low may exceed high here;
                // how boundDomain handles that is not visible from this file.
                IrDomain card = boundDomain(
                        ints.length - highCards + sets[i].getCard().getHighBound(),
                        ints.length - lowCards + sets[i].getCard().getLowBound());
                propagateSet(new PartialSet(enumDomain(env), enumDomain(ker), card), sets[i]);
            }
            return null;
        }
353
354        @Override
355        public Void visit(IrSortStrings ir, Void a) {
356            IrIntExpr[][] strings = ir.getStrings();
357            for (int i = 0; i < strings.length - 1; i++) {
358                if (ir.isStrict()) {
359                    propagateLessThanString(strings[i], strings[i + 1]);
360                } else {
361                    propagateLessThanEqualString(strings[i], strings[i + 1]);
362                }
363            }
364            return null;
365        }
366
        /**
         * Bounds sets that are sorted one after another.
         *
         * <p>The rules below treat the sets as consecutive contiguous ranges of
         * integers starting at 0, with {@code sets[i]} following {@code sets[i-1]}.
         * That is presumably what IrSortSets means; TODO confirm.
         * {@code low}/{@code high} are the smallest/largest possible total
         * cardinality of the sets before the current one.
         */
        @Override
        public Void visit(IrSortSets ir, Void a) {
            IrSetExpr[] sets = ir.getSets();
            int low = 0;
            int high = 0;
            for (int i = 0; i < sets.length; i++) {
                IrSetExpr set = sets[i];
                IrDomain card = set.getCard();
                int newLow = low + card.getLowBound();
                int newHigh = high + card.getHighBound();
                // The set can only start at or after `low` and end before `newHigh`.
                // NOTE(review): when newHigh - 1 < low this passes inverted bounds to
                // boundDomain; its behavior in that case is not visible here.
                IrDomain env = boundDomain(low, newHigh - 1);
                IrDomain ker = set.getKer();
                // A contiguous set must contain everything between its kernel's extremes.
                if (!ker.isEmpty() && !ker.isBounded()) {
                    ker = Irs.boundDomain(ker.getLowBound(), ker.getHighBound());
                }
                // Values in [high, newLow) are covered by this set however the cardinalities fall.
                if (high < newLow) {
                    ker = IrUtil.union(ker, boundDomain(high, newLow - 1));
                }
                propagateSet(new PartialSet(env, ker, null), set);
                low = newLow;
                high = newHigh;
            }
            return null;
        }
391
392        @Override
393        public Void visit(IrSortStringsChannel ir, Void a) {
394            IrIntExpr[][] strings = ir.getStrings();
395            IrIntExpr[] ints = ir.getInts();
396            for (int i = 0; i < strings.length; i++) {
397                for (int j = i + 1; j < strings.length; j++) {
398                    switch (IrUtil.compareString(strings[i], strings[j])) {
399                        case EQ:
400                            propagateEqual(ints[i], ints[j]);
401                            break;
402                        case LT:
403                            propagateLessThan(ints[i], ints[j]);
404                            break;
405                        case LE:
406                            propagateLessThanEqual(ints[i], ints[j]);
407                            break;
408                        case GT:
409                            propagateLessThan(ints[j], ints[i]);
410                            break;
411                        case GE:
412                            propagateLessThanEqual(ints[j], ints[i]);
413                            break;
414                    }
415                }
416            }
417            IrDomain dom = boundDomain(0, ints.length - 1);
418            for (int i = 0; i < ints.length; i++) {
419                propagateInt(dom, ints[i]);
420                for (int j = i + 1; j < ints.length; j++) {
421                    switch (IrUtil.compare(ints[i], ints[j])) {
422                        case EQ:
423                            propagateEqualString(strings[i], strings[j]);
424                            break;
425                        case LT:
426                            propagateLessThanString(strings[i], strings[j]);
427                            break;
428                        case LE:
429                            propagateLessThanEqualString(strings[i], strings[j]);
430                            break;
431                        case GT:
432                            propagateLessThanString(strings[j], strings[i]);
433                            break;
434                        case GE:
435                            propagateLessThanEqualString(strings[j], strings[i]);
436                            break;
437                    }
438                }
439            }
440            return null;
441        }
442
443        @Override
444        public Void visit(IrAllDifferent ir, Void a) {
445            IrIntExpr[] operands = ir.getOperands();
446            for (int i = 0; i < operands.length; i++) {
447                for (int j = i + 1; j > operands.length; j++) {
448                    propagateNotEqual(operands[i], operands[j]);
449                }
450            }
451            return null;
452        }
453
454        @Override
455        public Void visit(IrSelectN ir, Void a) {
456            IrBoolExpr[] bools = ir.getBools();
457            IrIntExpr n = ir.getN();
458            for (int i = 0; i < bools.length; i++) {
459                if (IrUtil.isTrue(bools[i]) && i >= n.getDomain().getLowBound()) {
460                    propagateInt(boundDomain(i + 1, bools.length), n);
461                } else if (IrUtil.isFalse(bools[i]) && i < n.getDomain().getHighBound()) {
462                    propagateInt(boundDomain(0, i), n);
463                }
464            }
465            for (int i = 0; i < n.getDomain().getLowBound(); i++) {
466                propagateInt(TrueDomain, bools[i]);
467            }
468            for (int i = n.getDomain().getHighBound(); i < bools.length; i++) {
469                propagateInt(FalseDomain, bools[i]);
470            }
471            return null;
472        }
473
        /**
         * Propagates a filter-string constraint between {@code getString()} entries
         * selected by {@code getSet()} and the {@code getResult()} array.
         *
         * <p>Walking the set's envelope in iterator order (presumably ascending;
         * TODO confirm), each element is mapped to {@code getString()[element -
         * getOffset()]}.
         */
        @Override
        public Void visit(IrFilterString ir, Void a) {
            TIntIterator iter = ir.getSet().getEnv().iterator();
            // Position in the result array. Becomes -1 (and stays -1) once an
            // envelope element outside the kernel is seen, because from then on
            // the result positions of later elements are no longer known.
            int i = 0;
            // Union of the domains of every string entry the set may select.
            IrDomain values = Irs.EmptyDomain;
            while (iter.hasNext()) {
                int env = iter.next();
                if (!ir.getSet().getKer().contains(env)) {
                    i = -1;
                }
                if (i >= 0) {
                    // Every element so far is certainly selected, so this entry
                    // lands at result position i.
                    IrIntExpr string = ir.getString()[env - ir.getOffset()];
                    IrIntExpr result = ir.getResult()[i];
                    propagateEqual(string, result);
                    i++;
                }
                values = IrUtil.union(values, ir.getString()[env - ir.getOffset()].getDomain());
            }
            // Positions below the minimum cardinality are certainly filled by some selected entry.
            for (int j = 0; j < ir.getSet().getCard().getLowBound(); j++) {
                propagateInt(values, ir.getResult()[j]);
            }
            // Positions at or beyond the maximum cardinality are fixed to -1
            // (presumably a padding value; TODO confirm).
            for (int j = ir.getSet().getCard().getHighBound(); j < ir.getResult().length; j++) {
                propagateInt(constantDomain(-1), ir.getResult()[j]);
            }
            return null;
        }
500
501        private void propagateEqual(IrIntExpr left, IrIntExpr right) {
502            Pair<IrIntExpr, IrSetVar> cardinality = AnalysisUtil.getAssignCardinality(left, right);
503            if (cardinality != null) {
504                IrIntExpr cardExpr = cardinality.getFst();
505                IrSetVar setVar = cardinality.getSnd();
506
507                if (cardExpr instanceof IrIntVar) {
508                    IrIntVar cardVar = (IrIntVar) cardExpr;
509                    IrIntVar duplicate = duplicates.put(setVar, cardVar);
510                    if (duplicate != null) {
511                        intGraph.union(cardVar, duplicate);
512                        return;
513                    }
514                }
515            }
516            if (left instanceof IrIntVar && right instanceof IrIntVar) {
517                intGraph.union((IrIntVar) left, (IrIntVar) right);
518            } else {
519                propagateInt(left.getDomain(), right);
520                propagateInt(right.getDomain(), left);
521            }
522        }
523
524        private void propagateEqualString(IrIntExpr[] a, IrIntExpr[] b) {
525            for (int i = 0; i < a.length; i++) {
526                propagateEqual(a[i], b[i]);
527            }
528        }
529
530        private void propagateNotEqual(IrIntExpr left, IrIntExpr right) {
531            Integer constant = IrUtil.getConstant(left);
532            if (constant != null) {
533                IrDomain minus = IrUtil.remove(right.getDomain(), constant);
534                propagateInt(minus, right);
535            }
536            constant = IrUtil.getConstant(right);
537            if (constant != null) {
538                IrDomain minus = IrUtil.remove(left.getDomain(), constant);
539                propagateInt(minus, left);
540            }
541        }
542
543        private void propagateLessThan(IrIntExpr left, IrIntExpr right) {
544            IrDomain leftDomain = left.getDomain();
545            IrDomain rightDomain = right.getDomain();
546            if (leftDomain.getHighBound() >= rightDomain.getHighBound()) {
547                propagateInt(IrUtil.boundHigh(leftDomain, rightDomain.getHighBound() - 1), left);
548            }
549            if (rightDomain.getLowBound() <= leftDomain.getLowBound()) {
550                propagateInt(IrUtil.boundLow(rightDomain, leftDomain.getLowBound() + 1), right);
551            }
552        }
553
554        private void propagateLessThanString(IrIntExpr[] a, IrIntExpr[] b) {
555            propagateLessThanString(a, b, 0);
556        }
557
558        private void propagateLessThanString(IrIntExpr[] a, IrIntExpr[] b, int index) {
559            assert a.length == b.length;
560            if (index == a.length) {
561                // Model is unsatisfiable. Compile anyways?
562            }
563            switch (IrUtil.compare(a[index], b[index])) {
564                case EQ:
565                    propagateLessThanString(a, b, index + 1);
566                    return;
567                case LT:
568                    return;
569                case GT:
570                    // Model is unsatisfiable. Compile anyways?
571                    return;
572                case LE:
573                case GE:
574                case UNKNOWN:
575                    switch (IrUtil.compareString(a, b, index + 1)) {
576                        case EQ:
577                        case GT:
578                        case GE:
579                            propagateLessThan(a[index], b[index]);
580                            return;
581                        case LT:
582                        case LE:
583                        case UNKNOWN:
584                            propagateLessThanEqual(a[index], b[index]);
585                            return;
586                        default:
587                            throw new IllegalStateException();
588                    }
589                default:
590                    throw new IllegalStateException();
591            }
592        }
593
594        private void propagateLessThanEqual(IrIntExpr left, IrIntExpr right) {
595            IrDomain leftDomain = left.getDomain();
596            IrDomain rightDomain = right.getDomain();
597            if (leftDomain.getHighBound() > rightDomain.getHighBound()) {
598                propagateInt(IrUtil.boundHigh(left.getDomain(), right.getDomain().getHighBound()), left);
599            }
600            if (rightDomain.getLowBound() < leftDomain.getLowBound()) {
601                propagateInt(IrUtil.boundLow(right.getDomain(), left.getDomain().getLowBound()), right);
602            }
603        }
604
605        private void propagateLessThanEqualString(IrIntExpr[] a, IrIntExpr[] b) {
606            propagateLessThanEqualString(a, b, 0);
607        }
608
609        private void propagateLessThanEqualString(IrIntExpr[] a, IrIntExpr[] b, int index) {
610            assert a.length == b.length;
611            if (index == a.length) {
612                return;
613            }
614            switch (IrUtil.compare(a[index], b[index])) {
615                case EQ:
616                    propagateLessThanEqualString(a, b, index + 1);
617                    return;
618                case LT:
619                    return;
620                case GT:
621                    // Model is unsatisfiable. Compile anyways?
622                    return;
623                case LE:
624                case GE:
625                case UNKNOWN:
626                    switch (IrUtil.compareString(a, b, index + 1)) {
627                        case EQ:
628                        case LT:
629                        case LE:
630                        case GE:
631                        case UNKNOWN:
632                            propagateLessThanEqual(a[index], b[index]);
633                            return;
634                        case GT:
635                            propagateLessThan(a[index], b[index]);
636                            return;
637                        default:
638                            throw new IllegalStateException();
639                    }
640                default:
641                    throw new IllegalStateException();
642            }
643        }
644
        /**
         * Records that {@code right} must take a value in {@code left}.
         *
         * <p>Does nothing if {@code right}'s domain already lies within {@code left}.
         * An int variable is unioned with a fresh variable over the intersected
         * domain. Minus, cardinality, addition and element expressions push the
         * restriction down to their operands. Any other expression kind is ignored.
         *
         * @param left the allowed values
         * @param right the expression to restrict
         */
        private void propagateInt(IrDomain left, IrIntExpr right) {
            if (IrUtil.isSubsetOf(right.getDomain(), left)) {
                return;
            }
            if (right instanceof IrIntVar) {
                IrDomain domain = IrUtil.intersection(left, right.getDomain());
                if (domain.isEmpty()) {
                    // Model is unsatisfiable. Compile anyways?
                } else {
                    intGraph.union((IrIntVar) right, domainInt("domain" + domain, domain));
                }
            } else if (right instanceof IrMinus) {
                // -e in left  <=>  e in minus(left).
                propagateInt(IrUtil.minus(left), ((IrMinus) right).getExpr());
            } else if (right instanceof IrCard) {
                propagateCard(left, ((IrCard) right).getSet());
            } else if (right instanceof IrAdd) {
                IrAdd add = (IrAdd) right;
                IrIntExpr[] addends = add.getAddends();
                if (addends.length == 1) {
                    // e + offset in left  <=>  e in left - offset.
                    propagateInt(IrUtil.offset(left, -add.getOffset()), addends[0]);
                } else {
                    for (IrIntExpr addend : addends) {
                        IrDomain domain = addend.getDomain();
                        // addend = sum - (the rest), where the rest spans
                        // [sumLow - addendLow, sumHigh - addendHigh] and sum is in left.
                        IrDomain bound =
                                IrUtil.intersection(
                                boundDomain(
                                left.getLowBound() - right.getDomain().getHighBound() + domain.getHighBound(),
                                left.getHighBound() - right.getDomain().getLowBound() + domain.getLowBound()),
                                domain);
                        propagateInt(bound, addend);
                    }
                }
            } else if (right instanceof IrElement) {
                // Keep only indices whose array entry could still produce a value in left.
                IrElement element = (IrElement) right;
                TIntHashSet domain = new TIntHashSet(element.getIndex().getDomain().size());
                TIntIterator iter = element.getIndex().getDomain().iterator();
                while (iter.hasNext()) {
                    int val = iter.next();
                    if (IrUtil.intersects(left, element.getArray()[val].getDomain())) {
                        domain.add(val);
                    }
                }
                propagateInt(enumDomain(domain), element.getIndex());
            }
        }
690
691        private void propagateSet(PartialSet left, IrSetExpr right) {
692            left.updateMask(right.getEnv(), right.getKer(), right.getCard());
693            if (left.hasMask()) {
694                if (right instanceof IrSetVar) {
695                    propagateSetVar(left, (IrSetVar) right);
696                } else if (right instanceof IrSingleton) {
697                    propagateSingleton(left, (IrSingleton) right);
698                } else if (right instanceof IrArrayToSet) {
699                    propagateArrayToSet(left, (IrArrayToSet) right);
700                } else if (right instanceof IrJoinRelation) {
701                    propagateJoinRelation(left, (IrJoinRelation) right);
702                } else if (right instanceof IrJoinFunction) {
703                    propagateJoinFunction(left, (IrJoinFunction) right);
704                } else if (right instanceof IrSetUnion) {
705                    propagateSetUnion(left, (IrSetUnion) right);
706                } else if (right instanceof IrOffset) {
707                    propagateOffset(left, (IrOffset) right);
708                }
709            }
710        }
711
712        private void propagateSetVar(PartialSet left, IrSetVar right) {
713            IrDomain env = right.getEnv();
714            IrDomain ker = right.getKer();
715            IrDomain card = right.getCard();
716            if (left.isEnvMask()) {
717                env = left.getEnv();
718            }
719            if (left.isKerMask()) {
720                ker = left.getKer();
721            }
722            if (left.isCardMask()) {
723                card = left.getCard();
724            }
725            IrSetVar set = newSet(left.toString(), env, ker, card);
726            if (set != null) {
727                setGraph.union(right, set);
728            }
729        }
730
731        private void propagateSingleton(PartialSet left, IrSingleton right) {
732            if (left.isKerMask()) {
733                IrDomain ker = left.getKer();
734                if (ker.size() == 1) {
735                    propagateInt(ker, right.getValue());
736                }
737            } else if (left.isEnvMask()) {
738                IrDomain env = left.getEnv();
739                propagateInt(env, right.getValue());
740            }
741        }
742
743        private void propagateArrayToSet(PartialSet left, IrArrayToSet right) {
744            if (left.isEnvMask()) {
745                IrDomain env = left.getEnv();
746                for (IrIntExpr child : right.getArray()) {
747                    propagateInt(env, child);
748                }
749            }
750            if (left.isKerMask()) {
751                TIntIterator iter = IrUtil.difference(left.getKer(), right.getKer()).iterator();
752                while (iter.hasNext()) {
753                    int val = iter.next();
754                    IrIntExpr index = null;
755                    for (IrIntExpr operand : right.getArray()) {
756                        if (operand.getDomain().contains(val)) {
757                            if (index != null) {
758                                index = null;
759                                break;
760                            }
761                            index = operand;
762                        }
763                    }
764                    if (index != null) {
765                        propagateInt(constantDomain(val), index);
766                    }
767                }
768            }
769        }
770
771        private void propagateJoinRelation(PartialSet left, IrJoinRelation right) {
772            if (right.isInjective()) {
773                if (left.isEnvMask() || left.isCardMask()) {
774                    IrDomain env = left.getEnv();
775                    IrDomain card = left.isCardMask() ? boundDomain(0, left.getCard().getHighBound()) : null;
776                    TIntIterator iter = right.getTake().getKer().iterator();
777                    PartialSet set = new PartialSet(env, null, card);
778                    while (iter.hasNext()) {
779                        propagateSet(set, right.getChildren()[iter.next()]);
780                    }
781                }
782                if (left.isKerMask()) {
783                    TIntIterator iter = IrUtil.difference(left.getKer(), right.getKer()).iterator();
784                    while (iter.hasNext()) {
785                        int val = iter.next();
786                        TIntIterator env = right.getTake().getEnv().iterator();
787                        int index = -1;
788                        while (env.hasNext()) {
789                            int j = env.next();
790                            if (right.getChildren()[j].getEnv().contains(val)) {
791                                if (index != -1) {
792                                    index = -1;
793                                    break;
794                                }
795                                index = j;
796                            }
797                        }
798                        if (index != -1) {
799                            propagateKer(constantDomain(index), right.getTake());
800                            propagateKer(constantDomain(val), right.getChildren()[index]);
801                        }
802                    }
803                }
804                if (left.isCardMask()) {
805                    IrSetExpr take = right.getTake();
806                    IrSetExpr[] children = right.getChildren();
807                    int lb = left.getCard().getLowBound();
808                    int ub = left.getCard().getHighBound();
809                    int[] envLbs = new int[take.getEnv().size() - take.getKer().size()];
810                    int[] envUbs = new int[envLbs.length];
811                    int kerMinCard = 0;
812                    int kerMaxCard = 0;
813                    int env = 0;
814                    TIntIterator iter = take.getEnv().iterator();
815                    while (iter.hasNext()) {
816                        int i = iter.next();
817                        if (take.getKer().contains(i)) {
818                            kerMinCard += children[i].getCard().getLowBound();
819                            kerMaxCard += children[i].getCard().getHighBound();
820                        } else {
821                            envLbs[env] = children[i].getCard().getLowBound();
822                            envUbs[env] = children[i].getCard().getHighBound();
823                            env++;
824                        }
825                    }
826                    Arrays.sort(envLbs);
827                    Arrays.sort(envUbs);
828                    int i;
829                    for (i = 0; i < envLbs.length && (kerMinCard < ub || envLbs[i] == 0); i++) {
830                        kerMinCard += envLbs[i];
831                    }
832                    int high = i + take.getKer().size();
833                    for (i = envUbs.length - 1; i >= 0 && kerMaxCard < lb; i--) {
834                        kerMaxCard += envUbs[i];
835                    }
836                    int low = envUbs.length - 1 - i + take.getKer().size();
837                    if (low > take.getCard().getLowBound() || high < take.getCard().getHighBound()) {
838                        propagateCard(boundDomain(low, high), take);
839                    }
840                }
841            }
842        }
843
844        private void propagateJoinFunction(PartialSet left, IrJoinFunction right) {
845            if (left.isEnvMask()) {
846                IrDomain env = left.getEnv();
847                TIntIterator iter = right.getTake().getKer().iterator();
848                while (iter.hasNext()) {
849                    propagateInt(env, right.getRefs()[iter.next()]);
850                }
851            }
852            if (left.isKerMask()) {
853                TIntIterator iter = IrUtil.difference(left.getKer(), right.getKer()).iterator();
854                while (iter.hasNext()) {
855                    int val = iter.next();
856                    TIntIterator env = right.getTake().getEnv().iterator();
857                    int index = -1;
858                    while (env.hasNext()) {
859                        int j = env.next();
860                        if (right.getRefs()[j].getDomain().contains(val)) {
861                            if (index != -1) {
862                                index = -1;
863                                break;
864                            }
865                            index = j;
866                        }
867                    }
868                    if (index != -1) {
869                        propagateKer(constantDomain(index), right.getTake());
870                        propagateInt(constantDomain(val), right.getRefs()[index]);
871                    }
872                }
873            }
874            if (left.isCardMask()) {
875                IrDomain card = left.getCard();
876                IrSetExpr take = right.getTake();
877                int low = Math.max(take.getKer().size(), card.getLowBound());
878                int high = Math.min(take.getEnv().size(),
879                        right.hasGlobalCardinality()
880                        ? card.getHighBound() * right.getGlobalCardinality()
881                        : take.getCard().getHighBound());
882                if (low > take.getCard().getLowBound() || high < take.getCard().getHighBound()) {
883                    propagateCard(boundDomain(low, high), take);
884                }
885            }
886        }
887
888        private void propagateSetUnion(PartialSet left, IrSetUnion right) {
889            if (left.isEnvMask() || left.isCardMask()) {
890                IrSetExpr[] operands = right.getOperands();
891                IrDomain env = left.getEnv();
892                if (right.isDisjoint() && left.isCardMask()) {
893                    int lowCards = 0;
894                    int highCards = 0;
895                    for (IrSetExpr operand : operands) {
896                        lowCards += operand.getCard().getLowBound();
897                        highCards += operand.getCard().getHighBound();
898                    }
899                    for (IrSetExpr operand : operands) {
900                        IrDomain card = boundDomain(
901                                left.getCard().getLowBound() - highCards + operand.getCard().getHighBound(),
902                                left.getCard().getHighBound() - lowCards + operand.getCard().getLowBound());
903                        PartialSet set = new PartialSet(env, null, card);
904                        propagateSet(set, operand);
905                    }
906                } else {
907                    IrDomain card = left.isCardMask() ? boundDomain(0, left.getCard().getHighBound()) : null;
908                    PartialSet set = new PartialSet(env, null, card);
909                    for (IrSetExpr operand : operands) {
910                        propagateSet(set, operand);
911                    }
912                }
913            }
914            if (left.isKerMask()) {
915                TIntIterator iter = IrUtil.difference(left.getKer(), right.getKer()).iterator();
916                while (iter.hasNext()) {
917                    int val = iter.next();
918                    IrSetExpr index = null;
919                    for (IrSetExpr operand : right.getOperands()) {
920                        if (operand.getEnv().contains(val)) {
921                            if (index != null) {
922                                index = null;
923                                break;
924                            }
925                            index = operand;
926                        }
927                    }
928                    if (index != null) {
929                        propagateKer(constantDomain(val), index);
930                    }
931                }
932            }
933        }
934
935        private void propagateOffset(PartialSet left, IrOffset right) {
936            int offset = right.getOffset();
937            IrDomain env = left.isEnvMask() ? IrUtil.offset(left.getEnv(), offset) : null;
938            IrDomain ker = left.isKerMask() ? IrUtil.offset(left.getKer(), offset) : null;
939            IrDomain card = left.isCardMask() ? left.getCard() : null;
940            propagateSet(new PartialSet(env, ker, card), right.getSet());
941        }
942
943        private void propagateEnv(IrDomain left, IrSetExpr right) {
944            propagateSet(env(left), right);
945        }
946
947        private void propagateKer(IrDomain left, IrSetExpr right) {
948            propagateSet(ker(left), right);
949        }
950
951        private void propagateCard(IrDomain left, IrSetExpr right) {
952            propagateSet(card(left), right);
953        }
954    }
955
956    private static class CoalesceRewriter extends IrRewriter<Void> {
957
958        private final Map<IrIntVar, IrIntVar> coalescedInts;
959        private final Map<IrSetVar, IrSetVar> coalescedSets;
960
961        CoalesceRewriter(Map<IrIntVar, IrIntVar> coalescedInts, Map<IrSetVar, IrSetVar> coalescedSets) {
962            this.coalescedInts = coalescedInts;
963            this.coalescedSets = coalescedSets;
964        }
965
966        @Override
967        public IrBoolVar visit(IrBoolVar ir, Void a) {
968            IrBoolVar var = (IrBoolVar) coalescedInts.get(ir);
969            return var == null ? ir : var;
970        }
971
972        @Override
973        public IrIntVar visit(IrIntVar ir, Void a) {
974            IrIntVar var = coalescedInts.get(ir);
975            return var == null ? ir : var;
976        }
977
978        @Override
979        public IrSetVar visit(IrSetVar ir, Void a) {
980            IrSetVar var = coalescedSets.get(ir);
981            return var == null ? ir : var;
982        }
983    }
984
985    private static PartialSet env(IrDomain env) {
986        return new PartialSet(env, null, null);
987    }
988
989    private static PartialSet ker(IrDomain ker) {
990        return new PartialSet(null, ker, null);
991    }
992
993    private static PartialSet card(IrDomain card) {
994        return new PartialSet(null, null, card);
995    }
996
997    private static class PartialSet {
998
999        private final IrDomain env;
1000        private final IrDomain ker;
1001        private final IrDomain card;
1002        private byte mask;
1003
1004        PartialSet(IrDomain env, IrDomain ker, IrDomain card) {
1005            assert env != null || ker != null || card != null;
1006            this.env = env;
1007            this.ker = ker;
1008            this.card = card;
1009        }
1010
1011        IrDomain getEnv() {
1012            return env;
1013        }
1014
1015        IrDomain getKer() {
1016            return ker;
1017        }
1018
1019        IrDomain getCard() {
1020            return card;
1021        }
1022
1023        boolean isEnvMask() {
1024            return (mask & 1) == 1;
1025        }
1026
1027        boolean isKerMask() {
1028            return (mask & 2) == 2;
1029        }
1030
1031        boolean isCardMask() {
1032            return (mask & 4) == 4;
1033        }
1034
1035        boolean hasMask() {
1036            return mask != 0;
1037        }
1038
1039        void updateMask(IrDomain env, IrDomain ker, IrDomain card) {
1040            if (this.env != null && !IrUtil.isSubsetOf(env, this.env)) {
1041                mask |= 1;
1042            }
1043            if (this.ker != null && !IrUtil.isSubsetOf(this.ker, ker)) {
1044                mask |= 2;
1045            }
1046            if (this.card != null && !IrUtil.isSubsetOf(card, this.card)) {
1047                mask |= 4;
1048            }
1049        }
1050
1051        @Override
1052        public String toString() {
1053            return (env != null ? "env=" + env : "")
1054                    + (ker != null ? "ker=" + ker : "")
1055                    + (card != null ? "card=" + card : "");
1056        }
1057    }
1058}