かずきのBlog@hatena

日本マイクロソフトに勤めています。XAML + C#の組み合わせをメインに、たまにASP.NETやJavaなどの.NET系以外のことも書いています。掲載内容は個人の見解であり、所属する企業を代表するものではありません。

DLINQ,XLINQみたいにIQueryableを拡張するには(下調べ)

DLINQとかXLINQってSystem.Linq.IQueryableを拡張してみるみたいね!
ってことでどうやるのか方法を探ってみる。

IQueryableを実装すると下の8個のメソッドを実装しろって言ってくる。

  1. IQueryable CreateQuery(System.Linq.Expressions.Expression expression)
  2. TResult Execute(System.Linq.Expressions.Expression expression)
  3. IEnumerator GetEnumerator()
  4. System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
  5. IQueryable CreateQuery(System.Linq.Expressions.Expression expression)
  6. Type ElementType{ get; }
  7. object Execute(System.Linq.Expressions.Expression expression)
  8. System.Linq.Expressions.Expression Expression { get; }

CreateQueryとExecuteとGetEnumeratorはGeneric版と非Generic版でかぶってるから実質5個の実装が必要らしい。


ちょいと調べた結果CreateQueryメソッドとElementTypeプロパティとExpressionプロパティとGetEnumeratorメソッドを実装すればとりあえず動きそうな雰囲気をかもし出してる。
というわけで動きをトレースするために適当な実装をでっちあげてみた。

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
using System.Diagnostics;

namespace QueryableExtension
{
    class Program
    {
        static void Main(string[] args)
        {
            TraceQueryable q = new TraceQueryable();
            Console.WriteLine("before query");
            var r = from i in q
                    where i == 0
                    select i * i;
            Console.WriteLine("after query");
            foreach (var i in r)
            {
                Console.WriteLine(i);
            }
        }
    }

    class TraceQueryable : IQueryable<int>
    {
        public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
        {
            Console.WriteLine(
                "TraceQueryable::CreateQuery<" + typeof(TElement) + ">(" + 
                expression + ")");
            // とりあえず自分を返す
            return (IQueryable<TElement>)this;
        }

        public TResult Execute<TResult>(Expression expression)
        {
            // なんか実装しなくても動く?
            throw new Exception("The method or operation is not implemented.");
        }
        public IEnumerator<int> GetEnumerator()
        {
            Console.WriteLine("TraceQueryable::GetEnumerator");
            
            // 適当
            yield return 0;
        }
        public Type ElementType
        {
            get 
            {
                Console.WriteLine("TraceQueryable::get_ElementType");
                return typeof(int);
            }
        }

        public Expression Expression
        {
            get 
            {
                Console.WriteLine("TraceQueryable::get_Expression");
                return Expression.Constant(this); 
            }
        }

        #region 非Genericな奴等はとりあえずスルー
        public object Execute(Expression expression)
        {
            throw new Exception("The method or operation is not implemented.");
        }
        public IQueryable CreateQuery(Expression expression)
        {
            throw new Exception("The method or operation is not implemented.");
        }
        System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
        {
            throw new Exception("The method or operation is not implemented.");
        }
        #endregion
    }

}

実行すると↓みたいになった

before query
TraceQueryable::get_Expression
TraceQueryable::CreateQuery<System.Int32>(value(QueryableExtension.TraceQueryable).Where(i => (i = 0)))
TraceQueryable::get_Expression
TraceQueryable::CreateQuery<System.Int32>(value(QueryableExtension.TraceQueryable).Select(i => (i * i)))
after query
TraceQueryable::GetEnumerator
0

基本的にExpressionプロパティが呼ばれて、CreateQueryが呼ばれる。
CreateQueryの中身をもうちょっと詳しく見るために↓みたいにして実行してみた。

        public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
        {
            Console.WriteLine(
                "TraceQueryable::CreateQuery<" + typeof(TElement) + ">(" + 
                expression + ")");
            Console.WriteLine("\texpression.GetType(): " +
                expression.GetType());
            Console.WriteLine("\texpression.NodeType: " +
                expression.NodeType);
            Console.WriteLine("\texpression.Type: " + expression.Type);
            Console.WriteLine("\texpression.Method: " + ((MethodCallExpression)expression).Method);
            ConstantExpression arg = (ConstantExpression)((MethodCallExpression)expression).Arguments[0];
            Console.WriteLine("\tArguments[0].NodeType: " + arg.NodeType);
            Console.WriteLine("\tArguments[0].Value: " + arg.Value);
            Console.WriteLine("\tArguments[0].Type: " + arg.Type);
            UnaryExpression arg2 = (UnaryExpression)((MethodCallExpression)expression).Arguments[1];
            Console.WriteLine("\tArguments[1].NodeType: " + arg2.NodeType);
            Console.WriteLine("\tArguments[1].Operand: " + arg2.Operand);
            Console.WriteLine("\tArguments[1].Operand.NodeType: " + arg2.Operand.NodeType);
            Console.WriteLine("\tArguments[1].Type: " + arg2.Type);

            // とりあえず自分を返す
            return (IQueryable<TElement>)this;
        }

実行すると↓になった。

before query
TraceQueryable::get_Expression
TraceQueryable::CreateQuery<System.Int32>(value(QueryableExtension.TraceQueryabl
e).Where(i => (i = 0)))
        expression.GetType(): System.Linq.Expressions.MethodCallExpression
        expression.NodeType: Call
        expression.Type: System.Linq.IQueryable`1[System.Int32]
        expression.Method: System.Linq.IQueryable`1[System.Int32] Where[Int32](S
ystem.Linq.IQueryable`1[System.Int32], System.Linq.Expressions.Expression`1[Syst
em.Linq.Func`2[System.Int32,System.Boolean]])
        Arguments[0].NodeType: Constant
        Arguments[0].Value: QueryableExtension.TraceQueryable
        Arguments[0].Type: QueryableExtension.TraceQueryable
        Arguments[1].NodeType: Quote
        Arguments[1].Operand: i => (i = 0)
        Arguments[1].Operand.NodeType: Lambda
        Arguments[1].Type: System.Linq.Expressions.Expression`1[System.Linq.Func
`2[System.Int32,System.Boolean]]
TraceQueryable::get_Expression
TraceQueryable::CreateQuery<System.Int32>(value(QueryableExtension.TraceQueryabl
e).Select(i => (i * i)))
        expression.GetType(): System.Linq.Expressions.MethodCallExpression
        expression.NodeType: Call
        expression.Type: System.Linq.IQueryable`1[System.Int32]
        expression.Method: System.Linq.IQueryable`1[System.Int32] Select[Int32,I
nt32](System.Linq.IQueryable`1[System.Int32], System.Linq.Expressions.Expression
`1[System.Linq.Func`2[System.Int32,System.Int32]])
        Arguments[0].NodeType: Constant
        Arguments[0].Value: QueryableExtension.TraceQueryable
        Arguments[0].Type: QueryableExtension.TraceQueryable
        Arguments[1].NodeType: Quote
        Arguments[1].Operand: i => (i * i)
        Arguments[1].Operand.NodeType: Lambda
        Arguments[1].Type: System.Linq.Expressions.Expression`1[System.Linq.Func
`2[System.Int32,System.Int32]]
after query
TraceQueryable::GetEnumerator
0

どうも、CreateQueryにわたってくるExpressionは、今のところMethodCallExpressionみたい。
MethodCallExpressionのArgumentsには、2つの値が詰まってる。
1つめは、Expressionプロパティで返したもの。
2つめは、whereやselectに渡されたラムダ式に該当するもの。

う〜ん、これを解析して頑張れってことか?
メンドクサソウ…