comparison env/lib/python3.9/site-packages/networkx/readwrite/gexf.py @ 0:4f3585e2f14b draft default tip

"planemo upload commit 60cee0fc7c0cda8592644e1aad72851dec82c959"
author shellac
date Mon, 22 Mar 2021 18:12:50 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:4f3585e2f14b
1 """Read and write graphs in GEXF format.
2
3 GEXF (Graph Exchange XML Format) is a language for describing complex
4 network structures, their associated data and dynamics.
5
6 This implementation does not support mixed graphs (directed and
7 undirected edges together).
8
9 Format
10 ------
11 GEXF is an XML format. See https://gephi.org/gexf/format/schema.html for the
12 specification and https://gephi.org/gexf/format/basic.html for examples.
13 """
14 import itertools
15 import time
16
17 import networkx as nx
18 from networkx.utils import open_file
19
20 from xml.etree.ElementTree import (
21 Element,
22 ElementTree,
23 SubElement,
24 tostring,
25 register_namespace,
26 )
27
28 __all__ = ["write_gexf", "read_gexf", "relabel_gexf_graph", "generate_gexf"]
29
30
@open_file(1, mode="wb")
def write_gexf(G, path, encoding="utf-8", prettyprint=True, version="1.2draft"):
    """Serialize the graph G as a GEXF document and write it to path.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Attributes that are part of the GEXF schema rather than user defined
    (for instance the visualization data stored under 'viz' [2]_) are
    written according to the selected schema version. See the example
    below for how to provide them.

    Parameters
    ----------
    G : graph
        A NetworkX graph
    path : file or string
        File or file name to write.
        File names ending in .gz or .bz2 will be compressed.
    encoding : string (optional, default: 'utf-8')
        Encoding for text data.
    prettyprint : bool (optional, default: True)
        If True use line breaks and indenting in output XML.
    version : string (optional, default: '1.2draft')
        Version of the GEXF schema to write.
        Supported values: "1.1draft", "1.2draft"

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> nx.write_gexf(G, "test.gexf")

    # visualization data
    >>> G.nodes[0]["viz"] = {"size": 54}
    >>> G.nodes[0]["viz"]["position"] = {"x": 0, "y": 1}
    >>> G.nodes[0]["viz"]["color"] = {"r": 0, "g": 0, "b": 255}


    Notes
    -----
    Mixed graphs (directed and undirected edges together) are not
    supported by this implementation.

    The node id attribute is set to be the string of the node label.
    To choose a different id, store it as node data, e.g.
    node['a']['id']=1 sets the id of node 'a' to 1.

    References
    ----------
    .. [1] GEXF File Format, https://gephi.org/gexf/format/
    .. [2] GEXF viz schema 1.1, https://gephi.org/gexf/1.1draft/viz
    """
    writer_options = {
        "encoding": encoding,
        "prettyprint": prettyprint,
        "version": version,
    }
    gexf_writer = GEXFWriter(**writer_options)
    # Build the XML tree in memory first, then serialize it to the handle
    # that the open_file decorator supplies for `path`.
    gexf_writer.add_graph(G)
    gexf_writer.write(path)
82
83
def generate_gexf(G, encoding="utf-8", prettyprint=True, version="1.2draft"):
    """Yield the GEXF representation of G one line at a time.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Parameters
    ----------
    G : graph
        A NetworkX graph
    encoding : string (optional, default: 'utf-8')
        Encoding for text data.
    prettyprint : bool (optional, default: True)
        If True use line breaks and indenting in output XML.
    version : string (default: 1.2draft)
        Version of GEXF File Format (see https://gephi.org/gexf/format/schema.html)
        Supported values: "1.1draft", "1.2draft"


    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> linefeed = chr(10)  # linefeed=\\n
    >>> s = linefeed.join(nx.generate_gexf(G))  # doctest: +SKIP
    >>> for line in nx.generate_gexf(G):  # doctest: +SKIP
    ...     print(line)

    Notes
    -----
    Mixed graphs (directed and undirected edges together) are not
    supported by this implementation.

    The node id attribute is set to be the string of the node label.
    To choose a different id, store it as node data, e.g.
    node['a']['id']=1 sets the id of node 'a' to 1.

    References
    ----------
    .. [1] GEXF File Format, https://gephi.org/gexf/format/
    """
    gexf_writer = GEXFWriter(
        encoding=encoding, prettyprint=prettyprint, version=version
    )
    gexf_writer.add_graph(G)
    # The whole document is rendered to one string and then split, so the
    # generator only starts producing lines once the tree is complete.
    document_text = str(gexf_writer)
    for line in document_text.splitlines():
        yield line
