-
Notifications
You must be signed in to change notification settings - Fork 67
/
t5_precompute.py
112 lines (97 loc) · 5.31 KB
/
t5_precompute.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0
"""Script to stream text from a dataset, compute CLIP and T5 latents, and write the latents to streaming dataset."""
import json
import os
from argparse import ArgumentParser
import torch
from streaming import MDSWriter, StreamingDataset
from tqdm import trange
from transformers import AutoModel, AutoTokenizer, CLIPTextModel
# TODO: Implement batching? 10% faster (when using t5-only), but a lot more complicated code
# Command-line interface: remote source/destination bases for MDS shards and
# the list of subdirectories to process.
parser = ArgumentParser()
parser.add_argument('--remote_src_base',
                    type=str,
                    required=True,
                    help='Remote base to download MDS-formatted shards.')
parser.add_argument('--remote_dst_base',
                    type=str,
                    required=True,
                    help='Remote base to write MDS-formatted shards.')
parser.add_argument('--subdir_paths',
                    nargs='+',
                    type=str,
                    required=True,
                    help='Path to the subdirectory to process.')
args = parser.parse_args()
cache_dir = '/tmp/hf_files'
# Shared kwargs for every from_pretrained call: read from the local HF cache
# and never hit the network.
hf_load_kwargs = {'cache_dir': cache_dir, 'local_files_only': True}
# Instantiate tokenizers
print('Building tokenizers')
t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl', **hf_load_kwargs)
clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
                                               subfolder='tokenizer',
                                               **hf_load_kwargs)
print('Building models')
# Only the encoder half of T5 is needed to produce text latents.
t5_model = AutoModel.from_pretrained('google/t5-v1_1-xxl',
                                     torch_dtype=torch.float16,
                                     **hf_load_kwargs).encoder.cuda().eval()
clip_model = CLIPTextModel.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
                                           subfolder='text_encoder',
                                           torch_dtype=torch.float16,
                                           **hf_load_kwargs).cuda().eval()
def _encode_t5(caption):
    """Run one caption through the T5 encoder.

    Returns:
        Tuple of (attention_mask_bytes, latents_bytes): the bool attention
        mask and the fp16 encoder hidden states, serialized for MDS storage.
    """
    tokenized = t5_tokenizer(caption,
                             padding='max_length',
                             max_length=t5_tokenizer.model_max_length,
                             truncation=True,
                             return_tensors='pt')
    input_ids = tokenized['input_ids'].cuda()
    mask = tokenized['attention_mask'].to(torch.bool)
    latents = t5_model(input_ids=input_ids, attention_mask=mask.cuda())[0]
    return mask.squeeze(0).numpy().tobytes(), latents.squeeze(0).cpu().numpy().tobytes()


def _encode_clip(caption):
    """Run one caption through the CLIP text encoder.

    Returns:
        Tuple of (attention_mask_bytes, latents_bytes, pooled_bytes): the bool
        attention mask, the penultimate hidden state (the layer SDXL consumes,
        not the final output), and the pooled text embedding, serialized for
        MDS storage.
    """
    tokenized = clip_tokenizer(caption,
                               padding='max_length',
                               max_length=clip_tokenizer.model_max_length,
                               truncation=True,
                               return_tensors='pt')
    clip_out = clip_model(input_ids=tokenized['input_ids'].cuda(),
                          attention_mask=tokenized['attention_mask'].cuda(),
                          output_hidden_states=True)
    mask_bytes = tokenized['attention_mask'].squeeze(0).to(torch.bool).numpy().tobytes()
    latents_bytes = clip_out.hidden_states[-2].squeeze(0).cpu().numpy().tobytes()
    pooled_bytes = clip_out[1].squeeze(0).cpu().numpy().tobytes()
    return mask_bytes, latents_bytes, pooled_bytes


columns = None
for subdir_path in args.subdir_paths:
    remote_src = os.path.join(args.remote_src_base, subdir_path)
    remote_dst = os.path.join(args.remote_dst_base, subdir_path)
    # One local path for both the shard cache and the index lookup below
    # (previously '/tmp' and '/tmp/' were joined inconsistently).
    local_path = os.path.join('/tmp', subdir_path)
    # Dataset
    print('Building dataset')
    dataset = StreamingDataset(remote=remote_src,
                               local=local_path,
                               download_timeout=300,
                               shuffle=False)
    # Derive the output schema once from the first subdir's shard index and
    # extend it with the new latent columns; all subdirs are assumed to share
    # the same source schema.
    if columns is None:
        with open(os.path.join(local_path, 'index.json')) as f:
            index_json = json.load(f)
        columns = dict(zip(index_json['shards'][0]['column_names'], index_json['shards'][0]['column_encodings']))
        columns['T5_ATTENTION_MASK'] = 'bytes'
        columns['T5_LATENTS'] = 'bytes'
        columns['CLIP_ATTENTION_MASK'] = 'bytes'
        columns['CLIP_LATENTS'] = 'bytes'
        columns['CLIP_POOLED_TEXT'] = 'bytes'
        print(columns)
    # Make writer. The context manager guarantees finish() runs even if a
    # sample raises mid-stream, so buffered shards are always flushed
    # (previously an exception skipped writer.finish() entirely).
    with MDSWriter(out=remote_dst, columns=columns, compression='zstd', hashes=[], size_limit='1GB') as writer:
        print('Loading batch')
        with torch.no_grad():
            for i in trange(len(dataset)):
                sample = dataset[i]
                captions = sample['DESCRIPTION']
                # Pre-compute T5
                sample['T5_ATTENTION_MASK'], sample['T5_LATENTS'] = _encode_t5(captions)
                # Pre-compute CLIP
                (sample['CLIP_ATTENTION_MASK'], sample['CLIP_LATENTS'],
                 sample['CLIP_POOLED_TEXT']) = _encode_clip(captions)
                writer.write(sample)