Skip to content

Commit

Permalink
Attempt to reuse the same machinery for CTB 9.0 as for 5.1 The file l…
Browse files Browse the repository at this point in the history
…ist is quite extensive...
  • Loading branch information
AngledLuffa committed Jul 7, 2023
1 parent 4ddacfd commit 323f446
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 15 deletions.
148 changes: 135 additions & 13 deletions stanza/utils/datasets/constituency/convert_ctb.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum
import glob
import os
import re
Expand All @@ -7,7 +8,11 @@
from stanza.models.constituency import tree_reader
from stanza.utils.datasets.constituency.utils import write_dataset

def filenum_to_shard(filenum):
class Version(Enum):
    """Identifies which Chinese Treebank release convert_ctb is asked to process."""
    # CTB 5.1: passed by process_ctb_51
    V51 = 1
    # CTB 9.0: passed by process_ctb_90
    V90 = 2

def filenum_to_shard_51(filenum):
if filenum >= 1 and filenum <= 815:
return 0
if filenum >= 1001 and filenum <= 1136:
Expand All @@ -25,18 +30,100 @@ def filenum_to_shard(filenum):

raise ValueError("Unhandled filenum %d" % filenum)

def collect_trees(root):
# Inclusive (first, last) file number ranges and individual file numbers for
# each shard of CTB 9.0.  The tables are consulted in the order 1, 2, 0, so
# file 900 (listed in both the 900-931 range of shard 1 and the 81-900 range
# of shard 0) resolves to shard 1.
_CTB90_SHARD_1_RANGES = (
    (1, 40),
    (900, 931),
    (2165, 2180),
    (2295, 2310),
    (2570, 2602),
    (2800, 2819),
    (3110, 3145),
)
_CTB90_SHARD_1_FILES = (1018, 1020, 1036, 1044, 1060, 1061, 1072, 1118, 1119, 1132, 1141, 1142, 1148)

_CTB90_SHARD_2_RANGES = (
    (41, 80),
    (1120, 1129),
    (2140, 2159),
    (2280, 2294),
    (2550, 2569),
    (2775, 2799),
    (3080, 3109),
)

_CTB90_SHARD_0_RANGES = (
    (81, 900),
    (1001, 1017),
    (1021, 1035),
    (1037, 1043),
    (1045, 1059),
    (1062, 1071),
    (1073, 1117),
    (1133, 1140),
    (1143, 1147),
    (1149, 2139),
    (2160, 2164),
    (2181, 2279),
    (2311, 2549),
    (2603, 2774),
    (2820, 3079),
    (4000, 7017),
)
_CTB90_SHARD_0_FILES = (1019, 1130, 1131)


def filenum_to_shard_90(filenum):
    """
    Map a CTB 9.0 file number to the index of the shard it belongs to.

    Returns 0, 1, or 2.  The result is used by convert_ctb as an index
    into its list of datasets
    (presumably train, dev, test in that order - TODO confirm against caller)

    Raises ValueError if the file number is not part of any known shard,
    which is the same behavior as filenum_to_shard_51.  Previously an
    unknown file number silently produced None.
    """
    def in_any_range(ranges):
        return any(first <= filenum <= last for first, last in ranges)

    # order matters: shard 1 must be checked before shard 0 because of file 900
    if filenum in _CTB90_SHARD_1_FILES or in_any_range(_CTB90_SHARD_1_RANGES):
        return 1
    if in_any_range(_CTB90_SHARD_2_RANGES):
        return 2
    if filenum in _CTB90_SHARD_0_FILES or in_any_range(_CTB90_SHARD_0_RANGES):
        return 0

    raise ValueError("Unhandled filenum %d" % filenum)


def collect_trees_s(root):
if root.tag == 'S':
yield root.text, root.attrib['ID']

for child in root:
for tree in collect_trees(child):
for tree in collect_trees_s(child):
yield tree

def collect_trees_text(root):
    """
    Recursively yield (text, None) for each TEXT or TURN element under root.

    Only elements whose own leading text contains something other than
    whitespace are yielded.  The second item of each pair is always None,
    as these elements do not carry a sentence ID.

    Elements are visited depth first: an element is yielded before
    its children are searched.
    """
    # ElementTree sets .text to None when an element has no leading text,
    # such as <TURN/> or a TEXT whose first content is a child element,
    # so check for None before calling strip()
    own_text = root.text
    if root.tag in ('TEXT', 'TURN') and own_text is not None and own_text.strip():
        yield own_text, None

    for child in root:
        yield from collect_trees_text(child)


# Matches sentence tags with an unquoted ID, such as <S ID=123>.
# convert_ctb uses the captured ID to rewrite the tag with the ID in quotes
# before the text is given to the XML parser.
id_re = re.compile("<S ID=([0-9a-z]+)>")
# Matches a single literal ampersand.
# NOTE(review): not referenced in the visible part of this file - confirm it is still needed
amp_re = re.compile("[&]")
# Matches opening <su id=...> and <msg id=...> tags.
# convert_ctb deletes these tags from some of the CTB 9.0 files.
su_re = re.compile("<(su|msg) id=([0-9a-zA-Z_=]+)>")

def convert_ctb(input_dir, output_dir, dataset_name):
def convert_ctb(input_dir, output_dir, dataset_name, version):
input_files = glob.glob(os.path.join(input_dir, "*"))

# train, dev, test
Expand All @@ -50,25 +137,60 @@ def convert_ctb(input_dir, output_dir, dataset_name):
sorted_filenames.sort()

