Disable jitter noise during evaluation in SwitchTransformers (huggingface#28077)

* Disable jitter noise during evaluation

* Update outdated configuration information

* Formatting

* Add new line
DaizeDong authored Dec 18, 2023
1 parent a0522de commit 7c5408d
Showing 3 changed files with 10 additions and 12 deletions.
@@ -187,7 +187,7 @@ def _compute_router_probabilities(self, hidden_states: torch.Tensor) -> Tuple[to
         self.input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(self.dtype)

-        if self.jitter_noise > 0:
+        if self.training and self.jitter_noise > 0:
             # Multiply the token inputs by the uniform distribution - adding some noise
             hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)

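For context, here is a minimal standalone sketch of the behavior this hunk introduces. The `JitteredRouterInput` module and its default noise value are illustrative, not the actual `SwitchTransformersTop1Router`: the multiplicative jitter noise on the router inputs is applied only while the module is in training mode, so `model.eval()` yields deterministic routing.

```python
import torch
import torch.nn as nn


class JitteredRouterInput(nn.Module):
    """Illustrative module: perturb router inputs with jitter noise only in training mode."""

    def __init__(self, jitter_noise: float = 0.01):
        super().__init__()
        self.jitter_noise = jitter_noise

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # nn.Module.training is True after model.train() and False after model.eval(),
        # so the multiplicative noise below is skipped entirely during evaluation.
        if self.training and self.jitter_noise > 0:
            hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_(
                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
            )
        return hidden_states


router_input = JitteredRouterInput()
x = torch.randn(2, 4, 8)
router_input.eval()
assert torch.equal(router_input(x), x)  # no noise at evaluation time
```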
@@ -38,7 +38,7 @@ class SwitchTransformersConfig(PretrainedConfig):
         vocab_size (`int`, *optional*, defaults to 32128):
             Vocabulary size of the SwitchTransformers model. Defines the number of different tokens that can be
             represented by the `inputs_ids` passed when calling [`SwitchTransformersModel`].
-        d_model (`int`, *optional*, defaults to 512):
+        d_model (`int`, *optional*, defaults to 768):
             Size of the encoder layers and the pooler layer.
         d_kv (`int`, *optional*, defaults to 64):
             Size of the key, query, value projections per attention head. `d_kv` has to be equal to `d_model //
@@ -50,21 +50,19 @@ class SwitchTransformersConfig(PretrainedConfig):
             Transformer.
         num_layers (`int`, *optional*, defaults to 12):
             Number of dense hidden layers in the Transformer encoder layer.
-        num_sparse_encoder_layers (`int`, *optional*, defaults to 6):
+        num_sparse_encoder_layers (`int`, *optional*, defaults to 3):
             Number of sparse (MoE) dense hidden layers in the Transformer encoder layer.
         num_decoder_layers (`int`, *optional*, defaults to 12):
             Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
-        num_sparse_decoder_layers (`int`, *optional*, defaults to 12):
+        num_sparse_decoder_layers (`int`, *optional*, defaults to 3):
             Number of sparse (MoE) dense hidden layers in the Transformer decoder layer.
-        num_heads (`int`, *optional*, defaults to 8):
+        num_heads (`int`, *optional*, defaults to 12):
             Number of attention heads for each attention layer in the Transformer encoder.
         num_experts (`int`, *optional*, defaults to 8):
             Number of experts for each SwitchTransformer layer.
-        router_type (`str`, *optional*, defaults to `"tokens_masked"`):
-            Router type - choose between `"tokens_masked"`, `"tokens_scatter"` and `"experts_masked"`.
-        router_bias (`bool`, *optional*, defaults to `True`):
+        router_bias (`bool`, *optional*, defaults to `False`):
             Whether to add a bias to the router.
-        router_jitter_noise (`float`, *optional*, defaults to 0.1):
+        router_jitter_noise (`float`, *optional*, defaults to 0.01):
             Amount of noise to add to the router.
         router_dtype (`str`, *optional*, default to `"float32"`):
             The `dtype` used for the routers. It is preferable to keep the `dtype` to `"float32"` as specified in the
@@ -83,10 +81,10 @@ class SwitchTransformersConfig(PretrainedConfig):
             The z loss factor for the total loss.
         router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
             The aux loss factor for the total loss.
-        initializer_factor (`float`, *optional*, defaults to 1):
+        initializer_factor (`float`, *optional*, defaults to 1.0):
             A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
             testing).
-        feed_forward_proj (`string`, *optional*, defaults to `"relu"`):
+        dense_act_fn (`string`, *optional*, defaults to `"relu"`):
             Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. SwitchTransformersv1.1
             uses the `"gated-gelu"` feed forward projection. Original SwitchTransformers uses `"relu"`.
         add_router_probs (`bool`, *optional*, defaults to `False`):
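As a quick sanity check on the documentation fixes above, the snippet below prints the library defaults so they can be compared against the updated docstring. This assumes a `transformers` version that ships `SwitchTransformersConfig`; the expected values in the comments are taken from the corrected docstring, not re-verified here.

```python
from transformers import SwitchTransformersConfig

config = SwitchTransformersConfig()  # all defaults

# Expected values per the corrected docstring above
print(config.d_model)                    # 768
print(config.num_heads)                  # 12
print(config.num_sparse_encoder_layers)  # 3
print(config.num_sparse_decoder_layers)  # 3
print(config.router_bias)                # False
print(config.router_jitter_noise)        # 0.01
```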
@@ -168,7 +168,7 @@ def _compute_router_probabilities(self, hidden_states: torch.Tensor) -> Tuple[to
         self.input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(self.dtype)

-        if self.jitter_noise > 0:
+        if self.training and self.jitter_noise > 0:
             # Multiply the token inputs by the uniform distribution - adding some noise
             hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)

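And the practical effect of the modeling change: with jitter noise gated on `self.training`, two eval-mode forward passes over the same batch should now produce identical logits. A rough sketch using the public `google/switch-base-8` checkpoint (downloads weights; the `True` at the end is the expected outcome, not a logged result):

```python
import torch
from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8").eval()

inputs = tokenizer("A <extra_id_0> walks into a bar.", return_tensors="pt")
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

with torch.no_grad():
    logits_1 = model(**inputs, decoder_input_ids=decoder_input_ids).logits
    logits_2 = model(**inputs, decoder_input_ids=decoder_input_ids).logits

# Before this fix, the router still injected jitter noise in eval mode, so repeated
# passes could differ; afterwards they should match exactly.
print(torch.equal(logits_1, logits_2))
```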
