Mercurial > repos > shellac > sam_consensus_v3
comparison env/lib/python3.9/site-packages/networkx/readwrite/gexf.py @ 0:4f3585e2f14b draft default tip
"planemo upload commit 60cee0fc7c0cda8592644e1aad72851dec82c959"
author | shellac |
---|---|
date | Mon, 22 Mar 2021 18:12:50 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:4f3585e2f14b |
---|---|
1 """Read and write graphs in GEXF format. | |
2 | |
3 GEXF (Graph Exchange XML Format) is a language for describing complex | |
4 network structures, their associated data and dynamics. | |
5 | |
6 This implementation does not support mixed graphs (directed and | |
7 undirected edges together). | |
8 | |
9 Format | |
10 ------ | |
11 GEXF is an XML format. See https://gephi.org/gexf/format/schema.html for the | |
12 specification and https://gephi.org/gexf/format/basic.html for examples. | |
13 """ | |
14 import itertools | |
15 import time | |
16 | |
17 import networkx as nx | |
18 from networkx.utils import open_file | |
19 | |
20 from xml.etree.ElementTree import ( | |
21 Element, | |
22 ElementTree, | |
23 SubElement, | |
24 tostring, | |
25 register_namespace, | |
26 ) | |
27 | |
# Public names exported by ``from <this module> import *``.
__all__ = ["write_gexf", "read_gexf", "relabel_gexf_graph", "generate_gexf"]
29 | |
30 | |
@open_file(1, mode="wb")
def write_gexf(G, path, encoding="utf-8", prettyprint=True, version="1.2draft"):
    """Write G in GEXF format to path.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Node attributes are checked according to the version of the GEXF
    schemas used for parameters which are not user defined,
    e.g. visualization 'viz' [2]_. See example for usage.

    Parameters
    ----------
    G : graph
        A NetworkX graph
    path : file or string
        File or file name to write.
        File names ending in .gz or .bz2 will be compressed.
    encoding : string (optional, default: 'utf-8')
        Encoding for text data.
    prettyprint : bool (optional, default: True)
        If True use line breaks and indenting in output XML.
    version : string (optional, default: '1.2draft')
        Version of GEXF File Format (see https://gephi.org/gexf/format/schema.html).
        Supported values: "1.1draft", "1.2draft".

    Raises
    ------
    NetworkXError
        If `version` is not one of the supported values.
    TypeError
        If a node or edge attribute value has a type that cannot be
        written as a GEXF attribute.

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> nx.write_gexf(G, "test.gexf")

    # visualization data (color channels are in the range 0-255)
    >>> G.nodes[0]["viz"] = {"size": 54}
    >>> G.nodes[0]["viz"]["position"] = {"x": 0, "y": 1}
    >>> G.nodes[0]["viz"]["color"] = {"r": 0, "g": 0, "b": 255}


    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    The node id attribute is set to be the string of the node label.
    To specify an id explicitly, set it as node data, e.g.
    ``G.nodes['a']['id'] = 1`` sets the id of node 'a' to 1.

    References
    ----------
    .. [1] GEXF File Format, https://gephi.org/gexf/format/
    .. [2] GEXF viz schema 1.1, https://gephi.org/gexf/1.1draft/viz
    """
    writer = GEXFWriter(encoding=encoding, prettyprint=prettyprint, version=version)
    writer.add_graph(G)
    writer.write(path)
82 | |
83 | |
def generate_gexf(G, encoding="utf-8", prettyprint=True, version="1.2draft"):
    """Generate lines of GEXF format representation of G.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Parameters
    ----------
    G : graph
        A NetworkX graph
    encoding : string (optional, default: 'utf-8')
        Encoding for text data.
    prettyprint : bool (optional, default: True)
        If True use line breaks and indenting in output XML.
    version : string (default: 1.2draft)
        Version of GEXF File Format (see https://gephi.org/gexf/format/schema.html)
        Supported values: "1.1draft", "1.2draft"

    Yields
    ------
    str
        One line of the serialized GEXF document at a time, without the
        trailing line terminator.

    Raises
    ------
    NetworkXError
        If `version` is not one of the supported values.

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> linefeed = chr(10)  # linefeed=\n
    >>> s = linefeed.join(nx.generate_gexf(G))  # doctest: +SKIP
    >>> for line in nx.generate_gexf(G):  # doctest: +SKIP
    ...     print(line)

    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    The node id attribute is set to be the string of the node label.
    To specify an id explicitly, set it as node data, e.g.
    ``G.nodes['a']['id'] = 1`` sets the id of node 'a' to 1.

    References
    ----------
    .. [1] GEXF File Format, https://gephi.org/gexf/format/
    """
    writer = GEXFWriter(encoding=encoding, prettyprint=prettyprint, version=version)
    writer.add_graph(G)
    yield from str(writer).splitlines()
