001package org.clafer.ast.analysis;
002
003import java.util.HashMap;
004import java.util.Map;
005import org.clafer.ast.AstBoolExpr;
006import org.clafer.ast.AstClafer;
007import org.clafer.ast.AstConcreteClafer;
008import org.clafer.ast.AstConstant;
009import org.clafer.ast.AstConstraint;
010import org.clafer.ast.AstExpr;
011import org.clafer.ast.AstExprRewriter;
012import org.clafer.ast.AstGlobal;
013import org.clafer.ast.AstJoin;
014import org.clafer.ast.AstJoinParent;
015import org.clafer.ast.AstSetExpr;
016import org.clafer.ast.AstThis;
017import static org.clafer.ast.Asts.*;
018import org.clafer.ast.Card;
019import org.clafer.common.Util;
020
021/**
022 * Optimizes the expressions inside the constraints. Assumes type checking has
023 * already passed.
024 *
025 * @author jimmy
026 */
027public class OptimizerAnalyzer extends AstExprRewriter<Analysis> implements Analyzer {
028
029    @Override
030    public Analysis analyze(Analysis analysis) {
031        Map<AstConstraint, AstBoolExpr> constraintExprs = new HashMap<>();
032        for (AstConstraint constraint : analysis.getConstraints()) {
033            constraintExprs.put(constraint, rewrite(analysis.getExpr(constraint), analysis));
034        }
035        return analysis.setConstraintExprs(constraintExprs);
036    }
037
038    @Override
039    public AstExpr visit(AstJoin ast, Analysis a) {
040        AstSetExpr left = rewrite(ast.getLeft(), a);
041        if (left instanceof AstThis) {
042            if (a.getScope(a.getCommonSupertype(ast.getLeft())) == 1) {
043                Card childCard = a.getCard(ast.getRight());
044                if (childCard.isExact()) {
045                    return constant(ast.getRight(), Util.fromTo(0, childCard.getLow()));
046                }
047                return global(ast.getRight());
048            }
049        } else if (left instanceof AstGlobal) {
050            return global(ast.getRight());
051        } else if (left instanceof AstConstant) {
052            Card childCard = a.getCard(ast.getRight());
053            if (Format.ParentGroup.equals(a.getFormat(ast.getRight()))) {
054                AstConstant constant = (AstConstant) left;
055                assert childCard.isExact();
056                int[] childConstant = new int[constant.getValue().length * childCard.getLow()];
057                for (int i = 0; i < constant.getValue().length; i++) {
058                    for (int j = 0; j < childCard.getLow(); j++) {
059                        childConstant[i * childCard.getLow() + j] =
060                                i * constant.getValue()[i] + j;
061                    }
062                }
063                return constant(ast.getRight(), childConstant);
064            }
065            assert childCard.getLow() != a.getScope(ast.getRight()) :
066                    "Didn't run scope analysis before format analysis?";
067        }
068        return join(left, ast.getRight());
069    }
070
071    @Override
072    public AstExpr visit(AstJoinParent ast, Analysis a) {
073        AstSetExpr children = rewrite(ast.getChildren(), a);
074        if (children instanceof AstThis) {
075            AstClafer type = a.getCommonSupertype(ast);
076            if (a.getScope(type) == 1) {
077                return constant(type, 0);
078            }
079        } else if (children instanceof AstGlobal) {
080            AstClafer childType = a.getCommonSupertype(ast.getChildren());
081            if (childType instanceof AstConcreteClafer) {
082                AstConcreteClafer concreteChildType = (AstConcreteClafer) childType;
083                if (a.getCard(concreteChildType).hasLow()) {
084                    return global(a.getCommonSupertype(ast));
085                }
086            }
087        } else if (children instanceof AstConstant) {
088            AstConstant constant = (AstConstant) children;
089            AstClafer type = a.getCommonSupertype(ast);
090            if (constant.getValue().length > 0 && a.getScope(type) == 1) {
091                return constant(type, 0);
092            }
093        }
094        return joinParent(children);
095    }
096    // TDODO: rewrite for all (global)
097}