Source code for streamlit_pyvista.trame_viewers.trame_backend

import asyncio
import base64
import os
import signal
import threading
import time
from abc import ABC, abstractmethod
from types import FrameType
from typing import Optional

import numpy as np
import pyvista as pv
import validators
from aiohttp import web
from trame.app import Server
from trame.app import get_server
from trame.ui.vuetify3 import VAppLayout

from streamlit_pyvista.helpers.cache import (save_mesh_content, DEFAULT_CACHE_DIR)
from streamlit_pyvista.helpers.streamlit_pyvista_logging import root_logger
from streamlit_pyvista.helpers.utils import is_web_link
from streamlit_pyvista.lazy_mesh import LazyMesh, LazyMeshList
from streamlit_pyvista.message_interface import ServerMessageInterface, EndpointsInterface

# Time units (in seconds) used to express timer delays in this module.
SECOND = 1
ONE_MINUTE = SECOND * 60


class TrameBackend(ABC):
    """
    A Trame server class that manages the view of a 3d mesh and its controls.

    Subclasses provide the state initialisation (`_setup_state`), the mesh display logic
    (`_replace_mesh_displayed`, `_clear_viewer`) and the user interface (`_build_ui`). The server also exposes
    HTTP endpoints (see `api_routes`) used by the client to select/upload meshes and to query or kill the server.
    """

    def __init__(self, plotter: Optional[pv.Plotter] = None, server: Optional[Server] = None, port: int = 8080,
                 host: str = "0.0.0.0"):
        """
        Initialize the trame server backend.

        Args:
            plotter (pv.Plotter, optional): The plotter object for visualization. Defaults to None.
            server (Server, optional): The server object for handling client connections. Defaults to None.
            port (int, optional): The port number for the server. Defaults to 8080.
            host (str, optional): The host address for the server. Defaults to "0.0.0.0".
        """
        self.shutdown_event = asyncio.Event()
        # Event loop the server runs on. Timer threads need it to hand coroutines over to the loop;
        # it is (re)captured in `start`/`run` once the loop is running.
        try:
            self._loop = asyncio.get_running_loop()
        except RuntimeError:
            self._loop = None
        # Strong references to fire-and-forget tasks: the event loop only keeps weak references to tasks,
        # so an unreferenced task may be garbage collected before it finishes.
        self._background_tasks = set()
        pv.OFF_SCREEN = True
        self.host = host
        self.port = port
        # Get a server if none was passed
        if server is None:
            self.server = get_server(port=self.port)
        else:
            self.server = server
        # Set mesh related attributes
        self.paths = None
        self.current_mesh = None
        self.warp_free_mesh = None
        self.cache_path = DEFAULT_CACHE_DIR
        # Set style attributes
        self.plotter_style = {
            "background-color": "black",
            "font-color": "white"
        }
        # Create a plotter and attributes related to it
        pl = self._setup_pl()
        self.pl = pl if plotter is None else plotter

        # Setup server lifecycle callback functions
        self.on_server_bind = self.server.controller.add("on_server_bind")(self.on_server_bind)
        self.on_client_exited = self.server.controller.add("on_client_exited")(self.on_client_exited)
        self.on_client_connected = self.server.controller.add("on_client_connected")(self.on_client_connected)
        self.on_server_exited = self.server.controller.add("on_server_exited")(self.on_server_exited)

        # Init the client counter to 1 at start so the server is not reported as available before any client
        # had a chance to connect. The counter is decremented by 1 after 3 seconds.
        self.client_counter = 1
        threading.Timer(3 * SECOND, self._client_counter_cb).start()
        # Ask the server to stop after 5 minutes; `request_stop` reschedules itself while clients are connected
        self._schedule_stop_request(5 * ONE_MINUTE)

        # Setup api endpoints
        self.api_routes = [
            web.get(EndpointsInterface.SelectMesh, self.change_mesh),
            web.get(EndpointsInterface.InitConnection, self.init_connection),
            web.post(EndpointsInterface.UploadMesh, self.upload_mesh),
            web.get(EndpointsInterface.ClientsNumber, self.client_number),
            web.get(EndpointsInterface.KillServer, self.kill_server),
        ]
        self.mesh_missing = None
        self.sequence_bounds = [0, 0]
        # Set state variables that need to exist before the ui is built
        self._setup_state()
        self.mesh_array = None
        self.width = 800
        self.height = 900
        self.controller_height = 450
        self.ui = self._build_ui()

    async def client_number(self, request):
        """
        Endpoint returning the current value of the client counter.

        Args:
            request: The request object (unused).

        Returns:
            web.Response: a json with the number of clients and a status 200.
        """
        return web.json_response({ServerMessageInterface.Keys.NumberClients: self.client_counter}, status=200)

    def _client_counter_cb(self):
        """
        Remove the placeholder client added by `__init__`.

        Called once by a timer 3 seconds after construction, so that the server is not considered free
        (0 clients) before a client had a chance to connect.
        """
        self.client_counter -= 1

    def _spawn_task(self, coro) -> asyncio.Task:
        """
        Create a task for `coro` on the running loop and keep a strong reference to it until it is done.

        Args:
            coro: the coroutine to run in the background.

        Returns:
            asyncio.Task: the created task.
        """
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def _schedule_stop_request(self, delay: float) -> None:
        """
        Run `request_stop` on the server event loop after `delay` seconds.

        `request_stop` is a coroutine function: calling it directly from a `threading.Timer` thread would only
        create a coroutine object that is never awaited, so the timer thread hands it over to the loop instead.

        Args:
            delay (float): number of seconds to wait before requesting the stop.
        """
        timer = threading.Timer(delay, self._request_stop_from_thread)
        # Do not keep the interpreter alive only to deliver a pending stop request
        timer.daemon = True
        timer.start()

    def _request_stop_from_thread(self) -> None:
        """
        Submit `request_stop` to the server event loop from a non-loop thread (timer callback).
        """
        loop = self._loop
        if loop is None or loop.is_closed():
            root_logger.warning(
                f"Trame server {self.host}:{self.port}: no event loop available, stop request skipped")
            return
        coro = self.request_stop()
        try:
            asyncio.run_coroutine_threadsafe(coro, loop)
        except RuntimeError:
            # The loop was closed in the meantime: discard the coroutine cleanly
            coro.close()

    async def kill_server(self, request):
        """
        Stops the server and returns a JSON response indicating success.

        Args:
            request: The request object.

        Returns:
            A JSON response with a success message and a status code of 200.
        """
        root_logger.info(f"Try to kill the server {self.server.port}")
        # Non-blocking pause: `time.sleep` would freeze the event loop and every other connection
        await asyncio.sleep(1)
        # Defer the forced stop to a later loop iteration so this response can be sent first
        asyncio.get_running_loop().call_soon(lambda: self._spawn_task(self.request_stop(force_stop=True)))
        root_logger.info(f"Successfully called the stop for server {self.server.port}")
        return web.json_response({ServerMessageInterface.Keys.Success: f"Trame Server {self.host}:{self.port} killed"},
                                 status=200)

    def _setup_pl(self) -> pv.Plotter:
        """
        Set up the plotter with the specified styles and return it.

        Returns:
            pv.Plotter: The configured plotter object.
        """
        # Create the plotter and add its styles
        pl = pv.Plotter()
        pl.background_color = self.plotter_style["background-color"]
        pl.theme.font.color = self.plotter_style["font-color"]
        self.bounds_scalar = None
        self.scalar_bar_mapper = None
        return pl

    @abstractmethod
    def _setup_state(self):
        """
        Set up all the state variables to initial values
        """
        pass

    @property
    def state(self):
        """The trame state of the server."""
        return self.server.state

    @property
    def ctrl(self):
        """The trame controller of the server."""
        return self.server.controller

    def _update_mesh_displayed_from_index(self, idx: int):
        """
        Update the mesh displayed in the plotter using its index in the sequence.

        Indices outside [0, number of frames) are ignored (negative indices are not wrapped around).

        Args:
            idx (int): Index of the mesh to show
        """
        if self.mesh_array is None:
            return
        if 0 <= idx < self.sequence_bounds[1]:
            self.warp_free_mesh = self.mesh_array[idx]
            self._replace_mesh_displayed(self.mesh_array[idx])

    def _handle_new_mesh_list(self, mesh_list: list[str]) -> list[tuple[str, int]]:
        """
        This function handles the loading of new mesh in the server.

        Invalid web links are stored as `None` in the mesh array; local files that do not exist are stored as
        `None` and reported as missing so the client can upload them.

        Args:
            mesh_list (List[str]): the paths of the mesh

        Returns:
            List[Tuple[str, int]]: the (path or link, index) of every mesh that couldn't be loaded
        """
        self.mesh_array = LazyMeshList()
        missing_mesh = []
        # If the mesh is a sequence, then format its paths and load all element in the mesh array
        for i, path in enumerate(mesh_list):
            # Cache location: file name of the path/link without any query string
            target_path = f"{self.cache_path}/{path.split('/')[-1].split('?')[0]}"
            # If the path is a link, call function to cache download and store the mesh
            if is_web_link(path):
                if not validators.url(path):
                    root_logger.error(
                        f"Trame server running on {self.host}:{self.server.port}: The link {path} is not valid")
                    self.mesh_array.append(None)
                    continue
            elif not os.path.exists(path):
                # If the file does not exist mark it as missing to notify it in the response
                missing_mesh.append((path, i))
                self.mesh_array.append(None)
                continue
            # NOTE: a no-op when `mesh_list` is `self.paths`, which is how `change_mesh` calls this method
            self.paths[i] = path
            self.mesh_array.append(LazyMesh(path, target_path))
        return missing_mesh

    async def change_mesh(self, request) -> web.Response:
        """
        This function is called when a request to '/select_mesh' is made

        Args:
            request: the request received

        Returns:
            web.Response: a http status 200 if there was no error, else a http status 400

        Note:
            This function require the request received to have a json body with the following fields:
                - mesh_list: the paths (or the link) of the mesh to load
                - width: the width of the plotter
                - height: the height of the plotter
                - nbr_frames: the number of frames in the sequence
        """
        request_body = await request.json()
        # Retrieve information from the request
        self.paths = request_body.get(ServerMessageInterface.ReqSetMesh.MeshList, None)
        self.width = request_body.get(ServerMessageInterface.ReqSetMesh.Width, self.width)
        self.height = request_body.get(ServerMessageInterface.ReqSetMesh.Height, self.height)
        self.sequence_bounds[1] = request_body.get(ServerMessageInterface.ReqSetMesh.NbrFrames,
                                                   self.sequence_bounds[1])
        if self.paths is None:
            root_logger.error(
                f"Trame server running on {self.host}:{self.server.port}: No filepath found in the change mesh request")
            return web.json_response({"error": "No filepath found in the change mesh request"}, status=400)

        # Reset the viewer to an empty state
        self._clear_viewer()
        # Get the mesh and prepare it to be displayed
        self.mesh_missing = self._handle_new_mesh_list(self.paths)
        if len(self.mesh_missing) > 0:
            # Ask the client to upload the files the server could not find; they arrive through `upload_mesh`
            root_logger.info(f"Missing mesh: {self.mesh_missing}, request made to client")
            return web.json_response({ServerMessageInterface.RespSetMesh.RequestFiles: self.mesh_missing}, status=200)

        self._update_viewer_for_new_meshes()
        # Every mesh is available: an empty body signals success to the client
        response_body = {}
        return web.json_response(response_body, status=200)

    def _fill_option_arrays(self):
        """
        Fills the option arrays for the Trame backend.

        `state.options` receives "None" followed by the array names of the first mesh that do not start with
        "vtk". `state.options_warp` receives the same entries (including the leading "None") as a separate list,
        so the two state variables never alias each other. If the first mesh could not be loaded, only "None"
        is offered.

        Returns:
            None
        """
        first_mesh = self.mesh_array[0]
        names = [] if first_mesh is None else [name for name in first_mesh.array_names
                                               if not name.startswith("vtk")]
        options = ["None", *names]
        self.state.options = options
        self.state.options_warp = list(options)

    def _update_viewer_for_new_meshes(self):
        """
        Handles a new mesh request by replacing the current mesh with the first mesh in the mesh array.
        Updates UI elements that depend on the mesh and shows the new mesh in the viewers and its controls.
        """
        self._update_mesh_displayed_from_index(0)
        self.pl.reset_camera()
        self._fill_option_arrays()
        # Show the new mesh in the viewers and its controls
        self._computes_bounds_scalar()
        self.ui = self._build_ui()

    def _is_inside_cache(self, file_path: str) -> bool:
        """
        Tell whether `file_path` resolves to a location inside the cache directory.

        Used to reject uploaded file names (e.g. "../../x") that would write outside of the cache.

        Args:
            file_path (str): the path to check

        Returns:
            bool: True if the normalized path is inside `self.cache_path`
        """
        cache_root = os.path.abspath(self.cache_path)
        candidate = os.path.abspath(file_path)
        try:
            return os.path.commonpath([cache_root, candidate]) == cache_root
        except ValueError:
            # Paths on different drives (Windows) cannot share a common path
            return False

    def _upload_error(self, message: str) -> web.Response:
        """
        Log `message` and wrap it in a http 400 json response.

        Args:
            message (str): description of the problem

        Returns:
            web.Response: a json response with an "error" key and a status 400
        """
        root_logger.error(f"Trame server running on {self.host}:{self.server.port}: {message}")
        return web.json_response({"error": message}, status=400)

    async def upload_mesh(self, request) -> web.Response:
        """
        This function is called when a request to '/upload_mesh' is made

        The json body maps each missing mesh key to a `[base64 content, index in the sequence]` pair.

        Args:
            request: The request object containing the mesh data.

        Returns:
            web.Response: A JSON response indicating the success of the upload (status 200), or a status 400 if
            no mesh was requested or an entry is malformed.
        """
        request_body = await request.json()
        if self.mesh_array is None or self.mesh_missing is None:
            # Nothing was requested through '/select_mesh': there is no slot to store the uploaded meshes in
            return self._upload_error("No mesh requested, select a mesh before uploading it")

        for key, entry in request_body.items():
            try:
                encoded_content, index = entry
                content = base64.b64decode(encoded_content)
            except (TypeError, ValueError):
                return self._upload_error(f"Malformed upload entry for {key}")
            if not isinstance(index, int):
                return self._upload_error(f"Invalid index {index!r} for {key}")

            target = f"{self.cache_path}/{key}"
            # The key comes from the client: refuse names that escape the cache directory
            if not self._is_inside_cache(target):
                return self._upload_error(f"Invalid file name {key}")

            loc = save_mesh_content(content, target)
            self.mesh_array[index] = LazyMesh(loc[0], loc[1])
            if (key, index) in self.mesh_missing:
                self.mesh_missing.remove((key, index))
            else:
                root_logger.warning(
                    f"Trame server running on {self.host}:{self.server.port}: uploaded mesh {key} at index {index} "
                    f"was not requested")

        # Only refresh the viewer once every requested mesh has been received
        if len(self.mesh_missing) == 0:
            self._update_viewer_for_new_meshes()
        return web.json_response({ServerMessageInterface.Keys.Success: "Mesh uploaded successfully"}, status=200)

    def _compute_field_interval(self, field: str = None) -> tuple[float, float]:
        """
        Compute the min and max of a field over all the frames of the sequence to get the upper and lower
        bound of the scalar bar.

        Frames whose mesh could not be loaded, frames missing the field and string arrays are skipped.

        Args:
            field (str): the field you want to compute the bounds of. Defaults to the current representation,
                or the first real option when no representation is selected.

        Returns:
            Tuple[float, float]: the (min, max) of the field; (inf, -inf) when no value was found
        """
        # If the field is None get the default field on which to compute the min and max
        if field is None:
            field = self.state.mesh_representation
        if field is None or field == "None":
            options = self.state.options
            if options is None or len(options) < 2:
                # No real field available ("None" is always the first option)
                return np.inf, -np.inf
            field = options[1]

        # Loop over all the frames and find the max of the array and the min
        max_bound = -np.inf
        min_bound = np.inf
        for i in range(self.sequence_bounds[1]):
            try:
                mesh = self.mesh_array[i]
            except IndexError:
                # Fewer meshes than announced frames: nothing left to scan
                break
            if mesh is None:
                # Mesh that could not be loaded (e.g. invalid link)
                continue
            try:
                arr = mesh.get_array(field)
            except KeyError:
                root_logger.error(
                    f"Trame server running on {self.host}:{self.server.port}: KeyError, field{field} does not exists")
                continue
            if len(arr) == 0 or isinstance(arr[0], str):
                continue
            l_max = arr.max()
            l_min = arr.min()
            if l_max > max_bound:
                max_bound = l_max
            if l_min < min_bound:
                min_bound = l_min
        return min_bound, max_bound

    def _computes_bounds_scalar(self):
        """
        Compute the bounds of all the scalars of the mesh and store them in `bounds_scalar` to avoid doing all
        the computation every time a bar is shown.
        """
        if self.state.options is None:
            return
        # Store bounds for all the fields available except "None", which is the first entry of the options array
        self.bounds_scalar = {field: self._compute_field_interval(field) for field in self.state.options[1:]}

    @abstractmethod
    def _replace_mesh_displayed(self, new_mesh: pv.DataSet):
        """
        Change the mesh displayed in the plotter and its related data

        Args:
            new_mesh (pv.DataSet): the new mesh to display
        """
        pass

    @abstractmethod
    def _clear_viewer(self):
        """
        Reset the viewer and its related attribute to an empty viewer
        """
        self.bounds_scalar = None
        self.state.mesh_representation = None

    @abstractmethod
    def _build_ui(self) -> VAppLayout:
        """
        Build all the ui frontend with all different components

        Returns:
            VAppLayout: a VAppLayout for the server
        """
        pass

    def on_server_bind(self, wslink_server):
        """
        When the server is bound, add the api endpoints to it

        Args:
            wslink_server: the socket manager of the server
        """
        wslink_server.app.add_routes(self.api_routes)

    def on_client_exited(self):
        """
        Handles the event when a client exits.

        Decreases the client counter and logs the new number of connected clients.
        """
        self.client_counter -= 1
        root_logger.debug(
            f"A client disconnected from Trame server {self.host}:{self.port}, there are {self.client_counter} "
            f"clients connected")

    def on_client_connected(self):
        """
        This method is called when a client connects to the Trame server.

        It increments the client counter and logs the connection details.
        """
        self.client_counter += 1
        root_logger.debug(
            f"A client connected to Trame server {self.host}:{self.port}, there are {self.client_counter} "
            f"clients connected")

    def on_server_exited(self, **kwargs):
        """
        Callback function called when the server has exited.
        """
        root_logger.debug(f"Trame server {self.host}:{self.port} has exited successfully")

    async def init_connection(self, request) -> web.Response:
        """
        Base api endpoint on '/init_connection' to inform the client of all the endpoints available and their
        locations.

        Args:
            request: the request made to this endpoint

        Returns:
            web.Response: a json with all information about endpoints required and a success status 200
        """
        response_body = {
            ServerMessageInterface.Keys.SelectMesh: EndpointsInterface.SelectMesh,
            ServerMessageInterface.Keys.UploadMesh: EndpointsInterface.UploadMesh,
            ServerMessageInterface.Keys.Host: f"{EndpointsInterface.Localhost}:{self.server.port}"
        }
        root_logger.debug(f"Trame server {self.host}:{self.port} initialized connection with a client")
        return web.json_response(response_body, status=200)

    async def start(self):
        """
        Starts the Trame server and waits for it to finish.
        """
        # Remember the loop so timer threads and signal handlers can schedule work on it
        self._loop = asyncio.get_running_loop()
        root_logger.info(f"Trame server running on {self.host}:{self.server.port}")
        await self.server.start(exec_mode="task")

    async def request_stop(self, force_stop: bool = False):
        """
        Stops the server if there are no active clients, otherwise schedules another stop request.

        If no client is connected (or `force_stop` is set), the server is stopped and `shutdown_event` is set.
        Otherwise a new request is scheduled on the event loop 2 minutes later. Calls made once the server has
        already been stopped do nothing.

        Args:
            force_stop (bool): Force the request to immediately stop the server
        """
        if self.shutdown_event.is_set():
            # Already stopped (e.g. killed through the API)
            return
        if self.client_counter <= 0 or force_stop:
            root_logger.debug(f"The Trame server {self.server.port} is about to stop")
            await self.server.stop()
            self.shutdown_event.set()
        else:
            self._schedule_stop_request(2 * ONE_MINUTE)

    def signal_handler(self, sig: int, frame: FrameType):
        """
        Handles the specified signal and initiates the shutdown process.

        The forced stop is scheduled thread-safely on the server event loop.

        Args:
            sig (int): The signal number.
            frame (FrameType): The current stack frame.
        """
        root_logger.info(f"Received signal {sig}. Shutting down...")
        loop = self._loop
        if loop is None or loop.is_closed():
            root_logger.warning(f"Trame server {self.host}:{self.port}: no event loop available to stop")
            return
        loop.call_soon_threadsafe(lambda: self._spawn_task(self.request_stop(force_stop=True)))

    async def run(self):
        """
        Runs the Trame server.

        This method sets up signal handlers for interrupt and termination signals, and then starts the Trame
        server.
        """
        # The loop must be known before a signal can be delivered
        self._loop = asyncio.get_running_loop()
        # Set up signal handlers
        for sig in (signal.SIGINT, signal.SIGTERM):
            signal.signal(sig, self.signal_handler)
        try:
            await self.start()
        finally:
            root_logger.info(f"Trame server on {self.host}:{self.server.port} stopped")