/**
   Support the automatic implementation of test doubles via programmable mocks.
 */
module unit_threaded.mock;

import unit_threaded.from;

/// Aliases its argument; lets a `__traits` expression be bound with `alias`.
alias Identity(alias T) = T;

/**
   True when `__traits(getMember, T, member)` does not compile from this
   module, i.e. the member is not accessible here.
 */
private enum isPrivate(T, string member) = !__traits(compiles, __traits(getMember, T, member));


/**
   Generates, at compile time, D source code that overrides the virtual
   methods of `T` so that calls are recorded.

   For every accessible virtual overload that is neither `const` nor
   `immutable` (the generated override has to mutate the mock's state, which
   such methods cannot do), the generated code:
   - declares `<name>_<i>_parameters` and `<name>_<i>_returnType` aliases;
   - for non-void overloads, declares a `<name>_<i>_returnValues` array whose
     elements are returned one per call, returning `<returnType>.init` once
     the array is empty;
   - appends the method name to `calledFuncs` and the arguments, formatted
     with `tuple(...).to!string`, to `calledValues`. For `nothrow`
     overloads this recording is wrapped in `try`/`catch(Exception)` so the
     override can keep the `nothrow` attribute.

   The result is intended to be mixed into a class deriving from `T` that
   also mixes in `MockImplCommon` and has `to`, `tuple`, `Parameters` and
   `ReturnType` in scope (see `Mock.MockAbstract`).

   Returns: the generated code, or `null` when not evaluated during CTFE.
 */
string implMixinStr(T)() {
    import std.array: join;
    import std.format : format;
    import std.traits: functionAttributes, FunctionAttribute, Parameters, ReturnType, arity;
    import std.conv: text;

    if(!__ctfe) return null;

    string[] lines;

    // Source text that names the i-th overload of `memberName` in T.
    string getOverload(in string memberName, in int i) {
        return `Identity!(__traits(getOverloads, T, "%s")[%s])`
            .format(memberName, i);
    }

    foreach(memberName; __traits(allMembers, T)) {

        static if(!isPrivate!(T, memberName)) {

            alias member = Identity!(__traits(getMember, T, memberName));

            static if(__traits(isVirtualMethod, member)) {
                foreach(i, overload; __traits(getOverloads, T, memberName)) {

                    // const/immutable methods can't be mocked: the override
                    // must modify calledFuncs/calledValues/returnValues.
                    static if(!(functionAttributes!overload & FunctionAttribute.const_) &&
                              !(functionAttributes!overload & FunctionAttribute.immutable_)) {

                        enum overloadName = text(memberName, "_", i);

                        enum overloadString = getOverload(memberName, i);
                        lines ~= "private alias %s_parameters = Parameters!(%s);".format(
                            overloadName, overloadString);
                        lines ~= "private alias %s_returnType = ReturnType!(%s);".format(
                            overloadName, overloadString);

                        static if(functionAttributes!overload & FunctionAttribute.nothrow_)
                            enum tryIndent = "    ";
                        else
                            enum tryIndent = "";

                        static if(is(ReturnType!overload == void))
                            enum returnDefault = "";
                        else {
                            // Queue of programmed return values, consumed front-first.
                            enum varName = overloadName ~ `_returnValues`;
                            lines ~= `%s_returnType[] %s;`.format(overloadName, varName);
                            lines ~= "";
                            enum returnDefault = [`    if(` ~ varName ~ `.length > 0) {`,
                                                  `        auto ret = ` ~ varName ~ `[0];`,
                                                  `        ` ~ varName ~ ` = ` ~ varName ~ `[1..$];`,
                                                  `        return ret;`,
                                                  `    } else`,
                                                  `        return %s_returnType.init;`.format(
                                                      overloadName)];
                        }

                        lines ~= `override ` ~ overloadName ~ "_returnType " ~ memberName ~
                            typeAndArgsParens!(Parameters!overload)(overloadName) ~ " " ~
                            functionAttributesString!overload ~ ` {`;

                        // Recording may throw (e.g. to!string), which a nothrow
                        // override is not allowed to do.
                        static if(functionAttributes!overload & FunctionAttribute.nothrow_)
                            lines ~= "try {";

                        lines ~= tryIndent ~ `    calledFuncs ~= "` ~ memberName ~ `";`;
                        lines ~= tryIndent ~ `    calledValues ~= tuple` ~
                            argNamesParens(arity!overload) ~ `.to!string;`;

                        static if(functionAttributes!overload & FunctionAttribute.nothrow_)
                            lines ~= "    } catch(Exception) {}";

                        lines ~= returnDefault;

                        lines ~= `}`;
                        lines ~= "";
                    }
                }
            }
        }
    }

    return lines.join("\n");
}

