Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor case #369

Merged
merged 2 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,6 @@ public class VtlNativeMethods {
Fun.<Boolean, Double, Double>toMethod(ConditionalVisitor::ifThenElse),
Fun.<Boolean, String, String>toMethod(ConditionalVisitor::ifThenElse),
Fun.<Boolean, Boolean, Boolean>toMethod(ConditionalVisitor::ifThenElse),
Fun.<Boolean, Long>toMethod(ConditionalVisitor::caseFn),
Fun.<Boolean, Double>toMethod(ConditionalVisitor::caseFn),
Fun.<Boolean, String>toMethod(ConditionalVisitor::caseFn),
Fun.<Boolean, Boolean>toMethod(ConditionalVisitor::caseFn),
Fun.<Long, Long>toMethod(ConditionalVisitor::nvl),
Fun.<Double, Double>toMethod(ConditionalVisitor::nvl),
Fun.<Double, Long>toMethod(ConditionalVisitor::nvl),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,10 @@
import fr.insee.vtl.parser.VtlBaseVisitor;
import fr.insee.vtl.parser.VtlParser;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.*;
import java.util.stream.Collectors;

import static fr.insee.vtl.engine.VtlScriptEngine.fromContext;
import static fr.insee.vtl.engine.utils.TypeChecking.assertBoolean;
import static fr.insee.vtl.engine.utils.TypeChecking.hasSameTypeOrNull;

/**
Expand Down Expand Up @@ -68,34 +64,6 @@ public static Boolean ifThenElse(Boolean condition, Boolean thenExpr, Boolean el
return condition ? thenExpr : elseExpr;
}

public static Long caseFn(Boolean condition, Long thenExpr) {
if (condition == null) {
return null;
}
return condition ? thenExpr : null;
}

public static Double caseFn(Boolean condition, Double thenExpr) {
if (condition == null) {
return null;
}
return condition ? thenExpr : null;
}

public static String caseFn(Boolean condition, String thenExpr) {
if (condition == null) {
return null;
}
return condition ? thenExpr : null;
}

public static Boolean caseFn(Boolean condition, Boolean thenExpr) {
if (condition == null) {
return null;
}
return condition ? thenExpr : null;
}

public static Long nvl(Long value, Long defaultValue) {
return value == null ? defaultValue : value;
}
Expand Down Expand Up @@ -149,92 +117,59 @@ public ResolvableExpression visitIfExpr(VtlParser.IfExprContext ctx) {
*/
@Override
public ResolvableExpression visitCaseExpr(VtlParser.CaseExprContext ctx) {
Positioned pos = fromContext(ctx);
List<VtlParser.ExprContext> exprs = ctx.expr();
List<VtlParser.ExprContext> whenExprs = new ArrayList<>();
List<VtlParser.ExprContext> thenExprs = new ArrayList<>();
for (int i = 0; i < exprs.size() - 1; i = i + 2) {
whenExprs.add(exprs.get(i));
thenExprs.add(exprs.get(i + 1));
}
List<ResolvableExpression> whenExpressions = whenExprs.stream()
.map(e -> assertBoolean(exprVisitor.visit(e), e))
.collect(Collectors.toList());
List<ResolvableExpression> thenExpressions = thenExprs.stream()
.map(exprVisitor::visit)
.collect(Collectors.toList());
ResolvableExpression elseExpression = exprVisitor.visit(exprs.get(exprs.size() - 1));
List<ResolvableExpression> forTypeCheck = (new ArrayList<>(thenExpressions));
forTypeCheck.add(elseExpression);
// TODO: handle better the default element position
if (!hasSameTypeOrNull(forTypeCheck)) {
try {
throw new InvalidTypeException(
forTypeCheck.get(0).getClass(),
Boolean.class,
fromContext(ctx.expr(0))
);
} catch (InvalidTypeException e) {
throw new RuntimeException(e);
try {
Positioned pos = fromContext(ctx);
List<VtlParser.ExprContext> exprs = ctx.expr();
List<VtlParser.ExprContext> whenExprs = new ArrayList<>();
List<VtlParser.ExprContext> thenExprs = new ArrayList<>();
for (int i = 0; i < exprs.size() - 1; i = i + 2) {
whenExprs.add(exprs.get(i));
thenExprs.add(exprs.get(i + 1));
}
List<ResolvableExpression> whenExpressions = whenExprs.stream()
.map(exprVisitor::visit)
.collect(Collectors.toList());
List<ResolvableExpression> thenExpressions = thenExprs.stream()
.map(exprVisitor::visit)
.collect(Collectors.toList());
ResolvableExpression elseExpression = exprVisitor.visit(exprs.get(exprs.size() - 1));
List<ResolvableExpression> forTypeCheck = (new ArrayList<>(thenExpressions));
forTypeCheck.add(elseExpression);
// TODO: handle better the default element position
if (!hasSameTypeOrNull(forTypeCheck)) {
try {
throw new InvalidTypeException(
forTypeCheck.get(0).getClass(),
Boolean.class,
fromContext(ctx.expr(0))
);
} catch (InvalidTypeException e) {
throw new RuntimeException(e);
}
}
}

Class<?> outputType = elseExpression.getType();

if (outputType.equals(String.class)) {
return ResolvableExpression.withType(String.class)
.withPosition(pos)
.using(context -> {
for (int i = 0; i < whenExprs.size(); i++) {
Boolean condition = (Boolean) whenExpressions.get(i).resolve(context);
if (condition) {
return (String) (new CastExpression(pos, thenExpressions.get(i), outputType)).resolve(context);
}
}
return (String) (new CastExpression(pos, elseExpression, outputType)).resolve(context);
});
}
if (outputType.equals(Double.class)) {
return ResolvableExpression.withType(Double.class)
.withPosition(pos)
.using(context -> {
for (int i = 0; i < whenExprs.size(); i++) {
Boolean condition = (Boolean) whenExpressions.get(i).resolve(context);
if (condition) {
return (Double) (new CastExpression(pos, thenExpressions.get(i), outputType)).resolve(context);
}
}
return (Double) (new CastExpression(pos, elseExpression, outputType)).resolve(context);
});
Class<?> outputType = elseExpression.getType();
return new CastExpression(pos, caseToIfIt(whenExpressions.listIterator(), thenExpressions.listIterator(), elseExpression), outputType);
} catch (VtlScriptException e) {
throw new VtlRuntimeException(e);
}
if (outputType.equals(Long.class)) {
return ResolvableExpression.withType(Long.class)
.withPosition(pos)
.using(context -> {
for (int i = 0; i < whenExprs.size(); i++) {
Boolean condition = (Boolean) whenExpressions.get(i).resolve(context);
if (condition) {
return (Long) (new CastExpression(pos, thenExpressions.get(i), outputType)).resolve(context);
}
}
return (Long) (new CastExpression(pos, elseExpression, outputType)).resolve(context);
});
}

private ResolvableExpression caseToIfIt(ListIterator<ResolvableExpression> whenExpr, ListIterator<ResolvableExpression> thenExpr, ResolvableExpression elseExpression) throws VtlScriptException {
if (!whenExpr.hasNext() || !thenExpr.hasNext()) {
return elseExpression;
}
if (outputType.equals(Boolean.class)) {
return ResolvableExpression.withType(Boolean.class)
.withPosition(pos)
.using(context -> {
for (int i = 0; i < whenExprs.size(); i++) {
Boolean condition = (Boolean) whenExpressions.get(i).resolve(context);
if (condition) {
return (Boolean) (new CastExpression(pos, thenExpressions.get(i), outputType)).resolve(context);
}
}
return (Boolean) (new CastExpression(pos, elseExpression, outputType)).resolve(context);
});
} else return null;

ResolvableExpression nextWhen = whenExpr.next();

return genericFunctionsVisitor.invokeFunction("ifThenElse", Java8Helpers.listOf(
nextWhen,
thenExpr.next(),
caseToIfIt(whenExpr, thenExpr, elseExpression)
), nextWhen);
}


/**
* Visits nvl expressions.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import fr.insee.vtl.engine.expressions.CastExpression;
import fr.insee.vtl.engine.expressions.ComponentExpression;
import fr.insee.vtl.engine.expressions.FunctionExpression;
import fr.insee.vtl.model.utils.Java8Helpers;
import fr.insee.vtl.engine.visitors.expression.ExpressionVisitor;
import fr.insee.vtl.model.*;
import fr.insee.vtl.model.exceptions.VtlScriptException;
import fr.insee.vtl.model.utils.Java8Helpers;
import fr.insee.vtl.parser.VtlBaseVisitor;
import fr.insee.vtl.parser.VtlParser;
import org.antlr.v4.runtime.Token;
Expand Down Expand Up @@ -169,7 +169,8 @@ private DatasetExpression invokeFunctionOnDataset(String funcName, List<Resolvab
.collect(Collectors.toMap(e -> "arg" + e.hashCode(), e -> e));
if (measureNames.size() != 1) {
throw new VtlRuntimeException(
new InvalidArgumentException("mono-measure datasets don't contain same measures (number or names)", position)
new InvalidArgumentException("Variables in the mono-measure datasets are not named the same: " +
measureNames + " found", position)
);
}
DatasetExpression ds = proc.executeInnerJoin(dsExprs);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import fr.insee.vtl.engine.exceptions.FunctionNotFoundException;
import fr.insee.vtl.engine.samples.DatasetSamples;
import fr.insee.vtl.model.utils.Java8Helpers;
import fr.insee.vtl.model.Dataset;
import fr.insee.vtl.model.utils.Java8Helpers;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -92,6 +92,15 @@ public void testCaseExpr() throws ScriptException {
Java8Helpers.mapOf("id", "Franck", "c", 1L)
);
assertThat(((Dataset) res1).getDataStructure().get("c").getType()).isEqualTo(Long.class);
engine.eval("ds1 := ds_1[keep id, long1][rename long1 to bool_var];" +
"ds2 := ds_2[keep id, long1][rename long1 to bool_var]; " +
"res_ds <- case when ds1 < 30 then ds1 else ds2;");
Object res_ds = engine.getContext().getAttribute("res_ds");
assertThat(((Dataset) res_ds).getDataAsMap()).containsExactlyInAnyOrder(
Java8Helpers.mapOf("id", "Hadrien", "bool_var", 10L),
Java8Helpers.mapOf("id", "Nico", "bool_var", 20L),
Java8Helpers.mapOf("id", "Franck", "bool_var", 100L)
);
}

@Test
Expand Down