Skip to content

Commit

Permalink
Update default hyper-parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
jjonescz committed Apr 30, 2022
1 parent 707b0e7 commit 12c8739
Showing 1 changed file with 42 additions and 30 deletions.
72 changes: 42 additions & 30 deletions awe/training/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,16 @@ class AttentionNormalization(str, enum.Enum):
vector = 'vector'
softmax = 'softmax'

def _freeze(value):
    """
    Creates a dataclass field whose default is a mutable value.

    Dataclasses refuse mutable defaults (such as `dict` or `list`) when they
    are assigned directly, so the value is supplied through `default_factory`
    instead.

    Each instance receives its own deep copy of `value`, so mutating the field
    on one instance cannot leak into other instances or into the default.

    :param value: the default value of the field.
    :return: result of `dataclasses.field` to be assigned to a class attribute.
    """

    # Local import: only this helper needs `copy`.
    import copy

    return dataclasses.field(default_factory=lambda: copy.deepcopy(value))

@dataclasses.dataclass
class Params:
"""
Expand All @@ -57,14 +67,14 @@ class Params:
Only considered when using the SWDE dataset for now.
"""

label_keys: list[str] = ('name', 'price', 'shortDescription', 'images')
label_keys: list[str] = ()
"""
Set of keys to select from the dataset.
Only considered when using the Apify dataset for now.
"""

train_website_indices: list[int] = (0, 3, 4, 5, 7)
train_website_indices: list[int] = (0, 1, 2, 3, 4)
"""
Indices of websites to put in the training set.
Expand All @@ -74,23 +84,23 @@ class Params:
exclude_websites: list[str] = ()
"""Website names to exclude from loading."""

train_subset: Optional[int] = 2000
train_subset: Optional[int] = 100
"""Number of pages per website to use for training."""

val_subset: Optional[int] = 50
val_subset: Optional[int] = 5
"""
Number of pages per website to use for validation (evaluation after each
training epoch).
"""

test_subset: Optional[int] = None
test_subset: Optional[int] = 250
"""
Number of pages per website to use for testing (evaluation of each
cross-validation run).
"""

# Trainer
epochs: int = 5
epochs: int = 3
"""
Number of epochs (passes over all training samples) to train the model for.
"""
Expand All @@ -103,13 +113,13 @@ class Params:
restore_num: Optional[int] = None
"""Existing version number to restore."""

batch_size: int = 16
batch_size: int = 32
"""Number of samples to have in a mini-batch during training/evaluation."""

save_every_n_epochs: Optional[int] = 1
save_every_n_epochs: Optional[int] = None
"""How often a checkpoint should be saved."""

save_better_val_loss_checkpoint: bool = True
save_better_val_loss_checkpoint: bool = False
"""
Save checkpoint after each epoch when better validation loss is achieved.
"""
Expand All @@ -124,12 +134,12 @@ class Params:
the storage space is not wasted by saving all checkpoints.
"""

log_every_n_steps: int = 10
log_every_n_steps: int = 100
"""
How often to evaluate and write metrics to TensorBoard during training.
"""

eval_every_n_steps: Optional[int] = 50
eval_every_n_steps: Optional[int] = None
"""
How often to execute full evaluation pass on the validation subset during
training.
Expand All @@ -150,13 +160,13 @@ class Params:
"""

# Sampling
load_visuals: bool = False
load_visuals: bool = True
"""
When loading HTML for pages, also load JSON visuals, parse visual attributes
and attach them to DOM nodes in memory.
"""

classify_only_text_nodes: bool = False
classify_only_text_nodes: bool = True
"""Sample only text fragments."""

classify_only_variable_nodes: bool = False
Expand All @@ -179,10 +189,10 @@ class Params:
validate_data: bool = True
"""Validate sampled page DOMs."""

ignore_invalid_pages: bool = False
ignore_invalid_pages: bool = True
"""Of sampled and validated pages, ignore those that are invalid."""

