Skip to content

Commit

Permalink
retrying spill when oom in shared mode
Browse files Browse the repository at this point in the history
  • Loading branch information
kecookier committed Dec 16, 2024
1 parent 2346584 commit 8637931
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,6 @@ public interface MemoryTargetVisitor<T> {
T visit(NoopMemoryTarget noopMemoryTarget);

T visit(DynamicOffHeapSizingMemoryTarget dynamicOffHeapSizingMemoryTarget);

T visit(RetryOnOomMemoryTarget retryOnOomMemoryTarget);
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import org.apache.gluten.memory.MemoryUsageStatsBuilder;
import org.apache.gluten.memory.memtarget.spark.TreeMemoryConsumers;

import org.apache.spark.SparkEnv;
import org.apache.spark.annotation.Experimental;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.util.SparkResourceUtil;

import java.util.Map;

Expand All @@ -43,6 +45,14 @@ public static MemoryTarget overAcquire(
return new OverAcquire(target, overTarget, overAcquiredRatio);
}

public static TreeMemoryTarget retrySpillOnOom(TreeMemoryTarget target) {
SparkEnv env = SparkEnv.get();
if (env != null && env.conf() != null && SparkResourceUtil.getTaskSlots(env.conf()) > 1) {
return new RetryOnOomMemoryTarget(target);
}
return target;
}

@Experimental
public static MemoryTarget dynamicOffHeapSizingIfEnabled(MemoryTarget memoryTarget) {
if (GlutenConfig.getConf().dynamicOffHeapSizingEnabled()) {
Expand All @@ -59,11 +69,12 @@ public static TreeMemoryTarget newConsumer(
Map<String, MemoryUsageStatsBuilder> virtualChildren) {
final TreeMemoryConsumers.Factory factory;
if (GlutenConfig.getConf().memoryIsolation()) {
factory = TreeMemoryConsumers.isolated();
return TreeMemoryConsumers.isolated().newConsumer(tmm, name, spiller, virtualChildren);
} else {
factory = TreeMemoryConsumers.shared();
// Retry of spilling is needed in shared mode because the maxMemoryPerTask of Vanilla Spark
// ExecutionMemoryPool is dynamic when with multi-slot config.
return MemoryTargets.retrySpillOnOom(
TreeMemoryConsumers.shared().newConsumer(tmm, name, spiller, virtualChildren));
}

return factory.newConsumer(tmm, name, spiller, virtualChildren);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.memory.memtarget;

import org.apache.gluten.memory.MemoryUsageStatsBuilder;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Map;

public class RetryOnOomMemoryTarget implements TreeMemoryTarget {
private static final Logger LOGGER = LoggerFactory.getLogger(RetryOnOomMemoryTarget.class);
private final TreeMemoryTarget target;

RetryOnOomMemoryTarget(TreeMemoryTarget target) {
this.target = target;
}

@Override
public long borrow(long size) {
long granted = target.borrow(size);
if (granted < size) {

LOGGER.info(
"Exceed Spark perTaskLimit with maxTaskSizeDynamic when "
+ "require:{} got:{}, try spill all.",
size,
granted);
final long spilled = TreeMemoryTargets.spillTree(target, Long.MAX_VALUE);
final long remaining = size - granted;
if (spilled >= remaining) {
granted += target.borrow(remaining);
}
}
return granted;
}

@Override
public long repay(long size) {
return target.repay(size);
}

@Override
public long usedBytes() {
return target.usedBytes();
}

@Override
public <T> T accept(MemoryTargetVisitor<T> visitor) {
return visitor.visit(this);
}

@Override
public String name() {
return target.name();
}

@Override
public MemoryUsageStats stats() {
return target.stats();
}

@Override
public TreeMemoryTarget newChild(
String name,
long capacity,
Spiller spiller,
Map<String, MemoryUsageStatsBuilder> virtualChildren) {
return target.newChild(name, capacity, spiller, virtualChildren);
}

@Override
public Map<String, TreeMemoryTarget> children() {
return target.children();
}

@Override
public TreeMemoryTarget parent() {
return target.parent();
}

@Override
public Spiller getNodeSpiller() {
return target.getNodeSpiller();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ object SparkMemoryUtil {
dynamicOffHeapSizingMemoryTarget: DynamicOffHeapSizingMemoryTarget): String = {
dynamicOffHeapSizingMemoryTarget.delegated().accept(this)
}

override def visit(retryOnOomMemoryTarget: RetryOnOomMemoryTarget): String = {
retryOnOomMemoryTarget.target().accept(this)
}
})
}

Expand Down

0 comments on commit 8637931

Please sign in to comment.