Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save rickenberg/b19044b1d1a35d1a24e66b0d77a6325d to your computer and use it in GitHub Desktop.
Save rickenberg/b19044b1d1a35d1a24e66b0d77a6325d to your computer and use it in GitHub Desktop.
EF6 extension to perform an UPSERT. This is a fork of another gist. I have added a few useful comments and exposed the SQL as a private field for 'unit testing'.
using System;
using System.Collections.Generic;
using System.Data.Entity;
using System.Data.Entity.Core.Mapping;
using System.Data.Entity.Core.Metadata.Edm;
using System.Data.Entity.Infrastructure;
using System.Data.SqlTypes;
using System.Linq;
using System.Linq.Expressions;
using System.Text;
namespace Extensions
{
/// <summary>
/// EF6 extension to perform an UPSERT - MERGE command.
/// </summary>
public static class DbContextExtensions
{
    /// <summary>
    /// Creates one T-SQL MERGE command for the insert/update of all entities in the enumerable passed in.
    /// You need to call the Execute method on the returned object to actually run the command on the DB.
    /// NOTE-1: When using a single entity, null values from the entity will be skipped during update.
    /// NOTE-2: Using multiple entities will always overwrite null values in database.
    /// NOTE-3: MERGE commands use SQL parameters (one per mapped property per entity) and SQL Server
    /// limits how many parameters a single command may carry (2100 per its capacity specifications).
    /// So you might need to chunk your batches accordingly.
    /// </summary>
    /// <example>
    /// <code>
    /// var context = new TestDbContext();
    /// var entity = new MyEntity { Id = 1, Name = "Entity1" };
    /// var op = context.Upsert(new List&lt;MyEntity&gt; { entity });
    /// op.Execute();
    /// </code>
    /// </example>
    /// <typeparam name="TEntity">Type of the EF entity to merge</typeparam>
    /// <param name="context">EF DB context object</param>
    /// <param name="entities">Enumerable of EF entities to merge in DB</param>
    /// <returns>The merge operation that can be executed</returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="context"/> or <paramref name="entities"/> is null.
    /// </exception>
    public static EntityOp<TEntity, int> Upsert<TEntity>(this DbContext context, IEnumerable<TEntity> entities) where TEntity : class
    {
        // Fail at the call site instead of deep inside the metadata lookup or SQL generation.
        if (context == null)
        {
            throw new ArgumentNullException(nameof(context));
        }

        if (entities == null)
        {
            throw new ArgumentNullException(nameof(entities));
        }

        return new UpsertOp<TEntity>(context, entities);
    }
}
/// <summary>
/// Base class for a bulk operation on entities of type <typeparamref name="TEntity"/>.
/// The constructor reads the EF6 metadata of the context to resolve the target table name,
/// the property-to-column mapping and the primary key property names; derived classes build
/// and run the actual command in <see cref="Execute"/>.
/// </summary>
/// <typeparam name="TEntity">CLR type of the mapped entity.</typeparam>
/// <typeparam name="TRet">Type returned by <see cref="Execute"/>.</typeparam>
public abstract class EntityOp<TEntity, TRet>
{
    public readonly DbContext _context;
    public readonly IEnumerable<TEntity> _entityList;

    /// <summary>Target table, formatted as <c>[schema].[table]</c>.</summary>
    protected readonly string _tableName;

    /// <summary>Names of all key members of the entity set's element type.</summary>
    protected readonly string[] _entityPrimaryKeyNames;

    /// <summary>Names of the key members flagged as store-generated identity.</summary>
    protected readonly string[] _storeGeneratedPrimaryKeyNames;

    /// <summary>Maps entity property name to its bracketed column name, e.g. <c>Name</c> to <c>[Name]</c>.</summary>
    protected readonly Dictionary<string, string> _propNames;

    /// <summary>Property names registered through <see cref="Key{TKey}"/>; null until the first call.</summary>
    protected List<string> _matchPropertyNames;