for filenum, filename in sorted_filenames:
with open(filename, errors='ignore', encoding="gb2312") as fin:
text = fin.read()
text = id_re.sub(r'<S ID="\1">', text)
text = text.replace("&", "&amp;")
if version is Version.V51:
with open(filename, errors='ignore', encoding="gb2312") as fin:
text = fin.read()
else:
with open(filename, encoding="utf-8") as fin:
text = fin.read()
if text.find("<TURN>") >= 0 and text.find("</TURN>") < 0:
text = text.replace("<TURN>", "")
if filenum in (4205, 4208, 4289):
text = text.replace("<)", "&lt;)").replace(">)", "&gt;)")
if filenum >= 4000 and filenum <= 4411:
if text.find("<segment") >= 0:
text = text.replace("<segment id=", "<S ID=").replace("</segment>", "</S>")
elif text.find("<seg") < 0:
text = "<TEXT>\n%s</TEXT>\n" % text
else:
text = text.replace("<seg id=", "<S ID=").replace("</seg>", "</S>")
text = "<foo>\n%s</foo>\n" % text
if filenum >= 5000 and filenum <= 5558 or filenum >= 6000 and filenum <= 6700 or filenum >= 7000 and filenum <= 7017:
text = su_re.sub("", text)
if filenum in (6066, 6453):
text = text.replace("<", "&lt;").replace(">", "&gt;")
text = "<foo><TEXT>\n%s</TEXT></foo>\n" % text
text = id_re.sub(r'<S ID="\1">', text)
text = text.replace("&", "&amp;")

try:
xml_root = ET.fromstring(text)
except Exception as e:
print(text[:1000])
raise RuntimeError("Cannot xml process %s" % filename) from e
trees = [x for x in collect_trees(xml_root)]
trees = [x[0] for x in trees if filenum != 414 or x[1] != "4366"]
trees = [x for x in collect_trees_s(xml_root)]
if version is Version.V90 and len(trees) == 0:
trees = [x for x in collect_trees_text(xml_root)]

if version is Version.V51:
trees = [x[0] for x in trees if filenum != 414 or x[1] != "4366"]
else:
trees = [x[0] for x in trees]

trees = "\n".join(trees)
trees = tree_reader.read_trees(trees)
try:
trees = tree_reader.read_trees(trees)
except ValueError as e:
print(text[:300])
raise RuntimeError("Could not process the tree text in %s" % filename)
trees = [t.prune_none().simplify_labels() for t in trees]

assert len(trees) > 0, "No trees in %s" % filename

shard = filenum_to_shard(filenum)
if version is Version.V51:
shard = filenum_to_shard_51(filenum)
else:
shard = filenum_to_shard_90(filenum)
datasets[shard].extend(trees)


Expand Down
21 changes: 19 additions & 2 deletions stanza/utils/datasets/constituency/prepare_con_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@
year={2005},
pages={207–238}}
zh-hans_ctb-90 is the 9.0 version of CTB
put LDC2016T13 in $CONSTITUENCY_BASE/chinese
python3 -m stanza.utils.datasets.constituency.prepare_con_dataset zh-hans_ctb-90
the splits used are the ones from the file docs/ctb9.0-file-list.txt
included in the CTB 9.0 release
en_ptb3-revised is an updated version of PTB with NML and stuff
put LDC2015T13 in $CONSTITUENCY_BASE/english
the directory name may look like LDC2015T13_eng_news_txt_tbnk-ptb_revised
Expand Down Expand Up @@ -169,7 +176,7 @@
from stanza.utils.datasets.constituency.convert_alt import convert_alt
from stanza.utils.datasets.constituency.convert_arboretum import convert_tiger_treebank
from stanza.utils.datasets.constituency.convert_cintil import convert_cintil_treebank
from stanza.utils.datasets.constituency.convert_ctb import convert_ctb
import stanza.utils.datasets.constituency.convert_ctb as convert_ctb
from stanza.utils.datasets.constituency.convert_it_turin import convert_it_turin
from stanza.utils.datasets.constituency.convert_it_vit import convert_it_vit
from stanza.utils.datasets.constituency.convert_starlang import read_starlang
Expand Down Expand Up @@ -380,7 +387,16 @@ def process_ctb_51(paths, dataset_name, *args):

input_dir = os.path.join(paths["CONSTITUENCY_BASE"], "chinese", "LDC2005T01U01_ChineseTreebank5.1", "bracketed")
output_dir = paths["CONSTITUENCY_DATA_DIR"]
convert_ctb(input_dir, output_dir, dataset_name)
convert_ctb.convert_ctb(input_dir, output_dir, dataset_name, convert_ctb.Version.V51)

def process_ctb_90(paths, dataset_name, *args):
    """
    Convert the CTB 9.0 bracketed files into the dataset named dataset_name.

    dataset_name must be zh-hans_ctb-90.  The input is read from the
    LDC2016T13 directory under $CONSTITUENCY_BASE/chinese, and the
    results are written to $CONSTITUENCY_DATA_DIR.
    """
    language, treebank = dataset_name.split("_", 1)
    assert language == 'zh-hans'
    assert treebank == 'ctb-90'

    # location of the bracketed trees inside the LDC2016T13 release
    bracketed_path = ("chinese", "LDC2016T13", "ctb9.0", "data", "bracketed")
    source_dir = os.path.join(paths["CONSTITUENCY_BASE"], *bracketed_path)
    target_dir = paths["CONSTITUENCY_DATA_DIR"]
    convert_ctb.convert_ctb(source_dir, target_dir, dataset_name, convert_ctb.Version.V90)


def process_ptb3_revised(paths, dataset_name, *args):
Expand Down Expand Up @@ -425,6 +441,7 @@ def process_ptb3_revised(paths, dataset_name, *args):
'vi_vlsp22': process_vlsp22,

'zh-hans_ctb-51': process_ctb_51,
'zh-hans_ctb-90': process_ctb_90,
}

def main(dataset_name, *args):
Expand Down

0 comments on commit 323f446

Please sign in to comment.