#!/usr/bin/env python3

import transformers
import torch
import sys
import os

# Positional command-line arguments, in order: device selector, input file
# path, output file path.
device, input_path, output_path = sys.argv[1], sys.argv[2], sys.argv[3]

# The checkpoint to load is named by the MODEL_NAME environment variable.
model_name = os.environ["MODEL_NAME"]
tokenizer = transformers.MarianTokenizer.from_pretrained(model_name)
# eval() hands back the module itself, so it can be chained onto the load.
model = transformers.MarianMTModel.from_pretrained(model_name).eval()
if device == "GPU":
    # Any other device value leaves the model where from_pretrained put it.
    model = model.cuda()

# Number of input lines per generate() call, and the beam width it uses.
batch_size, beam_size = 32, 4


def to_tensor(batch):
    """Right-pad every id list in *batch* to a common length and tensorize it.

    Shorter lists are extended with ``tokenizer.pad_token_id`` until they
    match the longest list; the result is a tensor created on
    ``model.device``.
    """
    width = max(map(len, batch))
    pad_id = tokenizer.pad_token_id
    padded = []
    for ids in batch:
        missing = width - len(ids)
        padded.append(ids + [pad_id] * missing)
    return torch.tensor(padded, device=model.device)


def batch_iter(lines):
    """Yield padded tensors built from groups of up to ``batch_size`` lines.

    Each line is parsed as whitespace-separated integers. A full group is
    emitted as soon as it reaches ``batch_size`` entries; whatever is left
    after the last line is emitted as a final, smaller group.
    """
    pending = []
    for line in lines:
        # split() with no separator already ignores surrounding whitespace.
        pending.append([int(piece) for piece in line.split()])
        if len(pending) != batch_size:
            continue
        yield to_tensor(pending)
        pending = []
    if pending:
        yield to_tensor(pending)


# The special-token ids do not change while the script runs, so read them
# from the tokenizer once instead of once per generated token.
special_ids = {tokenizer.pad_token_id, tokenizer.eos_token_id}

with open(input_path) as input_file, open(output_path, "w") as output_file:
    # Input lines are consumed in batches; one output line is written for
    # every row of ids returned by generate().
    for batch in batch_iter(input_file):
        outputs = model.generate(
            input_ids=batch,
            num_beams=beam_size,
            length_penalty=0,
            early_stopping=True,
        )

        for tokens in outputs.tolist():
            # Drop padding and end-of-sequence ids; everything else is
            # written as space-separated integers.
            kept = [str(token) for token in tokens if token not in special_ids]
            output_file.write(" ".join(kept) + "\n")
