// 4d4c2def1459460690dc00f3c10613ae77ee48c6
// [mono.git] / mcs / class / referencesource / System.Activities.Presentation / System.Activities.Presentation / System / Activities / Presentation / Model / GraphManager.cs
//----------------------------------------------------------------
// <copyright company="Microsoft Corporation">
//     Copyright (c) Microsoft Corporation.  All rights reserved.
// </copyright>
//----------------------------------------------------------------

namespace System.Activities.Presentation.Model
{
    using System;
    using System.Activities.Presentation;
    using System.Collections.Generic;
    using System.Diagnostics;
    using System.Linq;
    using System.Runtime;
    using System.Text;
    using Microsoft.Activities.Presentation.Xaml;

    // The graph is completely defined by a collection of vertices and a collection of edges. The back pointers are not part of the graph
    // definition, but serve as an auxiliary to quickly trace back to the root vertex if it is reachable from the root. A vertex should have
    // no back pointers if it's not reachable from the root.
    // This abstract base class is responsible for managing back pointers while the derived class is responsible for managing vertices and edges.
    internal abstract class GraphManager<TVertex, TEdge, TBackPointer> where TVertex : class
    {
        // The vertex from which reachability is measured. May be null, in which case no vertex is considered reachable
        // by CalculateReachableVertices.
        protected abstract TVertex Root { get; }

        // Consistency check of the back pointers against the reachability computed from the edges. Asserts when:
        //   - an out edge of a reachable vertex has no back pointer (checked inside CalculateReachableVertices),
        //   - a back pointer of a reachable vertex leads to an unreachable vertex or has no associated edge,
        //   - a vertex that is not reachable owns any back pointer.
        internal void VerifyBackPointers()
        {
            ICollection<TVertex> reachableVertices = this.CalculateReachableVertices(true);

            foreach (TVertex vertex in this.GetVertices())
            {
                if (reachableVertices.Contains(vertex))
                {
                    foreach (TBackPointer backPointer in this.GetBackPointers(vertex))
                    {
                        if (!reachableVertices.Contains(this.GetDestinationVertexFromBackPointer(backPointer)))
                        {
                            Fx.Assert(false, "a reachable vertex's back pointer should not point to a vertex that is not reachable");
                        }

                        if (!this.HasAssociatedEdge(backPointer))
                        {
                            Fx.Assert(false, "a reachable vertex doesn't have an outgoing edge to one of the vertex that have a back pointer to it");
                        }
                    }
                }
                else if (this.GetBackPointers(vertex).Any())
                {
                    // Any() stops at the first back pointer instead of counting all of them.
                    Fx.Assert(false, "an unreachable vertex should not have any back pointer");
                }
            }
        }

        // Breadth-first traversal over out edges starting at Root. Returns the set of vertices visited (compared by reference);
        // the set is empty when Root is null.
        // When verifyBackPointers is true, additionally asserts that every traversed edge has a back pointer.
        protected ICollection<TVertex> CalculateReachableVertices(bool verifyBackPointers = false)
        {
            HashSet<TVertex> reachableVertices = new HashSet<TVertex>(ObjectReferenceEqualityComparer<TVertex>.Default);

            // Read the abstract property once; the traversal below does not need to re-query it.
            TVertex root = this.Root;
            if (root == null)
            {
                return reachableVertices;
            }

            Queue<TVertex> queue = new Queue<TVertex>();
            reachableVertices.Add(root);
            queue.Enqueue(root);

            while (queue.Count > 0)
            {
                TVertex vertex = queue.Dequeue();

                foreach (TEdge edge in this.GetOutEdges(vertex))
                {
                    if (verifyBackPointers && !this.HasBackPointer(edge))
                    {
                        Fx.Assert(false, "a reachable vertex doesn't have a back pointer to one of its incoming edges");
                    }

                    // HashSet.Add returns false for an already-visited vertex, so one lookup both tests and records the visit.
                    TVertex to = this.GetDestinationVertexFromEdge(edge);
                    if (reachableVertices.Add(to))
                    {
                        queue.Enqueue(to);
                    }
                }
            }

            return reachableVertices;
        }

