※ TensorFlow Hub: a library released by Google to maximize the reuse of models on common problems. Pre-trained models can be loaded and fine-tuned with little effort.
※ Install hub: pip install tensorflow-hub
※ Install tfds: pip install tensorflow_datasets
This post repeats the earlier text classification example, this time using TensorFlow Hub.
Code
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds

# Load the IMDB reviews: the first 60% of the training set for training,
# the remaining 40% for validation, plus the separate test set.
train_data, validation_data, test_data = tfds.load(
    name="imdb_reviews",
    split=('train[:60%]', 'train[60%:]', 'test'),
    as_supervised=True)

# Peek at one batch of 10 examples (not used further below).
train_examples_batch, train_labels_batch = next(iter(train_data.batch(10)))

# Build the model: a pre-trained text embedding from TF Hub as the first layer.
# trainable=True means the embedding weights are fine-tuned during training.
embedding = "https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim/1"
hub_layer = hub.KerasLayer(embedding, input_shape=[],
                           dtype=tf.string, trainable=True)

model = tf.keras.Sequential()
model.add(hub_layer)
model.add(tf.keras.layers.Dense(16, activation='relu'))
model.add(tf.keras.layers.Dense(1))  # single output unit, raw logit
model.summary()

# The last layer has no activation, so the loss is computed from logits.
model.compile(optimizer='adam',
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Train the model
history = model.fit(train_data.shuffle(10000).batch(512),
                    epochs=20,
                    validation_data=validation_data.batch(512),
                    verbose=1)

# Evaluate on the test set
results = model.evaluate(test_data.batch(512), verbose=2)
for name, value in zip(model.metrics_names, results):
    print("%s: %.3f" % (name, value))
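The `split` argument in the code above carves the IMDB training set into a training portion and a validation portion. As a rough check of the resulting sizes, here is a small sketch that assumes `imdb_reviews` ships 25,000 labeled training reviews. The helper `split_sizes` is a hypothetical name used only for illustration and is not part of tfds; tfds may also round differently when a percentage does not divide evenly.

```python
def split_sizes(total, train_percent):
    """Return (train, validation) example counts for a percent-based split."""
    n_train = total * train_percent // 100
    return n_train, total - n_train

# Assumption: imdb_reviews has 25,000 labeled training reviews.
IMDB_TRAIN_TOTAL = 25_000

n_train, n_val = split_sizes(IMDB_TRAIN_TOTAL, 60)
print(n_train, n_val)  # 15000 10000
```

With a batch size of 512, that would come to roughly 30 training steps per epoch.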
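Because the final `Dense(1)` layer has no activation and the loss is built with `from_logits=True`, the model outputs raw logits rather than probabilities. A minimal sketch of converting them with a sigmoid, using NumPy only; the logit values are made up for illustration and are not outputs of the trained model, and `logits_to_probs` is a hypothetical helper name.

```python
import numpy as np

def logits_to_probs(logits):
    """Apply the sigmoid function to map logits to probabilities in (0, 1)."""
    logits = np.asarray(logits, dtype=np.float64)
    return 1.0 / (1.0 + np.exp(-logits))

# Hypothetical logits, standing in for what model.predict(...) might return.
example_logits = np.array([-2.0, 0.0, 3.0])
probs = logits_to_probs(example_logits)

# A logit of 0 maps to probability 0.5, so thresholding the logit at 0
# is equivalent to thresholding the probability at 0.5.
predicted_labels = (probs >= 0.5).astype(int)
print(predicted_labels)  # [0 1 1]
```

In this dataset label 1 marks a positive review, so a positive logit would read as a positive-sentiment prediction.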
Result