/// Returns `"(arg0, arg1, ...)"` with `N` names; `null` outside CTFE.
private string argNamesParens(int N) @safe pure {
    if(!__ctfe) return null;
    return "(" ~ argNames(N) ~ ")";
}

/// Returns `"arg0, arg1, ..."` with `N` names; `null` outside CTFE.
private string argNames(int N) @safe pure {
    import std.range;
    import std.algorithm;
    import std.conv;

    if(!__ctfe) return null;
    return iota(N).map!(a => "arg" ~ a.to!string).join(", ");
}

/**
   Returns a parameter list such as
   `"(<prefix>_parameters[0] arg0, <prefix>_parameters[1] arg1)"`,
   with one entry per type in `T`; `null` outside CTFE.
 */
private string typeAndArgsParens(T...)(string prefix) {
    import std.array;
    import std.conv;
    import std.format : format;

    if(!__ctfe) return null;

    string[] parts;

    foreach(i, t; T)
        parts ~= "%s_parameters[%s] arg%s".format(prefix, i, i);
    return "(" ~ parts.join(", ") ~ ")";
}

/**
   Returns the space-separated attributes of `F` that a mock override should
   repeat (`pure`, `nothrow`, `@trusted`, `@safe`, `@nogc`, `@system`,
   `shared`, `@property`); `null` outside CTFE.
 */
private string functionAttributesString(alias F)() {
    import std.traits: functionAttributes, FunctionAttribute;
    import std.array: join;

    if(!__ctfe) return null;

    string[] parts;

    const attrs = functionAttributes!F;

    if(attrs & FunctionAttribute.pure_) parts ~= "pure";
    if(attrs & FunctionAttribute.nothrow_) parts ~= "nothrow";
    if(attrs & FunctionAttribute.trusted) parts ~= "@trusted";
    if(attrs & FunctionAttribute.safe) parts ~= "@safe";
    if(attrs & FunctionAttribute.nogc) parts ~= "@nogc";
    if(attrs & FunctionAttribute.system) parts ~= "@system";
    // const and immutable can't be done since the mock needs
    // to alter state
    // if(attrs & FunctionAttribute.const_) parts ~= "const";
    // if(attrs & FunctionAttribute.immutable_) parts ~= "immutable";
    if(attrs & FunctionAttribute.shared_) parts ~= "shared";
    if(attrs & FunctionAttribute.property) parts ~= "@property";

    return parts.join(" ");
}

/**
   State and expectation/verification API shared by all mock kinds.

   Expected calls are stored in `expectedFuncs`/`expectedValues`, and actual
   calls are stored in `calledFuncs`/`calledValues`. Both value arrays hold
   the arguments formatted as `tuple(...).to!string`.
 */