        // Call this method when the root changes: back pointers are removed from the sub-graph of the old root
        // (without computing reachability, every vertex is treated as unreachable) and then added for the sub-graph of the new root.
        // Either argument may be null, in which case the corresponding step is skipped.
        protected void OnRootChanged(TVertex oldRoot, TVertex newRoot)
        {
            if (oldRoot != null)
            {
                this.RemoveBackPointers(oldRoot, true);
            }

            if (newRoot != null)
            {
                this.AddBackPointers(newRoot);
            }
        }

        // All vertices of the graph; enumerated by VerifyBackPointers.
        protected abstract IEnumerable<TVertex> GetVertices();

        // The edges leaving the given vertex; used by every forward traversal in this class.
        protected abstract IEnumerable<TEdge> GetOutEdges(TVertex vertex);

        // The back pointers currently held by the given vertex.
        protected abstract IEnumerable<TBackPointer> GetBackPointers(TVertex vertex);

        // The vertex an edge leads to.
        protected abstract TVertex GetDestinationVertexFromEdge(TEdge edge);

        // The vertex an edge starts from.
        protected abstract TVertex GetSourceVertexFromEdge(TEdge edge);

        // The vertex a back pointer leads to; CanReachRootViaBackPointer follows these towards the root.
        protected abstract TVertex GetDestinationVertexFromBackPointer(TBackPointer backPointer);

        // call this method when an edge is removed
        // Nothing is done when the edge's source vertex cannot reach the root via back pointers. Otherwise the edge's
        // back pointer is removed and the destination's sub-graph is cleaned up if it is no longer reachable.
        protected void OnEdgeRemoved(TEdge edgeRemoved)
        {
            Fx.Assert(edgeRemoved != null, "edgeRemoved should not be null");

            TVertex sourceVertex = this.GetSourceVertexFromEdge(edgeRemoved);
            if (!this.CanReachRootViaBackPointer(sourceVertex))
            {
                return;
            }

            this.RemoveAssociatedBackPointer(edgeRemoved);
            TVertex destinationVertex = this.GetDestinationVertexFromEdge(edgeRemoved);

            this.RemoveBackPointers(destinationVertex);
        }

        // call this method when an edge is added
        // Nothing is done when the edge's source vertex cannot reach the root via back pointers. Otherwise the edge's
        // back pointer is added, and if the destination could not reach the root before, back pointers are added
        // throughout the destination's sub-graph.
        protected void OnEdgeAdded(TEdge edgeAdded)
        {
            Fx.Assert(edgeAdded != null, "edgeAdded should not be null");

            TVertex sourceVertex = this.GetSourceVertexFromEdge(edgeAdded);
            if (!this.CanReachRootViaBackPointer(sourceVertex))
            {
                return;
            }

            // Reachability of the destination must be sampled before the new back pointer is attached.
            TVertex destinationVertex = this.GetDestinationVertexFromEdge(edgeAdded);
            bool wasReachable = this.CanReachRootViaBackPointer(destinationVertex);
            this.AddAssociatedBackPointer(edgeAdded);

            if (wasReachable)
            {
                return;
            }

            this.AddBackPointers(destinationVertex);
        }

        // Removes the back pointer that corresponds to the given edge.
        protected abstract void RemoveAssociatedBackPointer(TEdge edge);

        // Adds the back pointer that corresponds to the given edge.
        // NOTE(review): AddBackPointers may call this for an edge whose back pointer is already present - confirm implementations tolerate that.
        protected abstract void AddAssociatedBackPointer(TEdge edge);

        // Whether the given edge has a back pointer; used only for verification.
        protected abstract bool HasBackPointer(TEdge edge);

        // Whether the given back pointer has an edge associated with it; used only for verification.
        protected abstract bool HasAssociatedEdge(TBackPointer backPointer);

        // Invoked at the end of AddBackPointers with the vertices visited by that traversal.
        protected abstract void OnVerticesBecameReachable(IEnumerable<TVertex> reachableVertices);

        // Invoked at the end of RemoveBackPointers with the vertices visited by that traversal.
        protected abstract void OnVerticesBecameUnreachable(IEnumerable<TVertex> unreachableVertices);

