This notebook shows an example of performing a grid search for the Amazon review helpfulness project. If you are unfamiliar with Keras/TensorFlow grid searches, please review this notebook first.

library(tfruns)
library(dplyr)

Our grid search setup is located in the amazon-embeddings-grid-search.R script. The grid we search across is shown below and contains 5,832 hyperparameter combinations. That is far too many to evaluate exhaustively, so we instead perform a random search, which typically finds a near-optimal solution. We run a random search by passing tuning_run() a sample argument less than 1.

Here, we request only 10% (about 583 models) of the possible hyperparameter combinations. In practice, I stopped the search early, so only 137 models were evaluated.

Note: this will take several hours to run on a machine without a GPU.

grid_search <- list(
  top_n_words = c(5000, 10000, 20000),
  max_len = c(75, 150, 200),
  output_dim = c(16, 32, 64),
  learning_rate = c(0.001, 0.0001),
  batch_size = c(32, 64, 128),
  layers = c(1, 2, 3),
  units = c(16, 32, 128),
  dropout = c(0, 0.5),
  weight_decay = c(0, 0.01)
)

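tuning_run(
  "amazon-embeddings-grid-search.R",
  flags = grid_search,       # candidate values for each flag in the script
  runs_dir = "amazon_runs",  # directory where every run is stored
  sample = 0.1,              # randomly sample 10% of the combinations
  confirm = FALSE,           # start without asking for confirmation
  echo = FALSE               # do not print the script's expressions
)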
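
The counts quoted earlier can be checked with a few lines of base R. The sketch below repeats the grid so it runs on its own, multiplies the number of candidate values per flag, and then draws a 10% random subset with expand.grid() and sample(). It only illustrates the idea behind a random search; it is not how tfruns draws its sample internally, and the seed and the use of round() are assumptions made for the example.

```r
# Same grid as above, repeated so the snippet is self-contained.
grid_search <- list(
  top_n_words = c(5000, 10000, 20000),
  max_len = c(75, 150, 200),
  output_dim = c(16, 32, 64),
  learning_rate = c(0.001, 0.0001),
  batch_size = c(32, 64, 128),
  layers = c(1, 2, 3),
  units = c(16, 32, 128),
  dropout = c(0, 0.5),
  weight_decay = c(0, 0.01)
)

# Grid size = product of the number of candidate values per flag.
n_combinations <- prod(lengths(grid_search))

# Enumerate every combination, one row per model configuration.
all_combos <- expand.grid(grid_search)

# Illustrative 10% random subset (seed and rounding are assumptions).
set.seed(123)
n_sampled <- round(0.1 * n_combinations)
sampled <- all_combos[sample(nrow(all_combos), size = n_sampled), ]

n_combinations  # 5832
n_sampled       # 583
```

Six flags have three candidate values and three flags have two, which gives 3^6 * 2^3 = 5,832 combinations.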

This grid search execution creates an “amazon_runs” subdirectory within your working directory. Our results indicate that the best-performing models all have MSEs in the low 0.05 range.

ls_runs(runs_dir = "amazon_runs", order = eval_best_loss, decreasing = FALSE)
Data frame: 137 x 38 
# ... with 127 more rows
# ... with 27 more columns:
#   flag_top_n_words, flag_max_len, flag_output_dim, flag_learning_rate, flag_batch_size, flag_layers, flag_units,
#   flag_dropout, flag_weight_decay, samples, batch_size, steps_completed, epochs, epochs_completed, metrics, model,
#   loss_function, optimizer, learning_rate, script, start, end, completed, output, source_code, context, type
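
Beyond sorting the runs, it can help to see how each flag value fares across all evaluated models. The following is a minimal sketch in base R: summarise_by_flag() is a hypothetical helper (not part of tfruns) that groups a runs data frame by one flag column and reports the number of runs, the mean loss, and the best loss per value. It assumes the column names shown in the listing above (flag_* columns and eval_best_loss).

```r
# Hypothetical helper: summarise a metric for each value of one flag.
# `runs` is any data frame with the flag column and the metric column.
summarise_by_flag <- function(runs, flag, metric = "eval_best_loss") {
  loss <- runs[[metric]]
  value <- runs[[flag]]
  keep <- !is.na(loss)  # skip runs that did not record the metric
  loss <- loss[keep]
  value <- value[keep]
  # tapply() orders groups by sorted unique value, matching `value` below.
  data.frame(
    value = sort(unique(value)),
    n_runs = as.vector(tapply(loss, value, length)),
    mean_loss = as.vector(tapply(loss, value, mean)),
    best_loss = as.vector(tapply(loss, value, min))
  )
}

# Only runs if the tuning results exist in the working directory.
if (dir.exists("amazon_runs")) {
  runs <- tfruns::ls_runs(runs_dir = "amazon_runs")
  print(summarise_by_flag(runs, "flag_learning_rate"))
}
```

Because the combinations were sampled at random, these per-flag averages mix in the effect of every other flag, so treat them as a rough guide rather than a controlled comparison.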

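If we look at our optimal model we see it has the following parameters:

- top_n_words: 20,000
- max_len: 200
- output_dim: 32
- layers (hidden layers in classifier): 1
- units (in hidden layers): 32
- learning_rate: 0.0001
- batch_size: 128
- dropout: 0.5
- weight_decay: 0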

# Run directory of the model with the lowest eval_best_loss
best_run <- ls_runs(
  runs_dir = "amazon_runs",
  order = eval_best_loss,
  decreasing = FALSE
) %>%
  slice(1) %>%
  pull(run_dir)

# Open the report for that run
view_run(best_run)
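
As an alternative to the interactive viewer, the winning flag values can be read straight from the runs data frame. The sketch below uses a hypothetical helper, best_flags() (not a tfruns function), which finds the row with the lowest metric and returns its flag_* columns as a named list with the prefix removed. It assumes the column naming shown in the ls_runs() output above.

```r
# Hypothetical helper: flag values of the run with the lowest metric.
best_flags <- function(runs, metric = "eval_best_loss") {
  i <- which.min(runs[[metric]])  # which.min() ignores NA values
  flag_cols <- grep("^flag_", names(runs), value = TRUE)
  flags <- lapply(flag_cols, function(nm) runs[[nm]][[i]])
  # Drop the "flag_" prefix so names match the flags used in the script.
  names(flags) <- sub("^flag_", "", flag_cols)
  flags
}

# Only runs if the tuning results exist in the working directory.
if (dir.exists("amazon_runs")) {
  runs <- tfruns::ls_runs(runs_dir = "amazon_runs")
  str(best_flags(runs))
}
```

A list in this shape has the same names as the entries of grid_search, so it could in principle be passed as the flags argument of a single training run to refit the best configuration.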
