Hyperparameter tuning for DNNs tends to be more involved than for other ML models because of the number of hyperparameters that can and should be assessed and the dependencies between them. To automate hyperparameter tuning for Keras and TensorFlow, we use the tfruns package (https://github.com/rstudio/tfruns).
This notebook shows an example of performing a grid search on a densely connected feedforward neural network for the IMDB movie review classifier introduced in this case study (https://rstudio-conf-2020.github.io/dl-keras-tf/notebooks/02-imdb.nb.html).
library(tfruns)
library(dplyr)
tfruns provides tools for tracking, visualizing, and managing training runs. The most common way to use tfruns is to create an R script that contains the training code to be executed. For this example, I created the imdb-grid-search.R script (https://rstudio-conf-2020.github.io/dl-keras-tf/materials/99-extras/imdb-grid-search.R).
Within this script, we create “flags”, which identify the hyperparameters of interest. For this example, I assess:
- batch sizes of 128 and 512
- 1, 2, and 3 hidden layers
- 16, 32, and 64 units per hidden layer
- learning rates of 0.001 and 0.0001
- dropout rates of 0, 0.3, and 0.5
- weight decay of 0, 0.01, and 0.001
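The flags described above might be declared at the top of imdb-grid-search.R roughly as follows. This is a hypothetical sketch: the flag names mirror the grid used later in this notebook, but the default values and descriptions are assumptions, not taken from the actual script. flags(), flag_integer(), and flag_numeric() are tfruns functions.

```r
# Hypothetical flags block for imdb-grid-search.R.
# Names match the grid search; defaults are assumed for illustration.
library(tfruns)

FLAGS <- flags(
  flag_integer("batch_size", 512, "Mini-batch size"),
  flag_integer("layers", 1, "Number of hidden layers"),
  flag_integer("units", 16, "Units per hidden layer"),
  flag_numeric("learning_rate", 0.001, "Optimizer learning rate"),
  flag_numeric("dropout", 0, "Dropout rate for hidden layers"),
  flag_numeric("weight_decay", 0, "Weight decay penalty")
)

# Sourced directly, FLAGS holds these defaults; under tuning_run()
# each run receives one combination of values from the grid instead.
str(FLAGS)
```

The rest of the script would then read values such as FLAGS$units or FLAGS$batch_size wherever it would otherwise hard-code them when building and fitting the model.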
The hyperparameter values listed above equate to 324 models (2 × 3 × 3 × 2 × 3 × 3). Since these models train relatively quickly, and since I ran this at the end of the day, I performed a full Cartesian grid search, meaning I trained and assessed every one of the 324 models. If you are pressed for time, you can run a stochastic (random) grid search instead by using the sample parameter in tuning_run() below.
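To see where the 324 comes from, and how much a stochastic search trims away, the base R sketch below enumerates the full grid with expand.grid() and then draws a random 10% of the combinations. The subsetting only illustrates the idea: with tfruns you would pass sample = 0.1 to tuning_run() rather than sampling yourself, and the package's own sampling procedure may differ in detail.

```r
# Every combination of the six flag value sets (full Cartesian grid)
grid <- expand.grid(
  batch_size    = c(128, 512),
  layers        = c(1, 2, 3),
  units         = c(16, 32, 64),
  learning_rate = c(0.001, 0.0001),
  dropout       = c(0, 0.3, 0.5),
  weight_decay  = c(0, 0.01, 0.001)
)

nrow(grid)
#> [1] 324

# Random search: keep a fraction of the rows, sampled without replacement
set.seed(123)
frac <- 0.1
sampled <- grid[sample(nrow(grid), size = round(frac * nrow(grid))), ]

nrow(sampled)
#> [1] 32
```

Each sampled row is one model configuration, so a 10% sample cuts the run time roughly tenfold at the cost of possibly missing the best combination.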
Next, to execute this grid search, first specify the hyperparameter values for the flags you created in your .R script, as shown below. Then run tuning_run() in place of source() to execute your .R script for each combination in the supplied hyperparameter grid.
Note: this takes over 3 hours to run on a machine without a GPU.
grid_search <- list(
  batch_size = c(128, 512),
  layers = c(1, 2, 3),
  units = c(16, 32, 64),
  learning_rate = c(0.001, 0.0001),
  dropout = c(0, 0.3, 0.5),
  weight_decay = c(0, 0.01, 0.001)
)
tuning_run(
  "imdb-grid-search.R",
  flags = grid_search,
  runs_dir = "imdb_runs",
  confirm = FALSE,
  echo = FALSE
)
This grid search execution creates an "imdb_runs" subdirectory within your working directory (without the runs_dir argument, tfruns defaults to a "runs" subdirectory). This folder contains information for every training run executed during the grid search.
To list the results, run ls_runs():
data.frame(ls_runs(runs_dir = "imdb_runs"))
You can also filter and order the results. The following orders the runs by eval_best_loss (lowest first) and shows that a few models tied with the lowest loss score of 0.262.
ls_runs(runs_dir = "imdb_runs", order = eval_best_loss, decreasing = FALSE)
Data frame: 324 x 31
# ... with 314 more rows
# ... with 23 more columns:
# flag_batch_size, flag_layers, flag_units, flag_learning_rate, flag_dropout,
# flag_weight_decay, samples, batch_size, epochs, epochs_completed, metrics, model,
# loss_function, optimizer, learning_rate, script, start, end, completed, output,
# source_code, context, type
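Because ls_runs() returns a data frame, you can post-process it with dplyr, just as the code for selecting the best run does with slice() and pull(). The reported loss of 0.262 is likely rounded, so runs that look tied may differ in later decimals. One way to gather all near-best runs is to keep those within a small tolerance of the minimum and then check which flag values they share. In the sketch below, the data frame is a made-up stand-in for a few ls_runs() rows (not results from this search), and the 0.001 tolerance is an arbitrary choice. ls_runs() can also filter on its own through its subset argument, e.g. ls_runs(eval_best_loss < 0.27, runs_dir = "imdb_runs").

```r
library(dplyr)

# Made-up stand-in for a few rows of ls_runs() output (illustration only)
runs <- data.frame(
  run_dir        = c("runs/a", "runs/b", "runs/c", "runs/d", "runs/e"),
  flag_units     = c(32, 16, 16, 64, 16),
  flag_dropout   = c(0.3, 0.5, 0.5, 0, 0.3),
  eval_best_loss = c(0.301, 0.2904, 0.2901, 0.335, 0.2908),
  stringsAsFactors = FALSE
)

# Keep runs within `tol` of the lowest loss, best first
tol <- 0.001
top_runs <- runs %>%
  filter(eval_best_loss <= min(eval_best_loss) + tol) %>%
  arrange(eval_best_loss)

# How often each hyperparameter combination appears among the near-best runs
shared <- top_runs %>% count(flag_units, flag_dropout)
shared
```

With real results you would replace runs with the ls_runs(runs_dir = "imdb_runs") output and include every flag_ column in count().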
To see details about any one of these models, run view_run(). In this example, I take the first of the best models from above. When you execute this, a pop-up window appears with that model's summary information, as illustrated below.
best_run <- ls_runs(
  runs_dir = "imdb_runs",
  order = eval_best_loss,
  decreasing = FALSE
) %>%
  slice(1) %>%
  pull(run_dir)
view_run(best_run)
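Once you have settled on a winner, you may want to retrain that configuration on its own. ls_runs() reports each flag in a column prefixed with flag_ (flag_batch_size, flag_layers, and so on), so a small dplyr step can turn the best row into a named list that tfruns' training_run() accepts through its flags argument. The one-row data frame below is a made-up stand-in for that row; with real results you would take the first row of the ordered ls_runs() output instead. The training_run() call is left commented out so the sketch runs without tfruns or the training script.

```r
library(dplyr)

# Made-up stand-in for the best row of ls_runs() output (illustration only)
best <- data.frame(
  run_dir = "imdb_runs/example",
  flag_batch_size = 512, flag_layers = 2, flag_units = 16,
  flag_learning_rate = 0.001, flag_dropout = 0.5, flag_weight_decay = 0.01,
  eval_best_loss = 0.29,
  stringsAsFactors = FALSE
)

# Keep only the flag_ columns and drop the prefix so the names match
# the flags defined in the training script
best_flags <- best %>%
  select(starts_with("flag_")) %>%
  rename_with(~ sub("^flag_", "", .x)) %>%
  as.list()

str(best_flags)

# With tfruns installed, this combination could then be retrained:
# training_run("imdb-grid-search.R", flags = best_flags)
```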
The tfruns package has many other handy features. I suggest you check it out at https://tensorflow.rstudio.com/tools/tfruns/overview/ and take it for a test drive.