127 | |
128 | |
@open_file(0, mode="rb")
def read_gexf(path, node_type=None, relabel=False, version="1.2draft"):
    """Read graph in GEXF format from path.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Parameters
    ----------
    path : file or string
        File or file name to read.
        File names ending in .gz or .bz2 will be decompressed.
    node_type : Python type (default: None)
        Convert node ids to this type if not None.
    relabel : bool (default: False)
        If True relabel the nodes to use the GEXF node "label" attribute
        instead of the node "id" attribute as the NetworkX node label.
    version : string (default: 1.2draft)
        Version of GEXF File Format (see https://gephi.org/gexf/format/schema.html)
        Supported values: "1.1draft", "1.2draft". This version is tried
        first; the other supported versions are tried if it does not match.

    Returns
    -------
    graph: NetworkX graph
        If no parallel edges are found a Graph or DiGraph is returned.
        Otherwise a MultiGraph or MultiDiGraph is returned.

    Raises
    ------
    NetworkXError
        If `version` is unknown, or if no graph element is found under any
        supported GEXF namespace.

    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    References
    ----------
    .. [1] GEXF File Format, https://gephi.org/gexf/format/
    """
    reader = GEXFReader(node_type=node_type, version=version)
    G = reader(path)
    if relabel:
        G = relabel_gexf_graph(G)
    return G
171 | |
172 | |
class GEXF:
    """Version tables and type mappings shared by the GEXF reader and writer.

    Class attributes
    ----------------
    versions : dict
        Maps a supported version string ("1.1draft", "1.2draft") to the XML
        namespaces, schema location and version number used for it.
    xml_type : dict
        Python (and NumPy, when importable) type -> GEXF attribute type name;
        used when writing. Built from ``types``; later entries win.
    python_type : dict
        GEXF attribute type name -> Python type; used when reading. Built
        from ``types``; later entries win, so plain Python types shadow the
        NumPy ones for names they share.
    convert_bool : dict
        Accepted lexical forms of booleans.
    """

    versions = {}
    d = {
        "NS_GEXF": "http://www.gexf.net/1.1draft",
        "NS_VIZ": "http://www.gexf.net/1.1draft/viz",
        "NS_XSI": "http://www.w3.org/2001/XMLSchema-instance",
        "SCHEMALOCATION": " ".join(
            ["http://www.gexf.net/1.1draft", "http://www.gexf.net/1.1draft/gexf.xsd"]
        ),
        "VERSION": "1.1",
    }
    versions["1.1draft"] = d
    d = {
        "NS_GEXF": "http://www.gexf.net/1.2draft",
        "NS_VIZ": "http://www.gexf.net/1.2draft/viz",
        "NS_XSI": "http://www.w3.org/2001/XMLSchema-instance",
        "SCHEMALOCATION": " ".join(
            ["http://www.gexf.net/1.2draft", "http://www.gexf.net/1.2draft/gexf.xsd"]
        ),
        "VERSION": "1.2",
    }
    versions["1.2draft"] = d

    types = [
        (int, "integer"),
        (float, "float"),
        (float, "double"),
        (bool, "boolean"),
        (list, "string"),
        (dict, "string"),
        (int, "long"),
        (str, "liststring"),
        (str, "anyURI"),
        (str, "string"),
    ]

    # These additions to types allow writing numpy types
    try:
        import numpy as np
    except ImportError:
        pass
    else:
        # prepend so that python types are created upon read (last entry wins).
        # ``np.float_`` is intentionally not listed: it was only an alias of
        # ``np.float64`` and was removed in NumPy 2.0, where referencing it
        # raises AttributeError while this class body runs (breaking import).
        types = [
            (np.float64, "float"),
            (np.float32, "float"),
            (np.float16, "float"),
            (np.int_, "int"),
            (np.int8, "int"),
            (np.int16, "int"),
            (np.int32, "int"),
            (np.int64, "int"),
            (np.uint8, "int"),
            (np.uint16, "int"),
            (np.uint32, "int"),
            (np.uint64, "int"),
            (np.intc, "int"),
            (np.intp, "int"),
        ] + types

    xml_type = dict(types)
    python_type = dict(reversed(a) for a in types)

    # http://www.w3.org/TR/xmlschema-2/#boolean
    convert_bool = {
        "true": True,
        "false": False,
        "True": True,
        "False": False,
        "0": False,
        0: False,
        "1": True,
        1: True,
    }

    def set_version(self, version):
        """Copy the namespace/schema constants for `version` onto self.

        Raises
        ------
        NetworkXError
            If `version` is not a key of ``versions``.
        """
        d = self.versions.get(version)
        if d is None:
            raise nx.NetworkXError(f"Unknown GEXF version {version}.")
        self.NS_GEXF = d["NS_GEXF"]
        self.NS_VIZ = d["NS_VIZ"]
        self.NS_XSI = d["NS_XSI"]
        self.SCHEMALOCATION = d["SCHEMALOCATION"]
        self.VERSION = d["VERSION"]
        self.version = version
