diff --git a/pytext/torchscript/batchutils.py b/pytext/torchscript/batchutils.py index fd2a99ba7..861221570 100644 --- a/pytext/torchscript/batchutils.py +++ b/pytext/torchscript/batchutils.py @@ -1,6 +1,10 @@ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -from typing import List, Tuple + +from typing import Dict, List, Tuple +from pytext.torchscript.tensorizer.tensorizer import ScriptTensorizer + +import torch def max_tokens(per_sentence_tokens: List[List[Tuple[str, int, int]]]) -> int: @@ -12,3 +16,223 @@ def max_tokens(per_sentence_tokens: List[List[Tuple[str, int, int]]]) -> int: sentence_lengths = [len(sentence) for sentence in per_sentence_tokens] return max(sentence_lengths) + + +######################################################################## +# +# utility functions to destructure flat result tensor combining +# cross-request batches and client side +# batches into a cross-request list of +# client-side batch tensors +# + + +def destructure_tensor( + client_batch: List[int], + result_tensor: torch.Tensor, +) -> List[torch.Tensor]: + start = 0 + res_list: List[torch.Tensor] = [] + + for elems in client_batch: + end = start + elems + res_list.append(result_tensor.narrow(0, start, elems)) + start = end + + return res_list + + +def destructure_tensor_list( + client_batch: List[int], + result_tensor_list: List[torch.Tensor], +) -> List[List[torch.Tensor]]: + res_list: List[List[torch.Tensor]] = [] + start = 0 + + for elems in client_batch: + end = start + elems + res_list.append(result_tensor_list[start:end]) + start = end + + return res_list + + +############################################################################ +# +# make_prediction_* () +# utility functions to collect inputs from multiple batch elements into a +# a single cross request batch +# +# make_batch_* () +# utility functions for batch optimizations +# + + +def make_prediction_texts( + batch: List[ + Tuple[ + List[str], # texts + ] + ], +) -> List[str]: + + batchsize = len(batch) + flat_texts: List[str] = [] + + for i in range(batchsize): + batch_element = batch[i][0] + flat_texts.extend(batch_element) + + if len(flat_texts) == 0: + raise RuntimeError("This is not good. Empty request batch.") + + return flat_texts + + +def make_batch_texts( + tensorizer: ScriptTensorizer, + mega_batch: List[ + Tuple[ + List[str], # texts + int, + ] + ], + goals: Dict[str, str], +) -> List[List[Tuple[List[str], int,]]]: # texts + + # The next lines sort all cross-request batch elements by the token length. + # Note that cross-request batch element can in turn be a client batch. + mega_batch_key_list = [ + (max_tokens(tensorizer.tokenize(x[0], None)), n) + for (n, x) in enumerate(mega_batch) + ] + sorted_mega_batch_key_list = sorted(mega_batch_key_list) + sorted_mega_batch = [mega_batch[n] for (key, n) in sorted_mega_batch_key_list] + + # TBD: allow model server to specify batch size in goals dictionary + max_bs: int = 10 + len_mb = len(mega_batch) + num_batches = (len_mb + max_bs - 1) // max_bs + + batch_list: List[ + List[ + Tuple[ + List[str], # texts + int, # position + ] + ] + ] = [] + + start = 0 + + for _i in range(num_batches): + end = min(start + max_bs, len_mb) + batch_list.append(sorted_mega_batch[start:end]) + start = end + + return batch_list + + +# + + +def make_prediction_texts_wdense( + batch: List[ + Tuple[ + List[str], # texts + List[List[float]], # dense + ] + ], +) -> List[str]: + + batchsize = len(batch) + flat_texts: List[str] = [] + + for i in range(batchsize): + batch_element = batch[i][0] + flat_texts.extend(batch_element) + + if len(batch[i][0]) != len(batch[i][1]): + raise RuntimeError( + "This is not good. texts/dense client batch length mismatch" + ) + + if len(flat_texts) == 0: + raise RuntimeError("This is not good. Empty request batch.") + + return flat_texts + + +def make_prediction_wtexts_dense( + batch: List[ + Tuple[ + List[str], # texts + List[List[float]], # dense + ] + ], +) -> List[List[float]]: + + batchsize = len(batch) + flat_dense: List[List[float]] = [] + + for i in range(batchsize): + batch_element = batch[i][1] + flat_dense.extend(batch_element) + + if len(batch[i][0]) != len(batch[i][1]): + raise RuntimeError( + "This is not good. texts/dense client batch length mismatch" + ) + + if len(flat_dense) == 0: + raise RuntimeError("This is not good. Empty request batch.") + + return flat_dense + + +def make_batch_texts_dense( + tensorizer: ScriptTensorizer, + mega_batch: List[ + Tuple[ + List[str], # texts + List[List[float]], # dense + int, + ] + ], + goals: Dict[str, str], +) -> List[List[Tuple[List[str], List[List[float]], int,]]]: # texts # dense + + # The next lines sort all cross-request batch elements by the token length. + # Note that cross-request batch element can in turn be a client batch. + mega_batch_key_list = [ + (max_tokens(tensorizer.tokenize(x[0], None)), n) + for (n, x) in enumerate(mega_batch) + ] + sorted_mega_batch_key_list = sorted(mega_batch_key_list) + sorted_mega_batch = [mega_batch[n] for (key, n) in sorted_mega_batch_key_list] + + # TBD: allow model server to specify batch size in goals dictionary + max_bs: int = 10 + len_mb = len(mega_batch) + num_batches = (len_mb + max_bs - 1) // max_bs + + batch_list: List[ + List[ + Tuple[ + List[str], # texts + int, # position + ] + ] + ] = [] + + start = 0 + + for _i in range(num_batches): + end = min(start + max_bs, len_mb) + batch_list.append(sorted_mega_batch[start:end]) + start = end + + return batch_list + + +# diff --git a/pytext/torchscript/module.py b/pytext/torchscript/module.py index 40d30f80b..cef085135 100644 --- a/pytext/torchscript/module.py +++ b/pytext/torchscript/module.py @@ -3,7 +3,16 @@ from typing import Dict, List, Optional, Tuple import torch -from pytext.torchscript.batchutils import max_tokens +from pytext.torchscript.batchutils import ( + max_tokens, + make_prediction_texts, + make_batch_texts, + make_prediction_texts_wdense, + make_prediction_wtexts_dense, + make_batch_texts_dense, + destructure_tensor, + destructure_tensor_list, +) from pytext.torchscript.tensorizer.normalizer import VectorNormalizer from pytext.torchscript.tensorizer.tensorizer import ScriptTensorizer from pytext.torchscript.utils import ScriptBatchInput, squeeze_1d, squeeze_2d @@ -907,18 +916,16 @@ def forward( # * Sequence length and batch size padding for accelerators # +############################################ +# Pytext Classes: -class PyTextModule(ScriptModule): - def __init__( - self, - model: torch.jit.ScriptModule, - output_layer: torch.jit.ScriptModule, - tensorizer: ScriptTensorizer, - ): + +class PyTextEmbeddingModule(ScriptModule): + def __init__(self, model: torch.jit.ScriptModule, tensorizer: ScriptTensorizer): super().__init__() self.model = model - self.output_layer = output_layer self.tensorizer = tensorizer + log_class_usage(self.__class__) @torch.jit.script_method def set_padding_control(self, dimension: str, control: Optional[List[int]]): @@ -930,163 +937,66 @@ def set_padding_control(self, dimension: str, control: Optional[List[int]]): self.tensorizer.set_padding_control(dimension, control) @torch.jit.script_method - def forward(self, texts: List[str] = None): - inputs: ScriptBatchInput = ScriptBatchInput( - texts=resolve_texts(texts, None), - tokens=squeeze_2d(None), - languages=squeeze_1d(None), - ) + def _forward(self, inputs: ScriptBatchInput): input_tensors = self.tensorizer(inputs) - logits = self.model(input_tensors) - return self.output_layer(logits) - - -class PyTextModuleWithDense(PyTextModule): - def __init__( - self, - model: torch.jit.ScriptModule, - output_layer: torch.jit.ScriptModule, - tensorizer: ScriptTensorizer, - normalizer: VectorNormalizer, - ): - super().__init__(model, output_layer, tensorizer) - self.normalizer = normalizer - log_class_usage(self.__class__) + return self.model(input_tensors).cpu() @torch.jit.script_method def forward( self, texts: List[str], - dense_feat: List[List[float]], - ): + # the following arg was being ignored in the definition + # dense_feat: Optional[List[List[float]]] = None, + ) -> torch.Tensor: inputs: ScriptBatchInput = ScriptBatchInput( texts=resolve_texts(texts, None), tokens=squeeze_2d(None), languages=squeeze_1d(None), ) - input_tensors = self.tensorizer(inputs) - dense_feat = self.normalizer.normalize(dense_feat) - - dense_tensor = torch.tensor(dense_feat, dtype=torch.float) - if self.tensorizer.device != "": - dense_tensor = dense_tensor.to(self.tensorizer.device) - logits = self.model(input_tensors, dense_tensor) - return self.output_layer(logits) - - -class PytextTwoTowerModule(torch.jit.ScriptModule): - @torch.jit.script_method - def set_device(self, device: str): - self.right_tensorizer.set_device(device) - self.left_tensorizer.set_device(device) + return self._forward(inputs) @torch.jit.script_method - def set_padding_control(self, dimension: str, control: Optional[List[int]]): - """ - This functions will be called to set a padding style. - None - No padding - List: first element 0, round seq length to the smallest list element larger than inputs - """ - self.right_tensorizer.set_padding_control(dimension, control) - self.left_tensorizer.set_padding_control(dimension, control) + def make_prediction( + self, + batch: List[ + Tuple[ + List[str], # texts + ] + ], + ) -> List[torch.Tensor]: + flat_result: torch.Tensor = self.forward( + texts=make_prediction_texts(batch), + ) -class PyTextTwoTowerModule(PytextTwoTowerModule): - def __init__( - self, - model: torch.jit.ScriptModule, - output_layer: torch.jit.ScriptModule, - right_tensorizer: ScriptTensorizer, - left_tensorizer: ScriptTensorizer, - ): - super().__init__() - self.model = model - self.output_layer = output_layer - self.right_tensorizer = right_tensorizer - self.left_tensorizer = left_tensorizer + return destructure_tensor([len(be[0]) for be in batch], flat_result) @torch.jit.script_method - def forward( + def make_batch( self, - right_texts: List[str], - left_texts: List[str], - ): - right_inputs: ScriptBatchInput = ScriptBatchInput( - texts=resolve_texts(right_texts), - tokens=squeeze_2d(None), - languages=squeeze_1d(None), - ) - right_input_tensors = self.right_tensorizer(right_inputs) - left_inputs: ScriptBatchInput = ScriptBatchInput( - texts=resolve_texts(left_texts), - tokens=squeeze_2d(None), - languages=squeeze_1d(None), - ) - left_input_tensors = self.left_tensorizer(left_inputs) - logits = self.model(right_input_tensors, left_input_tensors) - return self.output_layer(logits) + mega_batch: List[ + Tuple[ + List[str], # texts + int, + ] + ], + goals: Dict[str, str], + ) -> List[List[Tuple[List[str], int,]]]: # texts + return make_batch_texts(self.tensorizer, mega_batch, goals) -class PyTextTwoTowerModuleWithDense(PyTextTwoTowerModule): + +class PyTextModule(ScriptModule): def __init__( self, model: torch.jit.ScriptModule, output_layer: torch.jit.ScriptModule, - right_tensorizer: ScriptTensorizer, - left_tensorizer: ScriptTensorizer, - right_normalizer: VectorNormalizer, - left_normalizer: VectorNormalizer, - ): - super().__init__(model, output_layer, right_tensorizer, left_tensorizer) - self.right_normalizer = right_normalizer - self.left_normalizer = left_normalizer - - @torch.jit.script_method - def forward( - self, - right_dense_feat: List[List[float]], - left_dense_feat: List[List[float]], - right_texts: List[str], - left_texts: List[str], + tensorizer: ScriptTensorizer, ): - right_inputs: ScriptBatchInput = ScriptBatchInput( - texts=resolve_texts(right_texts), - tokens=squeeze_2d(None), - languages=squeeze_1d(None), - ) - right_input_tensors = self.right_tensorizer(right_inputs) - left_inputs: ScriptBatchInput = ScriptBatchInput( - texts=resolve_texts(left_texts), - tokens=squeeze_2d(None), - languages=squeeze_1d(None), - ) - left_input_tensors = self.left_tensorizer(left_inputs) - - right_dense_feat = self.right_normalizer.normalize(right_dense_feat) - left_dense_feat = self.left_normalizer.normalize(left_dense_feat) - right_dense_tensor = torch.tensor(right_dense_feat, dtype=torch.float) - left_dense_tensor = torch.tensor(left_dense_feat, dtype=torch.float) - if self.right_tensorizer.device != "": - right_dense_tensor = right_dense_tensor.to(self.right_tensorizer.device) - if self.left_tensorizer.device != "": - left_dense_tensor = left_dense_tensor.to(self.left_tensorizer.device) - - logits = self.model( - right_input_tensors, - left_input_tensors, - right_dense_tensor, - left_dense_tensor, - ) - return self.output_layer(logits) - - -class PyTextEmbeddingModule(ScriptModule): - def __init__(self, model: torch.jit.ScriptModule, tensorizer: ScriptTensorizer): super().__init__() self.model = model + self.output_layer = output_layer self.tensorizer = tensorizer - self.argno = -1 - log_class_usage(self.__class__) @torch.jit.script_method def set_padding_control(self, dimension: str, control: Optional[List[int]]): @@ -1098,22 +1008,15 @@ def set_padding_control(self, dimension: str, control: Optional[List[int]]): self.tensorizer.set_padding_control(dimension, control) @torch.jit.script_method - def _forward(self, inputs: ScriptBatchInput): - input_tensors = self.tensorizer(inputs) - return self.model(input_tensors).cpu() - - @torch.jit.script_method - def forward( - self, - texts: List[str], - dense_feat: Optional[List[List[float]]] = None, - ) -> torch.Tensor: + def forward(self, texts: List[str]): inputs: ScriptBatchInput = ScriptBatchInput( texts=resolve_texts(texts, None), tokens=squeeze_2d(None), languages=squeeze_1d(None), ) - return self._forward(inputs) + input_tensors = self.tensorizer(inputs) + logits = self.model(input_tensors) + return self.output_layer(logits) @torch.jit.script_method def make_prediction( @@ -1121,46 +1024,15 @@ def make_prediction( batch: List[ Tuple[ List[str], # texts - Optional[List[List[float]]], # dense_feat ] ], ) -> List[torch.Tensor]: - batchsize = len(batch) - - client_batch: List[int] = [] - res_list: List[torch.Tensor] = [] - - flat_texts: List[str] = [] - - for i in range(batchsize): - batch_element = batch[i][0] - flat_texts.extend(batch_element) - client_batch.append(len(batch_element)) - - if batch[i][1] is not None: - # Cross-request batching not yet supported for requests with dense_feat - raise RuntimeError("Malformed request.") - - if len(flat_texts) == 0: - raise RuntimeError("This is not good. Empty request batch.") - flat_result: torch.Tensor = self.forward( - texts=flat_texts, - dense_feat=None, + texts=make_prediction_texts(batch), ) - # destructure flat result tensor combining - # cross-request batches and client side - # batches into a cross-request list of - # client-side batch tensors - start = 0 - for elems in client_batch: - end = start + elems - res_list.append(flat_result.narrow(0, start, elems)) - start = end - - return res_list + return destructure_tensor([len(be[0]) for be in batch], flat_result) @torch.jit.script_method def make_batch( @@ -1168,53 +1040,13 @@ def make_batch( mega_batch: List[ Tuple[ List[str], # texts - Optional[List[List[float]]], # dense_feat int, ] ], goals: Dict[str, str], - ) -> List[ - List[ - Tuple[ - List[str], # texts - Optional[List[List[float]]], # dense_feat - int, - ] - ] - ]: - - # The next lines sort all cross-request batch elements by the token length. - # Note that cross-request batch element can in turn be a client batch. - mega_batch_key_list = [ - (max_tokens(self.tensorizer.tokenize(x[0], None)), n) - for (n, x) in enumerate(mega_batch) - ] - sorted_mega_batch_key_list = sorted(mega_batch_key_list) - sorted_mega_batch = [mega_batch[n] for (key, n) in sorted_mega_batch_key_list] - - # TBD: allow model server to specify batch size in goals dictionary - max_bs: int = 10 - len_mb = len(mega_batch) - num_batches = (len_mb + max_bs - 1) // max_bs - - batch_list: List[ - List[ - Tuple[ - List[str], # texts - Optional[List[List[float]]], # dense_feat - int, # position - ] - ] - ] = [] + ) -> List[List[Tuple[List[str], int,]]]: # texts - start = 0 - - for _i in range(num_batches): - end = min(start + max_bs, len_mb) - batch_list.append(sorted_mega_batch[start:end]) - start = end - - return batch_list + return make_batch_texts(self.tensorizer, mega_batch, goals) class PyTextEmbeddingModuleIndex(PyTextEmbeddingModule): @@ -1258,10 +1090,8 @@ def _forward(self, inputs: ScriptBatchInput, dense_tensor: torch.Tensor): def forward( self, texts: List[str], - dense_feat: Optional[List[List[float]]] = None, + dense_feat: List[List[float]], ) -> torch.Tensor: - if dense_feat is None: - raise RuntimeError("Expect dense feature.") inputs: ScriptBatchInput = ScriptBatchInput( texts=resolve_texts(texts, None), @@ -1278,17 +1108,116 @@ def forward( else: return sentence_embedding - -class PyTextEmbeddingModuleWithDenseIndex(PyTextEmbeddingModuleWithDense): - def __init__( + @torch.jit.script_method + def make_prediction( self, - model: torch.jit.ScriptModule, - tensorizer: ScriptTensorizer, - normalizer: VectorNormalizer, - index: int = 0, - concat_dense: bool = True, - ): - super().__init__(model, tensorizer, normalizer, concat_dense) + batch: List[ + Tuple[ + List[str], # texts + List[List[float]], # dense + ] + ], + ) -> List[torch.Tensor]: + + flat_result: torch.Tensor = self.forward( + texts=make_prediction_texts_wdense(batch), + dense_feat=make_prediction_wtexts_dense(batch), + ) + + return destructure_tensor([len(be[0]) for be in batch], flat_result) + + @torch.jit.script_method + def make_batch( + self, + mega_batch: List[ + Tuple[ + List[str], # texts + List[List[float]], # dense + int, + ] + ], + goals: Dict[str, str], + ) -> List[List[Tuple[List[str], List[List[float]], int,]]]: # texts # dense + + return make_batch_texts_dense(self.tensorizer, mega_batch, goals) + + +class PyTextModuleWithDense(PyTextModule): + def __init__( + self, + model: torch.jit.ScriptModule, + output_layer: torch.jit.ScriptModule, + tensorizer: ScriptTensorizer, + normalizer: VectorNormalizer, + ): + super().__init__(model, output_layer, tensorizer) + self.normalizer = normalizer + log_class_usage(self.__class__) + + @torch.jit.script_method + def forward( + self, + texts: List[str], + dense_feat: List[List[float]], + ): + inputs: ScriptBatchInput = ScriptBatchInput( + texts=resolve_texts(texts, None), + tokens=squeeze_2d(None), + languages=squeeze_1d(None), + ) + input_tensors = self.tensorizer(inputs) + dense_feat = self.normalizer.normalize(dense_feat) + + dense_tensor = torch.tensor(dense_feat, dtype=torch.float) + if self.tensorizer.device != "": + dense_tensor = dense_tensor.to(self.tensorizer.device) + logits = self.model(input_tensors, dense_tensor) + return self.output_layer(logits) + + @torch.jit.script_method + def make_prediction( + self, + batch: List[ + Tuple[ + List[str], # texts + List[List[float]], # dense + ] + ], + ) -> List[torch.Tensor]: + + flat_result: torch.Tensor = self.forward( + texts=make_prediction_texts_wdense(batch), + dense_feat=make_prediction_wtexts_dense(batch), + ) + + return destructure_tensor([len(be[0]) for be in batch], flat_result) + + @torch.jit.script_method + def make_batch( + self, + mega_batch: List[ + Tuple[ + List[str], # texts + List[List[float]], # dense + int, + ] + ], + goals: Dict[str, str], + ) -> List[List[Tuple[List[str], List[List[float]], int,]]]: # texts # dense + + return make_batch_texts_dense(self.tensorizer, mega_batch, goals) + + +class PyTextEmbeddingModuleWithDenseIndex(PyTextEmbeddingModuleWithDense): + def __init__( + self, + model: torch.jit.ScriptModule, + tensorizer: ScriptTensorizer, + normalizer: VectorNormalizer, + index: int = 0, + concat_dense: bool = True, + ): + super().__init__(model, tensorizer, normalizer, concat_dense) self.index = torch.jit.Attribute(index, int) log_class_usage(self.__class__) @@ -1323,7 +1252,8 @@ def _forward(self, inputs: ScriptBatchInput): def forward( self, texts: List[str], - dense_feat: Optional[List[List[float]]] = None, + # the following is ignored by forward implementation. drop it + # dense_feat: Optional[List[List[float]]] = None, ) -> List[torch.Tensor]: inputs: ScriptBatchInput = ScriptBatchInput( texts=resolve_texts(texts, None), @@ -1338,64 +1268,26 @@ def make_prediction( batch: List[ Tuple[ List[str], # texts - Optional[List[List[float]]], # dense_feat ] ], ) -> List[List[torch.Tensor]]: - client_batch: List[int] = [] - res_list: List[List[torch.Tensor]] = [] - - flat_texts: List[str] = [] - - for be in batch: - batch_element = be[0] - if batch_element is not None: - flat_texts.extend(batch_element) - client_batch.append(len(batch_element)) - - if be[1] is not None: - raise RuntimeError( - "Malformed request: desnse_feat not supported for cross-request batching in VE Module" - ) - - if len(flat_texts) == 0: - raise RuntimeError("This is not good. Empty request batch.") - flat_result: List[torch.Tensor] = self.forward( - texts=flat_texts, - multi_texts=None, - tokens=None, - languages=None, - dense_feat=None, + texts=make_prediction_texts(batch), ) - # destructure flat result list combining - # cross-request batches and client side - # batches into a cross-request list of - # client-side batch result lists - start = 0 - for elems in client_batch: - end = start + elems - res_list.append(flat_result[start:end]) - start = end + return destructure_tensor_list([len(be[0]) for be in batch], flat_result) - return res_list +############################################ +# PytextTwoTower Classes: -class PyTextTwoTowerEmbeddingModule(PyTextTwoTowerModule): - def __init__( - self, - model: torch.jit.ScriptModule, - right_tensorizer: ScriptTensorizer, - left_tensorizer: ScriptTensorizer, - ): - super().__init__() - self.model = model - self.right_tensorizer = right_tensorizer - self.left_tensorizer = left_tensorizer - self.argno = -1 - log_class_usage(self.__class__) + +class PyTextTwoTowerBaseModule(torch.jit.ScriptModule): + @torch.jit.script_method + def set_device(self, device: str): + self.right_tensorizer.set_device(device) + self.left_tensorizer.set_device(device) @torch.jit.script_method def set_padding_control(self, dimension: str, control: Optional[List[int]]): @@ -1407,31 +1299,6 @@ def set_padding_control(self, dimension: str, control: Optional[List[int]]): self.right_tensorizer.set_padding_control(dimension, control) self.left_tensorizer.set_padding_control(dimension, control) - @torch.jit.script_method - def _forward(self, right_inputs: ScriptBatchInput, left_inputs: ScriptBatchInput): - right_input_tensors = self.right_tensorizer(right_inputs) - left_input_tensors = self.left_tensorizer(left_inputs) - - return self.model(right_input_tensors, left_input_tensors).cpu() - - @torch.jit.script_method - def forward( - self, - right_texts: List[str], - left_texts: List[str], - ) -> torch.Tensor: - right_inputs: ScriptBatchInput = ScriptBatchInput( - texts=resolve_texts(right_texts), - tokens=squeeze_2d(None), - languages=squeeze_1d(None), - ) - left_inputs: ScriptBatchInput = ScriptBatchInput( - texts=resolve_texts(left_texts), - tokens=squeeze_2d(None), - languages=squeeze_1d(None), - ) - return self._forward(right_inputs, left_inputs) - @torch.jit.script_method def make_prediction( self, @@ -1439,17 +1306,12 @@ def make_prediction( Tuple[ List[str], # right_texts List[str], # left_texts - Optional[List[List[float]]], # right_dense_feat - Optional[List[List[float]]], # left_dense_feat ] ], ) -> List[torch.Tensor]: batchsize = len(batch) - client_batch: List[int] = [] - res_list: List[torch.Tensor] = [] - flat_right_texts: List[str] = [] flat_left_texts: List[str] = [] @@ -1459,31 +1321,13 @@ def make_prediction( flat_right_texts.extend(batch_right_element) flat_left_texts.extend(batch_left_element) - client_batch.append(len(batch_right_element)) - - if batch[i][2] is not None or batch[i][3] is not None: - raise RuntimeError( - "Cross-request batching not supported for desne_feat" - ) - - flat_result: torch.Tensor = self.forward( - right_texts=flat_right_texts, - left_texts=flat_left_texts, - right_dense_feat=None, - left_dense_feat=None, - ) - # destructure flat result tensor combining - # cross-request batches and client side - # batches into a cross-request list of - # client-side batch tensors - start = 0 - for elems in client_batch: - end = start + elems - res_list.append(flat_result.narrow(0, start, elems)) - start = end + flat_result: torch.Tensor = self.forward( + right_texts=flat_right_texts, + left_texts=flat_left_texts, + ) - return res_list + return destructure_tensor([len(be[0]) for be in batch], flat_result) @torch.jit.script_method def make_batch( @@ -1492,23 +1336,11 @@ def make_batch( Tuple[ List[str], # right_texts List[str], # left_texts - Optional[List[List[float]]], # right_dense_feat - Optional[List[List[float]]], # left_dense_feat int, ] ], goals: Dict[str, str], - ) -> List[ - List[ - Tuple[ - List[str], # right_texts - List[str], # left_texts - Optional[List[List[float]]], # right_dense_feat - Optional[List[List[float]]], # left_dense_feat - int, - ] - ] - ]: + ) -> List[List[Tuple[List[str], List[str], int,]]]: # right_texts # left_texts # The next lines sort all cross-request batch elements by the token length of right_. # Note that cross-request batch element can in turn be a client batch. @@ -1529,8 +1361,6 @@ def make_batch( Tuple[ List[str], # right_texts List[str], # left_texts - Optional[List[List[float]]], # right_dense_feat - Optional[List[List[float]]], # left_dense_feat int, # position ] ] @@ -1546,6 +1376,81 @@ def make_batch( return batch_list +class PyTextTwoTowerEmbeddingModule(PyTextTwoTowerBaseModule): + def __init__( + self, + model: torch.jit.ScriptModule, + right_tensorizer: ScriptTensorizer, + left_tensorizer: ScriptTensorizer, + ): + super().__init__() + self.model = model + self.right_tensorizer = right_tensorizer + self.left_tensorizer = left_tensorizer + log_class_usage(self.__class__) + + @torch.jit.script_method + def _forward(self, right_inputs: ScriptBatchInput, left_inputs: ScriptBatchInput): + right_input_tensors = self.right_tensorizer(right_inputs) + left_input_tensors = self.left_tensorizer(left_inputs) + + return self.model(right_input_tensors, left_input_tensors).cpu() + + @torch.jit.script_method + def forward( + self, + right_texts: List[str], + left_texts: List[str], + ) -> torch.Tensor: + right_inputs: ScriptBatchInput = ScriptBatchInput( + texts=resolve_texts(right_texts), + tokens=squeeze_2d(None), + languages=squeeze_1d(None), + ) + left_inputs: ScriptBatchInput = ScriptBatchInput( + texts=resolve_texts(left_texts), + tokens=squeeze_2d(None), + languages=squeeze_1d(None), + ) + return self._forward(right_inputs, left_inputs) + + +class PyTextTwoTowerModule(PyTextTwoTowerBaseModule): + def __init__( + self, + model: torch.jit.ScriptModule, + output_layer: torch.jit.ScriptModule, + right_tensorizer: ScriptTensorizer, + left_tensorizer: ScriptTensorizer, + ): + super().__init__() + self.model = model + self.output_layer = output_layer + self.right_tensorizer = right_tensorizer + self.left_tensorizer = left_tensorizer + + @torch.jit.script_method + def forward( + self, + right_texts: List[str], + left_texts: List[str], + ): + right_inputs: ScriptBatchInput = ScriptBatchInput( + texts=resolve_texts(right_texts), + tokens=squeeze_2d(None), + languages=squeeze_1d(None), + ) + right_input_tensors = self.right_tensorizer(right_inputs) + left_inputs: ScriptBatchInput = ScriptBatchInput( + texts=resolve_texts(left_texts), + tokens=squeeze_2d(None), + languages=squeeze_1d(None), + ) + left_input_tensors = self.left_tensorizer(left_inputs) + logits = self.model(right_input_tensors, left_input_tensors) + return self.output_layer(logits) + + class PyTextTwoTowerEmbeddingModuleWithDense(PyTextTwoTowerEmbeddingModule): def __init__( self, @@ -1588,11 +1493,9 @@ def forward( self, right_texts: List[str], left_texts: List[str], - right_dense_feat: Optional[List[List[float]]] = None, - left_dense_feat: Optional[List[List[float]]] = None, + right_dense_feat: List[List[float]], + left_dense_feat: List[List[float]], ) -> torch.Tensor: - if right_dense_feat is None or left_dense_feat is None: - raise RuntimeError("Expect dense feature.") right_inputs: ScriptBatchInput = ScriptBatchInput( texts=resolve_texts(right_texts), @@ -1614,3 +1517,156 @@ def forward( right_inputs, left_inputs, right_dense_tensor, left_dense_tensor ) return sentence_embedding + + +class PyTextTwoTowerModuleWithDense(PyTextTwoTowerModule): + def __init__( + self, + model: torch.jit.ScriptModule, + output_layer: torch.jit.ScriptModule, + right_tensorizer: ScriptTensorizer, + left_tensorizer: ScriptTensorizer, + right_normalizer: VectorNormalizer, + left_normalizer: VectorNormalizer, + ): + super().__init__(model, output_layer, right_tensorizer, left_tensorizer) + self.right_normalizer = right_normalizer + self.left_normalizer = left_normalizer + + @torch.jit.script_method + def forward( + self, + right_texts: List[str], + left_texts: List[str], + right_dense_feat: List[List[float]], + left_dense_feat: List[List[float]], + ): + right_inputs: ScriptBatchInput = ScriptBatchInput( + texts=resolve_texts(right_texts), + tokens=squeeze_2d(None), + languages=squeeze_1d(None), + ) + right_input_tensors = self.right_tensorizer(right_inputs) + left_inputs: ScriptBatchInput = ScriptBatchInput( + texts=resolve_texts(left_texts), + tokens=squeeze_2d(None), + languages=squeeze_1d(None), + ) + left_input_tensors = self.left_tensorizer(left_inputs) + + right_dense_feat = self.right_normalizer.normalize(right_dense_feat) + left_dense_feat = self.left_normalizer.normalize(left_dense_feat) + right_dense_tensor = torch.tensor(right_dense_feat, dtype=torch.float) + left_dense_tensor = torch.tensor(left_dense_feat, dtype=torch.float) + if self.right_tensorizer.device != "": + right_dense_tensor = right_dense_tensor.to(self.right_tensorizer.device) + if self.left_tensorizer.device != "": + left_dense_tensor = left_dense_tensor.to(self.left_tensorizer.device) + + logits = self.model( + right_input_tensors, + left_input_tensors, + right_dense_tensor, + left_dense_tensor, + ) + return self.output_layer(logits) + + @torch.jit.script_method + def make_prediction( + self, + batch: List[ + Tuple[ + List[str], # right_texts + List[str], # left_texts + List[List[float]], # right_dense_feat + List[List[float]], # left_dense_feat + ] + ], + ) -> List[torch.Tensor]: + + batchsize = len(batch) + + flat_right_texts: List[str] = [] + flat_left_texts: List[str] = [] + flat_right_dense: List[List[float]] = [] + flat_left_dense: List[List[float]] = [] + + for i in range(batchsize): + batch_right_element = batch[i][0] + batch_left_element = batch[i][1] + batch_right_dense_element = batch[i][2] + batch_left_dense_element = batch[i][3] + + flat_right_texts.extend(batch_right_element) + flat_left_texts.extend(batch_left_element) + flat_right_dense.extend(batch_right_dense_element) + flat_left_dense.extend(batch_left_dense_element) + + flat_result: torch.Tensor = self.forward( + right_texts=flat_right_texts, + left_texts=flat_left_texts, + right_dense_feat=flat_right_dense, + left_dense_feat=flat_left_dense, + ) + + return destructure_tensor([len(be[0]) for be in batch], flat_result) + + @torch.jit.script_method + def make_batch( + self, + mega_batch: List[ + Tuple[ + List[str], # right_texts + List[str], # left_texts + List[List[float]], # right_dense_feat + List[List[float]], # left_dense_feat + int, + ] + ], + goals: Dict[str, str], + ) -> List[ + List[ + Tuple[ + List[str], # right_texts + List[str], # left_texts + List[List[float]], # right_dense_feat + List[List[float]], # left_dense_feat + int, + ] + ] + ]: + + # The next lines sort all cross-request batch elements by the token length of right_. + # Note that cross-request batch element can in turn be a client batch. + mega_batch_key_list = [ + (max_tokens(self.right_tensorizer.tokenize(x[0], None)), n) + for (n, x) in enumerate(mega_batch) + ] + sorted_mega_batch_key_list = sorted(mega_batch_key_list) + sorted_mega_batch = [mega_batch[n] for (key, n) in sorted_mega_batch_key_list] + + # TBD: allow model server to specify batch size in goals dictionary + max_bs: int = 10 + len_mb = len(mega_batch) + num_batches = (len_mb + max_bs - 1) // max_bs + + batch_list: List[ + List[ + Tuple[ + List[str], # right_texts + List[str], # left_texts + List[List[float]], # right_dense_feat + List[List[float]], # left_dense_feat + int, # position + ] + ] + ] = [] + + start = 0 + + for _i in range(num_batches): + end = min(start + max_bs, len_mb) + batch_list.append(sorted_mega_batch[start:end]) + start = end + + return batch_list