You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
23 lines
834 B
Python
23 lines
834 B
Python
7 years ago
|
import numpy as np
|
||
|
|
||
|
def split_data(chars, batch_size, num_steps, split_frac=0.9):
    """Build shifted input/target matrices and split them into train/validation.

    The sequence is trimmed to a whole number of ``batch_size * num_steps``
    slices and laid out as ``batch_size`` rows. Targets are the inputs shifted
    forward by one element. Both matrices are then cut column-wise so that the
    first ``split_frac`` of the batches form the training set and the rest
    form the validation set.

    Args:
        chars: 1-D NumPy array holding the sequence.
        batch_size: Number of rows in the returned matrices.
        num_steps: Number of columns consumed by each batch.
        split_frac: Fraction of the batches assigned to the training set.

    Returns:
        A tuple ``(train_x, train_y, val_x, val_y)`` of 2-D arrays, each with
        ``batch_size`` rows.
    """
    slice_size = batch_size * num_steps

    # The targets need one element beyond the last input, so count batches
    # against len(chars) - 1. Counting against len(chars) leaves the target
    # slice one element short whenever the length is an exact multiple of
    # slice_size, and np.split then refuses the uneven division.
    n_batches = max(len(chars) - 1, 0) // slice_size
    usable = n_batches * slice_size

    x = chars[:usable]
    y = chars[1:usable + 1]

    # One row per batch element: shape (batch_size, n_batches * num_steps).
    x = np.stack(np.split(x, batch_size))
    y = np.stack(np.split(y, batch_size))

    # Cut on a batch boundary so neither set contains a partial batch.
    split_col = int(n_batches * split_frac) * num_steps
    train_x, train_y = x[:, :split_col], y[:, :split_col]
    val_x, val_y = x[:, split_col:], y[:, split_col:]
    return train_x, train_y, val_x, val_y
|
||
|
|
||
|
def get_batch(arrs, num_steps):
    """Yield consecutive column windows of width ``num_steps`` from each array.

    Args:
        arrs: Sequence of 2-D arrays; the number of windows is derived from
            the column count of the first one.
        num_steps: Number of columns in every yielded window.

    Yields:
        A list holding, for each array in ``arrs``, the slice covering the
        current window of columns. Trailing columns that do not fill a whole
        window are not yielded.
    """
    _, total_steps = arrs[0].shape

    window_count = int(total_steps / num_steps)
    for start in range(0, window_count * num_steps, num_steps):
        stop = start + num_steps
        yield [arr[:, start:stop] for arr in arrs]
|