머신러닝/Tensorflow

Tensorflow 2.0 - Text classification by TF Hub

aiemag 2021. 3. 8. 22:40
반응형

※ Tensorflow Hub : 일반화된 문제들에 대해서 모델의 재사용성을 극대화 하기 위해 구글에서 공개한 API, 미리 훈련된 모델을 FIne Tuning하여 쉽게 사용할 수 있음.

※ hub 설치 pip install tensorflow-hub

※ tfds 서치 pip install tensorflow_datasets

 

Tensorflow Hub를 이용해서 기존 text classficiation 을 해봅니다.

 

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds
 
# load data
train_data, validation_data, test_data = tfds.load(
    name="imdb_reviews",
    split=('train[:60%]''train[60%:]''test'),
    as_supervised=True)
 
train_examples_batch, train_labels_batch = next(iter(train_data.batch(10)))
 
# make model
embedding = "https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim/1"
hub_layer = hub.KerasLayer(embedding, input_shape=[],
                           dtype=tf.string, trainable=True)
 
model = tf.keras.Sequential()
model.add(hub_layer)
model.add(tf.keras.layers.Dense(16, activation='relu'))
model.add(tf.keras.layers.Dense(1))
 
model.summary()
 
model.compile(optimizer='adam',
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
 
 
# train model
history = model.fit(train_data.shuffle(10000).batch(512),
                    epochs=20,
                    validation_data=validation_data.batch(512),
                    verbose=1)
 
# validation
results = model.evaluate(test_data.batch(512), verbose=2)
 
for name, value in zip(model.metrics_names, results):
    print("%s: %.3f" % (name, value))
cs

 

결과

반응형

'머신러닝 > Tensorflow' 카테고리의 다른 글

Tensorflow 2.0 - Basic text classification  (0) 2021.03.07
Tensorflow 2.0 - Basic image classification  (0) 2021.03.06
Tensorflow 2.0 - Tutorial  (0) 2021.03.05