127
128
@open_file(0, mode="rb")
def read_gexf(path, node_type=None, relabel=False, version="1.2draft"):
    """Parse a GEXF document from path and return it as a NetworkX graph.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Parameters
    ----------
    path : file or string
        File or file name to read.
        File names ending in .gz or .bz2 will be decompressed.
    node_type: Python type (default: None)
        Convert node ids to this type if not None.
    relabel : bool (default: False)
        If True relabel the nodes to use the GEXF node "label" attribute
        instead of the node "id" attribute as the NetworkX node label.
    version : string (default: 1.2draft)
        Version of GEXF File Format (see https://gephi.org/gexf/format/schema.html)
        Supported values: "1.1draft", "1.2draft"

    Returns
    -------
    graph: NetworkX graph
        If no parallel edges are found a Graph or DiGraph is returned.
        Otherwise a MultiGraph or MultiDiGraph is returned.

    Notes
    -----
    Mixed graphs (directed and undirected edges together) are not
    supported by this implementation.

    References
    ----------
    .. [1] GEXF File Format, https://gephi.org/gexf/format/
    """
    parse = GEXFReader(node_type=node_type, version=version)
    graph = parse(path)
    # Relabeling is a post-processing step on the graph that was just read.
    return relabel_gexf_graph(graph) if relabel else graph
171
172
class GEXF:
    """Constants and type tables shared by the GEXF reader and writer.

    Class attributes
    ----------------
    versions : dict
        Maps a supported version name ("1.1draft", "1.2draft") to the
        namespace, schema location and version string used for it.
    types : list of (python type, GEXF type name) pairs
        Source table for the two lookup dictionaries below.
    xml_type : dict
        Python type -> GEXF type name, used when writing.
    python_type : dict
        GEXF type name -> Python type, used when reading.
    convert_bool : dict
        Accepted spellings of boolean values -> bool.
    """

    versions = {}
    d = {
        "NS_GEXF": "http://www.gexf.net/1.1draft",
        "NS_VIZ": "http://www.gexf.net/1.1draft/viz",
        "NS_XSI": "http://www.w3.org/2001/XMLSchema-instance",
        "SCHEMALOCATION": " ".join(
            ["http://www.gexf.net/1.1draft", "http://www.gexf.net/1.1draft/gexf.xsd"]
        ),
        "VERSION": "1.1",
    }
    versions["1.1draft"] = d
    d = {
        "NS_GEXF": "http://www.gexf.net/1.2draft",
        "NS_VIZ": "http://www.gexf.net/1.2draft/viz",
        "NS_XSI": "http://www.w3.org/2001/XMLSchema-instance",
        "SCHEMALOCATION": " ".join(
            ["http://www.gexf.net/1.2draft", "http://www.gexf.net/1.2draft/gexf.xsd"]
        ),
        "VERSION": "1.2",
    }
    versions["1.2draft"] = d

    # Order matters: both lookup tables below are built with dict(), so for
    # a repeated key the LAST pair wins (e.g. int is written as "long" and
    # "string" is read back as str).
    types = [
        (int, "integer"),
        (float, "float"),
        (float, "double"),
        (bool, "boolean"),
        (list, "string"),
        (dict, "string"),
        (int, "long"),
        (str, "liststring"),
        (str, "anyURI"),
        (str, "string"),
    ]

    # These additions to types allow writing numpy types
    try:
        import numpy as np
    except ImportError:
        pass
    else:
        # prepend so that python types are created upon read (last entry wins)
        # NOTE: np.float_ is deliberately absent. It was only an alias of
        # np.float64 and was removed in NumPy 2.0; referencing it raised
        # AttributeError here and made the whole module unimportable.
        types = [
            (np.float64, "float"),
            (np.float32, "float"),
            (np.float16, "float"),
            (np.int_, "int"),
            (np.int8, "int"),
            (np.int16, "int"),
            (np.int32, "int"),
            (np.int64, "int"),
            (np.uint8, "int"),
            (np.uint16, "int"),
            (np.uint32, "int"),
            (np.uint64, "int"),
            (np.intc, "int"),
            (np.intp, "int"),
        ] + types

    xml_type = dict(types)
    python_type = dict(reversed(a) for a in types)

    # http://www.w3.org/TR/xmlschema-2/#boolean
    convert_bool = {
        "true": True,
        "false": False,
        "True": True,
        "False": False,
        "0": False,
        0: False,
        "1": True,
        1: True,
    }

    def set_version(self, version):
        """Select the GEXF schema version used by this instance.

        Copies the constants registered for ``version`` in ``versions``
        onto the instance (NS_GEXF, NS_VIZ, NS_XSI, SCHEMALOCATION,
        VERSION) and records the version name in ``self.version``.

        Raises
        ------
        NetworkXError
            If ``version`` is not a key of ``versions``.
        """
        d = self.versions.get(version)
        if d is None:
            raise nx.NetworkXError(f"Unknown GEXF version {version}.")
        self.NS_GEXF = d["NS_GEXF"]
        self.NS_VIZ = d["NS_VIZ"]
        self.NS_XSI = d["NS_XSI"]
        self.SCHEMALOCATION = d["SCHEMALOCATION"]
        self.VERSION = d["VERSION"]
        self.version = version
