Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change priority to user type readers and implement user entity readers #1487

Open
wants to merge 6 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -275,24 +275,23 @@ private static void BuildParameter(ParameterBuilder builder, System.Reflection.P

if (builder.TypeReader == null)
{
builder.TypeReader = service.GetDefaultTypeReader(paramType)
?? service.GetTypeReaders(paramType)?.FirstOrDefault().Value;
builder.TypeReader = service.GetTypeReaders(paramType, false)?.FirstOrDefault().Value
?? service.GetDefaultTypeReader(paramType);
}
}

internal static TypeReader GetTypeReader(CommandService service, Type paramType, Type typeReaderType, IServiceProvider services)
{
var readers = service.GetTypeReaders(paramType);
TypeReader reader = null;
var readers = service.GetTypeReaders(paramType, true);
if (readers != null)
{
if (readers.TryGetValue(typeReaderType, out reader))
return reader;
}
foreach (var kvp in readers)
if (kvp.Key == typeReaderType)
return kvp.Value;

//We dont have a cached type reader, create one
reader = ReflectionUtils.CreateObject<TypeReader>(typeReaderType.GetTypeInfo(), service, services);
service.AddTypeReader(paramType, reader, false);
TypeReader reader = ReflectionUtils.CreateObject<TypeReader>(typeReaderType.GetTypeInfo(), service, services);
reader.IsOverride = true;
service.AddTypeReader(paramType, reader);

return reader;
}
Expand Down
5 changes: 2 additions & 3 deletions src/Discord.Net.Commands/Builders/ParameterBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ private TypeReader GetReader(Type type)
if (type.GetTypeInfo().GetCustomAttribute<NamedArgumentTypeAttribute>() != null)
{
IsRemainder = true;
var reader = commands.GetTypeReaders(type)?.FirstOrDefault().Value;
var reader = commands.GetTypeReaders(type, false)?.FirstOrDefault().Value;
if (reader == null)
{
Type readerType;
Expand All @@ -80,8 +80,7 @@ private TypeReader GetReader(Type type)
return reader;
}


var readers = commands.GetTypeReaders(type);
var readers = commands.GetTypeReaders(type, false);
if (readers != null)
return readers.FirstOrDefault().Value;
else
Expand Down
123 changes: 91 additions & 32 deletions src/Discord.Net.Commands/CommandService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public class CommandService : IDisposable
private readonly SemaphoreSlim _moduleLock;
private readonly ConcurrentDictionary<Type, ModuleInfo> _typedModuleDefs;
private readonly ConcurrentDictionary<Type, ConcurrentDictionary<Type, TypeReader>> _typeReaders;
private readonly ConcurrentDictionary<Type, ConcurrentQueue<Type>> _userEntityTypeReaders;
private readonly ConcurrentDictionary<Type, TypeReader> _defaultTypeReaders;
private readonly ImmutableList<(Type EntityType, Type TypeReaderType)> _entityTypeReaders;
private readonly HashSet<ModuleInfo> _moduleDefs;
Expand Down Expand Up @@ -77,6 +78,15 @@ public class CommandService : IDisposable
/// </summary>
public ILookup<Type, TypeReader> TypeReaders => _typeReaders.SelectMany(x => x.Value.Select(y => new { y.Key, y.Value })).ToLookup(x => x.Key, x => x.Value);

/// <summary>
/// Represents all entity type reader <see cref="Type" />s loaded within <see cref="CommandService"/>.
/// </summary>
/// <returns>
/// A <see cref="ILookup{TKey, TElement}"/>; the key is the object type to be read by the <see cref="TypeReader"/>,
/// while the element is the type of the <see cref="TypeReader"/> generic definition.
/// </returns>
public ILookup<Type, Type> EntityTypeReaders => _userEntityTypeReaders.SelectMany(x => x.Value.Select(y => new { x.Key, TypeReaderType = y })).ToLookup(x => x.Key, y => y.TypeReaderType);

/// <summary>
/// Initializes a new <see cref="CommandService"/> class.
/// </summary>
Expand Down Expand Up @@ -109,6 +119,7 @@ public CommandService(CommandServiceConfig config)
_moduleDefs = new HashSet<ModuleInfo>();
_map = new CommandMap(this);
_typeReaders = new ConcurrentDictionary<Type, ConcurrentDictionary<Type, TypeReader>>();
_userEntityTypeReaders = new ConcurrentDictionary<Type, ConcurrentQueue<Type>>();

_defaultTypeReaders = new ConcurrentDictionary<Type, TypeReader>();
foreach (var type in PrimitiveParsers.SupportedTypes)
Expand Down Expand Up @@ -329,8 +340,6 @@ private bool RemoveModuleInternal(ModuleInfo module)
/// type.
/// If <typeparamref name="T" /> is a <see cref="ValueType" />, a nullable <see cref="TypeReader" /> will
/// also be added.
/// If a default <see cref="TypeReader" /> exists for <typeparamref name="T" />, a warning will be logged
/// and the default <see cref="TypeReader" /> will be replaced.
/// </summary>
/// <typeparam name="T">The object type to be read by the <see cref="TypeReader"/>.</typeparam>
/// <param name="reader">An instance of the <see cref="TypeReader" /> to be added.</param>
Expand All @@ -341,32 +350,70 @@ public void AddTypeReader<T>(TypeReader reader)
/// type.
/// If <paramref name="type" /> is a <see cref="ValueType" />, a nullable <see cref="TypeReader" /> for the
/// value type will also be added.
/// If a default <see cref="TypeReader" /> exists for <paramref name="type" />, a warning will be logged and
/// the default <see cref="TypeReader" /> will be replaced.
/// </summary>
/// <param name="type">A <see cref="Type" /> instance for the type to be read.</param>
/// <param name="reader">An instance of the <see cref="TypeReader" /> to be added.</param>
public void AddTypeReader(Type type, TypeReader reader)
{
if (_defaultTypeReaders.ContainsKey(type))
_ = _cmdLogger.WarningAsync($"The default TypeReader for {type.FullName} was replaced by {reader.GetType().FullName}." +
"To suppress this message, use AddTypeReader<T>(reader, true).");
AddTypeReader(type, reader, true);
var readers = _typeReaders.GetOrAdd(type, x => new ConcurrentDictionary<Type, TypeReader>());
readers[reader.GetType()] = reader;

if (type.GetTypeInfo().IsValueType)
AddNullableTypeReader(type, reader);
}
/// <summary>
/// Adds a custom entity <see cref="TypeReader" /> to this <see cref="CommandService" /> for the supplied
/// object type.
/// </summary>
/// <example>
/// <para>The following example adds a custom entity reader to this <see cref="CommandService"/>.</para>
/// <code language="cs" region="AddEntityTypeReader"
/// source="..\Discord.Net.Examples\Commands\CommandService.Examples.cs" />
/// </example>
/// <typeparam name="T">The object type to be read by the <see cref="TypeReader"/>.</typeparam>
/// <param name="typeReaderGenericType">A generic type definition (with one open argument) of the <see cref="TypeReader" />.</param>
public void AddEntityTypeReader<T>(Type typeReaderGenericType)
=> AddEntityTypeReader(typeof(T), typeReaderGenericType);
/// <summary>
/// Adds a custom entity <see cref="TypeReader" /> to this <see cref="CommandService" /> for the supplied
/// object type.
/// </summary>
/// <param name="type">A <see cref="Type" /> instance for the type to be read.</param>
SubZero0 marked this conversation as resolved.
Show resolved Hide resolved
/// <param name="typeReaderGenericType">A generic type definition (with one open argument) of the <see cref="TypeReader" />.</param>
public void AddEntityTypeReader(Type type, Type typeReaderGenericType)
{
if (!typeReaderGenericType.IsGenericTypeDefinition)
throw new ArgumentException("TypeReader type must be a generic type definition.", nameof(typeReaderGenericType));
Type[] genericArgs = typeReaderGenericType.GetGenericArguments();
if (genericArgs.Length != 1)
throw new ArgumentException("TypeReader type must accept one and only one open generic argument.", nameof(typeReaderGenericType));
if (!genericArgs[0].IsGenericParameter)
throw new ArgumentException("TypeReader type must accept one and only one open generic argument.", nameof(typeReaderGenericType));
if (!genericArgs[0].GenericParameterAttributes.HasFlag(GenericParameterAttributes.ReferenceTypeConstraint))
throw new ArgumentException("TypeReader generic argument must have a reference type constraint.", nameof(typeReaderGenericType));
var readers = _userEntityTypeReaders.GetOrAdd(type, x => new ConcurrentQueue<Type>());
readers.Enqueue(typeReaderGenericType);
}
/// <summary>
/// Adds a custom <see cref="TypeReader" /> to this <see cref="CommandService" /> for the supplied object
/// type.
/// If <typeparamref name="T" /> is a <see cref="ValueType" />, a nullable <see cref="TypeReader" /> will
/// also be added.
/// </summary>
/// <example>
/// <para>The following example adds a custom entity reader to this <see cref="CommandService"/>.</para>
/// <code language="cs" region="AddEntityTypeReader2"
/// source="..\Discord.Net.Examples\Commands\CommandService.Examples.cs" />
/// </example>
/// <typeparam name="T">The object type to be read by the <see cref="TypeReader"/>.</typeparam>
/// <param name="reader">An instance of the <see cref="TypeReader" /> to be added.</param>
/// <param name="replaceDefault">
/// Defines whether the <see cref="TypeReader"/> should replace the default one for
/// <see cref="Type" /> if it exists.
/// </param>
[Obsolete("This method is deprecated. Use the method without the replaceDefault argument.")]
public void AddTypeReader<T>(TypeReader reader, bool replaceDefault)
=> AddTypeReader(typeof(T), reader, replaceDefault);
=> AddTypeReader(typeof(T), reader);
/// <summary>
/// Adds a custom <see cref="TypeReader" /> to this <see cref="CommandService" /> for the supplied object
/// type.
Expand All @@ -379,27 +426,10 @@ public void AddTypeReader<T>(TypeReader reader, bool replaceDefault)
/// Defines whether the <see cref="TypeReader"/> should replace the default one for <see cref="Type" /> if
/// it exists.
/// </param>
[Obsolete("This method is deprecated. Use the method without the replaceDefault argument.")]
public void AddTypeReader(Type type, TypeReader reader, bool replaceDefault)
{
if (replaceDefault && HasDefaultTypeReader(type))
{
_defaultTypeReaders.AddOrUpdate(type, reader, (k, v) => reader);
if (type.GetTypeInfo().IsValueType)
{
var nullableType = typeof(Nullable<>).MakeGenericType(type);
var nullableReader = NullableTypeReader.Create(type, reader);
_defaultTypeReaders.AddOrUpdate(nullableType, nullableReader, (k, v) => nullableReader);
}
}
else
{
var readers = _typeReaders.GetOrAdd(type, x => new ConcurrentDictionary<Type, TypeReader>());
readers[reader.GetType()] = reader;
=> AddTypeReader(type, reader);

if (type.GetTypeInfo().IsValueType)
AddNullableTypeReader(type, reader);
}
}
internal bool HasDefaultTypeReader(Type type)
{
if (_defaultTypeReaders.ContainsKey(type))
Expand All @@ -408,18 +438,47 @@ internal bool HasDefaultTypeReader(Type type)
var typeInfo = type.GetTypeInfo();
if (typeInfo.IsEnum)
return true;
return _entityTypeReaders.Any(x => type == x.EntityType || typeInfo.ImplementedInterfaces.Contains(x.TypeReaderType));
return _entityTypeReaders.Any(x => type == x.EntityType || typeInfo.ImplementedInterfaces.Contains(x.EntityType));
}
internal void AddNullableTypeReader(Type valueType, TypeReader valueTypeReader)
{
var readers = _typeReaders.GetOrAdd(typeof(Nullable<>).MakeGenericType(valueType), x => new ConcurrentDictionary<Type, TypeReader>());
var nullableReader = NullableTypeReader.Create(valueType, valueTypeReader);
readers[nullableReader.GetType()] = nullableReader;
}
internal IDictionary<Type, TypeReader> GetTypeReaders(Type type)
internal IEnumerable<KeyValuePair<Type, TypeReader>> GetTypeReaders(Type type, bool includeOverride)
{
if (_typeReaders.TryGetValue(type, out var definedTypeReaders))
return definedTypeReaders;
return includeOverride ? definedTypeReaders : definedTypeReaders.Where(x => !x.Value.IsOverride);

var assignableEntityReaders = _userEntityTypeReaders.Where(x => x.Key.IsAssignableFrom(type));

int assignableTo = -1;
KeyValuePair<Type, ConcurrentQueue<Type>>? entityReaders = null;
foreach (var entityReader in assignableEntityReaders)
{
int assignables = assignableEntityReaders.Sum(x => !x.Equals(entityReader) && x.Key.IsAssignableFrom(entityReader.Key) ? 1 : 0);
if (assignableTo == -1)
{
// First time
assignableTo = assignables;
entityReaders = entityReader;
}
// Try to get the most specific type reader, i.e. IMessageChannel is assignable to IChannel, but not the inverse
else if (assignables > assignableTo)
{
assignableTo = assignables;
entityReaders = entityReader;
}
}

if (entityReaders != null)
{
var entityTypeReaderType = entityReaders.Value.Value.First();
TypeReader reader = Activator.CreateInstance(entityTypeReaderType.MakeGenericType(type)) as TypeReader;
AddTypeReader(type, reader);
return GetTypeReaders(type, false);
}
return null;
}
internal TypeReader GetDefaultTypeReader(Type type)
Expand Down Expand Up @@ -511,7 +570,7 @@ public async Task<IResult> ExecuteAsync(ICommandContext context, string input, I
await _commandExecutedEvent.InvokeAsync(Optional.Create<CommandInfo>(), context, searchResult).ConfigureAwait(false);
return searchResult;
}


var commands = searchResult.Commands;
var preconditionResults = new Dictionary<CommandMatch, PreconditionResult>();
Expand Down
4 changes: 2 additions & 2 deletions src/Discord.Net.Commands/Readers/NamedArgumentTypeReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ async Task<object> ReadArgumentAsync(PropertyInfo prop, string arg)
var overridden = prop.GetCustomAttribute<OverrideTypeReaderAttribute>();
var reader = (overridden != null)
? ModuleClassBuilder.GetTypeReader(_commands, elemType, overridden.TypeReader, services)
: (_commands.GetDefaultTypeReader(elemType)
?? _commands.GetTypeReaders(elemType).FirstOrDefault().Value);
: (_commands.GetTypeReaders(elemType, false)?.FirstOrDefault().Value
?? _commands.GetDefaultTypeReader(elemType));

if (reader != null)
{
Expand Down
1 change: 1 addition & 0 deletions src/Discord.Net.Commands/Readers/TypeReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace Discord.Commands
/// </summary>
public abstract class TypeReader
{
internal bool IsOverride { get; set; } = false;
/// <summary>
/// Attempts to parse the <paramref name="input"/> into the desired type.
/// </summary>
Expand Down
55 changes: 55 additions & 0 deletions src/Discord.Net.Examples/Commands/CommandService.Examples.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
using Discord.Commands;
using JetBrains.Annotations;
using System;
using System.Threading.Tasks;

namespace Discord.Net.Examples.Commands
{
[PublicAPI]
internal class CommandServiceExamples
{
#region AddEntityTypeReader

public void AddCustomUserEntityReader(CommandService commandService)
{
commandService.AddEntityTypeReader<IUser>(typeof(MyUserTypeReader<>));
}

public class MyUserTypeReader<T> : TypeReader
where T : class, IUser
{
public override async Task<TypeReaderResult> ReadAsync(ICommandContext context, string input, IServiceProvider services)
{
if (ulong.TryParse(input, out var id))
return ((await context.Client.GetUserAsync(id)) is T user)
? TypeReaderResult.FromSuccess(user)
: TypeReaderResult.FromError(CommandError.ObjectNotFound, "User not found.");
return TypeReaderResult.FromError(CommandError.ParseFailed, "Couldn't parse input to ulong.");
}
}

#endregion

#region AddEntityTypeReader2

public void AddCustomChannelEntityReader(CommandService commandService)
{
commandService.AddEntityTypeReader<IUser>(typeof(MyUserTypeReader<>));
}

public class MyChannelTypeReader<T> : TypeReader
where T : class, IChannel
{
public override async Task<TypeReaderResult> ReadAsync(ICommandContext context, string input, IServiceProvider services)
{
if (ulong.TryParse(input, out var id))
return ((await context.Client.GetChannelAsync(id)) is T channel)
? TypeReaderResult.FromSuccess(channel)
: TypeReaderResult.FromError(CommandError.ObjectNotFound, "Channel not found.");
return TypeReaderResult.FromError(CommandError.ParseFailed, "Couldn't parse input to ulong.");
}
}

#endregion
}
}
1 change: 1 addition & 0 deletions src/Discord.Net.Examples/Discord.Net.Examples.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\Discord.Net.Commands\Discord.Net.Commands.csproj" />
<ProjectReference Include="..\Discord.Net.Core\Discord.Net.Core.csproj" />
<ProjectReference Include="..\Discord.Net.WebSocket\Discord.Net.WebSocket.csproj" />
<PackageReference Include="JetBrains.Annotations" Version="2019.1.3" />
Expand Down