Skip to content

Commit

Permalink
IGNITE-23308 SQL Calcite: Fix wrong numeric type coercion with set-op…
Browse files Browse the repository at this point in the history
… operations - Fixes #11557.

Signed-off-by: Aleksey Plekhanov <plehanov.alex@gmail.com>
  • Loading branch information
Vladsz83 authored and alex-plekhanov committed Oct 4, 2024
1 parent 6648f8d commit 5b6a433
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.apache.ignite.internal.processors.query.calcite.rel.set.IgniteReduceIntersect;
import org.apache.ignite.internal.processors.query.calcite.rel.set.IgniteReduceMinus;
import org.apache.ignite.internal.processors.query.calcite.trait.IgniteDistributions;
import org.apache.ignite.internal.processors.query.calcite.util.Commons;

/**
* Set op (MINUS, INTERSECT) converter rule.
Expand Down Expand Up @@ -77,6 +78,8 @@ private abstract static class ColocatedSetOpConverterRule<T extends SetOp> exten
RelTraitSet outTrait = cluster.traitSetOf(IgniteConvention.INSTANCE).replace(IgniteDistributions.single());
List<RelNode> inputs = Util.transform(setOp.getInputs(), rel -> convert(rel, inTrait));

inputs = Commons.castToLeastRestrictiveIfRequired(inputs, cluster, inTrait);

return createNode(cluster, outTrait, inputs, setOp.all);
}
}
Expand Down Expand Up @@ -131,6 +134,8 @@ abstract PhysicalNode createReduceNode(RelOptCluster cluster, RelTraitSet traits
RelTraitSet outTrait = cluster.traitSetOf(IgniteConvention.INSTANCE);
List<RelNode> inputs = Util.transform(setOp.getInputs(), rel -> convert(rel, inTrait));

inputs = Commons.castToLeastRestrictiveIfRequired(inputs, cluster, inTrait);

RelNode map = createMapNode(cluster, outTrait, inputs, setOp.all);

return createReduceNode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,14 @@ public UnionConverterRule(Config cfg) {

/** {@inheritDoc} */
@Override public void onMatch(RelOptRuleCall call) {
final LogicalUnion union = call.rel(0);
LogicalUnion union = call.rel(0);

RelOptCluster cluster = union.getCluster();
RelTraitSet traits = cluster.traitSetOf(IgniteConvention.INSTANCE);
List<RelNode> inputs = Commons.transform(union.getInputs(), input -> convert(input, traits));

inputs = Commons.castToLeastRestrictiveIfRequired(inputs, cluster, traits);

RelNode res = new IgniteUnionAll(cluster, traits, inputs);

if (!union.all) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,8 @@ else if (type instanceof BasicSqlType || type instanceof IntervalSqlType) {
return Enum.class;
case ANY:
case OTHER:
return Object.class;
case NULL:
return Void.class;
return Object.class;
default:
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,16 @@
import org.apache.calcite.plan.Context;
import org.apache.calcite.plan.Contexts;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.SourceStringReader;
Expand All @@ -61,6 +67,7 @@
import org.apache.ignite.internal.processors.query.calcite.exec.exp.ExpressionFactoryImpl;
import org.apache.ignite.internal.processors.query.calcite.prepare.BaseQueryContext;
import org.apache.ignite.internal.processors.query.calcite.prepare.MappingQueryContext;
import org.apache.ignite.internal.processors.query.calcite.rel.IgniteProject;
import org.apache.ignite.internal.processors.query.calcite.type.IgniteTypeFactory;
import org.apache.ignite.internal.util.typedef.F;
import org.apache.ignite.internal.util.typedef.internal.A;
Expand Down Expand Up @@ -140,6 +147,61 @@ public static <T> List<T> intersect(Set<T> set, List<T> list) {
.collect(Collectors.toList());
}

/**
* Finds the least restrictive type of the inputs and adds a cast projection if required.
*
* @param inputs Inputs to try to cast.
* @param cluster Cluster.
* @param traits Traits.
* @return Converted inputs.
*/
public static List<RelNode> castToLeastRestrictiveIfRequired(List<RelNode> inputs, RelOptCluster cluster, RelTraitSet traits) {
List<RelDataType> inputRowTypes = inputs.stream().map(RelNode::getRowType).collect(Collectors.toList());

// Output type of a set operator is equal to leastRestrictive(inputTypes) (see SetOp::deriveRowType).
RelDataTypeFactory typeFactory = cluster.getTypeFactory();

RelDataType leastRestrictive = typeFactory.leastRestrictive(inputRowTypes);

if (leastRestrictive == null)
throw new IllegalStateException("Cannot find least restrictive type for arguments to set op: " + inputRowTypes);

// If input's type does not match the result type, then add a cast projection for non-matching fields.
RexBuilder rexBuilder = cluster.getRexBuilder();
List<RelNode> newInputs = new ArrayList<>(inputs.size());

for (RelNode input : inputs) {
RelDataType inputRowType = input.getRowType();

// It is always safe to convert from [T1 nullable, T2 not nullable] to [T1 nullable, T2 nullable] and
// the least restrictive type does exactly that.
if (SqlTypeUtil.equalAsStructSansNullability(typeFactory, leastRestrictive, inputRowType, null)) {
newInputs.add(input);

continue;
}

List<RexNode> expressions = new ArrayList<>(inputRowType.getFieldCount());

for (int i = 0; i < leastRestrictive.getFieldCount(); i++) {
RelDataType fieldType = inputRowType.getFieldList().get(i).getType();

RelDataType outFieldType = leastRestrictive.getFieldList().get(i).getType();

RexNode ref = rexBuilder.makeInputRef(input, i);

if (fieldType.equals(outFieldType))
expressions.add(ref);
else
expressions.add(rexBuilder.makeCast(outFieldType, ref, true, false));
}

newInputs.add(new IgniteProject(cluster, traits, input, expressions, leastRestrictive));
}

return newInputs;
}

/**
* Returns a given list as a typed list.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.ignite.internal.processors.query.calcite.integration;

import java.util.Arrays;
import java.util.List;

import javax.cache.Cache;
Expand Down Expand Up @@ -466,4 +467,72 @@ public void testSetOpRewindability() {
.returns(2)
.check();
}

/** */
@Test
public void testNumbersCastInUnion() throws Exception {
doTestNumbersCastInSetOp("UNION", 10, 20, 30, 33, 40, 44, 50, null);

doTestNumbersCastInSetOp("UNION ALL", 10, 20, 20, 30, 30, 33, 40, 44, 50, 50, 50, 50, null, null);
}

/** */
@Test
public void testNumbersCastInIntersect() throws Exception {
doTestNumbersCastInSetOp("INTERSECT", 20, 50, null);

doTestNumbersCastInSetOp("INTERSECT ALL", 20, 50, 50, null);
}

/** */
@Test
public void testNumbersCastInExcept() throws Exception {
doTestNumbersCastInSetOp("EXCEPT", 30, 40);

doTestNumbersCastInSetOp("EXCEPT ALL", 30, 30, 40);
}

/**
* Tests 'SELECT TBL1.val SetOp TBL2.val' where TBL1 has `INT val` and TBL2 has 'val' of different numeric type.
* TBL1: 30, 20, 30, 40, 50, 50, null
* TBL2: 10, 20, 33, 44, 50, 50, null
*
* @param op Operation like 'UNION' or 'INTERSECT'
* @param expected Expected result as integers.
*/
private void doTestNumbersCastInSetOp(String op, Integer... expected) throws InterruptedException {
List<String> types = F.asList("TINYINT", "SMALLINT", "INTEGER", "REAL", "FLOAT", "BIGINT", "DOUBLE", "DECIMAL");

sql(client, "CREATE TABLE t0(id INT PRIMARY KEY, val INTEGER) WITH \"affinity_key=id\"");

try {
sql(client, "INSERT INTO t0 VALUES (1, 30), (2, 20), (3, 30), (4, 40), (5, 50), (6, 50), (7, null)");

for (String tblOpts : Arrays.asList("", " WITH \"template=replicated\"", " WITH \"affinity_key=aff\"")) {
for (String t2 : types) {
sql(client, "CREATE TABLE t1(id INT, aff INT, val " + t2 + ", PRIMARY KEY(id, aff))" + tblOpts);

sql(client, "INSERT INTO t1 VALUES (1, 1, 10), (2, 1, 20), (3, 1, 33), (4, 2, 44), (5, 2, 50), " +
"(6, 3, 50), (7, 3, null)");

List<List<?>> res = sql(client, "SELECT val from t0 " + op + " select val from t1 ORDER BY 1 NULLS LAST");

sql(client, "DROP TABLE t1");

assertEquals(expected.length, res.size());

for (int i = 0; i < expected.length; ++i) {
assertEquals(1, res.get(i).size());

assertEquals(expected[i], res.get(i).get(0) == null ? null : ((Number)res.get(i).get(0)).intValue());
}
}
}
}
finally {
sql(client, "DROP TABLE t0");

awaitPartitionMapExchange();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.ignite.internal.processors.query.calcite.planner;

import java.util.Arrays;
import java.util.List;
import java.util.function.Predicate;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.SetOp;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.ignite.internal.processors.query.calcite.rel.IgniteProject;
import org.apache.ignite.internal.processors.query.calcite.schema.IgniteSchema;
import org.apache.ignite.internal.processors.query.calcite.trait.IgniteDistribution;
import org.apache.ignite.internal.processors.query.calcite.trait.IgniteDistributions;
import org.apache.ignite.internal.processors.query.calcite.type.IgniteTypeFactory;
import org.apache.ignite.internal.processors.query.calcite.util.Commons;
import org.junit.Test;

/**
* Planner test various types, casts and coercions.
*/
public class DataTypesPlannerTest extends AbstractPlannerTest {
/** Tests casts of numeric types in SetOps (UNION, EXCEPT, INTERSECT, etc.). */
@Test
public void testSetOpNumbersCast() throws Exception {
List<IgniteDistribution> distrs = Arrays.asList(IgniteDistributions.single(), IgniteDistributions.random(),
IgniteDistributions.affinity(0, 1001, 0));

for (IgniteDistribution d1 : distrs) {
for (IgniteDistribution d2 : distrs) {
doTestSetOpNumbersCast(d1, d2, true, true);

doTestSetOpNumbersCast(d1, d2, false, true);

doTestSetOpNumbersCast(d1, d2, false, false);
}
}
}

/** */
private void doTestSetOpNumbersCast(
IgniteDistribution distr1,
IgniteDistribution distr2,
boolean nullable1,
boolean nullable2
) throws Exception {
IgniteSchema schema = new IgniteSchema(DEFAULT_SCHEMA);

IgniteTypeFactory f = Commons.typeFactory();

SqlTypeName[] numTypes = new SqlTypeName[] {SqlTypeName.TINYINT, SqlTypeName.SMALLINT, SqlTypeName.REAL,
SqlTypeName.FLOAT, SqlTypeName.INTEGER, SqlTypeName.BIGINT, SqlTypeName.DOUBLE, SqlTypeName.DECIMAL};

boolean notNull = !nullable1 && !nullable2;

for (SqlTypeName t1 : numTypes) {
for (SqlTypeName t2 : numTypes) {
RelDataType type = new RelDataTypeFactory.Builder(f)
.add("C1", f.createTypeWithNullability(f.createSqlType(t1), nullable1))
.add("C2", f.createTypeWithNullability(f.createSqlType(SqlTypeName.VARCHAR), true))
.build();

createTable(schema, "TABLE1", type, distr1, null);

type = new RelDataTypeFactory.Builder(f)
.add("C1", f.createTypeWithNullability(f.createSqlType(t2), nullable2))
.add("C2", f.createTypeWithNullability(f.createSqlType(SqlTypeName.VARCHAR), true))
.build();

createTable(schema, "TABLE2", type, distr2, null);

for (String op : Arrays.asList("UNION", "INTERSECT", "EXCEPT")) {
String sql = "SELECT * FROM table1 " + op + " SELECT * FROM table2";

if (t1 == t2 && (!nullable1 || !nullable2))
assertPlan(sql, schema, nodeOrAnyChild(isInstanceOf(IgniteProject.class)).negate());
else {
RelDataType targetT = f.leastRestrictive(Arrays.asList(f.createSqlType(t1), f.createSqlType(t2)));

assertPlan(sql, schema, nodeOrAnyChild(isInstanceOf(SetOp.class)
.and(t1 == targetT.getSqlTypeName() ? input(0, nodeOrAnyChild(isInstanceOf(IgniteProject.class)).negate())
: input(0, projectFromTable("TABLE1", "CAST($0):" + targetT + (notNull ? " NOT NULL" : ""), "$1")))
.and(t2 == targetT.getSqlTypeName() ? input(1, nodeOrAnyChild(isInstanceOf(IgniteProject.class)).negate())
: input(1, projectFromTable("TABLE2", "CAST($0):" + targetT + (notNull ? " NOT NULL" : ""), "$1")))
));
}
}
}
}
}

/** */
protected Predicate<? extends RelNode> projectFromTable(String tableName, String... exprs) {
return nodeOrAnyChild(
isInstanceOf(IgniteProject.class)
.and(projection -> {
String actualProj = projection.getProjects().toString();

String expectedProj = Arrays.asList(exprs).toString();

return actualProj.equals(expectedProj);
})
.and(input(nodeOrAnyChild(isTableScan(tableName))))
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.ignite.internal.processors.query.calcite.planner.AggregatePlannerTest;
import org.apache.ignite.internal.processors.query.calcite.planner.CorrelatedNestedLoopJoinPlannerTest;
import org.apache.ignite.internal.processors.query.calcite.planner.CorrelatedSubqueryPlannerTest;
import org.apache.ignite.internal.processors.query.calcite.planner.DataTypesPlannerTest;
import org.apache.ignite.internal.processors.query.calcite.planner.HashAggregatePlannerTest;
import org.apache.ignite.internal.processors.query.calcite.planner.HashIndexSpoolPlannerTest;
import org.apache.ignite.internal.processors.query.calcite.planner.IndexRebuildPlannerTest;
Expand Down Expand Up @@ -67,6 +68,7 @@
TableFunctionPlannerTest.class,
TableDmlPlannerTest.class,
UnionPlannerTest.class,
DataTypesPlannerTest.class,
JoinCommutePlannerTest.class,
LimitOffsetPlannerTest.class,
MergeJoinPlannerTest.class,
Expand Down

0 comments on commit 5b6a433

Please sign in to comment.