260
261
class GEXFWriter(GEXF):
    """Build a GEXF XML tree from NetworkX graphs and serialize it.

    This is the implementation behind write_gexf() and generate_gexf();
    prefer those functions over using the class directly.

    Parameters
    ----------
    graph : graph, optional
        If given, it is added to the document immediately.
    encoding : string (default: 'utf-8')
        Encoding used when serializing the document.
    prettyprint : bool (default: True)
        If True, indent the XML before serializing.
    version : string (default: '1.2draft')
        GEXF schema version; must be a key of ``GEXF.versions``.
    """

    def __init__(
        self, graph=None, encoding="utf-8", prettyprint=True, version="1.2draft"
    ):
        self.prettyprint = prettyprint
        self.encoding = encoding
        self.set_version(version)
        self.xml = Element(
            "gexf",
            {
                "xmlns": self.NS_GEXF,
                "xmlns:xsi": self.NS_XSI,
                "xsi:schemaLocation": self.SCHEMALOCATION,
                "version": self.VERSION,
            },
        )

        # Make meta element a non-graph element
        # Also add lastmodifieddate as attribute, not tag
        meta_element = Element("meta")
        subelement_text = f"NetworkX {nx.__version__}"
        SubElement(meta_element, "creator").text = subelement_text
        meta_element.set("lastmodifieddate", time.strftime("%Y-%m-%d"))
        self.xml.append(meta_element)

        # Module-wide ElementTree registration so viz elements are
        # serialized with the "viz" prefix.
        register_namespace("viz", self.NS_VIZ)

        # counters for edge and attribute identifiers
        self.edge_id = itertools.count()
        self.attr_id = itertools.count()
        self.all_edge_ids = set()
        # default attributes are stored in dictionaries
        self.attr = {}
        self.attr["node"] = {}
        self.attr["edge"] = {}
        self.attr["node"]["dynamic"] = {}
        self.attr["node"]["static"] = {}
        self.attr["edge"]["dynamic"] = {}
        self.attr["edge"]["static"] = {}

        if graph is not None:
            self.add_graph(graph)

    def __str__(self):
        """Return the document as a string (indented if prettyprint is set)."""
        if self.prettyprint:
            self.indent(self.xml)
        s = tostring(self.xml).decode(self.encoding)
        return s

    def add_graph(self, G):
        """Append a <graph> element describing G to the document."""
        # first pass through G collecting edge ids, so that generated ids
        # never collide with ids the user supplied explicitly
        for u, v, dd in G.edges(data=True):
            eid = dd.get("id")
            if eid is not None:
                self.all_edge_ids.add(str(eid))
        # set graph attributes
        mode = "dynamic" if G.graph.get("mode") == "dynamic" else "static"
        # Add a graph element to the XML
        default = "directed" if G.is_directed() else "undirected"
        name = G.graph.get("name", "")
        graph_element = Element("graph", defaultedgetype=default, mode=mode, name=name)
        # Stored on the instance because get_attr_id() and
        # alter_graph_mode_timeformat() need to modify it while nodes and
        # edges are being added.
        self.graph_element = graph_element
        self.add_nodes(G, graph_element)
        self.add_edges(G, graph_element)
        self.xml.append(graph_element)

    def _pop_time_bounds(self, data, kw):
        # Move "start"/"end" (when present) from the data dict into the XML
        # attribute dict, switching the graph to dynamic mode if needed.
        for bound in ("start", "end"):
            if bound in data:
                value = data.pop(bound)
                kw[bound] = str(value)
                self.alter_graph_mode_timeformat(value)

    def add_nodes(self, G, graph_element):
        """Append a <nodes> element with one <node> per node of G."""
        nodes_element = Element("nodes")
        default = G.graph.get("node_default", {})
        for node, data in G.nodes(data=True):
            # work on a copy: the reserved keys are popped as they are used
            node_data = data.copy()
            node_id = str(node_data.pop("id", node))
            kw = {"id": node_id}
            label = str(node_data.pop("label", node))
            kw["label"] = label
            if "pid" in node_data:
                kw["pid"] = str(node_data.pop("pid"))
            self._pop_time_bounds(node_data, kw)
            # add node element with attributes
            node_element = Element("node", **kw)
            # add node element and attr subelements
            node_data = self.add_parents(node_element, node_data)
            if self.VERSION == "1.1":
                node_data = self.add_slices(node_element, node_data)
            else:
                node_data = self.add_spells(node_element, node_data)
            node_data = self.add_viz(node_element, node_data)
            # whatever is left is user data and becomes <attvalues>
            node_data = self.add_attributes("node", node_element, node_data, default)
            nodes_element.append(node_element)
        graph_element.append(nodes_element)

    def _edge_key_data(self, G):
        # Unify the multigraph and graph edge iterators: yields
        # (u, v, edge_id, edge_data) where edge_data is a private copy.
        if G.is_multigraph():
            edges = G.edges(data=True, keys=True)
        else:
            edges = ((u, v, None, data) for u, v, data in G.edges(data=True))
        multigraph = G.is_multigraph()
        for u, v, key, data in edges:
            edge_data = data.copy()
            if multigraph:
                # keep the multigraph key; add_attributes() stores it as
                # "networkx_key"
                edge_data.update(key=key)
            edge_id = edge_data.pop("id", None)
            if edge_id is None:
                # generate an id that is not already in use
                edge_id = next(self.edge_id)
                while str(edge_id) in self.all_edge_ids:
                    edge_id = next(self.edge_id)
                self.all_edge_ids.add(str(edge_id))
            yield u, v, edge_id, edge_data

    def add_edges(self, G, graph_element):
        """Append an <edges> element with one <edge> per edge of G."""
        edges_element = Element("edges")
        default = G.graph.get("edge_default", {})
        for u, v, key, edge_data in self._edge_key_data(G):
            kw = {"id": str(key)}
            # these keys are XML attributes of <edge>, not user attributes
            for name in ("label", "weight", "type"):
                if name in edge_data:
                    kw[name] = str(edge_data.pop(name))
            self._pop_time_bounds(edge_data, kw)
            # endpoints must use the same id that add_nodes() wrote
            source_id = str(G.nodes[u].get("id", u))
            target_id = str(G.nodes[v].get("id", v))
            edge_element = Element("edge", source=source_id, target=target_id, **kw)
            if self.VERSION == "1.1":
                edge_data = self.add_slices(edge_element, edge_data)
            else:
                edge_data = self.add_spells(edge_element, edge_data)
            edge_data = self.add_viz(edge_element, edge_data)
            edge_data = self.add_attributes("edge", edge_element, edge_data, default)
            edges_element.append(edge_element)
        graph_element.append(edges_element)

    @staticmethod
    def _xml_value(value):
        # Text form of an attribute value: lowercase booleans and the XML
        # Schema spellings for the special float values.
        if isinstance(value, bool):
            return str(value).lower()
        text = str(value)
        if isinstance(value, float):
            return {"inf": "INF", "-inf": "-INF", "nan": "NaN"}.get(text, text)
        return text

    def add_attributes(self, node_or_edge, xml_obj, data, default):
        """Append an <attvalues> element for the user data of a node or edge.

        Parameters
        ----------
        node_or_edge : "node" or "edge"
            Attribute class the values belong to.
        xml_obj : Element
            The <node> or <edge> element to extend.
        data : dict
            Remaining user attributes. A list value is treated as dynamic
            data made of (value, start, end) triples.
        default : dict
            Default values by attribute title (see get_attr_id).

        Returns
        -------
        data, unchanged.

        Raises
        ------
        TypeError
            If a value has a type with no entry in ``xml_type``.
        """
        if not data:
            return data
        attvalues = Element("attvalues")
        for k, v in data.items():
            # rename generic multigraph key to avoid any name conflict
            if k == "key":
                k = "networkx_key"
            val_type = type(v)
            if val_type not in self.xml_type:
                raise TypeError(f"attribute value type is not allowed: {val_type}")
            if isinstance(v, list):
                # dynamic data; decide the mode for THIS attribute only, so a
                # preceding dynamic attribute cannot leak its mode into it
                mode = "static"
                for val, start, end in v:
                    val_type = type(val)
                    if start is not None or end is not None:
                        mode = "dynamic"
                        self.alter_graph_mode_timeformat(start)
                        self.alter_graph_mode_timeformat(end)
                        break
                attr_id = self.get_attr_id(
                    str(k), self.xml_type[val_type], node_or_edge, default, mode
                )
                for val, start, end in v:
                    e = Element("attvalue")
                    e.attrib["for"] = attr_id
                    e.attrib["value"] = self._xml_value(val)
                    if start is not None:
                        e.attrib["start"] = str(start)
                    if end is not None:
                        e.attrib["end"] = str(end)
                    attvalues.append(e)
            else:
                # static data
                mode = "static"
                attr_id = self.get_attr_id(
                    str(k), self.xml_type[val_type], node_or_edge, default, mode
                )
                e = Element("attvalue")
                e.attrib["for"] = attr_id
                e.attrib["value"] = self._xml_value(v)
                attvalues.append(e)
        xml_obj.append(attvalues)
        return data

    def get_attr_id(self, title, attr_type, edge_or_node, default, mode):
        """Return the id of an attribute declaration, creating it if needed.

        Declarations are cached per class ("node"/"edge"), mode and title.
        A new declaration is appended to the matching <attributes> element
        of the current graph, which is created on demand; a <default>
        child is added when ``default`` has a non-None entry for ``title``.
        """
        # find the id of the attribute or generate a new id
        try:
            return self.attr[edge_or_node][mode][title]
        except KeyError:
            # generate new id
            new_id = str(next(self.attr_id))
            self.attr[edge_or_node][mode][title] = new_id
            attr_kwargs = {"id": new_id, "title": title, "type": attr_type}
            attribute = Element("attribute", **attr_kwargs)
            # add subelement for data default value if present
            default_title = default.get(title)
            if default_title is not None:
                default_element = Element("default")
                default_element.text = str(default_title)
                attribute.append(default_element)
            # now insert it into the XML
            attributes_element = None
            for a in self.graph_element.findall("attributes"):
                # find existing attributes element by class and mode
                a_class = a.get("class")
                a_mode = a.get("mode", "static")
                if a_class == edge_or_node and a_mode == mode:
                    attributes_element = a
            if attributes_element is None:
                # create new attributes element
                attr_kwargs = {"mode": mode, "class": edge_or_node}
                attributes_element = Element("attributes", **attr_kwargs)
                self.graph_element.insert(0, attributes_element)
            attributes_element.append(attribute)
            return new_id

    def add_viz(self, element, node_data):
        """Write the "viz" entry of node_data as viz:* child elements.

        The "viz" key is removed from node_data, which is returned.
        Recognized entries: color, size, thickness, shape, position.
        """
        viz = node_data.pop("viz", False)
        if viz:
            color = viz.get("color")
            if color is not None:
                if self.VERSION == "1.1":
                    e = Element(
                        f"{{{self.NS_VIZ}}}color",
                        r=str(color.get("r")),
                        g=str(color.get("g")),
                        b=str(color.get("b")),
                    )
                else:
                    # A color given without alpha is written as opaque;
                    # writing str(None) produced a="None", which the
                    # reader cannot convert back to a float.
                    e = Element(
                        f"{{{self.NS_VIZ}}}color",
                        r=str(color.get("r")),
                        g=str(color.get("g")),
                        b=str(color.get("b")),
                        a=str(color.get("a", 1.0)),
                    )
                element.append(e)

            size = viz.get("size")
            if size is not None:
                e = Element(f"{{{self.NS_VIZ}}}size", value=str(size))
                element.append(e)

            thickness = viz.get("thickness")
            if thickness is not None:
                e = Element(f"{{{self.NS_VIZ}}}thickness", value=str(thickness))
                element.append(e)

            shape = viz.get("shape")
            if shape is not None:
                if shape.startswith("http"):
                    # a URL is written as an image shape pointing at it
                    e = Element(
                        f"{{{self.NS_VIZ}}}shape", value="image", uri=str(shape)
                    )
                else:
                    e = Element(f"{{{self.NS_VIZ}}}shape", value=str(shape))
                element.append(e)

            position = viz.get("position")
            if position is not None:
                # Missing components are written as 0 (the same default the
                # reader applies) instead of the unparsable text "None".
                e = Element(
                    f"{{{self.NS_VIZ}}}position",
                    x=str(position.get("x", 0)),
                    y=str(position.get("y", 0)),
                    z=str(position.get("z", 0)),
                )
                element.append(e)
        return node_data

    def add_parents(self, node_element, node_data):
        """Write the "parents" entry of node_data as a <parents> element."""
        parents = node_data.pop("parents", False)
        if parents:
            parents_element = Element("parents")
            for p in parents:
                e = Element("parent")
                e.attrib["for"] = str(p)
                parents_element.append(e)
            node_element.append(parents_element)
        return node_data

    def add_slices(self, node_or_edge_element, node_or_edge_data):
        """Write the "slices" entry (pairs of start, end) as <slices>."""
        slices = node_or_edge_data.pop("slices", False)
        if slices:
            slices_element = Element("slices")
            for start, end in slices:
                e = Element("slice", start=str(start), end=str(end))
                slices_element.append(e)
            node_or_edge_element.append(slices_element)
        return node_or_edge_data

    def add_spells(self, node_or_edge_element, node_or_edge_data):
        """Write the "spells" entry (pairs of start, end) as <spells>.

        A bound that is None is omitted from the <spell> element.
        """
        spells = node_or_edge_data.pop("spells", False)
        if spells:
            spells_element = Element("spells")
            for start, end in spells:
                e = Element("spell")
                if start is not None:
                    e.attrib["start"] = str(start)
                    self.alter_graph_mode_timeformat(start)
                if end is not None:
                    e.attrib["end"] = str(end)
                    self.alter_graph_mode_timeformat(end)
                spells_element.append(e)
            node_or_edge_element.append(spells_element)
        return node_or_edge_data

    def alter_graph_mode_timeformat(self, start_or_end):
        """Switch the current graph to dynamic mode when a time value appears.

        Only acts while the graph element is still "static" and the value
        is not None. The timeformat is derived from the value's type:
        str -> "date", float -> "double", int -> "long".

        Raises
        ------
        NetworkXError
            If the value is of any other type.
        """
        if self.graph_element.get("mode") == "static":
            if start_or_end is not None:
                if isinstance(start_or_end, str):
                    timeformat = "date"
                elif isinstance(start_or_end, float):
                    timeformat = "double"
                elif isinstance(start_or_end, int):
                    timeformat = "long"
                else:
                    raise nx.NetworkXError(
                        "timeformat should be of the type int, float or str"
                    )
                self.graph_element.set("timeformat", timeformat)
                self.graph_element.set("mode", "dynamic")

    def write(self, fh):
        """Serialize the document, with an XML declaration, to the open fh."""
        if self.prettyprint:
            self.indent(self.xml)
        document = ElementTree(self.xml)
        document.write(fh, encoding=self.encoding, xml_declaration=True)

    def indent(self, elem, level=0):
        """Indent elem and its descendants in place for pretty printing."""
        i = "\n" + " " * level
        if len(elem):
            if not elem.text or not elem.text.strip():
                elem.text = i + " "
            if not elem.tail or not elem.tail.strip():
                elem.tail = i
            child = None
            for child in elem:
                self.indent(child, level + 1)
            # the last child closes the parent, so its tail goes back to
            # the parent's indentation level
            if not child.tail or not child.tail.strip():
                child.tail = i
        else:
            if level and (not elem.tail or not elem.tail.strip()):
                elem.tail = i
