001package org.clafer.ir; 002 003import org.clafer.common.Util; 004import static org.clafer.ir.Irs.*; 005 006/** 007 * 008 * @param <T> the parameter type 009 * @author jimmy 010 */ 011public abstract class IrRewriter<T> 012 implements IrIntExprVisitor<T, IrIntExpr>, IrSetExprVisitor<T, IrSetExpr> { 013 014 protected static <T> boolean changed(T t1, T t2) { 015 if (t1 == t2) { 016 return false; 017 } 018 assert !t1.equals(t2) : "Likely not optimized, the rewriter duplicated an object. Possible false negative."; 019 return true; 020 } 021 022 protected static <T> boolean changed(T[] t1, T[] t2) { 023 if (t1 == t2) { 024 return false; 025 } 026 if (t1.length != t2.length) { 027 return true; 028 } 029 for (int i = 0; i < t1.length; i++) { 030 if (changed(t1[i], t2[i])) { 031 return true; 032 } 033 } 034 return false; 035 } 036 037 protected static <T> boolean changed(T[][] t1, T[][] t2) { 038 if (t1 == t2) { 039 return false; 040 } 041 if (t1.length != t2.length) { 042 return true; 043 } 044 for (int i = 0; i < t1.length; i++) { 045 if (changed(t1[i], t2[i])) { 046 return true; 047 } 048 } 049 return false; 050 } 051 052 public IrModule rewrite(IrModule module, T t) { 053 IrModule optModule = new IrModule(module.getVariables().size(), module.getConstraints().size()); 054 for (IrVar variable : module.getVariables()) { 055 if (variable instanceof IrBoolVar) { 056 optModule.addVariable(visit((IrBoolVar) variable, t)); 057 } else if (variable instanceof IrIntVar) { 058 optModule.addVariable(visit((IrIntVar) variable, t)); 059 } else { 060 optModule.addVariable(visit((IrSetVar) variable, t)); 061 } 062 } 063 for (IrBoolExpr constraint : module.getConstraints()) { 064 optModule.addConstraint(rewrite(constraint, t)); 065 } 066 return optModule; 067 } 068 069 public IrBoolExpr rewrite(IrBoolExpr expr, T t) { 070 return (IrBoolExpr) expr.accept(this, t); 071 } 072 073 public IrBoolExpr[] rewrite(IrBoolExpr[] exprs, T t) { 074 IrBoolExpr[] rewritten = new IrBoolExpr[exprs.length]; 075 for (int i = 0; i < rewritten.length; i++) { 076 rewritten[i] = rewrite(exprs[i], t); 077 } 078 return rewritten; 079 } 080 081 public IrIntExpr rewrite(IrIntExpr expr, T t) { 082 return expr.accept(this, t); 083 } 084 085 public IrIntExpr[] rewrite(IrIntExpr[] exprs, T t) { 086 IrIntExpr[] rewritten = new IrIntExpr[exprs.length]; 087 for (int i = 0; i < rewritten.length; i++) { 088 rewritten[i] = rewrite(exprs[i], t); 089 } 090 return rewritten; 091 } 092 093 public IrIntExpr[][] rewrite(IrIntExpr[][] exprs, T t) { 094 IrIntExpr[][] rewritten = new IrIntExpr[exprs.length][]; 095 for (int i = 0; i < rewritten.length; i++) { 096 rewritten[i] = rewrite(exprs[i], t); 097 } 098 return rewritten; 099 } 100 101 public IrSetExpr rewrite(IrSetExpr expr, T t) { 102 return expr.accept(this, t); 103 } 104 105 public IrSetExpr[] rewrite(IrSetExpr[] exprs, T t) { 106 IrSetExpr[] rewritten = new IrSetExpr[exprs.length]; 107 for (int i = 0; i < rewritten.length; i++) { 108 rewritten[i] = rewrite(exprs[i], t); 109 } 110 return rewritten; 111 } 112 113 public IrSetExpr[][] rewrite(IrSetExpr[][] exprs, T t) { 114 IrSetExpr[][] rewritten = new IrSetExpr[exprs.length][]; 115 for (int i = 0; i < rewritten.length; i++) { 116 rewritten[i] = rewrite(exprs[i], t); 117 } 118 return rewritten; 119 } 120 121 @Override 122 public IrBoolVar visit(IrBoolVar ir, T a) { 123 return ir; 124 } 125 126 @Override 127 public IrBoolExpr visit(IrNot ir, T a) { 128 IrBoolExpr expr = rewrite(ir.getExpr(), a); 129 return changed(ir.getExpr(), expr) 130 ? not(expr) 131 : ir; 132 } 133 134 @Override 135 public IrBoolExpr visit(IrAnd ir, T a) { 136 IrBoolExpr[] operands = rewrite(ir.getOperands(), a); 137 return changed(ir.getOperands(), operands) 138 ? and(operands) 139 : ir; 140 } 141 142 @Override 143 public IrBoolExpr visit(IrLone ir, T a) { 144 IrBoolExpr[] operands = rewrite(ir.getOperands(), a); 145 return changed(ir.getOperands(), operands) 146 ? lone(operands) 147 : ir; 148 } 149 150 @Override 151 public IrBoolExpr visit(IrOne ir, T a) { 152 IrBoolExpr[] operands = rewrite(ir.getOperands(), a); 153 return changed(ir.getOperands(), operands) 154 ? one(operands) 155 : ir; 156 } 157 158 @Override 159 public IrBoolExpr visit(IrOr ir, T a) { 160 IrBoolExpr[] operands = rewrite(ir.getOperands(), a); 161 return changed(ir.getOperands(), operands) 162 ? or(operands) 163 : ir; 164 } 165 166 @Override 167 public IrBoolExpr visit(IrImplies ir, T a) { 168 IrBoolExpr antecedent = rewrite(ir.getAntecedent(), a); 169 IrBoolExpr consequent = rewrite(ir.getConsequent(), a); 170 return changed(ir.getAntecedent(), antecedent) 171 || changed(ir.getConsequent(), consequent) 172 ? implies(antecedent, consequent) 173 : ir; 174 } 175 176 @Override 177 public IrBoolExpr visit(IrNotImplies ir, T a) { 178 IrBoolExpr antecedent = rewrite(ir.getAntecedent(), a); 179 IrBoolExpr consequent = rewrite(ir.getConsequent(), a); 180 return changed(ir.getAntecedent(), antecedent) 181 || changed(ir.getConsequent(), consequent) 182 ? notImplies(antecedent, consequent) 183 : ir; 184 } 185 186 @Override 187 public IrBoolExpr visit(IrIfThenElse ir, T a) { 188 IrBoolExpr antecedent = rewrite(ir.getAntecedent(), a); 189 IrBoolExpr consequent = rewrite(ir.getConsequent(), a); 190 IrBoolExpr alternative = rewrite(ir.getAlternative(), a); 191 return changed(ir.getAntecedent(), antecedent) 192 || changed(ir.getConsequent(), consequent) 193 || changed(ir.getAlternative(), alternative) 194 ? ifThenElse(antecedent, consequent, alternative) 195 : ir; 196 } 197 198 @Override 199 public IrBoolExpr visit(IrIfOnlyIf ir, T a) { 200 IrBoolExpr left = rewrite(ir.getLeft(), a); 201 IrBoolExpr right = rewrite(ir.getRight(), a); 202 return changed(ir.getLeft(), left) || changed(ir.getRight(), right) 203 ? ifOnlyIf(left, right) 204 : ir; 205 } 206 207 @Override 208 public IrBoolExpr visit(IrXor ir, T a) { 209 IrBoolExpr left = rewrite(ir.getLeft(), a); 210 IrBoolExpr right = rewrite(ir.getRight(), a); 211 return changed(ir.getLeft(), left) || changed(ir.getRight(), right) 212 ? xor(left, right) 213 : ir; 214 } 215 216 @Override 217 public IrBoolExpr visit(IrWithin ir, T a) { 218 IrIntExpr value = rewrite(ir.getValue(), a); 219 return changed(ir.getValue(), value) 220 ? within(value, ir.getRange()) 221 : ir; 222 } 223 224 @Override 225 public IrBoolExpr visit(IrNotWithin ir, T a) { 226 IrIntExpr value = rewrite(ir.getValue(), a); 227 return changed(ir.getValue(), value) 228 ? notWithin(value, ir.getRange()) 229 : ir; 230 } 231 232 @Override 233 public IrBoolExpr visit(IrCompare ir, T a) { 234 IrIntExpr left = rewrite(ir.getLeft(), a); 235 IrIntExpr right = rewrite(ir.getRight(), a); 236 return changed(ir.getLeft(), left) || changed(ir.getRight(), right) 237 ? compare(left, ir.getOp(), right) 238 : ir; 239 } 240 241 @Override 242 public IrBoolExpr visit(IrSetTest ir, T a) { 243 IrSetExpr left = rewrite(ir.getLeft(), a); 244 IrSetExpr right = rewrite(ir.getRight(), a); 245 return changed(ir.getLeft(), left) || changed(ir.getRight(), right) 246 ? equality(left, ir.getOp(), right) 247 : ir; 248 } 249 250 @Override 251 public IrBoolExpr visit(IrMember ir, T a) { 252 IrIntExpr element = rewrite(ir.getElement(), a); 253 IrSetExpr set = rewrite(ir.getSet(), a); 254 return changed(ir.getElement(), element) || changed(ir.getSet(), set) 255 ? member(element, set) 256 : ir; 257 } 258 259 @Override 260 public IrBoolExpr visit(IrNotMember ir, T a) { 261 IrIntExpr element = rewrite(ir.getElement(), a); 262 IrSetExpr set = rewrite(ir.getSet(), a); 263 return changed(ir.getElement(), element) || changed(ir.getSet(), set) 264 ? notMember(element, set) 265 : ir; 266 } 267 268 @Override 269 public IrBoolExpr visit(IrSubsetEq ir, T a) { 270 IrSetExpr subset = rewrite(ir.getSubset(), a); 271 IrSetExpr superset = rewrite(ir.getSuperset(), a); 272 return changed(ir.getSubset(), subset) || changed(ir.getSuperset(), superset) 273 ? subsetEq(subset, superset) 274 : ir; 275 } 276 277 @Override 278 public IrBoolExpr visit(IrBoolChannel ir, T a) { 279 IrBoolExpr[] bools = rewrite(ir.getBools(), a); 280 IrSetExpr set = rewrite(ir.getSet(), a); 281 return changed(ir.getBools(), bools) || changed(ir.getSet(), set) 282 ? boolChannel(bools, set) 283 : ir; 284 } 285 286 @Override 287 public IrBoolExpr visit(IrIntChannel ir, T a) { 288 IrIntExpr[] ints = rewrite(ir.getInts(), a); 289 IrSetExpr[] sets = rewrite(ir.getSets(), a); 290 return changed(ir.getInts(), ints) || changed(ir.getSets(), sets) 291 ? intChannel(ints, sets) 292 : ir; 293 } 294 295 @Override 296 public IrBoolExpr visit(IrSortStrings ir, T a) { 297 IrIntExpr[][] strings = rewrite(ir.getStrings(), a); 298 return changed(ir.getStrings(), strings) 299 ? (ir.isStrict() ? sortStrict(strings) : sort(strings)) 300 : ir; 301 } 302 303 @Override 304 public IrBoolExpr visit(IrSortSets ir, T a) { 305 IrSetExpr[] sets = rewrite(ir.getSets(), a); 306 return changed(ir.getSets(), sets) 307 ? sort(sets) 308 : ir; 309 } 310 311 @Override 312 public IrBoolExpr visit(IrSortStringsChannel ir, T a) { 313 IrIntExpr[][] strings = rewrite(ir.getStrings(), a); 314 IrIntExpr[] ints = rewrite(ir.getInts(), a); 315 return changed(ir.getStrings(), strings) || changed(ir.getInts(), ints) 316 ? sortChannel(strings, ints) 317 : ir; 318 } 319 320 @Override 321 public IrBoolExpr visit(IrAllDifferent ir, T a) { 322 IrIntExpr[] operands = rewrite(ir.getOperands(), a); 323 return changed(ir.getOperands(), operands) 324 ? allDifferent(operands) 325 : ir; 326 } 327 328 @Override 329 public IrBoolExpr visit(IrSelectN ir, T a) { 330 IrBoolExpr[] bools = rewrite(ir.getBools(), a); 331 IrIntExpr n = rewrite(ir.getN(), a); 332 return changed(ir.getBools(), bools) || changed(ir.getN(), n) 333 ? selectN(bools, n) 334 : ir; 335 } 336 337 @Override 338 public IrIntExpr visit(IrAcyclic ir, T a) { 339 IrIntExpr[] edges = rewrite(ir.getEdges(), a); 340 return changed(ir.getEdges(), edges) 341 ? acyclic(edges) 342 : ir; 343 } 344 345 @Override 346 public IrIntExpr visit(IrUnreachable ir, T a) { 347 IrIntExpr[] edges = rewrite(ir.getEdges(), a); 348 return changed(ir.getEdges(), edges) 349 ? unreachable(edges, ir.getFrom(), ir.getTo()) 350 : ir; 351 } 352 353 @Override 354 public IrBoolExpr visit(IrFilterString ir, T a) { 355 IrSetExpr set = rewrite(ir.getSet(), a); 356 IrIntExpr[] string = rewrite(ir.getString(), a); 357 IrIntExpr[] result = rewrite(ir.getResult(), a); 358 return changed(ir.getSet(), set) 359 || changed(ir.getString(), string) 360 || changed(ir.getResult(), result) 361 ? filterString(set, ir.getOffset(), string, result) 362 : ir; 363 } 364 365 @Override 366 public IrIntVar visit(IrIntVar ir, T a) { 367 return ir; 368 } 369 370 @Override 371 public IrIntExpr visit(IrMinus ir, T a) { 372 IrIntExpr expr = rewrite(ir.getExpr(), a); 373 return changed(ir.getExpr(), expr) 374 ? minus(expr) 375 : ir; 376 } 377 378 @Override 379 public IrIntExpr visit(IrCard ir, T a) { 380 IrSetExpr set = rewrite(ir.getSet(), a); 381 return changed(ir.getSet(), set) 382 ? card(set) 383 : ir; 384 } 385 386 @Override 387 public IrIntExpr visit(IrAdd ir, T a) { 388 IrIntExpr[] addends = rewrite(ir.getAddends(), a); 389 return changed(ir.getAddends(), addends) 390 ? add(Util.cons(constant(ir.getOffset()), addends)) 391 : ir; 392 } 393 394 @Override 395 public IrIntExpr visit(IrMul ir, T a) { 396 IrIntExpr multiplicand = rewrite(ir.getMultiplicand(), a); 397 IrIntExpr multiplier = rewrite(ir.getMultiplier(), a); 398 return changed(ir.getMultiplicand(), multiplicand) || changed(ir.getMultiplier(), multiplier) 399 ? mul(multiplicand, multiplier) 400 : ir; 401 } 402 403 @Override 404 public IrIntExpr visit(IrDiv ir, T a) { 405 IrIntExpr dividend = rewrite(ir.getDividend(), a); 406 IrIntExpr divisor = rewrite(ir.getDivisor(), a); 407 return changed(ir.getDividend(), dividend) || changed(ir.getDivisor(), divisor) 408 ? div(dividend, divisor) 409 : ir; 410 } 411 412 @Override 413 public IrIntExpr visit(IrElement ir, T a) { 414 IrIntExpr[] array = rewrite(ir.getArray(), a); 415 IrIntExpr index = rewrite(ir.getIndex(), a); 416 return changed(ir.getArray(), array) || changed(ir.getIndex(), index) 417 ? element(array, index) 418 : ir; 419 } 420 421 @Override 422 public IrIntExpr visit(IrCount ir, T a) { 423 IrIntExpr[] array = rewrite(ir.getArray(), a); 424 return changed(ir.getArray(), array) 425 ? count(ir.getValue(), array) 426 : ir; 427 } 428 429 @Override 430 public IrIntExpr visit(IrSetSum ir, T a) { 431 IrSetExpr set = rewrite(ir.getSet(), a); 432 return changed(ir.getSet(), set) 433 ? sum(set) 434 : ir; 435 } 436 437 @Override 438 public IrIntExpr visit(IrTernary ir, T a) { 439 IrBoolExpr antecedent = rewrite(ir.getAntecedent(), a); 440 IrIntExpr consequent = rewrite(ir.getConsequent(), a); 441 IrIntExpr alternative = rewrite(ir.getAlternative(), a); 442 return changed(ir.getAntecedent(), antecedent) 443 || changed(ir.getConsequent(), consequent) 444 || changed(ir.getAlternative(), alternative) 445 ? ternary(antecedent, consequent, alternative) 446 : ir; 447 } 448 449 @Override 450 public IrSetVar visit(IrSetVar ir, T a) { 451 return ir; 452 } 453 454 @Override 455 public IrSetExpr visit(IrSingleton ir, T a) { 456 IrIntExpr value = rewrite(ir.getValue(), a); 457 return changed(ir.getValue(), value) 458 ? singleton(value) 459 : ir; 460 } 461 462 @Override 463 public IrSetExpr visit(IrArrayToSet ir, T a) { 464 IrIntExpr[] array = rewrite(ir.getArray(), a); 465 return changed(ir.getArray(), array) 466 ? arrayToSet(array, ir.getGlobalCardinality()) 467 : ir; 468 } 469 470 @Override 471 public IrSetExpr visit(IrJoinRelation ir, T a) { 472 IrSetExpr take = rewrite(ir.getTake(), a); 473 IrSetExpr[] children = rewrite(ir.getChildren(), a); 474 return changed(ir.getTake(), take) || changed(ir.getChildren(), children) 475 ? joinRelation(take, children, ir.isInjective()) 476 : ir; 477 } 478 479 @Override 480 public IrSetExpr visit(IrJoinFunction ir, T a) { 481 IrSetExpr take = rewrite(ir.getTake(), a); 482 IrIntExpr[] refs = rewrite(ir.getRefs(), a); 483 return changed(ir.getTake(), take) || changed(ir.getRefs(), refs) 484 ? joinFunction(take, refs, ir.getGlobalCardinality()) 485 : ir; 486 } 487 488 @Override 489 public IrSetExpr visit(IrSetDifference ir, T a) { 490 IrSetExpr minuend = rewrite(ir.getMinuend(), a); 491 IrSetExpr subtrahend = rewrite(ir.getSubtrahend(), a); 492 return changed(ir.getMinuend(), minuend) || changed(ir.getSubtrahend(), subtrahend) 493 ? difference(minuend, subtrahend) 494 : ir; 495 } 496 497 @Override 498 public IrSetExpr visit(IrSetIntersection ir, T a) { 499 IrSetExpr[] operands = rewrite(ir.getOperands(), a); 500 return changed(ir.getOperands(), operands) 501 ? intersection(operands) 502 : ir; 503 } 504 505 @Override 506 public IrSetExpr visit(IrSetUnion ir, T a) { 507 IrSetExpr[] operands = rewrite(ir.getOperands(), a); 508 return changed(ir.getOperands(), operands) 509 ? union(operands, ir.isDisjoint()) 510 : ir; 511 } 512 513 @Override 514 public IrSetExpr visit(IrOffset ir, T a) { 515 IrSetExpr set = rewrite(ir.getSet(), a); 516 return changed(ir.getSet(), set) 517 ? offset(set, ir.getOffset()) 518 : ir; 519 } 520 521 @Override 522 public IrSetExpr visit(IrMask ir, T a) { 523 IrSetExpr set = rewrite(ir.getSet(), a); 524 return changed(ir.getSet(), set) 525 ? mask(set, ir.getFrom(), ir.getTo()) 526 : ir; 527 } 528 529 @Override 530 public IrSetExpr visit(IrSetTernary ir, T a) { 531 IrBoolExpr antecedent = rewrite(ir.getAntecedent(), a); 532 IrSetExpr consequent = rewrite(ir.getConsequent(), a); 533 IrSetExpr alternative = rewrite(ir.getAlternative(), a); 534 return changed(ir.getAntecedent(), antecedent) 535 || changed(ir.getConsequent(), consequent) 536 || changed(ir.getAlternative(), alternative) 537 ? ternary(antecedent, consequent, alternative) 538 : ir; 539 } 540}