260 | |
261 | |
class GEXFWriter(GEXF):
    """Build, and serialize, a GEXF XML document for a NetworkX graph.

    Prefer :func:`write_gexf` / :func:`generate_gexf` over using this class
    directly. Typical use is ``add_graph(G)`` followed by ``write(fh)`` (an
    open binary file) or ``str(writer)``.

    Parameters
    ----------
    graph : NetworkX graph, optional
        If given, it is added immediately via :meth:`add_graph`.
    encoding : str, optional
        Encoding passed to the XML serializer by :meth:`write` and used to
        decode the bytes in ``__str__``.
    prettyprint : bool, optional
        Indent the XML in place before serializing.
    version : str, optional
        GEXF version; must be a key of ``GEXF.versions``.
    """

    # XML Schema spellings of the special float values
    _FLOAT_SPECIALS = {"inf": "INF", "-inf": "-INF", "nan": "NaN"}

    def __init__(
        self, graph=None, encoding="utf-8", prettyprint=True, version="1.2draft"
    ):
        self.prettyprint = prettyprint
        self.encoding = encoding
        self.set_version(version)
        self.xml = Element(
            "gexf",
            {
                "xmlns": self.NS_GEXF,
                "xmlns:xsi": self.NS_XSI,
                "xsi:schemaLocation": self.SCHEMALOCATION,
                "version": self.VERSION,
            },
        )

        # Make meta element a non-graph element
        # Also add lastmodifieddate as attribute, not tag
        meta_element = Element("meta")
        subelement_text = f"NetworkX {nx.__version__}"
        SubElement(meta_element, "creator").text = subelement_text
        meta_element.set("lastmodifieddate", time.strftime("%Y-%m-%d"))
        self.xml.append(meta_element)

        register_namespace("viz", self.NS_VIZ)

        # counters for edge and attribute identifiers
        self.edge_id = itertools.count()
        self.attr_id = itertools.count()
        # every edge id in use (user supplied or generated), stored as str
        self.all_edge_ids = set()
        # attribute title -> attribute id, per element class and mode
        self.attr = {
            "node": {"dynamic": {}, "static": {}},
            "edge": {"dynamic": {}, "static": {}},
        }

        if graph is not None:
            self.add_graph(graph)

    def __str__(self):
        if self.prettyprint:
            self.indent(self.xml)
        s = tostring(self.xml).decode(self.encoding)
        return s

    def add_graph(self, G):
        """Append a graph element describing G to the document."""
        # first pass through G collecting edge ids, so that generated ids
        # never collide with user-supplied ones
        for u, v, dd in G.edges(data=True):
            eid = dd.get("id")
            if eid is not None:
                self.all_edge_ids.add(str(eid))
        # set graph attributes
        mode = "dynamic" if G.graph.get("mode") == "dynamic" else "static"
        # Add a graph element to the XML
        default = "directed" if G.is_directed() else "undirected"
        name = G.graph.get("name", "")
        graph_element = Element("graph", defaultedgetype=default, mode=mode, name=name)
        self.graph_element = graph_element
        self.add_nodes(G, graph_element)
        self.add_edges(G, graph_element)
        self.xml.append(graph_element)

    def add_nodes(self, G, graph_element):
        """Append a nodes element with one node element per node of G."""
        nodes_element = Element("nodes")
        default = G.graph.get("node_default", {})
        for node, data in G.nodes(data=True):
            node_data = data.copy()
            kw = {
                "id": str(node_data.pop("id", node)),
                "label": str(node_data.pop("label", node)),
            }
            if "pid" in node_data:
                kw["pid"] = str(node_data.pop("pid"))
            for key in ("start", "end"):
                if key in node_data:
                    value = node_data.pop(key)
                    kw[key] = str(value)
                    # a start/end time switches the graph to dynamic mode
                    self.alter_graph_mode_timeformat(value)
            # add node element with attributes
            node_element = Element("node", **kw)
            # add node element and attr subelements
            node_data = self.add_parents(node_element, node_data)
            if self.VERSION == "1.1":
                node_data = self.add_slices(node_element, node_data)
            else:
                node_data = self.add_spells(node_element, node_data)
            node_data = self.add_viz(node_element, node_data)
            node_data = self.add_attributes("node", node_element, node_data, default)
            nodes_element.append(node_element)
        graph_element.append(nodes_element)

    def add_edges(self, G, graph_element):
        """Append an edges element with one edge element per edge of G."""

        def edge_key_data(G):
            # Unify the multigraph and graph edge iterators; yields
            # (u, v, edge_id, attribute dict with "id" removed).
            if G.is_multigraph():
                # the multigraph key travels as a "key" attribute, which
                # add_attributes writes out as "networkx_key"
                edges = (
                    (u, v, dict(data, key=key))
                    for u, v, key, data in G.edges(data=True, keys=True)
                )
            else:
                edges = ((u, v, data.copy()) for u, v, data in G.edges(data=True))
            for u, v, edge_data in edges:
                edge_id = edge_data.pop("id", None)
                if edge_id is None:
                    edge_id = next(self.edge_id)
                    while str(edge_id) in self.all_edge_ids:
                        edge_id = next(self.edge_id)
                    self.all_edge_ids.add(str(edge_id))
                yield u, v, edge_id, edge_data

        edges_element = Element("edges")
        default = G.graph.get("edge_default", {})
        for u, v, key, edge_data in edge_key_data(G):
            kw = {"id": str(key)}
            for attr_name in ("label", "weight", "type"):
                if attr_name in edge_data:
                    kw[attr_name] = str(edge_data.pop(attr_name))
            for attr_name in ("start", "end"):
                if attr_name in edge_data:
                    value = edge_data.pop(attr_name)
                    kw[attr_name] = str(value)
                    self.alter_graph_mode_timeformat(value)
            source_id = str(G.nodes[u].get("id", u))
            target_id = str(G.nodes[v].get("id", v))
            edge_element = Element("edge", source=source_id, target=target_id, **kw)
            if self.VERSION == "1.1":
                edge_data = self.add_slices(edge_element, edge_data)
            else:
                edge_data = self.add_spells(edge_element, edge_data)
            edge_data = self.add_viz(edge_element, edge_data)
            edge_data = self.add_attributes("edge", edge_element, edge_data, default)
            edges_element.append(edge_element)
        graph_element.append(edges_element)

    def _format_value(self, value):
        """Return the GEXF text for a single attribute value.

        Booleans use the XML Schema lexical form ("true"/"false"). For values
        whose GEXF type is "float" or "double", "inf"/"-inf"/"nan" are spelled
        "INF"/"-INF"/"NaN".
        """
        if isinstance(value, bool):
            return str(value).lower()
        text = str(value)
        if self.xml_type.get(type(value)) in ("float", "double"):
            text = self._FLOAT_SPECIALS.get(text, text)
        return text

    def add_attributes(self, node_or_edge, xml_obj, data, default):
        """Append an attvalues element for the remaining items of `data`.

        A list value is treated as dynamic data: a sequence of
        ``(value, start, end)`` triples. Any other value is static.

        Raises
        ------
        TypeError
            If a value's type has no GEXF attribute type.
        """
        if len(data) == 0:
            return data
        attvalues = Element("attvalues")
        for k, v in data.items():
            # rename generic multigraph key to avoid any name conflict
            if k == "key":
                k = "networkx_key"
            val_type = type(v)
            if val_type not in self.xml_type:
                raise TypeError(f"attribute value type is not allowed: {val_type}")
            if isinstance(v, list):
                # dynamic data. The mode is decided per attribute so a
                # previous dynamic attribute cannot leak "dynamic" into a
                # list whose entries carry no start/end times.
                mode = "static"
                for val, start, end in v:
                    val_type = type(val)
                    if start is not None or end is not None:
                        mode = "dynamic"
                        self.alter_graph_mode_timeformat(start)
                        self.alter_graph_mode_timeformat(end)
                        break
                if val_type not in self.xml_type:
                    raise TypeError(f"attribute value type is not allowed: {val_type}")
                attr_id = self.get_attr_id(
                    str(k), self.xml_type[val_type], node_or_edge, default, mode
                )
                for val, start, end in v:
                    e = Element("attvalue")
                    e.attrib["for"] = attr_id
                    e.attrib["value"] = self._format_value(val)
                    if start is not None:
                        e.attrib["start"] = str(start)
                    if end is not None:
                        e.attrib["end"] = str(end)
                    attvalues.append(e)
            else:
                # static data
                attr_id = self.get_attr_id(
                    str(k), self.xml_type[val_type], node_or_edge, default, "static"
                )
                e = Element("attvalue")
                e.attrib["for"] = attr_id
                e.attrib["value"] = self._format_value(v)
                attvalues.append(e)
        xml_obj.append(attvalues)
        return data

    def get_attr_id(self, title, attr_type, edge_or_node, default, mode):
        """Return the id for attribute `title`, declaring it on first use."""
        try:
            return self.attr[edge_or_node][mode][title]
        except KeyError:
            # generate new id
            new_id = str(next(self.attr_id))
            self.attr[edge_or_node][mode][title] = new_id
            attr_kwargs = {"id": new_id, "title": title, "type": attr_type}
            attribute = Element("attribute", **attr_kwargs)
            # add subelement for data default value if present
            default_title = default.get(title)
            if default_title is not None:
                default_element = Element("default")
                default_element.text = str(default_title)
                attribute.append(default_element)
            # new insert it into the XML
            attributes_element = None
            for a in self.graph_element.findall("attributes"):
                # find existing attributes element by class and mode
                a_class = a.get("class")
                a_mode = a.get("mode", "static")
                if a_class == edge_or_node and a_mode == mode:
                    attributes_element = a
            if attributes_element is None:
                # create new attributes element
                attr_kwargs = {"mode": mode, "class": edge_or_node}
                attributes_element = Element("attributes", **attr_kwargs)
                self.graph_element.insert(0, attributes_element)
            attributes_element.append(attribute)
        return new_id

    def add_viz(self, element, node_data):
        """Pop ``node_data["viz"]`` and append matching viz: subelements."""
        viz = node_data.pop("viz", False)
        if viz:
            color = viz.get("color")
            if color is not None:
                channels = {c: str(color.get(c)) for c in ("r", "g", "b")}
                if self.VERSION != "1.1":
                    # 1.2 colors carry an alpha channel. Default to opaque
                    # (1.0, also the reader's default) rather than writing
                    # the unreadable literal "None".
                    alpha = color.get("a")
                    channels["a"] = str(1.0 if alpha is None else alpha)
                element.append(Element(f"{{{self.NS_VIZ}}}color", **channels))

            size = viz.get("size")
            if size is not None:
                e = Element(f"{{{self.NS_VIZ}}}size", value=str(size))
                element.append(e)

            thickness = viz.get("thickness")
            if thickness is not None:
                e = Element(f"{{{self.NS_VIZ}}}thickness", value=str(thickness))
                element.append(e)

            shape = viz.get("shape")
            if shape is not None:
                if shape.startswith("http"):
                    e = Element(
                        f"{{{self.NS_VIZ}}}shape", value="image", uri=str(shape)
                    )
                else:
                    e = Element(f"{{{self.NS_VIZ}}}shape", value=str(shape))
                element.append(e)

            position = viz.get("position")
            if position is not None:
                # Only write coordinates that are present (e.g. 2-D layouts
                # have no "z"); writing "None" made the file unreadable.
                coords = {
                    axis: str(position[axis])
                    for axis in ("x", "y", "z")
                    if position.get(axis) is not None
                }
                element.append(Element(f"{{{self.NS_VIZ}}}position", **coords))
        return node_data

    def add_parents(self, node_element, node_data):
        parents = node_data.pop("parents", False)
        if parents:
            parents_element = Element("parents")
            for p in parents:
                e = Element("parent")
                e.attrib["for"] = str(p)
                parents_element.append(e)
            node_element.append(parents_element)
        return node_data

    def add_slices(self, node_or_edge_element, node_or_edge_data):
        slices = node_or_edge_data.pop("slices", False)
        if slices:
            slices_element = Element("slices")
            for start, end in slices:
                e = Element("slice", start=str(start), end=str(end))
                slices_element.append(e)
            node_or_edge_element.append(slices_element)
        return node_or_edge_data

    def add_spells(self, node_or_edge_element, node_or_edge_data):
        spells = node_or_edge_data.pop("spells", False)
        if spells:
            spells_element = Element("spells")
            for start, end in spells:
                e = Element("spell")
                if start is not None:
                    e.attrib["start"] = str(start)
                    self.alter_graph_mode_timeformat(start)
                if end is not None:
                    e.attrib["end"] = str(end)
                    self.alter_graph_mode_timeformat(end)
                spells_element.append(e)
            node_or_edge_element.append(spells_element)
        return node_or_edge_data

    def alter_graph_mode_timeformat(self, start_or_end):
        # If 'start' or 'end' appears, alter Graph mode to dynamic and
        # set timeformat
        if self.graph_element.get("mode") == "static":
            if start_or_end is not None:
                if isinstance(start_or_end, str):
                    timeformat = "date"
                elif isinstance(start_or_end, float):
                    timeformat = "double"
                elif isinstance(start_or_end, int):
                    timeformat = "long"
                else:
                    raise nx.NetworkXError(
                        "timeformat should be of the type int, float or str"
                    )
                self.graph_element.set("timeformat", timeformat)
                self.graph_element.set("mode", "dynamic")

    def write(self, fh):
        # Serialize graph G in GEXF to the open fh
        if self.prettyprint:
            self.indent(self.xml)
        document = ElementTree(self.xml)
        document.write(fh, encoding=self.encoding, xml_declaration=True)

    def indent(self, elem, level=0):
        # in-place prettyprint formatter
        i = "\n" + " " * level
        if len(elem):
            if not elem.text or not elem.text.strip():
                elem.text = i + " "
            if not elem.tail or not elem.tail.strip():
                elem.tail = i
            for elem in elem:
                self.indent(elem, level + 1)
            # after the loop `elem` is the last child; its tail closes the parent
            if not elem.tail or not elem.tail.strip():
                elem.tail = i
        else:
            if level and (not elem.tail or not elem.tail.strip()):
                elem.tail = i
