count_cycles.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. #!/usr/bin/env python3
  2. # Copyright 2020 The Chromium Authors. All rights reserved.
  3. # Use of this source code is governed by a BSD-style license that can be
  4. # found in the LICENSE file.
  5. """Command-line tool to enumerate cycles in Graph structures."""
  6. import argparse
  7. import collections
  8. from typing import Dict, List, Tuple
  9. import serialization
  10. import graph
  # A cycle as returned by find_cycles: the nodes in visiting order, with the
  # starting node repeated at the end, e.g. (a, b, c, a).
  Cycle = Tuple[graph.Node, ...]
  def find_cycles_from_node(
          start_node: graph.Node,
          max_cycle_length: int,
          node_to_id: Dict[graph.Node, int],
  ) -> List[List[List[graph.Node]]]:
      """Finds all cycles starting at |start_node| in a subset of nodes.

      Only nodes with ID >= |start_node|'s ID will be considered. This ensures
      uniquely counting all cycles since this function is called on all nodes
      of the graph, one at a time in increasing order. Some justification:
      Consider cycle C with smallest node n. When this function is called on
      node n, C will be found since all nodes of C are >= n. After that call,
      C will never be found again since further calls are on nodes > n (n is
      removed from the search space).

      Cycles are found by recursively scanning all outbound nodes starting from
      |start_node|, up to a certain depth. Note this is the same idea, but is
      different from DFS since nodes can be visited more than once (to avoid
      missing cycles). An example of normal DFS (where nodes can only be
      visited once) missing cycles is in the following graph, starting at a:

          a <-> b <-> c
          ^           ^
          |           |
          +-----------+

          DFS(a)
            DFS(b)
              DFS(a) (cycle aba, return)
              DFS(c)
                DFS(b) (already seen, return)
                DFS(a) (cycle abca, return)
            DFS(c) (already seen, return)

      Since DFS(c) cannot proceed, we miss the cycles aca and acba.

      Args:
          start_node: The node to start the cycle search from. Only nodes with
              ID >= |start_node|'s ID will be considered.
          max_cycle_length: The maximum length (number of edges) of cycles to
              be found.
          node_to_id: A map from a Node to a generated ID. It is looked up for
              |start_node| and for the outbound neighbors of every node
              visited, so a missing node raises KeyError.

      Returns:
          A list |cycles| of length |max_cycle_length| + 1, where cycles[i]
          contains all relevant cycles of length i. Each cycle is a list of
          nodes that begins and ends with |start_node|, e.g. [a, b, c, a].
      """
      start_node_id = node_to_id[start_node]
      cycles: List[List[List[graph.Node]]] = [
          [] for _ in range(max_cycle_length + 1)
      ]

      def edge_is_interesting(start: graph.Node, end: graph.Node) -> bool:
          """Returns whether the edge start -> end may be part of a cycle."""
          if start == end:
              # Ignore self-loops.
              return False
          # Ignore edges ending at nodes with ID lower than the start.
          return node_to_id[end] >= start_node_id

      # The nodes on the current search path, from |start_node| down to the
      # node currently being expanded. |on_path| holds the same nodes so that
      # membership checks are O(1) and do not grow any mapping as a side
      # effect.
      path: List[graph.Node] = []
      on_path = set()

      def find_cycles_dfs(cur_node: graph.Node, cur_length: int) -> None:
          """Extends the current path through |cur_node|'s outbound edges.

          |cur_length| is the number of edges from |start_node| to |cur_node|.
          """
          path.append(cur_node)
          on_path.add(cur_node)
          for other_node in cur_node.outbound:
              if not edge_is_interesting(cur_node, other_node):
                  continue
              if other_node == start_node:
                  # We have found a valid cycle; record a copy of the current
                  # path closed back at |start_node|.
                  cycles[cur_length + 1].append(path + [start_node])
              elif (other_node not in on_path
                    and cur_length + 1 < max_cycle_length):
                  # We are only allowed to recurse into the next node if:
                  # 1) It is not already on the current path. If the next node
                  #    n _is_ on the path, we have found a cycle starting and
                  #    ending at n. Since this function only returns cycles
                  #    starting at |start_node|, we only care if
                  #    |n = start_node| (which we already detect above).
                  # 2) It would not exceed the maximum depth allowed: any cycle
                  #    closed through |other_node| has at least cur_length + 2
                  #    edges.
                  find_cycles_dfs(other_node, cur_length + 1)
          # Nodes on the path are never re-entered, so each appears once and can
          # be removed unconditionally when backtracking.
          path.pop()
          on_path.remove(cur_node)

      find_cycles_dfs(start_node, 0)
      return cycles
  def find_cycles(base_graph: graph.Graph,
                  max_cycle_length: int) -> List[List[Cycle]]:
      """Finds all cycles in the graph within a certain length.

      The algorithm is as such: Number the nodes (any numbering works; here
      they are numbered in sorted order). For i from 0 to the number of nodes,
      find all cycles starting and ending at node i using only nodes with
      numbers >= i (see find_cycles_from_node). Taking the union of the
      results will give all relevant cycles in the graph.

      Args:
          base_graph: The graph to search. Its nodes are numbered in sorted
              order and used as dict keys, so they must be both sortable and
              hashable.
          max_cycle_length: The maximum length (number of edges) of cycles to
              be found.

      Returns:
          A list |cycles| of length |max_cycle_length| + 1, where cycles[i]
          contains all cycles of length i. Each cycle is a tuple of nodes that
          starts and ends with its lowest-numbered node, e.g. (a, b, c, a).
      """
      sorted_base_graph_nodes = sorted(base_graph.nodes)

      # Some preliminary setup: map between the graph nodes' unique keys and a
      # unique number, since the algorithm needs some way to decide when a node
      # is 'bigger'. Nodes with a lower number will be processed first, which
      # influences the output cycles. For example, the cycle abca is also valid
      # as the cycle bcab or cabc. By numbering node a lower than b and c, it is
      # guaranteed that the cycle will be output as abca.
      node_to_id = {
          node: generated_node_id
          for generated_node_id, node in enumerate(sorted_base_graph_nodes)
      }

      cycles: List[List[List[graph.Node]]] = [
          [] for _ in range(max_cycle_length + 1)
      ]
      for start_node in sorted_base_graph_nodes:
          start_node_cycles = find_cycles_from_node(
              start_node, max_cycle_length, node_to_id)
          for cycle_length, cycle_list in enumerate(start_node_cycles):
              cycles[cycle_length].extend(cycle_list)

      # Convert cycles to be tuples of nodes, so the cycles are hashable and
      # immutable.
      return [[tuple(cycle) for cycle in cycle_list] for cycle_list in cycles]
  def main():
      """Counts the cycles in a package dependency graph, up to a length limit.

      Command-line flags:
          -f/--file: JSON dependency graph to load (required).
          -l/--cycle-length: The maximum cycle length to look for (required).
          -o/--output: Optional file to receive the full cycle listing.

      A summary is always printed: the total count, then the count for each
      length from 2 upward. When --output is given, each length's cycles are
      also written there, one per line, sorted, with node names joined by
      ' > '.
      """
      parser = argparse.ArgumentParser(
          description='Given a JSON dependency graph, count the number of cycles '
          'in the package graph.')

      required_args = parser.add_argument_group('required arguments')
      required_args.add_argument(
          '-f', '--file', required=True,
          help='Path to the JSON file containing the dependency graph. '
          'See the README on how to generate this file.')
      required_args.add_argument(
          '-l', '--cycle-length', type=int, required=True,
          help='The maximum length of cycles to find, at most 5 or 6 to keep the '
          'script runtime low.')
      parser.add_argument(
          '-o', '--output', type=argparse.FileType('w'),
          help='Path to the file to write the list of cycles to.')
      args = parser.parse_args()

      _, package_graph, _ = serialization.load_class_and_package_graphs_from_file(
          args.file)

      # Lengths 0 and 1 can never hold a cycle because self-loops are ignored,
      # so reporting starts at length 2.
      per_length = list(enumerate(find_cycles(package_graph,
                                              args.cycle_length)))[2:]

      total = sum(len(group) for _, group in per_length)
      print(f'Found {total} cycles.')
      for length, group in per_length:
          print(f'Found {len(group)} cycles of length {length}.')

      if args.output is None:
          return

      print(f'Dumping cycles to {args.output.name}.')
      # argparse has already opened the file; the with-block closes it.
      with args.output as output_file:
          for length, group in per_length:
              output_file.write(f'Cycles of length {length}:\n')
              rendered = sorted(' > '.join(node.name for node in cycle)
                                for cycle in group)
              output_file.write('\n'.join(rendered))
              output_file.write('\n')


  if __name__ == '__main__':
      main()