Just spent half an hour "optimizing" my PyTorch code so it would need to do less array formatting per batch. Time per iteration is now the same as it was before the optimization. Woohoo!