671 | |
672 | |
673 class GEXFReader(GEXF): | |
674 # Class to read GEXF format files | |
675 # use read_gexf() function | |
676 def __init__(self, node_type=None, version="1.2draft"): | |
677 self.node_type = node_type | |
678 # assume simple graph and test for multigraph on read | |
679 self.simple_graph = True | |
680 self.set_version(version) | |
681 | |
682 def __call__(self, stream): | |
683 self.xml = ElementTree(file=stream) | |
684 g = self.xml.find(f"{{{self.NS_GEXF}}}graph") | |
685 if g is not None: | |
686 return self.make_graph(g) | |
687 # try all the versions | |
688 for version in self.versions: | |
689 self.set_version(version) | |
690 g = self.xml.find(f"{{{self.NS_GEXF}}}graph") | |
691 if g is not None: | |
692 return self.make_graph(g) | |
693 raise nx.NetworkXError("No <graph> element in GEXF file.") | |
694 | |
def make_graph(self, graph_xml):
    """Convert a GEXF graph element into a NetworkX graph.

    The graph is built as a MultiDiGraph (``defaultedgetype="directed"``) or
    a MultiGraph (anything else). At the end it is converted to DiGraph/Graph
    if no parallel edges were seen while adding edges.

    Raises
    ------
    NetworkXError
        If an attributes element has a ``class`` other than "node" or "edge".
    """
    # start with empty MultiDiGraph or MultiGraph
    edgedefault = graph_xml.get("defaultedgetype", None)
    if edgedefault == "directed":
        G = nx.MultiDiGraph()
    else:
        G = nx.MultiGraph()

    # graph attributes
    graph_name = graph_xml.get("name", "")
    if graph_name != "":
        G.graph["name"] = graph_name
    graph_start = graph_xml.get("start")
    if graph_start is not None:
        G.graph["start"] = graph_start
    graph_end = graph_xml.get("end")
    if graph_end is not None:
        G.graph["end"] = graph_end
    graph_mode = graph_xml.get("mode", "")
    if graph_mode == "dynamic":
        G.graph["mode"] = "dynamic"
    else:
        G.graph["mode"] = "static"

    # timeformat; "date" values are kept as strings
    self.timeformat = graph_xml.get("timeformat")
    if self.timeformat == "date":
        self.timeformat = "string"

    # node and edge attributes
    attributes_elements = graph_xml.findall(f"{{{self.NS_GEXF}}}attributes")
    # dictionaries to hold attributes and attribute defaults
    node_attr = {}
    node_default = {}
    edge_attr = {}
    edge_default = {}
    for a in attributes_elements:
        attr_class = a.get("class")
        if attr_class == "node":
            na, nd = self.find_gexf_attributes(a)
            node_attr.update(na)
            node_default.update(nd)
            G.graph["node_default"] = node_default
        elif attr_class == "edge":
            ea, ed = self.find_gexf_attributes(a)
            edge_attr.update(ea)
            edge_default.update(ed)
            G.graph["edge_default"] = edge_default
        else:
            # A bare ``raise`` here has no active exception to re-raise and
            # only produced an unhelpful RuntimeError; report the cause.
            raise nx.NetworkXError(
                f"Unknown attribute class {attr_class!r} in GEXF file."
            )

    # Hack to handle Gephi0.7beta bug
    # add weight attribute
    edge_attr["weight"] = {"type": "double", "mode": "static", "title": "weight"}
    G.graph["edge_default"] = edge_default

    # add nodes
    nodes_element = graph_xml.find(f"{{{self.NS_GEXF}}}nodes")
    if nodes_element is not None:
        for node_xml in nodes_element.findall(f"{{{self.NS_GEXF}}}node"):
            self.add_node(G, node_xml, node_attr)

    # add edges
    edges_element = graph_xml.find(f"{{{self.NS_GEXF}}}edges")
    if edges_element is not None:
        for edge_xml in edges_element.findall(f"{{{self.NS_GEXF}}}edge"):
            self.add_edge(G, edge_xml, edge_attr)

    # switch to Graph or DiGraph if no parallel edges were found.
    if self.simple_graph:
        if G.is_directed():
            G = nx.DiGraph(G)
        else:
            G = nx.Graph(G)
    return G