mixin template MockImplCommon() {
    bool _verified;
    string[] expectedFuncs;
    string[] calledFuncs;
    string[] expectedValues;
    string[] calledValues;

    /**
       Records that `funcName` is expected to be called next. When `values`
       are given, the arguments must also match; otherwise any arguments are
       accepted.
     */
    void expect(string funcName, V...)(auto ref V values) {
        import std.conv: to;
        import std.typecons: tuple;

        expectedFuncs ~= funcName;
        static if(V.length > 0)
            expectedValues ~= tuple(values).to!string;
        else
            expectedValues ~= ""; // empty means "don't check the arguments"
    }

    /// Adds an expectation and verifies immediately; the mock can be verified again later.
    void expectCalled(string func, string file = __FILE__, size_t line = __LINE__, V...)(auto ref V values) {
        expect!func(values);
        verify(file, line);
        _verified = false;
    }

    /**
       Checks that the recorded calls match the expectations, in order.

       Throws: `UnitTestException` (via `fail` or directly) if the mock was
       already verified, if an expected call is missing or is for a different
       function, or if the arguments differ from the expected ones.
     */
    void verify(string file = __FILE__, size_t line = __LINE__) @safe pure {
        import std.range: repeat, take, join;
        import std.conv: to;
        import unit_threaded.should: fail, UnitTestException;

        if(_verified)
            fail("Mock already verified", file, line);

        _verified = true;

        foreach(i; 0 .. expectedFuncs.length) {

            if(i >= calledFuncs.length)
                fail("Expected nth " ~ i.to!string ~ " call to " ~ expectedFuncs[i] ~ " did not happen", file, line);

            if(expectedFuncs[i] != calledFuncs[i])
                fail("Expected nth " ~ i.to!string ~ " call to " ~ expectedFuncs[i] ~ " but got " ~ calledFuncs[i] ~
                     " instead",
                     file, line);

            if(expectedValues[i] != calledValues[i] && expectedValues[i] != "")
                throw new UnitTestException([expectedFuncs[i] ~ " was called with unexpected " ~ calledValues[i],
                                             " ".repeat.take(expectedFuncs[i].length + 4).join ~
                                             "instead of the expected " ~ expectedValues[i]],
                                            file, line);
        }
    }
}

private enum isString(alias T) = is(typeof(T) == string);

/**
   A mock object that conforms to an interface/class.
 */
struct Mock(T) {

    MockAbstract _impl;
    alias _impl this;

    /// Subclass of `T` whose virtual methods record calls (see `implMixinStr`).
    class MockAbstract: T {
        import std.conv: to;
        import std.traits: Parameters, ReturnType;
        import std.typecons: tuple;

        //static if(__traits(identifier, T) == "foobarbaz")
        //pragma(msg, "\nimplMixinStr for ", T, "\n\n", implMixinStr!T, "\n\n");
        mixin(implMixinStr!T);
        mixin MockImplCommon;
    }

    /// Creates the implementation object; the `int` argument is ignored.
    this(int/* force constructor*/) {
        _impl = new MockAbstract;
    }

    /**
       Verifies the expectations unless `verify` was already called.
       A default-initialized `Mock` has no implementation and nothing to verify.
     */
    ~this() pure @safe {
        if(_impl !is null && !_verified) verify;
    }

    /// Set the returnValue of a function to certain values.
    void returnValue(string funcName, V...)(V values) {
        assertFunctionIsVirtual!funcName;
        return returnValue!(0, funcName)(values);
    }

    /**
       This version takes overloads into account. i is the overload
       index. e.g.:
       ---------
       interface Interface { void foo(int); void foo(string); }
       auto m = mock!Interface;
       m.returnValue!(0, "foo"); // int overload
       m.returnValue!(1, "foo"); // string overload
       ---------
     */
    void returnValue(int i, string funcName, V...)(V values) {
        assertFunctionIsVirtual!funcName;
        import std.conv: text;
        // Matches the `<name>_<i>_returnValues` array generated by implMixinStr.
        enum varName = funcName ~ text(`_`, i, `_returnValues`);
        foreach(v; values)
            mixin(varName ~ ` ~= v;`);
    }