    /// <summary>
    /// Property names to match rows on: the ones registered via <see cref="Key{TKey}"/>,
    /// or the primary key names when none were registered.
    /// </summary>
    public IEnumerable<string> MatchPropertyNames => (IEnumerable<string>)_matchPropertyNames ?? _entityPrimaryKeyNames;

    /// <summary>
    /// Extracts the member name from a member-access lambda such as <c>x => x.Id</c>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="selectMemberLambda"/> is null.</exception>
    /// <exception cref="ArgumentException">The lambda body is not a member access.</exception>
    private static string GetMemberName<T>(Expression<Func<TEntity, T>> selectMemberLambda)
    {
        if (selectMemberLambda == null)
        {
            throw new ArgumentNullException(nameof(selectMemberLambda));
        }

        var body = selectMemberLambda.Body;

        // When T is wider than the member type (e.g. Key<object>(x => x.Id)) the compiler wraps
        // the member access in a Convert node; look through it to reach the member.
        var conversion = body as UnaryExpression;
        if (conversion != null &&
            (conversion.NodeType == ExpressionType.Convert || conversion.NodeType == ExpressionType.ConvertChecked))
        {
            body = conversion.Operand;
        }

        var member = body as MemberExpression;
        if (member == null)
        {
            throw new ArgumentException("The parameter selectMemberLambda must be a member accessing labda such as x => x.Id", nameof(selectMemberLambda));
        }

        return member.Member.Name;
    }

    /// <summary>
    /// Captures the context and entities and resolves table, column and key metadata for <typeparamref name="TEntity"/>.
    /// </summary>
    /// <param name="context">EF DB context the operation runs against.</param>
    /// <param name="entityList">Entities the operation applies to.</param>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> or <paramref name="entityList"/> is null.</exception>
    public EntityOp(DbContext context, IEnumerable<TEntity> entityList)
    {
        if (context == null)
        {
            throw new ArgumentNullException(nameof(context));
        }

        if (entityList == null)
        {
            throw new ArgumentNullException(nameof(entityList));
        }

        _context = context;
        _entityList = entityList;

        var mapping = GetEntitySetMapping(typeof(TEntity), context);

        // Single() on both levels: only entities mapped by exactly one type mapping with exactly
        // one fragment are supported; anything else throws InvalidOperationException here.
        var fragment = mapping
            .EntityTypeMappings.Single()
            .Fragments.Single();

        // Property name -> bracketed column name.
        // https://romiller.com/2015/08/05/ef6-1-get-mapping-between-properties-and-columns/
        _propNames = fragment
            .PropertyMappings
            .OfType<ScalarPropertyMapping>()
            .ToDictionary(m => m.Property.Name, m => "[" + m.Column.Name + "]");

        // Get the names of the primary key members, split by whether they are store-generated
        // identities, as the latter are excluded from inserts (we are assuming Identity insert is OFF).
        var keyNames = mapping.EntitySet.ElementType.KeyMembers
            .ToLookup(k => k.IsStoreGeneratedIdentity, k => k.Name);
        _entityPrimaryKeyNames = keyNames.SelectMany(k => k).ToArray();
        _storeGeneratedPrimaryKeyNames = keyNames[true].ToArray();

        // The storage entity set (table) that the entity is mapped to.
        var table = fragment.StoreEntitySet;

        // Prefer the "Table"/"Schema" metadata values; fall back to the entity set's own name/schema.
        var tableName = (string)table.MetadataProperties["Table"].Value ?? table.Name;
        var schemaName = (string)table.MetadataProperties["Schema"].Value ?? table.Schema;
        _tableName = $"[{schemaName}].[{tableName}]";
    }

    /// <summary>Builds and runs the operation.</summary>
    /// <returns>The result of the operation as defined by the derived class.</returns>
    public abstract TRet Execute();

    /// <summary>Runs <see cref="Execute"/> and discards its result.</summary>
    public void Run()
    {
        Execute();
    }

