Skip to content

Commit

Permalink
Add enhanced cte scheduling mode
Browse files Browse the repository at this point in the history
  • Loading branch information
jaystarshot committed Dec 3, 2024
1 parent 170b377 commit a786c7a
Show file tree
Hide file tree
Showing 9 changed files with 218 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ public final class SystemSessionProperties
public static final String QUERY_RETRY_MAX_EXECUTION_TIME = "query_retry_max_execution_time";
public static final String PARTIAL_RESULTS_ENABLED = "partial_results_enabled";
public static final String PARTIAL_RESULTS_COMPLETION_RATIO_THRESHOLD = "partial_results_completion_ratio_threshold";
public static final String ENHANCED_CTE_BLOCKING = "enhanced_cte_blocking";
public static final String PARTIAL_RESULTS_MAX_EXECUTION_TIME_MULTIPLIER = "partial_results_max_execution_time_multiplier";
public static final String OFFSET_CLAUSE_ENABLED = "offset_clause_enabled";
public static final String VERBOSE_EXCEEDED_MEMORY_LIMIT_ERRORS_ENABLED = "verbose_exceeded_memory_limit_errors_enabled";
Expand Down Expand Up @@ -1278,6 +1279,11 @@ public SystemSessionProperties(
"Minimum query completion ratio threshold for partial results",
featuresConfig.getPartialResultsCompletionRatioThreshold(),
false),
booleanProperty(
ENHANCED_CTE_BLOCKING,
"Applicable for CTE Materialization. If enabled, only tablescans of the pending tablewriters are blocked and other stages can continue.",
featuresConfig.getEnhancedCTEBlocking(),
true),
booleanProperty(
OFFSET_CLAUSE_ENABLED,
"Enable support for OFFSET clause",
Expand Down Expand Up @@ -2663,6 +2669,11 @@ public static double getPartialResultsCompletionRatioThreshold(Session session)
return session.getSystemProperty(PARTIAL_RESULTS_COMPLETION_RATIO_THRESHOLD, Double.class);
}

public static boolean isEnhancedCTEBlockingEnabled(Session session)
{
return isCteMaterializationApplicable(session) & session.getSystemProperty(ENHANCED_CTE_BLOCKING, Boolean.class);
}

