Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Dapper.sln
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution
version.json = version.json
EndProjectSection
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Dapper.Contrib", "Dapper.Contrib\Dapper.Contrib.csproj", "{4E409F8F-CFBB-4332-8B0A-FD5A283051FD}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Dapper.Contrib", "src\Dapper.Contrib\Dapper.Contrib.csproj", "{4E409F8F-CFBB-4332-8B0A-FD5A283051FD}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Dapper.Tests.Contrib", "tests\Dapper.Tests.Contrib\Dapper.Tests.Contrib.csproj", "{DAB3C5B7-BCD1-4A5F-BB6B-50D2BB63DB4A}"
EndProject
Expand Down
108 changes: 108 additions & 0 deletions src/Dapper.Contrib/SqlMapperExtensions.Async.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Data;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
Expand Down Expand Up @@ -267,6 +268,113 @@ public static async Task<bool> UpdateAsync<T>(this IDbConnection connection, T e
var updated = await connection.ExecuteAsync(sb.ToString(), entityToUpdate, commandTimeout: commandTimeout, transaction: transaction).ConfigureAwait(false);
return updated > 0;
}

/// <summary>
/// Updates entity in table "Ts" asynchronously using Task, checks if the entity is modified if the entity is tracked by the Get() extension.
/// </summary>
/// <typeparam name="T">Type to be updated</typeparam>
/// <param name="connection">Open SqlConnection</param>
/// <param name="entityToUpdate">Entity to be updated</param>
/// <param name="transaction">The transaction to run under, null (the default) if none</param>
/// <param name="commandTimeout">Number of seconds before command execution timeout</param>
/// <returns>true if updated, false if not found or not modified (tracked entities)</returns>
public static async Task<bool> UpdateOnlyAsync<T>(this IDbConnection connection, T entityToUpdate, IDbTransaction transaction = null, int? commandTimeout = null, params Expression<Func<T, object>>[] includedFields) where T : class
{
if ((entityToUpdate is IProxy proxy) && !proxy.IsDirty)
{
return false;
}

var type = typeof(T);

if (type.IsArray)
{
type = type.GetElementType();
}
else if (type.IsGenericType)
{
var typeInfo = type.GetTypeInfo();
bool implementsGenericIEnumerableOrIsGenericIEnumerable =
typeInfo.ImplementedInterfaces.Any(ti => ti.IsGenericType && ti.GetGenericTypeDefinition() == typeof(IEnumerable<>)) ||
typeInfo.GetGenericTypeDefinition() == typeof(IEnumerable<>);

if (implementsGenericIEnumerableOrIsGenericIEnumerable)
{
type = type.GetGenericArguments()[0];
}
}

var keyProperties = KeyPropertiesCache(type).ToList();
var explicitKeyProperties = ExplicitKeyPropertiesCache(type);
if (keyProperties.Count == 0 && explicitKeyProperties.Count == 0)
throw new ArgumentException("Entity must have at least one [Key] or [ExplicitKey] property");

var name = GetTableName(type);

var sb = new StringBuilder();
sb.AppendFormat("update {0} set ", name);

var allProperties = includedFields.Select(ToPropertyInfo).ToList();
keyProperties.AddRange(explicitKeyProperties);
var computedProperties = ComputedPropertiesCache(type);
var nonIdProps = allProperties.Except(keyProperties.Union(computedProperties)).ToList();

var adapter = GetFormatter(connection);

for (var i = 0; i < nonIdProps.Count; i++)
{
var property = nonIdProps[i];
adapter.AppendColumnNameEqualsValue(sb, property.Name);
if (i < nonIdProps.Count - 1)
sb.Append(", ");
}
sb.Append(" where ");
for (var i = 0; i < keyProperties.Count; i++)
{
var property = keyProperties[i];
adapter.AppendColumnNameEqualsValue(sb, property.Name);
if (i < keyProperties.Count - 1)
sb.Append(" and ");
}
var updated = await connection.ExecuteAsync(sb.ToString(), entityToUpdate, commandTimeout: commandTimeout, transaction: transaction).ConfigureAwait(false);
return updated > 0;
}

