1 //----------------------------------------------------------------
2 // <copyright company="Microsoft Corporation">
3 // Copyright (c) Microsoft Corporation. All rights reserved.
5 //----------------------------------------------------------------
7 namespace System.Activities.Presentation.Model
using System.Activities.Presentation;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Runtime;
using Microsoft.Activities.Presentation.Xaml;
// The graph is completely defined by a collection of vertices and a collection of edges. The back pointers are not part of the graph
// definition; they are auxiliary data used to quickly trace a vertex back to the root vertex when it is reachable from the root.
// A vertex should have no back pointers if it is not reachable from the root.
// This abstract base class is responsible for managing back pointers, while the derived class is responsible for managing vertices and edges.
22 internal abstract class GraphManager<TVertex, TEdge, TBackPointer> where TVertex : class
/// <summary>
/// Gets the vertex from which reachability is measured. It may be null, in which case
/// no vertex is considered reachable.
/// </summary>
protected abstract TVertex Root
{
    get;
}
/// <summary>
/// Consistency check that raises <c>Fx.Assert</c> failures when the back pointers disagree with the edges:
/// a reachable vertex may only carry back pointers that lead to reachable vertices and that have an
/// associated edge, and an unreachable vertex must carry no back pointers at all. The reachability pass
/// itself (CalculateReachableVertices(true)) additionally checks that every out-edge of a reachable
/// vertex has its back pointer.
/// </summary>
internal void VerifyBackPointers()
{
    ICollection<TVertex> reachableVertices = this.CalculateReachableVertices(true);

    foreach (TVertex vertex in this.GetVertices())
    {
        if (reachableVertices.Contains(vertex))
        {
            foreach (TBackPointer backPointer in this.GetBackPointers(vertex))
            {
                if (!reachableVertices.Contains(this.GetDestinationVertexFromBackPointer(backPointer)))
                {
                    Fx.Assert(false, "a reachable vertex's back pointer should not point to a vertex that is not reachable");
                }

                if (!this.HasAssociatedEdge(backPointer))
                {
                    Fx.Assert(false, "a reachable vertex doesn't have an outgoing edge to one of the vertex that have a back pointer to it");
                }
            }
        }
        else if (this.GetBackPointers(vertex).Any())
        {
            // Only reachable vertices may keep back pointers; Any() stops at the first one instead of counting all.
            Fx.Assert(false, "an unreachable vertex should not have any back pointer");
        }
    }
}
/// <summary>
/// Computes, breadth-first, the set of vertices reachable from <see cref="Root"/> by following out-edges.
/// </summary>
/// <param name="verifyBackPointers">
/// When true, raises an <c>Fx.Assert</c> failure for every out-edge of a reachable vertex that has no back pointer.
/// </param>
/// <returns>
/// The reachable vertices (compared by reference); empty when <see cref="Root"/> is null.
/// </returns>
protected ICollection<TVertex> CalculateReachableVertices(bool verifyBackPointers = false)
{
    HashSet<TVertex> reachableVertices = new HashSet<TVertex>(ObjectReferenceEqualityComparer<TVertex>.Default);

    TVertex root = this.Root;
    if (root == null)
    {
        return reachableVertices;
    }

    Queue<TVertex> queue = new Queue<TVertex>();
    queue.Enqueue(root);
    reachableVertices.Add(root);

    while (queue.Count > 0)
    {
        TVertex vertex = queue.Dequeue();

        foreach (TEdge edge in this.GetOutEdges(vertex))
        {
            if (verifyBackPointers && !this.HasBackPointer(edge))
            {
                Fx.Assert(false, "a reachable vertex doesn't have a back pointer to one of its incoming edges");
            }

            TVertex to = this.GetDestinationVertexFromEdge(edge);

            // HashSet.Add returns false for an already-seen vertex, so each vertex is queued (and expanded) exactly once.
            // Enqueuing here is what lets the traversal go deeper than the root's direct successors.
            if (reachableVertices.Add(to))
            {
                queue.Enqueue(to);
            }
        }
    }

    return reachableVertices;
}
/// <summary>
/// Rebuilds back pointers after the root changes. Every vertex under the old root is treated as unreachable
/// and loses its back pointers; then back pointers are added for the sub-graph under the new root.
/// </summary>
/// <param name="oldRoot">The previous root; may be null (the root is allowed to be null).</param>
/// <param name="newRoot">The new root; may be null.</param>
protected void OnRootChanged(TVertex oldRoot, TVertex newRoot)
{
    // A null root has no sub-graph, so there is nothing to remove or add; without these guards a null
    // vertex would be handed to GetOutEdges during the traversal.
    if (oldRoot != null)
    {
        this.RemoveBackPointers(oldRoot, true);
    }

    if (newRoot != null)
    {
        this.AddBackPointers(newRoot);
    }
}
/// <summary>Enumerates the vertices of the graph (VerifyBackPointers checks each of them).</summary>
protected abstract IEnumerable<TVertex> GetVertices();

