Spaces:
Running
Running
File size: 4,369 Bytes
80fda88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import ast
from abc import ABC, abstractmethod
from app.config import config
from app.models import const
# Base class for state management
class BaseState(ABC):
    """Abstract interface for task-state storage backends.

    A backend keeps one record per ``task_id`` and exposes create/update,
    single lookup, paginated listing and deletion of those records.
    """

    @abstractmethod
    def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs):
        """Create or update the record for ``task_id``.

        Extra keyword arguments are stored as additional fields of the record.
        """

    @abstractmethod
    def get_task(self, task_id: str):
        """Return the record for ``task_id``, or ``None`` if there is none."""

    @abstractmethod
    def get_all_tasks(self, page: int, page_size: int):
        """Return ``(tasks, total)`` for the 1-based ``page`` of ``page_size`` items."""

    def delete_task(self, task_id: str):
        """Remove the record for ``task_id``.

        Deliberately not abstract: subclasses written before this method was
        declared on the interface must stay instantiable. Backends that
        support deletion override it; the default refuses loudly.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement delete_task"
        )
# Memory state management
class MemoryState(BaseState):
    """Task store kept in a plain dict on this instance.

    Records live only as long as this object and are not shared between
    processes.
    """

    def __init__(self):
        # task_id -> task record dict ("task_id", "state", "progress" + extras)
        self._tasks = {}

    def get_all_tasks(self, page: int, page_size: int):
        """Return ``(tasks, total)`` for the 1-based ``page`` of ``page_size`` items.

        ``total`` is the number of stored tasks; tasks are listed in insertion
        order. A ``page`` or ``page_size`` below 1 yields an empty page instead
        of negative slice indices that would wrap around to the end of the list.
        """
        tasks = list(self._tasks.values())
        total = len(tasks)
        if page < 1 or page_size < 1:
            return [], total
        start = (page - 1) * page_size
        return tasks[start:start + page_size], total

    def update_task(
        self,
        task_id: str,
        state: int = const.TASK_STATE_PROCESSING,
        progress: int = 0,
        **kwargs,
    ):
        """Create or update the record for ``task_id``.

        ``progress`` is converted with ``int()`` and clamped to 0..100.
        Fields stored by earlier calls and not passed again are kept, which
        matches RedisState, where per-field HSET never removes fields.
        """
        progress = min(max(int(progress), 0), 100)
        # Build a fresh dict rather than mutating in place, so records already
        # returned by get_task/get_all_tasks do not change under the caller.
        self._tasks[task_id] = {
            **self._tasks.get(task_id, {}),
            "task_id": task_id,
            "state": state,
            "progress": progress,
            **kwargs,
        }

    def get_task(self, task_id: str):
        """Return the record for ``task_id``, or ``None`` if it is unknown."""
        return self._tasks.get(task_id)

    def delete_task(self, task_id: str):
        """Remove the record for ``task_id``; unknown ids are ignored."""
        self._tasks.pop(task_id, None)
# Redis state management
class RedisState(BaseState):
    """Task store that keeps each task as a Redis hash whose key is the task_id.

    Every field is written as ``str(value)`` and parsed back on read by
    ``_convert_to_original_type``. Only values whose ``str()`` form is a
    Python literal keep their type across a round trip; for example, the
    string ``"123"`` is read back as the int ``123``.
    """

    def __init__(self, host="localhost", port=6379, db=0, password=None):
        # Imported here so the redis package is only required when this
        # backend is actually instantiated.
        import redis

        self._redis = redis.StrictRedis(
            host=host, port=port, db=db, password=password
        )

    def get_all_tasks(self, page: int, page_size: int):
        """Return ``(tasks, total)`` for the 1-based ``page`` of ``page_size`` items.

        NOTE(review): every key in the selected Redis DB is treated as a task
        hash (task keys carry no prefix). Confirm nothing else is stored in
        this DB.

        Keys are collected with a full SCAN, de-duplicated (SCAN may return a
        key more than once) and sorted. This keeps page boundaries stable
        between calls and makes ``total`` the real key count. A ``page`` or
        ``page_size`` below 1 yields an empty page.
        """
        keys = sorted(set(self._redis.scan_iter()))
        total = len(keys)
        if page < 1 or page_size < 1:
            return [], total
        start = (page - 1) * page_size
        tasks = []
        for key in keys[start:start + page_size]:
            task = self._decode_hash(self._redis.hgetall(key))
            # An empty reply means the key was deleted or expired after SCAN.
            if task:
                tasks.append(task)
        return tasks, total

    def update_task(
        self,
        task_id: str,
        state: int = const.TASK_STATE_PROCESSING,
        progress: int = 0,
        **kwargs,
    ):
        """Create or update the hash for ``task_id``.

        ``progress`` is converted with ``int()`` and clamped to 0..100.
        Existing fields not passed in this call are left untouched.
        """
        progress = min(max(int(progress), 0), 100)
        fields = {
            "task_id": task_id,
            "state": state,
            "progress": progress,
            **kwargs,
        }
        # Queue all HSETs in one transactional pipeline (MULTI/EXEC). That is
        # a single round trip, and readers never see a half-written update.
        with self._redis.pipeline() as pipe:
            for field, value in fields.items():
                pipe.hset(task_id, field, str(value))
            pipe.execute()

    def get_task(self, task_id: str):
        """Return the decoded hash for ``task_id``, or ``None`` if it is missing/empty."""
        task_data = self._redis.hgetall(task_id)
        if not task_data:
            return None
        return self._decode_hash(task_data)

    def delete_task(self, task_id: str):
        """Delete the hash stored under ``task_id`` (no-op if it does not exist)."""
        self._redis.delete(task_id)

    def _decode_hash(self, raw):
        """Turn an HGETALL reply (bytes field names/values) into a str-keyed dict."""
        return {
            key.decode("utf-8"): self._convert_to_original_type(value)
            for key, value in raw.items()
        }

    @staticmethod
    def _convert_to_original_type(value):
        """
        Convert a stored byte string back to a Python value.

        Python literals (ints, floats, lists, dicts, ...) are parsed with
        ``ast.literal_eval``, which evaluates literals only and never runs
        code. Decimal-digit strings that are not valid literals (e.g. with
        leading zeros such as ``"007"``) become ints. Anything else is
        returned as the decoded ``str``.
        """
        value_str = value.decode("utf-8")
        try:
            return ast.literal_eval(value_str)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # TypeError: e.g. "{[1]: 2}" (unhashable key);
            # MemoryError/RecursionError: pathologically nested input.
            pass
        # isdecimal (not isdigit): "²".isdigit() is True but int("²") raises.
        if value_str.isdecimal():
            return int(value_str)
        return value_str
# Global state: read the backend settings from the app config and build the
# shared store once, when this module is imported.
_enable_redis = config.app.get("enable_redis", False)
_redis_host = config.app.get("redis_host", "localhost")
_redis_port = config.app.get("redis_port", 6379)
_redis_db = config.app.get("redis_db", 0)
_redis_password = config.app.get("redis_password", None)

if _enable_redis:
    state = RedisState(
        host=_redis_host,
        port=_redis_port,
        db=_redis_db,
        password=_redis_password,
    )
else:
    state = MemoryState()
|