    /// <summary>
    /// Adds a property to match rows on. Once called at least once, the registered properties
    /// replace the primary key in <see cref="MatchPropertyNames"/>.
    /// </summary>
    /// <param name="selectKey">Member-access lambda such as <c>x => x.Id</c>.</param>
    /// <returns>This operation, to allow chaining.</returns>
    public EntityOp<TEntity, TRet> Key<TKey>(Expression<Func<TEntity, TKey>> selectKey)
    {
        var name = GetMemberName(selectKey);
        if (_matchPropertyNames == null)
        {
            _matchPropertyNames = new List<string>();
        }

        _matchPropertyNames.Add(name);
        return this;
    }

    /// <summary>
    /// Removes a property from the property-to-column map so it takes no part in the operation.
    /// </summary>
    /// <param name="selectField">Member-access lambda such as <c>x => x.Name</c>.</param>
    /// <returns>This operation, to allow chaining.</returns>
    public EntityOp<TEntity, TRet> ExcludeField<TField>(Expression<Func<TEntity, TField>> selectField)
    {
        _propNames.Remove(GetMemberName(selectField));
        return this;
    }

    /// <summary>
    /// Finds the conceptual-to-storage mapping of the entity set whose element type corresponds to <paramref name="type"/>.
    /// </summary>
    private static EntitySetMapping GetEntitySetMapping(Type type, DbContext context)
    {
        var metadata = ((IObjectContextAdapter)context).ObjectContext.MetadataWorkspace;

        // Get the part of the model that contains info about the actual CLR types
        var objectItemCollection = (ObjectItemCollection)metadata.GetItemCollection(DataSpace.OSpace);

        // Get the entity type from the model that maps to the CLR type
        var entityType = metadata
            .GetItems<EntityType>(DataSpace.OSpace)
            .Single(e => objectItemCollection.GetClrType(e) == type);

        // Get the entity set that uses this entity type (matched by type name)
        var entitySet = metadata
            .GetItems<EntityContainer>(DataSpace.CSpace)
            .Single()
            .EntitySets
            .Single(s => s.ElementType.Name == entityType.Name);

        // Find the mapping between conceptual and storage model for this entity set
        return metadata
            .GetItems<EntityContainerMapping>(DataSpace.CSSpace)
            .Single()
            .EntitySetMappings
            .Single(s => s.EntitySet == entitySet);
    }
}
/// <summary>
/// Builds and executes one parameterised T-SQL MERGE statement that updates rows matching
/// <see cref="EntityOp{TEntity,TRet}.MatchPropertyNames"/> and inserts the rest.
/// </summary>
/// <typeparam name="TEntity">CLR type of the mapped entity.</typeparam>
public class UpsertOp<TEntity> : EntityOp<TEntity, int>
{
    // Holds the SQL text of the most recent Execute call so it can be inspected (e.g. from tests).
    private StringBuilder _sql;

    /// <summary>Creates the upsert operation for the given context and entities.</summary>
    public UpsertOp(DbContext context, IEnumerable<TEntity> entityList) : base(context, entityList)
    { }

