Skip to content

Commit

Permalink
consistent hash
Browse files Browse the repository at this point in the history
  • Loading branch information
yikf committed Dec 16, 2024
1 parent 08bdb42 commit 0859e72
Show file tree
Hide file tree
Showing 6 changed files with 492 additions and 64 deletions.
277 changes: 277 additions & 0 deletions gluten-core/src/main/java/org/apache/gluten/hash/ConsistentHash.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.hash;

import com.google.common.base.Preconditions;
import org.apache.commons.codec.digest.MurmurHash3;

import javax.annotation.concurrent.ThreadSafe;

import java.util.*;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
* A consistent hash ring implementation. the class is thread-safe.
*
* @param <T> the type of node to be used in the ring.
*/
@ThreadSafe
public class ConsistentHash<T extends ConsistentHash.Node> {

// read-write lock for the ring.
private final ReadWriteLock lock = new ReentrantReadWriteLock(true);

// keep the mapping of node to its partitions, each partition is a slot in the ring.
private final Map<T, Set<Partition<T>>> nodes = new HashMap<>();

// keep the mapping of slot to partition, the partition actually is a virtual node.
private final SortedMap<Long, Partition<T>> ring = new TreeMap<>();

// the number of virtual nodes for each physical node.
private final int replicate;

// the hasher function to hash the key.
private final Hasher hasher;

public ConsistentHash(int replicate) {
Preconditions.checkArgument(replicate > 0, "HashRing require positive replicate number.");
this.replicate = replicate;
this.hasher =
(key, seed) -> {
byte[] data = key.getBytes();
return MurmurHash3.hash32x86(data, 0, data.length, seed);
};
}

public ConsistentHash(int replicate, Hasher hasher) {
Preconditions.checkArgument(replicate > 0, "HashRing require positive replicate number.");
Preconditions.checkArgument(hasher != null, "HashRing require non-null hasher.");
this.replicate = replicate;
this.hasher = hasher;
}

/**
* Add a node to the ring, the node will be replicated to `replicate` virtual nodes.
*
* @param node the node to be added to the ring.
* @return true if the node is added successfully, false otherwise.
*/
public boolean addNode(T node) {
lock.writeLock().lock();
try {
return add(node);
} finally {
lock.writeLock().unlock();
}
}

/**
* Remove the node from the ring, all the virtual nodes will be removed.
*
* @param node the node to be removed.
* @return true if the node is removed successfully, false otherwise.
*/
public boolean removeNode(T node) {
lock.writeLock().lock();
boolean removed = false;
try {
if (nodes.containsKey(node)) {
Set<Partition<T>> partitions = nodes.remove(node);
partitions.forEach(p -> ring.remove(p.getSlot()));
removed = true;
}
} finally {
lock.writeLock().unlock();
}
return removed;
}

/**
* Allocate the node by the key, the number of nodes to be located is specified by the count.
*
* @param key the key to locate the node.
* @param count the number of nodes to be located.
* @return a set of nodes located by the key.
*/
public Set<T> allocateNodes(String key, int count) {
lock.readLock().lock();
try {
Set<T> res = new HashSet<>();
if (key != null && count > 0) {
if (count < nodes.size()) {
long slot = hasher.hash(key, 0);
Iterator<Partition<T>> it = new AllocateIterator(slot);
while (it.hasNext() && res.size() < count) {
Partition<T> part = it.next();
res.add(part.getNode());
}
} else {
res.addAll(nodes.keySet());
}
}
return res;
} finally {
lock.readLock().unlock();
}
}

/**
* Get all the nodes in the ring.
*
* @return a set of nodes in the ring.
*/
public Set<T> getNodes() {
lock.readLock().lock();
try {
return new HashSet<>(nodes.keySet());
} finally {
lock.readLock().unlock();
}
}

/**
* Get the partitions of the node. return null if the node is not exists.
*
* @param node the node to get the partitions.
* @return a set of partitions of the node.
*/
public Set<Partition<T>> getPartition(T node) {
lock.readLock().lock();
try {
return nodes.get(node);
} finally {
lock.readLock().unlock();
}
}

/**
* Check if the node is in the ring.
*
* @param node the node to be checked.
* @return true if the node is in the ring, false otherwise.
*/
public boolean contains(T node) {
lock.readLock().lock();
try {
return nodes.containsKey(node);
} finally {
lock.readLock().unlock();
}
}

/**
* Check if the ring contains the slot.
*
* @param slot the slot to be checked.
* @return true if the slot is in the ring, false otherwise.
*/
public boolean ringContain(long slot) {
lock.readLock().lock();
try {
return ring.containsKey(slot);
} finally {
lock.readLock().unlock();
}
}

private boolean add(T node) {
boolean added = false;
if (node != null && !nodes.containsKey(node)) {
Set<Partition<T>> partitions =
IntStream.range(0, replicate)
.mapToObj(idx -> new Partition<T>(node, idx))
.collect(Collectors.toSet());
nodes.put(node, partitions);

// allocate slot.
for (Partition<T> partition : partitions) {
long slot;
int seed = 0;
do {
slot = this.hasher.hash(partition.getPartitionKey(), seed++);
} while (ring.containsKey(slot));

partition.setSlot(slot);
ring.put(slot, partition);
}
added = true;
}
return added;
}

public static class Partition<T extends ConsistentHash.Node> {
private final T node;

private final int index;

private long slot;

public Partition(T node, int index) {
this.node = node;
this.index = index;
}

public String getPartitionKey() {
return String.format("%s:%d", node, index);
}

public T getNode() {
return node;
}

public void setSlot(long slot) {
this.slot = slot;
}

public long getSlot() {
return this.slot;
}
}

/** Base interface for the node in the ring. */
public interface Node {
String key();
}

/** Base interface for the hash function. */
public interface Hasher {
long hash(String key, int seed);
}

private class AllocateIterator implements Iterator<Partition<T>> {
private final Iterator<Partition<T>> head;
private final Iterator<Partition<T>> tail;

AllocateIterator(long slot) {
this.head = ring.headMap(slot).values().iterator();
this.tail = ring.tailMap(slot).values().iterator();
}

@Override
public boolean hasNext() {
return head.hasNext() || tail.hasNext();
}

@Override
public Partition<T> next() {
return tail.hasNext() ? tail.next() : head.next();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
package org.apache.gluten.softaffinity

import org.apache.gluten.GlutenConfig
import org.apache.gluten.hash.ConsistentHash
import org.apache.gluten.logging.LogLevelUtil
import org.apache.gluten.softaffinity.strategy.SoftAffinityStrategy
import org.apache.gluten.softaffinity.strategy.{ConsistentHashSoftAffinityStrategy, ExecutorNode}
import org.apache.gluten.sql.shims.SparkShimLoader

import org.apache.spark.SparkEnv
Expand All @@ -38,7 +39,8 @@ abstract class AffinityManager extends LogLevelUtil with Logging {

private val resourceRWLock = new ReentrantReadWriteLock(true)

private val softAffinityAllocation = new SoftAffinityStrategy
lazy val softAffinityReplicationNum: Int =
GlutenConfig.GLUTEN_SOFT_AFFINITY_REPLICATIONS_NUM_DEFAULT_VALUE

lazy val minOnTargetHosts: Int = GlutenConfig.GLUTEN_SOFT_AFFINITY_MIN_TARGET_HOSTS_DEFAULT_VALUE

Expand Down Expand Up @@ -88,6 +90,13 @@ abstract class AffinityManager extends LogLevelUtil with Logging {
}
})

private var hashRing: Option[ConsistentHash[ExecutorNode]] = _

private val softAffinityStrategy = {
hashRing = Some(new ConsistentHash[ExecutorNode](softAffinityReplicationNum))
new ConsistentHashSoftAffinityStrategy(hashRing.get)
}

private val rand = new Random(System.currentTimeMillis)

def totalExecutors(): Int = totalRegisteredExecutors.intValue()
Expand All @@ -107,6 +116,7 @@ abstract class AffinityManager extends LogLevelUtil with Logging {
executorsSet.add(execHostId._1)
idForExecutors += execHostId
sortedIdForExecutors = idForExecutors.sortBy(_._2)
hashRing.foreach(_.addNode(ExecutorNode(execHostId._1, execHostId._2)))
totalRegisteredExecutors.addAndGet(1)
}
logOnLevel(
Expand Down Expand Up @@ -139,6 +149,7 @@ abstract class AffinityManager extends LogLevelUtil with Logging {
nodesExecutorsMap.remove(findedExecId._2)
}
sortedIdForExecutors = idForExecutors.sortBy(_._2)
hashRing.foreach(_.removeNode(ExecutorNode(execId, findedExecId._2)))
totalRegisteredExecutors.addAndGet(-1)
}
logOnLevel(
Expand Down Expand Up @@ -237,7 +248,7 @@ abstract class AffinityManager extends LogLevelUtil with Logging {
if (nodesExecutorsMap.size < 1) {
Array.empty
} else {
softAffinityAllocation.allocateExecs(file, sortedIdForExecutors)
softAffinityStrategy.allocateExecs(file, softAffinityReplicationNum)
}
} finally {
resourceRWLock.readLock().unlock()
Expand Down Expand Up @@ -302,6 +313,10 @@ object SoftAffinityManager extends AffinityManager {
GlutenConfig.GLUTEN_SOFT_AFFINITY_ENABLED_DEFAULT_VALUE
)

override lazy val softAffinityReplicationNum: Int = SparkEnv.get.conf.getInt(
GlutenConfig.GLUTEN_SOFT_AFFINITY_REPLICATIONS_NUM,
GlutenConfig.GLUTEN_SOFT_AFFINITY_REPLICATIONS_NUM_DEFAULT_VALUE)

override lazy val minOnTargetHosts: Int = SparkEnv.get.conf.getInt(
GlutenConfig.GLUTEN_SOFT_AFFINITY_MIN_TARGET_HOSTS,
GlutenConfig.GLUTEN_SOFT_AFFINITY_MIN_TARGET_HOSTS_DEFAULT_VALUE
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.softaffinity.strategy

import org.apache.gluten.hash.ConsistentHash

import java.util.Objects

import scala.collection.JavaConverters._

class ConsistentHashSoftAffinityStrategy(candidates: ConsistentHash[ExecutorNode])
extends SoftAffinityAllocationTrait {

/** Allocate the executors of count number from the candidates. */
override def allocateExecs(file: String, count: Int): Array[(String, String)] = {
candidates.allocateNodes(file, count).asScala.map(node => (node.exeId, node.host)).toArray
}
}

case class ExecutorNode(exeId: String, host: String) extends ConsistentHash.Node {
override def key(): String = s"$exeId-$host"

override def toString: String = s"$exeId-$host"

override def equals(o: Any): Boolean = {
o match {
case ExecutorNode(exeId, host) => this.exeId == exeId && this.host == host
case _ => false
}
}

override def hashCode(): Int = Objects.hashCode(exeId, host)
}
Loading

0 comments on commit 0859e72

Please sign in to comment.