Skip to content

Commit

Permalink
[Update lookup to latest version from fingermath
Browse files Browse the repository at this point in the history
  • Loading branch information
ApmeM committed Jul 19, 2024
1 parent 21aa4c7 commit b4dffcf
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 107 deletions.
61 changes: 0 additions & 61 deletions BrainAI.Tests/LookupTest.cs

This file was deleted.

123 changes: 77 additions & 46 deletions BrainAI/Pathfinding/Utils/Lookup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
using System;
using System.Collections;
using System.Linq;
using System.Diagnostics;

namespace BrainAI.Pathfinding
{
/// <summary>
/// Zero memory allocation lookup.
/// Copied from https://github.com/ApmeM/FingerMath/blob/master/FingerMath/Collections/Lookup.cs
/// </summary>
public class Lookup<TKey, TValue> : ILookup<TKey, TValue>
{
private Dictionary<TKey, LinkedListNode<(TKey, TValue)>> startReference = new Dictionary<TKey, LinkedListNode<(TKey, TValue)>>();
Expand Down Expand Up @@ -61,6 +64,7 @@ public void Add(TKey key, TValue value)
set.Add(tuple);
}
}

version++;
}

Expand All @@ -82,12 +86,16 @@ public void Remove(TKey key, TValue value)
var start = startReference[key];
while (
EqualityComparer<TKey>.Default.Equals(start.Value.Item1, key) &&
!EqualityComparer<TValue>.Default.Equals(start.Value.Item2, value))
!EqualityComparer<TValue>.Default.Equals(start.Value.Item2, value) &&
!EqualityComparer<LinkedListNode<(TKey, TValue)>>.Default.Equals(start, endReference[key]))
{
start = start.next;
}

if (!EqualityComparer<TKey>.Default.Equals(start.Value.Item1, key))
if (
!EqualityComparer<TKey>.Default.Equals(start.Value.Item1, key) ||
!EqualityComparer<TValue>.Default.Equals(start.Value.Item2, value)
)
{
return;
}
Expand Down Expand Up @@ -139,20 +147,10 @@ public bool Contains(TKey key)

public Enumerable this[TKey key] => this.Find(key);

public GroupingEnumerator GetEnumerator()
{
return new GroupingEnumerator(this);
}
public GroupingEnumerator GetEnumerator() => new GroupingEnumerator(this);

private Enumerable Find(TKey key)
{
if (!startReference.ContainsKey(key))
{
return new Enumerable(null, null, 0);
}
private Enumerable Find(TKey key) => new Enumerable(this, key);

return new Enumerable(startReference[key], endReference[key], counts[key]);
}
[Obsolete("Use Find instead. This method requires boxing and allocates memory.")]
IEnumerable<TValue> ILookup<TKey, TValue>.this[TKey key] => this.Find(key);

Expand All @@ -162,19 +160,19 @@ private Enumerable Find(TKey key)
[Obsolete("Use GetEnumerator instead. This method requires boxing and allocates memory.")]
IEnumerator<IGrouping<TKey, TValue>> IEnumerable<IGrouping<TKey, TValue>>.GetEnumerator() => this.GetEnumerator();