773 | |
def add_node(self, G, node_xml, node_attr, node_pid=None):
    """Add the node described by `node_xml` (and any nested subnodes) to G.

    `node_pid` is the id of the enclosing node when called recursively; an
    explicit ``pid`` XML attribute takes precedence over it. Subnodes are
    added before their parent.
    """
    # pick the time-interval reader for this GEXF version
    read_intervals = self.add_slices if self.VERSION == "1.1" else self.add_spells
    data = self.decode_attr_elements(node_attr, node_xml)
    for step in (self.add_parents, read_intervals, self.add_viz, self.add_start_end):
        data = step(data, node_xml)

    raw_id = node_xml.get("id")
    node_id = raw_id if self.node_type is None else self.node_type(raw_id)

    # every node should have a label
    data["label"] = node_xml.get("label")

    parent_id = node_xml.get("pid", node_pid)
    if parent_id is not None:
        data["pid"] = parent_id

    # check for subnodes, recursive
    subnodes = node_xml.find(f"{{{self.NS_GEXF}}}nodes")
    if subnodes is not None:
        for child_xml in subnodes.findall(f"{{{self.NS_GEXF}}}node"):
            self.add_node(G, child_xml, node_attr, node_pid=node_id)

    G.add_node(node_id, **data)
808 | |
def add_start_end(self, data, xml):
    """Store the element's ``start``/``end`` attributes in `data`.

    Each value present is converted with ``self.python_type[self.timeformat]``.
    The converter is looked up only when an attribute is present.
    """
    for bound in ("start", "end"):
        raw = xml.get(bound)
        if raw is not None:
            data[bound] = self.python_type[self.timeformat](raw)
    return data