        // Returns true when the vertex is the root, or when the root can be found by following back pointers
        // breadth-first from the vertex. Vertices are compared by reference.
        private bool CanReachRootViaBackPointer(TVertex vertex)
        {
            Fx.Assert(vertex != null, "vertex should not be null");

            // Read the abstract property once instead of once per back pointer.
            TVertex root = this.Root;
            if (object.ReferenceEquals(vertex, root))
            {
                return true;
            }

            HashSet<TVertex> visited = new HashSet<TVertex>(ObjectReferenceEqualityComparer<TVertex>.Default);
            Queue<TVertex> queue = new Queue<TVertex>();

            visited.Add(vertex);
            queue.Enqueue(vertex);

            while (queue.Count > 0)
            {
                TVertex current = queue.Dequeue();
                foreach (TBackPointer backPointer in this.GetBackPointers(current))
                {
                    TVertex destinationVertex = this.GetDestinationVertexFromBackPointer(backPointer);
                    if (object.ReferenceEquals(destinationVertex, root))
                    {
                        return true;
                    }

                    // Add returns false when the vertex was already visited.
                    if (visited.Add(destinationVertex))
                    {
                        queue.Enqueue(destinationVertex);
                    }
                }
            }

            return false;
        }

        // traverse the sub-graph starting from vertex and add back pointers
        // A destination that already owned a back pointer before the edge was processed is treated as already reachable
        // and is not traversed further. The visited vertices are reported through OnVerticesBecameReachable.
        // NOTE(review): the root may own no back pointers, so an edge leading back to the root looks "not yet reachable"
        // here and the root is re-enqueued and reported - confirm derived classes tolerate that.
        private void AddBackPointers(TVertex vertex)
        {
            HashSet<TVertex> verticesBecameReachable = new HashSet<TVertex>(ObjectReferenceEqualityComparer<TVertex>.Default);
            Queue<TVertex> queue = new Queue<TVertex>();

            verticesBecameReachable.Add(vertex);
            queue.Enqueue(vertex);

            while (queue.Count > 0)
            {
                TVertex currentVertex = queue.Dequeue();

                foreach (TEdge edge in this.GetOutEdges(currentVertex))
                {
                    TVertex destinationVertex = this.GetDestinationVertexFromEdge(edge);

                    // Must be sampled before the back pointer for this edge is attached.
                    bool wasReachable = this.GetBackPointers(destinationVertex).Any();
                    this.AddAssociatedBackPointer(edge);

                    if (!wasReachable && verticesBecameReachable.Add(destinationVertex))
                    {
                        queue.Enqueue(destinationVertex);
                    }
                }
            }

            this.OnVerticesBecameReachable(verticesBecameReachable);
        }

        // traverse the sub-graph starting from vertex, if a vertex is reachable then stop traversing its descendants,
        // otherwise remove back pointers that point to it and continue traversing its descendants
        // When isAllVerticesUnreachable is true the reachability computation is skipped and every vertex is treated as unreachable.
        // The visited vertices are reported through OnVerticesBecameUnreachable; nothing is reported when the starting
        // vertex is still reachable.
        private void RemoveBackPointers(TVertex vertex, bool isAllVerticesUnreachable = false)
        {
            // Stays null when every vertex is known to be unreachable; it is never consulted in that case.
            ICollection<TVertex> reachableVertices = null;

            if (!isAllVerticesUnreachable)
            {
                reachableVertices = this.CalculateReachableVertices();
                if (reachableVertices.Contains(vertex))
                {
                    return;
                }
            }

            Queue<TVertex> queue = new Queue<TVertex>();
            HashSet<TVertex> unreachableVertices = new HashSet<TVertex>(ObjectReferenceEqualityComparer<TVertex>.Default);

            unreachableVertices.Add(vertex);
            queue.Enqueue(vertex);

            while (queue.Count > 0)
            {
                TVertex unreachableVertex = queue.Dequeue();
                foreach (TEdge edge in this.GetOutEdges(unreachableVertex))
                {
                    // The edge starts at an unreachable vertex, so its back pointer goes away even if the destination stays reachable.
                    this.RemoveAssociatedBackPointer(edge);

                    TVertex to = this.GetDestinationVertexFromEdge(edge);
                    bool isUnreachable = isAllVerticesUnreachable || !reachableVertices.Contains(to);
                    if (isUnreachable && unreachableVertices.Add(to))
                    {
                        queue.Enqueue(to);
                    }
                }
            }

            this.OnVerticesBecameUnreachable(unreachableVertices);
        }
    }
}