Skip to content

Commit

Permalink
Issue 2917: Merge the pickle and cPickle module.
Browse files Browse the repository at this point in the history
  • Loading branch information
avassalotti committed Jun 11, 2008
1 parent 1e637b7 commit cc31306
Show file tree
Hide file tree
Showing 8 changed files with 4,685 additions and 126 deletions.
157 changes: 74 additions & 83 deletions Lib/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,37 +174,38 @@ def __init__(self, value):

# Pickling machinery

class Pickler:
class _Pickler:

def __init__(self, file, protocol=None):
"""This takes a binary file for writing a pickle data stream.
All protocols now read and write bytes.
The optional protocol argument tells the pickler to use the
given protocol; supported protocols are 0, 1, 2. The default
protocol is 2; it's been supported for many years now.
Protocol 1 is more efficient than protocol 0; protocol 2 is
more efficient than protocol 1.
given protocol; supported protocols are 0, 1, 2, 3. The default
protocol is 3; a backward-incompatible protocol designed for
Python 3.0.
Specifying a negative protocol version selects the highest
protocol version supported. The higher the protocol used, the
more recent the version of Python needed to read the pickle
produced.
The file parameter must have a write() method that accepts a single
string argument. It can thus be an open file object, a StringIO
object, or any other custom object that meets this interface.
The file argument must have a write() method that accepts a single
bytes argument. It can thus be a file object opened for binary
writing, an io.BytesIO instance, or any other custom object that
meets this interface.
"""
if protocol is None:
protocol = DEFAULT_PROTOCOL
if protocol < 0:
protocol = HIGHEST_PROTOCOL
elif not 0 <= protocol <= HIGHEST_PROTOCOL:
raise ValueError("pickle protocol must be <= %d" % HIGHEST_PROTOCOL)
self.write = file.write
try:
self.write = file.write
except AttributeError:
raise TypeError("file must have a 'write' attribute")
self.memo = {}
self.proto = int(protocol)
self.bin = protocol >= 1
Expand Down Expand Up @@ -270,10 +271,10 @@ def get(self, i, pack=struct.pack):

return GET + repr(i).encode("ascii") + b'\n'

def save(self, obj):
def save(self, obj, save_persistent_id=True):
# Check for persistent id (defined by a subclass)
pid = self.persistent_id(obj)
if pid:
if pid is not None and save_persistent_id:
self.save_pers(pid)
return

Expand Down Expand Up @@ -341,7 +342,7 @@ def persistent_id(self, obj):
def save_pers(self, pid):
# Save a persistent id reference
if self.bin:
self.save(pid)
self.save(pid, save_persistent_id=False)
self.write(BINPERSID)
else:
self.write(PERSID + str(pid).encode("ascii") + b'\n')
Expand All @@ -350,13 +351,13 @@ def save_reduce(self, func, args, state=None,
listitems=None, dictitems=None, obj=None):
# This API is called by some subclasses

# Assert that args is a tuple or None
# Assert that args is a tuple
if not isinstance(args, tuple):
raise PicklingError("args from reduce() should be a tuple")
raise PicklingError("args from save_reduce() should be a tuple")

# Assert that func is callable
if not hasattr(func, '__call__'):
raise PicklingError("func from reduce should be callable")
raise PicklingError("func from save_reduce() should be callable")

save = self.save
write = self.write
Expand Down Expand Up @@ -438,31 +439,6 @@ def save_bool(self, obj):
self.write(obj and TRUE or FALSE)
dispatch[bool] = save_bool

def save_int(self, obj, pack=struct.pack):
if self.bin:
# If the int is small enough to fit in a signed 4-byte 2's-comp
# format, we can store it more efficiently than the general
# case.
# First one- and two-byte unsigned ints:
if obj >= 0:
if obj <= 0xff:
self.write(BININT1 + bytes([obj]))
return
if obj <= 0xffff:
self.write(BININT2 + bytes([obj&0xff, obj>>8]))
return
# Next check for 4-byte signed ints:
high_bits = obj >> 31 # note that Python shift sign-extends
if high_bits == 0 or high_bits == -1:
# All high bits are copies of bit 2**31, so the value
# fits in a 4-byte signed int.
self.write(BININT + pack("<i", obj))
return
# Text pickle, or int too big to fit in signed 4-byte format.
self.write(INT + repr(obj).encode("ascii") + b'\n')
# XXX save_int is merged into save_long
# dispatch[int] = save_int

