From 846f22a03af975946a8796a6c1ff743668480925 Mon Sep 17 00:00:00 2001 From: Michael Gschwind Date: Tue, 10 Nov 2020 15:49:47 -0800 Subject: [PATCH 1/2] Refactor Pytext hierarchy Summary: Refactor Pytext hierarchy Differential Revision: D24816473 fbshipit-source-id: 0d87c879ffbf280713afcfe3c02d2daa7f39bdb1 --- pytext/torchscript/batchutils.py | 226 ++++++++++++- pytext/torchscript/module.py | 544 +++++++++++++++++-------------- 2 files changed, 522 insertions(+), 248 deletions(-) diff --git a/pytext/torchscript/batchutils.py b/pytext/torchscript/batchutils.py index fd2a99ba7..861221570 100644 --- a/pytext/torchscript/batchutils.py +++ b/pytext/torchscript/batchutils.py @@ -1,6 +1,10 @@ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -from typing import List, Tuple + +from typing import Dict, List, Tuple +from pytext.torchscript.tensorizer.tensorizer import ScriptTensorizer + +import torch def max_tokens(per_sentence_tokens: List[List[Tuple[str, int, int]]]) -> int: @@ -12,3 +16,223 @@ def max_tokens(per_sentence_tokens: List[List[Tuple[str, int, int]]]) -> int: sentence_lengths = [len(sentence) for sentence in per_sentence_tokens] return max(sentence_lengths) + + +######################################################################## +# +# utility functions to destructure flat result tensor combining +# cross-request batches and client side +# batches into a cross-request list of +# client-side batch tensors +# + + +def destructure_tensor( + client_batch: List[int], + result_tensor: torch.Tensor, +) -> List[torch.Tensor]: + start = 0 + res_list: List[torch.Tensor] = [] + + for elems in client_batch: + end = start + elems + res_list.append(result_tensor.narrow(0, start, elems)) + start = end + + return res_list + + +def destructure_tensor_list( + client_batch: List[int], + result_tensor_list: List[torch.Tensor], +) -> List[List[torch.Tensor]]: + res_list: List[List[torch.Tensor]] = [] + start = 0 + + for elems in client_batch: + end = start + elems + res_list.append(result_tensor_list[start:end]) + start = end + + return res_list + + +############################################################################ +# +# make_prediction_* () +# utility functions to collect inputs from multiple batch elements into a +# a single cross request batch +# +# make_batch_* () +# utility functions for batch optimizations +# + + +def make_prediction_texts( + batch: List[ + Tuple[ + List[str], # texts + ] + ], +) -> List[str]: + + batchsize = len(batch) + flat_texts: List[str] = [] + + for i in range(batchsize): + batch_element = batch[i][0] + flat_texts.extend(batch_element) + + if len(flat_texts) == 0: + raise RuntimeError("This is not good. Empty request batch.") + + return flat_texts + + +def make_batch_texts( + tensorizer: ScriptTensorizer, + mega_batch: List[ + Tuple[ + List[str], # texts + int, + ] + ], + goals: Dict[str, str], +) -> List[List[Tuple[List[str], int,]]]: # texts + + # The next lines sort all cross-request batch elements by the token length. + # Note that cross-request batch element can in turn be a client batch. + mega_batch_key_list = [ + (max_tokens(tensorizer.tokenize(x[0], None)), n) + for (n, x) in enumerate(mega_batch) + ] + sorted_mega_batch_key_list = sorted(mega_batch_key_list) + sorted_mega_batch = [mega_batch[n] for (key, n) in sorted_mega_batch_key_list] + + # TBD: allow model server to specify batch size in goals dictionary + max_bs: int = 10 + len_mb = len(mega_batch) + num_batches = (len_mb + max_bs - 1) // max_bs + + batch_list: List[ + List[ + Tuple[ + List[str], # texts + int, # position + ] + ] + ] = [] + + start = 0 + + for _i in range(num_batches): + end = min(start + max_bs, len_mb) + batch_list.append(sorted_mega_batch[start:end]) + start = end + + return batch_list + + +# + + +def make_prediction_texts_wdense( + batch: List[ + Tuple[ + List[str], # texts + List[List[float]], # dense + ] + ], +) -> List[str]: + + batchsize = len(batch) + flat_texts: List[str] = [] + + for i in range(batchsize): + batch_element = batch[i][0] + flat_texts.extend(batch_element) + + if len(batch[i][0]) != len(batch[i][1]): + raise RuntimeError( + "This is not good. texts/dense client batch length mismatch" + ) + + if len(flat_texts) == 0: + raise RuntimeError("This is not good. Empty request batch.") + + return flat_texts + + +def make_prediction_wtexts_dense( + batch: List[ + Tuple[ + List[str], # texts + List[List[float]], # dense + ] + ], +) -> List[List[float]]: + + batchsize = len(batch) + flat_dense: List[List[float]] = [] + + for i in range(batchsize): + batch_element = batch[i][1] + flat_dense.extend(batch_element) + + if len(batch[i][0]) != len(batch[i][1]): + raise RuntimeError( + "This is not good. texts/dense client batch length mismatch" + ) + + if len(flat_dense) == 0: + raise RuntimeError("This is not good. Empty request batch.") + + return flat_dense + + +def make_batch_texts_dense( + tensorizer: ScriptTensorizer, + mega_batch: List[ + Tuple[ + List[str], # texts + List[List[float]], # dense + int, + ] + ], + goals: Dict[str, str], +) -> List[List[Tuple[List[str], List[List[float]], int,]]]: # texts # dense + + # The next lines sort all cross-request batch elements by the token length. + # Note that cross-request batch element can in turn be a client batch. + mega_batch_key_list = [ + (max_tokens(tensorizer.tokenize(x[0], None)), n) + for (n, x) in enumerate(mega_batch) + ] + sorted_mega_batch_key_list = sorted(mega_batch_key_list) + sorted_mega_batch = [mega_batch[n] for (key, n) in sorted_mega_batch_key_list] + + # TBD: allow model server to specify batch size in goals dictionary + max_bs: int = 10 + len_mb = len(mega_batch) + num_batches = (len_mb + max_bs - 1) // max_bs + + batch_list: List[ + List[ + Tuple[ + List[str], # texts + int, # position + ] + ] + ] = [] + + start = 0 + + for _i in range(num_batches): + end = min(start + max_bs, len_mb) + batch_list.append(sorted_mega_batch[start:end]) + start = end + + return batch_list + + +# diff --git a/pytext/torchscript/module.py b/pytext/torchscript/module.py index 40d30f80b..52e2a28a0 100644 --- a/pytext/torchscript/module.py +++ b/pytext/torchscript/module.py @@ -3,7 +3,16 @@ from typing import Dict, List, Optional, Tuple import torch -from pytext.torchscript.batchutils import max_tokens +from pytext.torchscript.batchutils import ( + max_tokens, + make_prediction_texts, + make_batch_texts, + make_prediction_texts_wdense, + make_prediction_wtexts_dense, + make_batch_texts_dense, + destructure_tensor, + destructure_tensor_list, +) from pytext.torchscript.tensorizer.normalizer import VectorNormalizer from pytext.torchscript.tensorizer.tensorizer import ScriptTensorizer from pytext.torchscript.utils import ScriptBatchInput, squeeze_1d, squeeze_2d @@ -930,7 +939,7 @@ def set_padding_control(self, dimension: str, control: Optional[List[int]]): self.tensorizer.set_padding_control(dimension, control) @torch.jit.script_method - def forward(self, texts: List[str] = None): + def forward(self, texts: List[str]): inputs: ScriptBatchInput = ScriptBatchInput( texts=resolve_texts(texts, None), tokens=squeeze_2d(None), @@ -940,6 +949,36 @@ def forward(self, texts: List[str] = None): logits = self.model(input_tensors) return self.output_layer(logits) + @torch.jit.script_method + def make_prediction( + self, + batch: List[ + Tuple[ + List[str], # texts + ] + ], + ) -> List[torch.Tensor]: + + flat_result: torch.Tensor = self.forward( + texts=make_prediction_texts(batch), + ) + + return destructure_tensor([len(be[0]) for be in batch], flat_result) + + @torch.jit.script_method + def make_batch( + self, + mega_batch: List[ + Tuple[ + List[str], # texts + int, + ] + ], + goals: Dict[str, str], + ) -> List[List[Tuple[List[str], int,]]]: # texts + + return make_batch_texts(self.tensorizer, mega_batch, goals) + class PyTextModuleWithDense(PyTextModule): def __init__( @@ -973,8 +1012,41 @@ def forward( logits = self.model(input_tensors, dense_tensor) return self.output_layer(logits) + @torch.jit.script_method + def make_prediction( + self, + batch: List[ + Tuple[ + List[str], # texts + List[List[float]], # dense + ] + ], + ) -> List[torch.Tensor]: + + flat_result: torch.Tensor = self.forward( + texts=make_prediction_texts_wdense(batch), + dense_feat=make_prediction_wtexts_dense(batch), + ) + + return destructure_tensor([len(be[0]) for be in batch], flat_result) -class PytextTwoTowerModule(torch.jit.ScriptModule): + @torch.jit.script_method + def make_batch( + self, + mega_batch: List[ + Tuple[ + List[str], # texts + List[List[float]], # dense + int, + ] + ], + goals: Dict[str, str], + ) -> List[List[Tuple[List[str], List[List[float]], int,]]]: # texts # dense + + return make_batch_texts_dense(self.tensorizer, mega_batch, goals) + + +class PyTextTwoTowerBaseModule(torch.jit.ScriptModule): @torch.jit.script_method def set_device(self, device: str): self.right_tensorizer.set_device(device) @@ -990,8 +1062,84 @@ def set_padding_control(self, dimension: str, control: Optional[List[int]]): self.right_tensorizer.set_padding_control(dimension, control) self.left_tensorizer.set_padding_control(dimension, control) + @torch.jit.script_method + def make_prediction( + self, + batch: List[ + Tuple[ + List[str], # right_texts + List[str], # left_texts + ] + ], + ) -> List[torch.Tensor]: + + batchsize = len(batch) + + flat_right_texts: List[str] = [] + flat_left_texts: List[str] = [] + + for i in range(batchsize): + batch_right_element = batch[i][0] + batch_left_element = batch[i][1] + + flat_right_texts.extend(batch_right_element) + flat_left_texts.extend(batch_left_element) + + flat_result: torch.Tensor = self.forward( + right_texts=flat_right_texts, + left_texts=flat_left_texts, + ) + + return destructure_tensor([len(be[0]) for be in batch], flat_result) + + @torch.jit.script_method + def make_batch( + self, + mega_batch: List[ + Tuple[ + List[str], # right_texts + List[str], # left_texts + int, + ] + ], + goals: Dict[str, str], + ) -> List[List[Tuple[List[str], List[str], int,]]]: # right_texts # left_texts + + # The next lines sort all cross-request batch elements by the token length of right_. + # Note that cross-request batch element can in turn be a client batch. + mega_batch_key_list = [ + (max_tokens(self.right_tensorizer.tokenize(x[0], None)), n) + for (n, x) in enumerate(mega_batch) + ] + sorted_mega_batch_key_list = sorted(mega_batch_key_list) + sorted_mega_batch = [mega_batch[n] for (key, n) in sorted_mega_batch_key_list] + + # TBD: allow model server to specify batch size in goals dictionary + max_bs: int = 10 + len_mb = len(mega_batch) + num_batches = (len_mb + max_bs - 1) // max_bs + + batch_list: List[ + List[ + Tuple[ + List[str], # right_texts + List[str], # left_texts + int, # position + ] + ] + ] = [] + + start = 0 + + for _i in range(num_batches): + end = min(start + max_bs, len_mb) + batch_list.append(sorted_mega_batch[start:end]) + start = end + + return batch_list + -class PyTextTwoTowerModule(PytextTwoTowerModule): +class PyTextTwoTowerModule(PyTextTwoTowerBaseModule): def __init__( self, model: torch.jit.ScriptModule, @@ -1044,10 +1192,10 @@ def __init__( @torch.jit.script_method def forward( self, - right_dense_feat: List[List[float]], - left_dense_feat: List[List[float]], right_texts: List[str], left_texts: List[str], + right_dense_feat: List[List[float]], + left_dense_feat: List[List[float]], ): right_inputs: ScriptBatchInput = ScriptBatchInput( texts=resolve_texts(right_texts), @@ -1079,96 +1227,55 @@ def forward( ) return self.output_layer(logits) - -class PyTextEmbeddingModule(ScriptModule): - def __init__(self, model: torch.jit.ScriptModule, tensorizer: ScriptTensorizer): - super().__init__() - self.model = model - self.tensorizer = tensorizer - self.argno = -1 - log_class_usage(self.__class__) - - @torch.jit.script_method - def set_padding_control(self, dimension: str, control: Optional[List[int]]): - """ - This functions will be called to set a padding style. - None - No padding - List: first element 0, round seq length to the smallest list element larger than inputs - """ - self.tensorizer.set_padding_control(dimension, control) - - @torch.jit.script_method - def _forward(self, inputs: ScriptBatchInput): - input_tensors = self.tensorizer(inputs) - return self.model(input_tensors).cpu() - - @torch.jit.script_method - def forward( - self, - texts: List[str], - dense_feat: Optional[List[List[float]]] = None, - ) -> torch.Tensor: - inputs: ScriptBatchInput = ScriptBatchInput( - texts=resolve_texts(texts, None), - tokens=squeeze_2d(None), - languages=squeeze_1d(None), - ) - return self._forward(inputs) - @torch.jit.script_method def make_prediction( self, batch: List[ Tuple[ - List[str], # texts - Optional[List[List[float]]], # dense_feat + List[str], # right_texts + List[str], # left_texts + List[List[float]], # right_dense_feat + List[List[float]], # left_dense_feat ] ], ) -> List[torch.Tensor]: batchsize = len(batch) - client_batch: List[int] = [] - res_list: List[torch.Tensor] = [] - - flat_texts: List[str] = [] + flat_right_texts: List[str] = [] + flat_left_texts: List[str] = [] + flat_right_dense: List[List[float]] = [] + flat_left_dense: List[List[float]] = [] for i in range(batchsize): - batch_element = batch[i][0] - flat_texts.extend(batch_element) - client_batch.append(len(batch_element)) - - if batch[i][1] is not None: - # Cross-request batching not yet supported for requests with dense_feat - raise RuntimeError("Malformed request.") + batch_right_element = batch[i][0] + batch_left_element = batch[i][1] + batch_right_dense_element = batch[i][2] + batch_left_dense_element = batch[i][3] - if len(flat_texts) == 0: - raise RuntimeError("This is not good. Empty request batch.") + flat_right_texts.extend(batch_right_element) + flat_left_texts.extend(batch_left_element) + flat_right_dense.extend(batch_right_dense_element) + flat_left_dense.extend(batch_left_dense_element) flat_result: torch.Tensor = self.forward( - texts=flat_texts, - dense_feat=None, + right_texts=flat_right_texts, + left_texts=flat_left_texts, + right_dense_feat=flat_right_dense, + left_dense_feat=flat_left_dense, ) - # destructure flat result tensor combining - # cross-request batches and client side - # batches into a cross-request list of - # client-side batch tensors - start = 0 - for elems in client_batch: - end = start + elems - res_list.append(flat_result.narrow(0, start, elems)) - start = end - - return res_list + return destructure_tensor([len(be[0]) for be in batch], flat_result) @torch.jit.script_method def make_batch( self, mega_batch: List[ Tuple[ - List[str], # texts - Optional[List[List[float]]], # dense_feat + List[str], # right_texts + List[str], # left_texts + List[List[float]], # right_dense_feat + List[List[float]], # left_dense_feat int, ] ], @@ -1176,17 +1283,19 @@ def make_batch( ) -> List[ List[ Tuple[ - List[str], # texts - Optional[List[List[float]]], # dense_feat + List[str], # right_texts + List[str], # left_texts + List[List[float]], # right_dense_feat + List[List[float]], # left_dense_feat int, ] ] ]: - # The next lines sort all cross-request batch elements by the token length. + # The next lines sort all cross-request batch elements by the token length of right_. # Note that cross-request batch element can in turn be a client batch. mega_batch_key_list = [ - (max_tokens(self.tensorizer.tokenize(x[0], None)), n) + (max_tokens(self.right_tensorizer.tokenize(x[0], None)), n) for (n, x) in enumerate(mega_batch) ] sorted_mega_batch_key_list = sorted(mega_batch_key_list) @@ -1200,8 +1309,10 @@ def make_batch( batch_list: List[ List[ Tuple[ - List[str], # texts - Optional[List[List[float]]], # dense_feat + List[str], # right_texts + List[str], # left_texts + List[List[float]], # right_dense_feat + List[List[float]], # left_dense_feat int, # position ] ] @@ -1217,6 +1328,73 @@ def make_batch( return batch_list +class PyTextEmbeddingModule(ScriptModule): + def __init__(self, model: torch.jit.ScriptModule, tensorizer: ScriptTensorizer): + super().__init__() + self.model = model + self.tensorizer = tensorizer + self.argno = -1 + log_class_usage(self.__class__) + + @torch.jit.script_method + def set_padding_control(self, dimension: str, control: Optional[List[int]]): + """ + This functions will be called to set a padding style. + None - No padding + List: first element 0, round seq length to the smallest list element larger than inputs + """ + self.tensorizer.set_padding_control(dimension, control) + + @torch.jit.script_method + def _forward(self, inputs: ScriptBatchInput): + input_tensors = self.tensorizer(inputs) + return self.model(input_tensors).cpu() + + @torch.jit.script_method + def forward( + self, + texts: List[str], + # the following arg was being ignored in the definition + # dense_feat: Optional[List[List[float]]] = None, + ) -> torch.Tensor: + inputs: ScriptBatchInput = ScriptBatchInput( + texts=resolve_texts(texts, None), + tokens=squeeze_2d(None), + languages=squeeze_1d(None), + ) + return self._forward(inputs) + + @torch.jit.script_method + def make_prediction( + self, + batch: List[ + Tuple[ + List[str], # texts + ] + ], + ) -> List[torch.Tensor]: + + flat_result: torch.Tensor = self.forward( + texts=make_prediction_texts(batch), + ) + + return destructure_tensor([len(be[0]) for be in batch], flat_result) + + @torch.jit.script_method + def make_batch( + self, + mega_batch: List[ + Tuple[ + List[str], # texts + int, + ] + ], + goals: Dict[str, str], + ) -> List[List[Tuple[List[str], int,]]]: # texts + + return make_batch_texts(self.tensorizer, mega_batch, goals) + + class PyTextEmbeddingModuleIndex(PyTextEmbeddingModule): def __init__( self, @@ -1258,10 +1436,8 @@ def _forward(self, inputs: ScriptBatchInput, dense_tensor: torch.Tensor): def forward( self, texts: List[str], - dense_feat: Optional[List[List[float]]] = None, + dense_feat: List[List[float]], ) -> torch.Tensor: - if dense_feat is None: - raise RuntimeError("Expect dense feature.") inputs: ScriptBatchInput = ScriptBatchInput( texts=resolve_texts(texts, None), @@ -1278,6 +1454,39 @@ def forward( else: return sentence_embedding + @torch.jit.script_method + def make_prediction( + self, + batch: List[ + Tuple[ + List[str], # texts + List[List[float]], # dense + ] + ], + ) -> List[torch.Tensor]: + + flat_result: torch.Tensor = self.forward( + texts=make_prediction_texts_wdense(batch), + dense_feat=make_prediction_wtexts_dense(batch), + ) + + return destructure_tensor([len(be[0]) for be in batch], flat_result) + + @torch.jit.script_method + def make_batch( + self, + mega_batch: List[ + Tuple[ + List[str], # texts + List[List[float]], # dense + int, + ] + ], + goals: Dict[str, str], + ) -> List[List[Tuple[List[str], List[List[float]], int,]]]: # texts # dense + + return make_batch_texts_dense(self.tensorizer, mega_batch, goals) + class PyTextEmbeddingModuleWithDenseIndex(PyTextEmbeddingModuleWithDense): def __init__( @@ -1323,7 +1532,8 @@ def _forward(self, inputs: ScriptBatchInput): def forward( self, texts: List[str], - dense_feat: Optional[List[List[float]]] = None, + # the following is ignored by forward implementation. drop it + # dense_feat: Optional[List[List[float]]] = None, ) -> List[torch.Tensor]: inputs: ScriptBatchInput = ScriptBatchInput( texts=resolve_texts(texts, None), @@ -1338,52 +1548,18 @@ def make_prediction( batch: List[ Tuple[ List[str], # texts - Optional[List[List[float]]], # dense_feat ] ], ) -> List[List[torch.Tensor]]: - client_batch: List[int] = [] - res_list: List[List[torch.Tensor]] = [] - - flat_texts: List[str] = [] - - for be in batch: - batch_element = be[0] - if batch_element is not None: - flat_texts.extend(batch_element) - client_batch.append(len(batch_element)) - - if be[1] is not None: - raise RuntimeError( - "Malformed request: desnse_feat not supported for cross-request batching in VE Module" - ) - - if len(flat_texts) == 0: - raise RuntimeError("This is not good. Empty request batch.") - flat_result: List[torch.Tensor] = self.forward( - texts=flat_texts, - multi_texts=None, - tokens=None, - languages=None, - dense_feat=None, + texts=make_prediction_texts(batch), ) - # destructure flat result list combining - # cross-request batches and client side - # batches into a cross-request list of - # client-side batch result lists - start = 0 - for elems in client_batch: - end = start + elems - res_list.append(flat_result[start:end]) - start = end - - return res_list + return destructure_tensor_list([len(be[0]) for be in batch], flat_result) -class PyTextTwoTowerEmbeddingModule(PyTextTwoTowerModule): +class PyTextTwoTowerEmbeddingModule(PyTextTwoTowerBaseModule): def __init__( self, model: torch.jit.ScriptModule, @@ -1394,19 +1570,8 @@ def __init__( self.model = model self.right_tensorizer = right_tensorizer self.left_tensorizer = left_tensorizer - self.argno = -1 log_class_usage(self.__class__) - @torch.jit.script_method - def set_padding_control(self, dimension: str, control: Optional[List[int]]): - """ - This functions will be called to set a padding style. - None - No padding - List: first element 0, round seq length to the smallest list element larger than inputs - """ - self.right_tensorizer.set_padding_control(dimension, control) - self.left_tensorizer.set_padding_control(dimension, control) - @torch.jit.script_method def _forward(self, right_inputs: ScriptBatchInput, left_inputs: ScriptBatchInput): right_input_tensors = self.right_tensorizer(right_inputs) @@ -1432,119 +1597,6 @@ def forward( ) return self._forward(right_inputs, left_inputs) - @torch.jit.script_method - def make_prediction( - self, - batch: List[ - Tuple[ - List[str], # right_texts - List[str], # left_texts - Optional[List[List[float]]], # right_dense_feat - Optional[List[List[float]]], # left_dense_feat - ] - ], - ) -> List[torch.Tensor]: - - batchsize = len(batch) - - client_batch: List[int] = [] - res_list: List[torch.Tensor] = [] - - flat_right_texts: List[str] = [] - flat_left_texts: List[str] = [] - - for i in range(batchsize): - batch_right_element = batch[i][0] - batch_left_element = batch[i][1] - - flat_right_texts.extend(batch_right_element) - flat_left_texts.extend(batch_left_element) - client_batch.append(len(batch_right_element)) - - if batch[i][2] is not None or batch[i][3] is not None: - raise RuntimeError( - "Cross-request batching not supported for desne_feat" - ) - - flat_result: torch.Tensor = self.forward( - right_texts=flat_right_texts, - left_texts=flat_left_texts, - right_dense_feat=None, - left_dense_feat=None, - ) - - # destructure flat result tensor combining - # cross-request batches and client side - # batches into a cross-request list of - # client-side batch tensors - start = 0 - for elems in client_batch: - end = start + elems - res_list.append(flat_result.narrow(0, start, elems)) - start = end - - return res_list - - @torch.jit.script_method - def make_batch( - self, - mega_batch: List[ - Tuple[ - List[str], # right_texts - List[str], # left_texts - Optional[List[List[float]]], # right_dense_feat - Optional[List[List[float]]], # left_dense_feat - int, - ] - ], - goals: Dict[str, str], - ) -> List[ - List[ - Tuple[ - List[str], # right_texts - List[str], # left_texts - Optional[List[List[float]]], # right_dense_feat - Optional[List[List[float]]], # left_dense_feat - int, - ] - ] - ]: - - # The next lines sort all cross-request batch elements by the token length of right_. - # Note that cross-request batch element can in turn be a client batch. - mega_batch_key_list = [ - (max_tokens(self.right_tensorizer.tokenize(x[0], None)), n) - for (n, x) in enumerate(mega_batch) - ] - sorted_mega_batch_key_list = sorted(mega_batch_key_list) - sorted_mega_batch = [mega_batch[n] for (key, n) in sorted_mega_batch_key_list] - - # TBD: allow model server to specify batch size in goals dictionary - max_bs: int = 10 - len_mb = len(mega_batch) - num_batches = (len_mb + max_bs - 1) // max_bs - - batch_list: List[ - List[ - Tuple[ - List[str], # right_texts - List[str], # left_texts - Optional[List[List[float]]], # right_dense_feat - Optional[List[List[float]]], # left_dense_feat - int, # position - ] - ] - ] = [] - - start = 0 - - for _i in range(num_batches): - end = min(start + max_bs, len_mb) - batch_list.append(sorted_mega_batch[start:end]) - start = end - - return batch_list - class PyTextTwoTowerEmbeddingModuleWithDense(PyTextTwoTowerEmbeddingModule): def __init__( @@ -1588,11 +1640,9 @@ def forward( self, right_texts: List[str], left_texts: List[str], - right_dense_feat: Optional[List[List[float]]] = None, - left_dense_feat: Optional[List[List[float]]] = None, + right_dense_feat: List[List[float]], + left_dense_feat: List[List[float]], ) -> torch.Tensor: - if right_dense_feat is None or left_dense_feat is None: - raise RuntimeError("Expect dense feature.") right_inputs: ScriptBatchInput = ScriptBatchInput( texts=resolve_texts(right_texts), From a22a4318259e60ca811cb44612012c6d1f0eeda1 Mon Sep 17 00:00:00 2001 From: Michael Gschwind Date: Tue, 10 Nov 2020 15:50:31 -0800 Subject: [PATCH 2/2] Reorder Pytext classes Summary: Changed order of class definition for pytext classes. No functional change, just moving text around Reviewed By: gunchu Differential Revision: D24868080 fbshipit-source-id: d6d5e33669952869b325708251fcdbbccb717fa9 --- pytext/torchscript/module.py | 704 ++++++++++++++++++----------------- 1 file changed, 355 insertions(+), 349 deletions(-) diff --git a/pytext/torchscript/module.py b/pytext/torchscript/module.py index 52e2a28a0..cef085135 100644 --- a/pytext/torchscript/module.py +++ b/pytext/torchscript/module.py @@ -916,6 +916,75 @@ def forward( # * Sequence length and batch size padding for accelerators # +############################################ +# Pytext Classes: + + +class PyTextEmbeddingModule(ScriptModule): + def __init__(self, model: torch.jit.ScriptModule, tensorizer: ScriptTensorizer): + super().__init__() + self.model = model + self.tensorizer = tensorizer + log_class_usage(self.__class__) + + @torch.jit.script_method + def set_padding_control(self, dimension: str, control: Optional[List[int]]): + """ + This functions will be called to set a padding style. + None - No padding + List: first element 0, round seq length to the smallest list element larger than inputs + """ + self.tensorizer.set_padding_control(dimension, control) + + @torch.jit.script_method + def _forward(self, inputs: ScriptBatchInput): + input_tensors = self.tensorizer(inputs) + return self.model(input_tensors).cpu() + + @torch.jit.script_method + def forward( + self, + texts: List[str], + # the following arg was being ignored in the definition + # dense_feat: Optional[List[List[float]]] = None, + ) -> torch.Tensor: + inputs: ScriptBatchInput = ScriptBatchInput( + texts=resolve_texts(texts, None), + tokens=squeeze_2d(None), + languages=squeeze_1d(None), + ) + return self._forward(inputs) + + @torch.jit.script_method + def make_prediction( + self, + batch: List[ + Tuple[ + List[str], # texts + ] + ], + ) -> List[torch.Tensor]: + + flat_result: torch.Tensor = self.forward( + texts=make_prediction_texts(batch), + ) + + return destructure_tensor([len(be[0]) for be in batch], flat_result) + + @torch.jit.script_method + def make_batch( + self, + mega_batch: List[ + Tuple[ + List[str], # texts + int, + ] + ], + goals: Dict[str, str], + ) -> List[List[Tuple[List[str], int,]]]: # texts + + return make_batch_texts(self.tensorizer, mega_batch, goals) + class PyTextModule(ScriptModule): def __init__( @@ -980,6 +1049,99 @@ def make_batch( return make_batch_texts(self.tensorizer, mega_batch, goals) +class PyTextEmbeddingModuleIndex(PyTextEmbeddingModule): + def __init__( + self, + model: torch.jit.ScriptModule, + tensorizer: ScriptTensorizer, + index: int = 0, + ): + super().__init__(model, tensorizer) + self.index = torch.jit.Attribute(index, int) + log_class_usage(self.__class__) + + @torch.jit.script_method + def _forward(self, inputs: ScriptBatchInput): + input_tensors = self.tensorizer(inputs) + return self.model(input_tensors)[self.index].cpu() + + +class PyTextEmbeddingModuleWithDense(PyTextEmbeddingModule): + def __init__( + self, + model: torch.jit.ScriptModule, + tensorizer: ScriptTensorizer, + normalizer: VectorNormalizer, + concat_dense: bool = False, + ): + super().__init__(model, tensorizer) + self.normalizer = normalizer + self.concat_dense = torch.jit.Attribute(concat_dense, bool) + log_class_usage(self.__class__) + + @torch.jit.script_method + def _forward(self, inputs: ScriptBatchInput, dense_tensor: torch.Tensor): + input_tensors = self.tensorizer(inputs) + if self.tensorizer.device != "": + dense_tensor = dense_tensor.to(self.tensorizer.device) + return self.model(input_tensors, dense_tensor).cpu() + + @torch.jit.script_method + def forward( + self, + texts: List[str], + dense_feat: List[List[float]], + ) -> torch.Tensor: + + inputs: ScriptBatchInput = ScriptBatchInput( + texts=resolve_texts(texts, None), + tokens=squeeze_2d(None), + languages=squeeze_1d(None), + ) + # call model + dense_feat = self.normalizer.normalize(dense_feat) + dense_tensor = torch.tensor(dense_feat, dtype=torch.float) + + sentence_embedding = self._forward(inputs, dense_tensor) + if self.concat_dense: + return torch.cat([sentence_embedding, dense_tensor], 1) + else: + return sentence_embedding + + @torch.jit.script_method + def make_prediction( + self, + batch: List[ + Tuple[ + List[str], # texts + List[List[float]], # dense + ] + ], + ) -> List[torch.Tensor]: + + flat_result: torch.Tensor = self.forward( + texts=make_prediction_texts_wdense(batch), + dense_feat=make_prediction_wtexts_dense(batch), + ) + + return destructure_tensor([len(be[0]) for be in batch], flat_result) + + @torch.jit.script_method + def make_batch( + self, + mega_batch: List[ + Tuple[ + List[str], # texts + List[List[float]], # dense + int, + ] + ], + goals: Dict[str, str], + ) -> List[List[Tuple[List[str], List[List[float]], int,]]]: # texts # dense + + return make_batch_texts_dense(self.tensorizer, mega_batch, goals) + + class PyTextModuleWithDense(PyTextModule): def __init__( self, @@ -1046,6 +1208,81 @@ def make_batch( return make_batch_texts_dense(self.tensorizer, mega_batch, goals) +class PyTextEmbeddingModuleWithDenseIndex(PyTextEmbeddingModuleWithDense): + def __init__( + self, + model: torch.jit.ScriptModule, + tensorizer: ScriptTensorizer, + normalizer: VectorNormalizer, + index: int = 0, + concat_dense: bool = True, + ): + super().__init__(model, tensorizer, normalizer, concat_dense) + self.index = torch.jit.Attribute(index, int) + log_class_usage(self.__class__) + + @torch.jit.script_method + def _forward(self, inputs: ScriptBatchInput, dense_tensor: torch.Tensor): + input_tensors = self.tensorizer(inputs) + if self.tensorizer.device != "": + dense_tensor = dense_tensor.to(self.tensorizer.device) + return self.model(input_tensors, dense_tensor)[self.index].cpu() + + +class PyTextVariableSizeEmbeddingModule(PyTextEmbeddingModule): + """ + Assumes model returns a tuple of representations and sequence lengths, then slices + each example's representation according to length. Returns a list of tensors. The + slicing is easier to do outside a traced model. + """ + + def __init__(self, model: torch.jit.ScriptModule, tensorizer: ScriptTensorizer): + super().__init__(model, tensorizer) + log_class_usage(self.__class__) + + @torch.jit.script_method + def _forward(self, inputs: ScriptBatchInput): + input_tensors = self.tensorizer(inputs) + reps, seq_lens = self.model(input_tensors) + reps = reps.cpu() + seq_lens = seq_lens.cpu() + return [reps[i, : seq_lens[i]] for i in range(len(seq_lens))] + + @torch.jit.script_method + def forward( + self, + texts: List[str], + # the following is ignored by forward implementation. drop it + # dense_feat: Optional[List[List[float]]] = None, + ) -> List[torch.Tensor]: + inputs: ScriptBatchInput = ScriptBatchInput( + texts=resolve_texts(texts, None), + tokens=squeeze_2d(None), + languages=squeeze_1d(None), + ) + return self._forward(inputs) + + @torch.jit.script_method + def make_prediction( + self, + batch: List[ + Tuple[ + List[str], # texts + ] + ], + ) -> List[List[torch.Tensor]]: + + flat_result: List[torch.Tensor] = self.forward( + texts=make_prediction_texts(batch), + ) + + return destructure_tensor_list([len(be[0]) for be in batch], flat_result) + + +############################################ +# PytextTwoTower Classes: + + class PyTextTwoTowerBaseModule(torch.jit.ScriptModule): @torch.jit.script_method def set_device(self, device: str): @@ -1139,51 +1376,158 @@ def make_batch( return batch_list -class PyTextTwoTowerModule(PyTextTwoTowerBaseModule): +class PyTextTwoTowerEmbeddingModule(PyTextTwoTowerBaseModule): def __init__( self, model: torch.jit.ScriptModule, - output_layer: torch.jit.ScriptModule, right_tensorizer: ScriptTensorizer, left_tensorizer: ScriptTensorizer, ): super().__init__() self.model = model - self.output_layer = output_layer self.right_tensorizer = right_tensorizer self.left_tensorizer = left_tensorizer + log_class_usage(self.__class__) + + @torch.jit.script_method + def _forward(self, right_inputs: ScriptBatchInput, left_inputs: ScriptBatchInput): + right_input_tensors = self.right_tensorizer(right_inputs) + left_input_tensors = self.left_tensorizer(left_inputs) + + return self.model(right_input_tensors, left_input_tensors).cpu() @torch.jit.script_method def forward( self, right_texts: List[str], left_texts: List[str], - ): + ) -> torch.Tensor: right_inputs: ScriptBatchInput = ScriptBatchInput( texts=resolve_texts(right_texts), tokens=squeeze_2d(None), languages=squeeze_1d(None), ) - right_input_tensors = self.right_tensorizer(right_inputs) left_inputs: ScriptBatchInput = ScriptBatchInput( texts=resolve_texts(left_texts), tokens=squeeze_2d(None), languages=squeeze_1d(None), ) - left_input_tensors = self.left_tensorizer(left_inputs) - logits = self.model(right_input_tensors, left_input_tensors) - return self.output_layer(logits) + return self._forward(right_inputs, left_inputs) -class PyTextTwoTowerModuleWithDense(PyTextTwoTowerModule): +class PyTextTwoTowerModule(PyTextTwoTowerBaseModule): def __init__( self, model: torch.jit.ScriptModule, output_layer: torch.jit.ScriptModule, right_tensorizer: ScriptTensorizer, left_tensorizer: ScriptTensorizer, - right_normalizer: VectorNormalizer, - left_normalizer: VectorNormalizer, + ): + super().__init__() + self.model = model + self.output_layer = output_layer + self.right_tensorizer = right_tensorizer + self.left_tensorizer = left_tensorizer + + @torch.jit.script_method + def forward( + self, + right_texts: List[str], + left_texts: List[str], + ): + right_inputs: ScriptBatchInput = ScriptBatchInput( + texts=resolve_texts(right_texts), + tokens=squeeze_2d(None), + languages=squeeze_1d(None), + ) + right_input_tensors = self.right_tensorizer(right_inputs) + left_inputs: ScriptBatchInput = ScriptBatchInput( + texts=resolve_texts(left_texts), + tokens=squeeze_2d(None), + languages=squeeze_1d(None), + ) + left_input_tensors = self.left_tensorizer(left_inputs) + logits = self.model(right_input_tensors, left_input_tensors) + return self.output_layer(logits) + + +class PyTextTwoTowerEmbeddingModuleWithDense(PyTextTwoTowerEmbeddingModule): + def __init__( + self, + model: torch.jit.ScriptModule, + right_tensorizer: ScriptTensorizer, + left_tensorizer: ScriptTensorizer, + right_normalizer: VectorNormalizer, + left_normalizer: VectorNormalizer, + ): + super().__init__(model, right_tensorizer, left_tensorizer) + self.right_normalizer = right_normalizer + self.left_normalizer = left_normalizer + log_class_usage(self.__class__) + + @torch.jit.script_method + def _forward( + self, + right_inputs: ScriptBatchInput, + left_inputs: ScriptBatchInput, + right_dense_tensor: torch.Tensor, + left_dense_tensor: torch.Tensor, + ): + right_input_tensors = self.right_tensorizer(right_inputs) + left_input_tensors = self.left_tensorizer(left_inputs) + + if self.right_tensorizer.device != "": + right_dense_tensor = right_dense_tensor.to(self.right_tensorizer.device) + if self.left_tensorizer.device != "": + left_dense_tensor = left_dense_tensor.to(self.left_tensorizer.device) + + return self.model( + right_input_tensors, + left_input_tensors, + right_dense_tensor, + left_dense_tensor, + ).cpu() + + @torch.jit.script_method + def forward( + self, + right_texts: List[str], + left_texts: List[str], + right_dense_feat: List[List[float]], + left_dense_feat: List[List[float]], + ) -> torch.Tensor: + + right_inputs: ScriptBatchInput = ScriptBatchInput( + texts=resolve_texts(right_texts), + tokens=squeeze_2d(None), + languages=squeeze_1d(None), + ) + left_inputs: ScriptBatchInput = ScriptBatchInput( + texts=resolve_texts(left_texts), + tokens=squeeze_2d(None), + languages=squeeze_1d(None), + ) + + right_dense_feat = self.right_normalizer.normalize(right_dense_feat) + left_dense_feat = self.left_normalizer.normalize(left_dense_feat) + right_dense_tensor = torch.tensor(right_dense_feat, dtype=torch.float) + left_dense_tensor = torch.tensor(left_dense_feat, dtype=torch.float) + + sentence_embedding = self._forward( + right_inputs, left_inputs, right_dense_tensor, left_dense_tensor + ) + return sentence_embedding + + +class PyTextTwoTowerModuleWithDense(PyTextTwoTowerModule): + def __init__( + self, + model: torch.jit.ScriptModule, + output_layer: torch.jit.ScriptModule, + right_tensorizer: ScriptTensorizer, + left_tensorizer: ScriptTensorizer, + right_normalizer: VectorNormalizer, + left_normalizer: VectorNormalizer, ): super().__init__(model, output_layer, right_tensorizer, left_tensorizer) self.right_normalizer = right_normalizer @@ -1326,341 +1670,3 @@ def make_batch( start = end return batch_list - - -class PyTextEmbeddingModule(ScriptModule): - def __init__(self, model: torch.jit.ScriptModule, tensorizer: ScriptTensorizer): - super().__init__() - self.model = model - self.tensorizer = tensorizer - self.argno = -1 - log_class_usage(self.__class__) - - @torch.jit.script_method - def set_padding_control(self, dimension: str, control: Optional[List[int]]): - """ - This functions will be called to set a padding style. - None - No padding - List: first element 0, round seq length to the smallest list element larger than inputs - """ - self.tensorizer.set_padding_control(dimension, control) - - @torch.jit.script_method - def _forward(self, inputs: ScriptBatchInput): - input_tensors = self.tensorizer(inputs) - return self.model(input_tensors).cpu() - - @torch.jit.script_method - def forward( - self, - texts: List[str], - # the following arg was being ignored in the definition - # dense_feat: Optional[List[List[float]]] = None, - ) -> torch.Tensor: - inputs: ScriptBatchInput = ScriptBatchInput( - texts=resolve_texts(texts, None), - tokens=squeeze_2d(None), - languages=squeeze_1d(None), - ) - return self._forward(inputs) - - @torch.jit.script_method - def make_prediction( - self, - batch: List[ - Tuple[ - List[str], # texts - ] - ], - ) -> List[torch.Tensor]: - - flat_result: torch.Tensor = self.forward( - texts=make_prediction_texts(batch), - ) - - return destructure_tensor([len(be[0]) for be in batch], flat_result) - - @torch.jit.script_method - def make_batch( - self, - mega_batch: List[ - Tuple[ - List[str], # texts - int, - ] - ], - goals: Dict[str, str], - ) -> List[List[Tuple[List[str], int,]]]: # texts - - return make_batch_texts(self.tensorizer, mega_batch, goals) - - -class PyTextEmbeddingModuleIndex(PyTextEmbeddingModule): - def __init__( - self, - model: torch.jit.ScriptModule, - tensorizer: ScriptTensorizer, - index: int = 0, - ): - super().__init__(model, tensorizer) - self.index = torch.jit.Attribute(index, int) - log_class_usage(self.__class__) - - @torch.jit.script_method - def _forward(self, inputs: ScriptBatchInput): - input_tensors = self.tensorizer(inputs) - return self.model(input_tensors)[self.index].cpu() - - -class PyTextEmbeddingModuleWithDense(PyTextEmbeddingModule): - def __init__( - self, - model: torch.jit.ScriptModule, - tensorizer: ScriptTensorizer, - normalizer: VectorNormalizer, - concat_dense: bool = False, - ): - super().__init__(model, tensorizer) - self.normalizer = normalizer - self.concat_dense = torch.jit.Attribute(concat_dense, bool) - log_class_usage(self.__class__) - - @torch.jit.script_method - def _forward(self, inputs: ScriptBatchInput, dense_tensor: torch.Tensor): - input_tensors = self.tensorizer(inputs) - if self.tensorizer.device != "": - dense_tensor = dense_tensor.to(self.tensorizer.device) - return self.model(input_tensors, dense_tensor).cpu() - - @torch.jit.script_method - def forward( - self, - texts: List[str], - dense_feat: List[List[float]], - ) -> torch.Tensor: - - inputs: ScriptBatchInput = ScriptBatchInput( - texts=resolve_texts(texts, None), - tokens=squeeze_2d(None), - languages=squeeze_1d(None), - ) - # call model - dense_feat = self.normalizer.normalize(dense_feat) - dense_tensor = torch.tensor(dense_feat, dtype=torch.float) - - sentence_embedding = self._forward(inputs, dense_tensor) - if self.concat_dense: - return torch.cat([sentence_embedding, dense_tensor], 1) - else: - return sentence_embedding - - @torch.jit.script_method - def make_prediction( - self, - batch: List[ - Tuple[ - List[str], # texts - List[List[float]], # dense - ] - ], - ) -> List[torch.Tensor]: - - flat_result: torch.Tensor = self.forward( - texts=make_prediction_texts_wdense(batch), - dense_feat=make_prediction_wtexts_dense(batch), - ) - - return destructure_tensor([len(be[0]) for be in batch], flat_result) - - @torch.jit.script_method - def make_batch( - self, - mega_batch: List[ - Tuple[ - List[str], # texts - List[List[float]], # dense - int, - ] - ], - goals: Dict[str, str], - ) -> List[List[Tuple[List[str], List[List[float]], int,]]]: # texts # dense - - return make_batch_texts_dense(self.tensorizer, mega_batch, goals) - - -class PyTextEmbeddingModuleWithDenseIndex(PyTextEmbeddingModuleWithDense): - def __init__( - self, - model: torch.jit.ScriptModule, - tensorizer: ScriptTensorizer, - normalizer: VectorNormalizer, - index: int = 0, - concat_dense: bool = True, - ): - super().__init__(model, tensorizer, normalizer, concat_dense) - self.index = torch.jit.Attribute(index, int) - log_class_usage(self.__class__) - - @torch.jit.script_method - def _forward(self, inputs: ScriptBatchInput, dense_tensor: torch.Tensor): - input_tensors = self.tensorizer(inputs) - if self.tensorizer.device != "": - dense_tensor = dense_tensor.to(self.tensorizer.device) - return self.model(input_tensors, dense_tensor)[self.index].cpu() - - -class PyTextVariableSizeEmbeddingModule(PyTextEmbeddingModule): - """ - Assumes model returns a tuple of representations and sequence lengths, then slices - each example's representation according to length. Returns a list of tensors. The - slicing is easier to do outside a traced model. - """ - - def __init__(self, model: torch.jit.ScriptModule, tensorizer: ScriptTensorizer): - super().__init__(model, tensorizer) - log_class_usage(self.__class__) - - @torch.jit.script_method - def _forward(self, inputs: ScriptBatchInput): - input_tensors = self.tensorizer(inputs) - reps, seq_lens = self.model(input_tensors) - reps = reps.cpu() - seq_lens = seq_lens.cpu() - return [reps[i, : seq_lens[i]] for i in range(len(seq_lens))] - - @torch.jit.script_method - def forward( - self, - texts: List[str], - # the following is ignored by forward implementation. drop it - # dense_feat: Optional[List[List[float]]] = None, - ) -> List[torch.Tensor]: - inputs: ScriptBatchInput = ScriptBatchInput( - texts=resolve_texts(texts, None), - tokens=squeeze_2d(None), - languages=squeeze_1d(None), - ) - return self._forward(inputs) - - @torch.jit.script_method - def make_prediction( - self, - batch: List[ - Tuple[ - List[str], # texts - ] - ], - ) -> List[List[torch.Tensor]]: - - flat_result: List[torch.Tensor] = self.forward( - texts=make_prediction_texts(batch), - ) - - return destructure_tensor_list([len(be[0]) for be in batch], flat_result) - - -class PyTextTwoTowerEmbeddingModule(PyTextTwoTowerBaseModule): - def __init__( - self, - model: torch.jit.ScriptModule, - right_tensorizer: ScriptTensorizer, - left_tensorizer: ScriptTensorizer, - ): - super().__init__() - self.model = model - self.right_tensorizer = right_tensorizer - self.left_tensorizer = left_tensorizer - log_class_usage(self.__class__) - - @torch.jit.script_method - def _forward(self, right_inputs: ScriptBatchInput, left_inputs: ScriptBatchInput): - right_input_tensors = self.right_tensorizer(right_inputs) - left_input_tensors = self.left_tensorizer(left_inputs) - - return self.model(right_input_tensors, left_input_tensors).cpu() - - @torch.jit.script_method - def forward( - self, - right_texts: List[str], - left_texts: List[str], - ) -> torch.Tensor: - right_inputs: ScriptBatchInput = ScriptBatchInput( - texts=resolve_texts(right_texts), - tokens=squeeze_2d(None), - languages=squeeze_1d(None), - ) - left_inputs: ScriptBatchInput = ScriptBatchInput( - texts=resolve_texts(left_texts), - tokens=squeeze_2d(None), - languages=squeeze_1d(None), - ) - return self._forward(right_inputs, left_inputs) - - -class PyTextTwoTowerEmbeddingModuleWithDense(PyTextTwoTowerEmbeddingModule): - def __init__( - self, - model: torch.jit.ScriptModule, - right_tensorizer: ScriptTensorizer, - left_tensorizer: ScriptTensorizer, - right_normalizer: VectorNormalizer, - left_normalizer: VectorNormalizer, - ): - super().__init__(model, right_tensorizer, left_tensorizer) - self.right_normalizer = right_normalizer - self.left_normalizer = left_normalizer - log_class_usage(self.__class__) - - @torch.jit.script_method - def _forward( - self, - right_inputs: ScriptBatchInput, - left_inputs: ScriptBatchInput, - right_dense_tensor: torch.Tensor, - left_dense_tensor: torch.Tensor, - ): - right_input_tensors = self.right_tensorizer(right_inputs) - left_input_tensors = self.left_tensorizer(left_inputs) - - if self.right_tensorizer.device != "": - right_dense_tensor = right_dense_tensor.to(self.right_tensorizer.device) - if self.left_tensorizer.device != "": - left_dense_tensor = left_dense_tensor.to(self.left_tensorizer.device) - - return self.model( - right_input_tensors, - left_input_tensors, - right_dense_tensor, - left_dense_tensor, - ).cpu() - - @torch.jit.script_method - def forward( - self, - right_texts: List[str], - left_texts: List[str], - right_dense_feat: List[List[float]], - left_dense_feat: List[List[float]], - ) -> torch.Tensor: - - right_inputs: ScriptBatchInput = ScriptBatchInput( - texts=resolve_texts(right_texts), - tokens=squeeze_2d(None), - languages=squeeze_1d(None), - ) - left_inputs: ScriptBatchInput = ScriptBatchInput( - texts=resolve_texts(left_texts), - tokens=squeeze_2d(None), - languages=squeeze_1d(None), - ) - - right_dense_feat = self.right_normalizer.normalize(right_dense_feat) - left_dense_feat = self.left_normalizer.normalize(left_dense_feat) - right_dense_tensor = torch.tensor(right_dense_feat, dtype=torch.float) - left_dense_tensor = torch.tensor(left_dense_feat, dtype=torch.float) - - sentence_embedding = self._forward( - right_inputs, left_inputs, right_dense_tensor, left_dense_tensor - ) - return sentence_embedding