diff --git a/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py b/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
index 37d91f253..2bc7dc0fa 100644
--- a/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
+++ b/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
@@ -212,6 +212,9 @@ def __init__(self, base_tokenizer: BaseTokenizer,
     def _init_data_paths(self):
         doc_cls = self.base_tokenizer.get_doc_class()
         doc_cls.register_addon_path('relations', def_val=[], force=True)
+        entity_cls = self.base_tokenizer.get_entity_class()
+        entity_cls.register_addon_path('start', def_val=None, force=True)
+        entity_cls.register_addon_path('end', def_val=None, force=True)

     def save(self, save_path: str = "./") -> None:
         self.component.save(save_path=save_path)
@@ -833,27 +836,23 @@ def pipe(self, stream: Iterable[MutableDocument], *args, **kwargs
                 relations: list = doc.get_addon_data(  # type: ignore
                     "relations")
+                out_rels = predict_rel_dataset.dataset[
+                    "output_relations"][rel_idx]
                 relations.append(
                     {
                         "relation": rc_cnf.general.idx2labels[
                             predicted_label_id],
                         "label_id": predicted_label_id,
-                        "ent1_text": predict_rel_dataset.dataset[
-                            "output_relations"][rel_idx][
-                                2],
-                        "ent2_text": predict_rel_dataset.dataset[
-                            "output_relations"][rel_idx][
-                                3],
+                        "ent1_text": out_rels[2],
+                        "ent2_text": out_rels[3],
                         "confidence": float("{:.3f}".format(
                             confidence[0])),
-                        "start_ent_pos": "",
-                        "end_ent_pos": "",
-                        "start_entity_id":
-                            predict_rel_dataset.dataset[
-                                "output_relations"][rel_idx][8],
-                        "end_entity_id":
-                            predict_rel_dataset.dataset[
-                                "output_relations"][rel_idx][9]
+                        "start_ent1_char_pos": out_rels[18],
+                        "end_ent1_char_pos": out_rels[19],
+                        "start_ent2_char_pos": out_rels[20],
+                        "end_ent2_char_pos": out_rels[21],
+                        "start_entity_id": out_rels[8],
+                        "end_entity_id": out_rels[9],
                     })
             pbar.update(len(token_ids))
         pbar.close()
@@ -901,6 +900,8 @@ def predict_text_with_anns(self, text: str, annotations: list[dict]
             entity = base_tokenizer.create_entity(
                 doc, min(tkn_idx), max(tkn_idx) + 1, label=ann["value"])
             entity.cui = ann["cui"]
+            entity.set_addon_data('start', ann['start'])
+            entity.set_addon_data('end', ann['end'])
             doc.ner_ents.append(entity)

         doc = self(doc)
diff --git a/medcat-v2/medcat/components/addons/relation_extraction/rel_dataset.py b/medcat-v2/medcat/components/addons/relation_extraction/rel_dataset.py
index 149410538..6130f34f3 100644
--- a/medcat-v2/medcat/components/addons/relation_extraction/rel_dataset.py
+++ b/medcat-v2/medcat/components/addons/relation_extraction/rel_dataset.py
@@ -246,6 +246,9 @@ def _create_relation_validation(self,
             ent2_token: Union[str, MutableEntity] = tmp_doc_text[
                 ent2_start_char_pos: ent2_end_char_pos]

+            annotation_token_text = self.tokenizer.hf_tokenizers.convert_ids_to_tokens(
+                self.config.general.annotation_schema_tag_ids)
+
             if (abs(ent2_start_char_pos - ent1_start_char_pos
                     ) <= self.config.general.window_size
                     and ent1_token != ent2_token):
@@ -281,24 +284,20 @@ def _create_relation_validation(self,
                 if is_spacy_doc or is_mct_export:
                     tmp_doc_text = text

-                    _pre_e1 = tmp_doc_text[0: (ent1_start_char_pos)]
-                    _e1_s2 = (
-                        tmp_doc_text[ent1_end_char_pos: ent2_start_char_pos - 1])
-                    _e2_end = tmp_doc_text[ent2_end_char_pos + 1: text_length]
-                    ent2_token_end_pos = (ent2_token_end_pos + 2)
-
-                    annotation_token_text = (
-                        self.tokenizer.hf_tokenizers.convert_ids_to_tokens(
-                            self.config.general.annotation_schema_tag_ids))
+                    s1, e1, s2, e2 = annotation_token_text

                     tmp_doc_text = (
-                        str(_pre_e1) + " " +
-                        annotation_token_text[0] + " " +
-                        str(ent1_token) + " " +
-                        annotation_token_text[1] + " " + str(_e1_s2) + " " +
-                        annotation_token_text[2] + " " + str(ent2_token) + " " +
-                        annotation_token_text[3] + " " + str(_e2_end)
-                    )
+                        tmp_doc_text[:ent2_end_char_pos] +
+                        e2 + tmp_doc_text[ent2_end_char_pos:])
+                    tmp_doc_text = (
+                        tmp_doc_text[:ent2_start_char_pos] +
+                        s2 + tmp_doc_text[ent2_start_char_pos:])
+                    tmp_doc_text = (
+                        tmp_doc_text[:ent1_end_char_pos] +
+                        e1 + tmp_doc_text[ent1_end_char_pos:])
+                    tmp_doc_text = (
+                        tmp_doc_text[:ent1_start_char_pos] +
+                        s1 + tmp_doc_text[ent1_start_char_pos:])

                     ann_tag_token_len = len(annotation_token_text[0])
@@ -309,11 +308,10 @@ def _create_relation_validation(self,
                     _right_context_start_end_pos = (
                         # 8 for spces
                         right_context_end_char_pos + (ann_tag_token_len * 4) + 8)

-                    right_context_end_char_pos = (
-                        len(tmp_doc_text) + 1 if
-                        right_context_end_char_pos >= len(tmp_doc_text) or
-                        _right_context_start_end_pos >= len(tmp_doc_text)
-                        else _right_context_start_end_pos)
+                    right_context_end_char_pos = len(tmp_doc_text) if (
+                        right_context_end_char_pos >= len(tmp_doc_text)
+                        or _right_context_start_end_pos >= len(tmp_doc_text)
+                    ) else _right_context_start_end_pos

                     # reassign the new text with added tags
                     text_length = len(tmp_doc_text)
@@ -363,16 +361,20 @@ def _create_relation_validation(self,
                 ent2_token_start_pos += ent1_token_start_pos
                 ent1_ent2_new_start = (ent1_token_start_pos,
                                        ent2_token_start_pos)

-                en1_start, en1_end = window_tokenizer_data[
-                    "offset_mapping"][ent1_token_start_pos]
-                en2_start, en2_end = window_tokenizer_data[
-                    "offset_mapping"][ent2_token_start_pos]
+                os_map = window_tokenizer_data["offset_mapping"]
+                s1_start, s1_end = os_map[ent1_token_start_pos]
+                e1_start, e1_end = os_map[_ent1_token_end_pos]
+
+                s2_start, s2_end = os_map[ent2_token_start_pos]
+                e2_start, e2_end = os_map[_ent2_token_end_pos]

                 return [window_tokenizer_data["input_ids"],
                         ent1_ent2_new_start,
                         ent1_token, ent2_token, "UNK",
                         self.config.model.padding_idx, None, None, None,
                         None, None, None, doc_id, "",
-                        en1_start, en1_end, en2_start, en2_end]
+                        s1_start, e1_end, s2_start, e2_end,
+                        ent1_start_char_pos, ent1_end_char_pos,
+                        ent2_start_char_pos, ent2_end_char_pos]

         return []

     def _get_token_type_and_start_end(
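The rewritten tagging in rel_dataset.py only works because the four splices run from the highest character offset to the lowest: each insertion shifts only the text to its right, so the smaller offsets used by the later splices stay valid and nothing has to be recomputed. A minimal standalone sketch of that ordering, with invented tag strings and offsets (not MedCAT's actual annotation schema tags, which come from config.general.annotation_schema_tag_ids):

    # Illustrative only: tag strings and entity offsets are made up here;
    # in the diff they come from the tokenizer's annotation schema tags
    # and the ent{1,2}_{start,end}_char_pos variables.
    text = "aspirin reduces headache"
    s1, e1, s2, e2 = "[s1]", "[e1]", "[s2]", "[e2]"
    ent1_start, ent1_end = 0, 7    # "aspirin"
    ent2_start, ent2_end = 16, 24  # "headache"

    # Splice right-to-left: each insert leaves all smaller offsets intact.
    text = text[:ent2_end] + e2 + text[ent2_end:]
    text = text[:ent2_start] + s2 + text[ent2_start:]
    text = text[:ent1_end] + e1 + text[ent1_end:]
    text = text[:ent1_start] + s1 + text[ent1_start:]

    print(text)  # [s1]aspirin[e1] reduces [s2]headache[e2]

Inserting left-to-right instead would require adding the accumulated tag lengths to every subsequent offset, which is exactly the bookkeeping (the `+ 1`/`+ 2` adjustments) the old code got wrong.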