# Simulate one allocation of num_elements items under a Chinese Restaurant
# Process (CRP) with concentration parameter alpha.
#
# Item 1 starts category 1. Item i (i >= 2) opens a new category with
# probability alpha / (alpha + i - 1); otherwise it joins existing category k
# with probability n_k / (alpha + i - 1), where n_k is k's current size.
#
# Arguments:
#   num_elements  number of items to allocate (non-negative whole number).
#   alpha         concentration parameter (non-negative).
# Returns:
#   numeric vector of length num_elements with 1-based category labels;
#   labels are introduced in increasing order (1, 2, 3, ...).
simulateChineseRestaurant <- function(num_elements, alpha) {
  allocation <- numeric(num_elements)
  if (num_elements < 1) {
    return(allocation)
  }
  allocation[1] <- 1
  # counts[k] = number of items currently in category k. At most num_elements
  # categories can ever exist, so preallocate that many slots.
  counts <- integer(num_elements)
  counts[1] <- 1L
  num_cats <- 1
  # seq_len() keeps the loop empty when num_elements == 1 (2:1 would not).
  for (i in seq_len(num_elements - 1) + 1) {
    if (runif(1, 0, 1) < alpha / (alpha + i - 1)) {
      # choose a new category
      num_cats <- num_cats + 1
      this_category <- num_cats
    } else {
      # choose an existing category with probability proportional to its
      # size; `prob` must be named because the third positional argument of
      # sample.int() is `replace`.
      this_category <- sample.int(num_cats, 1, prob = counts[seq_len(num_cats)])
    }
    counts[this_category] <- counts[this_category] + 1L
    allocation[i] <- this_category
  }
  allocation
}
# Monte Carlo study of the pure-R sampler: distribution of the number of
# distinct categories among 500 elements, for alpha = 0.05 and alpha = 0.5.
# 1000 replicates leave a lot of Monte Carlo error; 1e4-1e5 replicates are
# better (20000 is used here, at roughly 1.5 minutes per run in pure R).
reps <- 20000
#reps <- 2000
#reps <- 1000
# date() calls bracket each run so elapsed wall-clock time can be read off the
# recorded output.
date()
## [1] "Wed Aug 31 16:08:09 2016"
# Each replicate simulates 500 elements with alpha = 0.05 and records how many
# distinct categories were used; dividing the table by reps gives proportions.
tab <- table(replicate(reps,length(unique(simulateChineseRestaurant(500,0.05))))) / reps
date()
## [1] "Wed Aug 31 16:09:36 2016"
# tab already sums to 1 (table counts sum to reps), so this division does not
# change the values; it only prints the proportions.
tab/sum(tab)
##
## 1 2 3 4 5
## 0.7093 0.2455 0.0403 0.0044 0.0005
plot(tab)
date()
## [1] "Wed Aug 31 16:09:36 2016"
# Same experiment with a larger concentration parameter (alpha = 0.5), which
# raises the new-category probability alpha / (alpha + i - 1).
tab <- table(replicate(reps,length(unique(simulateChineseRestaurant(500,0.5))))) / reps
date()
## [1] "Wed Aug 31 16:11:00 2016"
tab/sum(tab)
##
## 1 2 3 4 5 6 7 8 9
## 0.03775 0.13595 0.22355 0.23105 0.18125 0.10870 0.05105 0.02090 0.00665
## 10 11 12
## 0.00260 0.00050 0.00005
plot(tab)
# Load the compiled C sampler (a shared library built with R CMD SHLIB).
# If the library file is missing, warn instead of failing so the pure-R
# sampler can still be used.
path.to.libraries <- "/Users/lavila/Documents/Work/UCDavis/Winter_2016/Code/R_and_C/"
library.name <- "simulateChineseRestaurantC"
# Build the path once and use it for both the existence check and dyn.load(),
# so the file that is checked is the file that is loaded. paste0() is used
# because path.to.libraries already ends in "/" (file.path() would add another).
full.library.name <- paste0(path.to.libraries, library.name, .Platform$dynlib.ext)
library.path <- full.library.name
if (file.exists(full.library.name)) {
  dyn.load(full.library.name)
} else {
  warning("file with name ", full.library.name, " not found. C libraries not loaded")
}
## the wrapper function
# R wrapper around the compiled simulateChineseRestaurantC routine (via .C).
#
# Arguments:
#   num.elements  number of items to allocate (single, finite, >= 0).
#   alpha         concentration parameter (single, finite, >= 0).
# Returns:
#   integer vector of length num.elements with 1-based category labels.
#   For num.elements == 0 an empty integer vector is returned without calling
#   C, because the C routine writes allocation[0] before looping.
simulateChineseRestaurantC <- function(num.elements = 10, alpha = 0.5) {
  if (!is.numeric(num.elements) || length(num.elements) != 1L ||
      !is.finite(num.elements) || num.elements < 0) {
    stop("argument num.elements must be a single non-negative number")
  }
  if (!is.numeric(alpha) || length(alpha) != 1L ||
      !is.finite(alpha) || alpha < 0) {
    stop("argument alpha must be a single non-negative number")
  }
  if (num.elements < 1) {
    return(integer(0))
  }
  out <- .C("simulateChineseRestaurantC",
            n = as.integer(num.elements),
            x = as.double(alpha),
            allocation = integer(num.elements))
  out$allocation
}
## Running the compiled sampler with the same settings as the pure-R runs.
# This requires the compiled library to have been loaded with dyn.load().
# The number of distinct categories depends only on the new-category test
# u < alpha / (alpha + i - 1), which both implementations share, so tab2 and
# tab3 should match the pure-R tables up to Monte Carlo error. The recorded
# timestamps show about 1 second per run, versus about 1.5 minutes in pure R.
date()
## [1] "Wed Aug 31 16:11:00 2016"
tab2 <- table(replicate(reps,length(unique(simulateChineseRestaurantC(500,0.05))))) / reps
date()
## [1] "Wed Aug 31 16:11:01 2016"
# tab2 already sums to 1, so this division only prints the proportions.
tab2/sum(tab2)
##
## 1 2 3 4 5 6
## 0.71710 0.23755 0.04080 0.00420 0.00030 0.00005
plot(tab2)
date()
## [1] "Wed Aug 31 16:11:01 2016"
# Same experiment with alpha = 0.5.
tab3 <- table(replicate(reps,length(unique(simulateChineseRestaurantC(500,0.5))))) / reps
date()
## [1] "Wed Aug 31 16:11:02 2016"
tab3/sum(tab3)
##
## 1 2 3 4 5 6 7 8 9
## 0.03775 0.13795 0.22095 0.23465 0.17380 0.10870 0.05580 0.02075 0.00715
## 10 11
## 0.00195 0.00055
plot(tab3)
#include <stdio.h>
#include <math.h>
#include <stdlib.h>
#include <R.h>
#include <Rmath.h>
/*
 * Simulate one allocation of numElements[0] items under a Chinese Restaurant
 * Process with concentration alpha[0]. Called from R via .C(), so every
 * argument is passed as a pointer.
 *
 *   numElements  in:  pointer to the number of items n.
 *   alpha        in:  pointer to the concentration parameter.
 *   allocation   out: array of length n; allocation[k] receives the 1-based
 *                     category label of item k+1. New labels are issued in
 *                     increasing order (1, 2, 3, ...).
 *
 * Item i (1-based, i >= 2) opens a new category with probability
 * alpha / (alpha + i - 1). Otherwise it joins an existing category with
 * probability proportional to that category's current size. This is done by
 * copying the label of a uniformly chosen earlier item, which is exactly
 * size-proportional sampling and needs no per-category count array.
 */
