File size: 5,695 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
from __future__ import annotations
import datetime
import os
import threading
from collections import OrderedDict
from collections.abc import Iterator
from copy import copy, deepcopy
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from gradio.blocks import Blocks
from gradio.components import State
class StateHolder:
def __init__(self):
self.capacity = 10000
self.session_data: OrderedDict[str, SessionState] = OrderedDict()
self.time_last_used: dict[str, datetime.datetime] = {}
self.lock = threading.Lock()
def set_blocks(self, blocks: Blocks):
self.blocks = blocks
blocks.state_holder = self
self.capacity = blocks.state_session_capacity
def reset(self, blocks: Blocks):
"""Reset the state holder with new blocks. Used during reload mode."""
self.session_data = OrderedDict()
# Call set blocks again to set new ids
self.set_blocks(blocks)
def __getitem__(self, session_id: str) -> SessionState:
if session_id not in self.session_data:
self.session_data[session_id] = SessionState(self.blocks)
self.update(session_id)
self.time_last_used[session_id] = datetime.datetime.now()
return self.session_data[session_id]
def __contains__(self, session_id: str):
return session_id in self.session_data
def update(self, session_id: str):
with self.lock:
if session_id in self.session_data:
self.session_data.move_to_end(session_id)
if len(self.session_data) > self.capacity:
self.session_data.popitem(last=False)
def delete_all_expired_state(
self,
):
for session_id in self.session_data:
self.delete_state(session_id, expired_only=True)
def delete_state(self, session_id: str, expired_only: bool = False):
if session_id not in self.session_data:
return
to_delete = []
session_state = self.session_data[session_id]
for component, value, expired in session_state.state_components:
if not expired_only or expired:
component.delete_callback(value)
to_delete.append(component._id)
for component in to_delete:
del session_state.state_data[component]
class SessionState:
def __init__(self, blocks: Blocks):
self.blocks_config = copy(blocks.default_config)
# Keep a separate deep copy of the config so we can recreate
# the state for deep links
self.config_values = {
k: self.blocks_config.config_for_block(k, [], v)
for k, v in self.blocks_config.blocks.items()
if k in blocks.blocks
}
self.state_data: dict[int, Any] = {}
self._state_ttl = {}
self.is_closed = False
# When a session is closed, the state is stored for an hour to give the user time to reopen the session.
# During testing we set to a lower value to be able to test
self.STATE_TTL_WHEN_CLOSED = (
1 if os.getenv("GRADIO_IS_E2E_TEST", None) else 3600
)
def __getitem__(self, key: int) -> Any:
block = self.blocks_config.blocks[key]
if block.stateful:
if key not in self.state_data:
self.state_data[key] = deepcopy(getattr(block, "value", None))
return self.state_data[key]
else:
return block
def __setitem__(self, key: int, value: Any):
from gradio.components import State
block = self.blocks_config.blocks.get(key)
if isinstance(block, State):
self._state_ttl[key] = (
block.time_to_live,
datetime.datetime.now(),
)
self.state_data[key] = value
else:
self.blocks_config.blocks[key] = value
if block:
self.config_values[key] = self.blocks_config.config_for_block(
key, [], block
)
def _update_config(self, key: int):
if self[key] is not None:
self.config_values[key] = self.blocks_config.config_for_block(
key, [], self[key]
)
def _update_value_in_config(self, key: int, value: Any):
if key not in self.config_values:
self.config_values[key] = self.blocks_config.config_for_block(
key, [], self.blocks_config.blocks[key]
)
if "props" in self.config_values[key]:
self.config_values[key]["props"]["value"] = value
def __contains__(self, key: int):
block = self.blocks_config.blocks.get(key)
if block is None:
return False
if block.stateful:
return key in self.state_data
else:
return key in self.blocks_config.blocks
@property
def components(self) -> Iterator[dict]:
for _, config in self.config_values.items():
if config:
yield config
@property
def state_components(self) -> Iterator[tuple[State, Any, bool]]:
from gradio.components import State
for id in self.state_data:
block = self.blocks_config.blocks[id]
if isinstance(block, State) and id in self._state_ttl:
time_to_live, created_at = self._state_ttl[id]
if self.is_closed:
time_to_live = self.STATE_TTL_WHEN_CLOSED
value = self.state_data[id]
yield (
block,
value,
(datetime.datetime.now() - created_at).seconds > time_to_live,
)
|