Wednesday, January 21, 2009

A performance update

I've continued playing with the LLVM. I discovered that when generating code for the normcdf and Black-Scholes functions I did not tell LLVM that the functions that were called (exp etc.) are actually pure functions. That meant that the LLVM didn't perform CSE (common subexpression elimination) properly.

So here are some updated numbers for computing option prices for 10,000,000 options:

  • Pure Haskell: 8.7s
  • LLVM: 2.0s
Just as a reminder, the code for normcdf looks like this:
normcdf x = x %< 0 ?? (1 - w, w)
  where w = 1.0 - 1.0 / sqrt (2.0 * pi) * exp(-l*l / 2.0) * poly k
        k = 1.0 / (1.0 + 0.2316419 * l)
        l = abs x
        poly = horner coeff 
        coeff = [0.0,0.31938153,-0.356563782,1.781477937,-1.821255978,1.330274429] 
A noteworthy thing is that exactly the same code can be used both for the pure Haskell and the LLVM code generation; it's just a matter of overloading.

Vectors

A very cool aspect of the LLVM is that it has vector instructions. On the x86 these translate into using the SSE extensions to the processor and can speed up computations by doing things in parallel.

Again, by using overloading, the exact same code can be used to compute over vectors of numbers as with individual numbers.
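To make the overloading point concrete, here is a minimal sketch in plain Haskell. The V4 type and the lanes1/lanes2 helpers are toy stand-ins I made up for illustration; they are not the LLVM vector type, and nothing here generates SSE code. The sketch only shows that a function written once against Num runs unchanged on scalars and on four-element vectors.

```haskell
-- Toy stand-in for a four-lane vector; NOT the LLVM vector type.
data V4 = V4 Float Float Float Float
  deriving (Eq, Show)

-- Apply a binary operation lane by lane.
lanes2 :: (Float -> Float -> Float) -> V4 -> V4 -> V4
lanes2 f (V4 a b c d) (V4 a' b' c' d') =
  V4 (f a a') (f b b') (f c c') (f d d')

-- Apply a unary operation lane by lane.
lanes1 :: (Float -> Float) -> V4 -> V4
lanes1 f (V4 a b c d) = V4 (f a) (f b) (f c) (f d)

instance Num V4 where
  (+) = lanes2 (+)
  (-) = lanes2 (-)
  (*) = lanes2 (*)
  abs = lanes1 abs
  signum = lanes1 signum
  -- A literal is broadcast to all four lanes.
  fromInteger n = let x = fromInteger n in V4 x x x x

-- Written once against Num; works for Float and for V4 alike.
poly :: Num a => a -> a
poly x = x * x - 5 * x + 6
```

In GHCi, poly (3 :: Float) and poly (V4 1 2 3 4) both type check against the same definition; the second one simply evaluates the polynomial in each lane.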

So what about performance? I used four element vectors of 32 bit floating point numbers and got these results:

  • Pure Haskell: 8.7s
  • LLVM, scalar: 2.0s
  • LLVM, vector: 1.1s
  • C, gcc -O3: 2.5s
Some caveats if you try out vectors in the LLVM.
  • Only on MacOS does the LLVM package give you fast primitive functions, because that's the only platform that seems to have this as a standard.
  • The vector version of floating point comparison has not been implemented in the LLVM yet.
  • Do not use two element vectors of type 32 bit floats. This will generate code that is wrong on the x86. I sent in a bug report about this, but was told that it is a feature and not a bug. (I kid you not.) To make the code right you have to manually insert EMMS instructions.
  • The GHC FFI is broken for all operations that allocate memory for a Storable, e.g., alloca, with, withArray etc. These operations do not take the alignment into account when allocating. This means that, e.g., a vector of four floats may end up on 8 byte alignment instead of 16. This generates a segfault.
On the whole, I'm pretty happy with the LLVM performance now; about 8 times faster than ghc on this example.

[Edit:] Added point about broken FFI.


Saturday, January 10, 2009

LLVM arithmetic

So we want to compute x^2 - 5x + 6 using the Haskell LLVM bindings. It would look something like this.
    xsq <- mul x x
    x5  <- mul 5 x
    r1  <- sub xsq x5
    r   <- add r1 6
Not very readable; it would be nicer to write
    r   <- x^2 - 5*x + 6
But, e.g., the add instruction has the (simplified) type Value a -> Value a -> CodeGenFunction r (Value a), where CodeGenFunction is the monad for generating code for a function. (BTW, the r type variable is used to keep track of the return type of the function, used by the ret instruction.)

We can't change the return type of add, but we can change the argument type.

type TValue r a = CodeGenFunction r (Value a)
add' :: TValue r a -> TValue r a -> TValue r a
add' x y = do x' <- x; y' <- y; add x' y'
Now the type fits what the Num class wants. And we can make an instance declaration.
instance (Num a) => Num (TValue r a) where
    (+) = add'
    (-) = sub'
    (*) = mul'
    fromInteger = return . valueOf . fromInteger
We are getting close, the only little thing is that the x in our original LLVM code has type Value a rather than TValue r a, but return takes care of that. So:
    let x' = return x
    r <- x'^2 - 5*x' + 6
And a quick look at the generated LLVM code (for Double) shows us that all is well.
; x in %0
 %1 = mul double %0, %0
 %2 = mul double 5.000000e+00, %0
 %3 = sub double %1, %2
 %4 = add double %3, 6.000000e+00
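To see the trick in isolation, here is a minimal sketch that does not use the real bindings at all. Value, CodeGen, binop and emitPoly are toy stand-ins invented for this illustration: a Value is just an operand name, and the monad hands out fresh register numbers and logs the instructions it "emits" as strings. The Num instance is built exactly as above, by lifting the instructions so that they take monadic arguments.

```haskell
{-# LANGUAGE FlexibleInstances #-}

-- Toy model only: an operand is a string such as "%0" or "5".
newtype Value = Value String

-- State: next free register number and the instructions emitted so far.
type St = (Int, [String])

newtype CodeGen a = CodeGen { runCodeGen :: St -> (a, St) }

instance Functor CodeGen where
  fmap f (CodeGen g) = CodeGen $ \s -> let (a, s') = g s in (f a, s')

instance Applicative CodeGen where
  pure a = CodeGen $ \s -> (a, s)
  CodeGen f <*> CodeGen g = CodeGen $ \s ->
    let (h, s1) = f s
        (a, s2) = g s1
    in (h a, s2)

instance Monad CodeGen where
  CodeGen g >>= k = CodeGen $ \s -> let (a, s') = g s in runCodeGen (k a) s'

-- Emit one binary instruction and return the fresh register holding its result.
binop :: String -> Value -> Value -> CodeGen Value
binop op (Value a) (Value b) = CodeGen $ \(n, out) ->
  let r = "%" ++ show n
  in (Value r, (n + 1, out ++ [r ++ " = " ++ op ++ " " ++ a ++ ", " ++ b]))

-- Lift an instruction so that it takes monadic arguments, like add' above.
lift2 :: (Value -> Value -> CodeGen Value)
      -> CodeGen Value -> CodeGen Value -> CodeGen Value
lift2 f x y = do x' <- x; y' <- y; f x' y'

instance Num (CodeGen Value) where
  (+) = lift2 (binop "add")
  (-) = lift2 (binop "sub")
  (*) = lift2 (binop "mul")
  fromInteger = return . Value . show
  abs = error "abs: not modelled in this sketch"
  signum = error "signum: not modelled in this sketch"

-- Generate the instructions for x^2 - 5x + 6, numbering registers from 1.
emitPoly :: Value -> [String]
emitPoly x = snd (snd (runCodeGen expr (1, [])))
  where
    x' = return x
    expr = x' ^ (2 :: Int) - 5 * x' + 6
```

Evaluating mapM_ putStrLn (emitPoly (Value "%0")) in GHCi should list a mul, a mul, a sub and an add, in the same order as the LLVM listing above (minus the type annotations, which this toy does not track).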

All kinds of numeric instances and some other goodies are available in the LLVM.Util.Arithmetic module. Here is a complete Fibonacci (again) using this.

import Data.Int
import LLVM.Core
import LLVM.ExecutionEngine
import LLVM.Util.Arithmetic

mFib :: CodeGenModule (Function (Int32 -> IO Int32))
mFib = recursiveFunction $ \ rfib n ->
    n %< 2 ? (1, rfib (n-1) + rfib (n-2))

main :: IO ()
main = do
    let fib = unsafeGenerateFunction mFib
    print (fib 22)
The %< operator compares values returning a TValue r Bool. The c ? (t, e) is a conditional expression, like C's c ? t : e.

The type given to mFib is not the most general one, of course. The most general one can accept any numeric type for argument and result. Here it is:

mFib :: (Num a, Cmp a, IsConst a,
         Num b, Cmp b, IsConst b, FunctionRet b) =>
        CodeGenModule (Function (a -> IO b))
mFib = recursiveFunction $ \ rfib n ->
    n %< 2 ? (1, rfib (n-1) + rfib (n-2))
It's impossible to generate code for mFib directly; code can only be generated for monomorphic types. So a type signature is needed when generating code to force a monomorphic type.
main = do
    let fib :: Int32 -> Double
        fib = unsafeGenerateFunction mFib
        fib' :: Int16 -> Int64
        fib' = unsafeGenerateFunction mFib
    print (fib 22, fib' 22)

Another example

Let's try a more complex example. To pick something mathematical with lots of formulae, we'll do the cumulative distribution function of the normal distribution. For the precision of a Float it can be coded like this in Haskell (normal Haskell, no LLVM):
normcdf x = if x < 0 then 1 - w else w
  where w = 1 - 1 / sqrt (2 * pi) * exp(-l*l / 2) * poly k
        k = 1 / (1 + 0.2316419 * l)
        l = abs x
        poly = horner coeff 
        coeff = [0.0,0.31938153,-0.356563782,1.781477937,-1.821255978,1.330274429] 

horner coeff base = foldr1 multAdd coeff
  where multAdd x y = y*base + x
We cannot use this directly; it has type normcdf :: (Floating a, Ord a) => a -> a. The Ord context is a problem, because there is no instance of Ord for LLVM types. The comparison is the root of the problem, since it returns a Bool rather than a TValue r Bool.

It's possible to hide the Prelude and overload the comparisons, but you cannot overload the if construct. So a little rewriting is necessary regardless. Let's just bite the bullet and change the first line.

normcdf x = x %< 0 ? (1 - w, w)
And with that change, all we need to do is
mNormCDF = createFunction ExternalLinkage $ arithFunction $ normcdf

main :: IO ()
main = do
    writeFunction "CDF.bc" (mNormCDF :: CodeGenModule (Function (Float -> IO Float)))
So what happened here? Looking at normcdf, it contains things that the LLVM cannot handle, like lists. But all the list operations happen when the Haskell code runs and nothing of that is left in the LLVM code.
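Here is a small sketch of why the lists disappear, in plain Haskell without the bindings. Expr is a toy expression type I invented as a stand-in for generated code; horner is the function from above, with a type signature added. Running horner on a symbolic argument performs the foldr1 over the coefficient list at Haskell run time, and what comes out is a flat chain of multiplies and adds with no list in sight.

```haskell
-- Toy expression tree standing in for generated code (not an LLVM type).
data Expr
  = Var String
  | Lit Integer
  | Add Expr Expr
  | Mul Expr Expr
  deriving (Eq, Show)

instance Num Expr where
  (+) = Add
  (*) = Mul
  fromInteger = Lit
  negate e = Mul (Lit (-1)) e
  abs = error "abs: not modelled in this sketch"
  signum = error "signum: not modelled in this sketch"

-- Same definition as in the text.
horner :: Num a => [a] -> a -> a
horner coeff base = foldr1 multAdd coeff
  where multAdd x y = y * base + x

-- The list [1, 2, 3] is consumed here; only arithmetic nodes remain.
unrolled :: Expr
unrolled = horner [1, 2, 3] (Var "k")
```

The value of unrolled is the tree for (3*k + 2)*k + 1, which is the same multiply-add chain shape that shows up as the mulss/addss run in the assembly above.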

If you optimize and generate code for normcdf it looks like this (leaving out constants and half the code):

__fun1:
        subl    $28, %esp
        pxor    %xmm0, %xmm0
        ucomiss 32(%esp), %xmm0
        jbe     LBB1_2
        movss   32(%esp), %xmm0
        mulss   %xmm0, %xmm0
        divss   LCPI1_0, %xmm0
        movss   %xmm0, (%esp)
        call    _expf
        fstps   24(%esp)
        movss   32(%esp), %xmm0
        mulss   LCPI1_1, %xmm0
        movss   %xmm0, 8(%esp)
        movss   LCPI1_2, %xmm0
        movss   8(%esp), %xmm1
        addss   %xmm0, %xmm1
        movss   %xmm1, 8(%esp)
        movaps  %xmm0, %xmm1
        divss   8(%esp), %xmm1
        movaps  %xmm1, %xmm2
        mulss   LCPI1_3, %xmm2
        addss   LCPI1_4, %xmm2
        mulss   %xmm1, %xmm2
        addss   LCPI1_5, %xmm2
        mulss   %xmm1, %xmm2
        addss   LCPI1_6, %xmm2
        mulss   %xmm1, %xmm2
        addss   LCPI1_7, %xmm2
        mulss   %xmm1, %xmm2
        pxor    %xmm1, %xmm1
        addss   %xmm1, %xmm2
        movss   24(%esp), %xmm1
        mulss   LCPI1_8, %xmm1
        mulss   %xmm2, %xmm1
        addss   %xmm0, %xmm1
        subss   %xmm1, %xmm0
        movss   %xmm0, 20(%esp)
        flds    20(%esp)
        addl    $28, %esp
        ret
LBB1_2:
        ...
And that looks quite nice, in my opinion.

Black-Scholes

I work at a bank these days, so let's do the most famous formula in financial maths, the Black-Scholes formula for pricing options. Here's a Haskell rendition of it.
blackscholes iscall s x t r v = if iscall then call else put
  where call = s * normcdf d1 - x*exp (-r*t) * normcdf d2
        put  = x * exp (-r*t) * normcdf (-d2) - s * normcdf (-d1)
        d1 = (log(s/x) + (r+v*v/2)*t) / (v*sqrt t)
        d2 = d1 - v*sqrt t
Again, it uses an if, so it needs a little fix.
blackscholes iscall s x t r v  = iscall ? (call, put)
With that fix, code can be generated directly from blackscholes. The calls to normcdf are expanded inline, but it's easy to make some small changes to the code so that it actually does function calls.
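For reference, here is a self-contained plain Haskell version (no LLVM) that can be loaded straight into GHCi. The type signatures, the fixed Double type and the parityGap helper are my additions for this sketch. parityGap uses put-call parity (call - put = s - x*exp(-r*t) for European options without dividends) as a cheap sanity check on the formula; it should come out very close to zero.

```haskell
horner :: Num a => [a] -> a -> a
horner coeff base = foldr1 multAdd coeff
  where multAdd x y = y * base + x

-- Polynomial approximation of the normal CDF, as given earlier.
normcdf :: Double -> Double
normcdf x = if x < 0 then 1 - w else w
  where
    w = 1 - 1 / sqrt (2 * pi) * exp (-l * l / 2) * poly k
    k = 1 / (1 + 0.2316419 * l)
    l = abs x
    poly = horner coeff
    coeff = [0.0, 0.31938153, -0.356563782, 1.781477937, -1.821255978, 1.330274429]

-- Arguments: call/put flag, spot s, strike x, time t, rate r, volatility v.
blackscholes :: Bool -> Double -> Double -> Double -> Double -> Double -> Double
blackscholes iscall s x t r v = if iscall then call else put
  where
    call = s * normcdf d1 - x * exp (-r * t) * normcdf d2
    put = x * exp (-r * t) * normcdf (-d2) - s * normcdf (-d1)
    d1 = (log (s / x) + (r + v * v / 2) * t) / (v * sqrt t)
    d2 = d1 - v * sqrt t

-- Hypothetical helper: deviation from put-call parity, ideally about 0.
parityGap :: Double -> Double -> Double -> Double -> Double -> Double
parityGap s x t r v =
  (blackscholes True s x t r v - blackscholes False s x t r v)
    - (s - x * exp (-r * t))
```

For example, parityGap 100 95 0.5 0.05 0.2 should differ from zero only by floating point rounding.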

Some figures

To test the speed of the generated code I ran blackscholes over a portfolio of 10,000,000 options and summed their value. The time excludes the time to set up the portfolio array, because that is done in Haskell. I also ran the code in pure Haskell on a list with 10,000,000 elements.
Unoptimized LLVM   17.5s
Optimized LLVM      8.2s
Pure Haskell        9.3s
The only surprise here is how well pure Haskell (with -O2) performs. This is a very good example for Haskell though, because the list gets fused away and everything is strict. A lot of the time is spent in log and exp in this code, so perhaps the similar results are not so surprising after all.


Wednesday, January 07, 2009

LLVM

The LLVM, Low Level Virtual Machine, is a really cool compiler infrastructure project with many participants. The idea is that if you want to make a new high quality compiler you just have to generate LLVM code, and then there are lots of optimizations and code generators available to get fast code.

There are different ways to generate input to the LLVM tools. You can generate a text file with LLVM code and feed it to the tools, or you can use bindings for some programming language and programmatically build the LLVM code. The original bindings from the LLVM project are for C++, but they also provide C bindings. On top of the C bindings you can easily interface to other languages; for instance O'Caml and Haskell.

There are also different things you can do to LLVM code you have built programmatically. You can transform it, you can write it to a file, you can run an interpreter on it, or execute it with a JIT compiler.

Haskell LLVM bindings

There is a Haskell binding to the LLVM. It has two layers. You can either work on the C API level and have ample opportunity to shoot your own limbs to pieces, or you can use the high level interface which is mostly safe.

Bryan O'Sullivan did all the hard work of taking the C header files and producing the corresponding Haskell FFI files. He also made a first stab at the high level interface, which I have since changed a lot (for better or for worse).

An example

Let's do an example. We'll write the LLVM code for this function
  f x y z = (x + y) * z
In Haskell this function is polymorphic, but when generating machine code we have to settle for a type. Let's pick Int32. (The Haskell Int type cannot be used in talking to LLVM; it doesn't have a well-defined size.) Here is how it looks:
mAddMul :: CodeGenModule (Function (Int32 -> Int32 -> Int32 -> IO Int32))
mAddMul = 
  createFunction ExternalLinkage $ \ x y z -> do
    t <- add x y
    r <- mul t z
    ret r
For comparison, the LLVM code in text form for this would be:
define i32 @_fun1(i32, i32, i32) {
        %3 = add i32 %0, %1
        %4 = mul i32 %3, %2
        ret i32 %4
}
So what does the Haskell code say? The mAddMul definition is something in the CodeGenModule monad, and it generates a Function of type Int32 -> Int32 -> Int32 -> IO Int32. That last is the type of f above, except for that IO. Why the IO? The Haskell LLVM bindings force all defined functions to return something in the IO monad, because there is no restriction on what can happen in the LLVM code; it might very well do IO. So to be on the safe side, there's always an IO on the type. If we know the function is harmless, we can use unsafePerformIO to get rid of it.

So the code does a createFunction which does what the name suggests. The ExternalLinkage argument says that this function will be available outside the module it's in, the obvious opposite being InternalLinkage. Using InternalLinkage is like saying static on the top level in C. In this example it doesn't really matter which we pick.

The function has three arguments x y z. The last argument to createFunction should be a lambda expression with the right number of arguments, i.e., the number of arguments should agree with the type. We then use monadic syntax to generate an add, mul, and ret instruction.

The code looks like assembly code, which is the level that LLVM is at. It's a somewhat peculiar assembly code, because it's in SSA (Static Single Assignment) form. More about that later.

So what can we do with this function? Well, we can generate machine code for it and call it.

main = do
    addMul <- simpleFunction mAddMul
    a <- addMul 2 3 4
    print a
In this code addMul has type Int32 -> Int32 -> Int32 -> IO Int32, so it has to be called in the IO monad. Since this is a pure function, we can make the type pure, i.e., Int32 -> Int32 -> Int32 -> Int32.
main = do
    addMul <- simpleFunction mAddMul
    let addMul' = unsafePurify addMul
    print (addMul' 2 3 4)
The unsafePurify function is simply an extension of unsafePerformIO that drops the IO on the result of a function.

So that was pretty easy. To make a function, just specify the LLVM code using the LLVM DSEL that the Haskell bindings provide.
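As a rough sketch of the idea, here is a fixed-arity version for three-argument functions. The names purify3 and addMulIO are hypothetical and not part of the bindings; the real unsafePurify works for any number of arguments, which this sketch does not attempt. addMulIO is an ordinary Haskell function standing in for the JIT-compiled code.

```haskell
import Data.Int (Int32)
import System.IO.Unsafe (unsafePerformIO)

-- Hypothetical helper: drop the IO from the result of a three-argument
-- function. Only safe when the function really is free of side effects.
purify3 :: (a -> b -> c -> IO d) -> a -> b -> c -> d
purify3 f a b c = unsafePerformIO (f a b c)

-- Stand-in for the compiled addMul; the real one would come from the JIT.
addMulIO :: Int32 -> Int32 -> Int32 -> IO Int32
addMulIO x y z = return ((x + y) * z)

-- A pure-looking wrapper, analogous to addMul' in the text.
addMulPure :: Int32 -> Int32 -> Int32 -> Int32
addMulPure = purify3 addMulIO
```

With these definitions addMulPure 2 3 4 can be used in pure code, just like addMul' above.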

Fibonacci

No FP example is complete without the Fibonacci function, so here it is.
mFib :: CodeGenModule (Function (Word32 -> IO Word32))
mFib = do
    fib <- newFunction ExternalLinkage
    defineFunction fib $ \ arg -> do
        -- Create the two basic blocks.
        recurse <- newBasicBlock
        exit <- newBasicBlock

        -- Test if arg > 2
        test <- icmp IntUGT arg (2::Word32)
        condBr test recurse exit

        -- Just return 1 if not > 2
        defineBasicBlock exit
        ret (1::Word32)

        -- Recurse if > 2, using the cumbersome plus to add the results.
        defineBasicBlock recurse
        x1 <- sub arg (1::Word32)
        fibx1 <- call fib x1
        x2 <- sub arg (2::Word32)
        fibx2 <- call fib x2
        r <- add fibx1 fibx2
        ret r
    return fib
Instead of using createFunction to create the function we're using newFunction and defineFunction. The former is a shorthand for the latter two together. But separating the creation of the function from its definition means that we can refer to the function before it's been defined. We need this since fib is recursive.

Every instruction in the LLVM code belongs to a basic block. A basic block is a sequence of non-jump instructions (call is allowed in the LLVM) ending with some kind of jump. It is always entered at the top only. The top of each basic block can be thought of as a label that you can jump to, and those are the only places that you can jump to.

The code for fib starts with a test if the argument is Unsigned Greater Than 2. The condBr instruction branches to recurse if test is true otherwise to exit. To be able to refer to the two branch labels (i.e., basic blocks) before they are defined we create them with newBasicBlock and then later define them with defineBasicBlock. The defineBasicBlock simply starts a new basic block that runs to the next basic block start, or to the end of the function. The type system does not check that the basic block ends with a branch (I can't figure out how to do that without making the rest of the code more cumbersome).

In the false branch we simply return 1, and in the true branch we make the two usual recursive calls, add the results, and return the sum.

As you can see a few type annotations are necessary on constants. In my opinion they are quite annoying, because if you write anything different from ::Word32 in those annotations there will be a type error. This means that in principle the compiler has all the information, it's just too "stupid" to use it.

The performance you get from this Fibonacci function is decent, but in fact worse than GHC with -O2 gives. Even with full optimization turned on for the LLVM code it's still not as fast as GHC for this function.

[Edit: Added assembly] Here is the assembly code for Fibonacci. Note how there is only one recursive call. The other call has been transformed into a loop.

_fib:
 pushl %edi
 pushl %esi
 subl $4, %esp
 movl 16(%esp), %esi
 cmpl $2, %esi
 jbe LBB1_4
LBB1_1:
 movl $1, %edi
 .align 4,0x90
LBB1_2:
 leal -1(%esi), %eax
 movl %eax, (%esp)
 call _fib
 addl %edi, %eax
 addl $4294967294, %esi
 cmpl $2, %esi
 movl %eax, %edi
 ja LBB1_2
LBB1_3:
 addl $4, %esp
 popl %esi
 popl %edi
 ret
LBB1_4:
 movl $1, %eax
 jmp LBB1_3

Hello, World!

The code for printing "Hello, World!":
import Data.Word
import LLVM.Core
import LLVM.ExecutionEngine

bldGreet :: CodeGenModule (Function (IO ()))
bldGreet = do
    puts <- newNamedFunction ExternalLinkage "puts" :: TFunction (Ptr Word8 -> IO Word32)
    greetz <- createStringNul "Hello, World!"
    func <- createFunction ExternalLinkage $ do
      tmp <- getElementPtr greetz (0::Word32, (0::Word32, ()))
      call puts tmp -- Throw away return value.
      ret ()
    return func

main :: IO ()
main = do
    greet <- simpleFunction bldGreet
    greet
To get access to the C function puts we simply declare it and rely on the linker to link it in. The greetz variable has type pointer to array of characters. So to get a pointer to the first character we have to use the rather complicated getElementPtr instruction. See FAQ about it.

Phi instructions

Let's do the following simple C function
int f(int x)
{
  if (x < 0) x = -x;
  return (x+1);
}
Let's try to write some corresponding LLVM code:
  createFunction ExternalLinkage $ \ x -> do
    xneg <- newBasicBlock
    xpos <- newBasicBlock
    t <- icmp IntSLT x (0::Int32)
    condBr t xneg xpos

    defineBasicBlock xneg
    x' <- sub (0::Int32) x
    br xpos

    defineBasicBlock xpos
    r1 <- add ??? (1::Int32)
    ret r1
But what should we put at ???? When jumping from the condBr the value is in x, but when jumping from the negation block the value is in x'. And this is how SSA works. Every instruction puts the value in a new "register", so this situation is unavoidable. This is why SSA form (and thus LLVM) has phi instructions. A phi is a pseudo-instruction that tells the code generator which registers should be merged at the entry of a basic block. So the real code looks like this:
mAbs1 :: CodeGenModule (Function (Int32 -> IO Int32))
mAbs1 = 
  createFunction ExternalLinkage $ \ x -> do
    top <- getCurrentBasicBlock
    xneg <- newBasicBlock
    xpos <- newBasicBlock
    t <- icmp IntSLT x (0::Int32)
    condBr t xneg xpos

    defineBasicBlock xneg
    x' <- sub (0::Int32) x
    br xpos

    defineBasicBlock xpos
    r <- phi [(x, top), (x', xneg)]
    r1 <- add r (1::Int32)
    ret r1
The phi instruction takes a list of registers to merge, and paired up with each register is the basic block that the jump comes from. Since the first basic block in a function is created implicitly we have to get it with getCurrentBasicBlock which returns the current basic block.

If, like me, you have a perverse interest in the machine code that gets generated, here is the optimized code for that function for x86:

__fun1:
        movl    4(%esp), %eax
        movl    %eax, %ecx
        sarl    $31, %ecx
        addl    %ecx, %eax
        xorl    %ecx, %eax
        incl    %eax
        ret
Note how the conditional jump has cleverly been replaced by some non-jumping instructions. I think this code is as good as it gets.
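To spell out what those instructions compute, here is a sketch of the same trick in plain Haskell. The function names are mine. The arithmetic shift by 31 produces a mask that is 0 for non-negative x and -1 (all ones) for negative x; adding the mask and then xoring with it negates x exactly when it is negative.

```haskell
import Data.Bits (shiftR, xor)
import Data.Int (Int32)

-- Branch-free version, mirroring sarl / addl / xorl / incl.
absPlusOne :: Int32 -> Int32
absPlusOne x = ((x + m) `xor` m) + 1
  where
    -- Arithmetic shift: 0 when x >= 0, -1 (all bits set) when x < 0.
    m = x `shiftR` 31

-- Straightforward version with a conditional, like the C code.
absPlusOneRef :: Int32 -> Int32
absPlusOneRef x = (if x < 0 then negate x else x) + 1
```

For instance, absPlusOne (-5) gives 6 and absPlusOne 3 gives 4, and the two functions agree on every input I have tried.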

Loops and arrays

Let's do some simple array code, the dot product of two vectors. The function takes a length and pointers to two vectors. It sums the elementwise product of the vectors. Here's the C code:
double
dotProd(unsigned int len, double *aPtr, double *bPtr)
{
    unsigned int i;
    double s;

    s = 0;
    for (i = 0; i != len; i++)
        s += aPtr[i] * bPtr[i];
    return s;
}
The corresponding LLVM code is much more complicated and has some new twists.
import Data.Word
import Foreign.Marshal.Array
import LLVM.Core
import LLVM.ExecutionEngine

mDotProd :: CodeGenModule (Function (Word32 -> Ptr Double -> Ptr Double -> IO Double))
mDotProd =
  createFunction ExternalLinkage $ \ size aPtr bPtr -> do
    top <- getCurrentBasicBlock
    loop <- newBasicBlock
    body <- newBasicBlock
    exit <- newBasicBlock

    -- Enter loop, must use a br since control flow joins at the loop bb.
    br loop

    -- The loop control.
    defineBasicBlock loop
    i <- phi [(valueOf (0 :: Word32), top)]  -- i starts as 0, when entered from top bb
    s <- phi [(valueOf 0, top)]  -- s starts as 0, when entered from top bb
    t <- icmp IntNE i size       -- check for loop termination
    condBr t body exit

    -- Define the loop body
    defineBasicBlock body

    ap <- getElementPtr aPtr (i, ()) -- index into aPtr
    bp <- getElementPtr bPtr (i, ()) -- index into bPtr
    a <- load ap                 -- load element from a vector
    b <- load bp                 -- load element from b vector
    ab <- mul a b                -- multiply them
    s' <- add s ab               -- accumulate sum

    i' <- add i (valueOf (1 :: Word32)) -- Increment loop index

    addPhiInputs i [(i', body)]  -- Control flow reaches loop bb from body bb
    addPhiInputs s [(s', body)]
    br loop                      -- And loop

    defineBasicBlock exit
    ret (s :: Value Double)      -- Return sum

main = do
    ioDotProd <- simpleFunction mDotProd
    let dotProd a b =
         unsafePurify $
         withArrayLen a $ \ aLen aPtr ->
         withArrayLen b $ \ bLen bPtr ->
         ioDotProd (fromIntegral (aLen `min` bLen)) aPtr bPtr

    let a = [1,2,3]
        b = [4,5,6]
    print $ dotProd a b
    print $ sum $ zipWith (*) a b
First we have to set up the looping machinery. There are four basic blocks involved: the implicit basic block that is created at the start of every function, top; the top of the loop, loop; the body of the loop, body; and finally the block with the return from the function, exit.

There are two "registers", the loop index i and the running sum s, that arrive from two different basic blocks at the top of the loop. When entering the loop for the first time they should be 0. That's what the phi instruction specifies. The valueOf function simply turns a constant into an LLVM value. It's worth noting that the initial values for the two variables are constants rather than registers. The control flow also reaches the basic block loop from the end of body, but we don't have the names of those registers in scope yet, so we can't put them in the phi instruction. Instead, we have to use addPhiInputs to add more phi inputs later (when the registers are in scope).

The most mysterious instruction in the LLVM is getElementPtr. It simply does address arithmetic, so it really does something quite simple. But it can perform several levels of address arithmetic when addressing through multilevel arrays and structs. It can take several indices, but since here we simply want to add the index variable to a pointer the usage is pretty simple. Doing getElementPtr aPtr (i, ()) corresponds to aPtr + i in C.
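The same single-index address arithmetic can be tried from plain Haskell with the standard FFI library, which may help if C pointer arithmetic is not second nature. This is only an analogy, not the bindings: loadAt and demo are names made up for this sketch, advancePtr plays the role of getElementPtr with one index, and peek plays the role of load.

```haskell
import Foreign.Marshal.Array (advancePtr, withArray)
import Foreign.Ptr (Ptr)
import Foreign.Storable (peek)

-- Read element i: advance the pointer by i elements (not bytes), then load.
loadAt :: Ptr Double -> Int -> IO Double
loadAt p i = peek (p `advancePtr` i)

-- Temporarily marshal a list into memory and read the element at index 2.
demo :: IO Double
demo = withArray [1.5, 2.5, 3.5] $ \p -> loadAt p 2
```

Running demo reads the third element of the temporary array, which is 3.5 here.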

To test this function we need pointers to two vectors. The FFI function withArrayLen temporarily allocates the vector and fills it with elements from the list.

The essential part of the function looks like this in optimized x86 code:

        pxor    %xmm0, %xmm0
        xorl    %esi, %esi
        .align  4,0x90
LBB1_2:
        movsd   (%edx,%esi,8), %xmm1
        mulsd   (%ecx,%esi,8), %xmm1
        incl    %esi
        cmpl    %eax, %esi
        addsd   %xmm1, %xmm0
        jne     LBB1_2
Which is pretty good. Improving this would require SSE vector instructions. This is possible using the LLVM vector type, but I'll leave that for now.

Abstraction

The loop structure in dotProd is pretty common, so we would like to abstract it out for reuse. The creation of basic blocks and phi instructions is rather fiddly so it would be nice to do this once and not worry about it again.

What are the parts of the loop? Well, let's just do a simple "for" loop that loops from a lower index (inclusive) to an upper index (exclusive) and executes the loop body for each iteration. So there should be three arguments to the loop function: lower bound, upper bound and loop body. What is the loop body? Since the LLVM is using SSA the loop body can't really update the loop state variables. Instead it's like a pure functional language where you have to express it as a state transformation. So the loop body will take the old state and return a new state. It's also useful to pass the loop index to the loop body. Now when we've introduced the notion of a loop state we also need to have an initial value for the loop state as an argument to the loop function.

Let's start out easy and let the state to be updated in the loop be a single value. In dotProd it's simply the running sum (s).

forLoop low high start incr = do
    top <- getCurrentBasicBlock
    loop <- newBasicBlock
    body <- newBasicBlock
    exit <- newBasicBlock

    br loop

    defineBasicBlock loop
    i <- phi [(low, top)]
    state <- phi [(start, top)]
    t <- icmp IntNE i high
    condBr t body exit

    defineBasicBlock body

    state' <- incr i state
    i' <- add i (valueOf 1)

    body' <- getCurrentBasicBlock
    addPhiInputs i [(i', body')]
    addPhiInputs state [(state', body')]
    br loop
    defineBasicBlock exit

    return state
The low and high arguments are simply the loop bounds, start is the start value for the loop state variable, and finally incr is invoked in the loop body to get the new value for the state variable. Note that the incr can contain new basic blocks so there's no guarantee we're in the same basic block after incr has been called. That's why there is a call to getCurrentBasicBlock before adding to the phi instructions.

So the original loop in dotProd can now be written

    s <- forLoop 0 size 0 $ \ i s -> do
      ap <- getElementPtr aPtr (i, ()) -- index into aPtr
      bp <- getElementPtr bPtr (i, ()) -- index into bPtr
      a <- load ap                 -- load element from a vector
      b <- load bp                 -- load element from b vector
      ab <- mul a b                -- multiply them
      s' <- add s ab               -- accumulate sum
      return s'
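If the state-passing shape is unfamiliar, it may help to see the same loop as ordinary Haskell, with no code generation involved. The sketch below is a model I wrote for illustration, not the forLoop from the bindings: it works in any monad, uses plain Int indices, and indexes lists instead of pointers. Like the LLVM version it stops when i equals high, so it assumes low <= high.

```haskell
import Data.Functor.Identity (runIdentity)

-- Model of the loop: thread the state s through the body for i in [low, high).
forLoop :: Monad m => Int -> Int -> s -> (Int -> s -> m s) -> m s
forLoop low high start incr = go low start
  where
    go i s
      | i == high = return s                  -- corresponds to the exit block
      | otherwise = incr i s >>= go (i + 1)   -- body, then back to the loop head

-- Dot product with the running sum as the single state variable.
dotProd :: [Double] -> [Double] -> Double
dotProd xs ys = runIdentity $ forLoop 0 size 0 $ \i s ->
    return (s + xs !! i * ys !! i)
  where
    size = length xs `min` length ys
```

Here dotProd [1,2,3] [4,5,6] evaluates to 32, the same value as sum (zipWith (*) [1,2,3] [4,5,6]).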
So that wasn't too bad. But what if the loop needs multiple state variables? Or none? The tricky bit is handling the phi instructions since the number of instructions needed depends on how many state variables we have. So let's create a class for types that can be state variables. This way we can use tuples for multiple state variables. The class needs two methods, the generalization of phi and the generalization of addPhiInputs.
class Phi a where
    phis :: BasicBlock -> a -> CodeGenFunction r a
    addPhis :: BasicBlock -> a -> a -> CodeGenFunction r ()
A simple instance is when we have no state variables.
instance Phi () where
    phis _ _ = return ()
    addPhis _ _ _ = return ()
We also need to handle the case with a single state variable. All LLVM values are encapsulated in the Value type, so this is the one we create an instance for.
instance (IsFirstClass a) => Phi (Value a) where
    phis bb a = do
        a' <- phi [(a, bb)]
        return a'
    addPhis bb a a' = do
        addPhiInputs a [(a', bb)]
Finally, here's the instance for pair. Other tuples can be done in the same way (or we could just use nested pairs).
instance (Phi a, Phi b) => Phi (a, b) where
    phis bb (a, b) = do
        a' <- phis bb a
        b' <- phis bb b
        return (a', b')
    addPhis bb (a, b) (a', b') = do
        addPhis bb a a'
        addPhis bb b b'
Using this new class the looping function becomes
forLoop :: forall i a r . (Phi a, Num i, IsConst i, IsInteger i, IsFirstClass i) =>
           Value i -> Value i -> a -> (Value i -> a -> CodeGenFunction r a) -> CodeGenFunction r a
forLoop low high start incr = do
    top <- getCurrentBasicBlock
    loop <- newBasicBlock
    body <- newBasicBlock
    exit <- newBasicBlock

    br loop

    defineBasicBlock loop
    i <- phi [(low, top)]
    vars <- phis top start
    t <- icmp IntNE i high
    condBr t body exit

    defineBasicBlock body

    vars' <- incr i vars
    i' <- add i (valueOf 1 :: Value i)

    body' <- getCurrentBasicBlock
    addPhis body' vars vars'
    addPhiInputs i [(i', body')]
    br loop
    defineBasicBlock exit

    return vars

File operations

The Haskell bindings provide two convenient functions - writeBitcodeToFile and readBitcodeFromFile - for writing and reading modules in the LLVM binary format.

A simple example:

import Data.Int
import LLVM.Core

mIncr :: CodeGenModule (Function (Int32 -> IO Int32))
mIncr = 
  createNamedFunction ExternalLinkage "incr" $ \ x -> do
    r <- add x (1 :: Int32)
    ret r

main = do
    m <- newModule
    defineModule m mIncr
    writeBitcodeToFile "incr.bc" m
Running this will produce the file incr.bc which can be processed with the usual LLVM tools. E.g.
$ llvm-dis < incr.bc  # to look at the LLVM code
$ opt -std-compile-opts incr.bc -f -o incrO.bc # run optimizer
$ llvm-dis < incrO.bc  # to look at the optimized LLVM code
$ llc incrO.bc # generate assembly code
$ cat incrO.s  # look at assembly code
Reading a module file is equally easy, but what can you do with a module you have read? It could contain anything. To extract things from a module there is a function getModuleValues which returns a list of name-value pairs of all externally visible functions and global variables. The values all have type ModuleValue. To convert a ModuleValue to a regular Value you have to use castModuleValue. This is a safe conversion function that makes a dynamic type test to make sure the types match (think of ModuleValue as Dynamic and castModuleValue as fromDynamic).

Here's an example:

import Data.Int
import LLVM.Core
import LLVM.ExecutionEngine

main = do
    m <- readBitcodeFromFile "incr.bc"
    ee <- createModuleProviderForExistingModule m >>= createExecutionEngine
    funcs <- getModuleValues m
    let ioincr :: Function (Int32 -> IO Int32)
        Just ioincr = lookup "incr" funcs >>= castModuleValue
        incr = unsafePurify $ generateFunction ee ioincr

    print (incr 41)
This post is getting rather long, so I'll let this be the last example for today.