    /// <summary>
    /// Generates the MERGE statement for all entities and runs it through
    /// <c>Database.ExecuteSqlCommand</c>, passing every property value as a positional parameter (@p0, @p1, ...).
    /// With exactly one entity, properties whose value is null are left out of the UPDATE part;
    /// with several entities every non-key column is always updated.
    /// </summary>
    /// <returns>
    /// The value returned by <c>ExecuteSqlCommand</c>, or 0 without touching the database when there are no entities.
    /// </returns>
    /// <exception cref="InvalidOperationException">No mapped properties are left to merge.</exception>
    public override int Execute()
    {
        // Materialise once: the source may be a lazy sequence and is needed for both count and values.
        var entities = _entityList.ToList();
        if (entities.Count == 0)
        {
            // A MERGE with an empty VALUES list is not valid SQL; there is nothing to do.
            _sql = new StringBuilder();
            return 0;
        }

        // Snapshot the mapping so names, columns and values are guaranteed to share one order.
        var columns = _propNames.ToList();
        if (columns.Count == 0)
        {
            throw new InvalidOperationException("There are no mapped properties left to merge for " + typeof(TEntity).Name + ".");
        }

        var propInfos = columns.Select(c => typeof(TEntity).GetProperty(c.Key)).ToList();
        var valueList = new List<object>(columns.Count * entities.Count);

        _sql = new StringBuilder("merge into " + _tableName + " as T using (values ");
        var nextIndex = 0;
        foreach (var entity in entities)
        {
            // One "(@pN,@pN+1,...)" row per entity.
            _sql.Append('(');
            _sql.Append(string.Join(",", Enumerable.Range(nextIndex, columns.Count).Select(i => "@p" + i)));
            _sql.Append("),");
            nextIndex += columns.Count;

            foreach (var info in propInfos)
            {
                valueList.Add(ToParameterValue(info.GetValue(entity), info.PropertyType));
            }
        }

        _sql.Length -= 1; // remove last comma
        _sql.Append(") as S (");
        _sql.Append(string.Join(",", columns.Select(c => c.Value)));
        _sql.Append(") on (");
        // The source alias S exposes mapped column names, so match on columns, not on property names.
        _sql.Append(string.Join(" and ", MatchPropertyNames.Select(ColumnFor).Select(c => "T." + c + "=S." + c)));
        _sql.Append(")");

        // Null values are only skipped for a single entity: with several rows a column can be
        // null in one row and set in another, so it has to stay in the shared UPDATE clause.
        var skipNulls = entities.Count == 1;
        var updateColumns = new List<string>();
        for (var i = 0; i < columns.Count; i++)
        {
            if (_entityPrimaryKeyNames.Contains(columns[i].Key))
            {
                continue;
            }

            // With one entity, valueList[i] is the value of columns[i].
            if (skipNulls && IsNullParameter(valueList[i]))
            {
                continue;
            }

            updateColumns.Add(columns[i].Value);
        }

        // "update set" with no assignments is invalid, so drop the whole clause when nothing is left.
        if (updateColumns.Count > 0)
        {
            _sql.Append(" when matched then update set ");
            _sql.Append(string.Join(",", updateColumns.Select(c => "T." + c + "=S." + c)));
        }

        // Store-generated identity keys are assigned by the database and must not be inserted.
        var insertables = columns
            .Where(c => !_storeGeneratedPrimaryKeyNames.Contains(c.Key))
            .Select(c => c.Value)
            .ToList();
        _sql.Append(" when not matched then insert (");
        _sql.Append(string.Join(",", insertables));
        _sql.Append(") values (S.");
        _sql.Append(string.Join(",S.", insertables));
        _sql.Append(");");

        var command = _sql.ToString();
        return _context.Database.ExecuteSqlCommand(command, valueList.ToArray());
    }

    // Resolves a property name to its bracketed column name; unmapped names are bracketed as-is.
    private string ColumnFor(string propertyName)
    {
        string column;
        return _propNames.TryGetValue(propertyName, out column) ? column : "[" + propertyName + "]";
    }

    // Converts a CLR property value into the object sent as SQL parameter value.
    private static object ToParameterValue(object value, Type propertyType)
    {
        if (value != null)
        {
            return value;
        }

        // Handle types that DBNull doesn't work for: null byte[] is sent as SqlBinary.Null.
        return propertyType == typeof(byte[]) ? (object)SqlBinary.Null : DBNull.Value;
    }

    // True for the null markers produced by ToParameterValue.
    private static bool IsNullParameter(object value)
    {
        if (value is DBNull)
        {
            return true;
        }

        // SqlBinary's == operator yields SqlBoolean.Null when compared with Null, which is never
        // "true"; IsNull is the reliable check.
        return value is SqlBinary && ((SqlBinary)value).IsNull;
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment