[GLUTEN-8244][CORE] Soft affinity uses consistent hash scheduling #8245

Open · wants to merge 1 commit into base: main
@@ -377,6 +377,9 @@ object MergeTreePartsPartitionsUtil extends Logging {
partNameWithLocation: util.HashMap[String, String],
locationDistinct: util.HashSet[String]): Unit = {

logWarning(
s"[DEBUG] PR-8245 " +
s"partNameWithLocation = $partNameWithLocation, locationDistinct = $locationDistinct")
val currentSizeByLocation = new util.HashMap[String, Long]
val currentFilesByLocation = new util.HashMap[String, ArrayBuffer[MergeTreePartSplit]]

@@ -434,6 +437,7 @@ }
}

locationDistinct.forEach(closePartition)
logWarning(s"[DEBUG] PR-8245, partitions = $partitions")
}

/** Generate bucket partition */
gluten-core/src/main/java/org/apache/gluten/hash/ConsistentHash.java (new file: 277 additions, 0 deletions)
@@ -0,0 +1,277 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.hash;

import com.google.common.base.Preconditions;
import org.apache.commons.codec.digest.MurmurHash3;

import javax.annotation.concurrent.ThreadSafe;

import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
* A consistent hash ring implementation. The class is thread-safe.
*
* @param <T> the type of node to be used in the ring.
*/
@ThreadSafe
public class ConsistentHash<T extends ConsistentHash.Node> {

// read-write lock for the ring.
private final ReadWriteLock lock = new ReentrantReadWriteLock(true);

// Maps each node to its partitions; each partition occupies one slot in the ring.
private final Map<T, Set<Partition<T>>> nodes = new HashMap<>();

// Maps each slot to its partition; a partition is effectively a virtual node.
private final SortedMap<Long, Partition<T>> ring = new TreeMap<>();

// the number of virtual nodes for each physical node.
private final int replicate;

// the hasher function to hash the key.
private final Hasher hasher;

public ConsistentHash(int replicate) {
Preconditions.checkArgument(replicate > 0, "HashRing requires a positive replicate number.");
this.replicate = replicate;
this.hasher =
(key, seed) -> {
// An explicit charset keeps hashes stable across JVMs with different defaults.
byte[] data = key.getBytes(StandardCharsets.UTF_8);
return MurmurHash3.hash32x86(data, 0, data.length, seed);
};
}

public ConsistentHash(int replicate, Hasher hasher) {
Preconditions.checkArgument(replicate > 0, "HashRing requires a positive replicate number.");
Preconditions.checkArgument(hasher != null, "HashRing requires a non-null hasher.");
this.replicate = replicate;
this.hasher = hasher;
}

/**
* Add a node to the ring; the node will be expanded into `replicate` virtual nodes.
*
* @param node the node to be added to the ring.
* @return true if the node is added successfully, false otherwise.
*/
public boolean addNode(T node) {
lock.writeLock().lock();
try {
return add(node);
} finally {
lock.writeLock().unlock();
}
}

/**
* Remove the node from the ring; all of its virtual nodes are removed as well.
*
* @param node the node to be removed.
* @return true if the node is removed successfully, false otherwise.
*/
public boolean removeNode(T node) {
lock.writeLock().lock();
boolean removed = false;
try {
if (nodes.containsKey(node)) {
Set<Partition<T>> partitions = nodes.remove(node);
partitions.forEach(p -> ring.remove(p.getSlot()));
removed = true;
}
} finally {
lock.writeLock().unlock();
}
return removed;
}

/**
* Locate nodes for the given key; the number of nodes to return is specified by the count.
*
* @param key the key to locate the node.
* @param count the number of nodes to be located.
* @return a set of nodes located by the key.
*/
public Set<T> allocateNodes(String key, int count) {
lock.readLock().lock();
try {
Set<T> res = new HashSet<>();
if (key != null && count > 0) {
if (count < nodes.size()) {
long slot = hasher.hash(key, 0);
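// Walk the ring clockwise from the key's slot until `count` distinct nodes are found.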
Iterator<Partition<T>> it = new AllocateIterator(slot);
while (it.hasNext() && res.size() < count) {
Partition<T> part = it.next();
res.add(part.getNode());
}
} else {
res.addAll(nodes.keySet());
}
}
return res;
} finally {
lock.readLock().unlock();
}
}

/**
* Get all the nodes in the ring.
*
* @return a set of nodes in the ring.
*/
public Set<T> getNodes() {
lock.readLock().lock();
try {
return new HashSet<>(nodes.keySet());
} finally {
lock.readLock().unlock();
}
}

/**
* Get the partitions of the node. Returns null if the node does not exist.
*
* @param node the node to get the partitions.
* @return a set of partitions of the node.
*/
public Set<Partition<T>> getPartition(T node) {
lock.readLock().lock();
try {
return nodes.get(node);
} finally {
lock.readLock().unlock();
}
}

/**
* Check if the node is in the ring.
*
* @param node the node to be checked.
* @return true if the node is in the ring, false otherwise.
*/
public boolean contains(T node) {
lock.readLock().lock();
try {
return nodes.containsKey(node);
} finally {
lock.readLock().unlock();
}
}

/**
* Check if the ring contains the slot.
*
* @param slot the slot to be checked.
* @return true if the slot is in the ring, false otherwise.
*/
public boolean ringContain(long slot) {
lock.readLock().lock();
try {
return ring.containsKey(slot);
} finally {
lock.readLock().unlock();
}
}

private boolean add(T node) {
boolean added = false;
if (node != null && !nodes.containsKey(node)) {
Set<Partition<T>> partitions =
IntStream.range(0, replicate)
.mapToObj(idx -> new Partition<T>(node, idx))
.collect(Collectors.toSet());
nodes.put(node, partitions);

// Allocate a distinct slot for each partition; on a collision, retry with an
// incremented seed until a free slot is found.
for (Partition<T> partition : partitions) {
long slot;
int seed = 0;
do {
slot = this.hasher.hash(partition.getPartitionKey(), seed++);
} while (ring.containsKey(slot));

partition.setSlot(slot);
ring.put(slot, partition);
}
added = true;
}
return added;
}

public static class Partition<T extends ConsistentHash.Node> {
private final T node;

private final int index;

private long slot;

public Partition(T node, int index) {
this.node = node;
this.index = index;
}

public String getPartitionKey() {
// Use the node's stable key so the slot layout is deterministic across JVMs.
return String.format("%s:%d", node.key(), index);
}

public T getNode() {
return node;
}

public void setSlot(long slot) {
this.slot = slot;
}

public long getSlot() {
return this.slot;
}
}

/** Base interface for the node in the ring. */
public interface Node {
String key();
}

/** Base interface for the hash function. */
public interface Hasher {
long hash(String key, int seed);
}

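// Walks the ring clockwise from the given slot: the tail map (slots >= slot)
// first, then wrapping around to the head map (slots < slot).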
private class AllocateIterator implements Iterator<Partition<T>> {
private final Iterator<Partition<T>> head;
private final Iterator<Partition<T>> tail;

AllocateIterator(long slot) {
this.head = ring.headMap(slot).values().iterator();
this.tail = ring.tailMap(slot).values().iterator();
}

@Override
public boolean hasNext() {
return head.hasNext() || tail.hasNext();
}

@Override
public Partition<T> next() {
return tail.hasNext() ? tail.next() : head.next();
}
}
}
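(Reviewer note: a minimal usage sketch of the ring from Scala. HostNode, the replicate count, and the file name below are hypothetical, written only to exercise the Node interface; they are not part of this PR.)

import org.apache.gluten.hash.ConsistentHash

// Hypothetical node type for illustration; any class implementing
// ConsistentHash.Node (with value-based equals/hashCode, as case classes have) works.
case class HostNode(id: String) extends ConsistentHash.Node {
  override def key(): String = id
}

object ConsistentHashExample extends App {
  // Each physical node is expanded into 16 virtual nodes (slots) on the ring.
  val ring = new ConsistentHash[HostNode](16)
  Seq("host-1", "host-2", "host-3").foreach(h => ring.addNode(HostNode(h)))

  // The same key resolves to the same nodes while membership is unchanged.
  println(ring.allocateNodes("part-00042-example.parquet", 2))

  // Removing a node only remaps keys whose nearest slots belonged to it;
  // most keys keep their previous assignment.
  ring.removeNode(HostNode("host-2"))
  println(ring.allocateNodes("part-00042-example.parquet", 2))
}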
@@ -17,8 +17,9 @@
package org.apache.gluten.softaffinity

import org.apache.gluten.GlutenConfig
import org.apache.gluten.hash.ConsistentHash
import org.apache.gluten.logging.LogLevelUtil
import org.apache.gluten.softaffinity.strategy.SoftAffinityStrategy
import org.apache.gluten.softaffinity.strategy.{ConsistentHashSoftAffinityStrategy, ExecutorNode}
import org.apache.gluten.sql.shims.SparkShimLoader

import org.apache.spark.SparkEnv
@@ -38,7 +39,8 @@ abstract class AffinityManager extends LogLevelUtil with Logging {

private val resourceRWLock = new ReentrantReadWriteLock(true)

private val softAffinityAllocation = new SoftAffinityStrategy
lazy val softAffinityReplicationNum: Int =
GlutenConfig.GLUTEN_SOFT_AFFINITY_REPLICATIONS_NUM_DEFAULT_VALUE

lazy val minOnTargetHosts: Int = GlutenConfig.GLUTEN_SOFT_AFFINITY_MIN_TARGET_HOSTS_DEFAULT_VALUE

@@ -88,6 +90,13 @@ abstract class AffinityManager extends LogLevelUtil with Logging {
}
})

private val hashRing: Option[ConsistentHash[ExecutorNode]] =
  Some(new ConsistentHash[ExecutorNode](softAffinityReplicationNum))

private val softAffinityStrategy = new ConsistentHashSoftAffinityStrategy(hashRing.get)

private val rand = new Random(System.currentTimeMillis)

def totalExecutors(): Int = totalRegisteredExecutors.intValue()
@@ -107,6 +116,7 @@ abstract class AffinityManager extends LogLevelUtil with Logging {
executorsSet.add(execHostId._1)
idForExecutors += execHostId
sortedIdForExecutors = idForExecutors.sortBy(_._2)
hashRing.foreach(_.addNode(ExecutorNode(execHostId._1, execHostId._2)))
totalRegisteredExecutors.addAndGet(1)
}
logOnLevel(
@@ -139,6 +149,7 @@ abstract class AffinityManager extends LogLevelUtil with Logging {
nodesExecutorsMap.remove(findedExecId._2)
}
sortedIdForExecutors = idForExecutors.sortBy(_._2)
hashRing.foreach(_.removeNode(ExecutorNode(execId, findedExecId._2)))
totalRegisteredExecutors.addAndGet(-1)
}
logOnLevel(
@@ -237,7 +248,7 @@ abstract class AffinityManager extends LogLevelUtil with Logging {
if (nodesExecutorsMap.size < 1) {
Array.empty
} else {
softAffinityAllocation.allocateExecs(file, sortedIdForExecutors)
softAffinityStrategy.allocateExecs(file, softAffinityReplicationNum)
}
} finally {
resourceRWLock.readLock().unlock()
@@ -302,6 +313,10 @@ object SoftAffinityManager extends AffinityManager {
GlutenConfig.GLUTEN_SOFT_AFFINITY_ENABLED_DEFAULT_VALUE
)

override lazy val softAffinityReplicationNum: Int = SparkEnv.get.conf.getInt(
GlutenConfig.GLUTEN_SOFT_AFFINITY_REPLICATIONS_NUM,
GlutenConfig.GLUTEN_SOFT_AFFINITY_REPLICATIONS_NUM_DEFAULT_VALUE)

override lazy val minOnTargetHosts: Int = SparkEnv.get.conf.getInt(
GlutenConfig.GLUTEN_SOFT_AFFINITY_MIN_TARGET_HOSTS,
GlutenConfig.GLUTEN_SOFT_AFFINITY_MIN_TARGET_HOSTS_DEFAULT_VALUE
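(Reviewer note: ConsistentHashSoftAffinityStrategy itself is not shown in this diff. Given that AffinityManager now calls allocateExecs(file, softAffinityReplicationNum) and registers ExecutorNode(executorId, host) pairs on the ring, a plausible sketch of the strategy follows; the ExecutorNode field names are assumptions.)

import org.apache.gluten.hash.ConsistentHash

import scala.collection.JavaConverters._

// Sketch only: the real ConsistentHashSoftAffinityStrategy is not part of this diff.
// Assumes ExecutorNode is a case class of (executorId, host), matching the
// ExecutorNode(execId, host) construction in AffinityManager above.
class ConsistentHashSoftAffinityStrategySketch(ring: ConsistentHash[ExecutorNode]) {
  // Resolve a file path to `count` executors through the hash ring, so the same
  // file keeps mapping to the same executors while cluster membership is stable.
  def allocateExecs(file: String, count: Int): Array[(String, String)] =
    ring.allocateNodes(file, count).asScala.toArray.map(n => (n.executorId, n.host))
}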