This notebook shows an example of performing a grid search for the Amazon review helpfulness project. If you are unfamiliar with Keras/TensorFlow grid searches, please review this notebook first.
library(tfruns)
library(dplyr)
Our grid search setup is located in the amazon-embeddings-grid-search.R script. The grid we search across is shown below and contains 5,832 hyperparameter combinations. That is far too many to evaluate exhaustively, so we instead perform a random search, which typically finds a near-optimal solution. We run a random search by passing a sample argument less than 1 to tuning_run().
Here, we evaluate only 10% (583 models) of the possible hyperparameter combinations. In fact, I cut the grid search short, so only 137 models were evaluated.
Note: this will take several hours to run on a machine without a GPU.
grid_search <- list(
  top_n_words = c(5000, 10000, 20000),
  max_len = c(75, 150, 200),
  output_dim = c(16, 32, 64),
  learning_rate = c(0.001, 0.0001),
  batch_size = c(32, 64, 128),
  layers = c(1, 2, 3),
  units = c(16, 32, 128),
  dropout = c(0, 0.5),
  weight_decay = c(0, 0.01)
)
tuning_run(
  "amazon-embeddings-grid-search.R",
  flags = grid_search,
  runs_dir = "amazon_runs",
  sample = 0.1,
  confirm = FALSE,
  echo = FALSE
)
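The grid size and sample count quoted above can be checked with base R. The sketch below recomputes the number of combinations from the same grid and draws a 10% random subset of rows with `expand.grid()` and `sample()`. It only illustrates the idea of a random search; it is not how `tuning_run()` selects its runs internally, and the seed and the `floor()` rounding are assumptions made for the example.

```r
# Same grid as in the notebook, repeated so this block runs on its own.
grid_search <- list(
  top_n_words = c(5000, 10000, 20000),
  max_len = c(75, 150, 200),
  output_dim = c(16, 32, 64),
  learning_rate = c(0.001, 0.0001),
  batch_size = c(32, 64, 128),
  layers = c(1, 2, 3),
  units = c(16, 32, 128),
  dropout = c(0, 0.5),
  weight_decay = c(0, 0.01)
)

# Number of combinations = product of the number of values per hyperparameter.
n_combinations <- prod(lengths(grid_search))
n_combinations

# Materialise every combination as one row of a data frame.
full_grid <- expand.grid(grid_search)

# Illustrative random search: keep a random 10% of the rows.
# The seed and the floor() rounding are assumptions for this example only.
set.seed(123)
n_sampled <- floor(0.1 * nrow(full_grid))
sampled <- full_grid[sample(nrow(full_grid), size = n_sampled), ]
nrow(sampled)
```

Each row of `sampled` corresponds to one set of flag values that a single training run would receive.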
This grid search creates an “amazon_runs” subdirectory within your working directory. Our results indicate that the best-performing models all reach MSEs in the low 0.05 range.
ls_runs(runs_dir = "amazon_runs", order = eval_best_loss, decreasing = FALSE)
Data frame: 137 x 38
# ... with 127 more rows
# ... with 27 more columns:
# flag_top_n_words, flag_max_len, flag_output_dim, flag_learning_rate, flag_batch_size, flag_layers, flag_units,
# flag_dropout, flag_weight_decay, samples, batch_size, steps_completed, epochs, epochs_completed, metrics, model,
# loss_function, optimizer, learning_rate, script, start, end, completed, output, source_code, context, type
If we look at our best model, we see it has the following parameters:
- top_n_words: 20,000
- max_len: 200
- output_dim: 32
- layers (hidden layers in classifier): 1
- units (in hidden layers): 32
- learning_rate: 0.0001
- batch_size: 128
- dropout: 0.5
- weight_decay: 0
best_run <- ls_runs(
  runs_dir = "amazon_runs",
  order = eval_best_loss,
  decreasing = FALSE
) %>%
  slice(1) %>%
  pull(run_dir)
view_run(best_run)
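The parameter list above can also be read straight from the runs data frame instead of from the run viewer. The following is a minimal sketch that assumes the data frame has an `eval_best_loss` column and `flag_`-prefixed columns, as in the `ls_runs()` output shown earlier. The helper name `best_run_flags` is hypothetical and not part of tfruns.

```r
# Hypothetical helper (not part of tfruns): return the flag values of the
# run with the lowest loss as a named list.
best_run_flags <- function(runs, loss_col = "eval_best_loss") {
  runs <- as.data.frame(runs)
  # Sort ascending by loss; order() places NA losses last by default.
  ranked <- runs[order(runs[[loss_col]]), , drop = FALSE]
  # Keep only the hyperparameter columns, which carry a "flag_" prefix.
  flag_cols <- grep("^flag_", names(ranked), value = TRUE)
  as.list(ranked[1, flag_cols, drop = FALSE])
}

# Only runs if tfruns is installed and the grid search output exists.
if (requireNamespace("tfruns", quietly = TRUE) && dir.exists("amazon_runs")) {
  runs <- tfruns::ls_runs(runs_dir = "amazon_runs")
  str(best_run_flags(runs))
}
```

Because the helper works on any data frame with those columns, it can also be applied to a filtered subset of runs, for example only the runs that completed.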