Sunday, June 23, 2013

Implementing custom iterators with delayed execution

Those of you using spring4d will be familiar with methods like IEnumerable<T>.Where, which returns only the elements of an enumerable (usually a list) that match a certain condition passed as a delegate.

Today we will take a look at how that is achieved and which requirements such a method should meet. The method itself looks like this:

function Where(const predicate: TPredicate<T>): IEnumerable<T>;

Since it returns an enumerable itself, you can chain multiple Where calls and get an enumerable that contains only the elements matching every condition.

The filtering is deferred: as long as we don't iterate that enumerable (for example in a for-in loop), nothing actually happens. You can even still modify the list you called Where on, and those changes will be taken into account once the iteration starts. So calling Where just creates another enumerable object that stores the information needed to perform the filtering later: the source and the delegate.

The execution is also streamed. That means the elements of the original source are iterated only once, and only when they are needed. So if you, for example, leave the loop early, the remaining elements of the source enumerable are never iterated.

In the spring4d branch there is a new class called TIterator<T> that implements both IEnumerable<T> and IEnumerator<T>. Usually you have separate classes for the enumerable and the enumerator, but in our case the enumerable only stores the state for the operation and passes that information on to the enumerator class, which does the real work in its MoveNext method. So we can get rid of the duplicated information and put everything into the same class.

So this is what it looks like:

type
  TIteratorBase<T> = class(TEnumerableBase<T>, IEnumerator)
  protected
    function GetCurrentNonGeneric: TValue; virtual; abstract;
    function IEnumerator.GetCurrent = GetCurrentNonGeneric;
  public
    function MoveNext: Boolean; virtual;
    procedure Reset; virtual;
  end;

  TIterator<T> = class(TIteratorBase<T>, IEnumerator<T>)
  private
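    fThreadId: Cardinal; // compared against the current thread in GetEnumerator
  protected
    fState: Integer;     // 0 = not enumerated yet, 1 = handed out as enumerator;
                         // MoveNext implementations use further values (e.g. 2, -1)
    fCurrent: T;         // element returned by GetCurrent
  protected
    function GetCurrent: T;
    function GetCurrentNonGeneric: TValue; override;
  public
    constructor Create; override;
    function Clone: TIterator<T>; virtual; abstract;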
    function GetEnumerator: IEnumerator<T>; override;
  end;

Let's look at the GetEnumerator method:

function TIterator<T>.GetEnumerator: IEnumerator<T>;
var
  iterator: TIterator<T>;
begin
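  if (fThreadId = TThread.CurrentThread.ThreadID) and (fState = 0) then
  begin
    // same thread and first iteration: reuse this instance
    fState := 1;
    Result := Self;
  end
  else
  begin
    // different thread or already enumerated: hand out a fresh copy
    iterator := Clone;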
    iterator.fState := 1;
    Result := iterator;
  end;
end;

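This ensures that you get a new instance as the enumerator when necessary, that is, when the calling thread differs from fThreadId or the iterator has already been enumerated. If the thread matches and this is the first iteration, it just returns itself (which is the most common case).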
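To see that decision in isolation, here is a minimal console program. It is a sketch, not spring4d code: TDemoIterator is a hypothetical stripped-down class that keeps only the two fields GetEnumerator looks at, and because it is a plain class rather than an interfaced one, the caller frees any clone it receives. It assumes Delphi XE2 or later (for TThreadID, TThread.CurrentThread and TThread.CreateAnonymousThread).

```delphi
program ReuseOrClone;

{$APPTYPE CONSOLE}

uses
  System.Classes;

type
  // Hypothetical stripped-down iterator: only the state GetEnumerator inspects.
  TDemoIterator = class
  private
    fThreadId: TThreadID;
    fState: Integer;
  public
    constructor Create;
    function Clone: TDemoIterator;
    function GetEnumerator: TDemoIterator;
  end;

constructor TDemoIterator.Create;
begin
  inherited Create;
  fThreadId := TThread.CurrentThread.ThreadID; // remember the creating thread
end;

function TDemoIterator.Clone: TDemoIterator;
begin
  Result := TDemoIterator.Create; // a real iterator would also copy source and delegate
end;

function TDemoIterator.GetEnumerator: TDemoIterator;
begin
  if (fThreadId = TThread.CurrentThread.ThreadID) and (fState = 0) then
    Result := Self   // same thread, first enumeration: reuse this object
  else
    Result := Clone; // otherwise an independent copy (freed by the caller here)
  Result.fState := 1;
end;

var
  it, fromWorker, first, second: TDemoIterator;
  worker: TThread;
begin
  it := TDemoIterator.Create;
  worker := TThread.CreateAnonymousThread(
    procedure
    begin
      fromWorker := it.GetEnumerator; // runs on another thread, while fState is 0
    end);
  worker.FreeOnTerminate := False;
  worker.Start;
  worker.WaitFor;
  worker.Free;
  first := it.GetEnumerator;  // creating thread, fState is still 0
  second := it.GetEnumerator; // creating thread, but already handed out
  Writeln('worker got Self: ', fromWorker = it);
  Writeln('first got Self:  ', first = it);
  Writeln('second got Self: ', second = it);
  fromWorker.Free;
  second.Free;
  it.Free;
end.
```

Only the second line should report true: the worker thread receives a clone even though it asks first, and the second request on the creating thread receives a clone because the original has already been handed out.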

To implement the Where operation we basically only need to implement three methods: Create, Clone and MoveNext. Let's take a look at the class for that:

type
  TWhereEnumerableIterator<T> = class(TIterator<T>)
  private
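    fSource: IEnumerable<T>;     // the enumerable Where was called on
    fPredicate: TPredicate<T>;   // the filter condition
    fEnumerator: IEnumerator<T>; // enumerator of fSource, obtained in the first MoveNext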
  public
    constructor Create(const source: IEnumerable<T>;
      const predicate: TPredicate<T>);
    function Clone: TIterator<T>; override;
    function MoveNext: Boolean; override;
  end;

As mentioned above, the real work is done in the MoveNext method; the constructor and the Clone method are a no-brainer.

function TWhereEnumerableIterator<T>.MoveNext: Boolean;
var
  current: T;
begin
  Result := False;

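  if fState = 1 then
  begin
    // first call: only now is the source enumerator requested
    fEnumerator := fSource.GetEnumerator;
    fState := 2;
  end;

  if fState = 2 then
  begin
    // advance the source only until the next matching element
    while fEnumerator.MoveNext do
    begin
      current := fEnumerator.Current;
      if fPredicate(current) then
      begin
        fCurrent := current;
        Exit(True);
      end;
    end;
    // source exhausted: mark as finished and release the source enumerator
    fState := -1;
    fEnumerator := nil;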
  end;
end;

Now this is the really interesting part. As the requirements demand, the source is not iterated as long as we do not iterate the filtered enumerable. Only when MoveNext is called for the first time do we get the enumerator from the source, and it is only advanced as far as needed to find the next match.

Since Delphi does not support a yield syntax and compiler-generated iterator blocks like C# does, this code looks a bit more complicated than it would with such syntax support. Still, I think it is easy enough to understand, and implementing these iterators is pretty easy, since you only need to implement the MoveNext method (apart from a pretty straightforward constructor and Clone method).
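To observe the deferred and streamed behavior described earlier without spring4d, here is a self-contained console sketch. IMiniEnumerable<T>, IMiniEnumerator<T> and TWhereIterator<T> are hypothetical names standing in for the spring4d types. The TIterator<T> base class is folded into TWhereIterator<T> to keep it short, so the Clone step is the TWhereIterator<T>.Create call inside GetEnumerator. The source is a plain RTL TList<T>, whose class-based enumerator the iterator has to free itself. The predicate counts its calls so you can see how much of the source was actually inspected. Assumes Delphi XE2 or later.

```delphi
program WhereIteratorSketch;
{$APPTYPE CONSOLE}
uses System.SysUtils, System.Classes, System.Generics.Collections;
type
  IMiniEnumerator<T> = interface // hypothetical stand-in for spring4d's IEnumerator<T>
    function GetCurrent: T;
    function MoveNext: Boolean;
    property Current: T read GetCurrent;
  end;
  IMiniEnumerable<T> = interface // hypothetical stand-in for spring4d's IEnumerable<T>
    function GetEnumerator: IMiniEnumerator<T>;
  end;
  TWhereIterator<T> = class(TInterfacedObject, IMiniEnumerable<T>, IMiniEnumerator<T>)
  private
    fThreadId: TThreadID;
    fState: Integer;             // 0 fresh, 1 handed out, 2 running, -1 finished
    fCurrent: T;
    fSource: TList<T>;           // not owned by the iterator
    fPredicate: TPredicate<T>;
    fEnumerator: TEnumerator<T>; // owned; created on the first MoveNext
    function GetCurrent: T;
  public
    constructor Create(source: TList<T>; const predicate: TPredicate<T>);
    destructor Destroy; override;
    function GetEnumerator: IMiniEnumerator<T>;
    function MoveNext: Boolean;
  end;

constructor TWhereIterator<T>.Create(source: TList<T>; const predicate: TPredicate<T>);
begin
  inherited Create;
  fThreadId := TThread.CurrentThread.ThreadID;
  fSource := source; // only stored: no filtering happens here
  fPredicate := predicate;
end;

destructor TWhereIterator<T>.Destroy;
begin
  fEnumerator.Free; // non-nil only if a loop was left early
  inherited;
end;
function TWhereIterator<T>.GetCurrent: T;
begin
  Result := fCurrent;
end;
function TWhereIterator<T>.GetEnumerator: IMiniEnumerator<T>;
var
  iterator: TWhereIterator<T>;
begin
  iterator := Self; // same thread and first enumeration: reuse this object
  if (fThreadId <> TThread.CurrentThread.ThreadID) or (fState <> 0) then
    iterator := TWhereIterator<T>.Create(fSource, fPredicate); // the Clone step
  iterator.fState := 1;
  Result := iterator;
end;

function TWhereIterator<T>.MoveNext: Boolean;
begin
  Result := False;
  if fState = 1 then
  begin
    fEnumerator := fSource.GetEnumerator; // the source is touched only now
    fState := 2;
  end;
  if fState <> 2 then Exit;
  while fEnumerator.MoveNext do // advance the source only as far as needed
    if fPredicate(fEnumerator.Current) then
    begin
      fCurrent := fEnumerator.Current;
      Exit(True);
    end;
  fState := -1;
  FreeAndNil(fEnumerator);
end;

var
  list: TList<Integer>;
  evens: IMiniEnumerable<Integer>;
  checked, n: Integer; // globals start at 0
begin
  list := TList<Integer>.Create;
  list.AddRange([1, 3, 5]);
  evens := TWhereIterator<Integer>.Create(list,
    function(x: Integer): Boolean
    begin
      Inc(checked); // counts how many source items were inspected
      Result := x mod 2 = 0;
    end);
  list.AddRange([2, 4, 6]); // added after the filter was created
  Writeln('inspected before loop: ', checked);
  for n in evens do
  begin
    Writeln('first match: ', n);
    Break; // leave early: 4 and 6 are never inspected
  end;
  Writeln('inspected after loop: ', checked);
  evens := nil; // release the iterator before the list it reads from
  list.Free;
end.
```

Running it should report 0 items inspected before the loop and 4 after it. The values 1, 3 and 5 are rejected, 2 (added after the filter was created) is the first match, and 4 and 6 are never looked at because the loop is left early.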

However, just to tease you: wouldn't it be nice if we could write something like this?

function TEnumerable<T>.Where(const predicate: TPredicate<T>): IEnumerable<T>;
var
  item: T;
begin
  for item in Self do
    if predicate(item) then
      Yield(item); // hypothetical: Delphi has no yield support
end;

If you are interested in more on that topic, I suggest reading Jon Skeet's Reimplementing LINQ to Objects series.