671
672
class GEXFReader(GEXF):
    """Parse a GEXF document into a NetworkX graph.

    This is the implementation behind read_gexf(); prefer that function
    over using the class directly. Instances are callable with an open
    stream (or anything ElementTree accepts as ``file``).

    Parameters
    ----------
    node_type : Python type (default: None)
        If not None, node ids are converted with this callable.
    version : string (default: '1.2draft')
        GEXF schema version tried first; must be a key of
        ``GEXF.versions``.
    """

    def __init__(self, node_type=None, version="1.2draft"):
        self.node_type = node_type
        # assume simple graph and test for multigraph on read
        self.simple_graph = True
        self.set_version(version)

    def __call__(self, stream):
        """Parse stream and return the graph of its <graph> element.

        Raises
        ------
        NetworkXError
            If no <graph> element is found under any known GEXF namespace.
        """
        self.xml = ElementTree(file=stream)
        g = self.xml.find(f"{{{self.NS_GEXF}}}graph")
        if g is not None:
            return self.make_graph(g)
        # The document may use another schema version than the requested
        # one: try all the versions.
        for version in self.versions:
            self.set_version(version)
            g = self.xml.find(f"{{{self.NS_GEXF}}}graph")
            if g is not None:
                return self.make_graph(g)
        raise nx.NetworkXError("No <graph> element in GEXF file.")

    def make_graph(self, graph_xml):
        """Build a NetworkX graph from a <graph> element.

        Returns a MultiGraph/MultiDiGraph if parallel edges were seen,
        otherwise a Graph/DiGraph.

        Raises
        ------
        NetworkXError
            If an <attributes> element has a class other than "node" or
            "edge", or if an edge direction conflicts with the graph type.
        """
        # start with empty DiGraph or MultiDiGraph
        edgedefault = graph_xml.get("defaultedgetype", None)
        if edgedefault == "directed":
            G = nx.MultiDiGraph()
        else:
            G = nx.MultiGraph()

        # graph attributes
        graph_name = graph_xml.get("name", "")
        if graph_name != "":
            G.graph["name"] = graph_name
        graph_start = graph_xml.get("start")
        if graph_start is not None:
            G.graph["start"] = graph_start
        graph_end = graph_xml.get("end")
        if graph_end is not None:
            G.graph["end"] = graph_end
        graph_mode = graph_xml.get("mode", "")
        if graph_mode == "dynamic":
            G.graph["mode"] = "dynamic"
        else:
            G.graph["mode"] = "static"

        # timeformat; "date" values are kept as strings
        self.timeformat = graph_xml.get("timeformat")
        if self.timeformat == "date":
            self.timeformat = "string"

        # node and edge attributes
        attributes_elements = graph_xml.findall(f"{{{self.NS_GEXF}}}attributes")
        # dictionaries to hold attributes and attribute defaults
        node_attr = {}
        node_default = {}
        edge_attr = {}
        edge_default = {}
        for a in attributes_elements:
            attr_class = a.get("class")
            if attr_class == "node":
                na, nd = self.find_gexf_attributes(a)
                node_attr.update(na)
                node_default.update(nd)
                G.graph["node_default"] = node_default
            elif attr_class == "edge":
                ea, ed = self.find_gexf_attributes(a)
                edge_attr.update(ea)
                edge_default.update(ed)
                G.graph["edge_default"] = edge_default
            else:
                # A bare `raise` here has no active exception to re-raise
                # and would only surface as an unrelated RuntimeError.
                raise nx.NetworkXError(
                    f"Unknown attribute class {attr_class!r} in GEXF file; "
                    "expected 'node' or 'edge'."
                )

        # Hack to handle Gephi0.7beta bug
        # add weight attribute
        ea = {"weight": {"type": "double", "mode": "static", "title": "weight"}}
        ed = {}
        edge_attr.update(ea)
        edge_default.update(ed)
        G.graph["edge_default"] = edge_default

        # add nodes
        nodes_element = graph_xml.find(f"{{{self.NS_GEXF}}}nodes")
        if nodes_element is not None:
            for node_xml in nodes_element.findall(f"{{{self.NS_GEXF}}}node"):
                self.add_node(G, node_xml, node_attr)

        # add edges
        edges_element = graph_xml.find(f"{{{self.NS_GEXF}}}edges")
        if edges_element is not None:
            for edge_xml in edges_element.findall(f"{{{self.NS_GEXF}}}edge"):
                self.add_edge(G, edge_xml, edge_attr)

        # switch to Graph or DiGraph if no parallel edges were found.
        if self.simple_graph:
            if G.is_directed():
                G = nx.DiGraph(G)
            else:
                G = nx.Graph(G)
        return G

    def _convert_time(self, value):
        # Convert a start/end attribute to the graph's timeformat type.
        # A missing bound (None) means an open-ended interval and stays
        # None instead of being fed to int()/float().
        if value is None:
            return None
        return self.python_type[self.timeformat](value)

    def add_node(self, G, node_xml, node_attr, node_pid=None):
        """Add the node described by node_xml (and its sub-nodes) to G.

        ``node_pid`` is the id of the enclosing node when called
        recursively; an explicit "pid" attribute takes precedence.
        """
        # get attributes and subattributes for node
        data = self.decode_attr_elements(node_attr, node_xml)
        data = self.add_parents(data, node_xml)  # add any parents
        if self.VERSION == "1.1":
            data = self.add_slices(data, node_xml)  # add slices
        else:
            data = self.add_spells(data, node_xml)  # add spells
        data = self.add_viz(data, node_xml)  # add viz
        data = self.add_start_end(data, node_xml)  # add start/end

        # find the node id and cast it to the appropriate type
        node_id = node_xml.get("id")
        if self.node_type is not None:
            node_id = self.node_type(node_id)

        # every node should have a label (None if the attribute is absent)
        node_label = node_xml.get("label")
        data["label"] = node_label

        # parent node id
        node_pid = node_xml.get("pid", node_pid)
        if node_pid is not None:
            data["pid"] = node_pid

        # check for subnodes, recursive
        subnodes = node_xml.find(f"{{{self.NS_GEXF}}}nodes")
        if subnodes is not None:
            for child_xml in subnodes.findall(f"{{{self.NS_GEXF}}}node"):
                self.add_node(G, child_xml, node_attr, node_pid=node_id)

        G.add_node(node_id, **data)

    def add_start_end(self, data, xml):
        """Copy the start/end attributes of xml, when present, into data."""
        node_start = xml.get("start")
        if node_start is not None:
            data["start"] = self._convert_time(node_start)
        node_end = xml.get("end")
        if node_end is not None:
            data["end"] = self._convert_time(node_end)
        return data

    def add_viz(self, data, node_xml):
        """Collect the viz:* children of node_xml into data["viz"].

        data["viz"] is only set when at least one viz element is present.
        """
        viz = {}
        color = node_xml.find(f"{{{self.NS_VIZ}}}color")
        if color is not None:
            if self.VERSION == "1.1":
                viz["color"] = {
                    "r": int(color.get("r")),
                    "g": int(color.get("g")),
                    "b": int(color.get("b")),
                }
            else:
                viz["color"] = {
                    "r": int(color.get("r")),
                    "g": int(color.get("g")),
                    "b": int(color.get("b")),
                    "a": float(color.get("a", 1)),
                }

        size = node_xml.find(f"{{{self.NS_VIZ}}}size")
        if size is not None:
            viz["size"] = float(size.get("value"))

        thickness = node_xml.find(f"{{{self.NS_VIZ}}}thickness")
        if thickness is not None:
            viz["thickness"] = float(thickness.get("value"))

        shape = node_xml.find(f"{{{self.NS_VIZ}}}shape")
        if shape is not None:
            # The shape name is stored in the "value" attribute (that is
            # what GEXFWriter.add_viz emits); "shape" is only consulted as
            # a fallback for files that used that attribute name.
            viz["shape"] = shape.get("value", shape.get("shape"))
            if viz["shape"] == "image":
                viz["shape"] = shape.get("uri")

        position = node_xml.find(f"{{{self.NS_VIZ}}}position")
        if position is not None:
            viz["position"] = {
                "x": float(position.get("x", 0)),
                "y": float(position.get("y", 0)),
                "z": float(position.get("z", 0)),
            }

        if len(viz) > 0:
            data["viz"] = viz
        return data

    def add_parents(self, data, node_xml):
        """Store the ids listed under <parents> as data["parents"]."""
        parents_element = node_xml.find(f"{{{self.NS_GEXF}}}parents")
        if parents_element is not None:
            data["parents"] = []
            for p in parents_element.findall(f"{{{self.NS_GEXF}}}parent"):
                parent = p.get("for")
                data["parents"].append(parent)
        return data

    def add_slices(self, data, node_or_edge_xml):
        """Store <slices> as data["slices"], a list of (start, end) strings."""
        slices_element = node_or_edge_xml.find(f"{{{self.NS_GEXF}}}slices")
        if slices_element is not None:
            data["slices"] = []
            for s in slices_element.findall(f"{{{self.NS_GEXF}}}slice"):
                start = s.get("start")
                end = s.get("end")
                data["slices"].append((start, end))
        return data

    def add_spells(self, data, node_or_edge_xml):
        """Store <spells> as data["spells"], a list of (start, end) pairs.

        Bounds are converted to the graph's timeformat type; a bound that
        is absent from the <spell> element is returned as None.
        """
        spells_element = node_or_edge_xml.find(f"{{{self.NS_GEXF}}}spells")
        if spells_element is not None:
            data["spells"] = []
            for s in spells_element.findall(f"{{{self.NS_GEXF}}}spell"):
                start = self._convert_time(s.get("start"))
                end = self._convert_time(s.get("end"))
                data["spells"].append((start, end))
        return data

    def add_edge(self, G, edge_element, edge_attr):
        """Add the edge described by edge_element to G.

        Raises
        ------
        NetworkXError
            If the edge's "type" contradicts the directedness of G.
        """
        # raise error if we find mixed directed and undirected edges
        edge_direction = edge_element.get("type")
        if G.is_directed() and edge_direction == "undirected":
            raise nx.NetworkXError("Undirected edge found in directed graph.")
        if (not G.is_directed()) and edge_direction == "directed":
            raise nx.NetworkXError("Directed edge found in undirected graph.")

        # Get source and target and recast type if required
        source = edge_element.get("source")
        target = edge_element.get("target")
        if self.node_type is not None:
            source = self.node_type(source)
            target = self.node_type(target)

        data = self.decode_attr_elements(edge_attr, edge_element)
        data = self.add_start_end(data, edge_element)

        if self.VERSION == "1.1":
            data = self.add_slices(data, edge_element)  # add slices
        else:
            data = self.add_spells(data, edge_element)  # add spells

        # GEXF stores edge ids as an attribute
        # NetworkX uses them as keys in multigraphs
        # if networkx_key is not specified as an attribute
        edge_id = edge_element.get("id")
        if edge_id is not None:
            data["id"] = edge_id

        # check if there is a 'multigraph_key' and use that as edge_id
        multigraph_key = data.pop("networkx_key", None)
        if multigraph_key is not None:
            edge_id = multigraph_key

        weight = edge_element.get("weight")
        if weight is not None:
            data["weight"] = float(weight)

        edge_label = edge_element.get("label")
        if edge_label is not None:
            data["label"] = edge_label

        if G.has_edge(source, target):
            # seen this edge before - this is a multigraph
            self.simple_graph = False
        G.add_edge(source, target, key=edge_id, **data)
        if edge_direction == "mutual":
            G.add_edge(target, source, key=edge_id, **data)

    def decode_attr_elements(self, gexf_keys, obj_xml):
        """Decode the <attvalues> of a node or edge into a dict.

        Parameters
        ----------
        gexf_keys : dict
            Attribute declarations by id, as built by
            find_gexf_attributes().
        obj_xml : Element
            The <node> or <edge> element.

        Returns
        -------
        dict mapping attribute title to its value. Dynamic attributes map
        to a list of (value, start, end) triples, where a missing bound is
        None.

        Raises
        ------
        NetworkXError
            If an <attvalue> refers to an undeclared attribute id.
        """
        attr = {}
        # look for outer '<attvalues>' element
        attr_element = obj_xml.find(f"{{{self.NS_GEXF}}}attvalues")
        if attr_element is not None:
            # loop over <attvalue> elements
            for a in attr_element.findall(f"{{{self.NS_GEXF}}}attvalue"):
                key = a.get("for")  # for is required
                try:  # should be in our gexf_keys dictionary
                    title = gexf_keys[key]["title"]
                except KeyError as e:
                    raise nx.NetworkXError(f"No attribute defined for={key}.") from e
                atype = gexf_keys[key]["type"]
                value = a.get("value")
                if atype == "boolean":
                    value = self.convert_bool[value]
                else:
                    value = self.python_type[atype](value)
                if gexf_keys[key]["mode"] == "dynamic":
                    # for dynamic graphs use list of three-tuples
                    # [(value1,start1,end1), (value2,start2,end2), etc]
                    start = self._convert_time(a.get("start"))
                    end = self._convert_time(a.get("end"))
                    if title in attr:
                        attr[title].append((value, start, end))
                    else:
                        attr[title] = [(value, start, end)]
                else:
                    # for static graphs just assign the value
                    attr[title] = value
        return attr

    def find_gexf_attributes(self, attributes_element):
        """Extract attribute declarations and defaults from <attributes>.

        Returns
        -------
        (attrs, defaults) : tuple of dicts
            ``attrs`` maps attribute id to a dict with "title", "type" and
            "mode"; ``defaults`` maps attribute title to its converted
            default value, for declarations that have a <default> child.
        """
        attrs = {}
        defaults = {}
        mode = attributes_element.get("mode")
        for k in attributes_element.findall(f"{{{self.NS_GEXF}}}attribute"):
            attr_id = k.get("id")
            title = k.get("title")
            atype = k.get("type")
            attrs[attr_id] = {"title": title, "type": atype, "mode": mode}
            # check for the 'default' subelement of key element and add
            default = k.find(f"{{{self.NS_GEXF}}}default")
            if default is not None:
                if atype == "boolean":
                    value = self.convert_bool[default.text]
                else:
                    value = self.python_type[atype](default.text)
                defaults[title] = value
        return attrs, defaults
