diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index f66aeaf89a2ee..2b20baea31d30 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -234,6 +234,7 @@ public final class SystemSessionProperties public static final String QUERY_RETRY_MAX_EXECUTION_TIME = "query_retry_max_execution_time"; public static final String PARTIAL_RESULTS_ENABLED = "partial_results_enabled"; public static final String PARTIAL_RESULTS_COMPLETION_RATIO_THRESHOLD = "partial_results_completion_ratio_threshold"; + public static final String ENHANCED_CTE_BLOCKING = "enhanced_cte_blocking"; public static final String PARTIAL_RESULTS_MAX_EXECUTION_TIME_MULTIPLIER = "partial_results_max_execution_time_multiplier"; public static final String OFFSET_CLAUSE_ENABLED = "offset_clause_enabled"; public static final String VERBOSE_EXCEEDED_MEMORY_LIMIT_ERRORS_ENABLED = "verbose_exceeded_memory_limit_errors_enabled"; @@ -1278,6 +1279,11 @@ public SystemSessionProperties( "Minimum query completion ratio threshold for partial results", featuresConfig.getPartialResultsCompletionRatioThreshold(), false), + booleanProperty( + ENHANCED_CTE_BLOCKING, + "Applicable for CTE Materialization. If enabled, only tablescans of the pending tablewriters are blocked and other stages can continue.", + featuresConfig.getEnhancedCTEBlocking(), + true), booleanProperty( OFFSET_CLAUSE_ENABLED, "Enable support for OFFSET clause", @@ -2663,6 +2669,11 @@ public static double getPartialResultsCompletionRatioThreshold(Session session) return session.getSystemProperty(PARTIAL_RESULTS_COMPLETION_RATIO_THRESHOLD, Double.class); } + public static boolean isEnhancedCTEBlockingEnabled(Session session) + { + return isCteMaterializationApplicable(session) & session.getSystemProperty(ENHANCED_CTE_BLOCKING, Boolean.class); + } + public static double getPartialResultsMaxExecutionTimeMultiplier(Session session) { return session.getSystemProperty(PARTIAL_RESULTS_MAX_EXECUTION_TIME_MULTIPLIER, Double.class); diff --git a/presto-main/src/main/java/com/facebook/presto/execution/SqlStageExecution.java b/presto-main/src/main/java/com/facebook/presto/execution/SqlStageExecution.java index f3d03ead61cb5..da6d0063f8809 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/SqlStageExecution.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/SqlStageExecution.java @@ -27,9 +27,14 @@ import com.facebook.presto.server.remotetask.HttpRemoteTask; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.plan.PlanFragmentId; +import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.plan.TableFinishNode; +import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.plan.TemporaryTableInfo; import com.facebook.presto.split.RemoteSplit; import com.facebook.presto.sql.planner.PlanFragment; +import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableMap; @@ -60,8 +65,10 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; +import java.util.stream.Collectors; import static com.facebook.presto.SystemSessionProperties.getMaxFailedTaskPercentage; +import static com.facebook.presto.SystemSessionProperties.isEnhancedCTEBlockingEnabled; import static com.facebook.presto.failureDetector.FailureDetector.State.GONE; import static com.facebook.presto.operator.ExchangeOperator.REMOTE_CONNECTOR_ID; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; @@ -557,7 +564,6 @@ private synchronized RemoteTask scheduleTask(InternalNode node, TaskId taskId, M // stage finished while we were scheduling this task task.abort(); } - return task; } @@ -594,6 +600,59 @@ private static Split createRemoteSplitFor(TaskId taskId, URI remoteSourceTaskLoc return new Split(REMOTE_CONNECTOR_ID, new RemoteTransactionHandle(), new RemoteSplit(new Location(splitLocation), remoteSourceTaskId)); } + private String getCteIdFromSource(PlanNode source) + { + // Traverse the plan node tree to find a TableWriterNode with TemporaryTableInfo + return PlanNodeSearcher.searchFrom(source) + .where(planNode -> planNode instanceof TableFinishNode) + .findFirst() + .map(planNode -> ((TableFinishNode) planNode).getTemporaryTableInfo().orElseThrow( + () -> new IllegalStateException("TableFinishNode has no TemporaryTableInfo"))) + .map(TemporaryTableInfo::getCteId) + .orElseThrow(() -> new IllegalStateException("TemporaryTableInfo has no CTE ID")); + } + + public boolean isCTETableFinishStage() + { + return PlanNodeSearcher.searchFrom(planFragment.getRoot()) + .where(planNode -> (planNode instanceof TableFinishNode)) + .findAll().stream() + .anyMatch(planNode -> ((TableFinishNode) planNode).getTemporaryTableInfo().isPresent()); + } + + public String getCTEWriterId() + { + // Validate that this is a CTE TableFinish stage and return the associated CTE ID + if (!isCTETableFinishStage()) { + throw new IllegalStateException("This stage is not a CTE writer stage"); + } + return getCteIdFromSource(planFragment.getRoot()); + } + + public boolean requiresMaterializedCTE() + { + if (!isEnhancedCTEBlockingEnabled(session)) { + return false; + } + // Search for TableScanNodes and check if they reference TemporaryTableInfo + return PlanNodeSearcher.searchFrom(planFragment.getRoot()) + .where(planNode -> planNode instanceof TableScanNode) + .findAll().stream() + .anyMatch(planNode -> ((TableScanNode) planNode).getTemporaryTableInfo().isPresent()); + } + + public List getRequiredCTEList() + { + // Collect all CTE IDs referenced by TableScanNodes with TemporaryTableInfo + return PlanNodeSearcher.searchFrom(planFragment.getRoot()) + .where(planNode -> planNode instanceof TableScanNode) + .findAll().stream() + .map(planNode -> ((TableScanNode) planNode).getTemporaryTableInfo() + .orElseThrow(() -> new IllegalStateException("TableScanNode has no TemporaryTableInfo"))) + .map(TemporaryTableInfo::getCteId) + .collect(Collectors.toList()); + } + private void updateTaskStatus(TaskId taskId, TaskStatus taskStatus) { StageExecutionState stageExecutionState = getState(); diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/CTEMaterializationTracker.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/CTEMaterializationTracker.java new file mode 100644 index 0000000000000..3dc1263962328 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/CTEMaterializationTracker.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.scheduler; + +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +public class CTEMaterializationTracker +{ + private final Map> materializationFutures = new ConcurrentHashMap<>(); + + private final Map materializedCtes = new ConcurrentHashMap<>(); + + public ListenableFuture getFutureForCTE(String cteName) + { + if (materializationFutures.containsKey(cteName)) { + if (!materializationFutures.get(cteName).isCancelled()) { + return materializationFutures.get(cteName); + } + } + materializationFutures.put(cteName, SettableFuture.create()); + return materializationFutures.get(cteName); + } + + public void markCTEAsMaterialized(String cteName) + { + materializedCtes.put(cteName, true); + SettableFuture future = materializationFutures.get(cteName); + if (future != null && !future.isCancelled()) { + future.set(null); // Notify all listeners + } + } + + public void markAllCTEsMaterialized() + { + materializationFutures.forEach((k, v) -> { + markCTEAsMaterialized(k); + }); + } + + public boolean hasBeenMaterialized(String cteName) + { + return materializedCtes.containsKey(cteName); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/FixedSourcePartitionedScheduler.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/FixedSourcePartitionedScheduler.java index 8e965aab792c3..911e40fc49154 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/FixedSourcePartitionedScheduler.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/FixedSourcePartitionedScheduler.java @@ -74,6 +74,8 @@ public class FixedSourcePartitionedScheduler private final Queue tasksToRecover = new ConcurrentLinkedQueue<>(); + private final CTEMaterializationTracker cteMaterializationTracker; + @GuardedBy("this") private boolean closed; @@ -87,13 +89,15 @@ public FixedSourcePartitionedScheduler( int splitBatchSize, OptionalInt concurrentLifespansPerTask, NodeSelector nodeSelector, - List partitionHandles) + List partitionHandles, + CTEMaterializationTracker cteMaterializationTracker) { requireNonNull(stage, "stage is null"); requireNonNull(splitSources, "splitSources is null"); requireNonNull(bucketNodeMap, "bucketNodeMap is null"); checkArgument(!requireNonNull(nodes, "nodes is null").isEmpty(), "nodes is empty"); requireNonNull(partitionHandles, "partitionHandles is null"); + this.cteMaterializationTracker = cteMaterializationTracker; this.stage = stage; this.nodes = ImmutableList.copyOf(nodes); @@ -179,6 +183,29 @@ public ScheduleResult schedule() { // schedule a task on every node in the distribution List newTasks = ImmutableList.of(); + List> blocked = new ArrayList<>(); + + // CTE Materialization Check + if (stage.requiresMaterializedCTE()) { + List requiredCTEIds = stage.getRequiredCTEList(); // Ensure this method exists and returns a list of required CTE IDs as strings + for (String cteId : requiredCTEIds) { + if (!cteMaterializationTracker.hasBeenMaterialized(cteId)) { + // Add CTE materialization future to the blocked list + ListenableFuture materializationFuture = cteMaterializationTracker.getFutureForCTE(cteId); + blocked.add(materializationFuture); + } + } + // If any CTE is not materialized, return a blocked ScheduleResult + if (!blocked.isEmpty()) { + return ScheduleResult.blocked( + false, // true if all required CTEs are blocked + newTasks, + whenAnyComplete(blocked), // Wait for any CTE materialization to complete + BlockedReason.WAITING_FOR_CTE_MATERIALIZATION, + 0); + } + } + // schedule a task on every node in the distribution if (!scheduledTasks) { newTasks = Streams.mapWithIndex( nodes.stream(), @@ -193,7 +220,6 @@ public ScheduleResult schedule() } boolean allBlocked = true; - List> blocked = new ArrayList<>(); BlockedReason blockedReason = BlockedReason.NO_ACTIVE_DRIVER_GROUP; if (groupedLifespanScheduler.isPresent()) { diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ScheduleResult.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ScheduleResult.java index ed85bfff8fd94..2af53f619880e 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ScheduleResult.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ScheduleResult.java @@ -57,6 +57,8 @@ public enum BlockedReason * grouped execution where there are multiple lifespans per task). */ MIXED_SPLIT_QUEUES_FULL_AND_WAITING_FOR_SOURCE, + + WAITING_FOR_CTE_MATERIALIZATION, /**/; public BlockedReason combineWith(BlockedReason other) @@ -64,6 +66,7 @@ public BlockedReason combineWith(BlockedReason other) switch (this) { case WRITER_SCALING: throw new IllegalArgumentException("cannot be combined"); + case WAITING_FOR_CTE_MATERIALIZATION: case NO_ACTIVE_DRIVER_GROUP: return other; case SPLIT_QUEUES_FULL: diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SectionExecutionFactory.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SectionExecutionFactory.java index 3f885239d5d6d..96ad3537db651 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SectionExecutionFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SectionExecutionFactory.java @@ -168,7 +168,8 @@ public SectionExecution createSectionExecutions( boolean summarizeTaskInfo, RemoteTaskFactory remoteTaskFactory, SplitSourceFactory splitSourceFactory, - int attemptId) + int attemptId, + CTEMaterializationTracker cteMaterializationTracker) { // Only fetch a distribution once per section to ensure all stages see the same machine assignments Map partitioningCache = new HashMap<>(); @@ -184,7 +185,8 @@ public SectionExecution createSectionExecutions( summarizeTaskInfo, remoteTaskFactory, splitSourceFactory, - attemptId); + attemptId, + cteMaterializationTracker); StageExecutionAndScheduler rootStage = getLast(sectionStages); rootStage.getStageExecution().setOutputBuffers(outputBuffers); return new SectionExecution(rootStage, sectionStages); @@ -203,7 +205,8 @@ private List createStreamingLinkedStageExecutions( boolean summarizeTaskInfo, RemoteTaskFactory remoteTaskFactory, SplitSourceFactory splitSourceFactory, - int attemptId) + int attemptId, + CTEMaterializationTracker cteMaterializationTracker) { ImmutableList.Builder stageExecutionAndSchedulers = ImmutableList.builder(); @@ -238,7 +241,8 @@ private List createStreamingLinkedStageExecutions( summarizeTaskInfo, remoteTaskFactory, splitSourceFactory, - attemptId); + attemptId, + cteMaterializationTracker); stageExecutionAndSchedulers.addAll(subTree); childStagesBuilder.add(getLast(subTree).getStageExecution()); } @@ -260,7 +264,8 @@ private List createStreamingLinkedStageExecutions( stageExecution, partitioningHandle, tableWriteInfo, - childStageExecutions); + childStageExecutions, + cteMaterializationTracker); stageExecutionAndSchedulers.add(new StageExecutionAndScheduler( stageExecution, stageLinkage, @@ -279,7 +284,8 @@ private StageScheduler createStageScheduler( SqlStageExecution stageExecution, PartitioningHandle partitioningHandle, TableWriteInfo tableWriteInfo, - Set childStageExecutions) + Set childStageExecutions, + CTEMaterializationTracker cteMaterializationTracker) { Map splitSources = splitSourceFactory.createSplitSources(plan.getFragment(), session, tableWriteInfo); int maxTasksPerStage = getMaxTasksPerStage(session); @@ -383,7 +389,8 @@ else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) { splitBatchSize, getConcurrentLifespansPerNode(session), nodeScheduler.createNodeSelector(session, connectorId, nodePredicate), - connectorPartitionHandles); + connectorPartitionHandles, + cteMaterializationTracker); if (plan.getFragment().getStageExecutionDescriptor().isRecoverableGroupedExecution()) { stageExecution.registerStageTaskRecoveryCallback(taskId -> { checkArgument(taskId.getStageExecutionId().getStageId().equals(stageId), "The task did not execute this stage"); diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SqlQueryScheduler.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SqlQueryScheduler.java index 198f23f1ba22e..e34a2ccc5953a 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SqlQueryScheduler.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SqlQueryScheduler.java @@ -76,9 +76,9 @@ import static com.facebook.airlift.concurrent.MoreFutures.tryGetFutureValue; import static com.facebook.airlift.concurrent.MoreFutures.whenAnyComplete; import static com.facebook.airlift.http.client.HttpUriBuilder.uriBuilderFrom; -import static com.facebook.presto.SystemSessionProperties.getMaxConcurrentMaterializations; import static com.facebook.presto.SystemSessionProperties.getPartialResultsCompletionRatioThreshold; import static com.facebook.presto.SystemSessionProperties.getPartialResultsMaxExecutionTimeMultiplier; +import static com.facebook.presto.SystemSessionProperties.isEnhancedCTEBlockingEnabled; import static com.facebook.presto.SystemSessionProperties.isPartialResultsEnabled; import static com.facebook.presto.SystemSessionProperties.isRuntimeOptimizerEnabled; import static com.facebook.presto.execution.BasicStageExecutionStats.aggregateBasicStageStats; @@ -149,6 +149,7 @@ public class SqlQueryScheduler private final AtomicBoolean scheduling = new AtomicBoolean(); private final PartialResultQueryTaskTracker partialResultQueryTaskTracker; + private final CTEMaterializationTracker cteMaterializationTracker = new CTEMaterializationTracker(); public static SqlQueryScheduler createSqlQueryScheduler( LocationFactory locationFactory, @@ -258,7 +259,8 @@ private SqlQueryScheduler( stageExecutions.stream() .forEach(execution -> this.stageExecutions.put(execution.getStageExecution().getStageExecutionId().getStageId(), execution)); - this.maxConcurrentMaterializations = getMaxConcurrentMaterializations(session); +// this.maxConcurrentMaterializations = getMaxConcurrentMaterializations(session); + this.maxConcurrentMaterializations = 100; this.partialResultQueryTaskTracker = new PartialResultQueryTaskTracker(partialResultQueryManager, getPartialResultsCompletionRatioThreshold(session), getPartialResultsMaxExecutionTimeMultiplier(session), warningCollector); } @@ -278,6 +280,17 @@ else if (state == CANCELED) { for (StageExecutionAndScheduler stageExecutionInfo : stageExecutions.values()) { SqlStageExecution stageExecution = stageExecutionInfo.getStageExecution(); + // Add a listener for state changes + if (stageExecution.isCTETableFinishStage()) { + stageExecution.addStateChangeListener(state -> { + if (state == StageExecutionState.FINISHED) { + String cteName = stageExecution.getCTEWriterId(); + log.info("CTE write completed for: " + cteName); + // Notify the materialization tracker + cteMaterializationTracker.markCTEAsMaterialized(cteName); + } + }); + } stageExecution.addStateChangeListener(state -> { if (queryStateMachine.isDone()) { return; @@ -363,7 +376,8 @@ private List createStageExecutions( summarizeTaskInfo, remoteTaskFactory, splitSourceFactory, - 0).getSectionStages(); + 0, + cteMaterializationTracker).getSectionStages(); stages.addAll(sectionStages); return stages.build(); @@ -460,6 +474,7 @@ else if (!result.getBlocked().isDone()) { ScheduleResult.BlockedReason blockedReason = result.getBlockedReason().get(); switch (blockedReason) { case WRITER_SCALING: + case WAITING_FOR_CTE_MATERIALIZATION: // no-op break; case WAITING_FOR_SOURCE: @@ -678,7 +693,8 @@ private void updateStageExecutions(StreamingPlanSection section, Map updatedStageExecutions = sectionExecution.getSectionStages().stream() .collect(toImmutableMap(execution -> execution.getStageExecution().getStageExecutionId().getStageId(), identity())); @@ -774,10 +790,13 @@ private boolean isReadyForExecution(StreamingPlanSection section) // already scheduled return false; } - for (StreamingPlanSection child : section.getChildren()) { - SqlStageExecution rootStageExecution = getStageExecution(child.getPlan().getFragment().getId()); - if (rootStageExecution.getState() != FINISHED) { - return false; + if (!isEnhancedCTEBlockingEnabled(session)) { + // Block if child sections are not complete + for (StreamingPlanSection child : section.getChildren()) { + SqlStageExecution rootStageExecution = getStageExecution(child.getPlan().getFragment().getId()); + if (rootStageExecution.getState() != FINISHED) { + return false; + } } } return true; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 5c2040e6ade8a..93bc418ca1d13 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -138,6 +138,7 @@ public class FeaturesConfig private boolean ignoreStatsCalculatorFailures = true; private boolean printStatsForNonJoinQuery; private boolean defaultFilterFactorEnabled; + private boolean enhancedCTEBlocking = true; // Give a default 10% selectivity coefficient factor to avoid hitting unknown stats in join stats estimates // which could result in syntactic join order. Set it to 0 to disable this feature private double defaultJoinSelectivityCoefficient; @@ -1286,6 +1287,18 @@ public boolean isDefaultFilterFactorEnabled() return defaultFilterFactorEnabled; } + @Config("enhanced-cte-blocking") + public FeaturesConfig setEnhancedCTEBlocking(boolean enhancedCTEBlocking) + { + this.enhancedCTEBlocking = enhancedCTEBlocking; + return this; + } + + public boolean getEnhancedCTEBlocking() + { + return enhancedCTEBlocking; + } + @Config("optimizer.default-join-selectivity-coefficient") @ConfigDescription("Used when join selectivity estimation is unknown. Default 0 to disable the use of join selectivity, this will allow planner to fall back to FROM-clause join order when the join cardinality is unknown") public FeaturesConfig setDefaultJoinSelectivityCoefficient(double defaultJoinSelectivityCoefficient) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 4ec67782a1ab7..63839413230d4 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -246,6 +246,7 @@ public void testDefaults() .setInlineProjectionsOnValues(false) .setEagerPlanValidationEnabled(false) .setEagerPlanValidationThreadPoolSize(20) + .setEnhancedCTEBlocking(true) .setPrestoSparkExecutionEnvironment(false)); } @@ -442,6 +443,7 @@ public void testExplicitPropertyMappings() .put("eager-plan-validation-enabled", "true") .put("eager-plan-validation-thread-pool-size", "2") .put("presto-spark-execution-environment", "true") + .put("enhanced_cte_blocking", "false") .build(); FeaturesConfig expected = new FeaturesConfig() @@ -634,6 +636,7 @@ public void testExplicitPropertyMappings() .setInlineProjectionsOnValues(true) .setEagerPlanValidationEnabled(true) .setEagerPlanValidationThreadPoolSize(2) + .setEnhancedCTEBlocking(false) .setPrestoSparkExecutionEnvironment(true); assertFullMapping(properties, expected); }