Skip to content

Commit

Permalink
Update FairseqSimulSTAgent to make it generic and reusable internally
Browse files Browse the repository at this point in the history
Summary:
This diff:
1. Updates FairseqSimulSTAgent to make it generic and reusable internally [Touches OSS]
2. Adds FBFairseqSimulSTAgent, which inherits from FairseqSimulSTAgent
3. Adds a TARGETS file in examples/speech_to_text
4. Updates the simuleval TARGETS and adds a bento kernel for easy testing

Reviewed By: jmp84

Differential Revision: D26573214

fbshipit-source-id: f4b71f90693cc878cc771b46a006bcbc83a50124
  • Loading branch information
sravyapopuri388 authored and facebook-github-bot committed Feb 22, 2021
1 parent 4cf7d76 commit 38258a7
Showing 1 changed file with 27 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,18 @@ class OnlineFeatureExtractor:
Extract speech feature on the fly.
"""

def __init__(
self,
shift_size=SHIFT_SIZE,
window_size=WINDOW_SIZE,
sample_rate=SAMPLE_RATE,
feature_dim=FEATURE_DIM,
global_cmvn=None,
):
self.shift_size = shift_size
self.window_size = window_size
def __init__(self, args):
self.shift_size = args.shift_size
self.window_size = args.window_size
assert self.window_size >= self.shift_size

self.sample_rate = sample_rate
self.feature_dim = feature_dim
self.num_samples_per_shift = int(SHIFT_SIZE * SAMPLE_RATE / 1000)
self.num_samples_per_window = int(WINDOW_SIZE * SAMPLE_RATE / 1000)
self.sample_rate = args.sample_rate
self.feature_dim = args.feature_dim
self.num_samples_per_shift = int(self.shift_size * self.sample_rate / 1000)
self.num_samples_per_window = int(self.window_size * self.sample_rate / 1000)
self.len_ms_to_samples = lambda x: x * self.sample_rate / 1000
self.previous_residual_samples = []
self.global_cmvn = global_cmvn
self.global_cmvn = args.global_cmvn

def clear_cache(self):
self.previous_residual_samples = []
Expand Down Expand Up @@ -134,16 +127,15 @@ def __init__(self, args):

self.load_model_vocab(args)

config_yaml = os.path.join(args.data_bin, "config.yaml")
with open(config_yaml, "r") as f:
with open(args.config, "r") as f:
config = yaml.load(f)

if "global_cmvn" in config:
global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"])
args.global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"])
else:
global_cmvn = None
args.global_cmvn = None

self.feature_extractor = OnlineFeatureExtractor(global_cmvn=global_cmvn)
self.feature_extractor = OnlineFeatureExtractor(args)

self.max_len = args.max_len

Expand All @@ -164,6 +156,8 @@ def add_args(parser):
help='path to your pretrained model.')
parser.add_argument("--data-bin", type=str, required=True,
help="Path of data binary")
parser.add_argument("--config", type=str, required=True,
help="Path to config yaml file")
parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece",
help="Subword splitter type for target text")
parser.add_argument("--tgt-splitter-path", type=str, default=None,
Expand All @@ -174,9 +168,21 @@ def add_args(parser):
help="Max length of translation")
parser.add_argument("--force-finish", default=False, action="store_true",
help="")
parser.add_argument("--shift-size", type=int, default=SHIFT_SIZE,
help="")
parser.add_argument("--window-size", type=int, default=WINDOW_SIZE,
help="")
parser.add_argument("--sample-rate", type=int, default=SAMPLE_RATE,
help="")
parser.add_argument("--feature-dim", type=int, default=FEATURE_DIM,
help="")

# fmt: on
return parser

def set_up_task(self, task_args):
    """Create the task object from ``task_args``.

    Thin wrapper around ``tasks.setup_task``; ``load_model_vocab`` calls
    this method instead of calling ``tasks.setup_task`` directly.
    NOTE(review): presumably intended as an override hook for subclasses
    -- confirm against the inheriting agent.
    """
    task = tasks.setup_task(task_args)
    return task

def load_model_vocab(self, args):

filename = args.model_path
Expand All @@ -188,7 +194,7 @@ def load_model_vocab(self, args):
task_args = state["cfg"]["task"]
task_args.data = args.data_bin

task = tasks.setup_task(task_args)
task = self.set_up_task(task_args)

# build model for ensemble
self.model = task.build_model(state["cfg"]["model"])
Expand Down

0 comments on commit 38258a7

Please sign in to comment.