    // Compile-time check that `funcName` names a virtual method of T.
    private static void assertFunctionIsVirtual(string funcName)() {
        alias member = Identity!(__traits(getMember, T, funcName));

        static assert(__traits(isVirtualMethod, member),
                      "Cannot use returnValue on '" ~ funcName ~ "'");
    }
}

/// Returns one `import <module>;` line per argument; `null` outside CTFE.
private string importsString(string module_, string[] Modules...) {
    if(!__ctfe) return null;

    auto ret = `import ` ~ module_ ~ ";\n";
    foreach(extraModule; Modules) {
        ret ~= `import ` ~ extraModule ~ ";\n";
    }
    return ret;
}

/// Helper function for creating a Mock object.
auto mock(T)() {
    return Mock!T(0);
}

///
@("mock interface positive test no params")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int fun(Foo f) {
        return 2 * f.foo(5, "foobar");
    }

    auto m = mock!Foo;
    m.expect!"foo";
    fun(m);
}


///
@("mock interface positive test with params")
@safe pure unittest {
    import unit_threaded.asserts;

    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int fun(Foo f) {
        return 2 * f.foo(5, "foobar");
    }

    auto m = mock!Foo;
    m.expect!"foo"(5, "foobar");
    fun(m);
}


///
@("interface expectCalled")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int fun(Foo f) {
        return 2 * f.foo(5, "foobar");
    }

    auto m = mock!Foo;
    fun(m);
    m.expectCalled!"foo"(5, "foobar");
}

///
@("interface return value")
@safe pure unittest {

    interface Foo {
        int timesN(int i) @safe pure;
    }

    int fun(Foo f) {
        return f.timesN(3) * 2;
    }

    auto m = mock!Foo;
    m.returnValue!"timesN"(42);
    immutable res = fun(m);
    assert(res == 84);
}

///
@("interface return values")
@safe pure unittest {

    interface Foo {
        int timesN(int i) @safe pure;
    }

    int fun(Foo f) {
        return f.timesN(3) * 2;
    }

    auto m = mock!Foo;
    m.returnValue!"timesN"(42, 12);
    assert(fun(m) == 84);
    assert(fun(m) == 24);
    assert(fun(m) == 0);
}

@("default-initialized Mock can be destroyed")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
    }

    Mock!Foo m; // _impl is null; the destructor must not dereference it
}

/**
   Compile-time list of values that the mock function `function_` returns,
   one per call.
 */
struct ReturnValues(string function_, T...) if(from!"std.meta".allSatisfy!(isValue, T)) {
    alias funcName = function_;
    alias Values = T;

    /// The values as an array of the first value's type.
    static auto values() {
        typeof(T[0])[] ret;
        foreach(val; T) {
            ret ~= val;
        }
        return ret;
    }
}

/// True if `T` is an instantiation of `ReturnValues`.
enum isReturnValue(alias T) = is(T: ReturnValues!U, U...);
/// True if `T` is a value (it has a type).
enum isValue(alias T) = is(typeof(T));


/**
   Version of mockStruct that accepts 0 or more values of the same
   type. Whatever function is called on it, these values will
   be returned one by one, and `.init` of that type is returned once they
   run out. The limitation is that if more than one
   function is called on the mock, they all return the same type.
 */
auto mockStruct(T...)(auto ref T returns) {

    struct Mock {

        // Held by pointer so copies of the mock share recorded calls and return values.
        MockImpl* _impl;
        alias _impl this;

        static struct MockImpl {

            static if(T.length > 0) {
                alias FirstType = typeof(returns[0]);
                private FirstType[] _returnValues;
            }

            mixin MockImplCommon;

            auto opDispatch(string funcName, V...)(auto ref V values) {

                import std.conv: to;
                import std.typecons: tuple;

                calledFuncs ~= funcName;
                calledValues ~= tuple(values).to!string;

                static if(T.length > 0) {

                    if(_returnValues.length == 0) return typeof(_returnValues[0]).init;
                    auto ret = _returnValues[0];
                    _returnValues = _returnValues[1..$];
                    return ret;
                }
            }
        }
    }

    Mock m;
    m._impl = new Mock.MockImpl;
    static if(T.length > 0) {
        foreach(r; returns)
            m._impl._returnValues ~= r;
    }

    return m;
}

