Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
hsd1503 committed Apr 4, 2020
1 parent cb8a341 commit b0f1663
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions util.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,55 @@ def read_data_physionet_4(window_size=3000, stride=500):

return X_train, X_test, Y_train, Y_test, pid_test

def read_data_physionet_4_with_val(window_size=3000, stride=500):

# read pkl
with open('../data/challenge2017/challenge2017.pkl', 'rb') as fin:
res = pickle.load(fin)
## scale data
all_data = res['data']
for i in range(len(all_data)):
tmp_data = all_data[i]
tmp_std = np.std(tmp_data)
tmp_mean = np.mean(tmp_data)
all_data[i] = (tmp_data - tmp_mean) / tmp_std
## encode label
all_label = []
for i in res['label']:
if i == 'N':
all_label.append(0)
elif i == 'A':
all_label.append(1)
elif i == 'O':
all_label.append(2)
elif i == '~':
all_label.append(3)
all_label = np.array(all_label)

# split train val test
X_train, X_test, Y_train, Y_test = train_test_split(all_data, all_label, test_size=0.2, random_state=0)
X_val, X_test, Y_val, Y_test = train_test_split(X_test, Y_test, test_size=0.5, random_state=0)

# slide and cut
print('before: ')
print(Counter(Y_train), Counter(Y_val), Counter(Y_test))
X_train, Y_train = slide_and_cut(X_train, Y_train, window_size=window_size, stride=stride)
X_val, Y_val, pid_val = slide_and_cut(X_val, Y_val, window_size=window_size, stride=stride, output_pid=True)
X_test, Y_test, pid_test = slide_and_cut(X_test, Y_test, window_size=window_size, stride=stride, output_pid=True)
print('after: ')
print(Counter(Y_train), Counter(Y_val), Counter(Y_test))

# shuffle train
shuffle_pid = np.random.permutation(Y_train.shape[0])
X_train = X_train[shuffle_pid]
Y_train = Y_train[shuffle_pid]

X_train = np.expand_dims(X_train, 1)
X_val = np.expand_dims(X_val, 1)
X_test = np.expand_dims(X_test, 1)

return X_train, X_val, X_test, Y_train, Y_val, Y_test, pid_val, pid_test


def read_data_generated(n_samples, n_length, n_channel, n_classes, verbose=False):
"""
Expand Down

0 comments on commit b0f1663

Please sign in to comment.