Source code for kyoto_reader.document

import _pickle as cPickle
import copy
import io
import logging
from collections import OrderedDict, ChainMap, Counter, defaultdict
from typing import List, Dict, Set, Optional, Iterator, TextIO, Collection

import jaconv
from pyknp import BList, Bunsetsu, Tag, Morpheme, Rel

from .base_phrase import BasePhrase
from .constants import ALL_CASES, ALL_EXOPHORS, ALL_COREFS, NE_CATEGORIES
from .coreference import Mention, Entity
from .ne import NamedEntity
from .pas import Pas, Predicate, BaseArgument, Argument, SpecialArgument
from .sentence import Sentence

# Module-level logger. Its level is pinned to WARNING, so logger.info(...) calls
# made through it are filtered out unless the level is lowered afterwards.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.WARNING)


class Document:
    """A class to represent a document of KWDLC, KyotoCorpus, or AnnotatedFKCCorpus.

    Args:
        knp_string (str): KNP format string of the document.
        doc_id (str): A document ID.
        cases (Collection[str]): Cases to extract.
        corefs (Collection[str]): Coreference relations to extract.
        relax_cases (bool): Whether to consider relations with "≒" as those without "≒" (e.g. ガ≒格 -> ガ格).
        extract_nes (bool): Whether to extract named entities.
        use_pas_tag (bool): Whether to read predicate-argument structures from <述語項構造: > tags, not <rel> tags.

    Attributes:
        knp_string (str): KNP format string of the document.
        doc_id (str): A document ID.
        cases (Collection[str]): Cases to extract.
        corefs (Collection[str]): Coreference relations to extract.
        relax_cases (bool): Whether "≒" relations are treated as their plain counterparts.
        extract_nes (bool): Whether to extract named entities.
        use_pas_tag (bool): Whether PAS was read from <述語項構造: > tags instead of <rel> tags.
        sid2sentence (Dict[str, Sentence]): A mapping from a sentence ID to the corresponding sentence.
        mentions (Dict[int, Mention]): A mapping from a document-wide tag ID to the corresponding mention.
        entities (Dict[int, Entity]): A mapping from an entity ID to the corresponding entity.
        named_entities (List[NamedEntity]): Extracted named entities (empty unless ``extract_nes`` is True).
    """

    # Exophors that may be instantiated more than once in a document; any other exophor is a singleton entity.
    _NON_SINGLETON_EXOPHORS = ('不特定:人', '不特定:物', '不特定:状況')

    def __init__(self,
                 knp_string: str,
                 doc_id: str,
                 cases: Collection[str],
                 corefs: Collection[str],
                 relax_cases: bool,
                 extract_nes: bool,
                 use_pas_tag: bool,
                 ) -> None:
        self.knp_string: str = knp_string
        self.doc_id: str = doc_id
        self.cases: Collection[str] = cases
        self.corefs: Collection[str] = corefs
        self.relax_cases: bool = relax_cases
        self.extract_nes: bool = extract_nes
        self.use_pas_tag: bool = use_pas_tag
        # Always defined so the documented attribute exists even when NE extraction is disabled.
        self.named_entities: List[NamedEntity] = []

        self.sid2sentence: Dict[str, Sentence] = OrderedDict()
        # Running document-wide offsets of tags (dtid) and morphemes (dmid) for the next sentence.
        dtid = dmid = 0
        buff: List[str] = []
        for line in knp_string.strip().split('\n'):
            buff.append(line)
            if line.strip() == 'EOS':
                sentence = Sentence('\n'.join(buff) + '\n', dtid, dmid, doc_id)
                if sentence.sid in self.sid2sentence:
                    logger.warning(f'{sentence.sid}: duplicated sid found')
                self.sid2sentence[sentence.sid] = sentence
                dtid += len(sentence)
                dmid += len(sentence.mrph_list())
                buff = []
        if ''.join(buff).strip():
            # Content after the last EOS cannot form a sentence; report it instead of dropping it silently.
            logger.warning(f'{doc_id}: trailing lines without EOS are ignored')

        self._mrph2dmid: Dict[Morpheme, int] = dict(ChainMap(*(sent.mrph2dmid for sent in self.sentences)))

        self._pas: Dict[int, Pas] = OrderedDict()
        self.mentions: Dict[int, Mention] = {}
        self.entities: Dict[int, Entity] = {}
        if use_pas_tag:
            self._analyze_pas()
        else:
            self._analyze_rel()

        if extract_nes:
            self._extract_nes()

    def _analyze_pas(self) -> None:
        """Extract predicate-argument structures represented in <述語項構造: > tags."""
        sentences = self.sentences  # hoisted: the property builds a new list on every access
        sid2idx = {sid: idx for idx, sid in enumerate(self.sid2sentence.keys())}
        for bp in self.bp_list():
            if bp.tag.pas is None:
                continue
            pas = Pas(bp)
            for case, arguments in bp.tag.pas.arguments.items():
                if self.relax_cases:
                    if case in ALL_CASES and case.endswith('≒'):
                        case = case.rstrip('≒')  # ガ≒ -> ガ
                for arg in arguments:
                    arg.midasi = jaconv.h2z(arg.midasi, digit=True)  # 不特定:人1 -> 不特定:人１
                    # exophor
                    if arg.flag == 'E':
                        entity = self._create_entity(exophor=arg.midasi, eid=arg.eid)
                        pas.add_special_argument(case, arg.midasi, entity.eid, '')
                        continue
                    # The argument's sentence is found by stepping back `arg.sdist` sentences from `arg.sid`.
                    if arg.sid not in sid2idx:
                        logger.warning(f'{bp.sid}: sentence: {arg.sid} not found in {self.doc_id}')
                        continue
                    arg_sent_idx = sid2idx[arg.sid] - arg.sdist
                    # A negative index would silently wrap around to a sentence at the end of the document.
                    if not (0 <= arg_sent_idx < len(sentences)):
                        logger.warning(f'{bp.sid}: sentence distance: {arg.sdist} out of range')
                        continue
                    arg_bp = self._get_bp(sentences[arg_sent_idx].sid, arg.tid)
                    if arg_bp is None:
                        continue
                    # create a mention and an entity when an argument is found
                    _ = self._create_mention(arg_bp)
                    pas.add_argument(case, arg_bp, '')
            if pas.arguments:
                self._pas[pas.dtid] = pas

    def _analyze_rel(self) -> None:
        """Extract predicate-argument structures and coreference relations represented in <rel> tags."""
        known_types = ALL_CASES + ALL_COREFS  # hoisted out of the per-relation loop
        for bp in self.bp_list():
            rels = []
            for rel in self._extract_rel_tags(bp.tag):
                if self.relax_cases:
                    if rel.atype in ALL_CASES and rel.atype.endswith('≒'):
                        rel.atype = rel.atype.rstrip('≒')  # ガ≒ -> ガ
                valid = True
                if rel.sid is not None and rel.sid not in self.sid2sentence:
                    logger.warning(f'{bp.sid}: sentence: {rel.sid} not found in {self.doc_id}')
                    valid = False
                if rel.atype in known_types:
                    if not (rel.atype in self.cases or rel.atype in self.corefs):
                        logger.info(f'{bp.sid}: relation type: {rel.atype} is ignored')
                        valid = False
                else:
                    # Unknown types are still kept: a caller may have requested a case outside ALL_CASES.
                    logger.warning(f'{bp.sid}: unknown relation: {rel.atype}')
                if valid is True:
                    rels.append(rel)

            # extract PAS
            pas = Pas(bp)
            for rel in rels:
                if rel.atype in self.cases:
                    if rel.sid is not None:
                        assert rel.tid is not None
                        arg_bp = self._get_bp(rel.sid, rel.tid)
                        if arg_bp is None:
                            continue
                        # create a mention and an entity when an argument is found
                        _ = self._create_mention(arg_bp)
                        pas.add_argument(rel.atype, arg_bp, rel.mode)
                    # exophora
                    else:
                        if rel.target == 'なし':
                            pas.set_arguments_optional(rel.atype)
                            continue
                        if rel.target not in ALL_EXOPHORS:
                            logger.warning(f'{pas.sid}:unknown exophor: {rel.target}')
                            continue
                        entity = self._create_entity(rel.target)
                        pas.add_special_argument(rel.atype, rel.target, entity.eid, rel.mode)
            if pas.arguments:
                self._pas[pas.dtid] = pas

            # extract coreference
            for rel in rels:
                if rel.atype in self.corefs:
                    if rel.mode in ('', 'AND'):  # ignore "OR" and "?"
                        self._add_corefs(bp, rel)

    # to extract rels with mode: '?', rewrite initializer of pyknp Features class
    @staticmethod
    def _extract_rel_tags(tag: Tag) -> List[Rel]:
        """Parse tag.fstring to extract <rel> tags.

        The feature string is split on "><". For the last feature, find() returns -1, so the slice
        ``spec[tag_start:-1]`` also drops the final character (presumably the closing '>').
        """
        splitter = "><"
        rels = []
        spec = tag.fstring

        tag_start = 1
        tag_end = None
        while tag_end != -1:
            tag_end = spec.find(splitter, tag_start)
            if spec[tag_start:].startswith('rel '):
                rel = Rel(spec[tag_start:tag_end])
                if rel.target:
                    rel.target = jaconv.h2z(rel.target, digit=True)  # 不特定:人1 -> 不特定:人１
                if rel.atype is not None:
                    rels.append(rel)

            tag_start = tag_end + len(splitter)
        return rels

    def _add_corefs(self,
                    source_bp: BasePhrase,
                    rel: Rel,
                    ) -> None:
        """Add a coreference relation from ``source_bp`` to the target described by ``rel``.

        The target is either a base phrase (``rel.sid`` is set) or an exophor (``rel.target``).
        """
        if rel.sid is not None:
            target_bp = self._get_bp(rel.sid, rel.tid)
            if target_bp is None:
                return
            if target_bp.dtid == source_bp.dtid:
                logger.warning(f'{source_bp.sid}: coreference with self found: {source_bp}')
                return
        else:
            target_bp = None
            if rel.target not in ALL_EXOPHORS:
                logger.warning(f'{source_bp.sid}: unknown exophor: {rel.target}')
                return

        uncertain: bool = rel.atype.endswith('≒')
        source_mention = self._create_mention(source_bp)
        # Iterate snapshots: _merge_entities mutates mentions and may delete entities along the way.
        for eid in list(source_mention.all_eids):
            # the entity may already have been deleted by _merge_entities
            if eid not in self.entities:
                continue
            source_entity = self.entities[eid]
            if rel.sid is not None:
                target_mention = self._create_mention(target_bp)
                for target_eid in list(target_mention.all_eids):
                    if target_eid not in self.entities:
                        continue
                    target_entity = self.entities[target_eid]
                    self._merge_entities(source_mention, target_mention, source_entity, target_entity, uncertain)
            else:
                target_entity = self._create_entity(exophor=rel.target)
                self._merge_entities(source_mention, None, source_entity, target_entity, uncertain)

    def _create_mention(self, bp: BasePhrase) -> Mention:
        """Create a mention from the corresponding base phrase.

        If the base phrase has not been registered as a mention yet, create a new mention and a new entity
        containing it. Otherwise, return the registered mention.

        Args:
            bp (BasePhrase): A base phrase corresponding to the mention to be created.

        Returns:
            Mention: A mention.
        """
        if bp.dtid not in self.mentions:
            # make a new coreference cluster
            mention = Mention(bp)
            self.mentions[bp.dtid] = mention
            entity = self._create_entity()
            entity.add_mention(mention, uncertain=False)
        else:
            mention = self.mentions[bp.dtid]
        return mention

    def _create_entity(self,
                       exophor: Optional[str] = None,
                       eid: Optional[int] = None,
                       ) -> Entity:
        """Create an entity.

        A new Entity is created and returned, except when ``exophor`` names a singleton entity that already
        exists, in which case that entity is returned. Singleton entities are exophors such as "著者" or
        "不特定:人1" that can exist only once in a document; "不特定:人", "不特定:物" and "不特定:状況" may occur
        several times and are not singletons.
        If ``eid`` is omitted, negative, or already in use, one greater than the largest existing eid is used.

        Args:
            exophor (Optional[str]): An exophor (optional).
            eid (Optional[int]): An entity ID (omitting it is recommended).

        Returns:
            Entity: The created (or existing singleton) entity.
        """
        if exophor:
            if exophor not in self._NON_SINGLETON_EXOPHORS:  # exophor is a singleton entity
                entities = [e for e in self.entities.values() if exophor == e.exophor]
                # if the singleton entity already exists, return it instead of creating a new one
                if entities:
                    assert len(entities) == 1  # a singleton entity must be unique
                    return entities[0]
        eids: Set[int] = {e.eid for e in self.entities.values()}
        if eid in eids:
            eid_ = eid
            eid = max(eids) + 1
            logger.warning(f'{self.doc_id}:eid: {eid_} is already used. use eid: {eid} instead.')
        elif eid is None or eid < 0:
            eid = max(eids) + 1 if eids else 0
        entity = Entity(eid, exophor=exophor)
        self.entities[eid] = entity
        return entity

    def _merge_entities(self,
                        source_mention: Mention,
                        target_mention: Optional[Mention],
                        se: Entity,
                        te: Entity,
                        uncertain: bool,
                        ) -> None:
        """Merge two entities.

        source_mention belongs to ``se`` and target_mention belongs to ``te``, but no link exists yet between
        source and target, so each mention is added to the other's entity. If ``se`` and ``te`` are then known
        to be the same entity and their exophors agree (or one of them is None), ``te`` is deleted.

        Args:
            source_mention (Mention): A source mention.
            target_mention (Mention, optional): A target mention.
            se (Entity): A source entity.
            te (Entity): A target entity.
            uncertain (bool): Whether the relation between source and target mentions is uncertain (i.e., annotated \
                with "≒").
        """
        uncertain_tgt = (target_mention is not None) and target_mention.is_uncertain_to(te)
        uncertain_src = source_mention.is_uncertain_to(se)
        if se is te:
            if not uncertain:
                # In the triangle se(te) / source_mention / target_mention, if two edges are certain,
                # the third one is certain as well.
                if (not uncertain_src) and uncertain_tgt:
                    se.add_mention(target_mention, uncertain=False)
                if uncertain_src and (not uncertain_tgt):
                    se.add_mention(source_mention, uncertain=False)
            return
        if target_mention is not None:
            se.add_mention(target_mention, uncertain=(uncertain or uncertain_src))
        te.add_mention(source_mention, uncertain=(uncertain or uncertain_tgt))
        # keep te if se and te might not be the same entity
        if uncertain_src or uncertain or uncertain_tgt:
            return
        # keep te if the two entities refer to different exophors
        if se.exophor is not None and te.exophor is not None and se.exophor != te.exophor:
            return
        # from here on, prepare to delete te
        if se.exophor is None:
            se.exophor = te.exophor
        for tm in list(te.all_mentions):
            se.add_mention(tm, uncertain=tm.is_uncertain_to(te))
        # special arguments also hold an eid, so update them to point at se
        for arg in [arg for pas in self._pas.values() for args in pas.arguments.values() for arg in args]:
            if isinstance(arg, SpecialArgument) and arg.eid == te.eid:
                arg.eid = se.eid
        self._delete_entity(te.eid, source_mention.sid)  # delete target entity

    def _delete_entity(self,
                       eid: int,
                       sid: str
                       ) -> None:
        """Delete an entity.

        Remove the target entity from all of its mentions as well as from self.entities.
        Note that entity IDs can have a missing number afterwards.

        Args:
            eid (int): The entity ID of the entity to be deleted.
            sid (str): The sentence ID of the sentence being analyzed when the entity is deleted (for logging).
        """
        if eid not in self.entities:
            return
        entity = self.entities[eid]
        logger.info(f'{sid}: delete entity: {eid} ({entity})')
        # iterate over a snapshot because remove_mention mutates the entity's mention collections
        for mention in list(entity.all_mentions):
            entity.remove_mention(mention)
        self.entities.pop(eid)

    def _get_bp(self,
                sid: str,
                tid: int,
                ) -> Optional[BasePhrase]:
        """Get a base phrase from sentence ID and tag ID.

        Args:
            sid (str): A sentence ID.
            tid (int): A tag ID.

        Returns:
            Optional[BasePhrase]: The base phrase with sentence ID ``sid`` and tag ID ``tid``, or None if either
            the sentence or the tag does not exist.
        """
        sentence = self[sid]
        if sentence is None:
            return None
        if not (0 <= tid < len(sentence.bps)):
            logger.warning(f'{sid}: tag id: {tid} out of range')
            return None
        return sentence.bps[tid]

    def _extract_nes(self) -> None:
        """Extract named entities referring tag objects."""
        for sentence in self.sentences:
            tag_list = sentence.tag_list()
            # tag.features = {'NE': 'LOCATION:ダーマ神殿'}
            for tag in tag_list:
                if 'NE' not in tag.features:
                    continue
                category, name = tag.features['NE'].split(':', maxsplit=1)
                if category not in NE_CATEGORIES:
                    logger.warning(f'{sentence.sid}: unknown NE category: {category}')
                    continue
                mrph_list = [m for t in tag_list[:tag.tag_id + 1] for m in t.mrph_list()]
                mrph_span = self._find_mrph_span(name, mrph_list, tag)
                if mrph_span is None:
                    logger.warning(f'{sentence.sid}: mrph span of \'{name}\' not found')
                    continue
                ne = NamedEntity(category, name, sentence, mrph_span, self._mrph2dmid)
                self.named_entities.append(ne)

    @staticmethod
    def _find_mrph_span(name: str,
                        mrph_list: List[Morpheme],
                        tag: Tag
                        ) -> Optional[range]:
        """Return the range of morphemes whose concatenated surface equals ``name``.

        Candidate spans end within the last ``len(tag.mrph_list())`` morphemes of ``mrph_list`` and are grown
        leftwards one morpheme at a time. Returns None if no span matches.
        """
        for i in range(len(tag.mrph_list())):
            end_mid = len(mrph_list) - i
            mrph_span = ''
            for mrph in reversed(mrph_list[:end_mid]):
                mrph_span = mrph.midasi + mrph_span
                if mrph_span == name:
                    return range(mrph.mrph_id, end_mid)
                if len(mrph_span) >= len(name):
                    break  # the span only grows from here, so it can no longer match
        return None

    @property
    def sentences(self) -> List['Sentence']:
        """List of sentences in this document.

        Returns:
            List[Sentence]
        """
        return list(self.sid2sentence.values())

    @property
    def mrph2dmid(self) -> Dict[Morpheme, int]:
        """A mapping from morpheme to its document-wide ID."""
        return self._mrph2dmid

    @property
    def surf(self) -> str:
        """A surface expression of this document."""
        return ''.join(sent.surf for sent in self.sentences)

    def bnst_list(self) -> List[Bunsetsu]:
        """Return list of Bunsetsu object in pyknp."""
        return [bnst for sentence in self.sentences for bnst in sentence.bnst_list()]

    def bp_list(self) -> List[BasePhrase]:
        """Return list of base phrases."""
        return [bp for sentence in self.sentences for bp in sentence.bps]

    def tag_list(self) -> List[Tag]:
        """Return list of Tag object in pyknp."""
        return [tag for sentence in self.sentences for tag in sentence.tag_list()]

    def mrph_list(self) -> List[Morpheme]:
        """Return list of Morpheme object in pyknp."""
        return [mrph for sentence in self.sentences for mrph in sentence.mrph_list()]

    def get_entities(self, bp: BasePhrase, include_uncertain: bool = False) -> List[Entity]:
        """Return list of entities that the specified mention refers to.

        The mention is given as a type of BasePhrase.

        Args:
            bp (BasePhrase): A base phrase corresponding to the mention.
            include_uncertain (bool): Whether to return entities that have an uncertain relation with the mention.
        """
        if bp.dtid not in self.mentions:
            return []
        mention = self.mentions[bp.dtid]
        eids = mention.all_eids if include_uncertain else mention.eids
        return [self.entities[eid] for eid in eids]

    def pas_list(self) -> List[Pas]:
        """Return list of predicate-argument structures."""
        return list(self._pas.values())

    def get_predicates(self) -> List[Predicate]:
        """Return list of predicates."""
        return [pas.predicate for pas in self._pas.values()]

    def get_arguments(self,
                      predicate: Predicate,
                      relax: bool = False,
                      include_optional: bool = False,
                      ) -> Dict[str, List[BaseArgument]]:
        """Return all the arguments that the given predicate has.

        Args:
            predicate (Predicate): A predicate.
            relax (bool): If True, also return arguments that have a coreference relation with the arguments the
                predicate has.
            include_optional (bool): If True, return adverbial arguments such as "すぐに" as well.

        Returns:
            Dict[str, List[BaseArgument]]: A mapping from a case to arguments.
        """
        if predicate.dtid not in self._pas:
            return defaultdict(list)
        pas = copy.copy(self._pas[predicate.dtid])
        # deep-copy the arguments so that the filtering/expansion below does not touch the stored PAS
        pas.arguments = cPickle.loads(cPickle.dumps(pas.arguments, -1))
        if include_optional is False:
            for case in self.cases:
                pas.arguments[case] = list(filter(lambda a: a.optional is False, pas.arguments[case]))

        if relax is True:
            for case, args in self._pas[predicate.dtid].arguments.items():
                for arg in args:
                    # an optional argument filtered out above must not come back through its coreferents
                    if include_optional is False and case in self.cases and arg.optional is not False:
                        continue
                    if isinstance(arg, SpecialArgument):
                        entities = [self.entities[arg.eid]]
                    else:
                        assert isinstance(arg, Argument)
                        entities = self.get_entities(arg, include_uncertain=True)
                    for entity in entities:
                        if entity.is_special and entity.exophor != str(arg):
                            pas.add_special_argument(case, entity.exophor, entity.eid, 'AND')
                        for mention in entity.all_mentions:
                            if isinstance(arg, Argument) and mention.dtid == arg.dtid:
                                continue
                            pas.add_argument(case, mention, 'AND')

        return pas.arguments

    def get_siblings(self, mention: Mention, relax: bool = False) -> Set[Mention]:
        """Return all the mentions that have coreference chains with the specified mention.

        Args:
            mention (Mention): A mention.
            relax (bool): If True, also return mentions of entities the mention is only uncertainly linked to.

        Returns:
            Set[Mention]: A set of mentions (excluding ``mention`` itself).
        """
        mentions = set()
        for eid in mention.eids:
            entity = self.entities[eid]
            mentions.update(entity.mentions)
        if relax is True:
            for eid in mention.eids_unc:
                entity = self.entities[eid]
                mentions.update(entity.all_mentions)
        if mention in mentions:
            mentions.remove(mention)
        return mentions

    def draw_tree(self,
                  sid: Optional[str] = None,
                  coreference: bool = True,
                  fh: Optional[TextIO] = None,
                  ) -> None:
        """Write out the PAS and coreference relations in the specified sentence in a tree format.

        If sid is not specified, write out trees in all the sentences in this document.

        Args:
            sid (str, optional): A sentence ID of the target sentence.
            coreference (bool): If True, write out coreference relations as well.
            fh (TextIO, optional): The output stream.
        """
        if sid is None:
            for _sid in self.sid2sentence.keys():
                self._draw_sent_tree(_sid, coreference, fh)
        else:
            self._draw_sent_tree(sid, coreference, fh)

    def _draw_sent_tree(self,
                        sid: str,
                        coreference: bool,
                        fh: Optional[TextIO] = None,
                        ) -> None:
        """Write out the PAS and coreference relations in the specified sentence in a tree format.

        Targets are listed in a deterministic order (argument order for PAS, document order for coreference),
        so the output does not depend on string hash randomization.

        Args:
            sid (str): A sentence ID of the target sentence.
            coreference (bool): If True, write out coreference relations as well.
            fh (Optional[TextIO]): The output stream.
        """
        blist: BList = self[sid].blist
        with io.StringIO() as string:
            blist.draw_tag_tree(fh=string, show_pos=False)
            tree_strings = string.getvalue().rstrip('\n').split('\n')
        assert len(tree_strings) == len(blist.tag_list())
        # how many mentions share each surface string; ambiguous ones get their dtid appended
        surf_counts = Counter(str(m) for m in self.mentions.values())
        tid2mention = {mention.tid: mention for mention in self.mentions.values() if mention.sid == sid}
        for bp in self[sid].bps:
            tree_strings[bp.tid] += ' '
            # predicate-argument structure
            arguments = self.get_arguments(bp)
            for case in self.cases:
                targets: Dict[str, None] = {}  # dict used as an insertion-ordered set
                for arg in arguments[case]:
                    target = str(arg)
                    if surf_counts[target] > 1 and isinstance(arg, Argument):
                        target += str(arg.dtid)
                    targets[target] = None
                if targets:
                    tree_strings[bp.tid] += f'{case}:{",".join(targets)} '
            # coreference
            if coreference and bp.tid in tid2mention:
                src_mention = tid2mention[bp.tid]
                tgt_mentions = sorted((tgt for tgt in self.get_siblings(src_mention) if tgt.dtid < src_mention.dtid),
                                      key=lambda m: m.dtid)
                targets = {}
                for tgt_mention in tgt_mentions:
                    target = str(tgt_mention)
                    if surf_counts[target] > 1:
                        target += str(tgt_mention.dtid)
                    targets[target] = None
                for eid in sorted(src_mention.eids):
                    entity = self.entities[eid]
                    if entity.is_special:
                        targets[entity.exophor] = None
                if targets:
                    tree_strings[src_mention.tid] += f'=:{",".join(targets)}'

        print('\n'.join(tree_strings), file=fh)

    def stat(self) -> dict:
        """Calculate various kinds of statistics of this document."""
        tag_list = self.tag_list()
        ret = dict()
        ret['num_sents'] = len(self)
        ret['num_tags'] = len(tag_list)
        ret['num_mrphs'] = len(self.mrph_list())
        ret['num_taigen'] = sum(1 for tag in tag_list if '体言' in tag.features)
        ret['num_yougen'] = sum(1 for tag in tag_list if '用言' in tag.features)
        ret['num_entities'] = len(self.entities)
        ret['num_special_entities'] = sum(1 for ent in self.entities.values() if ent.is_special)

        num_mention = num_taigen = num_yougen = 0
        for src_mention in self.mentions.values():
            tgt_mentions: Set[Mention] = self.get_siblings(src_mention)
            if tgt_mentions:
                num_mention += 1
            for tgt_mention in tgt_mentions:
                if '体言' in tgt_mention.tag.features:
                    num_taigen += 1
                if '用言' in tgt_mention.tag.features:
                    num_yougen += 1
        ret['num_mentions'] = num_mention
        ret['num_taigen_mentions'] = num_taigen
        ret['num_yougen_mentions'] = num_yougen

        return ret

    def __len__(self) -> int:
        return len(self.sid2sentence)

    def __getitem__(self, sid: str) -> Optional[Sentence]:
        if sid in self.sid2sentence:
            return self.sid2sentence[sid]
        logger.error(f'sentence: {sid} is not in this document')
        return None

    def __iter__(self) -> Iterator[Sentence]:
        return iter(self.sid2sentence.values())

    def __eq__(self, other: object) -> bool:
        # Documents compare equal by ID; defer to the other operand for non-Document objects.
        if not isinstance(other, Document):
            return NotImplemented
        return self.doc_id == other.doc_id

    def __hash__(self) -> int:
        # consistent with __eq__ (defining __eq__ alone would make instances unhashable)
        return hash(self.doc_id)

    def __str__(self) -> str:
        return self.surf

    def __repr__(self) -> str:
        return 'Document([' + ', '.join(sent.surf for sent in self) + f'], did={self.doc_id})'