/**
   Version of mockStruct that accepts a compile-time mapping
   of function name to return values. Each template parameter
   must be an instantiation of `ReturnValues`.
 */
auto mockStruct(T...)() if(T.length > 0 && from!"std.meta".allSatisfy!(isReturnValue, T)) {

    struct Mock {
        mixin MockImplCommon;

        // Index of the next value to return, per function name.
        int[string] _retIndices;

        auto opDispatch(string funcName, V...)(auto ref V values) {

            import std.conv: to;
            import std.typecons: tuple;

            calledFuncs ~= funcName;
            calledValues ~= tuple(values).to!string;

            foreach(retVal; T) {
                static if(retVal.funcName == funcName) {
                    return retVal.values[_retIndices[funcName]++];
                }
            }
        }
    }

    Mock mock;

    foreach(retVal; T) {
        mock._retIndices[retVal.funcName] = 0;
    }

    return mock;
}

///
@("mock struct positive")
@safe pure unittest {
    void fun(T)(T t) {
        t.foobar;
    }
    auto m = mockStruct;
    m.expect!"foobar";
    fun(m);
    m.verify;
}


///
@("mock struct values positive")
@safe pure unittest {
    void fun(T)(T t) {
        t.foobar(2, "quux");
    }

    auto m = mockStruct;
    m.expect!"foobar"(2, "quux");
    fun(m);
    m.verify;
}


///
@("struct return value")
@safe pure unittest {

    int fun(T)(T f) {
        return f.timesN(3) * 2;
    }

    auto m = mockStruct(42, 12);
    assert(fun(m) == 84);
    assert(fun(m) == 24);
    assert(fun(m) == 0);
    m.expectCalled!"timesN";
}

///
@("struct expectCalled")
@safe pure unittest {
    void fun(T)(T t) {
        t.foobar(2, "quux");
    }

    auto m = mockStruct;
    fun(m);
    m.expectCalled!"foobar"(2, "quux");
}

///
@("mockStruct different return types for different functions")
@safe pure unittest {
    auto m = mockStruct!(ReturnValues!("length", 5),
                         ReturnValues!("greet", "hello"));
    assert(m.length == 5);
    assert(m.greet("bar") == "hello");
    m.expectCalled!"length";
    m.expectCalled!"greet"("bar");
}

///
@("mockStruct different return types for different functions and multiple return values")
@safe pure unittest {
    auto m = mockStruct!(ReturnValues!("length", 5, 3),
                         ReturnValues!("greet", "hello", "g'day"));
    assert(m.length == 5);
    m.expectCalled!"length";
    assert(m.length == 3);
    m.expectCalled!"length";

    assert(m.greet("bar") == "hello");
    m.expectCalled!"greet"("bar");
    assert(m.greet("quux") == "g'day");
    m.expectCalled!"greet"("quux");
}


/**
   A mock struct that always throws.

   Any member call throws a new `E` whose message is `"<name> was called"`,
   tagged with the caller's file and line. `R` is the declared return type of
   those calls.
 */
auto throwStruct(E = from!"unit_threaded.should".UnitTestException, R = void)() {

    struct Mock {

        R opDispatch(string funcName, string file = __FILE__, size_t line = __LINE__, V...)
                    (auto ref V values) {
            throw new E(funcName ~ " was called", file, line);
        }
    }

    return Mock();
}

///
@("throwStruct default")
@safe pure unittest {
    import std.exception: assertThrown;
    import unit_threaded.should: UnitTestException;
    auto m = throwStruct;
    assertThrown!UnitTestException(m.foo);
    assertThrown!UnitTestException(m.bar(1, "foo"));
}