amb chain

By leonardo maffi
V.1.1, Jul 3 2009
keywords: programming, D language, Python, benchmark

[Go back to the article index]

Here you can find the code shown in this post:
http://www.fantascienza.net/leonardo/js/amb_chain.zip

This page of the Rosetta Code site shows a programming task that asks you to emulate the 'amb' operator present in Scheme:
http://rosettacode.org/wiki/Amb

The amb operator is used to choose a word from each of the following string sets, and create a sentence of (four) words where the last character of word 1 is equal to the first character of word 2, and similarly with word 2 and word 3, as well as word 3 and word 4:
set 1: "the" "that" "a"
set 2: "frog" "elephant" "thing"
set 3: "walked" "treaded" "grows"
set 4: "slowly" "quickly"

The only sentence that satisfies these constraints is "that thing grows slowly".

The Python solutions shown on that page aren't fully Pythonic, because they try to partially mimic the amb operator. This is one of them:

# amb0.py
from itertools import product
sets = [
    set('the that a'.split()),
    set('frog elephant thing'.split()),
    set('walked treaded grows'.split()),
    set('slowly quickly'.split())
    ]

success = ( sentence for sentence in product(*sets)
                if all(sentence[word][-1]==sentence[word+1][0]
                       for word in range(3))
              )
print success.next()

In all the following code I'll ignore the requirement of emulating amb.

In the following programs I have included the whole code needed to run them (so I have added xpairwise(), select() and other functions that I normally keep in my Python or D libraries), even if this makes the code look bigger, to make it easier for other people to run it.

This is a slightly more readable Python version:
# amb1.py
from itertools import product, tee, izip

def xpairwise(iterable):
    "s -> (s0,s1), (s1,s2), (s2, s3), ..."
    # from the recipes of itertools documentation
    a, b = tee(iterable)
    next(b, None)
    return izip(a, b)

sets = """\
the that a
frog elephant thing
walked treaded grows
slowly quickly"""

sets = map(str.split, sets.splitlines())

for sentence in product(*sets):
    if all(w1[-1] == w2[0] for w1, w2 in xpairwise(sentence)):
        print sentence

Like the previous version, it works by lazily enumerating the Cartesian product of all the word sets, using itertools.product. This is handy when the cardinality of that product is small, as in this example, but as the word sets (and their number) grow, this program becomes far too slow.
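
A quick back-of-the-envelope check (a sketch, not included in the zip) of how quickly that product grows, using the set sizes of the benchmark files described further below:

# number of tuples enumerated by itertools.product for the benchmark inputs
# (5 sets of 20, 100 and 180 words respectively)
for name, sizes in [("sets2.txt", [20] * 5), ("sets3.txt", [100] * 5), ("sets4.txt", [180] * 5)]:
    print name, reduce(lambda a, b: a * b, sizes)
# sets2.txt 3200000
# sets3.txt 10000000000
# sets4.txt 188956800000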

A way to create a faster algorithm is to pre-compute the valid word associations, to avoid visiting whole sub-trees of the search space.

But thinking about this problem some more shows that there's no need to store such pre-computations in a tree or graph: you can start from the leaves of that tree and move toward the upper levels, keeping a simpler data structure. You can create all the valid partial sentences composed of just the words of the last two sets. Then you can add in front of each of them the valid words of the third-to-last word set, keeping only the partial sentences that satisfy the constraints.

The only valid partial sentence built from the last two word sets:

["grows slowly"]

The only two valid partial sentences built from the second word set and the result of the previous step:

["frog grows slowly", "thing grows slowly"]

Finally, using the first set:

["that thing grows slowly"]

As you can see, the last two sets (which are lists of strings) are combined into a single list of strings, then the same operation is performed on that result and the next remaining set, and so on. This is the semantics of the left-fold (the reduce built-in of Python), so the code becomes even shorter and more elegant, while being much more efficient, because a very large part of the search space is skipped:
# amb2.py
sets = """\
the that a
frog elephant thing
walked treaded grows
slowly quickly"""

sets = reversed(map(str.split, sets.splitlines()))
print reduce(lambda s1,s2: [w2+" "+w1 for w1 in s1 for w2 in s2 if w1[0]==w2[-1]], sets)

Neither Psyco nor moving the Python code into a function is able to speed up the 'amb2b.py' code.

Despite being fast, this algorithm isn't perfect, because as the word sets grow, the memory requirements grow quickly. One way to reduce the memory used is to turn this search from fully parallel (as it is now) to partially sequential, reducing the maximum size of the working set. A simpler (but less effective) way to reduce memory usage (and speed up the code) is to reduce the memory used to store the partial sentences.
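
A minimal sketch (not in the zip) of the partially sequential idea, which is developed much further below in 'amb11.py': all the sets but the first one are reduced as before, while the last and biggest level is generated one word of the first set at a time, so the full list of solutions is never materialized:

# sketch only: the final joining level is done lazily, word by word
sets = [line.split() for line in open("sets1.txt")]
combiner = lambda s1, s2: [w2 + " " + w1 for w1 in s1 for w2 in s2 if w1[0] == w2[-1]]
subs = reduce(combiner, reversed(sets[1:]))  # partial sentences for all sets but the first
for w1 in sets[0]:
    for w2 in subs:
        if w1[-1] == w2[0]:
            print w1 + " " + w2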

In Python strings (like almost everything else) are managed by reference, so if you create lists of words no word is actually copied, only their small references are. This suggests keeping the partial sentences as lists of strings instead of plain strings:
# amb3.py
sets = """\
the that a
frog elephant thing
walked treaded grows
slowly quickly"""

sets = [[[word] for word in sentence.split()] for sentence in reversed(sets.splitlines())]
print reduce(lambda s1, s2: [w2+w1 for w1 in s1 for w2 in s2 if w1[0][0]==w2[-1][-1]], sets)

As you can see, sets is now a matrix of string lists, and those lists are concatenated with w2+w1 (where w2 has length 1, and w1 grows at each iteration).

Such concatenation generates lots of garbage for the garbage collector, so to speed up the code I have tried to reduce that garbage by extending the w1 list instead, but I have seen no significant improvement in the running times (maybe there are ways to improve the situation).
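
By the way, the claim above that building lists of words copies only references (and never the strings themselves) can be checked with a tiny snippet like this (not part of the zip):

# storing a word inside a list doesn't copy the string, only a reference to it
w = "elephant"
partial = [w, "grows", "slowly"]
assert partial[0] is w  # same string object, nothing was copied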

To test the performance of the various versions I have created variants of them (named amb1b, amb2b, etc.) that load the word sets from a file, plus some input files with growing word sets:

Input Set files:
sets1.txt: 4 sets, 3 or 4 words each set
sets2.txt: 5 sets, 20 words each set
sets3.txt: 5 sets, 100 words each set
sets4.txt: 5 sets, 180 words each set
sets5.txt: 5 sets, 300 words each set (used only with the faster programs).

sets3.txt produces an output of about 5.4-7 MB, and sets4.txt produces an output of about 60-77 MB; that's a very large number of solutions.

Using select() from my dlibs (which more or less emulates Python list comprehensions using lazy arguments), plus a very simple left fold (reduce), it's easy to translate amb2.py to D1:
// amb4.d
import std.stdio, libs;
alias string str;
void main() {
  str[][] sets = [["the", "that", "a"], ["frog", "elephant", "thing"],
                  ["walked", "treaded", "grows"], ["slowly", "quickly"]];
  writefln(reduce((str[] s1, str[] s2) {str w1,w2; return select(w2~" "~w1, w1, s1, w2, s2, w1[0]==w2[$-1]); }, sets.reverse));
}
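
The select() call above corresponds, more or less, to the list comprehension of 'amb2.py'; this is only a sketch of its semantics (not of the D implementation in my dlibs):

# rough Python equivalent of:
#   select(w2 ~ " " ~ w1,  w1, s1,  w2, s2,  w1[0] == w2[$-1])
def select_equivalent(s1, s2):
    return [w2 + " " + w1 for w1 in s1 for w2 in s2 if w1[0] == w2[-1]]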

amb4b.d is similar to amb4.d, but it takes the word sets from standard input (instead of from a file named on the command line, as the Python "b" versions do).


This is a translation of 'amb3.py' to D1; as in Python, the result is worse performance (even though w2~w1 concatenates arrays, so no string is actually copied).
// amb5b.d
import std.stdio: writefln, readln;
import std.string: split;
import libs: reduce, select;
alias string str;

void main() {
  str[][][] sets;
  str line;
  while ((line = readln()) != null) {
    str[][] set;
    foreach (word; line.split())
      set ~= [word.dup];
    sets ~= set;
  }

  writefln(reduce((str[][] s1, str[][] s2) {str[] w1,w2; return select(w2~w1, w1, s1, w2, s2, w1[0][0]==w2[$-1][$-1]); }, sets.reverse));
}


All those dynamic arrays require a lot of memory. An obvious attempt to reduce the memory used is to store just references to the input strings (which are very few compared to the amount of output). Instead of pointers to strings I use indexes. The input strings aren't many, so I can even use a ushort to index them.

This program is I/O bound, so to further speed up the code it prints using C functions (appending a space and a '\0' at the end of each D string):
// amb6b.d
import std.c.stdio: printf, putchar;
import std.stdio: writefln, readln;
import std.string: split;
import libs: reduce, select;
alias ushort Index;

void main() {
    Index[][][] sets;
    string[] data;
    Index count;
    string line;
    while ((line = readln()) != null) {
        Index[][] set;
        foreach (word; line.split()) {
            data ~= word.dup;
            set ~= [count++];
        }
        sets ~= set;
    }

    auto r = reduce((Index[][] s1, Index[][] s2) {
                        Index[] w1,w2;
                        return select(w2~w1, w1, s1, w2, s2, data[w1[0]][0]==data[w2[$-1]][$-1]); },
                     sets.reverse);

    foreach (ref s; data)
        s ~= " \0";
    foreach (seq; r) {
        foreach (idx; seq)
            printf("%s", data[idx].ptr);
        putchar('\n');
    }
}


To reduce memory usage further, the sentences can be represented as linked lists of the input words. This leads to longer and less clean-looking code, which uses pointers too.

One node of such a singly-linked list contains the first and last char of the word (not strictly necessary, but there's space in the padding and it may speed up the code), the word index, and the pointer to the next list node.

I have moved the delegate out of the reduce() for clarity.

Inside Node(n2p.first, '-', n2p.idx, n1p) there's '-' because in all but the original single nodes there's no need to know the last char. Here the GC is disabled because I think all the structs created are kept in memory, so there's no garbage to collect (and the max RAM used is indeed the same with and without GC), and because the performance with no GC is significantly higher (maybe a lot of time is spent by the GC following those linked lists).
// amb7b.d
import std.c.stdio: printf, putchar;
import std.stdio: writefln, readln;
import std.string: split;
import libs: reduce, select;
import std.gc: disable;

alias ushort Index;
string[] data;

struct Node {
    Node* next;
    char first, last;
    Index idx;

    static Node* opCall(char first, char last, Index idx, Node* next) {
        Node* n = new Node;
        (*n).first = first;
        (*n).last = last;
        (*n).idx = idx;
        (*n).next = next;
        return n;
    }
}
static assert(Node.sizeof == 8);

void main() {
    disable();
    Node*[][] sets;
    Index count;
    string line;
    while ((line = readln()) != null) {
        Node*[] set;
        foreach (word; line.split()) {
            data ~= word.dup;
            set ~= Node(data[count][0], data[count][$-1], count, null);
            count++;
        }
        sets ~= set;
    }

    Node*[] combiner(Node*[] sub1, Node*[] sub2) {
        Node* n1p, n2p;
        return select(Node(n2p.first, '-', n2p.idx, n1p), n1p, sub1, n2p, sub2, n2p.last == n1p.first);
    }
    Node*[] r = reduce(&combiner, sets.reverse);

    foreach (ref s; data)
        s ~= " \0";
    foreach (n; r) {
        for ( ; n; n = n.next)
            printf("%s", data[n.idx].ptr);
        putchar('\n');
    }
}


Now an obvious optimization for performance (and memory, reducing wasted space) is to allocate such nodes from a memory pool:
// amb8b.d
import std.c.stdio: printf, putchar;
import std.stdio: readln;
import std.string: split;
import libs: reduce, select, MemoryPool;
import std.gc: disable;

alias ushort Index;
string[] data;

struct Node {
    Node* next;
    char first, last;
    Index idx;
}
static assert(Node.sizeof == 8);

MemoryPool!(Node) npool;

Node* newNode(char first, char last, Index idx, Node* next) {
    Node* n = npool.newItem();
    (*n).first = first;
    (*n).last = last;
    (*n).idx = idx;
    (*n).next = next;
    return n;
}

void main() {
    disable();
    Node*[][] sets;
    Index count;
    string line;
    while ((line = readln()) != null) {
        Node*[] set;
        foreach (word; line.split()) {
            data ~= word.dup;
            set ~= newNode(data[count][0], data[count][$-1], count, null);
            count++;
        }
        sets ~= set;
    }

    Node*[] combiner(Node*[] sub1, Node*[] sub2) {
        Node* n1p, n2p;
        return select(newNode(n2p.first, '-', n2p.idx, n1p), n1p, sub1, n2p, sub2, n2p.last == n1p.first);
    }
    Node*[] r = reduce(&combiner, sets.reverse);

    foreach (ref s; data)
        s ~= " \0";
    foreach (n; r) {
        for ( ; n; n = n.next)
            printf("%s", data[n.idx].ptr);
        putchar('\n');
    }
}

Now the performance is very much dominated by the I/O (only 0.29 seconds with no printing, on a slow CPU).

So far the D code uses the high-level functions reduce() and select(), which may introduce some overhead (not much, judging from the timings of 'amb8b.d'; they show that such high-level constructs often aren't a bottleneck). reduce() probably doesn't introduce any slowdown because its loop is very short. So I have inlined select() and its lazy arguments (which are delegates):
// amb9b.d
import std.c.stdio: printf, putchar;
import std.stdio: readln;
import std.string: split;
import libs: reduce, MemoryPool, ArrayBuilder;
import std.gc: disable;

alias ushort Index;
string[] data;

struct Node {
    Node* next;
    char first, last;
    Index idx;
}
static assert(Node.sizeof == 8);

MemoryPool!(Node) npool;

Node* newNode(char first, char last, Index idx, Node* next) {
    Node* n = npool.newItem();
    (*n).first = first;
    (*n).last = last;
    (*n).idx = idx;
    (*n).next = next;
    return n;
}

void main() {
    disable();
    Node*[][] sets;
    Index count;
    string line;
    while ((line = readln()) != null) {
        Node*[] set;
        foreach (word; line.split()) {
            data ~= word.dup;
            set ~= newNode(data[count][0], data[count][$-1], count, null);
            count++;
        }
        sets ~= set;
    }

    Node*[] combiner(Node*[] sub1, Node*[] sub2) {
        ArrayBuilder!(Node*) array;
        foreach (n1p; sub1)
            foreach (n2p; sub2)
                if (n2p.last == n1p.first)
                    array ~= newNode(n2p.first, '-', n2p.idx, n1p);
        return array.toarray;
    }
    Node*[] r = reduce(&combiner, sets.reverse);

    foreach (ref s; data)
        s ~= " \0";
    foreach (n; r) {
        for ( ; n; n = n.next)
            printf("%s", data[n.idx].ptr);
        putchar('\n');
    }
}

The user code is longer (not counting the libs). The timings of 'amb9b.d' with no output show an improvement (0.22 seconds on 'sets4.txt'). The memory used is about 1/4, and the shorter running time allows trying a bigger problem. So I have tried 'amb9b.d' (Index=ushort, no printing, no GC) on sets5.txt, which contains 5 sets of 300 words each: it takes 5.58 seconds of running time and 545 MB of RAM. On this PC this problem can't be solved (in an acceptable amount of time) with the first D versions.

A version that doesn't print all the solutions may be useful anyway, because you may want to scan the solutions looking for the ones that satisfy other constraints.
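
For example, a sketch (not in the zip; the length limit is just a hypothetical extra constraint) that collects the solutions and scans them instead of printing them all:

# collect the solutions like 'amb2.py' does, then keep only the ones with a
# total sentence length below 25 chars (a made-up extra constraint)
sets = [line.split() for line in open("sets1.txt")]
combiner = lambda s1, s2: [w2 + " " + w1 for w1 in s1 for w2 in s2 if w1[0] == w2[-1]]
solutions = reduce(combiner, reversed(sets))
print [s for s in solutions if len(s) < 25]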

Timings, > to file:
      Set     #1      #2     #3      #4     #5
  amb1b.py: 0.17   17.07      -       -
  amb2b.py: 0.16    0.16   1.28   13.95
  amb2b.py:    -       -      -   13.77      - (no GC)
  amb2b.py:    -       -   0.63    6.24      - (no printing)
  amb2b.py:    -       -      -    6.36      - (no GC, no printing)
  amb2c.py:    -       -   0.87    8.51      - (with psyco)
  amb2c.py:    -       -      -    1.03      - (with psyco, no printing)
  amb2d.py:    -       -   0.82    8.11      - (with psyco)
  amb2d.py:    -       -      -    0.66      - (with psyco, no printing)
  amb2e.py:    -       -   0.82    8.08      - (with psyco)
  amb2e.py:    -       -      -    0.63      - (with psyco, no printing)
  amb3b.py: 0.15    0.17   1.93   21.80
  amb4b.d:  0.04    0.03   0.30    3.89
  amb4b.d:     -       -      -    3.35      - (no GC)
  amb4b.d:     -       -      -    2.03      - (no printing)
  amb4b.d:     -       -      -    1.42      - (no GC, no printing)
  amb5b.d:     -       -      -    6.15
  amb5b.d:     -       -      -    3.50      - (no printing)
  amb6b.d:     -       -      -    3.67      - (Index=ushort)
  amb6b.d:     -       -      -    3.91      - (Index=uint)
  amb6b.d:     -       -      -    3.43      - (Index=ushort, no GC)
  amb6b.d:     -       -      -    1.02      - (Index=ushort, no printing, no GC)
  amb7b.d:     -       -      -    3.41      - (Index=ushort)
  amb7b.d:     -       -      -    2.87      - (Index=ushort, no GC)
  amb7b.d:     -       -      -    0.60      - (Index=ushort, no printing, no GC)
  amb8b.d:     -       -   0.26    2.58      - (Index=ushort, no GC)
  amb8b.d:     -       -   0.06    0.29      - (Index=ushort, no printing, no GC)
  amb9b.d:     -       -   0.28    2.52      - (Index=ushort, no GC)
  amb9b.d:     -       -   0.05    0.22   5.58 (Index=ushort, no printing, no GC)
  amb10b.d:    -       -      -    0.16   3.36 (Index=ushort, no printing, no GC)
  amb11.py:    -       -   0.34    3.05      - (with psyco)
  amb11.py:    -       -   0.19    1.42  23.48 (with psyco, no printing)
  amb12.py:    -       -   0.26    1.84      - (with psyco)
  amb12.py:    -       -   0.13    0.28   2.91 (with psyco, no printing)
  amb13.py:    -       -      -    0.29   2.87 (with psyco, no printing)
  amb14.py:    -       -   0.26    1.76      - (with psyco)
  amb14.py:    -       -      -       -   2.21 (with psyco, no printing)
  amb15.py:    -       -   0.26    1.88      - (with psyco)
  amb15.py:    -       -   0.12    0.24   2.43 (with psyco, no printing)
  amb16.d:     -       -   0.15    1.33      -
  amb16.d:     -       -   0.04    0.03   0.09 (no printing)

Max RAM used, MB, > to file:
     Set     #1    #2   #3    #4
  amb1b.py:   -   2.9    -     -
  amb2b.py:   -   2.4   14   105
  amb2c.py:   -     -   16   107
  amb2d.py:   -     -   16   107
  amb2e.py:   -     -   16   109
  amb2b.py:   -     -    -   105 (no GC)
  amb3b.py:   -   2.4   14   106
  amb4b.d:    -     -    -   119
  amb4b.d:    -     -    -   126 (no GC)
  amb5b.d:    -     -    -   119
  amb5b.d:    -     -    -   119 (no GC)
  amb6b.d:    -     -    -    56 (Index=ushort)
  amb6b.d:    -     -    -    79 (Index=uint)
  amb6b.d:    -     -    -    61 (Index=ushort, no GC)
  amb7b.d:    -     -    -    42 (Index=ushort)
  amb7b.d:    -     -    -    42 (Index=ushort, no GC)
  amb8b.d:    -     -    -    34 (Index=ushort, no GC)
  amb11.py:   -     -    6    15
  amb12.py:   -     -    6    15
  amb14.py:   -     -    5     6
  amb15.py:   -     -    -     6
  amb16.d:    -     -    -     1.7

D compiler used:
DMD Digital Mars D Compiler v1.042

D code compiled with:
dmd -O -release -inline

Python 2.6.2 (r262:71600, Apr 21 2009, 15:05:37) [MSC v.1500 32 bit (Intel)] on win32

Psyco 1.6.0 final

Operating system: Windows Vista Home Basic

CPU 1 core Celeron 560 at 2.13 GHz, 1 GB RAM (it's quite slower than the CPU I use in most of my other benchmarks).

Here you can find all the code shown in this post:
http://www.fantascienza.net/leonardo/js/amb_chain.zip

--------------------------

Update Jul 3 2009:

The version 'amb2b.py' gains nothing with the Psyco JIT because the lambda is not compiled. But you just need to pull the lambda out to see a twofold speed-up. A friend of mine (Andrea) has also shown me that there's no need to reverse the sets:
# amb2c.py
from sys import argv

combiner = lambda s1,s2: [w1+" "+w2 for w1 in s1 for w2 in s2 if w1[-1]==w2[0]]
import psyco; psyco.bind(combiner)
print reduce(combiner, [line.split() for line in open(argv[1])])


There are standard strategies (often used by compilers) to speed up that Python code. We can work on the n*m inner loop, avoiding a lot of tests:
# amb2d.py
from sys import argv
from collections import defaultdict

def combiner(s1, s2):
    d = defaultdict(list)
    for w2 in s2:
        d[w2[0]].append(w2)

    result = []
    for w1 in s1:
        if w1[-1] in d:
            w1s = w1 + " "
            for w2 in d[w1[-1]]:
                result.append(w1s + w2)
    return result

import psyco; psyco.bind(combiner)
print reduce(combiner, [line.split() for line in open(argv[1])])


Avoiding memory allocations in the inner loop may help (the allocation of the dict can also be avoided with a 'static' defaultdict and calls to its clear() method):
# amb2e.py
from sys import argv
from collections import defaultdict
from operator import itemgetter

ORDA = ord('a')
NCHARS = ord('z') - ord('a') + 1
infs = [100000] * NCHARS
sups = [-100000] * NCHARS

def combiner(s1, s2):
    for i in xrange(NCHARS):
        infs[i] = 100000
        sups[i] = -100000

    s2.sort(key=itemgetter(0))
    i = 0
    for w2 in s2:
        idx = ord(w2[0]) - ORDA
        if infs[idx] > i:
            infs[idx] = i
        if sups[idx] < i:
            sups[idx] = i
        i += 1

    result = []
    for w1 in s1:
        idx = ord(w1[-1]) - ORDA
        inf = infs[idx]
        sup = sups[idx]
        if inf <= sup:
            w1s = w1 + " "
            for i in xrange(inf, sup+1):
                result.append(w1s + s2[i])
    return result

import psyco; psyco.bind(combiner)
print reduce(combiner, [line.split() for line in open(argv[1])])


Those ideas can be applied to the D code too, but 'amb9b.d' grows singly-linked lists, so it's better to keep the reversed scan of the sets:
// amb10b.d
import std.c.stdio: printf, putchar;
import std.stdio: readln;
import std.string: split;
import libs: reduce, MemoryPool, ArrayBuilder, sort;
import std.gc: disable;

alias ushort Index;
string[] data;

struct Node {
    Node* next;
    char first, last;
    Index idx;
}
static assert(Node.sizeof == 8);

MemoryPool!(Node) npool;

Node* newNode(char first, char last, Index idx, Node* next) {
    Node* n = npool.newItem();
    (*n).first = first;
    (*n).last = last;
    (*n).idx = idx;
    (*n).next = next;
    return n;
}

void main() {
    disable();
    Node*[][] sets;
    Index count;
    string line;
    while ((line = readln()) != null) {
        Node*[] set;
        foreach (word; line.split()) {
            data ~= word.dup;
            set ~= newNode(data[count][0], data[count][$-1], count, null);
            count++;
        }
        sets ~= set;
    }

    Node*[] combiner(Node*[] sub1, Node*[] sub2) {
        static int['z' - 'a' + 1] infs = int.max;
        static int[infs.length] sups = int.min;
        infs[] = int.max;
        sups[] = int.min;
        sub1.sort((Node* np){ return np.first; });
        foreach (int i, n1p; sub1) {
            int idx = n1p.first - 'a';
            if (i < infs[idx])
                infs[idx] = i;
            if (i > sups[idx])
                sups[idx] = i;
        }

        ArrayBuilder!(Node*) array;
        foreach (n2p; sub2) {
            int idx = n2p.last - 'a';
            int inf = infs[idx];
            if (inf != int.max) {
                int sup = sups[idx];
                for (int i = inf; i < sup+1; i++)
                    array ~= newNode(n2p.first, '-', n2p.idx, sub1[i]);
            }
        }
        return array.toarray;
    }
    Node*[] r = reduce(&combiner, sets.reverse);

    foreach (ref s; data)
        s ~= " \0";
    foreach (n; r) {
        for ( ; n; n = n.next)
            printf("%s", data[n.idx].ptr);
        putchar('\n');
    }
}

There are many other possible algorithms that reduce both the memory used and the running time. A simple way to reduce memory is to store fewer partial solutions, for example by performing the first iteration of the reduce() in a lazy way (this is less useful if the first set of words is much bigger than the other ones); this also speeds up the program, probably because it allocates much less RAM:
# amb11.py
from sys import argv
import psyco; psyco.full()

def main(sets, subs):
    for w1 in sets[0]:
        w1s = w1 + " "
        w1_last = w1[-1]
        for w2 in subs:
            if w1_last == w2[0]:
                print w1s + w2

combiner = lambda s1,s2: [w1+" "+w2 for w1 in s1 for w2 in s2 if w1[-1]==w2[0]]
sets = [line.split() for line in open(argv[1])]
main(sets, reduce(combiner, sets[1:]))


That idea can be combined with 'amb2e.py' too, leading to a very quick program with Psyco (it uses 95 MB of RAM on 'sets5.txt'):
# amb12.py
from sys import argv
from collections import defaultdict
from operator import itemgetter

ORDA = ord('a')
NCHARS = ord('z') - ord('a') + 1
infs = [100000] * NCHARS
sups = [-100000] * NCHARS

def combiner(s1, s2):
    for i in xrange(NCHARS):
        infs[i] = 100000
        sups[i] = -100000

    s2.sort(key=itemgetter(0))
    i = 0
    for w2 in s2:
        idx = ord(w2[0]) - ORDA
        if infs[idx] > i:
            infs[idx] = i
        if sups[idx] < i:
            sups[idx] = i
        i += 1

    result = []
    for w1 in s1:
        idx = ord(w1[-1]) - ORDA
        inf = infs[idx]
        sup = sups[idx]
        if inf <= sup:
            w1s = w1 + " "
            for i in xrange(inf, sup+1):
                result.append(w1s + s2[i])
    return result


def print_solutions(s0, subs):
    d = defaultdict(list)
    for w0 in s0:
        d[w0[-1]].append(w0 + " ")

    for s2 in subs:
        if s2[0] in d:
            for w0_space in d[s2[0]]:
                print w0_space + s2

import psyco; psyco.full()
sets = [line.split() for line in open(argv[1])]
print_solutions(sets[0], reduce(combiner, sets[1:]))


'amb12.py' shows that with a better algorithm you can beat even a fast language like D. You may answer that the algorithm of 'amb12.py' can be translated to D. This is true, but it means little, because Python expands my capacity to invent better algorithms; D helps me less during prototyping.

In practice I have tried to improve 'amb10b.d' with the ideas of 'amb12.py' (and the successive Python versions), but I have failed: there's a bug (or more than one) that I haven't managed to remove in about one hour (with no debugger). So in the end Python has "won", allowing me to produce bug-free code that runs faster (but see 'amb16.d' for a different algorithm that, implemented in D, leads to the fastest program).

There are many other ways to improve the Python code; for example the memory used can be reduced by pre-processing the input with the clean() function below, removing from each set the words that can't join with words of the nearby sets. After the cleaning we are sure there are matches, so I have removed the "if" tests from print_solutions() and combiner():
# amb13.py
from sys import argv
from collections import defaultdict
from operator import itemgetter

ORDA = ord('a')
NCHARS = ord('z') - ord('a') + 1
infs = [100000] * NCHARS
sups = [-100000] * NCHARS

def combiner(s1, s2):
    for i in xrange(NCHARS):
        infs[i] = 100000
        sups[i] = -100000

    s2.sort(key=itemgetter(0))
    i = 0
    for w2 in s2:
        idx = ord(w2[0]) - ORDA
        if infs[idx] > i:
            infs[idx] = i
        if sups[idx] < i:
            sups[idx] = i
        i += 1

    result = []
    for w1 in s1:
        idx = ord(w1[-1]) - ORDA
        inf = infs[idx]
        sup = sups[idx]
        w1s = w1 + " "
        for i in xrange(inf, sup+1):
            result.append(w1s + s2[i])
    return result


def print_solutions(s0, subs):
    d = defaultdict(list)
    for w0 in s0:
        d[w0[-1]].append(w0 + " ")

    for s2 in subs:
        for w0_space in d[s2[0]]:
            x = w0_space + s2

def clean(sets):
    starts = [set([w[0] for w in s]) for s in sets]
    ends = [set([w[-1] for w in s]) for s in sets]
    starts1 = starts[1]
    new_sets = [[w for w in sets[0] if w[-1] in starts1]]

    for i, s in enumerate(sets[1:-1], 1):
        starts_next = starts[i+1]
        ends_prec = ends[i-1]
        new_sets.append([w for w in s if w[-1] in starts_next and w[0] in ends_prec])

    ends2 = ends[-2]
    new_sets.append([w for w in sets[-1] if w[0] in ends2])
    return new_sets

import psyco; psyco.full()
sets = clean([line.split() for line in open(argv[1])])
print_solutions(sets[0], reduce(combiner, sets[1:]))


Another way to reduce memory and speed up the code is to split the computation into two parts (only 11 MB of RAM used with 'sets5.txt' with no printing):
# amb14.py
...
import psyco; psyco.full()
sets = clean([line.split() for line in open(argv[1])])
subs1 = reduce(combiner, sets[:len(sets) // 2])
subs2 = reduce(combiner, sets[len(sets) // 2:])
print_solutions(subs1, subs2)

In theory that idea can be generalized by splitting the computation into a binary tree of pair-joinings (the general divide-and-conquer strategy of algorithm design); with the following change (hard-coded for 5 sets) it takes only 2.13 seconds (sets5.txt, with Psyco, no printing). This version of the code is hard-coded, so I have not included it in the zip:
s = clean([line.split() for line in open(argv[1])])
s01 = combiner(s[0], s[1])
s23 = combiner(s[2], s[3])
s234 = combiner(s23, s[4])
print_solutions(s01, s234)
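
A minimal sketch (not in the zip) of how that hard-coded pairing can be generalized to any number of sets, joining adjacent groups pairwise until a single group of partial solutions is left:

def tree_reduce(f, items):
    # repeatedly join adjacent pairs; an odd trailing group is carried over as-is
    while len(items) > 1:
        paired = [f(items[i], items[i + 1]) for i in xrange(0, len(items) - 1, 2)]
        if len(items) % 2:
            paired.append(items[-1])
        items = paired
    return items[0]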

With such a binary strategy the clean() function is not useful anymore, I think.

But in practice the current combiner(s1, s2) is designed for an s2 that is much smaller than s1, so as it is now it isn't good when s1 and s2 have about the same size, as in the binary-tree computation case.

To be used in that binary-splitting way, combiner() has to be redesigned. A possible way is to sort s2 with a bucket sort, which is O(n). But a better solution seems to be modifying combiner() so that its output is already 'sorted', for example into a list of NCHARS lists, bucketed by the first or the last char (so instead of clean() the program needs to pre-process the input to create such a bucketing for each input set).
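
A sketch of that bucketing (it's essentially what the base case of solve() in 'amb15.py' below does), assuming only lowercase ASCII words as in the rest of the code on this page:

ORDA = ord('a')
NCHARS = ord('z') - ORDA + 1

def bucket_by_head(words):
    # one bucket per letter, indexed by the first character of each word
    buckets = [[] for _ in xrange(NCHARS)]
    for w in words:
        buckets[ord(w[0]) - ORDA].append(w)
    return buckets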

After a few intermediate Python versions that I don't show, I have created this one, but it's slower than 'amb14.py'. All the time is spent inside do_print(). All the strings have a " " appended, so all the solutions have a trailing space (10 MB of RAM used with 'sets5.txt' with no printing):
# amb15.py
from sys import argv

ORDA = ord('a')
NCHARS = ord('z') - ORDA + 1

def solve(sets, by_head=True):
    len_sets = len(sets)
    assert len_sets > 0
    result = [[] for _ in xrange(NCHARS)]

    if len_sets == 1:
        idx = 0 if by_head else -2
        for w in sets[0]:
            result[ord(w[idx]) - ORDA].append(w)
    else:
        sub1 = solve(sets[: len_sets // 2], False)
        sub2 = solve(sets[len_sets // 2 :], True)

        if by_head:
            for b1, b2 in zip(sub1, sub2):
                for w1 in b1:
                    result[ord(w1[0]) - ORDA].extend(w1 + w2 for w2 in b2)
        else:
            for b1, b2 in zip(sub1, sub2):
                for w2 in b2:
                    result[ord(w2[-2]) - ORDA].extend(w1 + w2 for w1 in b1)
    return result

def do_print(sub1, sub2):
    for b1, b2 in zip(sub1, sub2):
        if b1 and b2:
            for w1 in b1:
                for w2 in b2:
                    print w1 + w2

import psyco; psyco.full()
sets = [[w + " " for w in line.split()] for line in open(argv[1])]
len_sets = len(sets)
sub1 = solve(sets[: len_sets // 2], False)
sub2 = solve(sets[len_sets // 2 :], True)
do_print(sub1, sub2)


This 'amb15.py' looks simple, so I have translated it to D (while I have not translated 'amb14.py'). The solve() can be made shorter. The output uses printf() because writefln is too slow. Now the computation takes minimal time and all the time is spent printing, so I stop here (5 MB of RAM used with 'sets5.txt' with no printing):
// amb16.d
import std.c.stdio: printf;
import std.stdio: readln;
import std.string: split;

const int NCHARS = 'z' - 'a' + 1;

string[][] solve(string[][] sets, bool byHead=true) {
    if (!sets.length) throw new Exception("");
    auto result = new string[][](NCHARS, 0);

    if (sets.length == 1) {
        if (byHead)
            foreach (w; sets[0])
                result[w[0] - 'a'] ~= w;
        else
           foreach (w; sets[0])
                result[w[$-2] - 'a'] ~= w;
    } else {
        auto sub1 = solve(sets[0 .. sets.length / 2], false);
        auto sub2 = solve(sets[sets.length / 2 .. $], true);

        if (byHead) {
            for (int i; i < NCHARS; i++) {
                auto b1 = sub1[i];
                auto b2 = sub2[i];
                foreach (w1; b1) {
                    int idx = w1[0] - 'a';
                    int pos = result[idx].length;
                    result[idx].length = pos + b2.length;
                    foreach (j, w2; b2)
                        result[idx][pos + j] = w1 ~ w2;
                }
            }
        } else {
            for (int i; i < NCHARS; i++) {
                auto b1 = sub1[i];
                auto b2 = sub2[i];
                foreach (w2; b2) {
                    int idx = w2[$-2] - 'a';
                    int pos = result[idx].length;
                    result[idx].length = pos + b1.length;
                    foreach (j, w1; b1)
                        result[idx][pos + j] ~= w1 ~ w2;
                }
            }
        }
    }

    return result;
}

void main() {
    string[][] sets;
    string line;
    while ((line = readln()) != null) {
        string[] set;
        foreach (word; line.split())
            set ~= (word ~ ' ');
        sets ~= set;
    }

    auto sub1 = solve(sets[0 .. sets.length / 2], false);
    auto sub2 = solve(sets[sets.length / 2 .. $], true);

    for (int i; i < NCHARS; i++) {
        auto b1 = sub1[i];
        auto b2 = sub2[i];
        if (b1.length && b2.length)
            foreach (w1; b1)
                foreach (w2; b2)
                    printf("%*s%*s\n", w1.length, w1.ptr, w2.length, w2.ptr);
    }
}

[Go back to the article index]