comparison env/lib/python3.9/site-packages/networkx/algorithms/tests/test_d_separation.py @ 0:4f3585e2f14b draft default tip

"planemo upload commit 60cee0fc7c0cda8592644e1aad72851dec82c959"
author shellac
date Mon, 22 Mar 2021 18:12:50 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:4f3585e2f14b
1 from itertools import combinations
2 import pytest
3 import networkx as nx
4
5
def path_graph():
    """Return a frozen directed path graph on three nodes: 0 -> 1 -> 2.

    The graph's ``name`` attribute is set to ``"path"`` and the graph is
    frozen, so any attempt to add or remove nodes/edges raises.
    """
    G = nx.path_graph(3, create_using=nx.DiGraph)
    G.graph["name"] = "path"
    nx.freeze(G)
    return G
12
13
def fork_graph():
    """Build the frozen three-node fork DAG ``1 <- 0 -> 2`` (node 0 is the shared parent)."""
    fork = nx.DiGraph(name="fork")
    for child in (1, 2):
        fork.add_edge(0, child)
    nx.freeze(fork)
    return fork
20
21
def collider_graph():
    """Build the frozen three-node collider (v-structure) DAG ``0 -> 2 <- 1``."""
    edges = [(0, 2), (1, 2)]
    collider = nx.DiGraph(name="collider")
    collider.add_edges_from(edges)
    nx.freeze(collider)
    return collider
28
29
def naive_bayes_graph():
    """Return a simple, frozen Naive Bayes PGM graph.

    Node ``0`` is the single parent of each of the nodes ``1`` through ``4``;
    there are no other edges.
    """
    G = nx.DiGraph(name="naive_bayes")
    G.add_edges_from([(0, 1), (0, 2), (0, 3), (0, 4)])
    nx.freeze(G)
    return G
36
37
def asia_graph():
    """Build the frozen 'Asia' PGM graph.

    Eight variables; ``asia`` and ``smoking`` have no parents, and ``xray``
    and ``dyspnea`` have no children.
    """
    edges = (
        ("asia", "tuberculosis"),
        ("smoking", "cancer"),
        ("smoking", "bronchitis"),
        ("tuberculosis", "either"),
        ("cancer", "either"),
        ("either", "xray"),
        ("either", "dyspnea"),
        ("bronchitis", "dyspnea"),
    )
    asia = nx.DiGraph(name="asia")
    asia.add_edges_from(edges)
    nx.freeze(asia)
    return asia
55
56
@pytest.fixture(name="path_graph")
def path_graph_fixture():
    """Pytest fixture exposing a fresh :func:`path_graph` as ``path_graph``."""
    graph = path_graph()
    return graph
60
61
@pytest.fixture(name="fork_graph")
def fork_graph_fixture():
    """Pytest fixture exposing a fresh :func:`fork_graph` as ``fork_graph``."""
    graph = fork_graph()
    return graph
65
66
@pytest.fixture(name="collider_graph")
def collider_graph_fixture():
    """Pytest fixture exposing a fresh :func:`collider_graph` as ``collider_graph``."""
    graph = collider_graph()
    return graph
70
71
@pytest.fixture(name="naive_bayes_graph")
def naive_bayes_graph_fixture():
    """Pytest fixture exposing a fresh :func:`naive_bayes_graph` as ``naive_bayes_graph``."""
    graph = naive_bayes_graph()
    return graph
75
76
@pytest.fixture(name="asia_graph")
def asia_graph_fixture():
    """Pytest fixture exposing a fresh :func:`asia_graph` as ``asia_graph``."""
    graph = asia_graph()
    return graph
80
81
@pytest.mark.parametrize(
    "graph",
    [path_graph(), fork_graph(), collider_graph(), naive_bayes_graph(), asia_graph()],
    # Name each case after its graph so a failure reports e.g. "[asia]"
    # instead of an opaque "[graph4]".
    ids=lambda graph: graph.name,
)
def test_markov_condition(graph):
    """Test that the local Markov condition holds for each PGM graph.

    Given its parents, every node must be d-separated from all of its
    non-descendants (excluding the node itself and its parents).
    """
    for node in graph.nodes:
        parents = set(graph.predecessors(node))
        non_descendants = graph.nodes - nx.descendants(graph, node) - {node} - parents
        assert nx.d_separated(graph, {node}, non_descendants, parents)