1000
1001
def relabel_gexf_graph(G):
    """Relabel graph using "label" node keyword for node label.

    Parameters
    ----------
    G : graph
        A NetworkX graph read from GEXF data

    Returns
    -------
    H : graph
        A NetworkX graph with relabeled nodes. The original node key is
        kept as the "id" node attribute and the "label" attribute is
        removed. An empty graph is returned as an empty copy.

    Raises
    ------
    NetworkXError
        If node labels are missing or not unique while relabel=True.

    Notes
    -----
    This function relabels the nodes in a NetworkX graph with the
    "label" attribute.  It also handles relabeling the specific GEXF
    node attributes "parents", and "pid".
    """
    # build mapping of node labels, do some error checking
    try:
        mapping = [(u, G.nodes[u]["label"]) for u in G]
    except KeyError as e:
        raise nx.NetworkXError(
            "Failed to relabel nodes: missing node labels found. Use relabel=False."
        ) from e
    # Collect the labels directly rather than unpacking zip(*mapping),
    # which raises ValueError when the graph has no nodes.
    labels = {label for _, label in mapping}
    if len(labels) != len(G):
        raise nx.NetworkXError(
            "Failed to relabel nodes: "
            "duplicate node labels found. "
            "Use relabel=False."
        )
    mapping = dict(mapping)
    H = nx.relabel_nodes(G, mapping)
    # relabel attributes
    for n in G:
        m = mapping[n]
        H.nodes[m]["id"] = n
        H.nodes[m].pop("label")
        # "pid" and "parents" hold node ids, so they must follow the
        # relabeling as well
        if "pid" in H.nodes[m]:
            H.nodes[m]["pid"] = mapping[G.nodes[n]["pid"]]
        if "parents" in H.nodes[m]:
            H.nodes[m]["parents"] = [mapping[p] for p in G.nodes[n]["parents"]]
    return H