딥러닝에 쓰이는 어려운 라이브러리 정리
Python 라이브러리 정리
이 게시글의 계기는 TensorFlow에 대한 github의 코드를 이해하지 못해서 입니다.
전체 코드를 보고 싶으시다면 char-rnn-tensorflow에 있습니다.물론 아래 코드는 해당 라이브러리 함수의 빙산의 일각에 불과합니다.
굉장히 주관적인 어려움에 의한 게시글입니다.
다행히 여기 쓰일 함수를 구글링해서 들어오신 분들이 많을거라 예상되기 때문에 저가 이미 알고 있는 함수들에 대한 언급은 하지 않겠습니다.
딥러닝에서 아래 라이브러리들이 다 꼭 쓰여야만 하는 것도 아닙니다.
단순히 나중을 위해 공부하는 것입니다.
import tensorflow as tf
import argparse
import time
import os
from six.moves import cPickle
from utils import TextLoader
from model import Model
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, default='data/tinyshakespeare',
help='data directory containing input.txt')
'''(중략)'''
parser.add_argument('--init_from', type=str, default=None,
help="""continue training from saved model at this path. Path must contain files saved by previous training process:
'config.pkl' : configuration;
'chars_vocab.pkl' : vocabulary definitions;
'checkpoint' : paths to model file(s) (created by tf).
Note: this file contains absolute paths, be careful when moving files around;
'model.ckpt-*' : file(s) with model definition (created by tf)""")
args = parser.parse_args()
train(args)
def train(args):
data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
args.vocab_size = data_loader.vocab_size
# check compatibility if training is continued from previously saved model
if args.init_from is not None:
# check if all necessary files exist
assert os.path.isdir(args.init_from), " %s must be a a path" % args.init_from
'''(중략)'''
ckpt = tf.train.get_checkpoint_state(args.init_from)
assert ckpt, "No checkpoint found"
assert ckpt.model_checkpoint_path, "No model path found in checkpoint"
# open old config and check if models are compatible
with open(os.path.join(args.init_from, 'config.pkl')) as f:
saved_model_args = cPickle.load(f)
need_be_same = ["model", "rnn_size", "num_layers", "seq_length"]
for checkme in need_be_same:
assert vars(saved_model_args)[checkme] == vars(args)[
checkme], "Command line argument and saved model disagree on '%s' " % checkme
# open saved vocab/dict and check if vocabs/dicts are compatible
with open(os.path.join(args.init_from, 'chars_vocab.pkl')) as f:
saved_chars, saved_vocab = cPickle.load(f)
assert saved_chars == data_loader.chars, "Data and loaded model disagree on character set!"
assert saved_vocab == data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"
with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
cPickle.dump(args, f)
with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
cPickle.dump((data_loader.chars, data_loader.vocab), f)
model = Model(args)
with tf.Session() as sess:
tf.initialize_all_variables().run()
saver = tf.train.Saver(tf.all_variables())
if args.init_from is not None:
saver.restore(sess, ckpt.model_checkpoint_path)
for e in range(args.num_epochs):
sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
data_loader.reset_batch_pointer()
state = model.initial_state.eval()
for b in range(data_loader.num_batches):
start = time.time()
x, y = data_loader.next_batch()
feed = {model.input_data: x, model.targets: y, model.initial_state: state}
train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
end = time.time()
print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
.format(e * data_loader.num_batches + b, args.num_epochs * data_loader.num_batches, e, train_loss, end - start))
if (e * data_loader.num_batches + b) % args.save_every == 0 or (e == args.num_epochs - 1 and b == data_loader.num_batches - 1):
checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
saver.save(sess, checkpoint_path, global_step=e * data_loader.num_batches + b)
print("model saved to {}".format(checkpoint_path))
위 코드는 RNN을 통해 셰익스피어의 책을 학습하고 컴퓨터가 스스로 그와 비슷한 작품을 만드는 코드의 일부분입니다.
깃헙을 보시면 아시겠지만
마지막 두 라이브러리인 TextLoader와 Model은 같은 폴더에 있는 클래스입니다.
1. tensorflow
2. argparse
3. Pickle
이제 하나씩 라이브러리에 대해 알아보죠.
1. tensorflow
1.1 tf.train.get_checkpoint_state(checkpoint_dir, latest_filename=None)
모델의 CheckpointState proto를 반환하는 함수입니다.
ckpt = tf.train.get_checkpoint_state(args.init_from)
assert ckpt, "No checkpoint found"
assert ckpt.model_checkpoint_path, "No model path found in checkpoint"
위처럼 예제에서 쓰였습니다.
assert()는 쉼표 앞에는 condition을 나타내고 뒤에는 condition이 False였을 시에 나타나는 AssertionError문입니다. 예외 처리의 일종이라고 합니다.
ckpt에 객체를 넣어놓고 assert를 통해 예외 처리하는 과정입니다.
1.2 tf.train.Saver(tf.all_variables())
이 함수가 제일 궁금했습니다. Saver 때문에 위 코드를 실행시키면 모델 체크 포인트가 디렉토리 안에 저장되었습니다. 딥러닝이 워낙 학습하는 데에 긴 시간이 걸리기 때문에 만들어진 함수가 아닌가 싶습니다.
저가 궁금한 것은 모델의 마지막 체크포인트를 불러들여 나의 예측하고자 하는 값을 넣으면 바로 예측값이 나오느냐입니다. 왜냐하면 만약 위와 같은 과정이 되고 13시간동안 99%의 정확도를 가진 모델이 내가 만든 프로그램에 들어간다면, 딥러닝이 생활에서도 쉽게 쓰일 수 있기 때문입니다.
모든 변수들을 Saver클래스의 매개변수로 넘겨 저장합니다.
saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
위와 같이 global_step의 수에 맞게 파일 이름도 지정되어 저장됩니다.
# Create a saver.
saver = tf.train.Saver(tf.all_variables())
# Launch the graph and train, saving the model every 1,000 steps.
sess = tf.Session()
for step in xrange(1000000):
sess.run(..training_op..)
if step % 1000 == 0:
# Append the step number to the checkpoint name:
saver.save(sess, 'my-model', global_step=step)
모든 변수들을 Saver클래스의 매개변수로 넘겨 저장합니다.
saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
위와 같이 global_step의 수에 맞게 파일 이름도 지정되어 저장됩니다.
1.3 tf.train.Saver.restore(sess, save_path)
saver.restore(sess, ckpt.model_checkpoint_path)는 자동으로 step에 맞춰 저장되기 전에 ckpt에 있던 args의 값들을 저장 saver.save가 아닌 restore이라는 함수를 써서 저장하는 것입니다.
1.4 ckpt = tf.train.get_checkpoint_state()
args에서 체크포인트를 받아 상태를 ckpt에 저장합니다. 이제 ckpt에서 . 연산자를 이용해 더 많은 일들을 할 준비가 됐습니다.
1.5 ckpt.model_checkpoint_path
위에서 ckpt에 상태를 모두 저장한 상태입니다. 또한 . 연산자를 이용해 체크 포인트의 path를 반환할 수 있습니다.
1.6 flags = tf.app.flags
이 함수는 아래의 argparse 모듈과 같은 역할을 합니다. argparse보다 나은 점은 저가 볼 때는 굳이 argparser를 import 할 필요 없이 아래와 같이 추가만 하면 됩니다.
flags = tf.app.flags
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
flags.DEFINE_boolean("is_crop", False, "True for training, False for testing [False]")
FLAGS = flags.FLAGS
if not os.path.exists(FLAGS.checkpoint_dir):
os.makedirs(FLAGS.checkpoint_dir)
flags에 tf.app.flags 를 통해 객체를 저장하고 DEFINE_integer나 DEFINE_string 을 통해 아래 argparse와 같은 역할을 할 수 있습니다.
2. argparse
2.1 parser = argparse.ArgumentParser()
parser 객체를 생성하는 코드입니다. 이 코드를 실행하고 parser 뒤에 . 을 붙여 함수들을 사용합니다.
2.2 parser.add_argument()
parser.add_argument('--data_dir', type=str, default='data/tinyshakespeare',
help='data directory containing input.txt')
이 함수는 위와 같이 쓰입니다. 위 함수를 토대로 설명하자면, $python a.py처럼 실행할 때, 위 옵션을 사용하려면 $python a.py --data_dir data/tiny 처럼 옵션을 설정할 수 있습니다.
위에는 default 옵션을 사용해서 default 디렉토리와 다른 디렉토리에 데이터가 있지 않는 이상 안써도 됩니다.
help 옵션은 $python a.pu -h 를 실행했을 때 help 문구가 출력되면서 옵션에 대한 설명을 나타냅니다.
2.3 parser.parse_args()
이제 parser에 add한 arguments를 통합해 변수에 저장할 때 쓰입니다.
변수에 저장하면, 아래와 같이 . 연산자를 이용해 옵션에 접근할 수 있습니다.
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, default='data/tinyshakespeare', help='data directory containing input.txt')
args = parser.parse_args()
print(args.data_dir)
3. pickle
우선 Pickle에 대해 간략히 설명하자면 객체를 저장하는 모듈입니다. serializing 하는 모듈이라고도 합니다.
3.1 pickle.save() & pickle.load()
객체를 위 함수를 이용해 저장 및 불러올 수 있습니다.
간단한 예제를 이용해 보죠.
import pickle
import os
tup_ob = {'a' : 3, 'b' : 5}
list_ob = ['string', 1023, 103.4]
with open(os.path.join(os.getcwd(), 'pickle.pkl'),'wb') as f:
save_pickle = pickle.dump((tup_ob,list_ob),f)
with open(os.path.join(os.getcwd(), 'pickle.pkl'),'rb') as f:
tup,li = pickle.load(f)
print(tup,li)
위 두 변수를 출력해보면, 아래와 같은 결과를 얻을 수 있습니다.
{'a': 3, 'b': 5} ['string', 1023, 103.4]
Reference
* https://github.com/sherjilozair/char-rnn-tensorflow/
댓글
댓글 쓰기