private static PropertyInfo ToPropertyInfo<T>(Expression<Func<T, object>> propertyLambda) where T : class
{
var type = typeof(T);

MemberExpression member;

if (propertyLambda.Body is UnaryExpression unary)
{
member = unary.Operand as MemberExpression;
}
else
{
member = propertyLambda.Body as MemberExpression;
}

if (member == null)
throw new ArgumentException(string.Format(
"Expression '{0}' refers to a method, not a property.",
propertyLambda.ToString()));

PropertyInfo propInfo = member.Member as PropertyInfo;
if (propInfo == null)
throw new ArgumentException(string.Format(
"Expression '{0}' refers to a field, not a property.",
propertyLambda.ToString()));

if (type != propInfo.ReflectedType &&
!type.IsSubclassOf(propInfo.ReflectedType))
throw new ArgumentException(string.Format(
"Expression '{0}' refers to a property that is not from type {1}.",
propertyLambda.ToString(),
type));

return propInfo;
}

/// <summary>
/// Delete entity in table "Ts" asynchronously using Task.
Expand Down
43 changes: 43 additions & 0 deletions tests/Dapper.Tests.Contrib/TestSuite.Async.cs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,50 @@ public async Task TestSimpleGetAsync()
await connection.DeleteAsync(user).ConfigureAwait(false);
}
}

[Fact]
public async Task InsertGetUpdateOnlyAsync()
{
using (var connection = GetOpenConnection())
{
Assert.Null(await connection.GetAsync<User>(30).ConfigureAwait(false));

var originalCount = (await connection.QueryAsync<int>("select Count(*) from Users").ConfigureAwait(false)).First();

var id = await connection.InsertAsync(new User { Name = "Adam", Age = 10 }).ConfigureAwait(false);

//get a user with "isdirty" tracking
var user = await connection.GetAsync<IUser>(id).ConfigureAwait(false);
Assert.Equal("Adam", user.Name);
Assert.False(await connection.UpdateAsync(user).ConfigureAwait(false)); //returns false if not updated, based on tracking
user.Name = "Bob";
Assert.True(await connection.UpdateAsync(user).ConfigureAwait(false)); //returns true if updated, based on tracking
user.Name = "Dave";
user.Age = 43;
Assert.True(await connection.UpdateOnlyAsync(user, includedFields: x => x.Age).ConfigureAwait(false)); //returns true if updated, based on tracking
user = await connection.GetAsync<IUser>(id).ConfigureAwait(false);
Assert.Equal("Bob", user.Name); // Name isn't updated
Assert.Equal(43, user.Age); // Age is updated

//get a user with no tracking
var notrackedUser = await connection.GetAsync<User>(id).ConfigureAwait(false);
Assert.Equal("Bob", notrackedUser.Name);
Assert.True(await connection.UpdateAsync(notrackedUser).ConfigureAwait(false));
//returns true, even though user was not changed
notrackedUser.Name = "Cecil";
Assert.True(await connection.UpdateAsync(notrackedUser).ConfigureAwait(false));
Assert.Equal("Cecil", (await connection.GetAsync<User>(id).ConfigureAwait(false)).Name);

Assert.Equal((await connection.QueryAsync<User>("select * from Users").ConfigureAwait(false)).Count(), originalCount + 1);
Assert.True(await connection.DeleteAsync(user).ConfigureAwait(false));
Assert.Equal((await connection.QueryAsync<User>("select * from Users").ConfigureAwait(false)).Count(), originalCount);

Assert.False(await connection.UpdateAsync(notrackedUser).ConfigureAwait(false)); //returns false, user not found

Assert.True(await connection.InsertAsync(new User { Name = "Adam", Age = 10 }).ConfigureAwait(false) > originalCount + 1);
}
}

[Fact]
public async Task InsertGetUpdateAsync()
{
Expand Down