API - Model Training
TensorLayerX provides two model training interfaces that can handle the training of a wide variety of deep learning tasks.
| Model | High-Level API for Training or Testing. |
| WithLoss | High-Level API for Training or Testing. |
| WithGrad | Module that returns the gradients. |
| TrainOneStep | High-Level API for Training One Step. |
Model
tensorlayerx.model.Model(network, loss_fn=None, optimizer=None, metrics=None, **kwargs)
High-Level API for Training or Testing.
Model groups layers into an object with training and inference features.
- Parameters
network (TensorLayerX model) – The training or testing network.
loss_fn (function) – Objective (loss) function.
optimizer (class) – Optimizer for updating the weights.
metrics (class) – Dict or set of metrics to be evaluated by the model during training and testing.
- Methods
train() – Model training.
eval() – Model prediction.
save_weights() – Save the model weights to the file at file_path in the given format. Use load_weights() to restore them.
load_weights() – Load model weights from a given file, which should have been saved previously by save_weights().
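save_weights() and load_weights() are meant to be used as a round trip: whatever one writes, the other should restore unchanged. The framework-free Python sketch below illustrates only that contract, using two hypothetical helpers that store a dict of named weights as JSON. It is not TensorLayerX's implementation and does not reproduce its file formats.

```python
# Framework-free sketch of a save/restore round trip for named weights.
# save_weights/load_weights here are hypothetical JSON helpers, not the
# TensorLayerX methods, and use none of its file formats.
import json
import os
import tempfile


def save_weights(weights, file_path):
    """Write a {name: list of floats} dict to file_path."""
    with open(file_path, "w") as f:
        json.dump(weights, f)


def load_weights(file_path):
    """Read back a dict previously written by save_weights()."""
    with open(file_path) as f:
        return json.load(f)


trained = {"w": [2.0, -1.5], "b": [0.25]}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.json")
    save_weights(trained, path)
    restored = load_weights(path)

print(restored == trained)  # prints True
```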
Examples
>>> import tensorlayerx as tlx
>>> from tensorlayerx.nn import Module
>>> class Net(Module):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         self.conv = tlx.nn.Conv2d(n_filter=32, filter_size=(3, 3), strides=(2, 2), in_channels=5, name='conv2d')
>>>         self.bn = tlx.nn.BatchNorm2d(num_features=32, act=tlx.ReLU)
>>>         self.flatten = tlx.nn.Flatten()
>>>         self.fc = tlx.nn.Dense(n_units=12, in_channels=32*224*224) # padding=0
>>>
>>>     def forward(self, x):
>>>         x = self.conv(x)
>>>         x = self.bn(x)
>>>         x = self.flatten(x)
>>>         out = self.fc(x)
>>>         return out
>>>
>>> net = Net()
>>> loss = tlx.losses.softmax_cross_entropy_with_logits
>>> optim = tlx.optimizers.Momentum(params=net.trainable_weights, learning_rate=0.1, momentum=0.9)
>>> model = tlx.model.Model(net, loss_fn=loss, optimizer=optim, metrics=None)
>>> dataset = get_dataset()  # placeholder: supply your own training dataset
>>> model.train(2, dataset)  # train for 2 epochs
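The example above delegates the whole loop to Model.train(). As a rough illustration of what such a high-level train/eval pair does, here is a framework-free Python sketch under stated assumptions: `LinearNet`, `SimpleModel` and `squared_error` are hypothetical names, the network is a one-input linear model, and the gradients are derived by hand for that model and loss, whereas a real Model relies on the backend's automatic differentiation and an optimizer object.

```python
# Framework-free sketch of a high-level train/eval API (hypothetical names,
# not TensorLayerX). Gradients are hand-derived for this one model and loss.


class LinearNet:
    """Hypothetical one-input network: y_hat = w * x + b."""

    def __init__(self):
        self.w, self.b = 0.0, 0.0

    def forward(self, x):
        return self.w * x + self.b


def squared_error(y_hat, y):
    return (y_hat - y) ** 2


class SimpleModel:
    """Bundles a network, a loss function and a learning rate."""

    def __init__(self, network, loss_fn, lr=0.05):
        self.network, self.loss_fn, self.lr = network, loss_fn, lr

    def train(self, n_epoch, dataset):
        for _ in range(n_epoch):
            for x, y in dataset:
                # d(loss)/d(y_hat) for squared error; the chain rule gives
                # d/dw = d * x and d/db = d.
                d = 2.0 * (self.network.forward(x) - y)
                self.network.w -= self.lr * d * x
                self.network.b -= self.lr * d

    def eval(self, dataset):
        # Mean loss over the dataset; no weights are updated.
        losses = [self.loss_fn(self.network.forward(x), y) for x, y in dataset]
        return sum(losses) / len(losses)


# Noise-free data generated from y = 2x + 1.
data = [(x, 2.0 * x + 1.0) for x in (0.0, 0.5, 1.0, 1.5, 2.0)]
model = SimpleModel(LinearNet(), squared_error)
model.train(200, data)
print(round(model.network.w, 3), round(model.network.b, 3))  # prints: 2.0 1.0
```

Because the data is noise-free, per-sample gradient descent drives w and b to the generating values 2 and 1, and eval() then reports a loss close to zero.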
WithLoss
tensorlayerx.model.WithLoss(backbone, loss_fn)
High-Level API for Training or Testing.
Wraps the network with a loss function. This Module takes data and labels as inputs and returns the computed loss.
- Parameters
backbone (TensorLayerX model) – The TensorLayerX network to wrap.
loss_fn (function) – Objective (loss) function.
- Methods
forward() – Model inference: runs the backbone on the input data and returns the loss computed against the label.
Examples
>>> import tensorlayerx as tlx
>>> net = vgg16()  # any TensorLayerX network can be used here
>>> loss_fn = tlx.losses.softmax_cross_entropy_with_logits
>>> net_with_loss = tlx.model.WithLoss(net, loss_fn)
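WithLoss turns a network into a single callable that maps (data, label) to a scalar loss. A minimal, framework-free sketch of that wrapping pattern follows, assuming a hypothetical `LossWrapper` class and a toy element-wise linear backbone; neither is TensorLayerX API.

```python
# Framework-free sketch of the "network + loss" wrapper pattern behind WithLoss.
# LossWrapper and LinearNet are hypothetical names, not TensorLayerX API.


class LinearNet:
    """Hypothetical backbone: y_hat = w * x + b, applied element-wise."""

    def __init__(self, w=1.0, b=0.0):
        self.w, self.b = w, b

    def __call__(self, data):
        return [self.w * x + self.b for x in data]


def mean_squared_error(outputs, labels):
    return sum((o - t) ** 2 for o, t in zip(outputs, labels)) / len(labels)


class LossWrapper:
    """Takes (data, label) and returns the loss instead of the raw outputs."""

    def __init__(self, backbone, loss_fn):
        self.backbone = backbone
        self.loss_fn = loss_fn

    def __call__(self, data, label):
        outputs = self.backbone(data)  # forward pass through the wrapped network
        return self.loss_fn(outputs, label)


net_with_loss = LossWrapper(LinearNet(w=2.0, b=1.0), mean_squared_error)
print(net_with_loss([0.0, 1.0, 2.0], [1.0, 3.0, 6.0]))  # prints 0.3333333333333333
```

The benefit of the wrapper is that a training step only needs one callable, (data, label) to loss, to differentiate and minimize.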
WithGrad
tensorlayerx.model.WithGrad(network, loss_fn=None, optimizer=None)
Module that returns the gradients.
- Parameters
network (TensorLayerX model) – The TensorLayerX network.
loss_fn (function) – Objective (loss) function.
optimizer (class) – Optimizer for updating the weights.
Examples
>>> import tensorlayerx as tlx
>>> net = vgg16()  # any TensorLayerX network can be used here
>>> loss_fn = tlx.losses.softmax_cross_entropy_with_logits
>>> optimizer = tlx.optimizers.Adam(learning_rate=1e-3)
>>> net_with_grad = tlx.model.WithGrad(net, loss_fn, optimizer)
>>> inputs, labels = tlx.nn.Input((128, 784), dtype=tlx.float32), tlx.nn.Input((128, 1), dtype=tlx.int32)
>>> net_with_grad(inputs, labels)  # returns the gradients
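WithGrad returns the gradients of the loss with respect to the network's weights rather than the loss itself. The sketch below shows the idea without any framework. `GradWrapper` and `LinearNet` are hypothetical, and central finite differences stand in for the automatic differentiation a real backend would use, so this approach is only practical for tiny models.

```python
# Framework-free sketch of a "return the gradients" wrapper in the spirit of
# WithGrad. GradWrapper is hypothetical; central finite differences stand in
# for the automatic differentiation a real backend would use.


class LinearNet:
    """Hypothetical network with a dict of trainable weights."""

    def __init__(self):
        self.weights = {"w": 0.0, "b": 0.0}

    def __call__(self, data):
        return [self.weights["w"] * x + self.weights["b"] for x in data]


def mean_squared_error(outputs, labels):
    return sum((o - t) ** 2 for o, t in zip(outputs, labels)) / len(labels)


class GradWrapper:
    """Hypothetical wrapper: (data, label) -> {weight name: d(loss)/d(weight)}."""

    def __init__(self, network, loss_fn, eps=1e-6):
        self.network, self.loss_fn, self.eps = network, loss_fn, eps

    def _loss(self, data, label):
        return self.loss_fn(self.network(data), label)

    def __call__(self, data, label):
        grads = {}
        weights = self.network.weights
        for name in list(weights):
            value = weights[name]
            weights[name] = value + self.eps  # nudge the weight up
            loss_up = self._loss(data, label)
            weights[name] = value - self.eps  # nudge the weight down
            loss_down = self._loss(data, label)
            weights[name] = value  # restore the original weight
            grads[name] = (loss_up - loss_down) / (2 * self.eps)
        return grads


net = LinearNet()
grads = GradWrapper(net, mean_squared_error)([1.0, 2.0], [3.0, 5.0])
print({name: round(g, 4) for name, g in grads.items()})  # prints {'w': -13.0, 'b': -8.0}
```

For this quadratic loss the analytic gradients at w = b = 0 are dL/dw = -13 and dL/db = -8, which the central differences reproduce up to floating-point rounding. The weights are left unchanged; applying the update is a separate step.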
TrainOneStep
tensorlayerx.model.TrainOneStep(net_with_loss, optimizer, train_weights)
High-Level API for Training One Step.
Wraps the network with an optimizer. Each call performs one training step: it computes the loss, uses the optimizer to update train_weights, and returns the loss.
- Parameters
net_with_loss (WithLoss) – The network wrapped with its loss function by WithLoss.
optimizer (class) – Optimizer for updating the weights.
train_weights (list) – The weights to be updated by the optimizer, for example net.trainable_weights.
Examples
>>> import tensorlayerx as tlx
>>> net = vgg16()  # any TensorLayerX network can be used here
>>> train_weights = net.trainable_weights
>>> loss_fn = tlx.losses.softmax_cross_entropy_with_logits
>>> optimizer = tlx.optimizers.Adam(learning_rate=1e-3)
>>> net_with_loss = tlx.model.WithLoss(net, loss_fn)
>>> train_one_step = tlx.model.TrainOneStep(net_with_loss, optimizer, train_weights)
>>> inputs, labels = tlx.nn.Input((128, 784), dtype=tlx.float32), tlx.nn.Input((128, 1), dtype=tlx.int32)
>>> train_one_step(inputs, labels)  # one forward/backward/update step
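TrainOneStep bundles the loss, the gradients and the weight update into a single call. Below is a framework-free sketch of one such step, under the assumption of plain SGD and hand-derived mean-squared-error gradients for a toy linear model. `train_one_step` and `LinearNet` are hypothetical names, not TensorLayerX API.

```python
# Framework-free sketch of a single training step, the job TrainOneStep does.
# train_one_step and LinearNet are hypothetical; plain SGD stands in for the
# optimizer, and gradients are hand-derived for MSE on y_hat = w * x + b.


class LinearNet:
    """Hypothetical network with a dict of trainable weights."""

    def __init__(self):
        self.weights = {"w": 0.0, "b": 0.0}

    def __call__(self, data):
        return [self.weights["w"] * x + self.weights["b"] for x in data]


def train_one_step(net, data, label, lr=0.1):
    outputs = net(data)
    n = len(label)
    loss = sum((o - t) ** 2 for o, t in zip(outputs, label)) / n
    # Gradients of the mean squared error with respect to w and b.
    grad_w = sum(2 * (o - t) * x for o, t, x in zip(outputs, label, data)) / n
    grad_b = sum(2 * (o - t) for o, t in zip(outputs, label)) / n
    # Plain SGD update of the trainable weights, in place.
    net.weights["w"] -= lr * grad_w
    net.weights["b"] -= lr * grad_b
    return loss  # the loss measured before this step's update


net = LinearNet()
data, label = [1.0, 2.0], [3.0, 5.0]
losses = [train_one_step(net, data, label) for _ in range(1000)]
print(losses[0])  # prints 17.0
```

Calling the step repeatedly is the whole training loop: here the loss falls from 17.0 toward zero as w and b approach 2 and 1, the exact solution for these two points.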