none_cutoff: Optional[int] = None
none_cutoff: Optional[int] = 30_000
"""
From 0 to 100,000. The higher, the more non-target nodes will be sampled.
"""
Expand All @@ -195,16 +205,16 @@ class Params:
"""Number of friends in the friend cycle."""

# Visual neighbors
visual_neighbors: bool = False
visual_neighbors: bool = True
"""Use visual neighbors as a feature when classifying nodes."""

n_neighbors: int = 4
n_neighbors: int = 10
"""Number of visual neighbors (the closest ones)."""

neighbor_distance: VisualNeighborDistance = VisualNeighborDistance.rect
"""How to determine the closest visual neighbors."""

neighbor_normalize: Optional[AttentionNormalization] = AttentionNormalization.softmax
neighbor_normalize: Optional[AttentionNormalization] = AttentionNormalization.vector
"""
How to normalize neighbor distances before feeding them to the attention
module.
Expand All @@ -217,16 +227,16 @@ class Params:
"""

# Ancestor chain
ancestor_chain: bool = False
ancestor_chain: bool = True
"""Use DOM ancestors as a feature when classifying nodes."""

n_ancestors: Optional[int] = 5
n_ancestors: Optional[int] = None
"""`None` to use all ancestors."""

ancestor_lstm_out_dim: int = 10
"""Output dimension of the LSTM aggregating ancestor features."""

ancestor_lstm_args: Optional[dict[str]] = None
ancestor_lstm_args: Optional[dict[str]] = _freeze({'bidirectional': True})
"""
Additional keyword arguments to the LSTM layer aggregating ancestor
features.
Expand All @@ -243,7 +253,7 @@ class Params:
"""

# Word vectors
tokenizer_family: TokenizerFamily = TokenizerFamily.custom
tokenizer_family: TokenizerFamily = TokenizerFamily.bert
"""Which tokenizer to use."""

tokenizer_id: str = ''
Expand All @@ -266,7 +276,7 @@ class Params:
"""

# HTML attributes
tokenize_node_attrs: list[str] = ()
tokenize_node_attrs: list[str] = ('itemprop',)
"""
DOM attributes to tokenize and use as a feature when classifying nodes and
also as a feature of each ancestor in the ancestor chain.
Expand All @@ -278,7 +288,7 @@ class Params:
"""Use the DOM attribute feature only for nodes of the ancestor chain."""

# LSTM
word_vector_function: Optional[str] = 'sum'
word_vector_function: Optional[str] = 'lstm'
"""
How to aggregate word vectors.
Expand All @@ -288,7 +298,7 @@ class Params:
lstm_dim: int = 100
"""Output dimension of the LSTM aggregating word vectors."""

lstm_args: Optional[dict[str]] = None
lstm_args: Optional[dict[str]] = _freeze({'bidirectional': True})
"""
Additional keyword arguments to the LSTM layer aggregating word vectors.
"""
Expand All @@ -313,7 +323,7 @@ class Params:
"""

# HTML DOM features
tag_name_embedding: bool = False
tag_name_embedding: bool = True
"""
Whether to use HTML tag name as a feature when classifying nodes.
Expand All @@ -323,15 +333,17 @@ class Params:
tag_name_embedding_dim: int = 30
"""Dimension of the output vector of HTML tag name embedding."""

position: bool = False
position: bool = True
"""
Whether to use visual position as a feature when classifying nodes.
See `awe.features.dom.Position`.
"""

# Visual features
enabled_visuals: Optional[list[str]] = None
enabled_visuals: Optional[list[str]] = (
"font_size", "font_style", "font_weight", "font_color"
)
"""
Filter visual attributes to only those in this list.
Expand All @@ -358,10 +370,10 @@ class Params:
layer_norm: bool = False
"""Use layer normalization in the classification head."""

head_dims: list[int] = (128, 64)
head_dims: list[int] = (100, 10)
"""Dimensions of feed-forward layers in the classification head."""

head_dropout: float = 0.5
head_dropout: float = 0.3
"""Dropout probability in the classification head."""

gradient_clipping: Optional[float] = None
Expand Down

0 comments on commit 12c8739

Please sign in to comment.