819 | |
def add_viz(self, data, node_xml):
    """Collect viz: subelements of `node_xml` into ``data["viz"]``.

    Only the keys found are set: "color" (ints, plus float alpha defaulting
    to 1 for version 1.2), "size", "thickness", "shape" (the image URI when
    the shape is "image") and "position" (missing coordinates become 0.0).
    `data` is returned unchanged when no viz element is present.
    """
    viz = {}
    color = node_xml.find(f"{{{self.NS_VIZ}}}color")
    if color is not None:
        if self.VERSION == "1.1":
            viz["color"] = {
                "r": int(color.get("r")),
                "g": int(color.get("g")),
                "b": int(color.get("b")),
            }
        else:
            viz["color"] = {
                "r": int(color.get("r")),
                "g": int(color.get("g")),
                "b": int(color.get("b")),
                "a": float(color.get("a", 1)),
            }

    size = node_xml.find(f"{{{self.NS_VIZ}}}size")
    if size is not None:
        viz["size"] = float(size.get("value"))

    thickness = node_xml.find(f"{{{self.NS_VIZ}}}thickness")
    if thickness is not None:
        viz["thickness"] = float(thickness.get("value"))

    shape = node_xml.find(f"{{{self.NS_VIZ}}}shape")
    if shape is not None:
        # viz:shape stores its name in the "value" attribute (this is what
        # the writer emits); reading "shape" always yielded None. Fall back
        # to a "shape" attribute for files that used that spelling.
        shape_name = shape.get("value", shape.get("shape"))
        if shape_name == "image":
            viz["shape"] = shape.get("uri")
        else:
            viz["shape"] = shape_name

    position = node_xml.find(f"{{{self.NS_VIZ}}}position")
    if position is not None:
        viz["position"] = {
            "x": float(position.get("x", 0)),
            "y": float(position.get("y", 0)),
            "z": float(position.get("z", 0)),
        }

    if len(viz) > 0:
        data["viz"] = viz
    return data