/// <summary>Enumerates the edges that start at <paramref name="vertex"/>.</summary>
protected abstract IEnumerable<TEdge> GetOutEdges(TVertex vertex);

/// <summary>Enumerates the back pointers currently stored on <paramref name="vertex"/>.</summary>
protected abstract IEnumerable<TBackPointer> GetBackPointers(TVertex vertex);

/// <summary>Returns the vertex that <paramref name="edge"/> points to.</summary>
protected abstract TVertex GetDestinationVertexFromEdge(TEdge edge);

/// <summary>Returns the vertex that <paramref name="edge"/> starts from.</summary>
protected abstract TVertex GetSourceVertexFromEdge(TEdge edge);

/// <summary>
/// Returns the vertex that <paramref name="backPointer"/> leads to; following these repeatedly is how a
/// vertex is traced back toward the root.
/// </summary>
protected abstract TVertex GetDestinationVertexFromBackPointer(TBackPointer backPointer);
/// <summary>
/// Call this method when an edge is removed. If the edge's source is reachable from the root, the edge's
/// back pointer is removed, and so are the back pointers of any sub-graph under the destination that is no
/// longer reachable.
/// NOTE: RemoveBackPointers recomputes reachability from GetOutEdges, so the edge is expected to be gone
/// from the graph already when this is called.
/// </summary>
/// <param name="edgeRemoved">The removed edge; must not be null.</param>
protected void OnEdgeRemoved(TEdge edgeRemoved)
{
    Fx.Assert(edgeRemoved != null, "edgeRemoved should not be null");

    TVertex sourceVertex = this.GetSourceVertexFromEdge(edgeRemoved);
    if (!this.CanReachRootViaBackPointer(sourceVertex))
    {
        // Out-edges of an unreachable vertex carry no back pointers, so there is nothing to undo.
        return;
    }

    this.RemoveAssociatedBackPointer(edgeRemoved);
    TVertex destinationVertex = this.GetDestinationVertexFromEdge(edgeRemoved);
    this.RemoveBackPointers(destinationVertex);
}
/// <summary>
/// Call this method when an edge is added. If the edge's source is reachable from the root, the edge's
/// back pointer is recorded; if the destination was not reachable before, back pointers are also added
/// for the sub-graph that has just become reachable.
/// </summary>
/// <param name="edgeAdded">The added edge; must not be null.</param>
protected void OnEdgeAdded(TEdge edgeAdded)
{
    Fx.Assert(edgeAdded != null, "edgeAdded should not be null");

    TVertex sourceVertex = this.GetSourceVertexFromEdge(edgeAdded);
    if (!this.CanReachRootViaBackPointer(sourceVertex))
    {
        // Edges leaving an unreachable vertex must not carry back pointers.
        return;
    }

    TVertex destinationVertex = this.GetDestinationVertexFromEdge(edgeAdded);

    // Evaluate before adding the new back pointer; afterwards the destination would always appear reachable.
    bool wasReachable = this.CanReachRootViaBackPointer(destinationVertex);
    this.AddAssociatedBackPointer(edgeAdded);
    if (wasReachable)
    {
        // The destination's sub-graph already has its back pointers; traversing it again would duplicate them.
        return;
    }

    this.AddBackPointers(destinationVertex);
}
/// <summary>Removes the back pointer that corresponds to <paramref name="edge"/>.</summary>
protected abstract void RemoveAssociatedBackPointer(TEdge edge);

/// <summary>Adds the back pointer that corresponds to <paramref name="edge"/>.</summary>
protected abstract void AddAssociatedBackPointer(TEdge edge);

/// <summary>Returns whether the back pointer corresponding to <paramref name="edge"/> exists.</summary>
protected abstract bool HasBackPointer(TEdge edge);

/// <summary>Returns whether the edge corresponding to <paramref name="backPointer"/> exists.</summary>
protected abstract bool HasAssociatedEdge(TBackPointer backPointer);

/// <summary>Notification raised with the vertices that a back-pointer traversal has just made reachable.</summary>
protected abstract void OnVerticesBecameReachable(IEnumerable<TVertex> reachableVertices);

