forked from kennethreitz/records
-
Notifications
You must be signed in to change notification settings - Fork 0
/
records.py
137 lines (102 loc) · 4.08 KB
/
records.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# -*- coding: utf-8 -*-
import os
from datetime import datetime
import tablib
import psycopg2
from psycopg2.extras import register_hstore, RealDictCursor
# Default connection string, taken from the environment at import time.
DATABASE_URL = os.environ.get('DATABASE_URL')
# Lists user tables only — system schemas are excluded.
PG_TABLES_QUERY = "SELECT * FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
# Lists every table, including PostgreSQL internal ones.
PG_INTERNAL_TABLES_QUERY = "SELECT * FROM pg_catalog.pg_tables"
class ResultSet(object):
    """A set of results from a query.

    Wraps a row iterator (e.g. a database cursor) and caches rows as they
    are consumed, so results can safely be iterated more than once even
    though the underlying iterator is consumed only once.
    """

    def __init__(self, rows):
        # Underlying row iterator; consumed at most once.
        self._rows = iter(rows)
        # Rows fetched so far, in order.
        self._all_rows = []
        # True once the underlying iterator is exhausted.
        self._completed = False

    def __repr__(self):
        return '<ResultSet {:o}>'.format(id(self))

    def __iter__(self):
        """Yield rows in order, replaying cached rows first.

        Fixes the original behavior where a second (or interrupted)
        iteration skipped rows that had already been consumed.
        """
        i = 0
        while True:
            if i < len(self._all_rows):
                # Replay a row fetched by an earlier iteration/next() call.
                yield self._all_rows[i]
            else:
                # Pull a fresh row from the underlying iterator, cache it.
                try:
                    row = next(self._rows)
                except StopIteration:
                    self._completed = True
                    return
                self._all_rows.append(row)
                yield row
            i += 1

    def next(self):
        """Return the next unseen row.

        Raises StopIteration when the underlying iterator is exhausted.
        Uses the builtin next() (the original called the Python-2-only
        ``.next()`` method, which fails on Python 3 iterators).
        """
        try:
            row = next(self._rows)
        except StopIteration:
            self._completed = True
            raise StopIteration("ResultSet contains no more rows.")
        # Cache the row so all()/__iter__ still see it.
        self._all_rows.append(row)
        return row

    # Python 3 iterator protocol.
    __next__ = next

    @property
    def dataset(self):
        """A Tablib Dataset representation of the ResultSet."""
        data = tablib.Dataset()
        rows = self.all()
        # Guard: an empty result set has no first row to take headers from.
        if not rows:
            return data
        # Column names come from the first row's keys (RealDictCursor rows).
        data.headers = rows[0].keys()
        # String-ify datetimes so the Dataset serializes cleanly.
        for row in rows:
            data.append(_reduce_datetimes([v for k, v in row.items()]))
        return data

    def all(self):
        """Returns a list of all rows for the ResultSet. If they haven't
        been fetched yet, consume the iterator and cache the results."""
        # Drain whatever remains of the iterator into the cache. Checking
        # _completed (not cache truthiness) keeps partially-iterated sets
        # from returning a truncated list.
        if not self._completed:
            self._all_rows.extend(self._rows)
            self._completed = True
        return self._all_rows
class Database(object):
    """A Database connection."""

    def __init__(self, db_url=None):
        # Fall back to $DATABASE_URL when no explicit URL was given.
        self.db_url = db_url or DATABASE_URL
        if not self.db_url:
            raise ValueError('You must provide a db_url.')
        # Open the connection; RealDictCursor makes rows dict-like.
        self.db = psycopg2.connect(self.db_url, cursor_factory=RealDictCursor)
        # Register hstore support when the extension is installed.
        self._enable_hstore()

    def _enable_hstore(self):
        # Registration raises ProgrammingError on databases without the
        # hstore extension; best-effort, so that case is ignored.
        try:
            register_hstore(self.db)
        except psycopg2.ProgrammingError:
            pass

    def get_table_names(self, internal=False):
        """Returns a list of table names for the connected database."""
        # Optionally include PostgreSQL's internal tables.
        if internal:
            table_query = PG_INTERNAL_TABLES_QUERY
        else:
            table_query = PG_TABLES_QUERY
        return [row['tablename'] for row in self.query(table_query)]

    def query(self, query, params=None, fetchall=False):
        """Executes the given SQL query against the Database. Parameters
        can, optionally, be provided. Returns a ResultSet, which can be
        iterated over to get result rows as dictionaries.
        """
        # Run the query on a fresh cursor.
        cursor = self.db.cursor()
        cursor.execute(query, params)
        # Hand the cursor's row iterator to a lazy ResultSet.
        results = ResultSet(iter(cursor))
        # Eagerly consume the cursor when requested.
        if fetchall:
            results.all()
        return results

    def query_file(self, path, params=None, fetchall=False):
        """Like Database.query, but takes a filename to load a query from."""
        # Slurp the .sql file, then defer to the query method.
        with open(path) as f:
            sql = f.read()
        return self.query(query=sql, params=params, fetchall=fetchall)
def _reduce_datetimes(row):
"""Receives a row, converts datetimes to strings."""
for i in range(len(row)):
if isinstance(row[i], datetime):
row[i] = '{}'.format(row[i])
return row