void simulateChineseRestaurantC(int *numElements, double *alpha, int *allocation)
{
    int n = numElements[0];
    int num_cats = 1;

    /* Empty output: do not touch allocation[0]. */
    if (n <= 0)
        return;

    allocation[0] = 1;
    GetRNGstate();
    for (int i = 2; i <= n; i++) {
        double u = runif(0.0, 1.0);
        int this_category;
        if (u < alpha[0] / (alpha[0] + (double)i - 1.0)) {
            /* choose a new category */
            num_cats = num_cats + 1;
            this_category = num_cats;
        } else {
            /* pick one of the i-1 earlier items (indices 0..i-2) uniformly */
            int j = (int)floor(runif(0.0, (double)(i - 1)));
            if (j > i - 2) /* defensive clamp against an endpoint draw */
                j = i - 2;
            this_category = allocation[j];
        }
        allocation[i - 1] = this_category;
    }
    PutRNGstate();
}
# Time the pure-R and compiled C samplers on the same problem size.
# microbenchmark is called through its namespace because no
# library(microbenchmark) call is visible in this script.
microbenchmark::microbenchmark(
  simulateChineseRestaurant(500, 0.5),
  simulateChineseRestaurantC(500, 0.5),
  times = 100)
## Unit: microseconds
## expr min lq mean
## simulateChineseRestaurant(500, 0.5) 2951.895 3626.143 5792.70760
## simulateChineseRestaurantC(500, 0.5) 25.918 30.056 45.90779
## median uq max neval
## 4731.9940 6760.525 44008.192 100
## 47.4635 54.275 110.473 100
The C code above was compiled into a shared library with the following command:
R CMD SHLIB simulateChineseRestaurantC.c