对iris数据集合进行多酚类
library(tfestimators)
response <- function() "Species"
features <- function() setdiff(names(iris), response())
# split into train, test datasets
set.seed(123)
partitions <- modelr::resample_partition(iris, c(test = 0.2, train = 0.8))
iris_train <- as.data.frame(partitions$train)
iris_test <- as.data.frame(partitions$test)
# construct feature columns
feature_columns <- feature_columns(
column_numeric(features())
)
# construct classifier
classifier <- dnn_classifier(
feature_columns = feature_columns,
hidden_units = c(10, 20, 10),
n_classes = 3
)
# construct input function
iris_input_fn <- function(data) {
input_fn(data, features = features(), response = response())
}
# train classifier with training dataset
train(classifier, input_fn = iris_input_fn(iris_train))
The following factor levels of 'Species' have been encoded:
- 'setosa' => 0
- 'versicolor' => 1
- 'virginica' => 2
2018-02-18 16:13:55.750304: E tensorflow/core/util/events_writer.cc:162] The events file /var/folders/jz/qf7zhsc97f71slzzf59mvs2w0000gn/T/tmpfg11dt9y/events.out.tfevents.1518941432.MideMacBook-Air.local has disappeared.
2018-02-18 16:13:55.750358: E tensorflow/core/util/events_writer.cc:131] Failed to flush 1 events to /var/folders/jz/qf7zhsc97f71slzzf59mvs2w0000gn/T/tmpfg11dt9y/events.out.tfevents.1518941432.MideMacBook-Air.local
[-] Training -- loss: 143.24, step: 1
2018-02-18 16:13:56.963146: E tensorflow/core/util/events_writer.cc:162] The events file /var/folders/jz/qf7zhsc97f71slzzf59mvs2w0000gn/T/tmpfg11dt9y/events.out.tfevents.1518941432.MideMacBook-Air.local has disappeared.
2018-02-18 16:13:56.963224: E tensorflow/core/util/events_writer.cc:131] Failed to flush 5 events to /var/folders/jz/qf7zhsc97f71slzzf59mvs2w0000gn/T/tmpfg11dt9y/events.out.tfevents.1518941432.MideMacBook-Air.local
# valuate with test dataset
predictions <- predict(classifier, input_fn = iris_input_fn(iris_test))
predictions
# A tibble: 29 x 4
logits probabilities classes class_ids
<list> <list> <list> <list>
1 <-1.44, 0.525, -1.15> <0.105, 0.753, 0.142> <1> <1>
2 <-1.37, 0.516, -1.15> <0.113, 0.746, 0.141> <1> <1>
3 <-1.29, 0.475, -1.05> <0.123, 0.72, 0.157> <1> <1>
4 <-1.28, 0.469, -1.03> <0.125, 0.716, 0.159> <1> <1>
5 <-1.28, 0.464, -1.02> <0.125, 0.713, 0.162> <1> <1>
6 <-1.28, 0.485, -1.09> <0.124, 0.726, 0.15> <1> <1>
7 <-1.32, 0.486, -1.07> <0.119, 0.728, 0.153> <1> <1>
8 <-1.25, 0.462, -1.03> <0.128, 0.711, 0.161> <1> <1>
9 <-1.45, 0.526, -1.14> <0.104, 0.754, 0.142> <1> <1>
10 <-1.22, 0.44, -0.963> <0.132, 0.697, 0.171> <1> <1>
# ... with 19 more rows
evaluation <- evaluate(classifier, input_fn = iris_input_fn(iris_test))
The following factor levels of 'Species' have been encoded:
- 'setosa' => 0
- 'versicolor' => 1
- 'virginica' => 2
[-] Evaluating -- loss: 43.15, step: 1
> evaluation
# A tibble: 1 x 4
average_loss accuracy loss global_step
<dbl> <dbl> <dbl> <dbl>
1 1.49 0.345 43.2 2.00