import json

from peewee import *
from peewee import Expression
from peewee import Node
from peewee import NodeList
from playhouse.postgres_ext import ArrayField
from playhouse.postgres_ext import DateTimeTZField
from playhouse.postgres_ext import IndexedFieldMixin
from playhouse.postgres_ext import IntervalField
from playhouse.postgres_ext import Match
from playhouse.postgres_ext import TSVectorField

# Helpers needed for psycopg3-specific overrides.
from playhouse.postgres_ext import _JsonLookupBase

try:
    import psycopg
    from psycopg.types.json import Jsonb
    from psycopg.pq import TransactionStatus
except ImportError:
    # psycopg3 is optional at import time; Psycopg3Database._connect() raises
    # ImproperlyConfigured if it is missing. Every name imported above is
    # given a placeholder so none of them is left undefined.
    psycopg = Jsonb = TransactionStatus = None


# Operator tokens used when building JSONB expressions below.
JSONB_CONTAINS = '@>'
JSONB_CONTAINED_BY = '<@'
JSONB_CONTAINS_KEY = '?'
JSONB_CONTAINS_ANY_KEY = '?|'
JSONB_CONTAINS_ALL_KEYS = '?&'
JSONB_EXISTS = '?'
JSONB_REMOVE = '-'


class _Psycopg3JsonLookupBase(_JsonLookupBase):
    """JSON lookup base that wraps Python values with psycopg3's ``Jsonb``.

    NOTE(review): ``as_json()`` is inherited from ``_JsonLookupBase`` and is
    presumably what makes the lookup render as JSON rather than text --
    confirm against playhouse.postgres_ext.
    """

    def concat(self, rhs):
        """Return an expression concatenating ``rhs`` onto this lookup.

        Plain Python values are wrapped in ``Jsonb``; peewee nodes are passed
        through unchanged.
        """
        if not isinstance(rhs, Node):
            rhs = Jsonb(rhs)  # Note: uses psycopg3's Jsonb.
        return Expression(self.as_json(True), OP.CONCAT, rhs)

    def contains(self, other):
        """Return a containment (``@>``) or existence (``?``) expression.

        Lists and dicts are compared with ``@>`` as JSONB documents; any other
        value is tested with the ``?`` operator.
        """
        clone = self.as_json(True)
        if isinstance(other, (list, dict)):
            return Expression(clone, JSONB_CONTAINS, Jsonb(other))  # Same.
        return Expression(clone, JSONB_EXISTS, other)


class JsonLookup(_Psycopg3JsonLookupBase):
    """Chained item lookup rendered with the ``->`` / ``->>`` operators."""

    def __getitem__(self, value):
        """Return a new lookup with ``value`` appended to the lookup parts."""
        return JsonLookup(self.node, self.parts + [value], self._as_json)

    def __sql__(self, ctx):
        """Render the node followed by one operator per lookup part."""
        ctx.sql(self.node)
        # Every intermediate step uses '->'.
        for part in self.parts[:-1]:
            ctx.literal('->').sql(part)
        # The final step uses '->' when _as_json is set, '->>' otherwise.
        if self.parts:
            (ctx
             .literal('->' if self._as_json else '->>')
             .sql(self.parts[-1]))

        return ctx


class JsonPath(_Psycopg3JsonLookupBase):
    """Path lookup rendered with the ``#>`` / ``#>>`` operators."""

    def __sql__(self, ctx):
        """Render the node, the path operator and a ``{a,b,c}`` path value."""
        return (ctx
                .sql(self.node)
                .literal('#>' if self._as_json else '#>>')
                .sql(Value('{%s}' % ','.join(map(str, self.parts)))))


def cast_jsonb(node):
    """Return ``node`` with ``::jsonb`` appended directly after it."""
    return NodeList((node, SQL('::jsonb')), glue='')


