001package org.clafer.ast.analysis;
002
003import java.util.HashMap;
004import java.util.HashSet;
005import java.util.Map;
006import java.util.Map.Entry;
007import java.util.Set;
008import org.clafer.ast.AstAbstractClafer;
009import org.clafer.ast.AstArithm;
010import org.clafer.ast.AstBoolArithm;
011import org.clafer.ast.AstBoolExpr;
012import org.clafer.ast.AstCard;
013import org.clafer.ast.AstClafer;
014import org.clafer.ast.AstCompare;
015import org.clafer.ast.AstConcreteClafer;
016import org.clafer.ast.AstConstant;
017import org.clafer.ast.AstConstraint;
018import org.clafer.ast.AstDecl;
019import org.clafer.ast.AstDifference;
020import org.clafer.ast.AstDowncast;
021import org.clafer.ast.AstExpr;
022import org.clafer.ast.AstExprVisitor;
023import org.clafer.ast.AstGlobal;
024import org.clafer.ast.AstIfThenElse;
025import org.clafer.ast.AstIntClafer;
026import org.clafer.ast.AstIntersection;
027import org.clafer.ast.AstJoin;
028import org.clafer.ast.AstJoinParent;
029import org.clafer.ast.AstJoinRef;
030import org.clafer.ast.AstLocal;
031import org.clafer.ast.AstMembership;
032import org.clafer.ast.AstMinus;
033import org.clafer.ast.AstNot;
034import org.clafer.ast.AstPrimClafer;
035import org.clafer.ast.AstQuantify;
036import org.clafer.ast.AstRef;
037import org.clafer.ast.AstSetExpr;
038import org.clafer.ast.AstSetTest;
039import org.clafer.ast.AstSum;
040import org.clafer.ast.AstTernary;
041import org.clafer.ast.AstThis;
042import org.clafer.ast.AstUnion;
043import org.clafer.ast.AstUpcast;
044import org.clafer.ast.AstUtil;
045import static org.clafer.ast.Asts.*;
046import org.clafer.common.Check;
047import org.clafer.common.Util;
048import org.clafer.objective.Objective;
049
050/**
051 * <p>
052 * Type checks and creates explicit upcast nodes in the AST. When the
053 * expressions are rewritten, the types need to be reanalyzed.
054 * </p>
 * <p>
 * The examples below refer to the following model:
 * </p>
 * <pre>
 * abstract A
 *     a
 * abstract B : A
 *     b
 * C : B
 *     c
 * D : B
 *     d
 * </pre>
 * <p>
 * The lowest common supertype of an expression directly corresponds to how
 * the solver stores the expression as a set. For example, suppose an
 * expression evaluates to {C0, D0}. Unfortunately, C0 is stored as 0 for the
 * type C and D0 is also stored as 0 for the type D. The solver resolves this
 * by upcasting both C0 and D0 to the type B, where C0 = B0 and D0 = B1, so the
 * set {C0, D0} can be stored as {B0, B1}, i.e. {0, 1} in Choco.
 * {@link Type#getCommonSuperType()} of each expression is the type used for
 * representing the set in Choco.
 * </p>
 * <p>
 * Even though expressions are stored as their common supertype, it is
 * sometimes necessary to convert them back to a subtype. For example,
 * consider the expression {@code (C ++ D).d}. {@code (C ++ D)} has the union
 * type {C, D} but is stored as the supertype B. The join {@code .d} is allowed
 * because D is in the union type of {@code (C ++ D)}, but since the set is
 * stored as a set of B's, it first needs to be downcast to a set of D's before
 * the join can be performed. If the expression were {@code (C ++ D).a}, then
 * the set of B's would need to be upcast to a set of A's. If the expression
 * were {@code (C ++ D).b}, then no casting is required.
 * </p>
088 *
089 * @author jimmy
090 */
091public class TypeAnalyzer implements Analyzer {
092
093    @Override
094    public Analysis analyze(Analysis analysis) {
095        Map<AstExpr, Type> typeMap = new HashMap<>();
096        Map<AstConstraint, AstBoolExpr> typedConstraints = new HashMap<>();
097        for (AstConstraint constraint : analysis.getConstraints()) {
098            AstClafer clafer = constraint.getContext();
099            TypeVisitor visitor = new TypeVisitor(Type.basicType(clafer), typeMap);
100            TypedExpr<AstBoolExpr> typedConstraint = visitor.typeCheck(analysis.getExpr(constraint));
101            typedConstraints.put(constraint, typedConstraint.getExpr());
102        }
103        Map<Objective, AstSetExpr> objectives = analysis.getObjectiveExprs();
104        Map<Objective, AstSetExpr> typedObjectives = new HashMap<>(objectives.size());
105        for (Entry<Objective, AstSetExpr> objective : objectives.entrySet()) {
106            TypeVisitor visitor = new TypeVisitor(Type.basicType(analysis.getModel()), typeMap);
107            TypedExpr<AstSetExpr> typedObjective = visitor.typeCheck(objective.getValue());
108            if (!(typedObjective.getCommonSupertype() instanceof AstIntClafer)) {
109                throw new TypeException("Cannot optimize on " + typedObjective.getType());
110            }
111            typedObjectives.put(objective.getKey(), typedObjective.getExpr());
112        }
113        return analysis.setTypeMap(typeMap)
114                .setConstraintExprs(typedConstraints)
115                .setObjectiveExprs(typedObjectives);
116    }
117
118    private static class TypeVisitor implements AstExprVisitor<Void, TypedExpr<?>> {
119
120        private final Type context;
121        private final Map<AstExpr, Type> typeMap;
122
123        TypeVisitor(Type context, Map<AstExpr, Type> typeMap) {
124            this.context = context;
125            this.typeMap = typeMap;
126        }
127
128        private <T extends AstExpr> TypedExpr<T> typeCheck(T expr) {
129            @SuppressWarnings("unchecked")
130            TypedExpr<T> typedExpr = (TypedExpr<T>) expr.accept(this, null);
131            return typedExpr;
132        }
133
134        private <T extends AstExpr> TypedExpr<T>[] typeCheck(T[] exprs) {
135            @SuppressWarnings("unchecked")
136            TypedExpr<T>[] typeChecked = new TypedExpr[exprs.length];
137            for (int i = 0; i < exprs.length; i++) {
138                typeChecked[i] = typeCheck(exprs[i]);
139            }
140            return typeChecked;
141        }
142
143        /**
144         * Multilevel cast.
145         *
146         * @param expr the expression
147         * @param target the target type
148         * @return the same expression but with the target type
149         */
150        private AstSetExpr castTo(TypedExpr<AstSetExpr> expr, AstClafer target) {
151            if (AstUtil.isAssignable(expr.getCommonSupertype(), target)) {
152                return upcastTo(expr, target);
153            } else if (isAnyAssignable(expr.getUnionType(), target)) {
154                return downcastTo(expr, target);
155            }
156            throw new TypeException("Cannot cast " + expr.getType() + " to " + target);
157        }
158
159        /**
160         * Multilevel downcast.
161         *
162         * @param expr the expression
163         * @param target the target type
164         * @return the same expression but with the target type
165         */
166        private AstSetExpr downcastTo(TypedExpr<AstSetExpr> expr, AstClafer target) {
167            if (expr.getType().getCommonSuperType().equals(target)) {
168                return expr.getExpr();
169            }
170            if (isAnyAssignable(expr.getUnionType(), target)) {
171                AstSetExpr subExpr = downcast(expr.getExpr(), target);
172                put(Type.basicType(target), subExpr);
173                return subExpr;
174            }
175            throw new TypeException("Cannot downcast " + expr.getType() + " to " + target);
176        }
177
178        /**
179         * Multilevel upcast.
180         *
181         * @param expr the expression
182         * @param target the target type
183         * @return the same expression but with the target type
184         */
185        private AstSetExpr upcastTo(TypedExpr<AstSetExpr> expr, AstClafer target) {
186            if (expr.getType().getCommonSuperType().equals(target)) {
187                return expr.getExpr();
188            }
189            if (AstUtil.isAssignable(expr.getCommonSupertype(), target)) {
190                AstSetExpr superExpr = upcast(expr.getExpr(), (AstAbstractClafer) target);
191                put(Type.basicType(target), superExpr);
192                return superExpr;
193            }
194            throw new TypeException("Cannot upcast " + expr.getType() + " to " + target);
195        }
196
197        private <T extends AstExpr> TypedExpr<T> put(AstClafer basicType, T expr) {
198            return put(Type.basicType(basicType), expr);
199        }
200
201        private <T extends AstExpr> TypedExpr<T> put(Type type, T expr) {
202            typeMap.put(expr, type);
203            return new TypedExpr<>(type, expr);
204        }
205
206        @Override
207        public TypedExpr<AstThis> visit(AstThis ast, Void a) {
208            return put(context, ast);
209        }
210
211        @Override
212        public TypedExpr<AstGlobal> visit(AstGlobal ast, Void a) {
213            return put(Type.basicType(ast.getType()), ast);
214        }
215
216        @Override
217        public TypedExpr<AstConstant> visit(AstConstant ast, Void a) {
218            return put(Type.basicType(ast.getType()), ast);
219        }
220
221        @Override
222        public TypedExpr<AstSetExpr> visit(AstJoin ast, Void a) {
223            TypedExpr<AstSetExpr> left = typeCheck(ast.getLeft());
224            AstConcreteClafer rightType = ast.getRight();
225            if (rightType.hasParent()) {
226                AstClafer joinType = rightType.getParent();
227                if (isAnyAssignable(left.getType().getUnionType(), joinType)) {
228                    return put(rightType, join(castTo(left, joinType), rightType));
229                }
230            }
231            throw new TypeException("Cannot join " + left.getType() + " . " + rightType);
232        }
233
234        @Override
235        public TypedExpr<AstSetExpr> visit(AstJoinParent ast, Void a) {
236            TypedExpr<AstSetExpr> children = typeCheck(ast.getChildren());
237            if (children.getType().isBasicType()) {
238                AstClafer childrenType = children.getType().getBasicType();
239                if (childrenType instanceof AstConcreteClafer) {
240                    AstConcreteClafer concreteChildrenType = (AstConcreteClafer) childrenType;
241                    if (concreteChildrenType.hasParent()) {
242                        return put(concreteChildrenType.getParent(),
243                                joinParent(children.getExpr()));
244                    }
245                }
246            }
247            throw new TypeException("Cannot join " + children.getType() + " . parent");
248        }
249
250        @Override
251        public TypedExpr<AstSetExpr> visit(AstJoinRef ast, Void a) {
252            TypedExpr<AstSetExpr> deref = typeCheck(ast.getDeref());
253
254            Set<AstRef> refs = new HashSet<>();
255            for (AstClafer type : deref.getUnionType()) {
256                AstRef ref = AstUtil.getInheritedRef(type);
257                if (ref != null) {
258                    refs.add(ref);
259                }
260            }
261            switch (refs.size()) {
262                case 0:
263                    throw new TypeException("Cannot join " + deref.getType() + " . ref");
264                case 1:
265                    AstRef ref = refs.iterator().next();
266                    return put(ref.getTargetType(), joinRef(castTo(deref, ref.getSourceType())));
267                default:
268                    throw new TypeException("Ambiguous join " + deref.getType() + " . ref");
269            }
270        }
271
272        @Override
273        public TypedExpr<AstBoolExpr> visit(AstNot ast, Void a) {
274            TypedExpr<AstBoolExpr> expr = typeCheck(ast.getExpr());
275            return put(BoolType, not(expr.getExpr()));
276        }
277
278        @Override
279        public TypedExpr<AstSetExpr> visit(AstMinus ast, Void a) {
280            TypedExpr<AstSetExpr> expr = typeCheck(ast.getExpr());
281            if (expr.getCommonSupertype() instanceof AstIntClafer) {
282                return put(IntType, minus(expr.getExpr()));
283            }
284            throw new TypeException("Cannot -" + expr.getType());
285        }
286
287        @Override
288        public TypedExpr<AstSetExpr> visit(AstCard ast, Void a) {
289            TypedExpr<AstSetExpr> set = typeCheck(ast.getSet());
290            if (set.getCommonSupertype() instanceof AstPrimClafer) {
291                throw new TypeException("Cannot |" + set.getType() + "|");
292            }
293            return put(IntType, card(set.getExpr()));
294        }
295
296        @Override
297        public TypedExpr<AstBoolExpr> visit(AstSetTest ast, Void a) {
298            TypedExpr<AstSetExpr> left = typeCheck(ast.getLeft());
299            TypedExpr<AstSetExpr> right = typeCheck(ast.getRight());
300
301            if (isDisjoint(left.getType(), right.getType())) {
302                throw new TypeException("Cannot " + left.getType() + " "
303                        + ast.getOp().getSyntax() + " " + right.getType());
304            }
305
306            AstClafer commonType = AstUtil.getLowestCommonSupertype(left.getCommonSupertype(), right.getCommonSupertype());
307            return put(BoolType, test(upcastTo(left, commonType), ast.getOp(), upcastTo(right, commonType)));
308        }
309
310        @Override
311        public TypedExpr<AstBoolExpr> visit(AstCompare ast, Void a) {
312            TypedExpr<AstSetExpr> left = typeCheck(ast.getLeft());
313            TypedExpr<AstSetExpr> right = typeCheck(ast.getRight());
314            if (left.getCommonSupertype() instanceof AstIntClafer
315                    && right.getCommonSupertype() instanceof AstIntClafer) {
316                return put(BoolType, compare(left.getExpr(), ast.getOp(), right.getExpr()));
317            }
318            throw new TypeException("Cannot " + left.getType() + " "
319                    + ast.getOp().getSyntax() + " " + right.getType());
320        }
321
322        @Override
323        public TypedExpr<AstSetExpr> visit(AstArithm ast, Void a) {
324            TypedExpr<AstSetExpr>[] operands = typeCheck(ast.getOperands());
325            for (TypedExpr<AstSetExpr> operand : operands) {
326                if (!(operand.getCommonSupertype() instanceof AstIntClafer)) {
327                    throw new TypeException("Cannot "
328                            + Util.intercalate(" " + ast.getOp().getSyntax() + " ",
329                            getTypes(operands)));
330                }
331            }
332            return put(IntType, arithm(ast.getOp(), getSetExprs(operands)));
333        }
334
335        @Override
336        public TypedExpr<AstSetExpr> visit(AstSum ast, Void a) {
337            TypedExpr<AstSetExpr> set = typeCheck(ast.getSet());
338
339            Set<AstRef> refs = new HashSet<>();
340            for (AstClafer type : set.getUnionType()) {
341                AstRef ref = AstUtil.getInheritedRef(type);
342                if (ref != null) {
343                    refs.add(ref);
344                }
345            }
346            switch (refs.size()) {
347                case 0:
348                    throw new TypeException("Cannot sum(" + set.getType() + ")");
349                case 1:
350                    AstRef ref = refs.iterator().next();
351                    return put(ref.getTargetType(), sum(castTo(set, ref.getSourceType())));
352                default:
353                    throw new TypeException("Ambiguous sum(" + set.getType() + ")");
354            }
355        }
356
357        @Override
358        public TypedExpr<?> visit(AstBoolArithm ast, Void a) {
359            TypedExpr<AstBoolExpr>[] operands = typeCheck(ast.getOperands());
360            return put(BoolType, arithm(ast.getOp(), getBoolExprs(operands)));
361        }
362
363        @Override
364        public TypedExpr<?> visit(AstDifference ast, Void a) {
365            TypedExpr<AstSetExpr> left = typeCheck(ast.getLeft());
366            TypedExpr<AstSetExpr> right = typeCheck(ast.getRight());
367
368            Set<AstClafer> unionType = new HashSet<>();
369            unionType.addAll(left.getUnionType());
370            unionType.addAll(right.getUnionType());
371
372            // TODO: check for primitives
373
374            Type type = new Type(left.getUnionType(), AstUtil.getLowestCommonSupertype(unionType));
375
376            return put(type, diff(
377                    upcastTo(left, type.getCommonSuperType()),
378                    upcastTo(right, type.getCommonSuperType())));
379        }
380
381        @Override
382        public TypedExpr<?> visit(AstIntersection ast, Void a) {
383            TypedExpr<AstSetExpr> left = typeCheck(ast.getLeft());
384            TypedExpr<AstSetExpr> right = typeCheck(ast.getRight());
385
386            Set<AstClafer> unionType = new HashSet<>();
387            unionType.addAll(left.getUnionType());
388            unionType.addAll(right.getUnionType());
389
390            Set<AstClafer> intersectionType = intersectionType(left.getType(), right.getType());
391
392            // TODO: check for primitives
393
394            Type type = new Type(intersectionType, AstUtil.getLowestCommonSupertype(unionType));
395
396            return put(type, inter(
397                    upcastTo(left, type.getCommonSuperType()),
398                    upcastTo(right, type.getCommonSuperType())));
399        }
400
401        @Override
402        public TypedExpr<?> visit(AstUnion ast, Void a) {
403            TypedExpr<AstSetExpr> left = typeCheck(ast.getLeft());
404            TypedExpr<AstSetExpr> right = typeCheck(ast.getRight());
405
406            Set<AstClafer> unionType = new HashSet<>();
407            unionType.addAll(left.getUnionType());
408            unionType.addAll(right.getUnionType());
409
410            // TODO: check for primitives
411
412            Type type = new Type(unionType);
413
414            if (type.getCommonSuperType() == null) {
415                throw new TypeException("Cannot " + left.getType() + " ++ " + right.getType());
416            }
417
418            return put(type, union(
419                    upcastTo(left, type.getCommonSuperType()),
420                    upcastTo(right, type.getCommonSuperType())));
421        }
422
423        @Override
424        public TypedExpr<AstBoolExpr> visit(AstMembership ast, Void a) {
425            TypedExpr<AstSetExpr> member = typeCheck(ast.getMember());
426            TypedExpr<AstSetExpr> set = typeCheck(ast.getSet());
427
428            if (isDisjoint(member.getType(), set.getType())) {
429                throw new TypeException("Cannot " + member.getType()
430                        + " " + ast.getOp().getSyntax() + " " + set.getType());
431            }
432
433            AstClafer commonType = AstUtil.getLowestCommonSupertype(member.getCommonSupertype(), set.getCommonSupertype());
434            return put(BoolType, membership(upcastTo(member, commonType), ast.getOp(), upcastTo(set, commonType)));
435        }
436
437        @Override
438        public TypedExpr<AstSetExpr> visit(AstTernary ast, Void a) {
439            TypedExpr<AstBoolExpr> antecedent = typeCheck(ast.getAntecedent());
440            TypedExpr<AstSetExpr> alternative = typeCheck(ast.getAlternative());
441            TypedExpr<AstSetExpr> consequent = typeCheck(ast.getConsequent());
442            AstClafer unionType = AstUtil.getLowestCommonSupertype(alternative.getCommonSupertype(), consequent.getCommonSupertype());
443            if (unionType == null) {
444                throw new TypeException("Cannot if " + antecedent.getType() + " then "
445                        + consequent.getType() + " else " + alternative.getType());
446            }
447            return put(unionType, ifThenElse(antecedent.getExpr(),
448                    upcastTo(consequent, unionType), upcastTo(alternative, unionType)));
449        }
450
451        @Override
452        public TypedExpr<AstBoolExpr> visit(AstIfThenElse ast, Void a) {
453            TypedExpr<AstBoolExpr> antecedent = typeCheck(ast.getAntecedent());
454            TypedExpr<AstBoolExpr> alternative = typeCheck(ast.getAlternative());
455            TypedExpr<AstBoolExpr> consequent = typeCheck(ast.getConsequent());
456            return put(BoolType, ifThenElse(antecedent.getExpr(), alternative.getExpr(), consequent.getExpr()));
457        }
458
459        @Override
460        public TypedExpr<?> visit(AstDowncast ast, Void a) {
461            TypedExpr<AstSetExpr> base = typeCheck(ast.getBase());
462            AstClafer to = ast.getTarget();
463            if (isAnyAssignable(base.getUnionType(), to)) {
464                return put(to, downcast(base.getExpr(), ast.getTarget()));
465            }
466            throw new TypeException("Cannot downcast from " + base.getType() + " to " + to);
467        }
468
469        @Override
470        public TypedExpr<AstSetExpr> visit(AstUpcast ast, Void a) {
471            TypedExpr<AstSetExpr> base = typeCheck(ast.getBase());
472            AstAbstractClafer to = ast.getTarget();
473            if (AstUtil.isAssignable(base.getCommonSupertype(), to)) {
474                return put(new Type(base.getUnionType(), to), upcast(base.getExpr(), ast.getTarget()));
475            }
476            throw new TypeException("Cannot upcast from " + base.getType() + " to " + to);
477        }
478
479        @Override
480        public TypedExpr<AstLocal> visit(AstLocal ast, Void a) {
481            Type localType = typeMap.get(ast);
482            if (localType == null) {
483                throw new AnalysisException(ast + " type not analyzed yet.");
484            }
485            return put(localType, ast);
486        }
487
488        @Override
489        public TypedExpr<AstBoolExpr> visit(AstQuantify ast, Void a) {
490            AstDecl[] decls = new AstDecl[ast.getDecls().length];
491            for (int i = 0; i < ast.getDecls().length; i++) {
492                AstDecl decl = ast.getDecls()[i];
493                TypedExpr<AstSetExpr> body = typeCheck(decl.getBody());
494                for (AstLocal local : decl.getLocals()) {
495                    put(body.getType(), local);
496                }
497                decls[i] = decl(decl.isDisjoint(), decl.getLocals(), body.getExpr());
498            }
499            TypedExpr<AstBoolExpr> body = typeCheck(ast.getBody());
500            return put(BoolType, quantify(ast.getQuantifier(), decls, body.getExpr()));
501        }
502    }
503
504    private static Type[] getTypes(TypedExpr<?>... exprs) {
505        Type[] types = new Type[exprs.length];
506        for (int i = 0; i < types.length; i++) {
507            types[i] = exprs[i].getType();
508        }
509        return types;
510    }
511
512    private static <T extends AstBoolExpr> AstBoolExpr[] getBoolExprs(TypedExpr<T>[] exprs) {
513        AstBoolExpr[] boolExprs = new AstBoolExpr[exprs.length];
514        for (int i = 0; i < boolExprs.length; i++) {
515            boolExprs[i] = exprs[i].getExpr();
516        }
517        return boolExprs;
518    }
519
520    private static <T extends AstSetExpr> AstSetExpr[] getSetExprs(TypedExpr<T>[] exprs) {
521        AstSetExpr[] setExprs = new AstSetExpr[exprs.length];
522        for (int i = 0; i < setExprs.length; i++) {
523            setExprs[i] = exprs[i].getExpr();
524        }
525        return setExprs;
526    }
527
528    private static boolean isAnyAssignable(Iterable<AstClafer> froms, AstClafer to) {
529        for (AstClafer from : froms) {
530            if (AstUtil.isAssignable(from, to)) {
531                return true;
532            }
533        }
534        return false;
535    }
536
537    private static boolean isAnyAssignable(AstClafer from, Iterable<AstClafer> tos) {
538        for (AstClafer to : tos) {
539            if (AstUtil.isAssignable(from, to)) {
540                return true;
541            }
542        }
543        return false;
544    }
545
546    private static boolean isDisjoint(Type t1, Type t2) {
547        for (AstClafer leftType : t1.getUnionType()) {
548            if (isAnyAssignable(leftType, t2.getUnionType())) {
549                return false;
550            }
551        }
552        for (AstClafer rightType : t2.getUnionType()) {
553            if (isAnyAssignable(rightType, t1.getUnionType())) {
554                return false;
555            }
556        }
557        return true;
558    }
559
560    private static Set<AstClafer> intersectionType(Type t1, Type t2) {
561        Set<AstClafer> interType = new HashSet<>();
562        for (AstClafer leftType : t1.getUnionType()) {
563            if (isAnyAssignable(leftType, t2.getUnionType())) {
564                interType.add(leftType);
565            }
566        }
567        for (AstClafer rightType : t2.getUnionType()) {
568            if (isAnyAssignable(rightType, t1.getUnionType())) {
569                interType.add(rightType);
570            }
571        }
572        return interType;
573    }
574
575    private static class TypedExpr<T extends AstExpr> {
576
577        private final Type type;
578        private final T expr;
579
580        TypedExpr(Type type, T expr) {
581            this.type = Check.notNull(type);
582            this.expr = Check.notNull(expr);
583        }
584
585        public Type getType() {
586            return type;
587        }
588
589        public Set<AstClafer> getUnionType() {
590            return type.getUnionType();
591        }
592
593        public AstClafer getCommonSupertype() {
594            return type.getCommonSuperType();
595        }
596
597        public boolean isBasicType() {
598            return type.isBasicType();
599        }
600
601        public AstClafer getBasicType() {
602            return type.getBasicType();
603        }
604
605        public T getExpr() {
606            return expr;
607        }
608    }
609}