/// <summary>Notification raised with the vertices that a back-pointer traversal has just made unreachable.</summary>
protected abstract void OnVerticesBecameUnreachable(IEnumerable<TVertex> unreachableVertices);
/// <summary>
/// Determines, breadth-first over back pointers, whether <see cref="Root"/> can be reached from
/// <paramref name="vertex"/>, i.e. whether the back-pointer bookkeeping currently treats it as reachable.
/// </summary>
/// <param name="vertex">The vertex to test; must not be null.</param>
/// <returns>
/// True if <paramref name="vertex"/> is the root or a chain of back pointers leads from it to the root;
/// otherwise false.
/// </returns>
private bool CanReachRootViaBackPointer(TVertex vertex)
{
    Fx.Assert(vertex != null, "vertex should not be null");

    TVertex root = this.Root;
    if (object.ReferenceEquals(vertex, root))
    {
        return true;
    }

    HashSet<TVertex> visited = new HashSet<TVertex>(ObjectReferenceEqualityComparer<TVertex>.Default);
    Queue<TVertex> queue = new Queue<TVertex>();

    // Mark the start vertex as visited so a back-pointer cycle leading back to it is not walked again.
    visited.Add(vertex);
    queue.Enqueue(vertex);

    while (queue.Count > 0)
    {
        TVertex current = queue.Dequeue();
        foreach (TBackPointer backPointer in this.GetBackPointers(current))
        {
            TVertex destinationVertex = this.GetDestinationVertexFromBackPointer(backPointer);
            if (object.ReferenceEquals(destinationVertex, root))
            {
                return true;
            }

            if (visited.Add(destinationVertex))
            {
                queue.Enqueue(destinationVertex);
            }
        }
    }

    return false;
}
/// <summary>
/// Traverses, breadth-first, the sub-graph starting at <paramref name="vertex"/> and adds the back pointer
/// of every out-edge it visits. A destination that already had back pointers was reachable before, so its
/// descendants are not traversed again. The start vertex and every vertex that received its first back
/// pointer are reported through OnVerticesBecameReachable.
/// </summary>
/// <param name="vertex">The vertex whose sub-graph has just become reachable.</param>
private void AddBackPointers(TVertex vertex)
{
    HashSet<TVertex> verticesBecameReachable = new HashSet<TVertex>(ObjectReferenceEqualityComparer<TVertex>.Default);
    Queue<TVertex> queue = new Queue<TVertex>();

    verticesBecameReachable.Add(vertex);
    queue.Enqueue(vertex);

    while (queue.Count > 0)
    {
        TVertex currentVertex = queue.Dequeue();

        foreach (TEdge edge in this.GetOutEdges(currentVertex))
        {
            TVertex destinationVertex = this.GetDestinationVertexFromEdge(edge);

            // Only reachable vertices carry back pointers, so check before adding the new one.
            // Any() stops at the first back pointer instead of counting them all.
            bool wasReachable = this.GetBackPointers(destinationVertex).Any();
            this.AddAssociatedBackPointer(edge);

            // HashSet.Add returns false when the vertex was already queued during this traversal.
            if (!wasReachable && verticesBecameReachable.Add(destinationVertex))
            {
                queue.Enqueue(destinationVertex);
            }
        }
    }

    this.OnVerticesBecameReachable(verticesBecameReachable);
}
/// <summary>
/// Traverses, breadth-first, the sub-graph starting at <paramref name="vertex"/>. For every vertex that is
/// no longer reachable from the root, it removes the back pointers of that vertex's out-edges and continues
/// into the descendants; traversal stops at vertices that are still reachable. Does nothing if
/// <paramref name="vertex"/> itself is still reachable. The vertices found unreachable are reported through
/// OnVerticesBecameUnreachable.
/// </summary>
/// <param name="vertex">The vertex whose sub-graph may have become unreachable.</param>
/// <param name="isAllVerticesUnreachable">
/// When true, reachability is not computed and every vertex in the sub-graph is treated as unreachable
/// (OnRootChanged passes true for the old root).
/// </param>
private void RemoveBackPointers(TVertex vertex, bool isAllVerticesUnreachable = false)
{
    ICollection<TVertex> reachableVertices = isAllVerticesUnreachable
        ? new HashSet<TVertex>(ObjectReferenceEqualityComparer<TVertex>.Default)
        : this.CalculateReachableVertices();

    if (reachableVertices.Contains(vertex))
    {
        // Still reachable through another path: every back pointer in its sub-graph remains valid.
        return;
    }

    Queue<TVertex> queue = new Queue<TVertex>();
    HashSet<TVertex> unreachableVertices = new HashSet<TVertex>(ObjectReferenceEqualityComparer<TVertex>.Default);

    unreachableVertices.Add(vertex);
    queue.Enqueue(vertex);

    while (queue.Count > 0)
    {
        TVertex unreachableVertex = queue.Dequeue();
        foreach (TEdge edge in this.GetOutEdges(unreachableVertex))
        {
            // The source of this edge is unreachable, so the edge must not keep a back pointer.
            this.RemoveAssociatedBackPointer(edge);

            TVertex to = this.GetDestinationVertexFromEdge(edge);

            // Continue into descendants that are not reachable; HashSet.Add prevents queuing a vertex twice.
            if ((isAllVerticesUnreachable || !reachableVertices.Contains(to)) && unreachableVertices.Add(to))
            {
                queue.Enqueue(to);
            }
        }
    }

    this.OnVerticesBecameUnreachable(unreachableVertices);
}