001package org.clafer.ir.analysis; 002 003import org.clafer.ir.IrBoolExpr; 004import org.clafer.ir.IrCompare; 005import static org.clafer.ir.IrCompare.Op.Equal; 006import static org.clafer.ir.IrCompare.Op.NotEqual; 007import org.clafer.ir.IrDomain; 008import org.clafer.ir.IrImplies; 009import org.clafer.ir.IrIntExpr; 010import org.clafer.ir.IrLone; 011import org.clafer.ir.IrModule; 012import org.clafer.ir.IrOffset; 013import org.clafer.ir.IrOr; 014import org.clafer.ir.IrRewriter; 015import org.clafer.ir.IrSetExpr; 016import org.clafer.ir.IrUtil; 017import static org.clafer.ir.Irs.*; 018 019/** 020 * 021 * @author jimmy 022 */ 023public class Optimizer { 024 025 private Optimizer() { 026 } 027 028 /** 029 * Optimize the module. 030 * 031 * @param module the module to optimize 032 * @return the optimized module 033 */ 034 public static IrModule optimize(IrModule module) { 035 return optimizer.rewrite(module, null); 036 } 037 private static final IrRewriter<Void> optimizer = new IrRewriter<Void>() { 038 @Override 039 public IrBoolExpr visit(IrLone ir, Void a) { 040 IrBoolExpr[] operands = rewrite(ir.getOperands(), a); 041 if (operands.length == 2) { 042 if (operands[0] instanceof IrCompare) { 043 IrBoolExpr antecedent = operands[1]; 044 IrCompare compare = (IrCompare) operands[0]; 045 if (compare.getOp().isEquality()) { 046 IrBoolExpr opt = optimizeLoneCompare(antecedent, compare.getLeft(), compare.getOp(), compare.getRight()); 047 if (opt == null) { 048 opt = optimizeLoneCompare(antecedent, compare.getRight(), compare.getOp(), compare.getLeft()); 049 } 050 if (opt != null) { 051 return opt; 052 } 053 } 054 } 055 if (operands[1] instanceof IrCompare) { 056 IrBoolExpr antecedent = operands[0]; 057 IrCompare compare = (IrCompare) operands[1]; 058 if (compare.getOp().isEquality()) { 059 IrBoolExpr opt = optimizeLoneCompare(antecedent, compare.getLeft(), compare.getOp(), compare.getRight()); 060 if (opt == null) { 061 opt = optimizeLoneCompare(antecedent, compare.getRight(), compare.getOp(), compare.getLeft()); 062 } 063 if (opt != null) { 064 return opt; 065 } 066 } 067 } 068 } 069 return changed(ir.getOperands(), operands) 070 ? lone(operands) 071 : ir; 072 } 073 074 @Override 075 public IrBoolExpr visit(IrOr ir, Void a) { 076 IrBoolExpr[] operands = rewrite(ir.getOperands(), a); 077 if (operands.length == 2) { 078 if (operands[0] instanceof IrCompare) { 079 IrBoolExpr antecedent = operands[1]; 080 IrCompare compare = (IrCompare) operands[0]; 081 if (compare.getOp().isEquality()) { 082 IrBoolExpr opt = optimizeOrCompare(antecedent, compare.getLeft(), compare.getOp(), compare.getRight()); 083 if (opt == null) { 084 opt = optimizeOrCompare(antecedent, compare.getRight(), compare.getOp(), compare.getLeft()); 085 } 086 if (opt != null) { 087 return opt; 088 } 089 } 090 } 091 if (operands[1] instanceof IrCompare) { 092 IrBoolExpr antecedent = operands[0]; 093 IrCompare compare = (IrCompare) operands[1]; 094 if (compare.getOp().isEquality()) { 095 IrBoolExpr opt = optimizeOrCompare(antecedent, compare.getLeft(), compare.getOp(), compare.getRight()); 096 if (opt == null) { 097 opt = optimizeOrCompare(antecedent, compare.getRight(), compare.getOp(), compare.getLeft()); 098 } 099 if (opt != null) { 100 return opt; 101 } 102 } 103 } 104 } 105 return changed(ir.getOperands(), operands) 106 ? or(operands) 107 : ir; 108 } 109 110 @Override 111 public IrBoolExpr visit(IrImplies ir, Void a) { 112 // Rewrite 113 // !a => !b 114 // to 115 // b => a 116 if (ir.getAntecedent().isNegative() && ir.getConsequent().isNegative()) { 117 return rewrite(implies(not(ir.getConsequent()), not(ir.getAntecedent())), a); 118 } 119 // Rewrite 120 // !a => b 121 // to 122 // a or b 123 if (ir.getAntecedent().isNegative()) { 124 return rewrite(or(not(ir.getAntecedent()), ir.getConsequent()), a); 125 } 126 // Rewrite 127 // a => !b 128 // to 129 // a + b <= 1 130 if (ir.getConsequent().isNegative()) { 131 return rewrite(lone(ir.getAntecedent(), not(ir.getConsequent())), a); 132 } 133 IrBoolExpr antecedent = rewrite(ir.getAntecedent(), a); 134 IrBoolExpr consequent = rewrite(ir.getConsequent(), a); 135 if (consequent instanceof IrCompare) { 136 IrCompare compare = (IrCompare) consequent; 137 if (compare.getOp().isEquality()) { 138 IrBoolExpr opt = optimizeImplicationCompare(antecedent, compare.getLeft(), compare.getOp(), compare.getRight()); 139 if (opt == null) { 140 opt = optimizeImplicationCompare(antecedent, compare.getRight(), compare.getOp(), compare.getLeft()); 141 } 142 if (opt != null) { 143 return opt; 144 } 145 } 146 } 147 return changed(ir.getAntecedent(), antecedent) 148 || changed(ir.getConsequent(), consequent) 149 ? implies(antecedent, consequent) 150 : ir; 151 } 152 153 @Override 154 public IrSetExpr visit(IrOffset ir, Void a) { 155 if (ir.getSet() instanceof IrOffset) { 156 // Rewrite 157 // offset(offset(set, a), b) 158 // to 159 // offset(set, a + b) 160 // This optimization is important for going multiple steps up the 161 // hierarchy. 162 IrOffset innerOffset = (IrOffset) ir.getSet(); 163 return rewrite(offset(innerOffset.getSet(), 164 ir.getOffset() + innerOffset.getOffset()), a); 165 } 166 return super.visit(ir, a); 167 } 168 }; 169 170 /** 171 * Optimize {@code lone(antecedent, left `op` right)} where `op` is = or !=. 172 */ 173 private static IrBoolExpr optimizeLoneCompare(IrBoolExpr antecedent, IrIntExpr left, IrCompare.Op op, IrIntExpr right) { 174 IrDomain domain = left.getDomain(); 175 Integer constant = IrUtil.getConstant(right); 176 if (domain.size() == 2 && constant != null) { 177 switch (op) { 178 case Equal: 179 // Rewrite 180 // lone(bool, int = 888) 181 // where dom(int) = {-3, 888} 182 // to 183 // asInt(bool) <= 888 - int 184 // asInt(bool) + int <= 888 185 if (domain.getHighBound() == constant.intValue()) { 186 return lessThanEqual(add(antecedent, left), 187 domain.getHighBound()); 188 } 189 // Rewrite 190 // lone(bool, int = -3) 191 // where dom(int) = {-3, 888} 192 // to 193 // asInt(bool) <= int - (-3) 194 if (domain.getLowBound() == constant.intValue()) { 195 return lessThanEqual(antecedent, 196 sub(left, domain.getLowBound())); 197 } 198 break; 199 case NotEqual: 200 // Rewrite 201 // lone(bool, int != 888) 202 // where dom(int) = {-3, 888} 203 // to 204 // asInt(bool) <= int - (-3) 205 if (domain.getHighBound() == constant.intValue()) { 206 return lessThanEqual(antecedent, 207 sub(left, domain.getLowBound())); 208 } 209 // Rewrite 210 // lone(bool, int != -3) 211 // where dom(int) = {-3, 888} 212 // to 213 // asInt(bool) <= 888 - int 214 // asInt(bool) + int <= 888 215 if (domain.getLowBound() == constant.intValue()) { 216 return lessThanEqual(add(antecedent, left), 217 domain.getHighBound()); 218 } 219 break; 220 } 221 } 222 return null; 223 } 224 225 /** 226 * Optimize {@code antecedent or (left `op` right)} where `op` is = or !=. 227 */ 228 private static IrBoolExpr optimizeOrCompare(IrBoolExpr antecedent, IrIntExpr left, IrCompare.Op op, IrIntExpr right) { 229 IrDomain domain = left.getDomain(); 230 Integer constant = IrUtil.getConstant(right); 231 if (domain.size() == 2 && constant != null) { 232 switch (op) { 233 case Equal: 234 // Rewrite 235 // bool or int = 888 236 // where dom(int) = {-3, 888} 237 // to 238 // asInt(bool) > (-3) - int 239 // asInt(bool) + int > (-3) 240 if (domain.getHighBound() == constant.intValue()) { 241 return greaterThan(add(antecedent, left), 242 domain.getLowBound()); 243 } 244 // Rewrite 245 // bool or int = -3 246 // where dom(int) = {-3, 888} 247 // to 248 // asInt(bool) > int - 888 249 if (domain.getLowBound() == constant.intValue()) { 250 return greaterThan(antecedent, 251 sub(left, domain.getHighBound())); 252 } 253 break; 254 case NotEqual: 255 // Rewrite 256 // bool or int != 888 257 // where dom(int) = {-3, 888} 258 // to 259 // asInt(bool) > int - 888 260 if (domain.getHighBound() == constant.intValue()) { 261 return greaterThan(antecedent, 262 sub(left, domain.getHighBound())); 263 } 264 // Rewrite 265 // bool or int != -3 266 // where dom(int) = {-3, 888} 267 // to 268 // asInt(bool) > (-3) - int 269 // asInt(bool) + int > (-3) 270 if (domain.getLowBound() == constant.intValue()) { 271 return greaterThan(add(antecedent, left), 272 domain.getLowBound()); 273 } 274 break; 275 } 276 } 277 return null; 278 } 279 280 /** 281 * Optimize {@code antecedent => (left `op` right)} where `op` is = or !=. 282 */ 283 private static IrBoolExpr optimizeImplicationCompare(IrBoolExpr antecedent, IrIntExpr left, IrCompare.Op op, IrIntExpr right) { 284 IrDomain domain = left.getDomain(); 285 Integer constant = IrUtil.getConstant(right); 286 if (domain.size() == 2 && constant != null) { 287 switch (op) { 288 case Equal: 289 // Rewrite 290 // bool => int = 888 291 // where dom(int) = {-3, 888} 292 // to 293 // asInt(bool) <= int - (-3) 294 if (domain.getHighBound() == constant.intValue()) { 295 return lessThanEqual(antecedent, 296 sub(left, domain.getLowBound())); 297 } 298 // Rewrite 299 // bool => int = -3 300 // where dom(int) = {-3, 888} 301 // to 302 // asInt(bool) <= 888 - int 303 // asInt(bool) + int <= 888 304 if (domain.getLowBound() == constant.intValue()) { 305 return lessThanEqual(add(antecedent, left), 306 domain.getHighBound()); 307 } 308 break; 309 case NotEqual: 310 // Rewrite 311 // bool => int != 888 312 // where dom(int) = {-3, 888} 313 // to 314 // asInt(bool) <= 888 - int 315 // asInt(bool) + int <= 888 316 if (domain.getHighBound() == constant.intValue()) { 317 return lessThanEqual(add(antecedent, left), 318 domain.getHighBound()); 319 } 320 // Rewrite 321 // bool => int != -3 322 // where dom(int) = {-3, 888} 323 // to 324 // asInt(bool) <= int - (-3) 325 if (domain.getLowBound() == constant.intValue()) { 326 return lessThanEqual(antecedent, 327 sub(left, domain.getLowBound())); 328 } 329 break; 330 } 331 } 332 return null; 333 } 334}