public struct Enumerable : IEnumerable<TValue>, IEnumerable, IGrouping<TKey, TValue>
public struct Enumerable : IEnumerable<TValue>, IEnumerable, IGrouping<TKey, TValue>, ICollection<TValue>
{
private readonly LinkedListNode<(TKey, TValue)> start;
private readonly LinkedListNode<(TKey, TValue)> end;
public readonly int Count;
private readonly Lookup<TKey, TValue> lookup;
public TKey Key { get; }

public int Count => this.lookup.counts.ContainsKey(this.Key) ? this.lookup.counts[this.Key] : 0;

public TKey Key => start.Value.Item1;
public bool IsReadOnly => false;

public Enumerable(LinkedListNode<(TKey, TValue)> start, LinkedListNode<(TKey, TValue)> end, int count)
public Enumerable(Lookup<TKey, TValue> lookup, TKey key)
{
this.start = start;
this.end = end;
this.Count = count;
this.lookup = lookup;
this.Key = key;
}

[Obsolete("Use GetEnumerator instead. This method requires boxing and allocates memory.")]
Expand All @@ -183,39 +181,65 @@ public Enumerable(LinkedListNode<(TKey, TValue)> start, LinkedListNode<(TKey, TV
[Obsolete("Use GetEnumerator instead. This method requires boxing and allocates memory.")]
IEnumerator<TValue> IEnumerable<TValue>.GetEnumerator() => GetEnumerator();

public Enumerator GetEnumerator()
public Enumerator GetEnumerator() => new Enumerator(this.lookup, this.Key);

public void Add(TValue item)
{
return new Enumerator(this.start, this.end);
this.lookup.Add(this.Key, item);
}

public void Clear()
{
this.lookup.Remove(this.Key);
}

public bool Contains(TValue item)
{
foreach (var el in this)
{
if (EqualityComparer<TValue>.Default.Equals(el, item))
{
return true;
}
}

return false;
}

public void CopyTo(TValue[] array, int arrayIndex)
{
foreach (var el in this)
{
array[arrayIndex] = el;
}
}

public bool Remove(TValue item)
{
this.lookup.Remove(this.Key, item);
return true;
}
}

public struct Enumerator : IEnumerator<TValue>, IEnumerator
{
private LinkedListNode<(TKey, TValue)> node;
private readonly int version;
private readonly LinkedListNode<(TKey, TValue)> start;
private LinkedListNode<(TKey, TValue)> end;
private readonly Lookup<TKey, TValue> lookup;
private readonly TKey key;
private LinkedListNode<(TKey, TValue)> node;
private TValue current;

public TValue Current => current;

object IEnumerator.Current => current;

internal Enumerator(LinkedListNode<(TKey, TValue)> start, LinkedListNode<(TKey, TValue)> end)
internal Enumerator(Lookup<TKey, TValue> lookup, TKey key)
{
this.start = start;
this.end = end;
this.node = start;
this.current = default;

if (this.node == null)
{
version = -1;
}
else
{
version = start.List.version;
}
this.version = lookup.version;
this.lookup = lookup;
this.key = key;
this.node = this.lookup.startReference.ContainsKey(key) ? this.lookup.startReference[key] : null;
}

public bool MoveNext()
Expand All @@ -225,14 +249,14 @@ public bool MoveNext()
return false;
}

if (version != node.List.version)
if (version != lookup.version)
{
throw new InvalidOperationException("EnumFailedVersion");
throw new InvalidOperationException("The underlying collection was changed.");
}

current = node.Value.Item2;

if (node == end)
if (node == this.lookup.endReference[key])
{
node = null;
}
Expand All @@ -256,12 +280,14 @@ public void Dispose()
public struct GroupingEnumerator : IEnumerator<IGrouping<TKey, TValue>>, IEnumerator
{
private Lookup<TKey, TValue> lookup;
private int version;
private Dictionary<TKey, LinkedListNode<(TKey, TValue)>>.Enumerator bucketEnumerator;

internal GroupingEnumerator(Lookup<TKey, TValue> lookup)
{
this.Current = default;
this.lookup = lookup;
this.version = this.lookup.version;
this.bucketEnumerator = lookup.startReference.GetEnumerator();
}

Expand All @@ -275,11 +301,16 @@ internal GroupingEnumerator(Lookup<TKey, TValue> lookup)

public bool MoveNext()
{
if (version != this.lookup.version)
{
throw new InvalidOperationException("The underlying collection was changed.");
}

var result = this.bucketEnumerator.MoveNext();
if (result)
{
var key = this.bucketEnumerator.Current.Key;
Current = new Enumerable(this.lookup.startReference[key], this.lookup.endReference[key], this.lookup.counts[key]);
Current = new Enumerable(this.lookup, key);
}
return result;
}
Expand Down

0 comments on commit b4dffcf

Please sign in to comment.