# 参考サイト:http://spark.rstudio.com/

#■sparklyrパッケージインストール
#install.packages("sparklyr")

#■下記のパッケージは必要に応じて実施
#install.packages("mnormt")
#install.packages("digest")
#install.packages("openssl")
#install.packages("rlang")
#install.packages("tibble")
#devtools::install_github("rstudio/sparklyr")

#■Check where SPARK_HOME points
# NOTE(review): the recorded value below is the installation's bin\
# subdirectory; SPARK_HOME is conventionally the Spark install root
# (presumably E:\spark-2.1.0-bin-hadoop2.7 here) -- confirm the setup.
print(Sys.getenv("SPARK_HOME"))
## [1] "E:\\spark-2.1.0-bin-hadoop2.7\\bin"
#■Install Spark locally (only needed once)
#sparklyr::spark_install(version = "2.1.0")

#■Open a connection to a local Spark 2.1.0 instance
library(sparklyr)
sc <- spark_connect(
  master = "local",
  version = "2.1.0"
)

# Install the example data packages (nycflights13, Lahman) if needed
#install.packages(c("nycflights13", "Lahman"))

#■Using dplyr
library(dplyr)
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
# Copy the three example data sets into Spark under explicit table names.
# overwrite = TRUE replaces a table left over from an earlier run in the same
# Spark session, so re-running this section does not fail because the table
# already exists.
iris_tbl <- copy_to(sc, iris, "iris", overwrite = TRUE)
flights_tbl <- copy_to(sc, nycflights13::flights, "flights", overwrite = TRUE)
batting_tbl <- copy_to(sc, Lahman::Batting, "batting", overwrite = TRUE)
src_tbls(sc)
## [1] "batting" "flights" "iris"
# Lazily filter the Spark flights table to departures that were exactly two
# minutes late; auto-printing runs the query and shows the first rows.
filter(flights_tbl, dep_delay == 2)
## # Source:   lazy query [?? x 19]
## # Database: spark_connection
##     year month   day dep_time sched_dep_time dep_delay arr_time
##    <int> <int> <int>    <int>          <int>     <dbl>    <int>
##  1  2013     1     1      517            515      2.00      830
##  2  2013     1     1      542            540      2.00      923
##  3  2013     1     1      702            700      2.00     1058
##  4  2013     1     1      715            713      2.00      911
##  5  2013     1     1      752            750      2.00     1025
##  6  2013     1     1      917            915      2.00     1206
##  7  2013     1     1      932            930      2.00     1219
##  8  2013     1     1     1028           1026      2.00     1350
##  9  2013     1     1     1042           1040      2.00     1325
## 10  2013     1     1     1231           1229      2.00     1523
## # ... with more rows, and 12 more variables: sched_arr_time <int>,
## #   arr_delay <dbl>, carrier <chr>, flight <int>, tailnum <chr>,
## #   origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>, hour <dbl>,
## #   minute <dbl>, time_hour <dttm>
# Per-aircraft summary computed in Spark: number of flights, mean distance and
# mean arrival delay. Keep aircraft with more than 20 flights, a mean distance
# under 2000 and a non-missing mean delay, then collect() the result into R.
# na.rm = TRUE states explicitly what SQL's AVG already does (ignore missing
# values), which silences dbplyr's "Missing values are always removed in SQL"
# warning without changing the result.
delay <- flights_tbl %>%
  group_by(tailnum) %>%
  summarise(
    count = n(),
    dist = mean(distance, na.rm = TRUE),
    delay = mean(arr_delay, na.rm = TRUE)
  ) %>%
  filter(count > 20, dist < 2000, !is.na(delay)) %>%
  collect()
# Plot mean arrival delay against mean distance per aircraft, sizing points by
# flight count and overlaying a smoothed trend line.
library(ggplot2)
ggplot(data = delay, mapping = aes(x = dist, y = delay)) +
  geom_point(mapping = aes(size = count), alpha = 0.5) +
  geom_smooth() +
  scale_size_area(max_size = 2)
## `geom_smooth()` using method = 'gam'

