diff --git a/fastchat/serve/cli.py b/fastchat/serve/cli.py index 67a32d80a..d395a1521 100644 --- a/fastchat/serve/cli.py +++ b/fastchat/serve/cli.py @@ -13,6 +13,7 @@ import os import re import sys +import ssl from prompt_toolkit import PromptSession from prompt_toolkit.auto_suggest import AutoSuggestFromHistory @@ -22,29 +23,57 @@ from rich.console import Console from rich.live import Live from rich.markdown import Markdown +import asyncio +import websockets from fastchat.model.model_adapter import add_model_args from fastchat.modules.gptq import GptqConfig from fastchat.serve.inference import ChatIO, chat_loop + + class SimpleChatIO(ChatIO): - def __init__(self, multiline: bool = False): + def __init__(self, websocket, multiline: bool = False): self._multiline = multiline + self.websocket = websocket - def prompt_for_input(self, role) -> str: - if not self._multiline: - return input(f"{role}: ") + async def chat_websocket_client(self)-> str: + # URI to Point to the Proxy WebSocket Gin Server so that it can relay the message: + uri = "wss://35.209.170.184:8080/ws" + async with websockets.connect(uri, ssl=ssl.SSLContext()) as websocket: + # while True: + try: + user_input = await websocket.recv() + response = "User Input we got" + user_input + print(response) + print(user_input) + await websocket.send(response) + task = asyncio.create_task(self.prompt_for_input("USER")) + return user_input + except websockets.exceptions.ConnectionClosed: + print("Connection Closed") + + + + + + async def prompt_for_input(self, role) -> str: + # if not self._multiline: + # return input(f"{role}: ") prompt_data = [] - line = input(f"{role} [ctrl-d/z on empty line to end]: ") - while True: + # line = input(f"{role} [ctrl-d/z on empty line to end]: ") + # line = self.receive_input_from_websocket(role + f" [ctrl-d/z on empty line to end]: ") + line = await self.chat_websocket_client() + while line: prompt_data.append(line.strip()) try: - line = input() + line = self.chat_websocket_client() except EOFError as e: break return "\n".join(prompt_data) + def prompt_for_output(self, role: str): print(f"{role}: ", end="", flush=True) @@ -160,6 +189,8 @@ def stream_output(self, output_stream): def main(args): + # First of all we retrieve the Event Loop + # asyncio.get_event_loop().run_until_complete(SimpleChatIO.chat_websocket_client()) if args.gpus: if len(args.gpus.split(",")) < args.num_gpus: raise ValueError( @@ -169,7 +200,10 @@ def main(args): os.environ["XPU_VISIBLE_DEVICES"] = args.gpus if args.style == "simple": - chatio = SimpleChatIO(args.multiline) + websocket = websockets.connect("wss://35.209.170.184:8080/ws") + chatio = SimpleChatIO(websocket,args.multiline) + # asyncio.get_event_loop().run_until_complete(SimpleChatIO.chat_websocket_client()) + asyncio.run(chatio.chat_websocket_client()) elif args.style == "rich": chatio = RichChatIO(args.multiline, args.mouse) elif args.style == "programmatic": @@ -180,6 +214,7 @@ def main(args): chatio = FileInputChatIO(root_input_file_path) else: raise ValueError(f"Invalid style for console: {args.style}") + try: chat_loop( args.model_path,