Commit e649727c authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

changes

parent 385b763b
Loading
Loading
Loading
Loading
+5 −7
Original line number Diff line number Diff line
import deepchem as dc
import numpy as np
import sklearn
from sklearn.ensemble import RandomForestClassifier

N = 10
N = 100
n_feat = 5
n_classes = 3
n_tasks = 1
X = np.random.rand(N, n_feat)
y = np.random.randint(3, size=(N, n_tasks))
y = np.random.randint(3, size=(N,))
dataset = dc.data.NumpyDataset(X, y)

sklearn_model = RandomForestClassifier(
    class_weight="balanced", n_estimators=50)
model = dc.models.SklearnModel(sklearn_model)

# Fit models
metric = dc.metrics.Metric(dc.metrics.roc_auc_score, np.mean)

# Fit trained model
print("About to fit model")
model.fit(dataset)
model.save()

print("About to evaluate model")
train_scores = model.evaluate(dataset, [metric], [])
train_scores = model.evaluate(dataset,
    sklearn.metrics.roc_auc_score, [])

print("Train scores")
print(train_scores)