Mercurial > repos > shellac > sam_consensus_v3
comparison env/lib/python3.9/site-packages/networkx/algorithms/tests/test_d_separation.py @ 0:4f3585e2f14b draft default tip
"planemo upload commit 60cee0fc7c0cda8592644e1aad72851dec82c959"
author | shellac |
---|---|
date | Mon, 22 Mar 2021 18:12:50 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:4f3585e2f14b |
---|---|
1 from itertools import combinations | |
2 import pytest | |
3 import networkx as nx | |
4 | |
5 | |
def path_graph():
    """Build a frozen, directed three-node path graph named "path"."""
    digraph = nx.path_graph(3, create_using=nx.DiGraph)
    digraph.graph.update(name="path")
    # freeze() hands back the same graph object, now immutable.
    return nx.freeze(digraph)
12 | |
13 | |
def fork_graph():
    """Build a frozen fork graph: node 0 is the common parent of 1 and 2."""
    fork = nx.DiGraph(name="fork")
    for child in (1, 2):
        fork.add_edge(0, child)
    # freeze() hands back the same graph object, now immutable.
    return nx.freeze(fork)
20 | |
21 | |
def collider_graph():
    """Build a frozen collider (v-structure): nodes 0 and 1 both point to 2."""
    collider = nx.DiGraph(name="collider")
    for parent in (0, 1):
        collider.add_edge(parent, 2)
    # freeze() hands back the same graph object, now immutable.
    return nx.freeze(collider)
28 | |
29 | |
def naive_bayes_graph():
    """Build a frozen, simple Naive Bayes PGM graph: node 0 points to 1..4."""
    bayes = nx.DiGraph(name="naive_bayes")
    bayes.add_edges_from((0, leaf) for leaf in range(1, 5))
    # freeze() hands back the same graph object, now immutable.
    return nx.freeze(bayes)
36 | |
37 | |
def asia_graph():
    """Build the frozen 'Asia' PGM graph."""
    # Parent -> children, listed so the expanded edge order is unchanged.
    children_of = {
        "asia": ("tuberculosis",),
        "smoking": ("cancer", "bronchitis"),
        "tuberculosis": ("either",),
        "cancer": ("either",),
        "either": ("xray", "dyspnea"),
        "bronchitis": ("dyspnea",),
    }
    asia = nx.DiGraph(name="asia")
    asia.add_edges_from(
        (parent, child)
        for parent, children in children_of.items()
        for child in children
    )
    # freeze() hands back the same graph object, now immutable.
    return nx.freeze(asia)
55 | |
56 | |
@pytest.fixture(name="path_graph")
def path_graph_fixture():
    """Fixture exposing a newly built path graph under the name ``path_graph``."""
    graph = path_graph()
    return graph
60 | |
61 | |
@pytest.fixture(name="fork_graph")
def fork_graph_fixture():
    """Fixture exposing a newly built fork graph under the name ``fork_graph``."""
    graph = fork_graph()
    return graph
65 | |
66 | |
@pytest.fixture(name="collider_graph")
def collider_graph_fixture():
    """Fixture exposing a newly built collider graph as ``collider_graph``."""
    graph = collider_graph()
    return graph
70 | |
71 | |
@pytest.fixture(name="naive_bayes_graph")
def naive_bayes_graph_fixture():
    """Fixture exposing a newly built Naive Bayes graph as ``naive_bayes_graph``."""
    graph = naive_bayes_graph()
    return graph
75 | |
76 | |
@pytest.fixture(name="asia_graph")
def asia_graph_fixture():
    """Fixture exposing a newly built 'Asia' graph under the name ``asia_graph``."""
    graph = asia_graph()
    return graph
80 | |
81 | |
@pytest.mark.parametrize(
    "graph",
    [
        build()
        for build in (
            path_graph,
            fork_graph,
            collider_graph,
            naive_bayes_graph,
            asia_graph,
        )
    ],
)
def test_markov_condition(graph):
    """Check the Markov condition on every example PGM graph.

    Each node must be d-separated from its non-descendants (excluding the
    node itself and its parents) when conditioning on its parents.
    """
    for node in graph.nodes:
        parents = set(graph.predecessors(node))
        # Everything that is NOT a candidate: descendants, the node, its parents.
        excluded = nx.descendants(graph, node) | {node} | parents
        non_descendants = set(graph.nodes) - excluded
        assert nx.d_separated(graph, {node}, non_descendants, parents)