#■Window functions
# For each player, keep the seasons whose hit count ranks 1 or 2 within that
# player (min_rank() gives tied values the same rank, so ties can yield more
# than two rows), and drop seasons without a hit.
batting_tbl %>%
  select(playerID, yearID, teamID, G, AB:H) %>%
  arrange(playerID, yearID, teamID) %>%
  group_by(playerID) %>%
  filter(min_rank(desc(H)) <= 2, H > 0)
## # Source:     lazy query [?? x 7]
## # Database:   spark_connection
## # Groups:     playerID
## # Ordered by: playerID, yearID, teamID
##    playerID  yearID teamID     G    AB     R     H
##    <chr>      <int> <chr>  <int> <int> <int> <int>
##  1 aaronha01   1959 ML1      154   629   116   223
##  2 aaronha01   1963 ML1      161   631   121   201
##  3 abadfe01    2012 HOU       37     7     0     1
##  4 abbated01   1905 BSN      153   610    70   170
##  5 abbated01   1904 BSN      154   579    76   148
##  6 abbeych01   1894 WAS      129   523    95   164
##  7 abbeych01   1895 WAS      132   511   102   141
##  8 abbotji01   1999 MIL       20    21     0     2
##  9 abnersh01   1992 CHA       97   208    21    58
## 10 abnersh01   1990 SDN       91   184    17    45
## # ... with more rows
#■Using SQL
# Run raw Spark SQL through the DBI interface and show the first 10 rows of
# the iris table (the recorded output shows the dots in the column names
# became underscores when the table was copied).
library(DBI)
iris_preview <- dbGetQuery(
  conn = sc,
  statement = "SELECT * FROM iris LIMIT 10"
)
print(iris_preview)
##    Sepal_Length Sepal_Width Petal_Length Petal_Width Species
## 1           5.1         3.5          1.4         0.2  setosa
## 2           4.9         3.0          1.4         0.2  setosa
## 3           4.7         3.2          1.3         0.2  setosa
## 4           4.6         3.1          1.5         0.2  setosa
## 5           5.0         3.6          1.4         0.2  setosa
## 6           5.4         3.9          1.7         0.4  setosa
## 7           4.6         3.4          1.4         0.3  setosa
## 8           5.0         3.4          1.5         0.2  setosa
## 9           4.4         2.9          1.4         0.2  setosa
## 10          4.9         3.1          1.5         0.1  setosa
#■Machine Learning
# Copy mtcars into Spark under an explicit table name. overwrite = TRUE keeps
# this section re-runnable in the same Spark session instead of failing
# because the table already exists.
mtcars_tbl <- copy_to(sc, mtcars, "mtcars", overwrite = TRUE)
# Keep cars with hp >= 100 and split them 50/50 into training and test sets;
# the fixed seed is there so the split can be reproduced.
# NOTE(review): cyl8 is computed but never used as a model feature below; it
# is kept so the partitioned data (and the recorded fit) stay unchanged.
partitions <- mtcars_tbl %>%
  filter(hp >= 100) %>%
  mutate(cyl8 = cyl == 8) %>%
  sdf_partition(training = 0.5, test = 0.5, seed = 1099)

# Fit a linear regression of mpg on wt and cyl using the training partition.
fit <- partitions$training %>%
  ml_linear_regression(response = "mpg", features = c("wt", "cyl"))
fit
## Call: ml_linear_regression.tbl_spark(., response = "mpg", features = c("wt", "cyl"))  
## 
## Formula: mpg ~ wt + cyl
## 
## Coefficients:
## (Intercept)          wt         cyl 
##   33.499452   -2.818463   -0.923187
summary(fit)
## Call: ml_linear_regression.tbl_spark(., response = "mpg", features = c("wt", "cyl"))  
## 
## Deviance Residuals:
##    Min     1Q Median     3Q    Max 
## -1.752 -1.134 -0.499  1.296  2.282 
## 
## Coefficients:
## (Intercept)          wt         cyl 
##   33.499452   -2.818463   -0.923187 
## 
## R-Squared: 0.8274
## Root Mean Squared Error: 1.422