class BinaryJSONField(IndexedFieldMixin, Field):
    """JSONB column field using psycopg3's ``Jsonb`` wrapper for parameters."""

    field_type = 'JSONB'
    _json_datatype = 'jsonb'
    __hash__ = Field.__hash__

    def __init__(self, dumps=None, *args, **kwargs):
        """Create the field.

        ``dumps`` is the serializer applied in ``db_value``; it defaults to
        ``json.dumps``. Remaining arguments are forwarded to ``Field``.
        """
        self.dumps = dumps or json.dumps
        super(BinaryJSONField, self).__init__(*args, **kwargs)

    def db_value(self, value):
        """Convert ``value`` for the database.

        ``None`` and ``Jsonb`` instances are returned unchanged; anything else
        is serialized with ``self.dumps`` and wrapped in a cast to ``jsonb``.
        """
        if value is None:
            return value
        if not isinstance(value, Jsonb):
            return Cast(self.dumps(value), self._json_datatype)
        return value

    def __getitem__(self, value):
        """Return a ``JsonLookup`` for ``value`` on this field."""
        return JsonLookup(self, [value])

    def path(self, *keys):
        """Return a ``JsonPath`` lookup for the given sequence of keys."""
        return JsonPath(self, keys)

    def concat(self, value):
        """Return a concatenation expression of this field and ``value``.

        Plain Python values are wrapped in ``Jsonb``; peewee nodes are passed
        through unchanged.
        """
        if not isinstance(value, Node):
            value = Jsonb(value)
        return super(BinaryJSONField, self).concat(value)

    def contains(self, other):
        """Return a containment (``@>``) or existence (``?``) expression.

        Lists/dicts (wrapped in ``Jsonb``) and other ``BinaryJSONField``
        instances use ``@>``; any other value is tested with ``?`` against
        this field cast to ``jsonb``.
        """
        if isinstance(other, (list, dict)):
            return Expression(self, JSONB_CONTAINS, Jsonb(other))
        elif isinstance(other, BinaryJSONField):
            return Expression(self, JSONB_CONTAINS, other)
        return Expression(cast_jsonb(self), JSONB_EXISTS, other)

    def contained_by(self, other):
        """Return a ``<@`` expression against ``other``.

        Plain Python values are wrapped in ``Jsonb``. Peewee nodes (another
        column or expression) are passed through unchanged, consistent with
        ``concat``; wrapping a node in ``Jsonb`` would turn it into a bound
        parameter that cannot be JSON-serialized.
        """
        if not isinstance(other, Node):
            other = Jsonb(other)
        return Expression(cast_jsonb(self), JSONB_CONTAINED_BY, other)

    def contains_any(self, *items):
        """Return a ``?|`` expression; ``items`` is bound as one list value."""
        return Expression(
            cast_jsonb(self),
            JSONB_CONTAINS_ANY_KEY,
            Value(list(items), unpack=False))

    def contains_all(self, *items):
        """Return a ``?&`` expression; ``items`` is bound as one list value."""
        return Expression(
            cast_jsonb(self),
            JSONB_CONTAINS_ALL_KEYS,
            Value(list(items), unpack=False))

    def has_key(self, key):
        """Return a ``?`` expression testing for ``key``."""
        return Expression(cast_jsonb(self), JSONB_CONTAINS_KEY, key)

    def remove(self, *items):
        """Return a ``-`` expression removing ``items`` from this field."""
        return Expression(
            cast_jsonb(self),
            JSONB_REMOVE,
            # Hack: psycopg3 parameterizes this as an array, e.g. '{k1,k2}',
            # but that doesn't seem to be working, so we explicitly cast.
            # Perhaps postgres is interpreting it as a string. Using the more
            # explicit ARRAY['k1','k2'] also works just fine -- but we'll make
            # the cast explicit to get it working.
            Cast(Value(list(items), unpack=False), 'text[]'))


class Psycopg3Database(PostgresqlDatabase):
    """``PostgresqlDatabase`` variant that connects through psycopg3."""

    def _connect(self):
        """Open and return a psycopg3 connection in autocommit mode.

        Raises ``ImproperlyConfigured`` when psycopg3 could not be imported.
        """
        if psycopg is None:
            raise ImproperlyConfigured('psycopg3 is not installed!')
        conn = psycopg.connect(dbname=self.database, **self.connect_params)
        if self._isolation_level is not None:
            conn.isolation_level = self._isolation_level
        conn.autocommit = True
        return conn

    def get_binary_type(self):
        """Return psycopg3's ``Binary`` wrapper."""
        return psycopg.Binary

    def _set_server_version(self, conn):
        """Record the server version reported by the connection.

        ``safe_create_index`` is switched on for versions >= 90600.
        """
        self.server_version = conn.pgconn.server_version
        if self.server_version >= 90600:
            self.safe_create_index = True

    def is_connection_usable(self):
        """Return True if the current connection is open and not in error."""
        if self._state.closed:
            return False

        # Returns True if we are idle, running a command, or in an active
        # transaction. If the connection is in an error state or the
        # connection is otherwise unusable, return False.
        conn = self._state.conn
        return conn.pgconn.transaction_status < TransactionStatus.INERROR

    def extract_date(self, date_part, date_field):
        """Return an ``EXTRACT(<date_part> FROM <date_field>)`` call.

        NOTE: ``date_part`` is interpolated as raw SQL, not bound as a
        parameter, so it must never come from untrusted input.
        """
        return fn.EXTRACT(NodeList((SQL(date_part), SQL('FROM'), date_field)))