864 | |
def add_parents(self, data, node_xml):
    """Store the ``for`` ids of any parents/parent elements in ``data["parents"]``."""
    parents_element = node_xml.find(f"{{{self.NS_GEXF}}}parents")
    if parents_element is None:
        return data
    data["parents"] = [
        parent.get("for")
        for parent in parents_element.findall(f"{{{self.NS_GEXF}}}parent")
    ]
    return data
873 | |
def add_slices(self, data, node_or_edge_xml):
    """Store GEXF 1.1 slices as raw ``(start, end)`` string pairs in ``data["slices"]``."""
    slices_element = node_or_edge_xml.find(f"{{{self.NS_GEXF}}}slices")
    if slices_element is None:
        return data
    data["slices"] = [
        (piece.get("start"), piece.get("end"))
        for piece in slices_element.findall(f"{{{self.NS_GEXF}}}slice")
    ]
    return data
883 | |
def add_spells(self, data, node_or_edge_xml):
    """Store GEXF 1.2 spells as ``(start, end)`` pairs in ``data["spells"]``.

    Present bounds are converted with ``self.python_type[self.timeformat]``.
    A missing bound (open-ended spell, which the writer produces when a
    bound is None) is kept as None instead of being passed to the converter,
    which raised TypeError for numeric time formats.
    """
    spells_element = node_or_edge_xml.find(f"{{{self.NS_GEXF}}}spells")
    if spells_element is not None:
        data["spells"] = []
        ttype = self.timeformat
        for s in spells_element.findall(f"{{{self.NS_GEXF}}}spell"):
            start = s.get("start")
            end = s.get("end")
            if start is not None:
                start = self.python_type[ttype](start)
            if end is not None:
                end = self.python_type[ttype](end)
            data["spells"].append((start, end))
    return data
894 | |
def add_edge(self, G, edge_element, edge_attr):
    """Add the edge described by `edge_element` to G.

    The GEXF ``id`` is stored as the "id" attribute and used as the edge key,
    unless a decoded ``networkx_key`` attribute supplies the key. A "mutual"
    edge is added in both directions. Seeing an already-present
    (source, target) pair marks the graph as a multigraph.

    Raises
    ------
    NetworkXError
        If the edge's explicit type contradicts the graph's direction.
    """
    edge_direction = edge_element.get("type")
    directed = G.is_directed()
    if directed and edge_direction == "undirected":
        raise nx.NetworkXError("Undirected edge found in directed graph.")
    if not directed and edge_direction == "directed":
        raise nx.NetworkXError("Directed edge found in undirected graph.")

    # Get source and target and recast type if required
    source, target = edge_element.get("source"), edge_element.get("target")
    if self.node_type is not None:
        source, target = self.node_type(source), self.node_type(target)

    data = self.add_start_end(
        self.decode_attr_elements(edge_attr, edge_element), edge_element
    )
    read_intervals = self.add_slices if self.VERSION == "1.1" else self.add_spells
    data = read_intervals(data, edge_element)

    edge_id = edge_element.get("id")
    if edge_id is not None:
        data["id"] = edge_id

    # a stored networkx_key attribute overrides the GEXF id as the key
    multigraph_key = data.pop("networkx_key", None)
    if multigraph_key is not None:
        edge_id = multigraph_key

    for attr_name, convert in (("weight", float), ("label", str)):
        raw = edge_element.get(attr_name)
        if raw is not None:
            data[attr_name] = convert(raw)

    if G.has_edge(source, target):
        # seen this edge before - this is a multigraph
        self.simple_graph = False
    G.add_edge(source, target, key=edge_id, **data)
    if edge_direction == "mutual":
        G.add_edge(target, source, key=edge_id, **data)
