-
Notifications
You must be signed in to change notification settings - Fork 1
/
prims.py
40 lines (29 loc) · 1.06 KB
/
prims.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import heapq
from typing import List
from prims_data_structures import Edge, Graph
# Type alias for the priority queue used by Prim's algorithm: a plain list of
# edges kept in heapq order (the smallest Edge, by Edge's own comparison, sits
# at index 0).
PrimsHeap = List[Edge]
def create_heap(graph: Graph, start_vertex: int) -> PrimsHeap:
    """Build the initial priority queue for Prim's algorithm.

    Collects every edge of *graph* whose ``start`` equals *start_vertex* and
    arranges them into a binary min-heap with ``heapq.heapify``. The smallest
    edge sits at index 0, according to ``Edge``'s own comparison (presumably
    by weight — defined in ``prims_data_structures``).

    Args:
        graph: Graph whose edges are iterated.
        start_vertex: Vertex the spanning tree grows from.

    Returns:
        A new heap-ordered list holding the outgoing edges of *start_vertex*.
    """
    outgoing = []
    for candidate in graph:
        if candidate.start == start_vertex:
            outgoing.append(candidate)
    heapq.heapify(outgoing)
    return outgoing
def prims(graph: Graph) -> Graph:
    """Return a minimum spanning tree of *graph* using Prim's algorithm.

    The input is treated as undirected. Every edge is paired with its
    reversed copy (``graph.reverse()``) before the search starts. The tree is
    grown from the start vertex of the first edge. Each step takes the
    cheapest queued edge, by ``Edge``'s ordering (presumably its weight),
    that reaches a vertex not yet in the tree.

    Stale queue entries are discarded lazily when popped. A stale entry is an
    edge whose end vertex joined the tree after the edge was pushed. Neither
    the input graph nor its ``Edge`` objects are mutated.

    Args:
        graph: The graph to span.

    Returns:
        A new ``Graph`` holding the tree edges in the order they were chosen.
        An empty ``Graph`` is returned when *graph* has no edges.

    Raises:
        ValueError: If *graph* is not connected, so no spanning tree exists.
    """
    if not graph.edges:
        # Nothing to span; looking up ``graph[0]`` below would fail.
        return Graph()

    # Treat the graph as undirected by adding every edge in both directions.
    graph = Graph(*graph.edges + graph.reverse().edges)
    all_vertices = graph.vertices
    start_vertex = graph[0].start  # arbitrary start vertex

    # Index edges by their start vertex once. Rescanning the whole graph after
    # every pick made the main loop O(V * E).
    outgoing = {}
    for edge in graph:
        outgoing.setdefault(edge.start, []).append(edge)

    visited_vertices = {start_vertex}
    mst = Graph()
    heap = create_heap(graph, start_vertex)
    while visited_vertices != all_vertices:
        if not heap:
            # Every reachable vertex is already in the tree, but some
            # vertices remain: they are in another component.
            raise ValueError("graph is not connected")
        winner = heapq.heappop(heap)
        if winner.end in visited_vertices:
            # Stale entry: its end was reached by a cheaper edge first.
            # Skipping it replaces overwriting ``weight`` with infinity, which
            # mutated Edge objects shared with the caller's graph.
            continue
        visited_vertices.add(winner.end)
        mst += winner
        for other_edge in outgoing.get(winner.end, ()):
            if other_edge.end not in visited_vertices:
                heapq.heappush(heap, other_edge)
    return mst