92 | |
93 | |
def test_path_graph_dsep(path_graph):
    """Example-based test of d-separation for path_graph.

    The end nodes are d-separated when conditioning on the middle node and
    are not d-separated when the conditioning set is empty.
    """
    assert nx.d_separated(path_graph, {0}, {2}, {1})
    # ``{}`` is an empty dict literal; pass a real empty set as the
    # conditioning set.
    assert not nx.d_separated(path_graph, {0}, {2}, set())
98 | |
99 | |
def test_fork_graph_dsep(fork_graph):
    """Example-based test of d-separation for fork_graph.

    The two children are d-separated when conditioning on their common
    parent and are not d-separated when the conditioning set is empty.
    """
    assert nx.d_separated(fork_graph, {1}, {2}, {0})
    # ``{}`` is an empty dict literal; pass a real empty set as the
    # conditioning set.
    assert not nx.d_separated(fork_graph, {1}, {2}, set())
104 | |
105 | |
def test_collider_graph_dsep(collider_graph):
    """Example-based test of d-separation for collider_graph.

    The two parents are d-separated when the conditioning set is empty and
    are not d-separated when conditioning on their common child.
    """
    # ``{}`` is an empty dict literal; pass a real empty set as the
    # conditioning set.
    assert nx.d_separated(collider_graph, {0}, {1}, set())
    assert not nx.d_separated(collider_graph, {0}, {1}, {2})
110 | |
111 | |
def test_naive_bayes_dsep(naive_bayes_graph):
    """Example-based test of d-separation for naive_bayes_graph.

    Every pair of nodes drawn from 1..4 is d-separated when conditioning on
    node 0 and is not d-separated when the conditioning set is empty.
    """
    for u, v in combinations(range(1, 5), 2):
        assert nx.d_separated(naive_bayes_graph, {u}, {v}, {0})
        # ``{}`` is an empty dict literal; pass a real empty set as the
        # conditioning set.
        assert not nx.d_separated(naive_bayes_graph, {u}, {v}, set())
117 | |
118 | |
def test_asia_graph_dsep(asia_graph):
    """Check two known d-separation statements on the 'Asia' graph."""
    # Each case is (x, y, z): x and y must be d-separated given z.
    expected_separations = (
        ({"asia", "smoking"}, {"dyspnea", "xray"}, {"bronchitis", "either"}),
        ({"tuberculosis", "cancer"}, {"bronchitis"}, {"smoking", "xray"}),
    )
    for x, y, z in expected_separations:
        assert nx.d_separated(asia_graph, x, y, z)
127 | |
128 | |
def test_undirected_graphs_are_not_supported():
    """
    Test that undirected graphs are not supported.

    d-separation does not apply in the case of undirected graphs.
    """
    # Build the graph outside the ``raises`` block so that only the
    # d_separated call itself can satisfy the expected exception.
    g = nx.path_graph(3, nx.Graph)
    with pytest.raises(nx.NetworkXNotImplemented):
        nx.d_separated(g, {0}, {1}, {2})
138 | |
139 | |
def test_cyclic_graphs_raise_error():
    """
    Test that cyclic graphs cause an error to be raised.

    This is because PGMs assume a directed acyclic graph.
    """
    # Build the graph outside the ``raises`` block so that only the
    # d_separated call itself can satisfy the expected exception.
    g = nx.cycle_graph(3, nx.DiGraph)
    with pytest.raises(nx.NetworkXError):
        nx.d_separated(g, {0}, {1}, {2})
149 | |
150 | |
def test_invalid_nodes_raise_error(asia_graph):
    """Check that passing nodes absent from the graph raises ``NodeNotFound``."""
    # The 'Asia' graph uses string node names, so these integers are not in it.
    x, y, z = {0}, {1}, {2}
    with pytest.raises(nx.NodeNotFound):
        nx.d_separated(asia_graph, x, y, z)