# base.py (2.92 KB)
from typing import NamedTuple, Optional
from urllib.parse import parse_qsl, unquote, urlparse

try:
    from collections.abc import Mapping
except ImportError:
    from collections import Mapping

from redis import ConnectionPool, StrictRedis


# Components of a parsed URL, as returned by ``url_to_parts``.
# ``scheme`` is always a string (possibly empty) and ``query`` is always a
# mapping.  Every other field is ``None`` when the URL omits that component,
# so those fields are annotated as Optional.
url_parts = NamedTuple('url_parts', [
    ('scheme', str),
    ('hostname', Optional[str]),
    ('port', Optional[int]),
    ('username', Optional[str]),
    ('password', Optional[str]),
    ('path', Optional[str]),  # path with its leading '/' removed
    ('query', Mapping),  # query-string parameters (name -> value)
])


def url_to_parts(url):
    # type: (str) -> url_parts
    """Parse ``url`` into a :class:`url_parts` tuple of components.

    The scheme is read from ``url`` itself.  Everything after it is parsed
    with HTTP URL semantics, so non-HTTP schemes (e.g. ``redis://``) still
    yield host, port, credentials and path.

    Hostname, username, password and path are percent-unquoted, and a single
    leading ``/`` is removed from the path.  Any of these that is absent or
    empty comes back as ``None``, as does a missing port.  ``query`` is
    always a dict, possibly empty.

    Raises:
        ValueError: if the port is not a valid integer.  This is raised by
            :mod:`urllib.parse` when the parsed result's ``port`` is read.
    """
    scheme = urlparse(url).scheme
    # Drop "<scheme>:" and then an optional "//" authority marker.  Slicing a
    # fixed ``len(scheme) + 3`` characters would eat the start of the host
    # (or path) whenever the URL does not actually contain "://".
    rest = url[len(scheme) + 1:] if scheme else url
    if rest.startswith('//'):
        rest = rest[2:]
    # parse with HTTP URL semantics
    parts = urlparse('http://' + rest)
    path = parts.path or ''
    path = path[1:] if path and path[0] == '/' else path
    return url_parts(
        scheme,
        unquote(parts.hostname or '') or None,
        parts.port,
        unquote(parts.username or '') or None,
        unquote(parts.password or '') or None,
        unquote(path or '') or None,
        dict(parse_qsl(parts.query)),
    )


class Redis:
    """Small facade over a ``StrictRedis`` client configured from a URL.

    Host, port, password and database index are read from ``url`` via
    ``url_to_parts``.  The URL's scheme, username and query string are not
    used.  The connection pool is created with ``decode_responses=True``, so
    replies come back as ``str`` rather than ``bytes``.
    """

    # Fallbacks for URL components that are omitted.  Passing ``None``
    # explicitly to the pool would override redis-py's own keyword defaults
    # instead of falling back to them.
    default_host = 'localhost'
    default_port = 6379
    default_db = 0

    def __init__(self, url, connection_pool=None, max_connections=20, socket_timeout=120,
                 retry_on_timeout=None, socket_connect_timeout=None):
        """Parse ``url`` and create the underlying client.

        Args:
            url: Connection URL, e.g. ``redis://:secret@localhost:6379/0``.
            connection_pool: Optional pool class (or factory) called with the
                connection parameters; defaults to ``redis.ConnectionPool``.
            max_connections: Forwarded to the pool as ``max_connections``.
            socket_timeout: Socket timeout in seconds; truthy values are
                converted to ``float``, falsy values are passed unchanged.
            retry_on_timeout: Forwarded as ``retry_on_timeout``; ``None``
                becomes ``False``.
            socket_connect_timeout: Connect timeout in seconds, converted to
                ``float`` when truthy.

        Raises:
            ValueError: if the URL's port or database is not an integer.
        """
        self._ConnectionPool = connection_pool
        # Only host, port, password and path (the db index) come from the URL.
        _scheme, host, port, _username, password, path, _query = url_to_parts(url)
        self.conn_params = {
            'host': host if host is not None else self.default_host,
            'port': port if port is not None else self.default_port,
            'db': self._parse_db(path),
            'password': password,
            'max_connections': max_connections,
            'socket_timeout': socket_timeout and float(socket_timeout),
            'retry_on_timeout': retry_on_timeout or False,
            'socket_connect_timeout':
                socket_connect_timeout and float(socket_connect_timeout),
            'decode_responses': True,
        }

        self.client = StrictRedis(
            connection_pool=self._get_pool(**self.conn_params),
        )

    @classmethod
    def _parse_db(cls, path):
        """Return the database index encoded in a URL path (``'2'`` -> 2).

        An absent or empty path (``None``, ``''`` or ``'/'``) selects
        ``default_db``.  Surrounding slashes are ignored.

        Raises:
            ValueError: if the path is not an integer.
        """
        db = (path or '').strip('/')
        if not db:
            return cls.default_db
        try:
            return int(db)
        except ValueError:
            raise ValueError(
                'Redis database in URL must be an integer, got {!r}'.format(path)
            ) from None

    @property
    def ConnectionPool(self):
        """Pool class used to build the client's pool.

        Uses the class given to ``__init__``; if none was given, falls back
        to ``redis.ConnectionPool`` on first access and caches it.
        """
        if self._ConnectionPool is None:
            self._ConnectionPool = ConnectionPool
        return self._ConnectionPool

    def _get_pool(self, **params):
        """Instantiate the connection pool with the given keyword params."""
        return self.ConnectionPool(**params)

    def get(self, key):
        """Return the value stored at ``key`` (``None`` if the key is missing)."""
        return self.client.get(key)

    def mget(self, keys):
        """Return the values for ``keys``, in order."""
        return self.client.mget(keys)

    def set(self, key, value, expires=None):
        """Store ``value`` at ``key``.

        When ``expires`` is truthy the key is written with SETEX and expires
        after that many seconds.  Otherwise a plain SET is used and the key
        does not expire.
        """
        if expires:
            return self.client.setex(key, expires, value)
        else:
            return self.client.set(key, value)

    def delete(self, key):
        """Delete ``key`` and return the client's result (number of keys removed)."""
        return self.client.delete(key)

    def incr(self, key):
        """Atomically increment the integer at ``key`` and return the new value."""
        return self.client.incr(key)

    def expire(self, key, value):
        """Set a time-to-live of ``value`` seconds on ``key``."""
        return self.client.expire(key, value)

    def lpush(self, key, values):
        """Prepend each item of ``values`` to the list at ``key``.

        ``values`` must be a non-empty iterable, because Redis rejects LPUSH
        with no elements.  Returns the list length after the push (int).
        """
        return self.client.lpush(key, *values)  # int

    def lrange(self, key, start, end):
        """Return the list elements from ``start`` to ``end`` (inclusive)."""
        return self.client.lrange(key, start, end)  # list

    def rpop(self, key):
        """Remove and return the last element of the list at ``key``, or ``None``."""
        return self.client.rpop(key)  # str or None