001package org.clafer.ir;
002
003import org.clafer.common.Util;
004import static org.clafer.ir.Irs.*;
005
006/**
007 *
008 * @param <T> the parameter type
009 * @author jimmy
010 */
011public abstract class IrRewriter<T>
012        implements IrIntExprVisitor<T, IrIntExpr>, IrSetExprVisitor<T, IrSetExpr> {
013
014    protected static <T> boolean changed(T t1, T t2) {
015        if (t1 == t2) {
016            return false;
017        }
018        assert !t1.equals(t2) : "Likely not optimized, the rewriter duplicated an object. Possible false negative.";
019        return true;
020    }
021
022    protected static <T> boolean changed(T[] t1, T[] t2) {
023        if (t1 == t2) {
024            return false;
025        }
026        if (t1.length != t2.length) {
027            return true;
028        }
029        for (int i = 0; i < t1.length; i++) {
030            if (changed(t1[i], t2[i])) {
031                return true;
032            }
033        }
034        return false;
035    }
036
037    protected static <T> boolean changed(T[][] t1, T[][] t2) {
038        if (t1 == t2) {
039            return false;
040        }
041        if (t1.length != t2.length) {
042            return true;
043        }
044        for (int i = 0; i < t1.length; i++) {
045            if (changed(t1[i], t2[i])) {
046                return true;
047            }
048        }
049        return false;
050    }
051
052    public IrModule rewrite(IrModule module, T t) {
053        IrModule optModule = new IrModule(module.getVariables().size(), module.getConstraints().size());
054        for (IrVar variable : module.getVariables()) {
055            if (variable instanceof IrBoolVar) {
056                optModule.addVariable(visit((IrBoolVar) variable, t));
057            } else if (variable instanceof IrIntVar) {
058                optModule.addVariable(visit((IrIntVar) variable, t));
059            } else {
060                optModule.addVariable(visit((IrSetVar) variable, t));
061            }
062        }
063        for (IrBoolExpr constraint : module.getConstraints()) {
064            optModule.addConstraint(rewrite(constraint, t));
065        }
066        return optModule;
067    }
068
069    public IrBoolExpr rewrite(IrBoolExpr expr, T t) {
070        return (IrBoolExpr) expr.accept(this, t);
071    }
072
073    public IrBoolExpr[] rewrite(IrBoolExpr[] exprs, T t) {
074        IrBoolExpr[] rewritten = new IrBoolExpr[exprs.length];
075        for (int i = 0; i < rewritten.length; i++) {
076            rewritten[i] = rewrite(exprs[i], t);
077        }
078        return rewritten;
079    }
080
081    public IrIntExpr rewrite(IrIntExpr expr, T t) {
082        return expr.accept(this, t);
083    }
084
085    public IrIntExpr[] rewrite(IrIntExpr[] exprs, T t) {
086        IrIntExpr[] rewritten = new IrIntExpr[exprs.length];
087        for (int i = 0; i < rewritten.length; i++) {
088            rewritten[i] = rewrite(exprs[i], t);
089        }
090        return rewritten;
091    }
092
093    public IrIntExpr[][] rewrite(IrIntExpr[][] exprs, T t) {
094        IrIntExpr[][] rewritten = new IrIntExpr[exprs.length][];
095        for (int i = 0; i < rewritten.length; i++) {
096            rewritten[i] = rewrite(exprs[i], t);
097        }
098        return rewritten;
099    }
100
101    public IrSetExpr rewrite(IrSetExpr expr, T t) {
102        return expr.accept(this, t);
103    }
104
105    public IrSetExpr[] rewrite(IrSetExpr[] exprs, T t) {
106        IrSetExpr[] rewritten = new IrSetExpr[exprs.length];
107        for (int i = 0; i < rewritten.length; i++) {
108            rewritten[i] = rewrite(exprs[i], t);
109        }
110        return rewritten;
111    }
112
113    public IrSetExpr[][] rewrite(IrSetExpr[][] exprs, T t) {
114        IrSetExpr[][] rewritten = new IrSetExpr[exprs.length][];
115        for (int i = 0; i < rewritten.length; i++) {
116            rewritten[i] = rewrite(exprs[i], t);
117        }
118        return rewritten;
119    }
120
121    @Override
122    public IrBoolVar visit(IrBoolVar ir, T a) {
123        return ir;
124    }
125
126    @Override
127    public IrBoolExpr visit(IrNot ir, T a) {
128        IrBoolExpr expr = rewrite(ir.getExpr(), a);
129        return changed(ir.getExpr(), expr)
130                ? not(expr)
131                : ir;
132    }
133
134    @Override
135    public IrBoolExpr visit(IrAnd ir, T a) {
136        IrBoolExpr[] operands = rewrite(ir.getOperands(), a);
137        return changed(ir.getOperands(), operands)
138                ? and(operands)
139                : ir;
140    }
141
142    @Override
143    public IrBoolExpr visit(IrLone ir, T a) {
144        IrBoolExpr[] operands = rewrite(ir.getOperands(), a);
145        return changed(ir.getOperands(), operands)
146                ? lone(operands)
147                : ir;
148    }
149
150    @Override
151    public IrBoolExpr visit(IrOne ir, T a) {
152        IrBoolExpr[] operands = rewrite(ir.getOperands(), a);
153        return changed(ir.getOperands(), operands)
154                ? one(operands)
155                : ir;
156    }
157
158    @Override
159    public IrBoolExpr visit(IrOr ir, T a) {
160        IrBoolExpr[] operands = rewrite(ir.getOperands(), a);
161        return changed(ir.getOperands(), operands)
162                ? or(operands)
163                : ir;
164    }
165
166    @Override
167    public IrBoolExpr visit(IrImplies ir, T a) {
168        IrBoolExpr antecedent = rewrite(ir.getAntecedent(), a);
169        IrBoolExpr consequent = rewrite(ir.getConsequent(), a);
170        return changed(ir.getAntecedent(), antecedent)
171                || changed(ir.getConsequent(), consequent)
172                ? implies(antecedent, consequent)
173                : ir;
174    }
175
176    @Override
177    public IrBoolExpr visit(IrNotImplies ir, T a) {
178        IrBoolExpr antecedent = rewrite(ir.getAntecedent(), a);
179        IrBoolExpr consequent = rewrite(ir.getConsequent(), a);
180        return changed(ir.getAntecedent(), antecedent)
181                || changed(ir.getConsequent(), consequent)
182                ? notImplies(antecedent, consequent)
183                : ir;
184    }
185
186    @Override
187    public IrBoolExpr visit(IrIfThenElse ir, T a) {
188        IrBoolExpr antecedent = rewrite(ir.getAntecedent(), a);
189        IrBoolExpr consequent = rewrite(ir.getConsequent(), a);
190        IrBoolExpr alternative = rewrite(ir.getAlternative(), a);
191        return changed(ir.getAntecedent(), antecedent)
192                || changed(ir.getConsequent(), consequent)
193                || changed(ir.getAlternative(), alternative)
194                ? ifThenElse(antecedent, consequent, alternative)
195                : ir;
196    }
197
198    @Override
199    public IrBoolExpr visit(IrIfOnlyIf ir, T a) {
200        IrBoolExpr left = rewrite(ir.getLeft(), a);
201        IrBoolExpr right = rewrite(ir.getRight(), a);
202        return changed(ir.getLeft(), left) || changed(ir.getRight(), right)
203                ? ifOnlyIf(left, right)
204                : ir;
205    }
206
207    @Override
208    public IrBoolExpr visit(IrXor ir, T a) {
209        IrBoolExpr left = rewrite(ir.getLeft(), a);
210        IrBoolExpr right = rewrite(ir.getRight(), a);
211        return changed(ir.getLeft(), left) || changed(ir.getRight(), right)
212                ? xor(left, right)
213                : ir;
214    }
215
216    @Override
217    public IrBoolExpr visit(IrWithin ir, T a) {
218        IrIntExpr value = rewrite(ir.getValue(), a);
219        return changed(ir.getValue(), value)
220                ? within(value, ir.getRange())
221                : ir;
222    }
223
224    @Override
225    public IrBoolExpr visit(IrNotWithin ir, T a) {
226        IrIntExpr value = rewrite(ir.getValue(), a);
227        return changed(ir.getValue(), value)
228                ? notWithin(value, ir.getRange())
229                : ir;
230    }
231
232    @Override
233    public IrBoolExpr visit(IrCompare ir, T a) {
234        IrIntExpr left = rewrite(ir.getLeft(), a);
235        IrIntExpr right = rewrite(ir.getRight(), a);
236        return changed(ir.getLeft(), left) || changed(ir.getRight(), right)
237                ? compare(left, ir.getOp(), right)
238                : ir;
239    }
240
241    @Override
242    public IrBoolExpr visit(IrSetTest ir, T a) {
243        IrSetExpr left = rewrite(ir.getLeft(), a);
244        IrSetExpr right = rewrite(ir.getRight(), a);
245        return changed(ir.getLeft(), left) || changed(ir.getRight(), right)
246                ? equality(left, ir.getOp(), right)
247                : ir;
248    }
249
250    @Override
251    public IrBoolExpr visit(IrMember ir, T a) {
252        IrIntExpr element = rewrite(ir.getElement(), a);
253        IrSetExpr set = rewrite(ir.getSet(), a);
254        return changed(ir.getElement(), element) || changed(ir.getSet(), set)
255                ? member(element, set)
256                : ir;
257    }
258
259    @Override
260    public IrBoolExpr visit(IrNotMember ir, T a) {
261        IrIntExpr element = rewrite(ir.getElement(), a);
262        IrSetExpr set = rewrite(ir.getSet(), a);
263        return changed(ir.getElement(), element) || changed(ir.getSet(), set)
264                ? notMember(element, set)
265                : ir;
266    }
267
268    @Override
269    public IrBoolExpr visit(IrSubsetEq ir, T a) {
270        IrSetExpr subset = rewrite(ir.getSubset(), a);
271        IrSetExpr superset = rewrite(ir.getSuperset(), a);
272        return changed(ir.getSubset(), subset) || changed(ir.getSuperset(), superset)
273                ? subsetEq(subset, superset)
274                : ir;
275    }
276
277    @Override
278    public IrBoolExpr visit(IrBoolChannel ir, T a) {
279        IrBoolExpr[] bools = rewrite(ir.getBools(), a);
280        IrSetExpr set = rewrite(ir.getSet(), a);
281        return changed(ir.getBools(), bools) || changed(ir.getSet(), set)
282                ? boolChannel(bools, set)
283                : ir;
284    }
285
286    @Override
287    public IrBoolExpr visit(IrIntChannel ir, T a) {
288        IrIntExpr[] ints = rewrite(ir.getInts(), a);
289        IrSetExpr[] sets = rewrite(ir.getSets(), a);
290        return changed(ir.getInts(), ints) || changed(ir.getSets(), sets)
291                ? intChannel(ints, sets)
292                : ir;
293    }
294
295    @Override
296    public IrBoolExpr visit(IrSortStrings ir, T a) {
297        IrIntExpr[][] strings = rewrite(ir.getStrings(), a);
298        return changed(ir.getStrings(), strings)
299                ? (ir.isStrict() ? sortStrict(strings) : sort(strings))
300                : ir;
301    }
302
303    @Override
304    public IrBoolExpr visit(IrSortSets ir, T a) {
305        IrSetExpr[] sets = rewrite(ir.getSets(), a);
306        return changed(ir.getSets(), sets)
307                ? sort(sets)
308                : ir;
309    }
310
311    @Override
312    public IrBoolExpr visit(IrSortStringsChannel ir, T a) {
313        IrIntExpr[][] strings = rewrite(ir.getStrings(), a);
314        IrIntExpr[] ints = rewrite(ir.getInts(), a);
315        return changed(ir.getStrings(), strings) || changed(ir.getInts(), ints)
316                ? sortChannel(strings, ints)
317                : ir;
318    }
319
320    @Override
321    public IrBoolExpr visit(IrAllDifferent ir, T a) {
322        IrIntExpr[] operands = rewrite(ir.getOperands(), a);
323        return changed(ir.getOperands(), operands)
324                ? allDifferent(operands)
325                : ir;
326    }
327
328    @Override
329    public IrBoolExpr visit(IrSelectN ir, T a) {
330        IrBoolExpr[] bools = rewrite(ir.getBools(), a);
331        IrIntExpr n = rewrite(ir.getN(), a);
332        return changed(ir.getBools(), bools) || changed(ir.getN(), n)
333                ? selectN(bools, n)
334                : ir;
335    }
336
337    @Override
338    public IrIntExpr visit(IrAcyclic ir, T a) {
339        IrIntExpr[] edges = rewrite(ir.getEdges(), a);
340        return changed(ir.getEdges(), edges)
341                ? acyclic(edges)
342                : ir;
343    }
344
345    @Override
346    public IrIntExpr visit(IrUnreachable ir, T a) {
347        IrIntExpr[] edges = rewrite(ir.getEdges(), a);
348        return changed(ir.getEdges(), edges)
349                ? unreachable(edges, ir.getFrom(), ir.getTo())
350                : ir;
351    }
352
353    @Override
354    public IrBoolExpr visit(IrFilterString ir, T a) {
355        IrSetExpr set = rewrite(ir.getSet(), a);
356        IrIntExpr[] string = rewrite(ir.getString(), a);
357        IrIntExpr[] result = rewrite(ir.getResult(), a);
358        return changed(ir.getSet(), set)
359                || changed(ir.getString(), string)
360                || changed(ir.getResult(), result)
361                ? filterString(set, ir.getOffset(), string, result)
362                : ir;
363    }
364
365    @Override
366    public IrIntVar visit(IrIntVar ir, T a) {
367        return ir;
368    }
369
370    @Override
371    public IrIntExpr visit(IrMinus ir, T a) {
372        IrIntExpr expr = rewrite(ir.getExpr(), a);
373        return changed(ir.getExpr(), expr)
374                ? minus(expr)
375                : ir;
376    }
377
378    @Override
379    public IrIntExpr visit(IrCard ir, T a) {
380        IrSetExpr set = rewrite(ir.getSet(), a);
381        return changed(ir.getSet(), set)
382                ? card(set)
383                : ir;
384    }
385
386    @Override
387    public IrIntExpr visit(IrAdd ir, T a) {
388        IrIntExpr[] addends = rewrite(ir.getAddends(), a);
389        return changed(ir.getAddends(), addends)
390                ? add(Util.cons(constant(ir.getOffset()), addends))
391                : ir;
392    }
393
394    @Override
395    public IrIntExpr visit(IrMul ir, T a) {
396        IrIntExpr multiplicand = rewrite(ir.getMultiplicand(), a);
397        IrIntExpr multiplier = rewrite(ir.getMultiplier(), a);
398        return changed(ir.getMultiplicand(), multiplicand) || changed(ir.getMultiplier(), multiplier)
399                ? mul(multiplicand, multiplier)
400                : ir;
401    }
402
403    @Override
404    public IrIntExpr visit(IrDiv ir, T a) {
405        IrIntExpr dividend = rewrite(ir.getDividend(), a);
406        IrIntExpr divisor = rewrite(ir.getDivisor(), a);
407        return changed(ir.getDividend(), dividend) || changed(ir.getDivisor(), divisor)
408                ? div(dividend, divisor)
409                : ir;
410    }
411
412    @Override
413    public IrIntExpr visit(IrElement ir, T a) {
414        IrIntExpr[] array = rewrite(ir.getArray(), a);
415        IrIntExpr index = rewrite(ir.getIndex(), a);
416        return changed(ir.getArray(), array) || changed(ir.getIndex(), index)
417                ? element(array, index)
418                : ir;
419    }
420
421    @Override
422    public IrIntExpr visit(IrCount ir, T a) {
423        IrIntExpr[] array = rewrite(ir.getArray(), a);
424        return changed(ir.getArray(), array)
425                ? count(ir.getValue(), array)
426                : ir;
427    }
428
429    @Override
430    public IrIntExpr visit(IrSetSum ir, T a) {
431        IrSetExpr set = rewrite(ir.getSet(), a);
432        return changed(ir.getSet(), set)
433                ? sum(set)
434                : ir;
435    }
436
437    @Override
438    public IrIntExpr visit(IrTernary ir, T a) {
439        IrBoolExpr antecedent = rewrite(ir.getAntecedent(), a);
440        IrIntExpr consequent = rewrite(ir.getConsequent(), a);
441        IrIntExpr alternative = rewrite(ir.getAlternative(), a);
442        return changed(ir.getAntecedent(), antecedent)
443                || changed(ir.getConsequent(), consequent)
444                || changed(ir.getAlternative(), alternative)
445                ? ternary(antecedent, consequent, alternative)
446                : ir;
447    }
448
449    @Override
450    public IrSetVar visit(IrSetVar ir, T a) {
451        return ir;
452    }
453
454    @Override
455    public IrSetExpr visit(IrSingleton ir, T a) {
456        IrIntExpr value = rewrite(ir.getValue(), a);
457        return changed(ir.getValue(), value)
458                ? singleton(value)
459                : ir;
460    }
461
462    @Override
463    public IrSetExpr visit(IrArrayToSet ir, T a) {
464        IrIntExpr[] array = rewrite(ir.getArray(), a);
465        return changed(ir.getArray(), array)
466                ? arrayToSet(array, ir.getGlobalCardinality())
467                : ir;
468    }
469
470    @Override
471    public IrSetExpr visit(IrJoinRelation ir, T a) {
472        IrSetExpr take = rewrite(ir.getTake(), a);
473        IrSetExpr[] children = rewrite(ir.getChildren(), a);
474        return changed(ir.getTake(), take) || changed(ir.getChildren(), children)
475                ? joinRelation(take, children, ir.isInjective())
476                : ir;
477    }
478
479    @Override
480    public IrSetExpr visit(IrJoinFunction ir, T a) {
481        IrSetExpr take = rewrite(ir.getTake(), a);
482        IrIntExpr[] refs = rewrite(ir.getRefs(), a);
483        return changed(ir.getTake(), take) || changed(ir.getRefs(), refs)
484                ? joinFunction(take, refs, ir.getGlobalCardinality())
485                : ir;
486    }
487
488    @Override
489    public IrSetExpr visit(IrSetDifference ir, T a) {
490        IrSetExpr minuend = rewrite(ir.getMinuend(), a);
491        IrSetExpr subtrahend = rewrite(ir.getSubtrahend(), a);
492        return changed(ir.getMinuend(), minuend) || changed(ir.getSubtrahend(), subtrahend)
493                ? difference(minuend, subtrahend)
494                : ir;
495    }
496
497    @Override
498    public IrSetExpr visit(IrSetIntersection ir, T a) {
499        IrSetExpr[] operands = rewrite(ir.getOperands(), a);
500        return changed(ir.getOperands(), operands)
501                ? intersection(operands)
502                : ir;
503    }
504
505    @Override
506    public IrSetExpr visit(IrSetUnion ir, T a) {
507        IrSetExpr[] operands = rewrite(ir.getOperands(), a);
508        return changed(ir.getOperands(), operands)
509                ? union(operands, ir.isDisjoint())
510                : ir;
511    }
512
513    @Override
514    public IrSetExpr visit(IrOffset ir, T a) {
515        IrSetExpr set = rewrite(ir.getSet(), a);
516        return changed(ir.getSet(), set)
517                ? offset(set, ir.getOffset())
518                : ir;
519    }
520
521    @Override
522    public IrSetExpr visit(IrMask ir, T a) {
523        IrSetExpr set = rewrite(ir.getSet(), a);
524        return changed(ir.getSet(), set)
525                ? mask(set, ir.getFrom(), ir.getTo())
526                : ir;
527    }
528
529    @Override
530    public IrSetExpr visit(IrSetTernary ir, T a) {
531        IrBoolExpr antecedent = rewrite(ir.getAntecedent(), a);
532        IrSetExpr consequent = rewrite(ir.getConsequent(), a);
533        IrSetExpr alternative = rewrite(ir.getAlternative(), a);
534        return changed(ir.getAntecedent(), antecedent)
535                || changed(ir.getConsequent(), consequent)
536                || changed(ir.getAlternative(), alternative)
537                ? ternary(antecedent, consequent, alternative)
538                : ir;
539    }
540}