92
93
def test_path_graph_dsep(path_graph):
    """Example-based test of d-separation for path_graph.

    Conditioning on the middle node 1 blocks the chain 0 -> 1 -> 2; with an
    empty conditioning set, 0 and 2 are d-connected.
    """
    assert nx.d_separated(path_graph, {0}, {2}, {1})
    # ``set()`` rather than ``{}``: the literal ``{}`` is an empty dict.
    assert not nx.d_separated(path_graph, {0}, {2}, set())
98
99
def test_fork_graph_dsep(fork_graph):
    """Example-based test of d-separation for fork_graph.

    The children 1 and 2 are d-separated given their shared parent 0, and
    d-connected through it when nothing is conditioned on.
    """
    assert nx.d_separated(fork_graph, {1}, {2}, {0})
    # ``set()`` rather than ``{}``: the literal ``{}`` is an empty dict.
    assert not nx.d_separated(fork_graph, {1}, {2}, set())
104
105
def test_collider_graph_dsep(collider_graph):
    """Example-based test of d-separation for collider_graph.

    The parents 0 and 1 are d-separated marginally, but conditioning on the
    collider 2 d-connects them.
    """
    # ``set()`` rather than ``{}``: the literal ``{}`` is an empty dict.
    assert nx.d_separated(collider_graph, {0}, {1}, set())
    assert not nx.d_separated(collider_graph, {0}, {1}, {2})
110
111
def test_naive_bayes_dsep(naive_bayes_graph):
    """Example-based test of d-separation for naive_bayes_graph.

    Any two of the nodes 1-4 are d-separated given their common parent 0,
    and d-connected through it when nothing is conditioned on.
    """
    for u, v in combinations(range(1, 5), 2):
        assert nx.d_separated(naive_bayes_graph, {u}, {v}, {0})
        # ``set()`` rather than ``{}``: the literal ``{}`` is an empty dict.
        assert not nx.d_separated(naive_bayes_graph, {u}, {v}, set())
117
118
def test_asia_graph_dsep(asia_graph):
    """Check two known d-separation statements (x, y given z) on asia_graph."""
    cases = [
        ({"asia", "smoking"}, {"dyspnea", "xray"}, {"bronchitis", "either"}),
        ({"tuberculosis", "cancer"}, {"bronchitis"}, {"smoking", "xray"}),
    ]
    for x, y, z in cases:
        assert nx.d_separated(asia_graph, x, y, z)
127
128
def test_undirected_graphs_are_not_supported():
    """
    Test that undirected graphs are not supported.

    d-separation does not apply in the case of undirected graphs.
    """
    g = nx.path_graph(3, nx.Graph)
    # Build the graph outside ``pytest.raises`` so that only the
    # ``d_separated`` call can produce the expected exception.
    with pytest.raises(nx.NetworkXNotImplemented):
        nx.d_separated(g, {0}, {1}, {2})
138
139
def test_cyclic_graphs_raise_error():
    """
    Test that cyclic graphs raise an error.

    This is because PGMs assume a directed acyclic graph.
    """
    g = nx.cycle_graph(3, nx.DiGraph)
    # Build the graph outside ``pytest.raises``: a ``NetworkXError`` raised
    # during construction would otherwise make this test pass spuriously.
    with pytest.raises(nx.NetworkXError):
        nx.d_separated(g, {0}, {1}, {2})
149
150
def test_invalid_nodes_raise_error(asia_graph):
    """
    Test that passing nodes absent from the graph raises ``NodeNotFound``.
    """
    # None of 0, 1, 2 are nodes of the string-labelled Asia graph.
    missing_x, missing_y, missing_z = {0}, {1}, {2}
    with pytest.raises(nx.NodeNotFound):
        nx.d_separated(asia_graph, missing_x, missing_y, missing_z)