Skip to content

Commit

Permalink
Use recursion
Browse files Browse the repository at this point in the history
  • Loading branch information
hadrienk committed Nov 4, 2024
1 parent b4e7224 commit 2d17976
Showing 1 changed file with 46 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@
import fr.insee.vtl.parser.VtlBaseVisitor;
import fr.insee.vtl.parser.VtlParser;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.*;
import java.util.stream.Collectors;

import static fr.insee.vtl.engine.VtlScriptEngine.fromContext;
Expand Down Expand Up @@ -149,92 +146,57 @@ public ResolvableExpression visitIfExpr(VtlParser.IfExprContext ctx) {
*/
@Override
public ResolvableExpression visitCaseExpr(VtlParser.CaseExprContext ctx) {
Positioned pos = fromContext(ctx);
List<VtlParser.ExprContext> exprs = ctx.expr();
List<VtlParser.ExprContext> whenExprs = new ArrayList<>();
List<VtlParser.ExprContext> thenExprs = new ArrayList<>();
for (int i = 0; i < exprs.size() - 1; i = i + 2) {
whenExprs.add(exprs.get(i));
thenExprs.add(exprs.get(i + 1));
}
List<ResolvableExpression> whenExpressions = whenExprs.stream()
.map(e -> assertBoolean(exprVisitor.visit(e), e))
.collect(Collectors.toList());
List<ResolvableExpression> thenExpressions = thenExprs.stream()
.map(exprVisitor::visit)
.collect(Collectors.toList());
ResolvableExpression elseExpression = exprVisitor.visit(exprs.get(exprs.size() - 1));
List<ResolvableExpression> forTypeCheck = (new ArrayList<>(thenExpressions));
forTypeCheck.add(elseExpression);
// TODO: handle better the default element position
if (!hasSameTypeOrNull(forTypeCheck)) {
try {
throw new InvalidTypeException(
forTypeCheck.get(0).getClass(),
Boolean.class,
fromContext(ctx.expr(0))
);
} catch (InvalidTypeException e) {
throw new RuntimeException(e);
try {
Positioned pos = fromContext(ctx);
List<VtlParser.ExprContext> exprs = ctx.expr();
List<VtlParser.ExprContext> whenExprs = new ArrayList<>();
List<VtlParser.ExprContext> thenExprs = new ArrayList<>();
for (int i = 0; i < exprs.size() - 1; i = i + 2) {
whenExprs.add(exprs.get(i));
thenExprs.add(exprs.get(i + 1));
}
List<ResolvableExpression> whenExpressions = whenExprs.stream()
.map(e -> assertBoolean(exprVisitor.visit(e), e))
.collect(Collectors.toList());
List<ResolvableExpression> thenExpressions = thenExprs.stream()
.map(exprVisitor::visit)
.collect(Collectors.toList());
ResolvableExpression elseExpression = exprVisitor.visit(exprs.get(exprs.size() - 1));
List<ResolvableExpression> forTypeCheck = (new ArrayList<>(thenExpressions));
forTypeCheck.add(elseExpression);
// TODO: handle better the default element position
if (!hasSameTypeOrNull(forTypeCheck)) {
try {
throw new InvalidTypeException(
forTypeCheck.get(0).getClass(),
Boolean.class,
fromContext(ctx.expr(0))
);
} catch (InvalidTypeException e) {
throw new RuntimeException(e);
}
}
}

Class<?> outputType = elseExpression.getType();

if (outputType.equals(String.class)) {
return ResolvableExpression.withType(String.class)
.withPosition(pos)
.using(context -> {
for (int i = 0; i < whenExprs.size(); i++) {
Boolean condition = (Boolean) whenExpressions.get(i).resolve(context);
if (condition) {
return (String) (new CastExpression(pos, thenExpressions.get(i), outputType)).resolve(context);
}
}
return (String) (new CastExpression(pos, elseExpression, outputType)).resolve(context);
});
}
if (outputType.equals(Double.class)) {
return ResolvableExpression.withType(Double.class)
.withPosition(pos)
.using(context -> {
for (int i = 0; i < whenExprs.size(); i++) {
Boolean condition = (Boolean) whenExpressions.get(i).resolve(context);
if (condition) {
return (Double) (new CastExpression(pos, thenExpressions.get(i), outputType)).resolve(context);
}
}
return (Double) (new CastExpression(pos, elseExpression, outputType)).resolve(context);
});
Class<?> outputType = elseExpression.getType();
return new CastExpression(pos, caseToIfIt(whenExpressions.listIterator(), thenExpressions.listIterator(), elseExpression), outputType);
} catch (VtlScriptException e) {
throw new VtlRuntimeException(e);
}
if (outputType.equals(Long.class)) {
return ResolvableExpression.withType(Long.class)
.withPosition(pos)
.using(context -> {
for (int i = 0; i < whenExprs.size(); i++) {
Boolean condition = (Boolean) whenExpressions.get(i).resolve(context);
if (condition) {
return (Long) (new CastExpression(pos, thenExpressions.get(i), outputType)).resolve(context);
}
}
return (Long) (new CastExpression(pos, elseExpression, outputType)).resolve(context);
});
}

private ResolvableExpression caseToIfIt(ListIterator<ResolvableExpression> whenExpr, ListIterator<ResolvableExpression> thenExpr, ResolvableExpression elseExpression) throws VtlScriptException {
if (!whenExpr.hasNext() || !thenExpr.hasNext()) {
return elseExpression;
}
if (outputType.equals(Boolean.class)) {
return ResolvableExpression.withType(Boolean.class)
.withPosition(pos)
.using(context -> {
for (int i = 0; i < whenExprs.size(); i++) {
Boolean condition = (Boolean) whenExpressions.get(i).resolve(context);
if (condition) {
return (Boolean) (new CastExpression(pos, thenExpressions.get(i), outputType)).resolve(context);
}
}
return (Boolean) (new CastExpression(pos, elseExpression, outputType)).resolve(context);
});
} else return null;

return genericFunctionsVisitor.invokeFunction("ifThenElse", Java8Helpers.listOf(
whenExpr.next(),
thenExpr.next(),
caseToIfIt(whenExpr, thenExpr, elseExpression)
), null);
}


/**
* Visits nvl expressions.
*
Expand Down

0 comments on commit 2d17976

Please sign in to comment.