import base64
import json
import os
import subprocess
import threading
from typing import Type, Union
import traceback
import requests
import streamlit as st
import streamlit.components.v1 as components
import validators
from streamlit_pyvista.helpers.cache import DEFAULT_CACHE_DIR
from streamlit_pyvista.helpers.streamlit_pyvista_logging import root_logger
from streamlit_pyvista.helpers.utils import is_localhost, is_web_link, replace_host, is_server_alive, is_notebook, \
wait_for_server_alive
from streamlit_pyvista.trame_viewers.trame_viewer import get_default_viewer_path
from .message_interface import ServerMessageInterface, EndpointsInterface
from .server_managers import ServerManagerBase, ServerManager
class MeshViewerComponent:
    """Streamlit component to display 3d meshes using pyvista and its Trame backend.

    On construction the component validates the mesh files, launches a local server manager when
    ``server_manager_url`` points to localhost and nothing answers there, asks the server manager for the
    endpoints it exposes and finally tells the Trame server which meshes to display. ``show()`` renders the
    viewer as an iframe (or an IPython ``IFrame`` inside a notebook).
    """

    def __init__(self, mesh_path: Union[str, list[str]] = None,
                 server_manager_url: str = "http://127.0.0.1:9422",
                 setup_endpoint: str = EndpointsInterface.InitConnection,
                 server_manager_class: Type[ServerManagerBase] = ServerManager,
                 trame_viewer_class: str = None):
        """
        Create the component and connect it to a server manager.

        Args:
            mesh_path (str|list[str]): Path, or list of paths, of the mesh file(s) to display.
            server_manager_url (str): Base url of the server manager. When it is a localhost url, a server
                manager is launched locally (unless one already answers there) on the port of this url,
                or on 9422 if the url has no explicit port.
            setup_endpoint (str): Endpoint, relative to ``server_manager_url``, used to request the other
                endpoints from the server.
            server_manager_class (Type[ServerManagerBase]): Server manager implementation launched locally.
            trame_viewer_class (str): Path of the trame viewer file sent to the server. Defaults to
                ``get_default_viewer_path()``.
        """
        # If the user specified only the path of one file, transform it to an element in a list
        if isinstance(mesh_path, str):
            mesh_path = [mesh_path]
        self.manager = server_manager_class
        self.viewer = get_default_viewer_path() if trame_viewer_class is None else trame_viewer_class
        # Every attribute is initialized before any early return so that show() and set_mesh() can still
        # be called (and degrade gracefully) on a component whose setup failed.
        self.width = 1200
        self.height = 1000
        self.server_url = server_manager_url
        # A locally launched server manager must listen on the port this component talks to
        self.default_port = self._port_from_url(server_manager_url, default=9422)
        self.server_timeout = 12
        self.error_during_setup = None
        self.nbr_max_launch_attempt = 3
        self.mesh_path = []
        self.sequence_size = 0
        # Endpoints the server must provide: select mesh asks the server to show specific meshes,
        # upload mesh sends it files it cannot access, and host is the host of the data rendering.
        self.required_endpoints = [ServerMessageInterface.Keys.SelectMesh,
                                   ServerMessageInterface.Keys.UploadMesh,
                                   ServerMessageInterface.Keys.Host]
        # Values received for our endpoints. Init connection is the default endpoint used to request
        # all the other required endpoints from the server.
        self.endpoints = {
            ServerMessageInterface.Keys.InitConnection: setup_endpoint,
            ServerMessageInterface.Keys.Host: self.server_url
        }
        # On invalid input files this sets self.endpoints to None, which show() reports as an error
        if not self._set_mesh_attributes(mesh_path):
            return
        # If the server url is on localhost we launch the server manager locally
        if is_localhost(self.server_url):
            root_logger.debug(f"Server Manager url ({self.server_url}), a local instance is launched")
            self._setup_server()
        if not self._setup_endpoints():
            root_logger.error("Couldn't setup the endpoints with the Trame server")
            # Guarantee show() has something to report instead of failing on a missing host
            if self.error_during_setup is None:
                self.error_during_setup = "Couldn't setup the endpoints with the Trame server"
            return
        self.set_mesh()
        root_logger.info("MeshViewer Created")

    @staticmethod
    def _port_from_url(url: str, default: int) -> int:
        """Return the explicit port of ``url``, or ``default`` when it has none or it cannot be parsed."""
        from urllib.parse import urlparse
        try:
            port = urlparse(url).port
        except ValueError:
            # urlparse(...).port raises ValueError for a non-numeric or out-of-range port
            return default
        return default if port is None else port

    def _set_mesh_attributes(self, mesh_path: list[str]) -> bool:
        """
        Validate ``mesh_path`` and store it as the current mesh sequence.

        Args:
            mesh_path (list[str]): Paths of the meshes to display.
        Returns:
            bool: True if the files are valid. Otherwise ``self.endpoints`` is set to None (which ``show()``
            reports as invalid files) and False is returned.
        """
        if not self.check_valid_input_files(mesh_path):
            self.endpoints = None
            return False
        self.mesh_path = mesh_path
        self.sequence_size = len(mesh_path)
        return True

    def _setup_server(self):
        """
        Launch a local server manager using a python subprocess on another thread, unless a server already
        answers at ``self.server_url``.
        """
        if is_server_alive(self.server_url, self.server_timeout):
            return
        server_manager_thread = threading.Thread(target=self._run_server_manager)
        server_manager_thread.start()
        root_logger.info("Local Server Manager is launched")
        root_logger.debug(
            "Server Manager logs and trame server logs are available in a "
            f"log file in {os.path.join(os.getcwd(), DEFAULT_CACHE_DIR)}")

    def _setup_endpoints(self):
        """
        Send the viewer file to the server and fill the endpoints dictionary with the info it returns.

        On failure ``self.error_during_setup`` is set so that ``show()`` can report it.

        Returns:
            bool: True if the setup was successful, False otherwise.
        """
        # If the server was launched locally, we need to wait for it to be up
        wait_for_server_alive(self.endpoints.get(ServerMessageInterface.Keys.Host, self.server_url),
                              self.server_timeout)
        with open(self.viewer, "rb") as f:
            viewer_b64 = base64.b64encode(f.read()).decode('utf-8')
        request_body = {ServerMessageInterface.Keys.Viewer: viewer_b64}
        # The host placeholder is replaced by the one returned by the server; pop() tolerates a second call
        self.endpoints.pop(ServerMessageInterface.Keys.Host, None)
        try:
            # NOTE: the body is a JSON string passed through `json=`, i.e. double encoded. Kept as is since
            # the server presumably decodes it that way.
            res = requests.get(self.server_url + self.endpoints[ServerMessageInterface.Keys.InitConnection],
                               json=json.dumps(request_body), timeout=self.server_timeout)
            json_res = res.json()
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            self.error_during_setup = f"Failed to establish a connection with the server {self.server_url}"
            root_logger.error(self.error_during_setup)
            return False
        except requests.exceptions.JSONDecodeError:
            # requests' JSONDecodeError is raised by res.json() whether json or simplejson is used
            self.error_during_setup = "The init_connection request should have a json object as response"
            root_logger.error(self.error_during_setup)
            return False
        if res.status_code != 200:
            self.error_during_setup = json_res
            root_logger.error(f"The init_connection request failed with status {res.status_code}: {json_res}")
            return False
        if not isinstance(json_res, dict):
            self.error_during_setup = "The init_connection request should have a json object as response"
            root_logger.error(self.error_during_setup)
            return False
        root_logger.debug(f"The endpoint received by the server are the following : {json_res}")
        # Check that all necessary endpoints were given before filling anything, so that a failure does
        # not leave a partially filled endpoints dict behind.
        for endpoint in self.required_endpoints:
            if endpoint not in json_res:
                root_logger.error(f"The endpoint {endpoint} was not specified by the server")
                self.error_during_setup = f"The endpoint {endpoint} was not specified by the server"
                return False
        for endpoint in self.required_endpoints:
            self.endpoints[endpoint] = json_res[endpoint]
        return True

    def _run_server_manager(self, attempt=0):
        """
        Launch the server manager using a python subprocess, retrying when it exits with an error.

        ``subprocess.run`` blocks until the process exits, which is why this runs on a separate thread.

        Args:
            attempt (int): Number of attempts already made. At most ``self.nbr_max_launch_attempt``
                launches are made in total.
        """
        import sys
        # Use the interpreter running this component so the subprocess sees the same installed packages
        python_executable = sys.executable or "python3"
        command = [python_executable, self.manager.get_launch_path(), "--port", str(self.default_port)]
        for _ in range(attempt, self.nbr_max_launch_attempt):
            try:
                subprocess.run(command, capture_output=True, text=True, check=True)
                return
            except subprocess.CalledProcessError as e:
                error_message = (
                    f"Command '{e.cmd}' returned non-zero exit status {e.returncode}.\n"
                    f"STDOUT:\n{e.stdout}\n"
                    f"STDERR:\n{e.stderr}\n"
                    f"Python Traceback:\n{traceback.format_exc()}"
                )
                root_logger.error("Tried to launch the Server Manager but got an unexpected error")
                root_logger.debug(f"Failed with the following error: {error_message}")
            except OSError as e:
                # The executable or script could not be started at all: retrying cannot help
                root_logger.error(f"Could not start the Server Manager process: {e}")
                return

    def set_mesh(self, meshes: Union[list[str], str] = None):
        """
        Set the meshes viewed on the server by making a request.

        Args:
            meshes (list[str]|str): Path or list of paths of the meshes to display. Defaults to the meshes
                currently set on the component.
        """
        if meshes is not None:
            if isinstance(meshes, str):
                meshes = [meshes]
            if not self._set_mesh_attributes(meshes):
                return
        # endpoints is None when input files were rejected; without the select mesh endpoint and the host
        # (i.e. when the endpoint setup failed) there is nobody to send the request to.
        if (self.endpoints is None
                or ServerMessageInterface.Keys.SelectMesh not in self.endpoints
                or ServerMessageInterface.Keys.Host not in self.endpoints):
            return
        meshes = self.mesh_path
        url = self.endpoints[ServerMessageInterface.Keys.Host] + self.endpoints[ServerMessageInterface.Keys.SelectMesh]
        data = {
            ServerMessageInterface.ReqSetMesh.MeshList: meshes,
            ServerMessageInterface.ReqSetMesh.NbrFrames: len(meshes),
            ServerMessageInterface.ReqSetMesh.Width: self.width,
            ServerMessageInterface.ReqSetMesh.Height: self.height
        }
        headers = {"Content-Type": "application/json"}
        try:
            response = requests.get(url, data=json.dumps(data), headers=headers, timeout=self.server_timeout)
            resp_body = response.json()
            # Check in the response if any action is necessary such as make the iframe bigger or uploading files
            if response.status_code == 400:
                if ServerMessageInterface.RespSetMesh.Error in resp_body:
                    self.error_during_setup = resp_body
                    root_logger.error(resp_body[ServerMessageInterface.RespSetMesh.Error])
            elif ServerMessageInterface.RespSetMesh.RequestSpace in resp_body:
                self.height = resp_body[ServerMessageInterface.RespSetMesh.RequestSpace]
            elif ServerMessageInterface.RespSetMesh.RequestFiles in resp_body:
                missing_files = resp_body[ServerMessageInterface.RespSetMesh.RequestFiles]
                root_logger.debug(
                    f"Trame server {self.endpoints[ServerMessageInterface.Keys.Host]} requested the following "
                    f"files : {missing_files}")
                self._send_missing_files(missing_files)
        except requests.exceptions.JSONDecodeError:
            # Best effort: a non-json answer requires no follow-up action, but leave a trace of it
            root_logger.warning(f"The response of {url} is not a json object, no further action taken")
            return
        except (requests.exceptions.Timeout, requests.exceptions.ConnectionError):
            self.error_during_setup = (f"Couldn't connect to the server with url {url}, either the server "
                                       f"or the proxy crashed")
            root_logger.error(self.error_during_setup)

    def _send_missing_files(self, missing_files: list):
        """
        Upload to the trame server the files it reported it cannot access.

        Only files that are part of ``self.mesh_path`` are uploaded (presumably so the server cannot request
        arbitrary local files); other requested paths are skipped with a warning.

        Args:
            missing_files (list): Pairs ``(path, index)`` as sent by the server. ``index`` is forwarded
                unchanged with the file content (presumably the file position in the mesh sequence).
        """
        headers = {"Content-Type": "application/json"}
        for file, index in missing_files:
            if file not in self.mesh_path:
                root_logger.warning(f"The server requested the file {file} which is not a displayed mesh, "
                                    f"it was not sent")
                continue
            with open(file, 'rb') as f:
                encoded = base64.b64encode(f.read()).decode('utf-8')
            request_body = {file: (encoded, index)}
            url = (self.endpoints[ServerMessageInterface.Keys.Host] +
                   self.endpoints[ServerMessageInterface.Keys.UploadMesh])
            root_logger.debug(f"Sent file with path {file} to {self.endpoints[ServerMessageInterface.Keys.Host]}")
            response = requests.post(url, data=json.dumps(request_body), headers=headers,
                                     timeout=self.server_timeout)
            if not response.ok:
                root_logger.error(f"Upload of {file} to {url} failed with status {response.status_code}")

    def show(self):
        """
        Render the streamlit component.

        Returns:
            The result of ``st.error`` when the setup failed, the result of ``components.iframe`` in a
            streamlit app, and None in a notebook (where the viewer is displayed with IPython).
        """
        from . import REMOTE_HOST
        # The only scenario that leads to endpoints = None is if one of the files is not valid
        if self.endpoints is None:
            return st.error("Some files passed as argument does not exists")
        if self.error_during_setup is not None:
            return st.error(self.error_during_setup)
        trame_host = self.endpoints.get(ServerMessageInterface.Keys.Host)
        if trame_host is None:
            return st.error("The viewer is not connected to a Trame server")
        headers = st.context.headers
        host = headers.get("Host") if not REMOTE_HOST and headers else REMOTE_HOST
        # When the app is reached through a non-local host but the servers run locally, the iframe must
        # point to the host the client actually uses.
        if host and not is_localhost(host) and is_localhost(self.server_url):
            root_logger.debug(
                f"The host that the iframe should have is {host} and it "
                f"actually has {trame_host}")
            iframe_host = replace_host(trame_host, host)
        else:
            root_logger.debug(
                f"The host that the iframe should have is {trame_host}")
            iframe_host = trame_host
        url = iframe_host + "/index.html"
        if is_notebook():
            from IPython.display import IFrame, display
            iframe = IFrame(src=url, width='100%', height=self.height)
            display(iframe)
        else:
            return components.iframe(url, height=self.height)