pipeline_astrobert.py @ 0:a35056104c2c (draft, default, tip)

planemo upload for repository https://github.com/esg-epfl-apc/tools-astro/tree/main/tools commit da42ae0d18f550dec7f6d7e29d297e7cf1909df2
author: astroteam
date: Fri, 13 Jun 2025 13:26:36 +0000
import re
import pandas as pd
import numpy as np
import tempfile
from transformers import AutoModelForTokenClassification, AutoTokenizer
from transformers import TokenClassificationPipeline


def split_text_in_phrases(text_id, text_):
    # Split on sentence boundaries (". " followed by a capital letter); the
    # capture group keeps the delimiter so the text can be reconstructed.
    list_proto_phrases = re.split(r"(\. [A-Z])", text_)
    for i in range(1, len(list_proto_phrases) - 1, 2):
        back_ = list_proto_phrases[i][0]
        front_ = list_proto_phrases[i][-1]
        list_proto_phrases[i + 1] = front_ + list_proto_phrases[i + 1]
        list_proto_phrases[i - 1] = list_proto_phrases[i - 1] + back_

    list_phrases = []
    for i in range(0, len(list_proto_phrases), 2):
        list_phrases.append(list_proto_phrases[i])

    # Sanity check: re-joining the phrases with single spaces should
    # reproduce the input text; report the text id if it does not.
    text_check = " ".join(list_phrases)
    if text_check != text_:
        print(text_id)
    return list_phrases
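# Illustrative self-check for the splitter (added for documentation; the sample
# sentence is made up and not part of the original tool). Call it manually,
# e.g. `from pipeline_astrobert import _demo_split; _demo_split()`.
def _demo_split():
    phrases = split_text_in_phrases(
        "demo-id",
        "We observed the Crab Nebula. The flux was measured with XMM-Newton.",
    )
    # The capture group keeps the ". " delimiter, so the sentences come back intact.
    assert phrases == [
        "We observed the Crab Nebula.",
        "The flux was measured with XMM-Newton.",
    ]
    return phrases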
# Run the astroBERT NER-DEAL pipeline over every phrase of one text and
# collect the raw entity hits.
def apply_astroBERT(text_id, body_text_0):
    dict_out = {"TEXT_ID": [], "word": [], "start": [], "end": [], "score": [], "entity_group": [], "Phrase": []}

    tmpdir_ = tempfile.TemporaryDirectory()

    try:
        # load astroBERT for NER-DEAL
        remote_model_path = 'adsabs/astroBERT'
        # you need to load the astroBERT trained for NER-DEAL, which is on a separate branch
        revision = 'NER-DEAL'

        astroBERT_NER_DEAL = AutoModelForTokenClassification.from_pretrained(
            pretrained_model_name_or_path=remote_model_path,
            revision=revision,
            cache_dir=tmpdir_.name
        )

        astroBERT_tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=remote_model_path,
            add_special_tokens=True,
            do_lower_case=False,
            model_max_length=512,
            cache_dir=tmpdir_.name
        )

        # use the Hugging Face Pipeline class
        NER_pipeline = TokenClassificationPipeline(
            model=astroBERT_NER_DEAL,
            tokenizer=astroBERT_tokenizer,
            task='astroBERT NER_DEAL',
            aggregation_strategy='average',
            ignore_labels=['O']
        )

        # normalise whitespace and a few unicode characters before splitting
        text = " ".join(body_text_0.split()).replace("°", "o").replace("º", "o").replace("−", "-")
        list_phrases = split_text_in_phrases(text_id, text)

        for phrase_ in list_phrases:
            result = NER_pipeline(phrase_)

            for u in result:
                ent_ = u["entity_group"]
                if ent_ in ["Instrument", "Telescope", "Wavelength", "CelestialObject", "CelestialRegion", "EntityOfFutureInterest", "Mission", "Observatory", "Survey"]:
                    dict_out["TEXT_ID"].append(text_id)
                    dict_out["Phrase"].append(phrase_)

                    dict_out["word"].append(u["word"])
                    dict_out["score"].append(u["score"])
                    dict_out["start"].append(u["start"])
                    dict_out["end"].append(u["end"])
                    dict_out["entity_group"].append(ent_)
    except Exception as e:
        print(f"An error occurred in apply_astroBERT: {e}")
    finally:
        tmpdir_.cleanup()

    return pd.DataFrame(dict_out)
# Post-process the raw astroBERT output: within each phrase, merge adjacent
# hits of the same entity group into a single entry and average their scores.
def get_astroBERT_cleaned_result(text_id, body_text_0):
    list_entities = ["Instrument", "Telescope", "Wavelength", "CelestialObject", "CelestialRegion", "EntityOfFutureInterest", "Mission", "Observatory", "Survey"]

    df_raw = apply_astroBERT(text_id, body_text_0)
    dict_out = {"TEXT_ID": [], "word": [], "start": [], "end": [], "Score": [], "Phrase": [], "entity_group": []}

    for entity_to_study in list_entities:
        df_tmp0 = df_raw[df_raw["entity_group"] == entity_to_study]
        phrases_ = np.unique(df_tmp0["Phrase"])

        for phrase_ in phrases_:
            df_tmp1 = df_tmp0[df_tmp0["Phrase"] == phrase_]
            if len(df_tmp1) == 1:
                dict_out["TEXT_ID"].append(text_id)
                dict_out["Phrase"].append(df_tmp1.Phrase.values[0])
                dict_out["word"].append(df_tmp1.word.values[0])
                dict_out["start"].append(df_tmp1.start.values[0])
                dict_out["end"].append(df_tmp1.end.values[0])
                dict_out["Score"].append(df_tmp1.score.values[0])
                dict_out["entity_group"].append(entity_to_study)

            else:
                # sort by start offset so adjacent hits can be merged in one
                # left-to-right pass
                df_tmp1 = df_tmp1.sort_values(by=['start'])
                for s_i, (s_, e_, sc_) in enumerate(zip(df_tmp1.start.values, df_tmp1.end.values, df_tmp1.score.values)):
                    if s_i == 0:
                        s_o = s_
                        e_o = e_
                        sc_s = sc_
                        word_size = 1
                    else:
                        if s_ <= e_o + 1:
                            # contiguous with the previous hit: extend it
                            e_o = e_
                            sc_s += sc_
                            word_size += 1
                        else:
                            # gap found: flush the accumulated entity and start a new one
                            dict_out["TEXT_ID"].append(text_id)
                            dict_out["Phrase"].append(phrase_)
                            dict_out["word"].append(phrase_[s_o:e_o])
                            dict_out["start"].append(s_o)
                            dict_out["end"].append(e_o)
                            dict_out["Score"].append(sc_s / word_size)
                            dict_out["entity_group"].append(entity_to_study)

                            s_o = s_
                            e_o = e_
                            sc_s = sc_
                            word_size = 1

                    if s_i == len(df_tmp1) - 1:
                        # last hit of the phrase: flush whatever is accumulated
                        dict_out["TEXT_ID"].append(text_id)
                        dict_out["Phrase"].append(phrase_)
                        dict_out["word"].append(phrase_[s_o:e_o])
                        dict_out["start"].append(s_o)
                        dict_out["end"].append(e_o)
                        dict_out["Score"].append(sc_s / word_size)
                        dict_out["entity_group"].append(entity_to_study)

    return pd.DataFrame(dict_out)
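# Minimal usage sketch (added for illustration; not part of the original tool
# wrapper). Running the module directly applies the full pipeline to a short,
# made-up abstract and prints the cleaned entity table. Note that the first run
# downloads 'adsabs/astroBERT' (NER-DEAL revision) and needs network access.
if __name__ == "__main__":
    sample_text = (
        "We observed the Crab Nebula with the XMM-Newton observatory. "
        "The source was also detected in the 2-10 keV band by NuSTAR."
    )
    df_entities = get_astroBERT_cleaned_result("sample-1", sample_text)
    print(df_entities)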