// Copyright (c) 2011 Daniel Grunwald
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of this
// software and associated documentation files (the "Software"), to deal in the Software
// without restriction, including without limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons
// to whom the Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all copies or
// substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Reflection;
using System.Reflection.Emit;
using System.Runtime.Serialization;
namespace ICSharpCode.NRefactory.Utils
{
public class FastSerializer
{
#region Properties
///
/// Gets/Sets the serialization binder that is being used.
/// The default value is null, which will cause the FastSerializer to use the
/// full assembly and type names.
///
public SerializationBinder SerializationBinder { get; set; }
///
/// Can be used to set several 'fixed' instances.
/// When serializing, such instances will not be included; and any references to a fixed instance
/// will be stored as the index in this array.
/// When deserializing, the same (or equivalent) instances must be specified, and the deserializer
/// will use them in place of the fixed instances.
///
public object[] FixedInstances { get; set; }
#endregion
#region Constants
const int magic = 0x71D28A5E;
const byte Type_ReferenceType = 1;
const byte Type_ValueType = 2;
const byte Type_SZArray = 3;
const byte Type_ParameterizedType = 4;
#endregion
#region Serialization
sealed class SerializationType
{
public readonly int ID;
public readonly Type Type;
public SerializationType(int iD, Type type)
{
this.ID = iD;
this.Type = type;
}
public ObjectScanner Scanner;
public ObjectWriter Writer;
public string TypeName;
public int AssemblyNameID;
}
sealed class SerializationContext
{
readonly Dictionary objectToID = new Dictionary(ReferenceComparer.Instance);
readonly List instances = new List(); // index: object ID
readonly List objectTypes = new List(); // index: object ID
SerializationType stringType;
readonly Dictionary typeMap = new Dictionary();
readonly List types = new List();
readonly Dictionary assemblyNameToID = new Dictionary();
readonly List assemblyNames = new List();
readonly FastSerializer fastSerializer;
public readonly BinaryWriter writer;
int fixedInstanceCount;
internal SerializationContext(FastSerializer fastSerializer, BinaryWriter writer)
{
this.fastSerializer = fastSerializer;
this.writer = writer;
instances.Add(null); // use object ID 0 for null
objectTypes.Add(null);
}
#region Scanning
public void MarkFixedInstances(object[] fixedInstances)
{
if (fixedInstances == null)
return;
foreach (object obj in fixedInstances) {
if (!objectToID.ContainsKey(obj)) {
objectToID.Add(obj, instances.Count);
instances.Add(obj);
fixedInstanceCount++;
}
}
}
///
/// Marks an instance for future scanning.
///
public void Mark(object instance)
{
if (instance == null || objectToID.ContainsKey(instance))
return;
Log(" Mark {0}", instance.GetType().Name);
objectToID.Add(instance, instances.Count);
instances.Add(instance);
}
internal void Scan()
{
Log("Scanning...");
// starting from 1, because index 0 is null
// Also, do not scan any of the 'fixed instances'.
for (int i = 1 + fixedInstanceCount; i < instances.Count; i++) {
object instance = instances[i];
ISerializable serializable = instance as ISerializable;
Type type = instance.GetType();
Log("Scan #{0}: {1}", i, type.Name);
SerializationType sType = MarkType(type);
objectTypes.Add(sType);
if (serializable != null) {
SerializationInfo info = new SerializationInfo(type, fastSerializer.formatterConverter);
serializable.GetObjectData(info, fastSerializer.streamingContext);
instances[i] = info;
foreach (SerializationEntry entry in info) {
Mark(entry.Value);
}
sType.Writer = serializationInfoWriter;
} else {
ObjectScanner objectScanner = sType.Scanner;
if (objectScanner == null) {
objectScanner = fastSerializer.GetScanner(type);
sType.Scanner = objectScanner;
sType.Writer = fastSerializer.GetWriter(type);
}
objectScanner(this, instance);
}
}
}
#endregion
#region Scan Types
SerializationType MarkType(Type type)
{
SerializationType sType;
if (!typeMap.TryGetValue(type, out sType)) {
string assemblyName = null;
string typeName = null;
if (type.HasElementType) {
Debug.Assert(type.IsArray);
MarkType(type.GetElementType());
} else if (type.IsGenericType && !type.IsGenericTypeDefinition) {
MarkType(type.GetGenericTypeDefinition());
foreach (Type typeArg in type.GetGenericArguments())
MarkType(typeArg);
} else if (type.IsGenericParameter) {
throw new NotSupportedException();
} else {
var serializationBinder = fastSerializer.SerializationBinder;
if (serializationBinder != null) {
serializationBinder.BindToName(type, out assemblyName, out typeName);
} else {
assemblyName = type.Assembly.FullName;
typeName = type.FullName;
Debug.Assert(typeName != null);
}
}
sType = new SerializationType(typeMap.Count, type);
sType.TypeName = typeName;
if (assemblyName != null) {
if (!assemblyNameToID.TryGetValue(assemblyName, out sType.AssemblyNameID)) {
sType.AssemblyNameID = assemblyNames.Count;
assemblyNameToID.Add(assemblyName, sType.AssemblyNameID);
assemblyNames.Add(assemblyName);
Log("Registered assembly #{0}: {1}", sType.AssemblyNameID, assemblyName);
}
}
typeMap.Add(type, sType);
types.Add(sType);
Log("Registered type %{0}: {1}", sType.ID, type);
if (type == typeof(string)) {
stringType = sType;
}
}
return sType;
}
internal void ScanTypes()
{
for (int i = 0; i < types.Count; i++) {
Type type = types[i].Type;
if (type.IsGenericTypeDefinition || type.HasElementType)
continue;
if (typeof(ISerializable).IsAssignableFrom(type))
continue;
foreach (FieldInfo field in GetSerializableFields(type)) {
MarkType(field.FieldType);
}
}
}
#endregion
#region Writing
public void WriteObjectID(object instance)
{
int id = (instance == null) ? 0 : objectToID[instance];
if (instances.Count <= ushort.MaxValue)
writer.Write((ushort)id);
else
writer.Write(id);
}
void WriteTypeID(Type type)
{
Debug.Assert(typeMap.ContainsKey(type));
int typeID = typeMap[type].ID;
if (types.Count <= ushort.MaxValue)
writer.Write((ushort)typeID);
else
writer.Write(typeID);
}
internal void Write()
{
Log("Writing...");
writer.Write(magic);
// Write out type information
writer.Write(instances.Count);
writer.Write(types.Count);
writer.Write(assemblyNames.Count);
writer.Write(fixedInstanceCount);
foreach (string assemblyName in assemblyNames) {
writer.Write(assemblyName);
}
foreach (SerializationType sType in types) {
Type type = sType.Type;
if (type.HasElementType) {
if (type.IsArray) {
if (type.GetArrayRank() == 1)
writer.Write(Type_SZArray);
else
throw new NotSupportedException();
} else {
throw new NotSupportedException();
}
WriteTypeID(type.GetElementType());
} else if (type.IsGenericType && !type.IsGenericTypeDefinition) {
writer.Write(Type_ParameterizedType);
WriteTypeID(type.GetGenericTypeDefinition());
foreach (Type typeArg in type.GetGenericArguments()) {
WriteTypeID(typeArg);
}
} else {
if (type.IsValueType) {
writer.Write(Type_ValueType);
} else {
writer.Write(Type_ReferenceType);
}
if (assemblyNames.Count <= ushort.MaxValue)
writer.Write((ushort)sType.AssemblyNameID);
else
writer.Write(sType.AssemblyNameID);
writer.Write(sType.TypeName);
}
}
foreach (SerializationType sType in types) {
Type type = sType.Type;
if (type.IsGenericTypeDefinition || type.HasElementType)
continue;
writer.Write(FastSerializerVersionAttribute.GetVersionNumber(type));
if (type.IsPrimitive || typeof(ISerializable).IsAssignableFrom(type)) {
writer.Write(byte.MaxValue);
} else {
var fields = GetSerializableFields(type);
if (fields.Count >= byte.MaxValue)
throw new SerializationException("Too many fields.");
writer.Write((byte)fields.Count);
foreach (var field in fields) {
WriteTypeID(field.FieldType);
writer.Write(field.Name);
}
}
}
// Write out information necessary to create the instances
// starting from 1, because index 0 is null
for (int i = 1 + fixedInstanceCount; i < instances.Count; i++) {
SerializationType sType = objectTypes[i];
if (types.Count <= ushort.MaxValue)
writer.Write((ushort)sType.ID);
else
writer.Write(sType.ID);
if (sType == stringType) {
// Strings are written to the output immediately
// - we can't create an empty string and fill it later
writer.Write((string)instances[i]);
} else if (sType.Type.IsArray) {
// For arrays, write down the length, because we need that to create the array instance
writer.Write(((Array)instances[i]).Length);
}
}
// Write out information necessary to fill data into the instances
for (int i = 1 + fixedInstanceCount; i < instances.Count; i++) {
Log("0x{2:x6}, Write #{0}: {1}", i, objectTypes[i].Type.Name, writer.BaseStream.Position);
objectTypes[i].Writer(this, instances[i]);
}
Log("Serialization done.");
}
#endregion
}
#region Object Scanners
delegate void ObjectScanner(SerializationContext context, object instance);
static readonly MethodInfo mark = typeof(SerializationContext).GetMethod("Mark", new[] { typeof(object) });
static readonly FieldInfo writerField = typeof(SerializationContext).GetField("writer");
Dictionary scanners = new Dictionary();
ObjectScanner GetScanner(Type type)
{
ObjectScanner scanner;
if (!scanners.TryGetValue(type, out scanner)) {
scanner = CreateScanner(type);
scanners.Add(type, scanner);
}
return scanner;
}
ObjectScanner CreateScanner(Type type)
{
bool isArray = type.IsArray;
if (isArray) {
if (type.GetArrayRank() != 1)
throw new NotSupportedException();
type = type.GetElementType();
if (!type.IsValueType) {
return delegate (SerializationContext context, object array) {
foreach (object val in (object[])array) {
context.Mark(val);
}
};
}
}
for (Type baseType = type; baseType != null; baseType = baseType.BaseType) {
if (!baseType.IsSerializable)
throw new SerializationException("Type " + baseType + " is not [Serializable].");
}
List fields = GetSerializableFields(type);
fields.RemoveAll(f => !IsReferenceOrContainsReferences(f.FieldType));
if (fields.Count == 0) {
// The scanner has nothing to do for this object.
return delegate { };
}
DynamicMethod dynamicMethod = new DynamicMethod(
(isArray ? "ScanArray_" : "Scan_") + type.Name,
typeof(void), new [] { typeof(SerializationContext), typeof(object) },
true);
ILGenerator il = dynamicMethod.GetILGenerator();
if (isArray) {
var instance = il.DeclareLocal(type.MakeArrayType());
il.Emit(OpCodes.Ldarg_1);
il.Emit(OpCodes.Castclass, type.MakeArrayType());
il.Emit(OpCodes.Stloc, instance); // instance = (type[])arg_1;
// for (int i = 0; i < instance.Length; i++) scan instance[i];
var loopStart = il.DefineLabel();
var loopHead = il.DefineLabel();
var loopVariable = il.DeclareLocal(typeof(int));
il.Emit(OpCodes.Ldc_I4_0);
il.Emit(OpCodes.Stloc, loopVariable); // loopVariable = 0
il.Emit(OpCodes.Br, loopHead); // goto loopHead;
il.MarkLabel(loopStart);
il.Emit(OpCodes.Ldloc, instance); // instance
il.Emit(OpCodes.Ldloc, loopVariable); // instance, loopVariable
il.Emit(OpCodes.Ldelem, type); // &instance[loopVariable]
EmitScanValueType(il, type);
il.Emit(OpCodes.Ldloc, loopVariable); // loopVariable
il.Emit(OpCodes.Ldc_I4_1); // loopVariable, 1
il.Emit(OpCodes.Add); // loopVariable+1
il.Emit(OpCodes.Stloc, loopVariable); // loopVariable++;
il.MarkLabel(loopHead);
il.Emit(OpCodes.Ldloc, loopVariable); // loopVariable
il.Emit(OpCodes.Ldloc, instance); // loopVariable, instance
il.Emit(OpCodes.Ldlen); // loopVariable, instance.Length
il.Emit(OpCodes.Conv_I4);
il.Emit(OpCodes.Blt, loopStart); // if (loopVariable < instance.Length) goto loopStart;
} else if (type.IsValueType) {
// boxed value type
il.Emit(OpCodes.Ldarg_1);
il.Emit(OpCodes.Unbox_Any, type);
EmitScanValueType(il, type);
} else {
// reference type
var instance = il.DeclareLocal(type);
il.Emit(OpCodes.Ldarg_1);
il.Emit(OpCodes.Castclass, type);
il.Emit(OpCodes.Stloc, instance); // instance = (type)arg_1;
foreach (FieldInfo field in fields) {
EmitScanField(il, instance, field); // scan instance.Field
}
}
il.Emit(OpCodes.Ret);
return (ObjectScanner)dynamicMethod.CreateDelegate(typeof(ObjectScanner));
}
///
/// Emit 'scan instance.Field'.
/// Stack transition: ... => ...
///
void EmitScanField(ILGenerator il, LocalBuilder instance, FieldInfo field)
{
if (field.FieldType.IsValueType) {
il.Emit(OpCodes.Ldloc, instance); // instance
il.Emit(OpCodes.Ldfld, field); // instance.field
EmitScanValueType(il, field.FieldType);
} else {
il.Emit(OpCodes.Ldarg_0); // context
il.Emit(OpCodes.Ldloc, instance); // context, instance
il.Emit(OpCodes.Ldfld, field); // context, instance.field
il.Emit(OpCodes.Call, mark); // context.Mark(instance.field);
}
}
///
/// Stack transition: ..., value => ...
///
void EmitScanValueType(ILGenerator il, Type valType)
{
var fieldRef = il.DeclareLocal(valType);
il.Emit(OpCodes.Stloc, fieldRef);
foreach (FieldInfo field in GetSerializableFields(valType)) {
if (IsReferenceOrContainsReferences(field.FieldType)) {
EmitScanField(il, fieldRef, field);
}
}
}
static List GetSerializableFields(Type type)
{
List fields = new List();
for (Type baseType = type; baseType != null; baseType = baseType.BaseType) {
FieldInfo[] declFields = baseType.GetFields(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.DeclaredOnly);
Array.Sort(declFields, (a,b) => string.Compare(a.Name, b.Name, StringComparison.Ordinal));
fields.AddRange(declFields);
}
fields.RemoveAll(f => f.IsNotSerialized);
return fields;
}
static bool IsReferenceOrContainsReferences(Type type)
{
if (!type.IsValueType)
return true;
if (type.IsPrimitive)
return false;
foreach (FieldInfo field in GetSerializableFields(type)) {
if (IsReferenceOrContainsReferences(field.FieldType))
return true;
}
return false;
}
#endregion
#region Object Writers
delegate void ObjectWriter(SerializationContext context, object instance);
static readonly MethodInfo writeObjectID = typeof(SerializationContext).GetMethod("WriteObjectID", new[] { typeof(object) });
static readonly MethodInfo writeByte = typeof(BinaryWriter).GetMethod("Write", new[] { typeof(byte) });
static readonly MethodInfo writeShort = typeof(BinaryWriter).GetMethod("Write", new[] { typeof(short) });
static readonly MethodInfo writeInt = typeof(BinaryWriter).GetMethod("Write", new[] { typeof(int) });
static readonly MethodInfo writeLong = typeof(BinaryWriter).GetMethod("Write", new[] { typeof(long) });
static readonly MethodInfo writeFloat = typeof(BinaryWriter).GetMethod("Write", new[] { typeof(float) });
static readonly MethodInfo writeDouble = typeof(BinaryWriter).GetMethod("Write", new[] { typeof(double) });
OpCode callVirt = OpCodes.Callvirt;
static readonly ObjectWriter serializationInfoWriter = delegate(SerializationContext context, object instance) {
BinaryWriter writer = context.writer;
SerializationInfo info = (SerializationInfo)instance;
writer.Write(info.MemberCount);
foreach (SerializationEntry entry in info) {
writer.Write(entry.Name);
context.WriteObjectID(entry.Value);
}
};
Dictionary writers = new Dictionary();
ObjectWriter GetWriter(Type type)
{
ObjectWriter writer;
if (!writers.TryGetValue(type, out writer)) {
writer = CreateWriter(type);
writers.Add(type, writer);
}
return writer;
}
ObjectWriter CreateWriter(Type type)
{
if (type == typeof(string)) {
// String contents are written in the object creation section,
// not into the field value section.
return delegate {};
}
bool isArray = type.IsArray;
if (isArray) {
if (type.GetArrayRank() != 1)
throw new NotSupportedException();
type = type.GetElementType();
if (!type.IsValueType) {
return delegate (SerializationContext context, object array) {
foreach (object val in (object[])array) {
context.WriteObjectID(val);
}
};
} else if (type == typeof(byte)) {
return delegate (SerializationContext context, object array) {
context.writer.Write((byte[])array);
};
}
}
List fields = GetSerializableFields(type);
if (fields.Count == 0) {
// The writer has nothing to do for this object.
return delegate { };
}
DynamicMethod dynamicMethod = new DynamicMethod(
(isArray ? "WriteArray_" : "Write_") + type.Name,
typeof(void), new [] { typeof(SerializationContext), typeof(object) },
true);
ILGenerator il = dynamicMethod.GetILGenerator();
var writer = il.DeclareLocal(typeof(BinaryWriter));
il.Emit(OpCodes.Ldarg_0);
il.Emit(OpCodes.Ldfld, writerField);
il.Emit(OpCodes.Stloc, writer); // writer = context.writer;
if (isArray) {
var instance = il.DeclareLocal(type.MakeArrayType());
il.Emit(OpCodes.Ldarg_1);
il.Emit(OpCodes.Castclass, type.MakeArrayType());
il.Emit(OpCodes.Stloc, instance); // instance = (type[])arg_1;
// for (int i = 0; i < instance.Length; i++) write instance[i];
var loopStart = il.DefineLabel();
var loopHead = il.DefineLabel();
var loopVariable = il.DeclareLocal(typeof(int));
il.Emit(OpCodes.Ldc_I4_0);
il.Emit(OpCodes.Stloc, loopVariable); // loopVariable = 0
il.Emit(OpCodes.Br, loopHead); // goto loopHead;
il.MarkLabel(loopStart);
if (type.IsEnum || type.IsPrimitive) {
if (type.IsEnum) {
type = type.GetEnumUnderlyingType();
}
Debug.Assert(type.IsPrimitive);
il.Emit(OpCodes.Ldloc, writer); // writer
il.Emit(OpCodes.Ldloc, instance); // writer, instance
il.Emit(OpCodes.Ldloc, loopVariable); // writer, instance, loopVariable
switch (Type.GetTypeCode(type)) {
case TypeCode.Boolean:
case TypeCode.SByte:
case TypeCode.Byte:
il.Emit(OpCodes.Ldelem_I1); // writer, instance[loopVariable]
il.Emit(callVirt, writeByte); // writer.Write(instance[loopVariable]);
break;
case TypeCode.Char:
case TypeCode.Int16:
case TypeCode.UInt16:
il.Emit(OpCodes.Ldelem_I2); // writer, instance[loopVariable]
il.Emit(callVirt, writeShort); // writer.Write(instance[loopVariable]);
break;
case TypeCode.Int32:
case TypeCode.UInt32:
il.Emit(OpCodes.Ldelem_I4); // writer, instance[loopVariable]
il.Emit(callVirt, writeInt); // writer.Write(instance[loopVariable]);
break;
case TypeCode.Int64:
case TypeCode.UInt64:
il.Emit(OpCodes.Ldelem_I8); // writer, instance[loopVariable]
il.Emit(callVirt, writeLong); // writer.Write(instance[loopVariable]);
break;
case TypeCode.Single:
il.Emit(OpCodes.Ldelem_R4); // writer, instance[loopVariable]
il.Emit(callVirt, writeFloat); // writer.Write(instance[loopVariable]);
break;
case TypeCode.Double:
il.Emit(OpCodes.Ldelem_R8); // writer, instance[loopVariable]
il.Emit(callVirt, writeDouble); // writer.Write(instance[loopVariable]);
break;
default:
throw new NotSupportedException("Unknown primitive type " + type);
}
} else {
il.Emit(OpCodes.Ldloc, instance); // instance
il.Emit(OpCodes.Ldloc, loopVariable); // instance, loopVariable
il.Emit(OpCodes.Ldelem, type); // instance[loopVariable]
EmitWriteValueType(il, writer, type);
}
il.Emit(OpCodes.Ldloc, loopVariable); // loopVariable
il.Emit(OpCodes.Ldc_I4_1); // loopVariable, 1
il.Emit(OpCodes.Add); // loopVariable+1
il.Emit(OpCodes.Stloc, loopVariable); // loopVariable++;
il.MarkLabel(loopHead);
il.Emit(OpCodes.Ldloc, loopVariable); // loopVariable
il.Emit(OpCodes.Ldloc, instance); // loopVariable, instance
il.Emit(OpCodes.Ldlen); // loopVariable, instance.Length
il.Emit(OpCodes.Conv_I4);
il.Emit(OpCodes.Blt, loopStart); // if (loopVariable < instance.Length) goto loopStart;
} else if (type.IsValueType) {
// boxed value type
if (type.IsEnum || type.IsPrimitive) {
il.Emit(OpCodes.Ldloc, writer);
il.Emit(OpCodes.Ldarg_1);
il.Emit(OpCodes.Unbox_Any, type);
WritePrimitiveValue(il, type);
} else {
il.Emit(OpCodes.Ldarg_1);
il.Emit(OpCodes.Unbox_Any, type);
EmitWriteValueType(il, writer, type);
}
} else {
// reference type
var instance = il.DeclareLocal(type);
il.Emit(OpCodes.Ldarg_1);
il.Emit(OpCodes.Castclass, type);
il.Emit(OpCodes.Stloc, instance); // instance = (type)arg_1;
foreach (FieldInfo field in fields) {
EmitWriteField(il, writer, instance, field); // write instance.Field
}
}
il.Emit(OpCodes.Ret);
return (ObjectWriter)dynamicMethod.CreateDelegate(typeof(ObjectWriter));
}
///
/// Emit 'write instance.Field'.
/// Stack transition: ... => ...
///
void EmitWriteField(ILGenerator il, LocalBuilder writer, LocalBuilder instance, FieldInfo field)
{
Type fieldType = field.FieldType;
if (fieldType.IsValueType) {
if (fieldType.IsPrimitive || fieldType.IsEnum) {
il.Emit(OpCodes.Ldloc, writer); // writer
il.Emit(OpCodes.Ldloc, instance); // writer, instance
il.Emit(OpCodes.Ldfld, field); // writer, instance.field
WritePrimitiveValue(il, fieldType);
} else {
il.Emit(OpCodes.Ldloc, instance); // instance
il.Emit(OpCodes.Ldfld, field); // instance.field
EmitWriteValueType(il, writer, fieldType);
}
} else {
il.Emit(OpCodes.Ldarg_0); // context
il.Emit(OpCodes.Ldloc, instance); // context, instance
il.Emit(OpCodes.Ldfld, field); // context, instance.field
il.Emit(OpCodes.Call, writeObjectID); // context.WriteObjectID(instance.field);
}
}
///
/// Writes a primitive value of the specified type.
/// Stack transition: ..., writer, value => ...
///
void WritePrimitiveValue(ILGenerator il, Type fieldType)
{
if (fieldType.IsEnum) {
fieldType = fieldType.GetEnumUnderlyingType();
Debug.Assert(fieldType.IsPrimitive);
}
switch (Type.GetTypeCode(fieldType)) {
case TypeCode.Boolean:
case TypeCode.SByte:
case TypeCode.Byte:
il.Emit(callVirt, writeByte); // writer.Write(value);
break;
case TypeCode.Char:
case TypeCode.Int16:
case TypeCode.UInt16:
il.Emit(callVirt, writeShort); // writer.Write(value);
break;
case TypeCode.Int32:
case TypeCode.UInt32:
il.Emit(callVirt, writeInt); // writer.Write(value);
break;
case TypeCode.Int64:
case TypeCode.UInt64:
il.Emit(callVirt, writeLong); // writer.Write(value);
break;
case TypeCode.Single:
il.Emit(callVirt, writeFloat); // writer.Write(value);
break;
case TypeCode.Double:
il.Emit(callVirt, writeDouble); // writer.Write(value);
break;
default:
throw new NotSupportedException("Unknown primitive type " + fieldType);
}
}
///
/// Stack transition: ..., value => ...
///
void EmitWriteValueType(ILGenerator il, LocalBuilder writer, Type valType)
{
Debug.Assert(valType.IsValueType);
Debug.Assert(!(valType.IsEnum || valType.IsPrimitive));
var fieldVal = il.DeclareLocal(valType);
il.Emit(OpCodes.Stloc, fieldVal);
foreach (FieldInfo field in GetSerializableFields(valType)) {
EmitWriteField(il, writer, fieldVal, field);
}
}
#endregion
StreamingContext streamingContext = new StreamingContext(StreamingContextStates.All);
FormatterConverter formatterConverter = new FormatterConverter();
public void Serialize(Stream stream, object instance)
{
Serialize(new BinaryWriterWith7BitEncodedInts(stream), instance);
}
public void Serialize(BinaryWriter writer, object instance)
{
SerializationContext context = new SerializationContext(this, writer);
context.MarkFixedInstances(this.FixedInstances);
context.Mark(instance);
context.Scan();
context.ScanTypes();
context.Write();
context.WriteObjectID(instance);
}
delegate void TypeSerializer(object instance, SerializationContext context);
#endregion
#region Deserialization
sealed class DeserializationContext
{
public Type[] Types; // index: type ID
public object[] Objects; // index: object ID
public BinaryReader Reader;
public object ReadObject()
{
if (this.Objects.Length <= ushort.MaxValue)
return this.Objects[Reader.ReadUInt16()];
else
return this.Objects[Reader.ReadInt32()];
}
#region DeserializeTypeDescriptions
internal int ReadTypeID()
{
if (this.Types.Length <= ushort.MaxValue)
return Reader.ReadUInt16();
else
return Reader.ReadInt32();
}
internal void DeserializeTypeDescriptions()
{
for (int i = 0; i < this.Types.Length; i++) {
Type type = this.Types[i];
if (type.IsGenericTypeDefinition || type.HasElementType)
continue;
int versionNumber = Reader.ReadInt32();
if (versionNumber != FastSerializerVersionAttribute.GetVersionNumber(type))
throw new SerializationException("Type '" + type.FullName + "' was serialized with version " + versionNumber + ", but is version " + FastSerializerVersionAttribute.GetVersionNumber(type));
bool isCustomSerialization = typeof(ISerializable).IsAssignableFrom(type);
bool typeIsSpecial = type.IsPrimitive || isCustomSerialization;
byte serializedFieldCount = Reader.ReadByte();
if (serializedFieldCount == byte.MaxValue) {
// special type
if (!typeIsSpecial)
throw new SerializationException("Type '" + type.FullName + "' was serialized as special type, but isn't special now.");
} else {
if (typeIsSpecial)
throw new SerializationException("Type '" + type.FullName + "' wasn't serialized as special type, but is special now.");
var availableFields = GetSerializableFields(this.Types[i]);
if (availableFields.Count != serializedFieldCount)
throw new SerializationException("Number of fields on " + type.FullName + " has changed.");
for (int j = 0; j < serializedFieldCount; j++) {
int fieldTypeID = ReadTypeID();
string fieldName = Reader.ReadString();
FieldInfo fieldInfo = availableFields[j];
if (fieldInfo.Name != fieldName)
throw new SerializationException("Field mismatch on type " + type.FullName);
if (fieldInfo.FieldType != this.Types[fieldTypeID])
throw new SerializationException(type.FullName + "." + fieldName + " was serialized as " + this.Types[fieldTypeID] + ", but now is " + fieldInfo.FieldType);
}
}
}
}
#endregion
}
delegate void ObjectReader(DeserializationContext context, object instance);
public object Deserialize(Stream stream)
{
return Deserialize(new BinaryReaderWith7BitEncodedInts(stream));
}
public object Deserialize(BinaryReader reader)
{
if (reader.ReadInt32() != magic)
throw new SerializationException("The data cannot be read by FastSerializer (unknown magic value)");
DeserializationContext context = new DeserializationContext();
context.Reader = reader;
context.Objects = new object[reader.ReadInt32()];
context.Types = new Type[reader.ReadInt32()];
string[] assemblyNames = new string[reader.ReadInt32()];
int fixedInstanceCount = reader.ReadInt32();
if (fixedInstanceCount != 0) {
if (this.FixedInstances == null || this.FixedInstances.Length != fixedInstanceCount)
throw new SerializationException("Number of fixed instances doesn't match");
for (int i = 0; i < fixedInstanceCount; i++) {
context.Objects[i + 1] = this.FixedInstances[i];
}
}
for (int i = 0; i < assemblyNames.Length; i++) {
assemblyNames[i] = reader.ReadString();
}
int stringTypeID = -1;
for (int i = 0; i < context.Types.Length; i++) {
byte typeKind = reader.ReadByte();
switch (typeKind) {
case Type_ReferenceType:
case Type_ValueType:
int assemblyID;
if (assemblyNames.Length <= ushort.MaxValue)
assemblyID = reader.ReadUInt16();
else
assemblyID = reader.ReadInt32();
string assemblyName = assemblyNames[assemblyID];
string typeName = reader.ReadString();
Type type;
if (SerializationBinder != null) {
type = SerializationBinder.BindToType(assemblyName, typeName);
} else {
type = Assembly.Load(assemblyName).GetType(typeName);
}
if (type == null)
throw new SerializationException("Could not find '" + typeName + "' in '" + assemblyName + "'");
if (typeKind == Type_ValueType && !type.IsValueType)
throw new SerializationException("Expected '" + typeName + "' to be a value type, but it is reference type");
if (typeKind == Type_ReferenceType && type.IsValueType)
throw new SerializationException("Expected '" + typeName + "' to be a reference type, but it is value type");
context.Types[i] = type;
if (type == typeof(string))
stringTypeID = i;
break;
case Type_SZArray:
context.Types[i] = context.Types[context.ReadTypeID()].MakeArrayType();
break;
case Type_ParameterizedType:
Type genericType = context.Types[context.ReadTypeID()];
int typeParameterCount = genericType.GetGenericArguments().Length;
Type[] typeArguments = new Type[typeParameterCount];
for (int j = 0; j < typeArguments.Length; j++) {
typeArguments[j] = context.Types[context.ReadTypeID()];
}
context.Types[i] = genericType.MakeGenericType(typeArguments);
break;
default:
throw new SerializationException("Unknown type kind");
}
}
context.DeserializeTypeDescriptions();
int[] typeIDByObjectID = new int[context.Objects.Length];
for (int i = 1 + fixedInstanceCount; i < context.Objects.Length; i++) {
int typeID = context.ReadTypeID();
object instance;
if (typeID == stringTypeID) {
instance = reader.ReadString();
} else {
Type type = context.Types[typeID];
if (type.IsArray) {
int length = reader.ReadInt32();
instance = Array.CreateInstance(type.GetElementType(), length);
} else {
instance = FormatterServices.GetUninitializedObject(type);
}
}
context.Objects[i] = instance;
typeIDByObjectID[i] = typeID;
}
List customDeserializatons = new List();
ObjectReader[] objectReaders = new ObjectReader[context.Types.Length]; // index: type ID
for (int i = 1 + fixedInstanceCount; i < context.Objects.Length; i++) {
object instance = context.Objects[i];
int typeID = typeIDByObjectID[i];
Log("0x{2:x6} Read #{0}: {1}", i, context.Types[typeID].Name, reader.BaseStream.Position);
ISerializable serializable = instance as ISerializable;
if (serializable != null) {
Type type = context.Types[typeID];
SerializationInfo info = new SerializationInfo(type, formatterConverter);
int count = reader.ReadInt32();
for (int j = 0; j < count; j++) {
string name = reader.ReadString();
object val = context.ReadObject();
info.AddValue(name, val);
}
CustomDeserializationAction action = GetCustomDeserializationAction(type);
customDeserializatons.Add(new CustomDeserialization(instance, info, action));
} else {
ObjectReader objectReader = objectReaders[typeID];
if (objectReader == null) {
objectReader = GetReader(context.Types[typeID]);
objectReaders[typeID] = objectReader;
}
objectReader(context, instance);
}
}
Log("File was read successfully, now running {0} custom deserializations...", customDeserializatons.Count);
foreach (CustomDeserialization customDeserializaton in customDeserializatons) {
customDeserializaton.Run(streamingContext);
}
for (int i = 1 + fixedInstanceCount; i < context.Objects.Length; i++) {
IDeserializationCallback dc = context.Objects[i] as IDeserializationCallback;
if (dc != null)
dc.OnDeserialization(null);
}
return context.ReadObject();
}
#region Object Reader
static readonly FieldInfo readerField = typeof(DeserializationContext).GetField("Reader");
static readonly MethodInfo readObject = typeof(DeserializationContext).GetMethod("ReadObject");
static readonly MethodInfo readByte = typeof(BinaryReader).GetMethod("ReadByte");
static readonly MethodInfo readShort = typeof(BinaryReader).GetMethod("ReadInt16");
static readonly MethodInfo readInt = typeof(BinaryReader).GetMethod("ReadInt32");
static readonly MethodInfo readLong = typeof(BinaryReader).GetMethod("ReadInt64");
static readonly MethodInfo readFloat = typeof(BinaryReader).GetMethod("ReadSingle");
static readonly MethodInfo readDouble = typeof(BinaryReader).GetMethod("ReadDouble");
Dictionary readers = new Dictionary();
ObjectReader GetReader(Type type)
{
ObjectReader reader;
if (!readers.TryGetValue(type, out reader)) {
reader = CreateReader(type);
readers.Add(type, reader);
}
return reader;
}
ObjectReader CreateReader(Type type)
{
if (type == typeof(string)) {
// String contents are written in the object creation section,
// not into the field value section; so there's nothing to read here.
return delegate {};
}
bool isArray = type.IsArray;
if (isArray) {
if (type.GetArrayRank() != 1)
throw new NotSupportedException();
type = type.GetElementType();
if (!type.IsValueType) {
return delegate (DeserializationContext context, object arrayInstance) {
object[] array = (object[])arrayInstance;
for (int i = 0; i < array.Length; i++) {
array[i] = context.ReadObject();
}
};
} else if (type == typeof(byte)) {
return delegate (DeserializationContext context, object arrayInstance) {
byte[] array = (byte[])arrayInstance;
BinaryReader binaryReader = context.Reader;
int pos = 0;
int bytesRead;
do {
bytesRead = binaryReader.Read(array, pos, array.Length - pos);
pos += bytesRead;
} while (bytesRead > 0);
if (pos != array.Length)
throw new EndOfStreamException();
};
}
}
var fields = GetSerializableFields(type);
if (fields.Count == 0) {
// The reader has nothing to do for this object.
return delegate { };
}
DynamicMethod dynamicMethod = new DynamicMethod(
(isArray ? "ReadArray_" : "Read_") + type.Name,
MethodAttributes.Public | MethodAttributes.Static,
CallingConventions.Standard,
typeof(void), new [] { typeof(DeserializationContext), typeof(object) },
type,
true);
ILGenerator il = dynamicMethod.GetILGenerator();
var reader = il.DeclareLocal(typeof(BinaryReader));
il.Emit(OpCodes.Ldarg_0);
il.Emit(OpCodes.Ldfld, readerField);
il.Emit(OpCodes.Stloc, reader); // reader = context.reader;
if (isArray) {
var instance = il.DeclareLocal(type.MakeArrayType());
il.Emit(OpCodes.Ldarg_1);
il.Emit(OpCodes.Castclass, type.MakeArrayType());
il.Emit(OpCodes.Stloc, instance); // instance = (type[])arg_1;
// for (int i = 0; i < instance.Length; i++) read &instance[i];
var loopStart = il.DefineLabel();
var loopHead = il.DefineLabel();
var loopVariable = il.DeclareLocal(typeof(int));
il.Emit(OpCodes.Ldc_I4_0);
il.Emit(OpCodes.Stloc, loopVariable); // loopVariable = 0
il.Emit(OpCodes.Br, loopHead); // goto loopHead;
il.MarkLabel(loopStart);
if (type.IsEnum || type.IsPrimitive) {
if (type.IsEnum) {
type = type.GetEnumUnderlyingType();
}
Debug.Assert(type.IsPrimitive);
il.Emit(OpCodes.Ldloc, instance); // instance
il.Emit(OpCodes.Ldloc, loopVariable); // instance, loopVariable
ReadPrimitiveValue(il, reader, type); // instance, loopVariable, value
switch (Type.GetTypeCode(type)) {
case TypeCode.Boolean:
case TypeCode.SByte:
case TypeCode.Byte:
il.Emit(OpCodes.Stelem_I1); // instance[loopVariable] = value;
break;
case TypeCode.Char:
case TypeCode.Int16:
case TypeCode.UInt16:
il.Emit(OpCodes.Stelem_I2); // instance[loopVariable] = value;
break;
case TypeCode.Int32:
case TypeCode.UInt32:
il.Emit(OpCodes.Stelem_I4); // instance[loopVariable] = value;
break;
case TypeCode.Int64:
case TypeCode.UInt64:
il.Emit(OpCodes.Stelem_I8); // instance[loopVariable] = value;
break;
case TypeCode.Single:
il.Emit(OpCodes.Stelem_R4); // instance[loopVariable] = value;
break;
case TypeCode.Double:
il.Emit(OpCodes.Stelem_R8); // instance[loopVariable] = value;
break;
default:
throw new NotSupportedException("Unknown primitive type " + type);
}
} else {
il.Emit(OpCodes.Ldloc, instance); // instance
il.Emit(OpCodes.Ldloc, loopVariable); // instance, loopVariable
il.Emit(OpCodes.Ldelema, type); // instance[loopVariable]
EmitReadValueType(il, reader, type);
}
il.Emit(OpCodes.Ldloc, loopVariable); // loopVariable
il.Emit(OpCodes.Ldc_I4_1); // loopVariable, 1
il.Emit(OpCodes.Add); // loopVariable+1
il.Emit(OpCodes.Stloc, loopVariable); // loopVariable++;
il.MarkLabel(loopHead);
il.Emit(OpCodes.Ldloc, loopVariable); // loopVariable
il.Emit(OpCodes.Ldloc, instance); // loopVariable, instance
il.Emit(OpCodes.Ldlen); // loopVariable, instance.Length
il.Emit(OpCodes.Conv_I4);
il.Emit(OpCodes.Blt, loopStart); // if (loopVariable < instance.Length) goto loopStart;
} else if (type.IsValueType) {
// boxed value type
il.Emit(OpCodes.Ldarg_1); // instance
il.Emit(OpCodes.Unbox, type); // &(Type)instance
if (type.IsEnum || type.IsPrimitive) {
if (type.IsEnum) {
type = type.GetEnumUnderlyingType();
}
Debug.Assert(type.IsPrimitive);
ReadPrimitiveValue(il, reader, type); // &(Type)instance, value
switch (Type.GetTypeCode(type)) {
case TypeCode.Boolean:
case TypeCode.SByte:
case TypeCode.Byte:
il.Emit(OpCodes.Stind_I1);
break;
case TypeCode.Char:
case TypeCode.Int16:
case TypeCode.UInt16:
il.Emit(OpCodes.Stind_I2);
break;
case TypeCode.Int32:
case TypeCode.UInt32:
il.Emit(OpCodes.Stind_I4);
break;
case TypeCode.Int64:
case TypeCode.UInt64:
il.Emit(OpCodes.Stind_I8);
break;
case TypeCode.Single:
il.Emit(OpCodes.Stind_R4);
break;
case TypeCode.Double:
il.Emit(OpCodes.Stind_R8);
break;
default:
throw new NotSupportedException("Unknown primitive type " + type);
}
} else {
EmitReadValueType(il, reader, type);
}
} else {
// reference type
var instance = il.DeclareLocal(type);
il.Emit(OpCodes.Ldarg_1);
il.Emit(OpCodes.Castclass, type);
il.Emit(OpCodes.Stloc, instance); // instance = (type)arg_1;
foreach (FieldInfo field in fields) {
EmitReadField(il, reader, instance, field); // read instance.Field
}
}
il.Emit(OpCodes.Ret);
return (ObjectReader)dynamicMethod.CreateDelegate(typeof(ObjectReader));
}
void EmitReadField(ILGenerator il, LocalBuilder reader, LocalBuilder instance, FieldInfo field)
{
Type fieldType = field.FieldType;
if (fieldType.IsValueType) {
if (fieldType.IsPrimitive || fieldType.IsEnum) {
il.Emit(OpCodes.Ldloc, instance); // instance
ReadPrimitiveValue(il, reader, fieldType); // instance, value
il.Emit(OpCodes.Stfld, field); // instance.field = value;
} else {
il.Emit(OpCodes.Ldloc, instance); // instance
il.Emit(OpCodes.Ldflda, field); // &instance.field
EmitReadValueType(il, reader, fieldType);
}
} else {
il.Emit(OpCodes.Ldloc, instance); // instance
il.Emit(OpCodes.Ldarg_0); // instance, context
il.Emit(OpCodes.Call, readObject); // instance, context.ReadObject()
il.Emit(OpCodes.Castclass, fieldType);
il.Emit(OpCodes.Stfld, field); // instance.field = (fieldType) context.ReadObject();
}
}
///
/// Reads a primitive value of the specified type.
/// Stack transition: ... => ..., value
///
void ReadPrimitiveValue(ILGenerator il, LocalBuilder reader, Type fieldType)
{
if (fieldType.IsEnum) {
fieldType = fieldType.GetEnumUnderlyingType();
Debug.Assert(fieldType.IsPrimitive);
}
il.Emit(OpCodes.Ldloc, reader);
switch (Type.GetTypeCode(fieldType)) {
case TypeCode.Boolean:
case TypeCode.SByte:
case TypeCode.Byte:
il.Emit(callVirt, readByte);
break;
case TypeCode.Char:
case TypeCode.Int16:
case TypeCode.UInt16:
il.Emit(callVirt, readShort);
break;
case TypeCode.Int32:
case TypeCode.UInt32:
il.Emit(callVirt, readInt);
break;
case TypeCode.Int64:
case TypeCode.UInt64:
il.Emit(callVirt, readLong);
break;
case TypeCode.Single:
il.Emit(callVirt, readFloat);
break;
case TypeCode.Double:
il.Emit(callVirt, readDouble);
break;
default:
throw new NotSupportedException("Unknown primitive type " + fieldType);
}
}
///
/// Stack transition: ..., field-ref => ...
///
void EmitReadValueType(ILGenerator il, LocalBuilder reader, Type valType)
{
Debug.Assert(valType.IsValueType);
Debug.Assert(!(valType.IsEnum || valType.IsPrimitive));
var fieldRef = il.DeclareLocal(valType.MakeByRefType());
il.Emit(OpCodes.Stloc, fieldRef);
foreach (FieldInfo field in GetSerializableFields(valType)) {
EmitReadField(il, reader, fieldRef, field);
}
}
#endregion
#region Custom Deserialization
struct CustomDeserialization
{
readonly object instance;
readonly SerializationInfo serializationInfo;
readonly CustomDeserializationAction action;
public CustomDeserialization(object instance, SerializationInfo serializationInfo, CustomDeserializationAction action)
{
this.instance = instance;
this.serializationInfo = serializationInfo;
this.action = action;
}
public void Run(StreamingContext context)
{
action(instance, serializationInfo, context);
}
}
delegate void CustomDeserializationAction(object instance, SerializationInfo info, StreamingContext context);
Dictionary customDeserializationActions = new Dictionary();
CustomDeserializationAction GetCustomDeserializationAction(Type type)
{
CustomDeserializationAction action;
if (!customDeserializationActions.TryGetValue(type, out action)) {
action = CreateCustomDeserializationAction(type);
customDeserializationActions.Add(type, action);
}
return action;
}
static CustomDeserializationAction CreateCustomDeserializationAction(Type type)
{
ConstructorInfo ctor = type.GetConstructor(
BindingFlags.DeclaredOnly | BindingFlags.ExactBinding | BindingFlags.Instance
| BindingFlags.NonPublic | BindingFlags.Public,
null,
new Type [] { typeof(SerializationInfo), typeof(StreamingContext) },
null);
if (ctor == null)
throw new SerializationException("Could not find deserialization constructor for " + type.FullName);
DynamicMethod dynamicMethod = new DynamicMethod(
"CallCtor_" + type.Name,
MethodAttributes.Public | MethodAttributes.Static,
CallingConventions.Standard,
typeof(void), new [] { typeof(object), typeof(SerializationInfo), typeof(StreamingContext) },
type,
true);
ILGenerator il = dynamicMethod.GetILGenerator();
il.Emit(OpCodes.Ldarg_0);
il.Emit(OpCodes.Ldarg_1);
il.Emit(OpCodes.Ldarg_2);
il.Emit(OpCodes.Call, ctor);
il.Emit(OpCodes.Ret);
return (CustomDeserializationAction)dynamicMethod.CreateDelegate(typeof(CustomDeserializationAction));
}
#endregion
#endregion
[Conditional("DEBUG_SERIALIZER")]
static void Log(string format, params object[] args)
{
Debug.WriteLine(format, args);
}
}
///
/// Specifies the version of the class.
/// The will refuse to deserialize an instance that was stored by
/// a different version of the class than the current one.
///
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Enum)]
public class FastSerializerVersionAttribute : Attribute
{
readonly int versionNumber;
public FastSerializerVersionAttribute(int versionNumber)
{
this.versionNumber = versionNumber;
}
public int VersionNumber {
get {
return versionNumber;
}
}
internal static int GetVersionNumber(Type type)
{
var arr = type.GetCustomAttributes(typeof(FastSerializerVersionAttribute), false);
if (arr.Length == 0)
return 0;
else
return ((FastSerializerVersionAttribute)arr[0]).VersionNumber;
}
}
}