public static double getPartialResultsMaxExecutionTimeMultiplier(Session session)
{
return session.getSystemProperty(PARTIAL_RESULTS_MAX_EXECUTION_TIME_MULTIPLIER, Double.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,14 @@
import com.facebook.presto.server.remotetask.HttpRemoteTask;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.plan.PlanFragmentId;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.plan.TableFinishNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.TemporaryTableInfo;
import com.facebook.presto.split.RemoteSplit;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -60,8 +65,10 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import static com.facebook.presto.SystemSessionProperties.getMaxFailedTaskPercentage;
import static com.facebook.presto.SystemSessionProperties.isEnhancedCTEBlockingEnabled;
import static com.facebook.presto.failureDetector.FailureDetector.State.GONE;
import static com.facebook.presto.operator.ExchangeOperator.REMOTE_CONNECTOR_ID;
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
Expand Down Expand Up @@ -557,7 +564,6 @@ private synchronized RemoteTask scheduleTask(InternalNode node, TaskId taskId, M
// stage finished while we were scheduling this task
task.abort();
}

return task;
}

Expand Down Expand Up @@ -594,6 +600,59 @@ private static Split createRemoteSplitFor(TaskId taskId, URI remoteSourceTaskLoc
return new Split(REMOTE_CONNECTOR_ID, new RemoteTransactionHandle(), new RemoteSplit(new Location(splitLocation), remoteSourceTaskId));
}

private String getCteIdFromSource(PlanNode source)
{
// Traverse the plan node tree to find a TableWriterNode with TemporaryTableInfo
return PlanNodeSearcher.searchFrom(source)
.where(planNode -> planNode instanceof TableFinishNode)
.findFirst()
.map(planNode -> ((TableFinishNode) planNode).getTemporaryTableInfo().orElseThrow(
() -> new IllegalStateException("TableFinishNode has no TemporaryTableInfo")))
.map(TemporaryTableInfo::getCteId)
.orElseThrow(() -> new IllegalStateException("TemporaryTableInfo has no CTE ID"));
}

public boolean isCTETableFinishStage()
{
return PlanNodeSearcher.searchFrom(planFragment.getRoot())
.where(planNode -> (planNode instanceof TableFinishNode))
.findAll().stream()
.anyMatch(planNode -> ((TableFinishNode) planNode).getTemporaryTableInfo().isPresent());
}

public String getCTEWriterId()
{
// Validate that this is a CTE TableFinish stage and return the associated CTE ID
if (!isCTETableFinishStage()) {
throw new IllegalStateException("This stage is not a CTE writer stage");
}
return getCteIdFromSource(planFragment.getRoot());
}

public boolean requiresMaterializedCTE()
{
if (!isEnhancedCTEBlockingEnabled(session)) {
return false;
}
// Search for TableScanNodes and check if they reference TemporaryTableInfo
return PlanNodeSearcher.searchFrom(planFragment.getRoot())
.where(planNode -> planNode instanceof TableScanNode)
.findAll().stream()
.anyMatch(planNode -> ((TableScanNode) planNode).getTemporaryTableInfo().isPresent());
}

public List<String> getRequiredCTEList()
{
// Collect all CTE IDs referenced by TableScanNodes with TemporaryTableInfo
return PlanNodeSearcher.searchFrom(planFragment.getRoot())
.where(planNode -> planNode instanceof TableScanNode)
.findAll().stream()
.map(planNode -> ((TableScanNode) planNode).getTemporaryTableInfo()
.orElseThrow(() -> new IllegalStateException("TableScanNode has no TemporaryTableInfo")))
.map(TemporaryTableInfo::getCteId)
.collect(Collectors.toList());
}

private void updateTaskStatus(TaskId taskId, TaskStatus taskStatus)
{
StageExecutionState stageExecutionState = getState();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.execution.scheduler;

import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class CTEMaterializationTracker
{
private final Map<String, SettableFuture<Void>> materializationFutures = new ConcurrentHashMap<>();

private final Map<String, Boolean> materializedCtes = new ConcurrentHashMap<>();

public ListenableFuture<Void> getFutureForCTE(String cteName)
{
if (materializationFutures.containsKey(cteName)) {
if (!materializationFutures.get(cteName).isCancelled()) {
return materializationFutures.get(cteName);
}
}
materializationFutures.put(cteName, SettableFuture.create());
return materializationFutures.get(cteName);
}

public void markCTEAsMaterialized(String cteName)
{
materializedCtes.put(cteName, true);
SettableFuture<Void> future = materializationFutures.get(cteName);
if (future != null && !future.isCancelled()) {
future.set(null); // Notify all listeners
}
}

public void markAllCTEsMaterialized()
{
materializationFutures.forEach((k, v) -> {
markCTEAsMaterialized(k);
});
}

public boolean hasBeenMaterialized(String cteName)
{
return materializedCtes.containsKey(cteName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ public class FixedSourcePartitionedScheduler

private final Queue<Integer> tasksToRecover = new ConcurrentLinkedQueue<>();

private final CTEMaterializationTracker cteMaterializationTracker;

@GuardedBy("this")
private boolean closed;

Expand All @@ -87,13 +89,15 @@ public FixedSourcePartitionedScheduler(
int splitBatchSize,
OptionalInt concurrentLifespansPerTask,
NodeSelector nodeSelector,
List<ConnectorPartitionHandle> partitionHandles)
List<ConnectorPartitionHandle> partitionHandles,
CTEMaterializationTracker cteMaterializationTracker)
{
requireNonNull(stage, "stage is null");
requireNonNull(splitSources, "splitSources is null");
requireNonNull(bucketNodeMap, "bucketNodeMap is null");
checkArgument(!requireNonNull(nodes, "nodes is null").isEmpty(), "nodes is empty");
requireNonNull(partitionHandles, "partitionHandles is null");
this.cteMaterializationTracker = cteMaterializationTracker;

this.stage = stage;
this.nodes = ImmutableList.copyOf(nodes);
Expand Down Expand Up @@ -179,6 +183,29 @@ public ScheduleResult schedule()
{
// schedule a task on every node in the distribution
List<RemoteTask> newTasks = ImmutableList.of();
List<ListenableFuture<?>> blocked = new ArrayList<>();

// CTE Materialization Check
if (stage.requiresMaterializedCTE()) {
List<String> requiredCTEIds = stage.getRequiredCTEList(); // Ensure this method exists and returns a list of required CTE IDs as strings
for (String cteId : requiredCTEIds) {
if (!cteMaterializationTracker.hasBeenMaterialized(cteId)) {
// Add CTE materialization future to the blocked list
ListenableFuture<Void> materializationFuture = cteMaterializationTracker.getFutureForCTE(cteId);
blocked.add(materializationFuture);
}
}
// If any CTE is not materialized, return a blocked ScheduleResult
if (!blocked.isEmpty()) {
return ScheduleResult.blocked(
false, // true if all required CTEs are blocked
newTasks,
whenAnyComplete(blocked), // Wait for any CTE materialization to complete
BlockedReason.WAITING_FOR_CTE_MATERIALIZATION,
0);
}
}
// schedule a task on every node in the distribution
if (!scheduledTasks) {
newTasks = Streams.mapWithIndex(
nodes.stream(),
Expand All @@ -193,7 +220,6 @@ public ScheduleResult schedule()
}

boolean allBlocked = true;
List<ListenableFuture<?>> blocked = new ArrayList<>();
BlockedReason blockedReason = BlockedReason.NO_ACTIVE_DRIVER_GROUP;

if (groupedLifespanScheduler.isPresent()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,16 @@ public enum BlockedReason
* grouped execution where there are multiple lifespans per task).
*/
MIXED_SPLIT_QUEUES_FULL_AND_WAITING_FOR_SOURCE,

WAITING_FOR_CTE_MATERIALIZATION,
/**/;

public BlockedReason combineWith(BlockedReason other)
{
switch (this) {
case WRITER_SCALING:
throw new IllegalArgumentException("cannot be combined");
case WAITING_FOR_CTE_MATERIALIZATION:
case NO_ACTIVE_DRIVER_GROUP:
return other;
case SPLIT_QUEUES_FULL:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ public SectionExecution createSectionExecutions(
boolean summarizeTaskInfo,
RemoteTaskFactory remoteTaskFactory,
SplitSourceFactory splitSourceFactory,
int attemptId)
int attemptId,
CTEMaterializationTracker cteMaterializationTracker)
{
// Only fetch a distribution once per section to ensure all stages see the same machine assignments
Map<PartitioningHandle, NodePartitionMap> partitioningCache = new HashMap<>();
Expand All @@ -184,7 +185,8 @@ public SectionExecution createSectionExecutions(
summarizeTaskInfo,
remoteTaskFactory,
splitSourceFactory,
attemptId);
attemptId,
cteMaterializationTracker);
StageExecutionAndScheduler rootStage = getLast(sectionStages);
rootStage.getStageExecution().setOutputBuffers(outputBuffers);
return new SectionExecution(rootStage, sectionStages);
Expand All @@ -203,7 +205,8 @@ private List<StageExecutionAndScheduler> createStreamingLinkedStageExecutions(
boolean summarizeTaskInfo,
RemoteTaskFactory remoteTaskFactory,
SplitSourceFactory splitSourceFactory,
int attemptId)
int attemptId,
CTEMaterializationTracker cteMaterializationTracker)
{
ImmutableList.Builder<StageExecutionAndScheduler> stageExecutionAndSchedulers = ImmutableList.builder();

Expand Down Expand Up @@ -238,7 +241,8 @@ private List<StageExecutionAndScheduler> createStreamingLinkedStageExecutions(
summarizeTaskInfo,
remoteTaskFactory,
splitSourceFactory,
attemptId);
attemptId,
cteMaterializationTracker);
stageExecutionAndSchedulers.addAll(subTree);
childStagesBuilder.add(getLast(subTree).getStageExecution());
}
Expand All @@ -260,7 +264,8 @@ private List<StageExecutionAndScheduler> createStreamingLinkedStageExecutions(
stageExecution,
partitioningHandle,
tableWriteInfo,
childStageExecutions);
childStageExecutions,
cteMaterializationTracker);
stageExecutionAndSchedulers.add(new StageExecutionAndScheduler(
stageExecution,
stageLinkage,
Expand All @@ -279,7 +284,8 @@ private StageScheduler createStageScheduler(
SqlStageExecution stageExecution,
PartitioningHandle partitioningHandle,
TableWriteInfo tableWriteInfo,
Set<SqlStageExecution> childStageExecutions)
Set<SqlStageExecution> childStageExecutions,
CTEMaterializationTracker cteMaterializationTracker)
{
Map<PlanNodeId, SplitSource> splitSources = splitSourceFactory.createSplitSources(plan.getFragment(), session, tableWriteInfo);
int maxTasksPerStage = getMaxTasksPerStage(session);
Expand Down Expand Up @@ -383,7 +389,8 @@ else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) {
splitBatchSize,
getConcurrentLifespansPerNode(session),
nodeScheduler.createNodeSelector(session, connectorId, nodePredicate),
connectorPartitionHandles);
connectorPartitionHandles,
cteMaterializationTracker);
if (plan.getFragment().getStageExecutionDescriptor().isRecoverableGroupedExecution()) {
stageExecution.registerStageTaskRecoveryCallback(taskId -> {
checkArgument(taskId.getStageExecutionId().getStageId().equals(stageId), "The task did not execute this stage");
Expand Down
Loading

0 comments on commit a786c7a

Please sign in to comment.