def save_long(self, obj, pack=struct.pack):
if self.bin:
# If the int is small enough to fit in a signed 4-byte 2's-comp
Expand Down Expand Up @@ -503,7 +479,7 @@ def save_float(self, obj, pack=struct.pack):

def save_bytes(self, obj, pack=struct.pack):
if self.proto < 3:
self.save_reduce(bytes, (list(obj),))
self.save_reduce(bytes, (list(obj),), obj=obj)
return
n = len(obj)
if n < 256:
Expand Down Expand Up @@ -579,12 +555,6 @@ def save_tuple(self, obj):

dispatch[tuple] = save_tuple

# save_empty_tuple() isn't used by anything in Python 2.3. However, I
# found a Pickler subclass in Zope3 that calls it, so it's not harmless
# to remove it.
def save_empty_tuple(self, obj):
self.write(EMPTY_TUPLE)

def save_list(self, obj):
write = self.write

Expand Down Expand Up @@ -696,7 +666,7 @@ def save_global(self, obj, name=None, pack=struct.pack):
module = whichmodule(obj, name)

try:
__import__(module)
__import__(module, level=0)
mod = sys.modules[module]
klass = getattr(mod, name)
except (ImportError, KeyError, AttributeError):
Expand All @@ -720,9 +690,19 @@ def save_global(self, obj, name=None, pack=struct.pack):
else:
write(EXT4 + pack("<i", code))
return
# Non-ASCII identifiers are supported only with protocols >= 3.
if self.proto >= 3:
write(GLOBAL + bytes(module, "utf-8") + b'\n' +
bytes(name, "utf-8") + b'\n')
else:
try:
write(GLOBAL + bytes(module, "ascii") + b'\n' +
bytes(name, "ascii") + b'\n')
except UnicodeEncodeError:
raise PicklingError(
"can't pickle global identifier '%s.%s' using "
"pickle protocol %i" % (module, name, self.proto))

write(GLOBAL + bytes(module, "utf-8") + b'\n' +
bytes(name, "utf-8") + b'\n')
self.memoize(obj)

dispatch[FunctionType] = save_global
Expand Down Expand Up @@ -781,7 +761,7 @@ def whichmodule(func, funcname):

# Unpickling machinery

class Unpickler:
class _Unpickler:

def __init__(self, file, *, encoding="ASCII", errors="strict"):
"""This takes a binary file for reading a pickle data stream.
Expand Down Expand Up @@ -841,6 +821,9 @@ def marker(self):
while stack[k] is not mark: k = k-1
return k

def persistent_load(self, pid):
    """Resolve a persistent id read from the pickle stream.

    load_persid() calls this with the id it decoded from the stream and
    pushes whatever is returned onto the unpickling stack.  This base
    implementation does not know how to resolve any id, so it always
    raises; a subclass that wants to support persistent ids has to
    provide its own persistent_load().

    Raises:
        UnpicklingError: always, for any pid.
    """
    # The exception name was previously misspelled ("UnpickingError"),
    # which made this fail with NameError instead of the intended
    # pickle-specific error.
    raise UnpicklingError("unsupported persistent id encountered")

dispatch = {}

def load_proto(self):
Expand All @@ -850,7 +833,7 @@ def load_proto(self):
dispatch[PROTO[0]] = load_proto

def load_persid(self):
pid = self.readline()[:-1]
pid = self.readline()[:-1].decode("ascii")
self.append(self.persistent_load(pid))
dispatch[PERSID[0]] = load_persid

Expand Down Expand Up @@ -879,9 +862,9 @@ def load_int(self):
val = True
else:
try:
val = int(data)
val = int(data, 0)
except ValueError:
val = int(data)
val = int(data, 0)
self.append(val)
dispatch[INT[0]] = load_int

Expand Down Expand Up @@ -933,7 +916,8 @@ def load_string(self):
break
else:
raise ValueError("insecure string pickle: %r" % orig)
self.append(codecs.escape_decode(rep)[0])
self.append(codecs.escape_decode(rep)[0]
.decode(self.encoding, self.errors))
dispatch[STRING[0]] = load_string

def load_binstring(self):
Expand Down Expand Up @@ -975,7 +959,7 @@ def load_tuple(self):
dispatch[TUPLE[0]] = load_tuple

def load_empty_tuple(self):
self.stack.append(())
self.append(())
dispatch[EMPTY_TUPLE[0]] = load_empty_tuple

def load_tuple1(self):
Expand All @@ -991,11 +975,11 @@ def load_tuple3(self):
dispatch[TUPLE3[0]] = load_tuple3

def load_empty_list(self):
self.stack.append([])
self.append([])
dispatch[EMPTY_LIST[0]] = load_empty_list

def load_empty_dictionary(self):
self.stack.append({})
self.append({})
dispatch[EMPTY_DICT[0]] = load_empty_dictionary

def load_list(self):
Expand All @@ -1022,13 +1006,13 @@ def load_dict(self):
def _instantiate(self, klass, k):
args = tuple(self.stack[k+1:])
del self.stack[k:]
instantiated = 0
instantiated = False
if (not args and
isinstance(klass, type) and
not hasattr(klass, "__getinitargs__")):
value = _EmptyClass()
value.__class__ = klass
instantiated = 1
instantiated = True
if not instantiated:
try:
value = klass(*args)
Expand All @@ -1038,8 +1022,8 @@ def _instantiate(self, klass, k):
self.append(value)

def load_inst(self):
module = self.readline()[:-1]
name = self.readline()[:-1]
module = self.readline()[:-1].decode("ascii")
name = self.readline()[:-1].decode("ascii")
klass = self.find_class(module, name)
self._instantiate(klass, self.marker())
dispatch[INST[0]] = load_inst
Expand All @@ -1059,8 +1043,8 @@ def load_newobj(self):
dispatch[NEWOBJ[0]] = load_newobj

def load_global(self):
module = self.readline()[:-1]
name = self.readline()[:-1]
module = self.readline()[:-1].decode("utf-8")
name = self.readline()[:-1].decode("utf-8")
klass = self.find_class(module, name)
self.append(klass)
dispatch[GLOBAL[0]] = load_global
Expand Down Expand Up @@ -1095,11 +1079,7 @@ def get_extension(self, code):

def find_class(self, module, name):
# Subclasses may override this
if isinstance(module, bytes_types):
module = module.decode("utf-8")
if isinstance(name, bytes_types):
name = name.decode("utf-8")
__import__(module)
__import__(module, level=0)
mod = sys.modules[module]
klass = getattr(mod, name)
return klass
Expand Down Expand Up @@ -1131,31 +1111,33 @@ def load_dup(self):
dispatch[DUP[0]] = load_dup

def load_get(self):
self.append(self.memo[self.readline()[:-1].decode("ascii")])
i = int(self.readline()[:-1])
self.append(self.memo[i])
dispatch[GET[0]] = load_get

def load_binget(self):
i = ord(self.read(1))
self.append(self.memo[repr(i)])
i = self.read(1)[0]
self.append(self.memo[i])
dispatch[BINGET[0]] = load_binget

def load_long_binget(self):
i = mloads(b'i' + self.read(4))
self.append(self.memo[repr(i)])
self.append(self.memo[i])
dispatch[LONG_BINGET[0]] = load_long_binget

def load_put(self):
self.memo[self.readline()[:-1].decode("ascii")] = self.stack[-1]
i = int(self.readline()[:-1])
self.memo[i] = self.stack[-1]
dispatch[PUT[0]] = load_put

def load_binput(self):
i = ord(self.read(1))
self.memo[repr(i)] = self.stack[-1]
i = self.read(1)[0]
self.memo[i] = self.stack[-1]
dispatch[BINPUT[0]] = load_binput

def load_long_binput(self):
i = mloads(b'i' + self.read(4))
self.memo[repr(i)] = self.stack[-1]
self.memo[i] = self.stack[-1]
dispatch[LONG_BINPUT[0]] = load_long_binput

def load_append(self):
Expand Down Expand Up @@ -1321,6 +1303,15 @@ def decode_long(data):
n -= 1 << (nbytes * 8)
return n

# Use the faster _pickle if possible
try:
from _pickle import *
except ImportError:
Pickler, Unpickler = _Pickler, _Unpickler
PickleError = _PickleError
PicklingError = _PicklingError
UnpicklingError = _UnpicklingError

# Shorthands

def dump(obj, file, protocol=None):
Expand All @@ -1333,14 +1324,14 @@ def dumps(obj, protocol=None):
assert isinstance(res, bytes_types)
return res

def load(file):
return Unpickler(file).load()
def load(file, *, encoding="ASCII", errors="strict"):
return Unpickler(file, encoding=encoding, errors=errors).load()

def loads(s):
def loads(s, *, encoding="ASCII", errors="strict"):
if isinstance(s, str):
raise TypeError("Can't load pickle from unicode string")
file = io.BytesIO(s)
return Unpickler(file).load()
return Unpickler(file, encoding=encoding, errors=errors).load()

# Doctest

Expand Down
Loading

0 comments on commit cc31306

Please sign in to comment.