946 | |
def decode_attr_elements(self, gexf_keys, obj_xml):
    """Decode the ``<attvalues>`` of a node or edge element into a dict.

    Parameters
    ----------
    gexf_keys : dict
        Attribute declarations keyed by attribute id; each entry is a dict
        with ``"title"``, ``"type"`` and ``"mode"`` items (the shape built
        by ``find_gexf_attributes``).
    obj_xml : Element
        The ``<node>`` or ``<edge>`` element being read.

    Returns
    -------
    dict
        Maps attribute title to value. Values are converted with
        ``self.convert_bool`` for ``"boolean"`` attributes and with
        ``self.python_type[type]`` otherwise. Attributes whose mode is
        ``"dynamic"`` map to a list of ``(value, start, end)`` tuples, with
        ``start``/``end`` converted via ``self.python_type[self.timeformat]``.
        A bound missing from the ``<attvalue>`` (open interval) is stored as
        ``None``. All other attributes map to the single converted value.

    Raises
    ------
    NetworkXError
        If an ``<attvalue>`` refers to an id not declared in *gexf_keys*.
    """
    attr = {}
    # look for outer '<attvalues>' element
    attr_element = obj_xml.find(f"{{{self.NS_GEXF}}}attvalues")
    if attr_element is None:
        return attr
    for a in attr_element.findall(f"{{{self.NS_GEXF}}}attvalue"):
        key = a.get("for")  # "for" is required by the schema
        try:  # should be in our gexf_keys dictionary
            title = gexf_keys[key]["title"]
        except KeyError as e:
            raise nx.NetworkXError(f"No attribute defined for={key}.") from e
        atype = gexf_keys[key]["type"]
        value = a.get("value")
        if atype == "boolean":
            value = self.convert_bool[value]
        else:
            value = self.python_type[atype](value)
        if gexf_keys[key]["mode"] == "dynamic":
            # for dynamic graphs use list of three-tuples
            # [(value1,start1,end1), (value2,start2,end2), etc]
            ttype = self.timeformat
            start = a.get("start")
            end = a.get("end")
            # Open intervals omit a bound; converting None would raise
            # TypeError (int/float) or produce the string "None" (str).
            if start is not None:
                start = self.python_type[ttype](start)
            if end is not None:
                end = self.python_type[ttype](end)
            attr.setdefault(title, []).append((value, start, end))
        else:
            # for static graphs just assign the value
            attr[title] = value
    return attr
980 | |
def find_gexf_attributes(self, attributes_element):
    """Read attribute declarations and defaults from an ``<attributes>`` element.

    Parameters
    ----------
    attributes_element : Element
        An ``<attributes>`` element; its ``mode`` attribute is copied into
        every declaration it contains.

    Returns
    -------
    (dict, dict)
        First, a dict mapping each ``<attribute>`` ``id`` to
        ``{"title": ..., "type": ..., "mode": ...}``. Second, a dict mapping
        attribute title to its ``<default>`` value, converted with
        ``self.convert_bool`` for ``"boolean"`` types and with
        ``self.python_type[type]`` otherwise. Only attributes that declare
        a ``<default>`` appear in the second dict.
    """
    ns = self.NS_GEXF
    mode = attributes_element.get("mode")
    declarations = {}
    default_values = {}
    for decl in attributes_element.findall(f"{{{ns}}}attribute"):
        title, atype = decl.get("title"), decl.get("type")
        declarations[decl.get("id")] = {"title": title, "type": atype, "mode": mode}
        default_node = decl.find(f"{{{ns}}}default")
        if default_node is None:
            continue
        raw = default_node.text
        # Booleans go through the explicit lookup table because bool("false")
        # would be True.
        default_values[title] = (
            self.convert_bool[raw] if atype == "boolean" else self.python_type[atype](raw)
        )
    return declarations, default_values
1000 | |
1001 | |
def relabel_gexf_graph(G):
    """Relabel graph using "label" node keyword for node label.

    Parameters
    ----------
    G : graph
        A NetworkX graph read from GEXF data

    Returns
    -------
    H : graph
        A NetworkX graph with relabeled nodes

    Raises
    ------
    NetworkXError
        If node labels are missing or not unique while relabel=True.

    Notes
    -----
    This function relabels the nodes in a NetworkX graph with the
    "label" attribute. It also handles relabeling the specific GEXF
    node attributes "parents", and "pid". Each relabeled node keeps its
    former node key in its "id" attribute, and its "label" attribute is
    removed. A graph without nodes yields an empty relabeled copy.
    """
    # build mapping of node labels, do some error checking
    try:
        mapping = {u: G.nodes[u]["label"] for u in G}
    except KeyError as err:
        raise nx.NetworkXError(
            "Failed to relabel nodes: missing node labels found. Use relabel=False."
        ) from err
    # Compare sizes directly (rather than unpacking zip(*pairs)) so that a
    # graph with no nodes does not fail with a ValueError.
    if len(set(mapping.values())) != len(mapping):
        raise nx.NetworkXError(
            "Failed to relabel nodes: "
            "duplicate node labels found. "
            "Use relabel=False."
        )
    H = nx.relabel_nodes(G, mapping)
    # relabel attributes
    for n in G:
        m = mapping[n]
        H.nodes[m]["id"] = n
        H.nodes[m].pop("label")
        if "pid" in H.nodes[m]:
            H.nodes[m]["pid"] = mapping[G.nodes[n]["pid"]]
        if "parents" in H.nodes[m]:
            H.nodes[m]["parents"] = [mapping[p] for p in G.nodes[n]["parents"]]
    return H